# Trainer API Reference

## Trainer Function

Main training function for `DiffusionLLM` models.

```python
def trainer(
    model: DiffusionLLM,
    train_dataset,
    val_dataset=None,
    batch_size: int = 8,
    num_epochs: int = 5,
    learning_rate: float = 5e-5,
    warmup_steps: int = 1000,
    max_grad_norm: float = 1.0,
    num_timesteps: int = 100,
    save_path: Optional[str] = None,
    device: torch.device = None
) -> DiffusionLLM
```

### Parameters

- `model`: The `DiffusionLLM` model to train
- `train_dataset`: Training dataset
- `val_dataset`: Validation dataset (optional)
- `batch_size`: Batch size for training
- `num_epochs`: Number of training epochs
- `learning_rate`: Learning rate for the optimizer
- `warmup_steps`: Number of learning-rate warmup steps
- `max_grad_norm`: Maximum gradient norm for gradient clipping
- `num_timesteps`: Number of diffusion timesteps
- `save_path`: Path for saving model checkpoints (optional)
- `device`: Device to train on

### Returns

- Trained `DiffusionLLM` model

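
A minimal usage sketch follows. The import path, the `DiffusionLLM` constructor arguments, and the dataset objects are illustrative assumptions; the keyword arguments mirror the defaults documented above.

```python
import torch

# Hypothetical import path -- adjust to the actual package layout.
# from diffusion_llm import DiffusionLLM, trainer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Assumed setup: `train_dataset` and `val_dataset` are map-style PyTorch
# datasets of tokenized sequences, and the constructor arguments below are
# placeholders, not the real `DiffusionLLM` signature.
model = DiffusionLLM(vocab_size=30522, hidden_size=768)  # hypothetical arguments

trained_model = trainer(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,      # optional; defaults to None
    batch_size=8,
    num_epochs=5,
    learning_rate=5e-5,
    warmup_steps=1000,
    max_grad_norm=1.0,
    num_timesteps=100,
    save_path="checkpoints/",     # hypothetical checkpoint location
    device=device,
)
```
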
## Evaluate Function

Evaluates a trained model on a dataset and returns the average loss.

```python
def evaluate(
    model: DiffusionLLM,
    dataloader: DataLoader,
    device: torch.device,
    num_timesteps: int = 100,
    num_eval_steps: int = None
) -> float
```

### Parameters

- `model`: Model to evaluate
- `dataloader`: DataLoader for the evaluation data
- `device`: Device to evaluate on
- `num_timesteps`: Number of diffusion timesteps
- `num_eval_steps`: Number of evaluation steps to run (optional; defaults to `None`)

### Returns

- Average loss on the evaluation set
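
A short evaluation sketch under the same assumptions; the `DataLoader` settings and the interpretation of `num_eval_steps` as a per-batch step count are not specified by this reference.

```python
from torch.utils.data import DataLoader

# Assumed: `val_dataset`, `trained_model`, and `device` come from the
# training sketch above, and one evaluation step corresponds to one batch.
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

avg_loss = evaluate(
    model=trained_model,
    dataloader=val_loader,
    device=device,
    num_timesteps=100,    # typically the same value used during training
    num_eval_steps=None,  # evaluate on the full dataloader (assumed behaviour)
)
print(f"Average evaluation loss: {avg_loss:.4f}")
```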