diff --git a/unsloth_zoo/temporary_patches/misc.py b/unsloth_zoo/temporary_patches/misc.py
index c4ff97629..0e1b9a90d 100644
--- a/unsloth_zoo/temporary_patches/misc.py
+++ b/unsloth_zoo/temporary_patches/misc.py
@@ -179,7 +179,18 @@ def forward(
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: KWARGS_TYPE,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        kwargs = process_output_options(self, locals(), kwargs)
+        output_attentions = kwargs.pop("output_attentions", None)
+        output_hidden_states = kwargs.pop("output_hidden_states", None)
+        output_attentions = (
+            output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
+        )
+        kwargs["output_attentions"] = output_attentions
+        kwargs["output_hidden_states"] = output_hidden_states
+        if input_ids is not None:
+            input_ids = input_ids.clamp(min=0, max=self.config.vocab_size - 1)
 
         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
         outputs = self.model(
@@ -232,7 +243,6 @@ def forward(
 
     if success: return
     # New transformers removes output_attentions and output_hidden_states
-    old_forward = forward
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -249,16 +259,141 @@ def forward(
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: KWARGS_TYPE,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
-        new_kwargs = locals().copy()
-        new_kwargs.pop('old_forward', None)
-        kwargs = new_kwargs.pop('kwargs', dict())
-        new_kwargs.update(kwargs)
-        return old_forward(**new_kwargs)
-    patch_function(transformers.models.csm.modeling_csm.CsmDepthDecoderForCausalLM, "forward", forward)
+        output_attentions = kwargs.pop("output_attentions", None)
+        output_hidden_states = kwargs.pop("output_hidden_states", None)
+        output_attentions = (
+            output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
+        )
+        kwargs["output_attentions"] = output_attentions
+        kwargs["output_hidden_states"] = output_hidden_states
+        if input_ids is not None:
+            input_ids = input_ids.clamp(min=0, max=self.config.vocab_size - 1)
+
+        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+        outputs = self.model(
+            input_ids = input_ids,
+            backbone_last_hidden_state = backbone_last_hidden_state,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_values = past_key_values,
+            inputs_embeds = inputs_embeds,
+            use_cache = use_cache,
+            cache_position = cache_position,
+            **kwargs,
+        )
+
+        hidden_states = outputs[0]
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        if isinstance(logits_to_keep, int):
+            if logits_to_keep == 0:
+                # skip idx 0 logits since it's for the concatenated backbone last hidden state
+                slice_indices = slice(1, None)
+            else:
+                slice_indices = slice(-logits_to_keep, None)
+        else:
+            slice_indices = logits_to_keep
+
+        logits = self.codebooks_head(
+            hidden_states[:, slice_indices, :], cache_position[slice_indices] if cache_position is not None else None
+        )
+        logits = logits.contiguous()
+
+        loss = None
+        if labels is not None:
+            shift_labels = labels[..., 1:].contiguous()
+            loss = ForCausalLMLoss(
+                logits=logits, labels=None, vocab_size=self.config.vocab_size, shift_labels=shift_labels, **kwargs
+            )
+
+        return process_return(CausalLMOutputWithPast, {
+            "loss" : loss,
+            "logits" : logits,
+            "past_key_values" : outputs.past_key_values,
+            "hidden_states" : outputs.hidden_states,
+            "attentions" : outputs.attentions,
+        })
+    patch_function(
+        transformers.models.csm.modeling_csm.CsmDepthDecoderForCausalLM,
+        "forward",
+        forward,
+        match_level="relaxed",
+    )
 pass
 TEMPORARY_PATCHES.append(patch_CsmDepthDecoderForCausalLM_forward)
 
 
+def patch_CsmDepthDecoderModel_forward():
+    """Fix in-place write on autograd leaf views in newer torch/transformers.
+
+    transformers CSM depth decoder writes `inputs_embeds[:, 0] = ...` after
+    embedding lookup. On some version combos this can trip:
+    "a view of a leaf Variable that requires grad is being used in an in-place operation."
+    Cloning the embedding output before the write avoids this while preserving behavior.
+    """
+    try:
+        import transformers.models.csm.modeling_csm as modeling_csm
+    except Exception as e:
+        return raise_error("CsmDepthDecoderModel.forward", e)
+
+    # Avoid double-patching in the same process.
+    unique_name = _get_unique_storage_name(modeling_csm.CsmDepthDecoderModel, "forward")
+    if hasattr(modeling_csm.CsmDepthDecoderModel, unique_name):
+        return
+
+    try:
+        old_forward = modeling_csm.CsmDepthDecoderModel.forward
+    except Exception as e:
+        return raise_error("CsmDepthDecoderModel.forward", e)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        backbone_last_hidden_state: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: KWARGS_TYPE,
+    ):
+        if inputs_embeds is None and input_ids is not None and backbone_last_hidden_state is not None:
+            input_ids = input_ids.clamp(min=0, max=self.config.vocab_size - 1)
+            if cache_position is None:
+                past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+                inputs_seq_length = input_ids.shape[1]
+                cache_position = torch.arange(
+                    past_seen_tokens,
+                    past_seen_tokens + inputs_seq_length,
+                    device=input_ids.device,
+                )
+            codebook_idxs = torch.clamp(cache_position - 1, min=0)
+            offset = codebook_idxs * self.vocab_size
+            inputs_embeds = self.embed_tokens(input_ids + offset).clone()
+            inputs_embeds[:, 0] = backbone_last_hidden_state
+            input_ids = None
+
+        return old_forward(
+            self,
+            input_ids=input_ids,
+            backbone_last_hidden_state=backbone_last_hidden_state,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            cache_position=cache_position,
+            **kwargs,
+        )
+    pass
+    patch_function(modeling_csm.CsmDepthDecoderModel, "forward", forward, match_level="relaxed")
+pass
+TEMPORARY_PATCHES.append(patch_CsmDepthDecoderModel_forward)
+
+
 def patch_CsmForConditionalGeneration_forward():
     try:
         import transformers.models.csm.modeling_csm
@@ -284,7 +419,16 @@ def forward(
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: KWARGS_TYPE,
     ) -> Union[Tuple, CsmOutputWithPast]:
-        kwargs = process_output_options(self, locals(), kwargs)
+        output_attentions = kwargs.pop("output_attentions", None)
+        output_hidden_states = kwargs.pop("output_hidden_states", None)
+        output_attentions = (
+            output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
+        )
+        kwargs["output_attentions"] = output_attentions
+        kwargs["output_hidden_states"] = output_hidden_states
 
         if input_ids is not None and input_ids.ndim == 2:
             merged_inputs = self._merge_input_ids_with_input_values(
                 input_ids, input_values, input_values_cutoffs, labels
             )
@@ -361,8 +505,13 @@ def forward(
                 **depth_decoder_kwargs,
             )
 
-        depth_decoder_loss = depth_decoder_outputs.loss
-        loss = backbone_loss + depth_decoder_loss
+        depth_decoder_loss = depth_decoder_outputs.loss if depth_decoder_outputs is not None else None
+        if backbone_loss is None:
+            loss = depth_decoder_loss
+        elif depth_decoder_loss is None:
+            loss = backbone_loss
+        else:
+            loss = backbone_loss + depth_decoder_loss
 
         return process_return(CsmOutputWithPast, {
             "loss" : loss,
@@ -386,7 +535,6 @@ def forward(
 
     if success: return
     # New transformers removes output_attentions and output_hidden_states
-    old_forward = forward
     def forward(
         self,
         input_ids: torch.LongTensor = None,
@@ -404,12 +552,118 @@ def forward(
         logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs: KWARGS_TYPE,
     ) -> Union[Tuple, CsmOutputWithPast]:
-        new_kwargs = locals().copy()
-        new_kwargs.pop('old_forward', None)
-        kwargs = new_kwargs.pop('kwargs', dict())
-        new_kwargs.update(kwargs)
-        return old_forward(**new_kwargs)
-    patch_function(transformers.models.csm.modeling_csm.CsmForConditionalGeneration, "forward", forward)
+        output_attentions = kwargs.pop("output_attentions", None)
+        output_hidden_states = kwargs.pop("output_hidden_states", None)
+        output_attentions = (
+            output_attentions if output_attentions is not None else getattr(self.config, "output_attentions", False)
+        )
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else getattr(self.config, "output_hidden_states", False)
+        )
+        kwargs["output_attentions"] = output_attentions
+        kwargs["output_hidden_states"] = output_hidden_states
+
+        if input_ids is not None and input_ids.ndim == 2:
+            merged_inputs = self._merge_input_ids_with_input_values(
+                input_ids, input_values, input_values_cutoffs, labels
+            )
+            inputs_embeds = merged_inputs["inputs_embeds"]
+            labels = merged_inputs["labels"]
+            input_ids = None
+
+        backbone_outputs = self.backbone_model(
+            input_ids = input_ids,
+            attention_mask = attention_mask,
+            position_ids = position_ids,
+            past_key_values = past_key_values,
+            inputs_embeds = inputs_embeds,
+            use_cache = use_cache,
+            cache_position = cache_position,
+            **kwargs,
+        )
+
+        backbone_hidden_states = backbone_outputs[0]
+        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        backbone_logits = self.lm_head(backbone_hidden_states[:, slice_indices, :])
+
+        loss = None
+        backbone_loss = None
+        depth_decoder_loss = None
+        depth_decoder_outputs = None
+        if labels is not None:
+            # select first codebook as labels for the backbone model
+            backbone_labels = labels[:, :, 0]
+            backbone_loss = self.loss_function(
+                logits=backbone_logits, labels=backbone_labels, vocab_size=self.config.vocab_size, **kwargs
+            )
+
+            # for the depth decoder, we need to select the frames to train on
+            # those are frames where the label is not uniformly `ignore_index` along the codebook dimension
+            train_mask = ~(labels[:, :, 1:] == -100).all(dim=-1)
+            depth_decoder_input_ids = labels[train_mask][..., : self.config.num_codebooks - 1]
+            # add placeholder in position 0 that will be replaced by the backbone_last_hidden_state
+            depth_decoder_input_ids = torch.nn.functional.pad(depth_decoder_input_ids, (1, 0), value=0)
+
+            train_idxs = train_mask.nonzero(as_tuple=True)
+            backbone_last_hidden_states = backbone_hidden_states[train_idxs[0], train_idxs[1] - 1, :]
+            depth_decoder_labels = labels[train_mask]
+
+            # Fix: explicitly pass kwargs to depth decoder to get access to num_items_in_batch
+            depth_decoder_kwargs = kwargs.copy()
+            # backbone loss num_items is based on the 0th codebook index
+            # while depth loss num_items is based on the remaining 31 codebooks
+            # therefore num_items_in_batch should be multiplied by 31
+            if 'num_items_in_batch' in depth_decoder_kwargs:
+                depth_decoder_kwargs['num_items_in_batch'] = depth_decoder_kwargs['num_items_in_batch'] * 31
+
+            # make sure return_dict is set to True
+            depth_decoder_kwargs.pop('return_dict', None)
+            # Move output_attentions and output_hidden_states since transformers 4.54 deletes them
+            depth_decoder_kwargs["output_attentions"] = kwargs.get("output_attentions")
+            depth_decoder_kwargs["output_hidden_states"] = kwargs.get("output_hidden_states")
+
+            depth_decoder_outputs = self.depth_decoder(
+                input_ids = depth_decoder_input_ids,
+                backbone_last_hidden_state = backbone_last_hidden_states,
+                use_cache = use_cache,
+                return_dict = True,
+                labels = depth_decoder_labels,
+                # Fix: explicitly pass kwargs to depth decoder to get access to num_items_in_batch
+                **depth_decoder_kwargs,
+            )
+
+        depth_decoder_loss = depth_decoder_outputs.loss if depth_decoder_outputs is not None else None
+        if backbone_loss is None:
+            loss = depth_decoder_loss
+        elif depth_decoder_loss is None:
+            loss = backbone_loss
+        else:
+            loss = backbone_loss + depth_decoder_loss
+
+        return process_return(CsmOutputWithPast, {
+            "loss" : loss,
+            "backbone_loss" : backbone_loss,
+            "depth_decoder_loss" : depth_decoder_loss,
+            "logits" : backbone_logits,
+            "past_key_values" : backbone_outputs.past_key_values,
+            "hidden_states" : backbone_outputs.hidden_states,
+            "attentions" : backbone_outputs.attentions,
+            "depth_decoder_logits" : depth_decoder_outputs.logits if depth_decoder_outputs is not None else None,
+            "depth_decoder_past_key_values" : depth_decoder_outputs.past_key_values
+            if depth_decoder_outputs is not None
+            else None,
+            "depth_decoder_hidden_states" : depth_decoder_outputs.hidden_states
+            if depth_decoder_outputs is not None
+            else None,
+            "depth_decoder_attentions" : depth_decoder_outputs.attentions if depth_decoder_outputs is not None else None,
+        })
+    patch_function(
+        transformers.models.csm.modeling_csm.CsmForConditionalGeneration,
+        "forward",
+        forward,
+        match_level="relaxed",
+    )
 pass
 TEMPORARY_PATCHES.append(patch_CsmForConditionalGeneration_forward)
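Not part of the diff above: a minimal, standalone PyTorch sketch of the autograd behavior that `patch_CsmDepthDecoderModel_forward`'s docstring describes. Variable names and shapes are illustrative only; the sketch uses a leaf tensor so the error message reproduces deterministically, whereas in the model the offending write targets whatever `inputs_embeds` reaches the depth decoder.

```python
import torch

# A leaf tensor standing in for `inputs_embeds` handed straight to the model.
inputs_embeds = torch.randn(2, 5, 8, requires_grad=True)

try:
    # In-place write into a view of a leaf that requires grad raises:
    # "a view of a leaf Variable that requires grad is being used in an in-place operation."
    inputs_embeds[:, 0] = torch.zeros(8)
except RuntimeError as err:
    print(err)

# The approach taken by the patch: clone first, then overwrite position 0.
backbone_last_hidden_state = torch.randn(2, 8)
safe_embeds = inputs_embeds.clone()          # non-leaf copy, still tracked by autograd
safe_embeds[:, 0] = backbone_last_hidden_state
safe_embeds.sum().backward()                 # backward still reaches the original leaf
print(inputs_embeds.grad[:, 0].abs().sum())  # overwritten slots receive zero gradient
```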