-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhandler.py
More file actions
103 lines (84 loc) · 3.02 KB
/
handler.py
File metadata and controls
103 lines (84 loc) · 3.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import runpod
import os
import torch
# Pre-import torchvision to ensure it's initialized correctly
import torchvision
import uuid
import sys
import requests
from PIL import Image
from io import BytesIO
from utils.s3 import S3Client
# Add TRELLIS to path: the TRELLIS_REPO directory is resolved against the
# current working directory and appended to the import search path
# (presumably so the `trellis` package imported later resolves - confirm
# against the image layout).
_trellis_repo_dir = os.path.join(os.getcwd(), "TRELLIS_REPO")
sys.path.append(_trellis_repo_dir)
# Delay import until handler to allow lazy loading/env setup if needed.
# Holds the shared pipeline; stays None until get_pipeline() has both loaded
# the model and moved it to the GPU successfully.
pipeline = None


def get_pipeline():
    """Return the shared TRELLIS image-to-3D pipeline, creating it on first use.

    The first successful call imports the TRELLIS modules, loads the model
    with ``from_pretrained`` and calls ``.cuda()`` on it; later calls return
    the same cached object.

    Returns:
        The cached ``TrellisImageTo3DPipeline`` instance.

    Raises:
        Any exception raised by the imports, ``from_pretrained`` or
        ``cuda()``. Nothing is cached in that case, so the next call retries
        the whole load instead of handing out a half-initialised pipeline.
    """
    global pipeline
    if pipeline is None:
        # Pre-import required modules (kept from the original; the import is
        # only wanted for its side effects, the name itself is unused).
        import utils3d  # noqa: F401

        # Load the specific large image model
        # Using gqk fork since Microsoft deleted original repos (see https://github.com/quanticsoul4772/trellis-runpod-worker/commit/63020acd3bdad52cafd0d675f4ca9d56e3984159)
        from trellis.pipelines import TrellisImageTo3DPipeline

        model_repo = "gqk/TRELLIS-image-large-fork"
        print(f"Loading {model_repo}...")
        loaded = TrellisImageTo3DPipeline.from_pretrained(model_repo)
        loaded.cuda()
        # Publish to the module global only after the GPU move succeeded.
        # Assigning earlier would leave a pipeline cached even when cuda()
        # raised, and every later job would silently reuse it.
        pipeline = loaded
    return pipeline
# Initialize S3 Client
# Created once when this module is imported and reused by every handler()
# call to upload the exported .glb file.
s3_client = S3Client()
def _extract_image_url(job_input):
    """Return the image URL from *job_input*, or None when none is usable."""
    image_url = job_input.get('image_url')
    if image_url:
        return image_url
    images = job_input.get('images')
    # Only index into a real, non-empty sequence: an empty list or a None
    # value must fall through to the "nothing provided" error, not raise.
    if isinstance(images, (list, tuple)) and images:
        return images[0]
    return None


def _remove_quietly(path):
    """Delete *path* if it exists; log instead of raising on failure."""
    try:
        if os.path.exists(path):
            os.remove(path)
    except OSError as exc:
        print(f"Could not remove temporary file {path}: {exc}")


def handler(job):
    """
    RunPod Serverless Handler for TRELLIS 3D Generation.

    Downloads the input image, runs the TRELLIS pipeline on it, exports the
    result as a GLB file under /tmp, uploads that file through ``s3_client``
    and removes the local copy.

    Args:
        job: Job dict. ``job['input']`` must carry either ``image_url`` or a
            non-empty ``images`` list (its first entry is used). ``job['id']``
            is used as the task id; a random UUID is generated when absent.

    Returns:
        On success: ``{"status": "COMPLETED", "model_url": ..., "task_id": ...}``.
        When no image is supplied or the upload returns a falsy URL:
        ``{"error": ...}``.
        On any exception during processing:
        ``{"error": ..., "trace": ..., "status": "FAILED"}``.
    """
    job_input = job.get('input') or {}

    # Extract image_url
    image_url = _extract_image_url(job_input)
    if not image_url:
        return {"error": "No image_url or images array provided."}

    task_id = job.get('id') or str(uuid.uuid4())
    model_path = f"/tmp/{task_id}.glb"

    try:
        # 1. Get/Initialize Pipeline (Downloads weights if first run)
        pipe = get_pipeline()
        from trellis.utils import postprocessing_utils

        # 2. Download Image (with ngrok skip header)
        headers = {"ngrok-skip-browser-warning": "true"}
        response = requests.get(image_url, headers=headers, timeout=30)
        response.raise_for_status()
        input_image = Image.open(BytesIO(response.content))

        # 3. Run Inference
        outputs = pipe.run(input_image, seed=1)

        # 4. Post-process to GLB
        glb = postprocessing_utils.to_glb(
            outputs['gaussian'][0],
            outputs['mesh'][0],
            simplify=0.95,
            texture_size=1024,
        )
        glb.export(model_path)

        # 5. Upload to S3
        s3_path = f"models/{task_id}.glb"
        public_url = s3_client.upload_file(model_path, s3_path)

        if not public_url:
            return {"error": "S3 Upload failed."}

        return {
            "status": "COMPLETED",
            "model_url": public_url,
            "task_id": task_id
        }

    except Exception as e:
        # Top-level boundary: report the failure to the caller as a result
        # dict instead of letting the exception escape the handler.
        import traceback
        error_msg = str(e)
        trace = traceback.format_exc()
        print(f"Error during execution: {error_msg}")
        return {
            "error": error_msg,
            "trace": trace,
            "status": "FAILED"
        }
    finally:
        # Cleanup local file on every path (success, failed upload, or an
        # exception after export) so /tmp does not fill up across jobs.
        _remove_quietly(model_path)
# Hand `handler` to the RunPod serverless runtime as the job callback.
# NOTE(review): this call sits at module level with no
# `if __name__ == "__main__":` guard, so it runs whenever the module is
# imported - confirm nothing imports this file for other purposes.
runpod.serverless.start({"handler": handler})