Skip to content
4 changes: 1 addition & 3 deletions ci/cscs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@ build_job:

test_job:
stage: test
extends: .container-runner-clariden-gh200
extends: .container-runner-santis-gh200
image: $PERSIST_IMAGE_NAME
script:
# - echo 'hello world'
- ls /opt
- pytest /opt/hirad-gen/tests -v
variables:
USE_MPI: NO
Expand Down
3 changes: 2 additions & 1 deletion ci/docker/Dockerfile.ci
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ FROM jfrog.svc.cscs.ch/docker-group-csstaff/alps-images/ngc-physicsnemo:25.11-al
RUN pip install --upgrade pip

# install dependencies
RUN pip install mlflow
RUN pip install mlflow \
anemoi-datasets

COPY . /opt/hirad-gen

Expand Down
2 changes: 1 addition & 1 deletion src/hirad/conf/generation/era_real.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ perf:
force_fp16: False
# Whether to force fp16 precision for the model. If false, it'll use the precision
# specified upon training.
use_torch_compile: False
use_torch_compile: True
# whether to use torch.compile on the diffusion model
# this will make the first time stamp generation very slow due to compilation overheads
# but will significantly speed up subsequent inference runs
Expand Down
12 changes: 6 additions & 6 deletions src/hirad/datasets/anemoi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,11 @@ def __getitem__(self, idx):
# next two steps only if target is cosmo, real has to be regridded first (done in training loop on gpu-s for efficiency)
# reshape to image_shape
# flip so that it starts in top-left corner (by default it is bottom left)
if not self.real_target:
target_shape = self.image_shape()
target_data = np.flip(target_data \
.reshape(-1,*target_shape),
1)
# if not self.real_target:
# target_shape = self.image_shape()
# target_data = np.flip(target_data \
# .reshape(-1,*target_shape),
# 1)

