Skip to content

Commit a796618

Browse files
Chapamanjosevalim
andauthored
feat: Sharding implementation (#1648)
Co-authored-by: José Valim <[email protected]>
1 parent 6246a4a commit a796618

File tree

3 files changed

+74
-0
lines changed

3 files changed

+74
-0
lines changed

nx/lib/nx/defn.ex

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,29 @@ defmodule Nx.Defn do
849849
:ok
850850
end
851851

852+
def shard_jit(fun, mesh, opts \\ []) when is_function(fun) and is_list(opts) do
853+
wrap(fun, &shard_jit_apply(fun, mesh, &1, opts))
854+
end
855+
856+
def shard_jit_apply(fun, mesh, args, opts \\ [])
857+
when is_function(fun) and is_list(args) and is_list(opts) do
858+
{on_conflict, opts} = Keyword.pop(opts, :on_conflict, :raise)
859+
860+
cond do
861+
Nx.Defn.current() == nil ->
862+
do_shard_jit_apply(fun, mesh, args, opts)
863+
864+
on_conflict == :raise ->
865+
raise "cannot invoke Shard JITed function when there is a Shard JIT compilation happening"
866+
867+
on_conflict == :force ->
868+
do_shard_jit_apply(fun, mesh, args, opts)
869+
870+
on_conflict == :reuse ->
871+
apply(fun, args)
872+
end
873+
end
874+
852875
defp compile_error!(env, description) do
853876
raise CompileError, line: env.line, file: env.file, description: description
854877
end

nx/lib/nx/defn/compiler.ex

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,33 @@ defmodule Nx.Defn.Compiler do
7373
"""
7474
@callback __to_backend__(keyword) :: {module, keyword}
7575

76+
@doc """
77+
Callback for compilation of a parallelizable computation.
78+
79+
Its main purpose is to compile a function for a given `Nx.Defn.Shard.Mesh`.
80+
81+
Receives an opaque `key` used for caching, a `mesh`, a list of `vars`
82+
in `[vars]`, the function `fun` which builds a defn expression, a list of
83+
argument lists in `args_list`, and the compiler options.
84+
85+
Using `[vars]` instead of a single `vars` allows the compiler to keep one
86+
set of abstract parameters per shard or logical device in the mesh. This is useful
87+
when the tensors are already divided into shards.
88+
"""
89+
@callback __shard_jit__(
90+
key :: term,
91+
mesh :: Nx.Defn.Shard.Mesh.t(),
92+
[vars],
93+
fun :: (vars -> Nx.Container.t()),
94+
args_list :: [[(-> Nx.Tensor.t())]],
95+
opts :: keyword
96+
) :: [Nx.Container.t()]
97+
when vars: [Nx.Container.t()]
98+
99+
@optional_callbacks [
100+
__shard_jit__: 6
101+
]
102+
76103
# Modules allowed in defn
77104
@allowed_modules [Nx.Constants, Nx.Defn, Nx.Defn.Kernel, Nx.LinAlg, Nx.Type]
78105

@@ -265,6 +292,14 @@ defmodule Nx.Defn.Compiler do
265292
{:__block__, [], quoted}
266293
end
267294

295+
def __shard_jit__(fun, mesh, params, args_list, opts) do
296+
{module, runtime_fun, opts} = prepare_options(fun, mesh, opts)
297+
module.__shard_jit__(fun, mesh, params, runtime_fun, args_list, opts)
298+
rescue
299+
e in [UndefinedFunctionError] ->
300+
raise_missing_callback(e, :__shard_jit__, 6, __STACKTRACE__)
301+
end
302+
268303
defp compile_prepare_arities(definitions) do
269304
for {{name, arity}, %{defaults: defaults}} <- definitions,
270305
arity <- (arity - map_size(defaults))..arity,

nx/lib/nx/defn/shard/mesh.ex

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
defmodule Nx.Defn.Mesh do
2+
@moduledoc """
3+
A mesh is a named collection of devices arranged in a logical shape.
4+
5+
`name` is a string identifier for the mesh in the lowered program so that
6+
sharding annotations can refer to a specific device topology without
7+
embedding concrete device handles directly in the intermediate
8+
representation.
9+
10+
`shape` is a tuple describing the logical layout of devices, where each
11+
element is the size of a mesh dimension. For instance, a shape like
12+
`{2, 4}` represents a 2x4 logical grid of devices.
13+
"""
14+
defstruct [:name, :shape]
15+
@type t :: %__MODULE__{name: String.t(), shape: tuple()}
16+
end

0 commit comments

Comments
 (0)