# LML-diffusion-sampler / scripts / StableDiffusion_COCO.py
# 王方懿康
# Initial commit
# ab2369a
# Raw
# History Blame Contribute Delete
# 4.97 kB
import sys
import torch
import os
import json
import argparse
sys.path.append(os.getcwd())
from diffusers import StableDiffusionPipeline, DPMSolverMultistepLMScheduler, DDIMLMScheduler, PNDMScheduler, UniPCMultistepScheduler
from scheduler.scheduling_dpmsolver_multistep_lm import DPMSolverMultistepLMScheduler
from scheduler.scheduling_ddim_lm import DDIMLMScheduler
# Sampler names this script knows how to configure. Any other value used to
# fall through silently to the pipeline's default scheduler while still being
# written into the output filenames, so argparse now rejects it.
_SAMPLER_TYPES = ('dpm', 'dpm_lm', 'dpm++', 'dpm++_lm', 'pndm', 'ddim', 'ddim_lm', 'unipc')

# sampler_type -> (algorithm_type, lm flag) for the DPMSolverMultistepLMScheduler variants.
_DPM_VARIANTS = {
    'dpm': ('dpmsolver', False),
    'dpm_lm': ('dpmsolver', True),
    'dpm++': ('dpmsolver++', False),
    'dpm++_lm': ('dpmsolver++', True),
}

# sampler_type -> lm flag for the DDIMLMScheduler variants.
_DDIM_VARIANTS = {'ddim': False, 'ddim_lm': True}

# Prompt file that used to be hard-coded; kept as the --prompts_json default.
_DEFAULT_PROMPTS_JSON = '/mnt/chongqinggeminiceph1fs/geminicephfs/mm-base-vision/pazelzhang/make_dataset/fid_3W_json.json'


def _parse_args(argv=None):
    """Parse and validate command-line options (reads sys.argv when argv is None)."""
    parser = argparse.ArgumentParser(description="sampling script for COCO14.")
    parser.add_argument('--test_num', type=int, default=1000)
    parser.add_argument('--start_index', type=int, default=0)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--num_inference_steps', type=int, default=20)
    parser.add_argument('--guidance', type=float, default=7.5)
    parser.add_argument('--sampler_type', type=str, default='ddim', choices=_SAMPLER_TYPES)
    parser.add_argument('--model_id', type=str, default='/xxx/xxx/stable-diffusion-v1-5')
    parser.add_argument('--save_dir', type=str, default='/xxx/xxx')
    parser.add_argument('--lamb', type=float, default=5.0)
    parser.add_argument('--kappa', type=float, default=0.0)
    parser.add_argument('--device', type=str, default='cuda')
    parser.add_argument('--prompts_json', type=str, default=_DEFAULT_PROMPTS_JSON,
                        help="JSON object whose keys are used as image ids and whose values are used as prompts.")
    args = parser.parse_args(argv)
    if args.start_index < 0 or args.test_num < 0:
        parser.error("--start_index and --test_num must be non-negative")
    return args


def _configure_scheduler(sd_pipe, sampler_type, lamb, kappa):
    """Replace ``sd_pipe.scheduler`` with the scheduler selected by ``sampler_type``.

    DPM variants get a DPMSolverMultistepLMScheduler with solver_order 3 and the
    mapped algorithm_type. DDIM variants get a DDIMLMScheduler. Both LM schedulers
    also receive the ``lamb``/``lm`` attributes, and DDIM additionally ``kappa``;
    their meaning is defined by those project schedulers.

    Raises:
        ValueError: if ``sampler_type`` is not one of ``_SAMPLER_TYPES``.
    """
    base_config = sd_pipe.scheduler.config
    if sampler_type in _DPM_VARIANTS:
        algorithm_type, use_lm = _DPM_VARIANTS[sampler_type]
        sd_pipe.scheduler = DPMSolverMultistepLMScheduler.from_config(base_config)
        # Set on the config after construction (not passed to from_config), exactly as
        # before; passing them to __init__ may hit init-time validation.
        # NOTE(review): relies on the scheduler reading these config fields at run time.
        sd_pipe.scheduler.config.solver_order = 3
        sd_pipe.scheduler.config.algorithm_type = algorithm_type
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = use_lm
    elif sampler_type in _DDIM_VARIANTS:
        sd_pipe.scheduler = DDIMLMScheduler.from_config(base_config)
        sd_pipe.scheduler.lamb = lamb
        sd_pipe.scheduler.lm = _DDIM_VARIANTS[sampler_type]
        sd_pipe.scheduler.kappa = kappa
    elif sampler_type == 'pndm':
        sd_pipe.scheduler = PNDMScheduler.from_config(base_config)
    elif sampler_type == 'unipc':
        sd_pipe.scheduler = UniPCMultistepScheduler.from_config(base_config)
    else:
        raise ValueError(f"unknown sampler_type: {sampler_type!r}")


def _load_prompts(path):
    """Load the prompt JSON; its keys are used as image ids and its values as prompts."""
    with open(path, encoding="utf-8") as fr:
        return json.load(fr)


def main():
    """Render a contiguous window of prompts from the COCO prompt file with one sampler.

    Prompts at positions [start_index, start_index + test_num) in the JSON's key
    order are each sampled once with ``--seed``. Each image is saved as a JPEG in
    ``--save_dir``, with the index, key, guidance, step count, seed and sampler
    encoded in the filename.
    """
    from itertools import islice

    args = _parse_args()
    device = args.device
    sampler_type = args.sampler_type
    guidance_scale = args.guidance
    num_inference_steps = args.num_inference_steps
    save_dir = args.save_dir

    # Load prompts before the (large) model so a bad path fails fast.
    prompts = _load_prompts(args.prompts_json)

    sd_pipe = StableDiffusionPipeline.from_pretrained(args.model_id, torch_dtype=torch.float32, safety_checker=None)
    sd_pipe = sd_pipe.to(device)
    print("sd model loaded")
    _configure_scheduler(sd_pipe, sampler_type, args.lamb, args.kappa)

    os.makedirs(save_dir, exist_ok=True)

    # Only the requested window is visited; the enumerate index stays absolute.
    window = islice(prompts.items(), args.start_index, args.start_index + args.test_num)
    with torch.no_grad():
        for pi, (key, prompt) in enumerate(window, start=args.start_index):
            print(key)
            print(prompt)
            # A fresh generator seeded with --seed for every prompt (as before). It is
            # created on --device; it used to be hard-coded to 'cuda', which broke --device cpu.
            generator = torch.Generator(device=device).manual_seed(args.seed)
            image = sd_pipe(prompt=prompt, negative_prompt=None, num_inference_steps=num_inference_steps,
                            guidance_scale=guidance_scale, generator=generator).images[0]
            # The filename records the seed actually used (it was always 'seed1' before).
            file_name = f"{pi:05d}_{key}_guidance{guidance_scale}_inference{num_inference_steps}_seed{args.seed}_{sampler_type}.jpg"
            image.save(os.path.join(save_dir, file_name))
            print(f"{sampler_type}##{key},done")
if __name__ == "__main__":
    # Run the sampler only when executed as a script, never on import.
    main()