From f36291fcc1e3a51d1e9e7822cc03e3fa90709d3d Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Mon, 29 Jul 2024 16:03:34 +0200 Subject: [PATCH 01/17] Pseudo Online evaluation and dataset conversion for BNCI2014001, BNCI2014002, BNCI2014004, BNCI2015001 --- examples/plot_pseudoonline.py | 63 ++++++++++++ moabb/datasets/base.py | 2 + moabb/datasets/bnci.py | 171 +++++++++++++++++++++++++++++-- moabb/datasets/preprocessing.py | 80 +++++++++++++-- moabb/evaluations/utils.py | 5 + moabb/paradigms/base.py | 9 +- moabb/paradigms/motor_imagery.py | 22 +++- 7 files changed, 332 insertions(+), 20 deletions(-) create mode 100644 examples/plot_pseudoonline.py diff --git a/examples/plot_pseudoonline.py b/examples/plot_pseudoonline.py new file mode 100644 index 000000000..f636672ca --- /dev/null +++ b/examples/plot_pseudoonline.py @@ -0,0 +1,63 @@ +# Set up the Directory for made it run on a server. +import sys +import os +import moabb +import mne +import resource +from moabb.paradigms import MotorImagery +from moabb.datasets import BNCI2014_001 +from pyriemann.classification import FgMDM +from sklearn.pipeline import Pipeline +from moabb.evaluations import WithinSessionEvaluation +from pyriemann.classification import MDM +from pyriemann.estimation import Covariances +import numpy as np + +sub = 1 + +# Initialize parameter for the Band Pass filter +fmin = 8 +fmax = 30 +tmin = 0 +tmax = 2 + +# Load Dataset and switch to Pseudoonline mode +dataset = BNCI2014_001() +dataset.pseudoonline = True + +#events = ["right_hand", "left_hand"] +events = list(dataset.event_id.keys()) + +paradigm = MotorImagery(events=events, n_classes=len(events), fmin=fmin, fmax=fmax, tmax=tmax, overlap=50) + +X, y, meta = paradigm.get_data(dataset=dataset, subjects=[sub]) +print("Print Events_id:", y) +unique, counts = np.unique(y, return_counts=True) +print("Number of events per class:", dict(zip(unique, counts))) + + +pipelines = {} +pipelines["MDM"] = Pipeline(steps=[ + ("Covariances", Covariances("cov")), + ("MDM", MDM(metric=dict(mean='riemann', distance='riemann'))) +]) + +pipelines["FgMDM"] = Pipeline(steps=[ + ("Covariances", Covariances("cov")), + ("FgMDM", FgMDM()) +]) + +# Select an evaluation Within Session +evaluation_online = WithinSessionEvaluation(paradigm=paradigm, + datasets=dataset, + overwrite=True, + random_state=42, + n_jobs=-1 + ) + +# Print the results +results_ALL = evaluation_online.process(pipelines) +results_pipeline = results_ALL.groupby(['pipeline'], as_index=False)["score"].mean() +results_pipeline_std = results_ALL.groupby(['pipeline'], as_index=False)["score"].std() +results_pipeline['std'] = results_pipeline_std["score"] +print(results_pipeline) \ No newline at end of file diff --git a/moabb/datasets/base.py b/moabb/datasets/base.py index 0d4672482..9a9aa3c4e 100644 --- a/moabb/datasets/base.py +++ b/moabb/datasets/base.py @@ -319,6 +319,7 @@ def __init__( paradigm, doi=None, unit_factor=1e6, + pseudoonline=False ): """Initialize function for the BaseDataset.""" try: @@ -348,6 +349,7 @@ def __init__( self.paradigm = paradigm self.doi = doi self.unit_factor = unit_factor + self.pseudoonline = pseudoonline def _create_process_pipeline(self): return Pipeline( diff --git a/moabb/datasets/bnci.py b/moabb/datasets/bnci.py index 0cef42910..2a058d4d1 100644 --- a/moabb/datasets/bnci.py +++ b/moabb/datasets/bnci.py @@ -1,7 +1,7 @@ """BNCI 2014-001 Motor imagery dataset.""" import numpy as np -from mne import create_info +from mne import create_info, find_events from mne.channels import 
make_standard_montage from mne.io import RawArray from mne.utils import verbose @@ -33,6 +33,7 @@ def load_data( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): # noqa: D301 """Get paths to local copies of a BNCI dataset files. @@ -116,6 +117,7 @@ def load_data( baseurl_list[dataset], only_filenames, verbose, + pseudoonline ) @@ -128,6 +130,7 @@ def _load_data_001_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): """Load data for 001-2014 dataset.""" if (subject < 1) or (subject > 9): @@ -144,13 +147,19 @@ def _load_data_001_2014( sessions = {} filenames = [] + time_task = 4 + time_fix = 2 for session_idx, r in enumerate(["T", "E"]): url = "{u}001-2014/A{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) filename = data_path(url, path, force_update, update_path) filenames += filename if only_filenames: continue - runs, ev = _convert_mi(filename[0], ch_names, ch_types) + + if pseudoonline: + runs, ev = _convert_mi_pseudoonline(filename[0], time_task, time_fix, ch_names, ch_types, pseudoonline) + else: + runs, ev = _convert_mi(filename[0], ch_names, ch_types) # FIXME: deal with run with no event (1:3) and name them sessions[f"{session_idx}{_map[r]}"] = { str(ii): run for ii, run in enumerate(runs) @@ -169,12 +178,15 @@ def _load_data_002_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): """Load data for 002-2014 dataset.""" if (subject < 1) or (subject > 14): raise ValueError("Subject must be between 1 and 14. Got %d." % subject) runs = [] + time_task = 5 + time_fix = 3 filenames = [] for r in ["T", "E"]: url = "{u}002-2014/S{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) @@ -183,7 +195,10 @@ def _load_data_002_2014( if only_filenames: continue # FIXME: electrode position and name are not provided directly. - raws, _ = _convert_mi(filename, None, ["eeg"] * 15) + if pseudoonline: + raws, _ = _convert_mi_pseudoonline(filename, time_task, time_fix, None, ["eeg"] * 15, pseudoonline) + else: + raws, _ = _convert_mi(filename, None, ["eeg"] * 15) runs.extend(zip([r] * len(raws), raws)) if only_filenames: return filenames @@ -200,6 +215,7 @@ def _load_data_004_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): """Load data for 004-2014 dataset.""" if (subject < 1) or (subject > 9): @@ -209,6 +225,8 @@ def _load_data_004_2014( ch_types = ["eeg"] * 3 + ["eog"] * 3 sessions = [] + time_task = 4.5 + time_fix = 3 filenames = [] for r in ["T", "E"]: url = "{u}004-2014/B{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) @@ -216,7 +234,10 @@ def _load_data_004_2014( filenames.append(filename) if only_filenames: continue - raws, _ = _convert_mi(filename, ch_names, ch_types) + if pseudoonline: + raws, _ = _convert_mi_pseudoonline(filename, time_task, time_fix, ch_names, ch_types, pseudoonline) + else: + raws, _ = _convert_mi(filename, ch_names, ch_types) sessions.extend(zip([r] * len(raws), raws)) if only_filenames: @@ -234,7 +255,12 @@ def _load_data_008_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): + + if pseudoonline: + raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") + """Load data for 008-2014 dataset.""" if (subject < 1) or (subject > 8): raise ValueError("Subject must be between 1 and 8. Got %d." 
% subject) @@ -260,7 +286,10 @@ def _load_data_009_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): + if pseudoonline: + raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") """Load data for 009-2014 dataset.""" if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 10. Got %d." % subject) @@ -299,6 +328,7 @@ def _load_data_001_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): """Load data for 001-2015 dataset.""" if (subject < 1) or (subject > 12): @@ -318,6 +348,8 @@ def _load_data_001_2015( ch_types = ["eeg"] * 13 sessions = {} + time_task = 5 + time_fix = 0 filenames = [] for session_idx, r in ses: url = "{u}001-2015/S{s:02d}{r}.mat".format(u=base_url, s=subject, r=r) @@ -325,7 +357,10 @@ def _load_data_001_2015( filenames += filename if only_filenames: continue - runs, ev = _convert_mi(filename[0], ch_names, ch_types) + if pseudoonline: + runs, ev = _convert_mi_pseudoonline(filename[0], time_task, time_fix, ch_names, ch_types, pseudoonline) + else: + runs, ev = _convert_mi(filename[0], ch_names, ch_types) sessions[f"{session_idx}{r}"] = {str(ii): run for ii, run in enumerate(runs)} if only_filenames: return filenames @@ -341,7 +376,10 @@ def _load_data_003_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): + if pseudoonline: + raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") """Load data for 003-2015 dataset.""" if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -400,7 +438,10 @@ def _load_data_004_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): + if pseudoonline: + raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") """Load data for 004-2015 dataset.""" if (subject < 1) or (subject > 9): raise ValueError("Subject must be between 1 and 9. Got %d." % subject) @@ -434,7 +475,10 @@ def _load_data_009_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): + if pseudoonline: + raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") """Load data for 009-2015 dataset.""" if (subject < 1) or (subject > 21): raise ValueError("Subject must be between 1 and 21. Got %d." % subject) @@ -465,7 +509,10 @@ def _load_data_010_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): + if pseudoonline: + raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") """Load data for 010-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -497,7 +544,10 @@ def _load_data_012_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): + if pseudoonline: + raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") """Load data for 012-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -524,7 +574,10 @@ def _load_data_013_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, + pseudoonline=False ): + if pseudoonline: + raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") """Load data for 013-2015 dataset.""" if (subject < 1) or (subject > 6): raise ValueError("Subject must be between 1 and 6. Got %d." 
% subject) @@ -584,6 +637,45 @@ def _convert_mi(filename, ch_names, ch_types): return runs, event_id +def _convert_mi_pseudoonline(filename, time_task, time_fix, ch_names, ch_types, pseudoonline): + """Process (Graz) motor imagery data from MAT files. + + Parameters + ---------- + filename : str + Path to the MAT file. + time_task: float + Actual duration of the task + time_fix: + Duration of Fixation Cross + ch_names : list of str + List of channel names. + ch_types : list of str + List of channel types. + + Returns + ------- + raw : instance of RawArray + returns list of recording runs.""" + runs = [] + event_id = {} + data = loadmat(filename, struct_as_record=False, squeeze_me=True) + + if isinstance(data["data"], np.ndarray): + run_array = data["data"] + else: + run_array = [data["data"]] + + for run in run_array: + raw, evd = _convert_run_pseudoonline(run, time_task, time_fix, ch_names, ch_types, None, pseudoonline) + if raw is None: + continue + runs.append(raw) + event_id.update(evd) + # change labels to match rest + standardize_keys(event_id) + return runs, event_id + def standardize_keys(d): master_list = [ ["both feet", "feet"], @@ -634,6 +726,64 @@ def _convert_run(run, ch_names=None, ch_types=None, verbose=None): return raw, event_id +def _convert_run_pseudoonline(run, time_task, time_fix, ch_names=None, ch_types=None, verbose=None, pseudoonline=False): + """Convert one run to raw.""" + # parse eeg data + event_id = {} + n_chan = run.X.shape[1] + montage = make_standard_montage("standard_1005") + eeg_data = 1e-6 * run.X + sfreq = run.fs + + if not ch_names: + ch_names = ["EEG%d" % ch for ch in range(1, n_chan + 1)] + montage = None # no montage + + if not ch_types: + ch_types = ["eeg"] * n_chan + + trigger = np.zeros((len(eeg_data), 1)) + # some runs does not contains trials i.e baseline runs + if len(run.trial) > 0: + trigger[run.trial - 1, 0] = run.y + else: + return None, None + + eeg_data = np.c_[eeg_data, trigger] + ch_names = ch_names + ["stim"] + ch_types = ch_types + ["stim"] + event_id = {ev: (ii + 1) for ii, ev in enumerate(run.classes)} + info = create_info(ch_names=ch_names, ch_types=ch_types, sfreq=sfreq) + raw = RawArray(data=eeg_data.T, info=info, verbose=verbose) + raw.set_montage(montage) + + if pseudoonline: + # ================================================================================================================= + # Code to add the event Nothing with label 9 + # ================================================================================================================= + # The idea is to replace the old stim channel with a new STIM channel that locate the events at the exact time that + # start and the event also for the nothing phase. + events = find_events(raw, stim_channel="stim") + stim_data = np.zeros((1, len(raw.times))) + + # Time when the task finish + time_nothing = (sfreq * time_task) + 1 + # Time where the task actually begin, because the events of "stim" give us when the fix cross appear, but not when + # the task begin. 
+ time_fixation_cross = (sfreq * time_fix) + beta_rebound = 0.5 * sfreq + for i in np.arange(len(events[:, 0])): + stim_data[0, int(events[i, 0] + time_fixation_cross)] = events[i, 2] + stim_data[0, int(events[i, 0] + time_fixation_cross + time_nothing)] = 9 + + info = create_info(ch_names=['STI'], ch_types=['stim'], sfreq=sfreq) + new_stim = RawArray(data=stim_data, info=info, verbose=verbose) + raw.add_channels([new_stim], force_update_info=True) + raw.drop_channels(['stim']) # Delete old stim channel + event_id["nothing"] = 9 + + return raw, event_id + @verbose def _convert_run_p300_sl(run, verbose=None): """Convert one p300 run from santa lucia file format.""" @@ -735,7 +885,7 @@ class MNEBNCI(BaseDataset): def _get_single_subject_data(self, subject): """Return data for a single subject.""" - sessions = load_data(subject=subject, dataset=self.code, verbose=False) + sessions = load_data(subject=subject, dataset=self.code, verbose=False, pseudoonline=self.pseudoonline) return sessions def data_path( @@ -749,6 +899,7 @@ def data_path( path=path, force_update=force_update, only_filenames=True, + pseudoonline=self.pseudoonline ) @@ -799,7 +950,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=2, - events={"left_hand": 1, "right_hand": 2, "feet": 3, "tongue": 4}, + events={"left_hand": 1, "right_hand": 2, "feet": 3, "tongue": 4, "nothing": 9}, code="BNCI2014-001", interval=[2, 6], paradigm="imagery", @@ -852,7 +1003,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 15)), sessions_per_subject=1, - events={"right_hand": 1, "feet": 2}, + events={"right_hand": 1, "feet": 2, "nothing": 9}, code="BNCI2014-002", interval=[3, 8], paradigm="imagery", @@ -926,7 +1077,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=5, - events={"left_hand": 1, "right_hand": 2}, + events={"left_hand": 1, "right_hand": 2, "nothing": 9}, code="BNCI2014-004", interval=[3, 7.5], paradigm="imagery", @@ -1088,7 +1239,7 @@ def __init__(self): super().__init__( subjects=list(range(1, 13)), sessions_per_subject=2, - events={"right_hand": 1, "feet": 2}, + events={"right_hand": 1, "feet": 2, "nothing": 9}, code="BNCI2015-001", interval=[0, 5], paradigm="imagery", diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py index 2bf30a0bf..fbd97b69c 100644 --- a/moabb/datasets/preprocessing.py +++ b/moabb/datasets/preprocessing.py @@ -31,6 +31,55 @@ def _unsafe_pick_events(events, include): raise e +def _events_pseudoonline(events, tmin, tmax, sfreq, overlap): + """ + This function create new events every duration length. + :param events: Real event created during registrations of the dataset + :param tmin: Minimum time where create new events(tmin MUST be 0). Is the starting time of epoch, and we consider as starting time + the initial value of the interval in normal MOABB [2, 6] + :param tmax: Maximum time of the windows. Is the final time of epoch. + :param sfreq: Sfreq of the recorded signal + :param overlap: Percentage of overlapping that we want in the sliding windows + :return: + return the new events, ove every starting point of the sliding windows and with univocal label + """ + # Compute duration of the windows in seconds + duration_s = tmax-tmin + # Convert the duration in time point. 
+ duration = duration_s * sfreq + # The starting point of the new windows in time point + ove = (((tmax - tmin) / 100) * (100 - overlap)) * sfreq + + # Total number of new events that need to be created + total = int((events[-1, 0] - events[0, 0]) / (100 - overlap)) + events_new = np.zeros((total, 3), dtype=int) + # Fill the first event with the same old events + events_new[0, :] = events[0, :] + + j = 0 + i = 1 + # Go on while we are at a time sample less than the last events in the data acquisition + while events_new[i - 1, 0] + duration <= events[-1, -0]: + # Assign the time stamp to the new events, so we add ove + events_new[i, 0] = events_new[i - 1, 0] + ove + # Now we have to check. If the new added events plus the duration is less then the time stamp of the new event + # we assign an univocal label. If is not we check the percentage of time stamp associate with a label is predominant in a windows. + # If we have 50/50 we assign the label as the next event since the subject want to switch in that direction. + if events_new[i, 0] + duration <= events[j + 1, 0]: + events_new[i, 2] = events[j, 2] + else: + First = abs(events[j + 1, 0] - events_new[i, 0]) + Second = abs((events_new[i, 0] + duration) - events[j + 1, 0]) + if First > Second: + events_new[i, 2] = events[j, 2] + else: + events_new[i, 2] = events[j + 1, 2] + j = j + 1 + i = i + 1 + + return events_new + + class ForkPipelines(TransformerMixin, BaseEstimator): def __init__(self, transformers: List[Tuple[str, Union[Pipeline, TransformerMixin]]]): for _, t in transformers: @@ -55,13 +104,16 @@ class SetRawAnnotations(FixedTransformer): Always sets the annotations, even if the events list is empty """ - def __init__(self, event_id, interval: Tuple[float, float]): + def __init__(self, event_id, interval: Tuple[float, float], tmin, tmax, overlap): assert isinstance(event_id, dict) # not None self.event_id = event_id if len(set(event_id.values())) != len(event_id): raise ValueError("Duplicate event code") self.event_desc = dict((code, desc) for desc, code in self.event_id.items()) self.interval = interval + self.tmin = tmin + self.tmax = tmax + self.overlap = overlap def transform(self, raw, y=None): duration = self.interval[1] - self.interval[0] @@ -74,9 +126,15 @@ def transform(self, raw, y=None): "No stim channel nor annotations found, skipping setting annotations." 
) return raw - events = mne.find_events(raw, shortest_event=0, verbose=False) - events = _unsafe_pick_events(events, include=list(self.event_id.values())) - events[:, 0] += offset + if self.overlap == None: + events = mne.find_events(raw, shortest_event=0, verbose=False) + events = _unsafe_pick_events(events, include=list(self.event_id.values())) + events[:, 0] += offset + else: + events_ = mne.find_events(raw, shortest_event=0, verbose=False) + events = _events_pseudoonline(events_, tmin=self.tmin, tmax=self.tmax, sfreq=raw.info["sfreq"], overlap=self.overlap) + duration = self.tmax - self.tmin + if len(events) != 0: annotations = mne.annotations_from_events( events, @@ -87,6 +145,8 @@ def transform(self, raw, y=None): ) annotations.set_durations(duration) raw.set_annotations(annotations) + # raw.plot() + # print("OK") else: log.warning("No events found, skipping setting annotations.") return raw @@ -97,16 +157,24 @@ class RawToEvents(FixedTransformer): Always returns an array for shape (n_events, 3), even if no events found """ - def __init__(self, event_id: dict[str, int], interval: Tuple[float, float]): + def __init__(self, event_id: dict[str, int], interval: Tuple[float, float], tmin, tmax, overlap): assert isinstance(event_id, dict) # not None self.event_id = event_id self.interval = interval + self.tmin = tmin + self.tmax = tmax + self.overlap = overlap def _find_events(self, raw): stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) if len(stim_channels) > 0: # returns empty array if none found - events = mne.find_events(raw, shortest_event=0, verbose=False) + if self.overlap == None: + events = mne.find_events(raw, shortest_event=0, verbose=False) + else: + events_ = mne.find_events(raw, shortest_event=0, verbose=False) + events = _events_pseudoonline(events_, tmin=self.tmin, tmax=self.tmax, sfreq=raw.info["sfreq"], + overlap=self.overlap) else: try: events, _ = mne.events_from_annotations( diff --git a/moabb/evaluations/utils.py b/moabb/evaluations/utils.py index 4a28b8d48..c1f8cd1ce 100644 --- a/moabb/evaluations/utils.py +++ b/moabb/evaluations/utils.py @@ -6,6 +6,7 @@ from numpy import argmax from sklearn.pipeline import Pipeline +from sklearn.metrics import matthews_corrcoef try: @@ -36,6 +37,10 @@ def _check_if_is_keras_model(model): except ImportError: return False +def _normalized_mcc(y_true, y_pred): + mcc = matthews_corrcoef(y_true, y_pred) + return (mcc + 1) / 2 + def _check_if_is_pytorch_model(model): """Check if the model is a Keras model. 
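As a side note on the _normalized_mcc helper added just above: it rescales the Matthews correlation coefficient from [-1, 1] to [0, 1], so chance-level predictions land near 0.5 and the score stays readable next to accuracy, which matters once the sliding windows make the added "nothing" class dominate. A minimal standalone sketch of the same idea (the label arrays below are invented for illustration):

import numpy as np
from sklearn.metrics import make_scorer, matthews_corrcoef

def normalized_mcc(y_true, y_pred):
    # MCC lies in [-1, 1]; rescaling to [0, 1] makes it comparable to accuracy-style scores
    return (matthews_corrcoef(y_true, y_pred) + 1) / 2

# hypothetical window labels: most windows belong to the added "nothing" class (code 9)
y_true = np.array([9, 9, 9, 1, 9, 2, 9, 9])
y_pred = np.array([9, 9, 1, 1, 9, 2, 9, 9])
print(normalized_mcc(y_true, y_pred))   # a single error on an imbalanced set is still penalised

scorer = make_scorer(normalized_mcc)    # wrapped the same way the paradigm's scoring property does
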
diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index 527d51c20..955047a23 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -68,6 +68,7 @@ def __init__( baseline: Optional[Tuple[float, float]] = None, channels: Optional[List[str]] = None, resample: Optional[float] = None, + overlap: Optional[float] = None, ): if tmax is not None: if tmin >= tmax: @@ -79,6 +80,7 @@ def __init__( self.tmin = tmin self.tmax = tmax self.interpolate_missing_channels = False + self.overlap = overlap @property @abc.abstractmethod @@ -169,6 +171,9 @@ def make_process_pipelines( SetRawAnnotations( dataset.event_id, interval=dataset.interval, + tmin=self.tmin, + tmax=self.tmax, + overlap=self.overlap ), ) ) @@ -513,6 +518,7 @@ def __init__( baseline=None, channels=None, resample=None, + overlap=None ): super().__init__( filters=filters, @@ -521,6 +527,7 @@ def __init__( resample=resample, tmin=tmin, tmax=tmax, + overlap=overlap ) self.events = events @@ -536,4 +543,4 @@ def scoring(self): def _get_events_pipeline(self, dataset): event_id = self.used_events(dataset) - return RawToEvents(event_id=event_id, interval=dataset.interval) + return RawToEvents(event_id=event_id, interval=dataset.interval, tmin=self.tmin, tmax=self.tmax, overlap=self.overlap) diff --git a/moabb/paradigms/motor_imagery.py b/moabb/paradigms/motor_imagery.py index 657a8e814..d6be07114 100644 --- a/moabb/paradigms/motor_imagery.py +++ b/moabb/paradigms/motor_imagery.py @@ -6,7 +6,8 @@ from moabb.datasets import utils from moabb.datasets.fake import FakeDataset from moabb.paradigms.base import BaseParadigm - +from sklearn.metrics import make_scorer +from moabb.evaluations.utils import _normalized_mcc log = logging.getLogger(__name__) @@ -51,6 +52,8 @@ class BaseMotorImagery(BaseParadigm): resample: float | None (default None) If not None, resample the eeg data with the sampling rate provided. 
+ + overlap: Overlap (in percentage) of the sliding windows approach for the pseudoonline evaluation """ def __init__( @@ -62,7 +65,13 @@ def __init__( baseline=None, channels=None, resample=None, + overlap=None ): + + if overlap is not None: + print("Overlap available only for pseudo online evaluation") + tmin = 0.0 + super().__init__( filters=filters, events=events, @@ -71,6 +80,7 @@ def __init__( resample=resample, tmin=tmin, tmax=tmax, + overlap=overlap ) def is_valid(self, dataset): @@ -102,7 +112,10 @@ def datasets(self): @property def scoring(self): - return "accuracy" + if self.overlap == None: + return "accuracy" + else: + return make_scorer(_normalized_mcc) class SinglePass(BaseMotorImagery): @@ -401,7 +414,10 @@ def scoring(self): if self.n_classes == 2: return "roc_auc" else: - return "accuracy" + if self.overlap == None: + return "accuracy" + else: + return make_scorer(_normalized_mcc) class FakeImageryParadigm(LeftRightImagery): From a67b9299241382fca5c0dba54e0be492c1dc9a1e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:04:23 +0000 Subject: [PATCH 02/17] [pre-commit.ci] auto fixes from pre-commit.com hooks --- examples/plot_pseudoonline.py | 59 ++++++++------- moabb/datasets/base.py | 2 +- moabb/datasets/bnci.py | 119 ++++++++++++++++++++++--------- moabb/datasets/preprocessing.py | 23 ++++-- moabb/evaluations/utils.py | 3 +- moabb/paradigms/base.py | 14 ++-- moabb/paradigms/motor_imagery.py | 10 +-- 7 files changed, 149 insertions(+), 81 deletions(-) diff --git a/examples/plot_pseudoonline.py b/examples/plot_pseudoonline.py index f636672ca..b7526f5f5 100644 --- a/examples/plot_pseudoonline.py +++ b/examples/plot_pseudoonline.py @@ -1,17 +1,14 @@ # Set up the Directory for made it run on a server. 
-import sys -import os -import moabb -import mne -import resource -from moabb.paradigms import MotorImagery -from moabb.datasets import BNCI2014_001 -from pyriemann.classification import FgMDM + +import numpy as np +from pyriemann.classification import MDM, FgMDM +from pyriemann.estimation import Covariances from sklearn.pipeline import Pipeline + +from moabb.datasets import BNCI2014_001 from moabb.evaluations import WithinSessionEvaluation -from pyriemann.classification import MDM -from pyriemann.estimation import Covariances -import numpy as np +from moabb.paradigms import MotorImagery + sub = 1 @@ -25,10 +22,12 @@ dataset = BNCI2014_001() dataset.pseudoonline = True -#events = ["right_hand", "left_hand"] +# events = ["right_hand", "left_hand"] events = list(dataset.event_id.keys()) -paradigm = MotorImagery(events=events, n_classes=len(events), fmin=fmin, fmax=fmax, tmax=tmax, overlap=50) +paradigm = MotorImagery( + events=events, n_classes=len(events), fmin=fmin, fmax=fmax, tmax=tmax, overlap=50 +) X, y, meta = paradigm.get_data(dataset=dataset, subjects=[sub]) print("Print Events_id:", y) @@ -37,27 +36,25 @@ pipelines = {} -pipelines["MDM"] = Pipeline(steps=[ - ("Covariances", Covariances("cov")), - ("MDM", MDM(metric=dict(mean='riemann', distance='riemann'))) -]) +pipelines["MDM"] = Pipeline( + steps=[ + ("Covariances", Covariances("cov")), + ("MDM", MDM(metric=dict(mean="riemann", distance="riemann"))), + ] +) -pipelines["FgMDM"] = Pipeline(steps=[ - ("Covariances", Covariances("cov")), - ("FgMDM", FgMDM()) -]) +pipelines["FgMDM"] = Pipeline( + steps=[("Covariances", Covariances("cov")), ("FgMDM", FgMDM())] +) # Select an evaluation Within Session -evaluation_online = WithinSessionEvaluation(paradigm=paradigm, - datasets=dataset, - overwrite=True, - random_state=42, - n_jobs=-1 - ) +evaluation_online = WithinSessionEvaluation( + paradigm=paradigm, datasets=dataset, overwrite=True, random_state=42, n_jobs=-1 +) # Print the results results_ALL = evaluation_online.process(pipelines) -results_pipeline = results_ALL.groupby(['pipeline'], as_index=False)["score"].mean() -results_pipeline_std = results_ALL.groupby(['pipeline'], as_index=False)["score"].std() -results_pipeline['std'] = results_pipeline_std["score"] -print(results_pipeline) \ No newline at end of file +results_pipeline = results_ALL.groupby(["pipeline"], as_index=False)["score"].mean() +results_pipeline_std = results_ALL.groupby(["pipeline"], as_index=False)["score"].std() +results_pipeline["std"] = results_pipeline_std["score"] +print(results_pipeline) diff --git a/moabb/datasets/base.py b/moabb/datasets/base.py index 9a9aa3c4e..cd558a998 100644 --- a/moabb/datasets/base.py +++ b/moabb/datasets/base.py @@ -319,7 +319,7 @@ def __init__( paradigm, doi=None, unit_factor=1e6, - pseudoonline=False + pseudoonline=False, ): """Initialize function for the BaseDataset.""" try: diff --git a/moabb/datasets/bnci.py b/moabb/datasets/bnci.py index 2a058d4d1..4cfd795c1 100644 --- a/moabb/datasets/bnci.py +++ b/moabb/datasets/bnci.py @@ -33,7 +33,7 @@ def load_data( base_url=BNCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): # noqa: D301 """Get paths to local copies of a BNCI dataset files. 
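For orientation while reading these hunks: the pseudoonline keyword added in the previous patch is simply threaded from the dataset object through load_data into each _load_data_* helper, where it selects the conversion routine. The recurring pattern, condensed from the diff (time_task and time_fix are the per-dataset constants introduced in patch 01):

# on the dataset side (MNEBNCI._get_single_subject_data forwards the flag)
dataset = BNCI2014_001()
dataset.pseudoonline = True
sessions = load_data(subject=1, dataset=dataset.code, verbose=False,
                     pseudoonline=dataset.pseudoonline)

# inside each _load_data_* helper
if pseudoonline:
    runs, ev = _convert_mi_pseudoonline(filename[0], time_task, time_fix,
                                        ch_names, ch_types, pseudoonline)
else:
    runs, ev = _convert_mi(filename[0], ch_names, ch_types)
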
@@ -117,7 +117,7 @@ def load_data( baseurl_list[dataset], only_filenames, verbose, - pseudoonline + pseudoonline, ) @@ -130,7 +130,7 @@ def _load_data_001_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): """Load data for 001-2014 dataset.""" if (subject < 1) or (subject > 9): @@ -157,7 +157,9 @@ def _load_data_001_2014( continue if pseudoonline: - runs, ev = _convert_mi_pseudoonline(filename[0], time_task, time_fix, ch_names, ch_types, pseudoonline) + runs, ev = _convert_mi_pseudoonline( + filename[0], time_task, time_fix, ch_names, ch_types, pseudoonline + ) else: runs, ev = _convert_mi(filename[0], ch_names, ch_types) # FIXME: deal with run with no event (1:3) and name them @@ -178,7 +180,7 @@ def _load_data_002_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): """Load data for 002-2014 dataset.""" if (subject < 1) or (subject > 14): @@ -196,7 +198,9 @@ def _load_data_002_2014( continue # FIXME: electrode position and name are not provided directly. if pseudoonline: - raws, _ = _convert_mi_pseudoonline(filename, time_task, time_fix, None, ["eeg"] * 15, pseudoonline) + raws, _ = _convert_mi_pseudoonline( + filename, time_task, time_fix, None, ["eeg"] * 15, pseudoonline + ) else: raws, _ = _convert_mi(filename, None, ["eeg"] * 15) runs.extend(zip([r] * len(raws), raws)) @@ -215,7 +219,7 @@ def _load_data_004_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): """Load data for 004-2014 dataset.""" if (subject < 1) or (subject > 9): @@ -235,7 +239,9 @@ def _load_data_004_2014( if only_filenames: continue if pseudoonline: - raws, _ = _convert_mi_pseudoonline(filename, time_task, time_fix, ch_names, ch_types, pseudoonline) + raws, _ = _convert_mi_pseudoonline( + filename, time_task, time_fix, ch_names, ch_types, pseudoonline + ) else: raws, _ = _convert_mi(filename, ch_names, ch_types) sessions.extend(zip([r] * len(raws), raws)) @@ -255,11 +261,13 @@ def _load_data_008_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): if pseudoonline: - raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 008-2014 dataset.""" if (subject < 1) or (subject > 8): @@ -286,10 +294,12 @@ def _load_data_009_2014( base_url=BNCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): if pseudoonline: - raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 009-2014 dataset.""" if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 10. Got %d." 
% subject) @@ -328,7 +338,7 @@ def _load_data_001_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): """Load data for 001-2015 dataset.""" if (subject < 1) or (subject > 12): @@ -358,7 +368,9 @@ def _load_data_001_2015( if only_filenames: continue if pseudoonline: - runs, ev = _convert_mi_pseudoonline(filename[0], time_task, time_fix, ch_names, ch_types, pseudoonline) + runs, ev = _convert_mi_pseudoonline( + filename[0], time_task, time_fix, ch_names, ch_types, pseudoonline + ) else: runs, ev = _convert_mi(filename[0], ch_names, ch_types) sessions[f"{session_idx}{r}"] = {str(ii): run for ii, run in enumerate(runs)} @@ -376,10 +388,12 @@ def _load_data_003_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): if pseudoonline: - raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 003-2015 dataset.""" if (subject < 1) or (subject > 10): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -438,10 +452,12 @@ def _load_data_004_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): if pseudoonline: - raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 004-2015 dataset.""" if (subject < 1) or (subject > 9): raise ValueError("Subject must be between 1 and 9. Got %d." % subject) @@ -475,10 +491,12 @@ def _load_data_009_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): if pseudoonline: - raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 009-2015 dataset.""" if (subject < 1) or (subject > 21): raise ValueError("Subject must be between 1 and 21. Got %d." % subject) @@ -509,10 +527,12 @@ def _load_data_010_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): if pseudoonline: - raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 010-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." % subject) @@ -544,10 +564,12 @@ def _load_data_012_2015( base_url=BBCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): if pseudoonline: - raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 012-2015 dataset.""" if (subject < 1) or (subject > 12): raise ValueError("Subject must be between 1 and 12. Got %d." 
% subject) @@ -574,10 +596,12 @@ def _load_data_013_2015( base_url=BNCI_URL, only_filenames=False, verbose=None, - pseudoonline=False + pseudoonline=False, ): if pseudoonline: - raise ValueError("Pseudo Online evaluation not currently implemented for this dataset") + raise ValueError( + "Pseudo Online evaluation not currently implemented for this dataset" + ) """Load data for 013-2015 dataset.""" if (subject < 1) or (subject > 6): raise ValueError("Subject must be between 1 and 6. Got %d." % subject) @@ -637,7 +661,9 @@ def _convert_mi(filename, ch_names, ch_types): return runs, event_id -def _convert_mi_pseudoonline(filename, time_task, time_fix, ch_names, ch_types, pseudoonline): +def _convert_mi_pseudoonline( + filename, time_task, time_fix, ch_names, ch_types, pseudoonline +): """Process (Graz) motor imagery data from MAT files. Parameters @@ -667,7 +693,9 @@ def _convert_mi_pseudoonline(filename, time_task, time_fix, ch_names, ch_types, run_array = [data["data"]] for run in run_array: - raw, evd = _convert_run_pseudoonline(run, time_task, time_fix, ch_names, ch_types, None, pseudoonline) + raw, evd = _convert_run_pseudoonline( + run, time_task, time_fix, ch_names, ch_types, None, pseudoonline + ) if raw is None: continue runs.append(raw) @@ -676,6 +704,7 @@ def _convert_mi_pseudoonline(filename, time_task, time_fix, ch_names, ch_types, standardize_keys(event_id) return runs, event_id + def standardize_keys(d): master_list = [ ["both feet", "feet"], @@ -726,7 +755,15 @@ def _convert_run(run, ch_names=None, ch_types=None, verbose=None): return raw, event_id -def _convert_run_pseudoonline(run, time_task, time_fix, ch_names=None, ch_types=None, verbose=None, pseudoonline=False): +def _convert_run_pseudoonline( + run, + time_task, + time_fix, + ch_names=None, + ch_types=None, + verbose=None, + pseudoonline=False, +): """Convert one run to raw.""" # parse eeg data event_id = {} @@ -770,20 +807,21 @@ def _convert_run_pseudoonline(run, time_task, time_fix, ch_names=None, ch_types= time_nothing = (sfreq * time_task) + 1 # Time where the task actually begin, because the events of "stim" give us when the fix cross appear, but not when # the task begin. 
- time_fixation_cross = (sfreq * time_fix) + time_fixation_cross = sfreq * time_fix beta_rebound = 0.5 * sfreq for i in np.arange(len(events[:, 0])): stim_data[0, int(events[i, 0] + time_fixation_cross)] = events[i, 2] stim_data[0, int(events[i, 0] + time_fixation_cross + time_nothing)] = 9 - info = create_info(ch_names=['STI'], ch_types=['stim'], sfreq=sfreq) + info = create_info(ch_names=["STI"], ch_types=["stim"], sfreq=sfreq) new_stim = RawArray(data=stim_data, info=info, verbose=verbose) raw.add_channels([new_stim], force_update_info=True) - raw.drop_channels(['stim']) # Delete old stim channel + raw.drop_channels(["stim"]) # Delete old stim channel event_id["nothing"] = 9 return raw, event_id + @verbose def _convert_run_p300_sl(run, verbose=None): """Convert one p300 run from santa lucia file format.""" @@ -885,7 +923,12 @@ class MNEBNCI(BaseDataset): def _get_single_subject_data(self, subject): """Return data for a single subject.""" - sessions = load_data(subject=subject, dataset=self.code, verbose=False, pseudoonline=self.pseudoonline) + sessions = load_data( + subject=subject, + dataset=self.code, + verbose=False, + pseudoonline=self.pseudoonline, + ) return sessions def data_path( @@ -899,7 +942,7 @@ def data_path( path=path, force_update=force_update, only_filenames=True, - pseudoonline=self.pseudoonline + pseudoonline=self.pseudoonline, ) @@ -950,7 +993,13 @@ def __init__(self): super().__init__( subjects=list(range(1, 10)), sessions_per_subject=2, - events={"left_hand": 1, "right_hand": 2, "feet": 3, "tongue": 4, "nothing": 9}, + events={ + "left_hand": 1, + "right_hand": 2, + "feet": 3, + "tongue": 4, + "nothing": 9, + }, code="BNCI2014-001", interval=[2, 6], paradigm="imagery", diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py index fbd97b69c..8bbf8be00 100644 --- a/moabb/datasets/preprocessing.py +++ b/moabb/datasets/preprocessing.py @@ -44,7 +44,7 @@ def _events_pseudoonline(events, tmin, tmax, sfreq, overlap): return the new events, ove every starting point of the sliding windows and with univocal label """ # Compute duration of the windows in seconds - duration_s = tmax-tmin + duration_s = tmax - tmin # Convert the duration in time point. 
duration = duration_s * sfreq # The starting point of the new windows in time point @@ -132,7 +132,13 @@ def transform(self, raw, y=None): events[:, 0] += offset else: events_ = mne.find_events(raw, shortest_event=0, verbose=False) - events = _events_pseudoonline(events_, tmin=self.tmin, tmax=self.tmax, sfreq=raw.info["sfreq"], overlap=self.overlap) + events = _events_pseudoonline( + events_, + tmin=self.tmin, + tmax=self.tmax, + sfreq=raw.info["sfreq"], + overlap=self.overlap, + ) duration = self.tmax - self.tmin if len(events) != 0: @@ -157,7 +163,9 @@ class RawToEvents(FixedTransformer): Always returns an array for shape (n_events, 3), even if no events found """ - def __init__(self, event_id: dict[str, int], interval: Tuple[float, float], tmin, tmax, overlap): + def __init__( + self, event_id: dict[str, int], interval: Tuple[float, float], tmin, tmax, overlap + ): assert isinstance(event_id, dict) # not None self.event_id = event_id self.interval = interval @@ -173,8 +181,13 @@ def _find_events(self, raw): events = mne.find_events(raw, shortest_event=0, verbose=False) else: events_ = mne.find_events(raw, shortest_event=0, verbose=False) - events = _events_pseudoonline(events_, tmin=self.tmin, tmax=self.tmax, sfreq=raw.info["sfreq"], - overlap=self.overlap) + events = _events_pseudoonline( + events_, + tmin=self.tmin, + tmax=self.tmax, + sfreq=raw.info["sfreq"], + overlap=self.overlap, + ) else: try: events, _ = mne.events_from_annotations( diff --git a/moabb/evaluations/utils.py b/moabb/evaluations/utils.py index c1f8cd1ce..96d159f97 100644 --- a/moabb/evaluations/utils.py +++ b/moabb/evaluations/utils.py @@ -5,8 +5,8 @@ from typing import Sequence from numpy import argmax -from sklearn.pipeline import Pipeline from sklearn.metrics import matthews_corrcoef +from sklearn.pipeline import Pipeline try: @@ -37,6 +37,7 @@ def _check_if_is_keras_model(model): except ImportError: return False + def _normalized_mcc(y_true, y_pred): mcc = matthews_corrcoef(y_true, y_pred) return (mcc + 1) / 2 diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index 955047a23..68a676ac7 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -173,7 +173,7 @@ def make_process_pipelines( interval=dataset.interval, tmin=self.tmin, tmax=self.tmax, - overlap=self.overlap + overlap=self.overlap, ), ) ) @@ -518,7 +518,7 @@ def __init__( baseline=None, channels=None, resample=None, - overlap=None + overlap=None, ): super().__init__( filters=filters, @@ -527,7 +527,7 @@ def __init__( resample=resample, tmin=tmin, tmax=tmax, - overlap=overlap + overlap=overlap, ) self.events = events @@ -543,4 +543,10 @@ def scoring(self): def _get_events_pipeline(self, dataset): event_id = self.used_events(dataset) - return RawToEvents(event_id=event_id, interval=dataset.interval, tmin=self.tmin, tmax=self.tmax, overlap=self.overlap) + return RawToEvents( + event_id=event_id, + interval=dataset.interval, + tmin=self.tmin, + tmax=self.tmax, + overlap=self.overlap, + ) diff --git a/moabb/paradigms/motor_imagery.py b/moabb/paradigms/motor_imagery.py index d6be07114..8a9597ba3 100644 --- a/moabb/paradigms/motor_imagery.py +++ b/moabb/paradigms/motor_imagery.py @@ -3,11 +3,13 @@ import abc import logging +from sklearn.metrics import make_scorer + from moabb.datasets import utils from moabb.datasets.fake import FakeDataset -from moabb.paradigms.base import BaseParadigm -from sklearn.metrics import make_scorer from moabb.evaluations.utils import _normalized_mcc +from moabb.paradigms.base import BaseParadigm + log = 
logging.getLogger(__name__) @@ -65,7 +67,7 @@ def __init__( baseline=None, channels=None, resample=None, - overlap=None + overlap=None, ): if overlap is not None: @@ -80,7 +82,7 @@ def __init__( resample=resample, tmin=tmin, tmax=tmax, - overlap=overlap + overlap=overlap, ) def is_valid(self, dataset): From 921f59265175142862e40e2ffdbd31972ee5a5af Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Mon, 29 Jul 2024 16:09:01 +0200 Subject: [PATCH 03/17] Updating whats new --- docs/source/whats_new.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/whats_new.rst b/docs/source/whats_new.rst index fe246d754..951b2ce78 100644 --- a/docs/source/whats_new.rst +++ b/docs/source/whats_new.rst @@ -22,7 +22,7 @@ Enhancements - Centralize dataset summary tables in CSV files (:gh:`635` by `Pierre Guetschel`_) - Add new dataset :class:`moabb.datasets.Liu2024` dataset (:gh:`619` by `Taha Habib`_) - Increasing the version in the pre-commit config (:gh:`631` by pre-commit bot) - +- Implementation of Pseudo Online framework (:gh:`641` by `Igor Carrara`_) Bugs From 0dcb68c50a0550a81e066cea223ee82081160754 Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Mon, 29 Jul 2024 16:41:08 +0200 Subject: [PATCH 04/17] Update example --- examples/plot_pseudoonline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/plot_pseudoonline.py b/examples/plot_pseudoonline.py index b7526f5f5..444189100 100644 --- a/examples/plot_pseudoonline.py +++ b/examples/plot_pseudoonline.py @@ -47,6 +47,7 @@ steps=[("Covariances", Covariances("cov")), ("FgMDM", FgMDM())] ) +dataset.subject_list = dataset.subject_list[int(sub) - 1:int(sub)] # Select an evaluation Within Session evaluation_online = WithinSessionEvaluation( paradigm=paradigm, datasets=dataset, overwrite=True, random_state=42, n_jobs=-1 From 8f53345e4d46cb4a9991f3c799efb70f0dd758e1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:41:55 +0000 Subject: [PATCH 05/17] [pre-commit.ci] auto fixes from pre-commit.com hooks --- examples/plot_pseudoonline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_pseudoonline.py b/examples/plot_pseudoonline.py index 444189100..05ceacc57 100644 --- a/examples/plot_pseudoonline.py +++ b/examples/plot_pseudoonline.py @@ -47,7 +47,7 @@ steps=[("Covariances", Covariances("cov")), ("FgMDM", FgMDM())] ) -dataset.subject_list = dataset.subject_list[int(sub) - 1:int(sub)] +dataset.subject_list = dataset.subject_list[int(sub) - 1 : int(sub)] # Select an evaluation Within Session evaluation_online = WithinSessionEvaluation( paradigm=paradigm, datasets=dataset, overwrite=True, random_state=42, n_jobs=-1 From adbeed78048a651654faa5c01e100d95b63c074a Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Mon, 29 Jul 2024 16:46:49 +0200 Subject: [PATCH 06/17] Fix error --- moabb/datasets/base.py | 5 +++-- moabb/datasets/preprocessing.py | 6 ++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/moabb/datasets/base.py b/moabb/datasets/base.py index cd558a998..5ca28d962 100644 --- a/moabb/datasets/base.py +++ b/moabb/datasets/base.py @@ -319,7 +319,7 @@ def __init__( paradigm, doi=None, unit_factor=1e6, - pseudoonline=False, + overlap=False, ): """Initialize function for the BaseDataset.""" try: @@ -349,7 +349,7 @@ def __init__( self.paradigm = paradigm self.doi = doi self.unit_factor = unit_factor - self.pseudoonline = pseudoonline + self.overlap = overlap def _create_process_pipeline(self): 
return Pipeline( @@ -359,6 +359,7 @@ def _create_process_pipeline(self): SetRawAnnotations( self.event_id, interval=self.interval, + overlap=self.overlap ), ), ] diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py index 8bbf8be00..d0ffe65fc 100644 --- a/moabb/datasets/preprocessing.py +++ b/moabb/datasets/preprocessing.py @@ -111,10 +111,12 @@ def __init__(self, event_id, interval: Tuple[float, float], tmin, tmax, overlap) raise ValueError("Duplicate event code") self.event_desc = dict((code, desc) for desc, code in self.event_id.items()) self.interval = interval - self.tmin = tmin - self.tmax = tmax self.overlap = overlap + if self.overlap: + self.tmin = tmin + self.tmax = tmax + def transform(self, raw, y=None): duration = self.interval[1] - self.interval[0] offset = int(self.interval[0] * raw.info["sfreq"]) From 3fb594965a1ce383f7d42e60864585878eab5123 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:47:59 +0000 Subject: [PATCH 07/17] [pre-commit.ci] auto fixes from pre-commit.com hooks --- moabb/datasets/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/moabb/datasets/base.py b/moabb/datasets/base.py index 5ca28d962..d5b6592fc 100644 --- a/moabb/datasets/base.py +++ b/moabb/datasets/base.py @@ -357,9 +357,7 @@ def _create_process_pipeline(self): ( StepType.RAW, SetRawAnnotations( - self.event_id, - interval=self.interval, - overlap=self.overlap + self.event_id, interval=self.interval, overlap=self.overlap ), ), ] From 406b36e322d8bf0d3c683ba72c07476cba73b1e3 Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Mon, 29 Jul 2024 16:58:05 +0200 Subject: [PATCH 08/17] Fix error --- moabb/datasets/preprocessing.py | 106 ++++++++++++++++++++++++++------ moabb/paradigms/base.py | 55 +++++++++++------ 2 files changed, 125 insertions(+), 36 deletions(-) diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py index d0ffe65fc..a35789e97 100644 --- a/moabb/datasets/preprocessing.py +++ b/moabb/datasets/preprocessing.py @@ -111,11 +111,6 @@ def __init__(self, event_id, interval: Tuple[float, float], tmin, tmax, overlap) raise ValueError("Duplicate event code") self.event_desc = dict((code, desc) for desc, code in self.event_id.items()) self.interval = interval - self.overlap = overlap - - if self.overlap: - self.tmin = tmin - self.tmax = tmax def transform(self, raw, y=None): duration = self.interval[1] - self.interval[0] @@ -128,20 +123,62 @@ def transform(self, raw, y=None): "No stim channel nor annotations found, skipping setting annotations." 
) return raw - if self.overlap == None: - events = mne.find_events(raw, shortest_event=0, verbose=False) - events = _unsafe_pick_events(events, include=list(self.event_id.values())) - events[:, 0] += offset + events = mne.find_events(raw, shortest_event=0, verbose=False) + events = _unsafe_pick_events(events, include=list(self.event_id.values())) + events[:, 0] += offset + + if len(events) != 0: + annotations = mne.annotations_from_events( + events, + raw.info["sfreq"], + self.event_desc, + first_samp=raw.first_samp, + verbose=False, + ) + annotations.set_durations(duration) + raw.set_annotations(annotations) + # raw.plot() + # print("OK") else: - events_ = mne.find_events(raw, shortest_event=0, verbose=False) - events = _events_pseudoonline( - events_, - tmin=self.tmin, - tmax=self.tmax, - sfreq=raw.info["sfreq"], - overlap=self.overlap, + log.warning("No events found, skipping setting annotations.") + return raw + +class SetRawAnnotations_PseudoOnline(FixedTransformer): + """ + Always sets the annotations, even if the events list is empty + """ + + def __init__(self, event_id, interval: Tuple[float, float], tmin, tmax, overlap): + assert isinstance(event_id, dict) # not None + self.event_id = event_id + if len(set(event_id.values())) != len(event_id): + raise ValueError("Duplicate event code") + self.event_desc = dict((code, desc) for desc, code in self.event_id.items()) + self.interval = interval + self.overlap = overlap + self.tmin = tmin + self.tmax = tmax + + def transform(self, raw, y=None): + duration = self.interval[1] - self.interval[0] + offset = int(self.interval[0] * raw.info["sfreq"]) + if raw.annotations: + return raw + stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) + if len(stim_channels) == 0: + log.warning( + "No stim channel nor annotations found, skipping setting annotations." 
) - duration = self.tmax - self.tmin + return raw + events_ = mne.find_events(raw, shortest_event=0, verbose=False) + events = _events_pseudoonline( + events_, + tmin=self.tmin, + tmax=self.tmax, + sfreq=raw.info["sfreq"], + overlap=self.overlap, + ) + duration = self.tmax - self.tmin if len(events) != 0: annotations = mne.annotations_from_events( @@ -165,6 +202,40 @@ class RawToEvents(FixedTransformer): Always returns an array for shape (n_events, 3), even if no events found """ + def __init__( + self, event_id: dict[str, int], interval: Tuple[float, float], tmin, tmax, overlap + ): + assert isinstance(event_id, dict) # not None + self.event_id = event_id + self.interval = interval + + def _find_events(self, raw): + stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) + if len(stim_channels) > 0: + # returns empty array if none found + events = mne.find_events(raw, shortest_event=0, verbose=False) + else: + try: + events, _ = mne.events_from_annotations( + raw, event_id=self.event_id, verbose=False + ) + offset = int(self.interval[0] * raw.info["sfreq"]) + events[:, 0] -= offset # return the original events onset + except ValueError as e: + if str(e) == "Could not find any of the events you specified.": + return np.zeros((0, 3), dtype="int32") + raise e + return events + + def transform(self, raw, y=None): + events = self._find_events(raw) + return _unsafe_pick_events(events, list(self.event_id.values())) + +class RawToEvents_PseudoOnline(FixedTransformer): + """ + Always returns an array for shape (n_events, 3), even if no events found + """ + def __init__( self, event_id: dict[str, int], interval: Tuple[float, float], tmin, tmax, overlap ): @@ -207,7 +278,6 @@ def transform(self, raw, y=None): events = self._find_events(raw) return _unsafe_pick_events(events, list(self.event_id.values())) - class RawToEventsP300(RawToEvents): def transform(self, raw, y=None): events = self._find_events(raw) diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index 68a676ac7..671fe673f 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -17,7 +17,9 @@ ForkPipelines, RawToEpochs, RawToEvents, + RawToEvents_PseudoOnline, SetRawAnnotations, + SetRawAnnotations_PseudoOnline, get_crop_pipeline, get_filter_pipeline, get_resample_pipeline, @@ -165,18 +167,29 @@ def make_process_pipelines( process_pipelines = [] for raw_pipeline in raw_pipelines: steps = [] - steps.append( - ( - StepType.RAW, - SetRawAnnotations( - dataset.event_id, - interval=dataset.interval, - tmin=self.tmin, - tmax=self.tmax, - overlap=self.overlap, - ), + if self.overlap: + steps.append( + ( + StepType.RAW, + SetRawAnnotations_PseudoOnline( + dataset.event_id, + interval=dataset.interval, + tmin=self.tmin, + tmax=self.tmax, + overlap=self.overlap, + ), + ) + ) + else: + steps.append( + ( + StepType.RAW, + SetRawAnnotations( + dataset.event_id, + interval=dataset.interval, + ), + ) ) - ) if raw_pipeline is not None: steps.append((StepType.RAW, raw_pipeline)) if epochs_pipeline is not None: @@ -543,10 +556,16 @@ def scoring(self): def _get_events_pipeline(self, dataset): event_id = self.used_events(dataset) - return RawToEvents( - event_id=event_id, - interval=dataset.interval, - tmin=self.tmin, - tmax=self.tmax, - overlap=self.overlap, - ) + if self.overlap: + return RawToEvents_PseudoOnline( + event_id=event_id, + interval=dataset.interval, + tmin=self.tmin, + tmax=self.tmax, + overlap=self.overlap, + ) + else: + return RawToEvents( + event_id=event_id, + interval=dataset.interval, + ) 
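To make the dispatch above concrete, here is a self-contained sketch of the sliding-window relabelling that the pseudo-online path performs on an events array: windows of length tmax - tmin are slid from the first to the last trial onset with a hop derived from the overlap percentage, and each window inherits a trial label. This is a simplified illustration (each window is labelled by the trial covering its centre), not the exact _events_pseudoonline implementation, which additionally arbitrates windows that straddle two trials:

import numpy as np

def sliding_window_events(events, tmin, tmax, sfreq, overlap):
    # events  : (n_trials, 3) array as returned by mne.find_events
    # overlap : percentage of overlap between consecutive windows (0-100)
    win = int(round((tmax - tmin) * sfreq))         # window length in samples
    step = int(round(win * (100 - overlap) / 100))  # hop size in samples
    onsets, labels = events[:, 0], events[:, 2]
    new_events = []
    for start in range(int(onsets[0]), int(onsets[-1]) - win + 1, step):
        centre = start + win // 2
        idx = np.searchsorted(onsets, centre, side="right") - 1  # trial covering the window centre
        new_events.append([start, 0, labels[idx]])
    return np.array(new_events, dtype=int)

# hypothetical trials at 0 s, 8 s and 16 s (labels 1, 2 and the "nothing" code 9) at 250 Hz,
# cut into 2 s windows with 50 % overlap
trials = np.array([[0, 0, 1], [2000, 0, 2], [4000, 0, 9]])
print(sliding_window_events(trials, tmin=0, tmax=2, sfreq=250, overlap=50))
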
From 7a60fb156645775cc8f30bcc9024a13fd6f05ddc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:58:27 +0000 Subject: [PATCH 09/17] [pre-commit.ci] auto fixes from pre-commit.com hooks --- moabb/datasets/preprocessing.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py index a35789e97..d1df9a229 100644 --- a/moabb/datasets/preprocessing.py +++ b/moabb/datasets/preprocessing.py @@ -143,6 +143,7 @@ def transform(self, raw, y=None): log.warning("No events found, skipping setting annotations.") return raw + class SetRawAnnotations_PseudoOnline(FixedTransformer): """ Always sets the annotations, even if the events list is empty @@ -231,6 +232,7 @@ def transform(self, raw, y=None): events = self._find_events(raw) return _unsafe_pick_events(events, list(self.event_id.values())) + class RawToEvents_PseudoOnline(FixedTransformer): """ Always returns an array for shape (n_events, 3), even if no events found @@ -278,6 +280,7 @@ def transform(self, raw, y=None): events = self._find_events(raw) return _unsafe_pick_events(events, list(self.event_id.values())) + class RawToEventsP300(RawToEvents): def transform(self, raw, y=None): events = self._find_events(raw) From 9be950f8fa1a2349e8161ec707daf94b7fbbf21d Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Mon, 29 Jul 2024 17:13:00 +0200 Subject: [PATCH 10/17] Erro in Base Dataset --- moabb/datasets/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/moabb/datasets/base.py b/moabb/datasets/base.py index d5b6592fc..60fc7bafc 100644 --- a/moabb/datasets/base.py +++ b/moabb/datasets/base.py @@ -357,7 +357,8 @@ def _create_process_pipeline(self): ( StepType.RAW, SetRawAnnotations( - self.event_id, interval=self.interval, overlap=self.overlap + self.event_id, + interval=self.interval, ), ), ] From 07dd9c817479e86e52841db23b09e4ff421bfbb4 Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Tue, 30 Jul 2024 10:17:29 +0200 Subject: [PATCH 11/17] Fix Bug --- moabb/paradigms/base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/moabb/paradigms/base.py b/moabb/paradigms/base.py index 671fe673f..f73a73434 100644 --- a/moabb/paradigms/base.py +++ b/moabb/paradigms/base.py @@ -167,7 +167,7 @@ def make_process_pipelines( process_pipelines = [] for raw_pipeline in raw_pipelines: steps = [] - if self.overlap: + if self.overlap is not None: steps.append( ( StepType.RAW, @@ -556,7 +556,7 @@ def scoring(self): def _get_events_pipeline(self, dataset): event_id = self.used_events(dataset) - if self.overlap: + if self.overlap is not None: return RawToEvents_PseudoOnline( event_id=event_id, interval=dataset.interval, From b2b7dda8ce9ea1fa13f3bc6f5712f1f4358dd2ec Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Tue, 30 Jul 2024 10:21:55 +0200 Subject: [PATCH 12/17] Fix Bug --- moabb/datasets/bnci.py | 1 - moabb/datasets/preprocessing.py | 4 +--- moabb/paradigms/motor_imagery.py | 4 ++-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/moabb/datasets/bnci.py b/moabb/datasets/bnci.py index 4cfd795c1..e0ae31ca7 100644 --- a/moabb/datasets/bnci.py +++ b/moabb/datasets/bnci.py @@ -808,7 +808,6 @@ def _convert_run_pseudoonline( # Time where the task actually begin, because the events of "stim" give us when the fix cross appear, but not when # the task begin. 
time_fixation_cross = sfreq * time_fix - beta_rebound = 0.5 * sfreq for i in np.arange(len(events[:, 0])): stim_data[0, int(events[i, 0] + time_fixation_cross)] = events[i, 2] stim_data[0, int(events[i, 0] + time_fixation_cross + time_nothing)] = 9 diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py index d1df9a229..2df291746 100644 --- a/moabb/datasets/preprocessing.py +++ b/moabb/datasets/preprocessing.py @@ -161,8 +161,6 @@ def __init__(self, event_id, interval: Tuple[float, float], tmin, tmax, overlap) self.tmax = tmax def transform(self, raw, y=None): - duration = self.interval[1] - self.interval[0] - offset = int(self.interval[0] * raw.info["sfreq"]) if raw.annotations: return raw stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) @@ -252,7 +250,7 @@ def _find_events(self, raw): stim_channels = mne.utils._get_stim_channel(None, raw.info, raise_error=False) if len(stim_channels) > 0: # returns empty array if none found - if self.overlap == None: + if self.overlap is None: events = mne.find_events(raw, shortest_event=0, verbose=False) else: events_ = mne.find_events(raw, shortest_event=0, verbose=False) diff --git a/moabb/paradigms/motor_imagery.py b/moabb/paradigms/motor_imagery.py index 8a9597ba3..f7af413ba 100644 --- a/moabb/paradigms/motor_imagery.py +++ b/moabb/paradigms/motor_imagery.py @@ -114,7 +114,7 @@ def datasets(self): @property def scoring(self): - if self.overlap == None: + if self.overlap is None: return "accuracy" else: return make_scorer(_normalized_mcc) @@ -416,7 +416,7 @@ def scoring(self): if self.n_classes == 2: return "roc_auc" else: - if self.overlap == None: + if self.overlap is None: return "accuracy" else: return make_scorer(_normalized_mcc) From 8e36d1f187c8164431c3c3f097d70488c58d9ebb Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Tue, 30 Jul 2024 10:33:19 +0200 Subject: [PATCH 13/17] Fix Bug --- moabb/datasets/preprocessing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py index 2df291746..f63c0e4eb 100644 --- a/moabb/datasets/preprocessing.py +++ b/moabb/datasets/preprocessing.py @@ -104,7 +104,7 @@ class SetRawAnnotations(FixedTransformer): Always sets the annotations, even if the events list is empty """ - def __init__(self, event_id, interval: Tuple[float, float], tmin, tmax, overlap): + def __init__(self, event_id, interval: Tuple[float, float]): assert isinstance(event_id, dict) # not None self.event_id = event_id if len(set(event_id.values())) != len(event_id): @@ -202,7 +202,7 @@ class RawToEvents(FixedTransformer): """ def __init__( - self, event_id: dict[str, int], interval: Tuple[float, float], tmin, tmax, overlap + self, event_id: dict[str, int], interval: Tuple[float, float] ): assert isinstance(event_id, dict) # not None self.event_id = event_id From 7e2c34c3ab6006d54ce5186d43ef5e51331c0d9a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Jul 2024 08:37:02 +0000 Subject: [PATCH 14/17] [pre-commit.ci] auto fixes from pre-commit.com hooks --- moabb/datasets/preprocessing.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/moabb/datasets/preprocessing.py b/moabb/datasets/preprocessing.py index f63c0e4eb..6a6ac4f80 100644 --- a/moabb/datasets/preprocessing.py +++ b/moabb/datasets/preprocessing.py @@ -201,9 +201,7 @@ class RawToEvents(FixedTransformer): Always returns an array for shape (n_events, 3), even if no 
events found """ - def __init__( - self, event_id: dict[str, int], interval: Tuple[float, float] - ): + def __init__(self, event_id: dict[str, int], interval: Tuple[float, float]): assert isinstance(event_id, dict) # not None self.event_id = event_id self.interval = interval From 0641fc68648f011c84d757065709b1c3a6c54f8e Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Tue, 30 Jul 2024 10:49:27 +0200 Subject: [PATCH 15/17] Reduce size of windows --- examples/plot_pseudoonline.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/plot_pseudoonline.py b/examples/plot_pseudoonline.py index 05ceacc57..a5b941790 100644 --- a/examples/plot_pseudoonline.py +++ b/examples/plot_pseudoonline.py @@ -9,14 +9,12 @@ from moabb.evaluations import WithinSessionEvaluation from moabb.paradigms import MotorImagery - sub = 1 # Initialize parameter for the Band Pass filter fmin = 8 fmax = 30 -tmin = 0 -tmax = 2 +tmax = 3 # Load Dataset and switch to Pseudoonline mode dataset = BNCI2014_001() From 9eface30094987482ce784206234eac2c8f75d6c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Jul 2024 08:49:43 +0000 Subject: [PATCH 16/17] [pre-commit.ci] auto fixes from pre-commit.com hooks --- examples/plot_pseudoonline.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/plot_pseudoonline.py b/examples/plot_pseudoonline.py index a5b941790..d7b6416c6 100644 --- a/examples/plot_pseudoonline.py +++ b/examples/plot_pseudoonline.py @@ -9,6 +9,7 @@ from moabb.evaluations import WithinSessionEvaluation from moabb.paradigms import MotorImagery + sub = 1 # Initialize parameter for the Band Pass filter From d8f76ef85d84d7f21b43e499649865d1efb039f0 Mon Sep 17 00:00:00 2001 From: CARRARA Igor Date: Tue, 30 Jul 2024 11:26:39 +0200 Subject: [PATCH 17/17] Reduce size parallel --- examples/plot_pseudoonline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/plot_pseudoonline.py b/examples/plot_pseudoonline.py index a5b941790..f10d06b2f 100644 --- a/examples/plot_pseudoonline.py +++ b/examples/plot_pseudoonline.py @@ -48,7 +48,7 @@ dataset.subject_list = dataset.subject_list[int(sub) - 1 : int(sub)] # Select an evaluation Within Session evaluation_online = WithinSessionEvaluation( - paradigm=paradigm, datasets=dataset, overwrite=True, random_state=42, n_jobs=-1 + paradigm=paradigm, datasets=dataset, overwrite=True, random_state=42, n_jobs=1 ) # Print the results
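
Putting the series together, the end-to-end usage implied by the final state of examples/plot_pseudoonline.py is roughly the following. The one-subject selection, 3 s windows and 50 % overlap mirror the example script; pseudoonline and overlap are the knobs introduced by this PR:

from pyriemann.classification import MDM
from pyriemann.estimation import Covariances
from sklearn.pipeline import Pipeline

from moabb.datasets import BNCI2014_001
from moabb.evaluations import WithinSessionEvaluation
from moabb.paradigms import MotorImagery

dataset = BNCI2014_001()
dataset.pseudoonline = True                      # switch the BNCI loader to the pseudo-online conversion
dataset.subject_list = dataset.subject_list[:1]  # keep the run short

# 3 s windows with 50 % overlap; the event list now includes the added "nothing" class
paradigm = MotorImagery(
    events=list(dataset.event_id.keys()),
    n_classes=len(dataset.event_id),
    fmin=8,
    fmax=30,
    tmax=3,
    overlap=50,
)

pipelines = {"MDM": Pipeline([("Covariances", Covariances("cov")), ("MDM", MDM())])}

evaluation = WithinSessionEvaluation(
    paradigm=paradigm, datasets=dataset, overwrite=True, random_state=42, n_jobs=1
)
results = evaluation.process(pipelines)
print(results.groupby("pipeline")["score"].agg(["mean", "std"]))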