diff --git a/torchFastText/datasets/dataset.py b/torchFastText/datasets/dataset.py index 2bacdaa..219789f 100644 --- a/torchFastText/datasets/dataset.py +++ b/torchFastText/datasets/dataset.py @@ -29,8 +29,8 @@ def __init__( self, categorical_variables: List[List[int]], texts: List[str], - outputs: List[int], tokenizer: NGramTokenizer, + outputs: List[int] = None, **kwargs, ): """ @@ -43,6 +43,13 @@ def __init__( y (List[int]): List of outcomes. tokenizer (Tokenizer): Tokenizer. """ + + if categorical_variables is not None and len(categorical_variables) != len(texts): + raise ValueError("Categorical variables and texts must have the same length.") + + if outputs is not None and len(outputs) != len(texts): + raise ValueError("Outputs and texts must have the same length.") + self.categorical_variables = categorical_variables self.texts = texts self.outputs = outputs @@ -55,7 +62,7 @@ def __len__(self) -> int: Returns: int: Number of observations. """ - return len(self.outputs) + return len(self.texts) def __str__(self) -> str: """ @@ -80,8 +87,12 @@ def __getitem__(self, index: int) -> List: self.categorical_variables[index] if self.categorical_variables is not None else None ) text = self.texts[index] - y = self.outputs[index] - return text, categorical_variables, y + + if self.outputs is not None: + y = self.outputs[index] + return text, categorical_variables, y + else: + return text, categorical_variables def collate_fn(self, batch): """ @@ -93,8 +104,12 @@ def collate_fn(self, batch): Returns: Tuple[torch.LongTensor]: Observation with given index. """ + # Unzip the batch in one go using zip(*batch) - text, *categorical_vars, y = zip(*batch) + if self.outputs is not None: + text, *categorical_vars, y = zip(*batch) + else: + text, *categorical_vars = zip(*batch) # Convert text to indices in parallel using map indices_batch = list(map(lambda x: self.tokenizer.indices_matrix(x)[0], text)) @@ -124,10 +139,12 @@ def collate_fn(self, batch): padded_batch.shape[0], 1, dtype=torch.float32, device=padded_batch.device ) - # Convert labels to tensor in one go - y = torch.tensor(y, dtype=torch.long) - - return (padded_batch, categorical_tensors, y) + if self.outputs is not None: + # Convert labels to tensor in one go + y = torch.tensor(y, dtype=torch.long) + return (padded_batch, categorical_tensors, y) + else: + return (padded_batch, categorical_tensors) def create_dataloader( self, diff --git a/torchFastText/datasets/tokenizer.py b/torchFastText/datasets/tokenizer.py index 8bcd8f8..61595ef 100644 --- a/torchFastText/datasets/tokenizer.py +++ b/torchFastText/datasets/tokenizer.py @@ -182,7 +182,7 @@ def get_word_index(self, word: str) -> int: def get_subwords(self, word: str) -> Tuple[List[str], List[int]]: """ Return all subwords tokens and indices for a given word. - Also adds the whole word token and indice if the word is in word_id_mapping + Also adds the whole word token and indice if the word is in word_id_mapping (==> the word is in initial vocabulary + seen at least MIN_COUNT times). Adds tags "<" and ">" to the word. @@ -198,11 +198,13 @@ def get_subwords(self, word: str) -> Tuple[List[str], List[int]]: # Get subwords and associated indices WITHOUT the whole word for n in range(self.min_n, self.max_n + 1): ngrams = self.get_ngram_list(word_with_tags, n) - tokens += [ngram for ngram in ngrams if ngram != word_with_tags and ngram != word] # Exclude the full word + tokens += [ + ngram for ngram in ngrams if ngram != word_with_tags and ngram != word + ] # Exclude the full word indices = [self.get_subword_index(token) for token in tokens] assert word not in tokens - + # Add word token and indice only if the word is in word_id_mapping if word in self.word_id_mapping.keys(): self.get_word_index(word) @@ -313,4 +315,4 @@ def from_json(cls: Type["NGramTokenizer"], filepath: str, training_text) -> "NGr """ with open(filepath, "r") as f: data = json.load(f) - return cls(**data, training_text=training_text) \ No newline at end of file + return cls(**data, training_text=training_text)