290 changes: 272 additions & 18 deletions unsloth_zoo/temporary_patches/misc.py
@@ -179,7 +179,18 @@ def forward(
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: KWARGS_TYPE,
) -> Union[Tuple, CausalLMOutputWithPast]:
kwargs = process_output_options(self, locals(), kwargs)
output_attentions = kwargs.pop("output_attentions", None)
output_hidden_states = kwargs.pop("output_hidden_states", None)
output_attentions = (
output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
)
kwargs["output_attentions"] = output_attentions
kwargs["output_hidden_states"] = output_hidden_states
if input_ids is not None:
input_ids = input_ids.clamp(min=0, max=self.config.vocab_size - 1)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
@@ -232,7 +243,6 @@ def forward(
if success: return

# Newer transformers versions remove output_attentions and output_hidden_states
old_forward = forward
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -249,16 +259,141 @@ def forward(
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: KWARGS_TYPE,
) -> Union[Tuple, CausalLMOutputWithPast]:
new_kwargs = locals().copy()
new_kwargs.pop('old_forward', None)
kwargs = new_kwargs.pop('kwargs', dict())
new_kwargs.update(kwargs)
return old_forward(**new_kwargs)
patch_function(transformers.models.csm.modeling_csm.CsmDepthDecoderForCausalLM, "forward", forward)
output_attentions = kwargs.pop("output_attentions", None)
output_hidden_states = kwargs.pop("output_hidden_states", None)
output_attentions = (
output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
)
kwargs["output_attentions"] = output_attentions
kwargs["output_hidden_states"] = output_hidden_states
if input_ids is not None:
input_ids = input_ids.clamp(min=0, max=self.config.vocab_size - 1)

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids = input_ids,
backbone_last_hidden_state = backbone_last_hidden_state,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_values = past_key_values,
inputs_embeds = inputs_embeds,
use_cache = use_cache,
cache_position = cache_position,
**kwargs,
)

hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
if isinstance(logits_to_keep, int):
if logits_to_keep == 0:
# skip idx 0 logits since it's for the concatenated backbone last hidden state
slice_indices = slice(1, None)
else:
slice_indices = slice(-logits_to_keep, None)
else:
slice_indices = logits_to_keep

logits = self.codebooks_head(
hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None
)
logits = logits.contiguous()

loss = None
if labels is not None:
shift_labels = labels[..., 1:].contiguous()
loss = ForCausalLMLoss(
logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
)

return process_return(CausalLMOutputWithPast, {
"loss" : loss,
"logits" : logits,
"past_key_values" : outputs.past_key_values,
"hidden_states" : outputs.hidden_states,
"attentions" : outputs.attentions,
})
patch_function(
transformers.models.csm.modeling_csm.CsmDepthDecoderForCausalLM,
"forward",
forward,
match_level="relaxed",
)
pass
TEMPORARY_PATCHES.append(patch_CsmDepthDecoderForCausalLM_forward)
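
Note on the `logits_to_keep` handling above: in most transformers causal-LM heads, `logits_to_keep=0` keeps every position, but the depth decoder treats `0` as "skip index 0", since that slot holds the concatenated backbone hidden state rather than a token. A minimal sketch of the slicing, using made-up tensor shapes:

    import torch

    hidden_states = torch.randn(2, 5, 8)  # hypothetical (batch, seq, hidden)

    def depth_slice(logits_to_keep):
        # Mirrors the patched logic: 0 skips the backbone placeholder at index 0,
        # a positive int keeps the last N positions, a tensor indexes directly.
        if isinstance(logits_to_keep, int):
            return slice(1, None) if logits_to_keep == 0 else slice(-logits_to_keep, None)
        return logits_to_keep

    print(hidden_states[:, depth_slice(0), :].shape)  # torch.Size([2, 4, 8])
    print(hidden_states[:, depth_slice(2), :].shape)  # torch.Size([2, 2, 8])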


def patch_CsmDepthDecoderModel_forward():
"""Fix in-place write on autograd leaf views in newer torch/transformers.

transformers CSM depth decoder writes `inputs_embeds[:, 0] = ...` after
embedding lookup. On some version combos this can trip:
"a view of a leaf Variable that requires grad is being used in an in-place operation."
Cloning the embedding output before the write avoids this while preserving behavior.
"""
try:
import transformers.models.csm.modeling_csm as modeling_csm
except Exception as e:
return raise_error("CsmDepthDecoderModel.forward", e)

# Avoid double-patching in the same process.
unique_name = _get_unique_storage_name(modeling_csm.CsmDepthDecoderModel, "forward")
if hasattr(modeling_csm.CsmDepthDecoderModel, unique_name):
return

try:
old_forward = modeling_csm.CsmDepthDecoderModel.forward
except Exception as e:
return raise_error("CsmDepthDecoderModel.forward", e)

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs: KWARGS_TYPE,
):
if inputs_embeds is None and input_ids is not None and backbone_last_hidden_state is not None:
input_ids = input_ids.clamp(min=0, max=self.config.vocab_size - 1)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
inputs_seq_length = input_ids.shape[1]
cache_position = torch.arange(
past_seen_tokens,
past_seen_tokens + inputs_seq_length,
device=input_ids.device,
)
codebook_idxs = torch.clamp(cache_position - 1, min=0)
offset = codebook_idxs * self.vocab_size
inputs_embeds = self.embed_tokens(input_ids + offset).clone()
inputs_embeds[:, 0] = backbone_last_hidden_state
input_ids = None

return old_forward(
self,
input_ids=input_ids,
backbone_last_hidden_state=backbone_last_hidden_state,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
cache_position=cache_position,
**kwargs,
)
pass
patch_function(modeling_csm.CsmDepthDecoderModel, "forward", forward, match_level="relaxed")
pass
TEMPORARY_PATCHES.append(patch_CsmDepthDecoderModel_forward)
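
For readers unfamiliar with the autograd error the docstring mentions, here is a standalone sketch (not part of the patch) of the failure mode and of why `.clone()` avoids it:

    import torch

    w = torch.randn(3, 4, requires_grad=True)  # a leaf Variable that requires grad
    view = w[:, 0]                             # a view of that leaf
    # view += 1.0  # RuntimeError: a view of a leaf Variable that requires grad
    #              # is being used in an in-place operation.

    safe = w.clone()        # the clone is a non-leaf with a grad_fn
    safe[:, 0] = 0.0        # the in-place write is now legal
    safe.sum().backward()   # and gradients still flow back to w
    print(w.grad[:, 0])     # zeros: column 0 was overwritten by a constant
    print(w.grad[:, 1])     # ones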


def patch_CsmForConditionalGeneration_forward():
try:
import transformers.models.csm.modeling_csm
@@ -284,7 +419,16 @@ def forward(
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: KWARGS_TYPE,
) -> Union[Tuple, CsmOutputWithPast]:
kwargs = process_output_options(self, locals(), kwargs)
output_attentions = kwargs.pop("output_attentions", None)
output_hidden_states = kwargs.pop("output_hidden_states", None)
output_attentions = (
output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
)
kwargs["output_attentions"] = output_attentions
kwargs["output_hidden_states"] = output_hidden_states

if input_ids is not None and input_ids.ndim == 2:
merged_inputs = self._merge_input_ids_with_input_values(
@@ -361,8 +505,13 @@ def forward(
**depth_decoder_kwargs,
)

depth_decoder_loss = depth_decoder_outputs.loss
loss = backbone_loss + depth_decoder_loss
depth_decoder_loss = depth_decoder_outputs.loss if depth_decoder_outputs is not None else None
if backbone_loss is None:
loss = depth_decoder_loss
elif depth_decoder_loss is None:
loss = backbone_loss
else:
loss = backbone_loss + depth_decoder_loss
Comment on lines +509 to +514 (Contributor, severity: medium):

While the current logic is correct, it can be simplified to be more concise and scalable for potentially adding more loss components in the future. You can use a list comprehension to filter out None values and then use sum():

            losses = [l for l in (backbone_loss, depth_decoder_loss) if l is not None]
            loss = sum(losses) if losses else None

return process_return(CsmOutputWithPast, {
"loss" : loss,
@@ -386,7 +535,6 @@ def forward(
if success: return

# Newer transformers versions remove output_attentions and output_hidden_states
old_forward = forward
def forward(
self,
input_ids: torch.LongTensor = None,
@@ -404,12 +552,118 @@ def forward(
logits_to_keep: Union[int, torch.Tensor] = 0,
**kwargs: KWARGS_TYPE,
) -> Union[Tuple, CsmOutputWithPast]:
new_kwargs = locals().copy()
new_kwargs.pop('old_forward', None)
kwargs = new_kwargs.pop('kwargs', dict())
new_kwargs.update(kwargs)
return old_forward(**new_kwargs)
patch_function(transformers.models.csm.modeling_csm.CsmForConditionalGeneration, "forward", forward)
output_attentions = kwargs.pop("output_attentions", None)
output_hidden_states = kwargs.pop("output_hidden_states", None)
output_attentions = (
output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
)
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
)
kwargs["output_attentions"] = output_attentions
kwargs["output_hidden_states"] = output_hidden_states

if input_ids is not None and input_ids.ndim == 2:
merged_inputs = self._merge_input_ids_with_input_values(
input_ids, input_values, input_values_cutoffs, labels
)
inputs_embeds = merged_inputs["inputs_embeds"]
labels = merged_inputs["labels"]
input_ids = None

backbone_outputs = self.backbone_model(
input_ids = input_ids,
attention_mask = attention_mask,
position_ids = position_ids,
past_key_values = past_key_values,
inputs_embeds = inputs_embeds,
use_cache = use_cache,
cache_position = cache_position,
**kwargs,
)

backbone_hidden_states = backbone_outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])

loss = None
backbone_loss = None
depth_decoder_loss = None
depth_decoder_outputs = None
if labels is not None:
# select first codebook as labels for the backbone model
backbone_labels = labels[:, :, 0]
backbone_loss = self.loss_function(
logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
)

# for the depth decoder, we need to select the frames to train on
# those are frames where the label is not uniformly `ignore_index` along the codebook dimension
train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
# add a placeholder in position 0 that will be replaced by the backbone_last_hidden_state
depth_decoder_input_ids = torch.nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)

train_idxs = train_mask.nonzero(as_tuple=True)
backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
depth_decoder_labels = labels[train_mask]

# Fix: explicitly pass kwargs to depth decoder to get access to num_items_in_batch
depth_decoder_kwargs = kwargs.copy()
# backbone loss num_items is based on the 0th codebook only,
# while depth loss num_items is based on the remaining 31 codebooks,
# therefore num_items_in_batch should be multiplied by 31
if 'num_items_in_batch' in depth_decoder_kwargs:
depth_decoder_kwargs['num_items_in_batch'] = depth_decoder_kwargs['num_items_in_batch'] * 31

# make sure return_dict is set to True
depth_decoder_kwargs.pop('return_dict', None)
# Pass output_attentions and output_hidden_states explicitly since transformers 4.54 deletes them
depth_decoder_kwargs["output_attentions"] = kwargs.get("output_attentions")
depth_decoder_kwargs["output_hidden_states"] = kwargs.get("output_hidden_states")

depth_decoder_outputs = self.depth_decoder(
input_ids = depth_decoder_input_ids,
backbone_last_hidden_state = backbone_last_hidden_states,
use_cache = use_cache,
return_dict = True,
labels = depth_decoder_labels,
# Fix: explicitly pass kwargs to depth decoder to get access to num_items_in_batch
**depth_decoder_kwargs,
)

depth_decoder_loss = depth_decoder_outputs.loss if depth_decoder_outputs is not None else None
if backbone_loss is None:
loss = depth_decoder_loss
elif depth_decoder_loss is None:
loss = backbone_loss
else:
loss = backbone_loss + depth_decoder_loss

return process_return(CsmOutputWithPast, {
"loss" : loss,
"backbone_loss" : backbone_loss,
"depth_decoder_loss" : depth_decoder_loss,
"logits" : backbone_logits,
"past_key_values" : backbone_outputs.past_key_values,
"hidden_states" : backbone_outputs.hidden_states,
"attentions" : backbone_outputs.attentions,
"depth_decoder_logits" : depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
"depth_decoder_past_key_values" : depth_decoder_outputs.past_key_values
if depth_decoder_outputs is not None
else None,
"depth_decoder_hidden_states" : depth_decoder_outputs.hidden_states
if depth_decoder_outputs is not None
else None,
"depth_decoder_attentions" : depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
})
patch_function(
transformers.models.csm.modeling_csm.CsmForConditionalGeneration,
"forward",
forward,
match_level="relaxed",
)
pass
TEMPORARY_PATCHES.append(patch_CsmForConditionalGeneration_forward)
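
As a sanity check on the `num_items_in_batch` scaling above (assuming CSM's default 32 codebooks): each trainable frame contributes one label token to the backbone loss (codebook 0) and 31 to the depth decoder loss (codebooks 1-31), so the depth decoder's loss normalizer is 31x the backbone's. With made-up numbers:

    # Hypothetical count for illustration: 4 trainable frames in the batch.
    num_codebooks = 32
    num_items_in_batch = 4                      # codebook-0 labels, one per frame
    depth_labels_per_frame = num_codebooks - 1  # codebooks 1..31

    depth_num_items = num_items_in_batch * depth_labels_per_frame
    assert depth_num_items == num_items_in_batch * 31  # matches the patch's scaling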
