-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
96 lines (86 loc) · 2.21 KB
/
train.py
File metadata and controls
96 lines (86 loc) · 2.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template
from datasets import load_dataset
import torch
# -----------------------
# Model
# -----------------------
model_name = "Qwen/Qwen2.5-Coder-7B"
# Load base weights 4-bit quantized (QLoRA-style) through Unsloth's fast
# loader; compute dtype is bfloat16 and sequences are capped at 4096 tokens.
# NOTE: this value must stay in sync with max_seq_length passed to the trainer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = model_name,
    max_seq_length = 4096,
    dtype = torch.bfloat16,
    load_in_4bit = True,
)
# Attach the "qwen25" chat template so tokenizer.apply_chat_template renders
# conversations in the format this model family expects.
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "qwen25",
)
# -----------------------
# LoRA
# -----------------------
# Wrap the quantized model with LoRA adapters: rank 16, scaling alpha 32,
# no adapter dropout, no trainable bias. Adapters are attached to every
# attention projection and every MLP projection layer.
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,
    lora_alpha = 32,
    lora_dropout = 0.0,
    bias = "none",
    target_modules = [
        "q_proj","k_proj","v_proj","o_proj",
        "gate_proj","up_proj","down_proj",
    ],
    # Recompute activations during backward to reduce memory at some speed cost.
    use_gradient_checkpointing = True,
)
# -----------------------
# Dataset
# -----------------------
# Load a local JSONL file as the training split. Each record is assumed to
# carry "prompt" and "output" string fields (consumed by format_example
# below) — verify against the file if the path changes.
dataset = load_dataset(
    "json",
    data_files = "/data01/RAG4Coding_datasets/Semantics4Coding/sample_6000.jsonl",
    split = "train",
)
def format_example(example):
    """Render one prompt/output record into a single chat-formatted string.

    Builds a two-turn conversation (user prompt, assistant answer), renders
    it with the module-level tokenizer's chat template, stores the result
    under example["text"], and returns the mutated example — the shape
    datasets.Dataset.map expects.
    """
    conversation = [
        {"role": "user", "content": example["prompt"]},
        {"role": "assistant", "content": example["output"]},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize = False,
        add_generation_prompt = False,
    )
    example["text"] = rendered
    return example
# Render every record and drop all original columns, leaving only the
# single "text" column; num_proc=8 parallelizes across worker processes.
dataset = dataset.map(
    format_example,
    remove_columns = dataset.column_names,
    num_proc = 8,
)
# -----------------------
# Training
# -----------------------
from trl import SFTTrainer
from transformers import TrainingArguments

# Supervised fine-tuning on the pre-rendered "text" field.
# Effective batch size = 2 (per device) * 16 (accumulation) = 32 sequences.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    dataset_text_field = "text",
    max_seq_length = 4096,  # must match the load-time max_seq_length above
    args = TrainingArguments(
        output_dir = "qwen25-coder-unsloth",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 16,
        num_train_epochs = 2,
        learning_rate = 2e-5,
        warmup_ratio = 0.1,            # linear warmup over the first 10% of steps
        bf16 = True,                   # matches the model's bfloat16 compute dtype
        logging_steps = 10,
        save_steps = 500,
        save_total_limit = 2,          # keep only the two most recent checkpoints
        optim = "adamw_8bit",          # 8-bit AdamW to cut optimizer-state memory
        weight_decay = 0.0,
        lr_scheduler_type = "cosine",
        report_to = "none",            # disable wandb/tensorboard reporting
    ),
)
trainer.train()