diff --git a/src/zeroband/models/llama/model.py b/src/zeroband/models/llama/model.py index e73c3fbf..78e2cce0 100644 --- a/src/zeroband/models/llama/model.py +++ b/src/zeroband/models/llama/model.py @@ -143,8 +143,8 @@ def apply_rotary_emb( xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - flop_counter.track_binary(xq_, freqs_cis) - flop_counter.track_binary(xk_, freqs_cis) + # flop_counter.track_binary(xq_, freqs_cis) + # flop_counter.track_binary(xk_, freqs_cis) return xq_out.type_as(xq), xk_out.type_as(xk)