Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions stripe/_api_requestor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import stripe.oauth_error as oauth_error
from stripe._multipart_data_generator import MultipartDataGenerator
from urllib.parse import urlencode
from stripe._encode import _api_encode, _json_encode_date_callback
from stripe._encode import _api_encode, _make_suitable_for_json
from stripe._stripe_response import (
StripeResponse,
StripeStreamResponse,
Expand Down Expand Up @@ -642,7 +642,7 @@ def _args_for_request_with_retries(

if api_mode == "V2":
encoded_body = json.dumps(
params or {}, default=_json_encode_date_callback
params or {}, default=_make_suitable_for_json
)
else:
encoded_body = encoded_params
Expand Down
15 changes: 12 additions & 3 deletions stripe/_encode.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ def _encode_datetime(dttime: datetime.datetime):
return int(utc_timestamp)


def _encode_decimal(dec) -> str:
return str(dec)


def _encode_nested_dict(key, data, fmt="%s[%s]"):
d = OrderedDict()
items = data._data.items() if hasattr(data, "_data") else data.items()
Expand All @@ -23,11 +27,16 @@ def _encode_nested_dict(key, data, fmt="%s[%s]"):
return d


def _json_encode_date_callback(value):
def _make_suitable_for_json(value: Any) -> Any:
"""
Handles taking arbitrary values and making sure they're JSON encodable.

Only cares about types that can appear on StripeObject that but are not serializable by default (like Decimal).
"""
if isinstance(value, datetime.datetime):
return _encode_datetime(value)
if isinstance(value, Decimal):
return str(value)
return _encode_decimal(value)
return value


Expand Down Expand Up @@ -102,7 +111,7 @@ def _coerce_decimal_string(value: Any, *, encode: bool) -> Any:
if isinstance(value, (Decimal, int, float)) and not isinstance(
value, bool
):
return str(value)
return _encode_decimal(value)
return value
else:
if isinstance(value, str):
Expand Down
41 changes: 22 additions & 19 deletions stripe/_stripe_object.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# pyright: strict
import datetime
import json
from copy import deepcopy
from typing_extensions import TYPE_CHECKING, Type, Literal, Self, deprecated
Expand All @@ -26,8 +25,11 @@
StripeStreamResponse,
StripeStreamResponseAsync,
)
from stripe._encode import _encode_datetime # pyright: ignore
from stripe._encode import _coerce_int64_string, _coerce_decimal_string # pyright: ignore
from stripe._encode import (
_coerce_int64_string, # pyright: ignore[reportPrivateUsage]
_coerce_decimal_string, # pyright: ignore[reportPrivateUsage]
_make_suitable_for_json, # pyright: ignore[reportPrivateUsage]
)
from stripe._request_options import (
PERSISTENT_OPTIONS_KEYS,
extract_options_from_dict,
Expand Down Expand Up @@ -82,12 +84,6 @@ def _serialize_list(


class StripeObject:
class _ReprJSONEncoder(json.JSONEncoder):
def default(self, o: Any) -> Any:
if isinstance(o, datetime.datetime):
return _encode_datetime(o)
return super(StripeObject._ReprJSONEncoder, self).default(o)

_retrieve_params: Mapping[str, Any]
_previous: Optional[Mapping[str, Any]]

Expand Down Expand Up @@ -528,20 +524,23 @@ def __str__(self) -> str:
self._to_dict_recursive(),
sort_keys=True,
indent=2,
cls=self._ReprJSONEncoder,
default=_make_suitable_for_json,
)

def to_dict(self, recursive: bool = True) -> Dict[str, Any]:
def to_dict(
self, recursive: bool = True, for_json: bool = False
) -> Dict[str, Any]:
"""
Dump the object's backing data. Recurses by default, but you can opt-out of that behavior by passing `recursive=False`
Dump the object's backing data. Recurses by default, but you can opt-out of that behavior by passing `recursive=False`.
Pass `for_json=True` to convert non-JSON-serializable values (e.g. Decimal -> str)
"""
if recursive:
return self._to_dict_recursive()
return self._to_dict_recursive(for_json=for_json)

# shallow copy, so nested objects will be shared
return self._data.copy()

def _to_dict_recursive(self) -> Dict[str, Any]:
def _to_dict_recursive(self, for_json: bool = False) -> Dict[str, Any]:
"""
used by __str__ to serialize the whole object
"""
Expand All @@ -552,7 +551,9 @@ def maybe_to_dict_recursive(
if value is None:
return None
elif isinstance(value, StripeObject):
return value._to_dict_recursive()
return value._to_dict_recursive(for_json=for_json)
elif for_json:
return _make_suitable_for_json(value)
else:
return value

Expand Down Expand Up @@ -643,12 +644,14 @@ def _get_inner_class_is_beneath_dict(self, field_name: str):

def _coerce_field_value(self, field_name: str, value: Any) -> Any:
"""
Apply field encoding coercion based on _field_encodings metadata.
Convert JSON types to more applicable Python types, if able.

For int64_string fields, converts string values from the API response
to native Python ints. For decimal_string fields, converts string
values to decimal.Decimal.
For example, "int64_string"s become `int`s.
"""

# WARNING: if you edit this function to produce a type that's not json-serializable, you need to update `_make_suitable_for_json` as well.
# By default, Python will only correctly dump a few standard types, so we have to handle the rest

encoding = self._field_encodings.get(field_name)
if encoding is None or value is None:
return value
Expand Down
31 changes: 31 additions & 0 deletions tests/test_stripe_object.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
from decimal import Decimal
import json
import pickle
from copy import copy, deepcopy
Expand Down Expand Up @@ -257,12 +258,14 @@ def test_repr(self):

obj["object"] = "\u4e00boo\u1f00"
obj.date = datetime.datetime.fromtimestamp(1511136000)
obj.dec = Decimal("1.23")

res = repr(obj)

assert "<StripeObject \u4e00boo\u1f00" in res
assert "id=foo" in res
assert '"date": 1511136000' in res
assert '"dec": "1.23"' in res

def test_pickling(self):
obj = StripeObject("foo", "bar", myparam=5)
Expand Down Expand Up @@ -760,6 +763,34 @@ def test_to_dict_with_list_of_nested_objects(self):
assert d == {"id": "x", "items": [{"a": 1}, {"b": 2}]}
assert not isinstance(d["items"][0], StripeObject)

def test_to_dict_json_serializable_converts_decimal(self):
obj = StripeObject.construct_from(
{"amount": Decimal("9.99"), "name": "foo"}, "key"
)
d = obj.to_dict(for_json=True)
assert d == {"amount": "9.99", "name": "foo"}
assert isinstance(d["amount"], str)

def test_to_dict_json_serializable_converts_datetime(self):
dt = datetime.datetime(
2024, 1, 15, 12, 0, 0, tzinfo=datetime.timezone.utc
)
obj = StripeObject.construct_from({"created": dt, "id": "x"}, "key")
d = obj.to_dict(for_json=True)
assert isinstance(d["created"], int)

def test_to_dict_json_serializable_nested(self):
inner = StripeObject.construct_from({"amount": Decimal("1.23")}, "key")
obj = StripeObject.construct_from({"child": inner, "id": "x"}, "key")
d = obj.to_dict(for_json=True)
assert d["child"] == {"amount": "1.23"}
assert isinstance(d["child"]["amount"], str)

def test_to_dict_json_serializable_false_preserves_decimal(self):
obj = StripeObject.construct_from({"amount": Decimal("9.99")}, "key")
d = obj.to_dict()
assert isinstance(d["amount"], Decimal)

def test_update_sets_values(self):
obj = StripeObject.construct_from({"id": "x", "name": "a"}, "key")
obj.update({"name": "b", "email": "b@example.com"})
Expand Down
Loading