From 09a940d46bb1d9f64afba6b3b2d3c183f81c155b Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Fri, 6 Sep 2024 18:05:26 +0530 Subject: [PATCH 1/2] ssm_enhancement These enhancements provide additional flexibility and options for implementing and experimenting with different recurrence methods in the Mamba and Jamba models, potentially improving performance and accuracy for various tasks. --- axlearn/common/ssm.py | 90 ++++++++++++++++++++ axlearn/common/ssm_test.py | 168 +++++++++++++++++++++++++++++++++++++ 2 files changed, 258 insertions(+) diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index 83d97e64a..4781df87a 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -513,6 +513,7 @@ class MambaMixerLayer(BaseLayer): Can be substituted for a MultiheadAttention layer. """ + @config_class class Config(BaseLayer.Config): """Configures a MambaMixerLayer.""" @@ -882,6 +883,30 @@ def _conv_update( conv_state = conv_state + bias return conv_state, new_cache + + def _full_sequence_forward( + self, inputs: Tensor, *, recurrence: BaseMambaRecurrence + ) -> MambaOutput: + """Computes the Mamba layer output from a full sequence of inputs. + + Args: + inputs: A Tensor of shape [batch_size, seq_len, input_dim]. + recurrence: A BaseMambaRecurrence to use for computing the recurrence. + + Returns: + A MambaOutput. + """ + conv_input, res = self._project_input(inputs) + conv_states = jax.nn.silu(self.conv(conv_input)) + # Get "continuous" ssm parameters. + a, b, c, delta, d = self._ssm_parameters(conv_states) + recurrence_output = recurrence(conv_states, a=a, b=b, c=c, delta=delta, d=d) + output = self._output_from_states(recurrence_output.data, res=res) + return MambaMixerLayer.MambaOutput( + data=output, conv_input=conv_input, states=recurrence_output.states + ) + + def _single_step_ssm_update( self, inputs: Tensor, @@ -971,6 +996,30 @@ def extend_step( class JambaMixerLayer(MambaMixerLayer): """A Jamba-style Mamba layer, which norms the input-dependent SSM parameters.""" + def _ssm_parameters(self, inputs: Tensor) -> MambaMixerLayer.SSMParameters: + """Computes layer-normed versions of the input-dependent SSM parameters. + + Args: + inputs: [batch_size, seq_len, inner_dim] + + Returns: + An instance of MambaMixerLayer.SSMParameters. 
+ """ + cfg = self.config + x_dbl = self.x_proj(inputs) # [batch_size, seq_len, dt_rank, state_dim*2] + dt, b, c = jnp.split( + x_dbl, + ( + self.dt_rank, + self.dt_rank + cfg.state_dim, + ), + axis=-1, + ) + dt, b, c = self.dt_norm(dt), self.b_norm(b), self.c_norm(c) + delta = jax.nn.softplus(self.dt_proj(dt)) # [batch_size, seq_len, inner_dim] + a = -jnp.exp(_at_least_float32(self.parameters["log_a"])).astype(inputs.dtype) + return MambaMixerLayer.SSMParameters(a=a, b=b, c=c, delta=delta, d=self.parameters["d"]) + @config_class class Config(MambaMixerLayer.Config): dt_norm: InstantiableConfig = RMSNorm.default_config() @@ -1445,3 +1494,44 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]): for i in range(cfg.num_layers) ] super().__init__(cfg.set(layer=layers), parent=parent) + + + +class HybridMambaRecurrence(BaseMambaRecurrence): + """A layer that combines different recurrence methods to leverage their strengths.""" + + @config_class + class Config(BaseMambaRecurrence.Config): + """Configures a HybridMambaRecurrence.""" + primary_recurrence: BaseMambaRecurrence = LinearScanMambaRecurrence.default_config() + secondary_recurrence: BaseMambaRecurrence = AssociativeScanMambaRecurrence.default_config() + + def __init__(self, cfg: Config, *, parent: Module): + super().__init__(cfg, parent=parent) + self._add_child("primary_recurrence", cfg.primary_recurrence) + self._add_child("secondary_recurrence", cfg.secondary_recurrence) + + def forward( + self, x: Tensor, *, a: Tensor, b: Tensor, c: Tensor, delta: Tensor, d: Tensor + ) -> BaseMambaRecurrence.Output: + primary_output = self.primary_recurrence(x, a=a, b=b, c=c, delta=delta, d=d) + secondary_output = self.secondary_recurrence(x, a=a, b=b, c=c, delta=delta, d=d) + combined_data = (primary_output.data + secondary_output.data) / 2 + combined_states = (primary_output.states + secondary_output.states) / 2 if primary_output.states is not None and secondary_output.states is not None else None + return BaseMambaRecurrence.Output(data=combined_data, states=combined_states) + + +class AlternativeMambaRecurrence(BaseMambaRecurrence): + """A layer that implements an alternative recurrence method.""" + + def forward( + self, x: Tensor, *, a: Tensor, b: Tensor, c: Tensor, delta: Tensor, d: Tensor + ) -> BaseMambaRecurrence.Output: + # Implement an alternative recurrence method here. + # For demonstration, let's use a simple weighted sum of inputs and parameters. 
+ weighted_sum = jnp.einsum("btd,sd->btsd", x, a) + jnp.einsum("bts,sd->btsd", b, c) + y = jnp.sum(weighted_sum, axis=-2) + d * x + states = weighted_sum if self.config.output_mode == MambaRecurrenceOutputMode.OUTPUTS_AND_STATES else None + return BaseMambaRecurrence.Output(data=y, states=states) + + diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index 3d8351285..dccd6d8e8 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -39,6 +39,8 @@ RepeatedSSMLayer, StackedMixedSSMTransformerLayer, StackedSSMLayer, + HybridMambaRecurrence, + AlternativeMambaRecurrence, ) from axlearn.common.test_utils import TestCase, assert_allclose from axlearn.common.utils import Nested, Tensor, cast_floats @@ -509,6 +511,66 @@ def test_prefill_states(self, dtype: jnp.dtype): assert_allclose(decoder_output, forward_outputs.data, atol=1e-6) + @parameterized.parameters(jnp.float32, jnp.bfloat16) + def test_hybrid_recurrence(self, dtype: jnp.dtype): + model_dim = 4 + state_dim = 16 + cfg = MambaMixerLayer.default_config().set( + input_dim=model_dim, + state_dim=state_dim, + cache_dtype=dtype, + dtype=dtype, + ) + layer: MambaMixerLayer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = cast_floats(layer_params, to_dtype=dtype) + batch_size, tgt_len = 2, 6 + query = jax.random.normal( + jax.random.PRNGKey(1), + [batch_size, tgt_len, model_dim], + dtype=dtype, + ) + inputs = dict(query=query) + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + recurrence=HybridMambaRecurrence.default_config().instantiate(parent=None), + ) + assert forward_outputs.data.shape == (batch_size, tgt_len, model_dim) + + @parameterized.parameters(jnp.float32, jnp.bfloat16) + def test_alternative_recurrence(self, dtype: jnp.dtype): + model_dim = 4 + state_dim = 16 + cfg = MambaMixerLayer.default_config().set( + input_dim=model_dim, + state_dim=state_dim, + cache_dtype=dtype, + dtype=dtype, + ) + layer: MambaMixerLayer = cfg.set(name="test").instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + layer_params = cast_floats(layer_params, to_dtype=dtype) + batch_size, tgt_len = 2, 6 + query = jax.random.normal( + jax.random.PRNGKey(1), + [batch_size, tgt_len, model_dim], + dtype=dtype, + ) + inputs = dict(query=query) + forward_outputs, _ = F( + layer, + state=layer_params, + is_training=False, + prng_key=jax.random.PRNGKey(2), + inputs=inputs, + recurrence=AlternativeMambaRecurrence.default_config().instantiate(parent=None), + ) + assert forward_outputs.data.shape == (batch_size, tgt_len, model_dim) + def _test_extend_step(layer_cfg: InstantiableConfig, *, model_dim: int, dtype: jnp.dtype): """Tests extend for composite layers.""" @@ -834,6 +896,58 @@ def test_prefill(self, block_klass: MambaBlock, dtype: jnp.dtype): _test_prefill_states(cfg, model_dim=model_dim, dtype=dtype) + @parameterized.product( + block_klass=(MambaBlock, JambaMambaBlock), + dtype=(jnp.float32, jnp.bfloat16), + ) + def test_hybrid_recurrence_in_block(self, block_klass: MambaBlock, dtype: jnp.dtype): + model_dim = 16 + state_dim = 16 + hidden_dim = 32 + num_layers = 3 + + cfg = StackedSSMLayer.default_config().set( + input_dim=model_dim, + num_layers=num_layers, + layer=block_klass.default_config().set( + state_dim=state_dim, + mamba_layer=MambaMixerLayer.default_config().set( + 
recurrence=HybridMambaRecurrence.default_config() + ) + ), + ) + cfg.layer.mamba_layer.set(dtype=dtype, cache_dtype=None) + if hasattr(cfg.layer, "feed_forward"): + cfg.layer.feed_forward.hidden_dim = hidden_dim + + _test_extend_step(cfg, model_dim=model_dim, dtype=dtype) + + @parameterized.product( + block_klass=(MambaBlock, JambaMambaBlock), + dtype=(jnp.float32, jnp.bfloat16), + ) + def test_alternative_recurrence_in_block(self, block_klass: MambaBlock, dtype: jnp.dtype): + model_dim = 16 + state_dim = 16 + hidden_dim = 32 + num_layers = 3 + + cfg = StackedSSMLayer.default_config().set( + input_dim=model_dim, + num_layers=num_layers, + layer=block_klass.default_config().set( + state_dim=state_dim, + mamba_layer=MambaMixerLayer.default_config().set( + recurrence=AlternativeMambaRecurrence.default_config() + ) + ), + ) + cfg.layer.mamba_layer.set(dtype=dtype, cache_dtype=None) + if hasattr(cfg.layer, "feed_forward"): + cfg.layer.feed_forward.hidden_dim = hidden_dim + + _test_extend_step(cfg, model_dim=model_dim, dtype=dtype) + class StackedMixedSSMTransformerTest(TestCase): """Tests that mixing SSM layers and transformer layers behaves as expected.""" @@ -927,3 +1041,57 @@ def test_prefill(self, dtype: jnp.dtype): cfg.layer.self_attention.attention.num_heads = num_heads cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None) _test_prefill_states(cfg, model_dim=model_dim, dtype=dtype) + + @parameterized.parameters(jnp.float32, jnp.bfloat16) + def test_hybrid_recurrence_in_mixed_layer(self, dtype: jnp.dtype): + model_dim = 16 + state_dim = 16 + num_heads = 4 + hidden_dim = 32 + num_layers = 4 + cfg = StackedMixedSSMTransformerLayer.default_config().set( + input_dim=model_dim, + num_layers=num_layers, + transformer_layer_period=3, + transformer_layer_offset=1, + ssm_layer=JambaMambaBlock.default_config().set( + state_dim=state_dim, + mamba_layer=MambaMixerLayer.default_config().set( + recurrence=HybridMambaRecurrence.default_config() + ) + ), + dtype=dtype, + ) + cfg.ssm_layer.feed_forward.hidden_dim = hidden_dim + cfg.ssm_layer.mamba_layer.set(dtype=dtype, cache_dtype=None) + cfg.layer.feed_forward.hidden_dim = hidden_dim + cfg.layer.self_attention.attention.num_heads = num_heads + cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None) + _test_extend_step(cfg, model_dim=model_dim, dtype=dtype) + + @parameterized.parameters(jnp.float32, jnp.bfloat16) + def test_alternative_recurrence_in_mixed_layer(self, dtype: jnp.dtype): + model_dim = 16 + state_dim = 16 + num_heads = 4 + hidden_dim = 32 + num_layers = 4 + cfg = StackedMixedSSMTransformerLayer.default_config().set( + input_dim=model_dim, + num_layers=num_layers, + transformer_layer_period=3, + transformer_layer_offset=1, + ssm_layer=JambaMambaBlock.default_config().set( + state_dim=state_dim, + mamba_layer=MambaMixerLayer.default_config().set( + recurrence=AlternativeMambaRecurrence.default_config() + ) + ), + dtype=dtype, + ) + cfg.ssm_layer.feed_forward.hidden_dim = hidden_dim + cfg.ssm_layer.mamba_layer.set(dtype=dtype, cache_dtype=None) + cfg.layer.feed_forward.hidden_dim = hidden_dim + cfg.layer.self_attention.attention.num_heads = num_heads + cfg.layer.self_attention.attention.input_linear.set(dtype=dtype, cache_dtype=None) + _test_extend_step(cfg, model_dim=model_dim, dtype=dtype) From 8a0982fd2d8ce9587ba8e7673b8f840e448aebd9 Mon Sep 17 00:00:00 2001 From: Vishesh <87526302+vishesh9131@users.noreply.github.com> Date: Thu, 12 Sep 2024 02:39:03 +0530 Subject: [PATCH 2/2] some 
fixes - fixed functions redundant definitions - fixed Incorrect Module Import in layers.py --- .../common/quantized_dot_general/layers.py | 3 +- axlearn/common/ssm.py | 89 +++---------------- axlearn/common/ssm_test.py | 14 ++- 3 files changed, 19 insertions(+), 87 deletions(-) diff --git a/axlearn/common/quantized_dot_general/layers.py b/axlearn/common/quantized_dot_general/layers.py index e32fefe3c..2935a5d15 100644 --- a/axlearn/common/quantized_dot_general/layers.py +++ b/axlearn/common/quantized_dot_general/layers.py @@ -27,6 +27,7 @@ import jax from absl import logging from aqt.jax.v2 import aqt_dot_general +from aqt.jax.v2 import utils as aqt_utils from jax import numpy as jnp from jax.lax import DotDimensionNumbers, Precision from jax.typing import DTypeLike @@ -79,7 +80,7 @@ def __call__( dimension_numbers: DotDimensionNumbers, precision: PrecisionLike = None, preferred_element_type: Optional[DTypeLike] = None, - context: aqt_dot_general.Context = aqt_dot_general.Context(key=None, train_step=None), + context: aqt_utils.Context = aqt_utils.Context(key=None, train_step=None), ) -> Tensor: ... diff --git a/axlearn/common/ssm.py b/axlearn/common/ssm.py index 4781df87a..98798687e 100644 --- a/axlearn/common/ssm.py +++ b/axlearn/common/ssm.py @@ -513,7 +513,6 @@ class MambaMixerLayer(BaseLayer): Can be substituted for a MultiheadAttention layer. """ - @config_class class Config(BaseLayer.Config): """Configures a MambaMixerLayer.""" @@ -883,30 +882,6 @@ def _conv_update( conv_state = conv_state + bias return conv_state, new_cache - - def _full_sequence_forward( - self, inputs: Tensor, *, recurrence: BaseMambaRecurrence - ) -> MambaOutput: - """Computes the Mamba layer output from a full sequence of inputs. - - Args: - inputs: A Tensor of shape [batch_size, seq_len, input_dim]. - recurrence: A BaseMambaRecurrence to use for computing the recurrence. - - Returns: - A MambaOutput. - """ - conv_input, res = self._project_input(inputs) - conv_states = jax.nn.silu(self.conv(conv_input)) - # Get "continuous" ssm parameters. - a, b, c, delta, d = self._ssm_parameters(conv_states) - recurrence_output = recurrence(conv_states, a=a, b=b, c=c, delta=delta, d=d) - output = self._output_from_states(recurrence_output.data, res=res) - return MambaMixerLayer.MambaOutput( - data=output, conv_input=conv_input, states=recurrence_output.states - ) - - def _single_step_ssm_update( self, inputs: Tensor, @@ -996,30 +971,6 @@ def extend_step( class JambaMixerLayer(MambaMixerLayer): """A Jamba-style Mamba layer, which norms the input-dependent SSM parameters.""" - def _ssm_parameters(self, inputs: Tensor) -> MambaMixerLayer.SSMParameters: - """Computes layer-normed versions of the input-dependent SSM parameters. - - Args: - inputs: [batch_size, seq_len, inner_dim] - - Returns: - An instance of MambaMixerLayer.SSMParameters. 
- """ - cfg = self.config - x_dbl = self.x_proj(inputs) # [batch_size, seq_len, dt_rank, state_dim*2] - dt, b, c = jnp.split( - x_dbl, - ( - self.dt_rank, - self.dt_rank + cfg.state_dim, - ), - axis=-1, - ) - dt, b, c = self.dt_norm(dt), self.b_norm(b), self.c_norm(c) - delta = jax.nn.softplus(self.dt_proj(dt)) # [batch_size, seq_len, inner_dim] - a = -jnp.exp(_at_least_float32(self.parameters["log_a"])).astype(inputs.dtype) - return MambaMixerLayer.SSMParameters(a=a, b=b, c=c, delta=delta, d=self.parameters["d"]) - @config_class class Config(MambaMixerLayer.Config): dt_norm: InstantiableConfig = RMSNorm.default_config() @@ -1033,30 +984,6 @@ def __init__(self, cfg: Config, *, parent: Module): self._add_child("b_norm", cfg.b_norm.set(input_dim=cfg.state_dim)) self._add_child("c_norm", cfg.c_norm.set(input_dim=cfg.state_dim)) - def _ssm_parameters(self, inputs: Tensor) -> MambaMixerLayer.SSMParameters: - """Computes layer-normed versions of the input-dependent SSM parameters. - - Args: - inputs: [batch_size, seq_len, inner_dim] - - Returns: - An instance of MambaMixerLayer.SSMParameters. - """ - cfg = self.config - x_dbl = self.x_proj(inputs) # [batch_size, seq_len, dt_rank, state_dim*2] - dt, b, c = jnp.split( - x_dbl, - ( - self.dt_rank, - self.dt_rank + cfg.state_dim, - ), - axis=-1, - ) - dt, b, c = self.dt_norm(dt), self.b_norm(b), self.c_norm(c) - delta = jax.nn.softplus(self.dt_proj(dt)) # [batch_size, seq_len, inner_dim] - a = -jnp.exp(_at_least_float32(self.parameters["log_a"])).astype(inputs.dtype) - return MambaMixerLayer.SSMParameters(a=a, b=b, c=c, delta=delta, d=self.parameters["d"]) - class BaseSSMLayer(BaseLayer): """An abstract class representing SSM layers. @@ -1496,13 +1423,13 @@ def __init__(self, cfg: Config, *, parent: Optional[Module]): super().__init__(cfg.set(layer=layers), parent=parent) - class HybridMambaRecurrence(BaseMambaRecurrence): """A layer that combines different recurrence methods to leverage their strengths.""" @config_class class Config(BaseMambaRecurrence.Config): """Configures a HybridMambaRecurrence.""" + primary_recurrence: BaseMambaRecurrence = LinearScanMambaRecurrence.default_config() secondary_recurrence: BaseMambaRecurrence = AssociativeScanMambaRecurrence.default_config() @@ -1517,7 +1444,11 @@ def forward( primary_output = self.primary_recurrence(x, a=a, b=b, c=c, delta=delta, d=d) secondary_output = self.secondary_recurrence(x, a=a, b=b, c=c, delta=delta, d=d) combined_data = (primary_output.data + secondary_output.data) / 2 - combined_states = (primary_output.states + secondary_output.states) / 2 if primary_output.states is not None and secondary_output.states is not None else None + combined_states = ( + (primary_output.states + secondary_output.states) / 2 + if primary_output.states is not None and secondary_output.states is not None + else None + ) return BaseMambaRecurrence.Output(data=combined_data, states=combined_states) @@ -1531,7 +1462,9 @@ def forward( # For demonstration, let's use a simple weighted sum of inputs and parameters. 
weighted_sum = jnp.einsum("btd,sd->btsd", x, a) + jnp.einsum("bts,sd->btsd", b, c) y = jnp.sum(weighted_sum, axis=-2) + d * x - states = weighted_sum if self.config.output_mode == MambaRecurrenceOutputMode.OUTPUTS_AND_STATES else None + states = ( + weighted_sum + if self.config.output_mode == MambaRecurrenceOutputMode.OUTPUTS_AND_STATES + else None + ) return BaseMambaRecurrence.Output(data=y, states=states) - - diff --git a/axlearn/common/ssm_test.py b/axlearn/common/ssm_test.py index dccd6d8e8..719f5d094 100644 --- a/axlearn/common/ssm_test.py +++ b/axlearn/common/ssm_test.py @@ -30,8 +30,10 @@ from axlearn.common.config import InstantiableConfig from axlearn.common.module import functional as F from axlearn.common.ssm import ( + AlternativeMambaRecurrence, AssociativeScanMambaRecurrence, BlockResidualMode, + HybridMambaRecurrence, JambaMambaBlock, LinearScanMambaRecurrence, MambaBlock, @@ -39,8 +41,6 @@ RepeatedSSMLayer, StackedMixedSSMTransformerLayer, StackedSSMLayer, - HybridMambaRecurrence, - AlternativeMambaRecurrence, ) from axlearn.common.test_utils import TestCase, assert_allclose from axlearn.common.utils import Nested, Tensor, cast_floats @@ -537,7 +537,6 @@ def test_hybrid_recurrence(self, dtype: jnp.dtype): is_training=False, prng_key=jax.random.PRNGKey(2), inputs=inputs, - recurrence=HybridMambaRecurrence.default_config().instantiate(parent=None), ) assert forward_outputs.data.shape == (batch_size, tgt_len, model_dim) @@ -567,7 +566,6 @@ def test_alternative_recurrence(self, dtype: jnp.dtype): is_training=False, prng_key=jax.random.PRNGKey(2), inputs=inputs, - recurrence=AlternativeMambaRecurrence.default_config().instantiate(parent=None), ) assert forward_outputs.data.shape == (batch_size, tgt_len, model_dim) @@ -913,7 +911,7 @@ def test_hybrid_recurrence_in_block(self, block_klass: MambaBlock, dtype: jnp.dt state_dim=state_dim, mamba_layer=MambaMixerLayer.default_config().set( recurrence=HybridMambaRecurrence.default_config() - ) + ), ), ) cfg.layer.mamba_layer.set(dtype=dtype, cache_dtype=None) @@ -939,7 +937,7 @@ def test_alternative_recurrence_in_block(self, block_klass: MambaBlock, dtype: j state_dim=state_dim, mamba_layer=MambaMixerLayer.default_config().set( recurrence=AlternativeMambaRecurrence.default_config() - ) + ), ), ) cfg.layer.mamba_layer.set(dtype=dtype, cache_dtype=None) @@ -1058,7 +1056,7 @@ def test_hybrid_recurrence_in_mixed_layer(self, dtype: jnp.dtype): state_dim=state_dim, mamba_layer=MambaMixerLayer.default_config().set( recurrence=HybridMambaRecurrence.default_config() - ) + ), ), dtype=dtype, ) @@ -1085,7 +1083,7 @@ def test_alternative_recurrence_in_mixed_layer(self, dtype: jnp.dtype): state_dim=state_dim, mamba_layer=MambaMixerLayer.default_config().set( recurrence=AlternativeMambaRecurrence.default_config() - ) + ), ), dtype=dtype, )
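
Usage sketch (not part of the patches above): the snippet below shows one way the new HybridMambaRecurrence could be attached to a MambaMixerLayer, following the config pattern used by the tests in this series. The layer name and the dimensions are illustrative only, and the sub-recurrence overrides simply restate the defaults declared in HybridMambaRecurrence.Config.

import jax
from axlearn.common.ssm import (
    AssociativeScanMambaRecurrence,
    HybridMambaRecurrence,
    LinearScanMambaRecurrence,
    MambaMixerLayer,
)

# Average the outputs of a linear scan and an associative scan (the defaults).
# Either field accepts any BaseMambaRecurrence config, so other combinations
# can be swapped in.
recurrence_cfg = HybridMambaRecurrence.default_config().set(
    primary_recurrence=LinearScanMambaRecurrence.default_config(),
    secondary_recurrence=AssociativeScanMambaRecurrence.default_config(),
)

# Attach the hybrid recurrence to a Mamba mixer layer, mirroring
# test_hybrid_recurrence_in_block above. Name and dims are illustrative.
layer_cfg = MambaMixerLayer.default_config().set(
    name="mamba_hybrid",
    input_dim=4,
    state_dim=16,
    recurrence=recurrence_cfg,
)
layer = layer_cfg.instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0))

Note that HybridMambaRecurrence.forward averages the primary and secondary outputs (and states, when both are present), so the two sub-recurrences must produce tensors of matching shape for a given input.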