Commit ba8ca02

Update evaluation and inference code to handle other precisions and models (#179)
1 parent b0a094f commit ba8ca02

File tree

7 files changed: +88, -74 lines changed

diffusion/datasets/image.py (+9, -6)

@@ -26,8 +26,10 @@ class StreamingImageDataset(StreamingDataset):
     Args:
         streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from.
             ``StreamingImageCaptionDataset`` uses either ``streams`` or ``remote``/``local``. Default:``None``.
-        remote (str, optional): Remote directory (S3 or local filesystem) where dataset is stored. Default: ``None``.
-        local (str, optional): Local filesystem directory where dataset is cached during operation. Default: ``None``.
+        remote (Union[str, Sequence[str]], optional): Remote directory (S3 or local filesystem) where dataset is
+            stored. Default: ``None``.
+        local (Union[str, Sequence[str]], optional): Local filesystem directory where dataset is cached during
+            operation. Default: ``None``.
         transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         image_output_key (optional, str): Optional output key for the image. If none, the value of `image_key` will
@@ -41,8 +43,8 @@ class StreamingImageDataset(StreamingDataset):
     def __init__(
         self,
         streams: Optional[Sequence[Stream]] = None,
-        remote: Optional[str] = None,
-        local: Optional[str] = None,
+        remote: Optional[Union[str, Sequence[str]]] = None,
+        local: Optional[Union[str, Sequence[str]]] = None,
         transform: Optional[Callable] = None,
         image_key: str = 'image',
         image_output_key: Optional[str] = None,
@@ -54,10 +56,11 @@ def __init__(
         streaming_kwargs.setdefault('shuffle_block_size', 1 << 18)
         streaming_kwargs.setdefault('shuffle_algo', 'py1s')

+        # Make the streams if necessary
+        streams = make_streams(remote, local=local) if streams is None else streams
+
         super().__init__(
             streams=streams,
-            remote=remote,
-            local=local,
             **streaming_kwargs,
         )
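
For context, here is a minimal sketch of what a helper like `make_streams` presumably does with the new `Union[str, Sequence[str]]` arguments: pair each remote path with its local cache path and build one `streaming.Stream` per pair. The real helper lives in the repo; this version only mirrors the call site `make_streams(remote, local=local)` shown above and is an assumption, not the actual implementation.

from typing import Optional, Sequence, Union

from streaming import Stream


def make_streams_sketch(remote: Union[str, Sequence[str]],
                        local: Optional[Union[str, Sequence[str]]] = None) -> list:
    """Illustrative stand-in for the repo's make_streams helper (assumed behavior)."""
    remotes = [remote] if isinstance(remote, str) else list(remote)
    if local is None:
        locals_ = [None] * len(remotes)
    elif isinstance(local, str):
        locals_ = [local]
    else:
        locals_ = list(local)
    if len(remotes) != len(locals_):
        raise ValueError('remote and local must have the same number of entries')
    # One Stream per remote/local pair; StreamingDataset then merges them.
    return [Stream(remote=r, local=l) for r, l in zip(remotes, locals_)]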

diffusion/evaluate.py (+4, -10)

@@ -14,7 +14,7 @@
 from composer.loggers import LoggerDestination
 from composer.utils import reproducibility
 from omegaconf import DictConfig, OmegaConf
-from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
 from torchmetrics.multimodal import CLIPScore

 from diffusion.evaluation.clean_fid_eval import CleanFIDEvaluator
@@ -31,14 +31,8 @@ def evaluate(config: DictConfig) -> None:
     # The model to evaluate
     model: ComposerModel = hydra.utils.instantiate(config.model)

-    tokenizer = model.tokenizer if hasattr(model, 'tokenizer') else None
-
-    # The dataloader to use for evaluation
-    if tokenizer:
-        eval_dataloader = hydra.utils.instantiate(config.eval_dataloader, tokenizer=tokenizer)
-
-    else:
-        eval_dataloader: DataLoader = hydra.utils.instantiate(config.eval_dataloader)
+    # The dataset
+    dataset: Dataset = hydra.utils.instantiate(config.dataset)

     # The CLIPScores metric to use for evaluation
     clip_metric: CLIPScore = hydra.utils.instantiate(config.clip_metric)
@@ -88,7 +82,7 @@ def evaluate(config: DictConfig) -> None:
     evaluator: CleanFIDEvaluator = hydra.utils.instantiate(
         config.evaluator,
         model=model,
-        eval_dataloader=eval_dataloader,
+        dataset=dataset,
         clip_metric=clip_metric,
         loggers=logger,
     )
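
A quick, hypothetical illustration of what this change implies for the eval config: Hydra now instantiates a `dataset` entry directly instead of an `eval_dataloader`. The `_target_` below is a stand-in (torchvision's `FakeData`), not the dataset the repo actually uses.

import hydra
from omegaconf import OmegaConf
from torch.utils.data import Dataset

config = OmegaConf.create({
    'dataset': {
        '_target_': 'torchvision.datasets.FakeData',  # stand-in target, for illustration only
        'size': 8,
    },
})
dataset: Dataset = hydra.utils.instantiate(config.dataset)
print(len(dataset))  # 8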

diffusion/evaluation/clean_fid_eval.py (+36, -31)

@@ -14,11 +14,10 @@
 from composer.core import get_precision_context
 from composer.loggers import LoggerDestination
 from composer.utils import dist
-from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
 from torchmetrics.multimodal import CLIPScore
-from torchvision.transforms.functional import to_pil_image
+from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from tqdm.auto import tqdm
-from transformers import PreTrainedTokenizerBase

 os.environ['TOKENIZERS_PARALLELISM'] = 'false'

@@ -32,7 +31,7 @@ class CleanFIDEvaluator:

     Args:
         model (ComposerModel): The model to evaluate.
-        eval_dataloader (DataLoader): The dataloader to use for evaluation.
+        dataset (Dataset): The dataset to use the prompts from.
         clip_metric (CLIPScore): The CLIPScore metric to use for evaluation.
         load_path (str, optional): The path to load the model from. Default: ``None``.
         guidance_scales (List[float]): The guidance scales to use for evaluation.
@@ -52,13 +51,14 @@ class CleanFIDEvaluator:
         default_prompt (Optional[str]): An optional default prompt to add before each eval prompt. Default: ``None``.
         default_negative_prompt (Optional[str]): An optional default negative prompt to add before each
             negative prompt. Default: ``None``.
+        sdxl_conditioning (bool): Whether or not to include SDXL conditioning in the evaluation. Default: ``False``.
         additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.

     """

     def __init__(self,
                  model: ComposerModel,
-                 eval_dataloader: DataLoader,
+                 dataset: Dataset,
                  clip_metric: CLIPScore,
                  load_path: Optional[str] = None,
                  guidance_scales: Optional[List[float]] = None,
@@ -75,10 +75,10 @@ def __init__(self,
                  prompts: Optional[List[str]] = None,
                  default_prompt: Optional[str] = None,
                  default_negative_prompt: Optional[str] = None,
+                 sdxl_conditioning: bool = False,
                  additional_generate_kwargs: Optional[Dict] = None):
         self.model = model
-        self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
-        self.eval_dataloader = eval_dataloader
+        self.dataset = dataset
         self.clip_metric = clip_metric
         self.load_path = load_path
         self.guidance_scales = guidance_scales if guidance_scales is not None else [1.0]
@@ -89,20 +89,19 @@ def __init__(self,
         self.loggers = loggers
         self.seed = seed
         self.output_dir = output_dir
-        self.num_samples = num_samples if num_samples is not None else float('inf')
+        self.num_samples = num_samples
         self.precision = precision
         self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
         self.default_prompt = default_prompt
         self.default_negative_prompt = default_negative_prompt
+        self.sdxl_conditioning = sdxl_conditioning
         self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
-        self.sdxl = model.sdxl

         # Load the model
         trainer = Trainer(model=self.model,
                           load_path=self.load_path,
                           load_weights_only=True,
                           load_strict_model_weights=load_strict_model_weights,
-                          eval_dataloader=self.eval_dataloader,
                           seed=self.seed,
                           loggers=self.loggers)
         self.trainer = trainer
@@ -139,18 +138,27 @@ def _generate_images(self, guidance_scale: float):

         # Storage for prompts
         prompts = {}
-        # Iterate over the eval dataloader
-        num_batches = len(self.eval_dataloader)
-        starting_seed = self.seed + num_batches * dist.get_local_rank()
-        for batch_id, batch in tqdm(enumerate(self.eval_dataloader)):
-            # Break if enough samples have been generated
-            if batch_id * self.batch_size * dist.get_world_size() >= self.num_samples:
-                break
-
-            real_images = batch[self.image_key]
-            tokenized_captions = batch[self.caption_key]
-            # Get the prompts from the tokens
-            text_captions = self.tokenizer.batch_decode(tokenized_captions, skip_special_tokens=True)
+        # Partition the dataset across the ranks
+        dataset_len = self.dataset.num_samples # type: ignore
+        # Truncate the dataset if num_samples is specified
+        if self.num_samples is not None and self.num_samples <= dataset_len:
+            dataset_len = self.num_samples
+        elif self.num_samples is not None and self.num_samples > dataset_len:
+            raise ValueError(f'num_samples {self.num_samples} is greater than the dataset length {dataset_len}.')
+        samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size())
+        start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank())
+        end_idx = start_idx + samples_per_rank
+        if dist.get_global_rank() < remainder:
+            end_idx += 1
+        print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.')
+        # Iterate over the dataset
+        for sample_id in tqdm(range(start_idx, end_idx)):
+            # Set a unique seed for this sample to ensure reproducible but different randomness
+            seed = self.seed + sample_id
+            # Image and caption come from the dataset. Note the caption is untokenized
+            sample = self.dataset[sample_id]
+            real_images = pil_to_tensor(sample[self.image_key]).unsqueeze(0) / 255.0
+            text_captions = sample[self.caption_key]
             # Add default prompts if specified
             augmented_captions = text_captions
             augmented_negative_prompt = None
@@ -159,15 +167,12 @@ def _generate_images(self, guidance_scale: float):
             if self.default_negative_prompt:
                 augmented_negative_prompt = [f'{self.default_negative_prompt}' for _ in text_captions]

-            if self.sdxl:
-                crop_params = batch['cond_crops_coords_top_left']
-                input_size_params = batch['cond_original_size']
+            if self.sdxl_conditioning:
+                crop_params = torch.tensor([0, 0]).unsqueeze(0)
+                input_size_params = torch.tensor([self.size, self.size]).unsqueeze(0)
             else:
                 crop_params = None
                 input_size_params = None
-
-            # Ensure a new seed for each batch, as randomness in model.generate is fixed.
-            seed = starting_seed + batch_id
             # Generate images from the captions
             with get_precision_context(self.precision):
                 generated_images = self.model.generate(prompt=augmented_captions,
@@ -188,11 +193,11 @@ def _generate_images(self, guidance_scale: float):
                 f'Images are expected to be in the range [0, 1]. Got max {real_images.max()} and min {real_images.min()}'
             )
             for i, img in enumerate(real_images):
-                to_pil_image(img).save(f'{real_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png')
-                prompts[f'{batch_id}_{i}_rank_{dist.get_local_rank()}'] = text_captions[i]
+                to_pil_image(img).save(f'{real_image_path}/{sample_id}_rank_{dist.get_local_rank()}.png')
+                prompts[f'{sample_id}_rank_{dist.get_local_rank()}'] = text_captions[i]
             # Save the generated images
             for i, img in enumerate(generated_images):
-                to_pil_image(img).save(f'{gen_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png')
+                to_pil_image(img).save(f'{gen_image_path}/{sample_id}_rank_{dist.get_local_rank()}.png')

         # Save the prompts as json
         json.dump(prompts, open(f'{real_image_path}/prompts_rank_{dist.get_local_rank()}.json', 'w'))
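
The new per-rank partitioning above replaces the dataloader loop. In isolation, the arithmetic gives each rank a contiguous shard, with shard sizes differing by at most one when the dataset length is not divisible by the world size. A small self-contained check of that logic (independent of composer.utils.dist; the helper name here is illustrative):

def rank_partition(dataset_len: int, rank: int, world_size: int) -> range:
    """Contiguous per-rank index range, mirroring the partitioning used above."""
    samples_per_rank, remainder = divmod(dataset_len, world_size)
    start_idx = rank * samples_per_rank + min(remainder, rank)
    end_idx = start_idx + samples_per_rank + (1 if rank < remainder else 0)
    return range(start_idx, end_idx)

# 10 samples over 4 ranks -> shard sizes [3, 3, 2, 2], covering indices 0..9 exactly once.
shards = [rank_partition(10, r, 4) for r in range(4)]
assert [len(s) for s in shards] == [3, 3, 2, 2]
assert sorted(i for s in shards for i in s) == list(range(10))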

diffusion/evaluation/generate_geneval_images.py (+4, -1)

@@ -27,6 +27,7 @@ class GenevalImageGenerator:
         load_path (str, optional): The path to load the model from. Default: ``None``.
         local_checkpoint_path (str, optional): The local path to save the model checkpoint. Default: ``'/tmp/model.pt'``.
         load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``.
+        precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
         guidance_scale (float): The guidance scale to use for evaluation. Default: ``7.0``.
         height (int): The height of the generated images. Default: ``1024``.
         width (int): The width of the generated images. Default: ``1024``.
@@ -46,6 +47,7 @@ def __init__(self,
                  load_path: Optional[str] = None,
                  local_checkpoint_path: str = '/tmp/model.pt',
                  load_strict_model_weights: bool = True,
+                 precision: str = 'amp_fp16',
                  guidance_scale: float = 7.0,
                  height: int = 1024,
                  width: int = 1024,
@@ -77,6 +79,7 @@ def __init__(self,
         self.load_path = load_path
         self.local_checkpoint_path = local_checkpoint_path
         self.load_strict_model_weights = load_strict_model_weights
+        self.precision = precision
         self.guidance_scale = guidance_scale
         self.height = height
         self.width = width
@@ -148,7 +151,7 @@ def generate(self):
                                                        **self.additional_generate_kwargs).images[0]
             img = generated_image
         else:
-            with get_precision_context('amp_fp16'):
+            with get_precision_context(self.precision):
                 generated_image = self.model.generate(prompt=caption,
                                                       height=self.height,
                                                       width=self.width,
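
The new `precision` argument is simply threaded into Composer's `get_precision_context`, which accepts precision names such as 'fp32', 'amp_fp16', and 'amp_bf16'. A minimal usage sketch, assuming a CUDA device and a placeholder model:

import torch
from composer.core import get_precision_context

model = torch.nn.Linear(8, 8).cuda()
x = torch.randn(2, 8, device='cuda')

# Matmuls inside the context run under bf16 autocast; 'amp_fp16' or 'fp32' behave analogously.
with get_precision_context('amp_bf16'):
    y = model(x)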

diffusion/inference/inference_model.py (+12, -2)

@@ -235,13 +235,23 @@ class ModelInference():
         model_name (str): Name of the model from `diffusion.models` to load. Ex: for stable diffusion xl, use 'stable_diffusion_xl'.
         local_checkpoint_path (str): Path to the local checkpoint. Default: '/tmp/model.pt'.
         strict (bool): Whether to load the model weights strictly. Default: False.
+        dtype: The data type to use. One of [`float32`, `float16`, `bfloat16`]. Default: `bfloat16`.
         **model_kwargs: Keyword arguments to pass to the model initialization.
     """

-    def __init__(self, model_name, local_checkpoint_path: str = LOCAL_CHECKPOINT_PATH, strict=False, **model_kwargs):
+    def __init__(self,
+                 model_name,
+                 local_checkpoint_path: str = LOCAL_CHECKPOINT_PATH,
+                 strict=False,
+                 dtype='bfloat16',
+                 **model_kwargs):
         self.device = torch.cuda.current_device()
         model_factory = getattr(diffusion.models, model_name)
         model = model_factory(**model_kwargs)
+        dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
+        if dtype not in dtype_map:
+            raise ValueError(f'Invalid dtype: {dtype}. Must be one of {list(dtype_map.keys())}')
+        self.dtype = dtype_map[dtype]

         if 'pretrained' in model_kwargs and model_kwargs['pretrained']:
             pass
@@ -290,7 +300,7 @@ def predict(self, model_requests: List[Dict[str, Any]]):
             raise RuntimeError('There must be the same number of negative prompts as prompts.')

         # Generate images
-        with torch.cuda.amp.autocast(True):
+        with torch.cuda.amp.autocast(True, dtype=self.dtype):
            imgs = self.model.generate(prompt=prompts, negative_prompt=negative_prompts, **generate_kwargs).cpu()

         # Send as bytes
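
Outside of the full ModelInference class, the dtype plumbing reduces to a name-to-torch.dtype lookup plus an autocast context. A minimal sketch, assuming a CUDA device; the generate_with_dtype wrapper is hypothetical, not part of the repo:

import torch

DTYPE_MAP = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}


def generate_with_dtype(model, prompt, dtype: str = 'bfloat16', **generate_kwargs):
    """Run generation under the requested autocast dtype (hypothetical helper)."""
    if dtype not in DTYPE_MAP:
        raise ValueError(f'Invalid dtype: {dtype}. Must be one of {list(DTYPE_MAP)}')
    with torch.cuda.amp.autocast(True, dtype=DTYPE_MAP[dtype]):
        return model.generate(prompt=prompt, **generate_kwargs).cpu()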

diffusion/models/models.py (+3, -5)

@@ -9,7 +9,8 @@

 import torch
 from composer.devices import DeviceGPU
-from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel
+from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler,
+                       UNet2DConditionModel)
 from peft import LoraConfig
 from torchmetrics import MeanSquaredError
 from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer, PretrainedConfig
@@ -770,16 +771,13 @@ def precomputed_text_latent_diffusion(
         'beta_schedule': 'scaled_linear',
         'trained_betas': None,
         'prediction_type': prediction_type,
-        'interpolation_type': 'linear',
-        'use_karras_sigmas': False,
         'timestep_spacing': 'leading',
-        'steps_offset': 1,
         'rescale_betas_zero_snr': False,
     }

     if inference_noise_scheduler_params is not None:
         inference_scheduler_params.update(inference_noise_scheduler_params)
-    inference_noise_scheduler = EulerDiscreteScheduler(**inference_scheduler_params)
+    inference_noise_scheduler = DPMSolverMultistepScheduler(**inference_scheduler_params)

     # Shift noise scheduler to correct for resolution changes
     noise_scheduler = shift_noise_schedule(noise_scheduler,
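
With this change, `precomputed_text_latent_diffusion` builds its inference scheduler as a `DPMSolverMultistepScheduler` from the trimmed parameter dict, dropping the `interpolation_type`, `use_karras_sigmas`, and `steps_offset` entries from the previous defaults. A standalone sketch; the numeric values below are illustrative defaults and not necessarily the repo's:

from diffusers import DPMSolverMultistepScheduler

inference_noise_scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=1000,  # illustrative values; the repo assembles this dict programmatically
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule='scaled_linear',
    trained_betas=None,
    prediction_type='epsilon',
    timestep_spacing='leading',
)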

diffusion/models/precomputed_text_latent_diffusion.py (+20, -19)

@@ -203,28 +203,29 @@ def decode_latents(self, latents):
     def encode_text(self, text, device):
         assert self.t5_tokenizer is not None and self.t5_encoder is not None
         assert self.clip_tokenizer is not None and self.clip_encoder is not None
-        # Encode with T5
-        t5_tokenizer_out = self.t5_tokenizer(text,
-                                             padding='max_length',
-                                             max_length=self.t5_tokenizer.model_max_length,
-                                             truncation=True,
-                                             return_tensors='pt')
-        t5_tokenized_captions = t5_tokenizer_out['input_ids'].to(device)
-        t5_attn_mask = t5_tokenizer_out['attention_mask'].to(torch.bool).to(device)
-        t5_embed = self.t5_encoder(input_ids=t5_tokenized_captions, attention_mask=t5_attn_mask)[0]
-        # Encode with CLIP
-        clip_tokenizer_out = self.clip_tokenizer(text,
-                                                 padding='max_length',
-                                                 max_length=self.clip_tokenizer.model_max_length,
-                                                 truncation=True,
-                                                 return_tensors='pt')
-        clip_tokenized_captions = clip_tokenizer_out['input_ids'].to(device)
-        clip_attn_mask = clip_tokenizer_out['attention_mask'].to(torch.bool).to(device)
-        clip_out = self.clip_encoder(input_ids=clip_tokenized_captions,
-                                     attention_mask=clip_attn_mask,
-                                     output_hidden_states=True)
-        clip_embed = clip_out.hidden_states[-2]
-        pooled_embeddings = clip_out[1]
+        with torch.autocast(device_type='cuda', enabled=False):
+            # Encode with T5
+            t5_tokenizer_out = self.t5_tokenizer(text,
+                                                 padding='max_length',
+                                                 max_length=self.t5_tokenizer.model_max_length,
+                                                 truncation=True,
+                                                 return_tensors='pt')
+            t5_tokenized_captions = t5_tokenizer_out['input_ids'].to(device)
+            t5_attn_mask = t5_tokenizer_out['attention_mask'].to(torch.bool).to(device)
+            t5_embed = self.t5_encoder(input_ids=t5_tokenized_captions, attention_mask=t5_attn_mask)[0]
+            # Encode with CLIP
+            clip_tokenizer_out = self.clip_tokenizer(text,
+                                                     padding='max_length',
+                                                     max_length=self.clip_tokenizer.model_max_length,
+                                                     truncation=True,
+                                                     return_tensors='pt')
+            clip_tokenized_captions = clip_tokenizer_out['input_ids'].to(device)
+            clip_attn_mask = clip_tokenizer_out['attention_mask'].to(torch.bool).to(device)
+            clip_out = self.clip_encoder(input_ids=clip_tokenized_captions,
+                                         attention_mask=clip_attn_mask,
+                                         output_hidden_states=True)
+            clip_embed = clip_out.hidden_states[-2]
+            pooled_embeddings = clip_out[1]
         return t5_embed, clip_embed, t5_attn_mask, clip_attn_mask, pooled_embeddings

     def prepare_text_embeddings(self, t5_embed: torch.Tensor, clip_embed: torch.Tensor, t5_mask: torch.Tensor,
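
The change above wraps text encoding in `torch.autocast(device_type='cuda', enabled=False)` so the T5 and CLIP encoders run in their native precision even when the caller is inside an AMP context. The pattern in isolation (the function name and arguments here are illustrative, not the repo's method):

import torch


def encode_full_precision(encoder, input_ids, attention_mask):
    """Run a text encoder with autocast disabled, regardless of the caller's AMP context (sketch)."""
    with torch.autocast(device_type='cuda', enabled=False):
        return encoder(input_ids=input_ids, attention_mask=attention_mask)[0]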
