Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 48 additions & 18 deletions pytorch/decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def forward(
x_flat = x.flatten(2).transpose(1, 2)
readout = cls_token.unsqueeze(1).expand(-1, x_flat.shape[1], -1)
x_cat = torch.cat([x_flat, readout], dim=-1)
x_proj = F.gelu(self.readout_projects[i](x_cat))
x_proj = F.gelu(self.readout_projects[i](x_cat), approximate="tanh")
x = x_proj.transpose(1, 2).reshape(b, d, h, w)
x = self.out_projections[i](x)
x = self.resize_layers[i](x)
Expand Down Expand Up @@ -264,37 +264,50 @@ def __init__(self, num_classes: int = 150, **kwargs) -> None:


class DepthDecoder(Decoder):
"""Decoder for monocular depth prediction using classification bins."""
"""Decoder for monocular depth prediction using classification bins.

def __init__(self, min_depth: float = 0.001, max_depth: float = 10.0, **kwargs) -> None:
# Decoder requires out_channels, we pass 256 as we use channels as bins,
# although we bypass the head in forward().
super().__init__(out_channels=256, **kwargs)
Predicts depth by classifying each pixel into uniformly-spaced depth bins
and computing the expected depth value.
"""

def __init__(
self,
num_depth_bins: int = 256,
min_depth: float = 0.001,
max_depth: float = 10.0,
**kwargs,
) -> None:
super().__init__(out_channels=num_depth_bins, **kwargs)
self.min_depth = min_depth
self.max_depth = max_depth
self.num_depth_bins = num_depth_bins
self.register_buffer(
"bin_centers", torch.linspace(min_depth, max_depth, 256)
"bin_centers",
torch.linspace(min_depth, max_depth, num_depth_bins),
)

def forward(
self,
intermediate_features: List[Tuple[torch.Tensor, torch.Tensor]],
image_size: Optional[Tuple[int, int]] = None,
) -> torch.Tensor:
# Bypass super().forward() to avoid the linear head applied there,
# and use raw DPT features as logits.
logits = self.dpt(intermediate_features) # (B, C, H', W')
# Apply ReLU and shift
# 1. Get DPT features + task head (nn.Linear) via parent class.
# Output shape: (B, num_depth_bins, H', W')
logits = super().forward(intermediate_features)

# 2. Classification-based depth prediction (following Scenic/AdaBins):
# relu + shift -> linear normalisation -> expectation over bins.
logits = torch.relu(logits) + self.min_depth
# Normalize to probabilities along the channel dimension
probs = logits / torch.sum(logits, dim=1, keepdim=True)
# Compute expectation: sum(prob * bin_center)
depth_map = torch.einsum(
"bchw,c->bhw", probs, self.bin_centers.to(logits.device)
)
depth_map = torch.einsum("bchw,c->bhw", probs, self.bin_centers.to(logits.device))

# 3. Upsample to target resolution.
if image_size is not None:
depth_map = F.interpolate(
depth_map.unsqueeze(1), size=image_size, mode="bilinear", align_corners=False
depth_map.unsqueeze(1),
size=image_size,
mode="bilinear",
align_corners=False,
).squeeze(1)
return depth_map.unsqueeze(1)

Expand All @@ -315,7 +328,12 @@ def __init__(self, **kwargs) -> None:
"convs.": "dpt.convs.",
"fusion_blocks.": "dpt.fusion_blocks.",
"project.": "dpt.project.",
# Task-specific head keys (Scenic Dense -> PyTorch head.*)
"segmentation_head.": "head.",
"pixel_segmentation.": "head.",
"pixel_depth_classif.": "head.",
"pixel_depth_regress.": "head.",
"pixel_normals.": "head.",
}


Expand All @@ -337,6 +355,14 @@ def load_decoder_weights(
"""
weights = dict(np.load(checkpoint_path, allow_pickle=False))

# Identify ConvTranspose layers that need kernel flipping.
# Flax ConvTranspose uses transpose_kernel=False (no kernel flip),
# while PyTorch ConvTranspose2d always flips. We pre-flip to compensate.
conv_transpose_weight_keys = set()
for name, module in model.named_modules():
if isinstance(module, nn.ConvTranspose2d):
conv_transpose_weight_keys.add(name + ".weight")

sd = {}
for key, value in weights.items():
new_key = key
Expand All @@ -345,7 +371,11 @@ def load_decoder_weights(
if key.startswith(old_prefix):
new_key = new_prefix + key[len(old_prefix):]
break
sd[new_key] = torch.from_numpy(value)
tensor = torch.from_numpy(value)
# Flip ConvTranspose kernels 180 degrees spatially to match Scenic/Flax.
if new_key in conv_transpose_weight_keys and tensor.ndim == 4:
tensor = tensor.flip([2, 3])
sd[new_key] = tensor

model.load_state_dict(sd, strict=True)
print(f"Loaded decoder weights from {checkpoint_path} ({len(sd)} tensors)")
Expand Down