@@ -1,7 +1,7 @@
 import functools as ft
 import math
 import warnings
-from typing import Callable, Literal, Optional, Tuple, Union
+from typing import Callable, Dict, Literal, Optional, Tuple, Union

 import equinox as eqx
 import jax
@@ -153,11 +153,15 @@ class MultiheadAttention(eqx.Module):
     value_proj: Linear
     output_proj: Linear
     dropout: Dropout
+
     autoregressive_index: StateIndex[
-        Tuple[
-            Float[Array, "S H QK"] | Float[Array, "S QK"],
-            Float[Array, "S H VO"] | Float[Array, "S VO"],
-            Int[Array, ""],
+        Dict[
+            str,
+            Tuple[
+                Float[Array, "S H QK"] | Float[Array, "S QK"],
+                Float[Array, "S H VO"] | Float[Array, "S VO"],
+                Int[Array, ""],
+            ],
         ]
     ]

@@ -241,9 +245,9 @@ def _make_autoregressive_cache(**_):
             else:
                 _int = jnp.int32

-            return jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int)
-            # initial_cache = (jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int))
-            # return dict(uncond=initial_cache, cond=initial_cache)
+            # return jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int)
+            initial_cache = (jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int))
+            return dict(uncond=initial_cache, cond=initial_cache)

         query_proj_out_size = qk_size
         key_proj_out_size = qk_size
@@ -312,6 +316,8 @@ def __call__(
         state: Optional[State] = None,
         *,
         key: Optional[PRNGKeyArray] = None,
+        temperature: Optional[float] = None,
+        which_cache: Literal["cond", "uncond"] = "cond",
         inference: Optional[bool] = None,
         deterministic: Optional[bool] = None,
         process_heads: Optional[
@@ -372,7 +378,8 @@ def __call__(
         if state is None:
             causal_mask_offset = 0
         else:
-            key_state, value_state, index = state.get(self.autoregressive_index)
+            cache = state.get(self.autoregressive_index)
+            key_state, value_state, index = cache[which_cache]

             # If the index is larger than state length, it will wrap around and start from zero
             key_state = lax.dynamic_update_slice_in_dim(
@@ -385,8 +392,12 @@ def __call__(
             causal_mask_offset = index  # Offset shifts attention lower-tril
             index = index + kv_seq_length  # i -> i + 1, nudging autoregression

+            # Write back only the branch used on this call; the other branch's
+            # cache is carried through unchanged, so interleaved cond/uncond
+            # steps do not clobber each other's state.
             state = state.set(
-                self.autoregressive_index, (key_state, value_state, index)
+                self.autoregressive_index,
+                {**cache, which_cache: (key_state, value_state, index)},
             )

             # if sample:
@@ -429,6 +440,8 @@ def __call__(
             self.dropout,
             inference,
             attn_bias=self.attn_bias,
+            # An explicit `temperature` replaces the module's scale factor.
+            scale_factor=self.scale_factor if temperature is None else temperature,
             keys=keys,
         )

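For context, a minimal sketch of how the split cache might be driven from a classifier-free-guidance sampling loop. It assumes the patched `__call__` returns `(output, state)` when a `state` is passed, as is conventional for stateful Equinox layers; the constructor arguments (including `state_length`, a stand-in for the cache size `S`) and the guidance weight `w` are hypothetical, not part of this commit.

```python
import equinox as eqx
import jax
import jax.numpy as jnp

# Hypothetical construction: a patched MultiheadAttention plus its mutable
# State, which now holds a {"cond": ..., "uncond": ...} KV cache.
mha, state = eqx.nn.make_with_state(MultiheadAttention)(
    num_heads=4,
    query_size=64,
    state_length=128,  # assumed name for the cache length S
    key=jax.random.PRNGKey(0),
)

x = jnp.zeros((1, 64))  # one token per autoregressive step (assumed shape)

# Conditional branch: reads and writes only the "cond" cache entry.
cond_out, state = mha(x, x, x, state=state, which_cache="cond")

# Unconditional branch: its "uncond" entry was untouched by the call above,
# so the two streams can be interleaved step by step.
uncond_out, state = mha(
    x, x, x,
    state=state,
    which_cache="uncond",
    temperature=0.8,  # replaces the module's scale factor for this call only
)

# Classifier-free guidance combination with a hypothetical weight.
w = 2.0
guided = uncond_out + w * (cond_out - uncond_out)
```

Because each call updates only its own dict entry, the caches stay independent without the caller having to thread two separate `State` objects through the sampling loop.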