diff --git a/pytorch/TIPS_segmentation_depth_norm.ipynb b/pytorch/TIPS_segmentation_depth_norm.ipynb new file mode 100644 index 0000000..a4f76f0 --- /dev/null +++ b/pytorch/TIPS_segmentation_depth_norm.ipynb @@ -0,0 +1,528 @@ +{ + "cells": [ + { + "id": "234d8eb3", + "cell_type": "markdown", + "source": [ + "Copyright 2026 Google LLC.\n", + "\n", + "SPDX-License-Identifier: Apache-2.0" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "c9d2ff46", + "cell_type": "code", + "source": [ + "# @title TIPSv2 Segmentation, Depth, and Normals DPT Notebook\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "ee8ab432", + "cell_type": "markdown", + "source": [ + "# Setup" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "d7b52c6e", + "cell_type": "code", + "source": [ + "# @title Install dependencies and clone TIPS repo.\n", + "import os\n", + "import sys\n", + "\n", + "# Root directory for all files (Colab default is /content).\n", + "ROOT_DIR = os.getcwd()\n", + "TIPS_DIR = os.path.join(ROOT_DIR, 'tips')\n", + "\n", + "# Install required packages.\n", + "!pip install -q torch torchvision torchaudio\n", + "!pip install -q tensorflow_text mediapy jax jaxlib scikit-learn\n", + "\n", + "# Clone the TIPS repository.\n", + "if not os.path.exists(TIPS_DIR):\n", + " !git clone --branch add-decoders-module https://github.com/google-deepmind/tips.git {TIPS_DIR}\n", + "\n", + "# Add the root directory to PYTHONPATH so that `tips.*` imports work.\n", + "if ROOT_DIR not in sys.path:\n", + " sys.path.insert(0, ROOT_DIR)\n", + "\n", + "print(f'ROOT_DIR: {ROOT_DIR}')\n", + "print(f'TIPS_DIR: {TIPS_DIR}')\n", + "print('Installation complete!')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "d4f49cd7", + "cell_type": "code", + "source": [ + "# @title Download the checkpoints and datasets.\n", + "import urllib.request\n", + "import zipfile\n", + "\n", + "variant = 'L' # @param [\"B\", \"L\", \"So\", \"g\"]\n", + "\n", + "CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/tips_data/v2_0/checkpoints/pytorch'\n", + "TOKENIZER_URL = 'https://storage.googleapis.com/tips_data/v1_0/checkpoints/tokenizer.model'\n", + "NYU_URL = 'http://horatio.cs.nyu.edu/mit/silberman/nyu_depth_v2/bedrooms_part6.zip'\n", + "NYU_TMP_PATH = os.path.join(ROOT_DIR, 'bedrooms_part6.zip')\n", + "\n", + "# Directories for checkpoints (under ROOT_DIR).\n", + "CKPT_DIR = os.path.join(ROOT_DIR, 'checkpoints')\n", + "os.makedirs(CKPT_DIR, exist_ok=True)\n", + "\n", + "# Mapping from variant to checkpoint basenames (matches GCS naming).\n", + "V2_CKPT_BASENAME_MAP = {\n", + " 'B': 'tips_v2_oss_b14',\n", + " 'L': 'tips_v2_oss_l14',\n", + " 'So': 'tips_v2_oss_so14',\n", + " 'g': 'tips_v2_oss_g14',\n", + "}\n", + "# DPT checkpoint naming uses a different convention (no \"oss\").\n", + "V2_DPT_BASENAME_MAP = {\n", + " 'B': 'tips_v2_b14',\n", + " 'L': 'tips_v2_l14',\n", + 
" 'So': 'tips_v2_so400m14',\n", + " 'g': 'tips_v2_g14',\n", + "}\n", + "ckpt_basename = V2_CKPT_BASENAME_MAP[variant]\n", + "dpt_basename = V2_DPT_BASENAME_MAP[variant]\n", + "\n", + "# Download the TIPS vision encoder checkpoint.\n", + "vision_ckpt_name = f'{ckpt_basename}_vision.npz'\n", + "vision_ckpt_url = f'{CHECKPOINT_BASE_URL}/{vision_ckpt_name}'\n", + "image_encoder_checkpoint = os.path.join(CKPT_DIR, vision_ckpt_name)\n", + "if not os.path.exists(image_encoder_checkpoint):\n", + " print(f'Downloading vision encoder: {vision_ckpt_url}...')\n", + " urllib.request.urlretrieve(vision_ckpt_url, image_encoder_checkpoint)\n", + "\n", + "# Download the TIPS text encoder checkpoint.\n", + "text_ckpt_name = f'{ckpt_basename}_text.npz'\n", + "text_ckpt_url = f'{CHECKPOINT_BASE_URL}/{text_ckpt_name}'\n", + "text_encoder_checkpoint = os.path.join(CKPT_DIR, text_ckpt_name)\n", + "if not os.path.exists(text_encoder_checkpoint):\n", + " print(f'Downloading text encoder: {text_ckpt_url}...')\n", + " urllib.request.urlretrieve(text_ckpt_url, text_encoder_checkpoint)\n", + "\n", + "# Download the tokenizer model.\n", + "tokenizer_path = os.path.join(CKPT_DIR, 'tokenizer.model')\n", + "if not os.path.exists(tokenizer_path):\n", + " print(f'Downloading tokenizer: {TOKENIZER_URL}...')\n", + " urllib.request.urlretrieve(TOKENIZER_URL, tokenizer_path)\n", + "\n", + "# Download DPT checkpoints (Segmentation, Depth, Normals).\n", + "dpt_tasks = ['segmentation', 'depth', 'normals']\n", + "dpt_checkpoint_paths = {}\n", + "\n", + "for task in dpt_tasks:\n", + " dpt_zip_name = f'{dpt_basename}_{task}_dpt_pytorch.zip'\n", + " dpt_zip_url = f'{CHECKPOINT_BASE_URL}/{dpt_zip_name}'\n", + " dpt_zip_path = os.path.join(CKPT_DIR, dpt_zip_name)\n", + "\n", + " if not os.path.exists(dpt_zip_path):\n", + " print(f'Downloading DPT {task} checkpoint: {dpt_zip_url}...')\n", + " try:\n", + " urllib.request.urlretrieve(dpt_zip_url, dpt_zip_path)\n", + " # Extract the .npz file(s) from the zip.\n", + " with zipfile.ZipFile(dpt_zip_path, 'r') as zf:\n", + " zf.extractall(CKPT_DIR)\n", + " print(f' Extracted {task} checkpoint to {CKPT_DIR}')\n", + " except Exception as e:\n", + " print(f' Failed to download {dpt_zip_name}: {e}')\n", + " else:\n", + " print(f' DPT {task} checkpoint already exists.')\n", + "\n", + " dpt_checkpoint_paths[task] = dpt_zip_path\n", + "\n", + "# Download and extract NYU dataset for sample images.\n", + "NYU_IMG_DIR = os.path.join(ROOT_DIR, 'nyu_images')\n", + "if not os.path.isdir(NYU_IMG_DIR):\n", + " print('\\nDownloading NYU dataset (bedrooms_part6.zip)...')\n", + " try:\n", + " urllib.request.urlretrieve(NYU_URL, NYU_TMP_PATH)\n", + " print('Extracting NYU images...')\n", + " os.makedirs(NYU_IMG_DIR, exist_ok=True)\n", + " with zipfile.ZipFile(NYU_TMP_PATH, 'r') as z:\n", + " z.extractall(NYU_IMG_DIR)\n", + " os.remove(NYU_TMP_PATH)\n", + " print(f' Extracted to {NYU_IMG_DIR}')\n", + " except Exception as e:\n", + " print(f' Failed to download or extract NYU dataset: {e}')\n", + "else:\n", + " print(' NYU images already extracted.')\n", + "\n", + "IMG_DIR = NYU_IMG_DIR\n", + "\n", + "# Download and extract ADE20K dataset for sample images.\n", + "ADE20K_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'\n", + "ADE20K_TMP_PATH = os.path.join(ROOT_DIR, 'ADEChallengeData2016.zip')\n", + "ADE20K_DIR = os.path.join(ROOT_DIR, 'ADEChallengeData2016')\n", + "\n", + "if not os.path.isdir(ADE20K_DIR):\n", + " print('\\nDownloading ADE20K dataset...')\n", + " try:\n", + " 
urllib.request.urlretrieve(ADE20K_URL, ADE20K_TMP_PATH)\n", + " print('Extracting ADE20K images...')\n", + " with zipfile.ZipFile(ADE20K_TMP_PATH, 'r') as z:\n", + " z.extractall(ROOT_DIR)\n", + " os.remove(ADE20K_TMP_PATH)\n", + " print(f' Extracted to {ADE20K_DIR}')\n", + " except Exception as e:\n", + " print(f' Failed to download or extract ADE20K dataset: {e}')\n", + "else:\n", + " print(' ADE20K images already extracted.')\n", + "\n", + "print('\\nAll downloads complete!')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "f2a58e02", + "cell_type": "code", + "source": [ + "# @title Configure the TIPS model.\n", + "\n", + "# Set the input image shape.\n", + "image_size = 448 # @param {type: \"number\"}\n", + "\n", + "# variant, image_encoder_checkpoint, text_encoder_checkpoint, tokenizer_path, and\n", + "# dpt_checkpoint_paths are all set in the download cell above.\n", + "\n", + "# The per-task DPT checkpoint archives live in dpt_checkpoint_paths; they are\n", + "# passed to load_decoder_weights in the next cell.\n", + "\n", + "print(f'Image encoder checkpoint: {image_encoder_checkpoint}')\n", + "print(f'Text encoder checkpoint: {text_encoder_checkpoint}')\n", + "print(f'Tokenizer path: {tokenizer_path}')\n", + "print(f'DPT checkpoints: {dpt_checkpoint_paths}')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "682ab8ab", + "cell_type": "code", + "source": [ + "import numpy as np\n", + "import torch\n", + "from tips.pytorch import image_encoder\n", + "from tips.pytorch.decoders import SegmentationDecoder, DepthDecoder, NormalsDecoder, load_decoder_weights, Decoder\n", + "\n", + "PATCH_SIZE = 14\n", + "weights_image = dict(np.load(image_encoder_checkpoint, allow_pickle=False))\n", + "for key in weights_image:\n", + " weights_image[key] = torch.tensor(weights_image[key])\n", + "ffn_layer = 'swiglu' if variant == 'g' else 'mlp'\n", + "\n", + "MODEL_CONSTRUCTOR_MAP = {'B': 'vit_base', 'L': 'vit_large', 'So': 'vit_so400m', 'g': 'vit_giant2'}\n", + "EMBED_DIM_MAP = {'B': 768, 'L': 1024, 'So': 1152, 'g': 1536}\n", + "INTERMEDIATE_LAYERS_MAP = {\n", + " 'B': [2, 5, 8, 11],\n", + " 'L': [5, 11, 17, 23],\n", + " 'So': [6, 13, 20, 26],\n", + " 'g': [9, 19, 29, 39],\n", + "}\n", + "\n", + "vit_constructor = getattr(image_encoder, MODEL_CONSTRUCTOR_MAP[variant])\n", + "embed_dim = EMBED_DIM_MAP[variant]\n", + "intermediate_layers = INTERMEDIATE_LAYERS_MAP[variant]\n", + "post_process_channels = (embed_dim // 8, embed_dim // 4, embed_dim // 2, embed_dim)\n", + "\n", + "# Load Vision Encoder\n", + "with torch.no_grad():\n", + " model_image = vit_constructor(\n", + " img_size=image_size, patch_size=PATCH_SIZE, ffn_layer=ffn_layer,\n", + " block_chunks=0, init_values=1.0, interpolate_antialias=True, interpolate_offset=0.0,\n", + " )\n", + " model_image.load_state_dict(weights_image)\n", + " model_image.eval()\n", + "\n", + "# Load Segmentation Decoder\n", + "with torch.no_grad():\n", + " seg_model = SegmentationDecoder(\n", + " num_classes=150,\n", + " input_embed_dim=embed_dim,\n", + " post_process_channels=post_process_channels,\n", + " )\n", + " load_decoder_weights(seg_model, dpt_checkpoint_paths['segmentation'])\n", + " seg_model.eval()\n", + "\n", + "# Load Depth Decoder (using fixed library version)\n", + "with torch.no_grad():\n", + " depth_model = DepthDecoder(\n", + " input_embed_dim=embed_dim,\n", + " post_process_channels=post_process_channels,\n", + " )\n", + " load_decoder_weights(depth_model, dpt_checkpoint_paths['depth'])\n", + " depth_model.eval()\n", +
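"\n", + "# Optional sanity check (editor's sketch; not part of the original notebook).\n", + "# Assuming get_intermediate_layers returns (feat, cls) pairs with feat of shape\n", + "# (1, embed_dim, image_size // PATCH_SIZE, image_size // PATCH_SIZE), dummy\n", + "# inputs should produce a (1, 1, image_size, image_size) depth map:\n", + "# h = w = image_size // PATCH_SIZE\n", + "# dummy = [(torch.zeros(1, embed_dim), torch.zeros(1, embed_dim, h, w))\n", + "#          for _ in intermediate_layers]\n", + "# with torch.no_grad():\n", + "#   d = depth_model(dummy, image_size=(image_size, image_size))\n", + "# assert d.shape == (1, 1, image_size, image_size)\n", +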
"\n", + "# Load Normals Decoder (using fixed library version)\n", + "with torch.no_grad():\n", + " normals_model = NormalsDecoder(\n", + " input_embed_dim=embed_dim,\n", + " post_process_channels=post_process_channels,\n", + " )\n", + " load_decoder_weights(normals_model, dpt_checkpoint_paths['normals'])\n", + " \n", + " # Keep monkey-patch for L2 norm since it's not in decoders.py yet\n", + " normals_model.forward = (\n", + " lambda features, image_size=None: torch.nn.functional.normalize(\n", + " Decoder.forward(normals_model, features, image_size), dim=1\n", + " )\n", + " )\n", + " normals_model.eval()" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "eb3dc08f", + "cell_type": "code", + "source": [ + "# @title Define segmentation classes.\n", + "import colorsys\n", + "\n", + "ADE20K_CLASSES = (\n", + " 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed', 'windowpane', 'grass',\n", + " 'cabinet', 'sidewalk', 'person', 'earth', 'door', 'table', 'mountain', 'plant', 'curtain', 'chair',\n", + " 'car', 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', 'field',\n", + " 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', 'lamp', 'bathtub', 'railing', 'cushion',\n", + " 'base', 'box', 'column', 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', 'skyscraper', 'fireplace',\n", + " 'refrigerator', 'grandstand', 'path', 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', 'stairway',\n", + " 'river', 'bridge', 'bookcase', 'blind', 'coffee table', 'toilet', 'flower', 'book', 'hill', 'bench',\n", + " 'countertop', 'stove', 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', 'arcade machine', 'hovel',\n", + " 'bus', 'towel', 'light', 'truck', 'tower', 'chandelier', 'awning', 'streetlight', 'booth', 'television',\n", + " 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', 'escalator', 'ottoman', 'bottle', 'buffet',\n", + " 'poster', 'stage', 'van', 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', 'swimming pool',\n", + " 'stool', 'barrel', 'basket', 'waterfall', 'tent', 'bag', 'minibike', 'cradle', 'oven', 'ball',\n", + " 'food', 'step', 'tank', 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', 'dishwasher',\n", + " 'screen', 'blanket', 'sculpture', 'hood', 'sconce', 'vase', 'traffic light', 'tray', 'ashcan', 'fan',\n", + " 'pier', 'crt screen', 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', 'clock', 'flag'\n", + ")\n", + "\n", + "NUM_ADE20K_CLASSES = 150\n", + "ADE20K_PALETTE = np.zeros((NUM_ADE20K_CLASSES + 1, 3), dtype=np.uint8)\n", + "for i in range(1, NUM_ADE20K_CLASSES + 1):\n", + " hue = (i * 0.618033988749895) % 1.0\n", + " saturation = 0.65 + 0.35 * ((i * 7) % 5) / 4.0\n", + " value = 0.70 + 0.30 * ((i * 11) % 3) / 2.0\n", + " r, g, b = colorsys.hsv_to_rgb(hue, saturation, value)\n", + " ADE20K_PALETTE[i] = [int(r * 255), int(g * 255), int(b * 255)]\n", + "\n", + "print(f'Defined {len(ADE20K_CLASSES)} classes and palette.')" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "6c4f2cb8", + "cell_type": "markdown", + "source": [ + "# Run Inference" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "93aa2aaf", + "cell_type": "code", + "source": [ + "# @title Run segmentation inference on ADE20K.\n", + "import PIL.Image\n", + "import matplotlib.pyplot as plt\n", + "import torchvision.transforms as TVT\n", + "import os\n", + "import torch\n", + "\n", + "# Find a sample image in the extracted ADE20K 
dataset.\n", + "ADE20K_DIR = os.path.join(ROOT_DIR, 'ADEChallengeData2016')\n", + "ade_img_dir = os.path.join(ADE20K_DIR, \"images\", \"validation\")\n", + "\n", + "if not os.path.exists(ade_img_dir):\n", + " raise FileNotFoundError(f\"ADE20K validation directory not found: {ade_img_dir}\")\n", + "\n", + "ade_images = sorted([os.path.join(ade_img_dir, f) for f in os.listdir(ade_img_dir) if f.endswith(\".jpg\")])\n", + "\n", + "if not ade_images:\n", + " raise FileNotFoundError(f\"No .jpg images found in {ade_img_dir}\")\n", + "\n", + "# It's a castle.\n", + "image_path = ade_images[1]\n", + "print(f'Using ADE20K image: {image_path}')\n", + "img = PIL.Image.open(image_path).convert(\"RGB\")\n", + "\n", + "transform = TVT.Compose([TVT.Resize((image_size, image_size)), TVT.ToTensor()])\n", + "tensor = transform(img).unsqueeze(0)\n", + "\n", + "device = next(model_image.parameters()).device\n", + "tensor = tensor.to(device)\n", + "\n", + "with torch.no_grad():\n", + " # Get intermediate features from ViT\n", + " features = model_image.get_intermediate_layers(\n", + " tensor, n=intermediate_layers, reshape=True, return_class_token=True, norm=True\n", + " )\n", + " # Swap order for decoder: (feat, cls) -\u003e (cls, feat)\n", + " features = [(cls, feat) for feat, cls in features]\n", + "\n", + " # Run Segmentation\n", + " seg_logits = seg_model(features, image_size=(image_size, image_size))\n", + " seg_map = seg_logits.argmax(dim=1).squeeze(0).cpu().numpy()\n", + "\n", + "# Visualize results\n", + "colored_seg = ADE20K_PALETTE[seg_map]\n", + "\n", + "plt.figure(figsize=(10, 5))\n", + "plt.subplot(1, 2, 1); plt.imshow(img.resize((image_size, image_size))); plt.title(\"Input Image (ADE)\"); plt.axis(\"off\")\n", + "plt.subplot(1, 2, 2); plt.imshow(colored_seg); plt.title(\"Segmentation\"); plt.axis(\"off\")\n", + "plt.show()" + ], + "metadata": {}, + "execution_count": null + }, + { + "id": "07e759dd", + "cell_type": "code", + "source": [ + "import PIL.Image\n", + "import matplotlib.cm as cm\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "import torch\n", + "import torchvision.transforms as TVT\n", + "\n", + "image_size = 448\n", + "\n", + "# Find a sample image in the extracted NYU dataset.\n", + "nyu_images = []\n", + "valid_extensions = (\".ppm\", \".jpg\", \".jpeg\", \".png\")\n", + "\n", + "for root, dirs, files in os.walk(IMG_DIR):\n", + " for file in files:\n", + " if file.lower().endswith(valid_extensions):\n", + " nyu_images.append(os.path.join(root, file))\n", + "\n", + "if not nyu_images:\n", + " raise FileNotFoundError(f\"No valid images found in {IMG_DIR}\")\n", + "\n", + "# Select the first 3 images for visualization\n", + "selected_images = nyu_images[:3]\n", + "\n", + "for i, image_path in enumerate(selected_images):\n", + " print(f\"Processing NYU image {i+1}/{len(selected_images)}: {image_path}\")\n", + " img = PIL.Image.open(image_path)\n", + "\n", + " transform = TVT.Compose(\n", + " [TVT.Resize((image_size, image_size)), TVT.ToTensor()]\n", + " )\n", + " tensor = transform(img).unsqueeze(0)\n", + "\n", + " device = next(model_image.parameters()).device\n", + " tensor = tensor.to(device)\n", + "\n", + " with torch.no_grad():\n", + " # Get intermediate features from ViT\n", + " features = model_image.get_intermediate_layers(\n", + " tensor,\n", + " n=intermediate_layers,\n", + " reshape=True,\n", + " return_class_token=True,\n", + " norm=True,\n", + " )\n", + " # Swap order for decoder: (feat, cls) -\u003e (cls, feat)\n", + " features = 
[(cls, feat) for feat, cls in features]\n", + "\n", + " # 1. Run Depth\n", + " depth_map = depth_model(features, image_size=(image_size, image_size))\n", + " depth_map = depth_map.squeeze().cpu().numpy()\n", + "\n", + " # Explicitly normalize depth for visualization (HuggingFace style)\n", + " depth_map = (depth_map - depth_map.min()) / (\n", + " depth_map.max() - depth_map.min() + 1e-8\n", + " )\n", + " # 2. Run Normals (Do NOT pass image_size here to get the raw low-res output)\n", + " normals_map = normals_model(features)\n", + "\n", + " # Normalize at low resolution\n", + " normals_map = torch.nn.functional.normalize(normals_map, dim=1)\n", + "\n", + " # Upsample in Colab using BICUBIC for smoother results and less grid artifacts\n", + " normals_map = torch.nn.functional.interpolate(\n", + " normals_map,\n", + " size=(image_size, image_size),\n", + " mode=\"bicubic\",\n", + " align_corners=False,\n", + " )\n", + "\n", + " # Re-normalize after upsampling to ensure they are still unit vectors\n", + " normals_map = torch.nn.functional.normalize(normals_map, dim=1)\n", + "\n", + " normals_map = normals_map.squeeze(0).cpu().numpy() # (3, H, W)\n", + " normals_map = np.transpose(normals_map, (1, 2, 0)) # (H, W, 3)\n", + "\n", + " # Map normals from [-1, 1] to [0, 1]\n", + " normals_map = (normals_map + 1.0) / 2.0\n", + " normals_map = np.clip(normals_map, 0.0, 1.0)\n", + "\n", + " # Visualize results\n", + " plt.figure(figsize=(15, 5))\n", + "\n", + " # Input Image\n", + " plt.subplot(1, 3, 1)\n", + " plt.imshow(img.resize((image_size, image_size)))\n", + " plt.title(f\"Input Image (NYU) {i+1}\")\n", + " plt.axis(\"off\")\n", + "\n", + " # Depth Map (Using 'turbo' colormap to match HF DPT depth)\n", + " plt.subplot(1, 3, 2)\n", + " plt.imshow(depth_map, cmap=\"turbo\")\n", + " plt.title(f\"Depth {i+1}\")\n", + " plt.axis(\"off\")\n", + "\n", + " # Surface Normals\n", + " plt.subplot(1, 3, 3)\n", + " plt.imshow(normals_map)\n", + " plt.title(f\"Surface Normals {i+1}\")\n", + " plt.axis(\"off\")\n", + "\n", + " plt.show() # Added to ensure plots are shown in each iteration" + ], + "metadata": {}, + "execution_count": null + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat_minor": 5, + "nbformat": 4 +} diff --git a/pytorch/decoders.py b/pytorch/decoders.py index e1dd93f..824e522 100644 --- a/pytorch/decoders.py +++ b/pytorch/decoders.py @@ -196,6 +196,7 @@ def forward( out = self.fusion_blocks[i](out, residual=x[-(i + 1)]) out = self.project(out) + out = F.relu(out) return out @@ -264,16 +265,26 @@ def __init__(self, num_classes: int = 150, **kwargs) -> None: class DepthDecoder(Decoder): - """Decoder for monocular depth prediction using classification bins.""" + """Decoder for monocular depth prediction using classification bins. - def __init__(self, min_depth: float = 0.001, max_depth: float = 10.0, **kwargs) -> None: - # Decoder requires out_channels, we pass 256 as we use channels as bins, - # although we bypass the head in forward(). - super().__init__(out_channels=256, **kwargs) + Predicts depth by classifying each pixel into uniformly-spaced depth bins + and computing the expected depth value. 
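+ Per pixel, the head produces num_depth_bins logits; a ReLU plus a min_depth + shift makes them positive, they are normalised to probabilities p_k, and the + predicted depth is the expectation sum_k p_k * c_k over the uniformly spaced + bin centers c_k = linspace(min_depth, max_depth, num_depth_bins).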
+ """ + + def __init__( + self, + num_depth_bins: int = 256, + min_depth: float = 0.001, + max_depth: float = 10.0, + **kwargs, + ) -> None: + super().__init__(out_channels=num_depth_bins, **kwargs) self.min_depth = min_depth self.max_depth = max_depth + self.num_depth_bins = num_depth_bins self.register_buffer( - "bin_centers", torch.linspace(min_depth, max_depth, 256) + "bin_centers", + torch.linspace(min_depth, max_depth, num_depth_bins), ) def forward( @@ -281,20 +292,23 @@ def forward( intermediate_features: List[Tuple[torch.Tensor, torch.Tensor]], image_size: Optional[Tuple[int, int]] = None, ) -> torch.Tensor: - # Bypass super().forward() to avoid the linear head applied there, - # and use raw DPT features as logits. - logits = self.dpt(intermediate_features) # (B, C, H', W') - # Apply ReLU and shift + # 1. Get DPT features + task head (nn.Linear) via parent class. + # Output shape: (B, num_depth_bins, H', W') + logits = super().forward(intermediate_features) + + # 2. Classification-based depth prediction (following Scenic/AdaBins): + # relu + shift -> linear normalisation -> expectation over bins. logits = torch.relu(logits) + self.min_depth - # Normalize to probabilities along the channel dimension probs = logits / torch.sum(logits, dim=1, keepdim=True) - # Compute expectation: sum(prob * bin_center) - depth_map = torch.einsum( - "bchw,c->bhw", probs, self.bin_centers.to(logits.device) - ) + depth_map = torch.einsum("bchw,c->bhw", probs, self.bin_centers.to(logits.device)) + + # 3. Upsample to target resolution. if image_size is not None: depth_map = F.interpolate( - depth_map.unsqueeze(1), size=image_size, mode="bilinear", align_corners=False + depth_map.unsqueeze(1), + size=image_size, + mode="bilinear", + align_corners=False, ).squeeze(1) return depth_map.unsqueeze(1) @@ -315,7 +329,12 @@ def __init__(self, **kwargs) -> None: "convs.": "dpt.convs.", "fusion_blocks.": "dpt.fusion_blocks.", "project.": "dpt.project.", + # Task-specific head keys (Scenic Dense -> PyTorch head.*) "segmentation_head.": "head.", + "pixel_segmentation.": "head.", + "pixel_depth_classif.": "head.", + "pixel_depth_regress.": "head.", + "pixel_normals.": "head.", }
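+ # Editor's note (illustration, not part of the original diff): a prefix map like
+ # the one above is typically applied by rewriting each checkpoint key before it
+ # is loaded into the PyTorch state dict, e.g.:
+ #
+ #   def remap_key(key, mapping):
+ #     for old, new in mapping.items():
+ #       if key.startswith(old):
+ #         return new + key[len(old):]
+ #     return key
+ #
+ # How load_decoder_weights actually consumes this mapping may differ.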