Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions src/spikeinterface/preprocessing/common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,17 @@ class CommonReferenceRecording(BasePreprocessor):
ref_channel_ids : list | str | int | None, default: None
If "global" reference, a list of channels to be used as reference.
If "single" reference, a list of one channel or a single channel id is expected.
If "groups" is provided, then a list of channels to be applied to each group is expected.
If "groups" is provided with "single" reference, a list with one reference channel id
per group is expected.
If "groups" is provided with "global" reference, a list with one *list* of reference
channel ids per group is expected: the reference subtracted from each group is the
operator (median/average) over that group's reference set. The reference set may contain
channels outside the group, enabling cross-group referencing (e.g. referencing each
tetrode to the median of all channels on the other tetrodes). If None, each group is
referenced to its own channels.
As a shortcut for that cross-group case, pass the string "complement" (with "global"
reference and "groups"): each group is then referenced to all channels NOT in it,
i.e. ref_channel_ids is auto-built as each group's complement.
local_radius : tuple(int, int), default: (30, 55)
Use in the local CAR implementation as the selecting annulus with the following format:

Expand Down Expand Up @@ -98,9 +108,27 @@ def __init__(
raise ValueError("'operator' must be either 'median', 'average'")

if reference == "global":
if ref_channel_ids == "complement":
# Convenience: reference each group to all channels NOT in it (its complement).
if groups is None:
raise ValueError("ref_channel_ids='complement' requires 'groups' to be set")
all_ids = list(recording.channel_ids)
ref_channel_ids = [[c for c in all_ids if c not in set(group)] for group in groups]
if ref_channel_ids is not None:
if not isinstance(ref_channel_ids, list):
raise ValueError("With 'global' reference, provide 'ref_channel_ids' as a list")
if groups is not None:
# Per-group reference sets: one list of channel ids per group. The reference
# subtracted from each group is the operator over that group's reference set
# (which may be channels outside the group, e.g. for cross-group referencing).
assert len(ref_channel_ids) == len(groups), (
"With 'global' reference and 'groups', 'ref_channel_ids' must be a list "
"with one channel-id list per group"
)
assert all(isinstance(r, (list, np.ndarray)) for r in ref_channel_ids), (
"With 'global' reference and 'groups', each element of 'ref_channel_ids' "
"must itself be a list of channel ids (the reference set for that group)"
)
elif reference == "single":
assert ref_channel_ids is not None, "With 'single' reference, provide 'ref_channel_ids'"
if groups is not None:
Expand Down Expand Up @@ -150,7 +178,11 @@ def __init__(
else:
group_indices = None
if ref_channel_ids is not None:
ref_channel_indices = self.ids_to_indices(ref_channel_ids)
if reference == "global" and groups is not None:
# one reference-channel index array per group
ref_channel_indices = [self.ids_to_indices(r) for r in ref_channel_ids]
else:
ref_channel_indices = self.ids_to_indices(ref_channel_ids)
else:
ref_channel_indices = None

Expand Down Expand Up @@ -246,7 +278,11 @@ def get_traces(self, start_frame, end_frame, channel_indices):
in_group_traces = traces[:, selected_indices_in_group]

if self.reference == "global":
shift = self.operator_func(traces[:, all_group_indices], axis=1, keepdims=True)
if self.ref_channel_indices is None:
ref_indices = all_group_indices # reference each group to its own channels
else:
ref_indices = self.ref_channel_indices[group_index] # per-group reference set
shift = self.operator_func(traces[:, ref_indices], axis=1, keepdims=True)
re_referenced_traces[:, out_indices] = in_group_traces - shift
else:
# single (as local is not allowed for groups)
Expand Down
42 changes: 42 additions & 0 deletions src/spikeinterface/preprocessing/tests/test_common_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,48 @@ def test_common_reference_groups(recording):
assert np.allclose(traces[:, 1], 0)


def test_common_reference_groups_cross(recording):
# "global" reference with groups AND a per-group ref_channel_ids: each group is
# referenced to a (possibly external) set of channels -> enables cross-group referencing.
original_traces = recording.get_traces()
groups = [["a", "c"], ["b", "d"]]
ref_channel_ids = [["b", "d"], ["a", "c"]] # reference each group to the OTHER group's channels

rec_cross = common_reference(
recording, reference="global", operator="median", groups=groups, ref_channel_ids=ref_channel_ids
)
traces = rec_cross.get_traces(channel_ids=["a", "b", "c", "d"])
# a, c (group 0) referenced to median of b, d
ref0 = np.median(original_traces[:, [1, 3]], axis=1)
assert np.allclose(traces[:, 0], original_traces[:, 0] - ref0, atol=0.01)
assert np.allclose(traces[:, 2], original_traces[:, 2] - ref0, atol=0.01)
# b, d (group 1) referenced to median of a, c
ref1 = np.median(original_traces[:, [0, 2]], axis=1)
assert np.allclose(traces[:, 1], original_traces[:, 1] - ref1, atol=0.01)
assert np.allclose(traces[:, 3], original_traces[:, 3] - ref1, atol=0.01)

# mismatched lengths raise
with pytest.raises(AssertionError):
common_reference(recording, reference="global", groups=groups, ref_channel_ids=[["b", "d"]])


def test_common_reference_groups_complement(recording):
# ref_channel_ids="complement" shortcut: reference each group to all channels NOT in it.
groups = [["a", "c"], ["b", "d"]]
# complements of these groups within {a,b,c,d} are exactly [["b","d"], ["a","c"]]
explicit = common_reference(
recording, reference="global", operator="median", groups=groups, ref_channel_ids=[["b", "d"], ["a", "c"]]
)
sugar = common_reference(
recording, reference="global", operator="median", groups=groups, ref_channel_ids="complement"
)
assert np.allclose(sugar.get_traces(), explicit.get_traces(), atol=1e-6)

# "complement" requires groups
with pytest.raises(ValueError):
common_reference(recording, reference="global", ref_channel_ids="complement")


def test_min_local_radius():
# Test that local radius smaller than the number of channels is handled correctly
recording = generate_recording(durations=[1.0], num_channels=32)
Expand Down
Loading