Skip to content

Commit 6e65e35

Browse files
authored
Add with_prior_preservation option for dreambooth (microsoft#2301)
## Describe your changes Add with_prior_preservation option for dreambooth. This option will generate class image automatically for user. ## Checklist before requesting a review - [ ] Add unit tests for this change. - [ ] Make sure all tests can pass. - [ ] Update documents if necessary. - [ ] Lint and apply fixes to your code by running `lintrunner -a` - [ ] Is this a user-facing change? If yes, give a description of this change to be included in the release notes. ## (Optional) Issue link
1 parent cb3e7af commit 6e65e35

File tree

2 files changed

+240
-43
lines changed

2 files changed

+240
-43
lines changed

olive/cli/diffusion_lora.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,36 @@ def register_subcommand(parser: ArgumentParser):
9595
help="Fixed prompt for all images in DreamBooth mode. Required when --dreambooth is set. "
9696
"Example: 'a photo of sks dog'.",
9797
)
98+
db_group.add_argument(
99+
"--with_prior_preservation",
100+
action="store_true",
101+
help="Enable prior preservation to prevent language drift. Requires --class_prompt.",
102+
)
103+
db_group.add_argument(
104+
"--class_prompt",
105+
type=str,
106+
default=None,
107+
help="Prompt for class images in prior preservation. Required when --with_prior_preservation is set. "
108+
"Example: 'a photo of a dog'.",
109+
)
110+
db_group.add_argument(
111+
"--class_data_dir",
112+
type=str,
113+
default=None,
114+
help="Directory containing class images. If not provided or has fewer than --num_class_images, "
115+
"images will be auto-generated.",
116+
)
117+
db_group.add_argument(
118+
"--num_class_images",
119+
type=int,
120+
default=200,
121+
help="Number of class images for prior preservation. Default: 200.",
122+
)
98123
db_group.add_argument(
99124
"--prior_loss_weight",
100125
type=float,
101126
default=1.0,
102-
help="Weight of prior preservation loss (only for DreamBooth). Default: 1.0.",
127+
help="Weight of prior preservation loss. Default: 1.0.",
103128
)
104129

105130
# Data options
@@ -274,6 +299,10 @@ def _get_run_config(self, tempdir: str) -> dict:
274299
((*pass_key, "lora_dropout"), self.args.lora_dropout),
275300
((*pass_key, "dreambooth"), self.args.dreambooth),
276301
((*pass_key, "instance_prompt"), self.args.instance_prompt),
302+
((*pass_key, "with_prior_preservation"), self.args.with_prior_preservation),
303+
((*pass_key, "class_prompt"), self.args.class_prompt),
304+
((*pass_key, "class_data_dir"), self.args.class_data_dir),
305+
((*pass_key, "num_class_images"), self.args.num_class_images),
277306
((*pass_key, "prior_loss_weight"), self.args.prior_loss_weight),
278307
((*pass_key, "merge_lora"), self.args.merge_lora),
279308
(

0 commit comments

Comments
 (0)