From 17b85a368c565747f87fc35ec4251c0997caf565 Mon Sep 17 00:00:00 2001 From: Leif Van Holland Date: Wed, 15 Apr 2026 13:05:37 +0000 Subject: [PATCH 1/2] fixed torch compile crash because of label func --- speedtest.py | 2 +- src/tensor_shape_assert/wrapper.py | 5 +++++ src/tensor_shape_assert_test.py | 16 +++++++++++++++- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/speedtest.py b/speedtest.py index b9588b4..52c8783 100644 --- a/speedtest.py +++ b/speedtest.py @@ -1,5 +1,5 @@ from src.tensor_shape_assert import ShapedTensor, check_tensor_shapes -from test_utils import get_library_by_name, NAME_LIBRARY_MAP +from src.test_utils import get_library_by_name, NAME_LIBRARY_MAP from time import time from tqdm import tqdm from tabulate import tabulate diff --git a/src/tensor_shape_assert/wrapper.py b/src/tensor_shape_assert/wrapper.py index fc667c1..efa00ec 100644 --- a/src/tensor_shape_assert/wrapper.py +++ b/src/tensor_shape_assert/wrapper.py @@ -119,6 +119,11 @@ def label_tensor( LabelAnnotationError If the given label is not registered or if the tensor already has labels. """ + + # skip if checks are disabled + if _global_check_mode == "never": + return tensor + if isinstance(label, str): label = [label] diff --git a/src/tensor_shape_assert_test.py b/src/tensor_shape_assert_test.py index 2d3e794..2b833ab 100644 --- a/src/tensor_shape_assert_test.py +++ b/src/tensor_shape_assert_test.py @@ -1420,9 +1420,23 @@ def test(x: ShapedTensor["n m 2"]) -> ShapedTensor["2"]: compiled_test(xp.zeros((4, 3, 1))) warnings.filterwarnings("default", category=CheckDisabledWarning) - self.assertTrue(True) + def test_labeling_compatible_with_torch_compile(self): + import torch + set_global_check_mode('never') + register_label("test_label") + + @check_tensor_shapes() + def test(x: ShapedTensor["n m 2"]) -> ShapedTensor["2"]: + label_tensor(x, "test_label") + return x.sum(axis=(0, 1)) + + compiled_test = torch.compile(test, fullgraph=True) + + # if TSA causes graph breaks, this will raise a RuntimeError about tracing + compiled_test(xp.zeros((4, 3, 1))) + class TestNonTensorTupleAnnotations(unittest.TestCase): def test_non_tensor_tuple_annotations(self): From 7ce9670b02660038a312d79c489196eabd68969f Mon Sep 17 00:00:00 2001 From: Leif Van Holland Date: Wed, 15 Apr 2026 13:09:42 +0000 Subject: [PATCH 2/2] bump version --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 38975a1..a5dc64f 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ setup( name="tensor-shape-assert", - version="0.4.2", + version="0.4.3", description="A simple runtime assert library for tensor-based frameworks.", long_description=long_description, long_description_content_type="text/markdown",