diff --git a/sdks/python/apache_beam/coders/coder_impl.py b/sdks/python/apache_beam/coders/coder_impl.py index 0bded25e05d2..177c5e2b5a57 100644 --- a/sdks/python/apache_beam/coders/coder_impl.py +++ b/sdks/python/apache_beam/coders/coder_impl.py @@ -28,10 +28,12 @@ For internal use only; no backwards-compatibility guarantees. """ + # pytype: skip-file # ruff: noqa: UP006 import dataclasses +import datetime import decimal import enum import itertools @@ -74,6 +76,11 @@ except ImportError: dill = None +try: + import zoneinfo +except ImportError: + zoneinfo = None + if TYPE_CHECKING: import proto @@ -98,9 +105,10 @@ try: import cython + is_compiled = cython.compiled except ImportError: - globals()['cython'] = type('fake_cython', (), {'cast': lambda typ, x: x}) + globals()["cython"] = type("fake_cython", (), {"cast": lambda typ, x: x}) else: # pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports @@ -111,9 +119,9 @@ # Make it possible to import create_InputStream and other cdef-classes # from apache_beam.coders.coder_impl when Cython codepath is used. 
- globals()['create_InputStream'] = create_InputStream - globals()['create_OutputStream'] = create_OutputStream - globals()['ByteCountingOutputStream'] = ByteCountingOutputStream + globals()["create_InputStream"] = create_InputStream + globals()["create_OutputStream"] = create_OutputStream + globals()["ByteCountingOutputStream"] = ByteCountingOutputStream # pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports is_compiled = True @@ -123,9 +131,9 @@ MIN_TIMESTAMP_micros = MIN_TIMESTAMP.micros MAX_TIMESTAMP_micros = MAX_TIMESTAMP.micros -IterableStateReader = Callable[[bytes, 'CoderImpl'], Iterable] -IterableStateWriter = Callable[[Iterable, 'CoderImpl'], bytes] -Observables = List[Tuple[observable.ObservableMixin, 'CoderImpl']] +IterableStateReader = Callable[[bytes, "CoderImpl"], Iterable] +IterableStateWriter = Callable[[Iterable, "CoderImpl"], bytes] +Observables = List[Tuple[observable.ObservableMixin, "CoderImpl"]] class CoderImpl(object): @@ -196,26 +204,26 @@ def get_estimated_size_and_observables(self, value, nested=False): """Returns estimated size of value along with any nested observables. - The list of nested observables is returned as a list of 2-tuples of - (obj, coder_impl), where obj is an instance of observable.ObservableMixin, - and coder_impl is the CoderImpl that can be used to encode elements sent by - obj to its observers. + The list of nested observables is returned as a list of 2-tuples of + (obj, coder_impl), where obj is an instance of observable.ObservableMixin, + and coder_impl is the CoderImpl that can be used to encode elements sent by + obj to its observers. - Arguments: - value: the value whose encoded size is to be estimated. - nested: whether the value is nested. + Arguments: + value: the value whose encoded size is to be estimated. + nested: whether the value is nested. 
- Returns: - The estimated encoded size of the given value and a list of observables - whose elements are 2-tuples of (obj, coder_impl) as described above. - """ + Returns: + The estimated encoded size of the given value and a list of observables + whose elements are 2-tuples of (obj, coder_impl) as described above. + """ return self.estimate_size(value, nested), [] class SimpleCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. - Subclass of CoderImpl implementing stream methods using encode/decode.""" + Subclass of CoderImpl implementing stream methods using encode/decode.""" def encode_to_stream(self, value, stream, nested): # type: (Any, create_OutputStream, bool) -> None @@ -232,7 +240,7 @@ def decode_from_stream(self, stream, nested): class StreamCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. - Subclass of CoderImpl implementing encode/decode using stream methods.""" + Subclass of CoderImpl implementing encode/decode using stream methods.""" def encode(self, value): # type: (Any) -> bytes out = create_OutputStream() @@ -255,11 +263,11 @@ def estimate_size(self, value, nested=False): class CallbackCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. - A CoderImpl that calls back to the _impl methods on the Coder itself. + A CoderImpl that calls back to the _impl methods on the Coder itself. - This is the default implementation used if Coder._get_impl() - is not overwritten. - """ + This is the default implementation used if Coder._get_impl() + is not overwritten. 
+ """ def __init__(self, encoder, decoder, size_estimator=None): self._encoder = encoder self._decoder = decoder @@ -296,7 +304,7 @@ def get_estimated_size_and_observables(self, value, nested=False): return self.estimate_size(value, nested), [] def __repr__(self): - return 'CallbackCoderImpl[encoder=%s, decoder=%s]' % ( + return "CallbackCoderImpl[encoder=%s, decoder=%s]" % ( self._encoder, self._decoder) @@ -340,10 +348,13 @@ def decode(self, value): BYTES_TYPE = 3 UNICODE_TYPE = 4 BOOL_TYPE = 9 +DATETIME_TYPE = 11 +DATE_TYPE = 12 LIST_TYPE = 5 TUPLE_TYPE = 6 DICT_TYPE = 7 SET_TYPE = 8 +FROZENSET_TYPE = 13 ITERABLE_LIKE_TYPE = 10 PROTO_TYPE = 100 @@ -379,7 +390,8 @@ def __init__( fallback_coder_impl, requires_deterministic_step_label=None, force_use_dill=False, - use_relative_filepaths=True): + use_relative_filepaths=True, + ): self.fallback_coder_impl = fallback_coder_impl self.iterable_coder_impl = IterableCoderImpl(self) self.requires_deterministic_step_label = requires_deterministic_step_label @@ -432,7 +444,7 @@ def encode_to_stream(self, value, stream, nested): elif t is str: unicode_value = value # for typing stream.write_byte(UNICODE_TYPE) - stream.write(unicode_value.encode('utf-8'), nested) + stream.write(unicode_value.encode("utf-8"), nested) elif t is list or t is tuple: stream.write_byte(LIST_TYPE if t is list else TUPLE_TYPE) stream.write_var_int64(len(value)) @@ -441,6 +453,22 @@ def encode_to_stream(self, value, stream, nested): elif t is bool: stream.write_byte(BOOL_TYPE) stream.write_byte(value) + elif t is datetime.datetime: + # We use RFC 9557 for lossless encoding of timezone info. 
+ stream.write_byte(DATETIME_TYPE) + rfc_9557_str = value.isoformat() + if (zoneinfo is not None and value.tzinfo is not None and + type(value.tzinfo) is not datetime.timezone): + rfc_9557_str += f"[{value.tzinfo}]" + if type(value.tzinfo) is datetime.timezone and ( + tzname := value.tzname()) is not None: + rfc_9557_str += f"[tzn={tzname}]" + if value.fold != 0: + rfc_9557_str += f"[f={value.fold}]" + stream.write(rfc_9557_str.encode("utf-8"), nested) + elif t is datetime.date: + stream.write_byte(DATE_TYPE) + stream.write(value.isoformat().encode("utf-8"), nested) elif t in _ITERABLE_LIKE_TYPES: stream.write_byte(ITERABLE_LIKE_TYPE) self.iterable_coder_impl.encode_to_stream(value, stream, nested) @@ -463,8 +491,11 @@ def encode_to_stream(self, value, stream, nested): for k, v in dict_value.items(): self.encode_to_stream(k, stream, True) self.encode_to_stream(v, stream, True) - elif t is set: - stream.write_byte(SET_TYPE) + elif t is set or t is frozenset: + if t is set: + stream.write_byte(SET_TYPE) + else: + stream.write_byte(FROZENSET_TYPE) stream.write_var_int64(len(value)) if self.requires_deterministic_step_label is not None: try: @@ -488,7 +519,8 @@ def encode_special_deterministic(self, value, stream): _LOGGER.warning( "Using fallback deterministic coder for type '%s' in '%s'.
", type(value), - self.requires_deterministic_step_label) + self.requires_deterministic_step_label, + ) self.warn_deterministic_fallback = False if isinstance(value, proto_utils.message_types): stream.write_byte(PROTO_TYPE) @@ -516,7 +548,7 @@ def encode_special_deterministic(self, value, stream): self.iterable_coder_impl.encode_to_stream(values, stream, True) except Exception as e: raise TypeError(self._deterministic_encoding_error_msg(value)) from e - elif isinstance(value, tuple) and hasattr(type(value), '_fields'): + elif isinstance(value, tuple) and hasattr(type(value), "_fields"): stream.write_byte(NAMED_TUPLE_TYPE) self.encode_type(type(value), stream) try: @@ -558,8 +590,8 @@ def _deterministic_encoding_error_msg(self, value): def encode_type_2_67_0(self, t, stream): """ - Encode special type with <=2.67.0 compatibility. - """ + Encode special type with <=2.67.0 compatibility. + """ if t not in _pickled_types: _verify_dill_compat() _pickled_types[t] = dill.dumps(t) @@ -573,7 +605,8 @@ def encode_type(self, t, stream): config = cloudpickle.CloudPickleConfig( id_generator=None, skip_reset_dynamic_type_state=True, - filepath_interceptor=cloudpickle.get_relative_path) + filepath_interceptor=cloudpickle.get_relative_path, + ) if not self.use_relative_filepaths: config.filepath_interceptor = None _pickled_types[t] = cloudpickle_pickler.dumps(t, config=config) @@ -596,14 +629,16 @@ def decode_from_stream(self, stream, nested): elif t == BYTES_TYPE: return stream.read_all(nested) elif t == UNICODE_TYPE: - return stream.read_all(nested).decode('utf-8') - elif t == LIST_TYPE or t == TUPLE_TYPE or t == SET_TYPE: + return stream.read_all(nested).decode("utf-8") + elif t == LIST_TYPE or t == TUPLE_TYPE or t == SET_TYPE or t == FROZENSET_TYPE: vlen = stream.read_var_int64() vlist = [self.decode_from_stream(stream, True) for _ in range(vlen)] if t == LIST_TYPE: return vlist elif t == TUPLE_TYPE: return tuple(vlist) + elif t == FROZENSET_TYPE: + return frozenset(vlist) 
return set(vlist) elif t == DICT_TYPE: vlen = stream.read_var_int64() @@ -614,6 +649,46 @@ def decode_from_stream(self, stream, nested): return v elif t == BOOL_TYPE: return not not stream.read_byte() + elif t == DATETIME_TYPE: + rfc_9557_str = stream.read_all(nested).decode("utf-8") + first_tag_idx = rfc_9557_str.find("[") + if first_tag_idx == -1: + return datetime.datetime.fromisoformat(rfc_9557_str) + + base_iso = rfc_9557_str[:first_tag_idx] + tags_str = rfc_9557_str[first_tag_idx:] + dt = datetime.datetime.fromisoformat(base_iso) + + fold = 0 + zone_name = None + tz_name = None + + tags = tags_str.replace("]", "").split("[") + for tag in tags: + if not tag: + continue + if tag.startswith("f="): + fold = int(tag[2:]) + elif tag.startswith("tzn="): + tz_name = tag[4:] + elif "=" in tag: + # Skip unknown tags like [knort=blorgel] + continue + else: + zone_name = tag + + if tz_name and (offset := dt.utcoffset()) is not None: + dt = dt.replace(tzinfo=datetime.timezone(offset=offset, name=tz_name)) + elif zoneinfo is not None and zone_name: + dt = dt.replace(tzinfo=zoneinfo.ZoneInfo(zone_name)) + + if fold != dt.fold: + dt = dt.replace(fold=fold) + + return dt + elif t == DATE_TYPE: + return datetime.date.fromisoformat( + stream.read_all(nested).decode("utf-8")) elif t == ITERABLE_LIKE_TYPE: return self.iterable_coder_impl.decode_from_stream(stream, nested) elif t == PROTO_TYPE: @@ -626,7 +701,7 @@ def decode_from_stream(self, stream, nested): vlen = stream.read_var_int64() fields = {} for _ in range(vlen): - field_name = stream.read_all(True).decode('utf-8') + field_name = stream.read_all(True).decode("utf-8") fields[field_name] = self.decode_from_stream(stream, True) return cls(**fields) elif t == DATACLASS_TYPE or t == NAMED_TUPLE_TYPE: @@ -644,7 +719,7 @@ def decode_from_stream(self, stream, nested): elif t == UNKNOWN_TYPE: return self.fallback_coder_impl.decode_from_stream(stream, nested) else: - raise ValueError('Unknown type tag %x' % t) + raise 
ValueError("Unknown type tag %x" % t) _pickled_types = {} # type: Dict[type, bytes] @@ -653,14 +728,14 @@ def decode_from_stream(self, stream, nested): def _unpickle_type_2_67_0(bs): """ - Decode special type with <=2.67.0 compatibility. - """ + Decode special type with <=2.67.0 compatibility. + """ t = _unpickled_types.get(bs, None) if t is None: _verify_dill_compat() t = _unpickled_types[bs] = dill.loads(bs) # Fix unpicklable anonymous named tuples for Python 3.6. - if t.__base__ is tuple and hasattr(t, '_fields'): + if t.__base__ is tuple and hasattr(t, "_fields"): try: pickle.loads(pickle.dumps(t)) except pickle.PicklingError: @@ -687,7 +762,7 @@ def _unpickle_named_tuple(bs, items): class BytesCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for bytes/str objects.""" + A coder for bytes/str objects.""" def encode_to_stream(self, value, out, nested): # type: (bytes, create_OutputStream, bool) -> None @@ -713,7 +788,7 @@ def decode(self, encoded): class BooleanCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for bool objects.""" + A coder for bool objects.""" def encode_to_stream(self, value, out, nested): out.write_byte(1 if value else 0) @@ -726,7 +801,7 @@ def decode_from_stream(self, in_stream, nested): raise ValueError("Expected 0 or 1, got %s" % value) def encode(self, value): - return b'\x01' if value else b'\x00' + return b"\x01" if value else b"\x00" def decode(self, encoded): value = ord(encoded) @@ -744,20 +819,21 @@ def estimate_size(self, unused_value, nested=False): class MapCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - Note this implementation always uses nested context when encoding keys - and values. This differs from Java's MapCoder, which uses - nested=False if possible for the last value encoded. + Note this implementation always uses nested context when encoding keys + and values. 
This differs from Java's MapCoder, which uses + nested=False if possible for the last value encoded. - This difference is acceptable because MapCoder is not standard. It is only - used in a standard context by RowCoder which always uses nested context for - attribute values. + This difference is acceptable because MapCoder is not standard. It is only + used in a standard context by RowCoder which always uses nested context for + attribute values. - A coder for typing.Mapping objects.""" + A coder for typing.Mapping objects.""" def __init__( self, key_coder, # type: CoderImpl value_coder, # type: CoderImpl - is_deterministic=False): + is_deterministic=False, + ): self._key_coder = key_coder self._value_coder = value_coder self._is_deterministic = is_deterministic @@ -800,7 +876,7 @@ def estimate_size(self, unused_value, nested=False): class NullableCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for typing.Optional objects.""" + A coder for typing.Optional objects.""" ENCODE_NULL = 0 ENCODE_PRESENT = 1 @@ -918,8 +994,7 @@ def _from_normal_time(self, value): def encode_to_stream(self, value, out, nested): # type: (IntervalWindow, create_OutputStream, bool) -> None typed_value = value - span_millis = ( - typed_value._end_micros // 1000 - typed_value._start_micros // 1000) + span_millis = typed_value._end_micros // 1000 - typed_value._start_micros // 1000 out.write_bigendian_uint64( self._from_normal_time(typed_value._end_micros // 1000)) out.write_var_int64(span_millis) @@ -933,10 +1008,10 @@ def decode_from_stream(self, in_, nested): # instantiating with None is not part of the public interface # pylint: disable=too-many-function-args typed_value = IntervalWindow(None, None) # type: ignore[arg-type] - typed_value._end_micros = ( - 1000 * self._to_normal_time(in_.read_bigendian_uint64())) - typed_value._start_micros = ( - typed_value._end_micros - 1000 * in_.read_var_int64()) + typed_value._end_micros = 1000 * 
self._to_normal_time( + in_.read_bigendian_uint64()) + typed_value._start_micros = typed_value._end_micros - 1000 * in_.read_var_int64( + ) return typed_value def estimate_size(self, value, nested=False): @@ -944,19 +1019,18 @@ def estimate_size(self, value, nested=False): # An IntervalWindow is context-insensitive, with a timestamp (8 bytes) # and a varint timespam. typed_value = value - span_millis = ( - typed_value._end_micros // 1000 - typed_value._start_micros // 1000) + span_millis = typed_value._end_micros // 1000 - typed_value._start_micros // 1000 return 8 + get_varint_size(span_millis) class TimestampCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - TODO: SDK agnostic encoding - For interoperability with Java SDK, encoding needs to match - that of the Java SDK InstantCoder. - https://github.com/apache/beam/blob/f5029b4f0dfff404310b2ef55e2632bbacc7b04f/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/InstantCoder.java#L79 - """ + TODO: SDK agnostic encoding + For interoperability with Java SDK, encoding needs to match + that of the Java SDK InstantCoder. 
+ https://github.com/apache/beam/blob/f5029b4f0dfff404310b2ef55e2632bbacc7b04f/sdks/java/core/src/main/java/org/apache/beam/sdk/coders/InstantCoder.java#L79 + """ def encode_to_stream(self, value, out, nested): # type: (Timestamp, create_OutputStream, bool) -> None millis = value.micros // 1000 @@ -990,6 +1064,7 @@ def __init__(self, key_coder_impl, window_coder_impl): self._key_coder_impl = key_coder_impl self._windows_coder_impl = TupleSequenceCoderImpl(window_coder_impl) from apache_beam.coders.coders import StrUtf8Coder + self._tag_coder_impl = StrUtf8Coder().get_impl() def encode_to_stream(self, value, out, nested): @@ -1008,6 +1083,7 @@ def encode_to_stream(self, value, out, nested): def decode_from_stream(self, in_stream, nested): # type: (create_InputStream, bool) -> userstate.Timer from apache_beam.transforms import userstate + user_key = self._key_coder_impl.decode_from_stream(in_stream, True) dynamic_timer_tag = self._tag_coder_impl.decode_from_stream(in_stream, True) windows = self._windows_coder_impl.decode_from_stream(in_stream, True) @@ -1020,7 +1096,8 @@ def decode_from_stream(self, in_stream, nested): clear_bit=clear_bit, fire_timestamp=None, hold_timestamp=None, - paneinfo=None) + paneinfo=None, + ) return userstate.Timer( user_key=user_key, @@ -1031,16 +1108,17 @@ def decode_from_stream(self, in_stream, nested): in_stream, True), hold_timestamp=self._timestamp_coder_impl.decode_from_stream( in_stream, True), - paneinfo=self._pane_info_coder_impl.decode_from_stream(in_stream, True)) + paneinfo=self._pane_info_coder_impl.decode_from_stream(in_stream, True), + ) -small_ints = [chr(_).encode('latin-1') for _ in range(128)] +small_ints = [chr(_).encode("latin-1") for _ in range(128)] class VarIntCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. 
- A coder for int objects.""" + A coder for int objects.""" def encode_to_stream(self, value, out, nested): # type: (int, create_OutputStream, bool) -> None try: @@ -1084,7 +1162,7 @@ def estimate_size(self, value, nested=False): class VarInt32CoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for int32 objects.""" + A coder for int32 objects.""" def encode_to_stream(self, value, out, nested): # type: (int, create_OutputStream, bool) -> None out.write_var_int32(value) @@ -1115,7 +1193,7 @@ def estimate_size(self, value, nested=False): class SingletonCoderImpl(CoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder that always encodes exactly one value.""" + A coder that always encodes exactly one value.""" def __init__(self, value): self._value = value @@ -1128,7 +1206,7 @@ def decode_from_stream(self, stream, nested): return self._value def encode(self, value): - b = b'' # avoid byte vs str vs unicode error + b = b"" # avoid byte vs str vs unicode error return b def decode(self, encoded): @@ -1142,7 +1220,7 @@ def estimate_size(self, value, nested=False): class AbstractComponentCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. 
- CoderImpl for coders that are comprised of several component coders.""" + CoderImpl for coders that are comprised of several component coders.""" def __init__(self, coder_impls): for c in coder_impls: assert isinstance(c, CoderImpl), c @@ -1158,7 +1236,7 @@ def encode_to_stream(self, value, out, nested): # type: (Any, create_OutputStream, bool) -> None values = self._extract_components(value) if len(self._coder_impls) != len(values): - raise ValueError('Number of components does not match number of coders.') + raise ValueError("Number of components does not match number of coders.") for i in range(0, len(self._coder_impls)): c = self._coder_impls[i] # type cast c.encode_to_stream( @@ -1177,7 +1255,7 @@ def estimate_size(self, value, nested=False): """Estimates the encoded size of the given value, in bytes.""" # TODO(ccy): This ignores sizes of observable components. - estimated_size, _ = (self.get_estimated_size_and_observables(value)) + estimated_size, _ = self.get_estimated_size_and_observables(value) return estimated_size def get_estimated_size_and_observables(self, value, nested=False): @@ -1189,9 +1267,9 @@ def get_estimated_size_and_observables(self, value, nested=False): observables = [] # type: Observables for i in range(0, len(self._coder_impls)): c = self._coder_impls[i] # type cast - child_size, child_observables = ( - c.get_estimated_size_and_observables( - values[i], nested=nested or i + 1 < len(self._coder_impls))) + child_size, child_observables = c.get_estimated_size_and_observables( + values[i], nested=nested or i + 1 < len(self._coder_impls) + ) estimated_size += child_size observables += child_observables return estimated_size, observables @@ -1251,35 +1329,35 @@ def __reduce__(self): class SequenceCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for sequences. + A coder for sequences. 
- If the length of the sequence in known we encode the length as a 32 bit - ``int`` followed by the encoded bytes. + If the length of the sequence in known we encode the length as a 32 bit + ``int`` followed by the encoded bytes. - If the length of the sequence is unknown, we encode the length as ``-1`` - followed by the encoding of elements buffered up to 64K bytes before prefixing - the count of number of elements. A ``0`` is encoded at the end to indicate the - end of stream. + If the length of the sequence is unknown, we encode the length as ``-1`` + followed by the encoding of elements buffered up to 64K bytes before prefixing + the count of number of elements. A ``0`` is encoded at the end to indicate the + end of stream. - The resulting encoding would look like this:: + The resulting encoding would look like this:: - -1 - countA element(0) element(1) ... element(countA - 1) - countB element(0) element(1) ... element(countB - 1) - ... - countX element(0) element(1) ... element(countX - 1) - 0 + -1 + countA element(0) element(1) ... element(countA - 1) + countB element(0) element(1) ... element(countB - 1) + ... + countX element(0) element(1) ... element(countX - 1) + 0 - If writing to state is enabled, the final terminating 0 will instead be - repaced with:: + If writing to state is enabled, the final terminating 0 will instead be + repaced with:: - varInt64(-1) - len(state_token) - state_token + varInt64(-1) + len(state_token) + state_token - where state_token is a bytes object used to retrieve the remainder of the - iterable via the state API. - """ + where state_token is a bytes object used to retrieve the remainder of the + iterable via the state API. + """ # Default buffer size of 64kB of handling iterables of unknown length. 
_DEFAULT_BUFFER_SIZE = 64 * 1024 @@ -1289,7 +1367,7 @@ def __init__( elem_coder, # type: CoderImpl read_state=None, # type: Optional[IterableStateReader] write_state=None, # type: Optional[IterableStateWriter] - write_state_threshold=0 # type: int + write_state_threshold=0, # type: int ): self._elem_coder = elem_coder self._read_state = read_state @@ -1302,7 +1380,7 @@ def _construct_from_sequence(self, values): def encode_to_stream(self, value, out, nested): # type: (Sequence, create_OutputStream, bool) -> None # Compatible with Java's IterableLikeCoder. - if hasattr(value, '__len__') and self._write_state is None: + if hasattr(value, "__len__") and self._write_state is None: out.write_bigendian_int32(len(value)) for elem in value: self._elem_coder.encode_to_stream(elem, out, True) @@ -1367,7 +1445,7 @@ def decode_from_stream(self, in_stream, nested): if count == -1: if self._read_state is None: raise ValueError( - 'Cannot read state-written iterable without state reader.') + "Cannot read state-written iterable without state reader.") state_token = in_stream.read_all(True) elements = _ConcatSequence( @@ -1380,7 +1458,7 @@ def estimate_size(self, value, nested=False): """Estimates the encoded size of the given value, in bytes.""" # TODO(ccy): This ignores element sizes. 
- estimated_size, _ = (self.get_estimated_size_and_observables(value)) + estimated_size, _ = self.get_estimated_size_and_observables(value) return estimated_size def get_estimated_size_and_observables(self, value, nested=False): @@ -1395,9 +1473,9 @@ def get_estimated_size_and_observables(self, value, nested=False): observables = [] # type: Observables for elem in value: - child_size, child_observables = ( - self._elem_coder.get_estimated_size_and_observables( - elem, nested=True)) + child_size, child_observables = self._elem_coder.get_estimated_size_and_observables( + elem, nested=True + ) estimated_size += child_size observables += child_observables # TODO: (https://github.com/apache/beam/issues/18169) Update to use an @@ -1414,7 +1492,7 @@ def get_estimated_size_and_observables(self, value, nested=False): class TupleSequenceCoderImpl(SequenceCoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for homogeneous tuple objects.""" + A coder for homogeneous tuple objects.""" def _construct_from_sequence(self, components): return tuple(components) @@ -1430,8 +1508,8 @@ def __iter__(self): def __repr__(self): head = [repr(e) for e in itertools.islice(self, 4)] if len(head) == 4: - head[-1] = '...' - return '_AbstractIterable([%s])' % ', '.join(head) + head[-1] = "..." + return "_AbstractIterable([%s])" % ", ".join(head) # Mostly useful for tests. def __eq__(left, right): @@ -1452,7 +1530,7 @@ def __eq__(left, right): class IterableCoderImpl(SequenceCoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for homogeneous iterable objects.""" + A coder for homogeneous iterable objects.""" def __init__(self, *args, use_abstract_iterable=None, **kwargs): super().__init__(*args, **kwargs) if use_abstract_iterable is None: @@ -1469,7 +1547,7 @@ def _construct_from_sequence(self, components): class ListCoderImpl(SequenceCoderImpl): """For internal use only; no backwards-compatibility guarantees. 
- A coder for homogeneous list objects.""" + A coder for homogeneous list objects.""" def _construct_from_sequence(self, components): return components if isinstance(components, list) else list(components) @@ -1477,12 +1555,12 @@ def _construct_from_sequence(self, components): class PaneInfoEncoding(object): """For internal use only; no backwards-compatibility guarantees. - Encoding used to describe a PaneInfo descriptor. A PaneInfo descriptor - can be encoded in three different ways: with a single byte (FIRST), with a - single byte followed by a varint describing a single index (ONE_INDEX) or - with a single byte followed by two varints describing two separate indices: - the index and nonspeculative index. - """ + Encoding used to describe a PaneInfo descriptor. A PaneInfo descriptor + can be encoded in three different ways: with a single byte (FIRST), with a + single byte followed by a varint describing a single index (ONE_INDEX) or + with a single byte followed by two varints describing two separate indices: + the index and nonspeculative index. + """ FIRST = 0 ONE_INDEX = 1 @@ -1497,10 +1575,10 @@ class PaneInfoEncoding(object): class PaneInfoCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. 
- Coder for a PaneInfo descriptor.""" + Coder for a PaneInfo descriptor.""" def _choose_encoding(self, value): - if ((value._index == 0 and value._nonspeculative_index == 0) or - value._timing == PaneInfoTiming_UNKNOWN): + if (value._index == 0 and value._nonspeculative_index + == 0) or value._timing == PaneInfoTiming_UNKNOWN: return PaneInfoEncoding_FIRST elif (value._index == value._nonspeculative_index or value._timing == windowed_value.PaneInfoTiming.EARLY): @@ -1521,7 +1599,7 @@ def encode_to_stream(self, value, out, nested): out.write_var_int64(value.index) out.write_var_int64(value.nonspeculative_index) else: - raise NotImplementedError('Invalid PaneInfoEncoding: %s' % encoding_type) + raise NotImplementedError("Invalid PaneInfoEncoding: %s" % encoding_type) def decode_from_stream(self, in_stream, nested): # type: (create_InputStream, bool) -> windowed_value.PaneInfo @@ -1541,7 +1619,7 @@ def decode_from_stream(self, in_stream, nested): index = in_stream.read_var_int64() nonspeculative_index = in_stream.read_var_int64() else: - raise NotImplementedError('Invalid PaneInfoEncoding: %s' % encoding_type) + raise NotImplementedError("Invalid PaneInfoEncoding: %s" % encoding_type) return windowed_value.PaneInfo( base.is_first, base.is_last, base.timing, index, nonspeculative_index) @@ -1567,7 +1645,7 @@ def __init__(self, coder_impl_types, fallback_coder_impl): def encode_to_stream(self, value, out, nested): value_t = type(value) - for (ix, t) in enumerate(self._types): + for ix, t in enumerate(self._types): if value_t is t: out.write_byte(ix) c = self._coder_impls[ix] # for typing @@ -1593,7 +1671,7 @@ def decode_from_stream(self, in_stream, nested): class WindowedValueCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for windowed values.""" + A coder for windowed values.""" # Ensure that lexicographic ordering of the bytes corresponds to # chronological order of timestamps. 
@@ -1657,9 +1735,10 @@ def decode_from_stream(self, in_stream, nested): value = self._value_coder.decode_from_stream(in_stream, nested) return windowed_value.create( value, - timestamp, # Avoid creation of Timestamp object. + timestamp, windows, - pane_info) + pane_info # Avoid creation of Timestamp object. + ) def get_estimated_size_and_observables(self, value, nested=False): # type: (Any, bool) -> Tuple[int, Observables] @@ -1672,31 +1751,30 @@ def get_estimated_size_and_observables(self, value, nested=False): estimated_size = 0 observables = [] # type: Observables value_estimated_size, value_observables = ( - self._value_coder.get_estimated_size_and_observables( - value.value, nested=nested)) + self._value_coder.get_estimated_size_and_observables(value.value, nested=nested) + ) estimated_size += value_estimated_size observables += value_observables - estimated_size += ( - self._timestamp_coder.estimate_size(value.timestamp, nested=True)) - estimated_size += ( - self._windows_coder.estimate_size(value.windows, nested=True)) - estimated_size += ( - self._pane_info_coder.estimate_size(value.pane_info, nested=True)) + estimated_size += self._timestamp_coder.estimate_size( + value.timestamp, nested=True) + estimated_size += self._windows_coder.estimate_size( + value.windows, nested=True) + estimated_size += self._pane_info_coder.estimate_size( + value.pane_info, nested=True) return estimated_size, observables class ParamWindowedValueCoderImpl(WindowedValueCoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for windowed values with constant timestamp, windows and - pane info. The coder drops timestamp, windows and pane info during - encoding, and uses the supplied parameterized timestamp, windows - and pane info values during decoding when reconstructing the windowed - value.""" + A coder for windowed values with constant timestamp, windows and + pane info. 
The coder drops timestamp, windows and pane info during + encoding, and uses the supplied parameterized timestamp, windows + and pane info values during decoding when reconstructing the windowed + value.""" def __init__(self, value_coder, window_coder, payload): super().__init__(value_coder, TimestampCoderImpl(), window_coder) - self._timestamp, self._windows, self._pane_info = self._from_proto( - payload, window_coder) + self._timestamp, self._windows, self._pane_info = self._from_proto(payload, window_coder) def _from_proto(self, payload, window_coder): windowed_value_coder = WindowedValueCoderImpl( @@ -1722,8 +1800,8 @@ def get_estimated_size_and_observables(self, value, nested=False): estimated_size = 0 observables = [] value_estimated_size, value_observables = ( - self._value_coder.get_estimated_size_and_observables( - value.value, nested=nested)) + self._value_coder.get_estimated_size_and_observables(value.value, nested=nested) + ) estimated_size += value_estimated_size observables += value_observables return estimated_size, observables @@ -1732,7 +1810,7 @@ def get_estimated_size_and_observables(self, value, nested=False): class LengthPrefixCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - Coder which prefixes the length of the encoded object in the stream.""" + Coder which prefixes the length of the encoded object in the stream.""" def __init__(self, value_coder): # type: (CoderImpl) -> None self._value_coder = value_coder @@ -1757,12 +1835,12 @@ def estimate_size(self, value, nested=False): class ShardedKeyCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for sharded user keys. + A coder for sharded user keys. 
- The encoding and decoding should follow the order: - shard id byte string - encoded user key - """ + The encoding and decoding should follow the order: + shard id byte string + encoded user key + """ def __init__(self, key_coder_impl): self._shard_id_coder_impl = BytesCoderImpl() self._key_coder_impl = key_coder_impl @@ -1781,23 +1859,22 @@ def decode_from_stream(self, in_stream, nested): def estimate_size(self, value, nested=False): # type: (Any, bool) -> int estimated_size = 0 - estimated_size += ( - self._shard_id_coder_impl.estimate_size(value._shard_id, nested=True)) - estimated_size += ( - self._key_coder_impl.estimate_size(value.key, nested=True)) + estimated_size += self._shard_id_coder_impl.estimate_size( + value._shard_id, nested=True) + estimated_size += self._key_coder_impl.estimate_size(value.key, nested=True) return estimated_size class TimestampPrefixingWindowCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for custom window types, which prefix required max_timestamp to - encoded original window. + A coder for custom window types, which prefix required max_timestamp to + encoded original window. - The coder encodes and decodes custom window types with following format: - window's max_timestamp() - encoded window using it's own coder. - """ + The coder encodes and decodes custom window types with following format: + window's max_timestamp() + encoded window using it's own coder. 
+ """ def __init__(self, window_coder_impl: CoderImpl) -> None: self._window_coder_impl = window_coder_impl @@ -1810,9 +1887,9 @@ def decode_from_stream(self, stream, nested): return self._window_coder_impl.decode_from_stream(stream, nested) def estimate_size(self, value: Any, nested: bool = False) -> int: - return ( - TimestampCoderImpl().estimate_size(value.max_timestamp()) + - self._window_coder_impl.estimate_size(value, nested)) + return TimestampCoderImpl().estimate_size( + value.max_timestamp()) + self._window_coder_impl.estimate_size( + value, nested) _OpaqueWindow = None @@ -1830,7 +1907,7 @@ def __init__(self, end, encoded_window): self.encoded_window = encoded_window def __repr__(self): - return 'OpaqueWindow(%s, %s)' % (self.end, self.encoded_window) + return "OpaqueWindow(%s, %s)" % (self.end, self.encoded_window) def __hash__(self): return hash(self.encoded_window) @@ -1846,13 +1923,13 @@ def __eq__(self, other): class TimestampPrefixingOpaqueWindowCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - A coder for unknown window types, which prefix required max_timestamp to - encoded original window. + A coder for unknown window types, which prefix required max_timestamp to + encoded original window. 
- The coder encodes and decodes custom window types with following format: - window's max_timestamp() - length prefixed encoded window - """ + The coder encodes and decodes custom window types with following format: + window's max_timestamp() + length prefixed encoded window + """ def __init__(self) -> None: pass @@ -1866,9 +1943,8 @@ def decode_from_stream(self, stream, nested): max_timestamp.successor(), stream.read_all(True)) def estimate_size(self, value: Any, nested: bool = False) -> int: - return ( - TimestampCoderImpl().estimate_size(value.max_timestamp()) + - len(value.encoded_window)) + return TimestampCoderImpl().estimate_size(value.max_timestamp()) + len( + value.encoded_window) row_coders_registered = False @@ -1943,8 +2019,8 @@ def __init__(self, schema, components): names_no_pos = ", ".join( [f.name for f in self.schema.fields if f.encoding_position is None]) raise ValueError( - f'''Schema with id {schema.id} has encoding_positions_set=True, - but found fields without encoding_position set: {names_no_pos}''') + f"""Schema with id {schema.id} has encoding_positions_set=True, + but found fields without encoding_position set: {names_no_pos}""") self.encoding_positions = list( field.encoding_position for field in self.schema.fields) self.encoding_positions_argsort = list(np.argsort(self.encoding_positions)) @@ -1988,7 +2064,7 @@ def encode_to_stream(self, value, out, nested): if attr is None: if not self.field_nullable[i]: raise ValueError( - "Attempted to encode null for non-nullable field \"{}\".".format( + 'Attempted to encode null for non-nullable field "{}".'.format( self.schema.fields[i].name)) continue component_coder = self.components[i] # for typing @@ -2034,8 +2110,8 @@ def encode_batch_to_stream(self, columns: Dict[str, np.ndarray], out): if has_null_bits[k] and null_flags[k, i]: if not self.field_nullable[i]: raise ValueError( - "Attempted to encode null for non-nullable field \"{}\".". 
- format(self.schema.fields[i].name)) + 'Attempted to encode null for non-nullable field "{}".'.format( + self.schema.fields[i].name)) else: cython.cast(RowColumnEncoder, attrs[i]).encode_to_stream(k, out) @@ -2119,8 +2195,8 @@ def decode_from_stream(self, in_stream, nested): class BigIntegerCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. - For interoperability with Java SDK, encoding needs to match that of the Java - SDK BigIntegerCoder.""" + For interoperability with Java SDK, encoding needs to match that of the Java + SDK BigIntegerCoder.""" def encode_to_stream(self, value, out, nested): # type: (int, create_OutputStream, bool) -> None if value < 0: @@ -2128,20 +2204,20 @@ def encode_to_stream(self, value, out, nested): else: byte_length = (value.bit_length() + 8) // 8 encoded_value = value.to_bytes( - length=byte_length, byteorder='big', signed=True) + length=byte_length, byteorder="big", signed=True) out.write(encoded_value, nested) def decode_from_stream(self, in_stream, nested): # type: (create_InputStream, bool) -> int encoded_value = in_stream.read_all(nested) - return int.from_bytes(encoded_value, byteorder='big', signed=True) + return int.from_bytes(encoded_value, byteorder="big", signed=True) class DecimalCoderImpl(StreamCoderImpl): """For internal use only; no backwards-compatibility guarantees. 
- For interoperability with Java SDK, encoding needs to match that of the Java - SDK BigDecimalCoder.""" + For interoperability with Java SDK, encoding needs to match that of the Java + SDK BigDecimalCoder.""" BIG_INT_CODER_IMPL = BigIntegerCoderImpl() diff --git a/sdks/python/apache_beam/coders/coders_test_common.py b/sdks/python/apache_beam/coders/coders_test_common.py index 422d494b61c7..b4ba8f92a37a 100644 --- a/sdks/python/apache_beam/coders/coders_test_common.py +++ b/sdks/python/apache_beam/coders/coders_test_common.py @@ -16,10 +16,12 @@ # """Tests common to all coder implementations.""" + # pytype: skip-file import base64 import collections +import datetime import enum import logging import math @@ -66,20 +68,25 @@ except ImportError: dill = None -MyNamedTuple = collections.namedtuple('A', ['x', 'y']) # type: ignore[name-match] -AnotherNamedTuple = collections.namedtuple('AnotherNamedTuple', ['x', 'y']) -MyTypedNamedTuple = NamedTuple('MyTypedNamedTuple', [('f1', int), ('f2', str)]) +try: + import zoneinfo +except ImportError: + zoneinfo = None + +MyNamedTuple = collections.namedtuple("A", ["x", "y"]) # type: ignore[name-match] +AnotherNamedTuple = collections.namedtuple("AnotherNamedTuple", ["x", "y"]) +MyTypedNamedTuple = NamedTuple("MyTypedNamedTuple", [("f1", int), ("f2", str)]) class MyEnum(enum.Enum): E1 = 5 E2 = enum.auto() - E3 = 'abc' + E3 = "abc" -MyIntEnum = enum.IntEnum('MyIntEnum', 'I1 I2 I3') -MyIntFlag = enum.IntFlag('MyIntFlag', 'F1 F2 F3') -MyFlag = enum.Flag('MyFlag', 'F1 F2 F3') # pylint: disable=too-many-function-args +MyIntEnum = enum.IntEnum("MyIntEnum", "I1 I2 I3") +MyIntFlag = enum.IntFlag("MyIntFlag", "F1 F2 F3") +MyFlag = enum.Flag("MyFlag", "F1 F2 F3") # pylint: disable=too-many-function-args class DefinesGetState: @@ -101,7 +108,7 @@ def __setstate__(self, value): # Defined out of line for picklability. 
class CustomCoder(coders.Coder): def encode(self, x): - return str(x + 1).encode('utf-8') + return str(x + 1).encode("utf-8") def decode(self, encoded): return int(encoded) - 1 @@ -131,7 +138,7 @@ class FrozenUnInitKwOnlyDataClass: def __post_init__(self): # Hack to update an attribute in a frozen dataclass. - object.__setattr__(self, 'area', self.side**2) + object.__setattr__(self, "area", self.side**2) # These tests need to all be run in the same process due to the asserts @@ -149,8 +156,8 @@ class CodersTest(unittest.TestCase): 1, -1, 1.5, - b'str\0str', - 'unicode\0\u0101', + b"str\0str", + "unicode\0\u0101", (), (1, 2, 3), [], @@ -161,13 +168,13 @@ class CodersTest(unittest.TestCase): test_values = test_values_deterministic + [ {}, { - 'a': 'b' + "a": "b" }, { 0: {}, 1: len }, set(), - {'a', 'b'}, + {"a", "b"}, len, ] @@ -180,7 +187,7 @@ def setUpClass(cls): def tearDownClass(cls): standard = set( c for c in coders.__dict__.values() if isinstance(c, type) and - issubclass(c, coders.Coder) and 'Base' not in c.__name__) + issubclass(c, coders.Coder) and "Base" not in c.__name__) standard -= set([ coders.Coder, coders.AvroGenericCoder, @@ -217,8 +224,8 @@ def _observe_nested(cls, coder): cls._observe_nested(c) def check_coder(self, coder, *values, **kwargs): - context = kwargs.pop('context', pipeline_context.PipelineContext()) - test_size_estimation = kwargs.pop('test_size_estimation', True) + context = kwargs.pop("context", pipeline_context.PipelineContext()) + test_size_estimation = kwargs.pop("test_size_estimation", True) assert not kwargs self._observe(coder) for v in values: @@ -229,7 +236,8 @@ def check_coder(self, coder, *values, **kwargs): coder.estimate_size(v), coder.get_impl().estimate_size(v)) self.assertEqual( coder.get_impl().get_estimated_size_and_observables(v), - (coder.get_impl().estimate_size(v), [])) + (coder.get_impl().estimate_size(v), []), + ) copy1 = pickler.loads(pickler.dumps(coder)) copy2 = 
coders.Coder.from_runner_api(coder.to_runner_api(context), context) for v in values: @@ -241,8 +249,11 @@ def test_custom_coder(self): self.check_coder(CustomCoder(), 1, -10, 5) self.check_coder( - coders.TupleCoder((CustomCoder(), coders.BytesCoder())), (1, b'a'), - (-10, b'b'), (5, b'c')) + coders.TupleCoder((CustomCoder(), coders.BytesCoder())), + (1, b"a"), + (-10, b"b"), + (5, b"c"), + ) def test_pickle_coder(self): coder = coders.PickleCoder() @@ -250,7 +261,7 @@ def test_pickle_coder(self): def test_cloudpickle_pickle_coder(self): cell_value = (lambda x: lambda: x)(0).__closure__[0] - self.check_coder(coders.CloudpickleCoder(), 'a', 1, cell_value) + self.check_coder(coders.CloudpickleCoder(), "a", 1, cell_value) self.check_coder( coders.TupleCoder((coders.VarIntCoder(), coders.CloudpickleCoder())), (1, cell_value)) @@ -265,21 +276,21 @@ def test_memoizing_pickle_coder(self): param(compat_version="2.68.0"), ]) def test_deterministic_coder(self, compat_version): - """ Test in process determinism for all special deterministic types - - - In SDK version <= 2.67.0 dill is used to encode "special types" - - In SDK version 2.68.0 cloudpickle is used to encode "special types" with - absolute filepaths in code objects and dynamic functions. - - In SDK version >=2.69.0 cloudpickle is used to encode "special types" - with relative filepaths in code objects and dynamic functions. - """ + """Test in process determinism for all special deterministic types + + - In SDK version <= 2.67.0 dill is used to encode "special types" + - In SDK version 2.68.0 cloudpickle is used to encode "special types" with + absolute filepaths in code objects and dynamic functions. + - In SDK version >=2.69.0 cloudpickle is used to encode "special types" + with relative filepaths in code objects and dynamic functions. 
+ """ with scoped_pipeline_options( PipelineOptions(update_compatibility_version=compat_version)): coder = coders.FastPrimitivesCoder() if not dill and compat_version == "2.67.0": with self.assertRaises(RuntimeError): coder.as_deterministic_coder(step_label="step") - self.skipTest('Dill not installed') + self.skipTest("Dill not installed") deterministic_coder = coder.as_deterministic_coder(step_label="step") self.check_coder(deterministic_coder, *self.test_values_deterministic) @@ -288,21 +299,22 @@ def test_deterministic_coder(self, compat_version): self.check_coder( coders.TupleCoder( (deterministic_coder, ) * len(self.test_values_deterministic)), - tuple(self.test_values_deterministic)) + tuple(self.test_values_deterministic), + ) self.check_coder(deterministic_coder, {}) - self.check_coder(deterministic_coder, {2: 'x', 1: 'y'}) + self.check_coder(deterministic_coder, {2: "x", 1: "y"}) with self.assertRaises(TypeError): - self.check_coder(deterministic_coder, {1: 'x', 'y': 2}) + self.check_coder(deterministic_coder, {1: "x", "y": 2}) self.check_coder(deterministic_coder, [1, {}]) with self.assertRaises(TypeError): - self.check_coder(deterministic_coder, [1, {1: 'x', 'y': 2}]) + self.check_coder(deterministic_coder, [1, {1: "x", "y": 2}]) self.check_coder( - coders.TupleCoder((deterministic_coder, coder)), (1, {}), ('a', [{}])) + coders.TupleCoder((deterministic_coder, coder)), (1, {}), ("a", [{}])) self.check_coder( - deterministic_coder, test_message.MessageA(field1='value')) + deterministic_coder, test_message.MessageA(field1="value")) # Skip this test during cloudpickle. 
Dill monkey patches the __reduce__ # method for anonymous named tuples (MyNamedTuple) which is not @@ -310,11 +322,11 @@ def test_deterministic_coder(self, compat_version): if compat_version == "2.67.0": self.check_coder( deterministic_coder, - [MyNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')]) + [MyNamedTuple(1, 2), MyTypedNamedTuple(1, "a")]) self.check_coder( deterministic_coder, - [AnotherNamedTuple(1, 2), MyTypedNamedTuple(1, 'a')]) + [AnotherNamedTuple(1, 2), MyTypedNamedTuple(1, "a")]) if dataclasses is not None: self.check_coder(deterministic_coder, FrozenDataClass(1, 2)) @@ -347,7 +359,7 @@ def test_deterministic_coder(self, compat_version): with self.assertRaises(TypeError): self.check_coder( deterministic_coder, DefinesGetAndSetState({ - 1: 'x', 'y': 2 + 1: "x", "y": 2 })) @parameterized.expand([ @@ -356,19 +368,19 @@ def test_deterministic_coder(self, compat_version): param(compat_version="2.68.0"), ]) def test_deterministic_map_coder_is_update_compatible(self, compat_version): - """ Test in process determinism for map coder including when a component - coder uses DeterministicFastPrimitivesCoder for "special types". - - - In SDK version <= 2.67.0 dill is used to encode "special types" - - In SDK version 2.68.0 cloudpickle is used to encode "special types" with - absolute filepaths in code objects and dynamic functions. - - In SDK version >=2.69.0 cloudpickle is used to encode "special types" - with relative file. - """ + """Test in process determinism for map coder including when a component + coder uses DeterministicFastPrimitivesCoder for "special types". + + - In SDK version <= 2.67.0 dill is used to encode "special types" + - In SDK version 2.68.0 cloudpickle is used to encode "special types" with + absolute filepaths in code objects and dynamic functions. + - In SDK version >=2.69.0 cloudpickle is used to encode "special types" + with relative file. 
+ """ with scoped_pipeline_options( PipelineOptions(update_compatibility_version=compat_version)): values = [{ - MyTypedNamedTuple(i, 'a'): MyTypedNamedTuple('a', i) + MyTypedNamedTuple(i, "a"): MyTypedNamedTuple("a", i) for i in range(10) }] @@ -378,14 +390,16 @@ def test_deterministic_map_coder_is_update_compatible(self, compat_version): if not dill and compat_version == "2.67.0": with self.assertRaises(RuntimeError): coder.as_deterministic_coder(step_label="step") - self.skipTest('Dill not installed') + self.skipTest("Dill not installed") deterministic_coder = coder.as_deterministic_coder(step_label="step") assert isinstance( deterministic_coder._key_coder, - coders.DeterministicFastPrimitivesCoderV2 if compat_version - in (None, "2.68.0") else coders.DeterministicFastPrimitivesCoder) + ( + coders.DeterministicFastPrimitivesCoderV2 if compat_version + in (None, "2.68.0") else coders.DeterministicFastPrimitivesCoder), + ) self.check_coder(deterministic_coder, *values) @@ -393,10 +407,10 @@ def test_dill_coder(self): if not dill: with self.assertRaises(RuntimeError): coders.DillCoder() - self.skipTest('Dill not installed') + self.skipTest("Dill not installed") cell_value = (lambda x: lambda: x)(0).__closure__[0] - self.check_coder(coders.DillCoder(), 'a', 1, cell_value) + self.check_coder(coders.DillCoder(), "a", 1, cell_value) self.check_coder( coders.TupleCoder((coders.VarIntCoder(), coders.DillCoder())), (1, cell_value)) @@ -418,24 +432,60 @@ def test_fake_deterministic_fast_primitives_coder(self): self.check_coder(coders.TupleCoder((coder, )), (v, )) def test_bytes_coder(self): - self.check_coder(coders.BytesCoder(), b'a', b'\0', b'z' * 1000) + self.check_coder(coders.BytesCoder(), b"a", b"\0", b"z" * 1000) def test_bool_coder(self): self.check_coder(coders.BooleanCoder(), True, False) + def test_fast_primitives_coder_datetime(self): + self.check_coder( + coders.FastPrimitivesCoder(), + datetime.datetime(2026, 1, 1), + datetime.datetime( + 2025, + 2, + 3, + 
tzinfo=datetime.timezone(datetime.timedelta(hours=3, minutes=30))), + datetime.datetime( + 2025, + 2, + 3, + tzinfo=datetime.timezone(datetime.timedelta(hours=3), name="Foo")), + # Nonsense tznaive fold is still preserved. + datetime.datetime(2026, 11, 1, 1, 30, fold=1), + ) + if zoneinfo is not None: + tz = zoneinfo.ZoneInfo("America/New_York") + self.check_coder( + coders.FastPrimitivesCoder(), + datetime.datetime(2026, 11, 1, 1, 30, tzinfo=tz, fold=0), + datetime.datetime(2026, 11, 1, 1, 30, tzinfo=tz, fold=1), + ) + + def test_fast_primitives_coder_date(self): + self.check_coder( + coders.FastPrimitivesCoder(), + datetime.date(2026, 1, 1), + ) + + def test_fast_primitives_coder_frozenset(self): + self.check_coder( + coders.FastPrimitivesCoder(), frozenset(), frozenset(["a", "b", "c"])) + def test_varint_coder(self): # Small ints. self.check_coder(coders.VarIntCoder(), *range(-10, 10)) # Multi-byte encoding starts at 128 self.check_coder(coders.VarIntCoder(), *range(120, 140)) # Large values - MAX_64_BIT_INT = 0x7fffffffffffffff + MAX_64_BIT_INT = 0x7FFFFFFFFFFFFFFF self.check_coder( coders.VarIntCoder(), *[ int(math.pow(-1, k) * math.exp(k)) for k in range(0, int(math.log(MAX_64_BIT_INT))) - ]) + ], + ) def test_varint32_coder(self): # Small ints. 
@@ -443,27 +493,31 @@ def test_varint32_coder(self): # Multi-byte encoding starts at 128 self.check_coder(coders.VarInt32Coder(), *range(120, 140)) # Large values - MAX_32_BIT_INT = 0x7fffffff + MAX_32_BIT_INT = 0x7FFFFFFF self.check_coder( coders.VarIntCoder(), *[ int(math.pow(-1, k) * math.exp(k)) for k in range(0, int(math.log(MAX_32_BIT_INT))) - ]) + ], + ) def test_float_coder(self): self.check_coder( coders.FloatCoder(), *[float(0.1 * x) for x in range(-100, 100)]) self.check_coder( coders.FloatCoder(), *[float(2**(0.1 * x)) for x in range(-100, 100)]) - self.check_coder(coders.FloatCoder(), float('-Inf'), float('Inf')) + self.check_coder(coders.FloatCoder(), float("-Inf"), float("Inf")) self.check_coder( - coders.TupleCoder((coders.FloatCoder(), coders.FloatCoder())), (0, 1), - (-100, 100), (0.5, 0.25)) + coders.TupleCoder((coders.FloatCoder(), coders.FloatCoder())), + (0, 1), + (-100, 100), + (0.5, 0.25), + ) def test_singleton_coder(self): - a = 'anything' - b = 'something else' + a = "anything" + b = "something else" self.check_coder(coders.SingletonCoder(a), a) self.check_coder(coders.SingletonCoder(b), b) self.check_coder( @@ -474,9 +528,10 @@ def test_interval_window_coder(self): self.check_coder( coders.IntervalWindowCoder(), *[ - window.IntervalWindow(x, y) for x in [-2**52, 0, 2**52] + window.IntervalWindow(x, y) for x in [-(2**52), 0, 2**52] for y in range(-100, 100) - ]) + ], + ) self.check_coder( coders.TupleCoder((coders.IntervalWindowCoder(), )), (window.IntervalWindow(0, 10), )) @@ -490,8 +545,10 @@ def test_paneinfo_window_coder(self): is_last=y == 9, timing=windowed_value.PaneInfoTiming.EARLY, index=y, - nonspeculative_index=-1) for y in range(0, 10) - ]) + nonspeculative_index=-1, + ) for y in range(0, 10) + ], + ) def test_timestamp_coder(self): self.check_coder( @@ -500,14 +557,17 @@ def test_timestamp_coder(self): self.check_coder( coders.TimestampCoder(), timestamp.Timestamp(micros=-1234567000), - timestamp.Timestamp(micros=1234567000)) 
+ timestamp.Timestamp(micros=1234567000), + ) self.check_coder( coders.TimestampCoder(), timestamp.Timestamp(micros=-1234567890123456000), - timestamp.Timestamp(micros=1234567890123456000)) + timestamp.Timestamp(micros=1234567890123456000), + ) self.check_coder( coders.TupleCoder((coders.TimestampCoder(), coders.BytesCoder())), - (timestamp.Timestamp.of(27), b'abc')) + (timestamp.Timestamp.of(27), b"abc"), + ) def test_timer_coder(self): self.check_coder( @@ -520,7 +580,8 @@ def test_timer_coder(self): clear_bit=True, fire_timestamp=None, hold_timestamp=None, - paneinfo=None), + paneinfo=None, + ), userstate.Timer( user_key="key", dynamic_timer_tag="tag", @@ -528,22 +589,28 @@ def test_timer_coder(self): clear_bit=False, fire_timestamp=timestamp.Timestamp.of(123), hold_timestamp=timestamp.Timestamp.of(456), - paneinfo=windowed_value.PANE_INFO_UNKNOWN) - ]) + paneinfo=windowed_value.PANE_INFO_UNKNOWN, + ), + ], + ) def test_tuple_coder(self): kv_coder = coders.TupleCoder((coders.VarIntCoder(), coders.BytesCoder())) # Test binary representation - self.assertEqual(b'\x04abc', kv_coder.encode((4, b'abc'))) + self.assertEqual(b"\x04abc", kv_coder.encode((4, b"abc"))) # Test unnested - self.check_coder(kv_coder, (1, b'a'), (-2, b'a' * 100), (300, b'abc\0' * 5)) + self.check_coder(kv_coder, (1, b"a"), (-2, b"a" * 100), (300, b"abc\0" * 5)) # Test nested self.check_coder( coders.TupleCoder(( coders.TupleCoder((coders.PickleCoder(), coders.VarIntCoder())), coders.StrUtf8Coder(), - coders.BooleanCoder())), ((1, 2), 'a', True), - ((-2, 5), 'a\u0101' * 100, False), ((300, 1), 'abc\0' * 5, True)) + coders.BooleanCoder(), + )), + ((1, 2), "a", True), + ((-2, 5), "a\u0101" * 100, False), + ((300, 1), "abc\0" * 5, True), + ) def test_tuple_sequence_coder(self): int_tuple_coder = coders.TupleSequenceCoder(coders.VarIntCoder()) @@ -553,10 +620,10 @@ def test_tuple_sequence_coder(self): (1, (1, 2, 3))) def test_base64_pickle_coder(self): - self.check_coder(coders.Base64PickleCoder(), 
'a', 1, 1.5, (1, 2, 3)) + self.check_coder(coders.Base64PickleCoder(), "a", 1, 1.5, (1, 2, 3)) def test_utf8_coder(self): - self.check_coder(coders.StrUtf8Coder(), 'a', 'ab\u00FF', '\u0101\0') + self.check_coder(coders.StrUtf8Coder(), "a", "ab\u00ff", "\u0101\0") def test_iterable_coder(self): iterable_coder = coders.IterableCoder(coders.VarIntCoder()) @@ -566,7 +633,8 @@ def test_iterable_coder(self): self.check_coder( coders.TupleCoder( (coders.VarIntCoder(), coders.IterableCoder(coders.VarIntCoder()))), - (1, [1, 2, 3])) + (1, [1, 2, 3]), + ) def test_iterable_coder_unknown_length(self): # Empty @@ -586,7 +654,8 @@ def iter_generator(count): iterable_coder = coders.IterableCoder(coders.VarIntCoder()) self.assertCountEqual( list(iter_generator(count)), - iterable_coder.decode(iterable_coder.encode(iter_generator(count)))) + iterable_coder.decode(iterable_coder.encode(iter_generator(count))), + ) def test_list_coder(self): list_coder = coders.ListCoder(coders.VarIntCoder()) @@ -624,7 +693,8 @@ def test_windowedvalue_coder_paneinfo(self): self.check_coder( coder, windowed_value.WindowedValue( - 123, 234, (GlobalWindow(), ), windowed_value.PANE_INFO_UNKNOWN)) + 123, 234, (GlobalWindow(), ), windowed_value.PANE_INFO_UNKNOWN), + ) for value in test_values: self.check_coder(coder, value) @@ -638,34 +708,41 @@ def test_windowed_value_coder(self): coders.VarIntCoder(), coders.GlobalWindowCoder()) # Test binary representation self.assertEqual( - b'\x7f\xdf;dZ\x1c\xac\t\x00\x00\x00\x01\x0f\x01', - coder.encode(window.GlobalWindows.windowed_value(1))) + b"\x7f\xdf;dZ\x1c\xac\t\x00\x00\x00\x01\x0f\x01", + coder.encode(window.GlobalWindows.windowed_value(1)), + ) # Test decoding large timestamp self.assertEqual( - coder.decode(b'\x7f\xdf;dZ\x1c\xac\x08\x00\x00\x00\x01\x0f\x00'), - windowed_value.create(0, MIN_TIMESTAMP.micros, (GlobalWindow(), ))) + coder.decode(b"\x7f\xdf;dZ\x1c\xac\x08\x00\x00\x00\x01\x0f\x00"), + windowed_value.create(0, MIN_TIMESTAMP.micros, 
(GlobalWindow(), )), + ) # Test unnested self.check_coder( coders.WindowedValueCoder(coders.VarIntCoder()), windowed_value.WindowedValue(3, -100, ()), - windowed_value.WindowedValue(-1, 100, (1, 2, 3))) + windowed_value.WindowedValue(-1, 100, (1, 2, 3)), + ) # Test Global Window self.check_coder( coders.WindowedValueCoder( coders.VarIntCoder(), coders.GlobalWindowCoder()), - window.GlobalWindows.windowed_value(1)) + window.GlobalWindows.windowed_value(1), + ) # Test nested self.check_coder( coders.TupleCoder(( coders.WindowedValueCoder(coders.FloatCoder()), - coders.WindowedValueCoder(coders.StrUtf8Coder()))), + coders.WindowedValueCoder(coders.StrUtf8Coder()), + )), ( windowed_value.WindowedValue(1.5, 0, ()), - windowed_value.WindowedValue("abc", 10, ('window', )))) + windowed_value.WindowedValue("abc", 10, ("window", )), + ), + ) def test_param_windowed_value_coder(self): from apache_beam.transforms.window import IntervalWindow @@ -673,11 +750,12 @@ def test_param_windowed_value_coder(self): # pylint: disable=too-many-function-args wv = windowed_value.create( - b'', + b"", # Milliseconds to microseconds 1000 * 1000, (IntervalWindow(11, 21), ), - PaneInfo(True, False, 1, 2, 3)) + PaneInfo(True, False, 1, 2, 3), + ) windowed_value_coder = coders.WindowedValueCoder( coders.BytesCoder(), coders.IntervalWindowCoder()) payload = windowed_value_coder.encode(wv) @@ -686,7 +764,7 @@ def test_param_windowed_value_coder(self): # Test binary representation self.assertEqual( - b'\x01', coder.encode(window.GlobalWindows.windowed_value(1))) + b"\x01", coder.encode(window.GlobalWindows.windowed_value(1))) # Test unnested self.check_coder( @@ -699,7 +777,8 @@ def test_param_windowed_value_coder(self): windowed_value.WindowedValue( 1, 1, (window.IntervalWindow(11, 21), ), - PaneInfo(True, False, 1, 2, 3))) + PaneInfo(True, False, 1, 2, 3)), + ) # Test nested self.check_coder( @@ -707,8 +786,8 @@ def test_param_windowed_value_coder(self): coders.ParamWindowedValueCoder( payload, 
[coders.FloatCoder(), coders.IntervalWindowCoder()]), coders.ParamWindowedValueCoder( - payload, - [coders.StrUtf8Coder(), coders.IntervalWindowCoder()]))), + payload, [coders.StrUtf8Coder(), coders.IntervalWindowCoder()]), + )), ( windowed_value.WindowedValue( 1.5, @@ -717,7 +796,9 @@ def test_param_windowed_value_coder(self): windowed_value.WindowedValue( "abc", 1, (window.IntervalWindow(11, 21), ), - PaneInfo(True, False, 1, 2, 3)))) + PaneInfo(True, False, 1, 2, 3)), + ), + ) @parameterized.expand([ param(compat_version=None), @@ -728,22 +809,22 @@ def test_cross_process_encoding_of_special_types_is_deterministic( self, compat_version): """Test cross-process determinism for all special deterministic types - - In SDK version <= 2.67.0 dill is used to encode "special types" - - In SDK version 2.68.0 cloudpickle is used to encode "special types" with - absolute filepaths in code objects and dynamic functions. - - In SDK version 2.69.0 cloudpickle is used to encode "special types" with - relative filepaths in code objects and dynamic functions. - """ + - In SDK version <= 2.67.0 dill is used to encode "special types" + - In SDK version 2.68.0 cloudpickle is used to encode "special types" with + absolute filepaths in code objects and dynamic functions. + - In SDK version 2.69.0 cloudpickle is used to encode "special types" with + relative filepaths in code objects and dynamic functions. 
+ """ is_using_dill = compat_version == "2.67.0" if is_using_dill: pytest.importorskip("dill") if sys.executable is None: - self.skipTest('No Python interpreter found') + self.skipTest("No Python interpreter found") # pylint: disable=line-too-long script = textwrap.dedent( - f'''\ + f"""\ import pickle import sys import collections @@ -820,10 +901,10 @@ def test_cross_process_encoding_of_special_types_is_deterministic( sys.stdout.buffer.write(pickle.dumps(results)) - ''') + """) def run_subprocess(): - result = subprocess.run([sys.executable, '-c', script], + result = subprocess.run([sys.executable, "-c", script], capture_output=True, timeout=30, check=False) @@ -861,11 +942,13 @@ def run_subprocess(): named_tuple_type = type(decoded1) self.assertEqual( os.path.isabs(named_tuple_type._make.__code__.co_filename), - not should_have_relative_path) + not should_have_relative_path, + ) self.assertEqual( os.path.isabs( - named_tuple_type.__getnewargs__.__globals__['__file__']), - not should_have_relative_path) + named_tuple_type.__getnewargs__.__globals__["__file__"]), + not should_have_relative_path, + ) self.assertEqual( decoded1, decoded2, f"Cross-process decoding differs for {test_name}") @@ -880,23 +963,23 @@ def test_proto_coder(self): ma = test_message.MessageA() mab = ma.field2.add() mab.field1 = True - ma.field1 = 'hello world' + ma.field1 = "hello world" mb = test_message.MessageA() - mb.field1 = 'beam' + mb.field1 = "beam" proto_coder = coders.ProtoCoder(ma.__class__) self.check_coder(proto_coder, ma) self.check_coder( - coders.TupleCoder((proto_coder, coders.BytesCoder())), (ma, b'a'), - (mb, b'b')) + coders.TupleCoder((proto_coder, coders.BytesCoder())), (ma, b"a"), + (mb, b"b")) def test_global_window_coder(self): coder = coders.GlobalWindowCoder() value = window.GlobalWindow() # Test binary representation - self.assertEqual(b'', coder.encode(value)) - self.assertEqual(value, coder.decode(b'')) + self.assertEqual(b"", coder.encode(value)) + 
self.assertEqual(value, coder.decode(b"")) # Test unnested self.check_coder(coder, value) # Test nested @@ -905,15 +988,15 @@ def test_global_window_coder(self): def test_length_prefix_coder(self): coder = coders.LengthPrefixCoder(coders.BytesCoder()) # Test binary representation - self.assertEqual(b'\x00', coder.encode(b'')) - self.assertEqual(b'\x01a', coder.encode(b'a')) - self.assertEqual(b'\x02bc', coder.encode(b'bc')) - self.assertEqual(b'\xff\x7f' + b'z' * 16383, coder.encode(b'z' * 16383)) + self.assertEqual(b"\x00", coder.encode(b"")) + self.assertEqual(b"\x01a", coder.encode(b"a")) + self.assertEqual(b"\x02bc", coder.encode(b"bc")) + self.assertEqual(b"\xff\x7f" + b"z" * 16383, coder.encode(b"z" * 16383)) # Test unnested - self.check_coder(coder, b'', b'a', b'bc', b'def') + self.check_coder(coder, b"", b"a", b"bc", b"def") # Test nested self.check_coder( - coders.TupleCoder((coder, coder)), (b'', b'a'), (b'bc', b'def')) + coders.TupleCoder((coder, coder)), (b"", b"a"), (b"bc", b"def")) def test_nested_observables(self): class FakeObservableIterator(observable.ObservableMixin): @@ -930,14 +1013,16 @@ def __iter__(self): value = windowed_value.WindowedValue(observ, 0, ()) self.assertEqual( coder.get_impl().get_estimated_size_and_observables(value)[1], - [(observ, elem_coder.get_impl())]) + [(observ, elem_coder.get_impl())], + ) # Test nested tuple observable. 
coder = coders.TupleCoder((coders.StrUtf8Coder(), iter_coder)) - value = ('123', observ) + value = ("123", observ) self.assertEqual( coder.get_impl().get_estimated_size_and_observables(value)[1], - [(observ, elem_coder.get_impl())]) + [(observ, elem_coder.get_impl())], + ) def test_state_backed_iterable_coder(self): # pylint: disable=global-variable-undefined @@ -946,7 +1031,7 @@ def test_state_backed_iterable_coder(self): state = {} def iterable_state_write(values, element_coder_impl): - token = b'state_token_%d' % len(state) + token = b"state_token_%d" % len(state) state[token] = [element_coder_impl.encode(e) for e in values] return token @@ -957,7 +1042,8 @@ def iterable_state_read(token, element_coder_impl): coders.VarIntCoder(), read_state=iterable_state_read, write_state=iterable_state_write, - write_state_threshold=1) + write_state_threshold=1, + ) # Note: do not use check_coder # see https://github.com/cloudpipe/cloudpickle/issues/452 self._observe(coder) @@ -988,27 +1074,30 @@ def test_map_coder(self): self.check_coder(map_coder.as_deterministic_coder("label"), *values) def test_sharded_key_coder(self): - key_and_coders = [(b'', b'\x00', coders.BytesCoder()), - (b'key', b'\x03key', coders.BytesCoder()), - ('key', b'\03\x6b\x65\x79', coders.StrUtf8Coder()), - (('k', 1), - b'\x01\x6b\x01', - coders.TupleCoder( - (coders.StrUtf8Coder(), coders.VarIntCoder())))] + key_and_coders = [ + (b"", b"\x00", coders.BytesCoder()), + (b"key", b"\x03key", coders.BytesCoder()), + ("key", b"\03\x6b\x65\x79", coders.StrUtf8Coder()), + ( + ("k", 1), + b"\x01\x6b\x01", + coders.TupleCoder((coders.StrUtf8Coder(), coders.VarIntCoder())), + ), + ] for key, bytes_repr, key_coder in key_and_coders: coder = coders.ShardedKeyCoder(key_coder) # Test str repr - self.assertEqual('%s' % coder, 'ShardedKeyCoder[%s]' % key_coder) + self.assertEqual("%s" % coder, "ShardedKeyCoder[%s]" % key_coder) - self.assertEqual(b'\x00' + bytes_repr, coder.encode(ShardedKey(key, b''))) + 
self.assertEqual(b"\x00" + bytes_repr, coder.encode(ShardedKey(key, b""))) self.assertEqual( - b'\x03123' + bytes_repr, coder.encode(ShardedKey(key, b'123'))) + b"\x03123" + bytes_repr, coder.encode(ShardedKey(key, b"123"))) # Test unnested - self.check_coder(coder, ShardedKey(key, b'')) - self.check_coder(coder, ShardedKey(key, b'123')) + self.check_coder(coder, ShardedKey(key, b"")) + self.check_coder(coder, ShardedKey(key, b"123")) # Test type hints self.assertTrue( @@ -1022,30 +1111,35 @@ def test_sharded_key_coder(self): self.assertEqual( coders.ShardedKeyCoder.from_type_hint( coder.to_type_hint(), typecoders.CoderRegistry()), - coder) + coder, + ) for other_key, _, other_key_coder in key_and_coders: other_coder = coders.ShardedKeyCoder(other_key_coder) # Test nested self.check_coder( coders.TupleCoder((coder, other_coder)), - (ShardedKey(key, b''), ShardedKey(other_key, b''))) + (ShardedKey(key, b""), ShardedKey(other_key, b"")), + ) self.check_coder( coders.TupleCoder((coder, other_coder)), - (ShardedKey(key, b'123'), ShardedKey(other_key, b''))) + (ShardedKey(key, b"123"), ShardedKey(other_key, b"")), + ) def test_timestamp_prefixing_window_coder(self): self.check_coder( coders.TimestampPrefixingWindowCoder(coders.IntervalWindowCoder()), *[ - window.IntervalWindow(x, y) for x in [-2**52, 0, 2**52] + window.IntervalWindow(x, y) for x in [-(2**52), 0, 2**52] for y in range(-100, 100) - ]) + ], + ) self.check_coder( coders.TupleCoder(( coders.TimestampPrefixingWindowCoder( coders.IntervalWindowCoder()), )), - (window.IntervalWindow(0, 10), )) + (window.IntervalWindow(0, 10), ), + ) def test_timestamp_prefixing_opaque_window_coder(self): sdk_coder = coders.TimestampPrefixingWindowCoder( @@ -1087,10 +1181,12 @@ def test_byte_coder(self): base64.b64encode(test_coder.encode(value)).decode().rstrip("=")) def test_OrderedUnionCoder(self): - test_coder = coders._OrderedUnionCoder((str, coders.StrUtf8Coder()), - (int, coders.VarIntCoder()), - 
fallback_coder=coders.FloatCoder()) - self.check_coder(test_coder, 's') + test_coder = coders._OrderedUnionCoder( + (str, coders.StrUtf8Coder()), + (int, coders.VarIntCoder()), + fallback_coder=coders.FloatCoder(), + ) + self.check_coder(test_coder, "s") self.check_coder(test_coder, 123) self.check_coder(test_coder, 1.5) @@ -1103,12 +1199,14 @@ def test_OrderedUnionCoderDeterministic(self): self.assertFalse(test_coder.is_deterministic()) - test_coder = coders._OrderedUnionCoder((str, coders.StrUtf8Coder()), - (int, coders.VarIntCoder()), - fallback_coder=coders.FloatCoder()) + test_coder = coders._OrderedUnionCoder( + (str, coders.StrUtf8Coder()), + (int, coders.VarIntCoder()), + fallback_coder=coders.FloatCoder(), + ) self.assertTrue(test_coder.is_deterministic()) -if __name__ == '__main__': +if __name__ == "__main__": logging.getLogger().setLevel(logging.INFO) unittest.main()