Conversation

@MichaelRipa
Member

Description

This feature makes it easier to programmatically access the submodules found via .source, by implementing __iter__ for EnvoySource, the object type returned from .source.
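
The change amounts to making EnvoySource iterable over its operation names. A minimal sketch of the idea (not the actual nnsight implementation; the internal attribute name _operations is hypothetical):

class EnvoySourceSketch:
    """Toy stand-in illustrating the iteration behavior added to EnvoySource."""

    def __init__(self, operations: dict):
        # Ordered, name-keyed mapping of source operations, e.g.
        # {'self_c_attn_0': ..., 'self_c_proj_0': ...}
        self._operations = operations

    def __iter__(self):
        # Iterating yields the operation names in source order,
        # matching the list shown in the demonstration below.
        return iter(self._operations)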

Demonstration

from nnsight import LanguageModel

model = LanguageModel("openai-community/gpt2", device_map="auto", dispatch=False)

print(list(model.transformer.h[0].attn.source))

Output:

['isinstance_0', 'past_key_values_is_updated_get_0', 'hasattr_0', 'ValueError_0', 'self_q_attn_0', 'self_c_attn_0', 'key_states_view_0', 'value_states_view_0', 'self_c_attn_1', 'key_states_view_1', 'value_states_view_1', 'query_states_view_0', 'curr_past_key_value_update_0', 'self__upcast_and_reordered_attn_0', 'attention_interface_0', 'attn_output_reshape_0', 'self_c_proj_0', 'self_resid_dropout_0', 'deprecate_kwarg_0']
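
Because iteration yields plain name strings, the names can be filtered or searched in code. A small usage example, continuing from the demonstration above; the getattr access is an assumption based on the attribute-style names in the labeled printout below:

attn_source = model.transformer.h[0].attn.source

# Collect only the c_attn projection calls by name.
c_attn_ops = [name for name in attn_source if name.startswith("self_c_attn")]
print(c_attn_ops)  # ['self_c_attn_0', 'self_c_attn_1']

# Fetch the corresponding operation for the first match (assumes each name
# is also exposed as an attribute on .source, as the printout below suggests).
first_c_attn = getattr(attn_source, c_attn_ops[0])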

Previously, printing out .source was the only way to get a look at these operations. It is useful, but quite verbose, and requires manually reading through the output:

                                       * @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
                                       0 def forward(
                                       1     self,
                                       2     hidden_states: Optional[tuple[torch.FloatTensor]],
                                       3     past_key_values: Optional[Cache] = None,
                                       4     cache_position: Optional[torch.LongTensor] = None,
                                       5     attention_mask: Optional[torch.FloatTensor] = None,
                                       6     head_mask: Optional[torch.FloatTensor] = None,
                                       7     encoder_hidden_states: Optional[torch.Tensor] = None,
                                       8     encoder_attention_mask: Optional[torch.FloatTensor] = None,
                                       9     output_attentions: Optional[bool] = False,
                                      10     **kwargs,
                                      11 ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
                                      12     is_cross_attention = encoder_hidden_states is not None
                                      13     if past_key_values is not None:
 isinstance_0                      -> 14         if isinstance(past_key_values, EncoderDecoderCache):
 past_key_values_is_updated_get_0  -> 15             is_updated = past_key_values.is_updated.get(self.layer_idx)
                                      16             if is_cross_attention:
                                      17                 # after the first generated id, we can subsequently re-use all key/value_layer from cache
                                      18                 curr_past_key_value = past_key_values.cross_attention_cache
                                      19             else:
                                      20                 curr_past_key_value = past_key_values.self_attention_cache
                                      21         else:
                                      22             curr_past_key_value = past_key_values
                                      23 
                                      24     if is_cross_attention:
 hasattr_0                         -> 25         if not hasattr(self, "q_attn"):
 ValueError_0                      -> 26             raise ValueError(
                                      27                 "If class is used as cross attention, the weights `q_attn` have to be defined. "
                                      28                 "Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
                                      29             )
 self_q_attn_0                     -> 30         query_states = self.q_attn(hidden_states)
                                      31         attention_mask = encoder_attention_mask
                                      32 
                                      33         # Try to get key/value states from cache if possible
                                      34         if past_key_values is not None and is_updated:
                                      35             key_states = curr_past_key_value.layers[self.layer_idx].keys
                                      36             value_states = curr_past_key_value.layers[self.layer_idx].values
                                      37         else:
 self_c_attn_0                     -> 38             key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
 split_0                           ->  +             ...
                                      39             shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
 key_states_view_0                 -> 40             key_states = key_states.view(shape_kv).transpose(1, 2)
 transpose_0                       ->  +             ...
 value_states_view_0               -> 41             value_states = value_states.view(shape_kv).transpose(1, 2)
 transpose_1                       ->  +             ...
                                      42     else:
 self_c_attn_1                     -> 43         query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
 split_1                           ->  +         ...
                                      44         shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
 key_states_view_1                 -> 45         key_states = key_states.view(shape_kv).transpose(1, 2)
 transpose_2                       ->  +         ...
 value_states_view_1               -> 46         value_states = value_states.view(shape_kv).transpose(1, 2)
 transpose_3                       ->  +         ...
                                      47 
                                      48     shape_q = (*query_states.shape[:-1], -1, self.head_dim)
 query_states_view_0               -> 49     query_states = query_states.view(shape_q).transpose(1, 2)
 transpose_4                       ->  +     ...
                                      50 
                                      51     if (past_key_values is not None and not is_cross_attention) or (
                                      52         past_key_values is not None and is_cross_attention and not is_updated
                                      53     ):
                                      54         # save all key/value_layer to cache to be re-used for fast auto-regressive generation
                                      55         cache_position = cache_position if not is_cross_attention else None
 curr_past_key_value_update_0      -> 56         key_states, value_states = curr_past_key_value.update(
                                      57             key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                                      58         )
                                      59         # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                                      60         if is_cross_attention:
                                      61             past_key_values.is_updated[self.layer_idx] = True
                                      62 
                                      63     is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
                                      64 
                                      65     using_eager = self.config._attn_implementation == "eager"
                                      66     attention_interface: Callable = eager_attention_forward
                                      67     if self.config._attn_implementation != "eager":
                                      68         attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
                                      69 
                                      70     if using_eager and self.reorder_and_upcast_attn:
 self__upcast_and_reordered_attn_0 -> 71         attn_output, attn_weights = self._upcast_and_reordered_attn(
                                      72             query_states, key_states, value_states, attention_mask, head_mask
                                      73         )
                                      74     else:
 attention_interface_0             -> 75         attn_output, attn_weights = attention_interface(
                                      76             self,
                                      77             query_states,
                                      78             key_states,
                                      79             value_states,
                                      80             attention_mask,
                                      81             head_mask=head_mask,
                                      82             dropout=self.attn_dropout.p if self.training else 0.0,
                                      83             is_causal=is_causal,
                                      84             **kwargs,
                                      85         )
                                      86 
 attn_output_reshape_0             -> 87     attn_output = attn_output.reshape(*attn_output.shape[:-2], -1).contiguous()
 contiguous_0                      ->  +     ...
 self_c_proj_0                     -> 88     attn_output = self.c_proj(attn_output)
 self_resid_dropout_0              -> 89     attn_output = self.resid_dropout(attn_output)
                                      90 
                                      91     return attn_output, attn_weights
                                      92 
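
With iteration available, the same information can be summarized in code rather than read off the printout. For example (continuing from the demonstration above), the operation names can be grouped by their base name, dropping the trailing index:

from collections import Counter

# Tally operations by base name, e.g. 'self_c_attn_0' and
# 'self_c_attn_1' both count toward 'self_c_attn'.
ops = list(model.transformer.h[0].attn.source)
summary = Counter(name.rsplit("_", 1)[0] for name in ops)
print(summary.most_common(3))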
