From f5d4fa446eb862b6c32db9bbecba4d69f9812eb2 Mon Sep 17 00:00:00 2001 From: getzze Date: Thu, 1 Feb 2024 16:22:00 +0000 Subject: [PATCH] add options for a signal aliases adapt to SignalRelay --- src/psygnal/__init__.py | 2 + src/psygnal/_dataclass_utils.py | 271 +++++++++++++++++++++++++++++- src/psygnal/_evented_decorator.py | 5 +- src/psygnal/_group_descriptor.py | 44 +++-- tests/test_custom_fields.py | 197 ++++++++++++++++++++++ 5 files changed, 496 insertions(+), 23 deletions(-) create mode 100644 tests/test_custom_fields.py diff --git a/src/psygnal/__init__.py b/src/psygnal/__init__.py index 8bceb9e7..758e6e7d 100644 --- a/src/psygnal/__init__.py +++ b/src/psygnal/__init__.py @@ -29,6 +29,7 @@ "EventedModel", "get_evented_namespace", "is_evented", + "PSYGNAL_METADATA", "Signal", "SignalGroup", "SignalGroupDescriptor", @@ -48,6 +49,7 @@ stacklevel=2, ) +from ._dataclass_utils import PSYGNAL_METADATA from ._evented_decorator import evented from ._exceptions import EmitLoopError from ._group import EmissionInfo, SignalGroup diff --git a/src/psygnal/_dataclass_utils.py b/src/psygnal/_dataclass_utils.py index 5b74be47..00eec4cf 100644 --- a/src/psygnal/_dataclass_utils.py +++ b/src/psygnal/_dataclass_utils.py @@ -4,13 +4,30 @@ import dataclasses import sys import types -from typing import TYPE_CHECKING, Any, Iterator, List, Protocol, cast, overload +from dataclasses import dataclass, fields +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterator, + List, + Mapping, + Protocol, + cast, + overload, +) if TYPE_CHECKING: + from dataclasses import Field + import attrs import msgspec from pydantic import BaseModel - from typing_extensions import TypeGuard # py310 + from typing_extensions import TypeAlias, TypeGuard # py310 + + EqOperator: TypeAlias = Callable[[Any, Any], bool] + +PSYGNAL_METADATA = "__psygnal_metadata" class _DataclassParams(Protocol): @@ -29,12 +46,11 @@ class AttrsType: __attrs_attrs__: tuple[attrs.Attribute, ...] -_DATACLASS_PARAMS = "__dataclass_params__" +KW_ONLY = object() with contextlib.suppress(ImportError): - from dataclasses import _DATACLASS_PARAMS # type: ignore + from dataclasses import KW_ONLY # py310 +_DATACLASS_PARAMS = "__dataclass_params__" _DATACLASS_FIELDS = "__dataclass_fields__" -with contextlib.suppress(ImportError): - from dataclasses import _DATACLASS_FIELDS # type: ignore class DataClassType: @@ -171,8 +187,8 @@ def iter_fields( yield field_name, p_field.annotation else: for p_field in cls.__fields__.values(): # type: ignore [attr-defined] - if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore - yield p_field.name, p_field.outer_type_ # type: ignore + if p_field.field_info.allow_mutation or not exclude_frozen: # type: ignore [attr-defined] + yield p_field.name, p_field.outer_type_ # type: ignore [attr-defined] return if (attrs_fields := getattr(cls, "__attrs_attrs__", None)) is not None: @@ -185,3 +201,242 @@ def iter_fields( type_ = cls.__annotations__.get(m_field, None) yield m_field, type_ return + + +@dataclass +class FieldOptions: + name: str + type_: type | None = None + # set KW_ONLY value for compatibility with python < 3.10 + _: KW_ONLY = KW_ONLY # type: ignore [valid-type] + alias: str | None = None + skip: bool | None = None + eq: EqOperator | None = None + disable_setattr: bool | None = None + + +def is_kw_only(f: Field) -> bool: + if hasattr(f, "kw_only"): + return cast(bool, f.kw_only) + # for python < 3.10 + if f.name not in ["name", "type_"]: + return True + return False + + +def sanitize_field_options_dict(d: Mapping) -> dict[str, Any]: + field_options_kws = [f.name for f in fields(FieldOptions) if is_kw_only(f)] + return {k: v for k, v in d.items() if k in field_options_kws} + + +def get_msgspec_metadata( + cls: type[msgspec.Struct], + m_field: str, +) -> tuple[type | None, dict[str, Any]]: + # Look for type in cls and super classes + type_: type | None = None + for super_cls in cls.__mro__: + if not hasattr(super_cls, "__annotations__"): + continue + type_ = super_cls.__annotations__.get(m_field, None) + if type_ is not None: + break + + msgspec = sys.modules.get("msgspec", None) + if msgspec is None: + return type_, {} + + metadata_list = getattr(type_, "__metadata__", []) + + metadata: dict[str, Any] = {} + for meta in metadata_list: + if not isinstance(meta, msgspec.Meta): + continue + single_meta: dict[str, Any] = getattr(meta, "extra", {}).get( + PSYGNAL_METADATA, {} + ) + metadata.update(single_meta) + + return type_, metadata + + +def iter_fields_with_options( + cls: type, exclude_frozen: bool = True +) -> Iterator[FieldOptions]: + """Iterate over all fields in the class, return a field description. + + This function recognizes dataclasses, attrs classes, msgspec Structs, and pydantic + models. + + Parameters + ---------- + cls : type + The class to iterate over. + exclude_frozen : bool, optional + If True, frozen fields will be excluded. By default True. + + Yields + ------ + FieldOptions + A dataclass instance with the name, type and metadata of each field. + """ + # Add metadata for dataclasses.dataclass + dclass_fields = getattr(cls, "__dataclass_fields__", None) + if dclass_fields is not None: + """ + Example + ------- + from dataclasses import dataclass, field + + + @dataclass + class Foo: + bar: int = field(metadata={"alias": "bar_alias"}) + + assert ( + Foo.__dataclass_fields__["bar"].metadata == + {"__psygnal_metadata": {"alias": "bar_alias"}} + ) + + """ + for d_field in dclass_fields.values(): + if d_field._field_type is dataclasses._FIELD: # type: ignore [attr-defined] + metadata = getattr(d_field, "metadata", {}).get(PSYGNAL_METADATA, {}) + metadata = sanitize_field_options_dict(metadata) + options = FieldOptions(d_field.name, d_field.type, **metadata) + yield options + return + + # Add metadata for pydantic dataclass + if is_pydantic_model(cls): + """ + Example + ------- + from typing import Annotated + + from pydantic import BaseModel, Field + + + # Only works with Pydantic v2 + class Foo(BaseModel): + bar: Annotated[ + str, + {'__psygnal_metadata': {"alias": "bar_alias"}} + ] = Field(...) + + # Working with Pydantic v2 and partially with v1 + # Alternative, using Field `json_schema_extra` keyword argument + class Bar(BaseModel): + bar: str = Field( + json_schema_extra={PSYGNAL_METADATA: {"alias": "bar_alias"}} + ) + + + assert ( + Foo.model_fields["bar"].metadata[0] == + {"__psygnal_metadata": {"alias": "bar_alias"}} + ) + assert ( + Bar.model_fields["bar"].json_schema_extra == + {"__psygnal_metadata": {"alias": "bar_alias"}} + ) + + """ + if hasattr(cls, "model_fields"): + # Pydantic v2 + for field_name, p_field in cls.model_fields.items(): + # skip frozen field + if exclude_frozen and p_field.frozen: + continue + metadata_list = getattr(p_field, "metadata", []) + metadata = {} + for field in metadata_list: + metadata.update(field.get(PSYGNAL_METADATA, {})) + # Compat with using Field `json_schema_extra` keyword argument + if isinstance(getattr(p_field, "json_schema_extra", None), Mapping): + meta_dict = cast(Mapping, p_field.json_schema_extra) + metadata.update(meta_dict.get(PSYGNAL_METADATA, {})) + metadata = sanitize_field_options_dict(metadata) + options = FieldOptions(field_name, p_field.annotation, **metadata) + yield options + return + + else: + # Pydantic v1, metadata is not always working + for pv1_field in cls.__fields__.values(): # type: ignore [attr-defined] + # skip frozen field + if exclude_frozen and not pv1_field.field_info.allow_mutation: + continue + meta_dict = getattr(pv1_field.field_info, "extra", {}).get( + "json_schema_extra", {} + ) + metadata = meta_dict.get(PSYGNAL_METADATA, {}) + + metadata = sanitize_field_options_dict(metadata) + options = FieldOptions( + pv1_field.name, + pv1_field.outer_type_, + **metadata, + ) + yield options + return + + # Add metadata for attrs dataclass + attrs_fields = getattr(cls, "__attrs_attrs__", None) + if attrs_fields is not None: + """ + Example + ------- + from attrs import define, field + + + @define + class Foo: + bar: int = field(metadata={"alias": "bar_alias"}) + + assert ( + Foo.__attrs_attrs__.bar.metadata == + {"__psygnal_metadata": {"alias": "bar_alias"}} + ) + + """ + for a_field in attrs_fields: + metadata = getattr(a_field, "metadata", {}).get(PSYGNAL_METADATA, {}) + metadata = sanitize_field_options_dict(metadata) + options = FieldOptions(a_field.name, a_field.type, **metadata) + yield options + return + + # Add metadata for attrs dataclass + if is_msgspec_struct(cls): + """ + Example + ------- + from typing import Annotated + + from msgspec import Meta, Struct + + + class Foo(Struct): + bar: Annotated[ + str, + Meta(extra={"__psygnal_metadata": {"alias": "bar_alias"})) + ] = "" + + + print(Foo.__annotations__["bar"].__metadata__[0].extra) + # {"__psygnal_metadata": {"alias": "bar_alias"}} + + """ + for m_field in cls.__struct_fields__: + try: + type_, metadata = get_msgspec_metadata(cls, m_field) + metadata = sanitize_field_options_dict(metadata) + except AttributeError: + msg = f"Cannot parse field metadata for {m_field}: {type_}" + # logger.exception(msg) + print(msg) + type_, metadata = None, {} + options = FieldOptions(m_field, type_, **metadata) + yield options + return diff --git a/src/psygnal/_evented_decorator.py b/src/psygnal/_evented_decorator.py index 4b084b9c..eee2b29a 100644 --- a/src/psygnal/_evented_decorator.py +++ b/src/psygnal/_evented_decorator.py @@ -5,7 +5,10 @@ from psygnal._group_descriptor import SignalGroupDescriptor if TYPE_CHECKING: - from psygnal._group_descriptor import EqOperator, FieldAliasFunc + from psygnal._group_descriptor import ( # type: ignore[attr-defined] + EqOperator, + FieldAliasFunc, + ) __all__ = ["evented"] diff --git a/src/psygnal/_group_descriptor.py b/src/psygnal/_group_descriptor.py index f787c7e3..12f0f854 100644 --- a/src/psygnal/_group_descriptor.py +++ b/src/psygnal/_group_descriptor.py @@ -21,7 +21,7 @@ overload, ) -from ._dataclass_utils import iter_fields +from ._dataclass_utils import iter_fields_with_options from ._group import SignalGroup from ._signal import Signal, SignalInstance @@ -31,7 +31,8 @@ from psygnal._weak_callback import RefErrorChoice, WeakCallback - EqOperator: TypeAlias = Callable[[Any, Any], bool] + from ._dataclass_utils import EqOperator + FieldAliasFunc: TypeAlias = Callable[[str], Optional[str]] __all__ = ["is_evented", "get_evented_namespace", "SignalGroupDescriptor"] @@ -196,29 +197,44 @@ def _build_dataclass_signal_group( signals = {} # create a Signal for each field in the dataclass - for name, type_ in iter_fields(cls): - if name in _equality_operators: - if not callable(_equality_operators[name]): # pragma: no cover + for f in iter_fields_with_options(cls): + # skip this field + if f.skip: + continue + + # Equality operator + if f.eq is not None: + if not callable(f.eq): # pragma: no cover + raise TypeError("`eq` field metadata must be callable") + eq_map[f.name] = f.eq + elif f.name in _equality_operators: + if not callable(_equality_operators[f.name]): # pragma: no cover raise TypeError("EqOperator must be callable") - eq_map[name] = _equality_operators[name] + eq_map[f.name] = _equality_operators[f.name] else: - eq_map[name] = _pick_equality_operator(type_) + eq_map[f.name] = _pick_equality_operator(f.type_) # Resolve the signal name for the field sig_name: str | None - if name in _signal_aliases: # an alias has been provided in a mapping - sig_name = _signal_aliases[name] + if f.alias is not None: # an alias has been provided in the field metadata + _signal_aliases[f.name] = sig_name = f.alias + elif f.name in _signal_aliases: # an alias has been provided in a mapping + sig_name = _signal_aliases[f.name] elif callable(transform): # a callable has been provided - _signal_aliases[name] = sig_name = transform(name) - elif name in signal_group_sig_aliases: # an alias has been defined in the class - _signal_aliases[name] = sig_name = signal_group_sig_aliases[name] + _signal_aliases[f.name] = sig_name = transform(f.name) + elif f.name in signal_group_sig_aliases: # an alias was defined in the class + _signal_aliases[f.name] = sig_name = signal_group_sig_aliases[f.name] else: # no alias has been defined, use the field name as the signal name - sig_name = name + sig_name = f.name # An alias mapping or callable returned `None`, skip this field if sig_name is None: continue + # Create the signal, but the alias is set to None + if f.disable_setattr: + _signal_aliases[f.name] = None + # Repeated signal if sig_name in signals: key = next((k for k, v in _signal_aliases.items() if v == sig_name), None) @@ -239,7 +255,7 @@ def _build_dataclass_signal_group( continue # Create the Signal - field_type = object if type_ is None else type_ + field_type = object if f.type_ is None else f.type_ signals[sig_name] = sig = Signal(field_type, field_type) # patch in our custom SignalInstance class with maxargs=1 on connect_setattr sig._signal_instance_class = _DataclassFieldSignalInstance diff --git a/tests/test_custom_fields.py b/tests/test_custom_fields.py new file mode 100644 index 00000000..10d94360 --- /dev/null +++ b/tests/test_custom_fields.py @@ -0,0 +1,197 @@ +# from __future__ import annotations # breaks msgspec Annotated + +import contextlib +import sys +from typing import ClassVar, Dict, Optional +from unittest.mock import Mock + +import pytest + +from psygnal import ( + PSYGNAL_METADATA, + EmissionInfo, + SignalGroupDescriptor, + is_evented, +) + +Annotated = None +with contextlib.suppress(ImportError): + from typing import Annotated # type: ignore + + +min_py_version = pytest.mark.skipif( + sys.version_info < (3, 9), reason="needs typing.Annotated" +) + + +def get_signal_aliases(obj: object) -> Dict[str, Optional[str]]: + if not is_evented(obj): + return {} + return obj.events._psygnal_aliases + + +@pytest.mark.parametrize( + "type_", + [ + "dataclass", + "attrs", + pytest.param("pydantic", marks=min_py_version), + pytest.param("msgspec", marks=min_py_version), + ], +) +def test_field_metadata(type_: str) -> None: + a_metadata = {PSYGNAL_METADATA: {"alias": "a_changed"}} + b_metadata = {PSYGNAL_METADATA: {"eq": lambda s1, s2: s1.lower() == s2.lower()}} + c_metadata = {PSYGNAL_METADATA: {"skip": True}} + d_metadata = {PSYGNAL_METADATA: {"disable_setattr": True}} + + if type_ == "dataclass": + from dataclasses import dataclass, field + + @dataclass + class Base: + a: int = field(metadata=a_metadata) + events: ClassVar = SignalGroupDescriptor() + + @dataclass + class Foo(Base): + b: str = field(metadata=b_metadata) + + @dataclass + class Bar(Foo): + c: float = field(metadata=c_metadata) + + @dataclass + class Baz(Bar): + d: float = field(metadata=d_metadata) + + elif type_ == "attrs": + from attrs import define, field + + @define + class Base: + a: int = field(metadata=a_metadata) + events: ClassVar = SignalGroupDescriptor() + + @define + class Foo(Base): + b: str = field(metadata=b_metadata) + + @define + class Bar(Foo): + c: float = field(metadata=c_metadata) + + @define + class Baz(Bar): + d: float = field(metadata=d_metadata) + + elif type_ == "pydantic": + pytest.importorskip("pydantic", minversion="2") + from pydantic import BaseModel, Field + + class Base(BaseModel): + a: Annotated[int, a_metadata] + events: ClassVar = SignalGroupDescriptor() + + # Alternative, using Field `json_schema_extra` keyword argument + class Foo(Base): + b: str = Field(json_schema_extra=b_metadata) + + class Bar(Foo): + c: Annotated[float, c_metadata] + + class Baz(Bar): + d: Annotated[float, d_metadata] + + elif type_ == "msgspec": + msgspec = pytest.importorskip("msgspec") + + class Base(msgspec.Struct): # type: ignore + a: Annotated[int, msgspec.Meta(extra=a_metadata)] + events: ClassVar = SignalGroupDescriptor() + + class Foo(Base): + b: Annotated[str, msgspec.Meta(extra=b_metadata)] + + class Bar(Foo): + c: Annotated[float, msgspec.Meta(extra=c_metadata)] + + class Baz(Bar): + d: Annotated[float, msgspec.Meta(extra=d_metadata)] + + assert Bar.events is Base.events + + # Instantiate objects + base = Base(a=1) + foo = Foo(a=1, b="b") + bar = Bar(a=1, b="b", c=3.0) + bar2 = Bar(a=1, b="b", c=3.0) + baz = Baz(a=1, b="b", c=3.0, d=4.0) + + # the patching of __setattr__ should only happen once + # and it will happen only on the first access of .events + assert set(base.events) == {"a_changed"} + assert set(foo.events) == {"a_changed", "b"} + assert set(bar.events) == {"a_changed", "b"} + assert set(bar2.events) == {"a_changed", "b"} + assert set(baz.events) == {"a_changed", "b", "d"} + + assert get_signal_aliases(base) == {"a": "a_changed"} + assert get_signal_aliases(foo) == {"a": "a_changed"} + assert get_signal_aliases(bar) == {"a": "a_changed"} + assert get_signal_aliases(bar2) == {"a": "a_changed"} + assert get_signal_aliases(baz) == {"a": "a_changed", "d": None} + + mock = Mock() + assert not hasattr(foo.events, "a") + foo.events.a_changed.connect(mock) + foo.events.b.connect(mock) + baz.events.a_changed.connect(mock) + baz.events.b.connect(mock) + baz.events.d.connect(mock) + + # base doesn't affect subclass + assert not hasattr(base.events, "a") + base.events.a_changed.emit(1) + mock.assert_not_called() + + base.events.a_changed.emit(2) + mock.assert_not_called() + + # `alias` works + assert hasattr(foo.events, "a_changed") + foo.a = 2 + mock.assert_called_once_with(2, 1) + mock.reset_mock() + + baz.a = 2 + mock.assert_called_once_with(2, 1) + mock.reset_mock() + + # `eq` works + foo.b = "B" + mock.assert_not_called() + baz.b = "B" + mock.assert_not_called() + + foo.b = "C" + mock.assert_called_once_with("C", "B") + mock.reset_mock() + + # `skip` works + assert not hasattr(baz.events, "c") + + # `disable_setattr` works + baz.d = 5.0 + mock.assert_not_called() + + # Check all + mock1 = Mock() + baz.events.all.connect(mock1) + baz.c = 4.0 + mock1.assert_not_called() + baz.d = 6.0 + mock1.assert_not_called() + baz.a = 3 + assert hasattr(baz.events, "a_changed") + mock1.assert_called_once_with(EmissionInfo(baz.events.a_changed, (3, 2)))