From 46e50dbd6b4efb62f2fde4710b15381d5e4bd429 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sat, 6 Sep 2025 14:25:29 +0200 Subject: [PATCH] Add split functionality to synthetic_data --- cebra/datasets/synthetic_data.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/cebra/datasets/synthetic_data.py b/cebra/datasets/synthetic_data.py index 9288a93d..4f675593 100644 --- a/cebra/datasets/synthetic_data.py +++ b/cebra/datasets/synthetic_data.py @@ -112,6 +112,25 @@ def __init__(self, name, root=_DEFAULT_DATADIR, download=True): self.index = self.data['u'] self.lam = self.data['lam'] + + def split(self, split): + tot_len = len(self.neural) + train_idx = np.arange(tot_len)[:int(tot_len*0.8)] + valid_idx = np.arange(tot_len)[int(tot_len*0.8):] + + if split == 'train': + self.neural = self.neural[train_idx] + self.index = self.index[train_idx] + self.idx = train_idx + elif split == 'valid': + self.neural = self.neural[valid_idx] + self.index = self.index[valid_idx] + self.idx = valid_idx + elif split == 'all': + pass + else: + raise ValueError(f"{split} not supported") + @property def input_dimension(self): return self.neural.size(1)