From dad6e8f71635f0e3b2954eeebfad94fff0f4e875 Mon Sep 17 00:00:00 2001 From: Bingyi Cao Date: Wed, 22 Apr 2026 08:19:05 +0000 Subject: [PATCH 1/2] Fix DPT decoder bugs: add ReLU, fix DepthDecoder head, add key remapping - Add F.relu() after DPTHead project conv to match Scenic's output_activation=True default - Fix DepthDecoder to route through parent's nn.Linear head instead of bypassing it - Register bin_centers as a buffer with configurable num_depth_bins - Add weight key remapping for all decoder types --- pytorch/decoders.py | 51 +++++++++++++++++++++++++++++++-------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/pytorch/decoders.py b/pytorch/decoders.py index e1dd93f..824e522 100644 --- a/pytorch/decoders.py +++ b/pytorch/decoders.py @@ -196,6 +196,7 @@ def forward( out = self.fusion_blocks[i](out, residual=x[-(i + 1)]) out = self.project(out) + out = F.relu(out) return out @@ -264,16 +265,26 @@ def __init__(self, num_classes: int = 150, **kwargs) -> None: class DepthDecoder(Decoder): - """Decoder for monocular depth prediction using classification bins.""" + """Decoder for monocular depth prediction using classification bins. - def __init__(self, min_depth: float = 0.001, max_depth: float = 10.0, **kwargs) -> None: - # Decoder requires out_channels, we pass 256 as we use channels as bins, - # although we bypass the head in forward(). - super().__init__(out_channels=256, **kwargs) + Predicts depth by classifying each pixel into uniformly-spaced depth bins + and computing the expected depth value. + """ + + def __init__( + self, + num_depth_bins: int = 256, + min_depth: float = 0.001, + max_depth: float = 10.0, + **kwargs, + ) -> None: + super().__init__(out_channels=num_depth_bins, **kwargs) self.min_depth = min_depth self.max_depth = max_depth + self.num_depth_bins = num_depth_bins self.register_buffer( - "bin_centers", torch.linspace(min_depth, max_depth, 256) + "bin_centers", + torch.linspace(min_depth, max_depth, num_depth_bins), ) def forward( @@ -281,20 +292,23 @@ def forward( intermediate_features: List[Tuple[torch.Tensor, torch.Tensor]], image_size: Optional[Tuple[int, int]] = None, ) -> torch.Tensor: - # Bypass super().forward() to avoid the linear head applied there, - # and use raw DPT features as logits. - logits = self.dpt(intermediate_features) # (B, C, H', W') - # Apply ReLU and shift + # 1. Get DPT features + task head (nn.Linear) via parent class. + # Output shape: (B, num_depth_bins, H', W') + logits = super().forward(intermediate_features) + + # 2. Classification-based depth prediction (following Scenic/AdaBins): + # relu + shift -> linear normalisation -> expectation over bins. logits = torch.relu(logits) + self.min_depth - # Normalize to probabilities along the channel dimension probs = logits / torch.sum(logits, dim=1, keepdim=True) - # Compute expectation: sum(prob * bin_center) - depth_map = torch.einsum( - "bchw,c->bhw", probs, self.bin_centers.to(logits.device) - ) + depth_map = torch.einsum("bchw,c->bhw", probs, self.bin_centers.to(logits.device)) + + # 3. Upsample to target resolution. if image_size is not None: depth_map = F.interpolate( - depth_map.unsqueeze(1), size=image_size, mode="bilinear", align_corners=False + depth_map.unsqueeze(1), + size=image_size, + mode="bilinear", + align_corners=False, ).squeeze(1) return depth_map.unsqueeze(1) @@ -315,7 +329,12 @@ def __init__(self, **kwargs) -> None: "convs.": "dpt.convs.", "fusion_blocks.": "dpt.fusion_blocks.", "project.": "dpt.project.", + # Task-specific head keys (Scenic Dense -> PyTorch head.*) "segmentation_head.": "head.", + "pixel_segmentation.": "head.", + "pixel_depth_classif.": "head.", + "pixel_depth_regress.": "head.", + "pixel_normals.": "head.", } From 1a0aee1589656a1d6d7cd058adbaca9ec5936517 Mon Sep 17 00:00:00 2001 From: Bingyi Cao Date: Thu, 30 Apr 2026 23:30:41 +0000 Subject: [PATCH 2/2] Apply verified Scenic<->PyTorch parity fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three fixes verified by numerical parity tests (max abs diff < 1e-4): 1. GELU approximation: JAX defaults to tanh approximation, PyTorch uses exact. Use F.gelu(x, approximate='tanh') for numerical parity. 2. Remove spurious ReLU: Scenic DPT defaults to output_activation=False, so no ReLU should be applied after the project conv. 3. ConvTranspose kernel flip: Flax ConvTranspose uses transpose_kernel=False (no kernel flip), while PyTorch ConvTranspose2d always flips. Pre-flip weights 180 degrees during loading to compensate. These fixes bring the max absolute difference between Scenic and PyTorch decoder outputs below 1e-4 across all three heads (depth, normals, segmentation). No checkpoint re-export needed — all fixes are in the inference code/weight loading path. --- pytorch/decoders.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/pytorch/decoders.py b/pytorch/decoders.py index 824e522..4ee0539 100644 --- a/pytorch/decoders.py +++ b/pytorch/decoders.py @@ -136,7 +136,7 @@ def forward( x_flat = x.flatten(2).transpose(1, 2) readout = cls_token.unsqueeze(1).expand(-1, x_flat.shape[1], -1) x_cat = torch.cat([x_flat, readout], dim=-1) - x_proj = F.gelu(self.readout_projects[i](x_cat)) + x_proj = F.gelu(self.readout_projects[i](x_cat), approximate="tanh") x = x_proj.transpose(1, 2).reshape(b, d, h, w) x = self.out_projections[i](x) x = self.resize_layers[i](x) @@ -196,7 +196,6 @@ def forward( out = self.fusion_blocks[i](out, residual=x[-(i + 1)]) out = self.project(out) - out = F.relu(out) return out @@ -356,6 +355,14 @@ def load_decoder_weights( """ weights = dict(np.load(checkpoint_path, allow_pickle=False)) + # Identify ConvTranspose layers that need kernel flipping. + # Flax ConvTranspose uses transpose_kernel=False (no kernel flip), + # while PyTorch ConvTranspose2d always flips. We pre-flip to compensate. + conv_transpose_weight_keys = set() + for name, module in model.named_modules(): + if isinstance(module, nn.ConvTranspose2d): + conv_transpose_weight_keys.add(name + ".weight") + sd = {} for key, value in weights.items(): new_key = key @@ -364,7 +371,11 @@ def load_decoder_weights( if key.startswith(old_prefix): new_key = new_prefix + key[len(old_prefix):] break - sd[new_key] = torch.from_numpy(value) + tensor = torch.from_numpy(value) + # Flip ConvTranspose kernels 180 degrees spatially to match Scenic/Flax. + if new_key in conv_transpose_weight_keys and tensor.ndim == 4: + tensor = tensor.flip([2, 3]) + sd[new_key] = tensor model.load_state_dict(sd, strict=True) print(f"Loaded decoder weights from {checkpoint_path} ({len(sd)} tensors)")