From 358a7d74af3e21fa2203dcbdddbc44cf7333d405 Mon Sep 17 00:00:00 2001 From: jrycw Date: Fri, 22 Mar 2024 12:30:36 +0800 Subject: [PATCH 1/3] refactor ModelBuilder and RandomBuilder --- piccolo/testing/model_builder.py | 44 ++---------------------- piccolo/testing/random_builder.py | 51 +++++++++++++++++++++++++++- tests/testing/test_random_builder.py | 27 +++++++++++++++ 3 files changed, 80 insertions(+), 42 deletions(-) diff --git a/piccolo/testing/model_builder.py b/piccolo/testing/model_builder.py index 87e1c87fb..3146d2f62 100644 --- a/piccolo/testing/model_builder.py +++ b/piccolo/testing/model_builder.py @@ -1,31 +1,15 @@ from __future__ import annotations -import datetime import json import typing as t -from decimal import Decimal -from uuid import UUID -from piccolo.columns import JSON, JSONB, Array, Column, ForeignKey +from piccolo.columns import JSON, JSONB, Column, ForeignKey from piccolo.custom_types import TableInstance from piccolo.testing.random_builder import RandomBuilder from piccolo.utils.sync import run_sync class ModelBuilder: - __DEFAULT_MAPPER: t.Dict[t.Type, t.Callable] = { - bool: RandomBuilder.next_bool, - bytes: RandomBuilder.next_bytes, - datetime.date: RandomBuilder.next_date, - datetime.datetime: RandomBuilder.next_datetime, - float: RandomBuilder.next_float, - int: RandomBuilder.next_int, - str: RandomBuilder.next_str, - datetime.time: RandomBuilder.next_time, - datetime.timedelta: RandomBuilder.next_timedelta, - UUID: RandomBuilder.next_uuid, - } - @classmethod async def build( cls, @@ -159,29 +143,7 @@ def _randomize_attribute(cls, column: Column) -> t.Any: Column class to randomize. """ - random_value: t.Any - if column.value_type == Decimal: - precision, scale = column._meta.params["digits"] or (4, 2) - random_value = RandomBuilder.next_float( - maximum=10 ** (precision - scale), scale=scale - ) - elif column.value_type == datetime.datetime: - tz_aware = getattr(column, "tz_aware", False) - random_value = RandomBuilder.next_datetime(tz_aware=tz_aware) - elif column.value_type == list: - length = RandomBuilder.next_int(maximum=10) - base_type = t.cast(Array, column).base_column.value_type - random_value = [ - cls.__DEFAULT_MAPPER[base_type]() for _ in range(length) - ] - elif column._meta.choices: - random_value = RandomBuilder.next_enum(column._meta.choices) - else: - random_value = cls.__DEFAULT_MAPPER[column.value_type]() - - if "length" in column._meta.params and isinstance(random_value, str): - return random_value[: column._meta.params["length"]] - elif isinstance(column, (JSON, JSONB)): + random_value: t.Any = RandomBuilder._build(column) + if isinstance(column, (JSON, JSONB)): return json.dumps({"value": random_value}) - return random_value diff --git a/piccolo/testing/random_builder.py b/piccolo/testing/random_builder.py index bca29a7f2..13a693e53 100644 --- a/piccolo/testing/random_builder.py +++ b/piccolo/testing/random_builder.py @@ -1,12 +1,50 @@ +from __future__ import annotations + import datetime import enum import random import string import typing as t import uuid +from decimal import Decimal +from functools import partial +from uuid import UUID + +from piccolo.columns import Array, Column class RandomBuilder: + @classmethod + def _build(cls, column: Column) -> t.Any: + if e := column._meta.choices: + return cls.next_enum(e) + + mapper: t.Dict[t.Type, t.Callable] = { + bool: cls.next_bool, + bytes: cls.next_bytes, + datetime.date: cls.next_date, + datetime.datetime: partial( + cls.next_datetime, getattr(column, "tz_aware", False) + ), + float: cls.next_float, + Decimal: partial( + cls.next_decimal, column._meta.params.get("digits") + ), + int: cls.next_int, + str: partial(cls.next_str, column._meta.params.get("length")), + datetime.time: cls.next_time, + datetime.timedelta: cls.next_timedelta, + UUID: cls.next_uuid, + } + + random_value_callable = mapper.get(column.value_type) + if random_value_callable is None: + random_value_callable = partial( + cls.next_list, + mapper[t.cast(Array, column).base_column.value_type], + ) + return random_value_callable() + @classmethod def next_bool(cls) -> bool: return random.choice([True, False]) @@ -43,12 +81,18 @@ def next_enum(cls, e: t.Type[enum.Enum]) -> t.Any: def next_float(cls, minimum=0, maximum=2147483647, scale=5) -> float: return round(random.uniform(minimum, maximum), scale) + @classmethod + def next_decimal(cls, digits: t.Tuple[int, int] | None = (4, 2)) -> float: + precision, scale = digits or (4, 2) + return cls.next_float(maximum=10 ** (precision - scale), scale=scale) + @classmethod def next_int(cls, minimum=0, maximum=2147483647) -> int: return random.randint(minimum, maximum) @classmethod - def next_str(cls, length=16) -> str: + def next_str(cls, length: int | None = 16) -> str: + length = length or 16 return "".join( random.choice(string.ascii_letters) for _ in range(length) ) @@ -72,3 +116,8 @@ def next_timedelta(cls) -> datetime.timedelta: @classmethod def next_uuid(cls) -> uuid.UUID: return uuid.uuid4() + + @classmethod + def next_list(cls, callable_: t.Callable) -> t.List[t.Any]: + length = cls.next_int(maximum=10) + return [callable_() for _ in range(length)] diff --git a/tests/testing/test_random_builder.py b/tests/testing/test_random_builder.py index 1f078cb9e..b3e8347a9 100644 --- a/tests/testing/test_random_builder.py +++ b/tests/testing/test_random_builder.py @@ -1,5 +1,7 @@ +import datetime import unittest from enum import Enum +from uuid import UUID from piccolo.testing.random_builder import RandomBuilder @@ -31,6 +33,10 @@ class Color(Enum): random_enum = RandomBuilder.next_enum(Color) self.assertIsInstance(random_enum, int) + def test_next_decimal(self): + random_decimal = RandomBuilder.next_decimal((5, 2)) + self.assertLessEqual(random_decimal, 1000) + def test_next_float(self): random_float = RandomBuilder.next_float(maximum=1000) self.assertLessEqual(random_float, 1000) @@ -52,3 +58,24 @@ def test_next_timedelta(self): def test_next_uuid(self): RandomBuilder.next_uuid() + + def test_next_list(self): + # `RandomBuilder.next_decimal` will return `float` + reversed_mapper = { + RandomBuilder.next_bool: bool, + RandomBuilder.next_bytes: bytes, + RandomBuilder.next_date: datetime.date, + RandomBuilder.next_datetime: datetime.datetime, + RandomBuilder.next_float: float, + RandomBuilder.next_decimal: float, + RandomBuilder.next_int: int, + RandomBuilder.next_str: str, + RandomBuilder.next_time: datetime.time, + RandomBuilder.next_timedelta: datetime.timedelta, + RandomBuilder.next_uuid: UUID, + } + + for callable_, typ in reversed_mapper.items(): + random_list = RandomBuilder.next_list(callable_) + self.assertIsInstance(random_list, list) + self.assertTrue(all(isinstance(elem, typ) for elem in random_list)) From d6df24744f8988d1e6b14b804efbdcd7c92ca1d8 Mon Sep 17 00:00:00 2001 From: jrycw Date: Sat, 23 Mar 2024 03:05:32 +0800 Subject: [PATCH 2/3] Another iteration of refactor --- piccolo/testing/model_builder.py | 105 ++++++++++++++++++++++++++- piccolo/testing/random_builder.py | 62 ++++------------ tests/testing/test_model_builder.py | 56 ++++++++++++++ tests/testing/test_random_builder.py | 31 +------- 4 files changed, 179 insertions(+), 75 deletions(-) diff --git a/piccolo/testing/model_builder.py b/piccolo/testing/model_builder.py index 3146d2f62..0cb5b009d 100644 --- a/piccolo/testing/model_builder.py +++ b/piccolo/testing/model_builder.py @@ -1,15 +1,22 @@ from __future__ import annotations +import datetime +import inspect import json import typing as t +from decimal import Decimal +from functools import partial -from piccolo.columns import JSON, JSONB, Column, ForeignKey +from piccolo.columns import JSON, JSONB, Array, Column, ForeignKey from piccolo.custom_types import TableInstance from piccolo.testing.random_builder import RandomBuilder from piccolo.utils.sync import run_sync class ModelBuilder: + __DEFAULT_MAPPER: t.Dict[t.Type, t.Callable] = {} + __OTHER_MAPPER: t.Dict[t.Type, t.Callable] = {} + @classmethod async def build( cls, @@ -143,7 +150,101 @@ def _randomize_attribute(cls, column: Column) -> t.Any: Column class to randomize. """ - random_value: t.Any = RandomBuilder._build(column) + random_value: t.Any + default_mapper = cls._get_default_mapper() + local_mapper = cls._get_local_mapper(column) + other_mapper = cls._get_other_mapper(column) + # order matters + mapper = {**default_mapper, **local_mapper, **other_mapper} + + if column._meta.choices: + random_value = RandomBuilder.next_enum(column._meta.choices) + elif column.value_type == list: + length = RandomBuilder.next_int(maximum=10) + base_type = t.cast(Array, column).base_column.value_type + random_value = [mapper[base_type]() for _ in range(length)] + else: + random_value = mapper[column.value_type]() + if isinstance(column, (JSON, JSONB)): return json.dumps({"value": random_value}) return random_value + + @classmethod + def _get_default_mapper(cls) -> t.Dict[t.Type, t.Callable]: + """ + This classmethod encapsulates the desired logic. + """ + mapper = cls.__DEFAULT_MAPPER + if not mapper: # execute once only + for typ, callable_name in RandomBuilder.DEFAULT_MAPPER.items(): + # a simpler approach available? + func = RandomBuilder.__dict__[callable_name].__func__ + mapper[typ] = partial(func, RandomBuilder) + return mapper + + @classmethod + def _get_local_mapper(cls, column: Column) -> t.Dict[t.Type, t.Callable]: + """ + This classmethod encapsulates the desired logic, utilizing information + from the column. + + :param column: + Column class to randomize. + """ + local_mapper: t.Dict[t.Type, t.Callable] = {} + if column.value_type == Decimal: + precision, scale = column._meta.params["digits"] or (4, 2) + local_mapper[Decimal] = partial( + RandomBuilder.next_decimal, precision, scale + ) + elif column.value_type == datetime.datetime: + tz_aware = getattr(column, "tz_aware", False) + local_mapper[datetime.datetime] = partial( + RandomBuilder.next_datetime, tz_aware + ) + elif column.value_type == str: + if _length := column._meta.params.get("length"): + local_mapper[str] = partial(RandomBuilder.next_str, _length) + return local_mapper + + @classmethod + def _get_other_mapper(cls, column: Column) -> t.Dict[t.Type, t.Callable]: + """ + This is a hook that allows users to register their own random type + callable. If the callable has a parameter named `column`, we assist + by injecting `column` using `partial`. + + :param column: + Column class to randomize. + + Examples:: + + # a callable not utilizing column information + ModelBuilder.register_random_type(str, lambda: "piccolo") + + # a callable utilizing the column information + def next_str(column: Column) -> str: + length = column._meta.params.get("length", 5) + return "".join("a" for _ in range(length)) + ) + ModelBuilder.register_random_type(str, next_str) + + """ + other_mapper: t.Dict[t.Type, t.Callable] = {} + for typ, callable_ in cls.__OTHER_MAPPER.items(): + sig = inspect.signature(callable_) + if sig.parameters.get("column"): + other_mapper[typ] = partial(callable_, column) + else: + other_mapper[typ] = callable_ + return other_mapper + + @classmethod + def register_random_type(cls, typ: t.Type, callable_: t.Callable) -> None: + cls.__OTHER_MAPPER[typ] = callable_ + + @classmethod + def unregister_random_type(cls, typ: t.Type) -> None: + if typ in cls.__OTHER_MAPPER: + del cls.__OTHER_MAPPER[typ] diff --git a/piccolo/testing/random_builder.py b/piccolo/testing/random_builder.py index 13a693e53..16fd82dc3 100644 --- a/piccolo/testing/random_builder.py +++ b/piccolo/testing/random_builder.py @@ -1,49 +1,26 @@ -from __future__ import annotations - import datetime +import decimal import enum import random import string import typing as t import uuid -from decimal import Decimal -from functools import partial -from uuid import UUID - -from piccolo.columns import Array, Column class RandomBuilder: - @classmethod - def _build(cls, column: Column) -> t.Any: - if e := column._meta.choices: - return cls.next_enum(e) - - mapper: t.Dict[t.Type, t.Callable] = { - bool: cls.next_bool, - bytes: cls.next_bytes, - datetime.date: cls.next_date, - datetime.datetime: partial( - cls.next_datetime, getattr(column, "tz_aware", False) - ), - float: cls.next_float, - Decimal: partial( - cls.next_decimal, column._meta.params.get("digits") - ), - int: cls.next_int, - str: partial(cls.next_str, column._meta.params.get("length")), - datetime.time: cls.next_time, - datetime.timedelta: cls.next_timedelta, - UUID: cls.next_uuid, - } - - random_value_callable = mapper.get(column.value_type) - if random_value_callable is None: - random_value_callable = partial( - cls.next_list, - mapper[t.cast(Array, column).base_column.value_type], - ) - return random_value_callable() + DEFAULT_MAPPER: t.Dict[t.Type, str] = { + bool: "next_bool", + bytes: "next_bytes", + datetime.date: "next_date", + datetime.datetime: "next_datetime", + float: "next_float", + decimal.Decimal: "next_decimal", + int: "next_int", + str: "next_str", + datetime.time: "next_time", + datetime.timedelta: "next_timedelta", + uuid.UUID: "next_uuid", + } @classmethod def next_bool(cls) -> bool: @@ -82,8 +59,7 @@ def next_float(cls, minimum=0, maximum=2147483647, scale=5) -> float: return round(random.uniform(minimum, maximum), scale) @classmethod - def next_decimal(cls, digits: t.Tuple[int, int] | None = (4, 2)) -> float: - precision, scale = digits or (4, 2) + def next_decimal(cls, precision: int, scale: int) -> float: return cls.next_float(maximum=10 ** (precision - scale), scale=scale) @classmethod @@ -91,8 +67,7 @@ def next_int(cls, minimum=0, maximum=2147483647) -> int: return random.randint(minimum, maximum) @classmethod - def next_str(cls, length: int | None = 16) -> str: - length = length or 16 + def next_str(cls, length: int = 16) -> str: return "".join( random.choice(string.ascii_letters) for _ in range(length) ) @@ -116,8 +91,3 @@ def next_timedelta(cls) -> datetime.timedelta: @classmethod def next_uuid(cls) -> uuid.UUID: return uuid.uuid4() - - @classmethod - def next_list(cls, callable_: t.Callable) -> t.List[t.Any]: - length = cls.next_int(maximum=10) - return [callable_() for _ in range(length)] diff --git a/tests/testing/test_model_builder.py b/tests/testing/test_model_builder.py index 242bac188..1531757b3 100644 --- a/tests/testing/test_model_builder.py +++ b/tests/testing/test_model_builder.py @@ -1,10 +1,13 @@ import asyncio +import builtins import json +import random import typing as t import unittest from piccolo.columns import ( Array, + Column, Decimal, ForeignKey, Integer, @@ -223,3 +226,56 @@ def test_json(self): .run_sync() ): self.assertIsInstance(facilities, dict) + + +@engines_skip("cockroach") +class TestModelBuilder2(unittest.TestCase): + @classmethod + def setUpClass(cls): + create_db_tables_sync(*TABLES) + + @classmethod + def tearDownClass(cls) -> None: + drop_db_tables_sync(*TABLES) + + def setUp(self) -> None: + ModelBuilder.__OTHER_MAPPER = {} + + def tearDown(self) -> None: + ModelBuilder.__OTHER_MAPPER = {} + + def test_register(self): + ModelBuilder.register_random_type(str, lambda: "piccolo") + manager = ModelBuilder.build_sync(Manager) + self.assertEqual(manager.name, "piccolo") + + def test_register_with_column_info(self): + def next_str(column: Column) -> str: + length = column._meta.params.get("length", 5) + return "".join("a" for _ in range(length)) + + ModelBuilder.register_random_type(str, next_str) + manager = ModelBuilder.build_sync(Manager) + self.assertEqual(len(manager.name), 50) + + post = ModelBuilder.build_sync(Poster) + self.assertEqual(len(post.content), 5) + + def test_register_same_type(self): + ModelBuilder.register_random_type(str, lambda: "piccolo") + ModelBuilder.register_random_type(str, lambda: "PICCOLO") + manager = ModelBuilder.build_sync(Manager) + self.assertEqual(manager.name, "PICCOLO") + + def test_unregister(self): + ModelBuilder.register_random_type(str, lambda: "piccolo") + ModelBuilder.unregister_random_type(str) + manager = ModelBuilder.build_sync(Manager) + self.assertNotEqual(manager.name, "piccolo") + + def test_unregister_same_type(self): + ModelBuilder.unregister_random_type(str) + ModelBuilder.unregister_random_type(str) + + def test_unregister_any_type(self): + ModelBuilder.unregister_random_type(random.choice(dir(builtins))) diff --git a/tests/testing/test_random_builder.py b/tests/testing/test_random_builder.py index b3e8347a9..4d94837f0 100644 --- a/tests/testing/test_random_builder.py +++ b/tests/testing/test_random_builder.py @@ -1,7 +1,5 @@ -import datetime import unittest from enum import Enum -from uuid import UUID from piccolo.testing.random_builder import RandomBuilder @@ -33,14 +31,14 @@ class Color(Enum): random_enum = RandomBuilder.next_enum(Color) self.assertIsInstance(random_enum, int) - def test_next_decimal(self): - random_decimal = RandomBuilder.next_decimal((5, 2)) - self.assertLessEqual(random_decimal, 1000) - def test_next_float(self): random_float = RandomBuilder.next_float(maximum=1000) self.assertLessEqual(random_float, 1000) + def test_next_decimal(self): + random_decimal = RandomBuilder.next_decimal(5, 2) + self.assertLessEqual(random_decimal, 1000) + def test_next_int(self): random_int = RandomBuilder.next_int() self.assertLessEqual(random_int, 2147483647) @@ -58,24 +56,3 @@ def test_next_timedelta(self): def test_next_uuid(self): RandomBuilder.next_uuid() - - def test_next_list(self): - # `RandomBuilder.next_decimal` will return `float` - reversed_mapper = { - RandomBuilder.next_bool: bool, - RandomBuilder.next_bytes: bytes, - RandomBuilder.next_date: datetime.date, - RandomBuilder.next_datetime: datetime.datetime, - RandomBuilder.next_float: float, - RandomBuilder.next_decimal: float, - RandomBuilder.next_int: int, - RandomBuilder.next_str: str, - RandomBuilder.next_time: datetime.time, - RandomBuilder.next_timedelta: datetime.timedelta, - RandomBuilder.next_uuid: UUID, - } - - for callable_, typ in reversed_mapper.items(): - random_list = RandomBuilder.next_list(callable_) - self.assertIsInstance(random_list, list) - self.assertTrue(all(isinstance(elem, typ) for elem in random_list)) From ad763c30ba81b6286c425bddbcc6907a3bacc3da Mon Sep 17 00:00:00 2001 From: jrycw Date: Sat, 23 Mar 2024 14:22:11 +0800 Subject: [PATCH 3/3] third refactor --- piccolo/testing/model_builder.py | 84 ++++++++++++++++------------ piccolo/testing/random_builder.py | 47 +++++++++++----- tests/testing/test_model_builder.py | 28 +++++++--- tests/testing/test_random_builder.py | 6 ++ 4 files changed, 104 insertions(+), 61 deletions(-) diff --git a/piccolo/testing/model_builder.py b/piccolo/testing/model_builder.py index 0cb5b009d..7aa3381ff 100644 --- a/piccolo/testing/model_builder.py +++ b/piccolo/testing/model_builder.py @@ -6,6 +6,7 @@ import typing as t from decimal import Decimal from functools import partial +from types import MappingProxyType from piccolo.columns import JSON, JSONB, Array, Column, ForeignKey from piccolo.custom_types import TableInstance @@ -97,7 +98,7 @@ async def _build( persist: bool = True, ) -> TableInstance: model = table_class(_ignore_missing=True) - defaults = {} if not defaults else defaults + defaults = defaults or {} for column, value in defaults.items(): if isinstance(column, str): @@ -150,38 +151,46 @@ def _randomize_attribute(cls, column: Column) -> t.Any: Column class to randomize. """ - random_value: t.Any - default_mapper = cls._get_default_mapper() - local_mapper = cls._get_local_mapper(column) - other_mapper = cls._get_other_mapper(column) - # order matters - mapper = {**default_mapper, **local_mapper, **other_mapper} - + reg = cls.get_registry(column) if column._meta.choices: random_value = RandomBuilder.next_enum(column._meta.choices) - elif column.value_type == list: - length = RandomBuilder.next_int(maximum=10) - base_type = t.cast(Array, column).base_column.value_type - random_value = [mapper[base_type]() for _ in range(length)] else: - random_value = mapper[column.value_type]() + random_value = reg[column.value_type]() if isinstance(column, (JSON, JSONB)): return json.dumps({"value": random_value}) return random_value @classmethod - def _get_default_mapper(cls) -> t.Dict[t.Type, t.Callable]: + def get_registry( + cls, column: Column + ) -> MappingProxyType[t.Type, t.Callable]: """ - This classmethod encapsulates the desired logic. + This serves as the public API allowing users to **view** + the complete registry for the specified column. + + :param column: + Column class to randomize. + """ - mapper = cls.__DEFAULT_MAPPER - if not mapper: # execute once only - for typ, callable_name in RandomBuilder.DEFAULT_MAPPER.items(): - # a simpler approach available? - func = RandomBuilder.__dict__[callable_name].__func__ - mapper[typ] = partial(func, RandomBuilder) - return mapper + default_mapper = cls.__DEFAULT_MAPPER + if not default_mapper: # execute once only + for typ, callable_ in RandomBuilder.get_mapper().items(): + default_mapper[typ] = callable_ + + # order matters + reg = { + **default_mapper, + **cls._get_local_mapper(column), + **cls._get_other_mapper(column), + } + + if column.value_type == list: + reg[list] = partial( + RandomBuilder.next_list, + reg[t.cast(Array, column).base_column.value_type], + ) + return MappingProxyType(reg) @classmethod def _get_local_mapper(cls, column: Column) -> t.Dict[t.Type, t.Callable]: @@ -193,19 +202,20 @@ def _get_local_mapper(cls, column: Column) -> t.Dict[t.Type, t.Callable]: Column class to randomize. """ local_mapper: t.Dict[t.Type, t.Callable] = {} - if column.value_type == Decimal: - precision, scale = column._meta.params["digits"] or (4, 2) - local_mapper[Decimal] = partial( - RandomBuilder.next_decimal, precision, scale - ) - elif column.value_type == datetime.datetime: - tz_aware = getattr(column, "tz_aware", False) - local_mapper[datetime.datetime] = partial( - RandomBuilder.next_datetime, tz_aware - ) - elif column.value_type == str: - if _length := column._meta.params.get("length"): - local_mapper[str] = partial(RandomBuilder.next_str, _length) + + precision, scale = column._meta.params.get("digits") or (4, 2) + local_mapper[Decimal] = partial( + RandomBuilder.next_decimal, precision, scale + ) + + tz_aware = getattr(column, "tz_aware", False) + local_mapper[datetime.datetime] = partial( + RandomBuilder.next_datetime, tz_aware + ) + + if _length := column._meta.params.get("length"): + local_mapper[str] = partial(RandomBuilder.next_str, _length) + return local_mapper @classmethod @@ -241,10 +251,10 @@ def next_str(column: Column) -> str: return other_mapper @classmethod - def register_random_type(cls, typ: t.Type, callable_: t.Callable) -> None: + def register_type(cls, typ: t.Type, callable_: t.Callable) -> None: cls.__OTHER_MAPPER[typ] = callable_ @classmethod - def unregister_random_type(cls, typ: t.Type) -> None: + def unregister_type(cls, typ: t.Type) -> None: if typ in cls.__OTHER_MAPPER: del cls.__OTHER_MAPPER[typ] diff --git a/piccolo/testing/random_builder.py b/piccolo/testing/random_builder.py index 16fd82dc3..f8d18cb5a 100644 --- a/piccolo/testing/random_builder.py +++ b/piccolo/testing/random_builder.py @@ -8,19 +8,26 @@ class RandomBuilder: - DEFAULT_MAPPER: t.Dict[t.Type, str] = { - bool: "next_bool", - bytes: "next_bytes", - datetime.date: "next_date", - datetime.datetime: "next_datetime", - float: "next_float", - decimal.Decimal: "next_decimal", - int: "next_int", - str: "next_str", - datetime.time: "next_time", - datetime.timedelta: "next_timedelta", - uuid.UUID: "next_uuid", - } + @classmethod + def get_mapper(cls) -> t.Dict[t.Type, t.Callable]: + """ + This is the public API for users to get the + provided random mapper. + + """ + return { + bool: cls.next_bool, + bytes: cls.next_bytes, + datetime.date: cls.next_date, + datetime.datetime: cls.next_datetime, + float: cls.next_float, + decimal.Decimal: cls.next_decimal, + int: cls.next_int, + str: cls.next_str, + datetime.time: cls.next_time, + datetime.timedelta: cls.next_timedelta, + uuid.UUID: cls.next_uuid, + } @classmethod def next_bool(cls) -> bool: @@ -59,8 +66,13 @@ def next_float(cls, minimum=0, maximum=2147483647, scale=5) -> float: return round(random.uniform(minimum, maximum), scale) @classmethod - def next_decimal(cls, precision: int, scale: int) -> float: - return cls.next_float(maximum=10 ** (precision - scale), scale=scale) + def next_decimal( + cls, precision: int = 4, scale: int = 2 + ) -> decimal.Decimal: + float_number = cls.next_float( + maximum=10 ** (precision - scale), scale=scale + ) + return decimal.Decimal(str(float_number)) @classmethod def next_int(cls, minimum=0, maximum=2147483647) -> int: @@ -91,3 +103,8 @@ def next_timedelta(cls) -> datetime.timedelta: @classmethod def next_uuid(cls) -> uuid.UUID: return uuid.uuid4() + + @classmethod + def next_list(cls, callable_: t.Callable) -> t.List[t.Any]: + length = cls.next_int(maximum=10) + return [callable_() for _ in range(length)] diff --git a/tests/testing/test_model_builder.py b/tests/testing/test_model_builder.py index 1531757b3..aa313129f 100644 --- a/tests/testing/test_model_builder.py +++ b/tests/testing/test_model_builder.py @@ -245,7 +245,7 @@ def tearDown(self) -> None: ModelBuilder.__OTHER_MAPPER = {} def test_register(self): - ModelBuilder.register_random_type(str, lambda: "piccolo") + ModelBuilder.register_type(str, lambda: "piccolo") manager = ModelBuilder.build_sync(Manager) self.assertEqual(manager.name, "piccolo") @@ -254,7 +254,7 @@ def next_str(column: Column) -> str: length = column._meta.params.get("length", 5) return "".join("a" for _ in range(length)) - ModelBuilder.register_random_type(str, next_str) + ModelBuilder.register_type(str, next_str) manager = ModelBuilder.build_sync(Manager) self.assertEqual(len(manager.name), 50) @@ -262,20 +262,30 @@ def next_str(column: Column) -> str: self.assertEqual(len(post.content), 5) def test_register_same_type(self): - ModelBuilder.register_random_type(str, lambda: "piccolo") - ModelBuilder.register_random_type(str, lambda: "PICCOLO") + ModelBuilder.register_type(str, lambda: "piccolo") + ModelBuilder.register_type(str, lambda: "PICCOLO") manager = ModelBuilder.build_sync(Manager) self.assertEqual(manager.name, "PICCOLO") def test_unregister(self): - ModelBuilder.register_random_type(str, lambda: "piccolo") - ModelBuilder.unregister_random_type(str) + ModelBuilder.register_type(str, lambda: "piccolo") + ModelBuilder.unregister_type(str) manager = ModelBuilder.build_sync(Manager) self.assertNotEqual(manager.name, "piccolo") def test_unregister_same_type(self): - ModelBuilder.unregister_random_type(str) - ModelBuilder.unregister_random_type(str) + ModelBuilder.unregister_type(str) + ModelBuilder.unregister_type(str) def test_unregister_any_type(self): - ModelBuilder.unregister_random_type(random.choice(dir(builtins))) + ModelBuilder.unregister_type(random.choice(dir(builtins))) + + def test_get_registry(self): + def next_str() -> str: + return "piccolo" + + ModelBuilder.register_type(str, next_str) + reg = ModelBuilder.get_registry(Varchar()) + + self.assertIn(str, reg) + self.assertIn(next_str, reg.values()) diff --git a/tests/testing/test_random_builder.py b/tests/testing/test_random_builder.py index 4d94837f0..ed5d2cad5 100644 --- a/tests/testing/test_random_builder.py +++ b/tests/testing/test_random_builder.py @@ -56,3 +56,9 @@ def test_next_timedelta(self): def test_next_uuid(self): RandomBuilder.next_uuid() + + def test_next_list(self): + for typ, callable_ in RandomBuilder.get_mapper().items(): + random_list = RandomBuilder.next_list(callable_) + self.assertIsInstance(random_list, list) + self.assertTrue(all(isinstance(elem, typ) for elem in random_list))