|
import inspect
import logging
from pathlib import Path
from typing import Optional, Tuple

import torch
from transformers import PreTrainedModel

from diffusionLM.model.transformers_model import DiffusionConfig, DiffusionLLM
| 8 | + |
| 9 | +logger = logging.getLogger(__name__) |
| 10 | + |
class ModelSaveError(Exception):
    """Raised when a model checkpoint cannot be saved or loaded."""
| 14 | + |
def save_model(
    model: DiffusionLLM,
    optimizer: torch.optim.Optimizer,
    save_path: str,
    final: bool = False,
) -> None:
    """Save model and optimizer state to a checkpoint file.

    Args:
        model: Model to save. Its ``config.__dict__`` and (optional)
            ``current_step`` attribute are stored alongside the weights.
        optimizer: Optimizer whose state dict accompanies the model.
        save_path: Directory in which the checkpoint file is written
            (created if it does not exist).
        final: If True, name the file ``final_model.pt`` instead of
            ``step_<n>_model.pt``.

    Raises:
        ModelSaveError: If the checkpoint could not be written.
    """
    try:
        save_dir = Path(save_path)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Default to step 1 when the model does not track training progress.
        step = getattr(model, "current_step", 1)
        prefix = "final" if final else f"step_{step}"
        save_name = save_dir / f"{prefix}_model.pt"

        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "step": step,
                "config": model.config.__dict__,
            },
            save_name,
        )
        # Lazy %-style args avoid formatting when the level is disabled.
        logger.info("Model saved to %s", save_name)

    except Exception as e:
        # logger.exception records the traceback; chain the cause so callers
        # can still inspect the underlying error.
        logger.exception("Failed to save model")
        raise ModelSaveError(f"Failed to save model: {str(e)}") from e
| 45 | + |
def load_model(
    load_path: str,
    device: Optional[torch.device] = None,
) -> Tuple[DiffusionLLM, torch.optim.Optimizer]:
    """Load a model and a fresh AdamW optimizer from a saved checkpoint.

    Args:
        load_path: Path to a checkpoint file produced by :func:`save_model`.
        device: Target device; defaults to CUDA when available, else CPU.

    Returns:
        Tuple of the restored model (moved to ``device``) and an AdamW
        optimizer with its state restored from the checkpoint.

    Raises:
        ModelSaveError: If the checkpoint is missing, has no config, or
            cannot be loaded.
    """
    try:
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if not Path(load_path).exists():
            raise ModelSaveError(f"Checkpoint not found at {load_path}")

        # NOTE(review): the checkpoint stores optimizer state and a plain
        # config dict, so weights-only loading is not applicable here — only
        # load checkpoints from trusted sources (torch.load unpickles).
        checkpoint = torch.load(load_path, map_location=device)

        config_dict = checkpoint.get("config", {})
        if not config_dict:
            raise ModelSaveError("No config found in checkpoint")

        # Keep only keys DiffusionConfig.__init__ actually accepts.
        # inspect.signature yields parameters only, unlike __code__.co_varnames
        # which also lists the function's local variables and could let
        # unexpected keys through.
        expected_keys = set(inspect.signature(DiffusionConfig.__init__).parameters)
        filtered_config_dict = {
            k: v for k, v in config_dict.items() if k in expected_keys
        }
        config = DiffusionConfig(**filtered_config_dict)

        model = DiffusionLLM(config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.to(device)

        # Fresh optimizer; hyperparameters are restored via its state dict.
        optimizer = torch.optim.AdamW(model.parameters())
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        return model, optimizer

    except ModelSaveError:
        # Already a domain error (missing file / missing config): re-raise
        # as-is instead of double-wrapping the message.
        raise
    except Exception as e:
        logger.exception("Failed to load model")
        raise ModelSaveError(f"Failed to load model: {str(e)}") from e