diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 0a9b26931b..6328199475 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -187,17 +187,24 @@ def add_recording_segment(self, recording_segment: "BaseRecordingSegment") -> No """ super().add_segment(recording_segment) - def get_sample_size_in_bytes(self): + def get_sample_size_in_bytes(self, dtype=None): """ Returns the size of a single sample across all channels in bytes. + Parameters + ---------- + dtype : data-type, optional + The data type to use for calculating the sample size. If None, + the recording's dtype is used. + Returns ------- int The size of a single sample in bytes """ num_channels = self.get_num_channels() - dtype_size_bytes = self.get_dtype().itemsize + dtype = self.get_dtype() if dtype is None else np.dtype(dtype) + dtype_size_bytes = dtype.itemsize sample_size = num_channels * dtype_size_bytes return sample_size diff --git a/src/spikeinterface/core/time_series.py b/src/spikeinterface/core/time_series.py index d4d4717dff..62113dafd2 100644 --- a/src/spikeinterface/core/time_series.py +++ b/src/spikeinterface/core/time_series.py @@ -34,7 +34,7 @@ def get_num_samples(self, segment_index: int | None = None) -> int: raise NotImplementedError @abstractmethod - def get_sample_size_in_bytes(self) -> int: + def get_sample_size_in_bytes(self, dtype=None) -> int: raise NotImplementedError @abstractmethod diff --git a/src/spikeinterface/core/time_series_tools.py b/src/spikeinterface/core/time_series_tools.py index fad697f94e..efdedc8282 100644 --- a/src/spikeinterface/core/time_series_tools.py +++ b/src/spikeinterface/core/time_series_tools.py @@ -62,9 +62,7 @@ def write_binary( if add_file_extension: file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list] - dtype = dtype if dtype is not None else time_series.get_dtype() - - sample_size_bytes = time_series.get_sample_size_in_bytes() + sample_size_bytes = time_series.get_sample_size_in_bytes(dtype=dtype) file_path_dict = {segment_index: file_path for segment_index, file_path in enumerate(file_path_list)} if file_timestamps_paths is not None: @@ -125,7 +123,7 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx): byte_offset = worker_ctx["byte_offset"] file = worker_ctx["file_dict"][segment_index] file_timestamps_dict = worker_ctx["file_timestamps_dict"] - sample_size_bytes = time_series.get_sample_size_in_bytes() + sample_size_bytes = time_series.get_sample_size_in_bytes(dtype=dtype) # Calculate byte offsets for the start frames relative to the entire recording start_byte = byte_offset + start_frame * sample_size_bytes