Skip to content
This repository was archived by the owner on Apr 2, 2024. It is now read-only.
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,9 @@ dmypy.json
.pyre/
.idea
.lightning

# weights
sd_weights.tar.gz
v1-5-pruned-emaonly.ckpt
768-v-ema.ckpt
512-base-ema.ckpt
5 changes: 5 additions & 0 deletions .lightningignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,8 @@ venv/
.git/
.idea/

# weights
sd_weights.tar.gz
v1-5-pruned-emaonly.ckpt
768-v-ema.ckpt
512-base-ema.ckpt
17 changes: 9 additions & 8 deletions app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# !pip install lightning_api_access
# !pip install lightning_api_access streamlit pandas
# !pip install 'git+https://github.com/Lightning-AI/stablediffusion.git@lit'
# !curl https://raw.githubusercontent.com/Lightning-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml -o v2-inference-v.yaml
import time
Expand Down Expand Up @@ -43,6 +43,7 @@ def setup(self):
def predict(self, requests):
start = time.time()
batch_size = len(requests.inputs)
print(batch_size)
texts = [request.text for request in requests.inputs]

images = self._model.predict_step(
Expand All @@ -57,7 +58,7 @@ def predict(self, requests):
image_str = base64.b64encode(buffer.getvalue()).decode("utf-8")
results.append(image_str)

print(f"finish predicting with batch size {batch_size} in {time.time() - start} seconds")
print(f"Finish predicting with batch size {batch_size} in {time.time() - start} seconds")
return BatchResponse(outputs=[{"image": image_str} for image_str in results])


Expand All @@ -67,15 +68,15 @@ def predict(self, requests):

# autoscaler args
min_replicas=1,
max_replicas=3,
max_replicas=1,
endpoint="/predict",
autoscale_up_interval=0,
autoscale_down_interval=1800, # 30 minutes
max_batch_size=8,
timeout_batching=2,
scale_out_interval=0,
scale_in_interval=30,
max_batch_size=7,
timeout_batching=5,
input_type=Text,
output_type=Image,
cold_start_proxy=CustomColdStartProxy(),
# cold_start_proxy=CustomColdStartProxy(),
)

app = L.LightningApp(component)
63 changes: 63 additions & 0 deletions app_emulator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# !pip install lightning_api_access
# !pip install 'git+https://github.com/Lightning-AI/stablediffusion.git@lit'
# !curl https://raw.githubusercontent.com/Lightning-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml -o v2-inference-v.yaml
import lightning as L
import os
from time import sleep

from autoscaler import AutoScaler
from datatypes import BatchText, BatchResponse, Text, Image

# Let PyTorch fall back to CPU for ops not implemented on Apple's MPS
# backend — presumably for local macOS development runs; no-op elsewhere.
# TODO(review): confirm this is still needed once running on CUDA only.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


class EmulatorDiffusionServer(L.app.components.PythonServer):
    """Fake diffusion server that emulates inference latency with sleeps.

    Stands in for the real Stable Diffusion server so autoscaling and
    batching behavior can be exercised without GPUs or model weights.
    """

    # Emulated latency (seconds) per batch size, mimicking the real server.
    _SLEEP_TIMES = {1: 10, 2: 16, 3: 24, 4: 30, 5: 36, 6: 46, 7: 53, 8: 59}

    def __init__(self, *args, **kwargs):
        super().__init__(
            input_type=BatchText,
            output_type=BatchResponse,
            *args,
            **kwargs,
        )

    def setup(self):
        # Nothing to load — the real server would fetch weights here.
        pass

    def predict(self, requests):
        """Sleep for a batch-size-dependent time, then return dummy images.

        Fix: the original indexed ``sleep_times[batch_size]`` directly,
        raising ``KeyError`` for any batch size outside 1..8 (e.g. an
        empty batch, or a larger ``max_batch_size``). The batch size is
        now clamped into the table's range so unexpected sizes degrade
        gracefully instead of crashing.
        """
        batch_size = len(requests.inputs)

        print(batch_size)

        # Clamp into the known 1..8 range rather than KeyError-ing.
        clamped = min(max(batch_size, 1), max(self._SLEEP_TIMES))
        sleep(self._SLEEP_TIMES[clamped])
        return BatchResponse(outputs=[{"image": "image_str"} for _ in range(batch_size)])


# Wire the latency emulator into the project's AutoScaler (see
# autoscaler.py) so scaling behavior can be observed end to end.
component = AutoScaler(
    EmulatorDiffusionServer,  # The component to scale
    cloud_compute=L.CloudCompute("gpu-rtx", disk_size=80),

    # autoscaler args
    min_replicas=1,
    max_replicas=1,  # pinned to a single replica for this emulation run
    endpoint="/predict",
    scale_out_interval=0,   # scale out immediately when load demands it
    scale_in_interval=30,   # wait 30s of low load before scaling in
    max_batch_size=6,       # NOTE(review): emulator's sleep table covers up to 8
    timeout_batching=4,     # seconds to wait while accumulating a batch
    input_type=Text,
    output_type=Image,
    # cold_start_proxy=CustomColdStartProxy(),
)

app = L.LightningApp(component)
Loading