Skip to content

Commit a821d9b

Browse files
committed
fix: fix batch size for validation
1 parent b5fa5e3 commit a821d9b

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

packages/train-yolo/main.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,13 @@ def main(
335335
name=run_name,
336336
)
337337

338+
if batch < 0:
339+
# If batch size is set to -1, ultralytics uses the maximum batch size
340+
# that fits in the GPU memory. We need to get the actual batch size
341+
# used, as it will be needed later during validation.
342+
# The actual batch size is stored in model.trainer.args.batch
343+
batch = model.trainer.args.batch
344+
338345
# Export the trained model to ONNX and TensorRT format
339346
model.export(
340347
format="onnx",

0 commit comments

Comments
 (0)