Skip to content

Commit 1b9c85c

Browse files
gufengc and gufengc authored
chore(mlx): upgrade mlx and mlx-lm (#452)
Co-authored-by: gufengc <gufeng@graident.network>
1 parent e16856c commit 1b9c85c

11 files changed

Lines changed: 48 additions & 18 deletions

pyproject.toml

Lines changed: 7 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -43,23 +43,23 @@ parallax = "parallax.cli:main"
4343
[project.optional-dependencies]
4444

4545
mac = [
46-
"nanobind==2.10.2",
46+
"nanobind==2.12.0",
4747
"torch==2.8.0",
48-
"mlx-lm==0.30.6",
49-
"mlx==0.30.4",
48+
"mlx-lm==0.31.3",
49+
"mlx==0.31.2",
5050
]
5151

5252
gpu = [
5353
"sglang[all]==0.5.12",
5454
"accelerate",
55-
"mlx-lm==0.28.4",
56-
"mlx[cpu]==0.30.0",
55+
"mlx-lm==0.31.3",
56+
"mlx[cpu]==0.31.2",
5757
]
5858

5959
vllm = [
6060
"vllm==0.14.0",
61-
"mlx-lm==0.28.4",
62-
"mlx[cpu]==0.30.0",
61+
"mlx-lm==0.31.3",
62+
"mlx[cpu]==0.31.2",
6363
]
6464

6565
benchmark = [

src/parallax/utils/utils.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -27,10 +27,7 @@ def is_mps_available():
2727
def is_metal_available():
2828
"""Check if MLX Metal backend is available"""
2929
try:
30-
import mlx.core as mx
31-
32-
mx.metal.device_info()
33-
return True
30+
return mx.metal.is_available()
3431
except (RuntimeError, AttributeError, ImportError):
3532
return False
3633

@@ -43,7 +40,7 @@ def get_current_device():
4340
device = "cpu"
4441
if is_cuda_available():
4542
device = "cuda"
46-
if is_mps_available():
43+
if is_metal_available():
4744
device = "mlx"
4845
return device
4946

src/parallax_extensions/kernels/paged_attention.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -120,7 +120,7 @@ void PagedAttentionV1::eval_gpu(
120120
auto kernel = d.get_kernel(kname, lib, hash_name, func_consts);
121121

122122
// Prepare to encode kernel
123-
auto& compute_encoder = d.get_command_encoder(s.index);
123+
auto& compute_encoder = mx::metal::get_command_encoder(s);
124124
compute_encoder.set_compute_pipeline_state(kernel);
125125

126126
// Shared Memory

src/parallax_extensions/kernels/reshape_and_cache.cpp

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,8 @@ namespace parallax_ext {
1515
mx::array reshape_and_cache(
1616
const mx::array& key, // [num_tokens, num_heads, head_size]
1717
const mx::array& value, // [num_tokens, num_heads, head_size]
18-
mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
19-
mx::array& value_cache, // [num_blocks, num_heads, head_size/x, block_size]
18+
const mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
19+
const mx::array& value_cache, // [num_blocks, num_heads, head_size/x, block_size]
2020
const mx::array& slot_mapping, // [num_tokens]
2121
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
2222
) {
@@ -88,7 +88,7 @@ void ReshapeAndCache::eval_gpu(
8888
auto kernel = d.get_kernel(kname, lib, hash_name, func_consts);
8989

9090
// Prepare to encode kernel
91-
auto& compute_encoder = d.get_command_encoder(s.index);
91+
auto& compute_encoder = mx::metal::get_command_encoder(s);
9292
compute_encoder.set_compute_pipeline_state(kernel);
9393

9494
// Calculate parameters

src/parallax_extensions/kernels/reshape_and_cache.h

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -8,8 +8,8 @@ namespace parallax_ext {
88
mx::array reshape_and_cache(
99
const mx::array& key, // [num_tokens, num_heads, head_size]
1010
const mx::array& value, // [num_tokens, num_heads, head_size]
11-
mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
12-
mx::array& value_cache, // [num_blocks, num_heads, head_size/x, block_size]
11+
const mx::array& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
12+
const mx::array& value_cache, // [num_blocks, num_heads, head_size/x, block_size]
1313
const mx::array& slot_mapping, // [num_tokens]
1414
mx::StreamOrDevice s /* = {} */ // Stream on which to schedule the operation
1515
);
240 Bytes
Binary file not shown.
224 Bytes
Binary file not shown.
240 Bytes
Binary file not shown.
80 Bytes
Binary file not shown.
8.58 KB
Binary file not shown.

0 commit comments

Comments
 (0)