
Handle missing CSM depth decoder loss during loss aggregation#496

Open
danielhanchen wants to merge 5 commits into main from fix-csm-depth-loss-none-aggregation

Conversation

@danielhanchen
Contributor

Summary

  • fixes CSM temporary patch loss aggregation when the depth decoder loss is absent
  • avoids a Tensor + NoneType TypeError during training loss computation

Root Cause

In the patched CsmForConditionalGeneration.forward, loss was always computed as:

  • backbone_loss + depth_decoder_loss

When depth_decoder_loss is None, this raises:

  • TypeError: unsupported operand type(s) for +: 'Tensor' and 'NoneType'
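
The failure mode can be shown in isolation. Plain floats are used here for brevity; the same TypeError is raised when backbone_loss is a torch.Tensor:

```python
# Minimal reproduction of the failure mode: adding a loss value to a
# missing (None) loss component raises TypeError.
backbone_loss = 1.5          # stands in for the backbone loss tensor
depth_decoder_loss = None    # depth decoder labels were not provided

try:
    loss = backbone_loss + depth_decoder_loss
except TypeError as e:
    print(e)  # unsupported operand type(s) for +: 'float' and 'NoneType'
```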

Changes

  • unsloth_zoo/temporary_patches/misc.py
    • replace unconditional sum with guarded aggregation:
      • if only one loss exists, use that loss
      • if both exist, sum both
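
A minimal sketch of the guarded aggregation described above (the helper name is illustrative; in the patch this logic is inlined in the patched forward):

```python
def aggregate_losses(backbone_loss, depth_decoder_loss):
    # Guarded aggregation: if only one loss exists, use that loss;
    # if both exist, sum them; if neither exists, the result stays None.
    if backbone_loss is None:
        return depth_decoder_loss
    if depth_decoder_loss is None:
        return backbone_loss
    return backbone_loss + depth_decoder_loss
```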

Validation

  • python -m compileall unsloth_zoo/temporary_patches/misc.py
  • static assertion confirms guarded aggregation block is present in patched source

Notes

  • behavior is unchanged when both losses are present
  • this only changes the edge case where one loss is missing

@gemini-code-assist
Contributor

Summary of Changes

Hello @danielhanchen, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the robustness of the CSM model's loss computation by introducing conditional logic to handle cases where the depth decoder loss might be absent. This prevents runtime errors during training by ensuring that loss components are only summed when both are valid, thereby improving the stability of the model's forward pass without altering behavior when both losses are present.

Highlights

  • Bug Fix: Resolved a TypeError that occurred when depth_decoder_loss was None during loss aggregation in the CSM model's forward pass.
  • Robustness: Implemented guarded aggregation logic to safely combine backbone_loss and depth_decoder_loss, preventing Tensor + NoneType errors.


Changelog
  • unsloth_zoo/temporary_patches/misc.py
    • Replaced direct summation of backbone_loss and depth_decoder_loss with conditional aggregation.
    • Added checks for None values in both backbone_loss and depth_decoder_loss before summing.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request correctly addresses a TypeError that could occur during loss aggregation when the depth decoder loss is not present. The fix is sound and handles all cases where backbone_loss or depth_decoder_loss might be None. I've added one suggestion to refactor the loss aggregation logic to be more concise and idiomatic, which could also improve scalability if more loss components are added in the future.

Comment on lines +365 to +370
if backbone_loss is None:
loss = depth_decoder_loss
elif depth_decoder_loss is None:
loss = backbone_loss
else:
loss = backbone_loss + depth_decoder_loss

medium

While the current logic is correct, it can be simplified to be more concise and scalable for potentially adding more loss components in the future. You can use a list comprehension to filter out None values and then use sum().

losses = [l for l in (backbone_loss, depth_decoder_loss) if l is not None]
loss = sum(losses) if losses else None
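
As a sanity check, the comprehension form agrees with the explicit branches on every None/non-None combination (floats stand in for loss tensors):

```python
def aggregate(backbone_loss, depth_decoder_loss):
    # Comprehension form: drop missing components, sum what remains,
    # and fall back to None when no loss component is present.
    losses = [l for l in (backbone_loss, depth_decoder_loss) if l is not None]
    return sum(losses) if losses else None

print(aggregate(2.0, None), aggregate(None, 3.0), aggregate(2.0, 3.0), aggregate(None, None))
```

One wrinkle with real tensors: sum() starts from the int 0, which PyTorch handles (0 + tensor is a valid scalar add), so this form should remain safe there as well.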

@danielhanchen
Contributor Author

Pushed follow-up fixes to this PR branch for transformers 5.x CSM compatibility.

New commits:

  • a2283ac csm: improve transformers 5.x forward patch compatibility
  • 622dda9 csm: restore misc.py file mode

What changed:

  • Added a dedicated CsmDepthDecoderModel forward patch to avoid the leaf-view in-place autograd failure.
  • Updated CSM forward/depth forward fallback patching so 5.x signatures are patched correctly.
  • Replaced 5.x fallback wrappers with full patched implementations so the loss aggregation guard and kwargs handling apply on 5.x.
  • Preserved output_hidden_states/output_attentions in fallback kwargs handling to avoid generation hidden state regressions.
  • Added defensive depth-decoder input id clamping in patched paths.
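
The defensive clamping idea can be sketched with plain ints (the function name and list inputs are illustrative; the actual patch operates on token id tensors, e.g. via Tensor.clamp_):

```python
def clamp_ids(ids, vocab_size):
    # Illustrative clamp: force every token id into [0, vocab_size - 1]
    # so embedding lookups cannot index out of bounds.
    return [min(max(i, 0), vocab_size - 1) for i in ids]
```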

Validation notes (local, no CSM short-circuit):

  • The previous training-time failures (the leaf-view in-place autograd error and the Tensor + NoneType loss error) are fixed on transformers 5.1.0.
  • CSM training now runs and completes.
  • There is still a remaining CUDA device-side assert in depth-decoder generation (index out of bounds in IndexKernel), which appears to be a deeper generation-path issue.
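
A standard way to localize such device-side asserts (general PyTorch debugging practice, not part of this PR; the script name below is a placeholder):

```shell
# Force synchronous CUDA kernel launches so a device-side assert surfaces
# at the failing op with a usable Python traceback, not at a later sync.
export CUDA_LAUNCH_BLOCKING=1
# then run the training command as usual, e.g.:
#   python train_script.py   (placeholder name)
```

Running the same generation on CPU often turns the opaque IndexKernel assert into an immediate Python-level IndexError that names the offending index.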

@danielhanchen
Contributor Author

Additional updates pushed after further validation:

  • ea8eb1b csm: preserve generation hidden states in 5.x fallbacks
  • dd26d59 csm: restore misc.py file mode after update

Current local status on transformers 5.1.0 (no CSM short-circuit):

  • CSM training path now completes.
  • Generation proceeds deeper than before, but there is still a remaining CUDA device-side assert in the depth-decoder generation path (index out of bounds in CUDA IndexKernel).

The smoke harness run with the standard CSM shortcut still passes after these updates.

@Datta0
Collaborator

Datta0 commented Feb 16, 2026

btw did you ask it to verify correctness on both transformers v4 and v5? The two seem to have different args/params
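
One generic way such signature differences are handled when patching across versions (a sketch, not code from this PR; the forward stubs are hypothetical):

```python
import inspect

def accepts_kwarg(fn, name):
    # True if the callable's signature declares the given parameter
    # (or a **kwargs catch-all), so a patch can pass it safely.
    params = inspect.signature(fn).parameters
    return name in params or any(
        p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()
    )

# Hypothetical stand-ins for v5- and v4-style forward signatures:
def forward_v5(self, input_ids, cache_position=None, **kwargs): ...
def forward_v4(self, input_ids, position_ids=None): ...
```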

@Datta0
Collaborator

Datta0 commented Feb 16, 2026

Ref: #495
