Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions tests/jax/test_custom_call_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
ScaledTensor1x,
ScaledTensor2x,
GroupedScaledTensor1x,
GroupedNoScaleTensor,
ScalingMode,
QuantizerFactory,
QuantizeLayout,
Expand Down Expand Up @@ -1787,13 +1788,18 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

# jitting grouped_gemm
lhs_tensor = GroupedNoScaleTensor(
data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape
)
rhs_tensor = GroupedNoScaleTensor(
data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape
)
prim_out = jax.jit(
tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes")
)(
lhs,
rhs,
group_sizes,
contracting_dims,
lhs_tensor,
rhs_tensor,
contracting_dims=contracting_dims,
use_async_d2h_group_sizes=True,
)

Expand Down Expand Up @@ -1825,8 +1831,17 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout
)
ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims)

lhs_tensor = GroupedNoScaleTensor(
data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape
)
rhs_tensor = GroupedNoScaleTensor(
data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape
)
prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))(
lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set
lhs_tensor,
rhs_tensor,
contracting_dims=contracting_dims,
quantizer_set=quantizer_set,
)

allclose_dtype = jnp.float8_e4m3fn
Expand Down
Loading