-
Notifications
You must be signed in to change notification settings - Fork 665
[Torchvision API] Input metadata #6364
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
d3f473b
b0c754b
57302e8
a4d6209
693e9af
518c4d1
2c7e9ef
08ebc42
12dddd3
5c32f8f
4fd51eb
c024bac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| # Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import nvidia.dali.experimental.dynamic as ndd | ||
| from nvidia.dali._typing import TensorLike | ||
| from nvidia.dali.experimental.dynamic._device import DeviceLike | ||
|
|
||
| from ..operator import adjust_input | ||
| from ..randomcrop import RandomCrop | ||
|
|
||
|
|
||
| def _get_crop_axes(inpt: TensorLike | ndd.Batch) -> list[int]: | ||
| layout = inpt.layout[-3:] | ||
| if layout == "HWC": | ||
| return [-3, -2] | ||
| if layout == "CHW": | ||
| return [-2, -1] | ||
| if inpt.layout[-2:] == "HW": | ||
| return [-2, -1] | ||
| raise ValueError(f"Unsupported layout: {inpt.layout!r}. Expected one of HWC, CHW, HW.") | ||
|
|
||
|
|
||
| def _verify_crop_coordinate(value, name: str) -> None: | ||
| if not isinstance(value, int): | ||
| raise TypeError(f"{name} must be int, got {type(value)}") | ||
|
|
||
|
|
||
| @adjust_input | ||
| def crop( | ||
| inpt: TensorLike | ndd.Batch, | ||
| top: int, | ||
| left: int, | ||
| height: int, | ||
| width: int, | ||
| device: DeviceLike = "cpu", | ||
| ) -> ndd.Tensor | ndd.Batch: | ||
| """ | ||
| Please refer to the ``RandomCrop`` operator for more details. | ||
| """ | ||
| _verify_crop_coordinate(top, "top") | ||
| _verify_crop_coordinate(left, "left") | ||
| RandomCrop.verify_args( | ||
| size=(height, width), | ||
| padding=None, | ||
| pad_if_needed=False, | ||
| padding_mode="constant", | ||
| fill=0, | ||
| ) | ||
|
|
||
| return ndd.slice( | ||
| inpt, | ||
| (top, left), | ||
| (height, width), | ||
| axes=_get_crop_axes(inpt), | ||
| out_of_bounds_policy="pad", | ||
| fill_values=0, | ||
| device=device, | ||
| ) |
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,82 @@ | ||||||||||||||||||||||
| # Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||||||||||||||||||||||
| # | ||||||||||||||||||||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||||||||||||||||
| # you may not use this file except in compliance with the License. | ||||||||||||||||||||||
| # You may obtain a copy of the License at | ||||||||||||||||||||||
| # | ||||||||||||||||||||||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||||||||||||||||||||||
| # | ||||||||||||||||||||||
| # Unless required by applicable law or agreed to in writing, software | ||||||||||||||||||||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||||||||||||||||||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||||||||||||||||
| # See the License for the specific language governing permissions and | ||||||||||||||||||||||
| # limitations under the License. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from typing import List | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| from PIL import Image | ||||||||||||||||||||||
| import torch | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def get_image_size(inpt: Image.Image | torch.Tensor) -> List[int]: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Return the spatial size of an image as ``[width, height]``. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Mirrors ``torchvision.transforms.v2.functional.get_image_size``. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| .. note:: | ||||||||||||||||||||||
| This function is provided for compatibility. The torchvision successor | ||||||||||||||||||||||
| ``get_size`` returns ``[height, width]`` instead. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||
| inpt : PIL Image or torch.Tensor | ||||||||||||||||||||||
| Input image. Tensors are expected in ``[…, H, W]`` layout (leading | ||||||||||||||||||||||
| channel / batch dimensions are ignored). | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Returns | ||||||||||||||||||||||
| ------- | ||||||||||||||||||||||
| List[int] | ||||||||||||||||||||||
| ``[width, height]`` | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| if isinstance(inpt, Image.Image): | ||||||||||||||||||||||
| return list(inpt.size) # PIL .size is (W, H) | ||||||||||||||||||||||
| elif isinstance(inpt, torch.Tensor): | ||||||||||||||||||||||
| if inpt.ndim < 2: | ||||||||||||||||||||||
| raise TypeError( | ||||||||||||||||||||||
| f"get_image_size requires a tensor with at least 2 dimensions, got {inpt.ndim}." | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| return [inpt.shape[-1], inpt.shape[-2]] # [W, H] | ||||||||||||||||||||||
| raise TypeError(f"Unsupported input type: {type(inpt)}.") | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
|
|
||||||||||||||||||||||
| def get_dimensions(inpt: Image.Image | torch.Tensor) -> List[int]: | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| Return the number of channels, height, and width of an image as | ||||||||||||||||||||||
| ``[channels, height, width]``. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Mirrors ``torchvision.transforms.v2.functional.get_dimensions``. | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Parameters | ||||||||||||||||||||||
| ---------- | ||||||||||||||||||||||
| inpt : PIL Image or torch.Tensor | ||||||||||||||||||||||
| Input image. Tensors are expected in ``[H, W]`` or ``[…, C, H, W]`` layout | ||||||||||||||||||||||
| (leading batch dimensions are ignored). | ||||||||||||||||||||||
|
|
||||||||||||||||||||||
| Returns | ||||||||||||||||||||||
| ------- | ||||||||||||||||||||||
| List[int] | ||||||||||||||||||||||
| ``[channels, height, width]`` | ||||||||||||||||||||||
| """ | ||||||||||||||||||||||
| if isinstance(inpt, Image.Image): | ||||||||||||||||||||||
| w, h = inpt.size | ||||||||||||||||||||||
| return [len(inpt.getbands()), h, w] | ||||||||||||||||||||||
| elif isinstance(inpt, torch.Tensor): | ||||||||||||||||||||||
| if inpt.ndim < 2: | ||||||||||||||||||||||
| raise TypeError( | ||||||||||||||||||||||
| f"get_dimensions requires a tensor with at least 2 dimensions, got {inpt.ndim}." | ||||||||||||||||||||||
| ) | ||||||||||||||||||||||
| if inpt.ndim == 2: | ||||||||||||||||||||||
|
Comment on lines
+75
to
+79
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Same sentence-ending convention issue in
Suggested change
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed |
||||||||||||||||||||||
| return [1, inpt.shape[-2], inpt.shape[-1]] | ||||||||||||||||||||||
| return [inpt.shape[-3], inpt.shape[-2], inpt.shape[-1]] # [C, H, W] | ||||||||||||||||||||||
| raise TypeError(f"Unsupported input type: {type(inpt)}.") | ||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Both error messages are missing a trailing period, violating the project convention that error messages must read as complete sentences.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed