Implementing a Custom Backward Pass in PyTorch for 3D Gaussian Splatting

In 3D Gaussian Splatting (3DGS), performance comes from a custom backward pass: the forward stage saves a minimal set of tensors, and inexpensive intermediates are recomputed during backpropagation to reduce memory usage.

In a previous blog post, we implemented the forward pass for 3D Gaussian Splatting (3DGS) from scratch. We covered how to project millions of 3D Gaussians into 2D, apply tile-based sorting, and perform compositing using explicit alpha blending. One of the defining characteristics of 3DGS, and a major reason for its training performance, is its custom backward pass, where we manually implement the gradient computations instead of relying on PyTorch's automatic differentiation.

In this article, we add that custom backward pass. We'll briefly explain why it's useful, show how to implement it using PyTorch's autograd.Function interface, and walk through a working example that computes the gradient with respect to color only. This is a simplified tutorial; full gradients (including position, opacity, rotation, and scale) are implemented in my full 3DGS course.

Why Implement a Custom Backward Pass in 3DGS?

While PyTorch's autograd can differentiate through most operations, it's not always optimal for performance-heavy rendering pipelines. In our forward pass, we loop over millions of Gaussians, tiles, and pixels, computing values like Gaussian weights g, per-pixel opacities α, and transmittance T. Letting PyTorch trace and store all these operations would be slow and memory-intensive.

By implementing our own backward pass, we:

Note: This approach is widely used in custom operators. For example, in fused layer normalization, a manually defined backward pass avoids materializing large intermediate tensors, leading to substantial memory savings.

Creating a Custom Autograd Function in PyTorch

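To define a custom backward, we subclass torch.autograd.Function and implement two static methods: forward(ctx, ...) and backward(ctx, grad_output). The function is invoked through RasterizerFunction.apply(...) rather than by calling forward directly, and PyTorch calls backward automatically when gradients are propagated through its output, for example during loss.backward().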

Inside the forward pass, we store any data we'll need later using ctx.save_for_backward(...), and optionally stash metadata (like image size or tile size) directly on ctx.
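To make the ctx mechanics concrete, here is a minimal sketch on a toy operation rather than the rasterizer itself. The class name ScaledBlend, its arguments, and the tensor sizes are hypothetical and chosen only for illustration: tensors go through ctx.save_for_backward, plain Python values are stored as attributes on ctx, and backward returns one value per forward input.

```python
import torch

class ScaledBlend(torch.autograd.Function):
    """Toy op (hypothetical): out = scale * (weights @ color)."""

    @staticmethod
    def forward(ctx, weights, color, scale):
        # Tensors needed later go through save_for_backward.
        ctx.save_for_backward(weights)
        # Non-tensor metadata can be stored as a plain attribute on ctx.
        ctx.scale = scale
        return scale * (weights @ color)               # (P, 3)

    @staticmethod
    def backward(ctx, grad_output):
        (weights,) = ctx.saved_tensors
        grad_color = ctx.scale * (weights.t() @ grad_output)   # (N, 3)
        # One return value per forward input: weights, color, scale.
        return None, grad_color, None

weights = torch.rand(5, 4)                             # 5 pixels, 4 Gaussians
color = torch.rand(4, 3, requires_grad=True)

out = ScaledBlend.apply(weights, color, 2.0)           # call via .apply, not .forward
out.sum().backward()                                   # triggers ScaledBlend.backward
print(color.grad.shape)
```

Only weights is saved here because it is the only tensor the color gradient needs; the same reasoning decides what the rasterizer keeps on ctx.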

Minimal Structure

import torch

class RasterizerFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, pos, color, opacity_raw, sigma, c2w, H, W, fx, fy, cx, cy,
                near, far, pix_guard, T, min_conis, chi_square_clip, alpha_max, alpha_cutoff):
        # Forward pass logic...
        ctx.save_for_backward(...)
        ctx.meta = (...)
        return final_image  # (H, W, 3)

    @staticmethod
    def backward(ctx, grad_output):
        # We'll fill this in next...
        pass

Computing Gradients: The Backward Pass

Let’s now build the backward method. PyTorch gives us grad_output, the gradient of the loss w.r.t. the output image. Our task is to compute the gradient w.r.t. the input color tensor.

The math

In the forward pass, each pixel's final RGB value is a weighted sum of the colors of the contributing Gaussians:

C_p = Σ_i w_{i,p} · color_i

where w_{i,p} is the weight of Gaussian i on pixel p.

Applying the chain rule, and summing over every pixel p that Gaussian i contributes to:

∂L/∂color_i = Σ_p w_{i,p} · ∂L/∂C_p

This expression shows that the color gradient depends on the same per-pixel weights used during the forward pass. These weights w are not stored explicitly, since doing so would be prohibitively expensive. Instead, we save on ctx all the information needed to reconstruct them, such as tile assignments, screen-space coordinates, opacities, and inverse covariances. During the backward pass, the weights are recomputed tile by tile using the same logic as in the forward pass.
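The color gradient can be sanity-checked on a toy case. The sketch below assumes the weights are already available as a dense matrix w of shape (pixels, Gaussians), which is not how the rasterizer stores them, and uses a hypothetical squared-error loss. It compares the manual chain-rule result against PyTorch's autograd.

```python
import torch

torch.manual_seed(0)
P, N = 6, 4                                    # toy sizes: pixels, Gaussians
w = torch.rand(P, N)                           # w[p, i]: weight of Gaussian i on pixel p
color = torch.rand(N, 3, requires_grad=True)

# Forward: each pixel is a weighted sum of Gaussian colors.
C = w @ color                                  # C[p] = sum_i w[p, i] * color[i]
target = torch.rand(P, 3)
loss = ((C - target) ** 2).sum()
loss.backward()                                # autograd reference, stored in color.grad

# Manual chain rule: dL/dcolor[i] = sum_p w[p, i] * dL/dC[p]
dL_dC = 2.0 * (C.detach() - target)            # gradient of the loss w.r.t. the image
manual_grad = w.t() @ dL_dC                    # (N, 3)

print(torch.allclose(color.grad, manual_grad, atol=1e-6))
```

The transpose in w.t() @ dL_dC is the sum over pixels: every pixel a Gaussian touches contributes to that Gaussian's color gradient.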

Backward Implementation Skeleton

@staticmethod
def backward(ctx, grad_output):
    saved = ctx.saved_tensors
    meta = ctx.meta

    grad_out_flat = grad_output.reshape(-1, 3)  # Flattened image gradients (reshape also handles non-contiguous grad_output)
    grad_color = torch.zeros((nb_gaussians, 3), device=..., dtype=...)

    for tile in tiles:
        # Recompute pixel positions
        # Recompute Gaussian weights w_i for each pixel (same as forward)
        # Multiply w_i by grad_out[pixel] and accumulate per-Gaussian

        dL_dcolor = (grad_out_flat[pixel_idx] * w.unsqueeze(-1)).sum(dim=1)
        grad_color.scatter_add_(...)

    return (None, grad_color, None, ..., None)

At the end of the backward pass, we return a tuple of gradients with exactly one entry per forward input (ctx excluded). For inputs that don't require gradients (e.g. camera intrinsics), we return None. In this example, only color gets a gradient (grad_color).

Important: The order of returned gradients must match the order of forward inputs.
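The recompute-in-backward pattern can be sketched end to end on a heavily simplified rasterizer. Everything here is an assumption made for illustration and not the implementation from this series: Gaussians are isotropic and already depth-sorted, a "tile" is just a contiguous chunk of pixels, and the names blend_weights and ToyRasterizer are hypothetical. Only the color gradient is computed.

```python
import torch

def blend_weights(pix, means2d, inv_sigma2, opacity):
    """w[p, i] = alpha[p, i] * T[p, i]; Gaussians are assumed depth-sorted."""
    d2 = ((pix[:, None, :] - means2d[None, :, :]) ** 2).sum(-1)      # (P, N)
    alpha = opacity[None, :] * torch.exp(-0.5 * d2 * inv_sigma2[None, :])
    # Transmittance in front of Gaussian i: product of (1 - alpha) over earlier ones.
    T = torch.cumprod(1.0 - alpha, dim=1)
    T = torch.cat([torch.ones_like(T[:, :1]), T[:, :-1]], dim=1)
    return alpha * T

class ToyRasterizer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, means2d, color, opacity, inv_sigma2, H, W, tile):
        ys, xs = torch.meshgrid(torch.arange(H), torch.arange(W), indexing="ij")
        pix = torch.stack([xs, ys], dim=-1).reshape(-1, 2).to(means2d)
        out = torch.empty(H * W, 3, dtype=color.dtype, device=color.device)
        for s in range(0, H * W, tile):                # one chunk of pixels per step
            w = blend_weights(pix[s:s + tile], means2d, inv_sigma2, opacity)
            out[s:s + tile] = w @ color
        # Save only the small inputs needed to rebuild w; w itself is discarded.
        ctx.save_for_backward(pix, means2d, opacity, inv_sigma2)
        ctx.meta = (H, W, tile)
        return out.view(H, W, 3)

    @staticmethod
    def backward(ctx, grad_output):
        pix, means2d, opacity, inv_sigma2 = ctx.saved_tensors
        H, W, tile = ctx.meta
        g = grad_output.reshape(-1, 3)
        grad_color = torch.zeros(means2d.shape[0], 3, dtype=g.dtype, device=g.device)
        for s in range(0, H * W, tile):
            # Recompute the weights with the same logic as forward.
            w = blend_weights(pix[s:s + tile], means2d, inv_sigma2, opacity)
            grad_color += w.t() @ g[s:s + tile]
        # Seven forward inputs, so seven return values in the same order.
        return None, grad_color, None, None, None, None, None

torch.manual_seed(0)
N, H, W = 5, 4, 6
means2d = torch.rand(N, 2) * torch.tensor([W, H])
opacity = torch.rand(N) * 0.9
inv_sigma2 = torch.full((N,), 0.5)
color = torch.rand(N, 3, requires_grad=True)

img = ToyRasterizer.apply(means2d, color, opacity, inv_sigma2, H, W, 8)
img.sum().backward()
```

The trade-off is visible in the two loops: backward pays for a second evaluation of blend_weights, and in exchange nothing of size (pixels × Gaussians) has to stay alive between the two passes.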

Summary of the Algorithm

Advantages of our implementation

As discussed above, implementing our own backward pass has major benefits:

If you ever write a custom backward pass, it's good practice to verify gradients using torch.autograd.gradcheck, which compares them against numerical finite differences and works best with double-precision inputs. We trust the math here, but in more complex cases, testing is key.
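As a minimal sketch of such a check, the snippet below runs torch.autograd.gradcheck on a toy function with a hand-written color gradient. The class WeightedColorSum is hypothetical and stands in for the rasterizer; inputs are created in double precision because the finite-difference comparison is unreliable in float32.

```python
import torch
from torch.autograd import gradcheck

class WeightedColorSum(torch.autograd.Function):
    """Toy stand-in (hypothetical): image = w @ color, gradient for color only."""

    @staticmethod
    def forward(ctx, w, color):
        ctx.save_for_backward(w)
        return w @ color

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        # No gradient for w, analytic gradient for color.
        return None, w.t() @ grad_output

w = torch.rand(6, 4, dtype=torch.double)                         # not differentiated
color = torch.rand(4, 3, dtype=torch.double, requires_grad=True)

# gradcheck perturbs each input that requires grad and compares the numerical
# Jacobian with the one produced by our backward.
ok = gradcheck(WeightedColorSum.apply, (w, color), eps=1e-6, atol=1e-4)
print(ok)
```

gradcheck only tests inputs that have requires_grad=True, so the None returned for w is not flagged. For a real rasterizer, a very small scene keeps the check fast.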

What's Next

We now have a custom differentiable rasterizer with a working backward pass. In my next blog post, I'll compare this custom backward to using PyTorch autograd end-to-end, focusing on:
