From 56457af6e374f222d8906bbf199aaa338b28e320 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Wed, 2 Apr 2025 12:39:26 +0000 Subject: [PATCH 1/5] [chore] format --- torchFastText/datasets/tokenizer.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/torchFastText/datasets/tokenizer.py b/torchFastText/datasets/tokenizer.py index 8bcd8f8..61595ef 100644 --- a/torchFastText/datasets/tokenizer.py +++ b/torchFastText/datasets/tokenizer.py @@ -182,7 +182,7 @@ def get_word_index(self, word: str) -> int: def get_subwords(self, word: str) -> Tuple[List[str], List[int]]: """ Return all subwords tokens and indices for a given word. - Also adds the whole word token and indice if the word is in word_id_mapping + Also adds the whole word token and indice if the word is in word_id_mapping (==> the word is in initial vocabulary + seen at least MIN_COUNT times). Adds tags "<" and ">" to the word. @@ -198,11 +198,13 @@ def get_subwords(self, word: str) -> Tuple[List[str], List[int]]: # Get subwords and associated indices WITHOUT the whole word for n in range(self.min_n, self.max_n + 1): ngrams = self.get_ngram_list(word_with_tags, n) - tokens += [ngram for ngram in ngrams if ngram != word_with_tags and ngram != word] # Exclude the full word + tokens += [ + ngram for ngram in ngrams if ngram != word_with_tags and ngram != word + ] # Exclude the full word indices = [self.get_subword_index(token) for token in tokens] assert word not in tokens - + # Add word token and indice only if the word is in word_id_mapping if word in self.word_id_mapping.keys(): self.get_word_index(word) @@ -313,4 +315,4 @@ def from_json(cls: Type["NGramTokenizer"], filepath: str, training_text) -> "NGr """ with open(filepath, "r") as f: data = json.load(f) - return cls(**data, training_text=training_text) \ No newline at end of file + return cls(**data, training_text=training_text) From 0827c8b999e30221f906d1ff785a37a9ea5d5de7 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee 
Date: Wed, 2 Apr 2025 12:40:15 +0000 Subject: [PATCH 2/5] [feat] enable outputs=None in Dataset Especially useful when there is no ground truth (on-the-fly inference) --- torchFastText/datasets/dataset.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/torchFastText/datasets/dataset.py b/torchFastText/datasets/dataset.py index 2bacdaa..d18eef4 100644 --- a/torchFastText/datasets/dataset.py +++ b/torchFastText/datasets/dataset.py @@ -29,8 +29,8 @@ def __init__( self, categorical_variables: List[List[int]], texts: List[str], - outputs: List[int], tokenizer: NGramTokenizer, + outputs: List[int] = None, **kwargs, ): """ @@ -80,8 +80,12 @@ def __getitem__(self, index: int) -> List: self.categorical_variables[index] if self.categorical_variables is not None else None ) text = self.texts[index] - y = self.outputs[index] - return text, categorical_variables, y + + if self.outputs is not None: + y = self.outputs[index] + return text, categorical_variables, y + else: + return text, categorical_variables def collate_fn(self, batch): """ @@ -124,10 +128,12 @@ def collate_fn(self, batch): padded_batch.shape[0], 1, dtype=torch.float32, device=padded_batch.device ) - # Convert labels to tensor in one go - y = torch.tensor(y, dtype=torch.long) - - return (padded_batch, categorical_tensors, y) + if self.outputs is not None: + # Convert labels to tensor in one go + y = torch.tensor(y, dtype=torch.long) + return (padded_batch, categorical_tensors, y) + else: + return (padded_batch, categorical_tensors) def create_dataloader( self, From 6764d6b0d2ac8d4bb158724525b924f637a7ca7a Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Wed, 2 Apr 2025 13:09:14 +0000 Subject: [PATCH 3/5] [fix] __len__ method uses texts instead of outputs. Added input validation. 
--- torchFastText/datasets/dataset.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/torchFastText/datasets/dataset.py b/torchFastText/datasets/dataset.py index d18eef4..209337c 100644 --- a/torchFastText/datasets/dataset.py +++ b/torchFastText/datasets/dataset.py @@ -43,6 +43,13 @@ def __init__( y (List[int]): List of outcomes. tokenizer (Tokenizer): Tokenizer. """ + + if len(categorical_variables) != len(texts): + raise ValueError("Categorical variables and texts must have the same length.") + + if outputs is not None and len(outputs) != len(texts): + raise ValueError("Outputs and texts must have the same length.") + self.categorical_variables = categorical_variables self.texts = texts self.outputs = outputs @@ -55,7 +62,7 @@ def __len__(self) -> int: Returns: int: Number of observations. """ - return len(self.outputs) + return len(self.texts) def __str__(self) -> str: """ From 4eed284b111e0e3c7c0ba4455e286aef29151f22 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Wed, 2 Apr 2025 13:21:17 +0000 Subject: [PATCH 4/5] [fix] Handling when y is None in collate_fn --- torchFastText/datasets/dataset.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchFastText/datasets/dataset.py b/torchFastText/datasets/dataset.py index 209337c..cd431e7 100644 --- a/torchFastText/datasets/dataset.py +++ b/torchFastText/datasets/dataset.py @@ -104,8 +104,12 @@ def collate_fn(self, batch): Returns: Tuple[torch.LongTensor]: Observation with given index. 
""" + # Unzip the batch in one go using zip(*batch) - text, *categorical_vars, y = zip(*batch) + if self.outputs is not None: + text, *categorical_vars, y = zip(*batch) + else: + text, *categorical_vars = zip(*batch) # Convert text to indices in parallel using map indices_batch = list(map(lambda x: self.tokenizer.indices_matrix(x)[0], text)) From 2a4f8784bbdfe7aad6c39b8b3b7a838bceb0ea5c Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Wed, 2 Apr 2025 13:35:44 +0000 Subject: [PATCH 5/5] [fix] Fix the input validation for null categorical variables --- torchFastText/datasets/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchFastText/datasets/dataset.py b/torchFastText/datasets/dataset.py index cd431e7..219789f 100644 --- a/torchFastText/datasets/dataset.py +++ b/torchFastText/datasets/dataset.py @@ -44,7 +44,7 @@ def __init__( tokenizer (Tokenizer): Tokenizer. """ - if len(categorical_variables) != len(texts): + if categorical_variables is not None and len(categorical_variables) != len(texts): raise ValueError("Categorical variables and texts must have the same length.") if outputs is not None and len(outputs) != len(texts):