diff --git a/piccolo/testing/model_builder.py b/piccolo/testing/model_builder.py index 87e1c87fb..7aa3381ff 100644 --- a/piccolo/testing/model_builder.py +++ b/piccolo/testing/model_builder.py @@ -1,10 +1,12 @@ from __future__ import annotations import datetime +import inspect import json import typing as t from decimal import Decimal -from uuid import UUID +from functools import partial +from types import MappingProxyType from piccolo.columns import JSON, JSONB, Array, Column, ForeignKey from piccolo.custom_types import TableInstance @@ -13,18 +15,8 @@ class ModelBuilder: - __DEFAULT_MAPPER: t.Dict[t.Type, t.Callable] = { - bool: RandomBuilder.next_bool, - bytes: RandomBuilder.next_bytes, - datetime.date: RandomBuilder.next_date, - datetime.datetime: RandomBuilder.next_datetime, - float: RandomBuilder.next_float, - int: RandomBuilder.next_int, - str: RandomBuilder.next_str, - datetime.time: RandomBuilder.next_time, - datetime.timedelta: RandomBuilder.next_timedelta, - UUID: RandomBuilder.next_uuid, - } + __DEFAULT_MAPPER: t.Dict[t.Type, t.Callable] = {} + __OTHER_MAPPER: t.Dict[t.Type, t.Callable] = {} @classmethod async def build( @@ -106,7 +98,7 @@ async def _build( persist: bool = True, ) -> TableInstance: model = table_class(_ignore_missing=True) - defaults = {} if not defaults else defaults + defaults = defaults or {} for column, value in defaults.items(): if isinstance(column, str): @@ -159,29 +151,110 @@ def _randomize_attribute(cls, column: Column) -> t.Any: Column class to randomize. """ - random_value: t.Any - if column.value_type == Decimal: - precision, scale = column._meta.params["digits"] or (4, 2) - random_value = RandomBuilder.next_float( - maximum=10 ** (precision - scale), scale=scale - ) - elif column.value_type == datetime.datetime: - tz_aware = getattr(column, "tz_aware", False) - random_value = RandomBuilder.next_datetime(tz_aware=tz_aware) - elif column.value_type == list: - length = RandomBuilder.next_int(maximum=10) - base_type = t.cast(Array, column).base_column.value_type - random_value = [ - cls.__DEFAULT_MAPPER[base_type]() for _ in range(length) - ] - elif column._meta.choices: + reg = cls.get_registry(column) + if column._meta.choices: random_value = RandomBuilder.next_enum(column._meta.choices) else: - random_value = cls.__DEFAULT_MAPPER[column.value_type]() + random_value = reg[column.value_type]() - if "length" in column._meta.params and isinstance(random_value, str): - return random_value[: column._meta.params["length"]] - elif isinstance(column, (JSON, JSONB)): + if isinstance(column, (JSON, JSONB)): return json.dumps({"value": random_value}) - return random_value + + @classmethod + def get_registry( + cls, column: Column + ) -> MappingProxyType[t.Type, t.Callable]: + """ + This serves as the public API allowing users to **view** + the complete registry for the specified column. + + :param column: + Column class to randomize. + + """ + default_mapper = cls.__DEFAULT_MAPPER + if not default_mapper: # execute once only + for typ, callable_ in RandomBuilder.get_mapper().items(): + default_mapper[typ] = callable_ + + # order matters + reg = { + **default_mapper, + **cls._get_local_mapper(column), + **cls._get_other_mapper(column), + } + + if column.value_type == list: + reg[list] = partial( + RandomBuilder.next_list, + reg[t.cast(Array, column).base_column.value_type], + ) + return MappingProxyType(reg) + + @classmethod + def _get_local_mapper(cls, column: Column) -> t.Dict[t.Type, t.Callable]: + """ + This classmethod encapsulates the desired logic, utilizing information + from the column. + + :param column: + Column class to randomize. + """ + local_mapper: t.Dict[t.Type, t.Callable] = {} + + precision, scale = column._meta.params.get("digits") or (4, 2) + local_mapper[Decimal] = partial( + RandomBuilder.next_decimal, precision, scale + ) + + tz_aware = getattr(column, "tz_aware", False) + local_mapper[datetime.datetime] = partial( + RandomBuilder.next_datetime, tz_aware + ) + + if _length := column._meta.params.get("length"): + local_mapper[str] = partial(RandomBuilder.next_str, _length) + + return local_mapper + + @classmethod + def _get_other_mapper(cls, column: Column) -> t.Dict[t.Type, t.Callable]: + """ + This is a hook that allows users to register their own random type + callable. If the callable has a parameter named `column`, we assist + by injecting `column` using `partial`. + + :param column: + Column class to randomize. + + Examples:: + + # a callable not utilizing column information + ModelBuilder.register_random_type(str, lambda: "piccolo") + + # a callable utilizing the column information + def next_str(column: Column) -> str: + length = column._meta.params.get("length", 5) + return "".join("a" for _ in range(length)) + ) + ModelBuilder.register_random_type(str, next_str) + + """ + other_mapper: t.Dict[t.Type, t.Callable] = {} + for typ, callable_ in cls.__OTHER_MAPPER.items(): + sig = inspect.signature(callable_) + if sig.parameters.get("column"): + other_mapper[typ] = partial(callable_, column) + else: + other_mapper[typ] = callable_ + return other_mapper + + @classmethod + def register_type(cls, typ: t.Type, callable_: t.Callable) -> None: + cls.__OTHER_MAPPER[typ] = callable_ + + @classmethod + def unregister_type(cls, typ: t.Type) -> None: + if typ in cls.__OTHER_MAPPER: + del cls.__OTHER_MAPPER[typ] diff --git a/piccolo/testing/random_builder.py b/piccolo/testing/random_builder.py index bca29a7f2..f8d18cb5a 100644 --- a/piccolo/testing/random_builder.py +++ b/piccolo/testing/random_builder.py @@ -1,4 +1,5 @@ import datetime +import decimal import enum import random import string @@ -7,6 +8,27 @@ class RandomBuilder: + @classmethod + def get_mapper(cls) -> t.Dict[t.Type, t.Callable]: + """ + This is the public API for users to get the + provided random mapper. + + """ + return { + bool: cls.next_bool, + bytes: cls.next_bytes, + datetime.date: cls.next_date, + datetime.datetime: cls.next_datetime, + float: cls.next_float, + decimal.Decimal: cls.next_decimal, + int: cls.next_int, + str: cls.next_str, + datetime.time: cls.next_time, + datetime.timedelta: cls.next_timedelta, + uuid.UUID: cls.next_uuid, + } + @classmethod def next_bool(cls) -> bool: return random.choice([True, False]) @@ -43,12 +65,21 @@ def next_enum(cls, e: t.Type[enum.Enum]) -> t.Any: def next_float(cls, minimum=0, maximum=2147483647, scale=5) -> float: return round(random.uniform(minimum, maximum), scale) + @classmethod + def next_decimal( + cls, precision: int = 4, scale: int = 2 + ) -> decimal.Decimal: + float_number = cls.next_float( + maximum=10 ** (precision - scale), scale=scale + ) + return decimal.Decimal(str(float_number)) + @classmethod def next_int(cls, minimum=0, maximum=2147483647) -> int: return random.randint(minimum, maximum) @classmethod - def next_str(cls, length=16) -> str: + def next_str(cls, length: int = 16) -> str: return "".join( random.choice(string.ascii_letters) for _ in range(length) ) @@ -72,3 +103,8 @@ def next_timedelta(cls) -> datetime.timedelta: @classmethod def next_uuid(cls) -> uuid.UUID: return uuid.uuid4() + + @classmethod + def next_list(cls, callable_: t.Callable) -> t.List[t.Any]: + length = cls.next_int(maximum=10) + return [callable_() for _ in range(length)] diff --git a/tests/testing/test_model_builder.py b/tests/testing/test_model_builder.py index 242bac188..aa313129f 100644 --- a/tests/testing/test_model_builder.py +++ b/tests/testing/test_model_builder.py @@ -1,10 +1,13 @@ import asyncio +import builtins import json +import random import typing as t import unittest from piccolo.columns import ( Array, + Column, Decimal, ForeignKey, Integer, @@ -223,3 +226,66 @@ def test_json(self): .run_sync() ): self.assertIsInstance(facilities, dict) + + +@engines_skip("cockroach") +class TestModelBuilder2(unittest.TestCase): + @classmethod + def setUpClass(cls): + create_db_tables_sync(*TABLES) + + @classmethod + def tearDownClass(cls) -> None: + drop_db_tables_sync(*TABLES) + + def setUp(self) -> None: + ModelBuilder.__OTHER_MAPPER = {} + + def tearDown(self) -> None: + ModelBuilder.__OTHER_MAPPER = {} + + def test_register(self): + ModelBuilder.register_type(str, lambda: "piccolo") + manager = ModelBuilder.build_sync(Manager) + self.assertEqual(manager.name, "piccolo") + + def test_register_with_column_info(self): + def next_str(column: Column) -> str: + length = column._meta.params.get("length", 5) + return "".join("a" for _ in range(length)) + + ModelBuilder.register_type(str, next_str) + manager = ModelBuilder.build_sync(Manager) + self.assertEqual(len(manager.name), 50) + + post = ModelBuilder.build_sync(Poster) + self.assertEqual(len(post.content), 5) + + def test_register_same_type(self): + ModelBuilder.register_type(str, lambda: "piccolo") + ModelBuilder.register_type(str, lambda: "PICCOLO") + manager = ModelBuilder.build_sync(Manager) + self.assertEqual(manager.name, "PICCOLO") + + def test_unregister(self): + ModelBuilder.register_type(str, lambda: "piccolo") + ModelBuilder.unregister_type(str) + manager = ModelBuilder.build_sync(Manager) + self.assertNotEqual(manager.name, "piccolo") + + def test_unregister_same_type(self): + ModelBuilder.unregister_type(str) + ModelBuilder.unregister_type(str) + + def test_unregister_any_type(self): + ModelBuilder.unregister_type(random.choice(dir(builtins))) + + def test_get_registry(self): + def next_str() -> str: + return "piccolo" + + ModelBuilder.register_type(str, next_str) + reg = ModelBuilder.get_registry(Varchar()) + + self.assertIn(str, reg) + self.assertIn(next_str, reg.values()) diff --git a/tests/testing/test_random_builder.py b/tests/testing/test_random_builder.py index 1f078cb9e..ed5d2cad5 100644 --- a/tests/testing/test_random_builder.py +++ b/tests/testing/test_random_builder.py @@ -35,6 +35,10 @@ def test_next_float(self): random_float = RandomBuilder.next_float(maximum=1000) self.assertLessEqual(random_float, 1000) + def test_next_decimal(self): + random_decimal = RandomBuilder.next_decimal(5, 2) + self.assertLessEqual(random_decimal, 1000) + def test_next_int(self): random_int = RandomBuilder.next_int() self.assertLessEqual(random_int, 2147483647) @@ -52,3 +56,9 @@ def test_next_timedelta(self): def test_next_uuid(self): RandomBuilder.next_uuid() + + def test_next_list(self): + for typ, callable_ in RandomBuilder.get_mapper().items(): + random_list = RandomBuilder.next_list(callable_) + self.assertIsInstance(random_list, list) + self.assertTrue(all(isinstance(elem, typ) for elem in random_list))