-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcoder.py
More file actions
149 lines (127 loc) · 5.33 KB
/
coder.py
File metadata and controls
149 lines (127 loc) · 5.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
#!/usr/bin/env python3
import os
import sys
import traceback
sys.path.append(os.getcwd())
from contextlib import redirect_stdout
from io import StringIO
from extra.models.llama import Transformer, convert_from_huggingface
from sentencepiece import SentencePieceProcessor
from tinygrad import Device, Tensor, dtypes, nn
from tinygrad.helpers import Timing, colored, fetch, getenv
def create_fixed_tokenizer(output_file):
  """Download the OpenHermes tokenizer and patch in the ChatML special tokens.

  Fetches the upstream sentencepiece model, appends the two ChatML control
  pieces (<|im_end|>, <|im_start|>) to its piece list, and serializes the
  patched model to output_file.
  """
  print("creating fixed tokenizer")
  import extra.junk.sentencepiece_model_pb2 as spb2
  tokenizer_url = "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/tokenizer.model?download=true"
  proto = spb2.ModelProto()
  proto.ParseFromString(fetch(tokenizer_url).read_bytes())
  # Appended in this order so the new pieces land at the end of the vocab.
  for special in ("<|im_end|>", "<|im_start|>"):
    proto.pieces.append(spb2.ModelProto.SentencePiece(piece=special, score=0))
  with open(output_file, "wb") as f:
    f.write(proto.SerializeToString())
if __name__ == "__main__":
  # Inference only — disable gradient tracking globally.
  Tensor.no_grad = True
  # Hyperparameters mirror the upstream HF config:
  # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/config.json
  with Timing("create model: "):
    model = Transformer(
      4096,    # presumably model dim (first positional arg) — matches HF config; confirm against Transformer signature
      14336,   # presumably FFN hidden dim — matches HF config
      n_heads=32,
      n_layers=32,
      norm_eps=1e-5,
      vocab_size=32002,  # 32000 base vocab + the two ChatML tokens added by create_fixed_tokenizer
      n_kv_heads=8,
      max_context=4096,
      jit=getenv("JIT", 1),  # set JIT=0 in the environment to disable the TinyJit path
    )
  # Weights are sharded across two pytorch .bin files on the HF hub.
  with Timing("download weights: "):
    part1 = nn.state.torch_load(
      fetch(
        "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00001-of-00002.bin?download=true"
      )
    )
    part2 = nn.state.torch_load(
      fetch(
        "https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/resolve/main/pytorch_model-00002-of-00002.bin?download=true"
      )
    )
# fix bf16, TODO: check if device supports bf16
def fix_bf16(weights):
return {
k: v.to(Device.DEFAULT).cast(dtypes.float16)
if v.dtype == dtypes.bfloat16
else v
for k, v in weights.items()
}
  with Timing("weights -> model: "):
    # strict=False because each shard holds only a subset of the full state
    # dict; the 32 / 8 args presumably map to n_heads / n_kv_heads above —
    # confirm against convert_from_huggingface's signature.
    nn.state.load_state_dict(
      model, fix_bf16(convert_from_huggingface(part1, model, 32, 8)), strict=False
    )
    nn.state.load_state_dict(
      model, fix_bf16(convert_from_huggingface(part2, model, 32, 8)), strict=False
    )
  # Build the patched tokenizer once and cache it in /tmp.
  if not os.path.isfile("/tmp/tokenizer.model"):
    create_fixed_tokenizer("/tmp/tokenizer.model")
  spp = SentencePieceProcessor(model_file="/tmp/tokenizer.model")
  # ChatML template, per the upstream tokenizer config:
  # https://huggingface.co/teknium/OpenHermes-2.5-Mistral-7B/blob/main/tokenizer_config.json
  # "chat_template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
  # Token ids of the two pieces appended in create_fixed_tokenizer
  # (base vocab is 32000, so the appended pieces take 32000 and 32001).
  IM_END = 32000
  IM_START = 32001
def encode_prompt(k, v):
return [IM_START] + spp.encode(f"{k}\n{v}") + [IM_END] + spp.encode("\n")
def start_prompt(k):
return [IM_START] + spp.encode(f"{k}\n")
def output(outputted, toks, color):
cur = spp.decode(toks)[len(outputted) :]
sys.stdout.write(colored(cur, color))
sys.stdout.flush()
outputted += cur
return outputted
  # *** app below this line ***
  # Seed the conversation: BOS token followed by the system prompt.
  toks = [spp.bos_id()] + encode_prompt(
    "system",
    "You are Quentin. Quentin is a useful assistant who writes Python code to answer questions. He keeps the code as short as possible and doesn't read from user input",
  )
  PROMPT = getenv("PROMPT", 1)       # PROMPT=0: no user input; model alternates user/assistant with itself
  temperature = getenv("TEMP", 0.7)  # sampling temperature, overridable via TEMP
  start_pos = 0                      # index of the first token not yet fed to the model — presumably the KV-cache cursor; confirm against Transformer.__call__
  outputted = output("", toks, "green")
  turn = True                        # only used in PROMPT=0 self-chat mode to alternate roles
  while 1:
    if PROMPT:
      # Interactive mode: read a question, then open an assistant turn for the model.
      toks += encode_prompt("user", input("Q: ")) + start_prompt("assistant")
    else:
      toks += start_prompt("user" if turn else "assistant")
      turn = not turn
    old_output_len = len(outputted)  # marks where this turn's text begins, for code extraction below
    while 1:
      # Feed only the tokens the model hasn't seen yet, then advance the cursor.
      tok = model(Tensor([toks[start_pos:]]), start_pos, temperature).item()
      start_pos = len(toks)
      toks.append(tok)
      outputted = output(outputted, toks, "blue" if not turn else "cyan")
      # Stop the turn on either the ChatML end tag or sentencepiece EOS.
      if tok == IM_END:
        break
      if tok == spp.eos_id():
        break
    new_output = outputted[old_output_len:]
    # If the assistant's turn ends with a fenced python block, offer to run it.
    if new_output.endswith("```") and "```python\n" in new_output:
      python_code = new_output.split("```python\n")[1].split("```")[0]
      # AI safety. Warning to user. Do not press y if the AI is trying to do unsafe things.
      # NOTE(review): the f-string below has no placeholders — a plain string would do.
      if (
        input(colored(f" <-- PYTHON DETECTED, RUN IT? ", "red")).lower()
        == "y"
      ):
        # Capture stdout so the code's output can be fed back into the conversation.
        my_stdout = StringIO()
        try:
          with redirect_stdout(my_stdout):
            exec(python_code)
          result = my_stdout.getvalue()
        except Exception as e:
          # On failure, feed back just the exception line(s) instead of output.
          result = "".join(traceback.format_exception_only(e))
        # Append the execution result to the context so the model can see it.
        toks += spp.encode(f"\nOutput:\n```\n{result}```")
        outputted = output(outputted, toks, "yellow")
        old_output_len = len(outputted)
    print("")