Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 9ae35bd

Browse files
Authored by: Sara Adkins
allow for teacher to be passed in as instantiated model (#2170) (#2172)
1 parent 69a99e1 commit 9ae35bd

File tree

3 files changed

+11
-11
lines changed

3 files changed

+11
-11
lines changed

src/sparseml/transformers/finetune/model_args.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,12 @@ class ModelArguments:
3030
)
3131
},
3232
)
33+
distill_teacher: Optional[str] = field(
34+
default=None,
35+
metadata={
36+
"help": "Teacher model (a trained text generation model)",
37+
},
38+
)
3339
config_name: Optional[str] = field(
3440
default=None,
3541
metadata={

src/sparseml/transformers/finetune/text_generation.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -155,10 +155,10 @@ def intialize_model_from_path(
155155
)
156156
teacher_config = (
157157
AutoConfig.from_pretrained(
158-
training_args.distill_teacher,
158+
model_args.distill_teacher,
159159
use_auth_token=True if model_args.use_auth_token else None,
160160
)
161-
if training_args.distill_teacher
161+
if model_args.distill_teacher
162162
else None
163163
)
164164

@@ -208,11 +208,11 @@ def intialize_model_from_path(
208208

209209
teacher = (
210210
SparseAutoModel.text_generation_from_pretrained(
211-
model_name_or_path=training_args.distill_teacher,
211+
model_name_or_path=model_args.distill_teacher,
212212
sequence_length=None, # use model default
213213
**teacher_kwargs,
214214
)
215-
if training_args.distill_teacher is not None
215+
if model_args.distill_teacher is not None
216216
else None
217217
)
218218

@@ -289,7 +289,7 @@ def main(
289289

290290
# Detecting last checkpoint.
291291
last_checkpoint = None
292-
teacher = None
292+
teacher = model_args.distill_teacher
293293
model_path = None
294294
model = model_args.model
295295
# Load tokenizer

src/sparseml/transformers/finetune/training_args.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -32,12 +32,6 @@ class TrainingArguments(HFTrainingArgs):
3232
arguments
3333
"""
3434

35-
distill_teacher: Optional[str] = field(
36-
default=None,
37-
metadata={
38-
"help": "Teacher model (a trained text generation model)",
39-
},
40-
)
4135
best_model_after_epoch: int = field(
4236
default=None,
4337
metadata={"help": "Epoch after which best model will be saved."},

0 commit comments

Comments (0)