Skip to content

Commit 77497a6

Browse files
luciaquirkeclaude
andcommitted
Match all-reduce op to query_method (AVG for mean, SUM for sum)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent dd6f29c commit 77497a6

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

bergson/double_backward.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,8 +203,9 @@ def schedule(step: Numeric) -> Numeric:
203203
)
204204

205205
if world_size > 1:
206+
reduce_op = dist.ReduceOp.AVG if run_cfg.query_method == "mean" else dist.ReduceOp.SUM
206207
for v in query_grads.values():
207-
dist.all_reduce(v, op=dist.ReduceOp.AVG)
208+
dist.all_reduce(v, op=reduce_op)
208209

209210
stream.requires_grad = True
210211
opt_grads = [

0 commit comments

Comments
 (0)