From b2ef31f381462edbe1caf1ddee728c9a6f1417a3 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Tue, 13 May 2025 13:47:55 +0000 Subject: [PATCH] feat: Enforce right values in y_train to avoid out of index error np.max(y_train) == len(np.unique(y_train))-1 should be True Solves #53 and #54 --- torchFastText/torchFastText.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchFastText/torchFastText.py b/torchFastText/torchFastText.py index a424493..aa1af37 100644 --- a/torchFastText/torchFastText.py +++ b/torchFastText/torchFastText.py @@ -311,6 +311,12 @@ def build( self.num_classes = len( np.unique(y_train) ) # Be sure that y_train contains all the classes ! + + if np.max(y_train) >= self.num_classes: + raise ValueError( + f"y_train must contain values between 0 and {self.num_classes - 1}. Make sure that np.max(y_train) == len(np.unique(y_train))-1." + ) + else: if self.num_classes is None: raise ValueError(