diff --git a/dinov2/configs/train/vitg14_reg4.yaml b/dinov2/configs/train/vitg14_reg4.yaml index 5df9524..05ceb64 100644 --- a/dinov2/configs/train/vitg14_reg4.yaml +++ b/dinov2/configs/train/vitg14_reg4.yaml @@ -41,6 +41,26 @@ optim: warmup_epochs: 10 layerwise_decay: 1.0 crops: + # -------------------------------------------------------- + # "baseline" : Only RandomResizedCrop + Horizontal Flip (DINOv2 Defaults) + # "vflip" : Baseline + Vertical Flip + # "rotate90" : Baseline + 90-degree Rotations + # "colorjitter_weak" : Baseline + Subtle Color Shifts + # "colorjitter_strong" : Baseline + Aggressive Color Shifts + # "blur_weak" : Baseline + Subtle Out-of-focus + # "blur_strong" : Baseline + Heavy Scanner Blur + # "hed_weak" : Baseline + Subtle HED Stain shift (0.005) + # "hed_medium" : Baseline + Standard HED Stain shift (0.01) + # "hed_strong" : Baseline + Aggressive HED Stain shift (0.03) + # "randstain_weak" : Baseline + RandStainNA Normalization (-0.4) + # "randstain_medium" : Baseline + RandStainNA Standard (-0.1) + # "randstain_strong" : Baseline + RandStainNA High Diversity (0.0) + # "randstain_hed_combo": Baseline + Both HED and RandStain + # "elastic" : Baseline + Tissue Stretching + # "jpeg" : Baseline + JPEG Artifacts + # "combo" : FULL PIPELINE + # ------------------------------ + ablation_mode: "elastic" local_crops_size: 98 evaluation: eval_period_iterations: 5000 # save checkpoint every 10 epochs diff --git a/dinov2/data/augmentations.py b/dinov2/data/augmentations.py index b57e680..1b3cc94 100644 --- a/dinov2/data/augmentations.py +++ b/dinov2/data/augmentations.py @@ -193,6 +193,97 @@ def forward(self, img): return img return self.augment(img) +class ElasticDeformation(torch.nn.Module): + def __init__(self, low_alpha=40.0, high_alpha=200.0, low_sigma=5.0, high_sigma=10.0, probability=0.5): + super().__init__() + self.low_alpha = low_alpha + self.high_alpha = high_alpha + self.low_sigma = low_sigma + self.high_sigma = high_sigma + self.probability 
= probability + + def forward(self, img): + if random.random() > self.probability: + return img + + + alpha = random.uniform(self.low_alpha, self.high_alpha) + sigma = random.uniform(self.low_sigma, self.high_sigma) + + + if isinstance(img, Image.Image): + img_np = np.array(img) + input_type = 'PIL' + elif isinstance(img, torch.Tensor): + img_np = (img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) + input_type = 'Tensor' + else: + img_np = np.array(img) + input_type = 'Array' + + h, w = img_np.shape[:2] + + dx = cv2.GaussianBlur((np.random.rand(h, w) * 2 - 1).astype(np.float32), (0, 0), sigma) * alpha + dy = cv2.GaussianBlur((np.random.rand(h, w) * 2 - 1).astype(np.float32), (0, 0), sigma) * alpha + + x, y = np.meshgrid(np.arange(w), np.arange(h)) + map_x = (x + dx).astype(np.float32) + map_y = (y + dy).astype(np.float32) + + + distorted = cv2.remap(img_np, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101) + + + if input_type == 'PIL': + return Image.fromarray(distorted) + elif input_type == 'Tensor': + return torch.from_numpy(distorted).permute(2, 0, 1).float() / 255.0 + return distorted + + +class JpegCompression(torch.nn.Module): + + def __init__(self, quality_lower=20, quality_upper=100, probability=0.5): + super().__init__() + self.quality_lower = quality_lower + self.quality_upper = quality_upper + self.probability = probability + + def forward(self, img): + if random.random() > self.probability: + return img + + quality = random.randint(self.quality_lower, self.quality_upper) + + was_pil = False + was_tensor = False + + if isinstance(img, Image.Image): + img_np = np.array(img) + was_pil = True + elif isinstance(img, torch.Tensor): + img_np = (img.permute(1, 2, 0).numpy() * 255.0).astype(np.uint8) + was_tensor = True + else: + img_np = img + + img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + + encode_param =[int(cv2.IMWRITE_JPEG_QUALITY), quality] + result, encimg = cv2.imencode('.jpg', img_bgr, encode_param) + + if result: + 
decimg = cv2.imdecode(encimg, 1) + img_rgb = cv2.cvtColor(decimg, cv2.COLOR_BGR2RGB) + else: + img_rgb = img_np + + if was_pil: + return Image.fromarray(img_rgb) + elif was_tensor: + return torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0 + + return img_rgb class hed_mod(torch.nn.Module): """ @@ -276,6 +367,7 @@ def __init__( local_crops_number, global_crops_size=224, local_crops_size=96, + ablation_mode="baseline", # This controls the ablation choices (baseline only applies the RandomResizedCrop + Horizontal Flip) ): self.global_crops_scale = global_crops_scale self.local_crops_scale = local_crops_scale @@ -292,73 +384,161 @@ def __init__( logger.info(f"local_crops_size: {local_crops_size}") logger.info("###################################") - # Geometric augmentations with rotation and both flips - self.geometric_augmentation_global = transforms.Compose( - [ - # RandomRotation90(), - transforms.RandomResizedCrop( - global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC - ), - transforms.RandomHorizontalFlip(p=0.5), - # transforms.RandomVerticalFlip(p=0.5), - ] - ) - - self.geometric_augmentation_local = transforms.Compose( - [ - # RandomRotation90(), - transforms.RandomResizedCrop( - local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC - ), - transforms.RandomHorizontalFlip(p=0.5), - # transforms.RandomVerticalFlip(p=0.5), - ] - ) - - # Normalization (ImageNet stats used by default) - self.normalize = transforms.Compose( - [ + # Everything is OFF by default. Only RandomResizedCrop and H-Flip are applied. 
+ vflip_p = 0.0 + rotate90_p = 0.0 + cj_params = None # (brightness, contrast, saturation, hue) + gb_params = None # (kernel_size, sigma_min, sigma_max) + hed_range = None + rs_hyper = None + use_elastic = False + use_jpeg = False + # use_ect = False + + if ablation_mode == "baseline": + pass + elif ablation_mode == "vflip": + vflip_p = 0.5 + elif ablation_mode == "rotate90": + rotate90_p = 0.5 + + # Color Jitter Ablations + elif ablation_mode == "colorjitter_weak": + cj_params = (0.1, 0.1, 0.05, 0.02) + elif ablation_mode == "colorjitter_medium": + cj_params = (0.2, 0.2, 0.1, 0.05) + elif ablation_mode == "colorjitter_strong": + cj_params = (0.4, 0.4, 0.2, 0.1) + + # Gaussian Blur Ablations + elif ablation_mode == "blur_weak": + gb_params = (7, 0.3, 0.8) + elif ablation_mode == "blur_medium": + gb_params = (9, 0.5, 1) + elif ablation_mode == "blur_strong": + gb_params = (15, 1.5, 2.5) + + # HED Ablations + elif ablation_mode == "hed_weak": + hed_range = 0.005 + elif ablation_mode == "hed_medium": + hed_range = 0.01 + elif ablation_mode == "hed_strong": + hed_range = 0.03 + + # RandStainNA Ablations + elif ablation_mode == "randstain_weak": + rs_hyper = -0.4 + elif ablation_mode == "randstain_medium": + rs_hyper = -0.1 + elif ablation_mode == "randstain_strong": + rs_hyper = 0.0 + + # Combo Ablations (randstain + hed) + elif ablation_mode == "randstain_hed_combo": + rs_hyper = -0.3 + hed_range = 0.08 + + elif ablation_mode == "elastic": + use_elastic = True + + elif ablation_mode == "jpeg": + use_jpeg = True + + elif ablation_mode == "combo": + # Combo of all above augmentations that showed improvements in isolation + vflip_p = 0.5 + rotate90_p = 0.5 + cj_params = (0.2, 0.2, 0.1, 0.05) + gb_params = (9, 0.1, 2.0) + hed_range = 0.05 + rs_hyper = -0.3 + use_elastic = True + use_jpeg = True + else: + raise ValueError(f"Unknown ablation mode: {ablation_mode}") + + # --- 3. 
LOGGING THE ACTIVE PIPELINE --- + logger.info("=========================================") + logger.info(f" BUILDING PIPELINE: MODE '{ablation_mode.upper()}' ") + logger.info("=========================================") + logger.info(f" [ON] RandomResizedCrop (Global: {global_crops_size}, Local: {local_crops_size})") + logger.info(f" [ON] HorizontalFlip (p=0.5)") + logger.info(f" [{'ON' if vflip_p > 0 else 'OFF'}] VerticalFlip (p={vflip_p})") + logger.info(f"[{'ON' if rotate90_p > 0 else 'OFF'}] Rotate90 (p={rotate90_p})") + + if cj_params: + logger.info(f" [ON] ColorJitter (B:{cj_params[0]}, C:{cj_params[1]}, S:{cj_params[2]}, H:{cj_params[3]})") + else: + logger.info(f" [OFF] ColorJitter") + + if gb_params: + logger.info(f" [ON] GaussianBlur (Kernel:{gb_params[0]}, Sigma range: {gb_params[1]}-{gb_params[2]})") + else: + logger.info(f" [OFF] GaussianBlur") + + logger.info(f" [{'ON' if hed_range else 'OFF'}] HED Stain Perturbation (Range: +/-{hed_range})") + logger.info(f"[{'ON' if rs_hyper is not None else 'OFF'}] RandStainNA (Std Hyper: {rs_hyper})") + logger.info(f"[{'ON' if use_elastic else 'OFF'}] Elastic Deformation") + logger.info(f"[{'ON' if use_jpeg else 'OFF'}] JPEG Compression") + logger.info("=========================================\n") + + + geom_global =[ + transforms.RandomResizedCrop(global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.RandomHorizontalFlip(p=0.5), + ] + geom_local =[ + transforms.RandomResizedCrop(local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.RandomHorizontalFlip(p=0.5), + ] + + if vflip_p > 0: + geom_global.append(transforms.RandomVerticalFlip(p=vflip_p)) + geom_local.append(transforms.RandomVerticalFlip(p=vflip_p)) + if rotate90_p > 0: + geom_global.append(RandomRotation90()) + geom_local.append(RandomRotation90()) + + self.geometric_augmentation_global = transforms.Compose(geom_global) + 
self.geometric_augmentation_local = transforms.Compose(geom_local) + + def build_pixel_pipeline(blur_probability): + pipeline =[] + + if use_elastic: + pipeline.append(ElasticDeformation(probability=0.5)) + + if rs_hyper is not None: + pipeline.append(RandStainNA(color_space='LAB', std_hyper=rs_hyper, distribution='normal', probability=0.5)) + + if hed_range is not None: + pipeline.append(hed_mod(probability=0.5, perturbation_range=hed_range)) + + if cj_params is not None: + pipeline.append(transforms.RandomApply([transforms.ColorJitter(*cj_params)], p=0.8)) + pipeline.append(transforms.RandomGrayscale(p=0.2)) + + if gb_params is not None: + k, s_min, s_max = gb_params + pipeline.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=k, sigma=(s_min, s_max))], p=blur_probability)) + + if use_jpeg: + pipeline.append(JpegCompression(probability=0.5)) + + pipeline.extend([ transforms.ToTensor(), make_normalize_transform(), - ] - ) - - # Pathology-specific stain augmentations - randstainna = RandStainNA( - color_space='LAB', - std_hyper=-0.3, - distribution='normal', - probability=0.5, - ) - - hed_aug = hed_mod(probability=0.5, perturbation_range=0.05) - - self.global_transfo1 = transforms.Compose([ - # randstainna, - hed_aug, - transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05)], p=0.8), - transforms.RandomGrayscale(p=0.2), - GaussianBlur(p=1.0), - self.normalize - ]) - - self.global_transfo2 = transforms.Compose([ - # randstainna, - hed_aug, - transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05)], p=0.8), - transforms.RandomGrayscale(p=0.2), - GaussianBlur(p=0.1), - self.normalize - ]) - - self.local_transfo = transforms.Compose([ - # randstainna, - hed_aug, - transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05)], p=0.8), - transforms.RandomGrayscale(p=0.2), - GaussianBlur(p=0.5), - self.normalize - ]) + 
]) + return transforms.Compose(pipeline) + + # DINOv2 uses asymmetric blur probabilities for the different crops + self.global_transfo1 = build_pixel_pipeline(blur_probability=1.0) + self.global_transfo2 = build_pixel_pipeline(blur_probability=0.1) + self.local_transfo = build_pixel_pipeline(blur_probability=0.5) + + # Normalization (ImageNet stats used by default) + self.normalize = transforms.Compose([transforms.ToTensor(), make_normalize_transform()]) def __call__(self, image): output = {} diff --git a/dinov2/data/curation/clusterer.py b/dinov2/data/curation/clusterer.py new file mode 100644 index 0000000..3f5278f --- /dev/null +++ b/dinov2/data/curation/clusterer.py @@ -0,0 +1,50 @@ +import h5py +import faiss +import argparse +import numpy as np +from tqdm import tqdm + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--embed-file", type=str, default="tcga_embeddings.h5") + parser.add_argument("--output-file", type=str, default="tcga_clusters.h5") + parser.add_argument("--k1", type=int, default=10000) + parser.add_argument("--k2", type=int, default=1000) + return parser.parse_args() + +def run_clustering(args): + print("Loading embeddings...") + with h5py.File(args.embed_file, 'r') as f: + embeddings = f['embeddings'][:] + + N, dim = embeddings.shape + print(f"Loaded {N} embeddings of dimension {dim}.") + + print(f"Training Level 1 K-Means (K1={args.k1})...") + + kmeans1 = faiss.Kmeans(d=dim, k=args.k1, niter=50, verbose=True, gpu=True) + kmeans1.train(embeddings) + + print("Assigning patches to K1 clusters...") + _, labels_k1 = kmeans1.index.search(embeddings, 1) + labels_k1 = labels_k1.squeeze() + + print(f"Training Level 2 K-Means (K2={args.k2}) on K1 centroids...") + kmeans2 = faiss.Kmeans(d=dim, k=args.k2, niter=50, verbose=True, gpu=True) + kmeans2.train(kmeans1.centroids) + + print("Assigning K1 centroids to K2 clusters...") + _, k1_to_k2_map = kmeans2.index.search(kmeans1.centroids, 1) + k1_to_k2_map = k1_to_k2_map.squeeze() 
+ + print("Mapping original patches to final K2 hierarchy...") + labels_k2 = k1_to_k2_map[labels_k1] + + with h5py.File(args.output_file, 'w') as f: + f.create_dataset("cluster_k1", data=labels_k1) + f.create_dataset("cluster_k2", data=labels_k2) + + print(f"Clustering complete. Saved to {args.output_file}.") + +if __name__ == "__main__": + run_clustering(get_args()) diff --git a/dinov2/data/curation/extractor.py b/dinov2/data/curation/extractor.py new file mode 100644 index 0000000..d37a41e --- /dev/null +++ b/dinov2/data/curation/extractor.py @@ -0,0 +1,145 @@ +import os +import torch +import h5py +import argparse +import numpy as np +import cv2 +from tqdm import tqdm +from PIL import Image +from io import BytesIO +from datasets import load_dataset +from torchvision import transforms + +from dinov2.models import build_model_from_cfg +from dinov2.utils.config import setup + +def get_args(): + parser = argparse.ArgumentParser("OpenMidnight Feature Extractor") + parser.add_argument("--config-file", type=str, required=True) + parser.add_argument("--checkpoint", type=str, required=True) + parser.add_argument("--output-file", type=str, default="tcga_embeddings.h5") + parser.add_argument("--data-dir", type=str, required=True) + parser.add_argument("--batch-size", type=int, default=1024) + parser.add_argument("--max-patches", type=int, default=2000000) + parser.add_argument("--num-workers", type=int, default=12) + parser.add_argument("--output-dir", type=str, default=".") + parser.add_argument("opts", default=None, nargs=argparse.REMAINDER) + return parser.parse_args() + +class FastPatchedDataset(torch.utils.data.IterableDataset): + def __init__(self, ds, transform): + self.ds = ds + self.transform = transform + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + ds_sharded = self.ds.shard(num_shards=worker_info.num_workers, index=worker_info.id) if worker_info else self.ds + + for item in ds_sharded: + try: + img_array = 
np.frombuffer(item["image_bytes"], np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img_pil = Image.fromarray(img) + + tensor_img = self.transform(img_pil) + meta = f"{item['slide_path']},{item['x']},{item['y']}" + yield tensor_img, meta + except Exception: + continue + +@torch.no_grad() +def extract_features(args): + device = torch.device("cuda") + + print("Loading and compiling model...") + cfg = setup(args) + model, _ = build_model_from_cfg(cfg, only_teacher=True) + + state_dict = torch.load(args.checkpoint, map_location="cpu")["teacher"] + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items() if k.startswith("backbone.")} + model.load_state_dict(state_dict, strict=False) + model.to(device).eval() + + model = torch.compile(model) + + transform = transforms.Compose([ + transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + ]) + + class FastPatchedDataset(torch.utils.data.IterableDataset): + def __init__(self, ds, transform, max_patches): + self.ds = ds + self.transform = transform + self.max_patches = max_patches + + def __iter__(self): + worker_info = torch.utils.data.get_worker_info() + ds_sharded = self.ds.shard(num_shards=worker_info.num_workers, index=worker_info.id) if worker_info else self.ds + + count = 0 + for item in ds_sharded: + if count >= (self.max_patches // (worker_info.num_workers if worker_info else 1)): break + + + img_array = np.frombuffer(item["image_bytes"], np.uint8) + img = cv2.imdecode(img_array, cv2.IMREAD_COLOR) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = Image.fromarray(img) + + yield self.transform(img), f"{item['slide_path']},{item['x']},{item['y']}" + count += 1 + + local_files = [os.path.join(args.data_dir, f) for f in os.listdir(args.data_dir) if f.endswith('.parquet')] + raw_ds = 
load_dataset("parquet", data_files=local_files, split="train", streaming=True) + + loader = torch.utils.data.DataLoader( + FastPatchedDataset(raw_ds, transform, args.max_patches), + batch_size=args.batch_size, + num_workers=args.num_workers, + prefetch_factor=8, + pin_memory=True, + persistent_workers=True + ) + + + WRITE_BUFFER_SIZE = 10000 + embed_buffer = [] + meta_buffer = [] + + embed_dim = model.embed_dim + with h5py.File(args.output_file, 'w') as f: + dset_embeds = f.create_dataset("embeddings", shape=(0, embed_dim), maxshape=(None, embed_dim), dtype='float32', chunks=(args.batch_size, embed_dim)) + dset_meta = f.create_dataset("metadata", shape=(0,), maxshape=(None,), dtype=h5py.string_dtype()) + + pbar = tqdm(total=args.max_patches, desc="⚡ Extracting") + + for imgs, metas in loader: + imgs = imgs.to(device, non_blocking=True) + + with torch.autocast(device_type="cuda", dtype=torch.bfloat16): + features = model(imgs, is_training=False).cpu().numpy() + + embed_buffer.append(features) + meta_buffer.extend(metas) + + if len(meta_buffer) >= WRITE_BUFFER_SIZE: + curr_len = len(dset_embeds) + all_feats = np.concatenate(embed_buffer, axis=0) + actual_write_size = len(all_feats) + + dset_embeds.resize(curr_len + actual_write_size, axis=0) + dset_embeds[curr_len:] = all_feats + + dset_meta.resize(curr_len + actual_write_size, axis=0) + dset_meta[curr_len:] = meta_buffer + + embed_buffer, meta_buffer = [], [] # Clear RAM + + pbar.update(len(features)) + +if __name__ == "__main__": + extract_features(get_args()) diff --git a/dinov2/data/curation/sampler.py b/dinov2/data/curation/sampler.py new file mode 100644 index 0000000..d2e2b97 --- /dev/null +++ b/dinov2/data/curation/sampler.py @@ -0,0 +1,44 @@ +import h5py +import argparse +import numpy as np +import json + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--embed-file", type=str, default="tcga_embeddings.h5") + parser.add_argument("--cluster-file", type=str, 
default="tcga_clusters.h5") + parser.add_argument("--output-txt", type=str, default="curated_dataset.txt") + parser.add_argument("--samples-per-cluster", type=int, default=500) + return parser.parse_args() + +def sample_dataset(args): + with h5py.File(args.embed_file, 'r') as f_embed, h5py.File(args.cluster_file, 'r') as f_cluster: + metadata = f_embed['metadata'][:] + labels_k2 = f_cluster['cluster_k2'][:] + + cluster_to_indices = {} + for i, cluster_id in enumerate(labels_k2): + if cluster_id not in cluster_to_indices: + cluster_to_indices[cluster_id] = [] + cluster_to_indices[cluster_id].append(i) + + selected_indices =[] + for cluster_id, indices in cluster_to_indices.items(): + if len(indices) > args.samples_per_cluster: + sampled = np.random.choice(indices, size=args.samples_per_cluster, replace=False) + else: + sampled = indices + selected_indices.extend(sampled) + + np.random.shuffle(selected_indices) + + print(f"Selected {len(selected_indices)} patches.") + with open(args.output_txt, 'w') as f: + for idx in selected_indices: + meta_str = metadata[idx].decode('utf-8') + f.write(f"{meta_str}\n") + + print(f"Saved curated list to {args.output_txt}.") + +if __name__ == "__main__": + sample_dataset(get_args()) diff --git a/dinov2/train/train.py b/dinov2/train/train.py index a824e63..b4f9773 100644 --- a/dinov2/train/train.py +++ b/dinov2/train/train.py @@ -1013,6 +1013,7 @@ def do_train(cfg, model, resume=False): cfg.crops.local_crops_number, global_crops_size=cfg.crops.global_crops_size, local_crops_size=cfg.crops.local_crops_size, + ablation_mode=getattr(cfg.crops, "ablation_mode", "baseline"), ) collate_fn = partial(