-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathrun_gradio.py
More file actions
92 lines (82 loc) · 2.83 KB
/
run_gradio.py
File metadata and controls
92 lines (82 loc) · 2.83 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import sys
import torch
from stable_audio_3 import StableAudioModel
from stable_audio_3.interface.diffusion_cond import create_diffusion_cond_ui
from stable_audio_3.verbose import set_verbose
# Silence Python warnings (FutureWarning, DeprecationWarning, etc.) unless --verbose.
# Must run before any ML library imports since most warnings fire at import time.
# We keep Gradio chatter, HF/torch progress bars, and generation tqdm intact.
if "--verbose" not in sys.argv:
import os as _os
_os.environ.setdefault("PYTHONWARNINGS", "ignore")
import warnings as _warnings
_warnings.filterwarnings("ignore")
def main(args):
set_verbose(getattr(args, "verbose", False))
torch.manual_seed(42)
model_half = args.model_half
model = StableAudioModel.from_pretrained(args.model, model_half=model_half)
if args.lora_ckpt_path:
model.load_lora(args.lora_ckpt_path)
interface = create_diffusion_cond_ui(
model,
gradio_title=args.title if args.title is not None else "Stable Audio 3",
default_prompt=args.default_prompt,
)
interface.queue()
interface.launch(
share=True,
js=getattr(interface, "_sao_js", None),
theme=getattr(interface, "_sao_theme", None),
)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run gradio interface")
parser.add_argument(
"--model", type=str, help="Name of pretrained model", required=True
)
parser.add_argument(
"--model-config", type=str, help="Path to model config", required=False
)
parser.add_argument(
"--ckpt-path", type=str, help="Path to model checkpoint", required=False
)
parser.add_argument(
"--pretransform-ckpt-path",
type=str,
help="Optional to model pretransform checkpoint",
required=False,
)
parser.add_argument("--username", type=str, help="Gradio username", required=False)
parser.add_argument("--password", type=str, help="Gradio password", required=False)
parser.add_argument(
"--model-half",
action="store_true",
help="Whether to use half precision",
required=False,
default=True,
)
parser.add_argument(
"--title", type=str, help="Display Title top of Gradio", required=False
)
parser.add_argument(
"--lora-ckpt-path",
type=str,
nargs="*",
help="Path(s) for LoRA(s) to apply. Can specify multiple.",
required=False,
)
parser.add_argument(
"--default-prompt",
type=str,
default=None,
help="Default prompt to pre-fill in the textbox",
)
parser.add_argument(
"--verbose",
action="store_true",
default=False,
help="Print detailed load/generation progress",
)
args = parser.parse_args()
main(args)