diff --git a/src/tensor_shape_assert/wrapper.py b/src/tensor_shape_assert/wrapper.py index 24f9e59..fc667c1 100644 --- a/src/tensor_shape_assert/wrapper.py +++ b/src/tensor_shape_assert/wrapper.py @@ -166,8 +166,8 @@ def check_iterable( continue try: - constrained_labels = ShapeDescriptor.filter_for_constrained_labels(descriptor.labels) - unconstrained_labels = ShapeDescriptor.filter_for_unconstrained_labels(descriptor.labels) + constrained_label_annotations = ShapeDescriptor.filter_for_constrained_labels(descriptor.labels) + unconstrained_label_annotations = ShapeDescriptor.filter_for_unconstrained_labels(descriptor.labels) except KeyError as e: raise LabelAnnotationError( f"Label '{e.args[0]}' is not registered. Only registered " @@ -175,7 +175,7 @@ def check_iterable( ) # check labels with constraint fns - for label in constrained_labels: + for label in constrained_label_annotations: constraint_fn = ShapeDescriptor.get_label_constraint_fn(label) if constraint_fn is not None: if not constraint_fn(obj): @@ -184,23 +184,23 @@ def check_iterable( f"label '{label}'." ) - registry_labels = _global_tensor_label_registry.get(obj, frozenset()) + labels_registered_for_tensor = _global_tensor_label_registry.get(obj, frozenset()) - if must_have_labels: + if must_have_labels and labels_registered_for_tensor: # check labels annotations for unconstrained labels - if unconstrained_labels != registry_labels: + if not unconstrained_label_annotations.issubset(labels_registered_for_tensor): raise LabelAnnotationError( - f"Tensor with labels {registry_labels} does not match " - f"annotation with labels {unconstrained_labels}." + f"Tensor with labels {labels_registered_for_tensor} does not match " + f"annotation with labels {unconstrained_label_annotations}." ) - elif registry_labels and not unconstrained_labels.issubset(registry_labels): + elif labels_registered_for_tensor and not unconstrained_label_annotations.issubset(labels_registered_for_tensor): raise LabelAnnotationError( - f"Tensor with labels {registry_labels} must at least have " - f"annotation with labels {unconstrained_labels}." + f"Tensor with labels {labels_registered_for_tensor} must at least have " + f"annotation with labels {unconstrained_label_annotations}." ) - elif not registry_labels and unconstrained_labels: + elif not labels_registered_for_tensor and unconstrained_label_annotations: # add to registry if not in registry and annotation has labels - label_tensor(obj, unconstrained_labels) + label_tensor(obj, unconstrained_label_annotations) # check shape variables = descriptor_to_variables( diff --git a/src/tensor_shape_assert_test.py b/src/tensor_shape_assert_test.py index 74f0da6..2d3e794 100644 --- a/src/tensor_shape_assert_test.py +++ b/src/tensor_shape_assert_test.py @@ -1778,5 +1778,19 @@ def test_only_overwrite_if_allowed(self): with self.assertRaises(LabelAnnotationError): label_tensor(x, "test_label_2", overwrite=False) - # def test_try_to_force_weakref_collision(self): - # ... \ No newline at end of file + def test_no_label_annotation_means_any_label(self): + @check_tensor_shapes() + def fn(x: ShapedTensor["n m 3"]) -> ShapedTensor["3"]: + return xp.sum(x, axis=(0, 1)) + + fn(xp.zeros((5, 6, 3))) + fn(label_tensor(xp.zeros((5, 6, 3)), "test_label")) + fn(label_tensor(xp.zeros((5, 6, 3)), "test_label_2")) + fn(label_tensor(xp.zeros((5, 6, 3)), ("test_label", "test_label_2"))) + + def test_more_labels_than_annotated_are_allowed(self): + @check_tensor_shapes() + def fn(x: ShapedTensor["n m 3 test_label"]) -> ShapedTensor["3"]: + return xp.sum(x, axis=(0, 1)) + + fn(label_tensor(xp.zeros((5, 6, 3)), ("test_label", "test_label_2")))