-
Notifications
You must be signed in to change notification settings - Fork 1.4k
8627 perceptual loss errors out after hitting the maximum number of downloads #8652
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: dev
Are you sure you want to change the base?
Changes from all commits
b1e4a50
0aeb4d9
fa0639b
685aee2
915de5f
5594bfe
b065de7
c99e16e
717b99b
b276f3c
2156b84
e2b982e
e3be8de
b02053b
6dfc209
d258390
2520920
d2ab308
081a673
cff03d6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,11 +19,18 @@ | |
| from monai.utils import optional_import | ||
| from monai.utils.enums import StrEnum | ||
|
|
||
| # Valid model name to download from the repository | ||
| HF_MONAI_MODELS = frozenset( | ||
| ("medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets", "radimagenet_resnet50") | ||
| ) | ||
|
Comment on lines
+22
to
+25
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Restrict model validation per family and guard 3D path to MedicalNet only. As written,
Recommend:
Example patch: @@
-# Valid model name to download from the repository
-HF_MONAI_MODELS = frozenset(
- ("medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets", "radimagenet_resnet50")
-)
+# Valid model names to download from the repository
+HF_MONAI_MODELS = frozenset(
+ ("medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets", "radimagenet_resnet50")
+)
+HF_MONAI_MEDICALNET_MODELS = frozenset(
+ ("medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets")
+)
+HF_MONAI_RADIMAGENET_MODELS = frozenset(("radimagenet_resnet50",))
@@
- # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used.
- if spatial_dims == 3 and is_fake_3d is False:
- self.perceptual_function = MedicalNetPerceptualSimilarity(
- net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir
- )
+ # If spatial_dims is 3, only MedicalNet supports 3D models; other networks must use the fake 3D path.
+ if spatial_dims == 3 and is_fake_3d is False:
+ if "medicalnet_" not in network_type:
+ raise ValueError(
+ "Only MedicalNet networks support 3D perceptual loss with is_fake_3d=False; "
+ f"got network_type={network_type!r}."
+ )
+ self.perceptual_function = MedicalNetPerceptualSimilarity(
+ net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir
+ )
@@
- if net not in HF_MONAI_MODELS:
- raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.")
+ if net not in HF_MONAI_MEDICALNET_MODELS:
+ raise ValueError(
+ f"Invalid MedicalNet model name '{net}'. Must be one of: {', '.join(HF_MONAI_MEDICALNET_MODELS)}."
+ )
@@
- if net not in HF_MONAI_MODELS:
- raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.")
+ if net not in HF_MONAI_RADIMAGENET_MODELS:
+ raise ValueError(
+ f"Invalid RadImageNet model name '{net}'. Must be one of: {', '.join(HF_MONAI_RADIMAGENET_MODELS)}."
+ )Also applies to: 94-107, 125-133, 234-239, 325-331 |
||
|
|
||
| LPIPS, _ = optional_import("lpips", name="LPIPS") | ||
| torchvision, _ = optional_import("torchvision") | ||
|
|
||
|
|
||
| class PercetualNetworkType(StrEnum): | ||
| class PerceptualNetworkType(StrEnum): | ||
| """Types of neural networks that are supported by perceptual loss.""" | ||
|
|
||
| alex = "alex" | ||
| vgg = "vgg" | ||
| squeeze = "squeeze" | ||
|
|
@@ -70,7 +77,7 @@ class PerceptualLoss(nn.Module): | |
| def __init__( | ||
| self, | ||
| spatial_dims: int, | ||
| network_type: str = PercetualNetworkType.alex, | ||
| network_type: str = PerceptualNetworkType.alex, | ||
| is_fake_3d: bool = True, | ||
| fake_3d_ratio: float = 0.5, | ||
| cache_dir: str | None = None, | ||
|
|
@@ -84,19 +91,25 @@ def __init__( | |
| if spatial_dims not in [2, 3]: | ||
| raise NotImplementedError("Perceptual loss is implemented only in 2D and 3D.") | ||
|
|
||
| if (spatial_dims == 2 or is_fake_3d) and "medicalnet_" in network_type: | ||
| raise ValueError( | ||
| "MedicalNet networks are only compatible with ``spatial_dims=3``." | ||
| "Argument is_fake_3d must be set to False." | ||
| ) | ||
|
|
||
| if channel_wise and "medicalnet_" not in network_type: | ||
| # Strict validation for MedicalNet | ||
| if "medicalnet_" in network_type: | ||
| if spatial_dims == 2 or is_fake_3d: | ||
| raise ValueError( | ||
| "MedicalNet networks are only compatible with ``spatial_dims=3``. Argument is_fake_3d must be set to False." | ||
| ) | ||
| if not channel_wise: | ||
| warnings.warn( | ||
| "MedicalNet networks supp, ort channel-wise loss. Consider setting channel_wise=True.", stacklevel=2 | ||
| ) | ||
|
|
||
| # Channel-wise only for MedicalNet | ||
| elif channel_wise: | ||
| raise ValueError("Channel-wise loss is only compatible with MedicalNet networks.") | ||
|
|
||
| if network_type.lower() not in list(PercetualNetworkType): | ||
| if network_type.lower() not in list(PerceptualNetworkType): | ||
| raise ValueError( | ||
| "Unrecognised criterion entered for Adversarial Loss. Must be one in: %s" | ||
| % ", ".join(PercetualNetworkType) | ||
| "Unrecognised criterion entered for Perceptual Loss. Must be one in: %s" | ||
| % ", ".join(PerceptualNetworkType) | ||
| ) | ||
|
|
||
| if cache_dir: | ||
|
|
@@ -108,12 +121,16 @@ def __init__( | |
|
|
||
| self.spatial_dims = spatial_dims | ||
| self.perceptual_function: nn.Module | ||
|
|
||
| # If spatial_dims is 3, only MedicalNet supports 3D models, otherwise, spatial_dims=2 and fake_3D must be used. | ||
| if spatial_dims == 3 and is_fake_3d is False: | ||
| self.perceptual_function = MedicalNetPerceptualSimilarity( | ||
| net=network_type, verbose=False, channel_wise=channel_wise | ||
| net=network_type, verbose=False, channel_wise=channel_wise, cache_dir=cache_dir | ||
| ) | ||
| elif "radimagenet_" in network_type: | ||
| self.perceptual_function = RadImageNetPerceptualSimilarity(net=network_type, verbose=False) | ||
| self.perceptual_function = RadImageNetPerceptualSimilarity( | ||
| net=network_type, verbose=False, cache_dir=cache_dir | ||
| ) | ||
| elif network_type == "resnet50": | ||
| self.perceptual_function = TorchvisionModelPerceptualSimilarity( | ||
| net=network_type, | ||
|
|
@@ -122,7 +139,9 @@ def __init__( | |
| pretrained_state_dict_key=pretrained_state_dict_key, | ||
| ) | ||
| else: | ||
| # VGG, AlexNet and SqueezeNet are independently handled by LPIPS. | ||
| self.perceptual_function = LPIPS(pretrained=pretrained, net=network_type, verbose=False) | ||
|
|
||
| self.is_fake_3d = is_fake_3d | ||
| self.fake_3d_ratio = fake_3d_ratio | ||
| self.channel_wise = channel_wise | ||
|
|
@@ -194,7 +213,7 @@ class MedicalNetPerceptualSimilarity(nn.Module): | |
| """ | ||
| Component to perform the perceptual evaluation with the networks pretrained by Chen, et al. "Med3D: Transfer | ||
| Learning for 3D Medical Image Analysis". This class uses torch Hub to download the networks from | ||
| "Warvito/MedicalNet-models". | ||
| "Project-MONAI/perceptual-models". | ||
|
|
||
| Args: | ||
| net: {``"medicalnet_resnet10_23datasets"``, ``"medicalnet_resnet50_23datasets"``} | ||
|
|
@@ -205,11 +224,19 @@ class MedicalNetPerceptualSimilarity(nn.Module): | |
| """ | ||
|
|
||
| def __init__( | ||
virginiafdez marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| self, net: str = "medicalnet_resnet10_23datasets", verbose: bool = False, channel_wise: bool = False | ||
| self, | ||
| net: str = "medicalnet_resnet10_23datasets", | ||
| verbose: bool = False, | ||
| channel_wise: bool = False, | ||
| cache_dir: str | None = None, | ||
| ) -> None: | ||
| super().__init__() | ||
| torch.hub._validate_not_a_forked_repo = lambda a, b, c: True | ||
| self.model = torch.hub.load("warvito/MedicalNet-models", model=net, verbose=verbose, trust_repo=True) | ||
| if net not in HF_MONAI_MODELS: | ||
| raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.") | ||
|
|
||
| self.model = torch.hub.load( | ||
virginiafdez marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True | ||
| ) | ||
| self.eval() | ||
|
|
||
| self.channel_wise = channel_wise | ||
|
|
@@ -258,7 +285,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: | |
| for i in range(input.shape[1]): | ||
| l_idx = i * feats_per_ch | ||
| r_idx = (i + 1) * feats_per_ch | ||
| results[:, i, ...] = feats_diff[:, l_idx : i + r_idx, ...].sum(dim=1) | ||
| results[:, i, ...] = feats_diff[:, l_idx:r_idx, ...].sum(dim=1) | ||
| else: | ||
| results = feats_diff.sum(dim=1, keepdim=True) | ||
|
|
||
|
|
@@ -287,17 +314,21 @@ class RadImageNetPerceptualSimilarity(nn.Module): | |
| """ | ||
| Component to perform the perceptual evaluation with the networks pretrained on RadImagenet (pretrained by Mei, et | ||
| al. "RadImageNet: An Open Radiologic Deep Learning Research Dataset for Effective Transfer Learning"). This class | ||
| uses torch Hub to download the networks from "Warvito/radimagenet-models". | ||
| uses torch Hub to download the networks from "Project-MONAI/perceptual-models". | ||
|
|
||
| Args: | ||
| net: {``"radimagenet_resnet50"``} | ||
| Specifies the network architecture to use. Defaults to ``"radimagenet_resnet50"``. | ||
| verbose: if false, mute messages from torch Hub load function. | ||
| """ | ||
|
|
||
| def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False) -> None: | ||
| def __init__(self, net: str = "radimagenet_resnet50", verbose: bool = False, cache_dir: str | None = None) -> None: | ||
| super().__init__() | ||
| self.model = torch.hub.load("Warvito/radimagenet-models", model=net, verbose=verbose, trust_repo=True) | ||
| if net not in HF_MONAI_MODELS: | ||
| raise ValueError(f"Invalid download model name '{net}'. Must be one of: {', '.join(HF_MONAI_MODELS)}.") | ||
| self.model = torch.hub.load( | ||
| "Project-MONAI/perceptual-models:main", model=net, verbose=verbose, cache_dir=cache_dir, trust_repo=True | ||
| ) | ||
| self.eval() | ||
|
|
||
| for param in self.parameters(): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cache path is pip cache dir.
Cache key uses hashFiles('requirements*.txt', 'setup.cfg', 'pyproject.toml') so it updates when dependencies change.
Kept torch installs --no-cache-dir to reduce disk pressure, while allowing smaller deps (e.g., requirements-min.txt) to benefit from caching.
cc @ericspod