diff --git a/axlearn/common/quantized_dot_general/layers.py b/axlearn/common/quantized_dot_general/layers.py
index e32fefe3c..2935a5d15 100644
--- a/axlearn/common/quantized_dot_general/layers.py
+++ b/axlearn/common/quantized_dot_general/layers.py
@@ -27,6 +27,7 @@ import jax
 from absl import logging
 from aqt.jax.v2 import aqt_dot_general
+from aqt.jax.v2 import utils as aqt_utils
 from jax import numpy as jnp
 from jax.lax import DotDimensionNumbers, Precision
 from jax.typing import DTypeLike
@@ -79,7 +80,7 @@ def __call__(
         dimension_numbers: DotDimensionNumbers,
         precision: PrecisionLike = None,
         preferred_element_type: Optional[DTypeLike] = None,
-        context: aqt_dot_general.Context = aqt_dot_general.Context(key=None, train_step=None),
+        context: aqt_utils.Context = aqt_utils.Context(key=None, train_step=None),
     ) -> Tensor:
         ...
 
diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py
index 83d97e64a..98798687e 100644
--- a/axlearn/common/ssm.py
+++ b/axlearn/common/ssm.py
@@ -984,30 +984,6 @@ def __init__(self, cfg: Config, *, parent: Module):
         self._add_child("b_norm", cfg.b_norm.set(input_dim=cfg.state_dim))
         self._add_child("c_norm", cfg.c_norm.set(input_dim=cfg.state_dim))
 
-    def _ssm_parameters(self, inputs: Tensor) -> MambaMixerLayer.SSMParameters:
-        """Computes layer-normed versions of the input-dependent SSM parameters.
-
-        Args:
-            inputs: [batch_size, seq_len, inner_dim]
-
-        Returns:
-            An instance of MambaMixerLayer.SSMParameters.
-        """
-        cfg = self.config
-        x_dbl = self.x_proj(inputs)  # [batch_size, seq_len, dt_rank, state_dim*2]
-        dt, b, c = jnp.split(
-            x_dbl,
-            (
-                self.dt_rank,
-                self.dt_rank + cfg.state_dim,
-            ),
-            axis=-1,
-        )
-        dt, b, c = self.dt_norm(dt), self.b_norm(b), self.c_norm(c)
-        delta = jax.nn.softplus(self.dt_proj(dt))  # [batch_size, seq_len, inner_dim]
-        a = -jnp.exp(_at_least_float32(self.parameters["log_a"])).astype(inputs.dtype)
-        return MambaMixerLayer.SSMParameters(a=a, b=b, c=c, delta=delta, d=self.parameters["d"])
-
 
 class BaseSSMLayer(BaseLayer):
     """An abstract class representing SSM layers.
@@ -1445,3 +1421,50 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]):
             for i in range(cfg.num_layers)
         ]
         super().__init__(cfg.set(layer=layers), parent=parent)
+
+
+class HybridMambaRecurrence(BaseMambaRecurrence):
+    """A layer that combines different recurrence methods to leverage their strengths."""
+
+    @config_class
+    class Config(BaseMambaRecurrence.Config):
+        """Configures a HybridMambaRecurrence."""
+
+        primary_recurrence: BaseMambaRecurrence = LinearScanMambaRecurrence.default_config()
+        secondary_recurrence: BaseMambaRecurrence = AssociativeScanMambaRecurrence.default_config()
+
+    def __init__(self, cfg: Config, *, parent: Module):
+        super().__init__(cfg, parent=parent)
+        self._add_child("primary_recurrence", cfg.primary_recurrence)
+        self._add_child("secondary_recurrence", cfg.secondary_recurrence)
+
+    def forward(
+        self, x: Tensor, *, a: Tensor, b: Tensor, c: Tensor, delta: Tensor, d: Tensor
+    ) -> BaseMambaRecurrence.Output:
+        primary_output = self.primary_recurrence(x, a=a, b=b, c=c, delta=delta, d=d)
+        secondary_output = self.secondary_recurrence(x, a=a, b=b, c=c, delta=delta, d=d)
+        combined_data = (primary_output.data + secondary_output.data) / 2
+        combined_states = (
+            (primary_output.states + secondary_output.states) / 2
+            if primary_output.states is not None and secondary_output.states is not None
+            else None
+        )
+        return BaseMambaRecurrence.Output(data=combined_data, states=combined_states)
+
+
+class AlternativeMambaRecurrence(BaseMambaRecurrence):
+    """A layer that implements an alternative recurrence method."""
+
+    def forward(
+        self, x: Tensor, *, a: Tensor, b: Tensor, c: Tensor, delta: Tensor, d: Tensor
+    ) -> BaseMambaRecurrence.Output:
+        # Implement an alternative recurrence method here.
+        # For demonstration, use a simple weighted sum of inputs and parameters.
+        # Assumed shapes: a: [state_dim, inner_dim]; b, c: [batch_size, seq_len, state_dim].
+        weighted_sum = jnp.einsum("btd,sd->btsd", x, a) + jnp.einsum("bts,btd->btsd", b * c, x)
+        y = jnp.sum(weighted_sum, axis=-2) + d * x
+        states = (
+            weighted_sum
+            if self.config.output_mode == MambaRecurrenceOutputMode.OUTPUTS_AND_STATES
+            else None
+        )
+        return BaseMambaRecurrence.Output(data=y, states=states)
diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py
index 3d8351285..719f5d094 100644
--- a/axlearn/common/ssm_test.py
+++ b/axlearn/common/ssm_test.py
@@ -30,8 +30,10 @@ from axlearn.common.config import InstantiableConfig
 from axlearn.common.module import functional as F
 from axlearn.common.ssm import (
+    AlternativeMambaRecurrence,
     AssociativeScanMambaRecurrence,
     BlockResidualMode,
+    HybridMambaRecurrence,
     JambaMambaBlock,
     LinearScanMambaRecurrence,
     MambaBlock,
@@ -509,6 +511,64 @@ def test_prefill_states(self, dtype: jnp.dtype):
 
         assert_allclose(decoder_output, forward_outputs.data, atol=1e-6)
 
+    @parameterized.parameters(jnp.float32, jnp.bfloat16)
+    def test_hybrid_recurrence(self, dtype: jnp.dtype):
+        model_dim = 4
+        state_dim = 16
+        cfg = MambaMixerLayer.default_config().set(
+            input_dim=model_dim,
+            state_dim=state_dim,
+            cache_dtype=dtype,
+            dtype=dtype,
+            recurrence=HybridMambaRecurrence.default_config(),
+        )
+        layer: MambaMixerLayer = cfg.set(name="test").instantiate(parent=None)
+        layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
+        layer_params = cast_floats(layer_params, to_dtype=dtype)
+        batch_size, tgt_len = 2, 6
+        query = jax.random.normal(
+            jax.random.PRNGKey(1),
+            [batch_size, tgt_len, model_dim],
+            dtype=dtype,
+        )
+        inputs = dict(query=query)
+        forward_outputs, _ = F(
+            layer,
+            state=layer_params,
+            is_training=False,
+            prng_key=jax.random.PRNGKey(2),
+            inputs=inputs,
+        )
+        assert forward_outputs.data.shape == (batch_size, tgt_len, model_dim)
+
+    @parameterized.parameters(jnp.float32, jnp.bfloat16)
+    def test_alternative_recurrence(self, dtype: jnp.dtype):
+        model_dim = 4
+        state_dim = 16
+        cfg = MambaMixerLayer.default_config().set(
+            input_dim=model_dim,
+            state_dim=state_dim,
+            cache_dtype=dtype,
+            dtype=dtype,
+            recurrence=AlternativeMambaRecurrence.default_config(),
+        )
+        layer: MambaMixerLayer = cfg.set(name="test").instantiate(parent=None)
+        layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))
+        layer_params = cast_floats(layer_params, to_dtype=dtype)
+        batch_size, tgt_len = 2, 6
+        query = jax.random.normal(
+            jax.random.PRNGKey(1),
+            [batch_size, tgt_len, model_dim],
+            dtype=dtype,
+        )
+        inputs = dict(query=query)
+        forward_outputs, _ = F(
+            layer,
+            state=layer_params,
+            is_training=False,
+            prng_key=jax.random.PRNGKey(2),
+            inputs=inputs,
+        )
+        assert forward_outputs.data.shape == (batch_size, tgt_len, model_dim)
+
 
 def _test_extend_step(layer_cfg: InstantiableConfig, *, model_dim: int, dtype: jnp.dtype):
     """Tests extend for composite layers."""
@@ -834,6 +894,58 @@ def test_prefill(self, block_klass: MambaBlock, dtype: jnp.dtype):
 
         _test_prefill_states(cfg, model_dim=model_dim, dtype=dtype)
 
+    @parameterized.product(
+        block_klass=(MambaBlock, JambaMambaBlock),
+        dtype=(jnp.float32, jnp.bfloat16),
+    )
+    def test_hybrid_recurrence_in_block(self, block_klass: MambaBlock, dtype: jnp.dtype):
+        model_dim = 16
+        state_dim = 16
+        hidden_dim = 32
+        num_layers = 3
+
+        cfg = StackedSSMLayer.default_config().set(
+            input_dim=model_dim,
+            num_layers=num_layers,
+            layer=block_klass.default_config().set(
+                state_dim=state_dim,
+                mamba_layer=MambaMixerLayer.default_config().set(
+                    recurrence=HybridMambaRecurrence.default_config()
+                ),
+            ),
+        )
+        cfg.layer.mamba_layer.set(dtype=dtype, cache_dtype=None)
+        if hasattr(cfg.layer, "feed_forward"):
+            cfg.layer.feed_forward.hidden_dim = hidden_dim
+
+        _test_extend_step(cfg, model_dim=model_dim, dtype=dtype)
+
+    @parameterized.product(
+        block_klass=(MambaBlock, JambaMambaBlock),
+        dtype=(jnp.float32, jnp.bfloat16),
+    )
+    def test_alternative_recurrence_in_block(self, block_klass: MambaBlock, dtype: jnp.dtype):
+        model_dim = 16
+        state_dim = 16
+        hidden_dim = 32
+        num_layers = 3
+
+        cfg = StackedSSMLayer.default_config().set(
+            input_dim=model_dim,
+            num_layers=num_layers,
+            layer=block_klass.default_config().set(
+                state_dim=state_dim,
+                mamba_layer=MambaMixerLayer.default_config().set(
+                    recurrence=AlternativeMambaRecurrence.default_config()
+                ),
+            ),
+        )
+        cfg.layer.mamba_layer.set(dtype=dtype, cache_dtype=None)
+        if hasattr(cfg.layer, "feed_forward"):
+            cfg.layer.feed_forward.hidden_dim = hidden_dim
+
+        _test_extend_step(cfg, model_dim=model_dim, dtype=dtype)
+
 
 class StackedMixedSSMTransformerTest(TestCase):
     """Tests that mixing SSM layers and transformer layers behaves as expected."""
@@ -927,3 +1039,57 @@ def test_prefill(self, dtype: jnp.dtype):
         cfg.layer.self_attention.attention.num_heads = num_heads
         cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None)
         _test_prefill_states(cfg, model_dim=model_dim, dtype=dtype)
+
+    @parameterized.parameters(jnp.float32, jnp.bfloat16)
+    def test_hybrid_recurrence_in_mixed_layer(self, dtype: jnp.dtype):
+        model_dim = 16
+        state_dim = 16
+        num_heads = 4
+        hidden_dim = 32
+        num_layers = 4
+        cfg = StackedMixedSSMTransformerLayer.default_config().set(
+            input_dim=model_dim,
+            num_layers=num_layers,
+            transformer_layer_period=3,
+            transformer_layer_offset=1,
+            ssm_layer=JambaMambaBlock.default_config().set(
+                state_dim=state_dim,
+                mamba_layer=MambaMixerLayer.default_config().set(
+                    recurrence=HybridMambaRecurrence.default_config()
+                ),
+            ),
+            dtype=dtype,
+        )
+        cfg.ssm_layer.feed_forward.hidden_dim = hidden_dim
+        cfg.ssm_layer.mamba_layer.set(dtype=dtype, cache_dtype=None)
+        cfg.layer.feed_forward.hidden_dim = hidden_dim
+        cfg.layer.self_attention.attention.num_heads = num_heads
+        cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None)
+        _test_extend_step(cfg, model_dim=model_dim, dtype=dtype)
+
+    @parameterized.parameters(jnp.float32, jnp.bfloat16)
+    def test_alternative_recurrence_in_mixed_layer(self, dtype: jnp.dtype):
+        model_dim = 16
+        state_dim = 16
+        num_heads = 4
+        hidden_dim = 32
+        num_layers = 4
+        cfg = StackedMixedSSMTransformerLayer.default_config().set(
+            input_dim=model_dim,
+            num_layers=num_layers,
+            transformer_layer_period=3,
+            transformer_layer_offset=1,
+            ssm_layer=JambaMambaBlock.default_config().set(
+                state_dim=state_dim,
+                mamba_layer=MambaMixerLayer.default_config().set(
+                    recurrence=AlternativeMambaRecurrence.default_config()
+                ),
+            ),
+            dtype=dtype,
+        )
+        cfg.ssm_layer.feed_forward.hidden_dim = hidden_dim
+        cfg.ssm_layer.mamba_layer.set(dtype=dtype, cache_dtype=None)
+        cfg.layer.feed_forward.hidden_dim = hidden_dim
+        cfg.layer.self_attention.attention.num_heads = num_heads
+        cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None)
+        _test_extend_step(cfg, model_dim=model_dim, dtype=dtype)
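
A minimal usage sketch of the new recurrence layers, mirroring the unit tests above. It assumes the `recurrence` config field on `MambaMixerLayer.Config` as exercised in this diff's tests, and the test-default shapes and dtypes; it is illustrative only, not a definitive API reference.

```python
import jax

from axlearn.common.module import functional as F
from axlearn.common.ssm import HybridMambaRecurrence, MambaMixerLayer

# Configure a MambaMixerLayer whose recurrence averages the primary (linear-scan)
# and secondary (associative-scan) outputs, as HybridMambaRecurrence does by default.
cfg = MambaMixerLayer.default_config().set(
    input_dim=4,
    state_dim=16,
    recurrence=HybridMambaRecurrence.default_config(),
)
layer = cfg.set(name="mamba").instantiate(parent=None)
params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))

# [batch_size, seq_len, input_dim] inputs, following the unit tests.
query = jax.random.normal(jax.random.PRNGKey(1), [2, 6, 4])
outputs, _ = F(
    layer,
    state=params,
    is_training=False,
    prng_key=jax.random.PRNGKey(2),
    inputs=dict(query=query),
)
print(outputs.data.shape)  # Expected: (2, 6, 4), matching the input shape.
```

Swapping in `AlternativeMambaRecurrence.default_config()` for the `recurrence` field follows the same pattern, as the block and mixed-layer tests above show.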