
Commit 2c596c6

Document Sync by Tina
1 parent 7b8a092 commit 2c596c6

File tree

2 files changed: +154 −1 lines changed


docs/stable/getting_started/installation.md

Lines changed: 9 additions & 0 deletions
@@ -32,3 +32,12 @@ conda activate sllm-worker
pip install -e ".[worker]"
pip install -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ serverless_llm_store==0.0.1.dev3
```

# vLLM Patch

To use vLLM with ServerlessLLM, we need to apply our patch `serverless_llm/store/vllm_patch/sllm_load.patch` to the installed vLLM library. Currently, the patch has only been tested with vLLM version `0.5.0`.

You can apply it by running the following commands:

```bash
VLLM_PATH=$(python -c "import vllm; import os; print(os.path.dirname(os.path.abspath(vllm.__file__)))")
patch -p2 -d $VLLM_PATH < serverless_llm/store/vllm_patch/sllm_load.patch
```
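
Because the patch has only been tested with vLLM `0.5.0`, you may want to confirm the installed version first. A minimal check from a Python shell (assuming vLLM is already installed in the `sllm-worker` environment):

```python
# Confirm the installed vLLM version before applying the patch;
# the patch has only been tested against 0.5.0.
import vllm

assert vllm.__version__ == "0.5.0", f"Untested vLLM version: {vllm.__version__}"
print(f"vLLM {vllm.__version__} found, OK to apply the patch.")
```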

docs/stable/store/quickstart.md

Lines changed: 145 additions & 1 deletion
@@ -105,4 +105,148 @@ outputs = model.generate(**inputs)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```

4. Clean up by pressing "Ctrl+C" to stop the server process.

## Usage with vLLM

To use ServerlessLLM as a load format for vLLM, you need to apply our patch `serverless_llm/store/vllm_patch/sllm_load.patch` to the installed vLLM library. Please make sure you have read and followed the steps in the `vLLM Patch` section of our [installation guide](../getting_started/installation.md).
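
If you are unsure whether the patch has already been applied, one quick sanity check is to list vLLM's registered load formats. This is only a sketch: it assumes the patch adds a `serverless_llm`-style entry to vLLM's `LoadFormat` enum, and the exact entry name may differ.

```python
# List the load formats known to the installed vLLM; after the patch is
# applied, a serverless_llm-related entry is expected to show up
# (assumption: the patch extends vllm.config.LoadFormat).
from vllm.config import LoadFormat

print([fmt.value for fmt in LoadFormat])
```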
Our API aims to be compatible with the `sharded_state` load format in vLLM. However, because vLLM modifies the model architecture, the model format used for vLLM is **not** the same as the one used with transformers. In the sections below, `ServerlessLLM format` therefore refers to the format integrated with vLLM, which differs from the `ServerlessLLM format` used in the previous sections.

For first-time users, this means you have to load the model from another backend and then convert it to the ServerlessLLM format.

1. Download the model from HuggingFace and save it in the ServerlessLLM format:
```python
import os
import shutil
from typing import Optional


class VllmModelDownloader:
    def __init__(self):
        pass

    def download_vllm_model(
        self,
        model_name: str,
        torch_dtype: str,
        tensor_parallel_size: int = 1,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ):
        import gc
        import shutil
        from tempfile import TemporaryDirectory

        import torch
        from huggingface_hub import snapshot_download
        from vllm import LLM
        from vllm.config import LoadFormat

        def _run_writer(input_dir, output_dir):
            # load models from the input directory
            llm_writer = LLM(
                model=input_dir,
                download_dir=input_dir,
                dtype=torch_dtype,
                tensor_parallel_size=tensor_parallel_size,
                num_gpu_blocks_override=1,
                enforce_eager=True,
                max_model_len=1,
            )
            model_executer = llm_writer.llm_engine.model_executor
            # save the models in the ServerlessLLM format
            model_executer.save_serverless_llm_state(
                path=output_dir, pattern=pattern, max_size=max_size
            )
            for file in os.listdir(input_dir):
                # Copy the metadata files into the output directory
                if os.path.splitext(file)[1] not in (
                    ".bin",
                    ".pt",
                    ".safetensors",
                ):
                    src_path = os.path.join(input_dir, file)
                    dest_path = os.path.join(output_dir, file)
                    if os.path.isdir(src_path):
                        shutil.copytree(src_path, dest_path)
                    else:
                        shutil.copy(src_path, output_dir)
            del model_executer
            del llm_writer
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

        # set the model storage path
        storage_path = os.getenv("STORAGE_PATH", "./models")
        model_dir = os.path.join(storage_path, model_name)

        # create the output directory
        if os.path.exists(model_dir):
            print(f"Already exists: {model_dir}")
            return
        os.makedirs(model_dir, exist_ok=True)

        try:
            with TemporaryDirectory() as cache_dir:
                # download model from huggingface
                input_dir = snapshot_download(
                    model_name,
                    cache_dir=cache_dir,
                    allow_patterns=["*.safetensors", "*.bin", "*.json", "*.txt"],
                )
                _run_writer(input_dir, model_dir)
        except Exception as e:
            print(f"An error occurred while saving the model: {e}")
            # remove the output dir
            shutil.rmtree(model_dir)
            raise RuntimeError(
                f"Failed to save model {model_name} for vllm backend: {e}"
            )


downloader = VllmModelDownloader()
downloader.download_vllm_model("facebook/opt-1.3b", "float16", 1)
```
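
The same helper can also save a model for multi-GPU serving by passing a larger `tensor_parallel_size`. The sketch below is illustrative only: it assumes at least two GPUs are available, uses `facebook/opt-2.7b` purely as an example, and the saved checkpoint should later be loaded with the same tensor-parallel degree.

```python
# Illustrative only: save a model for 2-way tensor parallelism
# (requires at least 2 GPUs; load it later with tensor_parallel_size=2 as well).
downloader = VllmModelDownloader()
downloader.download_vllm_model(
    "facebook/opt-2.7b", "float16", tensor_parallel_size=2
)
```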

After downloading the model, you can launch the checkpoint store server and load the model in vLLM through the `serverless_llm` load format.

2. Launch the checkpoint store server in a separate process:
```bash
# 'mem_pool_size' is the maximum size of the memory pool in GB. It should be larger than the model size.
sllm-store-server --storage_path $PWD/models --mem_pool_size 32
```
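
Note that `--storage_path` must point at the directory the converter wrote to in step 1 (the `STORAGE_PATH` environment variable, `./models` by default). A small sanity check, assuming the defaults used above:

```python
# Verify the converted model is present under the storage path that
# sllm-store-server will be pointed at (assumes the defaults from step 1).
import os

storage_path = os.getenv("STORAGE_PATH", "./models")
model_dir = os.path.join(storage_path, "facebook/opt-1.3b")
assert os.path.isdir(model_dir), f"Converted model not found at {model_dir}"
print(sorted(os.listdir(model_dir))[:10])
```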

3. Load the model in vLLM:
```python
from vllm import LLM, SamplingParams

import os

storage_path = os.getenv("STORAGE_PATH", "./models")
model_name = "facebook/opt-1.3b"
model_path = os.path.join(storage_path, model_name)

llm = LLM(
    model=model_path,
    load_format="serverless_llm",
    dtype="float16"
)

prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(prompts, sampling_params)

# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```
