@@ -39,18 +39,22 @@ def __call__(self, prompt: Dict[str, Any]) -> List[Message]:
39
39
40
40
# Iterate through roles and add content
41
41
for role , content in prompt .items ():
42
- if isinstance (content , str ):
42
+ if content is None :
43
+ continue
44
+ elif isinstance (content , str ):
43
45
new_content = [{"type" : "text" , "content" : content }]
44
- else :
45
- assert (
46
- "image" in content .keys ()
47
- ), "Multiple entries per role expect an image key"
46
+ elif "image" in content .keys ():
48
47
image_loc = content ["image" ]
49
48
image = load_image (image_loc )
50
49
new_content = [
51
50
{"type" : "image" , "content" : image },
52
51
{"type" : "text" , "content" : content ["text" ]},
53
52
]
53
+ else :
54
+ assert (
55
+ "text" in content .keys ()
56
+ ), "Multiple entries per role expect at least a text key"
57
+ new_content = [{"type" : "text" , "content" : content ["text" ]}]
54
58
messages .append (Message (role = role , content = new_content ))
55
59
56
60
# Finally, add an empty assistant message to kick-start generation
@@ -109,12 +113,12 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
109
113
f"Time for inference: { total_time :.02f} sec total, { tokens_per_second :.02f} tokens/sec"
110
114
)
111
115
self ._logger .info (
112
- f"Bandwidth achieved: { model_size * tokens_per_second / 1e9 :.02f} GB /s"
116
+ f"Bandwidth achieved: { model_size * tokens_per_second / ( 1024 ** 3 ) :.02f} GiB /s"
113
117
)
114
118
if self ._device .type != "cpu" :
115
119
torch_device = utils .get_torch_device_namespace ()
116
120
self ._logger .info (
117
- f"Max memory allocated: { torch_device .max_memory_allocated () / 1e9 :.02f} GB "
121
+ f"Max memory allocated: { torch_device .max_memory_allocated () / ( 1024 ** 3 ) :.02f} GiB "
118
122
)
119
123
120
124
@torch .inference_mode ()
0 commit comments