-
Notifications
You must be signed in to change notification settings - Fork 51
Expand file tree
/
Copy pathpano_gen.py
More file actions
104 lines (95 loc) · 3.26 KB
/
pano_gen.py
File metadata and controls
104 lines (95 loc) · 3.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import torch
import tempfile
from pathlib import Path
from huggingface_hub import hf_hub_download
from .models.flux_pano_gen_pipeline import FluxPipeline
from .models.flux_pano_fill_pipeline import FluxFillPipeline
import re
def build_pano_gen_model(lora_path=None, device="cuda"):
    """Build the FLUX.1-dev text-to-panorama pipeline with the WorldGen LoRA.

    Args:
        lora_path: Path to the LoRA ``.safetensors`` weights. If ``None``,
            the ``worldgen_text2scene`` weights are downloaded from the
            ``LeoXie/WorldGen`` Hugging Face Hub repo.
        device: Target device string (e.g. ``"cuda"`` or ``"cuda:1"``); a
            trailing ``:N`` index is forwarded to CPU offload as the GPU id.

    Returns:
        A ``FluxPipeline`` with the LoRA loaded and model CPU offload plus
        VAE tiling enabled (both reduce peak VRAM usage).
    """
    if lora_path is None:
        # Plain string literal: the filename contains no interpolation.
        lora_path = hf_hub_download(
            repo_id="LeoXie/WorldGen",
            filename="models--WorldGen-Flux-Lora/worldgen_text2scene.safetensors",
        )
    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16, device=device)
    print(f"Loading LoRA weights from: {lora_path}")
    pipe.load_lora_weights(lora_path)
    # Mirror build_pano_fill_model: honor an explicit "cuda:N" device when
    # choosing the offload GPU; bare "cuda" falls back to GPU 0 as before.
    match = re.search(r"cuda:(\d+)", str(device))
    gpu_id = int(match.group(1)) if match else 0
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)  # Save VRAM
    pipe.enable_vae_tiling()
    return pipe
def build_pano_fill_model(lora_path=None, device="cuda:0"):
    """Build the FLUX.1-Fill-dev inpainting pipeline with the WorldGen LoRA.

    Args:
        lora_path: Path to the LoRA ``.safetensors`` weights. If ``None``,
            the ``worldgen_img2scene`` weights are downloaded from the
            ``LeoXie/WorldGen`` Hugging Face Hub repo.
        device: Target device string (e.g. ``"cuda:0"``); a trailing ``:N``
            index is forwarded to CPU offload as the GPU id.

    Returns:
        A ``FluxFillPipeline`` with the LoRA loaded and model CPU offload
        plus VAE tiling enabled (both reduce peak VRAM usage).
    """
    if lora_path is None:
        # Plain string literal: the filename contains no interpolation.
        lora_path = hf_hub_download(
            repo_id="LeoXie/WorldGen",
            filename="models--WorldGen-Flux-Lora/worldgen_img2scene.safetensors",
        )
    pipe = FluxFillPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-Fill-dev",
        dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    print(f"Loading LoRA weights from: {lora_path}")
    pipe.load_lora_weights(lora_path)
    # Extract the GPU index from e.g. "cuda:1"; default to 0 for bare "cuda".
    match = re.search(r"cuda:(\d+)", str(device))
    gpu_id = int(match.group(1)) if match else 0
    pipe.enable_model_cpu_offload(gpu_id=gpu_id)  # Save VRAM
    pipe.enable_vae_tiling()
    return pipe
def gen_pano_image(
    model,
    prompt="",
    output_path=None,
    seed=42,
    guidance_scale=7.0,
    num_inference_steps=50,
    height=800,
    width=1600,
    blend_extend=6,
    prefix="A high quality 360 panorama photo of",
    suffix="HDR, RAW, 360 consistent, omnidirectional",
):
    """Generate a panorama image with the text-to-panorama pipeline.

    The user prompt is wrapped with *prefix* and *suffix* (comma-separated)
    to steer the model toward 360-degree equirectangular output.

    Args:
        model: The pipeline built by ``build_pano_gen_model`` (or any
            compatible callable returning an object with ``.images``).
        prompt: Free-form scene description inserted between prefix/suffix.
        output_path: If given, the generated image is saved there.
        seed: Seed for the CPU ``torch.Generator`` (reproducibility).
        guidance_scale, num_inference_steps, height, width, blend_extend:
            Forwarded to the pipeline call unchanged.
        prefix, suffix: Prompt framing text.

    Returns:
        The first generated PIL image.
    """
    full_prompt = ", ".join([prefix, prompt, suffix])
    rng = torch.Generator("cpu").manual_seed(seed)
    result = model(
        full_prompt,
        height=height,
        width=width,
        generator=rng,
        num_inference_steps=num_inference_steps,
        blend_extend=blend_extend,
        guidance_scale=guidance_scale,
    )
    pano = result.images[0]
    if output_path is not None:
        pano.save(output_path)
        print(f"Panorama image saved to {output_path}")
    return pano
def gen_pano_fill_image(
    model,
    image,
    mask,
    prompt="a scene",
    output_path=None,
    seed=42,
    guidance_scale=30.0,
    num_inference_steps=50,
    height=800,
    width=1600,
    blend_extend=0,
    prefix="A high quality 360 panorama photo of",
    suffix="HDR, RAW, 360 consistent, omnidirectional",
):
    """Inpaint the masked region of a panorama with the fill pipeline.

    Args:
        model: The pipeline built by ``build_pano_fill_model`` (or any
            compatible callable returning an object with ``.images``).
        image: Source image; resized to ``(width, height)`` before the call.
        mask: Inpainting mask; resized to ``(width, height)`` before the call.
        prompt: Scene description, joined with *prefix*/*suffix* by spaces.
        output_path: If given, the result image is saved there.
        seed: Seed for the CPU ``torch.Generator`` (reproducibility).
        guidance_scale, num_inference_steps, height, width, blend_extend:
            Forwarded to the pipeline call unchanged.
        prefix, suffix: Prompt framing text.

    Returns:
        The first generated PIL image.
    """
    full_prompt = " ".join([prefix, prompt, suffix])
    rng = torch.Generator("cpu").manual_seed(seed)
    # Bring both inputs to the generation resolution before inpainting.
    resized_image = image.resize((width, height))
    resized_mask = mask.resize((width, height))
    result = model(
        full_prompt,
        height=height,
        width=width,
        image=resized_image,
        mask_image=resized_mask,
        generator=rng,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        blend_extend=blend_extend,
    )
    filled = result.images[0]
    if output_path is not None:
        filled.save(output_path)
        print(f"Panorama image saved to {output_path}")
    return filled