14
14
from composer .core import get_precision_context
15
15
from composer .loggers import LoggerDestination
16
16
from composer .utils import dist
17
- from torch .utils .data import DataLoader
17
+ from torch .utils .data import Dataset
18
18
from torchmetrics .multimodal import CLIPScore
19
- from torchvision .transforms .functional import to_pil_image
19
+ from torchvision .transforms .functional import pil_to_tensor , to_pil_image
20
20
from tqdm .auto import tqdm
21
- from transformers import PreTrainedTokenizerBase
22
21
23
22
os .environ ['TOKENIZERS_PARALLELISM' ] = 'false'
24
23
@@ -32,7 +31,7 @@ class CleanFIDEvaluator:
32
31
33
32
Args:
34
33
model (ComposerModel): The model to evaluate.
35
- eval_dataloader (DataLoader ): The dataloader to use for evaluation .
34
+ dataset (Dataset ): The dataset to use the prompts from .
36
35
clip_metric (CLIPScore): The CLIPScore metric to use for evaluation.
37
36
load_path (str, optional): The path to load the model from. Default: ``None``.
38
37
guidance_scales (List[float]): The guidance scales to use for evaluation.
@@ -52,13 +51,14 @@ class CleanFIDEvaluator:
52
51
default_prompt (Optional[str]): An optional default prompt to add before each eval prompt. Default: ``None``.
53
52
default_negative_prompt (Optional[str]): An optional default negative prompt to add before each
54
53
negative prompt. Default: ``None``.
54
+ sdxl_conditioning (bool): Whether or not to include SDXL conditioning in the evaluation. Default: ``False``.
55
55
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
56
56
57
57
"""
58
58
59
59
def __init__ (self ,
60
60
model : ComposerModel ,
61
- eval_dataloader : DataLoader ,
61
+ dataset : Dataset ,
62
62
clip_metric : CLIPScore ,
63
63
load_path : Optional [str ] = None ,
64
64
guidance_scales : Optional [List [float ]] = None ,
@@ -75,10 +75,10 @@ def __init__(self,
75
75
prompts : Optional [List [str ]] = None ,
76
76
default_prompt : Optional [str ] = None ,
77
77
default_negative_prompt : Optional [str ] = None ,
78
+ sdxl_conditioning : bool = False ,
78
79
additional_generate_kwargs : Optional [Dict ] = None ):
79
80
self .model = model
80
- self .tokenizer : PreTrainedTokenizerBase = model .tokenizer
81
- self .eval_dataloader = eval_dataloader
81
+ self .dataset = dataset
82
82
self .clip_metric = clip_metric
83
83
self .load_path = load_path
84
84
self .guidance_scales = guidance_scales if guidance_scales is not None else [1.0 ]
@@ -89,20 +89,19 @@ def __init__(self,
89
89
self .loggers = loggers
90
90
self .seed = seed
91
91
self .output_dir = output_dir
92
- self .num_samples = num_samples if num_samples is not None else float ( 'inf' )
92
+ self .num_samples = num_samples
93
93
self .precision = precision
94
94
self .prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater' ]
95
95
self .default_prompt = default_prompt
96
96
self .default_negative_prompt = default_negative_prompt
97
+ self .sdxl_conditioning = sdxl_conditioning
97
98
self .additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
98
- self .sdxl = model .sdxl
99
99
100
100
# Load the model
101
101
trainer = Trainer (model = self .model ,
102
102
load_path = self .load_path ,
103
103
load_weights_only = True ,
104
104
load_strict_model_weights = load_strict_model_weights ,
105
- eval_dataloader = self .eval_dataloader ,
106
105
seed = self .seed ,
107
106
loggers = self .loggers )
108
107
self .trainer = trainer
@@ -139,18 +138,27 @@ def _generate_images(self, guidance_scale: float):
139
138
140
139
# Storage for prompts
141
140
prompts = {}
142
- # Iterate over the eval dataloader
143
- num_batches = len (self .eval_dataloader )
144
- starting_seed = self .seed + num_batches * dist .get_local_rank ()
145
- for batch_id , batch in tqdm (enumerate (self .eval_dataloader )):
146
- # Break if enough samples have been generated
147
- if batch_id * self .batch_size * dist .get_world_size () >= self .num_samples :
148
- break
149
-
150
- real_images = batch [self .image_key ]
151
- tokenized_captions = batch [self .caption_key ]
152
- # Get the prompts from the tokens
153
- text_captions = self .tokenizer .batch_decode (tokenized_captions , skip_special_tokens = True )
141
+ # Partition the dataset across the ranks
142
+ dataset_len = self .dataset .num_samples # type: ignore
143
+ # Truncate the dataset if num_samples is specified
144
+ if self .num_samples is not None and self .num_samples <= dataset_len :
145
+ dataset_len = self .num_samples
146
+ elif self .num_samples is not None and self .num_samples > dataset_len :
147
+ raise ValueError (f'num_samples { self .num_samples } is greater than the dataset length { dataset_len } .' )
148
+ samples_per_rank , remainder = divmod (dataset_len , dist .get_world_size ())
149
+ start_idx = dist .get_global_rank () * samples_per_rank + min (remainder , dist .get_global_rank ())
150
+ end_idx = start_idx + samples_per_rank
151
+ if dist .get_global_rank () < remainder :
152
+ end_idx += 1
153
+ print (f'Rank { dist .get_global_rank ()} processing samples { start_idx } to { end_idx } of { dataset_len } total.' )
154
+ # Iterate over the dataset
155
+ for sample_id in tqdm (range (start_idx , end_idx )):
156
+ # Set a unique seed for this sample to ensure reproducible but different randomness
157
+ seed = self .seed + sample_id
158
+ # Image and caption come from the dataset. Note the caption is untokenized
159
+ sample = self .dataset [sample_id ]
160
+ real_images = pil_to_tensor (sample [self .image_key ]).unsqueeze (0 ) / 255.0
161
+ text_captions = sample [self .caption_key ]
154
162
# Add default prompts if specified
155
163
augmented_captions = text_captions
156
164
augmented_negative_prompt = None
@@ -159,15 +167,12 @@ def _generate_images(self, guidance_scale: float):
159
167
if self .default_negative_prompt :
160
168
augmented_negative_prompt = [f'{ self .default_negative_prompt } ' for _ in text_captions ]
161
169
162
- if self .sdxl :
163
- crop_params = batch [ 'cond_crops_coords_top_left' ]
164
- input_size_params = batch [ 'cond_original_size' ]
170
+ if self .sdxl_conditioning :
171
+ crop_params = torch . tensor ([ 0 , 0 ]). unsqueeze ( 0 )
172
+ input_size_params = torch . tensor ([ self . size , self . size ]). unsqueeze ( 0 )
165
173
else :
166
174
crop_params = None
167
175
input_size_params = None
168
-
169
- # Ensure a new seed for each batch, as randomness in model.generate is fixed.
170
- seed = starting_seed + batch_id
171
176
# Generate images from the captions
172
177
with get_precision_context (self .precision ):
173
178
generated_images = self .model .generate (prompt = augmented_captions ,
@@ -188,11 +193,11 @@ def _generate_images(self, guidance_scale: float):
188
193
f'Images are expected to be in the range [0, 1]. Got max { real_images .max ()} and min { real_images .min ()} '
189
194
)
190
195
for i , img in enumerate (real_images ):
191
- to_pil_image (img ).save (f'{ real_image_path } /{ batch_id } _ { i } _rank_{ dist .get_local_rank ()} .png' )
192
- prompts [f'{ batch_id } _ { i } _rank_{ dist .get_local_rank ()} ' ] = text_captions [i ]
196
+ to_pil_image (img ).save (f'{ real_image_path } /{ sample_id } _rank_{ dist .get_local_rank ()} .png' )
197
+ prompts [f'{ sample_id } _rank_{ dist .get_local_rank ()} ' ] = text_captions [i ]
193
198
# Save the generated images
194
199
for i , img in enumerate (generated_images ):
195
- to_pil_image (img ).save (f'{ gen_image_path } /{ batch_id } _ { i } _rank_{ dist .get_local_rank ()} .png' )
200
+ to_pil_image (img ).save (f'{ gen_image_path } /{ sample_id } _rank_{ dist .get_local_rank ()} .png' )
196
201
197
202
# Save the prompts as json
198
203
json .dump (prompts , open (f'{ real_image_path } /prompts_rank_{ dist .get_local_rank ()} .json' , 'w' ))
0 commit comments