Commit ab49a48

Fix incorrect condition comment in log_target calculation (#633)
## Summary

This PR fixes an incorrect comment in the `log_target` condition of `LigerKLDivLossFunction`. The original comments described the two branches in reverse:

- When `log_target=False`, the actual computation is `target * (target.log() - input)`
- When `log_target=True`, the actual computation is `target.exp() * (target - input)`

This change corrects the comments to accurately reflect the implementation (documentation-only change).

## Details

The comment correction improves code clarity and prevents potential confusion for developers. No functional changes are made.

## Testing Done

**Not required** - this is a documentation-only change that:

1. Doesn't modify any executable code
2. Doesn't affect any existing functionality
3. Only fixes inaccurate comments
1 parent 87187b1 commit ab49a48

1 file changed: `src/liger_kernel/ops/kl_div.py` (2 additions, 2 deletions)
````diff
@@ -185,9 +185,9 @@ class LigerKLDivLossFunction(torch.autograd.Function):
     Class implementing the forward and backward pass for the KL Divergence Loss using Triton, as defined by the following formula:

     ```python
     if log_target:
-        loss = target * (target.log() - input)
-    else:
         loss = target.exp() * (target - input)
+    else:
+        loss = target * (target.log() - input)
     ```,

     then the loss is reduced according to the `reduction` parameter.
     as defined in the PyTorch documentation: https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html
````
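To illustrate why the corrected comments match the semantics of `torch.nn.KLDivLoss`, here is a minimal pure-Python sketch of the elementwise term (not the Triton kernel itself; `kl_div_elementwise` is a hypothetical helper written for this example). The key point is that `input` is always a log-probability, while `target` is a probability when `log_target=False` and a log-probability when `log_target=True` - so the two branches compute the same value when given consistent inputs.

```python
import math

def kl_div_elementwise(input_log_prob, target, log_target=False):
    """Elementwise KL-divergence term, mirroring the corrected docstring.

    input_log_prob: always a log-probability.
    target: a probability if log_target=False, a log-probability if log_target=True.
    """
    if log_target:
        # target is in log-space: exp(target) * (target - input)
        return math.exp(target) * (target - input_log_prob)
    else:
        # target is a plain probability: target * (log(target) - input)
        return target * (math.log(target) - input_log_prob)

p, q = 0.7, 0.4                 # target probability, model probability
inp = math.log(q)               # input is the model's log-probability

a = kl_div_elementwise(inp, p, log_target=False)
b = kl_div_elementwise(inp, math.log(p), log_target=True)
assert math.isclose(a, b)       # both branches agree on consistent inputs
```

With the comments as they stood before this fix, a reader would have expected the `log_target=True` branch to call `.log()` on an already-log-space target, which is exactly the confusion this PR removes.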