-
Notifications
You must be signed in to change notification settings - Fork 494
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[RFC] Image Generation Dataset #2140
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2140
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 4e6b320 with merge base 06a8379 (): This comment was automatically generated by Dr. CI and updates every 15 minutes. |
Saw this RFC and want to pen down some of my thoughts as I'm also building some fine-tuning pipeline for Flux.
|
def __call__(self, caption: str) -> str: ... | ||
|
||
|
||
class ImgTextDataset(torch.utils.data.Dataset): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this should basically look like the text completion dataset but with model_transform as you have here and column_map instead of column. I also don't know if we want to anchor this to vision as diffusion for audio etc would use the same pattern. I think an optional data_transform would work here instead of img/text.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should include an image transform here that way we can separate model-independent image augmentations from model-specific ones in the model transform.
Also, I don't think this should be a generic dataset class for diffusion in general. It shouldn't be tied to diffusion at all, and instead be for any downstream task that uses image-text pairs, e.g. non-diffusion image gen models, image captioning models, image-text joint encoders, etc. There were a lot of papers at NeurIPS this year that was finetuning CLIP. I would expect this to utilize the same ImageTextDataset as finetuning Flux would. If you're doing diffusion for audio, you would use an AudioTextDataset
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with you WRT a dataset class (Hence why everything so far essentially returns an SFT dataset). However, I do think there's tremendous value in aligning our dataset builders with specific tasks. It makes it easier to utilize from configs and find datasets to use on the Hub.
return data_dict | ||
|
||
|
||
class FluxTransform(Transform): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Generic diffusion model transform will just take a dict instead of a list of messages but otherwise be the same
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is logic here that is specific to Flux and I think it should exist withing a Flux-specific model transform
... | ||
|
||
|
||
def _build_torchvision_transforms(cfg): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This along with CaptionTransform is all within the abstraction of model transform or data transform as the user needs. Or is this meant to be an example?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we should separate the data transform logic from the model transform logic, e.g. the data augmentations like horizontal flip would be in a img transform that's entirely separate from model logic, and the model-specific logic like image normalization would be in the model transform
) | ||
|
||
|
||
def _load_img_text_dataset(path): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We use huggingface load_dataset as well as load_image
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Regarding load_image, thanks I'll switch to this. Question though: when the image path is a URL, should we include the option for saving these images to disk so that they don't need to be re-downloaded during the next epoch?
Regarding load_dataset, I address this in the first bullet of the user experience section. I personally think it's better if we handle simple cases like loading a image-caption TSV ourselves so the user doesn't have to go read huggingface docs, especially since most img gen finetuning will be done on small local datasets, but I'm also ok with just relying on huggingface's load_dataset since that does make our code simpler
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Totally understand your point on not wanting to overcomplicate things, but using load_dataset under the hood makes our lives way easier lol
|
||
```yaml | ||
dataset: | ||
_component_: torchtune.datasets.img_caption_dataset |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe just my naivate, but when I hear image-caption dataset, I assume it's a dataset for taking an image and generating a caption, which is not the case here.
Hugging Face has a label for these datasets called "Text-to-Image", which I think is a more accurate description. This also is inline with our addition of task-centered dataset builders like the vqa_dataset.
Concretely proposing changing the default dataset for diffusion from img_caption_dataset
to text_to_image_dataset
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I figured that this dataset could be used for any downstream task that uses pairs of images+text. Like finetuning CLIP for example. Maybe image_text_pair_dataset
? Or is it more clear for the user if we name the datasets based on a specific use of it?
```yaml | ||
dataset: | ||
_component_: torchtune.datasets.img_caption_dataset | ||
path: ~/my_dataset/data.tsv |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it more common to have ahem private data to finetune diffusion models, or data that might be published on the Hugging Face Hub? That should affect what the first-class citizen is here and what goes in all our examples.
Regardless, if we're using the load_dataset
functionality from Hugging Face (like we do for all our other datasets including image-to-text), why does this not follow the same format where we specify e.g. TSV as the source and data_files=~/my_dataset/data.tsv?
resize: [256, 256] | ||
center_crop: true | ||
horizontal_flip: 0.5 | ||
caption_transform: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See above comment, but would opt for text
not caption
here.
caption_transform: | ||
drop: 0.05 | ||
shuffle_parts: 0.1 | ||
tokenizer: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you show how this would look from code? I know we prefer flattened params for our configs, but if this was build via code I'd imagine we'd instantiate Clip and T5 and then pass that to our FluxTransform - right?
model_transform: Transform, | ||
*, | ||
path: str, | ||
img_transform: Config, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we ever want our builders to see the notion of configs. Configs are just a way to interface with our recipes, but builders should be able to be dropped into place anywhere.
def __call__(self, caption: str) -> str: ... | ||
|
||
|
||
class ImgTextDataset(torch.utils.data.Dataset): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree with you WRT a dataset class (Hence why everything so far essentially returns an SFT dataset). However, I do think there's tremendous value in aligning our dataset builders with specific tasks. It makes it easier to utilize from configs and find datasets to use on the Hub.
# User Experience | ||
|
||
- Regarding loading the TSV/Parquet/whatever data file, should we just rely on huggingface's `load_dataset` like we currently do in `SFTDataset`? It keeps the code simpler, but it makes the user leave torchtune and go read the huggingface docs, which is overkill if they just have some simple JSON file we could easily load ourselves. | ||
- In addition to absolute image paths in the data file, we should probably support image paths relative to the dataset folder, because it would be super annoying if you had to regenerate your data file any time to move the dataset to a new location. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is handled via our current image/text dataset utilities.
- Regarding loading the TSV/Parquet/whatever data file, should we just rely on huggingface's `load_dataset` like we currently do in `SFTDataset`? It keeps the code simpler, but it makes the user leave torchtune and go read the huggingface docs, which is overkill if they just have some simple JSON file we could easily load ourselves. | ||
- In addition to absolute image paths in the data file, we should probably support image paths relative to the dataset folder, because it would be super annoying if you had to regenerate your data file any time to move the dataset to a new location. | ||
- There's currently some potentially unnecessary fields in the config. For example with Flux models, the model determines the image size and the T5 tokenizer sequence length. Is it better to pass this information to the image transform and model transform, respectively? Which complicates the code but lowers the chance of user error. Or is it better to have the user define these values in the dataset config and tokenizer config, respectively? Which puts the burden on the user to match what the model expects. | ||
- Should we add scripts/utilities for inspecting the dataset? It's nice to see a preview of what a batch looks like, especially when you're messing around with color jitter and other hard-to-configure image augmentations. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely a cool feature, but probably a P2 or upon-request-from-users type of thing.
|
||
# Other | ||
- Naming of the image-text dataset builders/classes? Maybe the more verbose `image_caption_dataset_for_image_generation` is better to make it clear that this is NOT for something like finetuning a VLM to do image captioning (although maybe it could be generalized to the point where it can also do lists of Message objects and therefore can be used for whatever purpose). | ||
- Support multiple captions per image? I can imagine people wanting to generate multiple captions for their images, and randomly selecting one at a time during training to prevent overfitting. It's kinda a caption augmentation but it's unique for each caption so it would have to be supported at the data level. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be possible to do easily with torchtune, but definitely not OOTB.
In TorchTune, a simple version would look something like this: | ||
|
||
```yaml | ||
dataset: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay one big question: What direction are we trying to go in?
We landed torchdata support which started a refactor of our datasets into dataset-specific utils rather than an entire builder that essentially just spits back an SFT datasets class. IMO this means less code for the user to worry about and makes hacking easier. In addition, this gives us all the benefits from torchdata.
If we believe torchdata is the right way to go (especially for these more data-intensive use cases), then should this be refactored towards that end?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The goal of this was to follow the pattern of our current SFT dataset solution so it'd be easier to move in parallel to the torchdata solution. By following close to STF then it should be trivial to convert this to the torchdata solution once that's finalized.
Overview
This is an RFC regarding how we should support datasets for finetuning text-conditioned image generation models.
A basic data pipeline for this would be:
At a broad level, this fits well into our current TorchTune data ecosystem (except we wouldn't use the "list of Message objects" abstraction, which would change how we interact with the model's tokenizer).
In TorchTune, a simple version would look something like this:
TODO: Collate
We'll need to generalize our collate functions such that they can handle data outside of the tokens-and-labels format they currently expect. I will update this section after I've looked into this.
Caching/Preprocessing
From what I've seen online, some people finetune image generators on massive datasets, but most people just finetune on very small personal datasets, often 5-100 images. So we should probably add support for various caching/preprocessing options that increase disk/mem usage in order to achieve faster iterations. Some ideas for optional configurations:
but I bet preprocessing the Flux image encoding would save a lot of time and GPU memoryedit: actually the T5 text encoder is the part that would benefit the most from preprocessingBut we should evaluate whether each of these is worth it:
Dataset Creation
Should we include scripts/utilities for creating the captions? Users will probably often have just a folder with a bunch of images that they want to finetune on. So we could help them turn that folder into a dataset by using some model to automatically caption them. We could even provide our own models for this by distilling the image captioning capabilities of Llama3.2V-90B into several smaller Llama3.2V models, and let the user pick the one that fits on their device.
We'll also want to support adding words/phrases to the caption that tell the model to generate in the style of this dataset. For example, if I'm finetuning a model on images of myself, I'll want to include something like "a photo of cpelletier" in the caption so that the model learns to associate "cpelletier" with my face. This could be supported at the dataset creation step (i.e. the identifiers are put into the caption data itself, which is simpler), or at the text transform step (i.e. the identifier is specified in the text transform config like 'add "in the style of cpelletier" to the end of each caption', which is a bit more complex but nice that you don't have to change the dataset if you want to experiment with different identifiers).
User Experience
load_dataset
like we currently do inSFTDataset
? It keeps the code simpler, but it makes the user leave torchtune and go read the huggingface docs, which is overkill if they just have some simple JSON file we could easily load ourselves.Other
image_caption_dataset_for_image_generation
is better to make it clear that this is NOT for something like finetuning a VLM to do image captioning (although maybe it could be generalized to the point where it can also do lists of Message objects and therefore can be used for whatever purpose).