Commit b7e5029

Various transformer updates to improve performance (#182)
1 parent cb14024 commit b7e5029

5 files changed: +522 −221 lines

diffusion/callbacks/log_diffusion_images.py (+68 −46)
@@ -37,6 +37,7 @@ class LogDiffusionImages(Callback):
         seed (int, optional): Random seed to use for generation. Set a seed for reproducible generation.
             Default: ``1138``.
         use_table (bool): Whether to make a table of the images or not. Default: ``False``.
+        use_mask (bool): Whether or not to use the mask for the encoded text. Default: ``True``.
         t5_encoder (str, optional): path to the T5 encoder to as a second text encoder.
         clip_encoder (str, optional): path to the CLIP encoder as the first text encoder.
         t5_latent_key: (str): key to use for the T5 latents in the batch. Default: ``'T5_LATENTS'``.
@@ -56,6 +57,7 @@ def __init__(self,
                  rescaled_guidance: Optional[float] = None,
                  seed: Optional[int] = 1138,
                  use_table: bool = False,
+                 use_mask: bool = True,
                  t5_encoder: Optional[str] = None,
                  clip_encoder: Optional[str] = None,
                  t5_latent_key: str = 'T5_LATENTS',
@@ -71,6 +73,7 @@ def __init__(self,
         self.rescaled_guidance = rescaled_guidance
         self.seed = seed
         self.use_table = use_table
+        self.use_mask = use_mask
         self.t5_latent_key = t5_latent_key
         self.t5_mask_key = t5_mask_key
         self.clip_latent_key = clip_latent_key
@@ -100,47 +103,47 @@ def __init__(self,
                                                       local_files_only=True)
 
             t5_model = AutoModel.from_pretrained(t5_encoder,
-                                                 torch_dtype=torch.float16,
+                                                 torch_dtype=torch.bfloat16,
                                                  cache_dir=self.cache_dir,
                                                  local_files_only=True).encoder.cuda().eval()
             clip_model = CLIPTextModel.from_pretrained(clip_encoder,
                                                         subfolder='text_encoder',
-                                                        torch_dtype=torch.float16,
+                                                        torch_dtype=torch.bfloat16,
                                                         cache_dir=self.cache_dir,
                                                         local_files_only=True).cuda().eval()
-
-            for batch in self.batched_prompts:
-                latent_batch = {}
-                tokenized_t5 = t5_tokenizer(batch,
-                                            padding='max_length',
-                                            max_length=t5_tokenizer.model_max_length,
-                                            truncation=True,
-                                            return_tensors='pt')
-                t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
-                t5_ids = tokenized_t5['input_ids'].cuda()
-                t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu()
-                t5_attention_mask = t5_attention_mask.cpu().to(torch.long)
-
-                tokenized_clip = clip_tokenizer(batch,
-                                                padding='max_length',
-                                                max_length=clip_tokenizer.model_max_length,
-                                                truncation=True,
-                                                return_tensors='pt')
-                clip_attention_mask = tokenized_clip['attention_mask'].cuda()
-                clip_ids = tokenized_clip['input_ids'].cuda()
-                clip_outputs = clip_model(input_ids=clip_ids,
-                                          attention_mask=clip_attention_mask,
-                                          output_hidden_states=True)
-                clip_latents = clip_outputs.hidden_states[-2].cpu()
-                clip_pooled = clip_outputs[1].cpu()
-                clip_attention_mask = clip_attention_mask.cpu().to(torch.long)
-
-                latent_batch[self.t5_latent_key] = t5_latents
-                latent_batch[self.t5_mask_key] = t5_attention_mask
-                latent_batch[self.clip_latent_key] = clip_latents
-                latent_batch[self.clip_mask_key] = clip_attention_mask
-                latent_batch[self.clip_pooled_key] = clip_pooled
-                self.batched_latents.append(latent_batch)
+            with torch.no_grad():
+                for batch in self.batched_prompts:
+                    latent_batch = {}
+                    tokenized_t5 = t5_tokenizer(batch,
+                                                padding='max_length',
+                                                max_length=t5_tokenizer.model_max_length,
+                                                truncation=True,
+                                                return_tensors='pt')
+                    t5_attention_mask = tokenized_t5['attention_mask'].to(torch.bool).cuda()
+                    t5_ids = tokenized_t5['input_ids'].cuda()
+                    t5_latents = t5_model(input_ids=t5_ids, attention_mask=t5_attention_mask)[0].cpu()
+                    t5_attention_mask = t5_attention_mask.cpu().to(torch.long)
+
+                    tokenized_clip = clip_tokenizer(batch,
+                                                    padding='max_length',
+                                                    max_length=clip_tokenizer.model_max_length,
+                                                    truncation=True,
+                                                    return_tensors='pt')
+                    clip_attention_mask = tokenized_clip['attention_mask'].cuda()
+                    clip_ids = tokenized_clip['input_ids'].cuda()
+                    clip_outputs = clip_model(input_ids=clip_ids,
+                                              attention_mask=clip_attention_mask,
+                                              output_hidden_states=True)
+                    clip_latents = clip_outputs.hidden_states[-2].cpu()
+                    clip_pooled = clip_outputs[1].cpu()
+                    clip_attention_mask = clip_attention_mask.cpu().to(torch.long)
+
+                    latent_batch[self.t5_latent_key] = t5_latents
+                    latent_batch[self.t5_mask_key] = t5_attention_mask
+                    latent_batch[self.clip_latent_key] = clip_latents
+                    latent_batch[self.clip_mask_key] = clip_attention_mask
+                    latent_batch[self.clip_pooled_key] = clip_pooled
+                    self.batched_latents.append(latent_batch)
 
             del t5_model
             del clip_model
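
The hunk above makes two changes to latent precomputation: the T5 and CLIP encoders are loaded in bfloat16 instead of float16, and the tokenize-and-encode loop is wrapped in torch.no_grad() so no autograd state is retained for the frozen encoders. A minimal sketch of the same pattern, assuming a generic T5 checkpoint name and an illustrative helper called encode_prompts (neither is part of this repo):

import torch
from transformers import AutoModel, AutoTokenizer

def encode_prompts(prompts, encoder_name='google/t5-v1_1-xxl'):
    """Illustrative only: encode prompts with a frozen text encoder in bfloat16 under no_grad."""
    tokenizer = AutoTokenizer.from_pretrained(encoder_name)
    encoder = AutoModel.from_pretrained(encoder_name, torch_dtype=torch.bfloat16).encoder.cuda().eval()
    tokens = tokenizer(prompts,
                       padding='max_length',
                       max_length=tokenizer.model_max_length,
                       truncation=True,
                       return_tensors='pt')
    with torch.no_grad():  # inference only, so skip building the autograd graph
        latents = encoder(input_ids=tokens['input_ids'].cuda(),
                          attention_mask=tokens['attention_mask'].cuda())[0]
    return latents.cpu()  # move results off the GPU once encoding is done, as the callback does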
@@ -160,21 +163,40 @@ def eval_start(self, state: State, logger: Logger):
         if self.precomputed_latents:
             for batch in self.batched_latents:
                 pooled_prompt = batch[self.clip_pooled_key].cuda()
-                prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
-                                                                           batch[self.clip_latent_key].cuda(),
-                                                                           batch[self.t5_mask_key].cuda(),
-                                                                           batch[self.clip_mask_key].cuda())
-                gen_images = model.generate(prompt_embeds=prompt_embeds,
-                                            pooled_prompt=pooled_prompt,
-                                            prompt_mask=prompt_mask,
-                                            height=self.size[0],
-                                            width=self.size[1],
-                                            guidance_scale=self.guidance_scale,
-                                            rescaled_guidance=self.rescaled_guidance,
-                                            progress_bar=False,
-                                            num_inference_steps=self.num_inference_steps,
-                                            seed=self.seed)
+                if self.use_mask:
+                    prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
+                                                                               batch[self.clip_latent_key].cuda(),
+                                                                               batch[self.t5_mask_key].cuda(),
+                                                                               batch[self.clip_mask_key].cuda())
+                    gen_images = model.generate(prompt_embeds=prompt_embeds,
+                                                pooled_prompt=pooled_prompt,
+                                                prompt_mask=prompt_mask,
+                                                height=self.size[0],
+                                                width=self.size[1],
+                                                guidance_scale=self.guidance_scale,
+                                                rescaled_guidance=self.rescaled_guidance,
+                                                progress_bar=False,
+                                                num_inference_steps=self.num_inference_steps,
+                                                seed=self.seed)
+                else:
+                    prompt_embeds = model.prepare_text_embeddings(batch[self.t5_latent_key].cuda(),
+                                                                  batch[self.clip_latent_key].cuda())
+                    gen_images = model.generate(prompt_embeds=prompt_embeds,
+                                                pooled_prompt=pooled_prompt,
+                                                height=self.size[0],
+                                                width=self.size[1],
+                                                guidance_scale=self.guidance_scale,
+                                                rescaled_guidance=self.rescaled_guidance,
+                                                progress_bar=False,
+                                                num_inference_steps=self.num_inference_steps,
+                                                seed=self.seed)
                 all_gen_images.append(gen_images)
+                # Clear up GPU tensors
+                del pooled_prompt
+                del prompt_embeds
+                if self.use_mask:
+                    del prompt_mask
+                torch.cuda.empty_cache()
         else:
             for batch in self.batched_prompts:
                 gen_images = model.generate(
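
The new use_mask flag selects between two paths in eval_start: when it is True, prepare_text_embeddings receives the T5 and CLIP attention masks and generate is given the resulting prompt_mask; when it is False, only the latents are passed and no mask is used. A condensed sketch of that branching, using the default batch keys from the callback docstring (the helper name build_prompt_embeddings is illustrative, not part of the repo):

def build_prompt_embeddings(model, batch, use_mask: bool):
    """Illustrative only: mirror the masked vs. unmasked embedding paths gated by use_mask."""
    if use_mask:
        # Masked path: attention masks go into prepare_text_embeddings, and the returned
        # prompt_mask is later passed to model.generate(...).
        prompt_embeds, prompt_mask = model.prepare_text_embeddings(batch['T5_LATENTS'].cuda(),
                                                                   batch['CLIP_LATENTS'].cuda(),
                                                                   batch['T5_ATTENTION_MASK'].cuda(),
                                                                   batch['CLIP_ATTENTION_MASK'].cuda())
        return prompt_embeds, prompt_mask
    # Unmasked path: only the latents are used, and generate(...) is called without prompt_mask.
    prompt_embeds = model.prepare_text_embeddings(batch['T5_LATENTS'].cuda(), batch['CLIP_LATENTS'].cuda())
    return prompt_embeds, None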
New file: +91 lines

@@ -0,0 +1,91 @@
# Copyright 2022 MosaicML Diffusion authors
# SPDX-License-Identifier: Apache-2.0

"""Synthetic Image-Caption dataset."""

from typing import Dict, Optional

import torch
from composer.utils import dist
from torch.utils.data import DataLoader, Dataset


class SyntheticImageCaptionLatentsDataset(Dataset):
    """Synthetic dataset imitating a dataset containing image-caption pairs.

    Args:
        image_size (int): Size of the synthetic images. Default: ``512``.
        clip_length (int): Length of the synthetic clip embeddings. Default: ``77``.
        clip_dim (int): Dimension of the synthetic clip embeddings. Default: ``768``.
        t5_length (int): Length of the synthetic T5 embeddings. Default: ``512``.
        t5_dim (int): Dimension of the synthetic T5 embeddings. Default: ``4096``.
    """

    def __init__(self,
                 image_size: int = 512,
                 clip_length: int = 77,
                 clip_dim: int = 768,
                 t5_length: int = 512,
                 t5_dim: int = 4096):

        super().__init__()
        self.image_size = image_size
        self.clip_length = clip_length
        self.clip_dim = clip_dim
        self.t5_length = t5_length
        self.t5_dim = t5_dim

    def __len__(self):
        return 100_000

    def __getitem__(self, idx):
        out = {}
        out['cond_crops_coords_top_left'] = torch.tensor([0, 0], dtype=torch.float)
        out['cond_original_size'] = torch.tensor([self.image_size, self.image_size], dtype=torch.float)
        out['cond_target_size'] = torch.tensor([self.image_size, self.image_size], dtype=torch.float)
        out['image'] = torch.randn(3, self.image_size, self.image_size)
        out['CLIP_LATENTS'] = torch.randn(self.clip_length, self.clip_dim, dtype=torch.float)
        out['CLIP_POOLED'] = torch.randn(self.clip_dim, dtype=torch.float)
        out['CLIP_ATTENTION_MASK'] = torch.ones(self.clip_length)
        out['T5_LATENTS'] = torch.randn(self.t5_length, self.t5_dim, dtype=torch.float)
        out['T5_ATTENTION_MASK'] = torch.ones(self.t5_length)
        return out


def build_synthetic_image_caption_latents_dataloader(
    batch_size: int,
    image_size: int = 512,
    clip_length: int = 77,
    clip_dim: int = 768,
    t5_length: int = 512,
    t5_dim: int = 4096,
    dataloader_kwargs: Optional[Dict] = None,
):
    """Builds a dataloader for the synthetic image-caption dataset.

    Args:
        batch_size (int): Batch size for the dataloader.
        image_size (int): Size of the synthetic images. Default: ``512``.
        clip_length (int): Length of the synthetic clip embeddings. Default: ``77``.
        clip_dim (int): Dimension of the synthetic clip embeddings. Default: ``768``.
        t5_length (int): Length of the synthetic T5 embeddings. Default: ``512``.
        t5_dim (int): Dimension of the synthetic T5 embeddings. Default: ``4096``.
        dataloader_kwargs (optional, dict): Additional arguments to pass to the dataloader. Default ``None``.
    """
    if dataloader_kwargs is None:
        dataloader_kwargs = {}

    dataset = SyntheticImageCaptionLatentsDataset(image_size=image_size,
                                                  clip_length=clip_length,
                                                  clip_dim=clip_dim,
                                                  t5_length=t5_length,
                                                  t5_dim=t5_dim)

    dataloader = DataLoader(
        dataset=dataset,
        sampler=dist.get_sampler(dataset),
        batch_size=batch_size,
        **dataloader_kwargs,
    )

    return dataloader
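
A short usage sketch for the new builder, assuming it is imported from wherever this new module lives (the file path is not shown in this view), to confirm the tensor shapes a training run would see:

# Assumes build_synthetic_image_caption_latents_dataloader from the new module above is importable.
dataloader = build_synthetic_image_caption_latents_dataloader(batch_size=4, image_size=256)
batch = next(iter(dataloader))
print(batch['image'].shape)         # torch.Size([4, 3, 256, 256])
print(batch['T5_LATENTS'].shape)    # torch.Size([4, 512, 4096])
print(batch['CLIP_LATENTS'].shape)  # torch.Size([4, 77, 768])
print(batch['CLIP_POOLED'].shape)   # torch.Size([4, 768])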
