
Gradients become NaN when applying DAS to CLIP #11

@yunsangju


Hello,

I am applying the DAS method to CLIP. When computing importance, the text model produces well-behaved gradients, but most of the gradients for the vision model are NaN. The importance masks are applied to the outputs of `self_attn` and `mlp` in `CLIPEncoderLayer`, which is the same layer class used by both the text and vision models.
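For context, the mechanism being debugged can be sketched on a toy layer: a mask of ones multiplies a unit's output, and the gradient of a loss with respect to that mask is read as a per-unit importance score. The toy `nn.Linear` layer, the squared-output loss, and taking the absolute gradient are assumptions for illustration; DAS's own loss and its normalisation of importances may differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_size = 8
layer = nn.Linear(hidden_size, hidden_size)  # toy stand-in for self_attn or mlp

# One mask entry per output unit, initialised to ones so the forward pass is unchanged.
mask = torch.ones(hidden_size, requires_grad=True)

x = torch.randn(4, 10, hidden_size)  # (batch, tokens, hidden)
out = layer(x) * mask                # the mask broadcasts over batch and tokens
loss = out.pow(2).mean()             # placeholder loss, not the DAS objective
loss.backward()

# Absolute gradient w.r.t. the mask, used here as a per-unit importance score.
importance = mask.grad.abs()
print(importance.shape)  # torch.Size([8])
```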

I have declared the masks as follows:

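import torch
from torch import nn
from transformers import CLIPConfig

# CLIPEncoderLayer is the modified layer defined in this same file (shown below).


class CLIPEncoder(nn.Module):
    """
    Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
    [`CLIPEncoderLayer`].

    Args:
        config: CLIPConfig
    """

    def __init__(self, config: CLIPConfig):
        super().__init__()
        self.config = config
        self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)])
        # One mask row per layer, initialised to ones. The masks are float16, so the
        # gradients accumulated into them are stored in float16 as well.
        # They are plain tensors (not parameters or buffers), so model.to(device) does not move them.
        self.self_attn_mask = torch.ones(config.num_hidden_layers, config.hidden_size, dtype=torch.float16)
        self.self_attn_mask.requires_grad_(True)
        self.mlp_mask = torch.ones(config.num_hidden_layers, config.hidden_size, dtype=torch.float16)
        self.mlp_mask.requires_grad_(True)

        self.gradient_checkpointing = False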
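One thing that may be worth ruling out (a hypothesis, not a confirmed diagnosis): because the masks are float16, the gradient written into them is stored in float16, whose largest finite value is 65504. The mask broadcasts over the batch and token dimensions, so each entry's gradient is a sum over every token position in the batch, and large activations or many tokens can push it past that limit to `inf`. An `inf` importance can then turn into NaN in later arithmetic, for example `inf - inf` during normalisation. The toy below uses a made-up constant activation (300.0) and token count (257) purely for illustration; `mask_grad` is a hypothetical helper.

```python
import torch

# Largest finite float16 value.
print(torch.finfo(torch.float16).max)  # 65504.0


def mask_grad(mask_dtype, num_tokens, hidden_size=4, value=300.0):
    """Hypothetical helper: gradient of sum(hidden_states * mask) w.r.t. a ones-mask."""
    mask = torch.ones(hidden_size, dtype=mask_dtype, requires_grad=True)
    # float32 activations filled with an illustrative constant value.
    hidden_states = torch.full((1, num_tokens, hidden_size), value)
    # The mask broadcasts over (batch, tokens), so each mask entry's gradient is the
    # sum of `value` over all token positions: 300 * 257 = 77100 here.
    (hidden_states * mask).sum().backward()
    # Autograd stores the gradient in the mask's own dtype.
    return mask.grad


g16 = mask_grad(torch.float16, num_tokens=257)
g32 = mask_grad(torch.float32, num_tokens=257)
print(g16)  # every entry is inf: 77100 is outside the float16 range
print(g32)  # every entry is 77100.0
```

If the real mask gradients overflow in the same way (checkable with `torch.isinf(mask.grad).any()` right after `backward()`), keeping the masks in float32 would be a cheap thing to try.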

The masks are applied inside CLIPEncoderLayer as follows:

# self_attn_mask and mlp_mask are presumably this layer's rows of the CLIPEncoder masks
# (how they are passed into the layer is not shown here).
residual = hidden_states

hidden_states = self.layer_norm1(hidden_states)
hidden_states, attn_weights = self.self_attn(
    hidden_states=hidden_states,
    attention_mask=attention_mask,
    causal_attention_mask=causal_attention_mask,
    output_attentions=output_attentions,
)
self_attn_mask = self_attn_mask.to(hidden_states.device)
hidden_states = hidden_states * self_attn_mask  # mask the attention output per hidden unit
hidden_states = residual + hidden_states

mlp_mask = mlp_mask.to(hidden_states.device)
residual = hidden_states
hidden_states = self.layer_norm2(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = hidden_states * mlp_mask  # out-of-place, like the attention branch
hidden_states = residual + hidden_states
outputs = (hidden_states,)
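The masking logic above can be exercised in isolation with a toy stand-in to see where non-finite gradients first appear. In this sketch `ToyMaskedLayer` and `report_nonfinite` are hypothetical: `nn.Linear` blocks replace CLIP's attention and MLP, each layer receives its own mask rows as `forward` arguments (an assumption, since the issue does not show how the masks are passed), and a tensor hook on each row records any gradient that contains NaN or inf.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyMaskedLayer(nn.Module):
    """Hypothetical stand-in for CLIPEncoderLayer with per-unit output masks."""

    def __init__(self, hidden_size):
        super().__init__()
        self.layer_norm1 = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Linear(hidden_size, hidden_size)  # stand-in for attention
        self.layer_norm2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Linear(hidden_size, hidden_size)        # stand-in for the MLP

    def forward(self, hidden_states, self_attn_mask, mlp_mask):
        residual = hidden_states
        hidden_states = self.self_attn(self.layer_norm1(hidden_states))
        hidden_states = residual + hidden_states * self_attn_mask.to(hidden_states.device)

        residual = hidden_states
        hidden_states = self.mlp(self.layer_norm2(hidden_states))
        hidden_states = residual + hidden_states * mlp_mask.to(hidden_states.device)
        return (hidden_states,)


num_layers, hidden_size = 2, 8
layers = nn.ModuleList([ToyMaskedLayer(hidden_size) for _ in range(num_layers)])
self_attn_mask = torch.ones(num_layers, hidden_size, requires_grad=True)
mlp_mask = torch.ones(num_layers, hidden_size, requires_grad=True)

bad_grads = []  # names of mask rows whose gradient contained NaN or inf


def report_nonfinite(name):
    def hook(grad):
        if not torch.isfinite(grad).all():
            bad_grads.append(name)
    return hook


hidden_states = torch.randn(2, 5, hidden_size)
for i, layer in enumerate(layers):
    attn_row, mlp_row = self_attn_mask[i], mlp_mask[i]
    # Hooks on the per-layer rows report which layer's gradient goes bad.
    attn_row.register_hook(report_nonfinite(f"layer {i} self_attn_mask"))
    mlp_row.register_hook(report_nonfinite(f"layer {i} mlp_mask"))
    hidden_states = layer(hidden_states, attn_row, mlp_row)[0]

hidden_states.pow(2).mean().backward()  # placeholder loss
print(bad_grads)  # [] for this well-behaved toy input
```

On the real model, the same per-row hooks would point at the first vision layer whose mask gradient is non-finite, which narrows down where to inspect activations.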

Have you encountered the same behavior, or is something wrong with my implementation?

Thank you.
