-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathstyle_aligned_sd1.py
More file actions
75 lines (64 loc) · 2.72 KB
/
style_aligned_sd1.py
File metadata and controls
75 lines (64 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import argparse
import json
from diffusers import DDIMScheduler, StableDiffusionPipeline
import torch
import numpy as np
import random
from models.stylealigned import sa_handler
import os
def setup_seed(seed=1234):
    """Seed every random number generator the script relies on.

    Seeds Python's ``random`` module, NumPy's global RNG, the torch CPU RNG
    and the RNGs of all CUDA devices with the same ``seed``. It also switches
    cuDNN to deterministic algorithms with autotuning (benchmark mode)
    disabled, so repeated runs pick the same kernels.

    Args:
        seed: Integer seed shared by all generators (default 1234).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
# Set up argument parser
def parse_args(argv=None):
    """Parse the command-line options of this generation script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, which
            matches the original behaviour. Passing a list allows tests and
            programmatic callers to invoke the parser directly.

    Returns:
        argparse.Namespace with string attributes ``prompts_file`` and
        ``output_dir``. Both options are required, and argparse exits with
        status 2 if either is missing.
    """
    parser = argparse.ArgumentParser(description="Generate images using Stable Diffusion")
    parser.add_argument('--prompts_file', type=str, required=True,
                        help="Path to the JSON file containing a list of prompts")
    parser.add_argument('--output_dir', type=str, required=True,
                        help="Directory to save the generated images")
    return parser.parse_args(argv)
# Parse the arguments
args = parse_args()

# Load prompts from the JSON file. JSON text is UTF-8 by specification, so read
# it explicitly as UTF-8 rather than with the platform's default encoding.
with open(args.prompts_file, 'r', encoding='utf-8') as f:
    sets_of_prompts = json.load(f)

# Ensure the prompts are in a list format
if not isinstance(sets_of_prompts, list):
    raise ValueError("The JSON file must contain a list of prompts.")
# Each entry is read below as entry['prompt'] (text appended after every
# object) and entry['objects'] (the objects themselves). Validate the entries
# up front so a malformed file fails before the expensive model load, rather
# than with a KeyError partway through a run.
for idx, entry in enumerate(sets_of_prompts):
    if not isinstance(entry, dict) or 'prompt' not in entry or 'objects' not in entry:
        raise ValueError(
            f"Entry {idx} of the prompts file must be an object with "
            f"'prompt' and 'objects' keys."
        )

# Define output directory
output_dir = args.output_dir
os.makedirs(output_dir, exist_ok=True)

# Seed for the global RNGs (via setup_seed) and for the per-pair generator below.
SEED = 1234


def _safe_dirname(text):
    """Return *text* as a single path component usable under ``output_dir``.

    Path separators are replaced with '_'. Without this, a prompt containing
    '/' would create nested folders, and a prompt starting with '/' would make
    os.path.join discard ``output_dir`` entirely. The names '', '.' and '..'
    are also rewritten so they cannot resolve to output_dir or its parent.
    Prompts without these characters map to themselves, so existing folder
    names are unchanged.
    """
    for sep in (os.sep, os.altsep, '/'):
        if sep:
            text = text.replace(sep, '_')
    if text.strip('.') == '':
        text = '_' * max(len(text), 1)
    return text


scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False,
                          set_alpha_to_one=False)
# NOTE(review): recent diffusers releases deprecate revision="fp16" in favour of
# variant="fp16"; the call is left unchanged here so the loaded weights stay
# the same.
pipeline = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="fp16",
    scheduler=scheduler
)
pipeline = pipeline.to("cuda")
setup_seed(SEED)

for prompt_set in sets_of_prompts:
    prompts = [f"{obj} {prompt_set['prompt']}" for obj in prompt_set['objects']]
    # Each generation pairs prompts[0] (the reference) with one later prompt,
    # so a set with fewer than two objects produces nothing. Report it instead
    # of skipping it silently.
    if len(prompts) < 2:
        print(f"Skipping prompt set {prompt_set['prompt']!r}: "
              f"need at least 2 objects, got {len(prompts)}.")
        continue

    output_folder = os.path.join(output_dir, _safe_dirname(prompt_set['prompt']))
    os.makedirs(output_folder, exist_ok=True)
    reference_path = os.path.join(output_folder, "generated_image_0.png")

    for i in range(1, len(prompts)):
        # NOTE(review): a fresh Handler is created and registered for every
        # pair, and remove() is never called (kept as in the original). Confirm
        # in models/stylealigned/sa_handler that repeated register() calls do
        # not stack patches on the pipeline.
        handler = sa_handler.Handler(pipeline)
        sa_args = sa_handler.StyleAlignedArgs(share_group_norm=True,
                                              share_layer_norm=True,
                                              share_attention=True,
                                              adain_queries=True,
                                              adain_keys=True,
                                              adain_values=False,)
        handler.register(sa_args)
        # Run StyleAligned. Re-seed a dedicated generator for every pair so the
        # initial noise of the prompts[0] slot is identical in every call.
        # With generator=None the pipeline drew fresh latents from the global
        # RNG each call, so pairs with i >= 2 used a different reference than
        # the one saved as generated_image_0.png.
        generator = torch.Generator(device="cuda").manual_seed(SEED)
        images = pipeline([prompts[0], prompts[i]], generator=generator).images
        # Save images: the reference only once, then one image per object.
        if not os.path.exists(reference_path):
            images[0].save(reference_path)
        images[1].save(os.path.join(output_folder, f"generated_image_{i}.png"))