Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,17 +187,24 @@ def add_recording_segment(self, recording_segment: "BaseRecordingSegment") -> No
"""
super().add_segment(recording_segment)

def get_sample_size_in_bytes(self):
def get_sample_size_in_bytes(self, dtype=None):
"""
Returns the size of a single sample across all channels in bytes.

Parameters
----------
dtype : data-type, optional
The data type to use for calculating the sample size. If None,
the recording's dtype is used.

Returns
-------
int
The size of a single sample in bytes
"""
num_channels = self.get_num_channels()
dtype_size_bytes = self.get_dtype().itemsize
dtype = self.get_dtype() if dtype is None else np.dtype(dtype)
dtype_size_bytes = dtype.itemsize
sample_size = num_channels * dtype_size_bytes
return sample_size

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def get_num_samples(self, segment_index: int | None = None) -> int:
raise NotImplementedError

@abstractmethod
def get_sample_size_in_bytes(self) -> int:
def get_sample_size_in_bytes(self, dtype=None) -> int:
raise NotImplementedError

@abstractmethod
Expand Down
6 changes: 2 additions & 4 deletions src/spikeinterface/core/time_series_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def write_binary(
if add_file_extension:
file_path_list = [add_suffix(file_path, ["raw", "bin", "dat"]) for file_path in file_path_list]

dtype = dtype if dtype is not None else time_series.get_dtype()

sample_size_bytes = time_series.get_sample_size_in_bytes()
sample_size_bytes = time_series.get_sample_size_in_bytes(dtype=dtype)

file_path_dict = {segment_index: file_path for segment_index, file_path in enumerate(file_path_list)}
if file_timestamps_paths is not None:
Expand Down Expand Up @@ -125,7 +123,7 @@ def _write_binary_chunk(segment_index, start_frame, end_frame, worker_ctx):
byte_offset = worker_ctx["byte_offset"]
file = worker_ctx["file_dict"][segment_index]
file_timestamps_dict = worker_ctx["file_timestamps_dict"]
sample_size_bytes = time_series.get_sample_size_in_bytes()
sample_size_bytes = time_series.get_sample_size_in_bytes(dtype=dtype)

# Calculate byte offsets for the start frames relative to the entire recording
start_byte = byte_offset + start_frame * sample_size_bytes
Expand Down
Loading