Skip to content

Commit 339a554

Browse files
committed
vision tests passing
1 parent 7220d1f commit 339a554

File tree

1 file changed

+3
-9
lines changed

1 file changed

+3
-9
lines changed

models/tt_transformers/tt/generator.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,18 +1228,12 @@ def decode_forward_llama_vision(
12281228
else:
12291229
tt_logits = self._decode_forward_no_trace(**decode_kwargs)
12301230

1231-
output = None
12321231
if read_from_device:
12331232
to_host = self.read_decode_output(tt_logits)
1234-
output = self.process_decode_output_host(to_host)
1233+
# skip log_probs
1234+
return self.process_decode_output_host(to_host)[0]
12351235
else:
1236-
output = tt_logits
1237-
1238-
# skip returning log-probs
1239-
if isinstance(output, tuple):
1240-
return output[0]
1241-
else:
1242-
return output
1236+
return tt_logits
12431237

12441238
# Note: This function is called by vLLM
12451239
def read_decode_output(self, tt_out, async_read=False):

0 commit comments

Comments
 (0)