Skip to content

Commit db94d8d

Browse files
Add model saving and loading.
1 parent cd3c5ce commit db94d8d

File tree

3 files changed

+155
-0
lines changed

3 files changed

+155
-0
lines changed

diffusionLM/save_model/__init__.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Model saving and loading utilities"""
2+
3+
from .model_save import save_model, load_model, ModelSaveError
4+
from .register_model import registerANDpush, ModelRegistrationError
5+
6+
# Names re-exported as the public API of this package.
__all__ = [
    "save_model",
    "load_model",
    "ModelSaveError",
    "registerANDpush",
    "ModelRegistrationError",
]
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import torch
2+
import logging
3+
from pathlib import Path
4+
from typing import Tuple, Optional
5+
6+
from transformers import PreTrainedModel
7+
from diffusionLM.model.transformers_model import DiffusionConfig, DiffusionLLM
8+
9+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
10+
11+
class ModelSaveError(Exception):
    """Custom exception for model saving/loading errors."""
14+
15+
def save_model(
    model: DiffusionLLM,
    optimizer: torch.optim.Optimizer,
    save_path: str,
    final: bool = False,
) -> None:
    """Save model and optimizer state to a single checkpoint file.

    The checkpoint is written with ``torch.save`` into the directory
    ``save_path`` (created if missing). The file is named
    ``final_model.pt`` when ``final`` is true, otherwise
    ``step_<N>_model.pt`` where ``N`` is ``model.current_step`` (or ``1``
    when the model has no such attribute).

    The stored dictionary has the keys ``model_state_dict``,
    ``optimizer_state_dict``, ``step`` and ``config`` (the attribute
    dictionary of ``model.config``).

    Args:
        model: Model whose ``state_dict()`` and ``config`` are saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        save_path: Directory that receives the checkpoint file.
        final: Use the ``final`` file-name prefix instead of the step number.

    Raises:
        ModelSaveError: If any step of saving fails; the original exception
            is attached as ``__cause__``.
    """
    try:
        save_dir = Path(save_path)
        save_dir.mkdir(parents=True, exist_ok=True)

        step = getattr(model, "current_step", 1)
        prefix = "final" if final else f"step_{step}"
        save_name = save_dir / f"{prefix}_model.pt"

        # Save the model, optimizer and config together in one file.
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "step": step,
                "config": model.config.__dict__,
            },
            save_name,
        )
        logger.info(f"Model saved to {save_name}")

    except Exception as e:
        logger.error(f"Failed to save model: {str(e)}")
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ModelSaveError(f"Failed to save model: {str(e)}") from e
45+
46+
def load_model(
    load_path: str,
    device: Optional[torch.device] = None,
) -> Tuple[DiffusionLLM, torch.optim.Optimizer]:
    """Load a model and optimizer from a checkpoint written by ``save_model``.

    The checkpoint's ``config`` dictionary is filtered down to the keyword
    parameters that ``DiffusionConfig.__init__`` declares, a ``DiffusionLLM``
    is built from it, and the saved weights are loaded. A fresh
    ``torch.optim.AdamW`` is then created over the model parameters and the
    saved optimizer state is loaded into it.

    Args:
        load_path: Path to the checkpoint file.
        device: Device to map the checkpoint and model onto. Defaults to
            CUDA when available, otherwise CPU.

    Returns:
        A ``(model, optimizer)`` tuple.

    Raises:
        ModelSaveError: If the checkpoint is missing, has no config, or any
            other step fails. Every failure (including the ones raised
            inside this function) is wrapped with a
            ``"Failed to load model: ..."`` message, with the original
            exception attached as ``__cause__``.
    """
    # Local import: only this function needs it.
    import inspect

    try:
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load checkpoint
        if not Path(load_path).exists():
            raise ModelSaveError(f"Checkpoint not found at {load_path}")

        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from trusted sources.
        # NOTE(review): recent PyTorch releases default to weights_only=True,
        # which may reject the pickled config dict -- confirm against the
        # PyTorch version in use.
        checkpoint = torch.load(load_path, map_location=device)

        # Create config and model
        config_dict = checkpoint.get("config", {})
        if not config_dict:
            raise ModelSaveError("No config found in checkpoint")

        # Keep only keys that are real keyword parameters of the config
        # constructor. ``__code__.co_varnames`` would also list ``self`` and
        # the constructor's local variables, letting unrelated keys through.
        init_params = inspect.signature(DiffusionConfig.__init__).parameters
        keyword_kinds = (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
        expected_keys = {
            name
            for name, param in init_params.items()
            if name != "self" and param.kind in keyword_kinds
        }
        filtered_config_dict = {
            k: v for k, v in config_dict.items() if k in expected_keys
        }

        config = DiffusionConfig(**filtered_config_dict)

        # Create model
        model = DiffusionLLM(config)
        model.load_state_dict(checkpoint["model_state_dict"])
        model.to(device)

        # Create optimizer after the model is on its device, then restore
        # the saved state.
        # NOTE(review): assumes the checkpoint was saved from an optimizer
        # whose state is AdamW-compatible -- confirm against the trainer.
        optimizer = torch.optim.AdamW(model.parameters())
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        return model, optimizer

    except Exception as e:
        logger.error(f"Failed to load model: {str(e)}")
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ModelSaveError(f"Failed to load model: {str(e)}") from e
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import torch
2+
import os
3+
import logging
4+
from pathlib import Path
5+
from typing import Optional
6+
7+
from huggingface_hub import HfApi, Repository
8+
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
9+
from diffusionLM.model.transformers_model import DiffusionConfig, DiffusionLLM
10+
11+
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
12+
13+
class ModelRegistrationError(Exception):
    """Custom exception for model registration errors."""
16+
17+
def registerANDpush(
    model: DiffusionLLM,
    tokenizer,
    model_type: str,
    model_name: type[DiffusionLLM],
    model_config: type[DiffusionConfig],
    repo_id: str = "codewithdark/DiffusionLM",
    private: bool = False,
    local_dir: str = "SaveModel/DiffusionLM",
) -> None:
    """Register the model classes with transformers and push to the Hub.

    Steps, in order:

    1. Register ``model_config`` under ``model_type`` with ``AutoConfig``
       and ``model_name`` with ``AutoModel`` and ``AutoModelForCausalLM``.
    2. Try to create the Hub repository; a failure is only logged as a
       warning (the repository may already exist).
    3. Save the tokenizer and model into ``local_dir`` with
       ``save_pretrained``.
    4. Upload that folder to ``repo_id`` with ``HfApi.upload_folder``.

    Args:
        model: Model instance to save and upload.
        tokenizer: Tokenizer exposing ``save_pretrained``.
        model_type: Model-type string registered with ``AutoConfig``.
        model_name: Model class registered with the auto-model factories.
        model_config: Config class registered with ``AutoConfig``.
        repo_id: Target Hub repository.
        private: Whether a newly created repository is private.
        local_dir: Local staging directory for the saved files (created if
            missing).

    Raises:
        ModelRegistrationError: If registration, saving, or uploading fails;
            the original exception is attached as ``__cause__``.
    """
    try:
        # Register model architecture
        AutoConfig.register(model_type, model_config)
        AutoModel.register(model_config, model_name)
        AutoModelForCausalLM.register(model_config, model_name)

        api = HfApi()

        # Create repo (best effort: an existing repository is not an error)
        try:
            api.create_repo(repo_id=repo_id, private=private)
            logger.info(f"Created new repository: {repo_id}")
        except Exception as e:
            logger.warning(f"Repository creation failed (may already exist): {e}")

        # Setup local staging directory
        repo_local_path = Path(local_dir)
        repo_local_path.mkdir(parents=True, exist_ok=True)

        # Save model and tokenizer
        tokenizer.save_pretrained(repo_local_path)
        model.save_pretrained(repo_local_path)

        # Push to hub over HTTP. The git-based ``Repository`` helper is
        # deprecated in huggingface_hub, so upload the folder directly.
        api.upload_folder(
            repo_id=repo_id,
            folder_path=str(repo_local_path),
            commit_message="Initial model and tokenizer commit",
        )
        logger.info(f"Model and tokenizer pushed to {repo_id}")

    except Exception as e:
        logger.error(f"Model registration failed: {str(e)}")
        # Chain explicitly so the root cause stays visible in tracebacks.
        raise ModelRegistrationError(f"Failed to register model: {str(e)}") from e

0 commit comments

Comments
 (0)