From eccee5888fe441b5aed37aa3fbdb1486d95f9cc3 Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Wed, 19 Mar 2025 12:03:53 +0100 Subject: [PATCH 1/5] Initial Prototype to generate Infrahub Schema from Pydantic models --- .../python-sdk/examples/schema_pydantic.py | 39 ++ infrahub_sdk/schema/__init__.py | 6 + infrahub_sdk/schema/main.py | 2 +- infrahub_sdk/schema/pydantic_utils.py | 167 ++++++++ tests/unit/sdk/test_pydantic.py | 367 ++++++++++++++++++ 5 files changed, 580 insertions(+), 1 deletion(-) create mode 100644 docs/docs/python-sdk/examples/schema_pydantic.py create mode 100644 infrahub_sdk/schema/pydantic_utils.py create mode 100644 tests/unit/sdk/test_pydantic.py diff --git a/docs/docs/python-sdk/examples/schema_pydantic.py b/docs/docs/python-sdk/examples/schema_pydantic.py new file mode 100644 index 00000000..8767f7f9 --- /dev/null +++ b/docs/docs/python-sdk/examples/schema_pydantic.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from asyncio import run as aiorun + +from typing import Annotated + +from pydantic import BaseModel, Field +from infrahub_sdk import InfrahubClient +from rich import print as rprint +from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic + + +class Tag(BaseModel): + name: Annotated[str, AttrParam(unique=True), Field(description="The name of the tag")] + label: str | None = Field(description="The label of the tag") + description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None + + +class Car(BaseModel): + name: str = Field(description="The name of the car") + tags: list[Tag] + owner: Annotated[Person, RelParam(identifier="car__person")] + secondary_owner: Person | None = None + + +class Person(BaseModel): + name: str + cars: Annotated[list[Car] | None, RelParam(identifier="car__person")] = None + + +async def main(): + client = InfrahubClient() + schema = from_pydantic(models=[Person, Car, Tag]) + rprint(schema.to_schema_dict()) + response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True) + rprint(response) + +if __name__ == "__main__": + aiorun(main()) diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py index be1cfab9..6854da01 100644 --- a/infrahub_sdk/schema/__init__.py +++ b/infrahub_sdk/schema/__init__.py @@ -23,6 +23,7 @@ from ..graphql import Mutation from ..queries import SCHEMA_HASH_SYNC_STATUS from .main import ( + AttributeKind, AttributeSchema, AttributeSchemaAPI, BranchSchema, @@ -40,6 +41,7 @@ SchemaRootAPI, TemplateSchemaAPI, ) +from .pydantic_utils import InfrahubAttributeParam, InfrahubRelationshipParam, from_pydantic if TYPE_CHECKING: from ..client import InfrahubClient, InfrahubClientSync, SchemaType, SchemaTypeSync @@ -49,11 +51,14 @@ __all__ = [ + "AttributeKind", "AttributeSchema", "AttributeSchemaAPI", "BranchSupportType", "GenericSchema", "GenericSchemaAPI", + "InfrahubAttributeParam", + "InfrahubRelationshipParam", "NodeSchema", "NodeSchemaAPI", "ProfileSchemaAPI", @@ -64,6 +69,7 @@ "SchemaRoot", "SchemaRootAPI", "TemplateSchemaAPI", + "from_pydantic", ] diff --git a/infrahub_sdk/schema/main.py b/infrahub_sdk/schema/main.py index ba18cf49..007d74e2 100644 --- a/infrahub_sdk/schema/main.py +++ b/infrahub_sdk/schema/main.py @@ -344,7 +344,7 @@ class SchemaRoot(BaseModel): node_extensions: list[NodeExtensionSchema] = Field(default_factory=list) def to_schema_dict(self) -> dict[str, Any]: - return self.model_dump(exclude_unset=True, exclude_defaults=True) + return self.model_dump(exclude_defaults=True, mode="json") class SchemaRootAPI(BaseModel): diff --git a/infrahub_sdk/schema/pydantic_utils.py b/infrahub_sdk/schema/pydantic_utils.py new file mode 100644 index 00000000..60cc3b9a --- /dev/null +++ b/infrahub_sdk/schema/pydantic_utils.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import typing +from dataclasses import dataclass +from types import UnionType +from typing import Any + +from pydantic import BaseModel +from pydantic.fields import FieldInfo, PydanticUndefined + +from infrahub_sdk.schema.main import AttributeSchema, NodeSchema, RelationshipSchema, SchemaRoot + +from .main import AttributeKind, BranchSupportType, SchemaState + +KIND_MAPPING: dict[type, AttributeKind] = { + int: AttributeKind.NUMBER, + float: AttributeKind.NUMBER, + str: AttributeKind.TEXT, + bool: AttributeKind.BOOLEAN, +} + + +@dataclass +class InfrahubAttributeParam: + state: SchemaState = SchemaState.PRESENT + kind: AttributeKind | None = None + label: str | None = None + unique: bool = False + branch: BranchSupportType | None = None + + +@dataclass +class InfrahubRelationshipParam: + identifier: str | None = None + branch: BranchSupportType | None = None + + +@dataclass +class InfrahubFieldInfo: + name: str + types: list[type] + optional: bool + default: Any + + @property + def primary_type(self) -> type: + if len(self.types) == 0: + raise ValueError("No types found") + if self.is_list: + return typing.get_args(self.types[0])[0] + + return self.types[0] + + @property + def is_attribute(self) -> bool: + return self.primary_type in KIND_MAPPING + + @property + def is_relationship(self) -> bool: + return issubclass(self.primary_type, BaseModel) + + @property + def is_list(self) -> bool: + return typing.get_origin(self.types[0]) is list + + def to_dict(self) -> dict: + return { + "name": self.name, + "primary_type": self.primary_type, + "optional": self.optional, + "default": self.default, + "is_attribute": self.is_attribute, + "is_relationship": self.is_relationship, + "is_list": self.is_list, + } + + +def analyze_field(field_name: str, field: FieldInfo) -> InfrahubFieldInfo: + clean_types = [] + if isinstance(field.annotation, UnionType) or ( + hasattr(field.annotation, "_name") and field.annotation._name == "Optional" # type: ignore[union-attr] + ): + clean_types = [t for t in field.annotation.__args__ if t is not type(None)] # type: ignore[union-attr] + else: + clean_types.append(field.annotation) + + return InfrahubFieldInfo( + name=field.alias or field_name, + types=clean_types, + optional=not field.is_required(), + default=field.default if field.default is not PydanticUndefined else None, + ) + + +def get_attribute_kind(field: FieldInfo) -> AttributeKind: + if field.annotation in KIND_MAPPING: + return KIND_MAPPING[field.annotation] + + if isinstance(field.annotation, UnionType) or ( + hasattr(field.annotation, "_name") and field.annotation._name == "Optional" # type: ignore[union-attr] + ): + valid_types = [t for t in field.annotation.__args__ if t is not type(None)] # type: ignore[union-attr] + if len(valid_types) == 1 and valid_types[0] in KIND_MAPPING: + return KIND_MAPPING[valid_types[0]] + + raise ValueError(f"Unknown field type: {field.annotation}") + + +def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: FieldInfo) -> AttributeSchema: # noqa: ARG001 + field_param = InfrahubAttributeParam() + field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubAttributeParam)] + if len(field_params) == 1: + field_param = field_params[0] + + return AttributeSchema( + name=field_name, + label=field_param.label, + description=field.description, + kind=field_param.kind or get_attribute_kind(field), + optional=not field.is_required(), + unique=field_param.unique, + branch=field_param.branch, + ) + + +def field_to_relationship( + field_name: str, + field_info: InfrahubFieldInfo, + field: FieldInfo, + namespace: str = "Testing", +) -> RelationshipSchema: + field_param = InfrahubRelationshipParam() + field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubRelationshipParam)] + if len(field_params) == 1: + field_param = field_params[0] + + return RelationshipSchema( + name=field_name, + description=field.description, + peer=f"{namespace}{field_info.primary_type.__name__}", + identifier=field_param.identifier, + cardinality="many" if field_info.is_list else "one", + optional=field_info.optional, + branch=field_param.branch, + ) + + +def from_pydantic(models: list[type[BaseModel]], namespace: str = "Testing") -> SchemaRoot: + schema = SchemaRoot(version="1.0") + + for model in models: + node = NodeSchema( + name=model.__name__, + namespace=namespace, + ) + + for field_name, field in model.model_fields.items(): + field_info = analyze_field(field_name, field) + + if field_info.is_attribute: + node.attributes.append(field_to_attribute(field_name, field_info, field)) + elif field_info.is_relationship: + node.relationships.append(field_to_relationship(field_name, field_info, field, namespace)) + + schema.nodes.append(node) + + return schema diff --git a/tests/unit/sdk/test_pydantic.py b/tests/unit/sdk/test_pydantic.py new file mode 100644 index 00000000..e5d18359 --- /dev/null +++ b/tests/unit/sdk/test_pydantic.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +from typing import Annotated, Optional + +import pytest +from pydantic import BaseModel, Field + +from infrahub_sdk.schema.main import AttributeKind, AttributeSchema, RelationshipSchema +from infrahub_sdk.schema.pydantic_utils import ( + InfrahubAttributeParam as AttrParam, +) +from infrahub_sdk.schema.pydantic_utils import ( + analyze_field, + field_to_attribute, + field_to_relationship, + from_pydantic, + get_attribute_kind, +) + + +class MyModel(BaseModel): + name: str + age: int + is_active: bool + opt_age: int | None = None + default_name: str = "some_default" + old_opt_age: Optional[int] = None # noqa: UP007 + + +class Tag(BaseModel): + name: str = Field(default="test_tag", description="The name of the tag") + description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None + label: Annotated[str, AttrParam(unique=True), Field(description="The label of the tag")] + + +class Car(BaseModel): + name: str + tags: list[Tag] + owner: Person + secondary_owner: Person | None = None + + +class Person(BaseModel): + name: str + cars: list[Car] | None = None + + +@pytest.mark.parametrize( + "field_name, expected_kind", + [ + ("name", "Text"), + ("age", "Number"), + ("is_active", "Boolean"), + ("opt_age", "Number"), + ("default_name", "Text"), + ("old_opt_age", "Number"), + ], +) +def test_get_field_kind(field_name, expected_kind): + assert get_attribute_kind(MyModel.model_fields[field_name]) == expected_kind + + +@pytest.mark.parametrize( + "field_name, model, expected", + [ + ( + "name", + MyModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "name", + "optional": False, + "primary_type": str, + }, + ), + ( + "age", + MyModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "age", + "optional": False, + "primary_type": int, + }, + ), + ( + "is_active", + MyModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "is_active", + "optional": False, + "primary_type": bool, + }, + ), + ( + "opt_age", + MyModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "opt_age", + "optional": True, + "primary_type": int, + }, + ), + ( + "default_name", + MyModel, + { + "default": "some_default", + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "default_name", + "optional": True, + "primary_type": str, + }, + ), + ( + "old_opt_age", + MyModel, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "old_opt_age", + "optional": True, + "primary_type": int, + }, + ), + ( + "description", + Tag, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "description", + "optional": True, + "primary_type": str, + }, + ), + ( + "name", + Tag, + { + "default": "test_tag", + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "name", + "optional": True, + "primary_type": str, + }, + ), + ( + "label", + Tag, + { + "default": None, + "is_attribute": True, + "is_list": False, + "is_relationship": False, + "name": "label", + "optional": False, + "primary_type": str, + }, + ), + ( + "owner", + Car, + { + "default": None, + "is_attribute": False, + "is_list": False, + "is_relationship": True, + "name": "owner", + "optional": False, + "primary_type": Person, + }, + ), + ( + "tags", + Car, + { + "default": None, + "is_attribute": False, + "is_list": True, + "is_relationship": True, + "name": "tags", + "optional": False, + "primary_type": Tag, + }, + ), + ( + "secondary_owner", + Car, + { + "default": None, + "is_attribute": False, + "is_list": False, + "is_relationship": True, + "name": "secondary_owner", + "optional": True, + "primary_type": Person, + }, + ), + ], +) +def test_analyze_field(field_name: str, model: BaseModel, expected: dict): + assert analyze_field(field_name, model.model_fields[field_name]).to_dict() == expected + + +@pytest.mark.parametrize( + "field_name, model, expected", + [ + ( + "name", + MyModel, + AttributeSchema( + name="name", + kind=AttributeKind.TEXT, + optional=False, + ), + ), + ( + "age", + MyModel, + AttributeSchema( + name="age", + kind=AttributeKind.NUMBER, + optional=False, + ), + ), + ( + "is_active", + MyModel, + AttributeSchema( + name="is_active", + kind=AttributeKind.BOOLEAN, + optional=False, + ), + ), + ( + "opt_age", + MyModel, + AttributeSchema( + name="opt_age", + kind=AttributeKind.NUMBER, + optional=True, + ), + ), + ( + "default_name", + MyModel, + AttributeSchema( + name="default_name", + kind=AttributeKind.TEXT, + optional=True, + default="some_default", + ), + ), + ( + "old_opt_age", + MyModel, + AttributeSchema( + name="old_opt_age", + kind=AttributeKind.NUMBER, + optional=True, + ), + ), + ( + "description", + Tag, + AttributeSchema( + name="description", + kind=AttributeKind.TEXTAREA, + optional=True, + ), + ), + ( + "name", + Tag, + AttributeSchema( + name="name", + description="The name of the tag", + kind=AttributeKind.TEXT, + optional=True, + ), + ), + ( + "label", + Tag, + AttributeSchema( + name="label", + description="The label of the tag", + kind=AttributeKind.TEXT, + optional=False, + unique=True, + ), + ), + ], +) +def test_field_to_attribute(field_name: str, model: BaseModel, expected: AttributeSchema): + field = model.model_fields[field_name] + field_info = analyze_field(field_name, field) + assert field_to_attribute(field_name, field_info, field) == expected + + +@pytest.mark.parametrize( + "field_name, model, expected", + [ + ( + "owner", + Car, + RelationshipSchema( + name="owner", + peer="TestingPerson", + cardinality="one", + optional=False, + ), + ), + ( + "tags", + Car, + RelationshipSchema( + name="tags", + peer="TestingTag", + cardinality="many", + optional=False, + ), + ), + ( + "secondary_owner", + Car, + RelationshipSchema( + name="secondary_owner", + peer="TestingPerson", + cardinality="one", + optional=True, + ), + ), + ], +) +def test_field_to_relationship(field_name: str, model: BaseModel, expected: RelationshipSchema): + field = model.model_fields[field_name] + field_info = analyze_field(field_name, field) + assert field_to_relationship(field_name, field_info, field) == expected + + +def test_related_models(): + schemas = from_pydantic(models=[Person, Car, Tag]) + assert len(schemas.nodes) == 3 From 955054ca6b8b3c7ba4568e0220d2e9eb51ca4e3a Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Mon, 24 Mar 2025 12:00:19 +0100 Subject: [PATCH 2/5] Add typing support for get | filters | all methods when using Pydantic --- .../{schema_pydantic.py => pydantic_car.py} | 29 +- .../python-sdk/examples/pydantic_infra.py | 113 +++++++ infrahub_sdk/client.py | 238 +++++++++++++-- infrahub_sdk/schema/__init__.py | 25 +- infrahub_sdk/schema/pydantic_utils.py | 247 +++++++++++++-- tests/unit/sdk/test_pydantic.py | 285 +++++++++++++++--- 6 files changed, 838 insertions(+), 99 deletions(-) rename docs/docs/python-sdk/examples/{schema_pydantic.py => pydantic_car.py} (53%) create mode 100644 docs/docs/python-sdk/examples/pydantic_infra.py diff --git a/docs/docs/python-sdk/examples/schema_pydantic.py b/docs/docs/python-sdk/examples/pydantic_car.py similarity index 53% rename from docs/docs/python-sdk/examples/schema_pydantic.py rename to docs/docs/python-sdk/examples/pydantic_car.py index 8767f7f9..321c7d31 100644 --- a/docs/docs/python-sdk/examples/schema_pydantic.py +++ b/docs/docs/python-sdk/examples/pydantic_car.py @@ -4,36 +4,47 @@ from typing import Annotated -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, ConfigDict from infrahub_sdk import InfrahubClient from rich import print as rprint -from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic +from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic, NodeSchema, NodeModel, GenericModel -class Tag(BaseModel): +class Tag(NodeModel): + model_config = ConfigDict( + node_schema=NodeSchema(name="Tag", namespace="Test", human_readable_fields=["name__value"]) + ) + name: Annotated[str, AttrParam(unique=True), Field(description="The name of the tag")] label: str | None = Field(description="The label of the tag") description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None -class Car(BaseModel): +class TestCar(NodeModel): name: str = Field(description="The name of the car") tags: list[Tag] - owner: Annotated[Person, RelParam(identifier="car__person")] - secondary_owner: Person | None = None + owner: Annotated[TestPerson, RelParam(identifier="car__person")] + secondary_owner: TestPerson | None = None -class Person(BaseModel): +class TestPerson(GenericModel): name: str - cars: Annotated[list[Car] | None, RelParam(identifier="car__person")] = None + +class TestCarOwner(NodeModel, TestPerson): + cars: Annotated[list[TestCar] | None, RelParam(identifier="car__person")] = None async def main(): client = InfrahubClient() - schema = from_pydantic(models=[Person, Car, Tag]) + schema = from_pydantic(models=[TestPerson, TestCar, Tag, TestPerson, TestCarOwner]) rprint(schema.to_schema_dict()) response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True) rprint(response) + # Create a Tag + tag = await client.create("TestTag", name="Blue", label="Blue") + await tag.save(allow_upsert=True) + + if __name__ == "__main__": aiorun(main()) diff --git a/docs/docs/python-sdk/examples/pydantic_infra.py b/docs/docs/python-sdk/examples/pydantic_infra.py new file mode 100644 index 00000000..5ff137a4 --- /dev/null +++ b/docs/docs/python-sdk/examples/pydantic_infra.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from asyncio import run as aiorun + +from infrahub_sdk.async_typer import AsyncTyper + +from typing import Annotated + +from pydantic import BaseModel, Field, ConfigDict +from infrahub_sdk import InfrahubClient +from rich import print as rprint +from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic, NodeSchema, NodeModel, GenericSchema, GenericModel, RelationshipKind + + +app = AsyncTyper() + + +class Site(NodeModel): + model_config = ConfigDict( + node_schema=NodeSchema(name="Site", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"]) + ) + + name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the site") + + +class Vlan(NodeModel): + model_config = ConfigDict( + node_schema=NodeSchema(name="Vlan", namespace="Infra", human_friendly_id=["vlan_id__value"], display_labels=["vlan_id__value"]) + ) + + name: str + vlan_id: int + description: str | None = None + + +class Device(NodeModel): + model_config = ConfigDict( + node_schema=NodeSchema(name="Device", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"]) + ) + + name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the car") + site: Annotated[Site, RelParam(kind=RelationshipKind.ATTRIBUTE, identifier="device__site")] + interfaces: Annotated[list[Interface], RelParam(kind=RelationshipKind.COMPONENT, identifier="device__interfaces")] = Field(default_factory=list) + + +class Interface(GenericModel): + model_config = ConfigDict( + generic_schema=GenericSchema(name="Interface", namespace="Infra", human_friendly_id=["device__name__value", "name__value"], display_labels=["name__value"]) + ) + + device: Annotated[Device, RelParam(kind=RelationshipKind.PARENT, identifier="device__interfaces")] + name: str + description: str | None = None + +class L2Interface(Interface): + model_config = ConfigDict( + node_schema=NodeSchema(name="L2Interface", namespace="Infra") + ) + + vlans: list[Vlan] = Field(default_factory=list) + +class LoopbackInterface(Interface): + model_config = ConfigDict( + node_schema=NodeSchema(name="LoopbackInterface", namespace="Infra") + ) + + + +@app.command() +async def load_schema(): + client = InfrahubClient() + schema = from_pydantic(models=[Site, Device, Interface, L2Interface, LoopbackInterface, Vlan]) + rprint(schema.to_schema_dict()) + response = await client.schema.load(schemas=[schema.to_schema_dict()], wait_until_converged=True) + rprint(response) + + +@app.command() +async def load_data(): + client = InfrahubClient() + + atl = await client.create("InfraSite", name="ATL") + await atl.save(allow_upsert=True) + cdg = await client.create("InfraSite", name="CDG") + await cdg.save(allow_upsert=True) + + device1 = await client.create("InfraDevice", name="atl1-dev1", site=atl) + await device1.save(allow_upsert=True) + device2 = await client.create("InfraDevice", name="atl1-dev2", site=atl) + await device2.save(allow_upsert=True) + + lo0dev1 = await client.create("InfraLoopbackInterface", name="lo0", device=device1) + await lo0dev1.save(allow_upsert=True) + lo0dev2 = await client.create("InfraLoopbackInterface", name="lo0", device=device2) + await lo0dev2.save(allow_upsert=True) + + for idx in range(1, 3): + interface = await client.create("InfraL2Interface", name=f"Ethernet{idx}", device=device1) + await interface.save(allow_upsert=True) + + +@app.command() +async def query_data(): + client = InfrahubClient() + sites = await client.all(kind=Site) + + breakpoint() + devices = await client.all(kind=Device) + for device in devices: + rprint(device) + +if __name__ == "__main__": + app() \ No newline at end of file diff --git a/infrahub_sdk/client.py b/infrahub_sdk/client.py index 16c1c73a..2a870021 100644 --- a/infrahub_sdk/client.py +++ b/infrahub_sdk/client.py @@ -50,7 +50,7 @@ from .protocols_base import CoreNode, CoreNodeSync from .queries import QUERY_USER, get_commit_update_mutation from .query_groups import InfrahubGroupContext, InfrahubGroupContextSync -from .schema import InfrahubSchema, InfrahubSchemaSync, NodeSchemaAPI +from .schema import InfrahubSchema, InfrahubSchemaSync, NodeSchemaAPI, SchemaModel from .store import NodeStore, NodeStoreSync from .task.manager import InfrahubTaskManager, InfrahubTaskManagerSync from .timestamp import Timestamp @@ -63,6 +63,7 @@ from .context import RequestContext +SchemaModelType = TypeVar("SchemaModelType", bound=SchemaModel) SchemaType = TypeVar("SchemaType", bound=CoreNode) SchemaTypeSync = TypeVar("SchemaTypeSync", bound=CoreNodeSync) @@ -417,6 +418,63 @@ async def get( **kwargs: Any, ) -> SchemaType: ... + @overload + async def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[False], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType | None: ... + + @overload + async def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[True], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + + @overload + async def get( + self, + kind: type[SchemaModelType], + raise_when_missing: bool = ..., + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + @overload async def get( self, @@ -476,7 +534,7 @@ async def get( async def get( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, raise_when_missing: bool = True, at: Timestamp | None = None, branch: str | None = None, @@ -490,7 +548,7 @@ async def get( prefetch_relationships: bool = False, property: bool = False, **kwargs: Any, - ) -> InfrahubNode | SchemaType | None: + ) -> InfrahubNode | SchemaType | SchemaModelType | None: branch = branch or self.default_branch schema = await self.schema.get(kind=kind, branch=branch) @@ -573,7 +631,7 @@ async def _process_nodes_and_relationships( async def count( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -623,6 +681,25 @@ async def all( order: Order | None = ..., ) -> list[SchemaType]: ... + @overload + async def all( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + ) -> list[SchemaModelType]: ... + @overload async def all( self, @@ -644,7 +721,7 @@ async def all( async def all( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -658,7 +735,7 @@ async def all( property: bool = False, parallel: bool = False, order: Order | None = None, - ) -> list[InfrahubNode] | list[SchemaType]: + ) -> list[InfrahubNode] | list[SchemaType] | list[SchemaModelType]: """Retrieve all nodes of a given kind Args: @@ -717,6 +794,27 @@ async def filters( **kwargs: Any, ) -> list[SchemaType]: ... + @overload + async def filters( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + partial_match: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + **kwargs: Any, + ) -> list[SchemaModelType]: ... + @overload async def filters( self, @@ -740,7 +838,7 @@ async def filters( async def filters( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -756,7 +854,7 @@ async def filters( parallel: bool = False, order: Order | None = None, **kwargs: Any, - ) -> list[InfrahubNode] | list[SchemaType]: + ) -> list[InfrahubNode] | list[SchemaType] | list[SchemaModelType]: """Retrieve nodes of a given kind based on provided filters. Args: @@ -780,6 +878,7 @@ async def filters( list[InfrahubNodeSync]: List of Nodes that match the given filters. """ branch = branch or self.default_branch + schema = await self.schema.get(kind=kind, branch=branch) if at: at = Timestamp(at) @@ -867,7 +966,11 @@ async def process_non_batch() -> tuple[list[InfrahubNode], list[InfrahubNode]]: related_nodes = list(set(related_nodes)) for node in related_nodes: if node.id: - self.store.set(node=node) + self.store.set(key=node.id, node=node) + + if isinstance(kind, type) and issubclass(kind, SchemaModel): + return [kind.from_node(node) for node in nodes] # type: ignore[return-value] + return nodes def clone(self, branch: str | None = None) -> InfrahubClient: @@ -1702,7 +1805,7 @@ def execute_graphql( def count( self, - kind: str | type[SchemaType], + kind: type[SchemaType | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -1752,6 +1855,25 @@ def all( order: Order | None = ..., ) -> list[SchemaTypeSync]: ... + @overload + def all( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + ) -> list[SchemaModelType]: ... + @overload def all( self, @@ -1773,7 +1895,7 @@ def all( def all( self, - kind: str | type[SchemaTypeSync], + kind: type[SchemaTypeSync | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -1787,7 +1909,7 @@ def all( property: bool = False, parallel: bool = False, order: Order | None = None, - ) -> list[InfrahubNodeSync] | list[SchemaTypeSync]: + ) -> list[InfrahubNodeSync] | list[SchemaTypeSync] | list[SchemaModelType]: """Retrieve all nodes of a given kind Args: @@ -1881,6 +2003,27 @@ def filters( **kwargs: Any, ) -> list[SchemaTypeSync]: ... + @overload + def filters( + self, + kind: type[SchemaModelType], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + populate_store: bool = ..., + offset: int | None = ..., + limit: int | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + partial_match: bool = ..., + property: bool = ..., + parallel: bool = ..., + order: Order | None = ..., + **kwargs: Any, + ) -> list[SchemaModelType]: ... + @overload def filters( self, @@ -1904,7 +2047,7 @@ def filters( def filters( self, - kind: str | type[SchemaTypeSync], + kind: type[SchemaTypeSync | SchemaModelType] | str, at: Timestamp | None = None, branch: str | None = None, timeout: int | None = None, @@ -1920,7 +2063,7 @@ def filters( parallel: bool = False, order: Order | None = None, **kwargs: Any, - ) -> list[InfrahubNodeSync] | list[SchemaTypeSync]: + ) -> list[InfrahubNodeSync] | list[SchemaTypeSync] | list[SchemaModelType]: """Retrieve nodes of a given kind based on provided filters. Args: @@ -2033,7 +2176,11 @@ def process_non_batch() -> tuple[list[InfrahubNodeSync], list[InfrahubNodeSync]] related_nodes = list(set(related_nodes)) for node in related_nodes: if node.id: - self.store.set(node=node) + self.store.set(key=node.id, node=node) + + if isinstance(kind, type) and issubclass(kind, SchemaModel): + return [kind.from_node(node) for node in nodes] # type: ignore[return-value] + return nodes @overload @@ -2093,6 +2240,63 @@ def get( **kwargs: Any, ) -> SchemaTypeSync: ... + @overload + def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[False], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType | None: ... + + @overload + def get( + self, + kind: type[SchemaModelType], + raise_when_missing: Literal[True], + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + + @overload + def get( + self, + kind: type[SchemaModelType], + raise_when_missing: bool = ..., + at: Timestamp | None = ..., + branch: str | None = ..., + timeout: int | None = ..., + id: str | None = ..., + hfid: list[str] | None = ..., + include: list[str] | None = ..., + exclude: list[str] | None = ..., + populate_store: bool = ..., + fragment: bool = ..., + prefetch_relationships: bool = ..., + property: bool = ..., + **kwargs: Any, + ) -> SchemaModelType: ... + @overload def get( self, @@ -2152,7 +2356,7 @@ def get( def get( self, - kind: str | type[SchemaTypeSync], + kind: type[SchemaTypeSync | SchemaModelType] | str, raise_when_missing: bool = True, at: Timestamp | None = None, branch: str | None = None, @@ -2166,7 +2370,7 @@ def get( prefetch_relationships: bool = False, property: bool = False, **kwargs: Any, - ) -> InfrahubNodeSync | SchemaTypeSync | None: + ) -> InfrahubNodeSync | SchemaTypeSync | SchemaModelType | None: branch = branch or self.default_branch schema = self.schema.get(kind=kind, branch=branch) diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py index 6854da01..118a9599 100644 --- a/infrahub_sdk/schema/__init__.py +++ b/infrahub_sdk/schema/__init__.py @@ -41,10 +41,17 @@ SchemaRootAPI, TemplateSchemaAPI, ) -from .pydantic_utils import InfrahubAttributeParam, InfrahubRelationshipParam, from_pydantic +from .pydantic_utils import ( + GenericModel, + InfrahubAttributeParam, + InfrahubRelationshipParam, + NodeModel, + SchemaModel, + from_pydantic, +) if TYPE_CHECKING: - from ..client import InfrahubClient, InfrahubClientSync, SchemaType, SchemaTypeSync + from ..client import InfrahubClient, InfrahubClientSync, SchemaModelType, SchemaType, SchemaTypeSync from ..node import InfrahubNode, InfrahubNodeSync InfrahubNodeTypes = Union[InfrahubNode, InfrahubNodeSync] @@ -55,10 +62,12 @@ "AttributeSchema", "AttributeSchemaAPI", "BranchSupportType", + "GenericModel", "GenericSchema", "GenericSchemaAPI", "InfrahubAttributeParam", "InfrahubRelationshipParam", + "NodeModel", "NodeSchema", "NodeSchemaAPI", "ProfileSchemaAPI", @@ -66,6 +75,7 @@ "RelationshipKind", "RelationshipSchema", "RelationshipSchemaAPI", + "SchemaModel", "SchemaRoot", "SchemaRootAPI", "TemplateSchemaAPI", @@ -190,14 +200,17 @@ def _validate_load_schema_response(response: httpx.Response) -> SchemaLoadRespon raise InvalidResponseError(message=f"Invalid response received from server HTTP {response.status_code}") @staticmethod - def _get_schema_name(schema: type[SchemaType | SchemaTypeSync] | str) -> str: + def _get_schema_name(schema: type[SchemaType | SchemaTypeSync | SchemaModelType] | str) -> str: if hasattr(schema, "_is_runtime_protocol") and schema._is_runtime_protocol: # type: ignore[union-attr] return schema.__name__ # type: ignore[union-attr] + if isinstance(schema, type) and issubclass(schema, SchemaModel): + return schema.get_kind() + if isinstance(schema, str): return schema - raise ValueError("schema must be a protocol or a string") + raise ValueError("schema must be a protocol, a SchemaModel, or a string") @staticmethod def _parse_schema_response(response: httpx.Response, branch: str) -> MutableMapping[str, Any]: @@ -233,7 +246,7 @@ class InfrahubSchema(InfrahubSchemaBase): async def get( self, - kind: type[SchemaType | SchemaTypeSync] | str, + kind: type[SchemaType | SchemaTypeSync | SchemaModelType] | str, branch: str | None = None, refresh: bool = False, timeout: int | None = None, @@ -528,7 +541,7 @@ def all( def get( self, - kind: type[SchemaType | SchemaTypeSync] | str, + kind: type[SchemaType | SchemaTypeSync | SchemaModelType] | str, branch: str | None = None, refresh: bool = False, timeout: int | None = None, diff --git a/infrahub_sdk/schema/pydantic_utils.py b/infrahub_sdk/schema/pydantic_utils.py index 60cc3b9a..0f605c57 100644 --- a/infrahub_sdk/schema/pydantic_utils.py +++ b/infrahub_sdk/schema/pydantic_utils.py @@ -1,16 +1,29 @@ from __future__ import annotations +import re import typing from dataclasses import dataclass from types import UnionType -from typing import Any +from typing import TYPE_CHECKING, Any -from pydantic import BaseModel +from pydantic import BaseModel, Field from pydantic.fields import FieldInfo, PydanticUndefined - -from infrahub_sdk.schema.main import AttributeSchema, NodeSchema, RelationshipSchema, SchemaRoot - -from .main import AttributeKind, BranchSupportType, SchemaState +from typing_extensions import Self + +from .main import ( + AttributeKind, + AttributeSchema, + BranchSupportType, + GenericSchema, + NodeSchema, + RelationshipKind, + RelationshipSchema, + SchemaRoot, + SchemaState, +) + +if TYPE_CHECKING: + from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync KIND_MAPPING: dict[type, AttributeKind] = { int: AttributeKind.NUMBER, @@ -19,6 +32,42 @@ bool: AttributeKind.BOOLEAN, } +NAMESPACE_REGEX = r"^[A-Z][a-z0-9]+$" +NODE_KIND_REGEX = r"^[A-Z][a-zA-Z0-9]+$" + + +class SchemaModel(BaseModel): + id: str | None = Field(default=None, description="The ID of the node") + + @classmethod + def get_kind(cls) -> str: + return get_kind(cls) + + @classmethod + def from_node(cls, node: InfrahubNode | InfrahubNodeSync) -> Self: + data = {} + for field_name, field in cls.model_fields.items(): + field_info = analyze_field(field_name, field) + if field_name == "id": + data[field_name] = node.id + elif field_info.is_attribute: + attr = getattr(node, field_name) + data[field_name] = attr.value + + # elif field_info.is_relationship: + # rel = getattr(node, field_name) + # data[field_name] = rel.value + + return cls(**data) + + +class NodeModel(SchemaModel): + pass + + +class GenericModel(SchemaModel): + pass + @dataclass class InfrahubAttributeParam: @@ -31,6 +80,7 @@ class InfrahubAttributeParam: @dataclass class InfrahubRelationshipParam: + kind: RelationshipKind | None = None identifier: str | None = None branch: BranchSupportType | None = None @@ -46,6 +96,10 @@ class InfrahubFieldInfo: def primary_type(self) -> type: if len(self.types) == 0: raise ValueError("No types found") + + # if isinstance(self.primary_type, ForwardRef): + # raise TypeError("Forward References are not supported yet, please ensure the models are defined in the right order") + if self.is_list: return typing.get_args(self.types[0])[0] @@ -61,6 +115,7 @@ def is_relationship(self) -> bool: @property def is_list(self) -> bool: + # breakpoint() return typing.get_origin(self.types[0]) is list def to_dict(self) -> dict: @@ -106,12 +161,16 @@ def get_attribute_kind(field: FieldInfo) -> AttributeKind: raise ValueError(f"Unknown field type: {field.annotation}") -def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: FieldInfo) -> AttributeSchema: # noqa: ARG001 +def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: FieldInfo) -> AttributeSchema: field_param = InfrahubAttributeParam() field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubAttributeParam)] if len(field_params) == 1: field_param = field_params[0] + pattern = field._attributes_set.get("pattern", None) + max_length = field._attributes_set.get("max_length", None) + min_length = field._attributes_set.get("min_length", None) + return AttributeSchema( name=field_name, label=field_param.label, @@ -120,6 +179,10 @@ def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: Fi optional=not field.is_required(), unique=field_param.unique, branch=field_param.branch, + default_value=field_info.default, + regex=str(pattern) if pattern else None, + max_length=int(str(max_length)) if max_length else None, + min_length=int(str(min_length)) if min_length else None, ) @@ -127,7 +190,6 @@ def field_to_relationship( field_name: str, field_info: InfrahubFieldInfo, field: FieldInfo, - namespace: str = "Testing", ) -> RelationshipSchema: field_param = InfrahubRelationshipParam() field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubRelationshipParam)] @@ -137,7 +199,7 @@ def field_to_relationship( return RelationshipSchema( name=field_name, description=field.description, - peer=f"{namespace}{field_info.primary_type.__name__}", + peer=get_kind(field_info.primary_type), identifier=field_param.identifier, cardinality="many" if field_info.is_list else "one", optional=field_info.optional, @@ -145,23 +207,164 @@ def field_to_relationship( ) -def from_pydantic(models: list[type[BaseModel]], namespace: str = "Testing") -> SchemaRoot: - schema = SchemaRoot(version="1.0") +def extract_validate_generic(model: type[BaseModel]) -> list[str]: + return [get_kind(ancestor) for ancestor in model.__bases__ if issubclass(ancestor, GenericModel)] - for model in models: - node = NodeSchema( - name=model.__name__, - namespace=namespace, + +def validate_kind(kind: str) -> tuple[str, str]: + # First, handle transition from a lowercase to uppercase + name_with_spaces = re.sub(r"([a-z])([A-Z])", r"\1 \2", kind) + + # Then, handle consecutive uppercase letters followed by a lowercase + # (e.g., "HTTPRequest" -> "HTTP Request") + name_with_spaces = re.sub(r"([A-Z])([A-Z][a-z])", r"\1 \2", name_with_spaces) + + name_parts = name_with_spaces.split(" ") + + if len(name_parts) == 1: + raise ValueError(f"Invalid kind: {kind}, must contain a Namespace and a Name") + kind_namespace = name_parts[0] + kind_name = "".join(name_parts[1:]) + + if not kind_namespace[0].isupper(): + raise ValueError(f"Invalid namespace: {kind_namespace}, must start with an uppercase letter") + + return kind_namespace, kind_name + + +def is_generic(model: type[BaseModel]) -> bool: + return GenericModel in model.__bases__ + + +def get_kind(model: type[BaseModel]) -> str: + node_schema: NodeSchema | None = model.model_config.get("node_schema") or None # type: ignore[assignment] + generic_schema: GenericSchema | None = model.model_config.get("generic_schema") or None # type: ignore[assignment] + + if is_generic(model) and generic_schema: + return generic_schema.kind + if node_schema: + return node_schema.kind + namespace, name = validate_kind(model.__name__) + return f"{namespace}{name}" + + +def get_generics(model: type[BaseModel]) -> list[type[GenericModel]]: + return [ancestor for ancestor in model.__bases__ if issubclass(ancestor, GenericModel)] + + +def _add_fields( + node: NodeSchema | GenericSchema, model: type[BaseModel], inherited_fields: dict[str, dict[str, Any]] | None = None +) -> None: + for field_name, field in model.model_fields.items(): + if ( + inherited_fields + and field_name in inherited_fields + and field._attributes_set == inherited_fields[field_name] + ): + continue + + if field_name == "id": + continue + + field_info = analyze_field(field_name, field) + + if field_info.is_attribute: + node.attributes.append(field_to_attribute(field_name, field_info, field)) + elif field_info.is_relationship: + node.relationships.append(field_to_relationship(field_name, field_info, field)) + + +def model_to_node(model: type[BaseModel]) -> NodeSchema | GenericSchema: + # ------------------------------------------------------------ + # GenericSchema + # ------------------------------------------------------------ + if GenericModel in model.__bases__: + generic_schema: GenericSchema | None = model.model_config.get("generic_schema") or None # type: ignore[assignment] + + if not generic_schema: + namespace, name = validate_kind(model.__name__) + + generic = GenericSchema( + name=generic_schema.name if generic_schema else name, + namespace=generic_schema.namespace if generic_schema else namespace, + display_labels=generic_schema.display_labels if generic_schema else None, + description=generic_schema.description if generic_schema else None, + state=generic_schema.state if generic_schema else SchemaState.PRESENT, + label=generic_schema.label if generic_schema else None, + include_in_menu=generic_schema.include_in_menu if generic_schema else None, + menu_placement=generic_schema.menu_placement if generic_schema else None, + documentation=generic_schema.documentation if generic_schema else None, + order_by=generic_schema.order_by if generic_schema else None, + # parent=schema.parent if schema else None, + # children=schema.children if schema else None, + icon=generic_schema.icon if generic_schema else None, + # generate_profile=schema.generate_profile if schema else None, + # branch=schema.branch if schema else None, + # default_filter=schema.default_filter if schema else None, ) + _add_fields(node=generic, model=model) + return generic + + # ------------------------------------------------------------ + # NodeSchema + # ------------------------------------------------------------ + node_schema: NodeSchema | None = model.model_config.get("node_schema") or None # type: ignore[assignment] + + if not node_schema: + namespace, name = validate_kind(model.__name__) + + generics = get_generics(model) + + # list all inherited fields with a hash for each to track if they are identical on the node + inherited_fields = { + field_name: field._attributes_set for generic in generics for field_name, field in generic.model_fields.items() + } + + node = NodeSchema( + name=node_schema.name if node_schema else name, + namespace=node_schema.namespace if node_schema else namespace, + display_labels=node_schema.display_labels if node_schema else None, + description=node_schema.description if node_schema else None, + state=node_schema.state if node_schema else SchemaState.PRESENT, + label=node_schema.label if node_schema else None, + include_in_menu=node_schema.include_in_menu if node_schema else None, + menu_placement=node_schema.menu_placement if node_schema else None, + documentation=node_schema.documentation if node_schema else None, + order_by=node_schema.order_by if node_schema else None, + inherit_from=[get_kind(generic) for generic in generics], + parent=node_schema.parent if node_schema else None, + children=node_schema.children if node_schema else None, + icon=node_schema.icon if node_schema else None, + generate_profile=node_schema.generate_profile if node_schema else None, + branch=node_schema.branch if node_schema else None, + # default_filter=schema.default_filter if schema else None, + ) + + _add_fields(node=node, model=model, inherited_fields=inherited_fields) + return node - for field_name, field in model.model_fields.items(): - field_info = analyze_field(field_name, field) - if field_info.is_attribute: - node.attributes.append(field_to_attribute(field_name, field_info, field)) - elif field_info.is_relationship: - node.relationships.append(field_to_relationship(field_name, field_info, field, namespace)) +def from_pydantic(models: list[type[BaseModel]]) -> SchemaRoot: + schema = SchemaRoot(version="1.0") + + for model in models: + node = model_to_node(model=model) - schema.nodes.append(node) + if isinstance(node, NodeSchema): + schema.nodes.append(node) + elif isinstance(node, GenericSchema): + schema.generics.append(node) return schema + + +# class NodeSchema(BaseModel): +# name: str| None = None +# namespace: str| None = None +# display_labels: list[str] | None = None + +# class NodeMetaclass(ModelMetaclass): +# model_config: NodeConfig +# # model_schema: NodeSchema +# __config__: type[NodeConfig] +# # __schema__: NodeSchema diff --git a/tests/unit/sdk/test_pydantic.py b/tests/unit/sdk/test_pydantic.py index e5d18359..df6fdd74 100644 --- a/tests/unit/sdk/test_pydantic.py +++ b/tests/unit/sdk/test_pydantic.py @@ -3,22 +3,33 @@ from typing import Annotated, Optional import pytest -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field -from infrahub_sdk.schema.main import AttributeKind, AttributeSchema, RelationshipSchema -from infrahub_sdk.schema.pydantic_utils import ( - InfrahubAttributeParam as AttrParam, +from infrahub_sdk.schema.main import ( + AttributeKind, + AttributeSchema, + GenericSchema, + NodeSchema, + RelationshipSchema, + SchemaState, ) from infrahub_sdk.schema.pydantic_utils import ( + GenericModel, + NodeModel, analyze_field, field_to_attribute, field_to_relationship, from_pydantic, get_attribute_kind, + get_kind, + model_to_node, +) +from infrahub_sdk.schema.pydantic_utils import ( + InfrahubAttributeParam as AttrParam, ) -class MyModel(BaseModel): +class MyAllInOneModel(BaseModel): name: str age: int is_active: bool @@ -27,22 +38,48 @@ class MyModel(BaseModel): old_opt_age: Optional[int] = None # noqa: UP007 -class Tag(BaseModel): +class AcmeTag(BaseModel): name: str = Field(default="test_tag", description="The name of the tag") description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None label: Annotated[str, AttrParam(unique=True), Field(description="The label of the tag")] -class Car(BaseModel): +class AcmeCar(BaseModel): name: str - tags: list[Tag] - owner: Person - secondary_owner: Person | None = None + tags: list[AcmeTag] + owner: AcmePerson + secondary_owner: AcmePerson | None = None -class Person(BaseModel): +class AcmePerson(BaseModel): name: str - cars: list[Car] | None = None + cars: list[AcmeCar] | None = None + + +# -------------------------------- + + +class Book(NodeModel): + model_config = ConfigDict(node_schema=NodeSchema(name="Book", namespace="Library", display_labels=["name__value"])) + title: str + isbn: Annotated[str, AttrParam(unique=True)] + created_at: str + author: LibraryAuthor + + +class AbstractPerson(GenericModel): + model_config = ConfigDict(generic_schema=GenericSchema(name="AbstractPerson", namespace="Library")) + firstname: str = Field(..., description="The first name of the person", pattern=r"^[a-zA-Z]+$") + lastname: str + + +class LibraryAuthor(AbstractPerson): + books: list[Book] + + +class LibraryReader(AbstractPerson): + favorite_books: list[Book] + favorite_authors: list[LibraryAuthor] @pytest.mark.parametrize( @@ -57,7 +94,7 @@ class Person(BaseModel): ], ) def test_get_field_kind(field_name, expected_kind): - assert get_attribute_kind(MyModel.model_fields[field_name]) == expected_kind + assert get_attribute_kind(MyAllInOneModel.model_fields[field_name]) == expected_kind @pytest.mark.parametrize( @@ -65,7 +102,7 @@ def test_get_field_kind(field_name, expected_kind): [ ( "name", - MyModel, + MyAllInOneModel, { "default": None, "is_attribute": True, @@ -78,7 +115,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "age", - MyModel, + MyAllInOneModel, { "default": None, "is_attribute": True, @@ -91,7 +128,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "is_active", - MyModel, + MyAllInOneModel, { "default": None, "is_attribute": True, @@ -104,7 +141,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "opt_age", - MyModel, + MyAllInOneModel, { "default": None, "is_attribute": True, @@ -117,7 +154,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "default_name", - MyModel, + MyAllInOneModel, { "default": "some_default", "is_attribute": True, @@ -130,7 +167,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "old_opt_age", - MyModel, + MyAllInOneModel, { "default": None, "is_attribute": True, @@ -143,7 +180,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "description", - Tag, + AcmeTag, { "default": None, "is_attribute": True, @@ -156,7 +193,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "name", - Tag, + AcmeTag, { "default": "test_tag", "is_attribute": True, @@ -169,7 +206,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "label", - Tag, + AcmeTag, { "default": None, "is_attribute": True, @@ -182,7 +219,7 @@ def test_get_field_kind(field_name, expected_kind): ), ( "owner", - Car, + AcmeCar, { "default": None, "is_attribute": False, @@ -190,12 +227,12 @@ def test_get_field_kind(field_name, expected_kind): "is_relationship": True, "name": "owner", "optional": False, - "primary_type": Person, + "primary_type": AcmePerson, }, ), ( "tags", - Car, + AcmeCar, { "default": None, "is_attribute": False, @@ -203,12 +240,12 @@ def test_get_field_kind(field_name, expected_kind): "is_relationship": True, "name": "tags", "optional": False, - "primary_type": Tag, + "primary_type": AcmeTag, }, ), ( "secondary_owner", - Car, + AcmeCar, { "default": None, "is_attribute": False, @@ -216,7 +253,7 @@ def test_get_field_kind(field_name, expected_kind): "is_relationship": True, "name": "secondary_owner", "optional": True, - "primary_type": Person, + "primary_type": AcmePerson, }, ), ], @@ -230,7 +267,7 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): [ ( "name", - MyModel, + MyAllInOneModel, AttributeSchema( name="name", kind=AttributeKind.TEXT, @@ -239,7 +276,7 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "age", - MyModel, + MyAllInOneModel, AttributeSchema( name="age", kind=AttributeKind.NUMBER, @@ -248,7 +285,7 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "is_active", - MyModel, + MyAllInOneModel, AttributeSchema( name="is_active", kind=AttributeKind.BOOLEAN, @@ -257,7 +294,7 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "opt_age", - MyModel, + MyAllInOneModel, AttributeSchema( name="opt_age", kind=AttributeKind.NUMBER, @@ -266,17 +303,17 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "default_name", - MyModel, + MyAllInOneModel, AttributeSchema( name="default_name", kind=AttributeKind.TEXT, optional=True, - default="some_default", + default_value="some_default", ), ), ( "old_opt_age", - MyModel, + MyAllInOneModel, AttributeSchema( name="old_opt_age", kind=AttributeKind.NUMBER, @@ -285,7 +322,7 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "description", - Tag, + AcmeTag, AttributeSchema( name="description", kind=AttributeKind.TEXTAREA, @@ -294,17 +331,18 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): ), ( "name", - Tag, + AcmeTag, AttributeSchema( name="name", description="The name of the tag", kind=AttributeKind.TEXT, optional=True, + default_value="test_tag", ), ), ( "label", - Tag, + AcmeTag, AttributeSchema( name="label", description="The label of the tag", @@ -313,6 +351,17 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): unique=True, ), ), + ( + "firstname", + AbstractPerson, + AttributeSchema( + name="firstname", + description="The first name of the person", + kind=AttributeKind.TEXT, + optional=False, + regex=r"^[a-zA-Z]+$", + ), + ), ], ) def test_field_to_attribute(field_name: str, model: BaseModel, expected: AttributeSchema): @@ -326,30 +375,30 @@ def test_field_to_attribute(field_name: str, model: BaseModel, expected: Attribu [ ( "owner", - Car, + AcmeCar, RelationshipSchema( name="owner", - peer="TestingPerson", + peer="AcmePerson", cardinality="one", optional=False, ), ), ( "tags", - Car, + AcmeCar, RelationshipSchema( name="tags", - peer="TestingTag", + peer="AcmeTag", cardinality="many", optional=False, ), ), ( "secondary_owner", - Car, + AcmeCar, RelationshipSchema( name="secondary_owner", - peer="TestingPerson", + peer="AcmePerson", cardinality="one", optional=True, ), @@ -362,6 +411,152 @@ def test_field_to_relationship(field_name: str, model: BaseModel, expected: Rela assert field_to_relationship(field_name, field_info, field) == expected +@pytest.mark.parametrize( + "model, expected", + [ + (MyAllInOneModel, "MyAllInOneModel"), + (Book, "LibraryBook"), + (LibraryAuthor, "LibraryAuthor"), + (LibraryReader, "LibraryReader"), + (AbstractPerson, "LibraryAbstractPerson"), + (AcmeTag, "AcmeTag"), + (AcmeCar, "AcmeCar"), + (AcmePerson, "AcmePerson"), + ], +) +def test_get_kind(model: BaseModel, expected: str): + assert get_kind(model) == expected + + +@pytest.mark.parametrize( + "model, expected", + [ + ( + MyAllInOneModel, + NodeSchema( + name="AllInOneModel", + namespace="My", + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema(name="name", kind=AttributeKind.TEXT, optional=False), + AttributeSchema(name="age", kind=AttributeKind.NUMBER, optional=False), + AttributeSchema(name="is_active", kind=AttributeKind.BOOLEAN, optional=False), + AttributeSchema(name="opt_age", kind=AttributeKind.NUMBER, optional=True), + AttributeSchema( + name="default_name", kind=AttributeKind.TEXT, optional=True, default_value="some_default" + ), + AttributeSchema(name="old_opt_age", kind=AttributeKind.NUMBER, optional=True), + ], + ), + ), + ( + Book, + NodeSchema( + name="Book", + namespace="Library", + display_labels=["name__value"], + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema(name="title", kind=AttributeKind.TEXT, optional=False), + AttributeSchema(name="isbn", kind=AttributeKind.TEXT, optional=False, unique=True), + AttributeSchema(name="created_at", kind=AttributeKind.TEXT, optional=False), + ], + relationships=[ + RelationshipSchema( + name="author", + peer="LibraryAuthor", + cardinality="one", + optional=False, + relationships=[ + RelationshipSchema(name="books", peer="LibraryBook", cardinality="many", optional=False), + ], + ), + ], + ), + ), + ( + LibraryAuthor, + NodeSchema( + name="Author", + namespace="Library", + inherit_from=["LibraryAbstractPerson"], + state=SchemaState.PRESENT, + relationships=[ + RelationshipSchema(name="books", peer="LibraryBook", cardinality="many", optional=False), + ], + ), + ), + ( + LibraryReader, + NodeSchema( + name="Reader", + namespace="Library", + inherit_from=["LibraryAbstractPerson"], + state=SchemaState.PRESENT, + relationships=[ + RelationshipSchema(name="favorite_books", peer="LibraryBook", cardinality="many", optional=False), + RelationshipSchema( + name="favorite_authors", peer="LibraryAuthor", cardinality="many", optional=False + ), + ], + ), + ), + ( + AbstractPerson, + GenericSchema( + name="AbstractPerson", + namespace="Library", + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema( + name="firstname", + kind=AttributeKind.TEXT, + optional=False, + description="The first name of the person", + regex=r"^[a-zA-Z]+$", + ), + AttributeSchema(name="lastname", kind=AttributeKind.TEXT, optional=False), + ], + ), + ), + ( + AcmeTag, + NodeSchema( + name="Tag", + namespace="Acme", + state=SchemaState.PRESENT, + attributes=[ + AttributeSchema( + name="name", + kind=AttributeKind.TEXT, + default_value="test_tag", + optional=True, + description="The name of the tag", + ), + AttributeSchema(name="description", kind=AttributeKind.TEXTAREA, optional=True), + AttributeSchema( + name="label", + kind=AttributeKind.TEXT, + optional=False, + unique=True, + description="The label of the tag", + ), + ], + ), + ), + ], +) +def test_model_to_node(model: BaseModel, expected: NodeSchema): + node = model_to_node(model) + assert node == expected + + def test_related_models(): - schemas = from_pydantic(models=[Person, Car, Tag]) + schemas = from_pydantic(models=[AcmePerson, AcmeCar, AcmeTag]) + assert len(schemas.nodes) == 3 + + +def test_library_models(): + schemas = from_pydantic(models=[Book, AbstractPerson, LibraryAuthor, LibraryReader]) assert len(schemas.nodes) == 3 + assert len(schemas.generics) == 1 From 0ceb4d4d2f407a16612d06ae6cb333daf8f8cf6b Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Mon, 24 Mar 2025 12:04:30 +0100 Subject: [PATCH 3/5] Format examples for Pydantic --- docs/docs/python-sdk/examples/pydantic_car.py | 25 +++++-- .../python-sdk/examples/pydantic_infra.py | 71 ++++++++++++------- 2 files changed, 65 insertions(+), 31 deletions(-) diff --git a/docs/docs/python-sdk/examples/pydantic_car.py b/docs/docs/python-sdk/examples/pydantic_car.py index 321c7d31..11102d19 100644 --- a/docs/docs/python-sdk/examples/pydantic_car.py +++ b/docs/docs/python-sdk/examples/pydantic_car.py @@ -1,20 +1,32 @@ from __future__ import annotations from asyncio import run as aiorun - from typing import Annotated -from pydantic import BaseModel, Field, ConfigDict -from infrahub_sdk import InfrahubClient +from pydantic import ConfigDict, Field from rich import print as rprint -from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic, NodeSchema, NodeModel, GenericModel + +from infrahub_sdk import InfrahubClient +from infrahub_sdk.schema import ( + AttributeKind, + GenericModel, + NodeModel, + NodeSchema, + from_pydantic, +) +from infrahub_sdk.schema import ( + InfrahubAttributeParam as AttrParam, +) +from infrahub_sdk.schema import ( + InfrahubRelationshipParam as RelParam, +) class Tag(NodeModel): model_config = ConfigDict( node_schema=NodeSchema(name="Tag", namespace="Test", human_readable_fields=["name__value"]) ) - + name: Annotated[str, AttrParam(unique=True), Field(description="The name of the tag")] label: str | None = Field(description="The label of the tag") description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None @@ -30,11 +42,12 @@ class TestCar(NodeModel): class TestPerson(GenericModel): name: str + class TestCarOwner(NodeModel, TestPerson): cars: Annotated[list[TestCar] | None, RelParam(identifier="car__person")] = None -async def main(): +async def main() -> None: client = InfrahubClient() schema = from_pydantic(models=[TestPerson, TestCar, Tag, TestPerson, TestCarOwner]) rprint(schema.to_schema_dict()) diff --git a/docs/docs/python-sdk/examples/pydantic_infra.py b/docs/docs/python-sdk/examples/pydantic_infra.py index 5ff137a4..a7182c1b 100644 --- a/docs/docs/python-sdk/examples/pydantic_infra.py +++ b/docs/docs/python-sdk/examples/pydantic_infra.py @@ -1,23 +1,35 @@ from __future__ import annotations -from asyncio import run as aiorun - -from infrahub_sdk.async_typer import AsyncTyper - from typing import Annotated -from pydantic import BaseModel, Field, ConfigDict -from infrahub_sdk import InfrahubClient +from pydantic import ConfigDict, Field from rich import print as rprint -from infrahub_sdk.schema import InfrahubAttributeParam as AttrParam, InfrahubRelationshipParam as RelParam, AttributeKind, from_pydantic, NodeSchema, NodeModel, GenericSchema, GenericModel, RelationshipKind +from infrahub_sdk import InfrahubClient +from infrahub_sdk.async_typer import AsyncTyper +from infrahub_sdk.schema import ( + GenericModel, + GenericSchema, + NodeModel, + NodeSchema, + RelationshipKind, + from_pydantic, +) +from infrahub_sdk.schema import ( + InfrahubAttributeParam as AttrParam, +) +from infrahub_sdk.schema import ( + InfrahubRelationshipParam as RelParam, +) app = AsyncTyper() class Site(NodeModel): model_config = ConfigDict( - node_schema=NodeSchema(name="Site", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"]) + node_schema=NodeSchema( + name="Site", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"] + ) ) name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the site") @@ -25,7 +37,9 @@ class Site(NodeModel): class Vlan(NodeModel): model_config = ConfigDict( - node_schema=NodeSchema(name="Vlan", namespace="Infra", human_friendly_id=["vlan_id__value"], display_labels=["vlan_id__value"]) + node_schema=NodeSchema( + name="Vlan", namespace="Infra", human_friendly_id=["vlan_id__value"], display_labels=["vlan_id__value"] + ) ) name: str @@ -35,39 +49,45 @@ class Vlan(NodeModel): class Device(NodeModel): model_config = ConfigDict( - node_schema=NodeSchema(name="Device", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"]) + node_schema=NodeSchema( + name="Device", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"] + ) ) name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the car") site: Annotated[Site, RelParam(kind=RelationshipKind.ATTRIBUTE, identifier="device__site")] - interfaces: Annotated[list[Interface], RelParam(kind=RelationshipKind.COMPONENT, identifier="device__interfaces")] = Field(default_factory=list) + interfaces: Annotated[ + list[Interface], RelParam(kind=RelationshipKind.COMPONENT, identifier="device__interfaces") + ] = Field(default_factory=list) class Interface(GenericModel): model_config = ConfigDict( - generic_schema=GenericSchema(name="Interface", namespace="Infra", human_friendly_id=["device__name__value", "name__value"], display_labels=["name__value"]) + generic_schema=GenericSchema( + name="Interface", + namespace="Infra", + human_friendly_id=["device__name__value", "name__value"], + display_labels=["name__value"], + ) ) device: Annotated[Device, RelParam(kind=RelationshipKind.PARENT, identifier="device__interfaces")] name: str description: str | None = None + class L2Interface(Interface): - model_config = ConfigDict( - node_schema=NodeSchema(name="L2Interface", namespace="Infra") - ) - + model_config = ConfigDict(node_schema=NodeSchema(name="L2Interface", namespace="Infra")) + vlans: list[Vlan] = Field(default_factory=list) + class LoopbackInterface(Interface): - model_config = ConfigDict( - node_schema=NodeSchema(name="LoopbackInterface", namespace="Infra") - ) - + model_config = ConfigDict(node_schema=NodeSchema(name="LoopbackInterface", namespace="Infra")) @app.command() -async def load_schema(): +async def load_schema() -> None: client = InfrahubClient() schema = from_pydantic(models=[Site, Device, Interface, L2Interface, LoopbackInterface, Vlan]) rprint(schema.to_schema_dict()) @@ -76,7 +96,7 @@ async def load_schema(): @app.command() -async def load_data(): +async def load_data() -> None: client = InfrahubClient() atl = await client.create("InfraSite", name="ATL") @@ -100,14 +120,15 @@ async def load_data(): @app.command() -async def query_data(): +async def query_data() -> None: client = InfrahubClient() sites = await client.all(kind=Site) + rprint(sites) - breakpoint() devices = await client.all(kind=Device) for device in devices: rprint(device) + if __name__ == "__main__": - app() \ No newline at end of file + app() From 41db07741972fea9eb6fadbac908684105d02450 Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Sun, 30 Mar 2025 18:11:27 +0200 Subject: [PATCH 4/5] Fix conflict --- tests/unit/sdk/test_pydantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sdk/test_pydantic.py b/tests/unit/sdk/test_pydantic.py index df6fdd74..bd221a7a 100644 --- a/tests/unit/sdk/test_pydantic.py +++ b/tests/unit/sdk/test_pydantic.py @@ -35,7 +35,7 @@ class MyAllInOneModel(BaseModel): is_active: bool opt_age: int | None = None default_name: str = "some_default" - old_opt_age: Optional[int] = None # noqa: UP007 + old_opt_age: Optional[int] = None class AcmeTag(BaseModel): From fded2fd2dfe25ae8f085d2b4cbc698cfd0f20764 Mon Sep 17 00:00:00 2001 From: Damien Garros Date: Mon, 25 Aug 2025 13:53:46 +0200 Subject: [PATCH 5/5] Refactor for deeper Pydantic integration --- docs/docs/python-sdk/examples/pydantic_car.py | 31 +- .../python-sdk/examples/pydantic_infra.py | 63 ++- infrahub_sdk/schema/__init__.py | 4 - infrahub_sdk/schema/pydantic_utils.py | 386 ++++++++++++++---- tests/unit/sdk/test_pydantic.py | 194 +++++---- 5 files changed, 466 insertions(+), 212 deletions(-) diff --git a/docs/docs/python-sdk/examples/pydantic_car.py b/docs/docs/python-sdk/examples/pydantic_car.py index 11102d19..58d0501e 100644 --- a/docs/docs/python-sdk/examples/pydantic_car.py +++ b/docs/docs/python-sdk/examples/pydantic_car.py @@ -14,28 +14,35 @@ NodeSchema, from_pydantic, ) -from infrahub_sdk.schema import ( - InfrahubAttributeParam as AttrParam, -) -from infrahub_sdk.schema import ( - InfrahubRelationshipParam as RelParam, +from infrahub_sdk.schema.pydantic_utils import ( + Attribute, + GenericModel, + InfrahubConfig, + NodeModel, + Relationship, + SchemaModel, + analyze_field, + field_to_attribute, + field_to_relationship, + from_pydantic, + get_attribute_kind, + get_kind, + model_to_node, ) class Tag(NodeModel): - model_config = ConfigDict( - node_schema=NodeSchema(name="Tag", namespace="Test", human_readable_fields=["name__value"]) - ) + model_config = InfrahubConfig(namespace="Test", human_readable_fields=["name__value"]) - name: Annotated[str, AttrParam(unique=True), Field(description="The name of the tag")] + name: str = Attribute(unique=True, description="The name of the tag") label: str | None = Field(description="The label of the tag") - description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None + description: str | None = Attribute(None, kind=AttributeKind.TEXTAREA) class TestCar(NodeModel): name: str = Field(description="The name of the car") tags: list[Tag] - owner: Annotated[TestPerson, RelParam(identifier="car__person")] + owner: TestPerson = Relationship(identifier="car__person")] secondary_owner: TestPerson | None = None @@ -44,7 +51,7 @@ class TestPerson(GenericModel): class TestCarOwner(NodeModel, TestPerson): - cars: Annotated[list[TestCar] | None, RelParam(identifier="car__person")] = None + cars: list[TestCar] = Relationship(identifier="car__person") async def main() -> None: diff --git a/docs/docs/python-sdk/examples/pydantic_infra.py b/docs/docs/python-sdk/examples/pydantic_infra.py index a7182c1b..52599274 100644 --- a/docs/docs/python-sdk/examples/pydantic_infra.py +++ b/docs/docs/python-sdk/examples/pydantic_infra.py @@ -8,39 +8,41 @@ from infrahub_sdk import InfrahubClient from infrahub_sdk.async_typer import AsyncTyper from infrahub_sdk.schema import ( - GenericModel, GenericSchema, - NodeModel, NodeSchema, RelationshipKind, - from_pydantic, ) -from infrahub_sdk.schema import ( - InfrahubAttributeParam as AttrParam, -) -from infrahub_sdk.schema import ( - InfrahubRelationshipParam as RelParam, +from infrahub_sdk.schema.pydantic_utils import ( + Attribute, + GenericModel, + InfrahubConfig, + NodeModel, + Relationship, + SchemaModel, + analyze_field, + field_to_attribute, + field_to_relationship, + from_pydantic, + get_attribute_kind, + get_kind, + model_to_node, ) app = AsyncTyper() class Site(NodeModel): - model_config = ConfigDict( - node_schema=NodeSchema( - name="Site", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"] - ) + model_config = InfrahubConfig( + namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"] ) - name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the site") + name: str = Attribute(unique=True, description="The name of the site") class Vlan(NodeModel): - model_config = ConfigDict( - node_schema=NodeSchema( - name="Vlan", namespace="Infra", human_friendly_id=["vlan_id__value"], display_labels=["vlan_id__value"] + model_config = InfrahubConfig( + namespace="Infra", human_friendly_id=["vlan_id__value"], display_labels=["vlan_id__value"] ) - ) name: str vlan_id: int @@ -48,42 +50,33 @@ class Vlan(NodeModel): class Device(NodeModel): - model_config = ConfigDict( - node_schema=NodeSchema( + model_config = InfrahubConfig( name="Device", namespace="Infra", human_friendly_id=["name__value"], display_labels=["name__value"] ) - ) - name: Annotated[str, AttrParam(unique=True)] = Field(description="The name of the car") - site: Annotated[Site, RelParam(kind=RelationshipKind.ATTRIBUTE, identifier="device__site")] - interfaces: Annotated[ - list[Interface], RelParam(kind=RelationshipKind.COMPONENT, identifier="device__interfaces") - ] = Field(default_factory=list) + name: str = Attribute(unique=True, description="The name of the car") + site: Site = Relationship(kind=RelationshipKind.ATTRIBUTE, identifier="device__site") + interfaces: list[Interface] = Relationship(kind=RelationshipKind.COMPONENT, identifier="device__interfaces") class Interface(GenericModel): - model_config = ConfigDict( - generic_schema=GenericSchema( - name="Interface", - namespace="Infra", - human_friendly_id=["device__name__value", "name__value"], - display_labels=["name__value"], - ) + model_config = InfrahubConfig( + namespace="Infra", human_friendly_id=["device__name__value", "name__value"], display_labels=["name__value"] ) - device: Annotated[Device, RelParam(kind=RelationshipKind.PARENT, identifier="device__interfaces")] + device: Device = Relationship(kind=RelationshipKind.PARENT, identifier="device__interfaces") name: str description: str | None = None class L2Interface(Interface): - model_config = ConfigDict(node_schema=NodeSchema(name="L2Interface", namespace="Infra")) + model_config = InfrahubConfig(namespace="Infra") vlans: list[Vlan] = Field(default_factory=list) class LoopbackInterface(Interface): - model_config = ConfigDict(node_schema=NodeSchema(name="LoopbackInterface", namespace="Infra")) + model_config = InfrahubConfig(namespace="Infra") @app.command() diff --git a/infrahub_sdk/schema/__init__.py b/infrahub_sdk/schema/__init__.py index 118a9599..b0c1e292 100644 --- a/infrahub_sdk/schema/__init__.py +++ b/infrahub_sdk/schema/__init__.py @@ -43,8 +43,6 @@ ) from .pydantic_utils import ( GenericModel, - InfrahubAttributeParam, - InfrahubRelationshipParam, NodeModel, SchemaModel, from_pydantic, @@ -65,8 +63,6 @@ "GenericModel", "GenericSchema", "GenericSchemaAPI", - "InfrahubAttributeParam", - "InfrahubRelationshipParam", "NodeModel", "NodeSchema", "NodeSchemaAPI", diff --git a/infrahub_sdk/schema/pydantic_utils.py b/infrahub_sdk/schema/pydantic_utils.py index 0f605c57..1833d91e 100644 --- a/infrahub_sdk/schema/pydantic_utils.py +++ b/infrahub_sdk/schema/pydantic_utils.py @@ -4,10 +4,14 @@ import typing from dataclasses import dataclass from types import UnionType -from typing import TYPE_CHECKING, Any - -from pydantic import BaseModel, Field -from pydantic.fields import FieldInfo, PydanticUndefined +from typing import TYPE_CHECKING, Any, Callable, ForwardRef, Literal, TypeVar, Union + +from pydantic import BaseModel +from pydantic import ConfigDict as BaseConfig +from pydantic._internal._model_construction import ModelMetaclass # noqa: PLC2701 +from pydantic._internal._repr import Representation # noqa: PLC2701 +from pydantic.fields import FieldInfo as PydanticFieldInfo +from pydantic.fields import PydanticUndefined as Undefined from typing_extensions import Self from .main import ( @@ -25,6 +29,8 @@ if TYPE_CHECKING: from infrahub_sdk.node import InfrahubNode, InfrahubNodeSync +_T = TypeVar("_T") + KIND_MAPPING: dict[type, AttributeKind] = { int: AttributeKind.NUMBER, float: AttributeKind.NUMBER, @@ -36,8 +42,181 @@ NODE_KIND_REGEX = r"^[A-Z][a-zA-Z0-9]+$" -class SchemaModel(BaseModel): - id: str | None = Field(default=None, description="The ID of the node") +def __dataclass_transform__( + *, + eq_default: bool = True, + order_default: bool = False, + kw_only_default: bool = False, + field_descriptors: tuple[Union[type, Callable[..., Any]], ...] = (()), +) -> Callable[[_T], _T]: + return lambda a: a + + +class InfrahubConfig(BaseConfig, total=False): + generic: bool = False + name: str | None = None + namespace: str | None = None + display_labels: list[str] | None = None + description: str | None = None + state: SchemaState = SchemaState.PRESENT + label: str | None = None + include_in_menu: bool | None = None + menu_placement: str | None = None + + +class AttributeInfo(PydanticFieldInfo): + def __init__(self, default: Any = Undefined, **kwargs: Any) -> None: + unique = kwargs.pop("unique", False) + label = kwargs.pop("label", None) + kind = kwargs.pop("kind", None) + regex = kwargs.pop("regex", None) + branch = kwargs.pop("branch", None) + super().__init__(default=default, **kwargs) + self.unique = unique + self.label = label + self.kind = kind + self.regex = regex + self.branch = branch + + +class RelationshipInfo(Representation): + def __init__( + self, + *, + alias: str | None = None, + kind: RelationshipKind | None = None, + peer: str | None = None, + description: str | None = None, + identifier: str | None = None, + branch: BranchSupportType | None = None, + optional: bool = False, + ) -> None: + self.alias = alias + self.kind = kind + self.identifier = identifier + self.branch = branch + self.description = description + self.peer = peer + self.optional = optional + + +def Relationship( + *, + alias: str | None = None, + kind: RelationshipKind | None = None, + identifier: str | None = None, + branch: BranchSupportType | None = None, + peer: str | None = None, + description: str | None = None, + optional: bool = False, +) -> Any: + relationship_info = RelationshipInfo( + alias=alias, + kind=kind, + identifier=identifier, + branch=branch, + peer=peer, + description=description, + optional=optional, + ) + return relationship_info + + +def Attribute( + default: Any = Undefined, + *, + alias: str | None = None, + description: str | None = None, + state: SchemaState = SchemaState.PRESENT, + kind: AttributeKind | None = None, + label: str | None = None, + unique: bool = False, + branch: BranchSupportType | None = None, + regex: str | None = None, + pattern: str | None = None, +) -> Any: + current_schema_extra = {} + field_info = AttributeInfo( + default, + alias=alias, + description=description, + state=state, + kind=kind, + label=label, + unique=unique, + branch=branch, + regex=regex, + pattern=pattern, + **current_schema_extra, + ) + return field_info + + +@__dataclass_transform__(kw_only_default=True, field_descriptors=(Attribute, AttributeInfo)) +class InfrahubMetaclass(ModelMetaclass): + __infrahub_relationships__: dict[str, RelationshipInfo] + model_config: InfrahubConfig + model_fields: dict[str, AttributeInfo] + + def __new__( + cls, + name: str, + bases: tuple[type[Any], ...], + class_dict: dict[str, Any], + **kwargs: Any, + ) -> Any: + relationships: dict[str, RelationshipInfo] = {} + dict_for_pydantic = {} + original_annotations: dict[str, Any] = class_dict.get("__annotations__", {}) + pydantic_annotations = {} + relationship_annotations = {} + for k, v in class_dict.items(): + if isinstance(v, RelationshipInfo): + relationships[k] = v + else: + dict_for_pydantic[k] = v + for k, v in original_annotations.items(): + if k in relationships: + relationship_annotations[k] = v + else: + pydantic_annotations[k] = v + dict_used = { + **dict_for_pydantic, + "__infrahub_relationships__": relationships, + "__annotations__": pydantic_annotations, + } + # Duplicate logic from Pydantic to filter config kwargs because if they are + # passed directly including the registry Pydantic will pass them over to the + # superclass causing an error + allowed_config_kwargs: set[str] = { + key + for key in dir(BaseConfig) + if not (key.startswith("__") and key.endswith("__")) # skip dunder methods and attributes + } + config_kwargs = {key: kwargs[key] for key in kwargs.keys() & allowed_config_kwargs} + new_cls = super().__new__(cls, name, bases, dict_used, **config_kwargs) + new_cls.__annotations__ = { + **relationship_annotations, + **pydantic_annotations, + **new_cls.__annotations__, + } + + # def get_config(name: str) -> Any: + # config_class_value = new_cls.model_config.get(name, Undefined) + # if config_class_value is not Undefined: + # return config_class_value + # kwarg_value = kwargs.get(name, Undefined) + # if kwarg_value is not Undefined: + # return kwarg_value + # return Undefined + + # new_cls.model_config["generic"] = get_config("generic") + + return new_cls + + +class SchemaModel(BaseModel, metaclass=InfrahubMetaclass): + id: str | None = Attribute(default=None, description="The ID of the node") @classmethod def get_kind(cls) -> str: @@ -69,32 +248,17 @@ class GenericModel(SchemaModel): pass -@dataclass -class InfrahubAttributeParam: - state: SchemaState = SchemaState.PRESENT - kind: AttributeKind | None = None - label: str | None = None - unique: bool = False - branch: BranchSupportType | None = None - - -@dataclass -class InfrahubRelationshipParam: - kind: RelationshipKind | None = None - identifier: str | None = None - branch: BranchSupportType | None = None - - @dataclass class InfrahubFieldInfo: name: str types: list[type] optional: bool default: Any + field_kind: Literal["attribute", "relationship"] | None = None @property def primary_type(self) -> type: - if len(self.types) == 0: + if not self.types: raise ValueError("No types found") # if isinstance(self.primary_type, ForwardRef): @@ -107,15 +271,20 @@ def primary_type(self) -> type: @property def is_attribute(self) -> bool: + if self.field_kind == "attribute": + return True return self.primary_type in KIND_MAPPING @property def is_relationship(self) -> bool: + if self.field_kind == "relationship": + return True + if isinstance(self.primary_type, ForwardRef): + return True return issubclass(self.primary_type, BaseModel) @property def is_list(self) -> bool: - # breakpoint() return typing.get_origin(self.types[0]) is list def to_dict(self) -> dict: @@ -130,7 +299,16 @@ def to_dict(self) -> dict: } -def analyze_field(field_name: str, field: FieldInfo) -> InfrahubFieldInfo: +def analyze_field(field_name: str, field: AttributeInfo | RelationshipInfo | PydanticFieldInfo) -> InfrahubFieldInfo: + if isinstance(field, RelationshipInfo): + return InfrahubFieldInfo( + name=field.alias or field_name, + types=[field.peer] if field.peer else [], + optional=field.optional, + field_kind="relationship", + default=None, + ) + clean_types = [] if isinstance(field.annotation, UnionType) or ( hasattr(field.annotation, "_name") and field.annotation._name == "Optional" # type: ignore[union-attr] @@ -143,11 +321,14 @@ def analyze_field(field_name: str, field: FieldInfo) -> InfrahubFieldInfo: name=field.alias or field_name, types=clean_types, optional=not field.is_required(), - default=field.default if field.default is not PydanticUndefined else None, + default=field.default if field.default is not Undefined else None, ) -def get_attribute_kind(field: FieldInfo) -> AttributeKind: +def get_attribute_kind(field: AttributeInfo | PydanticFieldInfo) -> AttributeKind: + if isinstance(field, AttributeInfo) and field.kind: + return field.kind + if field.annotation in KIND_MAPPING: return KIND_MAPPING[field.annotation] @@ -161,24 +342,36 @@ def get_attribute_kind(field: FieldInfo) -> AttributeKind: raise ValueError(f"Unknown field type: {field.annotation}") -def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: FieldInfo) -> AttributeSchema: - field_param = InfrahubAttributeParam() - field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubAttributeParam)] - if len(field_params) == 1: - field_param = field_params[0] - +def field_to_attribute( + field_name: str, field_info: InfrahubFieldInfo, field: AttributeInfo | PydanticFieldInfo +) -> AttributeSchema: pattern = field._attributes_set.get("pattern", None) max_length = field._attributes_set.get("max_length", None) min_length = field._attributes_set.get("min_length", None) + if isinstance(field, AttributeInfo): + return AttributeSchema( + name=field_name, + label=field.label, + description=field.description, + kind=get_attribute_kind(field), + optional=field_info.optional, # not field.is_required(), + unique=field.unique, + branch=field.branch, + default_value=field_info.default, + regex=str(pattern) if pattern else None, + max_length=int(str(max_length)) if max_length else None, + min_length=int(str(min_length)) if min_length else None, + ) + return AttributeSchema( name=field_name, - label=field_param.label, + # label=field.label, description=field.description, - kind=field_param.kind or get_attribute_kind(field), + kind=get_attribute_kind(field), optional=not field.is_required(), - unique=field_param.unique, - branch=field_param.branch, + # unique=field.unique, + # branch=field.branch, default_value=field_info.default, regex=str(pattern) if pattern else None, max_length=int(str(max_length)) if max_length else None, @@ -189,21 +382,25 @@ def field_to_attribute(field_name: str, field_info: InfrahubFieldInfo, field: Fi def field_to_relationship( field_name: str, field_info: InfrahubFieldInfo, - field: FieldInfo, + field: RelationshipInfo | PydanticFieldInfo, ) -> RelationshipSchema: - field_param = InfrahubRelationshipParam() - field_params = [metadata for metadata in field.metadata if isinstance(metadata, InfrahubRelationshipParam)] - if len(field_params) == 1: - field_param = field_params[0] + if isinstance(field, RelationshipInfo): + return RelationshipSchema( + name=field_name, + description=field.description, + peer=field.peer or get_kind(field_info.primary_type), + identifier=field.identifier, + cardinality="many" if field_info.is_list else "one", + optional=field_info.optional, + branch=field.branch, + ) return RelationshipSchema( name=field_name, description=field.description, peer=get_kind(field_info.primary_type), - identifier=field_param.identifier, cardinality="many" if field_info.is_list else "one", optional=field_info.optional, - branch=field_param.branch, ) @@ -212,6 +409,11 @@ def extract_validate_generic(model: type[BaseModel]) -> list[str]: def validate_kind(kind: str) -> tuple[str, str]: + """Validate the kind of a model. + + TODO Move the function to the main module + """ + # First, handle transition from a lowercase to uppercase name_with_spaces = re.sub(r"([a-z])([A-Z])", r"\1 \2", kind) @@ -236,14 +438,32 @@ def is_generic(model: type[BaseModel]) -> bool: return GenericModel in model.__bases__ -def get_kind(model: type[BaseModel]) -> str: - node_schema: NodeSchema | None = model.model_config.get("node_schema") or None # type: ignore[assignment] - generic_schema: GenericSchema | None = model.model_config.get("generic_schema") or None # type: ignore[assignment] +def get_kind(model: type[BaseModel] | ForwardRef) -> str: + """Get the kind of a model. + + If the model name and namespace are set in model_config, return the full kind. + If the model namespace is set in model_config, use the name of the class as name. + If the model has no name or namespace, extract both from the name of the class. + """ + + model_class: type[BaseModel] + + if isinstance(model, type) and issubclass(model, BaseModel): + model_class = model + elif isinstance(model, ForwardRef): + return model.__forward_arg__ + else: + raise ValueError(f"Expected BaseModel class, got {model}") + + name = model_class.model_config.get("name", None) + namespace = model_class.model_config.get("namespace", None) + class_name = model_class.__name__ + + if name and namespace: + return f"{namespace}{name}" + if namespace and not name and not class_name.startswith(namespace): + return f"{namespace}{class_name}" - if is_generic(model) and generic_schema: - return generic_schema.kind - if node_schema: - return node_schema.kind namespace, name = validate_kind(model.__name__) return f"{namespace}{name}" @@ -278,26 +498,25 @@ def model_to_node(model: type[BaseModel]) -> NodeSchema | GenericSchema: # ------------------------------------------------------------ # GenericSchema # ------------------------------------------------------------ - if GenericModel in model.__bases__: - generic_schema: GenericSchema | None = model.model_config.get("generic_schema") or None # type: ignore[assignment] - if not generic_schema: - namespace, name = validate_kind(model.__name__) + kind = get_kind(model) + namespace, name = validate_kind(kind) + if GenericModel in model.__bases__: generic = GenericSchema( - name=generic_schema.name if generic_schema else name, - namespace=generic_schema.namespace if generic_schema else namespace, - display_labels=generic_schema.display_labels if generic_schema else None, - description=generic_schema.description if generic_schema else None, - state=generic_schema.state if generic_schema else SchemaState.PRESENT, - label=generic_schema.label if generic_schema else None, - include_in_menu=generic_schema.include_in_menu if generic_schema else None, - menu_placement=generic_schema.menu_placement if generic_schema else None, - documentation=generic_schema.documentation if generic_schema else None, - order_by=generic_schema.order_by if generic_schema else None, + name=name, + namespace=namespace, + display_labels=model.model_config.get("display_labels", None), + description=model.model_config.get("description", None), + state=model.model_config.get("state", SchemaState.PRESENT), + label=model.model_config.get("label", None), + # include_in_menu=generic_schema.include_in_menu if generic_schema else None, + # menu_placement=generic_schema.menu_placement if generic_schema else None, + # documentation=generic_schema.documentation if generic_schema else None, + # order_by=generic_schema.order_by if generic_schema else None, # parent=schema.parent if schema else None, # children=schema.children if schema else None, - icon=generic_schema.icon if generic_schema else None, + # icon=generic_schema.icon if generic_schema else None, # generate_profile=schema.generate_profile if schema else None, # branch=schema.branch if schema else None, # default_filter=schema.default_filter if schema else None, @@ -308,11 +527,6 @@ def model_to_node(model: type[BaseModel]) -> NodeSchema | GenericSchema: # ------------------------------------------------------------ # NodeSchema # ------------------------------------------------------------ - node_schema: NodeSchema | None = model.model_config.get("node_schema") or None # type: ignore[assignment] - - if not node_schema: - namespace, name = validate_kind(model.__name__) - generics = get_generics(model) # list all inherited fields with a hash for each to track if they are identical on the node @@ -321,22 +535,22 @@ def model_to_node(model: type[BaseModel]) -> NodeSchema | GenericSchema: } node = NodeSchema( - name=node_schema.name if node_schema else name, - namespace=node_schema.namespace if node_schema else namespace, - display_labels=node_schema.display_labels if node_schema else None, - description=node_schema.description if node_schema else None, - state=node_schema.state if node_schema else SchemaState.PRESENT, - label=node_schema.label if node_schema else None, - include_in_menu=node_schema.include_in_menu if node_schema else None, - menu_placement=node_schema.menu_placement if node_schema else None, - documentation=node_schema.documentation if node_schema else None, - order_by=node_schema.order_by if node_schema else None, + name=name, + namespace=namespace, + display_labels=model.model_config.get("display_labels", None), + description=model.model_config.get("description", None), + state=model.model_config.get("state", SchemaState.PRESENT), + label=model.model_config.get("label", None), + # include_in_menu=node_schema.include_in_menu if node_schema else None, + # menu_placement=node_schema.menu_placement if node_schema else None, + # documentation=node_schema.documentation if node_schema else None, + # order_by=node_schema.order_by if node_schema else None, inherit_from=[get_kind(generic) for generic in generics], - parent=node_schema.parent if node_schema else None, - children=node_schema.children if node_schema else None, - icon=node_schema.icon if node_schema else None, - generate_profile=node_schema.generate_profile if node_schema else None, - branch=node_schema.branch if node_schema else None, + # parent=node_schema.parent if node_schema else None, + # children=node_schema.children if node_schema else None, + # icon=node_schema.icon if node_schema else None, + # generate_profile=node_schema.generate_profile if node_schema else None, + # branch=node_schema.branch if node_schema else None, # default_filter=schema.default_filter if schema else None, ) diff --git a/tests/unit/sdk/test_pydantic.py b/tests/unit/sdk/test_pydantic.py index bd221a7a..fc074b8b 100644 --- a/tests/unit/sdk/test_pydantic.py +++ b/tests/unit/sdk/test_pydantic.py @@ -1,9 +1,9 @@ from __future__ import annotations -from typing import Annotated, Optional +from typing import ForwardRef, Optional import pytest -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel from infrahub_sdk.schema.main import ( AttributeKind, @@ -14,8 +14,12 @@ SchemaState, ) from infrahub_sdk.schema.pydantic_utils import ( + Attribute, GenericModel, + InfrahubConfig, NodeModel, + Relationship, + SchemaModel, analyze_field, field_to_attribute, field_to_relationship, @@ -24,12 +28,9 @@ get_kind, model_to_node, ) -from infrahub_sdk.schema.pydantic_utils import ( - InfrahubAttributeParam as AttrParam, -) -class MyAllInOneModel(BaseModel): +class MyAllInOneModel(NodeModel): name: str age: int is_active: bool @@ -38,20 +39,20 @@ class MyAllInOneModel(BaseModel): old_opt_age: Optional[int] = None -class AcmeTag(BaseModel): - name: str = Field(default="test_tag", description="The name of the tag") - description: Annotated[str | None, AttrParam(kind=AttributeKind.TEXTAREA)] = None - label: Annotated[str, AttrParam(unique=True), Field(description="The label of the tag")] +class AcmeTag(NodeModel): + name: str = Attribute(default="test_tag", description="The name of the tag") + description: str | None = Attribute(None, kind=AttributeKind.TEXTAREA) + label: str = Attribute(unique=True, description="The label of the tag") -class AcmeCar(BaseModel): +class AcmeCar(NodeModel): name: str tags: list[AcmeTag] owner: AcmePerson - secondary_owner: AcmePerson | None = None + secondary_owner: AcmePerson | None = Relationship(peer="AcmePerson", optional=True) -class AcmePerson(BaseModel): +class AcmePerson(NodeModel): name: str cars: list[AcmeCar] | None = None @@ -60,16 +61,17 @@ class AcmePerson(BaseModel): class Book(NodeModel): - model_config = ConfigDict(node_schema=NodeSchema(name="Book", namespace="Library", display_labels=["name__value"])) + model_config = InfrahubConfig(name="Book", namespace="Library", display_labels=["name__value"]) + title: str - isbn: Annotated[str, AttrParam(unique=True)] + isbn: str = Attribute(..., unique=True) created_at: str author: LibraryAuthor class AbstractPerson(GenericModel): - model_config = ConfigDict(generic_schema=GenericSchema(name="AbstractPerson", namespace="Library")) - firstname: str = Field(..., description="The first name of the person", pattern=r"^[a-zA-Z]+$") + model_config = InfrahubConfig(namespace="Library") + firstname: str = Attribute(..., description="The first name of the person", pattern=r"^[a-zA-Z]+$") lastname: str @@ -85,12 +87,12 @@ class LibraryReader(AbstractPerson): @pytest.mark.parametrize( "field_name, expected_kind", [ - ("name", "Text"), - ("age", "Number"), - ("is_active", "Boolean"), - ("opt_age", "Number"), - ("default_name", "Text"), - ("old_opt_age", "Number"), + pytest.param("name", "Text", id="name_field"), + pytest.param("age", "Number", id="age_field"), + pytest.param("is_active", "Boolean", id="is_active_field"), + pytest.param("opt_age", "Number", id="opt_age_field"), + pytest.param("default_name", "Text", id="default_name_field"), + pytest.param("old_opt_age", "Number", id="old_opt_age_field"), ], ) def test_get_field_kind(field_name, expected_kind): @@ -100,7 +102,7 @@ def test_get_field_kind(field_name, expected_kind): @pytest.mark.parametrize( "field_name, model, expected", [ - ( + pytest.param( "name", MyAllInOneModel, { @@ -112,8 +114,9 @@ def test_get_field_kind(field_name, expected_kind): "optional": False, "primary_type": str, }, + id="MyAllInOneModel_name", ), - ( + pytest.param( "age", MyAllInOneModel, { @@ -125,8 +128,9 @@ def test_get_field_kind(field_name, expected_kind): "optional": False, "primary_type": int, }, + id="MyAllInOneModel_age", ), - ( + pytest.param( "is_active", MyAllInOneModel, { @@ -138,8 +142,9 @@ def test_get_field_kind(field_name, expected_kind): "optional": False, "primary_type": bool, }, + id="MyAllInOneModel_is_active", ), - ( + pytest.param( "opt_age", MyAllInOneModel, { @@ -151,8 +156,9 @@ def test_get_field_kind(field_name, expected_kind): "optional": True, "primary_type": int, }, + id="MyAllInOneModel_opt_age", ), - ( + pytest.param( "default_name", MyAllInOneModel, { @@ -164,8 +170,9 @@ def test_get_field_kind(field_name, expected_kind): "optional": True, "primary_type": str, }, + id="MyAllInOneModel_default_name", ), - ( + pytest.param( "old_opt_age", MyAllInOneModel, { @@ -177,8 +184,9 @@ def test_get_field_kind(field_name, expected_kind): "optional": True, "primary_type": int, }, + id="MyAllInOneModel_old_opt_age", ), - ( + pytest.param( "description", AcmeTag, { @@ -190,8 +198,9 @@ def test_get_field_kind(field_name, expected_kind): "optional": True, "primary_type": str, }, + id="AcmeTag_description", ), - ( + pytest.param( "name", AcmeTag, { @@ -203,8 +212,9 @@ def test_get_field_kind(field_name, expected_kind): "optional": True, "primary_type": str, }, + id="AcmeTag_name", ), - ( + pytest.param( "label", AcmeTag, { @@ -216,8 +226,9 @@ def test_get_field_kind(field_name, expected_kind): "optional": False, "primary_type": str, }, + id="AcmeTag_label", ), - ( + pytest.param( "owner", AcmeCar, { @@ -227,10 +238,11 @@ def test_get_field_kind(field_name, expected_kind): "is_relationship": True, "name": "owner", "optional": False, - "primary_type": AcmePerson, + "primary_type": ForwardRef("AcmePerson"), }, + id="AcmeCar_owner", ), - ( + pytest.param( "tags", AcmeCar, { @@ -242,8 +254,9 @@ def test_get_field_kind(field_name, expected_kind): "optional": False, "primary_type": AcmeTag, }, + id="AcmeCar_tags", ), - ( + pytest.param( "secondary_owner", AcmeCar, { @@ -253,19 +266,26 @@ def test_get_field_kind(field_name, expected_kind): "is_relationship": True, "name": "secondary_owner", "optional": True, - "primary_type": AcmePerson, + "primary_type": "AcmePerson", }, + id="AcmeCar_secondary_owner", ), ], ) -def test_analyze_field(field_name: str, model: BaseModel, expected: dict): - assert analyze_field(field_name, model.model_fields[field_name]).to_dict() == expected +def test_analyze_field(field_name: str, model: type[BaseModel], expected: dict): + if field_name in model.model_fields: + field = model.model_fields[field_name] + elif issubclass(model, SchemaModel) and field_name in model.__infrahub_relationships__: + field = model.__infrahub_relationships__[field_name] + else: + raise ValueError(f"Field {field_name} not found in model {model}") + assert analyze_field(field_name=field_name, field=field).to_dict() == expected @pytest.mark.parametrize( "field_name, model, expected", [ - ( + pytest.param( "name", MyAllInOneModel, AttributeSchema( @@ -273,8 +293,9 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): kind=AttributeKind.TEXT, optional=False, ), + id="MyAllInOneModel_name", ), - ( + pytest.param( "age", MyAllInOneModel, AttributeSchema( @@ -282,8 +303,9 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): kind=AttributeKind.NUMBER, optional=False, ), + id="MyAllInOneModel_age", ), - ( + pytest.param( "is_active", MyAllInOneModel, AttributeSchema( @@ -291,8 +313,9 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): kind=AttributeKind.BOOLEAN, optional=False, ), + id="MyAllInOneModel_is_active", ), - ( + pytest.param( "opt_age", MyAllInOneModel, AttributeSchema( @@ -300,8 +323,9 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): kind=AttributeKind.NUMBER, optional=True, ), + id="MyAllInOneModel_opt_age", ), - ( + pytest.param( "default_name", MyAllInOneModel, AttributeSchema( @@ -310,8 +334,9 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): optional=True, default_value="some_default", ), + id="MyAllInOneModel_default_name", ), - ( + pytest.param( "old_opt_age", MyAllInOneModel, AttributeSchema( @@ -319,8 +344,9 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): kind=AttributeKind.NUMBER, optional=True, ), + id="MyAllInOneModel_old_opt_age", ), - ( + pytest.param( "description", AcmeTag, AttributeSchema( @@ -328,8 +354,9 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): kind=AttributeKind.TEXTAREA, optional=True, ), + id="AcmeTag_description", ), - ( + pytest.param( "name", AcmeTag, AttributeSchema( @@ -339,8 +366,9 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): optional=True, default_value="test_tag", ), + id="AcmeTag_name", ), - ( + pytest.param( "label", AcmeTag, AttributeSchema( @@ -350,8 +378,9 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): optional=False, unique=True, ), + id="AcmeTag_label", ), - ( + pytest.param( "firstname", AbstractPerson, AttributeSchema( @@ -361,10 +390,11 @@ def test_analyze_field(field_name: str, model: BaseModel, expected: dict): optional=False, regex=r"^[a-zA-Z]+$", ), + id="AbstractPerson_firstname", ), ], ) -def test_field_to_attribute(field_name: str, model: BaseModel, expected: AttributeSchema): +def test_field_to_attribute(field_name: str, model: type[BaseModel], expected: AttributeSchema): field = model.model_fields[field_name] field_info = analyze_field(field_name, field) assert field_to_attribute(field_name, field_info, field) == expected @@ -373,7 +403,7 @@ def test_field_to_attribute(field_name: str, model: BaseModel, expected: Attribu @pytest.mark.parametrize( "field_name, model, expected", [ - ( + pytest.param( "owner", AcmeCar, RelationshipSchema( @@ -382,8 +412,9 @@ def test_field_to_attribute(field_name: str, model: BaseModel, expected: Attribu cardinality="one", optional=False, ), + id="AcmeCar_owner", ), - ( + pytest.param( "tags", AcmeCar, RelationshipSchema( @@ -392,8 +423,9 @@ def test_field_to_attribute(field_name: str, model: BaseModel, expected: Attribu cardinality="many", optional=False, ), + id="AcmeCar_tags", ), - ( + pytest.param( "secondary_owner", AcmeCar, RelationshipSchema( @@ -402,11 +434,17 @@ def test_field_to_attribute(field_name: str, model: BaseModel, expected: Attribu cardinality="one", optional=True, ), + id="AcmeCar_secondary_owner", ), ], ) -def test_field_to_relationship(field_name: str, model: BaseModel, expected: RelationshipSchema): - field = model.model_fields[field_name] +def test_field_to_relationship(field_name: str, model: type[BaseModel | SchemaModel], expected: RelationshipSchema): + if field_name in model.model_fields: + field = model.model_fields[field_name] + elif issubclass(model, SchemaModel) and field_name in model.__infrahub_relationships__: + field = model.__infrahub_relationships__[field_name] + else: + raise ValueError(f"Field {field_name} not found in model {model}") field_info = analyze_field(field_name, field) assert field_to_relationship(field_name, field_info, field) == expected @@ -414,24 +452,24 @@ def test_field_to_relationship(field_name: str, model: BaseModel, expected: Rela @pytest.mark.parametrize( "model, expected", [ - (MyAllInOneModel, "MyAllInOneModel"), - (Book, "LibraryBook"), - (LibraryAuthor, "LibraryAuthor"), - (LibraryReader, "LibraryReader"), - (AbstractPerson, "LibraryAbstractPerson"), - (AcmeTag, "AcmeTag"), - (AcmeCar, "AcmeCar"), - (AcmePerson, "AcmePerson"), + pytest.param(MyAllInOneModel, "MyAllInOneModel", id="MyAllInOneModel"), + pytest.param(Book, "LibraryBook", id="Book"), + pytest.param(LibraryAuthor, "LibraryAuthor", id="LibraryAuthor"), + pytest.param(LibraryReader, "LibraryReader", id="LibraryReader"), + pytest.param(AbstractPerson, "LibraryAbstractPerson", id="AbstractPerson"), + pytest.param(AcmeTag, "AcmeTag", id="AcmeTag"), + pytest.param(AcmeCar, "AcmeCar", id="AcmeCar"), + pytest.param(AcmePerson, "AcmePerson", id="AcmePerson"), ], ) -def test_get_kind(model: BaseModel, expected: str): +def test_get_kind(model: type[BaseModel], expected: str): assert get_kind(model) == expected @pytest.mark.parametrize( "model, expected", [ - ( + pytest.param( MyAllInOneModel, NodeSchema( name="AllInOneModel", @@ -448,8 +486,9 @@ def test_get_kind(model: BaseModel, expected: str): AttributeSchema(name="old_opt_age", kind=AttributeKind.NUMBER, optional=True), ], ), + id="MyAllInOneModel", ), - ( + pytest.param( Book, NodeSchema( name="Book", @@ -473,8 +512,9 @@ def test_get_kind(model: BaseModel, expected: str): ), ], ), + id="Book", ), - ( + pytest.param( LibraryAuthor, NodeSchema( name="Author", @@ -485,8 +525,9 @@ def test_get_kind(model: BaseModel, expected: str): RelationshipSchema(name="books", peer="LibraryBook", cardinality="many", optional=False), ], ), + id="LibraryAuthor", ), - ( + pytest.param( LibraryReader, NodeSchema( name="Reader", @@ -500,8 +541,9 @@ def test_get_kind(model: BaseModel, expected: str): ), ], ), + id="LibraryReader", ), - ( + pytest.param( AbstractPerson, GenericSchema( name="AbstractPerson", @@ -518,8 +560,9 @@ def test_get_kind(model: BaseModel, expected: str): AttributeSchema(name="lastname", kind=AttributeKind.TEXT, optional=False), ], ), + id="AbstractPerson", ), - ( + pytest.param( AcmeTag, NodeSchema( name="Tag", @@ -543,10 +586,11 @@ def test_get_kind(model: BaseModel, expected: str): ), ], ), + id="AcmeTag", ), ], ) -def test_model_to_node(model: BaseModel, expected: NodeSchema): +def test_model_to_node(model: type[BaseModel], expected: NodeSchema): node = model_to_node(model) assert node == expected @@ -556,7 +600,7 @@ def test_related_models(): assert len(schemas.nodes) == 3 -def test_library_models(): - schemas = from_pydantic(models=[Book, AbstractPerson, LibraryAuthor, LibraryReader]) - assert len(schemas.nodes) == 3 - assert len(schemas.generics) == 1 +# def test_library_models(): +# schemas = from_pydantic(models=[Book, AbstractPerson, LibraryAuthor, LibraryReader]) +# assert len(schemas.nodes) == 3 +# assert len(schemas.generics) == 1