diff --git a/stripe/_api_requestor.py b/stripe/_api_requestor.py index 59ba0898f..237a32030 100644 --- a/stripe/_api_requestor.py +++ b/stripe/_api_requestor.py @@ -38,7 +38,7 @@ import stripe.oauth_error as oauth_error from stripe._multipart_data_generator import MultipartDataGenerator from urllib.parse import urlencode -from stripe._encode import _api_encode, _json_encode_date_callback +from stripe._encode import _api_encode, _make_suitable_for_json from stripe._stripe_response import ( StripeResponse, StripeStreamResponse, @@ -642,7 +642,7 @@ def _args_for_request_with_retries( if api_mode == "V2": encoded_body = json.dumps( - params or {}, default=_json_encode_date_callback + params or {}, default=_make_suitable_for_json ) else: encoded_body = encoded_params diff --git a/stripe/_encode.py b/stripe/_encode.py index 85847e87a..23de523da 100644 --- a/stripe/_encode.py +++ b/stripe/_encode.py @@ -15,6 +15,10 @@ def _encode_datetime(dttime: datetime.datetime): return int(utc_timestamp) +def _encode_decimal(dec) -> str: + return str(dec) + + def _encode_nested_dict(key, data, fmt="%s[%s]"): d = OrderedDict() items = data._data.items() if hasattr(data, "_data") else data.items() @@ -23,11 +27,16 @@ def _encode_nested_dict(key, data, fmt="%s[%s]"): return d -def _json_encode_date_callback(value): +def _make_suitable_for_json(value: Any) -> Any: + """ + Handles taking arbitrary values and making sure they're JSON encodable. + + Only cares about types that can appear on StripeObject that but are not serializable by default (like Decimal). + """ if isinstance(value, datetime.datetime): return _encode_datetime(value) if isinstance(value, Decimal): - return str(value) + return _encode_decimal(value) return value @@ -102,7 +111,7 @@ def _coerce_decimal_string(value: Any, *, encode: bool) -> Any: if isinstance(value, (Decimal, int, float)) and not isinstance( value, bool ): - return str(value) + return _encode_decimal(value) return value else: if isinstance(value, str): diff --git a/stripe/_stripe_object.py b/stripe/_stripe_object.py index cd42fb10f..6156586a2 100644 --- a/stripe/_stripe_object.py +++ b/stripe/_stripe_object.py @@ -1,5 +1,4 @@ # pyright: strict -import datetime import json from copy import deepcopy from typing_extensions import TYPE_CHECKING, Type, Literal, Self, deprecated @@ -26,8 +25,11 @@ StripeStreamResponse, StripeStreamResponseAsync, ) -from stripe._encode import _encode_datetime # pyright: ignore -from stripe._encode import _coerce_int64_string, _coerce_decimal_string # pyright: ignore +from stripe._encode import ( + _coerce_int64_string, # pyright: ignore[reportPrivateUsage] + _coerce_decimal_string, # pyright: ignore[reportPrivateUsage] + _make_suitable_for_json, # pyright: ignore[reportPrivateUsage] +) from stripe._request_options import ( PERSISTENT_OPTIONS_KEYS, extract_options_from_dict, @@ -82,12 +84,6 @@ def _serialize_list( class StripeObject: - class _ReprJSONEncoder(json.JSONEncoder): - def default(self, o: Any) -> Any: - if isinstance(o, datetime.datetime): - return _encode_datetime(o) - return super(StripeObject._ReprJSONEncoder, self).default(o) - _retrieve_params: Mapping[str, Any] _previous: Optional[Mapping[str, Any]] @@ -528,20 +524,23 @@ def __str__(self) -> str: self._to_dict_recursive(), sort_keys=True, indent=2, - cls=self._ReprJSONEncoder, + default=_make_suitable_for_json, ) - def to_dict(self, recursive: bool = True) -> Dict[str, Any]: + def to_dict( + self, recursive: bool = True, for_json: bool = False + ) -> Dict[str, Any]: """ - Dump the object's backing data. Recurses by default, but you can opt-out of that behavior by passing `recursive=False` + Dump the object's backing data. Recurses by default, but you can opt-out of that behavior by passing `recursive=False`. + Pass `for_json=True` to convert non-JSON-serializable values (e.g. Decimal -> str) """ if recursive: - return self._to_dict_recursive() + return self._to_dict_recursive(for_json=for_json) # shallow copy, so nested objects will be shared return self._data.copy() - def _to_dict_recursive(self) -> Dict[str, Any]: + def _to_dict_recursive(self, for_json: bool = False) -> Dict[str, Any]: """ used by __str__ to serialize the whole object """ @@ -552,7 +551,9 @@ def maybe_to_dict_recursive( if value is None: return None elif isinstance(value, StripeObject): - return value._to_dict_recursive() + return value._to_dict_recursive(for_json=for_json) + elif for_json: + return _make_suitable_for_json(value) else: return value @@ -643,12 +644,14 @@ def _get_inner_class_is_beneath_dict(self, field_name: str): def _coerce_field_value(self, field_name: str, value: Any) -> Any: """ - Apply field encoding coercion based on _field_encodings metadata. + Convert JSON types to more applicable Python types, if able. - For int64_string fields, converts string values from the API response - to native Python ints. For decimal_string fields, converts string - values to decimal.Decimal. + For example, "int64_string"s become `int`s. """ + + # WARNING: if you edit this function to produce a type that's not json-serializable, you need to update `_make_suitable_for_json` as well. + # By default, Python will only correctly dump a few standard types, so we have to handle the rest + encoding = self._field_encodings.get(field_name) if encoding is None or value is None: return value diff --git a/tests/test_stripe_object.py b/tests/test_stripe_object.py index 430fe81cd..23d0d1c27 100644 --- a/tests/test_stripe_object.py +++ b/tests/test_stripe_object.py @@ -1,4 +1,5 @@ import datetime +from decimal import Decimal import json import pickle from copy import copy, deepcopy @@ -257,12 +258,14 @@ def test_repr(self): obj["object"] = "\u4e00boo\u1f00" obj.date = datetime.datetime.fromtimestamp(1511136000) + obj.dec = Decimal("1.23") res = repr(obj) assert "