An intro to feedforward 3DGS — implementing Splatter Image from scratch in PyTorch
In my previous blog, we tackled single-view 3D reconstruction by optimizing a NeRF from a learned initialization. Meta-learning gave us a category-aware starting point, so adaptation only needed a few gradient steps.
But the optimization loop was still there.
This post takes the other route: feedforward 3D Gaussian Splatting. Instead of optimizing 3D parameters per scene, we train a network that maps an image to a full 3DGS representation in a single forward pass. No test-time optimization, no inner loop, no hours of fitting — one image in, a complete set of 3D Gaussians out.
The idea comes from the paper:
Splatter Image: Ultra-Fast Single-View 3D Reconstruction — Szymanowicz, Rupprecht, Vedaldi (CVPR 2024)
The implementation below is a compact PyTorch version of that idea, applied to single-view object reconstruction on the ShapeNet dataset. This is also the first post in a short series on feedforward 3DGS, so we will be revisiting and extending the method in upcoming posts.
Standard 3D Gaussian Splatting fits a set of Gaussians to a scene by optimization: you start from a sparse point cloud, define a photometric loss against multi-view images, and run gradient descent for thousands of iterations. The result is excellent quality — at the cost of minutes to hours per scene.
Feedforward 3DGS asks the opposite question:
Can we predict the Gaussians directly with a neural network, instead of optimizing them?
If a network has seen enough scenes from the same distribution, it can learn the priors that optimization usually has to rediscover from scratch. The network outputs a 3DGS representation in one forward pass, and we render with the standard differentiable rasterizer.
Splatter Image is the simplest, most elegant instance of this idea: one Gaussian per input pixel, predicted by a UNet.
3D Gaussian Splatting is fast at rendering, but slow at fitting. Optimizing a 3DGS scene from images can take minutes to hours. Feedforward 3DGS replaces that loop with a single forward pass.
The Splatter Image trick:
What if a UNet could predict 3DGS parameters directly from an image, the way it predicts a depth map?
Concretely: each pixel of the input image becomes one 3D Gaussian. A UNet that takes a [3, H, W] image and outputs a [15, H, W] feature map encodes all the parameters of H × W Gaussians at once.
That gives us H × W Gaussians, one per pixel, from a single forward pass: the 3D representation has the same shape as an image.
Conceptually, this is the opposite of last week's post: predict the representation rather than optimize it.
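To make the shape concrete, here is a minimal sketch of what such a model can look like. The backbone internals don't matter for this post; the only structural commitment is a 15-channel output head, which the init code later assumes is exposed as model.head. The class name SplatterUNet and the feature width are illustrative, not the paper's architecture.

import torch.nn as nn

class SplatterUNet(nn.Module):
    # Minimal sketch: any image-to-image backbone works here; only the
    # 15-channel output head is structurally required.
    def __init__(self, backbone, feat_ch=64):
        super().__init__()
        self.backbone = backbone                            # [B, 3, H, W] -> [B, feat_ch, H, W]
        self.head = nn.Conv2d(feat_ch, 15, kernel_size=1)   # 15 Gaussian params per pixel

    def forward(self, x):
        return self.head(self.backbone(x))                  # [B, 15, H, W]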
The pipeline has three stages:
- A UNet maps the source image [3, H, W] to a raw parameter map [15, H, W].
- A decoder turns the raw map into world-space Gaussian parameters (means, covariances, opacities, colors).
- A differentiable rasterizer renders those Gaussians from a target camera.
Training is supervised on novel views. For each scene, we sample two views — one source and one target. The model only ever sees the source. Its predicted Gaussians are rendered from the target camera and compared to the target image.
Each of the 15 output channels has a fixed meaning:
- channel 0: opacity (raw logit)
- channels 1–3: XYZ offset delta off the pixel ray
- channel 4: depth along the ray (pre-sigmoid)
- channels 5–7: log-scales of the Gaussian
- channels 8–11: rotation quaternion (xyzw)
- channels 12–14: RGB color (pre-sigmoid)
The model never directly outputs a 3D position. Instead, it outputs a depth along the ray plus a small offset. The position of each Gaussian is then origin + u · depth + delta, where u is the normalized pixel ray. This conditioning on the source camera makes training significantly more stable — the network only has to learn local geometry instead of arbitrary world-space coordinates.
The decoder takes the raw [1, 15, H, W] UNet output and produces world-space Gaussian parameters. Let's start with the means (the 3D centers).
import torch

def decode_gaussians(raw, source_c2w, fx, fy, cx, cy, znear, zfar, opacity_threshold=0.0):
    device = raw.device
    dtype = raw.dtype
    _, _, H, W = raw.shape
    # Slice the 15 channels into their fixed roles.
    opacity_raw = raw[:, 0:1]
    delta = raw[:, 1:4]        # small XYZ correction off the pixel ray
    depth_raw = raw[:, 4:5]
    scale_raw = raw[:, 5:8]
    quat_raw = raw[:, 8:12]
    color_raw = raw[:, 12:15]
    # Squash depth into [znear, zfar]; the clamp keeps the sigmoid out of its flat tails.
    depth = (zfar - znear) * torch.sigmoid(depth_raw.clamp(-10, 10)) + znear
    # Per-pixel ray directions in camera coordinates, shape [1, 3, H, W].
    u = build_pixel_grid_from_intrinsics(H, W, fx, fy, cx, cy, device, dtype).unsqueeze(0)
    Rcw = source_c2w[:3, :3]   # camera-to-world rotation
    tcw = source_c2w[:3, 3]    # camera position in world coordinates
    Rwc = Rcw.t()
    origin_cam = (Rwc @ (-tcw)).view(1, 3, 1, 1)  # world origin expressed in camera coordinates
    # Means in camera coordinates: anchor + ray * depth + learned offset.
    mean_cam = torch.empty((1, 3, H, W), device=device, dtype=dtype)
    mean_cam[:, 0:1] = origin_cam[:, 0:1] + u[:, 0:1] * depth + delta[:, 0:1]
    mean_cam[:, 1:2] = origin_cam[:, 1:2] + u[:, 1:2] * depth + delta[:, 1:2]
    mean_cam[:, 2:3] = origin_cam[:, 2:3] + depth + delta[:, 2:3]  # ray z-component is 1
Three things happen here:
- The raw depth channel is squashed through a sigmoid and rescaled into [znear, zfar]. This bounds the predicted depths to the scene's expected range.
- build_pixel_grid_from_intrinsics returns the per-pixel ray direction in camera coordinates: ((x + 0.5 - cx) / fx, (y + 0.5 - cy) / fy, 1). Multiplying by depth gives a 3D point along that ray. (A sketch of this helper follows below.)
- delta adds a small XYZ correction. This lets the network move a Gaussian slightly off its ray when geometry doesn't lie exactly on it.

I'm computing means in camera coordinates here — the world-space transformation comes next.
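The helper itself isn't shown in the snippet above; a minimal sketch consistent with how it is used (and with the ray formula in the list) could look like this:

def build_pixel_grid_from_intrinsics(H, W, fx, fy, cx, cy, device, dtype):
    # Pixel-center ray directions in camera coordinates, shape [3, H, W].
    ys = torch.arange(H, device=device, dtype=dtype)
    xs = torch.arange(W, device=device, dtype=dtype)
    y, x = torch.meshgrid(ys, xs, indexing="ij")
    u = (x + 0.5 - cx) / fx    # horizontal component
    v = (y + 0.5 - cy) / fy    # vertical component
    return torch.stack([u, v, torch.ones_like(u)], dim=0)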
Each 3D Gaussian needs a covariance matrix Σ. The standard 3DGS parameterization writes it as Σ = R S² Rᵀ, where R is a rotation matrix built from a quaternion and S is a diagonal scale matrix.
    # ...continuing inside decode_gaussians: activations for the remaining channels.
    color = torch.sigmoid(color_raw)
    quat_xyzw = torch.nn.functional.normalize(quat_raw, dim=1, eps=1e-6)
    # Flatten the [C, H, W] maps into per-Gaussian rows of shape [H*W, C].
    pos_cam = mean_cam[0].permute(1, 2, 0).reshape(-1, 3)
    opacity_raw_flat = opacity_raw[0].reshape(-1)
    scale_raw_flat = scale_raw[0].permute(1, 2, 0).reshape(-1, 3)
    quat_flat = quat_xyzw[0].permute(1, 2, 0).reshape(-1, 4)
    color_flat = color[0].permute(1, 2, 0).reshape(-1, 3)
    # Transform means from camera to world coordinates.
    pos_world = (Rcw @ pos_cam.t()).t() + tcw.unsqueeze(0)
    # Positive scales via exp; the clamps keep them numerically sane.
    scales = torch.exp(scale_raw_flat.clamp(-10.0, 3.0)).clamp_min(1e-6)
    R_local = quat_xyzw_to_rotmat(quat_flat)
    S = torch.diag_embed(scales)
    # Covariance in camera coordinates: sigma = R S^2 R^T ...
    sigma_cam = R_local @ S @ S @ R_local.transpose(1, 2)
    # ... then rotated into world coordinates.
    sigma_world = Rcw.unsqueeze(0) @ sigma_cam @ Rcw.t().unsqueeze(0)
    return pos_world, color_flat, opacity_raw_flat, sigma_world
A few things worth pointing out:
- Scales go through exp (with clamping), so they are always positive; quaternions are normalized to unit length before being turned into rotation matrices (see the quat_xyzw_to_rotmat sketch below).
- The covariance is built in camera coordinates and then rotated into world space: Σ_world = R_cw Σ_cam R_cwᵀ. This is consistent with the way positions are computed (in camera coordinates first), and lets the network learn local shape independently of scene-level pose.
- The output is the canonical per-Gaussian 3DGS tuple: position, color, opacity, covariance — one per pixel.
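quat_xyzw_to_rotmat also comes from a helper module not shown here. Assuming xyzw component order (as the name suggests), a standard batched implementation is:

def quat_xyzw_to_rotmat(q):
    # q: [N, 4] unit quaternions in (x, y, z, w) order -> [N, 3, 3] rotation matrices.
    x, y, z, w = q.unbind(dim=1)
    return torch.stack([
        1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y),
        2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x),
        2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y),
    ], dim=1).view(-1, 3, 3)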
There is a quietly critical piece in this implementation: how the final output head of the UNet is initialized.
import math

def init_output_head(model, znear=0.8, zfar=1.8, *, opacity_bias=-3.5, depth0=1.2,
                     scale_bias=0.02, opacity_gain=1.0, xyz_gain=0.1, depth_gain=1.0,
                     scale_gain=0.1, quat_gain=1.0, rgb_gain=5.0):
    def inv_sigmoid_scalar(y, eps=1e-6):
        # Inverse of the sigmoid: the pre-activation that sigmoid maps to y.
        y = max(eps, min(1.0 - eps, float(y)))
        return math.log(y / (1.0 - y))

    with torch.no_grad():
        weight = model.head.weight   # [15, C, k, k] conv weight of the output head
        bias = model.head.bias       # [15]
        weight.zero_()
        bias.zero_()
        # Pre-activation whose sigmoid lands at depth0 inside [znear, zfar].
        depth01 = (depth0 - znear) / (zfar - znear)
        depth_bias = inv_sigmoid_scalar(depth01)
        # Channel 0: opacity, biased to start nearly transparent.
        torch.nn.init.xavier_uniform_(weight[0], gain=opacity_gain)
        torch.nn.init.constant_(bias[0], opacity_bias)
        # Channels 1-3: XYZ offset; small gain keeps Gaussians near their rays.
        torch.nn.init.xavier_uniform_(weight[1:4], gain=xyz_gain)
        # Channel 4: depth, biased toward depth0.
        torch.nn.init.xavier_uniform_(weight[4], gain=depth_gain)
        torch.nn.init.constant_(bias[4], depth_bias)
        # Channels 5-7: log-scales, biased to log(scale_bias), i.e. small splats.
        torch.nn.init.xavier_uniform_(weight[5:8], gain=scale_gain)
        torch.nn.init.constant_(bias[5:8], math.log(scale_bias))
        # Channels 8-11: quaternion; bias the w component so the init is identity rotation.
        torch.nn.init.xavier_uniform_(weight[8:12], gain=quat_gain)
        torch.nn.init.constant_(bias[11], 1.0)
        # Channels 12-14: RGB color.
        torch.nn.init.xavier_uniform_(weight[12:15], gain=rgb_gain)
Without this step, the network produces nonsense Gaussians at iteration zero and training is unstable. With it, the very first forward pass already produces something sensible.
What each line buys us:
- opacity_bias = -3.5 means the opacity sigmoid starts at roughly 0.03, so Gaussians begin nearly transparent instead of saturating the first renders.
- For depth, we pick depth0 = 1.2 (near the middle of the scene volume) and invert the sigmoid to get the corresponding pre-activation. At init, every Gaussian sits roughly in the middle of the depth range (checked numerically below).
- scale_bias = 0.02 sets the log-scale bias so splats start small.
- bias[11] = 1 puts the unit quaternion in canonical position.

If you remember one thing from this post, remember this: reasonable parameterization is half the battle in feed-forward 3D models. A clean output head turns a divergent loss into one that descends from step one.
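A quick numeric check of those two claims, in plain Python with the default values copied from init_output_head above:

import math

def sigmoid(t):
    return 1 / (1 + math.exp(-t))

znear, zfar, depth0 = 0.8, 1.8, 1.2
depth01 = (depth0 - znear) / (zfar - znear)        # 0.4
depth_bias = math.log(depth01 / (1 - depth01))     # about -0.405, the pre-activation
print((zfar - znear) * sigmoid(depth_bias) + znear)  # 1.2, back to depth0 as intended
print(sigmoid(-3.5))                                 # about 0.029, nearly transparent at init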
We use the NMR (Neural Mesh Renderer) cars dataset, with multi-view renderings and known camera matrices.
import json
import os
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

class NMRDataset(Dataset):
    def __init__(self, data_path, json_path, train=True, n_views=24, H=64, W=64):
        self.data_path = data_path
        self.n_views = n_views
        self.H = H
        self.W = W
        with open(json_path, "r") as f:
            split = json.load(f)
        self.scenes = [
            os.path.join(data_path, f) for f in sorted(split["train" if train else "test"])]
        gt_pixels, c2ws, intrinsics = [], [], []
        for scene_path in tqdm(self.scenes, desc="loading scenes"):
            cam_data = np.load(os.path.join(scene_path, "cameras.npz"))
            scene_gt_pixels = torch.zeros((n_views, H, W, 3), dtype=torch.float32)
            scene_c2ws = torch.zeros((n_views, 4, 4), dtype=torch.float32)
            scene_intrinsics = torch.zeros((n_views, 4, 4), dtype=torch.float32)
            for view_idx in range(n_views):
                img = np.array(Image.open(
                    os.path.join(scene_path, "image", f"{view_idx:04d}.png")).convert("RGB"))
                c2w = cam_data[f"world_mat_inv_{view_idx}"]   # camera-to-world matrix
                K = cam_data[f"camera_mat_{view_idx}"]        # intrinsics
                scene_gt_pixels[view_idx] = torch.from_numpy(img).float() / 255.0
                scene_c2ws[view_idx] = torch.from_numpy(c2w).float()
                scene_intrinsics[view_idx] = torch.from_numpy(K).float()
            gt_pixels.append(scene_gt_pixels)
            c2ws.append(scene_c2ws)
            intrinsics.append(scene_intrinsics)
        # Stack everything once so __getitem__ is pure tensor indexing.
        self.gt_pixels = torch.stack(gt_pixels, dim=0)     # [B, N, H, W, 3]
        self.c2ws = torch.stack(c2ws, dim=0)               # [B, N, 4, 4]
        self.intrinsics = torch.stack(intrinsics, dim=0)   # [B, N, 4, 4]
Each scene has 24 views, each with image + camera-to-world matrix + intrinsics. Everything is pre-loaded into tensors, so training is I/O-free.
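To see why pre-loading everything is feasible, here is a back-of-the-envelope footprint. The scene count is a hypothetical figure for illustration, not the actual split size:

n_scenes, n_views, H, W = 2000, 24, 64, 64     # n_scenes is assumed, not the real split size
bytes_images = n_scenes * n_views * H * W * 3 * 4    # float32 RGB images
bytes_cams = n_scenes * n_views * 2 * 16 * 4         # one c2w + one intrinsics 4x4 per view
print(f"{(bytes_images + bytes_cams) / 1e9:.2f} GB")  # ~2.37 GB: comfortably fits in RAM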
The __getitem__ method is where the source/target sampling happens. For each batch element, we pick one source view (the input to the model) and a different target view (the supervision signal):
    def __getitem__(self, scene_idx):   # method of NMRDataset
        scene_imgs = self.gt_pixels[scene_idx]
        scene_c2ws = self.c2ws[scene_idx]
        scene_intr = self.intrinsics[scene_idx]
        # One random source view (model input) and one different target view (supervision).
        src_idx, tgt_idx = sample_source_target_indices(self.n_views)
        return {"src_img": scene_imgs[src_idx].permute(2, 0, 1),   # [3, H, W]
                "tgt_img": scene_imgs[tgt_idx].permute(2, 0, 1),
                "source_c2w": scene_c2ws[src_idx],
                "target_c2w": scene_c2ws[tgt_idx],
                "source_cam": scene_intr[src_idx],
                "target_cam": scene_intr[tgt_idx],
                "meta": torch.tensor([scene_idx, src_idx, tgt_idx], dtype=torch.long)}

def sample_source_target_indices(n_views):
    # Draw the target from n_views - 1 slots, then shift it past the source:
    # guarantees src != tgt without rejection sampling.
    src_idx = random.randrange(n_views)
    tgt_idx = random.randrange(n_views - 1)
    if tgt_idx >= src_idx:
        tgt_idx += 1
    return src_idx, tgt_idx
The sample_source_target_indices trick guarantees src_idx ≠ tgt_idx without a rejection loop: pick the target from n_views - 1 slots, then shift it past the source if needed.
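A quick way to convince yourself the shifted draw is correct and uniform over ordered pairs:

import random
from collections import Counter

random.seed(0)
counts = Counter(sample_source_target_indices(4) for _ in range(120_000))
assert all(src != tgt for src, tgt in counts)      # never samples the source as target
print(len(counts))                                  # 12: all 4 * 3 ordered pairs occur
print(min(counts.values()), max(counts.values()))   # roughly equal counts: uniform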
The training loop is short.
for step in tqdm(range(1, 800_001)):
    # Cycle the DataLoader indefinitely.
    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(train_loader)
        batch = next(train_iter)
    batch_src_imgs = batch["src_img"].to(device, non_blocking=True)
    batch_tgt_imgs = batch["tgt_img"].to(device, non_blocking=True)
    batch_source_c2w = batch["source_c2w"].to(device, non_blocking=True)
    batch_target_c2w = batch["target_c2w"].to(device, non_blocking=True)
    batch_source_cam = batch["source_cam"].to(device, non_blocking=True)
    batch_target_cam = batch["target_cam"].to(device, non_blocking=True)
    optimizer.zero_grad(set_to_none=True)
    # One batched UNet forward pass over all source images.
    raw_batch = model(batch_src_imgs)  # [B, 15, H, W]
    batch_loss = 0.0
    # Decoding and rendering run per sample (the renderer handles one scene at a time).
    for b in range(batch_size):
        fx_src, fy_src, cx_src, cy_src = intrinsics_to_fxfycxcy(batch_source_cam[b], H, W)
        fx_t, fy_t, cx_t, cy_t = intrinsics_to_fxfycxcy(batch_target_cam[b], H, W)
        raw_b = raw_batch[b:b + 1]
        # Decode with the *source* camera, render from the *target* camera.
        pos, color, opacity_raw, sigma = decode_gaussians(
            raw=raw_b, source_c2w=batch_source_c2w[b], fx=fx_src, fy=fy_src, cx=cx_src,
            cy=cy_src, znear=znear, zfar=zfar, opacity_threshold=0.0)
        pred = render(pos=pos, color=color, opacity_raw=opacity_raw, sigma=sigma,
                      c2w=batch_target_c2w[b], H=H, W=W, fx=fx_t, fy=fy_t, cx=cx_t,
                      cy=cy_t).permute(2, 0, 1).unsqueeze(0)
        tgt = batch_tgt_imgs[b:b + 1]
        # Photometric loss on the novel view is the only supervision.
        loss_b = torch.nn.functional.mse_loss(pred, tgt)
        batch_loss = batch_loss + loss_b
    batch_loss = batch_loss / batch_size
    batch_loss.backward()
    optimizer.step()
Per step, we:
- run the UNet once on the full batch of source images to get [B, 15, H, W] predictions,
- decode each sample's raw output into world-space Gaussians using its source camera,
- render those Gaussians from the target camera and compare against the target image with an MSE loss,
- average the per-sample losses and take an optimizer step.

The renderer itself (render, imported from a separate module) is a standard differentiable Gaussian rasterizer. I cover the rasterization details — including how gradients flow back through it, which is what makes feedforward 3DGS trainable end-to-end — in my 3D Gaussian Splatting course. Here we just call it.
One small note: I'm looping over the batch dimension explicitly because the renderer is implemented for a batch size of one. The forward pass through the UNet is fully batched; only the decoding and rendering happen one scene at a time.
After training, we run the model on unseen test scenes. For each one, we pick a single source view, predict Gaussians, and render from a grid of held-out target cameras.
import matplotlib.pyplot as plt

@torch.no_grad()
def render_novel_view_grid(model, gt_pixels, c2ws, intrinsics, H, W, znear, zfar,
                           device, save_path="splatter_image.png", num_test_scenes=10):
    novel_view_indices = [1, 2, 4, 7, 10, 13, 16, 19, 22, 23]
    ncols = len(novel_view_indices)
    fig, axes = plt.subplots(num_test_scenes, ncols,
                             figsize=(2.2 * ncols, 2.2 * num_test_scenes), dpi=300)
    for scene_idx in range(num_test_scenes):
        scene_imgs = gt_pixels[scene_idx]
        scene_c2ws = c2ws[scene_idx]
        scene_intr = intrinsics[scene_idx]
        # One source view per scene; all remaining views are novel targets.
        src_view_idx = novel_view_indices[scene_idx]
        tgt_view_indices = [v for v in novel_view_indices if v != src_view_idx]
        src_img = scene_imgs[src_view_idx].permute(2, 0, 1).unsqueeze(0).to(device)
        source_c2w = scene_c2ws[src_view_idx].to(device)
        source_cam = scene_intr[src_view_idx].to(device)
        fx_src, fy_src, cx_src, cy_src = intrinsics_to_fxfycxcy(source_cam, H, W)
        # Predict and decode the Gaussians once per scene.
        raw = model(src_img)
        pos, color, opacity_raw, sigma = decode_gaussians(
            raw=raw, source_c2w=source_c2w, fx=fx_src, fy=fy_src, cx=cx_src, cy=cy_src,
            znear=znear, zfar=zfar)
        # Column 0 shows the source view the model actually saw.
        axes[scene_idx, 0].imshow(scene_imgs[src_view_idx].numpy().clip(0, 1))
        axes[scene_idx, 0].axis("off")
        # Re-render the same Gaussians from every held-out camera.
        for col_idx, view_idx in enumerate(tgt_view_indices, start=1):
            target_c2w = scene_c2ws[view_idx].to(device)
            target_cam = scene_intr[view_idx].to(device)
            fx_t, fy_t, cx_t, cy_t = intrinsics_to_fxfycxcy(target_cam, H, W)
            pred = render(pos=pos, color=color, opacity_raw=opacity_raw, sigma=sigma,
                          c2w=target_c2w, H=H, W=W, fx=fx_t, fy=fy_t, cx=cx_t, cy=cy_t)
            axes[scene_idx, col_idx].imshow(pred.detach().cpu().numpy().clip(0, 1))
            axes[scene_idx, col_idx].axis("off")
    fig.savefig(save_path, bbox_inches="tight")
    plt.close(fig)
This is the moment of truth. The model has never seen these objects. It only sees one image of each. And yet the predicted Gaussians can be rendered from arbitrary novel viewpoints in plausible-looking ways — all in a single forward pass, no optimization at test time.
In the next post in this feedforward 3DGS series, we will look at how to predict multiple Gaussians per pixel, so the network can also represent occluded geometry. That single change pushes single-view reconstruction quality dramatically closer to multi-view methods, and it is one of the main directions modern feedforward 3DGS work has taken.
If you want to be notified when that one comes out, subscribe to the newsletter below.
Splatter Image is a small but striking demonstration of how parameterization shapes what neural networks can learn. The whole feedforward 3DGS pipeline works because:
- the output is image-shaped (one Gaussian per pixel), so a plain UNet can predict it;
- positions are parameterized as depth along the source ray plus a small offset, so the network learns local geometry instead of arbitrary world coordinates;
- the output head is initialized so the very first forward pass already produces sensible Gaussians;
- the differentiable rasterizer lets a photometric loss on novel views supervise everything end to end.
There is no per-scene optimization loop. There is no meta-learning. Just a UNet, a clean parameterization, and a renderer.
Do you want to truly understand 3D Gaussian Splatting—not just run a repo? My 3D Gaussian Splatting Course teaches you the full pipeline from first principles, including the differentiable rasterizer used in this post. Everything is broken down into clear modules with code you can actually read and modify.
Explore the Course →

Subscribe to my newsletter for future posts, updates, and practical guides on PyTorch, 3DGS, and differentiable rendering. The next one covers how to lift the one-splat-per-pixel limitation — you don't want to miss it.
Subscribe to the Newsletter →

We help teams bridge the gap between research and production. Our work focuses on practical integration of 3D Gaussian Splatting techniques, implementation of recent methods, and custom research or prototyping for advanced splatting pipelines.
For consulting inquiries:
contact@qubitanalytics.be