Skip to content
This repository was archived by the owner on Nov 26, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions torchFastText/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __init__(
self,
categorical_variables: List[List[int]],
texts: List[str],
outputs: List[int],
tokenizer: NGramTokenizer,
outputs: List[int] = None,
**kwargs,
):
"""
Expand All @@ -43,6 +43,13 @@ def __init__(
y (List[int]): List of outcomes.
tokenizer (Tokenizer): Tokenizer.
"""

if categorical_variables is not None and len(categorical_variables) != len(texts):
raise ValueError("Categorical variables and texts must have the same length.")

if outputs is not None and len(outputs) != len(texts):
raise ValueError("Outputs and texts must have the same length.")

self.categorical_variables = categorical_variables
self.texts = texts
self.outputs = outputs
Expand All @@ -55,7 +62,7 @@ def __len__(self) -> int:
Returns:
int: Number of observations.
"""
return len(self.outputs)
return len(self.texts)

def __str__(self) -> str:
"""
Expand All @@ -80,8 +87,12 @@ def __getitem__(self, index: int) -> List:
self.categorical_variables[index] if self.categorical_variables is not None else None
)
text = self.texts[index]
y = self.outputs[index]
return text, categorical_variables, y

if self.outputs is not None:
y = self.outputs[index]
return text, categorical_variables, y
else:
return text, categorical_variables

def collate_fn(self, batch):
"""
Expand All @@ -93,8 +104,12 @@ def collate_fn(self, batch):
Returns:
Tuple[torch.LongTensor]: Observation with given index.
"""

# Unzip the batch in one go using zip(*batch)
text, *categorical_vars, y = zip(*batch)
if self.outputs is not None:
text, *categorical_vars, y = zip(*batch)
else:
text, *categorical_vars = zip(*batch)

# Convert text to indices in parallel using map
indices_batch = list(map(lambda x: self.tokenizer.indices_matrix(x)[0], text))
Expand Down Expand Up @@ -124,10 +139,12 @@ def collate_fn(self, batch):
padded_batch.shape[0], 1, dtype=torch.float32, device=padded_batch.device
)

# Convert labels to tensor in one go
y = torch.tensor(y, dtype=torch.long)

return (padded_batch, categorical_tensors, y)
if self.outputs is not None:
# Convert labels to tensor in one go
y = torch.tensor(y, dtype=torch.long)
return (padded_batch, categorical_tensors, y)
else:
return (padded_batch, categorical_tensors)

def create_dataloader(
self,
Expand Down
10 changes: 6 additions & 4 deletions torchFastText/datasets/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def get_word_index(self, word: str) -> int:
def get_subwords(self, word: str) -> Tuple[List[str], List[int]]:
"""
Return all subwords tokens and indices for a given word.
Also adds the whole word token and indice if the word is in word_id_mapping
Also adds the whole word token and indice if the word is in word_id_mapping
(==> the word is in initial vocabulary + seen at least MIN_COUNT times).
Adds tags "<" and ">" to the word.

Expand All @@ -198,11 +198,13 @@ def get_subwords(self, word: str) -> Tuple[List[str], List[int]]:
# Get subwords and associated indices WITHOUT the whole word
for n in range(self.min_n, self.max_n + 1):
ngrams = self.get_ngram_list(word_with_tags, n)
tokens += [ngram for ngram in ngrams if ngram != word_with_tags and ngram != word] # Exclude the full word
tokens += [
ngram for ngram in ngrams if ngram != word_with_tags and ngram != word
] # Exclude the full word

indices = [self.get_subword_index(token) for token in tokens]
assert word not in tokens

# Add word token and indice only if the word is in word_id_mapping
if word in self.word_id_mapping.keys():
self.get_word_index(word)
Expand Down Expand Up @@ -313,4 +315,4 @@ def from_json(cls: Type["NGramTokenizer"], filepath: str, training_text) -> "NGr
"""
with open(filepath, "r") as f:
data = json.load(f)
return cls(**data, training_text=training_text)
return cls(**data, training_text=training_text)