Skip to content

Commit 3d2db4e

Browse files
authored
Update diffusion model checkpoint to be milabench checkpoint (#379)
* Update diffusion model checkpoint to be milabench checkpoint * Update main.py * Update prepare.py
1 parent 97f5ad2 commit 3d2db4e

File tree

3 files changed

+4
-3
lines changed

3 files changed

+4
-3
lines changed

benchmarks/diffusion/main.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
@dataclass
2323
class Arguments:
24-
model: str = "stabilityai/stable-diffusion-2"
24+
model: str = "Milabench/stable-diffusion-2"
2525
dataset: str = "lambdalabs/naruto-blip-captions"
2626
batch_size: int = 16
2727
num_workers: int = 8
@@ -252,4 +252,4 @@ def main():
252252

253253

254254
if __name__ == "__main__":
255-
main()
255+
main()

benchmarks/diffusion/prepare.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
@dataclass
88
class TrainingConfig:
9-
model: str = "stabilityai/stable-diffusion-2"
9+
model: str = "Milabench/stable-diffusion-2"
1010
dataset: str = "lambdalabs/naruto-blip-captions"
1111
revision: str = None
1212
variant: str = None

config/base.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,7 @@ _diffusion:
420420
method: per_gpu
421421

422422
argv:
423+
--model: "Milabench/stable-diffusion-2"
423424
--num_epochs: 5
424425
--batch_size: "auto_batch(32)"
425426
--num_workers: "auto({n_worker}, 8)"

0 commit comments

Comments
 (0)