Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions dinov2/configs/train/vitg14_reg4.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,26 @@ optim:
warmup_epochs: 10
layerwise_decay: 1.0
crops:
# --------------------------------------------------------
# "baseline" : Only RandomResizedCrop + Horizontal Flip (DINOv2 Defaults)
# "vflip" : Baseline + Vertical Flip
# "rotate90" : Baseline + 90-degree Rotations
#  "colorjitter_weak"   : Baseline + Subtle Color Shifts
#  "colorjitter_medium" : Baseline + Moderate Color Shifts
#  "colorjitter_strong" : Baseline + Aggressive Color Shifts
#  "blur_weak"          : Baseline + Subtle Out-of-focus
#  "blur_medium"        : Baseline + Moderate Blur
#  "blur_strong"        : Baseline + Heavy Scanner Blur
#  "hed_weak"           : Baseline + Subtle HED Stain shift (0.005)
#  "hed_medium"         : Baseline + Standard HED Stain shift (0.01)
#  "hed_strong"         : Baseline + Aggressive HED Stain shift (0.03)
#  "randstain_weak"     : Baseline + RandStainNA Normalization (-0.4)
#  "randstain_medium"   : Baseline + RandStainNA Standard (-0.1)
#  "randstain_strong"   : Baseline + RandStainNA High Diversity (0.0)
# "randstain_hed_combo": Baseline + Both HED and RandStain
# "elastic" : Baseline + Tissue Stretching
# "jpeg" : Baseline + JPEG Artifacts
# "combo" : FULL PIPELINE
# ------------------------------
ablation_mode: "elastic"
local_crops_size: 98
evaluation:
eval_period_iterations: 5000 # save checkpoint every 10 epochs
Expand Down
310 changes: 245 additions & 65 deletions dinov2/data/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,97 @@ def forward(self, img):
return img
return self.augment(img)

