110 changes: 110 additions & 0 deletions benchmarks/python/compare_recurrent_speed.py
@@ -0,0 +1,110 @@
# Copyright © 2023-2024 Apple Inc.

"""
Compare nn.GRU and nn.LSTM speed: legacy vs fast, on GPU and CPU.
Runs recurrent_bench.py for each (impl, device) and prints tables + speedups.

Run from benchmarks/python:
python compare_recurrent_speed.py
"""

import math
import os
import subprocess
import sys


def run_bench(impl: str, device: str, num_runs: int = 3):
"""Run recurrent_bench.py num_runs times; return median (gru_ms, lstm_ms) for stability."""
env = {**os.environ, "MLX_RNN_IMPL": impl}
cwd = os.path.dirname(os.path.abspath(__file__)) or "."
gru_list, lstm_list = [], []

for _ in range(num_runs):
result = subprocess.run(
[sys.executable, "recurrent_bench.py", "--device", device],
env=env,
capture_output=True,
text=True,
cwd=cwd,
timeout=120,
)
if result.returncode != 0:
print(result.stderr or result.stdout, file=sys.stderr)
return float("nan"), float("nan")

gru_ms = lstm_ms = float("nan")
for line in result.stdout.splitlines():
if "nn.GRU(" in line:
try:
gru_ms = float(line.split(":")[-1].strip().replace(" ms", ""))
except ValueError:
pass
if "nn.LSTM(" in line:
try:
lstm_ms = float(line.split(":")[-1].strip().replace(" ms", ""))
except ValueError:
pass
        if not math.isnan(gru_ms):
            gru_list.append(gru_ms)
        if not math.isnan(lstm_ms):
            lstm_list.append(lstm_ms)

gru_med = float("nan") if not gru_list else sorted(gru_list)[len(gru_list) // 2]
lstm_med = float("nan") if not lstm_list else sorted(lstm_list)[len(lstm_list) // 2]
return gru_med, lstm_med


def main():
print("Comparing nn.GRU / nn.LSTM: legacy (Python-only) vs fast (Metal)")
print("(batch=32, seq=40, input=128, hidden=200, h0/c0 set)")
print(
"(median of 3 runs per config; each run = median of 5×100 iters for stability)"
)
print()

gru_legacy_gpu, lstm_legacy_gpu = run_bench("legacy", "gpu")
gru_fast_gpu, lstm_fast_gpu = run_bench("fast", "gpu")
gru_legacy_cpu, lstm_legacy_cpu = run_bench("legacy", "cpu")
gru_fast_cpu, lstm_fast_cpu = run_bench("fast", "cpu")

    def fmt(v):
        return f"{v:.3f}" if math.isfinite(v) else "N/A"

print("=" * 70)
print("Full layer timings (ms)")
print("=" * 70)
print(
f" {'Layer':<12} {'legacy GPU':>12} {'fast GPU':>12} {'legacy CPU':>12} {'fast CPU':>12}"
)
print(
f" {'nn.GRU':<12} {fmt(gru_legacy_gpu):>12} {fmt(gru_fast_gpu):>12} {fmt(gru_legacy_cpu):>12} {fmt(gru_fast_cpu):>12}"
)
print(
f" {'nn.LSTM':<12} {fmt(lstm_legacy_gpu):>12} {fmt(lstm_fast_gpu):>12} {fmt(lstm_legacy_cpu):>12} {fmt(lstm_fast_cpu):>12}"
)
print()

print("Speedup (fast vs legacy) on same device:")
    if gru_fast_gpu > 0 and not math.isnan(gru_legacy_gpu):
        print(f" GPU GRU: {gru_legacy_gpu / gru_fast_gpu:.2f}x")
    if lstm_fast_gpu > 0 and not math.isnan(lstm_legacy_gpu):
        print(f" GPU LSTM: {lstm_legacy_gpu / lstm_fast_gpu:.2f}x")
    if gru_fast_cpu > 0 and not math.isnan(gru_legacy_cpu):
        print(f" CPU GRU: {gru_legacy_cpu / gru_fast_cpu:.2f}x")
    if lstm_fast_cpu > 0 and not math.isnan(lstm_legacy_cpu):
        print(f" CPU LSTM: {lstm_legacy_cpu / lstm_fast_cpu:.2f}x")
    if math.isnan(gru_fast_cpu) or math.isnan(lstm_fast_cpu):
        print(
            " (fast on CPU: Metal kernels are GPU-only; fast path falls back or N/A)"
        )
print()

print(
"Legacy: Python-only. Fast: Metal kernels on GPU; on CPU fast uses fallback (may be N/A)."
)
print("To use legacy: export MLX_RNN_IMPL=legacy")
print("See python/mlx/nn/layers/RECURRENT_VERSIONS.md for details.")


if __name__ == "__main__":
main()
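For a one-off check without the driver script, the implementation can be selected the same way the driver does it: via the `MLX_RNN_IMPL` environment variable that this PR introduces. A minimal sketch; since whether the layers read the variable at import time or per forward call is an internal detail of the PR's dispatch code, it is set before both here:

```python
import os

os.environ["MLX_RNN_IMPL"] = "legacy"  # or "fast"; read by the PR's RNN dispatch

import mlx.core as mx
import mlx.nn as nn

gru = nn.GRU(128, 200, bias=True)
x = mx.random.normal(shape=(32, 40, 128))
h0 = mx.zeros((32, 200))
mx.eval(gru(x, h0))  # force evaluation so the chosen path actually runs
```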
71 changes: 71 additions & 0 deletions benchmarks/python/recurrent_bench.py
@@ -0,0 +1,71 @@
# Copyright © 2023-2024 Apple Inc.

"""
Benchmark nn.GRU and nn.LSTM: legacy (Python-only) vs fast (Metal) implementation.

Run with device and implementation:
MLX_RNN_IMPL=legacy python recurrent_bench.py [--device gpu|cpu]
MLX_RNN_IMPL=fast python recurrent_bench.py [--device gpu|cpu]

Configuration matches the backup benchmark: h0/c0 are provided so the fast path is active from the first step; batch=32, seq=40, input=128, hidden=200.
"""

import argparse
import os

import mlx.core as mx
import mlx.nn as nn
from time_utils import measure_runtime


def main():
p = argparse.ArgumentParser(
description="Benchmark nn.GRU / nn.LSTM (legacy vs fast)"
)
p.add_argument(
"--device", choices=("gpu", "cpu"), default="gpu", help="Device to run on"
)
args = p.parse_args()

impl = os.environ.get("MLX_RNN_IMPL", "fast").strip().lower()
if impl not in ("legacy", "fast", "fast_v2"):
impl = "fast"

device = mx.gpu if args.device == "gpu" else mx.cpu
mx.set_default_device(device)

# Match backup benchmark_attentivefp_and_gru.py --layers and RECURRENT_VERSIONS.md
batch, seq_len, input_size, hidden_size = 32, 40, 128, 200
mx.random.seed(0)

gru = nn.GRU(input_size, hidden_size, bias=True)
lstm = nn.LSTM(input_size, hidden_size, bias=True)
x = mx.random.normal(shape=(batch, seq_len, input_size)).astype(mx.float32)
h0 = mx.zeros((batch, hidden_size), dtype=mx.float32)
c0 = mx.zeros((batch, hidden_size), dtype=mx.float32)
mx.eval(x, h0, c0)

def gru_forward():
out = gru(x, h0)
mx.eval(out)
return out

def lstm_forward():
out, _ = lstm(x, h0, c0)
mx.eval(out)
return out

gru_ms = measure_runtime(gru_forward)
lstm_ms = measure_runtime(lstm_forward)

print(f"MLX_RNN_IMPL={impl} device={args.device}")
print(
f" nn.GRU({batch}, {seq_len}, {input_size} -> {hidden_size}): {gru_ms:.3f} ms"
)
print(
f" nn.LSTM({batch}, {seq_len}, {input_size} -> {hidden_size}): {lstm_ms:.3f} ms"
)


if __name__ == "__main__":
main()
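A note on the `mx.eval` calls inside `gru_forward` and `lstm_forward` above: MLX evaluates lazily, so a benchmark that omits the eval measures only graph construction, not the recurrent compute. A small illustration of the difference:

```python
import time

import mlx.core as mx

a = mx.random.normal(shape=(2048, 2048))
mx.eval(a)

tic = time.perf_counter()
b = a @ a  # lazy: returns immediately, nothing has executed yet
build_ms = (time.perf_counter() - tic) * 1e3

tic = time.perf_counter()
mx.eval(b)  # forces the matmul to run
compute_ms = (time.perf_counter() - tic) * 1e3

print(f"graph build: {build_ms:.3f} ms, compute: {compute_ms:.3f} ms")
```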
21 changes: 13 additions & 8 deletions benchmarks/python/time_utils.py
@@ -26,13 +26,18 @@ def time_fn(fn, *args, **kwargs):
print(f"{msec:.5f} msec")


-def measure_runtime(fn, **kwargs):
-    # Warmup
-    for _ in range(5):
+def measure_runtime(fn, num_warmup=15, num_iters=100, num_runs=5, **kwargs):
+    """Run fn repeatedly and return median ms per call. More stable than a single run."""
+    # Warmup (enough for GPU to settle)
+    for _ in range(num_warmup):
         fn(**kwargs)
 
-    tic = time.perf_counter()
-    iters = 100
-    for _ in range(iters):
-        fn(**kwargs)
-    return (time.perf_counter() - tic) * 1000 / iters
+    times_ms = []
+    for _ in range(num_runs):
+        tic = time.perf_counter()
+        for _ in range(num_iters):
+            fn(**kwargs)
+        toc = time.perf_counter()
+        times_ms.append((toc - tic) * 1000 / num_iters)
+    times_ms.sort()
+    return times_ms[num_runs // 2]  # median
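The new keyword arguments let callers trade runtime for stability. A self-contained usage sketch of the updated helper (the GRU workload here is illustrative, mirroring recurrent_bench.py):

```python
import mlx.core as mx
import mlx.nn as nn
from time_utils import measure_runtime

gru = nn.GRU(128, 200, bias=True)
x = mx.random.normal(shape=(32, 40, 128))
h0 = mx.zeros((32, 200))
mx.eval(x, h0)

def step():
    mx.eval(gru(x, h0))  # the eval keeps the compute inside the timed region

ms = measure_runtime(step, num_warmup=15, num_iters=100, num_runs=5)
print(f"nn.GRU forward: {ms:.3f} ms (median of 5 runs of 100 iters)")
```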
21 changes: 21 additions & 0 deletions mlx/backend/cuda/primitives.cpp
@@ -24,16 +24,37 @@ namespace mlx::core {
throw std::runtime_error(#func " has no CUDA implementation."); \
}

#if CUDART_VERSION < 12080
void QQMatmul::eval_gpu(const std::vector<array>& inputs, array& out) {
throw std::runtime_error(
"[QQMatmul::eval_gpu] QQMM is only supported with CUDA 12.8 or higher.");
}
#endif

NO_GPU(BlockMaskedMM)
NO_GPU(FFT)
NO_GPU(GatherQMM)
NO_GPU(Hadamard)
NO_GPU_MULTI(LUF)
NO_GPU_MULTI(QRF)
NO_GPU(QuantizedMatmul)
NO_GPU(SegmentedMM)
NO_GPU_MULTI(SVD)
NO_GPU(Inverse)
NO_GPU(Cholesky)

namespace fast {
void FastGruCell::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
outputs = fallback_(inputs);
}
void FastLSTMCell::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
outputs = fallback_(inputs);
}
} // namespace fast
NO_GPU_MULTI(Eig)
NO_GPU_MULTI(Eigh)
NO_GPU(MaskedScatter)
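On CUDA, both fused cells route to `fallback_`, i.e. the primitive re-runs the cell as composed MLX ops rather than a fused kernel. What that composition amounts to for the LSTM cell, sketched in Python under the assumption of the standard [i, f, g, o] gate layout (the actual fallback is defined in C++ by this PR and is not shown in this diff):

```python
import mlx.core as mx

def lstm_cell_composed(input_proj, hidden_proj, c_prev):
    # input_proj, hidden_proj: (batch, 4 * hidden); c_prev: (batch, hidden)
    z = input_proj + hidden_proj
    i, f, g, o = mx.split(z, 4, axis=-1)
    c = mx.sigmoid(f) * c_prev + mx.sigmoid(i) * mx.tanh(g)
    h = mx.sigmoid(o) * mx.tanh(c)
    return h, c
```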
2 changes: 2 additions & 0 deletions mlx/backend/metal/CMakeLists.txt
@@ -112,6 +112,8 @@ target_sources(
${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
${CMAKE_CURRENT_SOURCE_DIR}/event.cpp
${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast_gru_cell.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fast_lstm_cell.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
${CMAKE_CURRENT_SOURCE_DIR}/fft.cpp
${CMAKE_CURRENT_SOURCE_DIR}/hadamard.cpp
98 changes: 98 additions & 0 deletions mlx/backend/metal/fast_gru_cell.cpp
@@ -0,0 +1,98 @@
// Copyright © 2024 Apple Inc.
// Fused GRU cell Metal backend. See Apple Metal docs:
// https://developer.apple.com/documentation/metal

#include "mlx/allocator.h"
#include "mlx/backend/gpu/copy.h"
#include "mlx/backend/metal/device.h"
#include "mlx/backend/metal/utils.h"
#include "mlx/fast_primitives.h"
#include "mlx/utils.h"

namespace mlx::core::fast {

void FastGruCell::eval_cpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
outputs = fallback_(inputs);
}

void FastGruCell::eval_gpu(
const std::vector<array>& inputs,
std::vector<array>& outputs) {
auto& s = stream();
auto& d = metal::device(s.device);

const array& input_proj = inputs[0];
const array& hidden_proj = inputs[1];
const array& hidden_prev = inputs[2];
const bool has_bhn = (inputs.size() == 4);
array& out = outputs[0];

if (input_proj.dtype() != float32 && input_proj.dtype() != bfloat16) {
outputs = fallback_(inputs);
return;
}

out.set_data(allocator::malloc(out.nbytes()));

  std::vector<array> copies;
  // Reserve up front: copy_if_needed hands out references into `copies`, and a
  // push_back that reallocates would invalidate the earlier ones.
  copies.reserve(inputs.size());
  auto copy_if_needed = [&copies, &s](const array& a) -> const array& {
    if (a.flags().row_contiguous)
      return a;
    copies.push_back(contiguous_copy_gpu(a, s));
    return copies.back();
  };
const array& in_proj = copy_if_needed(input_proj);
const array& hid_proj = copy_if_needed(hidden_proj);
const array& hid_prev = copy_if_needed(hidden_prev);
const array* bhn_ptr = has_bhn ? &copy_if_needed(inputs[3]) : nullptr;

size_t batch_size = in_proj.shape(0);
size_t hidden_size = hid_prev.shape(1);
uint32_t h_quads = (static_cast<uint32_t>(hidden_size) + 3) / 4;
uint32_t total_threads = static_cast<uint32_t>(batch_size) * h_quads;

std::string kname;
if (has_bhn) {
kname = (in_proj.dtype() == bfloat16) ? "gru_cell_fused_bfloat16_bias"
: "gru_cell_fused_float_bias";
} else {
kname = (in_proj.dtype() == bfloat16) ? "gru_cell_fused_bfloat16"
: "gru_cell_fused_float";
}

auto& enc = d.get_command_encoder(s.index);
auto kernel = d.get_kernel(kname);
enc.set_compute_pipeline_state(kernel);

enc.set_input_array(in_proj, 0);
enc.set_input_array(hid_proj, 1);
enc.set_input_array(hid_prev, 2);
if (has_bhn) {
enc.set_input_array(*bhn_ptr, 3);
enc.set_output_array(out, 4);
} else {
enc.set_output_array(out, 3);
}
uint32_t bs = static_cast<uint32_t>(batch_size);
uint32_t hs = static_cast<uint32_t>(hidden_size);
uint32_t bytes_base = has_bhn ? 5u : 4u;
enc.set_bytes(bs, bytes_base);
enc.set_bytes(hs, bytes_base + 1);

constexpr uint32_t threads_per_group = 512;
uint32_t num_groups =
(total_threads + threads_per_group - 1) / threads_per_group;
MTL::Size grid_dims(num_groups, 1, 1);
MTL::Size group_dims(threads_per_group, 1, 1);
enc.dispatch_threadgroups(grid_dims, group_dims);

d.add_temporaries(std::move(copies), s.index);
}

bool FastGruCell::is_equivalent(const Primitive& other) const {
return other.name() == name();
}

} // namespace mlx::core::fast
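For review, the per-element math the fused GRU kernel computes, written out with MLX ops. The [r, z, n] gate layout and the role of the optional fourth input as b_hn (the hidden-side candidate bias, which cannot be pre-folded into hidden_proj because r multiplies that term) are assumptions read off the kernel's argument list:

```python
import mlx.core as mx

def gru_cell_reference(input_proj, hidden_proj, h_prev, b_hn=None):
    # input_proj, hidden_proj: (batch, 3 * hidden); h_prev: (batch, hidden)
    xr, xz, xn = mx.split(input_proj, 3, axis=-1)
    hr, hz, hn = mx.split(hidden_proj, 3, axis=-1)
    r = mx.sigmoid(xr + hr)
    z = mx.sigmoid(xz + hz)
    if b_hn is not None:
        hn = hn + b_hn
    n = mx.tanh(xn + r * hn)
    return (1 - z) * n + z * h_prev
```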