Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions src/tensor_shape_assert/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,16 @@ def check_iterable(
continue

try:
constrained_labels = ShapeDescriptor.filter_for_constrained_labels(descriptor.labels)
unconstrained_labels = ShapeDescriptor.filter_for_unconstrained_labels(descriptor.labels)
constrained_label_annotations = ShapeDescriptor.filter_for_constrained_labels(descriptor.labels)
unconstrained_label_annotations = ShapeDescriptor.filter_for_unconstrained_labels(descriptor.labels)
except KeyError as e:
raise LabelAnnotationError(
f"Label '{e.args[0]}' is not registered. Only registered "
f"labels can be used in shape descriptors. "
)

# check labels with constraint fns
for label in constrained_labels:
for label in constrained_label_annotations:
constraint_fn = ShapeDescriptor.get_label_constraint_fn(label)
if constraint_fn is not None:
if not constraint_fn(obj):
Expand All @@ -184,23 +184,23 @@ def check_iterable(
f"label '{label}'."
)

registry_labels = _global_tensor_label_registry.get(obj, frozenset())
labels_registered_for_tensor = _global_tensor_label_registry.get(obj, frozenset())

if must_have_labels:
if must_have_labels and labels_registered_for_tensor:
# check labels annotations for unconstrained labels
if unconstrained_labels != registry_labels:
if not unconstrained_label_annotations.issubset(labels_registered_for_tensor):
raise LabelAnnotationError(
f"Tensor with labels {registry_labels} does not match "
f"annotation with labels {unconstrained_labels}."
f"Tensor with labels {labels_registered_for_tensor} does not match "
f"annotation with labels {unconstrained_label_annotations}."
)
elif registry_labels and not unconstrained_labels.issubset(registry_labels):
elif labels_registered_for_tensor and not unconstrained_label_annotations.issubset(labels_registered_for_tensor):
raise LabelAnnotationError(
f"Tensor with labels {registry_labels} must at least have "
f"annotation with labels {unconstrained_labels}."
f"Tensor with labels {labels_registered_for_tensor} must at least have "
f"annotation with labels {unconstrained_label_annotations}."
)
elif not registry_labels and unconstrained_labels:
elif not labels_registered_for_tensor and unconstrained_label_annotations:
# add to registry if not in registry and annotation has labels
label_tensor(obj, unconstrained_labels)
label_tensor(obj, unconstrained_label_annotations)

# check shape
variables = descriptor_to_variables(
Expand Down
18 changes: 16 additions & 2 deletions src/tensor_shape_assert_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1778,5 +1778,19 @@ def test_only_overwrite_if_allowed(self):
with self.assertRaises(LabelAnnotationError):
label_tensor(x, "test_label_2", overwrite=False)

# def test_try_to_force_weakref_collision(self):
# ...
def test_no_label_annotation_means_any_label(self):
@check_tensor_shapes()
def fn(x: ShapedTensor["n m 3"]) -> ShapedTensor["3"]:
return xp.sum(x, axis=(0, 1))

fn(xp.zeros((5, 6, 3)))
fn(label_tensor(xp.zeros((5, 6, 3)), "test_label"))
fn(label_tensor(xp.zeros((5, 6, 3)), "test_label_2"))
fn(label_tensor(xp.zeros((5, 6, 3)), ("test_label", "test_label_2")))

def test_more_labels_than_annotated_are_allowed(self):
@check_tensor_shapes()
def fn(x: ShapedTensor["n m 3 test_label"]) -> ShapedTensor["3"]:
return xp.sum(x, axis=(0, 1))

fn(label_tensor(xp.zeros((5, 6, 3)), ("test_label", "test_label_2")))
Loading