Skip to content

Commit cc81794

Browse files
fix: cross encoder output (#3)
1 parent 62f44c4 commit cc81794

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

vllm_rbln/model_executor/models/optimum/encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,6 @@ def forward(self, model_input: ModelInputForRBLN) -> torch.Tensor:
138138
assert hidden_states.dim() == 2, (
139139
f"We expected the shape to be dim 2 ([batch, num_labels]), "
140140
f"but the current output is dim {hidden_states.dim()}.")
141-
hidden_states = hidden_states[:request_nums]
141+
hidden_states = hidden_states[:request_nums].squeeze(-1)
142142

143143
return hidden_states

0 commit comments

Comments
 (0)