From 985f6ecb96ae689f1d27117f2d8d6fd5c72178e5 Mon Sep 17 00:00:00 2001 From: meilame-tayebjee Date: Mon, 19 May 2025 09:36:41 +0000 Subject: [PATCH] feat: enable flexibility in the dataloader creation Possibility to override default params and also pass any additional params that are accepted by PyTorch DataLoader --- torchFastText/datasets/dataset.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/torchFastText/datasets/dataset.py b/torchFastText/datasets/dataset.py index 219789f..32d0aef 100644 --- a/torchFastText/datasets/dataset.py +++ b/torchFastText/datasets/dataset.py @@ -152,15 +152,22 @@ def create_dataloader( shuffle: bool = False, drop_last: bool = False, num_workers: int = os.cpu_count() - 1, + pin_memory: bool = True, + persistent_workers: bool = True, **kwargs, ) -> torch.utils.data.DataLoader: """ - Creates a Dataloader. + Creates a Dataloader from the FastTextModelDataset. + Use collate_fn() to tokenize and pad the sequences. Args: batch_size (int): Batch size. shuffle (bool, optional): Shuffle option. Defaults to False. drop_last (bool, optional): Drop last option. Defaults to False. + num_workers (int, optional): Number of workers. Defaults to os.cpu_count() - 1. + pin_memory (bool, optional): Set True if working on GPU, False if CPU. Defaults to True. + persistent_workers (bool, optional): Set True for training, False for inference. Defaults to True. + **kwargs: Additional arguments for PyTorch DataLoader. Returns: torch.utils.data.DataLoader: Dataloader. @@ -174,7 +181,8 @@ def create_dataloader( collate_fn=self.collate_fn, shuffle=shuffle, drop_last=drop_last, - pin_memory=True, + pin_memory=pin_memory, num_workers=num_workers, - persistent_workers=True, + persistent_workers=persistent_workers, + **kwargs, )