Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

setup(
name="tensor-shape-assert",
version="0.4.2",
version="0.4.3",
description="A simple runtime assert library for tensor-based frameworks.",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down
2 changes: 1 addition & 1 deletion speedtest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from src.tensor_shape_assert import ShapedTensor, check_tensor_shapes
from test_utils import get_library_by_name, NAME_LIBRARY_MAP
from src.test_utils import get_library_by_name, NAME_LIBRARY_MAP
from time import time
from tqdm import tqdm
from tabulate import tabulate
Expand Down
5 changes: 5 additions & 0 deletions src/tensor_shape_assert/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,11 @@ def label_tensor(
LabelAnnotationError
If the given label is not registered or if the tensor already has labels.
"""

# skip if checks are disabled
if _global_check_mode == "never":
return tensor

if isinstance(label, str):
label = [label]

Expand Down
16 changes: 15 additions & 1 deletion src/tensor_shape_assert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,9 +1420,23 @@ def test(x: ShapedTensor["n m 2"]) -> ShapedTensor["2"]:
compiled_test(xp.zeros((4, 3, 1)))
warnings.filterwarnings("default", category=CheckDisabledWarning)


self.assertTrue(True)

def test_labeling_compatible_with_torch_compile(self):
import torch
set_global_check_mode('never')
register_label("test_label")

@check_tensor_shapes()
def test(x: ShapedTensor["n m 2"]) -> ShapedTensor["2"]:
label_tensor(x, "test_label")
return x.sum(axis=(0, 1))

compiled_test = torch.compile(test, fullgraph=True)

# if TSA causes graph breaks, this will raise a RuntimeError about tracing
compiled_test(xp.zeros((4, 3, 1)))


class TestNonTensorTupleAnnotations(unittest.TestCase):
def test_non_tensor_tuple_annotations(self):
Expand Down
Loading