    if self.args.use_prompts:
        with torch.no_grad():
            outputs = self.model(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels, train=False, task_id=self.task_id)

            if not self.args.local_query:
                query = outputs.last_hidden_state.mean(dim=1)
            else:
                query = outputs.last_hidden_state

            outputs_without_aux = {k: v for k, v in outputs.items() if k != "auxiliary_outputs" and k != "enc_outputs"}

            # Retrieve the matching between the outputs of the last layer and the targets
            indices = self.model.matcher(outputs_without_aux, labels)

            one_hot_proposals = torch.zeros((len(labels),300)).to(self.device)
            for i,ind in enumerate(indices):
                for j in ind[0]:
                    one_hot_proposals[i][j] = 1

            query_wt = self.model.model.prompts.query_tf(query.view(query.shape[0],-1))
            query_loss = F.cross_entropy(query_wt, one_hot_proposals)

I found that the query loss is computed inside the with torch.no_grad(): block, so it is excluded from the computation graph and no gradient can reach the parameters of query_tf. As a result, the query loss cannot converge.

MD-DETR/engine.py
Lines 95 to 115 in 125e771
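To illustrate the effect, here is a minimal sketch with toy modules rather than the actual MD-DETR model. The names `backbone` and `query_tf` below are hypothetical stand-ins (two small `nn.Linear` layers), and the idea that only the feature extraction should stay under `torch.no_grad()` is my assumption about the intended behavior, not something confirmed by the repository.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: a frozen feature extractor and a trainable query head.
backbone = nn.Linear(8, 16)
query_tf = nn.Linear(16, 300)

x = torch.randn(2, 8)
# Float targets with the same shape as the logits, like one_hot_proposals.
target = torch.zeros(2, 300)
target[0, 3] = 1.0
target[1, 7] = 1.0

# Case 1: the loss is computed inside no_grad, as in the quoted snippet.
with torch.no_grad():
    feats = backbone(x)
    loss_no_grad = F.cross_entropy(query_tf(feats), target)

# Case 2 (assumed fix): only the feature extraction runs under no_grad;
# the query head and the loss are computed outside the block.
with torch.no_grad():
    feats = backbone(x)
loss_with_grad = F.cross_entropy(query_tf(feats), target)
loss_with_grad.backward()

print(loss_no_grad.requires_grad)        # False: no graph was recorded
print(loss_with_grad.requires_grad)      # True
print(query_tf.weight.grad is not None)  # True: the head receives gradients
print(backbone.weight.grad is None)      # True: the backbone stays untouched
```

In the first case the loss tensor has no `grad_fn`, so it cannot contribute any gradient even if it is added to the total loss. Applied to engine.py, the analogous change would be to move the `query_tf` call and the `F.cross_entropy` call out of the `with torch.no_grad():` block, assuming the rest of the block is meant to stay gradient-free.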