Skip to content

Commit 46c0e1f

Browse files
authored
Fix incorrect type annotation in get_auxiliary_logits (#37955)
Correct type annotation from Dict(str, Tensor) to Dict[str, Tensor]
1 parent d80f53f commit 46c0e1f

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/transformers/models/mask2former/modeling_mask2former.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2359,7 +2359,7 @@ def get_loss(self, loss_dict: Dict[str, Tensor]) -> Tensor:
23592359
return sum(loss_dict.values())
23602360

23612361
def get_auxiliary_logits(self, classes: torch.Tensor, output_masks: torch.Tensor):
2362-
auxiliary_logits: List[Dict(str, Tensor)] = []
2362+
auxiliary_logits: List[Dict[str, Tensor]] = []
23632363

23642364
for aux_binary_masks, aux_classes in zip(output_masks[:-1], classes[:-1]):
23652365
auxiliary_logits.append({"masks_queries_logits": aux_binary_masks, "class_queries_logits": aux_classes})

0 commit comments

Comments
 (0)