Skip to content

Commit c84461b

Browse files
committed
Update test.py
1 parent f95b0ba commit c84461b

File tree

2 files changed

+36
-104
lines changed

2 files changed

+36
-104
lines changed

configs/train/Adapter-XL-sketch.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ model:
1414
data:
1515
target: dataset.dataset_laion.WebDataModuleFromConfig_Laion_Lexica
1616
params:
17-
tar_base1: "/group/30042/public_datasets/LAION_6plus"
18-
tar_base2: "/group/30042/public_datasets/RestoreData/Lexica/WebDataset"
17+
tar_base1: "data/LAION_6plus"
18+
tar_base2: "data/WebDataset"
1919
batch_size: 2
2020
num_workers: 8
2121
multinode: True

test.py

Lines changed: 34 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,107 +1,39 @@
1-
from omegaconf import OmegaConf
1+
from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, EulerAncestralDiscreteScheduler, AutoencoderKL
2+
from diffusers.utils import load_image, make_image_grid
3+
from controlnet_aux.lineart import LineartDetector
24
import torch
3-
import os
4-
import cv2
5-
import datetime
6-
from huggingface_hub import hf_hub_url
7-
import subprocess
8-
import shlex
9-
import copy
10-
from basicsr.utils import tensor2img
11-
12-
from Adapter.Sampling import diffusion_inference
13-
from configs.utils import instantiate_from_config
14-
from Adapter.inference_base import get_base_argument_parser
15-
from Adapter.extra_condition.api import get_cond_model, ExtraCondition
16-
from Adapter.extra_condition import api
17-
18-
urls = {
19-
'TencentARC/T2I-Adapter':[
20-
'models_XL/adapter-xl-canny.pth', 'models_XL/adapter-xl-sketch.pth',
21-
'models_XL/adapter-xl-openpose.pth', 'third-party-models/body_pose_model.pth',
22-
'third-party-models/table5_pidinet.pth'
23-
]
24-
}
25-
26-
if os.path.exists('checkpoints') == False:
27-
os.mkdir('checkpoints')
28-
for repo in urls:
29-
files = urls[repo]
30-
for file in files:
31-
url = hf_hub_url(repo, file)
32-
name_ckp = url.split('/')[-1]
33-
save_path = os.path.join('checkpoints',name_ckp)
34-
if os.path.exists(save_path) == False:
35-
subprocess.run(shlex.split(f'wget {url} -O {save_path}'))
36-
37-
# config
38-
parser = get_base_argument_parser()
39-
parser.add_argument(
40-
'--model_id',
41-
type=str,
42-
default="stabilityai/stable-diffusion-xl-base-1.0",
43-
help='huggingface url to stable diffusion model',
44-
)
45-
parser.add_argument(
46-
'--config',
47-
type=str,
48-
default='configs/inference/Adapter-XL-sketch.yaml',
49-
help='config path to T2I-Adapter',
50-
)
51-
parser.add_argument(
52-
'--path_source',
53-
type=str,
54-
default='examples/dog.png',
55-
help='config path to the source image',
56-
)
57-
parser.add_argument(
58-
'--in_type',
59-
type=str,
60-
default='image',
61-
help='config path to the source image',
62-
)
63-
global_opt = parser.parse_args()
64-
global_opt.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
655

666
if __name__ == '__main__':
67-
config = OmegaConf.load(global_opt.config)
68-
# Adapter creation
69-
cond_name = config.model.params.adapter_config.name
70-
adapter_config = config.model.params.adapter_config
71-
adapter = instantiate_from_config(adapter_config).cuda()
72-
adapter.load_state_dict(torch.load(config.model.params.adapter_config.pretrained))
73-
cond_model = get_cond_model(global_opt, getattr(ExtraCondition, cond_name))
74-
process_cond_module = getattr(api, f'get_cond_{cond_name}')
75-
76-
# diffusion sampler creation
77-
sampler = diffusion_inference(global_opt.model_id)
7+
# load adapter
8+
adapter = T2IAdapter.from_pretrained(
9+
"TencentARC/t2i-adapter-lineart-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
10+
).to("cuda")
7811

79-
# diffusion generation
80-
cond = process_cond_module(
81-
global_opt,
82-
global_opt.path_source,
83-
cond_inp_type = global_opt.in_type,
84-
cond_model = cond_model
12+
# load euler_a scheduler
13+
model_id = 'stabilityai/stable-diffusion-xl-base-1.0'
14+
euler_a = EulerAncestralDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler")
15+
vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
16+
pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
17+
model_id, vae=vae, adapter=adapter, scheduler=euler_a, torch_dtype=torch.float16, variant="fp16",
18+
).to("cuda")
19+
pipe.enable_xformers_memory_efficient_attention()
20+
21+
line_detector = LineartDetector.from_pretrained("lllyasviel/Annotators").to("cuda")
22+
23+
url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/figs_SDXLV1.0/org_lin.jpg"
24+
image = load_image(url)
25+
image = line_detector(
26+
image, detect_resolution=384, image_resolution=1024
8527
)
86-
with torch.no_grad():
87-
adapter_features = adapter(cond)
88-
result = sampler.inference(
89-
prompt = global_opt.prompt,
90-
prompt_n = global_opt.neg_prompt,
91-
steps = global_opt.steps,
92-
adapter_features = copy.deepcopy(adapter_features),
93-
guidance_scale = global_opt.scale,
94-
size = (cond.shape[-2], cond.shape[-1]),
95-
seed= global_opt.seed,
96-
)
97-
98-
# save results
99-
root_results = os.path.join('results', cond_name)
100-
if not os.path.exists(root_results):
101-
os.makedirs(root_results)
102-
now = datetime.datetime.now()
103-
formatted_date = now.strftime("%Y-%m-%d")
104-
formatted_time = now.strftime("%H:%M:%S")
105-
im_cond = tensor2img(cond)
106-
cv2.imwrite(os.path.join(root_results, formatted_date+'-'+formatted_time+'_image.png'), result)
107-
cv2.imwrite(os.path.join(root_results, formatted_date+'-'+formatted_time+'_condition.png'), im_cond)
28+
29+
prompt = "Ice dragon roar, 4k photo"
30+
negative_prompt = "anime, cartoon, graphic, text, painting, crayon, graphite, abstract, glitch, deformed, mutated, ugly, disfigured"
31+
gen_images = pipe(
32+
prompt=prompt,
33+
negative_prompt=negative_prompt,
34+
image=image,
35+
num_inference_steps=30,
36+
adapter_conditioning_scale=0.8,
37+
guidance_scale=7.5,
38+
).images[0]
39+
gen_images.save('out_lin.png')

0 commit comments

Comments (0)