diff --git a/torchFastText/torchFastText.py b/torchFastText/torchFastText.py index a424493..aa1af37 100644 --- a/torchFastText/torchFastText.py +++ b/torchFastText/torchFastText.py @@ -311,6 +311,12 @@ def build( self.num_classes = len( np.unique(y_train) ) # Be sure that y_train contains all the classes ! + + if np.max(y_train) >= self.num_classes: + raise ValueError( + f"y_train must contain values between 0 and {self.num_classes - 1}. Make sure that np.max(y_train) == len(np.unique(y_train))-1." + ) + else: if self.num_classes is None: raise ValueError(