class ElasticDeformation(torch.nn.Module):
    """Randomly warp an image with a smooth elastic displacement field.

    Per-pixel displacements are drawn uniformly in [-1, 1], smoothed with a
    Gaussian filter (std ``sigma``) and scaled by ``alpha``; the image is then
    resampled along the displaced grid with reflective border handling.

    Args:
        low_alpha / high_alpha: range the displacement magnitude ``alpha`` is
            drawn from on each call.
        low_sigma / high_sigma: range the smoothing std ``sigma`` is drawn
            from on each call.
        probability: chance of applying the deformation at all.
    """

    def __init__(self, low_alpha=40.0, high_alpha=200.0, low_sigma=5.0, high_sigma=10.0, probability=0.5):
        super().__init__()
        self.low_alpha = low_alpha
        self.high_alpha = high_alpha
        self.low_sigma = low_sigma
        self.high_sigma = high_sigma
        self.probability = probability

    def forward(self, img):
        # Bernoulli gate: skip the (expensive) warp most of the time.
        if random.random() > self.probability:
            return img

        # Per-call strength and smoothness of the deformation field.
        alpha = random.uniform(self.low_alpha, self.high_alpha)
        sigma = random.uniform(self.low_sigma, self.high_sigma)

        # Normalise the input to a uint8 HWC numpy array, remembering the
        # original container so the same kind can be returned.
        if isinstance(img, Image.Image):
            src_kind = 'PIL'
            arr = np.array(img)
        elif isinstance(img, torch.Tensor):
            src_kind = 'Tensor'
            # assumes a float CHW tensor in [0, 1] -- TODO confirm upstream
            arr = (img.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        else:
            src_kind = 'Array'
            arr = np.array(img)

        h, w = arr.shape[:2]

        # Smooth white noise -> a coherent displacement field per axis.
        noise_x = (np.random.rand(h, w) * 2 - 1).astype(np.float32)
        noise_y = (np.random.rand(h, w) * 2 - 1).astype(np.float32)
        dx = cv2.GaussianBlur(noise_x, (0, 0), sigma) * alpha
        dy = cv2.GaussianBlur(noise_y, (0, 0), sigma) * alpha

        grid_x, grid_y = np.meshgrid(np.arange(w), np.arange(h))
        map_x = (grid_x + dx).astype(np.float32)
        map_y = (grid_y + dy).astype(np.float32)

        # Reflect at the borders so the warp introduces no black seams.
        warped = cv2.remap(arr, map_x, map_y, cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT_101)

        if src_kind == 'PIL':
            return Image.fromarray(warped)
        if src_kind == 'Tensor':
            return torch.from_numpy(warped).permute(2, 0, 1).float() / 255.0
        return warped


class JpegCompression(torch.nn.Module):
    """Randomly round-trip an image through JPEG to simulate compression artifacts.

    A quality factor is sampled uniformly from
    [``quality_lower``, ``quality_upper``] (inclusive) and the image is
    re-encoded with OpenCV's JPEG codec. Input may be a PIL image, a float
    CHW tensor (assumed in [0, 1] -- TODO confirm upstream), or an RGB
    array-like; the same container kind is returned.

    Args:
        quality_lower: minimum JPEG quality (lower = stronger artifacts).
        quality_upper: maximum JPEG quality.
        probability: chance of applying the augmentation at all.
    """

    def __init__(self, quality_lower=20, quality_upper=100, probability=0.5):
        super().__init__()
        self.quality_lower = quality_lower
        self.quality_upper = quality_upper
        self.probability = probability

    def forward(self, img):
        # Bernoulli gate: apply with `probability`, else pass through untouched.
        if random.random() > self.probability:
            return img

        quality = random.randint(self.quality_lower, self.quality_upper)

        was_pil = False
        was_tensor = False

        if isinstance(img, Image.Image):
            img_np = np.array(img)
            was_pil = True
        elif isinstance(img, torch.Tensor):
            # .cpu() so CUDA tensors work too (mirrors ElasticDeformation);
            # without it .numpy() raises on GPU tensors.
            img_np = (img.permute(1, 2, 0).cpu().numpy() * 255.0).astype(np.uint8)
            was_tensor = True
        else:
            # Coerce array-likes to ndarray, consistent with ElasticDeformation.
            img_np = np.asarray(img)

        # OpenCV expects BGR channel order.
        img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
        result, encimg = cv2.imencode('.jpg', img_bgr, encode_param)

        if result:
            decimg = cv2.imdecode(encimg, 1)  # 1 -> force a 3-channel BGR decode
            img_rgb = cv2.cvtColor(decimg, cv2.COLOR_BGR2RGB)
        else:
            # Encoding failed: fall back to the unmodified pixels.
            img_rgb = img_np

        if was_pil:
            return Image.fromarray(img_rgb)
        if was_tensor:
            return torch.from_numpy(img_rgb).permute(2, 0, 1).float() / 255.0
        return img_rgb

class hed_mod(torch.nn.Module):
"""
Expand Down Expand Up @@ -276,6 +367,7 @@ def __init__(
local_crops_number,
global_crops_size=224,
local_crops_size=96,
ablation_mode="baseline", # This controls the abalation choices (baseline only applies the RandomResizeCrop + Horizontal Flip)
):
self.global_crops_scale = global_crops_scale
self.local_crops_scale = local_crops_scale
Expand All @@ -292,73 +384,161 @@ def __init__(
logger.info(f"local_crops_size: {local_crops_size}")
logger.info("###################################")

# Geometric augmentations with rotation and both flips
self.geometric_augmentation_global = transforms.Compose(
[
# RandomRotation90(),
transforms.RandomResizedCrop(
global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.RandomHorizontalFlip(p=0.5),
# transforms.RandomVerticalFlip(p=0.5),
]
)

self.geometric_augmentation_local = transforms.Compose(
[
# RandomRotation90(),
transforms.RandomResizedCrop(
local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC
),
transforms.RandomHorizontalFlip(p=0.5),
# transforms.RandomVerticalFlip(p=0.5),
]
)

# Normalization (ImageNet stats used by default)
self.normalize = transforms.Compose(
[
# Everything is OFF by default. Only RandomResizedCrop and H-Flip are applied.
vflip_p = 0.0
rotate90_p = 0.0
cj_params = None # (brightness, contrast, saturation, hue)
gb_params = None # (kernel_size, sigma_min, sigma_max)
hed_range = None
rs_hyper = None
use_elastic = False
use_jpeg = False
# use_ect = False

if ablation_mode == "baseline":
pass
elif ablation_mode == "vflip":
vflip_p = 0.5
elif ablation_mode == "rotate90":
rotate90_p = 0.5

# Color Jitter Ablations
elif ablation_mode == "colorjitter_weak":
cj_params = (0.1, 0.1, 0.05, 0.02)
elif ablation_mode == "colorjitter_medium":
cj_params = (0.2, 0.2, 0.1, 0.05)
elif ablation_mode == "colorjitter_strong":
cj_params = (0.4, 0.4, 0.2, 0.1)

# Gaussian Blur Ablations
elif ablation_mode == "blur_weak":
gb_params = (7, 0.3, 0.8)
elif ablation_mode == "blur_medium":
gb_params = (9, 0.5, 1)
elif ablation_mode == "blur_strong":
gb_params = (15, 1.5, 2.5)

# HED Ablations
elif ablation_mode == "hed_weak":
hed_range = 0.005
elif ablation_mode == "hed_medium":
hed_range = 0.01
elif ablation_mode == "hed_strong":
hed_range = 0.03

# RandStainNA Ablations
elif ablation_mode == "randstain_weak":
rs_hyper = -0.4
elif ablation_mode == "randstain_medium":
rs_hyper = -0.1
elif ablation_mode == "randstain_strong":
rs_hyper = 0.0

# Combo Ablations (randstain + hed)
elif ablation_mode == "randstain_hed_combo":
rs_hyper = -0.3
hed_range = 0.08

elif ablation_mode == "elastic":
use_elastic = True

elif ablation_mode == "jpeg":
use_jpeg = True

elif ablation_mode == "combo":
# Combo of all above augmentations that showed improvements in isolation
vflip_p = 0.5
rotate90_p = 0.5
cj_params = (0.2, 0.2, 0.1, 0.05)
gb_params = (9, 0.1, 2.0)
hed_range = 0.05
rs_hyper = -0.3
use_elastic = True
use_jpeg = True
else:
raise ValueError(f"Unknown ablation mode: {ablation_mode}")

# --- 3. LOGGING THE ACTIVE PIPELINE ---
logger.info("=========================================")
logger.info(f" BUILDING PIPELINE: MODE '{ablation_mode.upper()}' ")
logger.info("=========================================")
logger.info(f" [ON] RandomResizedCrop (Global: {global_crops_size}, Local: {local_crops_size})")
logger.info(f" [ON] HorizontalFlip (p=0.5)")
logger.info(f" [{'ON' if vflip_p > 0 else 'OFF'}] VerticalFlip (p={vflip_p})")
logger.info(f"[{'ON' if rotate90_p > 0 else 'OFF'}] Rotate90 (p={rotate90_p})")

if cj_params:
logger.info(f" [ON] ColorJitter (B:{cj_params[0]}, C:{cj_params[1]}, S:{cj_params[2]}, H:{cj_params[3]})")
else:
logger.info(f" [OFF] ColorJitter")

if gb_params:
logger.info(f" [ON] GaussianBlur (Kernel:{gb_params[0]}, Sigma range: {gb_params[1]}-{gb_params[2]})")
else:
logger.info(f" [OFF] GaussianBlur")

logger.info(f" [{'ON' if hed_range else 'OFF'}] HED Stain Perturbation (Range: +/-{hed_range})")
logger.info(f"[{'ON' if rs_hyper is not None else 'OFF'}] RandStainNA (Std Hyper: {rs_hyper})")
logger.info(f"[{'ON' if use_elastic else 'OFF'}] Elastic Deformation")
logger.info(f"[{'ON' if use_jpeg else 'OFF'}] JPEG Compression")
logger.info("=========================================\n")


geom_global =[
transforms.RandomResizedCrop(global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(p=0.5),
]
geom_local =[
transforms.RandomResizedCrop(local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC),
transforms.RandomHorizontalFlip(p=0.5),
]

if vflip_p > 0:
geom_global.append(transforms.RandomVerticalFlip(p=vflip_p))
geom_local.append(transforms.RandomVerticalFlip(p=vflip_p))
if rotate90_p > 0:
geom_global.append(RandomRotation90())
geom_local.append(RandomRotation90())

self.geometric_augmentation_global = transforms.Compose(geom_global)
self.geometric_augmentation_local = transforms.Compose(geom_local)

def build_pixel_pipeline(blur_probability):
pipeline =[]

if use_elastic:
pipeline.append(ElasticDeformation(probability=0.5))

if rs_hyper is not None:
pipeline.append(RandStainNA(color_space='LAB', std_hyper=rs_hyper, distribution='normal', probability=0.5))

if hed_range is not None:
pipeline.append(hed_mod(probability=0.5, perturbation_range=hed_range))

if cj_params is not None:
pipeline.append(transforms.RandomApply([transforms.ColorJitter(*cj_params)], p=0.8))
pipeline.append(transforms.RandomGrayscale(p=0.2))

if gb_params is not None:
k, s_min, s_max = gb_params
pipeline.append(transforms.RandomApply([transforms.GaussianBlur(kernel_size=k, sigma=(s_min, s_max))], p=blur_probability))

if use_jpeg:
pipeline.append(JpegCompression(probability=0.5))

pipeline.extend([
transforms.ToTensor(),
make_normalize_transform(),
]
)

# Pathology-specific stain augmentations
randstainna = RandStainNA(
color_space='LAB',
std_hyper=-0.3,
distribution='normal',
probability=0.5,
)

hed_aug = hed_mod(probability=0.5, perturbation_range=0.05)

self.global_transfo1 = transforms.Compose([
# randstainna,
hed_aug,
transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05)], p=0.8),
transforms.RandomGrayscale(p=0.2),
GaussianBlur(p=1.0),
self.normalize
])

self.global_transfo2 = transforms.Compose([
# randstainna,
hed_aug,
transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05)], p=0.8),
transforms.RandomGrayscale(p=0.2),
GaussianBlur(p=0.1),
self.normalize
])

self.local_transfo = transforms.Compose([
# randstainna,
hed_aug,
transforms.RandomApply([transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05)], p=0.8),
transforms.RandomGrayscale(p=0.2),
GaussianBlur(p=0.5),
self.normalize
])
])
return transforms.Compose(pipeline)

# DINOv2 uses asymmetric blur probabilities for the different crops
self.global_transfo1 = build_pixel_pipeline(blur_probability=1.0)
self.global_transfo2 = build_pixel_pipeline(blur_probability=0.1)
self.local_transfo = build_pixel_pipeline(blur_probability=0.5)

# Normalization (ImageNet stats used by default)
self.normalize = transforms.Compose([transforms.ToTensor(), make_normalize_transform()])

def __call__(self, image):
output = {}
Expand Down
Loading