diff --git a/piccolo/columns/base.py b/piccolo/columns/base.py index 9d3e2b1cc..b9ed2bd18 100644 --- a/piccolo/columns/base.py +++ b/piccolo/columns/base.py @@ -527,7 +527,9 @@ def _validate_default( elif ( default is None and None in allowed_types - or type(default) in allowed_types + or isinstance( + default, tuple(t for t in allowed_types if isinstance(t, type)) + ) ): self._validated = True return True @@ -539,8 +541,9 @@ def _validate_default( ): self._validated = True return True - elif ( - isinstance(default, Enum) and type(default.value) in allowed_types + elif isinstance(default, Enum) and isinstance( + default.value, + tuple(t for t in allowed_types if isinstance(t, type)), ): self._validated = True return True diff --git a/tests/columns/test_defaults.py b/tests/columns/test_defaults.py index 77df731bf..c6e32243f 100644 --- a/tests/columns/test_defaults.py +++ b/tests/columns/test_defaults.py @@ -25,6 +25,13 @@ from piccolo.table import Table +def get_custom_default(base): + class CustomDefault(base): + pass + + return CustomDefault() + + class TestDefaults(TestCase): """ Columns check the type of the default argument. @@ -66,6 +73,7 @@ def test_uuid(self): UUID(default=None, null=True) UUID(default=UUID4()) UUID(default=uuid.uuid4()) + UUID(default=get_custom_default(UUID4)) with self.assertRaises(ValueError): UUID(default="hello world") @@ -73,6 +81,7 @@ def test_time(self): Time(default=None, null=True) Time(default=TimeNow()) Time(default=datetime.datetime.now().time()) + Time(default=get_custom_default(TimeNow)) with self.assertRaises(ValueError): Time(default="hello world") # type: ignore @@ -80,6 +89,7 @@ def test_date(self): Date(default=None, null=True) Date(default=DateNow()) Date(default=datetime.datetime.now().date()) + Date(default=get_custom_default(DateNow)) with self.assertRaises(ValueError): Date(default="hello world") # type: ignore @@ -87,6 +97,7 @@ def test_timestamp(self): Timestamp(default=None, null=True) Timestamp(default=TimestampNow()) Timestamp(default=datetime.datetime.now()) + Timestamp(default=get_custom_default(TimestampNow)) with self.assertRaises(ValueError): Timestamp(default="hello world") # type: ignore