In 3D Gaussian Splatting (3DGS), performance comes from a custom backward pass: the forward stage saves a minimal set of tensors, and inexpensive intermediates are recomputed during backpropagation to reduce memory usage.
In a previous blog post, we implemented the forward pass for 3D Gaussian Splatting (3DGS) from scratch. We covered how to project millions of 3D Gaussians into 2D, apply tile-based sorting, and perform compositing using explicit alpha blending. One of the defining characteristics of 3DGS—and a major reason for its training performance—is its custom backward pass, where we manually implement the gradient computations instead of relying on PyTorch's automatic differentiation.
In this article, we add that custom backward pass. We'll briefly explain why it's useful, show how to implement it using
PyTorch's autograd.Function interface, and walk through a working example that computes
the gradient with respect to color only. This is a simplified tutorial; full gradients (including position, opacity,
rotation, and scale) are implemented in my full 3DGS course.
While PyTorch's autograd can differentiate through most operations, it's not always optimal for performance-heavy
rendering pipelines. In our forward pass, we loop over millions of Gaussians, tiles, and pixels, computing values
like Gaussian weights g, per-pixel opacities α, and
transmittance T. Letting PyTorch trace and store all these operations would be
slow and memory-intensive.
By implementing our own backward pass, we avoid tracing and storing all of these intermediates (g, α, and T), reducing memory usage and speeding up training.
To define a custom backward, we subclass torch.autograd.Function and implement two static
methods: forward(ctx, ...) and backward(ctx, grad_output).
PyTorch calls backward automatically when gradients propagate back through the graph (e.g. during loss.backward()).
Inside the forward pass, we store any data we'll need later using ctx.save_for_backward(...), and
optionally stash metadata (like image size or tile size) directly on ctx.
import torch

class RasterizerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, pos, color, opacity_raw, sigma, c2w, H, W, fx, fy, cx, cy,
                near, far, pix_guard, T, min_conis, chi_square_clip, alpha_max, alpha_cutoff):
        # Forward pass logic...
        ctx.save_for_backward(...)
        ctx.meta = (...)
        return final_image  # (H, W, 3)

    @staticmethod
    def backward(ctx, grad_output):
        # We'll fill this in next...
        pass
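Before filling in the rasterizer's backward, here is a minimal, self-contained sketch of the autograd.Function mechanics, using a toy weighted color blend in place of the full rasterizer. The names (WeightedBlend, weights) are illustrative, not part of the real pipeline.

```python
import torch

class WeightedBlend(torch.autograd.Function):
    @staticmethod
    def forward(ctx, colors, weights):
        # colors: (N, 3), weights: (N,) -> blended RGB (3,)
        ctx.save_for_backward(weights)
        return (weights.unsqueeze(-1) * colors).sum(dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        (weights,) = ctx.saved_tensors
        # d(blend)/d(colors_i) = w_i, so the chain rule just scales grad_output
        grad_colors = weights.unsqueeze(-1) * grad_output.unsqueeze(0)
        return grad_colors, None  # no gradient for weights in this sketch

colors = torch.rand(4, 3, requires_grad=True)
weights = torch.tensor([0.4, 0.3, 0.2, 0.1])
out = WeightedBlend.apply(colors, weights)
out.sum().backward()
# Since dL/d(out) is all ones here, each row of colors.grad equals its weight
print(colors.grad)
```

Note that we call the Function through .apply, never by instantiating it; this is what registers the custom backward in the autograd graph.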
Let’s now build the backward method. PyTorch gives us grad_output, the gradient of the loss
w.r.t. the output image. Our task is to compute the gradient w.r.t. the input color tensor.
In the forward pass, each pixel’s final RGB value is a weighted sum of the colors of contributing Gaussians:
C_pixel = Σ_i w_i · color_i

where w_i is the weight of Gaussian i on that pixel.
Applying the chain rule, each Gaussian's color gradient accumulates contributions from every pixel it touches:

∂L/∂color_i = Σ_pixels w_i · ∂L/∂C_pixel
This expression shows that the color gradient depends on the same per-pixel weights used during the forward pass.
Although these weights w_i are not stored explicitly—doing so would be prohibitively expensive—we retain (in ctx) all the information needed to reconstruct them: tile assignments, screen-space coordinates, opacities, and inverse covariances.
During the backward pass, the weights are therefore recomputed tile-by-tile using the same logic as in the forward pass.
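The accumulation step can be illustrated with a toy example: two pixels and three Gaussians, with the per-pixel weights hard-coded for clarity (in the real pipeline they are recomputed tile-by-tile). All tensor names here are illustrative.

```python
import torch

grad_out_flat = torch.tensor([[1.0, 0.0, 0.0],   # dL/dC for pixel 0
                              [0.0, 1.0, 0.0]])  # dL/dC for pixel 1

# Contribution list: pixel 0 is covered by Gaussians 0 and 2,
# pixel 1 by Gaussians 1 and 2, each with weight w_i
gauss_idx = torch.tensor([0, 2, 1, 2])
pixel_idx = torch.tensor([0, 0, 1, 1])
w = torch.tensor([0.7, 0.3, 0.5, 0.5])

grad_color = torch.zeros(3, 3)
# w_i * dL/dC_pixel for every (pixel, Gaussian) pair
contrib = w.unsqueeze(-1) * grad_out_flat[pixel_idx]
# Scatter each contribution onto its Gaussian's gradient row
grad_color.scatter_add_(0, gauss_idx.unsqueeze(-1).expand(-1, 3), contrib)
print(grad_color)
```

Gaussian 2 appears under both pixels, so its row sums both contributions—exactly the Σ over pixels in the chain rule above.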
@staticmethod
def backward(ctx, grad_output):
    saved = ctx.saved_tensors
    meta = ctx.meta
    grad_out_flat = grad_output.view(-1, 3)  # Flattened image gradients
    grad_color = torch.zeros((nb_gaussians, 3), device=..., dtype=...)

    for tile in tiles:
        # Recompute pixel positions
        # Recompute Gaussian weights w_i for each pixel (same as forward)
        # Multiply w_i by grad_out[pixel] and accumulate per-Gaussian
        dL_dcolor = (grad_out_flat[pixel_idx] * w.unsqueeze(-1)).sum(dim=1)
        grad_color.scatter_add_(...)

    return (None, grad_color, None, ..., None)
At the end of the backward pass, we return a tuple of gradients, one per input. For inputs that don’t require gradients (e.g. camera intrinsics),
we return None. In this example, only color gets a gradient (grad_color).
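The one-gradient-per-input convention is easy to see on a small Function. This sketch (ScaleShift is a made-up name) takes three inputs but only the tensor gets a gradient:

```python
import torch

class ScaleShift(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, scale, shift):
        ctx.scale = scale  # plain Python metadata stashed on ctx, like our meta tuple
        return x * scale + shift

    @staticmethod
    def backward(ctx, grad_output):
        # One return value per forward input, in order;
        # None for inputs (scale, shift) that don't need gradients.
        return grad_output * ctx.scale, None, None

x = torch.ones(3, requires_grad=True)
y = ScaleShift.apply(x, 2.0, 1.0)
y.sum().backward()
print(x.grad)  # tensor([2., 2., 2.])
```

In our rasterizer, the same rule is why the return tuple is mostly None: only the color slot carries a gradient in this simplified version.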
To recap the backward pass: we use the saved data (in ctx) to identify which Gaussians influence which pixels, recompute the weight w per pixel from Gaussian opacity, visibility, and the spatial kernel, and accumulate Σ w · grad_pixel per Gaussian. As discussed above, implementing our own backward pass has major benefits: less memory pressure and faster training.
If you ever write a custom backward pass, it’s good practice to verify gradients using
torch.autograd.gradcheck or numerical finite differences. We trust the math here,
but in more complex cases, testing is key.
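As a quick illustration of gradcheck, here is a custom exponential with a hand-written backward (MyExp is a toy, not part of the rasterizer). gradcheck compares the analytic gradient against finite differences and needs double-precision inputs with requires_grad=True:

```python
import torch

class MyExp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        y = torch.exp(x)
        # Save the output: d(exp(x))/dx = exp(x), so we reuse y in backward
        ctx.save_for_backward(y)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        (y,) = ctx.saved_tensors
        return grad_output * y

x = torch.randn(4, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(MyExp.apply, (x,))
print(ok)  # True when analytic and numerical gradients agree
```

If the hand-written backward were wrong (say, returning grad_output alone), gradcheck would raise with a detailed mismatch report instead of returning True.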
We now have a custom differentiable rasterizer with a working backward pass. In my next blog, I'll compare this custom backward to using PyTorch autograd end-to-end, focusing on speed and memory usage.
Want to be notified when that article goes live? Subscribe to my newsletter for future posts, updates, and practical guides on PyTorch, 3DGS, and differentiable rendering.
Want to truly understand 3D Gaussian Splatting—not just run a repo? My 3D Gaussian Splatting Course teaches the full pipeline from first principles in PyTorch only (no C++, no CUDA). You’ll learn initialization, rasterization, backward passes, training loops, and how to experiment with recent papers.
Explore the Course →

We help teams integrate 3D Gaussian Splatting techniques, build custom pipelines, and prototype new splatting research. If you need expertise, we can help.
Contact:
contact@qubitanalytics.be