Skip to content

Commit 31f66e3

Browse files
authored
Replaceinference_mode() with no_grad() to resolve training (#559)
1 parent 904c388 commit 31f66e3

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

rfdetr/detr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ def predict(
310310
"Alternatively, you can recompile the optimized model for a different batch size "
311311
"by calling model.optimize_for_inference(batch_size=<new_batch_size>).")
312312

313-
with torch.inference_mode():
313+
with torch.no_grad():
314314
if self._is_optimized_for_inference:
315315
predictions = self.model.inference_model(batch_tensor.to(dtype=self._optimized_dtype))
316316
else:

rfdetr/engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def train_one_epoch(
113113
scales = compute_multi_scale_scales(args.resolution, args.expanded_scales, args.patch_size, args.num_windows)
114114
random.seed(it)
115115
scale = random.choice(scales)
116-
with torch.inference_mode():
116+
with torch.no_grad():
117117
samples.tensors = F.interpolate(samples.tensors, size=scale, mode='bilinear', align_corners=False)
118118
samples.mask = F.interpolate(samples.mask.unsqueeze(1).float(), size=scale, mode='nearest').squeeze(1).bool()
119119

rfdetr/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -368,7 +368,7 @@ def lr_lambda(current_step: int):
368368

369369
utils.save_on_master(weights, checkpoint_path)
370370

371-
with torch.inference_mode():
371+
with torch.no_grad():
372372
test_stats, coco_evaluator = evaluate(
373373
model, criterion, postprocess, data_loader_val, base_ds, device, args=args
374374
)

0 commit comments

Comments
 (0)