
Data race: process-global InTracing::trace_stack_ static is not thread_local; concurrent compile/grad/vmap traces race other threads #3620

@al8n


Summary

mlx::core::detail::InTracing::trace_stack() returns a reference to a function-local static vector that is shared across the entire process — it is not thread_local:

mlx/transforms.cpp (v0.31.2)

std::vector<std::pair<char, char>>& detail::InTracing::trace_stack() {
  static std::vector<std::pair<char, char>> trace_stack_;
  return trace_stack_;
}

Every transform that builds a graph constructs an InTracing, whose constructor calls push_back on this vector and whose destructor calls pop_back (mlx/transforms_impl.h). This includes compile (via compile_trace), vjp, jvp, value_and_grad, and vmap. Ordinary array ops also read this state through detail::in_tracing() / in_dynamic_tracing() to decide their behavior.
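To make the push/pop pattern concrete, here is a simplified C++ model of it. This is a sketch under stated assumptions, not MLX's code: InTracingModel, in_tracing_model() and nested_depth() are hypothetical stand-ins for detail::InTracing and detail::in_tracing(), and the two char flags per entry are treated as opaque because their exact meaning is not described here. The point is that every guard and every reader touches the same function-local static vector.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Simplified model of the pattern described above; NOT MLX's actual code.
// The two char flags per entry are opaque here (their meaning is an assumption).
std::vector<std::pair<char, char>>& trace_stack() {
  static std::vector<std::pair<char, char>> trace_stack_;  // one vector shared by every thread
  return trace_stack_;
}

// Hypothetical stand-in for detail::InTracing: push on construction, pop on destruction.
struct InTracingModel {
  InTracingModel(char a, char b) { trace_stack().push_back({a, b}); }  // write
  ~InTracingModel() { trace_stack().pop_back(); }                      // write
  InTracingModel(const InTracingModel&) = delete;
  InTracingModel& operator=(const InTracingModel&) = delete;
};

// Hypothetical stand-in for detail::in_tracing(): a read of the same vector.
bool in_tracing_model() { return !trace_stack().empty(); }

// In this model, nested transforms nest guards; returns the depth seen innermost.
// The return value is computed before the guards' destructors run.
std::size_t nested_depth() {
  InTracingModel outer(0, 1);
  InTracingModel inner(1, 0);
  return trace_stack().size();
}
```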

Because the vector is process-global and unsynchronized, two threads touching it concurrently is a data race (undefined behavior):

  • thread A enters compile(...) / grad(...) → the InTracing constructor calls push_back on the global vector, while
  • thread B concurrently enters any transform (another grad / vmap / compile), or even just runs an ordinary op that reads in_tracing(),

so two threads perform concurrent read/write (or write/write) on the same std::vector with no lock → UB (torn reads, a reallocation under a concurrent reader, an inconsistent depth, etc.).
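The interference is not only a memory-level race. The hypothetical sketch below uses the same global-static shape as trace_stack(), under a stand-in name global_trace_stack(), and orders the two threads with std::promise so the demo itself is race-free. It shows that thread B, which never entered a transform, still observes thread A's tracing state. Without that ordering, the same accesses would be the data race described above.

```cpp
#include <cassert>
#include <future>
#include <thread>
#include <utility>
#include <vector>

// Hypothetical model: a process-global trace stack, as in the issue.
std::vector<std::pair<char, char>>& global_trace_stack() {
  static std::vector<std::pair<char, char>> stack;
  return stack;
}

// Returns what thread B observes for "am I tracing?" while thread A is inside
// a trace. The promise/future pairs establish happens-before between the
// accesses, so this demo has no data race; it isolates the logical leak.
bool other_thread_sees_tracing() {
  std::promise<void> a_pushed, b_checked;
  std::future<void> a_pushed_f = a_pushed.get_future();
  std::future<void> b_checked_f = b_checked.get_future();
  bool b_observed = false;

  std::thread a([&] {
    global_trace_stack().push_back({0, 1});  // A enters a "trace"
    a_pushed.set_value();
    b_checked_f.wait();                      // stay inside the trace until B has looked
    global_trace_stack().pop_back();         // A leaves the trace
  });
  std::thread b([&] {
    a_pushed_f.wait();
    b_observed = !global_trace_stack().empty();  // B never started a trace itself
    b_checked.set_value();
  });
  a.join();
  b.join();
  return b_observed;
}
```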

Note that compile's compiler_cache_ is thread_local, but the trace stack it pushes is not — so even the compile path is exposed.

Reproduction (conceptual)

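Two threads using the same library concurrently, one calling a compiled function and one repeatedly taking gradients (each grad call goes through a transform that constructs an InTracing):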

import threading, mlx.core as mx

def tracer():
    f = mx.compile(lambda x: x * x + 1)
    for _ in range(100000):
        mx.eval(f(mx.array([1.0, 2.0, 3.0])))

def grader():
    g = mx.grad(lambda x: (x * x).sum())
    for _ in range(100000):
        mx.eval(g(mx.array([1.0, 2.0, 3.0])))

t1 = threading.Thread(target=tracer)
t2 = threading.Thread(target=grader)
t1.start(); t2.start(); t1.join(); t2.join()

Python's GIL can mask this for pure-Python drivers, but the race is in the C++ core regardless of the front-end: it reproduces reliably from native / multi-threaded bindings and is flagged by ThreadSanitizer on the C++ side.

Impact

Any multi-threaded use where one thread traces (compile / grad / vjp / vmap) while another thread traces or runs ops is UB. This is a problem for native bindings (mlx-swift, mlx-c consumers) that expose compile / grad as ordinary, thread-safe-looking APIs: a safe wrapper cannot soundly expose them without serializing every op against the trace stack, which is impractical and defeats the point of a thin binding.

Suggested fix

Make the tracing state thread-local:

static thread_local std::vector<std::pair<char, char>> trace_stack_;

(and likewise for whatever backs in_dynamic_tracing()), or guard the trace stack with internal synchronization. thread_local matches the intent — tracing depth is a property of the calling thread's graph construction, not of the whole process — and would make these transforms safe to use concurrently across threads.
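A minimal sketch of the thread_local variant, assuming a simplified guard. TracingGuard, tl_trace_stack() and tl_in_tracing() are hypothetical names, not MLX APIs. Thread A holds a guard while a second thread checks its own state, and each thread sees only its own stack.

```cpp
#include <cassert>
#include <thread>
#include <utility>
#include <vector>

// Sketch of the suggested fix: one trace stack per thread.
std::vector<std::pair<char, char>>& tl_trace_stack() {
  static thread_local std::vector<std::pair<char, char>> stack;
  return stack;
}

// Hypothetical simplified guard modeled on InTracing.
struct TracingGuard {
  TracingGuard() { tl_trace_stack().push_back({0, 1}); }
  ~TracingGuard() { tl_trace_stack().pop_back(); }
  TracingGuard(const TracingGuard&) = delete;
  TracingGuard& operator=(const TracingGuard&) = delete;
};

bool tl_in_tracing() { return !tl_trace_stack().empty(); }

// Thread A is inside a trace while thread B checks its own state.
// Returns {A's view, B's view}.
std::pair<bool, bool> views_while_a_traces() {
  bool a_view = false, b_view = true;
  std::thread a([&] {
    TracingGuard g;  // A's guard lives until this lambda returns
    a_view = tl_in_tracing();
    std::thread b([&] { b_view = tl_in_tracing(); });  // B runs while A's guard is alive
    b.join();
  });
  a.join();
  return {a_view, b_view};
}
```

A mutex around the existing global vector would remove the memory-level race but not the cross-thread leak: thread B would still see A's entries, just without tearing. That is one reason the thread_local option appears to match the intent described above more closely.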

Version

MLX v0.31.2 (68cf2fdd).
