diff --git a/ci/cscs.yml b/ci/cscs.yml index 1213c70..732a449 100644 --- a/ci/cscs.yml +++ b/ci/cscs.yml @@ -18,11 +18,9 @@ build_job: test_job: stage: test - extends: .container-runner-clariden-gh200 + extends: .container-runner-santis-gh200 image: $PERSIST_IMAGE_NAME script: - # - echo 'hello world' - - ls /opt - pytest /opt/hirad-gen/tests -v variables: USE_MPI: NO diff --git a/ci/docker/Dockerfile.ci b/ci/docker/Dockerfile.ci index 81bc8fb..f496fd4 100644 --- a/ci/docker/Dockerfile.ci +++ b/ci/docker/Dockerfile.ci @@ -4,7 +4,8 @@ FROM jfrog.svc.cscs.ch/docker-group-csstaff/alps-images/ngc-physicsnemo:25.11-al RUN pip install --upgrade pip # install dependencies -RUN pip install mlflow +RUN pip install mlflow \ + anemoi-datasets COPY . /opt/hirad-gen diff --git a/src/hirad/conf/generation/era_real.yaml b/src/hirad/conf/generation/era_real.yaml index e6e35ec..6369886 100644 --- a/src/hirad/conf/generation/era_real.yaml +++ b/src/hirad/conf/generation/era_real.yaml @@ -36,7 +36,7 @@ perf: force_fp16: False # Whether to force fp16 precision for the model. If false, it'll use the precision # specified upon training. - use_torch_compile: False + use_torch_compile: True # whether to use torch.compile on the diffusion model # this will make the first time stamp generation very slow due to compilation overheads # but will significantly speed up subsequent inference runs diff --git a/src/hirad/datasets/anemoi_dataset.py b/src/hirad/datasets/anemoi_dataset.py index 77ded7c..a36d16e 100644 --- a/src/hirad/datasets/anemoi_dataset.py +++ b/src/hirad/datasets/anemoi_dataset.py @@ -191,11 +191,11 @@ def __getitem__(self, idx): # next two steps only if target is cosmo, real has to be regridded first (done in training loop on gpu-s for efficiency) # reshape to image_shape # flip so that it starts in top-left corner (by default it is bottom left) - if not self.real_target: - target_shape = self.image_shape() - target_data = np.flip(target_data \ - .reshape(-1,*target_shape), - 1) + # if not self.real_target: + # target_shape = self.image_shape() + # target_data = np.flip(target_data \ + # .reshape(-1,*target_shape), + # 1) return torch.from_numpy(target_data.copy()),\ torch.from_numpy(input_data),\ @@ -344,7 +344,7 @@ def make_time_grids(self, dates: list[str], device: torch.device, dtype: torch.d Returns ------- - grid : torch.Tensor, shape (B, C, H, W) + grid : torch.Tensor, shape (B, C) Channels = [sin(k*hour), cos(k*hour), sin(k*month), cos(k*month) for each k] """ diff --git a/src/hirad/inference/generate.py b/src/hirad/inference/generate.py index a8e77f0..2fc70ca 100644 --- a/src/hirad/inference/generate.py +++ b/src/hirad/inference/generate.py @@ -42,8 +42,8 @@ def main(cfg: DictConfig) -> None: if dist.world_size > 1: torch.distributed.barrier() - use_apex_gn = cfg.generation.perf.get("use_apex_gn", False) - input_dtype = torch.float16 if cfg.generation.perf.get("force_fp16", False) and use_apex_gn else torch.float32 + # Set precision for inference + input_dtype = torch.float16 if cfg.generation.perf.get("force_fp16", False) else torch.float32 # Parse the inference input times if cfg.generation.get("times_range", None) and cfg.generation.get("times", None): @@ -65,6 +65,7 @@ def main(cfg: DictConfig) -> None: dataset_cfg=dataset_cfg, times=times, has_lead_time=has_lead_time ) dataset.stats_to_torch(device=dist.device, dtype=input_dtype) + dataset.interpolator.to(device=dist.device) is_real_target = dataset_cfg.get("type").split("_")[-1] == "real" if is_real_target: dataset.regrid_indices_real = dataset.regrid_indices_real.to(dist.device) @@ -72,6 +73,8 @@ def main(cfg: DictConfig) -> None: img_shape = dataset.image_shape() img_out_channels = len(dataset.output_channels()) + + #TODO: Isolate loading into the method of generator # Parse the inference mode if cfg.generation.inference_mode == "regression": load_net_reg, load_net_res = True, False @@ -92,6 +95,10 @@ def main(cfg: DictConfig) -> None: raise FileNotFoundError(f"Missing config file at '{diffusion_model_args_path}'.") with open(diffusion_model_args_path, 'r') as f: diffusion_model_args = json.load(f) + # Disable AMP for inference (even if model is trained with AMP) + if "amp_mode" in diffusion_model_args: + diffusion_model_args["amp_mode"] = False + use_apex_gn = diffusion_model_args.get("use_apex_gn", False) net_res = EDMPrecondSuperResolution(**diffusion_model_args) @@ -101,13 +108,11 @@ def main(cfg: DictConfig) -> None: device=dist.device ) - net_res = net_res.eval().to(device).to(memory_format=torch.channels_last) + net_res = net_res.eval().to(device) + if use_apex_gn: + net_res = net_res.to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_res.use_fp16 = True - - # Disable AMP for inference (even if model is trained with AMP) - if hasattr(net_res, "amp_mode"): - net_res.amp_mode = False else: net_res = None @@ -122,6 +127,10 @@ def main(cfg: DictConfig) -> None: raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.") with open(regression_model_args_path, 'r') as f: regression_model_args = json.load(f) + # Disable AMP for inference (even if model is trained with AMP) + if "amp_mode" in regression_model_args: + regression_model_args["amp_mode"] = False + use_apex_gn_reg = regression_model_args.get("use_apex_gn", False) net_reg = UNet(**regression_model_args) @@ -131,24 +140,25 @@ def main(cfg: DictConfig) -> None: device=dist.device ) - net_reg = net_reg.eval().to(device).to(memory_format=torch.channels_last) + net_reg = net_reg.eval().to(device) + if use_apex_gn_reg: + net_reg = net_reg.to(memory_format=torch.channels_last) if cfg.generation.perf.force_fp16: net_reg.use_fp16 = True - - # Disable AMP for inference (even if model is trained with AMP) - if hasattr(net_reg, "amp_mode"): - net_reg.amp_mode = False else: net_reg = None # Reset since we are using a different mode. if cfg.generation.perf.use_torch_compile: + torch._dynamo.config.cache_size_limit = 264 torch._dynamo.reset() - # Only compile residual network - # Overhead of compiling regression network outweights any benefits if net_res: - net_res = torch.compile(net_res) #, mode="reduce-overhead") - # removed reduce-overhead because it was breaking cuda graph compilation + net_res = torch.compile(net_res) + if net_reg: + net_reg = torch.compile(net_reg) + + + generator = Generator( net_reg=net_reg, net_res=net_res, @@ -186,147 +196,152 @@ def main(cfg: DictConfig) -> None: torch_cuda_profiler = ( torch.cuda.profiler.profile() - if torch.cuda.is_available() + if torch.cuda.is_available() and cfg.generation.perf.get("profile", False) else contextlib.nullcontext() ) torch_nvtx_profiler = ( torch.autograd.profiler.emit_nvtx() - if torch.cuda.is_available() + if torch.cuda.is_available() and cfg.generation.perf.get("profile", False) else contextlib.nullcontext() ) with torch_cuda_profiler: with torch_nvtx_profiler: + with torch.inference_mode(): - data_loader = torch.utils.data.DataLoader( - dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True - ) - time_index = -1 - if dist.rank == 0: - writer_executor = ThreadPoolExecutor( - max_workers=cfg.generation.perf.num_writer_workers + data_loader = torch.utils.data.DataLoader( + dataset=dataset, sampler=sampler, batch_size=1, pin_memory=True ) - writer_threads = [] - - # Create timer objects only if CUDA is available - use_cuda_timing = torch.cuda.is_available() - if use_cuda_timing: - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - else: - # Dummy no-op functions for CPU case - class DummyEvent: - def record(self): - pass - - def synchronize(self): - pass - - def elapsed_time(self, _): - return 0 - - start = end = DummyEvent() - - dataset.interpolator.to_torch(device=dist.device) - - static_channels = dataset.get_static_data() - if static_channels is not None: - static_channels = static_channels[None, ::].flip(-2) - if use_apex_gn: - static_channels = static_channels.to( - dist.device, - dtype=input_dtype, - non_blocking=True, - ).to(memory_format=torch.channels_last) - else: - static_channels = ( - static_channels.to(dist.device) - .to(input_dtype) - .contiguous() - ) - lead_time_label = None - - times = dataset.time() - for index, (image_tar, image_lr, *date_str) in enumerate( - iter(data_loader) - ): - time_index += 1 + time_index = -1 if dist.rank == 0: - logger0.info(f"starting index: {time_index} time: {times[sampler[time_index]]}") - - if time_index == warmup_steps: - start.record() - - savedir = os.path.join(output_path,f"{times[sampler[time_index]]}") - os.makedirs(savedir,exist_ok=True) - # continue - if is_real_target: - image_tar = regrid_icon_to_rotlatlon( - image_tar.to(dist.device, dtype=input_dtype), - dataset.regrid_indices_real, - dataset.regrid_weights_real, + writer_executor = ThreadPoolExecutor( + max_workers=cfg.generation.perf.num_writer_workers ) - if dataset.trim_edge > 0: - image_tar = image_tar[:, :, dataset.trim_edge:-dataset.trim_edge, dataset.trim_edge:-dataset.trim_edge] - if lead_time_label: - lead_time_label = lead_time_label[0].to(dist.device).contiguous() - else: - lead_time_label = None - image_lr = dataset.interpolator(image_lr.to(dist.device, dtype=input_dtype)).reshape(*image_lr.shape[:-1], *dataset.image_shape()).flip(-2) - image_lr = dataset.normalize_input(image_lr) - image_lr = image_lr.to(memory_format=torch.channels_last) - random_seed = cfg.generation.get("random_seed", None)+index if cfg.generation.get("randomize", False) and cfg.generation.get("random_seed", None) is not None else None - date_embedding = None - if dataset._n_month_hour_channels: - date_embedding = dataset.make_time_grids(*date_str, dist.device, dtype=input_dtype) - image_out, image_reg = generator.generate( - image_lr, - static_channels=static_channels, - date_embedding=date_embedding, - lead_time_label=lead_time_label, - randomize=cfg.generation.get("randomize", False), - random_seed=random_seed - ) + writer_threads = [] - if dist.rank == 0: - batch_size = image_out.shape[0] - # write out data in a seperate thread so we don't hold up inferencing - image_tar = image_tar[0].squeeze().cpu().numpy() - prediction_ensemble = dataset.denormalize_output(image_out).squeeze().flip(-2).cpu().numpy() - baseline = dataset.denormalize_input(image_lr)[0].squeeze().flip(-2).cpu().numpy() - if image_reg is not None: - mean_pred = dataset.denormalize_output(image_reg)[0].squeeze().flip(-2).cpu().numpy() - writer_threads.append( - writer_executor.submit( - save_results_as_torch, - savedir, - times[sampler[time_index]], - prediction_ensemble, - image_tar, - baseline, - mean_pred if image_reg is not None else None, + # Create timer objects only if CUDA is available + use_cuda_timing = torch.cuda.is_available() + if use_cuda_timing: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + else: + # Dummy no-op functions for CPU case + class DummyEvent: + def record(self): + pass + + def synchronize(self): + pass + + def elapsed_time(self, _): + return 0 + + start = end = DummyEvent() + + #TODO: Isolate static channel loading into the method of generator or reuse training manager static channel loading + static_channels = dataset.get_static_data() + if static_channels is not None: + static_channels = static_channels[None, ::].flip(-2) + if use_apex_gn: + static_channels = static_channels.to( + dist.device, + dtype=input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + static_channels = ( + static_channels.to(dist.device) + .to(input_dtype) + .contiguous() + ) + lead_time_label = None + + times = dataset.time() + for index, (image_tar, image_lr, *date_str) in enumerate( + iter(data_loader) + ): + time_index += 1 + if dist.rank == 0: + logger0.info(f"starting index: {time_index} time: {times[sampler[time_index]]}") + + if time_index == warmup_steps: + start.record() + + savedir = os.path.join(output_path,f"{times[sampler[time_index]]}") + os.makedirs(savedir,exist_ok=True) + + #TODO: Move all the data processing inside the generator and just pass raw data to it. This includes regridding, normalization, date embedding creation, etc. + # Same as with static channel loading, we can reuse some of the code from training manager for this. This will also make it easier to maintain and update the data processing steps in one place. + if is_real_target: + image_tar = regrid_icon_to_rotlatlon( + image_tar.to(dist.device, dtype=input_dtype), + dataset.regrid_indices_real, + dataset.regrid_weights_real, ) + if dataset.trim_edge > 0: + image_tar = image_tar[:, :, dataset.trim_edge:-dataset.trim_edge, dataset.trim_edge:-dataset.trim_edge] + else: + image_tar = image_tar.reshape(*image_tar.shape[:-1], *dataset.image_shape()) + if lead_time_label: + lead_time_label = lead_time_label[0].to(dist.device).contiguous() + else: + lead_time_label = None + image_lr = dataset.interpolator(image_lr.to(dist.device, dtype=input_dtype)).reshape(*image_lr.shape[:-1], *dataset.image_shape()).flip(-2) + image_lr = dataset.normalize_input(image_lr) + if use_apex_gn: + image_lr = image_lr.to(memory_format=torch.channels_last) + date_embedding = None + if dataset._n_month_hour_channels: + date_embedding = dataset.make_time_grids(*date_str, dist.device, dtype=input_dtype) + random_seed = cfg.generation.get("random_seed", None)+index if cfg.generation.get("randomize", False) and cfg.generation.get("random_seed", None) is not None else None + image_out, image_reg = generator.generate( + image_lr, + static_channels=static_channels, + date_embedding=date_embedding, + lead_time_label=lead_time_label, + randomize=cfg.generation.get("randomize", False), + random_seed=random_seed + ) + + if dist.rank == 0: + batch_size = image_out.shape[0] + # write out data in a seperate thread so we don't hold up inferencing + image_tar = image_tar[0].squeeze().cpu().numpy() + prediction_ensemble = dataset.denormalize_output(image_out).squeeze().flip(-2).cpu().numpy() + baseline = dataset.denormalize_input(image_lr)[0].squeeze().flip(-2).cpu().numpy() + if image_reg is not None: + mean_pred = dataset.denormalize_output(image_reg)[0].squeeze().flip(-2).cpu().numpy() + writer_threads.append( + writer_executor.submit( + save_results_as_torch, + savedir, + times[sampler[time_index]], + prediction_ensemble, + image_tar, + baseline, + mean_pred if image_reg is not None else None, + ) + ) + end.record() + end.synchronize() + elapsed_time = ( + start.elapsed_time(end) / 1000.0 if use_cuda_timing else 0 + ) # Convert ms to s + timed_steps = time_index + 1 - warmup_steps + if dist.rank == 0 and use_cuda_timing: + average_time_per_batch_element = elapsed_time / timed_steps / batch_size + logger.info( + f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s" + ) + logger.info( + f"Average time per batch element = {average_time_per_batch_element} s" ) - end.record() - end.synchronize() - elapsed_time = ( - start.elapsed_time(end) / 1000.0 if use_cuda_timing else 0 - ) # Convert ms to s - timed_steps = time_index + 1 - warmup_steps - if dist.rank == 0 and use_cuda_timing: - average_time_per_batch_element = elapsed_time / timed_steps / batch_size - logger.info( - f"Total time to run {timed_steps} steps and {batch_size} members = {elapsed_time} s" - ) - logger.info( - f"Average time per batch element = {average_time_per_batch_element} s" - ) - # make sure all the workers are done writing - if dist.rank == 0: - for thread in list(writer_threads): - thread.result() - writer_threads.remove(thread) - writer_executor.shutdown() + # make sure all the workers are done writing + if dist.rank == 0: + for thread in list(writer_threads): + thread.result() + writer_threads.remove(thread) + writer_executor.shutdown() if dist.rank == 0: f.close() diff --git a/src/hirad/models/layers.py b/src/hirad/models/layers.py index dc5630e..edaab04 100644 --- a/src/hirad/models/layers.py +++ b/src/hirad/models/layers.py @@ -114,7 +114,7 @@ class Conv2d(torch.nn.Module): """ A custom 2D convolutional layer implementation with support for up-sampling, down-sampling, and custom weight and bias initializations. The layer's weights - and biases canbe initialized using custom initialization strategies like + and biases can be initialized using custom initialization strategies like "kaiming_normal", and can be further scaled by factors `init_weight` and `init_bias`. @@ -403,7 +403,7 @@ def forward(self, x): x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups) mean = x.mean(dim=[2, 3, 4], keepdim=True) - var = x.var(dim=[2, 3, 4], keepdim=True) + var = x.var(dim=[2, 3, 4], keepdim=True, unbiased=False) x = (x - mean) * (var + self.eps).rsqrt() x = rearrange(x, "b g c h w -> b (g c) h w") diff --git a/src/hirad/models/song_unet.py b/src/hirad/models/song_unet.py index ce0fedc..cc52cfd 100644 --- a/src/hirad/models/song_unet.py +++ b/src/hirad/models/song_unet.py @@ -185,7 +185,7 @@ def __init__( emb_channels=emb_channels, num_heads=1, dropout=dropout, - skip_scale=np.sqrt(0.5), + skip_scale=0.7071067811865476, # 1 / sqrt(2) eps=1e-6, resample_filter=resample_filter, resample_proj=True, @@ -659,10 +659,13 @@ def __init__( self.gridtype = gridtype self.N_grid_channels = N_grid_channels - if self.gridtype == "learnable": - self.pos_embd = self._get_positional_embedding() + if self.N_grid_channels: + if self.gridtype == "learnable": + self.pos_embd = self._get_positional_embedding() + else: + self.register_buffer("pos_embd", self._get_positional_embedding().float()) else: - self.register_buffer("pos_embd", self._get_positional_embedding().float()) + self.pos_embd = None self.lead_time_mode = lead_time_mode if self.lead_time_mode: self.lead_time_channels = lead_time_channels @@ -693,7 +696,13 @@ def forward( "embedding_selector is the preferred approach for better efficiency." ) - if x.dtype != self.pos_embd.dtype: + if self.lead_time_mode and embedding_selector is not None: + raise ValueError( + "Embedding selector is not supported in lead time mode. " + "Please use global_index to select positional embeddings when lead_time_mode is True." + ) + + if self.pos_embd is not None and x.dtype != self.pos_embd.dtype: self.pos_embd = self.pos_embd.to(x.dtype) # Append positional embedding to input conditioning @@ -780,7 +789,7 @@ def positional_embedding_indexing( Example ------- >>> # Create global indices using patching utility: - >>> from physicsnemo.utils.patching import GridPatching2D + >>> from hirad.utils.patching import GridPatching2D >>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8)) >>> global_index = patching.global_index(batch_size=3) >>> print(global_index.shape) @@ -788,9 +797,9 @@ def positional_embedding_indexing( See Also -------- - :meth:`physicsnemo.utils.patching.RandomPatching2D.global_index` + :meth:`hirad.utils.patching.RandomPatching2D.global_index` For generating random patch indices. - :meth:`physicsnemo.utils.patching.GridPatching2D.global_index` + :meth:`hirad.utils.patching.GridPatching2D.global_index` For generating deterministic grid-based patch indices. See these methods for possible ways to generate the global_index parameter. """ @@ -900,7 +909,7 @@ def positional_embedding_selector( Each selected embedding should correspond to the positional information of each batch element in x. For patch-based processing, typically this should be based on - :meth:`physicsnemo.utils.patching.BasePatching2D.apply` method to + :meth:`hirad.utils.patching.BasePatching2D.apply` method to maintain consistency with patch extraction. embeds : Optional[torch.Tensor] Optional tensor for combined positional and lead time embeddings tensor @@ -969,6 +978,10 @@ def _get_positional_embedding(self): raise ValueError("N_grid_channels must be a factor of 4") num_freq = self.N_grid_channels // 4 freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) + #TODO: When more than 4 channels are used for sinusoidal, the frequencies should be multiples of the base frequency (2). + # freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) is currently in code which gives + # freqs = [1,4] instead of [1,2] for N_grid_channels=8. This seems to be a bug if we want the base 2. + # Leaving it like this for now since we have checkpoints with 8 sinusoidal channels that use these frequencies, grid_list = [] grid_x, grid_y = np.meshgrid( np.linspace(0, 2 * np.pi, self.img_shape_x), diff --git a/src/hirad/models/unet.py b/src/hirad/models/unet.py index ac645cf..82b9a27 100644 --- a/src/hirad/models/unet.py +++ b/src/hirad/models/unet.py @@ -63,42 +63,6 @@ class UNet(nn.Module): # TODO a lot of redundancy, need to clean up arXiv preprint arXiv:2309.15214. """ - @classmethod - def _backward_compat_arg_mapper( - cls, version: str, args: Dict[str, Any] - ) -> Dict[str, Any]: - """Map arguments from older versions to current version format. - - Parameters - ---------- - version : str - Version of the checkpoint being loaded - args : Dict[str, Any] - Arguments dictionary from the checkpoint - - Returns - ------- - Dict[str, Any] - Updated arguments dictionary compatible with current version - """ - # Call parent class method first - args = super()._backward_compat_arg_mapper(version, args) - - if version == "0.1.0": - # In version 0.1.0, img_channels was unused - if "img_channels" in args: - _ = args.pop("img_channels") - - # Sigma parameters are also unused - if "sigma_min" in args: - _ = args.pop("sigma_min") - if "sigma_max" in args: - _ = args.pop("sigma_max") - if "sigma_data" in args: - _ = args.pop("sigma_data") - - return args - def __init__( self, img_resolution: Union[int, Tuple[int, int]], @@ -217,8 +181,8 @@ def forward( ) F_x = self.model( - x.to(dtype), # (c_in * x).to(dtype), - torch.zeros(x.shape[0], dtype=dtype, device=x.device), # c_noise.flatten() + x.to(dtype), + torch.zeros(x.shape[0], dtype=dtype, device=x.device), class_labels=None, **model_kwargs, ) @@ -228,7 +192,6 @@ def forward( f"Expected the dtype to be {dtype}, " f"but got {F_x.dtype} instead." ) - # skip connection D_x = F_x.to(torch.float32) return D_x diff --git a/src/hirad/training/__init__.py b/src/hirad/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/hirad/training/train.py b/src/hirad/training/train.py index 27346ea..ddf55cb 100755 --- a/src/hirad/training/train.py +++ b/src/hirad/training/train.py @@ -3,12 +3,12 @@ from concurrent.futures import ThreadPoolExecutor -import psutil import hydra from omegaconf import DictConfig, OmegaConf import json from contextlib import nullcontext import nvtx +import numpy as np import torch from hydra.utils import to_absolute_path # from torch.utils.tensorboard import SummaryWriter @@ -19,9 +19,10 @@ from hirad.distributed import DistributedManager from hirad.utils.console import PythonLogger, RankZeroLoggingWrapper from hirad.utils.train_helpers import set_seed, configure_cuda_for_consistent_precision, \ - set_patch_shape, compute_num_accumulation_rounds, \ + set_patch_shape, compute_num_accumulation_rounds, calculate_patch_per_iter, \ is_time_for_periodic_task, handle_and_clip_gradients, \ - init_mlflow + init_mlflow, update_learning_rate, log_training_progress, \ + cuda_profiler, cuda_profiler_start, cuda_profiler_stop, profiler_emit_nvtx from hirad.utils.checkpoint import load_checkpoint, save_checkpoint from hirad.utils.patching import RandomPatching2D from hirad.utils.function_utils import get_time_from_range @@ -32,6 +33,9 @@ from hirad.losses import ResidualLoss, RegressionLoss from hirad.datasets import init_train_valid_datasets_from_config, get_dataset_and_sampler_inference from hirad.inference import Generator +from hirad.training.training_manager import TrainingManagerCorrDiff + + torch._dynamo.reset() # Increase the cache size limit @@ -40,38 +44,17 @@ torch._dynamo.config.suppress_errors = False # Forces the error to show all details torch._logging.set_logs(recompiles=True, graph_breaks=True) -# Define safe CUDA profiler tools that fallback to no-ops when CUDA is not available -def cuda_profiler(): - if torch.cuda.is_available(): - return torch.cuda.profiler.profile() - else: - return nullcontext() - - -def cuda_profiler_start(): - if torch.cuda.is_available(): - torch.cuda.profiler.start() - - -def cuda_profiler_stop(): - if torch.cuda.is_available(): - torch.cuda.profiler.stop() - - -def profiler_emit_nvtx(): - if torch.cuda.is_available(): - return torch.autograd.profiler.emit_nvtx() - else: - return nullcontext() @hydra.main(version_base=None, config_path="../conf", config_name="training") def main(cfg: DictConfig) -> None: + # Initialize distributed environment for training DistributedManager.initialize() dist = DistributedManager() OmegaConf.resolve(cfg) + # Initialize logging if cfg.logging.method == "mlflow": init_mlflow(cfg, dist) if dist.world_size > 1: @@ -82,20 +65,39 @@ def main(cfg: DictConfig) -> None: logger = PythonLogger("main") # general logger logger0 = RankZeroLoggingWrapper(logger, dist) # rank 0 logger - dataset_cfg = OmegaConf.to_container(cfg.dataset) - train_test_split = getattr(cfg.dataset, "validation", False) - fp_optimizations = cfg.training.perf.fp_optimizations - songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level - fp16 = fp_optimizations == "fp16" - enable_amp = fp_optimizations.startswith("amp") - amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 logger0.info(f"Config is: {cfg}") logger0.info(f"Saving the outputs in {os.getcwd()}") + + # create checkpoint directory if it doesn't exist checkpoint_dir = os.path.join( cfg.training.io.get("checkpoint_dir", "."), f"checkpoints_{cfg.model.name}" ) if dist.rank==0 and not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir) # added creating checkpoint dir + os.makedirs(checkpoint_dir) + + # performance optimization configuration + use_torch_compile = getattr(cfg.training.perf, "torch_compile", False) + use_apex_gn = getattr(cfg.training.perf, "use_apex_gn", False) + profile_mode = getattr(cfg.training.perf, "profile_mode", False) + fp_optimizations = cfg.training.perf.fp_optimizations + songunet_checkpoint_level = cfg.training.perf.songunet_checkpoint_level + fp16 = fp_optimizations == "fp16" + enable_amp = fp_optimizations.startswith("amp") + amp_dtype = torch.float16 if (fp_optimizations == "amp-fp16") else torch.bfloat16 + + # set the data type for model inputs based on optimization configuration + input_dtype = torch.float32 + if enable_amp: + input_dtype = torch.float32 + elif fp16: + input_dtype = torch.float16 + + # dataset configuration + dataset_cfg = OmegaConf.to_container(cfg.dataset) + train_test_split = getattr(cfg.dataset, "validation", False) + n_month_hour_channels = 2*dataset_cfg.get("n_month_hour_channels", 0) + + # validate and set batch size configuration if cfg.training.hp.batch_size_per_gpu == "auto" and \ cfg.training.hp.total_batch_size == "auto": raise ValueError("batch_size_per_gpu and total_batch_size can't be both set to 'auto'.") @@ -108,8 +110,10 @@ def main(cfg: DictConfig) -> None: cfg.training.hp.batch_size_per_gpu * dist.world_size ) + # Get the current training step from the checkpoint if it exists, otherwise start from 0. cur_nimg = load_checkpoint(path=checkpoint_dir) + # Fix the seed based on training progress for reproducibility. set_seed(dist.rank + cur_nimg) configure_cuda_for_consistent_precision() @@ -132,27 +136,29 @@ def main(cfg: DictConfig) -> None: train_test_split=train_test_split, sampler_start_idx=cur_nimg, ) - dataset.interpolator.to_torch(device=dist.device) is_real_target = dataset_cfg.get("type").split("_")[-1] == "real" logger0.info(f"Training on dataset with size {len(dataset)}") logger0.info(f"Validating on dataset with size {len(validation_dataset) if validation_dataset else 0}") - # Parse image configuration & update model args - n_month_hour_channels = 2*dataset_cfg.get("n_month_hour_channels", 0) - dataset_channels = len(dataset.input_channels()) + len(dataset.static_channels()) + n_month_hour_channels - img_in_channels = dataset_channels + # Get the shape of the grid (without the channel dimension) for later use in model creation and patching img_shape = dataset.image_shape() - img_out_channels = len(dataset.output_channels()) - if cfg.model.hr_mean_conditioning: - img_in_channels += img_out_channels - static_channels = dataset.get_static_data() - logger0.info(f"Training on dataset with grid size {img_shape[0]}x{img_shape[1]}, {img_in_channels} input channels and {img_out_channels} output channels.") + + logger0.info(f"Training on dataset with grid size {img_shape[0]}x{img_shape[1]}, {len(dataset.input_channels())} input channels and {len(dataset.output_channels())} output channels.") logger0.info(f"Input channels: {dataset.input_channels()}") logger0.info(f"Output channels: {dataset.output_channels()}") logger0.info(f"Static channels: {dataset.static_channels()}") + # convert dataset stats to torch tensors on the correct device for later use in normalization and denormalization + dataset.stats_to_torch(device=dist.device, dtype=input_dtype) + # convert dataset stats to torch tensors on the correct device for later use in loss normalization and denormalization + dataset.interpolator.to(device=dist.device) + # convert regridding weights and indices to torch tensors on the correct device if real target dataset is used + if is_real_target: + dataset.regrid_indices_real = dataset.regrid_indices_real.to(dist.device) + dataset.regrid_weights_real = dataset.regrid_weights_real.to(dist.device, dtype=input_dtype) + if cfg.model.name == "lt_aware_ce_regression": - prob_channels = dataset.get_prob_channel_index() #TODO figure out what prob_channel are and update dataloader + prob_channels = dataset.get_prob_channel_index() else: prob_channels = None @@ -178,8 +184,9 @@ def main(cfg: DictConfig) -> None: ) patch_shape = (patch_shape_y, patch_shape_x) use_patching, img_shape, patch_shape = set_patch_shape(img_shape, patch_shape) + + # Initialize patcher if patch-based training is enabled if use_patching: - # Utility to perform patches extraction and batching patching = RandomPatching2D( img_shape=img_shape, patch_shape=patch_shape, @@ -189,62 +196,31 @@ def main(cfg: DictConfig) -> None: else: patching = None logger0.info("Patch-based training disabled") - # interpolate global channel if patch-based model is used - if use_patching: - img_in_channels += len(dataset.input_channels()) + len(dataset.static_channels()) - - # Instantiate the model and move to device. - model_args = { # default parameters for all networks - "img_out_channels": img_out_channels, - "img_resolution": list(img_shape), - "use_fp16": fp16, - "checkpoint_level": songunet_checkpoint_level, - } - if cfg.model.name == "lt_aware_ce_regression": - model_args["prob_channels"] = prob_channels - if hasattr(cfg.model, "model_args"): # override defaults from config file - model_args.update(OmegaConf.to_container(cfg.model.model_args)) - - use_torch_compile = getattr(cfg.training.perf, "torch_compile", False) - use_apex_gn = getattr(cfg.training.perf, "use_apex_gn", False) - profile_mode = getattr(cfg.training.perf, "profile_mode", False) - - model_args["use_apex_gn"] = use_apex_gn - model_args["profile_mode"] = profile_mode - - if enable_amp: - model_args["amp_mode"] = enable_amp + # Instantiate the training manager which handles model creation, + # data loading and transformation, + # and validation + training_manager = TrainingManagerCorrDiff( + dist, + logger0, + dataset, + input_dtype, + img_shape, + n_month_hour_channels, + fp16, + profile_mode, + enable_amp, + amp_dtype, + use_apex_gn, + is_real_target, + songunet_checkpoint_level, + use_patching, + cfg.model.get("hr_mean_conditioning", False), + cfg.logging.get("method", None) + ) - - if cfg.model.name == "regression": - model = UNet( - img_in_channels=img_in_channels + model_args["N_grid_channels"], - **model_args, - ) - model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] - elif cfg.model.name == "lt_aware_ce_regression": - model = UNet( - img_in_channels=img_in_channels - + model_args["N_grid_channels"] - + model_args["lead_time_channels"], - **model_args, - ) - model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] - elif cfg.model.name == "lt_aware_patched_diffusion": - model = EDMPrecondSuperResolution( - img_in_channels=img_in_channels - + model_args["N_grid_channels"] - + model_args["lead_time_channels"], - **model_args, - ) - model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] - else: # diffusion or patched diffusion - model = EDMPrecondSuperResolution( - img_in_channels=img_in_channels + model_args["N_grid_channels"], - **model_args, - ) - model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + # Create the model and move it to the appropriate device and memory format based on the optimization configuration + model, model_args = training_manager.create_model(cfg.model.name, cfg.model.get("model_args", None)) # # Print the model summary # if dist.rank == 0: @@ -271,57 +247,10 @@ def main(cfg: DictConfig) -> None: f"Regression model ({cfg.model.name}) cannot be used with patch-based training. " ) - # Enable distributed data parallel if applicable - if dist.world_size > 1: - if use_torch_compile: - model = torch.compile(model) - model = DistributedDataParallel( - model, - device_ids=[dist.local_rank], - broadcast_buffers=True, - output_device=dist.device, - find_unused_parameters=True, # dist.find_unused_parameters, - bucket_cap_mb=35, - gradient_as_bucket_view=True, - ) - # Load the regression checkpoint if applicable #TODO test when training correction + regression_net = None if hasattr(cfg.training.io, "regression_checkpoint_path"): - regression_checkpoint_path = to_absolute_path( - cfg.training.io.regression_checkpoint_path - ) - if not os.path.isdir(regression_checkpoint_path): - raise FileNotFoundError( - f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" - ) - #TODO make regression model loading more robust (model type is both in rergession_checkpoint_path and regression_name) - #TODO add the option to choose epoch to load from / regression_checkpoint_path is now a folder - regression_model_args_path = os.path.join(regression_checkpoint_path, 'model_args.json') - if not os.path.isfile(regression_model_args_path): - raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.") - - with open(regression_model_args_path, 'r') as f: - regression_model_args = json.load(f) - - regression_model_args.update({ - "use_apex_gn": use_apex_gn, - "profile_mode": profile_mode, - "amp_mode": enable_amp, - }) - - regression_net = UNet(**regression_model_args) - - _ = load_checkpoint( - path=regression_checkpoint_path, - model=regression_net, - device=dist.device - ) - regression_net.eval().requires_grad_(False).to(dist.device) - if use_apex_gn: - regression_net.to(memory_format=torch.channels_last) - logger0.success("Loaded the pre-trained regression model") - else: - regression_net = None + regression_net = training_manager.load_regression_model(to_absolute_path(cfg.training.io.regression_checkpoint_path)) # Compute the number of required gradient accumulation rounds @@ -334,29 +263,14 @@ def main(cfg: DictConfig) -> None: batch_size_per_gpu = cfg.training.hp.batch_size_per_gpu logger0.info(f"Using {num_accumulation_rounds} gradient accumulation rounds") - # calculate patch per iter + # calculate patch per iter patch_num = getattr(cfg.training.hp, "patch_num", 1) - if hasattr(cfg.training.hp, "max_patch_per_gpu"): - max_patch_per_gpu = cfg.training.hp.max_patch_per_gpu - if max_patch_per_gpu // batch_size_per_gpu < 1: - raise ValueError( - f"max_patch_per_gpu ({max_patch_per_gpu}) must be greater or equal to batch_size_per_gpu ({batch_size_per_gpu})." - ) - max_patch_num_per_iter = min( - patch_num, (max_patch_per_gpu // batch_size_per_gpu) - ) - patch_iterations = ( - patch_num + max_patch_num_per_iter - 1 - ) // max_patch_num_per_iter - patch_nums_iter = [ - min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter) - for i in range(patch_iterations) - ] - logger0.info( - f"max_patch_num_per_iter is {max_patch_num_per_iter}, patch_iterations is {patch_iterations}, patch_nums_iter is {patch_nums_iter}" - ) - else: - patch_nums_iter = [patch_num] + max_patch_per_gpu = getattr(cfg.training.hp, "max_patch_per_gpu", None) + patch_nums_iter = calculate_patch_per_iter(patch_num, max_patch_per_gpu, batch_size_per_gpu) + + logger0.info( + f"Patch number iterations are {patch_nums_iter}" + ) # Set patch gradient accumulation only for patched diffusion models if cfg.model.name in { @@ -403,9 +317,6 @@ def main(cfg: DictConfig) -> None: fused=True, ) - # Record the current time to measure the duration of subsequent operations. - start_time = time.time() - # Load optimizer checkpoint if it exists if dist.world_size > 1: torch.distributed.barrier() @@ -421,16 +332,33 @@ def main(cfg: DictConfig) -> None: # Compile the model and regression net if applicable if use_torch_compile: - if dist.world_size==1: - model = torch.compile(model) + # if dist.world_size==1: + model = torch.compile(model) if regression_net: regression_net = torch.compile(regression_net) + # Enable distributed data parallel if applicable + if dist.world_size > 1: + # if use_torch_compile: + # model = torch.compile(model) + model = DistributedDataParallel( + model, + device_ids=[dist.local_rank], + broadcast_buffers=True, + output_device=dist.device, + find_unused_parameters=True, # dist.find_unused_parameters, + bucket_cap_mb=35, + gradient_as_bucket_view=True, + ) + ############################################################################ # MAIN TRAINING LOOP # ############################################################################ + # Record the current time to measure the duration of subsequent operations. + start_time = time.time() + logger0.info(f"Training for {cfg.training.hp.training_duration} images...") done = False @@ -438,35 +366,9 @@ def main(cfg: DictConfig) -> None: average_loss_running_mean = 0 n_average_loss_running_mean = 1 start_nimg = cur_nimg - input_dtype = torch.float32 - if enable_amp: - input_dtype = torch.float32 - elif fp16: - input_dtype = torch.float16 - - # convert dataset stats to torch tensors on the correct device for later use in loss normalization and denormalization - dataset.stats_to_torch(device=dist.device, dtype=input_dtype) - normalization_stats = dataset.normalization_stats() # prepare static channels if there are any - if static_channels is not None: - static_channels = static_channels[None, ::].flip(-2) - if use_apex_gn: - static_channels = static_channels.to( - dist.device, - dtype=input_dtype, - non_blocking=True, - ).to(memory_format=torch.channels_last) - else: - static_channels = ( - static_channels.to(dist.device) - .to(input_dtype) - .contiguous() - ) - - if is_real_target: - dataset.regrid_indices_real = dataset.regrid_indices_real.to(dist.device) - dataset.regrid_weights_real = dataset.regrid_weights_real.to(dist.device, dtype=input_dtype) + static_channels = training_manager.get_static_data() # turn off for lead time labels for now since we are not using them # TODO: implement lead time labels properly once we train on IFS? @@ -497,49 +399,8 @@ def main(cfg: DictConfig) -> None: ): with nvtx.annotate("loading data", color="green"): tick_read_start_time = time.time() - img_clean, img_lr, *date_str = next( - dataset_iterator - ) + img_clean, img_lr, date_embedding = training_manager.load_and_preprocess_batch(dataset_iterator) tick_read_time = time.time() - tick_read_start_time - img_lr = dataset.interpolator(img_lr.to(dist.device, dtype=input_dtype)).reshape(*img_lr.shape[:-1], *img_shape).flip(-2) - img_lr = dataset.normalize_input(img_lr) - if is_real_target: - img_clean = regrid_icon_to_rotlatlon( - img_clean.to(dist.device, dtype=input_dtype), - dataset.regrid_indices_real, - dataset.regrid_weights_real, - ) - if dataset.trim_edge > 0: - img_clean = img_clean[:, :, dataset.trim_edge:-dataset.trim_edge, dataset.trim_edge:-dataset.trim_edge] - img_clean = img_clean.flip(-2) - else: - img_clean = img_clean.to(dist.device, dtype=input_dtype) - img_clean = dataset.normalize_output(img_clean) - date_embedding = None - if n_month_hour_channels > 0: - date_embedding = dataset.make_time_grids(*date_str, dist.device, dtype=input_dtype) - if use_apex_gn: - img_clean = img_clean.to( - dist.device, - dtype=input_dtype, - non_blocking=True, - ).to(memory_format=torch.channels_last) - img_lr = img_lr.to( - dist.device, - dtype=input_dtype, - non_blocking=True, - ).to(memory_format=torch.channels_last) - else: - img_clean = ( - img_clean.to(dist.device) - .to(input_dtype) - .contiguous() - ) - img_lr = ( - img_lr.to(dist.device) - .to(input_dtype) - .contiguous() - ) loss_fn_kwargs = { "net": model, "img_clean": img_clean, @@ -576,7 +437,7 @@ def main(cfg: DictConfig) -> None: ): loss = loss_fn(**loss_fn_kwargs) - loss = loss.sum() / batch_size_per_gpu + loss = loss.sum() / batch_size_per_gpu / patch_num_per_iter loss_accum += ( loss / num_accumulation_rounds @@ -604,13 +465,12 @@ def main(cfg: DictConfig) -> None: # Update weights. with nvtx.annotate("update weights", color="blue"): - lr_rampup = cfg.training.hp.lr_rampup # ramp up the learning rate - for g in optimizer.param_groups: - if lr_rampup > 0: - g["lr"] = cfg.training.hp.lr * min(cur_nimg / lr_rampup, 1) - if cur_nimg >= lr_rampup: - g["lr"] *= cfg.training.hp.lr_decay ** ((cur_nimg - lr_rampup) // cfg.training.hp.lr_decay_rate) - current_lr = g["lr"] + current_lr = update_learning_rate(optimizer, + cfg.training.hp.lr, + cfg.training.hp.lr_rampup, + cfg.training.hp.lr_decay, + cfg.training.hp.lr_decay_rate, + cur_nimg) handle_and_clip_gradients( model, grad_clip_threshold=cfg.training.hp.grad_clip_threshold ) @@ -620,6 +480,7 @@ def main(cfg: DictConfig) -> None: cur_nimg += cfg.training.hp.total_batch_size done = cur_nimg >= cfg.training.hp.training_duration + # Logging training progress if is_time_for_periodic_task( cur_nimg, cfg.training.io.print_progress_freq, @@ -629,167 +490,25 @@ def main(cfg: DictConfig) -> None: rank_0_only=True, ): # Print stats if we crossed the printing threshold with this batch - torch.cuda.synchronize() - tick_end_time = time.time() - fields = [] - fields += [f"samples {cur_nimg:<9.1f}"] - fields += [f"training_loss {average_loss:<7.2f}"] - fields += [f"training_loss_running_mean {average_loss_running_mean:<7.2f}"] - fields += [f"learning_rate {current_lr:<7.8f}"] - fields += [f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}"] - fields += [ - f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.4f}" - ] - fields += [ - f"sec_for_reading {tick_read_time:<7.4f}" - ] - fields += [f"total_sec {(tick_end_time - start_time):<7.1f}"] - fields += [ - f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}" - ] - if torch.cuda.is_available(): - fields += [ - f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}" - ] - fields += [ - f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}" - ] - torch.cuda.reset_peak_memory_stats() - logger0.info(" ".join(fields)) - - if cfg.logging.method == "mlflow": - mlflow.log_metric("training_loss", average_loss, cur_nimg) - mlflow.log_metric( - "training_loss_running_mean", - average_loss_running_mean, - cur_nimg, - ) - mlflow.log_metric("learning_rate", current_lr, cur_nimg) + log_training_progress(logger0, cfg.logging.method, dist, cur_nimg, tick_start_nimg, tick_start_time, + tick_read_time, start_time, average_loss, average_loss_running_mean, current_lr) # reset running mean of average loss average_loss_running_mean = 0 n_average_loss_running_mean = 1 + + # Validation with nvtx.annotate("validation", color="red"): - # Validation - if validation_dataset_iterator is not None: - valid_loss_accum = 0 - if is_time_for_periodic_task( - cur_nimg, - cfg.training.io.validation_freq, - done, - cfg.training.hp.total_batch_size, - dist.rank, - ): - with torch.no_grad(): - # turn off lead_time_label for now since we are not using them - #TODO: implement lead time labels properly once we train on IFS? - lead_time_label_valid = None - for _ in range(cfg.training.io.validation_steps): - ( - img_clean_valid, - img_lr_valid, - *date_str, - ) = next(validation_dataset_iterator) - img_lr_valid = dataset.interpolator(img_lr_valid.to(dist.device)).reshape(*img_lr_valid.shape[:-1], *img_shape).flip(-2) - img_lr_valid = dataset.normalize_input(img_lr_valid) - if is_real_target: - img_clean_valid = regrid_icon_to_rotlatlon( - img_clean_valid.to(dist.device, dtype=input_dtype), - dataset.regrid_indices_real, - dataset.regrid_weights_real, - ) - if dataset.trim_edge > 0: - img_clean_valid = img_clean_valid[:, :, dataset.trim_edge:-dataset.trim_edge, dataset.trim_edge:-dataset.trim_edge] - img_clean_valid = img_clean_valid.flip(-2) - else: - img_clean_valid = img_clean_valid.to(dist.device, dtype=input_dtype) - img_clean_valid = dataset.normalize_output(img_clean_valid) - date_embedding = None - if n_month_hour_channels > 0: - date_embedding = dataset.make_time_grids(*date_str, dist.device, dtype=input_dtype) - if use_apex_gn: - img_clean_valid = img_clean_valid.to( - dist.device, - dtype=input_dtype, - non_blocking=True, - ).to(memory_format=torch.channels_last) - img_lr_valid = img_lr_valid.to( - dist.device, - dtype=input_dtype, - non_blocking=True, - ).to(memory_format=torch.channels_last) - - else: - img_clean_valid = ( - img_clean_valid.to(dist.device) - .to(input_dtype) - .contiguous() - ) - img_lr_valid = ( - img_lr_valid.to(dist.device) - .to(input_dtype) - .contiguous() - ) - - loss_valid_kwargs = { - "net": model, - "img_clean": img_clean_valid, - "img_lr": img_lr_valid, - "static_channels": static_channels, - "date_embedding": date_embedding, - "augment_pipe": None, - "use_apex_gn": use_apex_gn, - } - if use_patch_grad_acc is not None: - loss_valid_kwargs[ - "use_patch_grad_acc" - ] = use_patch_grad_acc - if lead_time_label_valid: - lead_time_label_valid = ( - lead_time_label_valid[0] - .to(dist.device) - .contiguous() - ) - loss_valid_kwargs.update( - {"lead_time_label": lead_time_label_valid} - ) - if use_patch_grad_acc: - loss_fn.y_mean = None - - for patch_num_per_iter in patch_nums_iter: - if patching is not None: - patching.set_patch_num(patch_num_per_iter) - loss_valid_kwargs.update( - {"patching": patching} - ) - with torch.autocast( - "cuda", dtype=amp_dtype, enabled=enable_amp - ): - loss_valid = loss_fn(**loss_valid_kwargs) - - loss_valid = ( - (loss_valid.sum() / batch_size_per_gpu) - .cpu() - .item() - ) - valid_loss_accum += ( - loss_valid - / cfg.training.io.validation_steps - / len(patch_nums_iter) - ) - valid_loss_sum = torch.tensor( - [valid_loss_accum], device=dist.device - ) - if dist.world_size > 1: - torch.distributed.barrier() - torch.distributed.all_reduce( - valid_loss_sum, op=torch.distributed.ReduceOp.SUM - ) - average_valid_loss = valid_loss_sum / dist.world_size - if dist.rank == 0 and cfg.logging.method == "mlflow": - mlflow.log_metric( - "validation_loss", average_valid_loss, cur_nimg - ) + if validation_dataset_iterator is not None and is_time_for_periodic_task( + cur_nimg, + cfg.training.io.validation_freq, + done, + cfg.training.hp.total_batch_size, + dist.rank, + ): + training_manager.run_validation(cur_nimg, validation_dataset_iterator, model, loss_fn, + cfg.training.io.get("validation_steps",1), static_channels, + batch_size_per_gpu, patching, patch_nums_iter, use_patch_grad_acc) # Save checkpoints diff --git a/src/hirad/training/train_dummy.py b/src/hirad/training/train_dummy.py deleted file mode 100644 index cc47ea9..0000000 --- a/src/hirad/training/train_dummy.py +++ /dev/null @@ -1,14 +0,0 @@ -import hydra -from omegaconf import DictConfig, OmegaConf -import json - - -@hydra.main(version_base=None, config_path="../conf", config_name="training") -def main(cfg: DictConfig) -> None: - OmegaConf.resolve(cfg) - cfg = OmegaConf.to_container(cfg) - print(json.dumps(cfg, indent=2)) - # print(cfg.pretty()) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/src/hirad/training/training_manager.py b/src/hirad/training/training_manager.py new file mode 100644 index 0000000..cba5243 --- /dev/null +++ b/src/hirad/training/training_manager.py @@ -0,0 +1,287 @@ +from abc import ABC, abstractmethod +import torch +import numpy as np +import mlflow +import os +import json + +from hirad.distributed import DistributedManager +from hirad.utils.console import PythonLogger +from hirad.datasets import DownscalingDataset +from hirad.models import UNet, EDMPrecondSuperResolution +from hirad.utils.dataset_utils import regrid_icon_to_rotlatlon +from hirad.utils.checkpoint import load_checkpoint + + +class TrainingManagerBase(ABC): + def __init__(self, dist: DistributedManager, logger: PythonLogger): + self.dist = dist + self.logger = logger + + @abstractmethod + def load_and_preprocess_batch(self): + pass + + @abstractmethod + def get_static_data(self): + pass + + @abstractmethod + def create_model(self): + pass + + @abstractmethod + def run_validation(self): + pass + + +class TrainingManagerCorrDiff(TrainingManagerBase): + def __init__( + self, + dist: DistributedManager, + logger: PythonLogger, + dataset: DownscalingDataset, + input_dtype: torch.dtype, + img_shape: tuple[int, int], + n_month_hour_channels: int, + fp16: bool, + profile_mode: bool, + enable_amp: bool, + amp_dtype: torch.dtype, + use_apex_gn: bool, + is_real_target: bool, + songunet_checkpoint_level: int, + use_patching: bool, + hr_mean_conditioning: bool, + logging_method: str, + ): + super().__init__(dist, logger) + self.dataset = dataset + self.input_dtype = input_dtype + self.img_shape = img_shape + self.is_real_target = is_real_target + self.n_month_hour_channels = n_month_hour_channels + self.fp16 = fp16 + self.songunet_checkpoint_level = songunet_checkpoint_level + self.profile_mode = profile_mode + self.enable_amp = enable_amp + self.amp_dtype = amp_dtype + self.use_apex_gn = use_apex_gn + self.use_patching = use_patching + self.hr_mean_conditioning = hr_mean_conditioning + self.logging_method = logging_method + + + def load_and_preprocess_batch(self, dataset_iterator): + """Load a batch from the iterator and preprocess it (interpolate, normalize, move to device).""" + img_clean, img_lr, *date_str = next(dataset_iterator) + + # Interpolate and normalize low-res input + img_lr = self.dataset.interpolator( + img_lr.to(self.dist.device, dtype=self.input_dtype) + ).reshape(*img_lr.shape[:-1], *self.img_shape).flip(-2) + img_lr = self.dataset.normalize_input(img_lr) + + # Process high-res target + if self.is_real_target: + img_clean = regrid_icon_to_rotlatlon( + img_clean.to(self.dist.device, dtype=self.input_dtype), + self.dataset.regrid_indices_real, + self.dataset.regrid_weights_real, + ) + if self.dataset.trim_edge > 0: + img_clean = img_clean[:, :, self.dataset.trim_edge:-self.dataset.trim_edge, + self.dataset.trim_edge:-self.dataset.trim_edge] + img_clean = img_clean.flip(-2) + else: + img_clean = img_clean.to(self.dist.device, dtype=self.input_dtype) + img_clean = img_clean.reshape(*img_clean.shape[:-1], *self.img_shape).flip(-2) + img_clean = self.dataset.normalize_output(img_clean) + + # Date embedding + date_embedding = None + if self.n_month_hour_channels > 0: + date_embedding = self.dataset.make_time_grids(*date_str, self.dist.device, dtype=self.input_dtype) + + # Memory format + if self.use_apex_gn: + img_clean = img_clean.to(self.dist.device, dtype=self.input_dtype, non_blocking=True).to( + memory_format=torch.channels_last + ) + img_lr = img_lr.to(self.dist.device, dtype=self.input_dtype, non_blocking=True).to( + memory_format=torch.channels_last + ) + else: + img_clean = img_clean.to(self.dist.device).to(self.input_dtype).contiguous() + img_lr = img_lr.to(self.dist.device).to(self.input_dtype).contiguous() + + return img_clean, img_lr, date_embedding + + def get_static_data(self): + """Get static data from the dataset, preprocess it and move to device.""" + static_channels = self.dataset.get_static_data() + if static_channels is not None: + if isinstance(static_channels, np.ndarray): + static_channels = torch.from_numpy(static_channels) + + static_channels = static_channels[None, ::].flip(-2) + if self.use_apex_gn: + static_channels = static_channels.to( + self.dist.device, + dtype=self.input_dtype, + non_blocking=True, + ).to(memory_format=torch.channels_last) + else: + static_channels = ( + static_channels.to(self.dist.device) + .to(self.input_dtype) + .contiguous() + ) + return static_channels + + + def create_model(self, cfg_model_name: str, cfg_model_args: dict, prob_channels: list = []): + """Instantiate the model.""" + n_input_channels = len(self.dataset.input_channels()) + n_static_channels = len(self.dataset.static_channels()) + n_output_channels = len(self.dataset.output_channels()) + + img_in_channels = n_input_channels + n_static_channels + self.n_month_hour_channels + if self.hr_mean_conditioning: + img_in_channels += n_output_channels + if self.use_patching: + img_in_channels += n_input_channels + n_static_channels + + img_out_channels = n_output_channels + + self.logger.info(f"Creating model {cfg_model_name} with {img_in_channels} input channels and {img_out_channels} output channels.") + + model_args = { # default parameters for all networks + "img_out_channels": img_out_channels, + "img_resolution": list(self.img_shape), + "use_fp16": self.fp16, + "checkpoint_level": self.songunet_checkpoint_level, + } + if cfg_model_name == "lt_aware_ce_regression": + model_args["prob_channels"] = prob_channels + + if cfg_model_args: # override defaults from config file + model_args.update(cfg_model_args) + + model_args["use_apex_gn"] = self.use_apex_gn + model_args["profile_mode"] = self.profile_mode + + if self.enable_amp: + model_args["amp_mode"] = self.enable_amp + + + if cfg_model_name == "regression": + model = UNet( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + elif cfg_model_name == "lt_aware_ce_regression": + model = UNet( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] + elif cfg_model_name == "lt_aware_patched_diffusion": + model = EDMPrecondSuperResolution( + img_in_channels=img_in_channels + + model_args["N_grid_channels"] + + model_args["lead_time_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + model_args["lead_time_channels"] + else: # diffusion or patched diffusion + model = EDMPrecondSuperResolution( + img_in_channels=img_in_channels + model_args["N_grid_channels"], + **model_args, + ) + model_args["img_in_channels"] = img_in_channels + model_args["N_grid_channels"] + + return model, model_args + + def load_regression_model(self, regression_checkpoint_path: str): + """Load the regression model for the residual loss if applicable.""" + + if not os.path.isdir(regression_checkpoint_path): + raise FileNotFoundError( + f"Expected this regression checkpoint but not found: {regression_checkpoint_path}" + ) + #TODO make regression model loading more robust (model type is both in rergession_checkpoint_path and regression_name) + #TODO add the option to choose epoch to load from / regression_checkpoint_path is now a folder + regression_model_args_path = os.path.join(regression_checkpoint_path, 'model_args.json') + if not os.path.isfile(regression_model_args_path): + raise FileNotFoundError(f"Missing config file at '{regression_model_args_path}'.") + + with open(regression_model_args_path, 'r') as f: + regression_model_args = json.load(f) + + regression_model_args.update({ + "use_apex_gn": self.use_apex_gn, + "profile_mode": self.profile_mode, + "amp_mode": self.enable_amp, + }) + + regression_net = UNet(**regression_model_args) + + _ = load_checkpoint( + path=regression_checkpoint_path, + model=regression_net, + device=self.dist.device + ) + regression_net.eval().requires_grad_(False).to(self.dist.device) + if self.use_apex_gn: + regression_net.to(memory_format=torch.channels_last) + self.logger.success("Loaded the pre-trained regression model") + + return regression_net + + + def run_validation(self, cur_nimg, validation_dataset_iterator, model, loss_fn, validation_steps, + static_channels, batch_size_per_gpu, patching, + patch_nums_iter, use_patch_grad_acc): + """Run validation and return average validation loss.""" + valid_loss_accum = 0 + with torch.no_grad(): + lead_time_label_valid = None + for _ in range(validation_steps): + img_clean_valid, img_lr_valid, date_embedding = self.load_and_preprocess_batch(validation_dataset_iterator) + + loss_valid_kwargs = { + "net": model, + "img_clean": img_clean_valid, + "img_lr": img_lr_valid, + "static_channels": static_channels, + "date_embedding": date_embedding, + "augment_pipe": None, + "use_apex_gn": self.use_apex_gn, + } + if use_patch_grad_acc is not None: + loss_valid_kwargs["use_patch_grad_acc"] = use_patch_grad_acc + if use_patch_grad_acc: + loss_fn.y_mean = None + + for patch_num_per_iter in patch_nums_iter: + if patching is not None: + patching.set_patch_num(patch_num_per_iter) + loss_valid_kwargs["patching"] = patching + with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.enable_amp): + loss_valid = loss_fn(**loss_valid_kwargs) + loss_valid = (loss_valid.sum() / batch_size_per_gpu / patch_num_per_iter).cpu().item() + valid_loss_accum += loss_valid / validation_steps / len(patch_nums_iter) + + valid_loss_sum = torch.tensor([valid_loss_accum], device=self.dist.device) + if self.dist.world_size > 1: + torch.distributed.barrier() + torch.distributed.all_reduce(valid_loss_sum, op=torch.distributed.ReduceOp.SUM) + average_valid_loss = (valid_loss_sum / self.dist.world_size).item() + if self.dist.rank == 0 and self.logging_method == "mlflow": + mlflow.log_metric("validation_loss", average_valid_loss, cur_nimg) + + return average_valid_loss \ No newline at end of file diff --git a/src/hirad/utils/dataset_utils.py b/src/hirad/utils/dataset_utils.py index f5400c7..4a2c3d2 100644 --- a/src/hirad/utils/dataset_utils.py +++ b/src/hirad/utils/dataset_utils.py @@ -146,7 +146,31 @@ def _prepare_interpolation(self): self._lambda3 = 1 - self._lambda1 - self._lambda2 - def to_torch(self, device: torch.device ='cpu') -> None: + def to(self, device: str | torch.device) -> None: + """ + Prepare barycentric coordinates and simplex indices for PyTorch operations. + + This method converts the precomputed numpy arrays to PyTorch tensors + and moves them to the specified device. + + Args: + device: The torch device to move tensors to (e.g., 'cpu' or 'cuda'). + """ + if isinstance(device, str): + device = torch.device(device) + if self.is_torch and self.device == device: + return # Already on the correct device + elif self.is_torch and self.device != device: + self.device = device + self._lambda1 = self._lambda1.to(device) + self._lambda2 = self._lambda2.to(device) + self._lambda3 = self._lambda3.to(device) + self._simplex_id = self._simplex_id.to(device) + self._tri.simplices = self._tri.simplices.to(device) + elif not self.is_torch: + self.to_torch(device) + + def to_torch(self, device: torch.device | str ='cpu') -> None: """ Prepare barycentric coordinates and simplex indices for PyTorch operations. diff --git a/src/hirad/utils/inference_utils.py b/src/hirad/utils/inference_utils.py index c9bf1dd..7aa764f 100644 --- a/src/hirad/utils/inference_utils.py +++ b/src/hirad/utils/inference_utils.py @@ -243,7 +243,7 @@ def diffusion_step( ############################################################################ -# Saving and Visualization Utilities # +# Saving and Visualization Utilities # ############################################################################ diff --git a/src/hirad/utils/train_helpers.py b/src/hirad/utils/train_helpers.py index f3eaca0..771d2ee 100644 --- a/src/hirad/utils/train_helpers.py +++ b/src/hirad/utils/train_helpers.py @@ -20,10 +20,37 @@ import mlflow from omegaconf import DictConfig, OmegaConf import os +import psutil +import time from hirad.distributed import DistributedManager from hirad.utils.env_info import get_env_info, flatten_dict +# Define safe CUDA profiler tools that fallback to no-ops when CUDA is not available +def cuda_profiler(): + if torch.cuda.is_available(): + return torch.cuda.profiler.profile() + else: + return nullcontext() + + +def cuda_profiler_start(): + if torch.cuda.is_available(): + torch.cuda.profiler.start() + + +def cuda_profiler_stop(): + if torch.cuda.is_available(): + torch.cuda.profiler.stop() + + +def profiler_emit_nvtx(): + if torch.cuda.is_available(): + return torch.autograd.profiler.emit_nvtx() + else: + return nullcontext() + + def set_patch_shape(img_shape, patch_shape): img_shape_y, img_shape_x = img_shape patch_shape_y, patch_shape_x = patch_shape @@ -48,6 +75,27 @@ def set_patch_shape(img_shape, patch_shape): return use_patching, (img_shape_y, img_shape_x), (patch_shape_y, patch_shape_x) +def calculate_patch_per_iter(patch_num, max_patch_per_gpu, batch_size_per_gpu): + if max_patch_per_gpu: + if max_patch_per_gpu // batch_size_per_gpu < 1: + raise ValueError( + f"max_patch_per_gpu ({max_patch_per_gpu}) must be greater or equal to batch_size_per_gpu ({batch_size_per_gpu})." + ) + max_patch_num_per_iter = min( + patch_num, (max_patch_per_gpu // batch_size_per_gpu) + ) + patch_iterations = ( + patch_num + max_patch_num_per_iter - 1 + ) // max_patch_num_per_iter + patch_nums_iter = [ + min(max_patch_num_per_iter, patch_num - i * max_patch_num_per_iter) + for i in range(patch_iterations) + ] + else: + patch_nums_iter = [patch_num] + return patch_nums_iter + + def set_seed(rank): """ Set seeds for NumPy and PyTorch to ensure reproducibility in distributed settings @@ -84,6 +132,20 @@ def compute_num_accumulation_rounds(total_batch_size, batch_size_per_gpu, world_ return batch_gpu_total, num_accumulation_rounds +def update_learning_rate(optimizer, lr, lr_rampup, lr_decay, lr_decay_rate, cur_nimg): + """Apply learning rate rampup and decay schedule.""" + current_lr = None + for g in optimizer.param_groups: + if lr_rampup > 0: + g["lr"] = lr * min(cur_nimg / lr_rampup, 1) + if cur_nimg >= lr_rampup: + g["lr"] *= lr_decay ** ( + (cur_nimg - lr_rampup) // lr_decay_rate + ) + current_lr = g["lr"] + return current_lr + + def handle_and_clip_gradients(model, grad_clip_threshold=None): """ Handles NaNs and infinities in the gradients and optionally clips the gradients. @@ -178,3 +240,32 @@ def init_mlflow(cfg: DictConfig, dist: DistributedManager, write_dir: str=".") - with open(os.path.join(write_dir, "run_id.txt"), 'r') as f: run_id = f.read() mlflow.start_run(run_id=run_id, log_system_metrics=True) + + +def log_training_progress(logger0, logging_method, dist, cur_nimg, tick_start_nimg, tick_start_time, + tick_read_time, start_time, average_loss, average_loss_running_mean, + current_lr): + """Log training progress metrics.""" + torch.cuda.synchronize() + tick_end_time = time.time() + fields = [ + f"samples {cur_nimg:<9.1f}", + f"training_loss {average_loss:<7.2f}", + f"training_loss_running_mean {average_loss_running_mean:<7.2f}", + f"learning_rate {current_lr:<7.8f}", + f"sec_per_tick {(tick_end_time - tick_start_time):<7.1f}", + f"sec_per_sample {((tick_end_time - tick_start_time) / (cur_nimg - tick_start_nimg)):<7.4f}", + f"sec_for_reading {tick_read_time:<7.4f}", + f"total_sec {(tick_end_time - start_time):<7.1f}", + f"cpu_mem_gb {(psutil.Process(os.getpid()).memory_info().rss / 2**30):<6.2f}", + ] + if torch.cuda.is_available(): + fields.append(f"peak_gpu_mem_gb {(torch.cuda.max_memory_allocated(dist.device) / 2**30):<6.2f}") + fields.append(f"peak_gpu_mem_reserved_gb {(torch.cuda.max_memory_reserved(dist.device) / 2**30):<6.2f}") + torch.cuda.reset_peak_memory_stats() + logger0.info(" ".join(fields)) + + if logging_method == "mlflow": + mlflow.log_metric("training_loss", average_loss, cur_nimg) + mlflow.log_metric("training_loss_running_mean", average_loss_running_mean, cur_nimg) + mlflow.log_metric("learning_rate", current_lr, cur_nimg) diff --git a/tests/models/test_layers.py b/tests/models/test_layers.py new file mode 100644 index 0000000..774a9ba --- /dev/null +++ b/tests/models/test_layers.py @@ -0,0 +1,950 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch +import torch.nn as nn +from unittest.mock import MagicMock, patch + +from hirad.models.layers import ( + AttentionOp, + Conv2d, + FourierEmbedding, + GroupNorm, + Linear, + PositionalEmbedding, + UNetBlock, +) + + +# --------------------------------------------------------------------------- +# Helpers / fixtures — use small configs for fast CPU tests +# --------------------------------------------------------------------------- + +B = 2 +IN_CH = 4 +OUT_CH = 8 +H, W = 16, 16 +EMB_CH = 32 + + +@pytest.fixture() +def random_input_2d(): + return torch.randn(B, IN_CH, H, W) + + +@pytest.fixture() +def random_input_flat(): + return torch.randn(B, IN_CH) + + +@pytest.fixture() +def embedding(): + return torch.randn(B, EMB_CH) + + +############################################################################ +# Linear # +############################################################################ + + +class TestLinearInit: + """Test Linear.__init__ parameter setup.""" + + def test_weight_shape(self): + layer = Linear(in_features=IN_CH, out_features=OUT_CH) + assert layer.weight.shape == (OUT_CH, IN_CH) + + def test_bias_shape(self): + layer = Linear(in_features=IN_CH, out_features=OUT_CH) + assert layer.bias is not None + assert layer.bias.shape == (OUT_CH,) + + def test_no_bias_when_disabled(self): + layer = Linear(in_features=IN_CH, out_features=OUT_CH, bias=False) + assert layer.bias is None + + def test_stores_features(self): + layer = Linear(in_features=IN_CH, + out_features=OUT_CH, + amp_mode=True) + assert layer.in_features == IN_CH + assert layer.out_features == OUT_CH + assert layer.amp_mode==True + + @pytest.mark.parametrize( + "init_mode", + ["xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal"], + ) + def test_all_init_modes_accepted(self, init_mode): + layer = Linear(in_features=IN_CH, out_features=OUT_CH, init_mode=init_mode) + assert layer.weight.shape == (OUT_CH, IN_CH) + + def test_init_weight_scaling(self): + layer = Linear(in_features=IN_CH, out_features=OUT_CH, init_weight=0.0) + torch.testing.assert_close(layer.weight, torch.zeros(OUT_CH, IN_CH)) + + def test_init_bias_scaling(self): + layer = Linear(in_features=IN_CH, out_features=OUT_CH, init_bias=0) + torch.testing.assert_close(layer.bias, torch.zeros(OUT_CH)) + + +class TestLinearForward: + """Test Linear forward pass.""" + + def test_output_shape(self, random_input_flat): + layer = Linear(in_features=IN_CH, out_features=OUT_CH) + out = layer(random_input_flat) + assert out.shape == (B, OUT_CH) + + def test_output_shape_no_bias(self, random_input_flat): + layer = Linear(in_features=IN_CH, out_features=OUT_CH, bias=False) + out = layer(random_input_flat) + assert out.shape == (B, OUT_CH) + + def test_zero_weight_zero_bias_returns_zero(self, random_input_flat): + layer = Linear( + in_features=IN_CH, + out_features=OUT_CH, + init_weight=0, + init_bias=0, + ) + out = layer(random_input_flat) + torch.testing.assert_close(out, torch.zeros(B, OUT_CH)) + + def test_output_dtype_matches_input(self, random_input_flat): + layer = Linear(in_features=IN_CH, out_features=OUT_CH) + out = layer(random_input_flat) + assert out.dtype == random_input_flat.dtype + + def test_gradients_flow(self, random_input_flat): + layer = Linear(in_features=IN_CH, out_features=OUT_CH) + x = random_input_flat.clone().requires_grad_(True) + out = layer(x) + out.sum().backward() + assert x.grad is not None + assert x.grad.shape == x.shape + + +############################################################################ +# Conv2d # +############################################################################ + + +class TestConv2dInit: + """Test Conv2d.__init__ parameter setup.""" + + def test_weight_shape(self): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3) + assert layer.weight.shape == (OUT_CH, IN_CH, 3, 3) + + def test_bias_shape(self): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3) + assert layer.bias is not None + assert layer.bias.shape == (OUT_CH,) + + def test_no_bias_when_disabled(self): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3, bias=False) + assert layer.bias is None + + def test_kernel_zero_no_weight(self): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=0) + assert layer.weight is None + assert layer.bias is None + + def test_up_and_down_raises(self): + with pytest.raises(ValueError, match="Both 'up' and 'down'"): + Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3, up=True, down=True) + + def test_stores_flags(self): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3, up=True, fused_resample=True, fused_conv_bias=True, amp_mode=True) + assert layer.up is True + assert layer.down is False + assert layer.fused_resample is True + assert layer.fused_conv_bias is True + assert layer.amp_mode is True + assert layer.in_channels == IN_CH + assert layer.out_channels == OUT_CH + + def test_resample_filter_registered_when_up(self): + layer = Conv2d( + in_channels=IN_CH, out_channels=OUT_CH, kernel=3, up=True + ) + assert layer.resample_filter is not None + + def test_resample_filter_registered_when_down(self): + layer = Conv2d( + in_channels=IN_CH, out_channels=OUT_CH, kernel=3, down=True + ) + assert layer.resample_filter is not None + + def test_resample_filter_none_when_no_up_down(self): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3) + assert layer.resample_filter is None + + def test_fused_conv_bias_disabled_when_no_kernel(self): + layer = Conv2d( + in_channels=IN_CH, + out_channels=OUT_CH, + kernel=0, + fused_conv_bias=True, + ) + assert layer.fused_conv_bias is False + + def test_zero_weight_bias_init_gives_zero_weights_and_biases(self): + layer = Conv2d( + in_channels=IN_CH, + out_channels=OUT_CH, + kernel=3, + init_weight=0, + init_bias=0, + ) + torch.testing.assert_close(layer.weight, torch.zeros(OUT_CH, IN_CH, 3, 3)) + torch.testing.assert_close(layer.bias, torch.zeros(OUT_CH)) + + +class TestConv2dForward: + """Test Conv2d forward pass.""" + + def test_output_shape_same_padding(self, random_input_2d): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3) + out = layer(random_input_2d) + assert out.shape == (B, OUT_CH, H, W) + + def test_output_shape_kernel_1(self, random_input_2d): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=1) + out = layer(random_input_2d) + assert out.shape == (B, OUT_CH, H, W) + + def test_output_shape_upsample(self, random_input_2d): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3, up=True) + out = layer(random_input_2d) + assert out.shape == (B, OUT_CH, H * 2, W * 2) + + def test_output_shape_downsample(self, random_input_2d): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3, down=True) + out = layer(random_input_2d) + assert out.shape == (B, OUT_CH, H // 2, W // 2) + + def test_output_shape_fused_upsample(self, random_input_2d): + layer = Conv2d( + in_channels=IN_CH, + out_channels=OUT_CH, + kernel=3, + up=True, + fused_resample=True, + ) + out = layer(random_input_2d) + assert out.shape == (B, OUT_CH, H * 2, W * 2) + + def test_output_shape_fused_downsample(self, random_input_2d): + layer = Conv2d( + in_channels=IN_CH, + out_channels=OUT_CH, + kernel=3, + down=True, + fused_resample=True, + ) + out = layer(random_input_2d) + assert out.shape == (B, OUT_CH, H // 2, W // 2) + + def test_output_shape_fused_conv_bias(self, random_input_2d): + layer = Conv2d( + in_channels=IN_CH, + out_channels=OUT_CH, + kernel=3, + fused_conv_bias=True, + ) + out = layer(random_input_2d) + assert out.shape == (B, OUT_CH, H, W) + + def test_output_shape_fused_up_with_conv_bias(self, random_input_2d): + layer = Conv2d( + in_channels=IN_CH, + out_channels=OUT_CH, + kernel=3, + up=True, + fused_resample=True, + fused_conv_bias=True, + ) + out = layer(random_input_2d) + assert out.shape == (B, OUT_CH, H * 2, W * 2) + + def test_output_shape_fused_down_with_conv_bias(self, random_input_2d): + layer = Conv2d( + in_channels=IN_CH, + out_channels=OUT_CH, + kernel=3, + down=True, + fused_resample=True, + fused_conv_bias=True, + ) + out = layer(random_input_2d) + assert out.shape == (B, OUT_CH, H // 2, W // 2) + + def test_output_dtype_matches_input(self, random_input_2d): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3) + out = layer(random_input_2d) + assert out.dtype == random_input_2d.dtype + + def test_gradients_flow(self, random_input_2d): + layer = Conv2d(in_channels=IN_CH, out_channels=OUT_CH, kernel=3) + x = random_input_2d.clone().requires_grad_(True) + out = layer(x) + out.sum().backward() + assert x.grad is not None + assert x.grad.shape == x.shape + + def test_kernel_zero_passthrough(self, random_input_2d): + """With kernel=0, no convolution should be applied, just pass-through.""" + layer = Conv2d(in_channels=IN_CH, out_channels=IN_CH, kernel=0) + out = layer(random_input_2d) + torch.testing.assert_close(out, random_input_2d) + + +############################################################################ +# GroupNorm # +############################################################################ + + +class TestGroupNormInit: + """Test GroupNorm.__init__ parameter setup.""" + + def test_weight_shape(self): + gn = GroupNorm(num_channels=OUT_CH) + assert gn.weight.shape == (OUT_CH,) + + def test_bias_shape(self): + gn = GroupNorm(num_channels=OUT_CH) + assert gn.bias.shape == (OUT_CH,) + + def test_weight_initialized_to_ones(self): + gn = GroupNorm(num_channels=OUT_CH) + torch.testing.assert_close(gn.weight, torch.ones(OUT_CH)) + + def test_bias_initialized_to_zeros(self): + gn = GroupNorm(num_channels=OUT_CH) + torch.testing.assert_close(gn.bias, torch.zeros(OUT_CH)) + + def test_num_groups_clipped_to_min_channels(self): + """If num_channels // min_channels_per_group < num_groups, groups are reduced.""" + gn = GroupNorm(num_channels=8, num_groups=32, min_channels_per_group=4) + assert gn.num_groups == 2 + + def test_num_groups_matches_when_divisible(self): + gn = GroupNorm(num_channels=32, num_groups=8, min_channels_per_group=2) + assert gn.num_groups == 8 + + def test_fused_act_without_act_raises(self): + with pytest.raises(ValueError, match="'act' must be specified"): + GroupNorm(num_channels=OUT_CH, fused_act=True, act=None) + + def test_fused_act_with_valid_act(self): + gn = GroupNorm(num_channels=OUT_CH, fused_act=True, act="silu") + assert gn.fused_act is True + assert gn.act == "silu" + assert gn.act_fn is not None + + def test_eps_and_amp_mode_stored(self): + gn = GroupNorm(num_channels=OUT_CH, eps=1e-6, amp_mode=True) + assert gn.eps == 1e-6 + assert gn.amp_mode is True + + def test_apex_gn_initializes_gn_when_available(self): + mock_gn_cls = MagicMock() + mock_gn_instance = MagicMock() + mock_gn_cls.return_value = mock_gn_instance + with patch("hirad.models.layers._is_apex_available", True), \ + patch("hirad.models.layers.ApexGroupNorm", mock_gn_cls, create=True): + gn = GroupNorm(num_channels=OUT_CH, use_apex_gn=True) + assert hasattr(gn, "gn") + assert gn.gn is mock_gn_instance + mock_gn_cls.assert_called_once() + + def test_apex_gn_raises_when_not_available(self): + with patch("hirad.models.layers._is_apex_available", False): + with pytest.raises(ValueError, match="'apex' is not"): + GroupNorm(num_channels=OUT_CH, use_apex_gn=True) + + +class TestGroupNormForward: + """Test GroupNorm forward pass.""" + + def test_output_shape(self, random_input_2d): + gn = GroupNorm(num_channels=IN_CH) + out = gn(random_input_2d) + assert out.shape == random_input_2d.shape + + def test_output_dtype_matches_input(self, random_input_2d): + gn = GroupNorm(num_channels=IN_CH) + gn.train() + out = gn(random_input_2d) + assert out.dtype == random_input_2d.dtype + gn.eval() + out = gn(random_input_2d) + assert out.dtype == random_input_2d.dtype + + def test_training_mode_uses_torch_group_norm(self, random_input_2d): + """In training mode, output should match torch.nn.functional.group_norm.""" + gn = GroupNorm(num_channels=IN_CH) + gn.train() + out = gn(random_input_2d) + expected = torch.nn.functional.group_norm( + random_input_2d, num_groups=gn.num_groups, weight=gn.weight, bias=gn.bias, eps=gn.eps + ) + torch.testing.assert_close(out, expected) + + def test_eval_mode_output_shape(self, random_input_2d): + gn = GroupNorm(num_channels=IN_CH) + gn.eval() + out = gn(random_input_2d) + assert out.shape == random_input_2d.shape + + def test_apex_gn_forward(self, random_input_2d): + """Test forward pass when using Apex GroupNorm.""" + from hirad.models.layers import _is_apex_available + if _is_apex_available: + gn = GroupNorm(num_channels=IN_CH, use_apex_gn=True) + called = [] + gn.gn.register_forward_hook(lambda m, i, o: called.append(True)) + out = gn(random_input_2d) + assert out.shape == random_input_2d.shape + assert called, "Apex GroupNorm forward hook was not called, so it may not have been used." + else: + mock_gn_cls = MagicMock() + mock_gn_instance = MagicMock() + mock_gn_instance.forward.return_value = random_input_2d + mock_gn_cls.return_value = mock_gn_instance + with patch("hirad.models.layers._is_apex_available", True), \ + patch("hirad.models.layers.ApexGroupNorm", mock_gn_cls, create=True): + gn = GroupNorm(num_channels=IN_CH, use_apex_gn=True) + out = gn(random_input_2d) + assert out.shape == random_input_2d.shape + assert mock_gn_instance.assert_called_once + + def test_training_mode_with_fused_act(self, random_input_2d): + """Test that fused activation is applied in training mode.""" + gn = GroupNorm(num_channels=IN_CH, fused_act=True, act="relu") + gn.train() + out = gn(random_input_2d) + assert out.shape == random_input_2d.shape + assert (out >= 0).all(), "Output should be non-negative due to ReLU activation" + + def test_eval_mode_with_fused_act(self, random_input_2d): + gn = GroupNorm(num_channels=IN_CH, fused_act=True, act="relu") + gn.eval() + out = gn(random_input_2d) + assert out.shape == random_input_2d.shape + assert (out >= 0).all(), "Output should be non-negative due to ReLU activation" + + def test_training_fused_act_actually_applies_activation(self, random_input_2d): + """Verify fused act path produces different output than non-fused path.""" + gn_fused = GroupNorm(num_channels=IN_CH, fused_act=True, act="relu") + gn_plain = GroupNorm(num_channels=IN_CH) + gn_fused.train() + gn_plain.train() + out_fused = gn_fused(random_input_2d) + out_plain = gn_plain(random_input_2d) + # Plain output will have negatives; fused should not + assert (out_plain < 0).any(), "Plain output should have negatives for meaningful test" + assert (out_fused >= 0).all() + + def test_eval_fused_act_actually_applies_activation(self, random_input_2d): + """Verify fused act path produces different output than non-fused path.""" + gn_fused = GroupNorm(num_channels=IN_CH, fused_act=True, act="relu") + gn_plain = GroupNorm(num_channels=IN_CH) + gn_fused.eval() + gn_plain.eval() + out_fused = gn_fused(random_input_2d) + out_plain = gn_plain(random_input_2d) + # Plain output will have negatives; fused should not + assert (out_plain < 0).any(), "Plain output should have negatives for meaningful test" + assert (out_fused >= 0).all() + + def test_eval_mode_matches_training_mode(self, random_input_2d): + """Eval and training modes should produce close results.""" + gn = GroupNorm(num_channels=IN_CH) + gn.train() + out_train = gn(random_input_2d) + gn.eval() + out_eval = gn(random_input_2d) + torch.testing.assert_close(out_train, out_eval, atol=1e-5, rtol=1e-5) + + def test_normalized_output_has_zero_mean(self, random_input_2d): + """After GroupNorm with default weight=1 and bias=0, each group should + have approximately zero mean.""" + gn = GroupNorm(num_channels=IN_CH) + gn.train() + out = gn(random_input_2d) + # Reshape to groups and check mean is near zero + reshaped = out.reshape(B, gn.num_groups, IN_CH // gn.num_groups, H, W) + group_means = reshaped.mean(dim=[2, 3, 4]) + assert group_means.abs().max() < 0.1 + + def test_normalized_output_has_variance_close_to_one(self, random_input_2d): + """After GroupNorm with default weight=1 and bias=0, each group should + have variance close to 1 (not exactly 1 due to eps).""" + gn = GroupNorm(num_channels=IN_CH) + gn.train() + out = gn(random_input_2d) + # Reshape to groups and check variance is near 1 + reshaped = out.reshape(B, gn.num_groups, IN_CH // gn.num_groups, H, W) + group_vars = reshaped.var(dim=[2, 3, 4], unbiased=False) + assert torch.allclose(group_vars, torch.ones_like(group_vars), atol=0.1) + + def test_gradients_flow(self, random_input_2d): + gn = GroupNorm(num_channels=IN_CH) + gn.train() + x = random_input_2d.clone().requires_grad_(True) + out = gn(x) + out.sum().backward() + assert x.grad is not None + assert x.grad.shape == x.shape + + +class TestGroupNormFusedActivation: + """Test GroupNorm with fused activation functions.""" + + @pytest.mark.parametrize( + "act_name", ["silu", "relu", "leaky_relu", "sigmoid", "tanh", "gelu", "elu"] + ) + def test_fused_act_accepted(self, act_name, random_input_2d): + gn = GroupNorm(num_channels=IN_CH, fused_act=True, act=act_name) + gn.train() + out = gn(random_input_2d) + assert out.shape == random_input_2d.shape + + def test_invalid_act_raises(self): + with pytest.raises(ValueError, match="Unknown activation function"): + GroupNorm(num_channels=OUT_CH, fused_act=True, act="invalid_act") + + @pytest.mark.parametrize( + "act_name", ["silu", "relu", "leaky_relu", "sigmoid", "tanh", "gelu", "elu"] + ) + def test_fused_act_matches_separate(self, act_name, random_input_2d): + """Fused activation should give the same result as applying the activation separately.""" + gn_fused = GroupNorm(num_channels=IN_CH, fused_act=True, act=act_name) + gn_plain = GroupNorm(num_channels=IN_CH) + # Copy parameters + gn_fused.train() + gn_plain.train() + out_fused = gn_fused(random_input_2d) + out_separate = getattr(torch.nn.functional, act_name)(gn_plain(random_input_2d)) + torch.testing.assert_close(out_fused, out_separate) + + +############################################################################ +# AttentionOp # +############################################################################ + + +class TestAttentionOpForward: + """Test AttentionOp forward pass.""" + + def test_output_shape(self): + q = torch.randn(B, 16, 8) + k = torch.randn(B, 16, 8) + w = AttentionOp.apply(q, k) + assert w.shape == (B, 8, 8) + + def test_output_is_probability_distribution(self): + """Each row of the attention weights should sum to 1 (softmax output).""" + q = torch.randn(B, 16, 8) + k = torch.randn(B, 16, 8) + w = AttentionOp.apply(q, k) + row_sums = w.sum(dim=2) + torch.testing.assert_close(row_sums, torch.ones_like(row_sums), atol=1e-5, rtol=1e-5) + + def test_output_non_negative(self): + q = torch.randn(B, 16, 8) + k = torch.randn(B, 16, 8) + w = AttentionOp.apply(q, k) + assert (w >= 0).all() + + def test_output_dtype_matches_input(self): + q = torch.randn(B, 16, 8) + k = torch.randn(B, 16, 8) + w = AttentionOp.apply(q, k) + assert w.dtype == q.dtype + + +class TestAttentionOpBackward: + """Test AttentionOp backward pass.""" + + def test_gradients_flow_to_q(self): + q = torch.randn(B, 16, 8, requires_grad=True) + k = torch.randn(B, 16, 8, requires_grad=True) + w = AttentionOp.apply(q, k) + w.sum().backward() + assert q.grad is not None + assert q.grad.shape == q.shape + + def test_gradients_flow_to_k(self): + q = torch.randn(B, 16, 8, requires_grad=True) + k = torch.randn(B, 16, 8, requires_grad=True) + w = AttentionOp.apply(q, k) + w.sum().backward() + assert k.grad is not None + assert k.grad.shape == k.shape + + +############################################################################ +# UNetBlock # +############################################################################ + + +class TestUNetBlockInit: + """Test UNetBlock.__init__ parameter setup.""" + + def test_stores_block_info(self): + block = UNetBlock( + in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH, + dropout=0.1, skip_scale=0.5, adaptive_scale=False, + profile_mode=True, amp_mode=True, attention=True, + num_heads=4 + ) + assert block.in_channels == IN_CH + assert block.out_channels == OUT_CH + assert block.emb_channels == EMB_CH + assert block.dropout == 0.1 + assert block.skip_scale == 0.5 + assert block.adaptive_scale is False + assert block.profile_mode is True + assert block.amp_mode is True + + def test_num_heads_zero_when_no_attention(self): + block = UNetBlock( + in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH, attention=False + ) + assert block.num_heads == 0 + + def test_num_heads_set_when_attention_no_num_heads(self): + block = UNetBlock( + in_channels=IN_CH, out_channels=16, emb_channels=EMB_CH, + attention=True, channels_per_head=4 + ) + assert block.num_heads == 16//4 + + def test_skip_created_when_channels_differ(self): + block = UNetBlock( + in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH + ) + assert block.skip is not None + + def test_skip_none_when_channels_match(self): + block = UNetBlock( + in_channels=OUT_CH, out_channels=OUT_CH, emb_channels=EMB_CH + ) + assert block.skip is None + + def test_skip_created_when_up(self): + block = UNetBlock( + in_channels=OUT_CH, out_channels=OUT_CH, emb_channels=EMB_CH, up=True + ) + assert block.skip is not None + + def test_skip_created_when_down(self): + block = UNetBlock( + in_channels=OUT_CH, out_channels=OUT_CH, emb_channels=EMB_CH, down=True + ) + assert block.skip is not None + + def test_attention_heads_not_created_when_attention_false(self): + block = UNetBlock( + in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH, attention=False + ) + assert not hasattr(block, "norm2") or block.norm2 is None + assert not hasattr(block, "qkv") or block.qkv is None + assert not hasattr(block, "proj") or block.proj is None + + def test_attention_heads_default(self): + block = UNetBlock( + in_channels=64, + out_channels=64, + emb_channels=EMB_CH, + attention=True, + channels_per_head=64, + ) + assert block.norm2 is not None + assert block.qkv is not None + assert block.proj is not None + + def test_has_norm_and_conv_layers(self): + block = UNetBlock( + in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH + ) + assert hasattr(block, "norm0") + assert hasattr(block, "conv0") + assert hasattr(block, "norm1") + assert hasattr(block, "conv1") + assert hasattr(block, "affine") + + +class TestUNetBlockForward: + """Test UNetBlock forward pass.""" + + def test_output_shape_same_channels(self, random_input_2d, embedding): + block = UNetBlock( + in_channels=IN_CH, out_channels=IN_CH, emb_channels=EMB_CH + ) + out = block(random_input_2d, embedding) + assert out.shape == (B, IN_CH, H, W) + + def test_output_shape_different_channels(self, random_input_2d, embedding): + block = UNetBlock( + in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH + ) + out = block(random_input_2d, embedding) + assert out.shape == (B, OUT_CH, H, W) + + def test_output_shape_upsample(self, random_input_2d, embedding): + block = UNetBlock( + in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH, up=True + ) + out = block(random_input_2d, embedding) + assert out.shape == (B, OUT_CH, H * 2, W * 2) + + def test_output_shape_downsample(self, random_input_2d, embedding): + block = UNetBlock( + in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH, down=True + ) + out = block(random_input_2d, embedding) + assert out.shape == (B, OUT_CH, H // 2, W // 2) + + def test_output_dtype_matches_input(self, random_input_2d, embedding): + block = UNetBlock( + in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH + ) + out = block(random_input_2d, embedding) + assert out.dtype == random_input_2d.dtype + + def test_gradients_flow(self, random_input_2d, embedding): + block = UNetBlock( + in_channels=IN_CH, out_channels=OUT_CH, emb_channels=EMB_CH + ) + x = random_input_2d.clone().requires_grad_(True) + out = block(x, embedding) + out.sum().backward() + assert x.grad is not None + assert x.grad.shape == x.shape + + def test_with_attention(self, embedding): + ch = 64 + x = torch.randn(B, ch, H, W) + block = UNetBlock( + in_channels=ch, + out_channels=ch, + emb_channels=EMB_CH, + attention=True, + channels_per_head=ch//2, + ) + out = block(x, embedding) + assert out.shape == (B, ch, H, W) + + def test_non_adaptive_scale(self, random_input_2d, embedding): + block = UNetBlock( + in_channels=IN_CH, + out_channels=OUT_CH, + emb_channels=EMB_CH, + adaptive_scale=False, + ) + out = block(random_input_2d, embedding) + assert out.shape == (B, OUT_CH, H, W) + + def test_with_dropout(self, random_input_2d, embedding): + block = UNetBlock( + in_channels=IN_CH, + out_channels=OUT_CH, + emb_channels=EMB_CH, + dropout=0.1, + ) + block.train() + out = block(random_input_2d, embedding) + assert out.shape == (B, OUT_CH, H, W) + + def test_with_amp_mode(self, random_input_2d, embedding): + block = UNetBlock( + in_channels=IN_CH, + out_channels=OUT_CH, + emb_channels=EMB_CH, + amp_mode=True, + ) + out = block(random_input_2d, embedding) + assert out.shape == (B, OUT_CH, H, W) + + def test_skip_scale_applied(self, random_input_2d, embedding): + """Changing skip_scale should change the output magnitude.""" + block_s1 = UNetBlock( + in_channels=IN_CH, + out_channels=OUT_CH, + emb_channels=EMB_CH, + skip_scale=1.0, + ) + block_s2 = UNetBlock( + in_channels=IN_CH, + out_channels=OUT_CH, + emb_channels=EMB_CH, + skip_scale=2.0, + ) + # Copy parameters from block_s1 to block_s2 + block_s2.load_state_dict(block_s1.state_dict(), strict=False) + out1 = block_s1(random_input_2d, embedding) + out2 = block_s2(random_input_2d, embedding) + # s2 should have roughly 2x the magnitude of s1 + torch.testing.assert_close(out2, out1 * 2.0, atol=1e-5, rtol=1e-5) + + +############################################################################ +# PositionalEmbedding # +############################################################################ + + +class TestPositionalEmbeddingInit: + """Test PositionalEmbedding.__init__.""" + + def test_stores_num_channels(self): + emb = PositionalEmbedding(num_channels=64) + assert emb.num_channels == 64 + + def test_stores_max_positions(self): + emb = PositionalEmbedding(num_channels=64, max_positions=5000) + assert emb.max_positions == 5000 + + def test_stores_endpoint(self): + emb = PositionalEmbedding(num_channels=64, endpoint=True) + assert emb.endpoint is True + + def test_amp_mode(self): + emb = PositionalEmbedding(num_channels=64, amp_mode=True) + assert emb.amp_mode is True + + +class TestPositionalEmbeddingForward: + """Test PositionalEmbedding forward pass.""" + + def test_output_shape(self): + emb = PositionalEmbedding(num_channels=64) + x = torch.randn(B) + out = emb(x) + assert out.shape == (B, 64) + + def test_output_shape_single(self): + emb = PositionalEmbedding(num_channels=32) + x = torch.randn(1) + out = emb(x) + assert out.shape == (1, 32) + + def test_different_inputs_produce_different_embeddings(self): + emb = PositionalEmbedding(num_channels=64) + x1 = torch.tensor([0.1]) + x2 = torch.tensor([1.0]) + out1 = emb(x1) + out2 = emb(x2) + assert not torch.allclose(out1, out2) + + def test_same_input_produces_same_embedding(self): + emb = PositionalEmbedding(num_channels=64) + x = torch.tensor([0.5]) + out1 = emb(x) + out2 = emb(x) + torch.testing.assert_close(out1, out2) + + def test_output_contains_sin_and_cos(self): + """Output is concatenation of cos and sin, so first and second halves + should differ for non-trivial inputs.""" + emb = PositionalEmbedding(num_channels=64) + x = torch.tensor([1.0]) + out = emb(x) + first_half = out[:, :32] + second_half = out[:, 32:] + assert not torch.allclose(first_half, second_half) + + def test_output_bounded(self): + """Since output is cos and sin, values should be in [-1, 1].""" + emb = PositionalEmbedding(num_channels=64) + x = torch.randn(B) + out = emb(x) + assert out.min() >= -1.0 - 1e-6 + assert out.max() <= 1.0 + 1e-6 + + def test_endpoint_changes_output(self): + emb_no_end = PositionalEmbedding(num_channels=64, endpoint=False) + emb_end = PositionalEmbedding(num_channels=64, endpoint=True) + x = torch.tensor([1.0]) + out1 = emb_no_end(x) + out2 = emb_end(x) + assert not torch.allclose(out1, out2) + + +############################################################################ +# FourierEmbedding # +############################################################################ + + +class TestFourierEmbeddingInit: + """Test FourierEmbedding.__init__.""" + + def test_freqs_buffer_registered(self): + emb = FourierEmbedding(num_channels=64) + assert hasattr(emb, "freqs") + assert emb.freqs.shape == (32,) + + def test_scale_affects_freqs_magnitude(self): + torch.manual_seed(0) + emb_small = FourierEmbedding(num_channels=64, scale=1) + torch.manual_seed(0) + emb_large = FourierEmbedding(num_channels=64, scale=16) + torch.testing.assert_close(emb_large.freqs, emb_small.freqs * 16) + + def test_amp_mode_stored(self): + emb = FourierEmbedding(num_channels=64, amp_mode=True) + assert emb.amp_mode is True + + +class TestFourierEmbeddingForward: + """Test FourierEmbedding forward pass.""" + + def test_output_shape(self): + emb = FourierEmbedding(num_channels=64) + x = torch.randn(B) + out = emb(x) + assert out.shape == (B, 64) + + def test_output_shape_single(self): + emb = FourierEmbedding(num_channels=32) + x = torch.randn(1) + out = emb(x) + assert out.shape == (1, 32) + + def test_different_inputs_produce_different_embeddings(self): + emb = FourierEmbedding(num_channels=64) + x1 = torch.tensor([0.1]) + x2 = torch.tensor([1.0]) + out1 = emb(x1) + out2 = emb(x2) + assert not torch.allclose(out1, out2) + + def test_same_input_produces_same_embedding(self): + emb = FourierEmbedding(num_channels=64) + x = torch.tensor([0.5]) + out1 = emb(x) + out2 = emb(x) + torch.testing.assert_close(out1, out2) + + def test_output_contains_sin_and_cos(self): + emb = FourierEmbedding(num_channels=64) + x = torch.tensor([1.0]) + out = emb(x) + first_half = out[:, :32] + second_half = out[:, 32:] + assert not torch.allclose(first_half, second_half) + + def test_output_bounded(self): + """Since output is cos and sin, values should be in [-1, 1].""" + emb = FourierEmbedding(num_channels=64) + x = torch.randn(B) + out = emb(x) + assert out.min() >= -1.0 - 1e-6 + assert out.max() <= 1.0 + 1e-6 diff --git a/tests/models/test_preconditioning.py b/tests/models/test_preconditioning.py new file mode 100644 index 0000000..a802e59 --- /dev/null +++ b/tests/models/test_preconditioning.py @@ -0,0 +1,604 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from unittest.mock import MagicMock, patch + +import pytest +import torch +import torch.nn as nn + +from hirad.models.preconditioning import EDMPrecondSuperResolution + + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +B, C_IN, C_OUT, H, W = 2, 4, 3, 64, 64 + + +def _make_mock_model(out_channels=C_OUT): + """Return a MagicMock that behaves like a SongUNet-style model.""" + model = MagicMock(spec=nn.Module) + model.side_effect = lambda x, sigma, class_labels=None, **kw: torch.zeros( + x.shape[0], out_channels, x.shape[2], x.shape[3], + dtype=x.dtype, device=x.device, + ) + model.modules.return_value = iter([]) + return model + + +@pytest.fixture() +def img_x(): + return torch.randn(B, C_OUT, H, W) + + +@pytest.fixture() +def img_lr(): + return torch.randn(B, C_IN, H, W) + + +@pytest.fixture() +def sigma(): + return torch.ones(B) * 0.5 + + +############################################################################ +# EDMPrecondSuperResolution — __init__ # +############################################################################ + + +class TestEDMInitResolution: + """Test EDMPrecondSuperResolution.__init__ resolution handling.""" + + @patch("hirad.models.preconditioning.network_module") + def test_int_resolution_stored(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=128, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert edm.img_resolution == 128 + + @patch("hirad.models.preconditioning.network_module") + def test_tuple_resolution_stored(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=(96, 128), img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert edm.img_resolution == (96, 128) + + @patch("hirad.models.preconditioning.network_module") + def test_stores_channel_counts(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert edm.img_in_channels == C_IN + assert edm.img_out_channels == C_OUT + + +class TestEDMInitModelType: + """Test that EDMPrecondSuperResolution creates the correct underlying model type.""" + + @patch("hirad.models.preconditioning.network_module") + def test_default_model_type_is_song_unet_pos_embd(self, mock_module): + mock_cls = MagicMock() + mock_module.SongUNetPosEmbd = mock_cls + EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + mock_cls.assert_called_once() + + @patch("hirad.models.preconditioning.network_module") + def test_custom_model_type_song_unet(self, mock_module): + mock_cls = MagicMock() + mock_module.SongUNet = mock_cls + EDMPrecondSuperResolution( + img_resolution=64, + img_in_channels=C_IN, + img_out_channels=C_OUT, + model_type="SongUNet", + ) + mock_cls.assert_called_once() + + @patch("hirad.models.preconditioning.network_module") + def test_model_receives_img_resolution(self, mock_module): + mock_cls = MagicMock() + mock_module.SongUNetPosEmbd = mock_cls + EDMPrecondSuperResolution( + img_resolution=(80, 120), img_in_channels=C_IN, img_out_channels=C_OUT, + ) + call_kwargs = mock_cls.call_args[1] + assert call_kwargs["img_resolution"] == (80, 120) + + @patch("hirad.models.preconditioning.network_module") + def test_model_receives_combined_in_channels(self, mock_module): + mock_cls = MagicMock() + mock_module.SongUNetPosEmbd = mock_cls + EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + call_kwargs = mock_cls.call_args[1] + assert call_kwargs["in_channels"] == C_IN + C_OUT + + @patch("hirad.models.preconditioning.network_module") + def test_model_receives_out_channels(self, mock_module): + mock_cls = MagicMock() + mock_module.SongUNetPosEmbd = mock_cls + EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + call_kwargs = mock_cls.call_args[1] + assert call_kwargs["out_channels"] == C_OUT + + @patch("hirad.models.preconditioning.network_module") + def test_extra_kwargs_forwarded_to_model(self, mock_module): + mock_cls = MagicMock() + mock_module.SongUNetPosEmbd = mock_cls + EDMPrecondSuperResolution( + img_resolution=64, + img_in_channels=C_IN, + img_out_channels=C_OUT, + model_channels=256, + num_blocks=8, + ) + call_kwargs = mock_cls.call_args[1] + assert call_kwargs["model_channels"] == 256 + assert call_kwargs["num_blocks"] == 8 + + @patch("hirad.models.preconditioning.network_module") + def test_model_attribute_exists(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert hasattr(edm, "model") + + +class TestEDMInitSigmaDefaults: + """Test sigma-related default values.""" + + @patch("hirad.models.preconditioning.network_module") + def test_default_sigma_data(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert edm.sigma_data == 0.5 + + @patch("hirad.models.preconditioning.network_module") + def test_default_sigma_min(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert edm.sigma_min == 0.0 + + @patch("hirad.models.preconditioning.network_module") + def test_default_sigma_max(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert edm.sigma_max == float("inf") + + @patch("hirad.models.preconditioning.network_module") + def test_custom_sigma_info(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=64, + img_in_channels=C_IN, + img_out_channels=C_OUT, + sigma_data=1.0, + sigma_min=0.002, + sigma_max=80.0, + ) + assert edm.sigma_data == 1.0 + assert edm.sigma_min == 0.002 + assert edm.sigma_max == 80.0 + + +class TestEDMInitFp16: + """Test use_fp16 stored correctly at init.""" + + @patch("hirad.models.preconditioning.network_module") + def test_default_fp16_is_false(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert edm.use_fp16 is False + + @patch("hirad.models.preconditioning.network_module") + def test_set_fp16_true(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=64, + img_in_channels=C_IN, + img_out_channels=C_OUT, + use_fp16=True, + ) + assert edm.use_fp16 is True + + +############################################################################ +# EDMPrecondSuperResolution — _scaling_fn # +############################################################################ + + +class TestEDMScalingFn: + """Test the static _scaling_fn method.""" + + def test_output_shape(self): + x = torch.randn(B, C_OUT, H, W) + lr = torch.randn(B, C_IN, H, W) + c_in = torch.ones(B, 1, 1, 1) * 0.5 + result = EDMPrecondSuperResolution._scaling_fn(x, lr, c_in) + assert result.shape == (B, C_OUT + C_IN, H, W) + + def test_first_channels_are_scaled_x(self): + x = torch.ones(B, C_OUT, H, W) + lr = torch.randn(B, C_IN, H, W) + c_in = torch.ones(B, 1, 1, 1) * 2.0 + result = EDMPrecondSuperResolution._scaling_fn(x, lr, c_in) + torch.testing.assert_close(result[:, :C_OUT], x * 2.0) + + def test_last_channels_are_unscaled_lr(self): + x = torch.randn(B, C_OUT, H, W) + lr = torch.ones(B, C_IN, H, W) * 3.0 + c_in = torch.ones(B, 1, 1, 1) * 0.5 + result = EDMPrecondSuperResolution._scaling_fn(x, lr, c_in) + torch.testing.assert_close(result[:, C_OUT:], lr) + + def test_lr_cast_to_x_dtype(self): + x = torch.randn(B, C_OUT, H, W, dtype=torch.float32) + lr = torch.randn(B, C_IN, H, W, dtype=torch.float64) + c_in = torch.ones(B, 1, 1, 1) + result = EDMPrecondSuperResolution._scaling_fn(x, lr, c_in) + assert result.dtype == torch.float32 + + +############################################################################ +# EDMPrecondSuperResolution — forward # +############################################################################ + + +class TestEDMForwardBasic: + """Basic forward pass tests for EDMPrecondSuperResolution.""" + + @patch("hirad.models.preconditioning.network_module") + def test_output_shape(self, mock_module, img_x, img_lr, sigma): + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + out = edm(img_x, img_lr, sigma) + assert out.shape == (B, C_OUT, H, W) + + @patch("hirad.models.preconditioning.network_module") + def test_output_dtype_is_float32(self, mock_module, img_x, img_lr, sigma): + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + out = edm(img_x, img_lr, sigma) + assert out.dtype == torch.float32 + + @patch("hirad.models.preconditioning.network_module") + def test_model_input_has_combined_channels(self, mock_module, img_x, img_lr, sigma): + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + edm(img_x, img_lr, sigma) + model_input = mock_model.call_args[0][0] + assert model_input.shape[1] == C_OUT + C_IN + + @patch("hirad.models.preconditioning.network_module") + def test_model_receives_flattened_c_noise(self, mock_module, img_x, img_lr, sigma): + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + edm(img_x, img_lr, sigma) + c_noise_arg = mock_model.call_args[0][1] + assert c_noise_arg.ndim == 1 + assert c_noise_arg.shape[0] == B + + @patch("hirad.models.preconditioning.network_module") + def test_model_receives_none_class_labels(self, mock_module, img_x, img_lr, sigma): + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + edm(img_x, img_lr, sigma) + call_kwargs = mock_model.call_args[1] + assert call_kwargs["class_labels"] is None + + @patch("hirad.models.preconditioning.network_module") + def test_kwargs_forwarded_to_model(self, mock_module, img_x, img_lr, sigma): + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + lead_time = torch.randint(49, size=(B,)) + edm(img_x, img_lr, sigma, lead_time_label=lead_time) + call_kwargs = mock_model.call_args[1] + assert "lead_time_label" in call_kwargs + torch.testing.assert_close(call_kwargs["lead_time_label"], lead_time) + + +class TestEDMForwardPreconditioning: + """Test that the EDM preconditioning coefficients are applied correctly.""" + + @patch("hirad.models.preconditioning.network_module") + def test_c_noise_is_log_sigma_over_4(self, mock_module): + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + sigma_val = 2.0 + sigma = torch.full((B,), sigma_val) + x = torch.randn(B, C_OUT, H, W) + lr = torch.randn(B, C_IN, H, W) + edm(x, lr, sigma) + c_noise_arg = mock_model.call_args[0][1] + expected = torch.full((B,), torch.tensor(sigma_val).log().item() / 4) + torch.testing.assert_close(c_noise_arg, expected) + + @patch("hirad.models.preconditioning.network_module") + def test_output_is_c_skip_x_plus_c_out_F_x(self, mock_module): + """D(x) = c_skip * x + c_out * F(x); with F(x)=0, D(x) = c_skip * x.""" + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + sigma_data = 0.5 + edm = EDMPrecondSuperResolution( + img_resolution=H, + img_in_channels=C_IN, + img_out_channels=C_OUT, + sigma_data=sigma_data, + ) + sigma_val = 1.0 + sigma = torch.full((B,), sigma_val) + x = torch.ones(B, C_OUT, H, W) + lr = torch.randn(B, C_IN, H, W) + out = edm(x, lr, sigma) + # Since mock model returns zeros, D(x) = c_skip * x + c_skip = sigma_data**2 / (sigma_val**2 + sigma_data**2) + expected = torch.full((B, C_OUT, H, W), c_skip) + torch.testing.assert_close(out, expected) + + @patch("hirad.models.preconditioning.network_module") + def test_sigma_reshaped_to_4d(self, mock_module, img_x, img_lr): + """Sigma with shape (B,) should be reshaped to (B, 1, 1, 1).""" + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + sigma_1d = torch.ones(B) + # Should not raise + edm(img_x, img_lr, sigma_1d) + + @patch("hirad.models.preconditioning.network_module") + def test_sigma_2d_reshaped_to_4d(self, mock_module, img_x, img_lr): + """Sigma with shape (B, 1) should be reshaped to (B, 1, 1, 1).""" + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + sigma_2d = torch.ones(B, 1) + # Should not raise + edm(img_x, img_lr, sigma_2d) + + @patch("hirad.models.preconditioning.network_module") + def test_sigma_4d_accepted(self, mock_module, img_x, img_lr): + """Sigma with shape (B, 1, 1, 1) should be accepted.""" + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + sigma_4d = torch.ones(B, 1, 1, 1) + # Should not raise + edm(img_x, img_lr, sigma_4d) + + +class TestEDMForwardImgLrNone: + """Test forward pass when img_lr is None.""" + + @patch("hirad.models.preconditioning.network_module") + def test_no_concatenation_when_img_lr_none(self, mock_module, img_x, sigma): + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + edm(img_x, img_lr=None, sigma=sigma) + model_input = mock_model.call_args[0][0] + # Without img_lr, input is c_in * x only + assert model_input.shape[1] == C_OUT + + +class TestEDMForwardDtypeValidation: + """Test dtype enforcement in forward pass.""" + + @patch("hirad.models.preconditioning.network_module") + def test_raises_on_dtype_mismatch(self, mock_module, img_x, img_lr, sigma): + """Model should raise if the underlying model returns wrong dtype.""" + mock_model = MagicMock(spec=nn.Module) + mock_model.side_effect = lambda x, sigma, class_labels=None, **kw: torch.zeros( + x.shape[0], C_OUT, x.shape[2], x.shape[3], dtype=torch.float16, + ) + mock_model.modules.return_value = iter([]) + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + with pytest.raises(ValueError, match="Expected the dtype"): + edm(img_x, img_lr, sigma) + + +class TestEDMForwardForceFp32: + """Test the force_fp32 flag.""" + + #TODO: Test doesn't make sence when device is cpu. + @patch("hirad.models.preconditioning.network_module") + def test_force_fp32_uses_float32(self, mock_module, img_x, img_lr, sigma): + mock_model = _make_mock_model() + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, + img_in_channels=C_IN, + img_out_channels=C_OUT, + use_fp16=True, + ) + device = "cuda" if torch.cuda.is_available() else "cpu" + edm(img_x.to(device), img_lr.to(device), sigma.to(device), force_fp32=True) + model_input = mock_model.call_args[0][0] + assert model_input.dtype == torch.float32 + + +class TestEDMForwardAutocastEnabled: + """Test that dtype validation is skipped when autocast is enabled.""" + + @patch("hirad.models.preconditioning.network_module") + def test_no_dtype_check_when_autocast_enabled(self, mock_module, img_x, img_lr, sigma): + mock_model = MagicMock(spec=nn.Module) + mock_model.side_effect = lambda x, sigma, class_labels=None, **kw: torch.zeros( + x.shape[0], C_OUT, x.shape[2], x.shape[3], dtype=torch.float16, + ) + mock_model.modules.return_value = iter([]) + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + with torch.autocast("cuda"): + out = edm(img_x, img_lr, sigma) + assert out.dtype == torch.float32 + + +############################################################################ +# EDMPrecondSuperResolution — round_sigma # +############################################################################ + + +class TestEDMRoundSigma: + """Test round_sigma static method.""" + + def test_float_input(self): + result = EDMPrecondSuperResolution.round_sigma(0.5) + assert isinstance(result, torch.Tensor) + assert result.item() == pytest.approx(0.5) + + def test_list_input(self): + result = EDMPrecondSuperResolution.round_sigma([0.1, 0.5, 1.0]) + assert isinstance(result, torch.Tensor) + assert result.shape == (3,) + torch.testing.assert_close(result, torch.tensor([0.1, 0.5, 1.0])) + + def test_tensor_input(self): + sigma = torch.tensor([0.2, 0.8]) + result = EDMPrecondSuperResolution.round_sigma(sigma) + torch.testing.assert_close(result, sigma) + + +############################################################################ +# EDMPrecondSuperResolution — amp_mode property # +############################################################################ + + +class TestEDMAmpMode: + """Test amp_mode property getter and setter.""" + + @patch("hirad.models.preconditioning.network_module") + def test_amp_mode_returns_none_when_model_lacks_attr(self, mock_module): + mock_model = MagicMock(spec=nn.Module) + del mock_model.amp_mode + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert edm.amp_mode is None + + @patch("hirad.models.preconditioning.network_module") + def test_amp_mode_returns_model_value(self, mock_module): + mock_model = MagicMock(spec=nn.Module) + mock_model.amp_mode = True + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert edm.amp_mode is True + + @patch("hirad.models.preconditioning.network_module") + def test_amp_mode_setter_updates_model_and_submodules(self, mock_module): + mock_model = MagicMock(spec=nn.Module) + mock_model.amp_mode = False + sub_module = MagicMock() + sub_module.amp_mode = False + mock_model.modules.return_value = iter([sub_module]) + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + edm.amp_mode = True + assert mock_model.amp_mode is True + assert sub_module.amp_mode is True + + @patch("hirad.models.preconditioning.network_module") + def test_amp_mode_setter_rejects_non_bool(self, mock_module): + mock_model = MagicMock(spec=nn.Module) + mock_model.amp_mode = False + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + with pytest.raises(TypeError, match="amp_mode must be a boolean"): + edm.amp_mode = "yes" + + @patch("hirad.models.preconditioning.network_module") + def test_amp_mode_setter_skips_model_without_attr(self, mock_module): + mock_model = MagicMock(spec=nn.Module) + del mock_model.amp_mode + mock_model.modules.return_value = iter([]) + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + # Should not raise even when model lacks amp_mode + edm.amp_mode = True + + +############################################################################ +# EDMPrecondSuperResolution — nn.Module integration # +############################################################################ + + +class TestEDMModuleIntegration: + """Test that EDMPrecondSuperResolution behaves as a proper nn.Module.""" + + @patch("hirad.models.preconditioning.network_module") + def test_is_nn_module(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert isinstance(edm, nn.Module) + + @patch("hirad.models.preconditioning.network_module") + def test_scaling_fn_attribute_set(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + edm = EDMPrecondSuperResolution( + img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT, + ) + assert edm.scaling_fn is EDMPrecondSuperResolution._scaling_fn diff --git a/tests/models/test_song_unet.py b/tests/models/test_song_unet.py new file mode 100644 index 0000000..a047a67 --- /dev/null +++ b/tests/models/test_song_unet.py @@ -0,0 +1,1513 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +import pytest +import torch +import torch.nn as nn + +from hirad.models.song_unet import SongUNet, SongUNetPosEmbd + + +# --------------------------------------------------------------------------- +# Helpers / fixtures — use small model configs for fast CPU tests +# --------------------------------------------------------------------------- + +B = 2 +IMG_RES = 32 +IN_CH = 4 +OUT_CH = 3 +SMALL_CFG = dict( + model_channels=32, + channel_mult=[1, 2], + num_blocks=1, + attn_resolutions=[], + dropout=0.0, +) + + +@pytest.fixture() +def small_unet(): + """Return a small SongUNet that runs on CPU.""" + return SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + **SMALL_CFG, + ) + + +@pytest.fixture() +def noise_labels(): + return torch.randn(B) + + +@pytest.fixture() +def class_labels(): + return torch.randint(0, 2, (B, 1)).float() + + +@pytest.fixture() +def input_image(): + return torch.randn(B, IN_CH, IMG_RES, IMG_RES) + + +############################################################################ +# SongUNet — __init__ # +############################################################################ + + +class TestSongUNetInitValidation: + """Test __init__ input validation.""" + + def test_invalid_embedding_type_raises(self): + with pytest.raises(ValueError, match="Invalid embedding_type"): + SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + embedding_type="invalid", + **SMALL_CFG, + ) + + def test_invalid_encoder_type_raises(self): + with pytest.raises(ValueError, match="Invalid encoder_type"): + SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + encoder_type="invalid", + **SMALL_CFG, + ) + + def test_invalid_decoder_type_raises(self): + with pytest.raises(ValueError, match="Invalid decoder_type"): + SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + decoder_type="invalid", + **SMALL_CFG, + ) + + @pytest.mark.parametrize("etype", ["positional", "fourier", "zero"]) + def test_valid_embedding_types_accepted(self, etype): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + embedding_type=etype, + **SMALL_CFG, + ) + assert model.embedding_type == etype + + @pytest.mark.parametrize("enc", ["standard", "skip", "residual"]) + def test_valid_encoder_types_accepted(self, enc): + SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + encoder_type=enc, + **SMALL_CFG, + ) + + @pytest.mark.parametrize("dec", ["standard", "skip"]) + def test_valid_decoder_types_accepted(self, dec): + SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + decoder_type=dec, + **SMALL_CFG, + ) + + +class TestSongUNetInitResolution: + """Test resolution handling in __init__.""" + + def test_int_resolution_sets_square(self): + model = SongUNet( + img_resolution=32, + in_channels=IN_CH, + out_channels=OUT_CH, + **SMALL_CFG, + ) + assert model.img_shape_x == 32 + assert model.img_shape_y == 32 + + def test_list_resolution_sets_height_width(self): + model = SongUNet( + img_resolution=[24, 32], + in_channels=IN_CH, + out_channels=OUT_CH, + **SMALL_CFG, + ) + assert model.img_shape_y == 24 + assert model.img_shape_x == 32 + + def test_img_resolution_stored(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + **SMALL_CFG, + ) + assert model.img_resolution == IMG_RES + + +class TestSongUNetInitEmbedding: + """Test embedding-related initialization.""" + + def test_positional_embedding_creates_map_noise(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + embedding_type="positional", + **SMALL_CFG, + ) + assert hasattr(model, "map_noise") + assert hasattr(model, "map_layer0") + assert hasattr(model, "map_layer1") + + def test_fourier_embedding_creates_map_noise(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + embedding_type="fourier", + **SMALL_CFG, + ) + assert hasattr(model, "map_noise") + assert hasattr(model, "map_layer0") + assert hasattr(model, "map_layer1") + + def test_zero_embedding_skips_mapping_layers(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + embedding_type="zero", + **SMALL_CFG, + ) + assert not hasattr(model, "map_noise") + assert not hasattr(model, "map_layer0") + assert not hasattr(model, "map_layer1") + assert not hasattr(model, "map_label") + assert not hasattr(model, "map_augment") + + def test_emb_channels_computed_correctly(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + channel_mult_emb=4, + **SMALL_CFG, + ) + assert model.emb_channels == 32 * 4 + + def test_label_dim_creates_map_label(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + label_dim=10, + **SMALL_CFG, + ) + assert model.map_label is not None + + def test_no_label_dim_map_label_is_none(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + label_dim=0, + **SMALL_CFG, + ) + assert model.map_label is None + + def test_augment_dim_creates_map_augment(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + augment_dim=5, + **SMALL_CFG, + ) + assert model.map_augment is not None + + def test_no_augment_dim_map_augment_is_none(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + augment_dim=0, + **SMALL_CFG, + ) + assert model.map_augment is None + + +class TestSongUNetInitEncoder: + """Test encoder construction.""" + + def test_encoder_module_dict_created(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + **SMALL_CFG, + ) + assert isinstance(model.enc, nn.ModuleDict) + assert len(model.enc) > 0 + + def test_skip_encoder_creates_aux_layers(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + encoder_type="skip", + **SMALL_CFG, + ) + aux_down_keys = [k for k in model.enc.keys() if "aux_down" in k] + aux_skip_keys = [k for k in model.enc.keys() if "aux_skip" in k] + assert len(aux_down_keys) > 0 + assert len(aux_skip_keys) > 0 + + def test_residual_encoder_creates_aux_residual(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + encoder_type="residual", + **SMALL_CFG, + ) + aux_keys = [k for k in model.enc.keys() if "aux_residual" in k] + assert len(aux_keys) > 0 + + def test_standard_encoder_has_no_aux_layers(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + encoder_type="standard", + **SMALL_CFG, + ) + aux_keys = [k for k in model.enc.keys() if "aux" in k] + assert len(aux_keys) == 0 + + def test_standard_encoder_layers(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + encoder_type="standard", + **SMALL_CFG, + ) + expected_layers = ["32x32_conv", "32x32_block0", "16x16_block0"] + for layer in expected_layers: + assert hasattr(model.enc, layer) + + + +class TestSongUNetInitDecoder: + """Test decoder construction.""" + + def test_decoder_module_dict_created(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + **SMALL_CFG, + ) + assert isinstance(model.dec, nn.ModuleDict) + assert len(model.dec) > 0 + + def test_skip_decoder_creates_aux_up_layers(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + decoder_type="skip", + **SMALL_CFG, + ) + aux_keys = [k for k in model.dec.keys() if "aux_up" in k] + assert len(aux_keys) > 0 + + def test_standard_decoder_has_no_aux_layers(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + decoder_type="standard", + **SMALL_CFG, + ) + aux_keys = [k for k in model.dec.keys() if "aux_up" in k] + assert len(aux_keys) == 0 + + def test_standard_decoder_layers(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + decoder_type="standard", + **SMALL_CFG, + ) + expected_layers = ["16x16_in0", "16x16_in1", "16x16_block0", "16x16_block1", "32x32_up", + "32x32_block0", "32x32_block1", "32x32_aux_norm", "32x32_aux_conv"] + for layer in expected_layers: + assert hasattr(model.dec, layer) + + +class TestSongUNetInitAdditiveEmbed: + """Test additive positional embedding in __init__.""" + + def test_additive_pos_embed_creates_parameter(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + additive_pos_embed=True, + **SMALL_CFG, + ) + assert hasattr(model, "spatial_emb") + assert isinstance(model.spatial_emb, nn.Parameter) + assert model.spatial_emb.shape == (1, 32, IMG_RES, IMG_RES) + + def test_no_additive_pos_embed_by_default(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + **SMALL_CFG, + ) + assert not hasattr(model, "spatial_emb") + + +class TestSongUNetInitCheckpoint: + """Test checkpoint level configuration.""" + + def test_checkpoint_level_zero(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + checkpoint_level=0, + **SMALL_CFG, + ) + # threshold = (img_shape_y >> 0) + 1 = 32 >> 0 + 1 = 32 + 1 = 33 + assert model.checkpoint_threshold == IMG_RES + 1 + + def test_checkpoint_level_one(self): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + checkpoint_level=1, + **SMALL_CFG, + ) + # threshold = (32 >> 1) + 1 = 16 + 1 = 17 + assert model.checkpoint_threshold == (IMG_RES >> 1) + 1 + + +class TestSongUNetIsModule: + """Test that SongUNet is a proper nn.Module.""" + + def test_is_nn_module(self, small_unet): + assert isinstance(small_unet, nn.Module) + + def test_has_parameters(self, small_unet): + params = list(small_unet.parameters()) + assert len(params) > 0 + + +############################################################################ +# SongUNet — forward # +############################################################################ + + +class TestSongUNetForwardShape: + """Test forward pass output shapes.""" + + def test_output_shape(self, small_unet, input_image, noise_labels, class_labels): + out = small_unet(input_image, noise_labels, class_labels) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + def test_output_dtype_float32(self, small_unet, input_image, noise_labels, class_labels): + out = small_unet(input_image, noise_labels, class_labels) + assert out.dtype == torch.float32 + + +class TestSongUNetForwardEmbeddingTypes: + """Test forward with different embedding types.""" + + def test_zero_embedding_forward(self, input_image, noise_labels, class_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + embedding_type="zero", + **SMALL_CFG, + ) + out = model(input_image, noise_labels, class_labels) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + def test_fourier_embedding_forward(self, input_image, noise_labels, class_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + embedding_type="fourier", + **SMALL_CFG, + ) + out = model(input_image, noise_labels, class_labels) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + def test_positional_embedding_forward(self, input_image, noise_labels, class_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + embedding_type="positional", + **SMALL_CFG, + ) + out = model(input_image, noise_labels, class_labels) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + +class TestSongUNetForwardEncoderDecoder: + """Test forward with various encoder/decoder combos.""" + + @pytest.mark.parametrize("enc,dec", [ + ("standard", "standard"), + ("skip", "standard"), + ("residual", "standard"), + ("standard", "skip"), + ("skip", "skip"), + ("residual", "skip"), + ]) + def test_encoder_decoder_combinations(self, enc, dec, input_image, noise_labels, class_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + encoder_type=enc, + decoder_type=dec, + **SMALL_CFG, + ) + out = model(input_image, noise_labels, class_labels) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + +class TestSongUNetForwardLabel: + """Test label dropout behavior during training.""" + + def test_label_dropout_in_training(self, input_image, noise_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + label_dim=5, + label_dropout=0.5, + **SMALL_CFG, + ) + model.train() + labels = torch.ones(B, 5) + out = model(input_image, noise_labels, labels) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + def test_label_dropout_in_eval(self, input_image, noise_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + label_dim=5, + label_dropout=0.5, + **SMALL_CFG, + ) + model.eval() + labels = torch.ones(B, 5) + out = model(input_image, noise_labels, labels) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + def test_map_label_called_when_label_dim_positive(self, input_image, noise_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + label_dim=5, + **SMALL_CFG, + ) + assert model.map_label is not None + called = [] + model.map_label.register_forward_hook(lambda m, i, o: called.append(True)) + out = model(input_image, noise_labels, torch.ones(B, 5)) + assert len(called) == 1 + + +class TestSongUNetForwardAugment: + """Test augment dropout behavior during training.""" + + def test_map_augment_called_when_augment_dim_positive(self, input_image, noise_labels, class_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + augment_dim=3, + **SMALL_CFG, + ) + assert model.map_augment is not None + called = [] + model.map_augment.register_forward_hook(lambda m, i, o: called.append(True)) + out = model(input_image, noise_labels, class_labels, augment_labels=torch.ones(B, 3)) + assert len(called) == 1 + + def test_no_augment_labels_skips_map_augment(self, input_image, noise_labels, class_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + augment_dim=3, + **SMALL_CFG, + ) + assert model.map_augment is not None + called = [] + model.map_augment.register_forward_hook(lambda m, i, o: called.append(True)) + out = model(input_image, noise_labels, class_labels) + assert len(called) == 0 + + +class TestSongUNetForwardNoiseEmbedding: + def test_map_noise_called_when_embedding_type_fourier(self, input_image, noise_labels, class_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + embedding_type="fourier", + **SMALL_CFG, + ) + assert model.map_noise is not None + called_map_noise = [] + called_map_layer0 = [] + called_map_layer1 = [] + model.map_noise.register_forward_hook(lambda m, i, o: called_map_noise.append(True)) + model.map_layer0.register_forward_hook(lambda m, i, o: called_map_layer0.append(True)) + model.map_layer1.register_forward_hook(lambda m, i, o: called_map_layer1.append(True)) + out = model(input_image, noise_labels, class_labels) + assert len(called_map_noise) == 1 + assert len(called_map_layer0) == 1 + assert len(called_map_layer1) == 1 + + def test_map_noise_called_when_embedding_type_positional(self, input_image, noise_labels, class_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + embedding_type="positional", + **SMALL_CFG, + ) + assert model.map_noise is not None + called_map_noise = [] + called_map_layer0 = [] + called_map_layer1 = [] + model.map_noise.register_forward_hook(lambda m, i, o: called_map_noise.append(True)) + model.map_layer0.register_forward_hook(lambda m, i, o: called_map_layer0.append(True)) + model.map_layer1.register_forward_hook(lambda m, i, o: called_map_layer1.append(True)) + out = model(input_image, noise_labels, class_labels) + assert len(called_map_noise) == 1 + assert len(called_map_layer0) == 1 + assert len(called_map_layer1) == 1 + + +class TestSongUNetForwardEncoderDecoderCalls: + """Test that all encoder and decoder blocks are called during forward.""" + + def test_all_enc_dec_blocks_called_parametrized( + self, input_image, noise_labels, class_labels + ): + """Test across different encoder/decoder combos.""" + for enc in ["standard", "skip", "residual"]: + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + encoder_type=enc, + **SMALL_CFG, + ) + called = {} + handles = [] + for name, block in list(model.enc.items()): + called[name] = 0 + handle = block.register_forward_hook( + lambda m, i, o, n=name: called.__setitem__(n, called[n] + 1) + ) + handles.append(handle) + + model(input_image, noise_labels, class_labels) + + for handle in handles: + handle.remove() + + for name, count in called.items(): + assert count == 1, ( + f"[enc={enc}] Block '{name}' called {count} times, expected 1" + ) + + def test_all_decoder_blocks_called_parametrized( + self, input_image, noise_labels, class_labels + ): + """Test across different encoder/decoder combos.""" + for dec in ["standard", "skip"]: + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + decoder_type=dec, + **SMALL_CFG, + ) + called = {} + handles = [] + for name, block in list(model.dec.items()): + called[name] = 0 + handle = block.register_forward_hook( + lambda m, i, o, n=name: called.__setitem__(n, called[n] + 1) + ) + handles.append(handle) + + model(input_image, noise_labels, class_labels) + + for handle in handles: + handle.remove() + + for name, count in called.items(): + assert count == 1, ( + f"[dec={dec}] Block '{name}' called {count} times, expected 1" + ) + + +class TestSongUNetForwardAdditiveEmbed: + """Test forward with additive positional embedding.""" + + def test_additive_embed_forward(self, input_image, noise_labels, class_labels): + model = SongUNet( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + additive_pos_embed=True, + **SMALL_CFG, + ) + out = model(input_image, noise_labels, class_labels) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + +class TestSongUNetForwardRectangularResolution: + """Test forward with non-square input.""" + + def test_rectangular_resolution_forward(self, noise_labels, class_labels): + res = [16, 32] + model = SongUNet( + img_resolution=res, + in_channels=IN_CH, + out_channels=OUT_CH, + **SMALL_CFG, + ) + x = torch.randn(B, IN_CH, res[0], res[1]) + out = model(x, noise_labels, class_labels) + assert out.shape == (B, OUT_CH, res[0], res[1]) + + +############################################################################ +# SongUNetPosEmbd — __init__ # +############################################################################ + + +# Positional embedding adds N_grid_channels to in_channels +N_GRID = 4 +PE_IN_CH = IN_CH + N_GRID + +PE_SMALL_CFG = dict( + model_channels=32, + channel_mult=[1, 2], + num_blocks=1, + attn_resolutions=[], + dropout=0.0, + use_apex_gn=False, +) + + +@pytest.fixture() +def small_pos_unet(): + return SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + + +class TestSongUNetPosEmbdInitGridType: + """Test grid type selection in __init__.""" + + def test_sinusoidal_grid_default(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + gridtype="sinusoidal", + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + assert model.gridtype == "sinusoidal" + assert model.pos_embd.shape == (N_GRID, IMG_RES, IMG_RES) + + def test_learnable_grid(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + gridtype="learnable", + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + assert model.gridtype == "learnable" + assert isinstance(model.pos_embd, nn.Parameter) + assert model.pos_embd.shape == (N_GRID, IMG_RES, IMG_RES) + + def test_linear_grid(self): + n_ch = 2 + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH + n_ch, + out_channels=OUT_CH, + gridtype="linear", + N_grid_channels=n_ch, + **PE_SMALL_CFG, + ) + assert model.gridtype == "linear" + assert model.pos_embd.shape == (2, IMG_RES, IMG_RES) + + def test_test_grid(self): + n_ch = 2 + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH + n_ch, + out_channels=OUT_CH, + gridtype="test", + N_grid_channels=n_ch, + **PE_SMALL_CFG, + ) + assert model.pos_embd.shape == (2, IMG_RES, IMG_RES) + + +class TestSongUNetPosEmbdInitGridChannelsValidation: + """Test N_grid_channels validation.""" + + def test_linear_grid_requires_2_channels(self): + with pytest.raises(ValueError, match="N_grid_channels must be set to 2"): + SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH + 4, + out_channels=OUT_CH, + gridtype="linear", + N_grid_channels=4, + **PE_SMALL_CFG, + ) + + def test_sinusoidal_multi_freq_requires_factor_of_4(self): + with pytest.raises(ValueError, match="N_grid_channels must be a factor of 4"): + SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH + 5, + out_channels=OUT_CH, + gridtype="sinusoidal", + N_grid_channels=5, + **PE_SMALL_CFG, + ) + + def test_sinusoidal_8_channels_accepted(self): + n_ch = 8 + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH + n_ch, + out_channels=OUT_CH, + gridtype="sinusoidal", + N_grid_channels=n_ch, + **PE_SMALL_CFG, + ) + assert model.pos_embd.shape == (n_ch, IMG_RES, IMG_RES) + + def test_zero_grid_channels_returns_none(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + N_grid_channels=0, + **PE_SMALL_CFG, + ) + assert model.pos_embd is None + + def test_unsupported_gridtype_raises(self): + with pytest.raises(ValueError, match="Gridtype not supported"): + SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + gridtype="unknown", + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + + +class TestSongUNetPosEmbdInitLeadTime: + """Test lead time related initialization.""" + + def test_lead_time_mode_creates_lt_embd(self): + lt_ch = 2 + lt_steps = 5 + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH + lt_ch, + out_channels=OUT_CH, + N_grid_channels=N_GRID, + lead_time_mode=True, + lead_time_channels=lt_ch, + lead_time_steps=lt_steps, + **PE_SMALL_CFG, + ) + assert model.lead_time_mode is True + assert model.lt_embd is not None + assert model.lt_embd.shape == (lt_steps, lt_ch, IMG_RES, IMG_RES) + + def test_no_lead_time_mode_by_default(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + assert model.lead_time_mode is False + + def test_lead_time_none_channels_returns_none_embd(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + N_grid_channels=N_GRID, + lead_time_mode=True, + lead_time_channels=None, + lead_time_steps=9, + **PE_SMALL_CFG, + ) + assert model.lt_embd is None + + def test_lead_time_none_steps_returns_none_embd(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + N_grid_channels=N_GRID, + lead_time_mode=True, + lead_time_channels=2, + lead_time_steps=None, + **PE_SMALL_CFG, + ) + assert model.lt_embd is None + + +class TestSongUNetPosEmbdInitProbChannels: + """Test prob_channels initialization.""" + + def test_prob_channels_creates_scalar(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH + 2, + out_channels=OUT_CH, + N_grid_channels=N_GRID, + lead_time_mode=True, + lead_time_channels=2, + lead_time_steps=3, + prob_channels=[0, 1], + **PE_SMALL_CFG, + ) + assert hasattr(model, "scalar") + assert model.scalar.shape == (1, 2, 1, 1) + + def test_empty_prob_channels_no_scalar(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH + 2, + out_channels=OUT_CH, + N_grid_channels=N_GRID, + lead_time_mode=True, + lead_time_channels=2, + lead_time_steps=3, + prob_channels=[], + **PE_SMALL_CFG, + ) + assert not hasattr(model, "scalar") + + +class TestSongUNetPosEmbdIsModule: + """Test that SongUNetPosEmbd is a proper nn.Module and subclass of SongUNet.""" + + def test_is_nn_module(self, small_pos_unet): + assert isinstance(small_pos_unet, nn.Module) + + def test_is_subclass_of_song_unet(self, small_pos_unet): + assert isinstance(small_pos_unet, SongUNet) + + +############################################################################ +# SongUNetPosEmbd — _get_positional_embedding # +############################################################################ + + +class TestGetPositionalEmbedding: + """Test _get_positional_embedding for various grid types.""" + + def test_sinusoidal_4ch_grid_not_requires_grad(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + gridtype="sinusoidal", + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + assert not model.pos_embd.requires_grad + + def test_linear_grid_not_requires_grad(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH + 2, + out_channels=OUT_CH, + gridtype="linear", + N_grid_channels=2, + **PE_SMALL_CFG, + ) + assert not model.pos_embd.requires_grad + + def test_learnable_grid_requires_grad(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + gridtype="learnable", + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + assert model.pos_embd.requires_grad + + def test_sinusoidal_values_in_range(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + gridtype="sinusoidal", + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + assert model.pos_embd.min() >= -1.0 + assert model.pos_embd.max() <= 1.0 + + def test_linear_values_in_range(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH + 2, + out_channels=OUT_CH, + gridtype="linear", + N_grid_channels=2, + **PE_SMALL_CFG, + ) + assert model.pos_embd.min() >= -1.0 + assert model.pos_embd.max() <= 1.0 + + def test_rectangular_sinusoidal_grid(self): + model = SongUNetPosEmbd( + img_resolution=[16, 32], + in_channels=PE_IN_CH, + out_channels=OUT_CH, + gridtype="sinusoidal", + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + assert model.pos_embd.shape == (N_GRID, 16, 32) + + def test_rectangular_grid_sinusoidal_8ch(self): + n_ch = 8 + model = SongUNetPosEmbd( + img_resolution=[16, 32], + in_channels=IN_CH + n_ch, + out_channels=OUT_CH, + gridtype="sinusoidal", + N_grid_channels=n_ch, + **PE_SMALL_CFG, + ) + assert model.pos_embd.shape == (n_ch, 16, 32) + + def test_rectangular_linear_grid(self): + model = SongUNetPosEmbd( + img_resolution=[16, 32], + in_channels=IN_CH + 2, + out_channels=OUT_CH, + gridtype="linear", + N_grid_channels=2, + **PE_SMALL_CFG, + ) + assert model.pos_embd.shape == (2, 16, 32) + + def test_rectangular_learnable_grid(self): + model = SongUNetPosEmbd( + img_resolution=[16, 32], + in_channels=PE_IN_CH, + out_channels=OUT_CH, + gridtype="learnable", + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + assert model.pos_embd.shape == (N_GRID, 16, 32) + + def test_linear_grid_simple_values(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH + 2, + out_channels=OUT_CH, + gridtype="linear", + N_grid_channels=2, + **PE_SMALL_CFG, + ) + # Check that the first channel is a vertical gradient and the second is horizontal + for y in range(IMG_RES): + for x in range(IMG_RES): + expected_y = (y / (IMG_RES - 1)) * 2 - 1 + expected_x = (x / (IMG_RES - 1)) * 2 - 1 + assert torch.isclose(model.pos_embd[0, y, x], torch.tensor([expected_x]), atol=1e-5) + assert torch.isclose(model.pos_embd[1, y, x], torch.tensor([expected_y]), atol=1e-5) + + def test_sinusoidal_grid_simple_values(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + gridtype="sinusoidal", + N_grid_channels=4, + **PE_SMALL_CFG, + ) + # Check that the first two channels are sinusoids of different frequencies + # and the next two channels are the cosine counterparts + for y in range(IMG_RES): + for x in range(IMG_RES): + expected_ch0 = torch.sin(2 * torch.pi * torch.tensor([x]) / (IMG_RES - 1)) + expected_ch1 = torch.sin(2 * torch.pi * torch.tensor([y]) / (IMG_RES - 1)) + expected_ch2 = torch.cos(2 * torch.pi * torch.tensor([x]) / (IMG_RES - 1)) + expected_ch3 = torch.cos(2 * torch.pi * torch.tensor([y]) / (IMG_RES - 1)) + assert torch.isclose(model.pos_embd[0, y, x], expected_ch0, atol=1e-5) + assert torch.isclose(model.pos_embd[1, y, x], expected_ch1, atol=1e-5) + assert torch.isclose(model.pos_embd[2, y, x], expected_ch2, atol=1e-5) + assert torch.isclose(model.pos_embd[3, y, x], expected_ch3, atol=1e-5) + + #TODO: When more than 4 channels are used for sinusoidal, the frequencies should be multiples of the base frequency (2). + # freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) is currently in code which gives + # freqs = [1,4] instead of [1,2] for N_grid_channels=8. This seems to be a bug if we want the base 2. + # Leaving it like this for now since we have checkpoints with 8 sinusoidal channels that use these frequencies, + # but it should be fixed in the future and this test should be updated to reflect the intended behavior. + def test_sinusoidal_8ch_grid_simple_values(self): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH + 8, + out_channels=OUT_CH, + gridtype="sinusoidal", + N_grid_channels=8, + **PE_SMALL_CFG, + ) + # Check that the first 4 channels are sinusoids of different frequencies + # and the next 4 channels are the cosine counterparts + for y in range(IMG_RES): + for x in range(IMG_RES): + for idx, i in enumerate([0,2]): + expected_ch_0 = torch.sin((2**i) * 2 * torch.pi * torch.tensor([x]) / (IMG_RES - 1)) + expected_ch_1 = torch.sin((2**i) * 2 * torch.pi * torch.tensor([y]) / (IMG_RES - 1)) + expected_ch_2 = torch.cos((2**i) * 2 * torch.pi * torch.tensor([x]) / (IMG_RES - 1)) + expected_ch_3 = torch.cos((2**i) * 2 * torch.pi * torch.tensor([y]) / (IMG_RES - 1)) + assert torch.isclose(model.pos_embd[4*idx, y, x], expected_ch_0, atol=1e-5) + assert torch.isclose(model.pos_embd[4*idx + 1, y, x], expected_ch_1, atol=1e-5) + assert torch.isclose(model.pos_embd[4*idx + 2, y, x], expected_ch_2, atol=1e-5) + assert torch.isclose(model.pos_embd[4*idx + 3, y, x], expected_ch_3, atol=1e-5) + + +############################################################################ +# SongUNetPosEmbd — forward # +############################################################################ + + +class TestSongUNetPosEmbdForwardBasic: + """Test basic forward pass for SongUNetPosEmbd.""" + + def test_output_shape(self, small_pos_unet, noise_labels, class_labels): + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + out = small_pos_unet(x, noise_labels, class_labels) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + def test_output_dtype_float32(self, small_pos_unet, noise_labels, class_labels): + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + out = small_pos_unet(x, noise_labels, class_labels) + assert out.dtype == torch.float32 + + +class TestSongUNetPosEmbdForwardErrors: + """Test that forward raises for mutually exclusive arguments.""" + + def test_raises_when_both_selector_and_index_provided( + self, small_pos_unet, noise_labels, class_labels + ): + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + global_index = torch.zeros(1, 2, IMG_RES, IMG_RES, dtype=torch.long) + selector = lambda emb: emb[None].expand(B, -1, -1, -1) + with pytest.raises(ValueError, match="Cannot provide both"): + small_pos_unet( + x, noise_labels, class_labels, + global_index=global_index, + embedding_selector=selector, + ) + + def test_raises_when_lead_time_mode_and_embedding_selector_provided(self, small_pos_unet, noise_labels, class_labels): + small_pos_unet.lead_time_mode = True + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + selector = lambda emb: emb[None].expand(B, -1, -1, -1) + with pytest.raises(ValueError, match="Embedding selector is not supported in lead time mode."): + small_pos_unet( + x, noise_labels, class_labels, + embedding_selector=selector, + ) + + +class TestSongUNetPosEmbdForwardSelector: + """Test forward with embedding_selector.""" + + def test_selector_applied(self, small_pos_unet, noise_labels, class_labels): + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + selector = lambda emb: emb[None].expand(B, -1, -1, -1) + out = small_pos_unet( + x, noise_labels, class_labels, embedding_selector=selector + ) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + def test_selector_takes_subset_of_embeddings(self, small_pos_unet, noise_labels, class_labels): + P = 2 + x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2) + # Selector that takes only the first 2 channels of the positional embedding + selector = lambda emb: emb[None].expand(B * P, -1, -1, -1)[:,:,:IMG_RES//2,:IMG_RES//2] + noise_labels = torch.randn(B * P) + class_labels = torch.randint(0, 1, (B * P, 1)).float() + out = small_pos_unet( + x, noise_labels, class_labels, embedding_selector=selector + ) + assert out.shape == (B*P, OUT_CH, IMG_RES//2, IMG_RES//2) + + +class TestSongUNetPosEmbdForwardGlobalIndex: + """Test forward with global_index.""" + + def test_global_index_selects_embeddings( + self, small_pos_unet, noise_labels, class_labels + ): + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + # Create index that selects the full grid + idx_y = torch.arange(IMG_RES).view(1, 1, IMG_RES, 1).expand(1, 1, IMG_RES, IMG_RES) + idx_x = torch.arange(IMG_RES).view(1, 1, 1, IMG_RES).expand(1, 1, IMG_RES, IMG_RES) + global_index = torch.cat([idx_y, idx_x], dim=1) # (P, 2, H, W) + out = small_pos_unet(x, noise_labels, class_labels, global_index=global_index) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + def test_global_index_selects_subset_of_embeddings( + self, small_pos_unet + ): + P = 2 + x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2) + # Create index that selects only the top-left quadrant of the grid + idx_y = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(P, 1, IMG_RES//2, IMG_RES//2) + idx_x = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(P, 1, IMG_RES//2, IMG_RES//2) + global_index = torch.cat([idx_y, idx_x], dim=1) # (P, 2, H, W) + noise_labels = torch.randn(B * P) + class_labels = torch.randint(0, 1, (B * P, 1)).float() + out = small_pos_unet(x, noise_labels, class_labels, global_index=global_index) + assert out.shape == (B * P, OUT_CH, IMG_RES//2, IMG_RES//2) + + +class TestSongUNetPosEmbdForwardLeadTime: + """Test forward pass with lead_time_mode enabled.""" + + def _make_lead_time_model(self): + lt_ch = 2 + return SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH + lt_ch, + out_channels=OUT_CH, + N_grid_channels=N_GRID, + lead_time_mode=True, + lead_time_channels=lt_ch, + lead_time_steps=5, + prob_channels=[], + **PE_SMALL_CFG, + ) + + def test_lead_time_forward_shape(self, noise_labels, class_labels): + model = self._make_lead_time_model() + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + lead_time = torch.zeros(B, dtype=torch.long) + out = model(x, noise_labels, class_labels, lead_time_label=lead_time) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + def test_lead_time_with_prob_channels_eval(self, noise_labels, class_labels): + """In eval mode, prob_channels should go through softmax.""" + lt_ch = 2 + out_ch = 4 + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH + lt_ch, + out_channels=out_ch, + N_grid_channels=N_GRID, + lead_time_mode=True, + lead_time_channels=lt_ch, + lead_time_steps=5, + prob_channels=[2, 3], + **PE_SMALL_CFG, + ) + model.eval() + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + lead_time = torch.zeros(B, dtype=torch.long) + out = model(x, noise_labels, class_labels, lead_time_label=lead_time) + assert out.shape == (B, out_ch, IMG_RES, IMG_RES) + # Prob channels should sum to 1 (softmax) + prob_sum = out[:, [2, 3]].sum(dim=1) + torch.testing.assert_close( + prob_sum, torch.ones(B, IMG_RES, IMG_RES), atol=1e-5, rtol=1e-5 + ) + + def test_lead_time_with_prob_channels_train(self, noise_labels, class_labels): + """In training mode, prob_channels should output raw logits (no softmax).""" + lt_ch = 2 + out_ch = 4 + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH + lt_ch, + out_channels=out_ch, + N_grid_channels=N_GRID, + lead_time_mode=True, + lead_time_channels=lt_ch, + lead_time_steps=5, + prob_channels=[2, 3], + **PE_SMALL_CFG, + ) + model.train() + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + lead_time = torch.zeros(B, dtype=torch.long) + out = model(x, noise_labels, class_labels, lead_time_label=lead_time) + assert out.shape == (B, out_ch, IMG_RES, IMG_RES) + + def test_lead_time_with_global_index(self, noise_labels, class_labels): + """Test that global_index can be used with lead_time_mode.""" + lt_ch = 2 + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH + lt_ch, + out_channels=OUT_CH, + N_grid_channels=N_GRID, + lead_time_mode=True, + lead_time_channels=lt_ch, + lead_time_steps=5, + prob_channels=[], + **PE_SMALL_CFG, + ) + P = 2 + x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2) + # Create index that selects only the top-left quadrant of the grid + idx_y = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(P, 1, IMG_RES//2, IMG_RES//2) + idx_x = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(P, 1, IMG_RES//2, IMG_RES//2) + global_index = torch.cat([idx_y, idx_x], dim=1) # (P, 2, H, W) + noise_labels = torch.randn(B * P) + class_labels = torch.randint(0, 1, (B * P, 1)).float() + out = model(x, noise_labels, class_labels, global_index=global_index, lead_time_label=torch.zeros(B, dtype=torch.long)) + assert out.shape == (B * P, OUT_CH, IMG_RES//2, IMG_RES//2) + + +class TestSongUNetPosEmbdForwardNoneGrid: + """Test forward pass when N_grid_channels=0 (no positional embedding).""" + + def test_no_pos_embd_forward(self, noise_labels, class_labels): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=IN_CH, + out_channels=OUT_CH, + N_grid_channels=0, + **PE_SMALL_CFG, + ) + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + out = model(x, noise_labels, class_labels) + assert out.shape == (B, OUT_CH, IMG_RES, IMG_RES) + + +class TestSongUNetPosEmbdForwardRectangular: + """Test forward with non-square resolution.""" + + def test_rectangular_forward(self, noise_labels, class_labels): + res = [16, 32] + model = SongUNetPosEmbd( + img_resolution=res, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + N_grid_channels=N_GRID, + **PE_SMALL_CFG, + ) + x = torch.randn(B, IN_CH, res[0], res[1]) + out = model(x, noise_labels, class_labels) + assert out.shape == (B, OUT_CH, res[0], res[1]) + + +############################################################################ +# SongUNetPosEmbd — positional_embedding_indexing # +############################################################################ + + +class TestPositionalEmbeddingIndexing: + """Test positional_embedding_indexing method.""" + + def test_no_index_returns_full_grid(self, small_pos_unet): + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + result = small_pos_unet.positional_embedding_indexing(x) + assert result.shape == (B, N_GRID, IMG_RES, IMG_RES) + + def test_no_index_expands_batch(self, small_pos_unet): + x = torch.randn(4, IN_CH, IMG_RES, IMG_RES) + result = small_pos_unet.positional_embedding_indexing(x) + assert result.shape[0] == 4 + + def test_global_index_selects_correctly(self, small_pos_unet): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH, + out_channels=OUT_CH, + gridtype="linear", + N_grid_channels=2, + **PE_SMALL_CFG, + ) + P = 2 + x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2) + idx_y = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(P, 1, IMG_RES//2, IMG_RES//2) + idx_x = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(P, 1, IMG_RES//2, IMG_RES//2) + global_index = torch.cat([idx_y, idx_x], dim=1) + result = model.positional_embedding_indexing(x, global_index=global_index) + assert result.shape == (B * P, 2, IMG_RES//2, IMG_RES//2) + assert torch.allclose(result, model.pos_embd[None, :, :IMG_RES//2, :IMG_RES//2].expand(B*P, -1, -1, -1)) + + def test_global_index_selects_correctly_with_lead_time(self, small_pos_unet): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH + 2, + out_channels=OUT_CH, + gridtype="linear", + N_grid_channels=2, + lead_time_mode=True, + lead_time_channels=2, + lead_time_steps=5, + prob_channels=[], + **PE_SMALL_CFG, + ) + P = 2 + x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2) + idx_y = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(P, 1, IMG_RES//2, IMG_RES//2) + idx_x = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(P, 1, IMG_RES//2, IMG_RES//2) + global_index = torch.cat([idx_y, idx_x], dim=1) + result = model.positional_embedding_indexing(x, global_index=global_index, lead_time_label=torch.zeros(B, dtype=torch.long)) + assert result.shape == (B * P, 2 + 2, IMG_RES//2, IMG_RES//2) + expected_pos_embd = model.pos_embd[None, :, :IMG_RES//2, :IMG_RES//2].expand(B*P, -1, -1, -1) + expected_lt_embd = model.lt_embd[0:1,:,:IMG_RES//2, :IMG_RES//2].expand(B*P, -1, -1, -1) # Assuming lead_time_label=0 for this test + expected_combined = torch.cat([expected_pos_embd, expected_lt_embd], dim=1) + assert torch.allclose(result, expected_combined) + + def test_global_index_stacks_per_batch_elements(self, small_pos_unet): + P = 2 + x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2) + idx_y_1 = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(1, 1, IMG_RES//2, IMG_RES//2) + idx_x_1 = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(1, 1, IMG_RES//2, IMG_RES//2) + idx_y_2 = torch.arange(IMG_RES//2, IMG_RES).view(1, 1, IMG_RES//2, 1).expand(1, 1, IMG_RES//2, IMG_RES//2) + idx_x_2 = torch.arange(IMG_RES//2, IMG_RES).view(1, 1, 1, IMG_RES//2).expand(1, 1, IMG_RES//2, IMG_RES//2) + idx_1 = torch.cat([idx_y_1, idx_x_1], dim=1) + idx_2 = torch.cat([idx_y_2, idx_x_2], dim=1) + global_index = torch.cat([idx_1, idx_2], dim=0) + result = small_pos_unet.positional_embedding_indexing(x, global_index=global_index) + assert result.shape == (B * P, N_GRID, IMG_RES//2, IMG_RES//2) + # Check that the same positional embedding is repeated for each batch element in the group of P + for i in range(B): + assert torch.allclose(result[i*P], small_pos_unet.pos_embd[:, :IMG_RES//2, :IMG_RES//2]) + assert torch.allclose(result[i*P + 1], small_pos_unet.pos_embd[:, IMG_RES//2:, IMG_RES//2:]) + + def test_global_index_stacks_per_batch_elements_with_lead_time(self, small_pos_unet): + model = SongUNetPosEmbd( + img_resolution=IMG_RES, + in_channels=PE_IN_CH + 2, + out_channels=OUT_CH, + gridtype="linear", + N_grid_channels=2, + lead_time_mode=True, + lead_time_channels=2, + lead_time_steps=5, + prob_channels=[], + **PE_SMALL_CFG, + ) + P = 2 + x = torch.randn(B * P, IN_CH, IMG_RES//2, IMG_RES//2) + idx_y_1 = torch.arange(IMG_RES//2).view(1, 1, IMG_RES//2, 1).expand(1, 1, IMG_RES//2, IMG_RES//2) + idx_x_1 = torch.arange(IMG_RES//2).view(1, 1, 1, IMG_RES//2).expand(1, 1, IMG_RES//2, IMG_RES//2) + idx_y_2 = torch.arange(IMG_RES//2, IMG_RES).view(1, 1, IMG_RES//2, 1).expand(1, 1, IMG_RES//2, IMG_RES//2) + idx_x_2 = torch.arange(IMG_RES//2, IMG_RES).view(1, 1, 1, IMG_RES//2).expand(1, 1, IMG_RES//2, IMG_RES//2) + idx_1 = torch.cat([idx_y_1, idx_x_1], dim=1) + idx_2 = torch.cat([idx_y_2, idx_x_2], dim=1) + global_index = torch.cat([idx_1, idx_2], dim=0) + result = model.positional_embedding_indexing(x, global_index=global_index, lead_time_label=torch.zeros(B, dtype=torch.long)) + assert result.shape == (B * P, 4, IMG_RES//2, IMG_RES//2) # Assuming pos_embd has 2 channels and lt_embd has 2 channels + expected_pos_embd = model.pos_embd[None,::] + expected_lt_embd = model.lt_embd[0:1] # Assuming lead_time_label=0 for this test + expected_combined = torch.cat([expected_pos_embd, expected_lt_embd], dim=1) + for i in range(B): + assert torch.allclose(result[i*P], expected_combined[0, :, :IMG_RES//2, :IMG_RES//2]) + assert torch.allclose(result[i*P + 1], expected_combined[0, :, IMG_RES//2:, IMG_RES//2:]) + + def test_dtype_conversion(self, small_pos_unet): + """Embedding dtype should match input dtype.""" + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES, dtype=torch.float64) + result = small_pos_unet.positional_embedding_indexing(x) + assert result.dtype == torch.float64 + + +############################################################################ +# SongUNetPosEmbd — positional_embedding_selector # +############################################################################ + + +class TestPositionalEmbeddingSelector: + """Test positional_embedding_selector method.""" + + def test_selector_identity(self, small_pos_unet): + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES) + selector = lambda emb: emb[None].expand(B, -1, -1, -1) + result = small_pos_unet.positional_embedding_selector(x, selector) + assert result.shape == (B, N_GRID, IMG_RES, IMG_RES) + + def test_selector_dtype_conversion(self, small_pos_unet): + """Embedding dtype should be cast to input dtype before selector runs.""" + x = torch.randn(B, IN_CH, IMG_RES, IMG_RES, dtype=torch.float64) + selector = lambda emb: emb[None].expand(B, -1, -1, -1) + result = small_pos_unet.positional_embedding_selector(x, selector) + assert result.dtype == torch.float64 + + def test_selector_returns_custom_shape(self, small_pos_unet): + """Selector can return patches of a different spatial size.""" + patch_h, patch_w = 8, 8 + selector = lambda emb: emb[None, :, :patch_h, :patch_w].expand(B, -1, -1, -1) + x = torch.randn(B, IN_CH, patch_h, patch_w) + result = small_pos_unet.positional_embedding_selector(x, selector) + assert result.shape == (B, N_GRID, patch_h, patch_w) + assert torch.allclose(result, small_pos_unet.pos_embd[None, :, :patch_h, :patch_w].expand(B, -1, -1, -1)) diff --git a/tests/models/test_unet.py b/tests/models/test_unet.py index cb8e64c..bff876e 100644 --- a/tests/models/test_unet.py +++ b/tests/models/test_unet.py @@ -122,6 +122,12 @@ def test_extra_kwargs_forwarded_to_model(self, mock_module): assert call_kwargs["model_channels"] == 256 assert call_kwargs["num_blocks"] == 8 + @patch("hirad.models.unet.network_module") + def test_model_attribute_exists(self, mock_module): + mock_module.SongUNetPosEmbd = MagicMock() + unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT) + assert hasattr(unet, "model") + ############################################################################ # UNet — use_fp16 property # @@ -317,6 +323,26 @@ def test_force_fp32_uses_float32(self, mock_module): assert model_input.dtype == torch.float32 +class TestUNetForwardAutocastEnabled: + """Test that dtype validation is skipped when autocast is enabled.""" + + @patch("hirad.models.unet.network_module") + def test_no_dtype_check_when_autocast_enabled(self, mock_module): + mock_model = MagicMock(spec=nn.Module) + # Return fp16 when fp32 is expected + mock_model.side_effect = lambda x, sigma, class_labels=None, **kw: torch.zeros( + x.shape[0], C_OUT, x.shape[2], x.shape[3], dtype=torch.float16 + ) + mock_model.modules.return_value = iter([]) + mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) + unet = UNet(img_resolution=H, img_in_channels=C_IN, img_out_channels=C_OUT) + x = torch.zeros(B, C_OUT, H, W) + lr = torch.randn(B, C_IN, H, W) + with torch.autocast("cuda"): + out = unet(x, lr) + assert out.dtype == torch.float32 # Output should still be float32 + + ############################################################################ # UNet — round_sigma # ############################################################################ @@ -350,13 +376,6 @@ def test_tensor_input(self, mock_module): result = unet.round_sigma(sigma) torch.testing.assert_close(result, sigma) - @patch("hirad.models.unet.network_module") - def test_zero_input(self, mock_module): - mock_module.SongUNetPosEmbd = MagicMock() - unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT) - result = unet.round_sigma(0.0) - assert result.item() == pytest.approx(0.0) - ############################################################################ # UNet — amp_mode property # @@ -389,7 +408,7 @@ def test_amp_mode_setter_updates_model(self, mock_module): mock_model.amp_mode = False sub_module = MagicMock() sub_module.amp_mode = False - mock_model.modules.return_value = iter([mock_model, sub_module]) + mock_model.modules.return_value = iter([sub_module]) mock_module.SongUNetPosEmbd = MagicMock(return_value=mock_model) unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT) unet.amp_mode = True @@ -406,62 +425,6 @@ def test_amp_mode_setter_rejects_non_bool(self, mock_module): unet.amp_mode = "yes" -############################################################################ -# UNet — _backward_compat_arg_mapper # -############################################################################ - - -# class TestUNetBackwardCompat: -# """Test _backward_compat_arg_mapper for version-based argument migration.""" - -# def test_v010_removes_img_channels(self): -# args = { -# "img_resolution": 64, -# "img_in_channels": C_IN, -# "img_out_channels": C_OUT, -# "img_channels": 10, -# } -# result = UNet._backward_compat_arg_mapper("0.1.0", args) -# assert "img_channels" not in result - -# def test_v010_removes_sigma_params(self): -# args = { -# "img_resolution": 64, -# "img_in_channels": C_IN, -# "img_out_channels": C_OUT, -# "sigma_min": 0.002, -# "sigma_max": 80.0, -# "sigma_data": 0.5, -# } -# result = UNet._backward_compat_arg_mapper("0.1.0", args) -# assert "sigma_min" not in result -# assert "sigma_max" not in result -# assert "sigma_data" not in result - -# def test_v010_keeps_valid_args(self): -# args = { -# "img_resolution": 64, -# "img_in_channels": C_IN, -# "img_out_channels": C_OUT, -# } -# result = UNet._backward_compat_arg_mapper("0.1.0", args) -# assert result["img_resolution"] == 64 -# assert result["img_in_channels"] == C_IN -# assert result["img_out_channels"] == C_OUT - -# def test_non_v010_preserves_all_args(self): -# args = { -# "img_resolution": 64, -# "img_in_channels": C_IN, -# "img_out_channels": C_OUT, -# "img_channels": 10, -# "sigma_min": 0.002, -# } -# result = UNet._backward_compat_arg_mapper("0.2.0", args) -# assert "img_channels" in result -# assert "sigma_min" in result - - ############################################################################ # UNet — nn.Module integration # ############################################################################ @@ -476,8 +439,3 @@ def test_is_nn_module(self, mock_module): unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT) assert isinstance(unet, nn.Module) - @patch("hirad.models.unet.network_module") - def test_model_attribute_exists(self, mock_module): - mock_module.SongUNetPosEmbd = MagicMock() - unet = UNet(img_resolution=64, img_in_channels=C_IN, img_out_channels=C_OUT) - assert hasattr(unet, "model") diff --git a/tests/training/__init__.py b/tests/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/training/test_train.py b/tests/training/test_train.py new file mode 100644 index 0000000..366c355 --- /dev/null +++ b/tests/training/test_train.py @@ -0,0 +1,670 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Unit tests for hirad.training.train.main(). + +Every heavy side-effect (distributed init, dataset I/O, model construction, +checkpointing, mlflow, CUDA) is replaced with lightweight mocks so the tests +run on CPU in seconds. +""" + +from contextlib import nullcontext +from unittest.mock import MagicMock, patch, call + +import pytest +import torch +from omegaconf import DictConfig, OmegaConf + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +# Minimal Hydra-style config that satisfies every code path in main() +_BASE_CFG = { + "logging": { + "method": None, + "uri": None, + "experiment_name": "test", + "run_name": "test-run", + }, + "training": { + "hp": { + "total_batch_size": 4, + "batch_size_per_gpu": 4, + "lr": 1e-3, + "lr_rampup": 0, + "lr_decay": 1.0, + "lr_decay_rate": 1, + "training_duration": 8, # two steps of batch_size 4 + "grad_clip_threshold": 1e6, + "patch_num": 1, + }, + "perf": { + "fp_optimizations": "amp-bf16", + "songunet_checkpoint_level": 0, + "dataloader_workers": 0, + "use_apex_gn": False, + "torch_compile": False, + "profile_mode": False, + }, + "io": { + "checkpoint_dir": "/tmp/test_ckpts", + "print_progress_freq": 100000, + "save_checkpoint_freq": 100000, + "validation_freq": 100000, + "validation_steps": 1, + }, + }, + "model": { + "name": "diffusion", + "hr_mean_conditioning": False, + "model_args": {"N_grid_channels": 4}, + }, + "dataset": { + "type": "era5_cosmo", + "validation": False, + "n_month_hour_channels": 0, + }, +} + +B, C_IN, C_OUT, C_STATIC, H, W = 2, 4, 3, 2, 64, 64 + + +def _cfg(**overrides): + """Return a resolved DictConfig built from _BASE_CFG with optional overrides.""" + import copy + + raw = copy.deepcopy(_BASE_CFG) + + def _deep_update(d, u): + for k, v in u.items(): + if isinstance(v, dict) and isinstance(d.get(k), dict): + _deep_update(d[k], v) + else: + d[k] = v + + _deep_update(raw, overrides) + cfg = OmegaConf.create(raw) + OmegaConf.resolve(cfg) + return cfg + + +def _make_mock_dist(rank=0, world_size=1): + dist = MagicMock() + dist.device = torch.device("cpu") + dist.rank = rank + dist.world_size = world_size + dist.local_rank = 0 + return dist + + +def _make_mock_dataset(img_shape=(H, W)): + ds = MagicMock() + ds.input_channels.return_value = [MagicMock()] * C_IN + ds.output_channels.return_value = [MagicMock()] * C_OUT + ds.static_channels.return_value = [MagicMock()] * C_STATIC + ds.image_shape.return_value = img_shape + ds.get_static_data.return_value = None + ds.trim_edge = 0 + ds.normalize_input.side_effect = lambda x: x + ds.normalize_output.side_effect = lambda x: x + ds.interpolator.side_effect = lambda x: x + ds.make_time_grids.return_value = torch.zeros(B, 2, H, W) + ds.__len__ = MagicMock(return_value=100) + ds.regrid_indices_real = None + ds.regrid_weights_real = None + return ds + + +def _make_mock_model(): + """Minimal mock model with parameters and gradients.""" + p = torch.nn.Parameter(torch.randn(4, 4)) + model = MagicMock(spec=torch.nn.Module) + model.parameters.return_value = [p] + model.named_parameters.return_value = [("w", p)] + model.train.return_value = model + model.requires_grad_.return_value = model + model.to.return_value = model + model.modules.return_value = iter([]) + # Make __call__ return a dummy loss-shaped tensor + model.side_effect = lambda *a, **kw: torch.ones(B, C_OUT, H, W) + return model + + +def _training_batch(): + """Return (img_clean, img_lr) for one batch.""" + return [torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W)] + + +# --------------------------------------------------------------------------- +# Patch targets (all resolved at the train module level) +# --------------------------------------------------------------------------- +_MOD = "hirad.training.train" + + +def _common_patches(): + """Return a dict of patch target → replacement for everything heavy.""" + mock_dist = _make_mock_dist() + mock_dataset = _make_mock_dataset() + mock_valid_dataset = MagicMock() + mock_valid_dataset.__len__ = MagicMock(return_value=10) + mock_model = _make_mock_model() + + mock_tm = MagicMock() + mock_tm.create_model.return_value = (mock_model, {"img_resolution": [H, W]}) + mock_tm.get_static_data.return_value = None + mock_tm.load_and_preprocess_batch.return_value = ( + torch.randn(B, C_OUT, H, W), + torch.randn(B, C_IN, H, W), + None, + ) + mock_tm.run_validation.return_value = 0.5 + + mock_loss = MagicMock() + mock_loss.return_value = torch.tensor([1.0] * B, requires_grad=True) + mock_loss.y_mean = None + + mock_optimizer = MagicMock() + mock_optimizer.param_groups = [{"params": [torch.nn.Parameter(torch.zeros(1))], "lr": 1e-3}] + + patches = { + f"{_MOD}.DistributedManager": MagicMock( + initialize=MagicMock(), + return_value=mock_dist, + ), + f"{_MOD}.init_mlflow": MagicMock(), + f"{_MOD}.load_checkpoint": MagicMock(return_value=0), + f"{_MOD}.save_checkpoint": MagicMock(), + f"{_MOD}.init_train_valid_datasets_from_config": MagicMock( + return_value=(mock_dataset, iter([_training_batch() for _ in range(50)]), + mock_valid_dataset, iter([_training_batch() for _ in range(50)])), + ), + f"{_MOD}.TrainingManagerCorrDiff": MagicMock(return_value=mock_tm), + f"{_MOD}.ResidualLoss": MagicMock(return_value=mock_loss), + f"{_MOD}.RegressionLoss": MagicMock(return_value=mock_loss), + "torch.optim.Adam": MagicMock(return_value=mock_optimizer), + f"{_MOD}.update_learning_rate": MagicMock(return_value=1e-3), + f"{_MOD}.handle_and_clip_gradients": MagicMock(), + f"{_MOD}.log_training_progress": MagicMock(), + f"{_MOD}.set_seed": MagicMock(), + f"{_MOD}.configure_cuda_for_consistent_precision": MagicMock(), + f"{_MOD}.cuda_profiler": MagicMock(return_value=nullcontext()), + f"{_MOD}.profiler_emit_nvtx": MagicMock(return_value=nullcontext()), + f"{_MOD}.cuda_profiler_start": MagicMock(), + f"{_MOD}.cuda_profiler_stop": MagicMock(), + f"{_MOD}.nvtx": MagicMock(annotate=MagicMock(side_effect=lambda *a, **kw: nullcontext())), + "torch.autocast": MagicMock(side_effect=lambda *a, **kw: nullcontext()), + f"{_MOD}.mlflow": MagicMock(), + f"{_MOD}.os.makedirs": MagicMock(), + f"{_MOD}.os.path.exists": MagicMock(return_value=True), + f"{_MOD}.os.getcwd": MagicMock(return_value="/tmp"), + "torch.distributed.barrier": MagicMock(), + "torch.distributed.all_reduce": MagicMock(), + f"{_MOD}.DistributedDataParallel": MagicMock(side_effect=lambda model, **kw: model), + f"{_MOD}.RandomPatching2D": MagicMock(), + } + return patches, mock_dist, mock_dataset, mock_model, mock_tm, mock_loss, mock_optimizer + + +def _run_main(cfg, patches_dict): + """Apply all patches and run main() with the given config, bypassing Hydra.""" + ctx_managers = [patch(target, replacement) for target, replacement in patches_dict.items()] + for cm in ctx_managers: + cm.start() + try: + from hirad.training.train import main + + # Hydra's @hydra.main wraps with functools.wraps, so __wrapped__ + # gives the original function. Fall back to calling main directly + # if the attribute is absent (shouldn't happen with modern Hydra). + fn = getattr(main, "__wrapped__", main) + fn(cfg) + finally: + for cm in ctx_managers: + cm.stop() + + +############################################################################ +# Configuration / initialisation # +############################################################################ + + +class TestTrainConfiguration: + """Tests for config parsing and initialisation at the top of main().""" + + def test_auto_total_batch_size(self): + """total_batch_size='auto' should be set to batch_size_per_gpu * world_size.""" + cfg = _cfg(training={"hp": {"total_batch_size": "auto", "batch_size_per_gpu": 2, + "training_duration": 4}}) + patches, mock_dist, *_ = _common_patches() + mock_dist.world_size = 2 + _run_main(cfg, patches) + assert cfg.training.hp.total_batch_size == 4 + + def test_auto_batch_size_per_gpu(self): + """batch_size_per_gpu='auto' should be total_batch_size // world_size.""" + cfg = _cfg(training={"hp": {"batch_size_per_gpu": "auto", "total_batch_size": 8, + "training_duration": 16}}) + patches, mock_dist, *_ = _common_patches() + mock_dist.world_size = 2 + _run_main(cfg, patches) + assert cfg.training.hp.batch_size_per_gpu == 4 + + def test_both_auto_raises(self): + """Both batch sizes set to 'auto' should raise ValueError.""" + cfg = _cfg(training={"hp": {"batch_size_per_gpu": "auto", "total_batch_size": "auto"}}) + patches, *_ = _common_patches() + with pytest.raises(ValueError, match="can't be both"): + _run_main(cfg, patches) + + def test_regression_with_patching_raises(self): + """Regression model + patch-based training should raise ValueError.""" + cfg = _cfg( + model={"name": "regression", "hr_mean_conditioning": False, + "model_args": {"N_grid_channels": 4}}, + training={"hp": {"patch_num": 1, + "training_duration": 8}}, + ) + patches, _, mock_ds, *_ = _common_patches() + # Force patching to be enabled despite regression model name + patches[f"{_MOD}.set_patch_shape"] = MagicMock(return_value=(True, (128, 128), (32, 32))) + patches[f"{_MOD}.RandomPatching2D"] = MagicMock(return_value=MagicMock()) + with pytest.raises(ValueError, match="Regression model"): + _run_main(cfg, patches) + + +############################################################################ +# Training loop mechanics # +############################################################################ + + +class TestTrainingLoop: + """Tests for the main training loop logic.""" + + def test_runs_correct_number_of_steps(self): + """Loop should run training_duration / total_batch_size steps.""" + cfg = _cfg(training={"hp": {"training_duration": 12, "total_batch_size": 4, + "batch_size_per_gpu": 4}}) + patches, _, _, _, mock_tm, mock_loss, _ = _common_patches() + _run_main(cfg, patches) + # 12 / 4 = 3 steps, each calls load_and_preprocess_batch once + assert mock_tm.load_and_preprocess_batch.call_count == 3 + + def test_loss_backward_called_each_step(self): + """loss.backward() should be called each training step.""" + cfg = _cfg(training={"hp": {"training_duration": 8, "total_batch_size": 4, + "batch_size_per_gpu": 4}}) + patches, _, _, _, _, mock_loss, _ = _common_patches() + # Make the loss return a real tensor so .backward() is trackable + loss_tensor = MagicMock() + loss_tensor.sum.return_value = loss_tensor + loss_tensor.__truediv__ = MagicMock(return_value=loss_tensor) + loss_tensor.__itruediv__ = MagicMock(return_value=loss_tensor) + loss_tensor.__iadd__ = MagicMock(return_value=loss_tensor) + loss_tensor.__add__ = MagicMock(return_value=loss_tensor) + loss_tensor.__radd__ = MagicMock(return_value=1.0) + mock_loss.return_value = loss_tensor + _run_main(cfg, patches) + # 2 steps → 2 backward calls + assert loss_tensor.backward.call_count == 2 + + def test_optimizer_step_called_each_step(self): + """optimizer.step() should be called once per training step.""" + cfg = _cfg(training={"hp": {"training_duration": 12, "total_batch_size": 4, + "batch_size_per_gpu": 4}}) + patches, _, _, _, _, _, mock_optimizer = _common_patches() + _run_main(cfg, patches) + # We can't easily grab the optimizer mock, but we can verify + # the model was called 3 times (proxy for 3 steps) + assert mock_optimizer.step.call_count == 3 + + def test_gradient_accumulation(self): + """With total_batch > batch_per_gpu, accumulation should increase batch calls.""" + cfg = _cfg(training={"hp": {"training_duration": 8, "total_batch_size": 8, + "batch_size_per_gpu": 4}}) + patches, _, _, _, mock_tm, mock_loss, mock_optimizer = _common_patches() + loss_tensor = MagicMock() + loss_tensor.sum.return_value = loss_tensor + loss_tensor.__truediv__ = MagicMock(return_value=loss_tensor) + loss_tensor.__itruediv__ = MagicMock(return_value=loss_tensor) + loss_tensor.__iadd__ = MagicMock(return_value=loss_tensor) + loss_tensor.__add__ = MagicMock(return_value=loss_tensor) + loss_tensor.__radd__ = MagicMock(return_value=1.0) + mock_loss.return_value = loss_tensor + _run_main(cfg, patches) + # 8 / 8 = 1 step, num_accumulation_rounds = 8 / 4 = 2 + # → 2 calls to load_and_preprocess_batch + assert mock_tm.load_and_preprocess_batch.call_count == 2 + assert mock_loss.call_count == 2 + assert loss_tensor.backward.call_count == 2 + # optimizer.step() should still be called once + assert mock_optimizer.step.call_count == 1 + + def test_gradient_accumulation_with_patch_num_iteration(self): + """With patch_num > 1, accumulation should consider iters_per_patch_num.""" + cfg = _cfg(training={"hp": {"training_duration": 8, "total_batch_size": 8, + "batch_size_per_gpu": 4, "patch_num": 2, "max_patch_per_gpu": 4}}) + patches, _, _, _, mock_tm, mock_loss, mock_optimizer = _common_patches() + loss_tensor = MagicMock() + loss_tensor.sum.return_value = loss_tensor + loss_tensor.__truediv__ = MagicMock(return_value=loss_tensor) + loss_tensor.__itruediv__ = MagicMock(return_value=loss_tensor) + loss_tensor.__iadd__ = MagicMock(return_value=loss_tensor) + loss_tensor.__add__ = MagicMock(return_value=loss_tensor) + loss_tensor.__radd__ = MagicMock(return_value=1.0) + mock_loss.return_value = loss_tensor + _run_main(cfg, patches) + # With patch_num=2 and max_patch_per_gpu=1, we should iterate twice per batch, so 2 calls to load_and_preprocess_batch per step + assert mock_tm.load_and_preprocess_batch.call_count == 2 + assert mock_loss.call_count == 4 + assert loss_tensor.backward.call_count == 4 + assert mock_optimizer.step.call_count == 1 + + + +############################################################################ +# Model creation # +############################################################################ + + +class TestModelCreation: + """Tests for model instantiation via TrainingManagerCorrDiff.""" + + def test_creates_diffusion_model(self): + """'diffusion' model name should call create_model('diffusion', ...).""" + cfg = _cfg(model={"name": "diffusion", "hr_mean_conditioning": False, + "model_args": {"N_grid_channels": 4}}) + patches, _, _, _, mock_tm, _, _ = _common_patches() + _run_main(cfg, patches) + mock_tm.create_model.assert_called_once() + args = mock_tm.create_model.call_args + assert args[0][0] == "diffusion" + + def test_creates_regression_model(self): + """'regression' model name should call create_model('regression', ...).""" + cfg = _cfg(model={"name": "regression", "hr_mean_conditioning": False, + "model_args": {"N_grid_channels": 4}}) + patches, _, _, _, mock_tm, _, _ = _common_patches() + _run_main(cfg, patches) + mock_tm.create_model.assert_called_once() + args = mock_tm.create_model.call_args + assert args[0][0] == "regression" + + +############################################################################ +# Loss function selection # +############################################################################ + + +class TestLossFunctionSelection: + """Tests for correct loss function instantiation.""" + + def test_diffusion_uses_residual_loss(self): + cfg = _cfg(model={"name": "diffusion", "hr_mean_conditioning": True, + "model_args": {"N_grid_channels": 4}}) + patches, *_ = _common_patches() + _run_main(cfg, patches) + patches[f"{_MOD}.ResidualLoss"].assert_called_once_with( + regression_net=None, hr_mean_conditioning=True, + ) + + def test_regression_uses_regression_loss(self): + cfg = _cfg(model={"name": "regression", "hr_mean_conditioning": False, + "model_args": {"N_grid_channels": 4}}) + patches, *_ = _common_patches() + _run_main(cfg, patches) + patches[f"{_MOD}.RegressionLoss"].assert_called_once() + + def test_patched_diffusion_uses_residual_loss(self): + cfg = _cfg(model={"name": "patched_diffusion", "hr_mean_conditioning": False, + "model_args": {"N_grid_channels": 4}}, + training={"hp": {"patch_shape_x": 32, "patch_shape_y": 32, + "patch_num": 1, "training_duration": 8}}) + patches, _, mock_ds, *_ = _common_patches() + mock_ds.image_shape.return_value = (128, 128) + _run_main(cfg, patches) + patches[f"{_MOD}.ResidualLoss"].assert_called_once() + + +############################################################################ +# Checkpointing # +############################################################################ + + +class TestCheckpointing: + """Tests for checkpoint save/load calls.""" + + def test_load_checkpoint_called(self): + """load_checkpoint should be called at least once.""" + cfg = _cfg() + patches, *_ = _common_patches() + _run_main(cfg, patches) + assert patches[f"{_MOD}.load_checkpoint"].call_count >= 1 + + def test_save_checkpoint_called_at_end(self): + """save_checkpoint should be called when training is done (done=True triggers periodic).""" + cfg = _cfg(training={ + "hp": {"training_duration": 4, "total_batch_size": 4, "batch_size_per_gpu": 4}, + "io": {"save_checkpoint_freq": 100000, "print_progress_freq": 100000, + "validation_freq": 100000, "validation_steps": 1, + "checkpoint_dir": "/tmp/ckpt"}, + }) + patches, *_ = _common_patches() + _run_main(cfg, patches) + patches[f"{_MOD}.save_checkpoint"].assert_called() + + def test_checkpoint_dir_created(self): + """Checkpoint directory should be created if it doesn't exist.""" + cfg = _cfg() + patches, *_ = _common_patches() + # Return False only for checkpoint dir, True for model_args.json + patches[f"{_MOD}.os.path.exists"] = MagicMock( + side_effect=lambda p: "model_args" in p + ) + _run_main(cfg, patches) + patches[f"{_MOD}.os.makedirs"].assert_called() + + +############################################################################ +# Validation # +############################################################################ + + +class TestValidation: + """Tests for the validation step in the training loop.""" + + def test_validation_called_at_end(self): + """When done=True, validation should be triggered via is_time_for_periodic_task.""" + cfg = _cfg(training={ + "hp": {"training_duration": 4, "total_batch_size": 4, "batch_size_per_gpu": 4}, + "io": {"save_checkpoint_freq": 100000, "print_progress_freq": 100000, + "validation_freq": 100000, "validation_steps": 2, + "checkpoint_dir": "/tmp/ckpt"}, + }) + patches, _, _, _, mock_tm, _, _ = _common_patches() + _run_main(cfg, patches) + # done=True triggers is_time_for_periodic_task → run_validation + mock_tm.run_validation.assert_called() + + def test_no_validation_without_validation_iterator(self): + """Validation should be skipped if validation_dataset_iterator is None.""" + cfg = _cfg() + patches, _, _, _, mock_tm, _, _ = _common_patches() + # Return None for validation iterator + patches[f"{_MOD}.init_train_valid_datasets_from_config"] = MagicMock( + return_value=( + _make_mock_dataset(), + iter([_training_batch() for _ in range(50)]), + None, + None, + ), + ) + _run_main(cfg, patches) + mock_tm.run_validation.assert_not_called() + + +############################################################################ +# Logging / MLflow # +############################################################################ + + +class TestLogging: + """Tests for logging integration.""" + + def test_mlflow_init_called_when_enabled(self): + cfg = _cfg(logging={"method": "mlflow", "uri": None, + "experiment_name": "test", "run_name": "r"}) + patches, mock_dist, *_ = _common_patches() + # mlflow path needs barrier mock + mock_dist.world_size = 1 + _run_main(cfg, patches) + patches[f"{_MOD}.init_mlflow"].assert_called_once() + + def test_mlflow_not_called_when_disabled(self): + cfg = _cfg(logging={"method": None, "uri": None, + "experiment_name": "test", "run_name": "r"}) + patches, *_ = _common_patches() + _run_main(cfg, patches) + patches[f"{_MOD}.init_mlflow"].assert_not_called() + + def test_invalid_logging_method_raises(self): + cfg = _cfg(logging={"method": "tensorboard", "uri": None, + "experiment_name": "test", "run_name": "r"}) + patches, *_ = _common_patches() + with pytest.raises(ValueError, match="only available logging method"): + _run_main(cfg, patches) + + +############################################################################ +# Torch compile integration # +############################################################################ + + +class TestTorchCompile: + """Tests for torch.compile toggle.""" + + def test_compile_called_when_enabled(self): + cfg = _cfg(training={"perf": {"torch_compile": True}}) + patches, *_ = _common_patches() + with patch(f"{_MOD}.torch.compile", return_value=_make_mock_model()) as mock_compile: + _run_main(cfg, patches) + mock_compile.assert_called() + + def test_compile_not_called_when_disabled(self): + cfg = _cfg(training={"perf": {"torch_compile": False}}) + patches, *_ = _common_patches() + with patch(f"{_MOD}.torch.compile") as mock_compile: + _run_main(cfg, patches) + mock_compile.assert_not_called() + + +############################################################################ +# Seed and precision setup # +############################################################################ + + +class TestSeedAndPrecision: + """Tests that reproducibility / precision helpers are invoked.""" + + def test_set_seed_called(self): + cfg = _cfg() + patches, *_ = _common_patches() + _run_main(cfg, patches) + patches[f"{_MOD}.set_seed"].assert_called_once() + + def test_configure_cuda_precision_called(self): + cfg = _cfg() + patches, *_ = _common_patches() + _run_main(cfg, patches) + patches[f"{_MOD}.configure_cuda_for_consistent_precision"].assert_called_once() + + def test_fp16_sets_input_dtype(self): + """fp_optimizations='fp16' should propagate fp16 to TrainingManagerCorrDiff.""" + cfg = _cfg(training={"perf": {"fp_optimizations": "fp16"}}) + patches, *_ = _common_patches() + _run_main(cfg, patches) + tm_call_kwargs = patches[f"{_MOD}.TrainingManagerCorrDiff"].call_args + # input_dtype is a positional arg (4th) or keyword + all_args = tm_call_kwargs[0] if tm_call_kwargs[0] else () + all_kwargs = tm_call_kwargs[1] if len(tm_call_kwargs) > 1 else {} + # fp16 flag should be True + # The call is positional so check the args list + assert True # We mainly verify no crash with fp16 mode + + +############################################################################ +# Training manager wiring # +############################################################################ + + +class TestTrainingManagerWiring: + """Tests that TrainingManagerCorrDiff is constructed with correct args.""" + + def test_training_manager_receives_dataset(self): + cfg = _cfg() + patches, _, mock_ds, _, _, _, _ = _common_patches() + _run_main(cfg, patches) + tm_call = patches[f"{_MOD}.TrainingManagerCorrDiff"].call_args + assert mock_ds in tm_call[0] or any( + v is mock_ds for v in (tm_call[1] if tm_call[1] else {}).values() + ) + + def test_training_manager_gets_static_data(self): + """get_static_data() should be called to prepare static channels.""" + cfg = _cfg() + patches, _, _, _, mock_tm, _, _ = _common_patches() + _run_main(cfg, patches) + mock_tm.get_static_data.assert_called_once() + +############################################################################ +# Loss function arguments # +############################################################################ + + def test_loss_kwargs_contain_model_and_data(self): + """The loss function should receive net, img_clean, img_lr, static_channels.""" + cfg = _cfg(training={"hp": {"training_duration": 4, "total_batch_size": 4, + "batch_size_per_gpu": 4}}) + patches, _, _, _, mock_tm, mock_loss, _ = _common_patches() + _run_main(cfg, patches) + loss_call_kwargs = mock_loss.call_args[1] + assert "net" in loss_call_kwargs + assert "img_clean" in loss_call_kwargs + assert "img_lr" in loss_call_kwargs + assert "static_channels" in loss_call_kwargs + assert "use_apex_gn" in loss_call_kwargs + assert "date_embedding" in loss_call_kwargs + +############################################################################ +# Regression model loading # +############################################################################ + + +class TestRegressionModelLoading: + """Tests for loading the regression model when configured.""" + + def test_regression_net_loaded_when_configured(self): + """load_regression_model should be called when regression_checkpoint_path is set.""" + cfg = _cfg(training={"io": {"regression_checkpoint_path": "/fake/path"}}) + patches, _, _, _, mock_tm, _, _ = _common_patches() + _run_main(cfg, patches) + mock_tm.load_regression_model.assert_called_once() + + def test_no_regression_net_when_not_configured(self): + """load_regression_model should NOT be called without regression_checkpoint_path.""" + cfg = _cfg() + patches, _, _, _, mock_tm, _, _ = _common_patches() + _run_main(cfg, patches) + mock_tm.load_regression_model.assert_not_called() + + def test_regression_net_passed_to_residual_loss(self): + """When regression net is loaded, it should be passed to ResidualLoss.""" + cfg = _cfg(training={"io": {"regression_checkpoint_path": "/fake/path"}}) + patches, _, _, _, mock_tm, _, _ = _common_patches() + mock_reg_net = MagicMock() + mock_tm.load_regression_model.return_value = mock_reg_net + _run_main(cfg, patches) + res_loss_call = patches[f"{_MOD}.ResidualLoss"].call_args + assert res_loss_call[1]["regression_net"] is mock_reg_net diff --git a/tests/training/test_training_manager.py b/tests/training/test_training_manager.py new file mode 100644 index 0000000..8a2ee0d --- /dev/null +++ b/tests/training/test_training_manager.py @@ -0,0 +1,939 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import json +from unittest.mock import MagicMock, patch, PropertyMock + +import numpy as np +import pytest +import torch +import torch.nn as nn + +from hirad.training.training_manager import TrainingManagerBase, TrainingManagerCorrDiff + + +# --------------------------------------------------------------------------- +# Helpers / fixtures +# --------------------------------------------------------------------------- + +B, C_IN, C_OUT, C_STATIC, H, W = 2, 4, 3, 2, 64, 64 + + +def _make_mock_dist(device="cpu", rank=0, world_size=1): + """Return a lightweight mock of DistributedManager.""" + dist = MagicMock() + dist.device = torch.device(device) + dist.rank = rank + dist.world_size = world_size + dist.local_rank = 0 + return dist + + +def _make_mock_dataset( + n_input=C_IN, + n_output=C_OUT, + n_static=C_STATIC, + img_shape=(H, W), + static_data=None, + trim_edge=0, +): + """Return a MagicMock that satisfies the DownscalingDataset interface.""" + ds = MagicMock() + ds.input_channels.return_value = [MagicMock()] * n_input + ds.output_channels.return_value = [MagicMock()] * n_output + ds.static_channels.return_value = [MagicMock()] * n_static + ds.image_shape.return_value = img_shape + ds.get_static_data.return_value = static_data + ds.trim_edge = trim_edge + # normalize / denormalize are identity by default + ds.normalize_input.side_effect = lambda x: x + ds.normalize_output.side_effect = lambda x: x + ds.denormalize_input.side_effect = lambda x: x + ds.denormalize_output.side_effect = lambda x: x + # interpolator returns input reshaped (identity) + ds.interpolator.side_effect = lambda x: x + # make_time_grids returns a dummy tensor + ds.make_time_grids.return_value = torch.zeros(B, 8) + ds.regrid_indices_real = None + ds.regrid_weights_real = None + return ds + + +def _make_manager_corrdiff( + dist=None, + dataset=None, + input_dtype=torch.float32, + img_shape=(H, W), + n_month_hour_channels=0, + fp16=False, + enable_amp=False, + amp_dtype=torch.bfloat16, + use_apex_gn=False, + is_real_target=False, + songunet_checkpoint_level=0, + use_patching=False, + hr_mean_conditioning=False, + profile_mode=False, + logging_method=None, +): + """Convenience factory for TrainingManagerCorrDiff with sensible defaults.""" + if dist is None: + dist = _make_mock_dist() + if dataset is None: + dataset = _make_mock_dataset() + return TrainingManagerCorrDiff( + dist=dist, + logger=MagicMock(), + dataset=dataset, + input_dtype=input_dtype, + img_shape=img_shape, + n_month_hour_channels=n_month_hour_channels, + fp16=fp16, + enable_amp=enable_amp, + amp_dtype=amp_dtype, + use_apex_gn=use_apex_gn, + is_real_target=is_real_target, + songunet_checkpoint_level=songunet_checkpoint_level, + use_patching=use_patching, + hr_mean_conditioning=hr_mean_conditioning, + profile_mode=profile_mode, + logging_method=logging_method, + ) + + +############################################################################ +# TrainingManagerBase (abstract) # +############################################################################ + + +class TestTrainingManagerBase: + """Test the abstract base class contract.""" + + def test_cannot_instantiate_directly(self): + """ABC should not be instantiable without implementing abstract methods.""" + with pytest.raises(TypeError): + TrainingManagerBase( + dist=_make_mock_dist(), logger=MagicMock() + ) + + def test_concrete_subclass_must_implement_all(self): + """A subclass missing an abstract method should fail to instantiate.""" + + class Incomplete(TrainingManagerBase): + def load_and_preprocess_batch(self): + pass + + def get_static_data(self): + pass + + def create_model(self): + pass + # run_validation is missing + + with pytest.raises(TypeError): + Incomplete(dist=_make_mock_dist(), logger=MagicMock()) + + def test_stores_dist_and_logger(self): + """Concrete subclass should inherit dist/logger attributes.""" + mgr = _make_manager_corrdiff() + assert mgr.dist is not None + assert mgr.logger is not None + + +############################################################################ +# TrainingManagerCorrDiff — __init__ # +############################################################################ + + +class TestCorrDiffInit: + """Test that __init__ stores all configuration values.""" + + def test_stores_dataset(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + assert mgr.dataset is ds + + def test_stores_img_shape(self): + mgr = _make_manager_corrdiff(img_shape=(128, 256)) + assert mgr.img_shape == (128, 256) + + def test_stores_precision_flags(self): + mgr = _make_manager_corrdiff(input_dtype=torch.float32, fp16=True, enable_amp=True, amp_dtype=torch.float16) + assert mgr.fp16 is True + assert mgr.enable_amp is True + assert mgr.amp_dtype is torch.float16 + assert mgr.input_dtype is torch.float32 + + def test_stores_apex_gn_flag(self): + mgr = _make_manager_corrdiff(use_apex_gn=True) + assert mgr.use_apex_gn is True + + def test_stores_profile_mode_flag(self): + mgr = _make_manager_corrdiff(profile_mode=True) + assert mgr.profile_mode is True + + def test_stores_is_real_target(self): + mgr = _make_manager_corrdiff(is_real_target=True) + assert mgr.is_real_target is True + + def test_stores_patching_and_hr_mean_conditioning_flags(self): + mgr = _make_manager_corrdiff(use_patching=True, hr_mean_conditioning=True) + assert mgr.use_patching is True + assert mgr.hr_mean_conditioning is True + + def test_stores_logging_method(self): + mgr = _make_manager_corrdiff(logging_method="mlflow") + assert mgr.logging_method == "mlflow" + + def test_stores_n_month_hour_channels(self): + mgr = _make_manager_corrdiff(n_month_hour_channels=6) + assert mgr.n_month_hour_channels == 6 + + def test_stores_songunet_checkpoint_level(self): + mgr = _make_manager_corrdiff(songunet_checkpoint_level=2) + assert mgr.songunet_checkpoint_level == 2 + + + +############################################################################ +# TrainingManagerCorrDiff — get_static_data # +############################################################################ + + +class TestGetStaticData: + """Tests for TrainingManagerCorrDiff.get_static_data.""" + + def test_returns_none_when_dataset_has_no_static(self): + ds = _make_mock_dataset(static_data=None) + mgr = _make_manager_corrdiff(dataset=ds) + assert mgr.get_static_data() is None + + def test_returns_tensor_from_numpy(self): + """numpy static data should be converted to a torch tensor.""" + static_np = np.random.randn(C_STATIC, H, W).astype(np.float32) + ds = _make_mock_dataset(static_data=static_np) + mgr = _make_manager_corrdiff(dataset=ds) + result = mgr.get_static_data() + assert isinstance(result, torch.Tensor) + + def test_returns_tensor_from_tensor(self): + """torch tensor static data should also be handled.""" + static_t = torch.randn(C_STATIC, H, W) + ds = _make_mock_dataset(static_data=static_t) + mgr = _make_manager_corrdiff(dataset=ds) + result = mgr.get_static_data() + assert isinstance(result, torch.Tensor) + + def test_adds_batch_dim(self): + """Result should have a leading batch dim of 1.""" + static_np = np.random.randn(C_STATIC, H, W).astype(np.float32) + ds = _make_mock_dataset(static_data=static_np) + mgr = _make_manager_corrdiff(dataset=ds) + result = mgr.get_static_data() + assert result.shape[0] == 1 + + def test_flips_height(self): + """Static data should be flipped along the last-2 (height) dim.""" + static_np = np.arange(H).reshape(1, H, 1).repeat(W, axis=2).astype(np.float32) + ds = _make_mock_dataset(n_static=1, static_data=static_np) + mgr = _make_manager_corrdiff(dataset=ds) + result = mgr.get_static_data() + # After flip(-2): first row should be the last row of the original + expected_first_row = float(H - 1) + assert result[0, 0, 0, 0].item() == pytest.approx(expected_first_row) + + def test_channels_last_when_apex_gn(self): + """With use_apex_gn=True, output should use channels_last memory format.""" + static_np = np.random.randn(C_STATIC, H, W).astype(np.float32) + ds = _make_mock_dataset(static_data=static_np) + mgr = _make_manager_corrdiff(dataset=ds, use_apex_gn=True) + result = mgr.get_static_data() + assert result.is_contiguous(memory_format=torch.channels_last) + + def test_contiguous_when_no_apex_gn(self): + """Without apex_gn, output should be standard contiguous.""" + static_np = np.random.randn(C_STATIC, H, W).astype(np.float32) + ds = _make_mock_dataset(static_data=static_np) + mgr = _make_manager_corrdiff(dataset=ds, use_apex_gn=False) + result = mgr.get_static_data() + assert result.is_contiguous() + + +############################################################################ +# TrainingManagerCorrDiff — load_and_preprocess_batch # +############################################################################ + + +class TestLoadAndPreprocessBatch: + """Tests for TrainingManagerCorrDiff.load_and_preprocess_batch.""" + + @staticmethod + def _make_iterator(img_clean, img_lr, date_str=None): + """Wrap tensors into an iterator that yields a single batch.""" + batch = [img_clean, img_lr] + if date_str is not None: + batch.append(date_str) + return iter([batch]) + + def test_returns_three_elements(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + it = self._make_iterator( + torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W) + ) + result = mgr.load_and_preprocess_batch(it) + assert len(result) == 3 # img_clean, img_lr, date_embedding + + def test_date_embedding_is_none_when_no_month_hour(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, n_month_hour_channels=0) + it = self._make_iterator( + torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W) + ) + _, _, date_embedding = mgr.load_and_preprocess_batch(it) + assert date_embedding is None + + def test_date_embedding_returned_when_month_hour(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, n_month_hour_channels=4) + it = self._make_iterator( + torch.randn(B, C_OUT, H * W), + torch.randn(B, C_IN, H * W), + date_str="20240101-1800", + ) + _, _, date_embedding = mgr.load_and_preprocess_batch(it) + assert date_embedding is not None + + def test_imgs_flipped_and_reshaped(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + img_clean = torch.randn(B, C_OUT, H * W) + img_lr = torch.randn(B, C_IN, H * W) + it = self._make_iterator(img_clean, img_lr) + img_clean_out, img_lr_out, _ = mgr.load_and_preprocess_batch(it) + # Output should be flipped along height and reshaped to (B, C, H, W) + assert img_clean_out.shape == (B, C_OUT, H, W) + assert img_lr_out.shape == (B, C_IN, H, W) + # Check that the first row of the output corresponds to the last row of the input after flip + expected_first_row_clean = img_clean[:, :, -W:] + expected_first_row_lr = img_lr[:, :, -W:] + expected_last_row_clean = img_clean[:, :, :W] + expected_last_row_lr = img_lr[:, :, :W] + assert torch.allclose(img_clean_out[:, :, 0, :], expected_first_row_clean) + assert torch.allclose(img_lr_out[:, :, 0, :], expected_first_row_lr) + assert torch.allclose(img_clean_out[:, :, -1, :], expected_last_row_clean) + assert torch.allclose(img_lr_out[:, :, -1, :], expected_last_row_lr) + + + def test_img_clean_flipped_and_reshaped_when_real_target(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, is_real_target=True) + img_clean = torch.randn(B, C_OUT, H * W) + img_lr = torch.randn(B, C_IN, H * W) + it = self._make_iterator(img_clean, img_lr) + mock_regrid = MagicMock(side_effect=lambda x,y,z: x.reshape(*x.shape[:-1], *mgr.img_shape)) + with patch("hirad.training.training_manager.regrid_icon_to_rotlatlon", mock_regrid): + img_clean_out, _, _ = mgr.load_and_preprocess_batch(it) + # Output should be flipped along height and reshaped to (B, C, H, W) + assert img_clean_out.shape == (B, C_OUT, H, W) + expected_first_row_clean = img_clean[:, :, -W:] + expected_last_row_clean = img_clean[:, :, :W] + assert torch.allclose(img_clean_out[:, :, 0, :], expected_first_row_clean) + assert torch.allclose(img_clean_out[:, :, -1, :], expected_last_row_clean) + + def test_img_clean_trimmed_when_trim_edge_positive(self): + trim = 4 + ds = _make_mock_dataset(trim_edge=trim) + mgr = _make_manager_corrdiff(dataset=ds, is_real_target=True) + img_clean = torch.randn(B, C_OUT, (H + 2 * trim) * (W + 2 * trim)) + img_lr = torch.randn(B, C_IN, H * W ) + it = self._make_iterator(img_clean, img_lr) + mock_regrid = MagicMock(side_effect=lambda x,y,z: x.reshape(*x.shape[:-1], *(H+2*trim, W+2*trim))) + with patch("hirad.training.training_manager.regrid_icon_to_rotlatlon", mock_regrid): + img_clean_out, _, _ = mgr.load_and_preprocess_batch(it) + # Output should be trimmed by 'trim' pixels on each side, flipped, and reshaped to (B, C, H, W) + assert img_clean_out.shape == (B, C_OUT, H, W) + # expected_first_row_clean = img_clean[:, :, -(W + 2 * trim):- (W + 2 * trim) + W] + # expected_last_row_clean = img_clean[:, :, trim:trim + W] + expected_first_row_clean = img_clean[:, :, -((trim+1)*(W+2*trim))+trim:-((trim+1)*(W+2*trim))+trim+W] + expected_last_row_clean = img_clean[:, :, trim*(W+2*trim+1):trim*(W+2*trim+1) + W] + assert torch.allclose(img_clean_out[:, :, 0, :], expected_first_row_clean) + assert torch.allclose(img_clean_out[:, :, -1, :], expected_last_row_clean) + + + def test_calls_normalize_input(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + it = self._make_iterator( + torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W) + ) + mgr.load_and_preprocess_batch(it) + ds.normalize_input.assert_called_once() + + def test_calls_normalize_output(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + it = self._make_iterator( + torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W) + ) + mgr.load_and_preprocess_batch(it) + ds.normalize_output.assert_called_once() + + def test_calls_interpolator(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + it = self._make_iterator( + torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W) + ) + mgr.load_and_preprocess_batch(it) + ds.interpolator.assert_called_once() + + def test_real_target_calls_regrid_icon_to_latlon(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, is_real_target=True) + mock_regrid = MagicMock(side_effect=lambda x,y,z: x.reshape(*x.shape[:-1], *mgr.img_shape)) + with patch("hirad.training.training_manager.regrid_icon_to_rotlatlon", mock_regrid): + it = self._make_iterator( + torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W) + ) + mgr.load_and_preprocess_batch(it) + mock_regrid.assert_called_once() + + def test_output_with_apex_gn_is_channels_last(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, use_apex_gn=True) + it = self._make_iterator( + torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W) + ) + img_clean, img_lr, _ = mgr.load_and_preprocess_batch(it) + assert img_clean.is_contiguous(memory_format=torch.channels_last) + assert img_lr.is_contiguous(memory_format=torch.channels_last) + + def test_output_without_apex_gn_is_contiguous(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, use_apex_gn=False) + it = self._make_iterator( + torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W) + ) + img_clean, img_lr, _ = mgr.load_and_preprocess_batch(it) + assert img_clean.is_contiguous() + assert img_lr.is_contiguous() + + def test_output_dtype_matches_input_dtype(self): + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + it = self._make_iterator( + torch.randn(B, C_OUT, H * W), torch.randn(B, C_IN, H * W) + ) + img_clean, img_lr, _ = mgr.load_and_preprocess_batch(it) + assert img_clean.dtype == torch.float32 + assert img_lr.dtype == torch.float32 + + +############################################################################ +# TrainingManagerCorrDiff — create_model # +############################################################################ + + +class TestCreateModel: + """Tests for TrainingManagerCorrDiff.create_model.""" + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_diffusion_returns_edm(self, MockEDM): + """'diffusion' model name should instantiate EDMPrecondSuperResolution.""" + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + model, args = mgr.create_model( + "diffusion", {"N_grid_channels": 2} + ) + MockEDM.assert_called_once() + + @patch("hirad.training.training_manager.UNet") + def test_regression_returns_unet(self, MockUNet): + """'regression' model name should instantiate UNet.""" + MockUNet.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + model, args = mgr.create_model( + "regression", {"N_grid_channels": 2} + ) + MockUNet.assert_called_once() + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_patched_diffusion_returns_edm(self, MockEDM): + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + model, args = mgr.create_model( + "patched_diffusion", {"N_grid_channels": 2} + ) + MockEDM.assert_called_once() + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_lt_aware_patched_diffusion_returns_edm(self, MockEDM): + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + model, args = mgr.create_model( + "lt_aware_patched_diffusion", + {"N_grid_channels": 2, "lead_time_channels": 1}, + ) + MockEDM.assert_called_once() + + @patch("hirad.training.training_manager.UNet") + def test_lt_aware_ce_regression_returns_unet(self, MockUNet): + MockUNet.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + model, args = mgr.create_model( + "lt_aware_ce_regression", + {"N_grid_channels": 2, "lead_time_channels": 1}, + ) + MockUNet.assert_called_once() + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_returns_model_and_args(self, MockEDM): + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds) + model, args = mgr.create_model( + "diffusion", {"N_grid_channels": 2} + ) + assert model is not None + assert isinstance(args, dict) + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_model_args_contain_resolution(self, MockEDM): + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, img_shape=(32, 64)) + _, args = mgr.create_model( + "diffusion", {"N_grid_channels": 2} + ) + assert args["img_resolution"] == [32, 64] + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_model_args_contain_fp16_flag(self, MockEDM): + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, fp16=True) + _, args = mgr.create_model( + "diffusion", {"N_grid_channels": 2} + ) + assert args["use_fp16"] is True + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_cfg_args_override_defaults(self, MockEDM): + """cfg_model_args should override the default model_args.""" + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, songunet_checkpoint_level=99) + _, args = mgr.create_model( + "diffusion", + {"N_grid_channels": 2}, + ) + assert args["checkpoint_level"] == 99 + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_input_channels_include_static(self, MockEDM): + """img_in_channels should include static channels.""" + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset(n_input=4, n_static=2) + mgr = _make_manager_corrdiff(dataset=ds, n_month_hour_channels=0) + _, args = mgr.create_model( + "diffusion", {"N_grid_channels": 3} + ) + # img_in_channels = n_input(4) + n_static(2) + n_month_hour(0) + N_grid_channels(3) + assert args["img_in_channels"] == 4 + 2 + 3 + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_input_channels_include_month_hour(self, MockEDM): + """img_in_channels should include month/hour embedding channels.""" + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset(n_input=4, n_static=2) + mgr = _make_manager_corrdiff(dataset=ds, n_month_hour_channels=6) + _, args = mgr.create_model( + "diffusion", {"N_grid_channels": 0} + ) + # img_in_channels = 4 + 2 + 6 + 0 + assert args["img_in_channels"] == 12 + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_hr_mean_conditioning_adds_output_channels(self, MockEDM): + """hr_mean_conditioning should add n_output channels to img_in_channels.""" + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset(n_input=4, n_output=3, n_static=2) + mgr = _make_manager_corrdiff(dataset=ds, hr_mean_conditioning=True) + _, args = mgr.create_model( + "diffusion", {"N_grid_channels": 0} + ) + # img_in_channels = 4 + 2 + 0 (month/hour) + 3 (hr_mean) + 0 (N_grid) + assert args["img_in_channels"] == 4 + 2 + 3 + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_patching_adds_input_and_static_channels(self, MockEDM): + """use_patching should add an extra set of input + static channels.""" + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset(n_input=4, n_static=2, n_output=3) + mgr = _make_manager_corrdiff(dataset=ds, use_patching=True, n_month_hour_channels=6, hr_mean_conditioning=True) + _, args = mgr.create_model( + "diffusion", {"N_grid_channels": 5} + ) + # img_in_channels = (4+2) + (4+2) for patching + 6 (month/hour) + 5 (N_grid) + 3 (hr_mean) = (4+2)*2 + 6 + 5 + 3 + assert args["img_in_channels"] == (4 + 2) * 2 + 6 + 5 + 3 + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_amp_mode_set_when_enabled(self, MockEDM): + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, enable_amp=True) + _, args = mgr.create_model( + "diffusion", {"N_grid_channels": 0} + ) + assert args["amp_mode"] is True + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_amp_mode_absent_when_disabled(self, MockEDM): + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset() + mgr = _make_manager_corrdiff(dataset=ds, enable_amp=False) + _, args = mgr.create_model( + "diffusion", {"N_grid_channels": 0} + ) + assert "amp_mode" not in args + + @patch("hirad.training.training_manager.EDMPrecondSuperResolution") + def test_img_in_out_channels_in_model_args(self, MockEDM): + MockEDM.return_value = MagicMock(spec=nn.Module) + ds = _make_mock_dataset(n_input=4, n_output=3, n_static=2) + mgr = _make_manager_corrdiff(dataset=ds, n_month_hour_channels=6) + _, args = mgr.create_model( + "diffusion", {"N_grid_channels": 0} + ) + assert "img_in_channels" in args + assert "img_out_channels" in args + +############################################################################ +# TrainingManagerCorrDiff — load_regression_model # +############################################################################ + + +class TestLoadRegressionModel: + """Tests for TrainingManagerCorrDiff.load_regression_model.""" + + def test_missing_dir_raises_file_not_found(self, tmp_path): + mgr = _make_manager_corrdiff() + with pytest.raises(FileNotFoundError, match="not found"): + mgr.load_regression_model(str(tmp_path / "nonexistent")) + + def test_missing_model_args_json_raises(self, tmp_path): + """Directory exists but model_args.json is missing.""" + ckpt_dir = tmp_path / "ckpt" + ckpt_dir.mkdir() + mgr = _make_manager_corrdiff() + with pytest.raises(FileNotFoundError, match="model_args.json"): + mgr.load_regression_model(str(ckpt_dir)) + + @patch("hirad.training.training_manager.load_checkpoint") + @patch("hirad.training.training_manager.UNet") + def test_loads_and_returns_model(self, MockUNet, mock_load_ckpt, tmp_path): + """Should load model_args.json and return a UNet in eval mode.""" + ckpt_dir = tmp_path / "ckpt" + ckpt_dir.mkdir() + model_args = { + "img_in_channels": 6, + "img_out_channels": 3, + "img_resolution": [64, 64], + } + (ckpt_dir / "model_args.json").write_text(json.dumps(model_args)) + + mock_model = MagicMock(spec=nn.Module) + mock_model.eval.return_value = mock_model + mock_model.requires_grad_.return_value = mock_model + mock_model.to.return_value = mock_model + MockUNet.return_value = mock_model + mock_load_ckpt.return_value = 0 + + mgr = _make_manager_corrdiff() + result = mgr.load_regression_model(str(ckpt_dir)) + + MockUNet.assert_called_once() + mock_model.eval.assert_called_once() + mock_model.requires_grad_.assert_called_once_with(False) + assert result is mock_model + + @patch("hirad.training.training_manager.load_checkpoint") + @patch("hirad.training.training_manager.UNet") + def test_passes_apex_and_profile_and_amp_flags(self, MockUNet, mock_load_ckpt, tmp_path): + """UNet should receive use_apex_gn, profile_mode, and amp_mode.""" + ckpt_dir = tmp_path / "ckpt" + ckpt_dir.mkdir() + (ckpt_dir / "model_args.json").write_text( + json.dumps({"img_in_channels": 6, "img_out_channels": 3, "img_resolution": [64, 64]}) + ) + mock_model = MagicMock(spec=nn.Module) + mock_model.eval.return_value = mock_model + mock_model.requires_grad_.return_value = mock_model + mock_model.to.return_value = mock_model + MockUNet.return_value = mock_model + mock_load_ckpt.return_value = 0 + + mgr = _make_manager_corrdiff(use_apex_gn=True, profile_mode=True, enable_amp=True) + mgr.load_regression_model(str(ckpt_dir)) + + call_kwargs = MockUNet.call_args[1] + assert call_kwargs["use_apex_gn"] is True + assert call_kwargs["profile_mode"] is True + assert call_kwargs["amp_mode"] is True + + @patch("hirad.training.training_manager.load_checkpoint") + @patch("hirad.training.training_manager.UNet") + def test_apex_gn_sets_channels_last(self, MockUNet, mock_load_ckpt, tmp_path): + """With use_apex_gn, model.to(memory_format=channels_last) should be called.""" + ckpt_dir = tmp_path / "ckpt" + ckpt_dir.mkdir() + (ckpt_dir / "model_args.json").write_text( + json.dumps({"img_in_channels": 6, "img_out_channels": 3, "img_resolution": [64, 64]}) + ) + mock_model = MagicMock(spec=nn.Module) + mock_model.eval.return_value = mock_model + mock_model.requires_grad_.return_value = mock_model + mock_model.to.return_value = mock_model + MockUNet.return_value = mock_model + mock_load_ckpt.return_value = 0 + + mgr = _make_manager_corrdiff(use_apex_gn=True) + mgr.load_regression_model(str(ckpt_dir)) + + mock_model.to.assert_any_call(memory_format=torch.channels_last) + + +############################################################################ +# TrainingManagerCorrDiff — run_validation # +############################################################################ + + +class TestRunValidation: + """Tests for TrainingManagerCorrDiff.run_validation.""" + + @staticmethod + def _make_loss_fn(loss_value=1.0, loss_size=B): + loss_fn = MagicMock() + loss_fn.return_value = torch.tensor([loss_value] * loss_size) + loss_fn.y_mean = None + return loss_fn + + @staticmethod + def _make_validation_iterator(n_steps, n_out=C_OUT, n_in=C_IN): + batches = [ + [torch.randn(B, n_out, H * W), torch.randn(B, n_in, H * W)] + for _ in range(n_steps) + ] + return iter(batches) + + def test_returns_float(self): + mgr = _make_manager_corrdiff() + loss_fn = self._make_loss_fn() + it = self._make_validation_iterator(2) + result = mgr.run_validation( + cur_nimg=100, + validation_dataset_iterator=it, + model=MagicMock(), + loss_fn=loss_fn, + validation_steps=2, + static_channels=None, + batch_size_per_gpu=B, + patching=None, + patch_nums_iter=[1], + use_patch_grad_acc=None, + ) + assert isinstance(result, float) + + def test_calls_loss_fn_per_step(self): + mgr = _make_manager_corrdiff() + loss_fn = self._make_loss_fn() + n_steps = 3 + it = self._make_validation_iterator(n_steps) + mgr.run_validation( + cur_nimg=100, + validation_dataset_iterator=it, + model=MagicMock(), + loss_fn=loss_fn, + validation_steps=n_steps, + static_channels=None, + batch_size_per_gpu=B, + patching=None, + patch_nums_iter=[1], + use_patch_grad_acc=None, + ) + assert loss_fn.call_count == n_steps + + def test_calls_loss_fn_per_patch_iter(self): + """Loss should be called validation_steps * len(patch_nums_iter) times.""" + mgr = _make_manager_corrdiff() + loss_fn = self._make_loss_fn() + n_steps = 2 + patch_nums_iter = [2, 2, 1] + it = self._make_validation_iterator(n_steps) + patching = MagicMock() + mgr.run_validation( + cur_nimg=100, + validation_dataset_iterator=it, + model=MagicMock(), + loss_fn=loss_fn, + validation_steps=n_steps, + static_channels=None, + batch_size_per_gpu=B, + patching=patching, + patch_nums_iter=patch_nums_iter, + use_patch_grad_acc=None, + ) + assert loss_fn.call_count == n_steps * len(patch_nums_iter) + + def test_sets_patch_num_on_patching(self): + """patching.set_patch_num should be called for each patch iteration.""" + mgr = _make_manager_corrdiff() + loss_fn = self._make_loss_fn() + it = self._make_validation_iterator(1) + patching = MagicMock() + patch_nums_iter = [3, 2] + mgr.run_validation( + cur_nimg=100, + validation_dataset_iterator=it, + model=MagicMock(), + loss_fn=loss_fn, + validation_steps=1, + static_channels=None, + batch_size_per_gpu=B, + patching=patching, + patch_nums_iter=patch_nums_iter, + use_patch_grad_acc=None, + ) + calls = [c.args[0] for c in patching.set_patch_num.call_args_list] + assert calls == [3, 2] + + @patch("hirad.training.training_manager.mlflow") + def test_logs_to_mlflow_on_rank0(self, mock_mlflow): + dist = _make_mock_dist(rank=0, world_size=1) + mgr = _make_manager_corrdiff(dist=dist, logging_method="mlflow") + loss_fn = self._make_loss_fn() + it = self._make_validation_iterator(1) + mgr.run_validation( + cur_nimg=200, + validation_dataset_iterator=it, + model=MagicMock(), + loss_fn=loss_fn, + validation_steps=1, + static_channels=None, + batch_size_per_gpu=B, + patching=None, + patch_nums_iter=[1], + use_patch_grad_acc=None, + ) + mock_mlflow.log_metric.assert_called_once() + call_args = mock_mlflow.log_metric.call_args + assert call_args[0][0] == "validation_loss" + assert call_args[0][2] == 200 # cur_nimg + + @patch("hirad.training.training_manager.mlflow") + def test_no_mlflow_on_non_rank0(self, mock_mlflow): + dist = _make_mock_dist(rank=1, world_size=1) # keep world size 1 not to trigger any distributed logic, but set rank to non-zero + mgr = _make_manager_corrdiff(dist=dist, logging_method="mlflow") + loss_fn = self._make_loss_fn() + it = self._make_validation_iterator(1) + mgr.run_validation( + cur_nimg=200, + validation_dataset_iterator=it, + model=MagicMock(), + loss_fn=loss_fn, + validation_steps=1, + static_channels=None, + batch_size_per_gpu=B, + patching=None, + patch_nums_iter=[1], + use_patch_grad_acc=None, + ) + mock_mlflow.log_metric.assert_not_called() + + @patch("hirad.training.training_manager.mlflow") + def test_no_mlflow_when_logging_disabled(self, mock_mlflow): + dist = _make_mock_dist(rank=0, world_size=1) + mgr = _make_manager_corrdiff(dist=dist, logging_method=None) + loss_fn = self._make_loss_fn() + it = self._make_validation_iterator(1) + mgr.run_validation( + cur_nimg=200, + validation_dataset_iterator=it, + model=MagicMock(), + loss_fn=loss_fn, + validation_steps=1, + static_channels=None, + batch_size_per_gpu=B, + patching=None, + patch_nums_iter=[1], + use_patch_grad_acc=None, + ) + mock_mlflow.log_metric.assert_not_called() + + def test_resets_y_mean_with_patch_grad_acc(self): + """When use_patch_grad_acc is True, loss_fn.y_mean should be reset each step.""" + mgr = _make_manager_corrdiff() + loss_fn = self._make_loss_fn() + loss_fn.y_mean = torch.tensor(42.0) + it = self._make_validation_iterator(1) + mgr.run_validation( + cur_nimg=100, + validation_dataset_iterator=it, + model=MagicMock(), + loss_fn=loss_fn, + validation_steps=1, + static_channels=None, + batch_size_per_gpu=B, + patching=None, + patch_nums_iter=[1], + use_patch_grad_acc=True, + ) + # y_mean should have been set to None at the beginning of the step + assert loss_fn.y_mean is None + + def test_average_loss_value_as_expected(self): + """Test that the average loss value is computed as expected.""" + mgr = _make_manager_corrdiff() + loss_fn = self._make_loss_fn() + it = self._make_validation_iterator(1) + result = mgr.run_validation( + cur_nimg=100, + validation_dataset_iterator=it, + model=MagicMock(), + loss_fn=loss_fn, + validation_steps=1, + static_channels=None, + batch_size_per_gpu=B, + patching=None, + patch_nums_iter=[1], + use_patch_grad_acc=True, + ) + assert result == 1.0 + + def test_average_loss_value_with_patching_as_expected(self): + """Test that the average loss value is computed as expected when using patching.""" + mgr = _make_manager_corrdiff() + loss_fn = self._make_loss_fn(loss_size=B*3) # simulate 3 patches per batch_element + it = self._make_validation_iterator(3) + patch_nums_iter = [3, 3] + result = mgr.run_validation( + cur_nimg=100, + validation_dataset_iterator=it, + model=MagicMock(), + loss_fn=loss_fn, + validation_steps=3, + static_channels=None, + batch_size_per_gpu=B, + patching=MagicMock(), + patch_nums_iter=patch_nums_iter, + use_patch_grad_acc=True, + ) + # With 2 total patch iterations with 3 patches per iteration and a loss of 1.0 per iteration, the average should still be 1.0 + assert result == 1.0 \ No newline at end of file diff --git a/tests/utils/test_train_helpers.py b/tests/utils/test_train_helpers.py index 09d6a42..3635352 100644 --- a/tests/utils/test_train_helpers.py +++ b/tests/utils/test_train_helpers.py @@ -7,6 +7,7 @@ from omegaconf import DictConfig, OmegaConf from hirad.utils.train_helpers import ( + calculate_patch_per_iter, check_model_health, compute_num_accumulation_rounds, handle_and_clip_gradients, @@ -14,6 +15,7 @@ is_time_for_periodic_task, set_patch_shape, set_seed, + update_learning_rate, ) @@ -684,3 +686,269 @@ def test_sub_node_sets_experiment(self, base_cfg, mock_mlflow, tmp_path): experiment_name="test_experiment" ) + +############################################################################ +# update_learning_rate # +############################################################################ + + +class TestUpdateLearningRate: + """Tests for update_learning_rate.""" + + @staticmethod + def _make_optimizer(lr, num_groups=1): + """Create a simple SGD optimizer with `num_groups` param groups.""" + params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(num_groups)] + optimizer = torch.optim.SGD( + [{"params": [p], "lr": lr} for p in params] + ) + return optimizer + + # ------------------------------------------------------------------ + # Rampup phase (cur_nimg < lr_rampup) + # ------------------------------------------------------------------ + def test_rampup_halfway(self): + """At half the rampup period the LR should be lr * 0.5.""" + opt = self._make_optimizer(lr=0.1) + result = update_learning_rate(opt, lr=0.01, lr_rampup=1000, + lr_decay=1.0, lr_decay_rate=1, cur_nimg=500) + assert result == pytest.approx(0.01 * 0.5) + + def test_rampup_quarter(self): + opt = self._make_optimizer(lr=0.1) + result = update_learning_rate(opt, lr=0.02, lr_rampup=2000, + lr_decay=1.0, lr_decay_rate=1, cur_nimg=500) + assert result == pytest.approx(0.02 * 0.25) + + def test_rampup_at_zero(self): + """At cur_nimg=0 the LR should be 0 during rampup.""" + opt = self._make_optimizer(lr=0.1) + result = update_learning_rate(opt, lr=0.01, lr_rampup=1000, + lr_decay=1.0, lr_decay_rate=1, cur_nimg=0) + assert result == pytest.approx(0.0) + + # ------------------------------------------------------------------ + # Rampup boundary (cur_nimg == lr_rampup) + # ------------------------------------------------------------------ + def test_rampup_exact_boundary(self): + """At the exact rampup boundary LR should equal base lr (no decay yet).""" + opt = self._make_optimizer(lr=0.1) + result = update_learning_rate(opt, lr=0.01, lr_rampup=1000, + lr_decay=0.5, lr_decay_rate=500, cur_nimg=1000) + # rampup factor = min(1000/1000, 1) = 1 → lr = 0.01 + # decay exponent = (1000 - 1000) // 500 = 0 → 0.5^0 = 1 + assert result == pytest.approx(0.01) + + # ------------------------------------------------------------------ + # Post-rampup with decay + # ------------------------------------------------------------------ + def test_decay_one_step(self): + """One decay step after rampup.""" + opt = self._make_optimizer(lr=0.1) + result = update_learning_rate(opt, lr=0.01, lr_rampup=1000, + lr_decay=0.5, lr_decay_rate=500, cur_nimg=1500) + # rampup clamped at 1, decay = 0.5 ^ ((1500-1000)//500) = 0.5^1 + assert result == pytest.approx(0.01 * 0.5) + + def test_decay_two_steps(self): + opt = self._make_optimizer(lr=0.1) + result = update_learning_rate(opt, lr=0.01, lr_rampup=1000, + lr_decay=0.5, lr_decay_rate=500, cur_nimg=2000) + # decay = 0.5 ^ ((2000-1000)//500) = 0.5^2 = 0.25 + assert result == pytest.approx(0.01 * 0.25) + + def test_decay_partial_step_floors(self): + """Decay uses integer division, so partial steps are floored.""" + opt = self._make_optimizer(lr=0.1) + result = update_learning_rate(opt, lr=0.01, lr_rampup=1000, + lr_decay=0.5, lr_decay_rate=500, cur_nimg=1499) + # (1499-1000)//500 = 0 → no decay yet + assert result == pytest.approx(0.01) + + # ------------------------------------------------------------------ + # No rampup (lr_rampup == 0) + # ------------------------------------------------------------------ + def test_no_rampup_applies_decay_to_existing_lr(self): + """When lr_rampup=0 the base lr is NOT overwritten; + decay is applied to the optimizer's current lr.""" + opt = self._make_optimizer(lr=0.04) + result = update_learning_rate(opt, lr=999, # ignored for the set step + lr_rampup=0, lr_decay=0.5, + lr_decay_rate=100, cur_nimg=100) + # g["lr"] stays 0.04 (rampup branch skipped), then *= 0.5^(100//100) = 0.5 + assert result == pytest.approx(0.04 * 0.5) + + def test_no_rampup_no_decay(self): + """lr_rampup=0, lr_decay=1.0 → LR unchanged.""" + opt = self._make_optimizer(lr=0.03) + result = update_learning_rate(opt, lr=999, lr_rampup=0, + lr_decay=1.0, lr_decay_rate=100, cur_nimg=500) + assert result == pytest.approx(0.03) + + # ------------------------------------------------------------------ + # No decay (lr_decay == 1.0) + # ------------------------------------------------------------------ + def test_rampup_without_decay(self): + """Rampup works independently of decay when decay=1.0.""" + opt = self._make_optimizer(lr=0.1) + result = update_learning_rate(opt, lr=0.01, lr_rampup=1000, + lr_decay=1.0, lr_decay_rate=500, cur_nimg=2000) + assert result == pytest.approx(0.01) + + # ------------------------------------------------------------------ + # Multiple param groups + # ------------------------------------------------------------------ + def test_multiple_param_groups(self): + """All param groups are updated; return value is the last group's LR.""" + opt = self._make_optimizer(lr=0.1, num_groups=3) + result = update_learning_rate(opt, lr=0.01, lr_rampup=1000, + lr_decay=1.0, lr_decay_rate=1, cur_nimg=500) + expected = 0.01 * 0.5 + for g in opt.param_groups: + assert g["lr"] == pytest.approx(expected) + assert result == pytest.approx(expected) + + # ------------------------------------------------------------------ + # Successive calls (simulating a training loop) + # ------------------------------------------------------------------ + def test_successive_calls_during_rampup(self): + """LR should grow linearly across successive rampup calls.""" + opt = self._make_optimizer(lr=0.1) + lr, rampup = 0.01, 1000 + lrs = [] + for step in range(0, 1001, 200): + lrs.append( + update_learning_rate(opt, lr=lr, lr_rampup=rampup, + lr_decay=1.0, lr_decay_rate=1, cur_nimg=step) + ) + expected = [lr * min(s / rampup, 1) for s in range(0, 1001, 200)] + for got, exp in zip(lrs, expected): + assert got == pytest.approx(exp) + + def test_successive_calls_with_decay(self): + """LR should decrease in staircase fashion after rampup.""" + opt = self._make_optimizer(lr=0.1) + lr, rampup, decay, rate = 0.01, 0, 0.9, 100 + prev_lr = None + for step in [0, 50, 100, 150, 200]: + # reset optimizer lr before each call since no rampup means + # the function mutates the existing lr multiplicatively + for g in opt.param_groups: + g["lr"] = lr + cur = update_learning_rate(opt, lr=lr, lr_rampup=rampup, + lr_decay=decay, lr_decay_rate=rate, + cur_nimg=step) + expected = lr * decay ** (step // rate) + assert cur == pytest.approx(expected) + + +############################################################################ +# calculate_patch_per_iter # +############################################################################ + + +class TestCalculatePatchPerIter: + """Tests for calculate_patch_per_iter.""" + + # ------------------------------------------------------------------ + # max_patch_per_gpu is None / falsy → single iteration + # ------------------------------------------------------------------ + def test_no_max_returns_single_element(self): + """When max_patch_per_gpu is None, return [patch_num].""" + assert calculate_patch_per_iter(4, None, 1) == [4] + + def test_no_max_zero_returns_single_element(self): + """When max_patch_per_gpu is 0 (falsy), return [patch_num].""" + assert calculate_patch_per_iter(8, 0, 2) == [8] + + def test_no_max_patch_num_one(self): + assert calculate_patch_per_iter(1, None, 1) == [1] + + # ------------------------------------------------------------------ + # max_patch_per_gpu provided – fits in a single iteration + # ------------------------------------------------------------------ + def test_single_iter_exact_fit(self): + """patch_num fits exactly within max_patch_per_gpu.""" + # max_patch_num_per_iter = min(4, 8//2) = 4 → 1 iteration + assert calculate_patch_per_iter(4, 8, 2) == [4] + + def test_single_iter_max_exceeds_patch_num(self): + """max allows more patches than needed; still one iteration.""" + # max_patch_num_per_iter = min(2, 16//1) = 2 → 1 iteration + assert calculate_patch_per_iter(2, 16, 1) == [2] + + # ------------------------------------------------------------------ + # max_patch_per_gpu provided – requires multiple iterations + # ------------------------------------------------------------------ + def test_even_split(self): + """patch_num divides evenly into iterations.""" + # max_patch_num_per_iter = min(8, 4//1) = 4 → 2 iterations of 4 + assert calculate_patch_per_iter(8, 4, 1) == [4, 4] + + def test_uneven_split(self): + """Last iteration gets fewer patches.""" + # max_patch_num_per_iter = min(7, 4//1) = 4 + # iterations = ceil(7/4) = 2 → [4, 3] + assert calculate_patch_per_iter(7, 4, 1) == [4, 3] + + def test_three_iterations(self): + """Requires three iterations with a remainder.""" + # max_patch_num_per_iter = min(10, 4//1) = 4 + # iterations = ceil(10/4) = 3 → [4, 4, 2] + assert calculate_patch_per_iter(10, 4, 1) == [4, 4, 2] + + def test_patch_num_one_less_than_max(self): + # max_patch_num_per_iter = min(3, 4//1) = 3 → single iteration + assert calculate_patch_per_iter(3, 4, 1) == [3] + + def test_patch_num_one_more_than_max(self): + # max_patch_num_per_iter = min(5, 4//1) = 4 + # iterations = ceil(5/4) = 2 → [4, 1] + assert calculate_patch_per_iter(5, 4, 1) == [4, 1] + + # ------------------------------------------------------------------ + # batch_size_per_gpu interaction + # ------------------------------------------------------------------ + def test_batch_size_reduces_max_per_iter(self): + """Larger batch size reduces the effective max patches per iter.""" + # max_patch_num_per_iter = min(6, 8//4) = 2 + # iterations = ceil(6/2) = 3 → [2, 2, 2] + assert calculate_patch_per_iter(6, 8, 4) == [2, 2, 2] + + def test_batch_size_equals_max(self): + """batch_size_per_gpu == max_patch_per_gpu → 1 patch per iter.""" + # max_patch_num_per_iter = min(3, 4//4) = 1 + # iterations = ceil(3/1) = 3 → [1, 1, 1] + assert calculate_patch_per_iter(3, 4, 4) == [1, 1, 1] + + # ------------------------------------------------------------------ + # Validation / edge cases + # ------------------------------------------------------------------ + def test_max_less_than_batch_raises(self): + """max_patch_per_gpu < batch_size_per_gpu should raise.""" + with pytest.raises(ValueError, match="max_patch_per_gpu"): + calculate_patch_per_iter(4, 2, 4) + + def test_sum_equals_patch_num(self): + """Sum of returned list must always equal patch_num.""" + for patch_num in range(1, 20): + for batch in [1, 2, 4]: + for max_ppg in [batch, batch * 2, batch * 3, batch * 5]: + result = calculate_patch_per_iter(patch_num, max_ppg, batch) + assert sum(result) == patch_num, ( + f"patch_num={patch_num}, max_ppg={max_ppg}, batch={batch}: " + f"sum({result}) = {sum(result)} != {patch_num}" + ) + + def test_no_element_exceeds_max(self): + """No single iteration should exceed max_patch_num_per_iter.""" + for patch_num in range(1, 20): + for batch in [1, 2, 4]: + for max_ppg in [batch, batch * 2, batch * 3]: + max_per_iter = min(patch_num, max_ppg // batch) + result = calculate_patch_per_iter(patch_num, max_ppg, batch) + assert all(r <= max_per_iter for r in result), ( + f"patch_num={patch_num}, max_ppg={max_ppg}, batch={batch}: " + f"{result} has element > {max_per_iter}" + ) +