return torch.from_numpy(target_data.copy()),\
torch.from_numpy(input_data),\
Expand Down Expand Up @@ -344,7 +344,7 @@ def make_time_grids(self, dates: list[str], device: torch.device, dtype: torch.d

Returns
-------
grid : torch.Tensor, shape (B, C, H, W)
grid : torch.Tensor, shape (B, C)
Channels = [sin(k*hour), cos(k*hour), sin(k*month), cos(k*month) for each k]
"""

Expand Down
297 changes: 156 additions & 141 deletions src/hirad/inference/generate.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/hirad/models/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ class Conv2d(torch.nn.Module):
"""
A custom 2D convolutional layer implementation with support for up-sampling,
down-sampling, and custom weight and bias initializations. The layer's weights
and biases canbe initialized using custom initialization strategies like
and biases can be initialized using custom initialization strategies like
"kaiming_normal", and can be further scaled by factors `init_weight` and
`init_bias`.

Expand Down Expand Up @@ -403,7 +403,7 @@ def forward(self, x):
x = rearrange(x, "b (g c) h w -> b g c h w", g=self.num_groups)

mean = x.mean(dim=[2, 3, 4], keepdim=True)
var = x.var(dim=[2, 3, 4], keepdim=True)
var = x.var(dim=[2, 3, 4], keepdim=True, unbiased=False)

x = (x - mean) * (var + self.eps).rsqrt()
x = rearrange(x, "b g c h w -> b (g c) h w")
Expand Down
31 changes: 22 additions & 9 deletions src/hirad/models/song_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def __init__(
emb_channels=emb_channels,
num_heads=1,
dropout=dropout,
skip_scale=np.sqrt(0.5),
skip_scale=0.7071067811865476, # 1 / sqrt(2)
eps=1e-6,
resample_filter=resample_filter,
resample_proj=True,
Expand Down Expand Up @@ -659,10 +659,13 @@ def __init__(

self.gridtype = gridtype
self.N_grid_channels = N_grid_channels
if self.gridtype == "learnable":
self.pos_embd = self._get_positional_embedding()
if self.N_grid_channels:
if self.gridtype == "learnable":
self.pos_embd = self._get_positional_embedding()
else:
self.register_buffer("pos_embd", self._get_positional_embedding().float())
else:
self.register_buffer("pos_embd", self._get_positional_embedding().float())
self.pos_embd = None
self.lead_time_mode = lead_time_mode
if self.lead_time_mode:
self.lead_time_channels = lead_time_channels
Expand Down Expand Up @@ -693,7 +696,13 @@ def forward(
"embedding_selector is the preferred approach for better efficiency."
)

if x.dtype != self.pos_embd.dtype:
if self.lead_time_mode and embedding_selector is not None:
raise ValueError(
"Embedding selector is not supported in lead time mode. "
"Please use global_index to select positional embeddings when lead_time_mode is True."
)

if self.pos_embd is not None and x.dtype != self.pos_embd.dtype:
self.pos_embd = self.pos_embd.to(x.dtype)

# Append positional embedding to input conditioning
Expand Down Expand Up @@ -780,17 +789,17 @@ def positional_embedding_indexing(
Example
-------
>>> # Create global indices using patching utility:
>>> from physicsnemo.utils.patching import GridPatching2D
>>> from hirad.utils.patching import GridPatching2D
>>> patching = GridPatching2D(img_shape=(16, 16), patch_shape=(8, 8))
>>> global_index = patching.global_index(batch_size=3)
>>> print(global_index.shape)
torch.Size([4, 2, 8, 8])

See Also
--------
:meth:`physicsnemo.utils.patching.RandomPatching2D.global_index`
:meth:`hirad.utils.patching.RandomPatching2D.global_index`
For generating random patch indices.
:meth:`physicsnemo.utils.patching.GridPatching2D.global_index`
:meth:`hirad.utils.patching.GridPatching2D.global_index`
For generating deterministic grid-based patch indices.
See these methods for possible ways to generate the global_index parameter.
"""
Expand Down Expand Up @@ -900,7 +909,7 @@ def positional_embedding_selector(
Each selected embedding should correspond to the positional
information of each batch element in x.
For patch-based processing, typically this should be based on
:meth:`physicsnemo.utils.patching.BasePatching2D.apply` method to
:meth:`hirad.utils.patching.BasePatching2D.apply` method to
maintain consistency with patch extraction.
embeds : Optional[torch.Tensor]
Optional tensor for combined positional and lead time embeddings tensor
Expand Down Expand Up @@ -969,6 +978,10 @@ def _get_positional_embedding(self):
raise ValueError("N_grid_channels must be a factor of 4")
num_freq = self.N_grid_channels // 4
freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq)
#TODO: When more than 4 channels are used for sinusoidal, the frequencies should be multiples of the base frequency (2).
# freq_bands = 2.0 ** np.linspace(0.0, num_freq, num=num_freq) is currently in code which gives
# freqs = [1,4] instead of [1,2] for N_grid_channels=8. This seems to be a bug if we want the base 2.
# Leaving it like this for now since we have checkpoints with 8 sinusoidal channels that use these frequencies.
grid_list = []
grid_x, grid_y = np.meshgrid(
np.linspace(0, 2 * np.pi, self.img_shape_x),
Expand Down
41 changes: 2 additions & 39 deletions src/hirad/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,42 +63,6 @@ class UNet(nn.Module): # TODO a lot of redundancy, need to clean up
arXiv preprint arXiv:2309.15214.
"""

@classmethod
def _backward_compat_arg_mapper(
cls, version: str, args: Dict[str, Any]
) -> Dict[str, Any]:
"""Map arguments from older versions to current version format.

Parameters
----------
version : str
Version of the checkpoint being loaded
args : Dict[str, Any]
Arguments dictionary from the checkpoint

Returns
-------
Dict[str, Any]
Updated arguments dictionary compatible with current version
"""
# Call parent class method first
args = super()._backward_compat_arg_mapper(version, args)

if version == "0.1.0":
# In version 0.1.0, img_channels was unused
if "img_channels" in args:
_ = args.pop("img_channels")

# Sigma parameters are also unused
if "sigma_min" in args:
_ = args.pop("sigma_min")
if "sigma_max" in args:
_ = args.pop("sigma_max")
if "sigma_data" in args:
_ = args.pop("sigma_data")

return args

def __init__(
self,
img_resolution: Union[int, Tuple[int, int]],
Expand Down Expand Up @@ -217,8 +181,8 @@ def forward(
)

F_x = self.model(
x.to(dtype), # (c_in * x).to(dtype),
torch.zeros(x.shape[0], dtype=dtype, device=x.device), # c_noise.flatten()
x.to(dtype),
torch.zeros(x.shape[0], dtype=dtype, device=x.device),
class_labels=None,
**model_kwargs,
)
Expand All @@ -228,7 +192,6 @@ def forward(
f"Expected the dtype to be {dtype}, " f"but got {F_x.dtype} instead."
)

# skip connection
D_x = F_x.to(torch.float32)
return D_x

Expand Down
Empty file added src/hirad/training/__init__.py
Empty file.
Loading