From 4cd50dd02fa3a385006c8198297a9912d3b50dc8 Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 01/25] Add `uses_shape_dsl` to `shape_extensions` Summary: Add the `uses_shape_dsl` decorator to `shape_extensions`, the public API decorator that associates a shape DSL function with an API function in library stubs. At runtime it's a no-op passthrough; Pyrefly will use it at type-checking time to route bound arguments through the shape DSL for return-type refinement. Also register `UsesShapeDsl` as a `SpecialExport` variant in the Rust export system, so that later phases can detect `uses_shape_dsl` decorators during binding. Differential Revision: D105696519 --- pyrefly/lib/export/special.rs | 3 +++ .../fixtures/shape_extensions/__init__.py | 19 +++++++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/pyrefly/lib/export/special.rs b/pyrefly/lib/export/special.rs index 3d38a6f98a..35cedbc09f 100644 --- a/pyrefly/lib/export/special.rs +++ b/pyrefly/lib/export/special.rs @@ -71,6 +71,7 @@ pub enum SpecialExport { Final, TypingMapping, TypeForm, + UsesShapeDsl, } impl SpecialExport { @@ -133,6 +134,7 @@ impl SpecialExport { "Final" => Some(Self::Final), "Mapping" => Some(Self::TypingMapping), "TypeForm" => Some(Self::TypeForm), + "uses_shape_dsl" => Some(Self::UsesShapeDsl), _ => None, } } @@ -204,6 +206,7 @@ impl SpecialExport { "typing" | "typing_extensions" | "collections.abc" ), Self::Deprecated => matches!(m.as_str(), "warnings" | "typing_extensions"), + Self::UsesShapeDsl => matches!(m.as_str(), "shape_extensions"), } } diff --git a/test/tensor_shapes/fixtures/shape_extensions/__init__.py b/test/tensor_shapes/fixtures/shape_extensions/__init__.py index 52f764be46..990d24529f 100644 --- a/test/tensor_shapes/fixtures/shape_extensions/__init__.py +++ b/test/tensor_shapes/fixtures/shape_extensions/__init__.py @@ -149,3 +149,22 @@ def __iter__(self): @property def __typing_is_unpacked_typevartuple__(self): return True + + +def uses_shape_dsl( + ir_fn: typing.Callable, + *, + capture_init: list[str] | None = None, +) -> typing.Callable[[typing.Callable], typing.Callable]: + """Decorator that associates a shape DSL function with an API function. + + At runtime this is a no-op: the decorator arguments are ignored and the + decorated function is returned unchanged. Pyrefly uses this decorator + at type-checking time to route bound arguments through the shape DSL + for return-type refinement. + """ + + def decorator(fn: typing.Callable) -> typing.Callable: + return fn + + return decorator From 95e72e2d98f941e6d3a64caedb2522e8f718ce6c Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 02/25] Add `shape_extensions.dsl` submodule with `shape_dsl_function` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Add the `shape_extensions.dsl` submodule containing the `shape_dsl_function` decorator. This decorator is DSL-internal — it marks a function whose body should be converted to shape DSL IR during binding. It lives in a separate submodule from the public `shape_extensions` API because it's only used inside DSL definition files (like `torch/_shapes.pyi`), not in normal stubs or user code. Also register `ShapeDslFunction` as a `SpecialExport` variant with `defined_in` matching `shape_extensions.dsl`, so later phases can detect `shape_dsl_function` decorators during binding. Differential Revision: D105696516 --- pyrefly/lib/export/special.rs | 3 +++ .../fixtures/shape_extensions/dsl.py | 24 +++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 test/tensor_shapes/fixtures/shape_extensions/dsl.py diff --git a/pyrefly/lib/export/special.rs b/pyrefly/lib/export/special.rs index 35cedbc09f..f8de550faf 100644 --- a/pyrefly/lib/export/special.rs +++ b/pyrefly/lib/export/special.rs @@ -72,6 +72,7 @@ pub enum SpecialExport { TypingMapping, TypeForm, UsesShapeDsl, + ShapeDslFunction, } impl SpecialExport { @@ -135,6 +136,7 @@ impl SpecialExport { "Mapping" => Some(Self::TypingMapping), "TypeForm" => Some(Self::TypeForm), "uses_shape_dsl" => Some(Self::UsesShapeDsl), + "shape_dsl_function" => Some(Self::ShapeDslFunction), _ => None, } } @@ -207,6 +209,7 @@ impl SpecialExport { ), Self::Deprecated => matches!(m.as_str(), "warnings" | "typing_extensions"), Self::UsesShapeDsl => matches!(m.as_str(), "shape_extensions"), + Self::ShapeDslFunction => matches!(m.as_str(), "shape_extensions.dsl"), } } diff --git a/test/tensor_shapes/fixtures/shape_extensions/dsl.py b/test/tensor_shapes/fixtures/shape_extensions/dsl.py new file mode 100644 index 0000000000..e6829517d4 --- /dev/null +++ b/test/tensor_shapes/fixtures/shape_extensions/dsl.py @@ -0,0 +1,24 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-ignore-all-errors + +"""DSL internals for shape typing. + +Only used inside DSL definition files (e.g. torch/_shapes.pyi), not in +normal stubs or user code. +""" + +import typing + + +def shape_dsl_function(fn: typing.Callable) -> typing.Callable: + """Marks a function as a shape DSL function. + + At runtime this is a no-op: the decorated function is returned unchanged. + Pyrefly uses this decorator at type-checking time to convert the function + body to DSL IR via convert_fndef. + """ + return fn From 298623981a98c4dc1b70b5c0109c8a1be405de9b Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 03/25] Add DSL builtins to `shape_extensions.dsl` and update builtin prefix MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Add `prod`, `sum`, and `parse_einsum_equation` stub definitions to `shape_extensions/dsl.py` so DSL function files can import them from the canonical location. Update all DSL builtin references from `shape_extensions.*` to `shape_extensions.dsl.*` in both the Rust string matching (`convert_call`, `Display for DslBuiltin`) and the `DSL_SOURCE` Python code. Fix `convert_call` to support multi-level dotted names (e.g. `shape_extensions.dsl.prod`) — it previously only handled single-dotted names like `shape_extensions.prod`. Differential Revision: D105698000 --- crates/pyrefly_types/src/meta_shape_dsl.rs | 50 +++++++++++-------- .../pyrefly_types/src/tensor_ops_registry.rs | 14 +++--- .../fixtures/shape_extensions/dsl.py | 15 ++++++ 3 files changed, 50 insertions(+), 29 deletions(-) diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index c6fd1dc27f..2a2d2d8fe0 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -607,10 +607,12 @@ impl fmt::Display for DslBuiltin { match self { DslBuiltin::Len => write!(f, "len"), DslBuiltin::Range => write!(f, "range"), - DslBuiltin::Prod => write!(f, "shape_extensions.prod"), - DslBuiltin::Sum => write!(f, "shape_extensions.sum"), + DslBuiltin::Prod => write!(f, "shape_extensions.dsl.prod"), + DslBuiltin::Sum => write!(f, "shape_extensions.dsl.sum"), DslBuiltin::Str => write!(f, "str"), - DslBuiltin::ParseEinsumEquation => write!(f, "shape_extensions.parse_einsum_equation"), + DslBuiltin::ParseEinsumEquation => { + write!(f, "shape_extensions.dsl.parse_einsum_equation") + } DslBuiltin::Enumerate => write!(f, "enumerate"), DslBuiltin::Zip => write!(f, "zip"), } @@ -1240,21 +1242,25 @@ fn convert_expr(expr: &Expr) -> Result { } } -/// Convert a function call expression, dispatching special forms. -fn convert_call(call: &ruff_python_ast::ExprCall) -> Result { - // Extract function name for dispatch. Supports both simple names (`len`) - // and dotted names (`shape_extensions.prod`). - let func_name = match call.func.as_ref() { - Expr::Name(n) => n.id.to_string(), +/// Recursively extract a dotted name from an expression (e.g. `shape_extensions.dsl.prod`). +fn dotted_name(expr: &Expr) -> Option { + match expr { + Expr::Name(n) => Some(n.id.to_string()), Expr::Attribute(a) => { - let prefix = match a.value.as_ref() { - Expr::Name(n) => n.id.as_str(), - _ => return Err(format!("unsupported call target: {:?}", call.func)), - }; - format!("{}.{}", prefix, a.attr) + let prefix = dotted_name(&a.value)?; + Some(format!("{}.{}", prefix, a.attr)) } - _ => return Err(format!("unsupported call target: {:?}", call.func)), - }; + _ => None, + } +} + +/// Convert a function call expression, dispatching special forms. +fn convert_call(call: &ruff_python_ast::ExprCall) -> Result { + // Extract function name for dispatch. Supports simple names (`len`), + // single-dotted names (`Tensor`), and multi-dotted names + // (`shape_extensions.dsl.prod`). + let func_name = dotted_name(&call.func) + .ok_or_else(|| format!("unsupported call target: {:?}", call.func))?; match func_name.as_str() { // Special forms with non-call syntax — keep as dedicated DslExpr variants @@ -1324,14 +1330,14 @@ fn convert_call(call: &ruff_python_ast::ExprCall) -> Result { "str" | "enumerate" | "zip" - | "shape_extensions.prod" - | "shape_extensions.sum" - | "shape_extensions.parse_einsum_equation" => { + | "shape_extensions.dsl.prod" + | "shape_extensions.dsl.sum" + | "shape_extensions.dsl.parse_einsum_equation" => { let builtin = match func_name.as_str() { - "shape_extensions.prod" => DslBuiltin::Prod, - "shape_extensions.sum" => DslBuiltin::Sum, + "shape_extensions.dsl.prod" => DslBuiltin::Prod, + "shape_extensions.dsl.sum" => DslBuiltin::Sum, "str" => DslBuiltin::Str, - "shape_extensions.parse_einsum_equation" => DslBuiltin::ParseEinsumEquation, + "shape_extensions.dsl.parse_einsum_equation" => DslBuiltin::ParseEinsumEquation, "enumerate" => DslBuiltin::Enumerate, "zip" => DslBuiltin::Zip, _ => unreachable!(), diff --git a/crates/pyrefly_types/src/tensor_ops_registry.rs b/crates/pyrefly_types/src/tensor_ops_registry.rs index a40bfd257c..0d2178a47b 100644 --- a/crates/pyrefly_types/src/tensor_ops_registry.rs +++ b/crates/pyrefly_types/src/tensor_ops_registry.rs @@ -646,8 +646,8 @@ def reshape_ir(self: Tensor, shape: list[int | symint]) -> Tensor: if has_zero: raise Error("reshape dimensions cannot contain 0") if minus_one_count > 0: - known = shape_extensions.prod([d for d in shape if d != -1]) - total = shape_extensions.prod(self.shape) + known = shape_extensions.dsl.prod([d for d in shape if d != -1]) + total = shape_extensions.dsl.prod(self.shape) if isinstance(total, int) and isinstance(known, int) and total % known != 0: raise Error("could not infer size for dimension -1: expected " + str(total) + " to be divisible by " + str(known)) return Tensor(shape=[total // known if d == -1 else d for d in shape]) @@ -679,7 +679,7 @@ def flatten_ir(self: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor: rank = len(self.shape) s = normalize_dim(rank, start_dim) e = normalize_dim(rank, end_dim) - return Tensor(shape=self.shape[:s] + [shape_extensions.prod(self.shape[s:e + 1])] + self.shape[e + 1:]) + return Tensor(shape=self.shape[:s] + [shape_extensions.dsl.prod(self.shape[s:e + 1])] + self.shape[e + 1:]) def expand_ir(self: Tensor, sizes: list[int | symint]) -> Tensor: return Tensor(shape=[d if t == -1 else t for d, t in zip(self.shape, sizes)]) @@ -702,7 +702,7 @@ def unfold_ir(self: Tensor, dimension: int, size: int | symint, step: int = 1) - def cat_ir(tensors: list[Tensor], dim: int = 0) -> Tensor: first = tensors[0] d = normalize_dim(len(first.shape), dim) - return Tensor(shape=[shape_extensions.sum([t.shape[i] for t in tensors]) if i == d else dim_val for i, dim_val in enumerate(first.shape)]) + return Tensor(shape=[shape_extensions.dsl.sum([t.shape[i] for t in tensors]) if i == d else dim_val for i, dim_val in enumerate(first.shape)]) def stack_ir(tensors: list[Tensor], dim: int = 0) -> Tensor: first = tensors[0] @@ -792,7 +792,7 @@ def topk_ir(self: Tensor, k: int | symint, dim: int = -1) -> [Tensor, Tensor]: def repeat_interleave_ir(self: Tensor, repeats: int | symint, dim: int | None = None) -> Tensor: if dim == None: - return Tensor(shape=[shape_extensions.prod(self.shape) * repeats]) + return Tensor(shape=[shape_extensions.dsl.prod(self.shape) * repeats]) d = normalize_dim(len(self.shape), dim) return Tensor(shape=replace_dim(self.shape, d, self.shape[d] * repeats)) @@ -883,7 +883,7 @@ def apply_einsum(output_map: list[list[int]], check_pairs: list[list[int]], inpu def einsum_ir(spec: str, operands: list[Tensor] | None = None) -> Tensor: if operands != None: - output_map, check_pairs = shape_extensions.parse_einsum_equation(spec) + output_map, check_pairs = shape_extensions.dsl.parse_einsum_equation(spec) return apply_einsum(output_map, check_pairs, operands) return Unknown @@ -975,7 +975,7 @@ def size_ir(self: Tensor, dim: int | None = None) -> int | symint: return [d for d in self.shape] def numel_ir(self: Tensor) -> int | symint: - return shape_extensions.prod(self.shape) + return shape_extensions.dsl.prod(self.shape) def dim_ir(self: Tensor) -> int: return len(self.shape) diff --git a/test/tensor_shapes/fixtures/shape_extensions/dsl.py b/test/tensor_shapes/fixtures/shape_extensions/dsl.py index e6829517d4..3d8cfeee7e 100644 --- a/test/tensor_shapes/fixtures/shape_extensions/dsl.py +++ b/test/tensor_shapes/fixtures/shape_extensions/dsl.py @@ -22,3 +22,18 @@ def shape_dsl_function(fn: typing.Callable) -> typing.Callable: body to DSL IR via convert_fndef. """ return fn + + +def prod(xs: list[int]) -> int: + """Compute the product of a list of dimension sizes.""" + ... + + +def sum(xs: list[int]) -> int: + """Compute the sum of a list of dimension sizes.""" + ... + + +def parse_einsum_equation(spec: str) -> list[list[list[int]]]: + """Parse an einsum equation string into output map and check pairs.""" + ... From 5d0d7d4476a9a13f62b7818bb5e40eb747823ce7 Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 04/25] Add `import shape_extensions.dsl` to `DSL_SOURCE` Summary: Add an import statement at the top of `DSL_SOURCE` to make it look more like a normal Python module. The import is silently skipped by `parse_dsl` (which only processes function definitions), but having it there makes the DSL source ready to behave like a real stub file once Phase 5 migrates it out of the Rust string. Differential Revision: D105698001 --- crates/pyrefly_types/src/tensor_ops_registry.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/crates/pyrefly_types/src/tensor_ops_registry.rs b/crates/pyrefly_types/src/tensor_ops_registry.rs index 0d2178a47b..59486fec77 100644 --- a/crates/pyrefly_types/src/tensor_ops_registry.rs +++ b/crates/pyrefly_types/src/tensor_ops_registry.rs @@ -576,6 +576,8 @@ impl Default for TensorOpsRegistry { /// The full DSL source defining all tensor shape ops and utility functions. /// This is valid Python syntax (a strict subset) that we parse with Pyrefly's parser. const DSL_SOURCE: &str = r#" +import shape_extensions.dsl + def normalize_dim(rank: int, dim: int) -> int: if dim < 0: return dim + rank From 8afde149ab1175c7727eca7dc0e84768db813374 Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 05/25] Add public shape-DSL wrapper API in pyrefly_types MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Phase 2 of the tensor-shapes-in-stubs migration. Adds a public surface to `pyrefly_types::meta_shape_dsl` that lets later phases (the binder and solver in `pyrefly/lib`) drive the DSL pipeline without exposing the grammar-aligned `DslFnDef` internals. The new surface: - `ShapeDslFunction` — an opaque, cheap (one `Arc`) handle to a single DSL function lowered from its Python AST. - `ShapeDslProgram` — a bundle of `ShapeDslFunction`s that has been validated together as a program. The only way to obtain one is via `build_shape_dsl_program`, which type-checks the bundle. - `convert_shape_dsl_function` — AST → `ShapeDslFunction`. - `build_shape_dsl_program` — `Iterator` → `ShapeDslProgram`. Panics on type-check failure today, matching the existing `parse_dsl` semantics; Phase 7 of the migration will convert this to a `Result` and surface diagnostics. - `make_meta_shape_function` — `(&ShapeDslProgram, root_name)` → `Box`. Taking a `&ShapeDslProgram` (not raw pieces) enforces that callers cannot build a `MetaShapeFunction` from un-type-checked DSL. The internal helpers (`convert_fndef`, `type_check_program`, `bind_dsl_params`, `eval_dsl_body`) stay module-private; the wrappers live in the same module and call them directly. This keeps DSL internals fully opaque outside the module. No call sites change in this commit; the new API is exercised by the follow-up dogfood refactor of `tensor_ops_registry`. Differential Revision: D105720303 --- crates/pyrefly_types/src/meta_shape_dsl.rs | 190 +++++++++++++++--- .../pyrefly_types/src/tensor_ops_registry.rs | 12 +- 2 files changed, 166 insertions(+), 36 deletions(-) diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index 2a2d2d8fe0..0946625135 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -44,9 +44,7 @@ use crate::tensor::TensorType; use crate::tuple::Tuple; use crate::types::Type; -// ============================================================================ -// Runtime Values -// ============================================================================ +// Section: Runtime Values /// Runtime value produced by parameter extraction and manipulated by the /// interpreter. Bridges between `Type` (the type-checker's representation) @@ -156,9 +154,7 @@ impl Val { } } -// ============================================================================ -// Extraction Helpers -// ============================================================================ +// Section: Extraction Helpers /// Helper functions for extracting typed values from `Type`. /// @@ -332,9 +328,7 @@ mod extract { } } -// ============================================================================ -// Meta-Shape Function Trait -// ============================================================================ +// Section: Meta-Shape Function Trait /// A function that computes output shapes from input shapes. /// @@ -369,9 +363,7 @@ pub trait MetaShapeFunction: Debug + Send + Sync { } } -// ============================================================================ -// Grammar-aligned data types -// ============================================================================ +// Section: Grammar-aligned data types /// Binary operators: arithmetic, comparison, and logical. /// Corresponds to OP in ` OP `. @@ -569,9 +561,7 @@ pub(crate) struct DslFnDef { body: DslBody, } -// ============================================================================ -// Display implementations -// ============================================================================ +// Section: Display implementations impl fmt::Display for DslOp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -767,9 +757,7 @@ impl fmt::Display for DslParam { } } -// ============================================================================ -// AST conversion: ruff Python AST → DSL grammar types -// ============================================================================ +// Section: AST conversion: ruff Python AST → DSL grammar types /// Convert an isinstance type argument to a DslTypeCon. fn convert_type_constructor(expr: &Expr) -> Result { @@ -1417,9 +1405,7 @@ fn convert_fndef(func: &ruff_python_ast::StmtFunctionDef) -> Result Result, String> { Ok(fndefs) } -// ============================================================================ -// Interpreter — evaluate DSL directly against runtime Val values -// ============================================================================ +// Section: Interpreter — evaluate DSL directly against runtime Val values /// Extract a runtime `Val` from a type-checker `Type` based on the declared `DslType`. /// `actual_arg_type` is the type the user passed; `expected_param_type` is the DSL @@ -3109,3 +3091,157 @@ impl MetaShapeFunction for DslMetaShapeFunction { } } } +// Section: Public wrapper API +// +// These wrappers form the public surface of +// `pyrefly_types::meta_shape_dsl`. They let callers outside this module (the +// binder and solver in `pyrefly/lib`) drive the DSL pipeline without exposing +// the grammar-aligned `DslFnDef` internals. +// +// Constraint: the public surface never returns `Arc`. Callers store +// `ShapeDslFunction` / `ShapeDslProgram` opaquely; only code inside +// `pyrefly_types` (e.g. `tensor_ops_registry`) may reach the underlying +// `DslFnDef` via the `pub(crate)` fields. + +/// A single DSL function that has been lowered from its Python AST. +/// +/// This is a cheap (one `Arc`) opaque handle. It is the unit produced by +/// [`convert_shape_dsl_function`] and consumed by [`build_shape_dsl_program`]. +#[derive(Debug, Clone)] +pub struct ShapeDslFunction { + pub(crate) inner: Arc, +} + +/// A bundle of DSL functions that have been validated together as a program. +/// +/// The functions held by a `ShapeDslProgram` are guaranteed to have passed +/// `type_check_program` against the program's own signature set. The factory +/// [`make_meta_shape_function`] takes a `&ShapeDslProgram` (not raw pieces) +/// precisely to enforce that callers can only build a `MetaShapeFunction` +/// from validated DSL. +/// +/// The names used as keys for helper resolution are the *local* names a +/// caller's body refers to. Programs are populated with +/// `[self_dsl_fn, ...resolved_callees]` where each callee's `name` field +/// has already been rewritten (or annotated) to its caller-local form. This +/// API does not perform that rewriting itself — it accepts programs that +/// already satisfy that constraint. +#[derive(Debug, Clone)] +pub struct ShapeDslProgram { + pub(crate) fns: Vec>, +} + +/// Convert a single Python function definition into a [`ShapeDslFunction`]. +/// +/// This is pure AST-to-IR lowering — it does not parse source text or run +/// the type checker. The output is a single opaque handle; the caller is +/// expected to combine handles from this function (and possibly other +/// modules) into a [`ShapeDslProgram`] via [`build_shape_dsl_program`]. +/// +/// Returns `Err` with a terse description if the function body uses Python +/// syntax outside the DSL subset. +pub fn convert_shape_dsl_function( + func: &ruff_python_ast::StmtFunctionDef, +) -> Result { + let fndef = convert_fndef(func)?; + Ok(ShapeDslFunction { + inner: Arc::new(fndef), + }) +} + +/// Bundle a set of [`ShapeDslFunction`]s into a validated [`ShapeDslProgram`]. +/// +/// Type-checks the bundle as a whole, building the global signature map that +/// `infer_call` needs in order to resolve cross-function calls. Today this +/// step *panics* on type errors (undefined variable, unknown call target, +/// etc.), matching the existing `parse_dsl` behavior. +pub fn build_shape_dsl_program(fns: impl IntoIterator) -> ShapeDslProgram { + let fns: Vec> = fns.into_iter().map(|f| f.inner).collect(); + // `type_check_program` borrows a `&[DslFnDef]`; clone the Arc'd defs into + // a temporary slice for the call. The clone is one-time per program build + // and avoids changing the internal `type_check_program` signature. + // `type_check_program` panics on type errors today; that matches the + // existing `parse_dsl` semantics. + let view: Vec = fns.iter().map(|f| (**f).clone()).collect(); + type_check_program(&view); + ShapeDslProgram { fns } +} + +/// Construct a [`MetaShapeFunction`] keyed at `root_name` from a validated +/// [`ShapeDslProgram`]. +/// +/// The factory takes a `&ShapeDslProgram` (not raw pieces) so that a caller +/// cannot build a `MetaShapeFunction` from un-type-checked DSL — the only +/// way to obtain a `ShapeDslProgram` is via [`build_shape_dsl_program`], +/// which runs the type checker. +/// +/// `root_name` selects which function inside the program is the entry point +/// (the one whose parameters get bound from call-site arguments). All +/// functions in the program — including `root_name` — become part of the +/// resulting `fn_lookup`, which the interpreter consults for cross-function +/// calls. The lookup is keyed by each function's `name`, which the caller +/// is responsible for setting to the caller-local name; see +/// [`ShapeDslProgram`] for that invariant. +/// +/// Returns `None` if no function in the program has `name == root_name`. +pub fn make_meta_shape_function( + program: &ShapeDslProgram, + root_name: &str, +) -> Option> { + let fn_def = program + .fns + .iter() + .find(|f| f.name == root_name) + .map(Arc::clone)?; + let fn_lookup: Arc>> = Arc::new( + program + .fns + .iter() + .map(|f| (f.name.clone(), Arc::clone(f))) + .collect(), + ); + Some(Box::new(DslMetaShapeFunction { fn_def, fn_lookup })) +} + +// TODO: Remove this unit test once the DSL is fully in stubs. +// The e2e tests in pyrefly/test/tensor_shapes will exercise the same code +// paths more thoroughly, making this redundant. +#[cfg(test)] +mod tests { + use pyrefly_python::ast::Ast; + use ruff_python_ast::PySourceType; + use ruff_python_ast::Stmt; + + use super::*; + + /// Sanity check: the public wrapper API composes from AST through to a + /// `MetaShapeFunction` without panicking on a trivial well-formed DSL + /// function. Does not evaluate the DSL — only verifies the surface + /// composes. + #[test] + fn wrapper_api_composes_on_trivial_function() { + let source = "def add_one(x: int) -> int:\n return x + 1\n"; + let (module, errors, _unsupported) = Ast::parse(source, PySourceType::Python); + assert!(errors.is_empty(), "test DSL fixture should parse cleanly"); + let func = module + .body + .iter() + .find_map(|stmt| match stmt { + Stmt::FunctionDef(f) => Some(f), + _ => None, + }) + .expect("test source contains exactly one function def"); + + let shape_fn = convert_shape_dsl_function(func).expect("lowering succeeds"); + let program = build_shape_dsl_program(std::iter::once(shape_fn)); + let meta_fn = make_meta_shape_function(&program, "add_one"); + assert!( + meta_fn.is_some(), + "factory should resolve a function whose name matches the program" + ); + assert!( + make_meta_shape_function(&program, "no_such_function").is_none(), + "factory should return None for an unknown name" + ); + } +} diff --git a/crates/pyrefly_types/src/tensor_ops_registry.rs b/crates/pyrefly_types/src/tensor_ops_registry.rs index 59486fec77..5aff7362ac 100644 --- a/crates/pyrefly_types/src/tensor_ops_registry.rs +++ b/crates/pyrefly_types/src/tensor_ops_registry.rs @@ -19,9 +19,7 @@ use crate::meta_shape_dsl::DslMetaShapeFunction; use crate::meta_shape_dsl::MetaShapeFunction; use crate::meta_shape_dsl::parse_dsl; -// ============================================================================ -// DSL-based MetaShapeFunction construction -// ============================================================================ +// Section: DSL-based MetaShapeFunction construction /// Look up a DSL function by name and create a `DslMetaShapeFunction`. fn dsl_fn( @@ -39,9 +37,7 @@ fn dsl_fn( }) } -// ============================================================================ -// Meta-Shape Registry -// ============================================================================ +// Section: Meta-Shape Registry /// Registry mapping PyTorch op names to their shape functions. /// @@ -569,9 +565,7 @@ impl Default for TensorOpsRegistry { } } -// ============================================================================ -// DSL source code -// ============================================================================ +// Section: DSL source code /// The full DSL source defining all tensor shape ops and utility functions. /// This is valid Python syntax (a strict subset) that we parse with Pyrefly's parser. From 7fdf1daa18c64d70aba2e062160c60bc7f69cc56 Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 06/25] Dogfood Phase 2 wrapper API in tensor_ops_registry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Refactor `TensorOpsRegistry::new()` to drive DSL construction through the public `pyrefly_types::meta_shape_dsl` wrapper API added in the previous commit, rather than reaching into the now-obsolete `parse_dsl` / `Arc` / `DslMetaShapeFunction` internals. This validates the new API works end-to-end before Phase 3 depends on it, and it shrinks the number of code paths through the DSL engine to one: AST → `convert_shape_dsl_function` → `build_shape_dsl_program` → `make_meta_shape_function`. `parse_dsl` was the only caller of itself; with the registry converted there are no callers left, so it is deleted (together with the source-text parsing helper it relied on). The pipeline still panics on parse / type-check failure, exactly as before — Phase 7 will revisit error handling end-to-end. Differential Revision: D105720304 --- crates/pyrefly_types/src/meta_shape_dsl.rs | 44 +- .../pyrefly_types/src/tensor_ops_registry.rs | 435 +++++++++--------- 2 files changed, 214 insertions(+), 265 deletions(-) diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index 0946625135..c6d00bbda7 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -24,13 +24,11 @@ use std::fmt; use std::fmt::Debug; use std::sync::Arc; -use pyrefly_python::ast::Ast; use ruff_python_ast::BoolOp as RuffBoolOp; use ruff_python_ast::CmpOp as RuffCmpOp; use ruff_python_ast::Expr; use ruff_python_ast::Number; use ruff_python_ast::Operator as RuffOperator; -use ruff_python_ast::PySourceType; use ruff_python_ast::Stmt; use ruff_python_ast::UnaryOp as RuffUnaryOp; @@ -555,7 +553,7 @@ enum DslExpr { /// Function definition. Corresponds to `` in the grammar. #[derive(Debug, Clone)] pub(crate) struct DslFnDef { - pub(crate) name: String, + name: String, params: Vec, return_type: Option, body: DslBody, @@ -1902,40 +1900,6 @@ fn type_check_program(fndefs: &[DslFnDef]) { } } -// Section: Entry point - -/// Parse DSL source code, convert to grammar-aligned types, and return the -/// list of function definitions. -pub(crate) fn parse_dsl(source: &str) -> Result, String> { - let (module, errors, _unsupported) = Ast::parse(source, PySourceType::Python); - if !errors.is_empty() { - return Err(format!( - "DSL syntax errors:\n{}", - errors - .iter() - .map(|e| format!(" {}", e)) - .collect::>() - .join("\n") - )); - } - - let fndefs: Vec = module - .body - .iter() - .filter_map(|stmt| { - if let Stmt::FunctionDef(f) = stmt { - Some(convert_fndef(f)) - } else { - None // skip comments, blank lines (not in AST anyway) - } - }) - .collect::>()?; - - type_check_program(&fndefs); - - Ok(fndefs) -} - // Section: Interpreter — evaluate DSL directly against runtime Val values /// Extract a runtime `Val` from a type-checker `Type` based on the declared `DslType`. @@ -3031,12 +2995,12 @@ fn val_to_type( /// A `MetaShapeFunction` backed by a parsed DSL function definition. /// The DSL is interpreted directly — no IR conversion. #[derive(Debug)] -pub(crate) struct DslMetaShapeFunction { +struct DslMetaShapeFunction { /// The primary function to evaluate. - pub(crate) fn_def: Arc, + fn_def: Arc, /// Precomputed lookup table mapping function names to definitions. /// Shared across all instances — built once at registry init. - pub(crate) fn_lookup: Arc>>, + fn_lookup: Arc>>, } impl MetaShapeFunction for DslMetaShapeFunction { diff --git a/crates/pyrefly_types/src/tensor_ops_registry.rs b/crates/pyrefly_types/src/tensor_ops_registry.rs index 5aff7362ac..8fb6d9b2f7 100644 --- a/crates/pyrefly_types/src/tensor_ops_registry.rs +++ b/crates/pyrefly_types/src/tensor_ops_registry.rs @@ -12,29 +12,26 @@ //! interpreted directly — no IR layer. use std::collections::HashMap; -use std::sync::Arc; -use crate::meta_shape_dsl::DslFnDef; -use crate::meta_shape_dsl::DslMetaShapeFunction; +use pyrefly_python::ast::Ast; +use ruff_python_ast::PySourceType; +use ruff_python_ast::Stmt; + use crate::meta_shape_dsl::MetaShapeFunction; -use crate::meta_shape_dsl::parse_dsl; +use crate::meta_shape_dsl::ShapeDslProgram; +use crate::meta_shape_dsl::build_shape_dsl_program; +use crate::meta_shape_dsl::convert_shape_dsl_function; +use crate::meta_shape_dsl::make_meta_shape_function; // Section: DSL-based MetaShapeFunction construction -/// Look up a DSL function by name and create a `DslMetaShapeFunction`. -fn dsl_fn( - fn_lookup: &Arc>>, - name: &str, -) -> Box { - let fn_def = Arc::clone( - fn_lookup - .get(name) - .unwrap_or_else(|| panic!("DSL function `{name}` not found")), - ); - Box::new(DslMetaShapeFunction { - fn_def, - fn_lookup: Arc::clone(fn_lookup), - }) +/// Look up a DSL function by name in the shared program and wrap it as a +/// `MetaShapeFunction` suitable for registration. Panics if the program does +/// not contain a function with that name — the bundled `DSL_SOURCE` is +/// fixed, so a missing name is a programming error in this file. +fn dsl_fn(program: &ShapeDslProgram, name: &str) -> Box { + make_meta_shape_function(program, name) + .unwrap_or_else(|| panic!("DSL function `{name}` not found")) } // Section: Meta-Shape Registry @@ -58,369 +55,357 @@ pub struct TensorOpsRegistry { impl TensorOpsRegistry { /// Create a new registry with built-in meta-shape functions. pub fn new() -> Self { - // Parse DSL once; definitions are shared via Arc across all instances. - let dsl_fns: Vec> = parse_dsl(DSL_SOURCE) - .expect("DSL source in tensor_ops_registry.rs has errors") - .into_iter() - .map(Arc::new) - .collect(); - // Build function lookup table once, shared by all DslMetaShapeFunctions. - let fn_lookup: Arc>> = Arc::new( - dsl_fns - .iter() - .map(|f| (f.name.clone(), Arc::clone(f))) - .collect(), - ); + // Parse DSL_SOURCE to AST, then lower each function via the public + // `convert_shape_dsl_function` wrapper and bundle the results into a + // type-checked `ShapeDslProgram`. Sharing one `ShapeDslProgram` + // across all registrations means the underlying `Arc` + // graph is built once and re-used (each `make_meta_shape_function` + // call just clones a couple of `Arc`s). + let (module, errors, _unsupported) = Ast::parse(DSL_SOURCE, PySourceType::Python); + assert!( + errors.is_empty(), + "DSL source in tensor_ops_registry.rs has parse errors: {errors:?}" + ); + let shape_fns = module.body.iter().filter_map(|stmt| match stmt { + Stmt::FunctionDef(f) => Some( + convert_shape_dsl_function(f) + .expect("DSL source in tensor_ops_registry.rs has errors"), + ), + _ => None, + }); + let program = build_shape_dsl_program(shape_fns); let mut registry = Self { functions: HashMap::new(), init_captures: HashMap::new(), }; // Shape manipulation - registry.register_dual("reshape", || dsl_fn(&fn_lookup, "reshape_ir")); - registry.register("torch.cat", dsl_fn(&fn_lookup, "cat_ir")); - registry.register("torch.broadcast_to", dsl_fn(&fn_lookup, "broadcast_to_ir")); - registry.register_dual("squeeze", || dsl_fn(&fn_lookup, "squeeze_ir")); - registry.register_dual("unsqueeze", || dsl_fn(&fn_lookup, "unsqueeze_ir")); - registry.register_dual("transpose", || dsl_fn(&fn_lookup, "transpose_ir")); + registry.register_dual("reshape", || dsl_fn(&program, "reshape_ir")); + registry.register("torch.cat", dsl_fn(&program, "cat_ir")); + registry.register("torch.broadcast_to", dsl_fn(&program, "broadcast_to_ir")); + registry.register_dual("squeeze", || dsl_fn(&program, "squeeze_ir")); + registry.register_dual("unsqueeze", || dsl_fn(&program, "unsqueeze_ir")); + registry.register_dual("transpose", || dsl_fn(&program, "transpose_ir")); // torch.permute takes dims as a tuple; Tensor.permute takes *dims (variadic). // Both use the same DSL function; parameter binding matches by name. - registry.register("torch.permute", dsl_fn(&fn_lookup, "permute_ir")); - registry.register("torch.Tensor.permute", dsl_fn(&fn_lookup, "permute_ir")); - registry.register("torch.flatten", dsl_fn(&fn_lookup, "flatten_ir")); - registry.register("torch.stack", dsl_fn(&fn_lookup, "stack_ir")); - registry.register("torch.tile", dsl_fn(&fn_lookup, "tile_ir")); - registry.register("torch.view", dsl_fn(&fn_lookup, "reshape_ir")); - registry.register("torch.unbind", dsl_fn(&fn_lookup, "unbind_ir")); - registry.register("torch.Tensor.unbind", dsl_fn(&fn_lookup, "unbind_ir")); - registry.register("torch.movedim", dsl_fn(&fn_lookup, "movedim_ir")); - registry.register("torch.moveaxis", dsl_fn(&fn_lookup, "movedim_ir")); - registry.register("torch.Tensor.movedim", dsl_fn(&fn_lookup, "movedim_ir")); - registry.register("torch.Tensor.moveaxis", dsl_fn(&fn_lookup, "movedim_ir")); - registry.register("torch.unfold", dsl_fn(&fn_lookup, "unfold_ir")); - registry.register("torch.Tensor.unfold", dsl_fn(&fn_lookup, "unfold_ir")); + registry.register("torch.permute", dsl_fn(&program, "permute_ir")); + registry.register("torch.Tensor.permute", dsl_fn(&program, "permute_ir")); + registry.register("torch.flatten", dsl_fn(&program, "flatten_ir")); + registry.register("torch.stack", dsl_fn(&program, "stack_ir")); + registry.register("torch.tile", dsl_fn(&program, "tile_ir")); + registry.register("torch.view", dsl_fn(&program, "reshape_ir")); + registry.register("torch.unbind", dsl_fn(&program, "unbind_ir")); + registry.register("torch.Tensor.unbind", dsl_fn(&program, "unbind_ir")); + registry.register("torch.movedim", dsl_fn(&program, "movedim_ir")); + registry.register("torch.moveaxis", dsl_fn(&program, "movedim_ir")); + registry.register("torch.Tensor.movedim", dsl_fn(&program, "movedim_ir")); + registry.register("torch.Tensor.moveaxis", dsl_fn(&program, "movedim_ir")); + registry.register("torch.unfold", dsl_fn(&program, "unfold_ir")); + registry.register("torch.Tensor.unfold", dsl_fn(&program, "unfold_ir")); // Method-only shape manipulation - registry.register("torch.Tensor.reshape", dsl_fn(&fn_lookup, "reshape_ir")); - registry.register("torch.Tensor.view", dsl_fn(&fn_lookup, "reshape_ir")); - registry.register("torch.Tensor.squeeze", dsl_fn(&fn_lookup, "squeeze_ir")); - registry.register("torch.Tensor.flatten", dsl_fn(&fn_lookup, "flatten_ir")); - registry.register("torch.Tensor.tile", dsl_fn(&fn_lookup, "tile_ir")); - registry.register( - "torch.Tensor.diag_embed", - dsl_fn(&fn_lookup, "diag_embed_ir"), - ); - registry.register("torch.Tensor.repeat", dsl_fn(&fn_lookup, "repeat_ir")); - registry.register("torch.Tensor.expand", dsl_fn(&fn_lookup, "expand_ir")); + registry.register("torch.Tensor.reshape", dsl_fn(&program, "reshape_ir")); + registry.register("torch.Tensor.view", dsl_fn(&program, "reshape_ir")); + registry.register("torch.Tensor.squeeze", dsl_fn(&program, "squeeze_ir")); + registry.register("torch.Tensor.flatten", dsl_fn(&program, "flatten_ir")); + registry.register("torch.Tensor.tile", dsl_fn(&program, "tile_ir")); + registry.register("torch.Tensor.diag_embed", dsl_fn(&program, "diag_embed_ir")); + registry.register("torch.Tensor.repeat", dsl_fn(&program, "repeat_ir")); + registry.register("torch.Tensor.expand", dsl_fn(&program, "expand_ir")); // Reduction operations - registry.register_dual("sum", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("mean", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("prod", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("min", || dsl_fn(&fn_lookup, "min_max_median_ir")); - registry.register_dual("max", || dsl_fn(&fn_lookup, "min_max_median_ir")); - registry.register_dual("all", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("any", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("std", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("var", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("argmax", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("argmin", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register("torch.median", dsl_fn(&fn_lookup, "min_max_median_ir")); - registry.register("torch.logsumexp", dsl_fn(&fn_lookup, "reduce_ir")); - registry.register("torch.count_nonzero", dsl_fn(&fn_lookup, "reduce_ir")); - registry.register("torch.aminmax", dsl_fn(&fn_lookup, "aminmax_ir")); - registry.register("torch.norm", dsl_fn(&fn_lookup, "reduce_ir")); - registry.register("torch.mode", dsl_fn(&fn_lookup, "tuple_reduce_ir")); - registry.register("torch.topk", dsl_fn(&fn_lookup, "topk_ir")); - registry.register("torch.kthvalue", dsl_fn(&fn_lookup, "tuple_reduce_ir")); - registry.register("torch.var_mean", dsl_fn(&fn_lookup, "aminmax_ir")); - registry.register("torch.std_mean", dsl_fn(&fn_lookup, "aminmax_ir")); + registry.register_dual("sum", || dsl_fn(&program, "reduce_ir")); + registry.register_dual("mean", || dsl_fn(&program, "reduce_ir")); + registry.register_dual("prod", || dsl_fn(&program, "reduce_ir")); + registry.register_dual("min", || dsl_fn(&program, "min_max_median_ir")); + registry.register_dual("max", || dsl_fn(&program, "min_max_median_ir")); + registry.register_dual("all", || dsl_fn(&program, "reduce_ir")); + registry.register_dual("any", || dsl_fn(&program, "reduce_ir")); + registry.register_dual("std", || dsl_fn(&program, "reduce_ir")); + registry.register_dual("var", || dsl_fn(&program, "reduce_ir")); + registry.register_dual("argmax", || dsl_fn(&program, "reduce_ir")); + registry.register_dual("argmin", || dsl_fn(&program, "reduce_ir")); + registry.register("torch.median", dsl_fn(&program, "min_max_median_ir")); + registry.register("torch.logsumexp", dsl_fn(&program, "reduce_ir")); + registry.register("torch.count_nonzero", dsl_fn(&program, "reduce_ir")); + registry.register("torch.aminmax", dsl_fn(&program, "aminmax_ir")); + registry.register("torch.norm", dsl_fn(&program, "reduce_ir")); + registry.register("torch.mode", dsl_fn(&program, "tuple_reduce_ir")); + registry.register("torch.topk", dsl_fn(&program, "topk_ir")); + registry.register("torch.kthvalue", dsl_fn(&program, "tuple_reduce_ir")); + registry.register("torch.var_mean", dsl_fn(&program, "aminmax_ir")); + registry.register("torch.std_mean", dsl_fn(&program, "aminmax_ir")); // Reduction method versions - registry.register( - "torch.Tensor.median", - dsl_fn(&fn_lookup, "min_max_median_ir"), - ); - registry.register("torch.Tensor.logsumexp", dsl_fn(&fn_lookup, "reduce_ir")); - registry.register( - "torch.Tensor.count_nonzero", - dsl_fn(&fn_lookup, "reduce_ir"), - ); - registry.register("torch.Tensor.aminmax", dsl_fn(&fn_lookup, "aminmax_ir")); - registry.register("torch.Tensor.norm", dsl_fn(&fn_lookup, "reduce_ir")); - registry.register("torch.Tensor.mode", dsl_fn(&fn_lookup, "tuple_reduce_ir")); - registry.register("torch.Tensor.topk", dsl_fn(&fn_lookup, "topk_ir")); - registry.register( - "torch.Tensor.kthvalue", - dsl_fn(&fn_lookup, "tuple_reduce_ir"), - ); + registry.register("torch.Tensor.median", dsl_fn(&program, "min_max_median_ir")); + registry.register("torch.Tensor.logsumexp", dsl_fn(&program, "reduce_ir")); + registry.register("torch.Tensor.count_nonzero", dsl_fn(&program, "reduce_ir")); + registry.register("torch.Tensor.aminmax", dsl_fn(&program, "aminmax_ir")); + registry.register("torch.Tensor.norm", dsl_fn(&program, "reduce_ir")); + registry.register("torch.Tensor.mode", dsl_fn(&program, "tuple_reduce_ir")); + registry.register("torch.Tensor.topk", dsl_fn(&program, "topk_ir")); + registry.register("torch.Tensor.kthvalue", dsl_fn(&program, "tuple_reduce_ir")); // Repeat interleave registry.register( "torch.Tensor.repeat_interleave", - dsl_fn(&fn_lookup, "repeat_interleave_ir"), + dsl_fn(&program, "repeat_interleave_ir"), ); registry.register( "torch.repeat_interleave", - dsl_fn(&fn_lookup, "repeat_interleave_ir"), + dsl_fn(&program, "repeat_interleave_ir"), ); // Cosine similarity (reduces one dim) registry.register( "torch.nn.functional.cosine_similarity", - dsl_fn(&fn_lookup, "cosine_similarity_ir"), + dsl_fn(&program, "cosine_similarity_ir"), ); // Indexing/slicing - registry.register("torch.select", dsl_fn(&fn_lookup, "select_ir")); - registry.register("torch.narrow", dsl_fn(&fn_lookup, "narrow_ir")); - registry.register("torch.split", dsl_fn(&fn_lookup, "split_ir")); - registry.register("torch.chunk", dsl_fn(&fn_lookup, "chunk_ir")); - registry.register("torch.index_select", dsl_fn(&fn_lookup, "index_select_ir")); - registry.register("torch.Tensor.select", dsl_fn(&fn_lookup, "select_ir")); - registry.register("torch.Tensor.narrow", dsl_fn(&fn_lookup, "narrow_ir")); - registry.register("torch.Tensor.split", dsl_fn(&fn_lookup, "split_ir")); - registry.register("torch.Tensor.chunk", dsl_fn(&fn_lookup, "chunk_ir")); + registry.register("torch.select", dsl_fn(&program, "select_ir")); + registry.register("torch.narrow", dsl_fn(&program, "narrow_ir")); + registry.register("torch.split", dsl_fn(&program, "split_ir")); + registry.register("torch.chunk", dsl_fn(&program, "chunk_ir")); + registry.register("torch.index_select", dsl_fn(&program, "index_select_ir")); + registry.register("torch.Tensor.select", dsl_fn(&program, "select_ir")); + registry.register("torch.Tensor.narrow", dsl_fn(&program, "narrow_ir")); + registry.register("torch.Tensor.split", dsl_fn(&program, "split_ir")); + registry.register("torch.Tensor.chunk", dsl_fn(&program, "chunk_ir")); registry.register( "torch.Tensor.index_select", - dsl_fn(&fn_lookup, "index_select_ir"), + dsl_fn(&program, "index_select_ir"), ); // Tensor creation - registry.register("torch.randn", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.rand", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.zeros", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.ones", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.empty", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.full", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.randint", dsl_fn(&fn_lookup, "randint_ir")); - registry.register("torch.arange", dsl_fn(&fn_lookup, "arange_ir")); - registry.register("torch.linspace", dsl_fn(&fn_lookup, "linspace_ir")); - registry.register("torch.eye", dsl_fn(&fn_lookup, "eye_ir")); - registry.register("torch.diag_embed", dsl_fn(&fn_lookup, "diag_embed_ir")); - registry.register("torch.tril_indices", dsl_fn(&fn_lookup, "tri_indices_ir")); - registry.register("torch.triu_indices", dsl_fn(&fn_lookup, "tri_indices_ir")); + registry.register("torch.randn", dsl_fn(&program, "randn_ir")); + registry.register("torch.rand", dsl_fn(&program, "randn_ir")); + registry.register("torch.zeros", dsl_fn(&program, "randn_ir")); + registry.register("torch.ones", dsl_fn(&program, "randn_ir")); + registry.register("torch.empty", dsl_fn(&program, "randn_ir")); + registry.register("torch.full", dsl_fn(&program, "randn_ir")); + registry.register("torch.randint", dsl_fn(&program, "randint_ir")); + registry.register("torch.arange", dsl_fn(&program, "arange_ir")); + registry.register("torch.linspace", dsl_fn(&program, "linspace_ir")); + registry.register("torch.eye", dsl_fn(&program, "eye_ir")); + registry.register("torch.diag_embed", dsl_fn(&program, "diag_embed_ir")); + registry.register("torch.tril_indices", dsl_fn(&program, "tri_indices_ir")); + registry.register("torch.triu_indices", dsl_fn(&program, "tri_indices_ir")); // Linear algebra - registry.register("torch.matmul", dsl_fn(&fn_lookup, "matmul_ir")); - registry.register("torch.mv", dsl_fn(&fn_lookup, "mv_ir")); - registry.register("torch.outer", dsl_fn(&fn_lookup, "outer_ir")); - registry.register("torch.tensordot", dsl_fn(&fn_lookup, "tensordot_ir")); - registry.register("torch.einsum", dsl_fn(&fn_lookup, "einsum_ir")); - registry.register("torch.Tensor.matmul", dsl_fn(&fn_lookup, "matmul_ir")); - registry.register("torch.Tensor.__matmul__", dsl_fn(&fn_lookup, "matmul_ir")); - registry.register("torch.Tensor.mv", dsl_fn(&fn_lookup, "mv_ir")); + registry.register("torch.matmul", dsl_fn(&program, "matmul_ir")); + registry.register("torch.mv", dsl_fn(&program, "mv_ir")); + registry.register("torch.outer", dsl_fn(&program, "outer_ir")); + registry.register("torch.tensordot", dsl_fn(&program, "tensordot_ir")); + registry.register("torch.einsum", dsl_fn(&program, "einsum_ir")); + registry.register("torch.Tensor.matmul", dsl_fn(&program, "matmul_ir")); + registry.register("torch.Tensor.__matmul__", dsl_fn(&program, "matmul_ir")); + registry.register("torch.Tensor.mv", dsl_fn(&program, "mv_ir")); // Eigenvalue decomposition - registry.register("torch.linalg.eig", dsl_fn(&fn_lookup, "eig_ir")); - registry.register("torch.eig", dsl_fn(&fn_lookup, "eig_ir")); - registry.register("torch.linalg.eigh", dsl_fn(&fn_lookup, "eig_ir")); - registry.register("torch.eigh", dsl_fn(&fn_lookup, "eig_ir")); - registry.register("torch.linalg.eigvals", dsl_fn(&fn_lookup, "eigvals_ir")); - registry.register("torch.linalg.eigvalsh", dsl_fn(&fn_lookup, "eigvals_ir")); + registry.register("torch.linalg.eig", dsl_fn(&program, "eig_ir")); + registry.register("torch.eig", dsl_fn(&program, "eig_ir")); + registry.register("torch.linalg.eigh", dsl_fn(&program, "eig_ir")); + registry.register("torch.eigh", dsl_fn(&program, "eig_ir")); + registry.register("torch.linalg.eigvals", dsl_fn(&program, "eigvals_ir")); + registry.register("torch.linalg.eigvalsh", dsl_fn(&program, "eigvals_ir")); // Linear solvers - registry.register("torch.linalg.solve", dsl_fn(&fn_lookup, "solve_ir")); - registry.register("torch.solve", dsl_fn(&fn_lookup, "solve_ir")); + registry.register("torch.linalg.solve", dsl_fn(&program, "solve_ir")); + registry.register("torch.solve", dsl_fn(&program, "solve_ir")); registry.register( "torch.linalg.solve_triangular", - dsl_fn(&fn_lookup, "solve_ir"), + dsl_fn(&program, "solve_ir"), ); registry.register( "torch.triangular_solve", - dsl_fn(&fn_lookup, "solve_reversed_ir"), + dsl_fn(&program, "solve_reversed_ir"), ); registry.register( "torch.linalg.cholesky_solve", - dsl_fn(&fn_lookup, "solve_reversed_ir"), + dsl_fn(&program, "solve_reversed_ir"), ); registry.register( "torch.cholesky_solve", - dsl_fn(&fn_lookup, "solve_reversed_ir"), + dsl_fn(&program, "solve_reversed_ir"), ); - registry.register("torch.lu_solve", dsl_fn(&fn_lookup, "solve_ir")); + registry.register("torch.lu_solve", dsl_fn(&program, "solve_ir")); // Determinant - registry.register("torch.linalg.slogdet", dsl_fn(&fn_lookup, "slogdet_ir")); - registry.register("torch.slogdet", dsl_fn(&fn_lookup, "slogdet_ir")); - registry.register("torch.Tensor.slogdet", dsl_fn(&fn_lookup, "slogdet_ir")); + registry.register("torch.linalg.slogdet", dsl_fn(&program, "slogdet_ir")); + registry.register("torch.slogdet", dsl_fn(&program, "slogdet_ir")); + registry.register("torch.Tensor.slogdet", dsl_fn(&program, "slogdet_ir")); // Convolution - registry.register("torch.nn.functional.conv1d", dsl_fn(&fn_lookup, "conv_ir")); - registry.register("torch.nn.functional.conv2d", dsl_fn(&fn_lookup, "conv_ir")); - registry.register("torch.nn.functional.conv3d", dsl_fn(&fn_lookup, "conv_ir")); + registry.register("torch.nn.functional.conv1d", dsl_fn(&program, "conv_ir")); + registry.register("torch.nn.functional.conv2d", dsl_fn(&program, "conv_ir")); + registry.register("torch.nn.functional.conv3d", dsl_fn(&program, "conv_ir")); registry.register( "torch.nn.functional.conv_transpose1d", - dsl_fn(&fn_lookup, "conv_transpose_ir"), + dsl_fn(&program, "conv_transpose_ir"), ); registry.register( "torch.nn.functional.conv_transpose2d", - dsl_fn(&fn_lookup, "conv_transpose_ir"), + dsl_fn(&program, "conv_transpose_ir"), ); registry.register( "torch.nn.functional.conv_transpose3d", - dsl_fn(&fn_lookup, "conv_transpose_ir"), + dsl_fn(&program, "conv_transpose_ir"), ); // Pooling registry.register( "torch.nn.functional.max_pool1d", - dsl_fn(&fn_lookup, "pool_ir"), + dsl_fn(&program, "pool_ir"), ); registry.register( "torch.nn.functional.max_pool2d", - dsl_fn(&fn_lookup, "pool_ir"), + dsl_fn(&program, "pool_ir"), ); registry.register( "torch.nn.functional.max_pool3d", - dsl_fn(&fn_lookup, "pool_ir"), + dsl_fn(&program, "pool_ir"), ); registry.register( "torch.nn.functional.avg_pool1d", - dsl_fn(&fn_lookup, "pool_ir"), + dsl_fn(&program, "pool_ir"), ); registry.register( "torch.nn.functional.avg_pool2d", - dsl_fn(&fn_lookup, "pool_ir"), + dsl_fn(&program, "pool_ir"), ); registry.register( "torch.nn.functional.avg_pool3d", - dsl_fn(&fn_lookup, "pool_ir"), + dsl_fn(&program, "pool_ir"), ); registry.register( "torch.nn.functional.adaptive_max_pool1d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), + dsl_fn(&program, "adaptive_pool_ir"), ); registry.register( "torch.nn.functional.adaptive_max_pool2d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), + dsl_fn(&program, "adaptive_pool_ir"), ); registry.register( "torch.nn.functional.adaptive_max_pool3d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), + dsl_fn(&program, "adaptive_pool_ir"), ); registry.register( "torch.nn.functional.adaptive_avg_pool1d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), + dsl_fn(&program, "adaptive_pool_ir"), ); registry.register( "torch.nn.functional.adaptive_avg_pool2d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), + dsl_fn(&program, "adaptive_pool_ir"), ); registry.register( "torch.nn.functional.adaptive_avg_pool3d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), + dsl_fn(&program, "adaptive_pool_ir"), ); // Interpolation registry.register( "torch.nn.functional.interpolate", - dsl_fn(&fn_lookup, "interpolate_ir"), + dsl_fn(&program, "interpolate_ir"), ); registry.register( "torch.nn.functional.upsample", - dsl_fn(&fn_lookup, "interpolate_ir"), + dsl_fn(&program, "interpolate_ir"), ); // Conditional operations - registry.register("torch.where", dsl_fn(&fn_lookup, "where_ir")); + registry.register("torch.where", dsl_fn(&program, "where_ir")); registry.register( "torch.take_along_dim", - dsl_fn(&fn_lookup, "take_along_dim_ir"), + dsl_fn(&program, "take_along_dim_ir"), ); registry.register( "torch.Tensor.take_along_dim", - dsl_fn(&fn_lookup, "take_along_dim_ir"), + dsl_fn(&program, "take_along_dim_ir"), ); // Loss functions - registry.register( - "torch.nn.functional.mse_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register("torch.nn.functional.l1_loss", dsl_fn(&fn_lookup, "loss_ir")); - registry.register( - "torch.nn.functional.nll_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); + registry.register("torch.nn.functional.mse_loss", dsl_fn(&program, "loss_ir")); + registry.register("torch.nn.functional.l1_loss", dsl_fn(&program, "loss_ir")); + registry.register("torch.nn.functional.nll_loss", dsl_fn(&program, "loss_ir")); registry.register( "torch.nn.functional.cross_entropy", - dsl_fn(&fn_lookup, "loss_ir"), + dsl_fn(&program, "loss_ir"), ); registry.register( "torch.nn.functional.binary_cross_entropy", - dsl_fn(&fn_lookup, "loss_ir"), + dsl_fn(&program, "loss_ir"), ); registry.register( "torch.nn.functional.binary_cross_entropy_with_logits", - dsl_fn(&fn_lookup, "loss_ir"), + dsl_fn(&program, "loss_ir"), ); - registry.register("torch.nn.functional.kl_div", dsl_fn(&fn_lookup, "loss_ir")); + registry.register("torch.nn.functional.kl_div", dsl_fn(&program, "loss_ir")); registry.register( "torch.nn.functional.smooth_l1_loss", - dsl_fn(&fn_lookup, "loss_ir"), + dsl_fn(&program, "loss_ir"), ); registry.register( "torch.nn.functional.huber_loss", - dsl_fn(&fn_lookup, "loss_ir"), + dsl_fn(&program, "loss_ir"), ); registry.register( "torch.nn.functional.poisson_nll_loss", - dsl_fn(&fn_lookup, "loss_ir"), + dsl_fn(&program, "loss_ir"), ); registry.register( "torch.nn.functional.cosine_embedding_loss", - dsl_fn(&fn_lookup, "loss_ir"), + dsl_fn(&program, "loss_ir"), ); registry.register( "torch.nn.functional.margin_ranking_loss", - dsl_fn(&fn_lookup, "loss_ir"), + dsl_fn(&program, "loss_ir"), ); registry.register( "torch.nn.functional.triplet_margin_loss", - dsl_fn(&fn_lookup, "loss_ir"), + dsl_fn(&program, "loss_ir"), ); registry.register( "torch.nn.functional.hinge_embedding_loss", - dsl_fn(&fn_lookup, "loss_ir"), + dsl_fn(&program, "loss_ir"), ); // Padding - registry.register("torch.nn.functional.pad", dsl_fn(&fn_lookup, "pad_ir")); + registry.register("torch.nn.functional.pad", dsl_fn(&program, "pad_ir")); // FFT - registry.register("torch.fft.rfft", dsl_fn(&fn_lookup, "rfft_ir")); - registry.register("torch.fft.irfft", dsl_fn(&fn_lookup, "irfft_ir")); - registry.register("torch.fft.hfft", dsl_fn(&fn_lookup, "irfft_ir")); - registry.register("torch.fft.ihfft", dsl_fn(&fn_lookup, "rfft_ir")); + registry.register("torch.fft.rfft", dsl_fn(&program, "rfft_ir")); + registry.register("torch.fft.irfft", dsl_fn(&program, "irfft_ir")); + registry.register("torch.fft.hfft", dsl_fn(&program, "irfft_ir")); + registry.register("torch.fft.ihfft", dsl_fn(&program, "rfft_ir")); // Tensor properties - registry.register("torch.Tensor.size", dsl_fn(&fn_lookup, "size_ir")); - registry.register("torch.Tensor.numel", dsl_fn(&fn_lookup, "numel_ir")); - registry.register("torch.Tensor.dim", dsl_fn(&fn_lookup, "dim_ir")); - registry.register("torch.Tensor.nelement", dsl_fn(&fn_lookup, "numel_ir")); - registry.register("torch.Tensor.item", dsl_fn(&fn_lookup, "item_ir")); - registry.register("torch.Tensor.tolist", dsl_fn(&fn_lookup, "tolist_ir")); - registry.register("torch.numel", dsl_fn(&fn_lookup, "numel_ir")); + registry.register("torch.Tensor.size", dsl_fn(&program, "size_ir")); + registry.register("torch.Tensor.numel", dsl_fn(&program, "numel_ir")); + registry.register("torch.Tensor.dim", dsl_fn(&program, "dim_ir")); + registry.register("torch.Tensor.nelement", dsl_fn(&program, "numel_ir")); + registry.register("torch.Tensor.item", dsl_fn(&program, "item_ir")); + registry.register("torch.Tensor.tolist", dsl_fn(&program, "tolist_ir")); + registry.register("torch.numel", dsl_fn(&program, "numel_ir")); // nn.Module forward methods with init capture. // register_init_forward registers both the forward DSL function and the // list of __init__ params to capture in the NNModule type. let maxpool_captures = &["kernel_size", "stride", "padding", "dilation"]; registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.MaxPool1d", "nn_maxpool_forward_ir", maxpool_captures, ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.MaxPool2d", "nn_maxpool_forward_ir", maxpool_captures, ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.MaxPool3d", "nn_maxpool_forward_ir", maxpool_captures, @@ -428,81 +413,81 @@ impl TensorOpsRegistry { let avgpool_captures = &["kernel_size", "stride", "padding"]; registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.AvgPool1d", "nn_avgpool_forward_ir", avgpool_captures, ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.AvgPool2d", "nn_avgpool_forward_ir", avgpool_captures, ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.AvgPool3d", "nn_avgpool_forward_ir", avgpool_captures, ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.Flatten", "nn_flatten_forward_ir", &["start_dim", "end_dim"], ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.PixelShuffle", "nn_pixel_shuffle_forward_ir", &["upscale_factor"], ); - registry.register_init_forward(&fn_lookup, "torch.nn.GLU", "nn_glu_forward_ir", &["dim"]); + registry.register_init_forward(&program, "torch.nn.GLU", "nn_glu_forward_ir", &["dim"]); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.LSTM", "nn_lstm_forward_ir", &["input_size", "hidden_size", "num_layers", "bidirectional"], ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.Upsample", "nn_upsample_forward_ir", &["size", "scale_factor"], ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.GRU", "nn_gru_forward_ir", &["input_size", "hidden_size", "num_layers", "bidirectional"], ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.LSTMCell", "nn_lstmcell_forward_ir", &["input_size", "hidden_size"], ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.ReflectionPad2d", "nn_reflectionpad2d_forward_ir", &["padding"], ); registry.register_init_forward( - &fn_lookup, + &program, "torch.nn.ReplicationPad2d", "nn_reflectionpad2d_forward_ir", &["padding"], ); // Random sampling - registry.register("torch.multinomial", dsl_fn(&fn_lookup, "multinomial_ir")); + registry.register("torch.multinomial", dsl_fn(&program, "multinomial_ir")); registry.register( "torch.Tensor.multinomial", - dsl_fn(&fn_lookup, "multinomial_ir"), + dsl_fn(&program, "multinomial_ir"), ); - registry.register("torch.normal", dsl_fn(&fn_lookup, "normal_ir")); + registry.register("torch.normal", dsl_fn(&program, "normal_ir")); registry } @@ -537,14 +522,14 @@ impl TensorOpsRegistry { /// the init captures under `"{class_name}"`. fn register_init_forward( &mut self, - fn_lookup: &Arc>>, + program: &ShapeDslProgram, class_name: &str, dsl_fn_name: &str, capture_params: &[&str], ) { self.functions.insert( format!("{class_name}.forward"), - dsl_fn(fn_lookup, dsl_fn_name), + dsl_fn(program, dsl_fn_name), ); self.init_captures.insert( class_name.to_owned(), From 28b7426b631e7bc47af1842402c855ac8585126a Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 07/25] Add `capture_init` plumbing to class metadata MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Thread a `capture_init: Option>` field from class binding through to solved `ClassMetadata`, following the `pydantic_before_validator_fields` precedent. This field holds `__init__` parameter names extracted from `uses_shape_dsl(..., capture_init=[...])` decorators on `forward` methods — today it is populated but not yet consumed. Phase 4 will wire it into `maybe_wrap_nn_module` to replace the hardcoded `TensorOpsRegistry::get_init_capture` lookup. Differential Revision: D105720302 --- pyrefly/lib/alt/class/class_metadata.rs | 2 ++ pyrefly/lib/alt/solve.rs | 2 ++ pyrefly/lib/alt/types/class_metadata.rs | 10 ++++++++ pyrefly/lib/binding/binding.rs | 5 +++- pyrefly/lib/binding/class.rs | 34 +++++++++++++++++++++++++ pyrefly/lib/report/binding_memory.rs | 1 + 6 files changed, 53 insertions(+), 1 deletion(-) diff --git a/pyrefly/lib/alt/class/class_metadata.rs b/pyrefly/lib/alt/class/class_metadata.rs index c0a3853bcc..a4e5994e0d 100644 --- a/pyrefly/lib/alt/class/class_metadata.rs +++ b/pyrefly/lib/alt/class/class_metadata.rs @@ -129,6 +129,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { pydantic_config_dict: &PydanticConfigDict, pydantic_before_validator_fields: &[Name], django_field_info: &DjangoFieldInfo, + capture_init: Option<&[Name]>, errors: &ErrorCollector, ) -> ClassMetadata { // Get class decorators. @@ -507,6 +508,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { is_factory_boy_factory, is_metaclass, slots_info, + capture_init.map(|names| names.to_vec()), ) } diff --git a/pyrefly/lib/alt/solve.rs b/pyrefly/lib/alt/solve.rs index dd49d2f0c5..5372ef2953 100644 --- a/pyrefly/lib/alt/solve.rs +++ b/pyrefly/lib/alt/solve.rs @@ -344,6 +344,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { pydantic_config_dict, pydantic_before_validator_fields, django_field_info, + capture_init, } = binding; let metadata = match &self.get_idx(*k).0 { None => ClassMetadata::recursive(), @@ -356,6 +357,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { pydantic_config_dict, pydantic_before_validator_fields, django_field_info, + capture_init.as_deref(), errors, ), }; diff --git a/pyrefly/lib/alt/types/class_metadata.rs b/pyrefly/lib/alt/types/class_metadata.rs index 8d4b97eb7b..7c8d8636e3 100644 --- a/pyrefly/lib/alt/types/class_metadata.rs +++ b/pyrefly/lib/alt/types/class_metadata.rs @@ -74,6 +74,9 @@ pub struct ClassMetadata { /// Whether this class is a metaclass (i.e., a subclass of `type`). is_metaclass: bool, slots_info: Option, + /// `__init__` parameter names to capture for shape inference, extracted from + /// `@uses_shape_dsl(..., capture_init=[...])` on a `forward` method. + capture_init: Option>, } impl VisitMut for ClassMetadata { @@ -132,6 +135,7 @@ impl ClassMetadata { is_factory_boy_factory: bool, is_metaclass: bool, slots_info: Option, + capture_init: Option>, ) -> ClassMetadata { ClassMetadata { metaclass, @@ -158,6 +162,7 @@ impl ClassMetadata { is_factory_boy_factory, is_metaclass, slots_info, + capture_init, } } @@ -187,6 +192,7 @@ impl ClassMetadata { is_factory_boy_factory: false, is_metaclass: false, slots_info: None, + capture_init: None, } } @@ -346,6 +352,10 @@ impl ClassMetadata { pub fn django_model_metadata(&self) -> Option<&DjangoModelMetadata> { self.django_model_metadata.as_ref() } + + pub fn capture_init(&self) -> Option<&[Name]> { + self.capture_init.as_deref() + } } /// A field that we synthesize and add to a class. Note that if a non-synthesized field already diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index 005dc9ea8f..8c23943b41 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -118,7 +118,7 @@ assert_words!(BindingAnnotation, 15); assert_words!(BindingClass, 11); assert_words!(BindingTParams, 10); assert_words!(BindingClassBaseType, 3); -assert_words!(BindingClassMetadata, 11); +assert_words!(BindingClassMetadata, 13); assert_bytes!(BindingClassMro, 4); assert_bytes!(BindingAbstractClassCheck, 4); assert_bytes!(BindingClassSubscriptSymmetry, 4); @@ -3167,6 +3167,9 @@ pub struct BindingClassMetadata { pub pydantic_before_validator_fields: Box<[Name]>, /// Django-specific field information. pub django_field_info: Box, + /// `__init__` parameter names to capture for shape inference, extracted from + /// `@uses_shape_dsl(..., capture_init=[...])` on a `forward` method. + pub capture_init: Option>, } impl DisplayWith for BindingClassMetadata { diff --git a/pyrefly/lib/binding/class.rs b/pyrefly/lib/binding/class.rs index 8faa1ce0c8..fb821221fb 100644 --- a/pyrefly/lib/binding/class.rs +++ b/pyrefly/lib/binding/class.rs @@ -246,6 +246,7 @@ impl<'a> BindingsBuilder<'a> { let body = mem::take(&mut x.body); let field_docstrings = self.extract_field_docstrings(&body); let pydantic_before_validator_fields = self.extract_field_validator_fields(&body); + let capture_init = self.extract_capture_init(&body); let decorators = self.ensure_and_bind_decorators(mem::take(&mut x.decorator_list), class_object.usage()); @@ -540,6 +541,7 @@ impl<'a> BindingsBuilder<'a> { pydantic_before_validator_fields: pydantic_before_validator_fields .into_boxed_slice(), django_field_info: Box::new(django_field_info), + capture_init: capture_init.map(|v| v.into_boxed_slice()), }, ); self.insert_binding_idx( @@ -556,6 +558,37 @@ impl<'a> BindingsBuilder<'a> { ); } + /// Scan a class body for a `forward` method decorated with + /// `@uses_shape_dsl(..., capture_init=[...])` and return the list of `__init__` + /// parameter names to capture for shape inference. + fn extract_capture_init(&self, body: &[Stmt]) -> Option> { + body.iter() + .filter_map(|stmt| stmt.as_function_def_stmt()) + .filter(|func_def| func_def.name.as_str() == "forward") + .flat_map(|func_def| &func_def.decorator_list) + .find_map(|decorator| { + let call = decorator.expression.as_call_expr()?; + if self.as_special_export(&call.func) != Some(SpecialExport::UsesShapeDsl) { + return None; + } + let capture_init_kw = call.arguments.keywords.iter().find(|kw| { + kw.arg + .as_ref() + .is_some_and(|a| a.as_str() == "capture_init") + })?; + let list = capture_init_kw.value.as_list_expr()?; + let names: Vec = list + .elts + .iter() + .filter_map(|elt| { + elt.as_string_literal_expr() + .map(|s| Name::new(s.value.to_str())) + }) + .collect(); + Some(names) + }) + } + /// Extracts docstrings for each field, mapping the field's range to the docstring's range. fn extract_field_docstrings( &self, @@ -994,6 +1027,7 @@ impl<'a> BindingsBuilder<'a> { pydantic_config_dict: PydanticConfigDict::default(), pydantic_before_validator_fields: Box::default(), django_field_info: Box::default(), + capture_init: None, }, ); self.insert_binding_idx( diff --git a/pyrefly/lib/report/binding_memory.rs b/pyrefly/lib/report/binding_memory.rs index 80661881df..fac7aa3c0d 100644 --- a/pyrefly/lib/report/binding_memory.rs +++ b/pyrefly/lib/report/binding_memory.rs @@ -166,6 +166,7 @@ mod tests { pydantic_config_dict: PydanticConfigDict::default(), pydantic_before_validator_fields: Box::default(), django_field_info: Box::default(), + capture_init: None, }; assert_eq!( ReportKey::new(module, &v), From 8ca70085bdccefd042f96939ff7c1b4fd1b30e3a Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 08/25] Add `FunctionKind::ShapeDsl` variant MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Add a new `FunctionKind::ShapeDsl(Arc, Arc)` variant that will represent functions whose return types are computed by the shape DSL. The `FuncId` provides identity (module, class, name) for display and lookup; the `ShapeDslFunction` carries the parsed DSL IR. The DSL definition is carried inside the `FunctionKind` variant rather than as a separate `Option` field on `Function` to avoid touching ~30-40 construction sites with `dsl_def: None`. This is semantically equivalent — `dsl_def` is `Some` exactly when `FunctionKind` is `ShapeDsl`, so embedding it in the variant enforces the invariant by construction. Adds `PartialEq`/`Eq`/`Hash`/`Ord`/`Visit`/`VisitMut`/`TypeEq` implementations on `ShapeDslFunction` (pointer-identity semantics, no-op visiting since DSL IR contains no `Type` values). Differential Revision: D105720305 --- crates/pyrefly_types/src/callable.rs | 10 +++- crates/pyrefly_types/src/meta_shape_dsl.rs | 67 ++++++++++++++++++++++ 2 files changed, 76 insertions(+), 1 deletion(-) diff --git a/crates/pyrefly_types/src/callable.rs b/crates/pyrefly_types/src/callable.rs index 7ec216a8f8..70c300da82 100644 --- a/crates/pyrefly_types/src/callable.rs +++ b/crates/pyrefly_types/src/callable.rs @@ -36,6 +36,7 @@ use crate::display::TypeDisplayContext; use crate::equality::TypeEq; use crate::equality::TypeEqCtx; use crate::keywords::DataclassTransformMetadata; +use crate::meta_shape_dsl::ShapeDslFunction; use crate::type_output::TypeOutput; use crate::types::AnyStyle; use crate::types::Type; @@ -810,6 +811,10 @@ pub enum FunctionKind { NumbaJit, /// `numba.njit()` NumbaNjit, + /// A function whose return type is computed by a shape DSL definition. + /// The `FuncId` provides identity (module, class, name) for display and + /// lookup; the `ShapeDslFunction` carries the parsed DSL IR. + ShapeDsl(Arc, Arc), } impl Callable { @@ -1218,6 +1223,7 @@ impl FunctionKind { Self::NumbaJit => ModuleName::from_str("numba"), Self::NumbaNjit => ModuleName::from_str("numba"), Self::Def(func_id) => func_id.module.name().dupe(), + Self::ShapeDsl(id, _) => id.module.name().dupe(), } } @@ -1244,6 +1250,7 @@ impl FunctionKind { Self::NumbaJit => Cow::Owned(Name::new_static("jit")), Self::NumbaNjit => Cow::Owned(Name::new_static("njit")), Self::Def(func_id) => Cow::Borrowed(&func_id.name), + Self::ShapeDsl(id, _) => Cow::Borrowed(&id.name), } } @@ -1270,12 +1277,13 @@ impl FunctionKind { Self::TotalOrdering => None, Self::DisjointBase => None, Self::Def(func_id) => func_id.cls.clone(), + Self::ShapeDsl(id, _) => id.cls.clone(), } } pub fn outer_funcs(&self) -> Option<&Name> { match self { - Self::Def(func_id) => func_id.outer_funcs.as_ref(), + Self::Def(func_id) | Self::ShapeDsl(func_id, _) => func_id.outer_funcs.as_ref(), _ => None, } } diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index c6d00bbda7..56264e9fb7 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -19,11 +19,16 @@ //! //! The data types mirror the DSL grammar defined in `meta_shape_pythonic.md`. +use std::cmp::Ordering; use std::collections::HashMap; use std::fmt; use std::fmt::Debug; +use std::hash::Hash; +use std::hash::Hasher; use std::sync::Arc; +use pyrefly_util::visit::Visit; +use pyrefly_util::visit::VisitMut; use ruff_python_ast::BoolOp as RuffBoolOp; use ruff_python_ast::CmpOp as RuffCmpOp; use ruff_python_ast::Expr; @@ -35,6 +40,8 @@ use ruff_python_ast::UnaryOp as RuffUnaryOp; use crate::dimension::ShapeError; use crate::dimension::SizeExpr; use crate::dimension::canonicalize; +use crate::equality::TypeEq; +use crate::equality::TypeEqCtx; use crate::lit_int::LitInt; use crate::literal::Lit; use crate::tensor::TensorShape; @@ -3076,6 +3083,66 @@ pub struct ShapeDslFunction { pub(crate) inner: Arc, } +/// Pointer identity: two `ShapeDslFunction`s are equal iff they point to the +/// same `DslFnDef` allocation. +impl PartialEq for ShapeDslFunction { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl Eq for ShapeDslFunction {} + +impl Hash for ShapeDslFunction { + fn hash(&self, state: &mut H) { + (Arc::as_ptr(&self.inner) as *const () as usize).hash(state); + } +} + +impl PartialOrd for ShapeDslFunction { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ShapeDslFunction { + fn cmp(&self, other: &Self) -> Ordering { + let self_ptr = Arc::as_ptr(&self.inner) as *const () as usize; + let other_ptr = Arc::as_ptr(&other.inner) as *const () as usize; + self_ptr.cmp(&other_ptr) + } +} + +/// DSL IR contains no `Type` values, so visiting is a no-op. +impl Visit for ShapeDslFunction { + const RECURSE_CONTAINS: bool = false; + fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {} +} + +/// DSL IR contains no `Type` values, so visiting is a no-op. +impl VisitMut for ShapeDslFunction { + const RECURSE_CONTAINS: bool = false; + fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {} +} + +/// DSL IR contains no `Type` values, so visiting through `Arc` is also a no-op. +impl Visit for Arc { + const RECURSE_CONTAINS: bool = false; + fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {} +} + +/// DSL IR contains no `Type` values, so visiting through `Arc` is also a no-op. +impl VisitMut for Arc { + const RECURSE_CONTAINS: bool = false; + fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {} +} + +impl TypeEq for ShapeDslFunction { + fn type_eq(&self, other: &Self, _ctx: &mut TypeEqCtx) -> bool { + self == other + } +} + /// A bundle of DSL functions that have been validated together as a program. /// /// The functions held by a `ShapeDslProgram` are guaranteed to have passed From 0747d7733a332f0119dbb1f0c8485a488b86d3ed Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 09/25] Detect `@shape_dsl_function` and produce `FunctionKind::ShapeDsl` Summary: Wire the binder and solver so that functions decorated with `shape_dsl_function` (from `shape_extensions.dsl`) are converted to DSL IR at binding time and produce `FunctionKind::ShapeDsl` at solve time. The DSL definition is stored as `shape_dsl_def: Option>` on `BindingUndecoratedFunction` rather than as a separate `Binding` variant, since the function still needs its normal name binding and decorator processing chain. When the solver sees `shape_dsl_def` is `Some`, it constructs `FunctionKind::ShapeDsl(func_id, dsl_fn)` instead of `FunctionKind::Def`. The function's name resolves through the normal binding chain, so `from torch._shapes import reshape_ir` works automatically. The conversion must happen before `function_body()` consumes the AST body via `mem::take`. Conversion failures panic (Phase 7 will add structured diagnostics). Differential Revision: D105728837 --- pyrefly/lib/alt/function.rs | 36 ++++++++++++++++++++++++++------- pyrefly/lib/alt/solve.rs | 1 + pyrefly/lib/binding/binding.rs | 7 ++++++- pyrefly/lib/binding/function.rs | 18 +++++++++++++++++ 4 files changed, 54 insertions(+), 8 deletions(-) diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index efad7d3617..f0e7d3c848 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -15,11 +15,13 @@ use pyrefly_python::ast::Ast; use pyrefly_python::dunder; use pyrefly_python::module_path::ModuleStyle; use pyrefly_python::short_identifier::ShortIdentifier; +use pyrefly_types::callable::FuncId; use pyrefly_types::callable::Params; use pyrefly_types::callable::PlaceholderBodyKind; use pyrefly_types::class::Class; use pyrefly_types::class::ClassType; use pyrefly_types::dimension::SizeExpr; +use pyrefly_types::meta_shape_dsl::ShapeDslFunction; use pyrefly_types::quantified::Quantified; use pyrefly_types::quantified::QuantifiedOrigin; use pyrefly_types::type_var::Restriction; @@ -435,6 +437,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { legacy_tparams: &[Idx], module_style: ModuleStyle, outer_funcs: Option, + shape_dsl_def: Option>, errors: &ErrorCollector, ) -> Arc { let defining_cls = class_key.and_then(|k| self.get_idx(*k).0.dupe()); @@ -536,13 +539,32 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { tparams.extend(legacy_tparams); let tparams = self.validated_tparams(def.range, tparams, TParamsSource::Function, errors); - let kind = FunctionKind::from_name( - self.module().dupe(), - defining_cls.clone(), - &def.name.id, - Some(def_index), - outer_funcs, - ); + let kind = if let Some(dsl_fn) = shape_dsl_def { + // Shape DSL functions carry their parsed IR for call-site evaluation. + // + // TODO: the embedded `ShapeDslFunction` currently carries only this + // function's own DSL body, with no `fn_lookup` of resolved helpers. + // That's adequate for leaf DSL functions but breaks any function that + // calls another `@shape_dsl_function` (e.g. `reshape_ir` -> + // `normalize_dim`). We need to build the per-function `fn_lookup` here + // from resolved transitive callees. + let func_id = Arc::new(FuncId { + module: self.module().dupe(), + cls: defining_cls.clone(), + name: def.name.id.clone(), + def_index: Some(def_index), + outer_funcs, + }); + FunctionKind::ShapeDsl(func_id, dsl_fn) + } else { + FunctionKind::from_name( + self.module().dupe(), + defining_cls.clone(), + &def.name.id, + Some(def_index), + outer_funcs, + ) + }; let metadata = FuncMetadata { kind, flags }; Arc::new(UndecoratedFunction { diff --git a/pyrefly/lib/alt/solve.rs b/pyrefly/lib/alt/solve.rs index 5372ef2953..2c9e263602 100644 --- a/pyrefly/lib/alt/solve.rs +++ b/pyrefly/lib/alt/solve.rs @@ -5351,6 +5351,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &x.legacy_tparams, x.module_style, x.outer_funcs.clone(), + x.shape_dsl_def.clone(), errors, ) } diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index 8c23943b41..895f2289f0 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -9,6 +9,7 @@ use std::fmt; use std::fmt::Debug; use std::fmt::Display; use std::hash::Hash; +use std::sync::Arc; use dupe::Dupe; use pyrefly_derive::TypeEq; @@ -22,6 +23,7 @@ use pyrefly_python::short_identifier::ShortIdentifier; use pyrefly_python::symbol_kind::SymbolKind; use pyrefly_types::callable::PlaceholderBodyKind; use pyrefly_types::heap::TypeHeap; +use pyrefly_types::meta_shape_dsl::ShapeDslFunction; use pyrefly_types::special_form::SpecialForm; use pyrefly_types::type_alias::TypeAlias; use pyrefly_types::type_alias::TypeAliasIndex; @@ -129,7 +131,7 @@ assert_words!(BindingYield, 4); assert_words!(BindingYieldFrom, 4); assert_words!(BindingDecorator, 10); assert_bytes!(BindingDecoratedFunction, 20); -assert_words!(BindingUndecoratedFunction, 18); +assert_words!(BindingUndecoratedFunction, 19); #[derive(Clone, Dupe, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum AnyIdx { @@ -1856,6 +1858,9 @@ pub struct BindingUndecoratedFunction { /// Dot-separated path of enclosing function names (e.g. `"f1"` for `f2` defined inside `f1`, /// or `"f1.g1"` for two levels deep). `None` for top-level or class-method functions. pub outer_funcs: Option, + /// When the function is decorated with `@shape_dsl_function`, this holds the + /// parsed DSL IR so the solver can produce `FunctionKind::ShapeDsl`. + pub shape_dsl_def: Option>, } impl DisplayWith for BindingUndecoratedFunction { diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index 8d187964f7..f298bb5d54 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -6,6 +6,7 @@ */ use std::mem; +use std::sync::Arc; use dupe::Dupe as _; use pyrefly_graph::index::Idx; @@ -16,6 +17,7 @@ use pyrefly_python::nesting_context::NestingContext; use pyrefly_python::short_identifier::ShortIdentifier; use pyrefly_python::sys_info::SysInfo; use pyrefly_types::callable::PlaceholderBodyKind; +use pyrefly_types::meta_shape_dsl::convert_shape_dsl_function; use pyrefly_util::prelude::VecExt; use pyrefly_util::visit::Visit; use ruff_python_ast::Decorator; @@ -804,12 +806,27 @@ impl<'a> BindingsBuilder<'a> { _ => (None, None), }; + // Check whether this function is decorated with `@shape_dsl_function` + // before `decorators()` takes the decorator list. + let is_shape_dsl = x.decorator_list.iter().any(|d| { + self.as_special_export(&d.expression) == Some(SpecialExport::ShapeDslFunction) + }); + self.scopes.push(Scope::annotation(x.range)); let (return_ann_with_range, legacy_tparams) = self.function_header(&mut x, &func_name, class_key, def_idx.usage(), parent); let decorators = self.decorators(mem::take(&mut x.decorator_list), def_idx.usage()); + // Convert the function body to DSL IR before `function_body` takes the body. + let shape_dsl_def = if is_shape_dsl { + Some(Arc::new( + convert_shape_dsl_function(&x).expect("@shape_dsl_function body must be valid DSL"), + )) + } else { + None + }; + let docstring_range = Docstring::range_from_stmts(x.body.as_slice()); let (stub_or_impl, placeholder_body_kind, is_return_inferred, self_assignments) = self .function_body( @@ -844,6 +861,7 @@ impl<'a> BindingsBuilder<'a> { legacy_tparams: legacy_tparams.into_boxed_slice(), module_style: self.module_info.path().style(), outer_funcs, + shape_dsl_def, }, ); From 967bc224e341af5adccfefbd90fb82059f335eb8 Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 10/25] Extract `@uses_shape_dsl(ir_fn)` argument at binding time Summary: When a function is decorated with `uses_shape_dsl(reshape_ir)`, extract the first positional argument's name from the decorator call AST and store it as `uses_shape_dsl_ir_name: Option` on `BindingUndecoratedFunction`. This is needed because the existing `KwCall` mechanism in the decorator pipeline only captures keyword arguments, not positional ones. Binding-time extraction makes the IR function name available for Phase 4, where the solver will resolve it to a `Type::Function` with `FunctionKind::ShapeDsl` and wire it into `FuncFlags.shape_transform`. Differential Revision: D105728838 --- pyrefly/lib/binding/binding.rs | 6 +++++- pyrefly/lib/binding/function.rs | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index 895f2289f0..3ff076cf21 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -131,7 +131,7 @@ assert_words!(BindingYield, 4); assert_words!(BindingYieldFrom, 4); assert_words!(BindingDecorator, 10); assert_bytes!(BindingDecoratedFunction, 20); -assert_words!(BindingUndecoratedFunction, 19); +assert_words!(BindingUndecoratedFunction, 22); #[derive(Clone, Dupe, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum AnyIdx { @@ -1861,6 +1861,10 @@ pub struct BindingUndecoratedFunction { /// When the function is decorated with `@shape_dsl_function`, this holds the /// parsed DSL IR so the solver can produce `FunctionKind::ShapeDsl`. pub shape_dsl_def: Option>, + /// Name of the IR function passed as the first positional argument to + /// `@uses_shape_dsl(ir_fn)`. Extracted at binding time so the solver can + /// resolve it to a `FunctionKind::ShapeDsl` type. + pub uses_shape_dsl_ir_name: Option, } impl DisplayWith for BindingUndecoratedFunction { diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index f298bb5d54..cbcf84a605 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -812,6 +812,19 @@ impl<'a> BindingsBuilder<'a> { self.as_special_export(&d.expression) == Some(SpecialExport::ShapeDslFunction) }); + // Extract the IR function name from @uses_shape_dsl(ir_fn) if present. + let uses_shape_dsl_ir_name = x.decorator_list.iter().find_map(|d| { + let call = d.expression.as_call_expr()?; + if self.as_special_export(&call.func) != Some(SpecialExport::UsesShapeDsl) { + return None; + } + // The first positional argument is the IR function reference. + let first_arg = call.arguments.args.first()?; + // Must be a simple name (not a dotted path or arbitrary expression). + let name_expr = first_arg.as_name_expr()?; + Some(name_expr.id.clone()) + }); + self.scopes.push(Scope::annotation(x.range)); let (return_ann_with_range, legacy_tparams) = self.function_header(&mut x, &func_name, class_key, def_idx.usage(), parent); @@ -862,6 +875,7 @@ impl<'a> BindingsBuilder<'a> { module_style: self.module_info.path().style(), outer_funcs, shape_dsl_def, + uses_shape_dsl_ir_name, }, ); From a7990e6772a4c51985e6083bdbe44c4afefa4d7d Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 11/25] Add `shape_transform` field to `FuncFlags` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Add `ShapeTransformRef` type in `meta_shape_dsl.rs` and a new `shape_transform: Option>` field on `FuncFlags`. `ShapeTransformRef` carries an `Arc` — the resolved DSL function definition. By the time this field is populated (Phase 4b), the IR function name has been resolved to a `Type::Function` with `FunctionKind::ShapeDsl`, so we store the extracted definition directly rather than a name or binding key. Trait impls follow the same pointer-identity pattern as `ShapeDslFunction`: `PartialEq/Eq/Hash/PartialOrd/Ord` delegate to the inner `ShapeDslFunction`, `Visit/VisitMut` are no-ops (DSL IR contains no `Type` values), and `TypeEq` delegates to `PartialEq`. Differential Revision: D105739362 --- crates/pyrefly_types/src/callable.rs | 4 ++ crates/pyrefly_types/src/meta_shape_dsl.rs | 60 ++++++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/crates/pyrefly_types/src/callable.rs b/crates/pyrefly_types/src/callable.rs index 70c300da82..d25cee2c92 100644 --- a/crates/pyrefly_types/src/callable.rs +++ b/crates/pyrefly_types/src/callable.rs @@ -37,6 +37,7 @@ use crate::equality::TypeEq; use crate::equality::TypeEqCtx; use crate::keywords::DataclassTransformMetadata; use crate::meta_shape_dsl::ShapeDslFunction; +use crate::meta_shape_dsl::ShapeTransformRef; use crate::type_output::TypeOutput; use crate::types::AnyStyle; use crate::types::Type; @@ -646,6 +647,9 @@ pub struct FuncFlags { /// `dataclass_transform` call. See /// https://typing.python.org/en/latest/spec/dataclasses.html#specification. pub dataclass_transform_metadata: Option, + /// A function decorated with `@uses_shape_dsl`, whose return type should be + /// refined by evaluating the referenced shape-DSL function at call sites. + pub shape_transform: Option>, } impl FuncFlags { diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index 56264e9fb7..7fcba0223a 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -3143,6 +3143,66 @@ impl TypeEq for ShapeDslFunction { } } +/// Reference to a shape-DSL function that refines a callable's return type. +/// Carried on `FuncFlags` for functions decorated with `@uses_shape_dsl`. +#[derive(Debug, Clone)] +pub struct ShapeTransformRef { + pub dsl_fn: Arc, +} + +/// Pointer identity: delegates to `ShapeDslFunction`'s pointer-identity equality. +impl PartialEq for ShapeTransformRef { + fn eq(&self, other: &Self) -> bool { + self.dsl_fn == other.dsl_fn + } +} + +impl Eq for ShapeTransformRef {} + +impl Hash for ShapeTransformRef { + fn hash(&self, state: &mut H) { + self.dsl_fn.hash(state); + } +} + +impl PartialOrd for ShapeTransformRef { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ShapeTransformRef { + fn cmp(&self, other: &Self) -> Ordering { + self.dsl_fn.cmp(&other.dsl_fn) + } +} + +impl Visit for ShapeTransformRef { + const RECURSE_CONTAINS: bool = false; + fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {} +} + +impl VisitMut for ShapeTransformRef { + const RECURSE_CONTAINS: bool = false; + fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {} +} + +impl Visit for Arc { + const RECURSE_CONTAINS: bool = false; + fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {} +} + +impl VisitMut for Arc { + const RECURSE_CONTAINS: bool = false; + fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {} +} + +impl TypeEq for ShapeTransformRef { + fn type_eq(&self, other: &Self, _ctx: &mut TypeEqCtx) -> bool { + self == other + } +} + /// A bundle of DSL functions that have been validated together as a program. /// /// The functions held by a `ShapeDslProgram` are guaranteed to have passed From 6dbc7a451b5560a683eb234621afbe1a70a0a67f Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:07 -0700 Subject: [PATCH 12/25] Wire up `@uses_shape_dsl` decorator recognition and `shape_transform` population MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Add `FunctionKind::UsesShapeDsl` so that calling `uses_shape_dsl(...)` produces a `Type::KwCall`, making the decorator identifiable in `get_special_decorator`. Without this, the decorator's type would be plain `Callable`, indistinguishable from any other callable-returning decorator, and the generic pipeline would produce `Any`. The decorator is consumed via `SpecialDecorator::UsesShapeDsl` → `set_flag_from_special_decorator` (returns `true` to filter it out). The actual `FuncFlags.shape_transform` is populated after the decorator loop in `undecorated_function`, where `uses_shape_dsl_ir_name` (extracted at binding time in Phase 3d) is resolved via `Key::BoundName` to get the IR function's `ShapeDslFunction` from its `FunctionKind::ShapeDsl` variant. To enable solve-time lookup, `uses_shape_dsl_ir_name` is changed from `Option` to `Option<(Name, ShortIdentifier)>`, carrying the `TextRange` needed for `Key::BoundName` resolution (Pyrefly's binding lookup is range-based, not name-based). Differential Revision: D105739363 --- crates/pyrefly_types/src/callable.rs | 6 +++++ pyrefly/lib/alt/call.rs | 4 ++- pyrefly/lib/alt/function.rs | 28 +++++++++++++++++++++ pyrefly/lib/alt/solve.rs | 1 + pyrefly/lib/alt/types/decorated_function.rs | 1 + pyrefly/lib/binding/binding.rs | 4 +-- pyrefly/lib/binding/function.rs | 2 +- 7 files changed, 42 insertions(+), 4 deletions(-) diff --git a/crates/pyrefly_types/src/callable.rs b/crates/pyrefly_types/src/callable.rs index d25cee2c92..d951903dc2 100644 --- a/crates/pyrefly_types/src/callable.rs +++ b/crates/pyrefly_types/src/callable.rs @@ -819,6 +819,8 @@ pub enum FunctionKind { /// The `FuncId` provides identity (module, class, name) for display and /// lookup; the `ShapeDslFunction` carries the parsed DSL IR. ShapeDsl(Arc, Arc), + /// The `shape_extensions.uses_shape_dsl` decorator function itself. + UsesShapeDsl, } impl Callable { @@ -1194,6 +1196,7 @@ impl FunctionKind { ("typing" | "typing_extensions", None, "disjoint_base") => Self::DisjointBase, ("numba.core.decorators", None, "jit") => Self::NumbaJit, ("numba.core.decorators", None, "njit") => Self::NumbaNjit, + ("shape_extensions", None, "uses_shape_dsl") => Self::UsesShapeDsl, _ => Self::Def(Arc::new(FuncId { module, cls, @@ -1228,6 +1231,7 @@ impl FunctionKind { Self::NumbaNjit => ModuleName::from_str("numba"), Self::Def(func_id) => func_id.module.name().dupe(), Self::ShapeDsl(id, _) => id.module.name().dupe(), + Self::UsesShapeDsl => ModuleName::from_str("shape_extensions"), } } @@ -1255,6 +1259,7 @@ impl FunctionKind { Self::NumbaNjit => Cow::Owned(Name::new_static("njit")), Self::Def(func_id) => Cow::Borrowed(&func_id.name), Self::ShapeDsl(id, _) => Cow::Borrowed(&id.name), + Self::UsesShapeDsl => Cow::Owned(Name::new_static("uses_shape_dsl")), } } @@ -1282,6 +1287,7 @@ impl FunctionKind { Self::DisjointBase => None, Self::Def(func_id) => func_id.cls.clone(), Self::ShapeDsl(id, _) => id.cls.clone(), + Self::UsesShapeDsl => None, } } diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index 1c085cdbf1..bf8b1072dd 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -1313,7 +1313,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { if let Some(m) = metadata && (matches!( m.kind, - FunctionKind::Dataclass | FunctionKind::DataclassTransform + FunctionKind::Dataclass + | FunctionKind::DataclassTransform + | FunctionKind::UsesShapeDsl ) || m.kind.is_signature_preserving_decorator() || m.flags.dataclass_transform_metadata.is_some()) { diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index f0e7d3c848..e10774d212 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -22,6 +22,7 @@ use pyrefly_types::class::Class; use pyrefly_types::class::ClassType; use pyrefly_types::dimension::SizeExpr; use pyrefly_types::meta_shape_dsl::ShapeDslFunction; +use pyrefly_types::meta_shape_dsl::ShapeTransformRef; use pyrefly_types::quantified::Quantified; use pyrefly_types::quantified::QuantifiedOrigin; use pyrefly_types::type_var::Restriction; @@ -438,6 +439,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { module_style: ModuleStyle, outer_funcs: Option, shape_dsl_def: Option>, + uses_shape_dsl_ir_name: Option<(Name, ShortIdentifier)>, errors: &ErrorCollector, ) -> Arc { let defining_cls = class_key.and_then(|k| self.get_idx(*k).0.dupe()); @@ -565,6 +567,20 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { outer_funcs, ) }; + + // Resolve the IR function reference from @uses_shape_dsl(ir_fn) and + // populate `flags.shape_transform` with the DSL function it points to. + if let Some((_name, ir_identifier)) = uses_shape_dsl_ir_name { + let ir_type = self.get(&Key::BoundName(ir_identifier)).arc_clone_ty(); + if let Type::Function(func) = &ir_type + && let FunctionKind::ShapeDsl(_, dsl_fn) = &func.metadata.kind + { + flags.shape_transform = Some(Arc::new(ShapeTransformRef { + dsl_fn: dsl_fn.clone(), + })); + } + } + let metadata = FuncMetadata { kind, flags }; Arc::new(UndecoratedFunction { @@ -806,6 +822,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { { Some(SpecialDecorator::DataclassTransformCall(&call.keywords)) } + _ if let Type::KwCall(call) = &decorator.ty + && call.has_function_kind(FunctionKind::UsesShapeDsl) => + { + Some(SpecialDecorator::UsesShapeDsl) + } Some(CalleeKind::Class(ClassKind::EnumNonmember)) => { Some(SpecialDecorator::EnumNonmember) } @@ -894,6 +915,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { flags.is_abstract_method = true; true } + SpecialDecorator::UsesShapeDsl => { + // The actual shape_transform flag is populated after the decorator + // loop in undecorated_function, where uses_shape_dsl_ir_name is + // available. Returning true here just filters the decorator out of + // the list so it doesn't go through the generic decorator pipeline. + true + } _ => false, } } diff --git a/pyrefly/lib/alt/solve.rs b/pyrefly/lib/alt/solve.rs index 2c9e263602..5401e4c59a 100644 --- a/pyrefly/lib/alt/solve.rs +++ b/pyrefly/lib/alt/solve.rs @@ -5352,6 +5352,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { x.module_style, x.outer_funcs.clone(), x.shape_dsl_def.clone(), + x.uses_shape_dsl_ir_name.clone(), errors, ) } diff --git a/pyrefly/lib/alt/types/decorated_function.rs b/pyrefly/lib/alt/types/decorated_function.rs index 3d97bb0c34..45c6fffda9 100644 --- a/pyrefly/lib/alt/types/decorated_function.rs +++ b/pyrefly/lib/alt/types/decorated_function.rs @@ -98,6 +98,7 @@ pub enum SpecialDecorator<'a> { DataclassTransformCall(&'a TypeMap), EnumNonmember, AbstractMethod, + UsesShapeDsl, } impl UndecoratedFunction { diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index 3ff076cf21..d642f282b4 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -131,7 +131,7 @@ assert_words!(BindingYield, 4); assert_words!(BindingYieldFrom, 4); assert_words!(BindingDecorator, 10); assert_bytes!(BindingDecoratedFunction, 20); -assert_words!(BindingUndecoratedFunction, 22); +assert_words!(BindingUndecoratedFunction, 23); #[derive(Clone, Dupe, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum AnyIdx { @@ -1864,7 +1864,7 @@ pub struct BindingUndecoratedFunction { /// Name of the IR function passed as the first positional argument to /// `@uses_shape_dsl(ir_fn)`. Extracted at binding time so the solver can /// resolve it to a `FunctionKind::ShapeDsl` type. - pub uses_shape_dsl_ir_name: Option, + pub uses_shape_dsl_ir_name: Option<(Name, ShortIdentifier)>, } impl DisplayWith for BindingUndecoratedFunction { diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index cbcf84a605..c1c449d91f 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -822,7 +822,7 @@ impl<'a> BindingsBuilder<'a> { let first_arg = call.arguments.args.first()?; // Must be a simple name (not a dotted path or arbitrary expression). let name_expr = first_arg.as_name_expr()?; - Some(name_expr.id.clone()) + Some((name_expr.id.clone(), ShortIdentifier::expr_name(name_expr))) }); self.scopes.push(Scope::annotation(x.range)); From 67627db4c70c8b8a8d031888f1a30ffdaba3db8b Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 13/25] Add overload regression tests for `@uses_shape_dsl` Summary: Verify that `uses_shape_dsl` decorator recognition works correctly: - Plain function: decorator is consumed and function type is preserved - Overloaded with implementation: `shape_transform` flows through `merge_overload_metadata_with_implementation` via `FuncFlags` - Stub-only overloads (no implementation): `shape_transform` flows through `merge_overload_metadata_no_implementation` from the first overload These tests catch regressions if `merge_overload_metadata_*` ever changes in a way that drops `FuncFlags` fields. Differential Revision: D105739364 --- .cargo/config.toml | 1 + pyrefly/lib/test/mod.rs | 1 + pyrefly/lib/test/shape_dsl.rs | 86 +++++++++++++++++++++++++++++++++++ 3 files changed, 88 insertions(+) create mode 100644 pyrefly/lib/test/shape_dsl.rs diff --git a/.cargo/config.toml b/.cargo/config.toml index f3e89cd01e..c3a81b9d69 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -10,3 +10,4 @@ MARSHMALLOW_TEST_PATH = { value = "pyrefly/lib/test/marshmallow/third-party", re GLEAN_SNAPSHOTS_PATH = { value = "pyrefly/lib/report/glean/snapshots", relative = true } REPORT_TEST_PATH = { value = "pyrefly/lib/test/report/test_files", relative = true } STUBGEN_TEST_PATH = { value = "pyrefly/lib/test/stubgen", relative = true } +SHAPE_DSL_TEST_PATH = { value = "test/tensor_shapes/fixtures", relative = true } diff --git a/pyrefly/lib/test/mod.rs b/pyrefly/lib/test/mod.rs index 032d2bcbbe..7f9cc219bc 100644 --- a/pyrefly/lib/test/mod.rs +++ b/pyrefly/lib/test/mod.rs @@ -67,6 +67,7 @@ mod redundant_cast; mod returns; mod scope; mod semantic_syntax_errors; +mod shape_dsl; mod simple; mod slots; mod state; diff --git a/pyrefly/lib/test/shape_dsl.rs b/pyrefly/lib/test/shape_dsl.rs new file mode 100644 index 0000000000..d42c10d0ec --- /dev/null +++ b/pyrefly/lib/test/shape_dsl.rs @@ -0,0 +1,86 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +use crate::test::util::TestEnv; +use crate::testcase; + +fn shape_dsl_env() -> TestEnv { + let path = std::env::var("SHAPE_DSL_TEST_PATH").expect("SHAPE_DSL_TEST_PATH must be set"); + let mut env = TestEnv::new_with_site_package_paths(&[&path]); + env.add_with_path( + "my_shapes", + "my_shapes.pyi", + r#" +from shape_extensions.dsl import shape_dsl_function + +@shape_dsl_function +def identity_ir(x: int) -> int: + return x +"#, + ); + env.add_with_path( + "my_lib", + "my_lib.pyi", + r#" +from typing import overload +from shape_extensions import uses_shape_dsl +from my_shapes import identity_ir + +@uses_shape_dsl(identity_ir) +def plain_fn(x: int) -> int: ... + +@overload +def overloaded_with_impl(x: int) -> int: ... +@overload +def overloaded_with_impl(x: str) -> str: ... +@uses_shape_dsl(identity_ir) +def overloaded_with_impl(x: int | str) -> int | str: ... + +@uses_shape_dsl(identity_ir) +@overload +def overloaded_no_impl(x: int) -> int: ... +@overload +def overloaded_no_impl(x: str) -> str: ... +"#, + ); + env +} + +testcase!( + test_uses_shape_dsl_preserves_type, + shape_dsl_env(), + r#" +from typing import assert_type +from my_lib import plain_fn + +assert_type(plain_fn(1), int) +"#, +); + +testcase!( + test_uses_shape_dsl_overload_with_implementation, + shape_dsl_env(), + r#" +from typing import assert_type +from my_lib import overloaded_with_impl + +assert_type(overloaded_with_impl(1), int) +assert_type(overloaded_with_impl("a"), str) +"#, +); + +testcase!( + test_uses_shape_dsl_overload_no_implementation, + shape_dsl_env(), + r#" +from typing import assert_type +from my_lib import overloaded_no_impl + +assert_type(overloaded_no_impl(1), int) +assert_type(overloaded_no_impl("a"), str) +"#, +); From 1fe03d3b15c23460b243da92c36ce87368fe46ef Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 14/25] Document why `val_to_type` Int/Bool branches use `Literal[n]` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: The `DslType::Int` and `DslType::Bool` branches in `val_to_type` synthesize `Literal[n]` / `Literal[bool]` from the DSL's traced runtime value. This looks inconsistent with the other branches (Tensor, List, Tuple, None, Str) which return `expected_return_type.clone()`, but the difference is intentional and load-bearing. Functions like `dim_ir`, `numel_ir`, and `size_ir(dim=N)` trace exact integer results. Downstream consumers (assert_type, reshape validation, shape inference) rely on this literal precision. The fixture return type for these functions is just `int` — the literal value comes solely from DSL evaluation. In contrast, the Tensor/List branches' `expected_return_type` already carries refined structure (e.g. `Tensor[B, C, H, W]` with shape injected), so cloning it is correct there. This commit adds comments explaining the invariant so future readers don't mistake the asymmetry for a bug. Differential Revision: D105758573 --- crates/pyrefly_types/src/meta_shape_dsl.rs | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index 7fcba0223a..9040aa125d 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -2904,6 +2904,21 @@ fn val_to_type( ), }, + // Int and Bool synthesize Literal[n] / Literal[bool] from the DSL's + // traced runtime value, just like SymInt does via `val_to_scalar_type`. + // This is intentionally load-bearing: functions like `dim_ir`, + // `numel_ir`, and `size_ir(dim=N)` trace exact integer results, and + // downstream consumers (assert_type, reshape validation, shape + // inference) rely on the literal precision. Returning + // `expected_return_type` here would discard the traced value and + // produce `int` instead of e.g. `Literal[3]`. + // + // This differs from the Tensor/List/Tuple/None/Str branches, which + // return `expected_return_type.clone()`. Those branches are correct + // because their `expected_return_type` already carries the refined + // structure (e.g. `Tensor[B, C, H, W]` with shape injected). For + // scalars, the fixture return type is just `int` — the literal value + // comes solely from DSL evaluation. DslType::Int => match val { Val::Int(n) => Lit::Int(LitInt::new(n)).to_implicit_type(), _ => panic!( @@ -2912,6 +2927,10 @@ fn val_to_type( ), }, + // SymInt synthesizes a type from the traced `Val`: `val_to_scalar_type` + // returns `Type::Dim` for `Val::Dim` and `Literal[n]` for `Val::Int`. + // The trace value is load-bearing for shape inference — downstream + // tensor shape types are built from these dimension representations. DslType::SymInt => val_to_scalar_type(&val), DslType::Bool => match val { From c12ccd5b33fc2fa3c4cd9d74e33756eb9e1512f1 Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 15/25] Wire up solver consumption with legacy fallback MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Add `ShapeTransformRef::to_meta_shape_function()` which builds a `DslMetaShapeFunction` from the decorator-carried DSL definition, and wire it into `callable_infer_inner` so the decorator-based `shape_transform` is preferred over the legacy registry lookup. The registry serves as fallback for functions not yet migrated to `uses_shape_dsl`. The decorator path is intentionally not gated by the `tensor_shapes` flag since `uses_shape_dsl` is itself the opt-in. Thread `shape_transform: Option<&ShapeTransformRef>` through the call inference chain (`call_infer_with_callee_range` → `call_infer_inner` → `callable_infer` → `callable_infer_inner`) and the overload path (`call_overloads` → `find_closest_overload` → `call_overload`). Fix ordering in the binder: `convert_shape_dsl_function` must run before `function_header`, which consumes `x.returns` via `mem::take`. Without this, the DSL converter sees no return type annotation and produces `DslFnDef.return_type = None`. Update `maybe_wrap_nn_module` to check `ClassMetadata.capture_init()` first, falling back to `TensorOpsRegistry::get_init_capture` for classes not yet migrated. Differential Revision: D105758771 --- crates/pyrefly_types/src/meta_shape_dsl.rs | 11 ++++++ pyrefly/lib/alt/call.rs | 41 ++++++++++++++++------ pyrefly/lib/alt/callable.rs | 20 ++++++++--- pyrefly/lib/alt/class/dataclass.rs | 1 + pyrefly/lib/alt/overload.rs | 10 ++++++ pyrefly/lib/binding/function.rs | 15 ++++---- pyrefly/lib/test/shape_dsl.rs | 15 ++++---- 7 files changed, 86 insertions(+), 27 deletions(-) diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index 9040aa125d..d5c545644b 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -3222,6 +3222,17 @@ impl TypeEq for ShapeTransformRef { } } +impl ShapeTransformRef { + /// Build a `MetaShapeFunction` evaluator from this shape transform reference. + /// Uses an empty `fn_lookup` — cross-function DSL calls are not yet supported. + pub fn to_meta_shape_function(&self) -> Box { + Box::new(DslMetaShapeFunction { + fn_def: self.dsl_fn.inner.clone(), + fn_lookup: Arc::new(HashMap::new()), + }) + } +} + /// A bundle of DSL functions that have been validated together as a program. /// /// The functions held by a `ShapeDslProgram` are guaranteed to have passed diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index bf8b1072dd..69280dda59 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -10,6 +10,7 @@ use std::sync::Arc; use dupe::Dupe; use pyrefly_python::dunder; +use pyrefly_types::meta_shape_dsl::ShapeTransformRef; use pyrefly_types::quantified::Quantified; use pyrefly_types::special_form::SpecialForm; use pyrefly_types::tensor_ops_registry::TensorOpsRegistry; @@ -1120,14 +1121,27 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { errors: &ErrorCollector, result: Type, ) -> Type { - use std::sync::OnceLock; - static TENSOR_OPS_REGISTRY: OnceLock = OnceLock::new(); + // Check ClassMetadata.capture_init first (populated from @uses_shape_dsl + // decorator on the forward method). Fall back to TensorOpsRegistry for + // classes not yet migrated to stub-based declarations. + let class_metadata = self.get_metadata_for_class(ct.class_object()); + let capture_names_from_metadata: Vec; + let capture_names: &[Name] = if let Some(names) = class_metadata.capture_init() { + capture_names_from_metadata = names.to_vec(); + &capture_names_from_metadata + } else { + use std::sync::OnceLock; + static TENSOR_OPS_REGISTRY: OnceLock = OnceLock::new(); - let class_name = format!("{}.{}", ct.class_object().module_name(), ct.name()); - let registry = TENSOR_OPS_REGISTRY.get_or_init(TensorOpsRegistry::new); - let capture_names = match registry.get_init_capture(&class_name) { - Some(names) => names, - None => return result, + let class_name = format!("{}.{}", ct.class_object().module_name(), ct.name()); + let registry = TENSOR_OPS_REGISTRY.get_or_init(TensorOpsRegistry::new); + match registry.get_init_capture(&class_name) { + Some(names) => { + capture_names_from_metadata = names.iter().map(Name::new).collect(); + &capture_names_from_metadata + } + None => return result, + } }; let infer_type_or_expr = |toe: TypeOrExpr, errors: &ErrorCollector| -> Type { @@ -1139,17 +1153,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let mut fields = SmallMap::new(); for (i, param_name) in capture_names.iter().enumerate() { - let name = Name::new(param_name); // First check keyword args. if let Some(kw) = keywords.iter().find(|k| { k.arg .is_some_and(|id| id.id.as_str() == param_name.as_str()) }) { - fields.insert(name, infer_type_or_expr(kw.value, errors)); + fields.insert(param_name.clone(), infer_type_or_expr(kw.value, errors)); } else if i < args.len() { // Map positional arg by index to the capture param name. if let CallArg::Arg(toe) = &args[i] { - fields.insert(name, infer_type_or_expr(*toe, errors)); + fields.insert(param_name.clone(), infer_type_or_expr(*toe, errors)); } } // If neither keyword nor positional, the param uses its default. @@ -1422,6 +1435,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) => self.call_infer_inner( signature, Some(&metadata.kind), + metadata.flags.shape_transform.as_deref(), tparams.as_deref(), Some(obj), args, @@ -1436,6 +1450,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { CallTarget::Callable(TargetWithTParams(tparams, callable)) => self.call_infer_inner( callable, None, + None, tparams.as_deref(), None, args, @@ -1456,6 +1471,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { )) => self.call_infer_inner( callable, Some(&metadata.kind), + metadata.flags.shape_transform.as_deref(), tparams.as_deref(), None, args, @@ -1471,6 +1487,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.call_overloads( overloads, &metadata, + metadata.flags.shape_transform.as_deref(), None, args, keywords, @@ -1486,6 +1503,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.call_overloads( overloads, &meta, + meta.flags.shape_transform.as_deref(), Some(obj), args, keywords, @@ -1553,6 +1571,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, callable: Callable, callable_name: Option<&FunctionKind>, + shape_transform: Option<&ShapeTransformRef>, tparams: Option<&TParams>, self_obj: Option, args: &[CallArg], @@ -1570,6 +1589,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let res_no_hint = self.callable_infer( callable.clone(), callable_name, + shape_transform, tparams, self_obj.clone(), args, @@ -1589,6 +1609,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let res_with_hint = self.callable_infer( callable, callable_name, + shape_transform, tparams, self_obj, args, diff --git a/pyrefly/lib/alt/callable.rs b/pyrefly/lib/alt/callable.rs index 551cf0e905..0cd81b1478 100644 --- a/pyrefly/lib/alt/callable.rs +++ b/pyrefly/lib/alt/callable.rs @@ -12,6 +12,7 @@ use itertools::Itertools; use pyrefly_python::dunder; use pyrefly_types::callable::FunctionKind; use pyrefly_types::meta_shape_dsl::MetaShapeFunction; +use pyrefly_types::meta_shape_dsl::ShapeTransformRef; use pyrefly_types::tensor_ops_registry::TensorOpsRegistry; use pyrefly_types::tuple::Tuple; use pyrefly_types::typed_dict::ExtraItems; @@ -1326,6 +1327,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, callable: Callable, callable_name: Option<&FunctionKind>, + shape_transform: Option<&ShapeTransformRef>, tparams: Option<&TParams>, self_obj: Option, args: &[CallArg], @@ -1348,6 +1350,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.callable_infer_inner( callable.clone(), callable_name, + shape_transform, tparams, self_obj.clone(), args, @@ -1368,6 +1371,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, callable: Callable, callable_name: Option<&FunctionKind>, + shape_transform: Option<&ShapeTransformRef>, tparams: Option<&TParams>, mut self_obj: Option, mut args: &[CallArg], @@ -1387,14 +1391,22 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { call_context.set_argument_side(ArgumentSide::Got); call_context.require_boundary_consumption(); - // Look up meta-shape early so we can conditionally collect bound args. - // Only consult the registry when tensor_shapes is enabled to avoid - // unnecessary DSL parsing and per-call HashMap lookups. - let meta_shape_func = if self.solver().tensor_shapes { + // Prefer decorator-based shape_transform over the legacy registry. + // Fall back to lookup_meta_shape only when shape_transform is None. + // The decorator path is not gated by tensor_shapes — @uses_shape_dsl + // is itself the opt-in. The registry applies implicitly by qualified + // name and needs the feature gate to avoid unnecessary DSL parsing + // and per-call HashMap lookups. + let shape_transform_func = shape_transform.map(|t| t.to_meta_shape_function()); + let registry_func = if shape_transform_func.is_none() && self.solver().tensor_shapes { Self::lookup_meta_shape(callable_name) } else { None }; + let meta_shape_func: Option<&dyn MetaShapeFunction> = shape_transform_func + .as_ref() + .map(|b| &**b) + .or(registry_func); let mut bound_args: Option> = meta_shape_func.map(|_| HashMap::new()); let (callable_qs, mut callable) = if let Some(tparams) = tparams { diff --git a/pyrefly/lib/alt/class/dataclass.rs b/pyrefly/lib/alt/class/dataclass.rs index 85b4969cba..a1a8b8c03b 100644 --- a/pyrefly/lib/alt/class/dataclass.rs +++ b/pyrefly/lib/alt/class/dataclass.rs @@ -678,6 +678,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { })) .unwrap(), &overload.metadata, + None, // no shape_transform for dataclass constructors None, &args.args.map(CallArg::expr_maybe_starred), &args.keywords.map(CallKeyword::new), diff --git a/pyrefly/lib/alt/overload.rs b/pyrefly/lib/alt/overload.rs index 28311de3c7..aba14988e4 100644 --- a/pyrefly/lib/alt/overload.rs +++ b/pyrefly/lib/alt/overload.rs @@ -13,6 +13,7 @@ use itertools::Itertools; use pyrefly_types::callable::ArgCount; use pyrefly_types::callable::ArgCounts; use pyrefly_types::callable::Param; +use pyrefly_types::meta_shape_dsl::ShapeTransformRef; use pyrefly_types::tuple::Tuple; use pyrefly_types::types::TArgs; use pyrefly_util::gas::Gas; @@ -231,6 +232,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, overloads: Vec1>, metadata: &FuncMetadata, + shape_transform: Option<&ShapeTransformRef>, self_obj: Option, args: &[CallArg], keywords: &[CallKeyword], @@ -286,6 +288,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let (mut closest_overload, mut matched) = self.find_closest_overload( &arity_compatible_overloads, metadata, + shape_transform, self_obj.as_ref(), &args, &keywords, @@ -310,6 +313,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let (cur_closest, cur_matched) = self.find_closest_overload( &arity_compatible_overloads, metadata, + shape_transform, self_obj.as_ref(), cur_args, cur_keywords, @@ -516,6 +520,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, overloads: &Vec1<&'c TargetWithTParams>, metadata: &FuncMetadata, + shape_transform: Option<&ShapeTransformRef>, self_obj: Option<&Type>, args: &[CallArg], keywords: &[CallKeyword], @@ -535,6 +540,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let called_overload = self.call_overload( callable, metadata, + shape_transform, self_obj, args, keywords, @@ -656,6 +662,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let res = self.call_overload( o.func, metadata, + shape_transform, self_obj, &materialized_args, &materialized_keywords, @@ -687,6 +694,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let contextual_overload = self.call_overload( overload.func, metadata, + shape_transform, self_obj, args, keywords, @@ -811,6 +819,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, callable: &'c TargetWithTParams, metadata: &FuncMetadata, + shape_transform: Option<&ShapeTransformRef>, self_obj: Option<&Type>, args: &[CallArg], keywords: &[CallKeyword], @@ -830,6 +839,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let (res, specialization_errors, expected_types) = self.callable_infer( callable.1.signature.clone(), Some(&metadata.kind), + shape_transform, tparams, self_obj.cloned(), args, diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index c1c449d91f..76a48ccc96 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -825,13 +825,8 @@ impl<'a> BindingsBuilder<'a> { Some((name_expr.id.clone(), ShortIdentifier::expr_name(name_expr))) }); - self.scopes.push(Scope::annotation(x.range)); - let (return_ann_with_range, legacy_tparams) = - self.function_header(&mut x, &func_name, class_key, def_idx.usage(), parent); - - let decorators = self.decorators(mem::take(&mut x.decorator_list), def_idx.usage()); - - // Convert the function body to DSL IR before `function_body` takes the body. + // Convert the function to DSL IR before `function_header` takes `returns` + // and before `function_body` takes `body`. let shape_dsl_def = if is_shape_dsl { Some(Arc::new( convert_shape_dsl_function(&x).expect("@shape_dsl_function body must be valid DSL"), @@ -840,6 +835,12 @@ impl<'a> BindingsBuilder<'a> { None }; + self.scopes.push(Scope::annotation(x.range)); + let (return_ann_with_range, legacy_tparams) = + self.function_header(&mut x, &func_name, class_key, def_idx.usage(), parent); + + let decorators = self.decorators(mem::take(&mut x.decorator_list), def_idx.usage()); + let docstring_range = Docstring::range_from_stmts(x.body.as_slice()); let (stub_or_impl, placeholder_body_kind, is_return_inferred, self_assignments) = self .function_body( diff --git a/pyrefly/lib/test/shape_dsl.rs b/pyrefly/lib/test/shape_dsl.rs index d42c10d0ec..aa39e0f7a8 100644 --- a/pyrefly/lib/test/shape_dsl.rs +++ b/pyrefly/lib/test/shape_dsl.rs @@ -54,10 +54,13 @@ testcase!( test_uses_shape_dsl_preserves_type, shape_dsl_env(), r#" -from typing import assert_type +from typing import Literal, assert_type from my_lib import plain_fn -assert_type(plain_fn(1), int) +# identity_ir returns its input unchanged. Because val_to_type synthesizes +# Literal[n] from the DSL's traced integer value (not the declared return +# type), the result is Literal[1], not int. +assert_type(plain_fn(1), Literal[1]) "#, ); @@ -65,10 +68,10 @@ testcase!( test_uses_shape_dsl_overload_with_implementation, shape_dsl_env(), r#" -from typing import assert_type +from typing import Literal, assert_type from my_lib import overloaded_with_impl -assert_type(overloaded_with_impl(1), int) +assert_type(overloaded_with_impl(1), Literal[1]) assert_type(overloaded_with_impl("a"), str) "#, ); @@ -77,10 +80,10 @@ testcase!( test_uses_shape_dsl_overload_no_implementation, shape_dsl_env(), r#" -from typing import assert_type +from typing import Literal, assert_type from my_lib import overloaded_no_impl -assert_type(overloaded_no_impl(1), int) +assert_type(overloaded_no_impl(1), Literal[1]) assert_type(overloaded_no_impl("a"), str) "#, ); From bad31c93cb21365aa1d5da74bc32adc7376f40e1 Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 16/25] Create `torch/_shapes.pyi` with all DSL functions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Move all 86 shape DSL functions (14 helpers + 72 IR functions) from `DSL_SOURCE` in `tensor_ops_registry.rs` to `test/tensor_shapes/fixtures/torch/_shapes.pyi`, adding `shape_dsl_function` decorators. The function bodies are verbatim copies. No behavioral change — nothing references this file yet. The decorators will be consumed in commit 5c when `uses_shape_dsl` decorators are added to the torch fixture stubs. Differential Revision: D105775130 --- test/tensor_shapes/fixtures/torch/_shapes.pyi | 805 ++++++++++++++++++ 1 file changed, 805 insertions(+) create mode 100644 test/tensor_shapes/fixtures/torch/_shapes.pyi diff --git a/test/tensor_shapes/fixtures/torch/_shapes.pyi b/test/tensor_shapes/fixtures/torch/_shapes.pyi new file mode 100644 index 0000000000..56d8b37f13 --- /dev/null +++ b/test/tensor_shapes/fixtures/torch/_shapes.pyi @@ -0,0 +1,805 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from shape_extensions.dsl import shape_dsl_function + +@shape_dsl_function +def normalize_dim(rank: int, dim: int) -> int: + if dim < 0: + return dim + rank + return dim + +@shape_dsl_function +def int_max(a: int, b: int) -> int: + if a > b: + return a + return b + +@shape_dsl_function +def replace_dim( + dims: list[int | symint], i: int, value: int | symint +) -> list[int | symint]: + return dims[:i] + [value] + dims[i + 1 :] + +@shape_dsl_function +def remove_dim(dims: list[int | symint], i: int) -> list[int | symint]: + return dims[:i] + dims[i + 1 :] + +@shape_dsl_function +def insert_dim( + dims: list[int | symint], i: int, value: int | symint +) -> list[int | symint]: + return dims[:i] + [value] + dims[i:] + +@shape_dsl_function +def broadcast(a: list[int | symint], b: list[int | symint]) -> list[int | symint]: + max_len = int_max(len(a), len(b)) + padded_a = [1 for _ in range(max_len - len(a))] + a + padded_b = [1 for _ in range(max_len - len(b))] + b + return [bd if ad == 1 else ad for ad, bd in zip(padded_a, padded_b)] + +@shape_dsl_function +def broadcast_int( + expr: int | symint | list[int | symint], n: int +) -> list[int | symint]: + if isinstance(expr, list): + return expr + return [expr for _ in range(n)] + +@shape_dsl_function +def reduce_shape( + dims: list[int | symint], dim: int | list[int] | None, keepdim: bool +) -> list[int | symint]: + if dim == None: + if keepdim: + return [1 for _ in range(len(dims))] + return [] + dim_list = dim if isinstance(dim, list) else [dim] + norm = [normalize_dim(len(dims), d) for d in dim_list] + return [ + 1 if i in norm else elem + for i, elem in enumerate(dims) + if not (i in norm) or keepdim + ] + +@shape_dsl_function +def contains(lst: list[int], val: int) -> bool: + return len([x for x in lst if x == val]) > 0 + +@shape_dsl_function +def scatter(size: int, indices: list[int], values: list[int], fill: int) -> list[int]: + matches = [[k for k in range(len(indices)) if indices[k] == i] for i in range(size)] + return [values[m[0]] if len(m) > 0 else fill for m in matches] + +@shape_dsl_function +def move_dims( + dims: list[int | symint], source: int | list[int], dest: int | list[int], rank: int +) -> list[int | symint]: + src = broadcast_int(source, 1) + dst = broadcast_int(dest, 1) + src_norm = [normalize_dim(rank, s) for s in src] + dst_norm = [normalize_dim(rank, d) for d in dst] + non_dst = [i for i in range(rank) if not contains(dst_norm, i)] + remaining = [i for i in range(rank) if not contains(src_norm, i)] + perm = scatter(rank, dst_norm + non_dst, src_norm + remaining, 0) + return [dims[p] for p in perm] + +@shape_dsl_function +def conv_spatial_out( + input_dim: int | symint, + kernel: int | symint, + stride: int | symint, + padding: int | symint, + dilation: int | symint, +) -> int | symint: + return (input_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1 + +@shape_dsl_function +def reshape_ir(self: Tensor, shape: list[int | symint]) -> Tensor: + minus_one_count = len([d for d in shape if d == -1]) + if minus_one_count > 1: + raise Error("can only specify one unknown dimension as -1") + has_bad_neg = len([d for d in shape if isinstance(d, int) and d < -1]) > 0 + if has_bad_neg: + raise Error("invalid negative dimension value (only -1 is allowed)") + has_zero = len([d for d in shape if isinstance(d, int) and d == 0]) > 0 + if has_zero: + raise Error("reshape dimensions cannot contain 0") + if minus_one_count > 0: + known = shape_extensions.dsl.prod([d for d in shape if d != -1]) + total = shape_extensions.dsl.prod(self.shape) + if isinstance(total, int) and isinstance(known, int) and total % known != 0: + raise Error( + "could not infer size for dimension -1: expected " + + str(total) + + " to be divisible by " + + str(known) + ) + return Tensor(shape=[total // known if d == -1 else d for d in shape]) + return Tensor(shape=shape) + +@shape_dsl_function +def squeeze_ir(self: Tensor, dim: int | None = None) -> Tensor: + if dim == None: + return Tensor(shape=[d for d in self.shape if d != 1]) + idx = normalize_dim(len(self.shape), dim) + return Tensor( + shape=[d for i, d in enumerate(self.shape) if not (i == idx and d == 1)] + ) + +@shape_dsl_function +def unsqueeze_ir(self: Tensor, dim: int) -> Tensor: + d = normalize_dim(len(self.shape) + 1, dim) + return Tensor(shape=insert_dim(self.shape, d, 1)) + +@shape_dsl_function +def transpose_ir(self: Tensor, dim0: int, dim1: int) -> Tensor: + rank = len(self.shape) + d0 = normalize_dim(rank, dim0) + d1 = normalize_dim(rank, dim1) + return Tensor( + shape=[ + self.shape[d1] if i == d0 else self.shape[d0] if i == d1 else d + for i, d in enumerate(self.shape) + ] + ) + +@shape_dsl_function +def permute_ir(self: Tensor, dims: list[int]) -> Tensor: + rank = len(self.shape) + if len(dims) != rank: + raise Error("permute: expected " + str(rank) + " dims, got " + str(len(dims))) + return Tensor(shape=[self.shape[normalize_dim(rank, d)] for d in dims]) + +@shape_dsl_function +def flatten_ir(self: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor: + rank = len(self.shape) + s = normalize_dim(rank, start_dim) + e = normalize_dim(rank, end_dim) + return Tensor( + shape=self.shape[:s] + + [shape_extensions.dsl.prod(self.shape[s : e + 1])] + + self.shape[e + 1 :] + ) + +@shape_dsl_function +def expand_ir(self: Tensor, sizes: list[int | symint]) -> Tensor: + return Tensor(shape=[d if t == -1 else t for d, t in zip(self.shape, sizes)]) + +@shape_dsl_function +def repeat_ir(self: Tensor, sizes: list[int | symint]) -> Tensor: + return Tensor(shape=[d * r for d, r in zip(self.shape, sizes)]) + +@shape_dsl_function +def unbind_ir(self: Tensor, dim: int = 0) -> list[Tensor]: + d = normalize_dim(len(self.shape), dim) + return [Tensor(shape=remove_dim(self.shape, d)), ...] + +@shape_dsl_function +def movedim_ir( + self: Tensor, source: int | list[int], destination: int | list[int] +) -> Tensor: + return Tensor(shape=move_dims(self.shape, source, destination, len(self.shape))) + +@shape_dsl_function +def unfold_ir( + self: Tensor, dimension: int, size: int | symint, step: int = 1 +) -> Tensor: + d = normalize_dim(len(self.shape), dimension) + new_dim = (self.shape[d] - size) // step + 1 + return Tensor(shape=replace_dim(self.shape, d, new_dim) + [size]) + +@shape_dsl_function +def cat_ir(tensors: list[Tensor], dim: int = 0) -> Tensor: + first = tensors[0] + d = normalize_dim(len(first.shape), dim) + return Tensor( + shape=[ + shape_extensions.dsl.sum([t.shape[i] for t in tensors]) + if i == d + else dim_val + for i, dim_val in enumerate(first.shape) + ] + ) + +@shape_dsl_function +def stack_ir(tensors: list[Tensor], dim: int = 0) -> Tensor: + first = tensors[0] + d = normalize_dim(len(first.shape) + 1, dim) + return Tensor(shape=insert_dim(first.shape, d, len(tensors))) + +@shape_dsl_function +def broadcast_to_ir(self: Tensor, shape: list[int | symint]) -> Tensor: + return Tensor(shape=shape) + +@shape_dsl_function +def tile_ir(self: Tensor, dims: list[int]) -> Tensor: + rank = len(self.shape) + if len(dims) > rank: + extra = len(dims) - rank + return Tensor( + shape=[r for r in dims[:extra]] + + [d * r for d, r in zip(self.shape, dims[extra:])] + ) + return Tensor(shape=[d * r for d, r in zip(self.shape, dims)]) + +@shape_dsl_function +def select_ir(self: Tensor, dim: int) -> Tensor: + d = normalize_dim(len(self.shape), dim) + return Tensor(shape=remove_dim(self.shape, d)) + +@shape_dsl_function +def narrow_ir(self: Tensor, dim: int, length: int | symint) -> Tensor: + return Tensor( + shape=replace_dim(self.shape, normalize_dim(len(self.shape), dim), length) + ) + +@shape_dsl_function +def split_ir( + self: Tensor, + split_size_or_sections: int | symint | list[int | symint] | None = None, + dim: int = 0, +) -> list[Tensor]: + d = normalize_dim(len(self.shape), dim) + if isinstance(split_size_or_sections, list): + return [ + Tensor(shape=replace_dim(self.shape, d, section)) + for section in split_size_or_sections + ] + if isinstance(split_size_or_sections, int): + dim_val = self.shape[d] + if isinstance(dim_val, int): + count = (dim_val + split_size_or_sections - 1) // split_size_or_sections + return [ + Tensor( + shape=replace_dim( + self.shape, + d, + split_size_or_sections + if i < count - 1 + else dim_val - (count - 1) * split_size_or_sections, + ) + ) + for i in range(count) + ] + return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)), ...] + if split_size_or_sections != None: + quotient = self.shape[d] // split_size_or_sections + if isinstance(quotient, int): + return [ + Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)) + for _ in range(quotient) + ] + return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)), ...] + return Unknown + +@shape_dsl_function +def chunk_ir(self: Tensor, chunks: int, dim: int = 0) -> list[Tensor]: + d = normalize_dim(len(self.shape), dim) + dim_val = self.shape[d] + if isinstance(dim_val, int): + chunk_size = (dim_val + chunks - 1) // chunks + return [ + Tensor( + shape=replace_dim( + self.shape, + d, + chunk_size + if i < chunks - 1 + else dim_val - (chunks - 1) * chunk_size, + ) + ) + for i in range(chunks) + ] + return [ + Tensor(shape=replace_dim(self.shape, d, dim_val // chunks)) + for i in range(chunks) + ] + +@shape_dsl_function +def index_select_ir(self: Tensor, dim: int, index: Tensor) -> Tensor: + return Tensor( + shape=replace_dim( + self.shape, normalize_dim(len(self.shape), dim), index.shape[0] + ) + ) + +@shape_dsl_function +def reduce_ir( + self: Tensor, dim: int | list[int] | None = None, keepdim: bool = False +) -> Tensor: + if dim == None: + return Tensor(shape=reduce_shape(self.shape, dim, keepdim)) + if isinstance(dim, list): + return Tensor(shape=reduce_shape(self.shape, dim, keepdim)) + return Tensor(shape=reduce_single(self.shape, dim, keepdim)) + +@shape_dsl_function +def reduce_single( + dims: list[int | symint], dim: int, keepdim: bool +) -> list[int | symint]: + before = dims[:dim] + if dim == -1: + if keepdim: + return before + [1] + return before + after = dims[dim + 1 :] + if keepdim: + return before + [1] + after + return before + after + +@shape_dsl_function +def min_max_median_ir( + self: Tensor, dim: int | None = None, keepdim: bool = False +) -> Tensor: + if dim == None: + return Tensor(shape=[]) + s = reduce_shape(self.shape, dim, keepdim) + return [Tensor(shape=s), Tensor(shape=s)] + +@shape_dsl_function +def aminmax_ir( + self: Tensor, dim: int | list[int] | None = None, keepdim: bool = False +) -> [Tensor, Tensor]: + s = reduce_shape(self.shape, dim, keepdim) + return [Tensor(shape=s), Tensor(shape=s)] + +@shape_dsl_function +def tuple_reduce_ir( + self: Tensor, dim: int = -1, keepdim: bool = False +) -> [Tensor, Tensor]: + s = reduce_shape(self.shape, dim, keepdim) + return [Tensor(shape=s), Tensor(shape=s)] + +@shape_dsl_function +def topk_ir(self: Tensor, k: int | symint, dim: int = -1) -> [Tensor, Tensor]: + s = replace_dim(self.shape, normalize_dim(len(self.shape), dim), k) + return [Tensor(shape=s), Tensor(shape=s)] + +@shape_dsl_function +def repeat_interleave_ir( + self: Tensor, repeats: int | symint, dim: int | None = None +) -> Tensor: + if dim == None: + return Tensor(shape=[shape_extensions.dsl.prod(self.shape) * repeats]) + d = normalize_dim(len(self.shape), dim) + return Tensor(shape=replace_dim(self.shape, d, self.shape[d] * repeats)) + +@shape_dsl_function +def cosine_similarity_ir(x1: Tensor, x2: Tensor, dim: int = 1) -> Tensor: + s = broadcast(x1.shape, x2.shape) + return Tensor(shape=reduce_single(s, normalize_dim(len(s), dim), False)) + +@shape_dsl_function +def randn_ir(size: list[int | symint]) -> Tensor: + return Tensor(shape=size) + +@shape_dsl_function +def randint_ir(low: int, high: int, size: list[int | symint]) -> Tensor: + return Tensor(shape=size) + +@shape_dsl_function +def linspace_ir(steps: int | symint) -> Tensor: + return Tensor(shape=[steps]) + +@shape_dsl_function +def eye_ir(n: int | symint, m: int | symint | None = None) -> Tensor: + if m == None: + return Tensor(shape=[n, n]) + return Tensor(shape=[n, m]) + +@shape_dsl_function +def arange_ir( + start: int | symint | None = None, + end: int | symint | None = None, + step: int | symint | None = None, +) -> Tensor: + if start != None and end != None and step != None: + return Tensor(shape=[(end - start) // step]) + if start != None and end != None: + return Tensor(shape=[end - start]) + if end != None: + return Tensor(shape=[end]) + if start != None: + return Tensor(shape=[start]) + return Unknown + +@shape_dsl_function +def normal_ir( + mean: Tensor | None = None, std: Tensor | None = None, size: list[int] | None = None +) -> Tensor: + if size != None: + return Tensor(shape=[s for s in size]) + if mean != None: + return Tensor(shape=mean.shape) + if std != None: + return Tensor(shape=std.shape) + return Unknown + +@shape_dsl_function +def diag_embed_ir(self: Tensor, offset: int = 0) -> Tensor: + new_dim = self.shape[-1] + (offset if offset >= 0 else -offset) + return Tensor(shape=self.shape[:-1] + [new_dim, new_dim]) + +@shape_dsl_function +def tri_indices_ir(row: int | symint, col: int | symint, offset: int = 0) -> Tensor: + return Tensor(shape=[2, 0]) + +@shape_dsl_function +def matmul_ir(self: Tensor, other: Tensor) -> Tensor: + r1 = len(self.shape) + r2 = len(other.shape) + if r1 == 1 and r2 == 1: + return Tensor(shape=[]) + if r1 == 1 and r2 == 2: + return Tensor(shape=[other.shape[1]]) + if r1 == 2 and r2 == 1: + return Tensor(shape=[self.shape[0]]) + if r1 == 2 and r2 == 2: + return Tensor(shape=[self.shape[0], other.shape[1]]) + if r1 == 2 and r2 >= 3: + return Tensor(shape=other.shape[:-2] + [self.shape[0]] + [other.shape[-1]]) + if r1 >= 3 and r2 == 2: + return Tensor(shape=self.shape[:-2] + [self.shape[-2]] + [other.shape[1]]) + if r1 >= 3 and r2 >= 3: + return Tensor( + shape=broadcast(self.shape[:-2], other.shape[:-2]) + + [self.shape[-2]] + + [other.shape[-1]] + ) + return Unknown + +@shape_dsl_function +def mv_ir(self: Tensor, vec: Tensor) -> Tensor: + if len(self.shape) != 2: + raise Error("mv expects 2D matrix, got " + str(len(self.shape)) + "D tensor") + if len(vec.shape) != 1: + raise Error("mv expects 1D vector, got " + str(len(vec.shape)) + "D tensor") + return Tensor(shape=[self.shape[0]]) + +@shape_dsl_function +def outer_ir(self: Tensor, vec2: Tensor) -> Tensor: + if len(self.shape) != 1 or len(vec2.shape) != 1: + raise Error( + "outer expects 1D tensors, got " + + str(len(self.shape)) + + "D and " + + str(len(vec2.shape)) + + "D" + ) + return Tensor(shape=[self.shape[0], vec2.shape[0]]) + +@shape_dsl_function +def tensordot_ir(self: Tensor, other: Tensor, dims: int) -> Tensor: + return Tensor(shape=self.shape[: len(self.shape) - dims] + other.shape[dims:]) + +@shape_dsl_function +def apply_einsum( + output_map: list[list[int]], check_pairs: list[list[int]], inputs: list[Tensor] +) -> Tensor: + bad_dims = [ + 1 + for i0, d0, i1, d1 in check_pairs + if isinstance(inputs[i0].shape[d0], int) + and isinstance(inputs[i1].shape[d1], int) + and inputs[i0].shape[d0] != inputs[i1].shape[d1] + ] + if len(bad_dims) > 0: + raise Error("einsum: inconsistent dimensions for repeated index") + return Tensor(shape=[inputs[inp].shape[dim] for inp, dim in output_map]) + +@shape_dsl_function +def einsum_ir(spec: str, operands: list[Tensor] | None = None) -> Tensor: + if operands != None: + output_map, check_pairs = shape_extensions.dsl.parse_einsum_equation(spec) + return apply_einsum(output_map, check_pairs, operands) + return Unknown + +@shape_dsl_function +def eigvals_ir(self: Tensor) -> Tensor: + if len(self.shape) < 2: + raise Error( + "eigvals requires at least 2D input, got " + + str(len(self.shape)) + + "D tensor" + ) + return Tensor(shape=self.shape[:-2] + [self.shape[-2]]) + +@shape_dsl_function +def eig_ir(self: Tensor) -> [Tensor, Tensor]: + if len(self.shape) < 2: + raise Error( + "eig requires at least 2D input, got " + str(len(self.shape)) + "D tensor" + ) + batch = self.shape[:-2] + return [ + Tensor(shape=batch + [self.shape[-2]]), + Tensor(shape=batch + self.shape[-2:]), + ] + +@shape_dsl_function +def slogdet_ir(self: Tensor) -> [Tensor, Tensor]: + if len(self.shape) < 2: + raise Error( + "slogdet requires at least 2D input, got " + + str(len(self.shape)) + + "D tensor" + ) + return [Tensor(shape=self.shape[:-2]), Tensor(shape=self.shape[:-2])] + +@shape_dsl_function +def solve_ir(self: Tensor, other: Tensor) -> Tensor: + return Tensor(shape=other.shape) + +@shape_dsl_function +def solve_reversed_ir(self: Tensor, other: Tensor) -> Tensor: + return Tensor(shape=self.shape) + +@shape_dsl_function +def conv_ir( + self: Tensor, + weight: Tensor, + stride: int | list[int] = 1, + padding: int | list[int] = 0, + dilation: int | list[int] = 1, +) -> Tensor: + spatial_dims = len(self.shape) - 2 + stride_list = broadcast_int(stride, spatial_dims) + padding_list = broadcast_int(padding, spatial_dims) + dilation_list = broadcast_int(dilation, spatial_dims) + return Tensor( + shape=[self.shape[0], weight.shape[0]] + + [ + conv_spatial_out(s, k, st, p, dil) + for s, k, st, p, dil in zip( + self.shape[2:], + weight.shape[2:], + stride_list, + padding_list, + dilation_list, + ) + ] + ) + +@shape_dsl_function +def conv_transpose_ir( + self: Tensor, + weight: Tensor, + stride: int | list[int] = 1, + padding: int | list[int] = 0, + output_padding: int | list[int] = 0, + dilation: int | list[int] = 1, +) -> Tensor: + spatial_dims = len(self.shape) - 2 + stride_list = broadcast_int(stride, spatial_dims) + padding_list = broadcast_int(padding, spatial_dims) + outpad_list = broadcast_int(output_padding, spatial_dims) + dilation_list = broadcast_int(dilation, spatial_dims) + return Tensor( + shape=[self.shape[0], weight.shape[1]] + + [ + (s - 1) * st - 2 * p + dil * (k - 1) + op + 1 + for s, k, st, p, op, dil in zip( + self.shape[2:], + weight.shape[2:], + stride_list, + padding_list, + outpad_list, + dilation_list, + ) + ] + ) + +@shape_dsl_function +def pool_ir( + self: Tensor, + kernel_size: int | list[int], + stride: int | list[int] | None = None, + padding: int | list[int] = 0, + dilation: int | list[int] = 1, + return_indices: bool = False, +) -> Tensor: + spatial_dims = len(self.shape) - 2 + ks_list = broadcast_int(kernel_size, spatial_dims) + stride_list = ks_list if stride == None else broadcast_int(stride, spatial_dims) + padding_list = broadcast_int(padding, spatial_dims) + dilation_list = broadcast_int(dilation, spatial_dims) + out = [self.shape[0], self.shape[1]] + [ + conv_spatial_out(s, k, st, p, dil) + for s, k, st, p, dil in zip( + self.shape[2:], ks_list, stride_list, padding_list, dilation_list + ) + ] + if return_indices: + return [Tensor(shape=out), Tensor(shape=out)] + return Tensor(shape=out) + +@shape_dsl_function +def adaptive_pool_ir( + self: Tensor, output_size: int | symint | list[int | symint] +) -> Tensor: + out_sizes = broadcast_int(output_size, len(self.shape) - 2) + return Tensor(shape=[self.shape[0], self.shape[1]] + out_sizes) + +@shape_dsl_function +def interpolate_ir( + self: Tensor, + size: int | symint | list[int | symint] | None = None, + scale_factor: int | symint | None = None, +) -> Tensor: + if size != None: + return Tensor( + shape=[self.shape[0], self.shape[1]] + + broadcast_int(size, len(self.shape) - 2) + ) + if scale_factor != None: + return Tensor( + shape=[self.shape[0], self.shape[1]] + + [d * scale_factor for d in self.shape[2:]] + ) + raise Error("interpolate requires either 'size' or 'scale_factor' argument") + +@shape_dsl_function +def loss_ir(self: Tensor, reduction: str = "mean") -> Tensor: + if reduction == "none": + return Tensor(shape=self.shape) + return Tensor(shape=[]) + +@shape_dsl_function +def pad_ir(self: Tensor, pad: list[int]) -> Tensor: + rank = len(self.shape) + num_pad_dims = len(pad) // 2 + offsets = [ + pad[(rank - 1 - i) * 2] + pad[(rank - 1 - i) * 2 + 1] + if i >= rank - num_pad_dims + else 0 + for i in range(rank) + ] + return Tensor(shape=[d + offsets[i] for i, d in enumerate(self.shape)]) + +@shape_dsl_function +def rfft_ir(self: Tensor, n: int | symint | None = None, dim: int = -1) -> Tensor: + d = normalize_dim(len(self.shape), dim) + if n != None: + return Tensor(shape=replace_dim(self.shape, d, n // 2 + 1)) + return Tensor(shape=replace_dim(self.shape, d, self.shape[d] // 2 + 1)) + +@shape_dsl_function +def irfft_ir(self: Tensor, n: int | symint | None = None, dim: int = -1) -> Tensor: + d = normalize_dim(len(self.shape), dim) + if n != None: + return Tensor(shape=replace_dim(self.shape, d, n)) + return Tensor(shape=replace_dim(self.shape, d, 2 * (self.shape[d] - 1))) + +@shape_dsl_function +def size_ir(self: Tensor, dim: int | None = None) -> int | symint: + if dim != None: + return self.shape[normalize_dim(len(self.shape), dim)] + return [d for d in self.shape] + +@shape_dsl_function +def numel_ir(self: Tensor) -> int | symint: + return shape_extensions.dsl.prod(self.shape) + +@shape_dsl_function +def dim_ir(self: Tensor) -> int: + return len(self.shape) + +@shape_dsl_function +def item_ir(self: Tensor) -> Tensor: + if len(self.shape) != 0: + raise Error( + "item() only works on 0-dimensional tensors, got " + + str(len(self.shape)) + + "D tensor" + ) + return Unknown + +@shape_dsl_function +def tolist_ir(self: Tensor) -> Tensor: + return Unknown + +@shape_dsl_function +def multinomial_ir(self: Tensor, num_samples: int | symint) -> Tensor: + return Tensor(shape=self.shape[:-1] + [num_samples]) + +@shape_dsl_function +def where_ir(condition: Tensor, x: Tensor, y: Tensor) -> Tensor: + return Tensor(shape=x.shape) + +@shape_dsl_function +def take_along_dim_ir(self: Tensor, indices: Tensor) -> Tensor: + return Tensor(shape=indices.shape) + +@shape_dsl_function +def nn_flatten_forward_ir( + input: Tensor, start_dim: symint = 1, end_dim: symint = -1 +) -> Tensor: + return flatten_ir(input, start_dim, end_dim) + +@shape_dsl_function +def nn_maxpool_forward_ir( + input: Tensor, + kernel_size: symint = 1, + stride: symint | None = None, + padding: symint = 0, + dilation: symint = 1, +) -> Tensor: + return pool_ir(input, kernel_size, stride, padding, dilation) + +@shape_dsl_function +def nn_avgpool_forward_ir( + input: Tensor, + kernel_size: symint = 1, + stride: symint | None = None, + padding: symint = 0, +) -> Tensor: + return pool_ir(input, kernel_size, stride, padding, 1) + +@shape_dsl_function +def nn_upsample_forward_ir( + input: Tensor, size: symint | None = None, scale_factor: symint | None = None +) -> Tensor: + return interpolate_ir(input, size, scale_factor) + +@shape_dsl_function +def nn_pixel_shuffle_forward_ir(input: Tensor, upscale_factor: symint) -> Tensor: + r = upscale_factor + return Tensor( + shape=[input.shape[0], input.shape[1] // (r * r)] + + [d * r for d in input.shape[2:]] + ) + +@shape_dsl_function +def nn_glu_forward_ir(input: Tensor, dim: symint = 1) -> Tensor: + rank = len(input.shape) + d = normalize_dim(rank, dim) + return Tensor(shape=replace_dim(input.shape, d, input.shape[d] // 2)) + +@shape_dsl_function +def nn_lstm_forward_ir( + input: Tensor, + input_size: symint, + hidden_size: symint, + num_layers: symint = 1, + bidirectional: bool = False, +) -> [Tensor, Tensor, Tensor]: + nd = 2 if bidirectional else 1 + output = Tensor(shape=[input.shape[0], input.shape[1], hidden_size * nd]) + h_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) + c_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) + return [output, h_n, c_n] + +@shape_dsl_function +def nn_gru_forward_ir( + input: Tensor, + input_size: symint, + hidden_size: symint, + num_layers: symint = 1, + bidirectional: bool = False, +) -> [Tensor, Tensor]: + nd = 2 if bidirectional else 1 + output = Tensor(shape=[input.shape[0], input.shape[1], hidden_size * nd]) + h_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) + return [output, h_n] + +@shape_dsl_function +def nn_lstmcell_forward_ir( + input: Tensor, input_size: symint, hidden_size: symint +) -> [Tensor, Tensor]: + h = Tensor(shape=[input.shape[0], hidden_size]) + c = Tensor(shape=[input.shape[0], hidden_size]) + return [h, c] + +@shape_dsl_function +def nn_reflectionpad2d_forward_ir(input: Tensor, padding: symint) -> Tensor: + return Tensor( + shape=[ + input.shape[0], + input.shape[1], + input.shape[2] + 2 * padding, + input.shape[3] + 2 * padding, + ] + ) From 79bed43ea7d75aad5cb0a2201b651e8e4cc0535e Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 17/25] Add fn_lookup infrastructure for DSL helper resolution Summary: Enable DSL functions that call helpers (e.g., `reshape_ir` calling `normalize_dim`) to work through the decorator path. Previously, `ShapeTransformRef::to_meta_shape_function()` used an empty `fn_lookup`, which panicked on any helper call. Introduces a `Derived` wrapper in `pyrefly_types` for attaching auxiliary data to types without affecting identity comparisons. Uses it to carry same-module DSL siblings on `FunctionKind::ShapeDsl` and `ShapeTransformRef`. At `shape_dsl_function` solve time, all siblings from the module's `BindingsMetadata` are collected; at `uses_shape_dsl` consumer sites, the siblings flow through to `to_meta_shape_function` which builds fn_lookup from self + siblings. This is a deliberate all-siblings shortcut matching the registry's flat-namespace behavior. Follow-up #9 replaces it with per-caller transitive-callee resolution. Differential Revision: D105775131 --- crates/pyrefly_types/src/callable.rs | 71 ++++++++++++++++++++-- crates/pyrefly_types/src/meta_shape_dsl.rs | 15 ++++- pyrefly/lib/alt/function.rs | 24 +++++--- pyrefly/lib/binding/function.rs | 7 ++- pyrefly/lib/binding/metadata.rs | 17 ++++++ pyrefly/lib/test/shape_dsl.rs | 24 +++++++- 6 files changed, 139 insertions(+), 19 deletions(-) diff --git a/crates/pyrefly_types/src/callable.rs b/crates/pyrefly_types/src/callable.rs index d951903dc2..e60b6cc78b 100644 --- a/crates/pyrefly_types/src/callable.rs +++ b/crates/pyrefly_types/src/callable.rs @@ -12,6 +12,7 @@ use std::fmt; use std::fmt::Display; use std::hash::Hash; use std::hash::Hasher; +use std::ops::Deref; use std::sync::Arc; use dupe::Dupe; @@ -42,6 +43,62 @@ use crate::type_output::TypeOutput; use crate::types::AnyStyle; use crate::types::Type; +/// A wrapper for derived/cached data that should not participate in +/// equality, hashing, or ordering comparisons. `Derived` always +/// compares as equal, hashes as a no-op, and orders as `Equal`. +/// +/// This is useful for attaching auxiliary data to types that derive +/// `PartialEq`, `Hash`, `Ord`, etc. without affecting their identity. +#[derive(Debug, Clone)] +pub struct Derived(pub T); + +impl PartialEq for Derived { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for Derived {} + +impl Hash for Derived { + fn hash(&self, _state: &mut H) {} +} + +impl PartialOrd for Derived { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Derived { + fn cmp(&self, _other: &Self) -> Ordering { + Ordering::Equal + } +} + +impl Visit for Derived { + const RECURSE_CONTAINS: bool = false; + fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {} +} + +impl VisitMut for Derived { + const RECURSE_CONTAINS: bool = false; + fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {} +} + +impl TypeEq for Derived { + fn type_eq(&self, _other: &Self, _ctx: &mut TypeEqCtx) -> bool { + true + } +} + +impl Deref for Derived { + type Target = T; + fn deref(&self) -> &T { + &self.0 + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Visit, VisitMut, TypeEq)] pub struct Callable { @@ -818,7 +875,11 @@ pub enum FunctionKind { /// A function whose return type is computed by a shape DSL definition. /// The `FuncId` provides identity (module, class, name) for display and /// lookup; the `ShapeDslFunction` carries the parsed DSL IR. - ShapeDsl(Arc, Arc), + ShapeDsl( + Arc, + Arc, + Derived>>>, + ), /// The `shape_extensions.uses_shape_dsl` decorator function itself. UsesShapeDsl, } @@ -1230,7 +1291,7 @@ impl FunctionKind { Self::NumbaJit => ModuleName::from_str("numba"), Self::NumbaNjit => ModuleName::from_str("numba"), Self::Def(func_id) => func_id.module.name().dupe(), - Self::ShapeDsl(id, _) => id.module.name().dupe(), + Self::ShapeDsl(id, _, _) => id.module.name().dupe(), Self::UsesShapeDsl => ModuleName::from_str("shape_extensions"), } } @@ -1258,7 +1319,7 @@ impl FunctionKind { Self::NumbaJit => Cow::Owned(Name::new_static("jit")), Self::NumbaNjit => Cow::Owned(Name::new_static("njit")), Self::Def(func_id) => Cow::Borrowed(&func_id.name), - Self::ShapeDsl(id, _) => Cow::Borrowed(&id.name), + Self::ShapeDsl(id, _, _) => Cow::Borrowed(&id.name), Self::UsesShapeDsl => Cow::Owned(Name::new_static("uses_shape_dsl")), } } @@ -1286,14 +1347,14 @@ impl FunctionKind { Self::TotalOrdering => None, Self::DisjointBase => None, Self::Def(func_id) => func_id.cls.clone(), - Self::ShapeDsl(id, _) => id.cls.clone(), + Self::ShapeDsl(id, _, _) => id.cls.clone(), Self::UsesShapeDsl => None, } } pub fn outer_funcs(&self) -> Option<&Name> { match self { - Self::Def(func_id) | Self::ShapeDsl(func_id, _) => func_id.outer_funcs.as_ref(), + Self::Def(func_id) | Self::ShapeDsl(func_id, _, _) => func_id.outer_funcs.as_ref(), _ => None, } } diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index d5c545644b..712653ed18 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -37,6 +37,7 @@ use ruff_python_ast::Operator as RuffOperator; use ruff_python_ast::Stmt; use ruff_python_ast::UnaryOp as RuffUnaryOp; +use crate::callable::Derived; use crate::dimension::ShapeError; use crate::dimension::SizeExpr; use crate::dimension::canonicalize; @@ -3167,6 +3168,8 @@ impl TypeEq for ShapeDslFunction { #[derive(Debug, Clone)] pub struct ShapeTransformRef { pub dsl_fn: Arc, + // TODO: Replace all-siblings snapshot with resolved transitive callees. + pub helpers: Derived>>>, } /// Pointer identity: delegates to `ShapeDslFunction`'s pointer-identity equality. @@ -3224,11 +3227,19 @@ impl TypeEq for ShapeTransformRef { impl ShapeTransformRef { /// Build a `MetaShapeFunction` evaluator from this shape transform reference. - /// Uses an empty `fn_lookup` — cross-function DSL calls are not yet supported. + /// Populates `fn_lookup` with this function and all same-module siblings + /// so that cross-function DSL calls resolve correctly. pub fn to_meta_shape_function(&self) -> Box { + // helpers already includes self (it's all same-module siblings). + let fn_lookup: Arc>> = Arc::new( + self.helpers + .iter() + .map(|h| (h.inner.name.clone(), h.inner.clone())) + .collect(), + ); Box::new(DslMetaShapeFunction { fn_def: self.dsl_fn.inner.clone(), - fn_lookup: Arc::new(HashMap::new()), + fn_lookup, }) } } diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index e10774d212..82c784aaff 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -15,6 +15,7 @@ use pyrefly_python::ast::Ast; use pyrefly_python::dunder; use pyrefly_python::module_path::ModuleStyle; use pyrefly_python::short_identifier::ShortIdentifier; +use pyrefly_types::callable::Derived; use pyrefly_types::callable::FuncId; use pyrefly_types::callable::Params; use pyrefly_types::callable::PlaceholderBodyKind; @@ -543,13 +544,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let kind = if let Some(dsl_fn) = shape_dsl_def { // Shape DSL functions carry their parsed IR for call-site evaluation. - // - // TODO: the embedded `ShapeDslFunction` currently carries only this - // function's own DSL body, with no `fn_lookup` of resolved helpers. - // That's adequate for leaf DSL functions but breaks any function that - // calls another `@shape_dsl_function` (e.g. `reshape_ir` -> - // `normalize_dim`). We need to build the per-function `fn_lookup` here - // from resolved transitive callees. let func_id = Arc::new(FuncId { module: self.module().dupe(), cls: defining_cls.clone(), @@ -557,7 +551,18 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { def_index: Some(def_index), outer_funcs, }); - FunctionKind::ShapeDsl(func_id, dsl_fn) + // Collect all same-module @shape_dsl_function siblings for fn_lookup. + // TODO: Replace all-siblings snapshot with per-caller transitive-callee + // resolution. + let siblings: Arc>> = Arc::new( + self.bindings() + .metadata() + .shape_dsl_functions() + .iter() + .map(|(_, dsl)| Arc::clone(dsl)) + .collect(), + ); + FunctionKind::ShapeDsl(func_id, dsl_fn, Derived(siblings)) } else { FunctionKind::from_name( self.module().dupe(), @@ -573,10 +578,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { if let Some((_name, ir_identifier)) = uses_shape_dsl_ir_name { let ir_type = self.get(&Key::BoundName(ir_identifier)).arc_clone_ty(); if let Type::Function(func) = &ir_type - && let FunctionKind::ShapeDsl(_, dsl_fn) = &func.metadata.kind + && let FunctionKind::ShapeDsl(_, dsl_fn, helpers) = &func.metadata.kind { flags.shape_transform = Some(Arc::new(ShapeTransformRef { dsl_fn: dsl_fn.clone(), + helpers: helpers.clone(), })); } } diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index 76a48ccc96..73c131d1ea 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -828,9 +828,12 @@ impl<'a> BindingsBuilder<'a> { // Convert the function to DSL IR before `function_header` takes `returns` // and before `function_body` takes `body`. let shape_dsl_def = if is_shape_dsl { - Some(Arc::new( + let dsl_fn = Arc::new( convert_shape_dsl_function(&x).expect("@shape_dsl_function body must be valid DSL"), - )) + ); + self.metadata + .push_shape_dsl(func_name.id.clone(), Arc::clone(&dsl_fn)); + Some(dsl_fn) } else { None }; diff --git a/pyrefly/lib/binding/metadata.rs b/pyrefly/lib/binding/metadata.rs index 628c78c82b..fe3a38478b 100644 --- a/pyrefly/lib/binding/metadata.rs +++ b/pyrefly/lib/binding/metadata.rs @@ -11,8 +11,12 @@ //! This module stores per-class metadata (starting with field information) //! in a `Vec` indexed by `ClassDefIndex`, enabling efficient lookups. +use std::sync::Arc; + use pyrefly_types::class::ClassDefIndex; use pyrefly_types::class::ClassFields; +use pyrefly_types::meta_shape_dsl::ShapeDslFunction; +use ruff_python_ast::name::Name; /// Metadata for a single class definition, populated during binding. #[derive(Debug, Clone, Default)] @@ -38,12 +42,14 @@ pub struct ClassMetadata { #[derive(Debug, Clone)] pub struct BindingsMetadata { classes: Vec, + shape_dsl_functions: Vec<(Name, Arc)>, } impl BindingsMetadata { pub fn new() -> Self { Self { classes: Vec::new(), + shape_dsl_functions: Vec::new(), } } @@ -67,4 +73,15 @@ impl BindingsMetadata { pub fn get_class_mut(&mut self, idx: ClassDefIndex) -> &mut ClassMetadata { &mut self.classes[idx.0 as usize] } + + /// Record a `@shape_dsl_function` definition so sibling DSL functions + /// can be discovered for cross-function helper resolution. + pub fn push_shape_dsl(&mut self, name: Name, dsl_fn: Arc) { + self.shape_dsl_functions.push((name, dsl_fn)); + } + + /// All `@shape_dsl_function` definitions recorded in this module. + pub fn shape_dsl_functions(&self) -> &[(Name, Arc)] { + &self.shape_dsl_functions + } } diff --git a/pyrefly/lib/test/shape_dsl.rs b/pyrefly/lib/test/shape_dsl.rs index aa39e0f7a8..af949cd53f 100644 --- a/pyrefly/lib/test/shape_dsl.rs +++ b/pyrefly/lib/test/shape_dsl.rs @@ -20,6 +20,14 @@ from shape_extensions.dsl import shape_dsl_function @shape_dsl_function def identity_ir(x: int) -> int: return x + +@shape_dsl_function +def times_two(x: int) -> int: + return x + x + +@shape_dsl_function +def double_ir(x: int) -> int: + return times_two(x) "#, ); env.add_with_path( @@ -28,7 +36,7 @@ def identity_ir(x: int) -> int: r#" from typing import overload from shape_extensions import uses_shape_dsl -from my_shapes import identity_ir +from my_shapes import identity_ir, double_ir @uses_shape_dsl(identity_ir) def plain_fn(x: int) -> int: ... @@ -45,6 +53,9 @@ def overloaded_with_impl(x: int | str) -> int | str: ... def overloaded_no_impl(x: int) -> int: ... @overload def overloaded_no_impl(x: str) -> str: ... + +@uses_shape_dsl(double_ir) +def double_fn(x: int) -> int: ... "#, ); env @@ -87,3 +98,14 @@ assert_type(overloaded_no_impl(1), Literal[1]) assert_type(overloaded_no_impl("a"), str) "#, ); + +testcase!( + test_uses_shape_dsl_cross_function_call, + shape_dsl_env(), + r#" +from typing import Literal, assert_type +from my_lib import double_fn + +assert_type(double_fn(3), Literal[6]) +"#, +); From 8bc14e525a4497dc952c15a79897f7f363cd0d95 Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 18/25] Add `@uses_shape_dsl` decorators to torch fixture stubs Summary: Wire up ~131 torch fixture stub functions with `uses_shape_dsl(ir_fn)` decorators, matching the `TensorOpsRegistry` mappings. The decorator path is now preferred over the registry for all decorated functions. Modified fixture files: - `torch/__init__.pyi`: ~112 decorators on module-level functions and `Tensor` methods (shape manipulation, reductions, creation, linalg, indexing, random, conditional, properties) - `torch/nn/functional.pyi`: ~30 decorators (conv, pool, loss, pad, interpolate, cosine_similarity) - `torch/fft.pyi`: 4 decorators (rfft, irfft, hfft, ihfft) - `torch/linalg.pyi`: 9 decorators (eig, eigvals, solve, slogdet, etc.) All tensor_shapes tests pass through the decorator path. Differential Revision: D105775133 --- .../tensor_shapes/fixtures/torch/__init__.pyi | 177 ++++++++++++++++++ test/tensor_shapes/fixtures/torch/fft.pyi | 6 + test/tensor_shapes/fixtures/torch/linalg.pyi | 10 + .../fixtures/torch/nn/functional.pyi | 48 +++++ 4 files changed, 241 insertions(+) diff --git a/test/tensor_shapes/fixtures/torch/__init__.pyi b/test/tensor_shapes/fixtures/torch/__init__.pyi index d0480d3b60..c17aeca161 100644 --- a/test/tensor_shapes/fixtures/torch/__init__.pyi +++ b/test/tensor_shapes/fixtures/torch/__init__.pyi @@ -19,6 +19,61 @@ For operations handled by meta-shapes, see pyrefly_types/src/meta_shape.rs: import builtins from typing import Any, overload, Self, TYPE_CHECKING +from shape_extensions import uses_shape_dsl +from torch._shapes import ( + aminmax_ir, + arange_ir, + broadcast_to_ir, + cat_ir, + chunk_ir, + diag_embed_ir, + dim_ir, + eig_ir, + einsum_ir, + expand_ir, + eye_ir, + flatten_ir, + index_select_ir, + item_ir, + linspace_ir, + matmul_ir, + min_max_median_ir, + movedim_ir, + multinomial_ir, + mv_ir, + narrow_ir, + normal_ir, + numel_ir, + outer_ir, + permute_ir, + randint_ir, + randn_ir, + reduce_ir, + repeat_interleave_ir, + repeat_ir, + reshape_ir, + select_ir, + size_ir, + slogdet_ir, + solve_ir, + solve_reversed_ir, + split_ir, + squeeze_ir, + stack_ir, + take_along_dim_ir, + tensordot_ir, + tile_ir, + tolist_ir, + topk_ir, + transpose_ir, + tri_indices_ir, + tuple_reduce_ir, + unbind_ir, + unfold_ir, + unsqueeze_ir, + where_ir, +) + if TYPE_CHECKING: from shape_extensions import Dim as _Dim @@ -95,6 +150,7 @@ class Tensor[*Shape]: # ==== Matrix Multiplication ==== # Uses meta-shape for shape inference + @uses_shape_dsl(matmul_ir) def __matmul__(self: Tensor, other: Tensor) -> Tensor: """Matrix multiplication (@). Shape inference via meta-shape: torch.Tensor.matmul""" ... @@ -134,6 +190,7 @@ class Tensor[*Shape]: # ==== Shape Manipulation Operations ==== # Handled by meta-shape functions - simplified signatures + @uses_shape_dsl(reshape_ir) @overload def reshape(self: Tensor, *shape: int) -> Tensor: """Reshape tensor. Shape inference via meta-shape: torch.Tensor.reshape""" @@ -144,6 +201,7 @@ class Tensor[*Shape]: """Reshape tensor. Shape inference via meta-shape: torch.Tensor.reshape""" ... + @uses_shape_dsl(reshape_ir) @overload def view(self: Tensor, *shape: int) -> Tensor: """View (alias for reshape). Shape inference via meta-shape: torch.Tensor.view""" @@ -154,14 +212,17 @@ class Tensor[*Shape]: """View (alias for reshape). Shape inference via meta-shape: torch.Tensor.view""" ... + @uses_shape_dsl(flatten_ir) def flatten(self: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor: """Flatten dimensions. Shape inference via meta-shape: torch.flatten""" ... + @uses_shape_dsl(transpose_ir) def transpose(self: Tensor, dim0: int, dim1: int) -> Tensor: """Transpose two dimensions. Shape inference via meta-shape: torch.transpose""" ... + @uses_shape_dsl(permute_ir) @overload def permute(self: Tensor, *dims: int) -> Tensor: """Permute dimensions. Shape inference via meta-shape: torch.Tensor.permute""" @@ -172,14 +233,17 @@ class Tensor[*Shape]: """Permute dimensions. Shape inference via meta-shape: torch.Tensor.permute""" ... + @uses_shape_dsl(squeeze_ir) def squeeze(self: Tensor, dim: int | None = None) -> Tensor: """Remove dimensions of size 1. Shape inference via meta-shape: torch.squeeze""" ... + @uses_shape_dsl(unsqueeze_ir) def unsqueeze(self: Tensor, dim: int) -> Tensor: """Add dimension of size 1. Shape inference via meta-shape: torch.unsqueeze""" ... + @uses_shape_dsl(repeat_ir) @overload def repeat(self: Tensor, *sizes: int) -> Tensor: """Repeat tensor. Shape inference via meta-shape: torch.Tensor.repeat""" @@ -194,6 +258,7 @@ class Tensor[*Shape]: """Transpose 2D tensor. Swaps dimensions.""" ... + @uses_shape_dsl(expand_ir) def expand(self: Tensor, *sizes: int) -> Tensor: """Expand tensor. Shape inference via meta-shape: torch.Tensor.expand""" ... @@ -202,6 +267,7 @@ class Tensor[*Shape]: """Expand tensor to match the shape of `other`.""" ... + @uses_shape_dsl(repeat_interleave_ir) def repeat_interleave( self: Tensor, repeats: int | Tensor, dim: int | None = None ) -> Tensor: @@ -332,26 +398,32 @@ class Tensor[*Shape]: """Enable/disable gradient tracking in-place. Shape-preserving.""" ... + @uses_shape_dsl(item_ir) def item(self: Tensor) -> float | int: """Returns Python scalar from 0-dimensional tensor. Shape inference via meta-shape: torch.Tensor.item""" ... + @uses_shape_dsl(tolist_ir) def tolist(self: Tensor) -> Any: """Returns tensor as nested Python list. Shape inference via meta-shape: torch.Tensor.tolist""" ... + @uses_shape_dsl(tile_ir) def tile(self: Tensor, dims: tuple[int, ...]) -> Tensor: """Tile tensor. Shape inference via meta-shape: torch.Tensor.tile""" ... + @uses_shape_dsl(select_ir) def select(self: Tensor, dim: int, index: int) -> Tensor: """Select along dimension. Shape inference via meta-shape: torch.Tensor.select""" ... + @uses_shape_dsl(narrow_ir) def narrow(self: Tensor, dim: int, start: int, length: int) -> Tensor: """Narrow tensor along dimension. Shape inference via meta-shape: torch.Tensor.narrow""" ... + @uses_shape_dsl(split_ir) @overload def split( self: Tensor, split_size_or_sections: int, dim: int = 0 @@ -366,10 +438,12 @@ class Tensor[*Shape]: """Split tensor into variable-sized chunks. Shape inference via meta-shape: torch.Tensor.split""" ... + @uses_shape_dsl(chunk_ir) def chunk(self: Tensor, chunks: int, dim: int = 0) -> tuple[Tensor, ...]: """Split tensor into chunks. Shape inference via meta-shape: torch.Tensor.chunk""" ... + @uses_shape_dsl(index_select_ir) def index_select(self: Tensor, dim: int, index: Tensor) -> Tensor: """Select elements along dimension. Shape inference via meta-shape: torch.Tensor.index_select""" ... @@ -392,10 +466,12 @@ class Tensor[*Shape]: # ==== Phase 1.1: Missing Shape Operations (Methods) ==== + @uses_shape_dsl(unbind_ir) def unbind(self: Tensor, dim: int = 0) -> tuple[Tensor, ...]: """Remove dimension by slicing along it. Shape inference via meta-shape: torch.Tensor.unbind""" ... + @uses_shape_dsl(movedim_ir) @overload def movedim(self: Tensor, source: int, destination: int) -> Tensor: """Move single dimension to new position. Shape inference via meta-shape: torch.Tensor.movedim""" @@ -408,6 +484,7 @@ class Tensor[*Shape]: """Move multiple dimensions to new positions. Shape inference via meta-shape: torch.Tensor.movedim""" ... + @uses_shape_dsl(movedim_ir) @overload def moveaxis(self: Tensor, source: int, destination: int) -> Tensor: """Alias for movedim. Shape inference via meta-shape: torch.Tensor.moveaxis""" @@ -420,10 +497,12 @@ class Tensor[*Shape]: """Alias for movedim. Shape inference via meta-shape: torch.Tensor.moveaxis""" ... + @uses_shape_dsl(unfold_ir) def unfold(self: Tensor, dimension: int, size: int, step: int) -> Tensor: """Returns sliding window view. Shape inference via meta-shape: torch.Tensor.unfold""" ... + @uses_shape_dsl(size_ir) @overload def size(self: Tensor) -> tuple[builtins.int, ...]: """Returns the size of the tensor as a tuple. Shape inference via meta-shape: torch.Tensor.size""" @@ -437,6 +516,7 @@ class Tensor[*Shape]: # ==== Reduction Operations ==== # Handled by meta-shape functions - simplified signatures + @uses_shape_dsl(reduce_ir) @overload def sum(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Sum along dimension(s). Shape inference via meta-shape: torch.Tensor.sum""" @@ -447,10 +527,12 @@ class Tensor[*Shape]: """Sum along multiple dimensions. Shape inference via meta-shape: torch.Tensor.sum""" ... + @uses_shape_dsl(reduce_ir) def mean(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Mean along dimension(s). Shape inference via meta-shape: torch.mean""" ... + @uses_shape_dsl(min_max_median_ir) @overload def max(self: Tensor) -> Tensor: """Max of all elements (scalar). Shape inference via meta-shape: torch.Tensor.max""" @@ -461,6 +543,7 @@ class Tensor[*Shape]: """Max along dimension. Returns (values, indices). Shape inference via meta-shape: torch.Tensor.max""" ... + @uses_shape_dsl(min_max_median_ir) @overload def min(self: Tensor) -> Tensor: """Min of all elements (scalar). Shape inference via meta-shape: torch.Tensor.min""" @@ -471,28 +554,34 @@ class Tensor[*Shape]: """Min along dimension. Returns (values, indices). Shape inference via meta-shape: torch.Tensor.min""" ... + @uses_shape_dsl(reduce_ir) def prod(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Product along dimension(s). Shape inference via meta-shape: torch.prod""" ... + @uses_shape_dsl(reduce_ir) def std(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Standard deviation along dimension(s). Shape inference via meta-shape: torch.std""" ... + @uses_shape_dsl(reduce_ir) def var(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Variance along dimension(s). Shape inference via meta-shape: torch.var""" ... + @uses_shape_dsl(reduce_ir) def argmax(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Argmax along dimension(s). Shape inference via meta-shape: torch.argmax""" ... + @uses_shape_dsl(reduce_ir) def argmin(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Argmin along dimension(s). Shape inference via meta-shape: torch.argmin""" ... # ==== Phase 1.2: Missing Reduction Operations (Methods) ==== + @uses_shape_dsl(min_max_median_ir) @overload def median(self: Tensor) -> Tensor: """Median of all elements (scalar). Shape inference via meta-shape: torch.Tensor.median""" @@ -503,22 +592,26 @@ class Tensor[*Shape]: """Median along dimension. Returns (values, indices). Shape inference via meta-shape: torch.Tensor.median""" ... + @uses_shape_dsl(reduce_ir) def logsumexp( self: Tensor, dim: int | None = None, keepdim: bool = False ) -> Tensor: """Log-sum-exp along dimension(s). Shape inference via meta-shape: torch.Tensor.logsumexp""" ... + @uses_shape_dsl(reduce_ir) def count_nonzero(self: Tensor, dim: int | None = None) -> Tensor: """Count non-zero elements. Shape inference via meta-shape: torch.Tensor.count_nonzero""" ... + @uses_shape_dsl(aminmax_ir) def aminmax( self: Tensor, dim: int | None = None, keepdim: bool = False ) -> tuple[Tensor, Tensor]: """Min and max along dimension(s). Shape inference via meta-shape: torch.Tensor.aminmax""" ... + @uses_shape_dsl(reduce_ir) def norm( self: Tensor, p: int | float = 2, @@ -554,12 +647,14 @@ class Tensor[*Shape]: # ==== Tier 2: Additional Reduction Methods ==== + @uses_shape_dsl(tuple_reduce_ir) def mode( self: Tensor, dim: int = -1, keepdim: bool = False ) -> tuple[Tensor, Tensor]: """Mode along dimension. Returns (values, indices). Shape inference via meta-shape: torch.Tensor.mode""" ... + @uses_shape_dsl(topk_ir) def topk( self: Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True ) -> tuple[Tensor, Tensor]: @@ -575,6 +670,7 @@ class Tensor[*Shape]: """Sort tensor. Returns (values, indices). Shape-preserving operation.""" ... + @uses_shape_dsl(tuple_reduce_ir) def kthvalue( self: Tensor, k: int, dim: int = -1, keepdim: bool = False ) -> tuple[Tensor, Tensor]: @@ -583,6 +679,7 @@ class Tensor[*Shape]: # ==== Phase 1.3: Tensor Creation Operations (Methods) ==== + @uses_shape_dsl(diag_embed_ir) def diag_embed( self: Tensor, offset: int = 0, dim1: int = -2, dim2: int = -1 ) -> Tensor: @@ -599,6 +696,7 @@ class Tensor[*Shape]: # ==== Phase 1.4: Basic Linear Algebra Operations (Methods) ==== + @uses_shape_dsl(matmul_ir) def matmul(self: Tensor, other: Tensor) -> Tensor: """Matrix multiplication. Shape inference via meta-shape: torch.Tensor.matmul""" ... @@ -613,6 +711,7 @@ class Tensor[*Shape]: """Batch matrix multiplication (3D @ 3D). Output: [B, N, M].""" ... + @uses_shape_dsl(mv_ir) def mv(self: Tensor, vec: Tensor) -> Tensor: """Matrix-vector multiplication. Shape inference via meta-shape: torch.Tensor.mv""" ... @@ -976,6 +1075,7 @@ class Tensor[*Shape]: """Log determinant. Returns batch dimensions only (drops last 2 dims).""" ... + @uses_shape_dsl(slogdet_ir) def slogdet(self: Tensor) -> tuple[Tensor, Tensor]: """Sign and log determinant. Shape inference via meta-shape: torch.Tensor.slogdet""" ... @@ -1058,6 +1158,7 @@ class Tensor[*Shape]: """Take elements at indices. Output shape matches index shape.""" ... + @uses_shape_dsl(take_along_dim_ir) def take_along_dim(self: Tensor, indices: Tensor, dim: int) -> Tensor: """Take along dimension. Shape inference via meta-shape: torch.Tensor.take_along_dim""" ... @@ -1080,6 +1181,7 @@ class Tensor[*Shape]: """Sample from Bernoulli distribution in-place. Shape inference via generic fixture signature.""" ... + @uses_shape_dsl(multinomial_ir) def multinomial( self: Tensor, num_samples: int, replacement: bool = False ) -> Tensor: @@ -1098,14 +1200,17 @@ class Tensor[*Shape]: """Fill with uniform distribution in-place. Shape inference via generic fixture signature.""" ... + @uses_shape_dsl(numel_ir) def numel(self: Tensor) -> int: """Number of elements. Shape inference via meta-shape: torch.Tensor.numel""" ... + @uses_shape_dsl(dim_ir) def dim(self: Tensor) -> int: """Number of dimensions. Shape inference via meta-shape: torch.Tensor.dim""" ... + @uses_shape_dsl(numel_ir) def nelement(self: Tensor) -> int: """Number of elements. Shape inference via meta-shape: torch.Tensor.nelement""" ... @@ -1114,46 +1219,57 @@ class Tensor[*Shape]: # Module-level Functions # ============================================================================ +@uses_shape_dsl(matmul_ir) def matmul(self: Tensor, other: Tensor) -> Tensor: """Matrix multiplication function. Shape inference via meta-shape: torch.matmul""" ... +@uses_shape_dsl(cat_ir) def cat(tensors: list[Tensor] | tuple[Tensor, ...], dim: int = 0) -> Tensor: """Concatenate tensors. Shape inference via meta-shape: torch.cat""" ... +@uses_shape_dsl(stack_ir) def stack(tensors: list[Tensor] | tuple[Tensor, ...], dim: int = 0) -> Tensor: """Stack tensors (adds new dimension).""" ... +@uses_shape_dsl(transpose_ir) def transpose(self: Tensor, dim0: int, dim1: int) -> Tensor: """Transpose two dimensions. Shape inference via meta-shape: torch.transpose""" ... +@uses_shape_dsl(reshape_ir) def reshape(self: Tensor, shape: tuple[int, ...]) -> Tensor: """Reshape tensor. Shape inference via meta-shape: torch.reshape""" ... +@uses_shape_dsl(squeeze_ir) def squeeze(self: Tensor, dim: int | None = None) -> Tensor: """Remove dimensions of size 1. Shape inference via meta-shape: torch.squeeze""" ... +@uses_shape_dsl(unsqueeze_ir) def unsqueeze(self: Tensor, dim: int) -> Tensor: """Add dimension of size 1. Shape inference via meta-shape: torch.unsqueeze""" ... +@uses_shape_dsl(permute_ir) def permute(self: Tensor, dims: tuple[int, ...]) -> Tensor: """Permute dimensions. Shape inference via meta-shape: torch.permute""" ... +@uses_shape_dsl(reduce_ir) def sum(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Sum along dimension(s). Shape inference via meta-shape: torch.sum""" ... +@uses_shape_dsl(reduce_ir) def mean(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Mean along dimension(s). Shape inference via meta-shape: torch.mean""" ... +@uses_shape_dsl(min_max_median_ir) @overload def max(self: Tensor) -> Tensor: """Max of all elements (scalar). Shape inference via meta-shape: torch.max""" @@ -1164,6 +1280,7 @@ def max(self: Tensor, dim: int, keepdim: bool = False) -> tuple[Tensor, Tensor]: """Max along dimension. Returns (values, indices). Shape inference via meta-shape: torch.max""" ... +@uses_shape_dsl(min_max_median_ir) @overload def min(self: Tensor) -> Tensor: """Min of all elements (scalar). Shape inference via meta-shape: torch.min""" @@ -1179,36 +1296,44 @@ def min[*S](input: Tensor[*S], other: Tensor) -> Tensor[*S]: """Element-wise minimum of two tensors.""" ... +@uses_shape_dsl(reduce_ir) def prod(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Product along dimension(s). Shape inference via meta-shape: torch.prod""" ... +@uses_shape_dsl(reduce_ir) def std(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Standard deviation. Shape inference via meta-shape: torch.std""" ... +@uses_shape_dsl(reduce_ir) def var(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Variance. Shape inference via meta-shape: torch.var""" ... +@uses_shape_dsl(reduce_ir) def argmax(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Argmax. Shape inference via meta-shape: torch.argmax""" ... +@uses_shape_dsl(reduce_ir) def argmin(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Argmin. Shape inference via meta-shape: torch.argmin""" ... +@uses_shape_dsl(flatten_ir) def flatten(self: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor: """Flatten dimensions. Shape inference via meta-shape: torch.flatten""" ... +@uses_shape_dsl(stack_ir) def stack(tensors: list[Tensor] | tuple[Tensor, ...], dim: int = 0) -> Tensor: """Stack tensors. Shape inference via meta-shape: torch.stack""" ... # ==== Tensor Creation Functions ==== +@uses_shape_dsl(randn_ir) @overload def randn(*size: int, dtype: Any = None, device: Any = None) -> Tensor: """Create tensor with random values. Shape inference via meta-shape: torch.randn""" @@ -1219,6 +1344,7 @@ def randn(size: tuple[int, ...], dtype: Any = None, device: Any = None) -> Tenso """Create tensor with random values (tuple size). Shape inference via meta-shape: torch.randn""" ... +@uses_shape_dsl(randn_ir) @overload def rand(*size: int, dtype: Any = None, device: Any = None) -> Tensor: """Create tensor with random values [0, 1). Shape inference via meta-shape: torch.rand""" @@ -1229,6 +1355,7 @@ def rand(size: tuple[int, ...], dtype: Any = None, device: Any = None) -> Tensor """Create tensor with random values (tuple size). Shape inference via meta-shape: torch.rand""" ... +@uses_shape_dsl(randn_ir) @overload def zeros(*size: int, dtype: Any = None, device: Any = None) -> Tensor: """Create tensor filled with zeros. Shape inference via meta-shape: torch.zeros""" @@ -1239,6 +1366,7 @@ def zeros(size: tuple[int, ...], dtype: Any = None, device: Any = None) -> Tenso """Create tensor filled with zeros (tuple size). Shape inference via meta-shape: torch.zeros""" ... +@uses_shape_dsl(randn_ir) @overload def ones(*size: int, dtype: Any = None, device: Any = None) -> Tensor: """Create tensor filled with ones. Shape inference via meta-shape: torch.ones""" @@ -1249,6 +1377,7 @@ def ones(size: tuple[int, ...], dtype: Any = None, device: Any = None) -> Tensor """Create tensor filled with ones (tuple size). Shape inference via meta-shape: torch.ones""" ... +@uses_shape_dsl(randn_ir) @overload def empty(*size: int, dtype: Any = None, device: Any = None) -> Tensor: """Create uninitialized tensor. Shape inference via meta-shape: torch.empty""" @@ -1259,11 +1388,13 @@ def empty(size: tuple[int, ...], dtype: Any = None, device: Any = None) -> Tenso """Create uninitialized tensor (tuple size). Shape inference via meta-shape: torch.empty""" ... +@uses_shape_dsl(randn_ir) def full(size: tuple[int, ...], fill_value: float) -> Tensor: """Create tensor filled with value. Shape inference via meta-shape: torch.full""" ... # arange overloads - Dim is compatible with int, so meta-shape handles both +@uses_shape_dsl(arange_ir) @overload def arange(end: int) -> Tensor: """Create 1D tensor with range [0, end). Shape inference via meta-shape: torch.arange""" @@ -1291,44 +1422,53 @@ def arange( """Create 1D tensor with range [start, end) with step. Shape inference via meta-shape: torch.arange""" ... +@uses_shape_dsl(linspace_ir) def linspace( start: float, end: float, steps: int, *, dtype: Any = None, device: Any = None ) -> Tensor: """Create 1D tensor with linearly spaced values. Shape inference via meta-shape: torch.linspace""" ... +@uses_shape_dsl(eye_ir) def eye(n: int) -> Tensor: """Create 2D identity matrix. Shape inference via meta-shape: torch.eye""" ... # ==== Shape Manipulation Functions ==== +@uses_shape_dsl(broadcast_to_ir) def broadcast_to(self: Tensor, shape: tuple[int, ...]) -> Tensor: """Broadcast tensor to shape. Shape inference via meta-shape: torch.broadcast_to""" ... +@uses_shape_dsl(tile_ir) def tile(self: Tensor, dims: tuple[int, ...]) -> Tensor: """Tile tensor by repeating. Shape inference via meta-shape: torch.tile""" ... +@uses_shape_dsl(select_ir) def select(self: Tensor, dim: int, index: int) -> Tensor: """Select along dimension. Shape inference via meta-shape: torch.select""" ... +@uses_shape_dsl(narrow_ir) def narrow(self: Tensor, dim: int, start: int, length: int) -> Tensor: """Narrow tensor along dimension. Shape inference via meta-shape: torch.narrow""" ... +@uses_shape_dsl(split_ir) def split( self: Tensor, split_size_or_sections: int, dim: int = 0 ) -> tuple[Tensor, ...]: """Split tensor into chunks. Shape inference via meta-shape: torch.split""" ... +@uses_shape_dsl(chunk_ir) def chunk(self: Tensor, chunks: int, dim: int = 0) -> tuple[Tensor, ...]: """Split tensor into chunks. Shape inference via meta-shape: torch.chunk""" ... +@uses_shape_dsl(index_select_ir) def index_select(self: Tensor, dim: int, index: Tensor) -> Tensor: """Select elements along dimension. Shape inference via meta-shape: torch.index_select""" ... @@ -1351,10 +1491,12 @@ def masked_select(self: Tensor, mask: Tensor) -> Tensor[Any]: # ==== Phase 1.1: Missing Shape Operations ==== +@uses_shape_dsl(unbind_ir) def unbind(self: Tensor, dim: int = 0) -> tuple[Tensor, ...]: """Remove dimension by slicing along it. Shape inference via meta-shape: torch.unbind""" ... +@uses_shape_dsl(movedim_ir) @overload def movedim(self: Tensor, source: int, destination: int) -> Tensor: """Move single dimension to new position. Shape inference via meta-shape: torch.movedim""" @@ -1367,6 +1509,7 @@ def movedim( """Move multiple dimensions to new positions. Shape inference via meta-shape: torch.movedim""" ... +@uses_shape_dsl(movedim_ir) @overload def moveaxis(self: Tensor, source: int, destination: int) -> Tensor: """Alias for movedim. Shape inference via meta-shape: torch.moveaxis""" @@ -1379,22 +1522,26 @@ def moveaxis( """Alias for movedim. Shape inference via meta-shape: torch.moveaxis""" ... +@uses_shape_dsl(unfold_ir) def unfold(self: Tensor, dimension: int, size: int, step: int) -> Tensor: """Returns sliding window view. Shape inference via meta-shape: torch.unfold""" ... # ==== Additional Reduction Functions ==== +@uses_shape_dsl(reduce_ir) def all(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Check if all elements are True. Shape inference via meta-shape: torch.all""" ... +@uses_shape_dsl(reduce_ir) def any(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Check if any element is True. Shape inference via meta-shape: torch.any""" ... # ==== Phase 1.2: Missing Reduction Operations ==== +@uses_shape_dsl(min_max_median_ir) @overload def median(self: Tensor) -> Tensor: """Median of all elements (scalar). Shape inference via meta-shape: torch.median""" @@ -1405,20 +1552,24 @@ def median(self: Tensor, dim: int, keepdim: bool = False) -> tuple[Tensor, Tenso """Median along dimension. Returns (values, indices). Shape inference via meta-shape: torch.median""" ... +@uses_shape_dsl(reduce_ir) def logsumexp(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Log-sum-exp along dimension(s). Shape inference via meta-shape: torch.logsumexp""" ... +@uses_shape_dsl(reduce_ir) def count_nonzero(self: Tensor, dim: int | None = None) -> Tensor: """Count non-zero elements. Shape inference via meta-shape: torch.count_nonzero""" ... +@uses_shape_dsl(aminmax_ir) def aminmax( self: Tensor, dim: int | None = None, keepdim: bool = False ) -> tuple[Tensor, Tensor]: """Min and max along dimension(s). Shape inference via meta-shape: torch.aminmax""" ... +@uses_shape_dsl(reduce_ir) def norm( self: Tensor, p: int | float = 2, @@ -1453,10 +1604,12 @@ def cummin[*Shape]( ... # Tier 2: Additional reduction operations (always return tuples) +@uses_shape_dsl(tuple_reduce_ir) def mode(self: Tensor, dim: int = -1, keepdim: bool = False) -> tuple[Tensor, Tensor]: """Mode along dimension. Returns (values, indices). Shape inference via meta-shape: torch.mode""" ... +@uses_shape_dsl(topk_ir) def topk( self: Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True ) -> tuple[Tensor, Tensor]: @@ -1469,6 +1622,7 @@ def sort[*Shape]( """Sort tensor. Returns (values, indices). Shape-preserving operation.""" ... +@uses_shape_dsl(tuple_reduce_ir) def kthvalue( self: Tensor, k: int, dim: int = -1, keepdim: bool = False ) -> tuple[Tensor, Tensor]: @@ -1476,6 +1630,7 @@ def kthvalue( ... # Tier 3: Statistical operations returning tuples +@uses_shape_dsl(aminmax_ir) def var_mean( self: Tensor, dim: int | tuple[int, ...] | None = None, @@ -1485,6 +1640,7 @@ def var_mean( """Variance and mean. Returns (var, mean). Shape inference via meta-shape: torch.var_mean""" ... +@uses_shape_dsl(aminmax_ir) def std_mean( self: Tensor, dim: int | tuple[int, ...] | None = None, @@ -1520,6 +1676,7 @@ def randn_like[*Shape](input: Tensor[*Shape]) -> Tensor[*Shape]: """Create random normal tensor with same shape. Shape inference via generic fixture signature.""" ... +@uses_shape_dsl(diag_embed_ir) def diag_embed(self: Tensor, offset: int = 0, dim1: int = -2, dim2: int = -1) -> Tensor: """Create diagonal tensor. Shape inference via meta-shape: torch.diag_embed""" ... @@ -1532,10 +1689,12 @@ def triu[*Shape](input: Tensor[*Shape], diagonal: int = 0) -> Tensor[*Shape]: """Upper triangular part. Shape inference via generic fixture signature.""" ... +@uses_shape_dsl(tri_indices_ir) def tril_indices(row: int, col: int, offset: int = 0) -> Tensor: """Indices of lower triangular part. Shape inference via meta-shape: torch.tril_indices""" ... +@uses_shape_dsl(tri_indices_ir) def triu_indices(row: int, col: int, offset: int = 0) -> Tensor: """Indices of upper triangular part. Shape inference via meta-shape: torch.triu_indices""" ... @@ -1553,6 +1712,7 @@ def bmm[B, N, K, M](input: Tensor[B, N, K], mat2: Tensor[B, K, M]) -> Tensor[B, """Batch matrix multiplication (3D @ 3D). Output: [B, N, M].""" ... +@uses_shape_dsl(mv_ir) def mv(self: Tensor, vec: Tensor) -> Tensor: """Matrix-vector multiplication (2D @ 1D). Shape inference via meta-shape: torch.mv""" ... @@ -1839,21 +1999,25 @@ def fmin[*Shape](input: Tensor[*Shape], other: Tensor) -> Tensor[*Shape]: # ============================================================================== # Advanced matmul operations +@uses_shape_dsl(tensordot_ir) def tensordot( self: Tensor, other: Tensor, dims: int | tuple[list[int], list[int]] = 2 ) -> Tensor: """Tensor contraction over specified dimensions. Shape inference via meta-shape: torch.tensordot""" ... +@uses_shape_dsl(einsum_ir) def einsum(spec: str, *operands: Tensor) -> Tensor: """Einstein summation convention. Shape inference via meta-shape: torch.einsum""" ... # Eigenvalue decomposition +@uses_shape_dsl(eig_ir) def eig(self: Tensor, eigenvectors: bool = False) -> tuple[Tensor, Tensor]: """Eigenvalue decomposition. Shape inference via meta-shape: torch.eig""" ... +@uses_shape_dsl(eig_ir) def eigh(self: Tensor, UPLO: str = "L") -> tuple[Tensor, Tensor]: """Hermitian eigenvalue decomposition. Shape inference via meta-shape: torch.eigh""" ... @@ -1864,18 +2028,22 @@ def cholesky[*Shape](input: Tensor[*Shape], upper: bool = False) -> Tensor[*Shap ... # Linear system solvers +@uses_shape_dsl(solve_ir) def solve(self: Tensor, other: Tensor) -> Tensor: """Solve linear system. Shape inference via meta-shape: torch.solve""" ... +@uses_shape_dsl(solve_reversed_ir) def triangular_solve(self: Tensor, other: Tensor, upper: bool = True) -> Tensor: """Solve triangular system. Shape inference via meta-shape: torch.triangular_solve""" ... +@uses_shape_dsl(solve_reversed_ir) def cholesky_solve(self: Tensor, other: Tensor, upper: bool = False) -> Tensor: """Solve using Cholesky. Shape inference via meta-shape: torch.cholesky_solve""" ... +@uses_shape_dsl(solve_ir) def lu_solve(self: Tensor, other: Tensor, LU_pivots: Tensor) -> Tensor: """Solve using LU decomposition. Shape inference via meta-shape: torch.lu_solve""" ... @@ -1894,6 +2062,7 @@ def logdet[*Batch, M, N](input: Tensor[*Batch, M, N]) -> Tensor[*Batch]: """Log determinant. Returns batch dimensions only (drops last 2 dims).""" ... +@uses_shape_dsl(slogdet_ir) def slogdet(self: Tensor) -> tuple[Tensor, Tensor]: """Sign and log determinant. Shape inference via meta-shape: torch.slogdet""" ... @@ -1924,6 +2093,7 @@ def matrix_rank[*Batch, M, N]( # ============================================================================== # Conditional operations +@uses_shape_dsl(where_ir) def where(condition: Tensor, x: Tensor, y: Tensor) -> Tensor: """Conditional element-wise selection. Shape inference via meta-shape: torch.where""" ... @@ -1973,6 +2143,7 @@ def take[*IndexShape](input: Tensor, index: Tensor[*IndexShape]) -> Tensor[*Inde """Take elements at indices. Output shape matches index shape.""" ... +@uses_shape_dsl(take_along_dim_ir) def take_along_dim(self: Tensor, indices: Tensor, dim: int) -> Tensor: """Take along dimension. Shape inference via meta-shape: torch.take_along_dim""" ... @@ -1992,10 +2163,12 @@ def bernoulli[*Shape](input: Tensor[*Shape], p: float = 0.5) -> Tensor[*Shape]: """Sample from Bernoulli distribution. Shape inference via generic fixture signature.""" ... +@uses_shape_dsl(multinomial_ir) def multinomial(self: Tensor, num_samples: int, replacement: bool = False) -> Tensor: """Sample from multinomial distribution. Shape inference via meta-shape: torch.multinomial""" ... +@uses_shape_dsl(normal_ir) @overload def normal(mean: Tensor, std: Tensor) -> Tensor: """Sample from normal distribution (tensor mean, tensor std). Shape inference via meta-shape: torch.normal""" @@ -2021,6 +2194,7 @@ def poisson[*Shape](input: Tensor[*Shape]) -> Tensor[*Shape]: ... # Tensor property functions +@uses_shape_dsl(numel_ir) def numel[*Dims](self: Tensor[*Dims]) -> int: """Number of elements. Shape inference via meta-shape: torch.numel""" ... @@ -2056,6 +2230,7 @@ def tensor( """Create tensor from data. Returns shapeless tensor (shape depends on input data).""" ... +@uses_shape_dsl(randint_ir) def randint( low: int, high: int, @@ -2076,6 +2251,7 @@ def rsqrt[*Shape](input: Tensor[*Shape]) -> Tensor[*Shape]: """Reciprocal square root (1/sqrt(x)). Shape-preserving element-wise operation.""" ... +@uses_shape_dsl(outer_ir) def outer(self: Tensor, vec2: Tensor) -> Tensor: """Outer product of two 1D tensors. Shape inference via meta-shape: torch.outer""" ... @@ -2142,6 +2318,7 @@ def cross[*B]( """Cross product of two tensors along a dimension of size 3.""" ... +@uses_shape_dsl(flatten_ir) def flatten( self: Tensor, start_dim: int = 0, diff --git a/test/tensor_shapes/fixtures/torch/fft.pyi b/test/tensor_shapes/fixtures/torch/fft.pyi index d03502ab22..64720f95ca 100644 --- a/test/tensor_shapes/fixtures/torch/fft.pyi +++ b/test/tensor_shapes/fixtures/torch/fft.pyi @@ -4,7 +4,9 @@ # LICENSE file in the root directory of this source tree. # Type stubs for torch.fft module (Phase 6: FFT Operations) +from shape_extensions import uses_shape_dsl from torch import Tensor +from torch._shapes import irfft_ir, rfft_ir # 1D FFT operations def fft[*Shape]( @@ -13,9 +15,13 @@ def fft[*Shape]( def ifft[*Shape]( input: Tensor[*Shape], n: int = None, dim: int = -1, norm: str = None ) -> Tensor[*Shape]: ... +@uses_shape_dsl(rfft_ir) def rfft(self: Tensor, n: int = None, dim: int = -1, norm: str = None) -> Tensor: ... +@uses_shape_dsl(irfft_ir) def irfft(self: Tensor, n: int = None, dim: int = -1, norm: str = None) -> Tensor: ... +@uses_shape_dsl(irfft_ir) def hfft(self: Tensor, n: int = None, dim: int = -1, norm: str = None) -> Tensor: ... +@uses_shape_dsl(rfft_ir) def ihfft(self: Tensor, n: int = None, dim: int = -1, norm: str = None) -> Tensor: ... # 2D FFT operations diff --git a/test/tensor_shapes/fixtures/torch/linalg.pyi b/test/tensor_shapes/fixtures/torch/linalg.pyi index 2154a55a84..73d9cdec8f 100644 --- a/test/tensor_shapes/fixtures/torch/linalg.pyi +++ b/test/tensor_shapes/fixtures/torch/linalg.pyi @@ -4,22 +4,31 @@ # LICENSE file in the root directory of this source tree. # Type stubs for torch.linalg module (Phase 4: Advanced Linear Algebra) +from shape_extensions import uses_shape_dsl from torch import Tensor +from torch._shapes import eig_ir, eigvals_ir, slogdet_ir, solve_ir, solve_reversed_ir # Eigenvalue decomposition +@uses_shape_dsl(eig_ir) def eig(self: Tensor) -> tuple[Tensor, Tensor]: ... +@uses_shape_dsl(eig_ir) def eigh(self: Tensor, UPLO: str = "L") -> tuple[Tensor, Tensor]: ... # Tier 3: Eigenvalues only (no eigenvectors) +@uses_shape_dsl(eigvals_ir) def eigvals(self: Tensor) -> Tensor: ... +@uses_shape_dsl(eigvals_ir) def eigvalsh(self: Tensor, UPLO: str = "L") -> Tensor: ... # Cholesky decomposition def cholesky[*Shape](input: Tensor[*Shape], upper: bool = False) -> Tensor[*Shape]: ... # Linear system solvers +@uses_shape_dsl(solve_ir) def solve(self: Tensor, other: Tensor) -> Tensor: ... +@uses_shape_dsl(solve_ir) def solve_triangular(self: Tensor, other: Tensor, upper: bool = False) -> Tensor: ... +@uses_shape_dsl(solve_reversed_ir) def cholesky_solve(self: Tensor, other: Tensor, upper: bool = False) -> Tensor: ... # Matrix inverse @@ -29,6 +38,7 @@ def inv[*Shape](input: Tensor[*Shape]) -> Tensor[*Shape]: ... def det[*Batch, M, N](input: Tensor[*Batch, M, N]) -> Tensor[*Batch]: ... # Sign and log determinant +@uses_shape_dsl(slogdet_ir) def slogdet(self: Tensor) -> tuple[Tensor, Tensor]: ... # Matrix power diff --git a/test/tensor_shapes/fixtures/torch/nn/functional.pyi b/test/tensor_shapes/fixtures/torch/nn/functional.pyi index 9198e9b808..04a648262b 100644 --- a/test/tensor_shapes/fixtures/torch/nn/functional.pyi +++ b/test/tensor_shapes/fixtures/torch/nn/functional.pyi @@ -10,6 +10,18 @@ Functional neural network operations including convolution, pooling, activation, from typing import Literal, overload +from shape_extensions import uses_shape_dsl +from torch._shapes import ( + adaptive_pool_ir, + conv_ir, + conv_transpose_ir, + cosine_similarity_ir, + interpolate_ir, + loss_ir, + pad_ir, + pool_ir, +) + from .. import Tensor __all__ = [ @@ -93,6 +105,7 @@ __all__ = [ # ==================================================================== # Convolution operations +@uses_shape_dsl(conv_ir) def conv1d( self: Tensor, weight: Tensor, @@ -105,6 +118,7 @@ def conv1d( """1D convolution. Shape inference via meta-shape: torch.nn.functional.conv1d""" ... +@uses_shape_dsl(conv_ir) def conv2d( self: Tensor, weight: Tensor, @@ -117,6 +131,7 @@ def conv2d( """2D convolution. Shape inference via meta-shape: torch.nn.functional.conv2d""" ... +@uses_shape_dsl(conv_ir) def conv3d( self: Tensor, weight: Tensor, @@ -130,6 +145,7 @@ def conv3d( ... # Transposed convolution operations +@uses_shape_dsl(conv_transpose_ir) def conv_transpose1d( self: Tensor, weight: Tensor, @@ -143,6 +159,7 @@ def conv_transpose1d( """1D transposed convolution. Shape inference via meta-shape: torch.nn.functional.conv_transpose1d""" ... +@uses_shape_dsl(conv_transpose_ir) def conv_transpose2d( self: Tensor, weight: Tensor, @@ -156,6 +173,7 @@ def conv_transpose2d( """2D transposed convolution. Shape inference via meta-shape: torch.nn.functional.conv_transpose2d""" ... +@uses_shape_dsl(conv_transpose_ir) def conv_transpose3d( self: Tensor, weight: Tensor, @@ -170,6 +188,7 @@ def conv_transpose3d( ... # Max pooling operations +@uses_shape_dsl(pool_ir) @overload def max_pool1d( self: Tensor, @@ -196,6 +215,7 @@ def max_pool1d( """1D max pooling with indices. Shape inference via meta-shape: torch.nn.functional.max_pool1d""" ... +@uses_shape_dsl(pool_ir) @overload def max_pool2d( self: Tensor, @@ -222,6 +242,7 @@ def max_pool2d( """2D max pooling with indices. Shape inference via meta-shape: torch.nn.functional.max_pool2d""" ... +@uses_shape_dsl(pool_ir) @overload def max_pool3d( self: Tensor, @@ -249,6 +270,7 @@ def max_pool3d( ... # Average pooling operations +@uses_shape_dsl(pool_ir) def avg_pool1d( self: Tensor, kernel_size: int | tuple[int], @@ -260,6 +282,7 @@ def avg_pool1d( """1D average pooling. Shape inference via meta-shape: torch.nn.functional.avg_pool1d""" ... +@uses_shape_dsl(pool_ir) def avg_pool2d( self: Tensor, kernel_size: int | tuple[int, int], @@ -272,6 +295,7 @@ def avg_pool2d( """2D average pooling. Shape inference via meta-shape: torch.nn.functional.avg_pool2d""" ... +@uses_shape_dsl(pool_ir) def avg_pool3d( self: Tensor, kernel_size: int | tuple[int, int, int], @@ -285,12 +309,14 @@ def avg_pool3d( ... # Adaptive max pooling operations +@uses_shape_dsl(adaptive_pool_ir) def adaptive_max_pool1d( self: Tensor, output_size: int | tuple[int], return_indices: bool = False ) -> Tensor: """1D adaptive max pooling. Shape inference via meta-shape: torch.nn.functional.adaptive_max_pool1d""" ... +@uses_shape_dsl(adaptive_pool_ir) def adaptive_max_pool2d( self: Tensor, output_size: int | tuple[int, int] | None, @@ -299,6 +325,7 @@ def adaptive_max_pool2d( """2D adaptive max pooling. Shape inference via meta-shape: torch.nn.functional.adaptive_max_pool2d""" ... +@uses_shape_dsl(adaptive_pool_ir) def adaptive_max_pool3d( self: Tensor, output_size: int | tuple[int, int, int] | None, @@ -308,16 +335,19 @@ def adaptive_max_pool3d( ... # Adaptive average pooling operations +@uses_shape_dsl(adaptive_pool_ir) def adaptive_avg_pool1d(self: Tensor, output_size: int | tuple[int]) -> Tensor: """1D adaptive average pooling. Shape inference via meta-shape: torch.nn.functional.adaptive_avg_pool1d""" ... +@uses_shape_dsl(adaptive_pool_ir) def adaptive_avg_pool2d( self: Tensor, output_size: int | tuple[int, int] | None ) -> Tensor: """2D adaptive average pooling. Shape inference via meta-shape: torch.nn.functional.adaptive_avg_pool2d""" ... +@uses_shape_dsl(adaptive_pool_ir) def adaptive_avg_pool3d( self: Tensor, output_size: int | tuple[int, int, int] | None ) -> Tensor: @@ -325,6 +355,7 @@ def adaptive_avg_pool3d( ... # Interpolation/upsampling operations +@uses_shape_dsl(interpolate_ir) def interpolate( self: Tensor, size: int | tuple[int, ...] | None = None, @@ -337,6 +368,7 @@ def interpolate( """Interpolate/upsample tensor. Shape inference via meta-shape: torch.nn.functional.interpolate""" ... +@uses_shape_dsl(interpolate_ir) def upsample( self: Tensor, size: int | tuple[int, ...] | None = None, @@ -552,6 +584,7 @@ def logsigmoid[*Shape](input: Tensor[*Shape]) -> Tensor[*Shape]: # Phase 6: Loss Functions # ============================================================================== +@uses_shape_dsl(loss_ir) def mse_loss( self: Tensor, target: Tensor, @@ -562,6 +595,7 @@ def mse_loss( """Mean squared error loss. Shape inference via meta-shape: torch.nn.functional.mse_loss""" ... +@uses_shape_dsl(loss_ir) def l1_loss( self: Tensor, target: Tensor, @@ -572,6 +606,7 @@ def l1_loss( """L1 loss. Shape inference via meta-shape: torch.nn.functional.l1_loss""" ... +@uses_shape_dsl(loss_ir) def nll_loss( self: Tensor, target: Tensor, @@ -584,6 +619,7 @@ def nll_loss( """Negative log likelihood loss. Shape inference via meta-shape: torch.nn.functional.nll_loss""" ... +@uses_shape_dsl(loss_ir) def cross_entropy( self: Tensor, target: Tensor, @@ -597,6 +633,7 @@ def cross_entropy( """Cross entropy loss. Shape inference via meta-shape: torch.nn.functional.cross_entropy""" ... +@uses_shape_dsl(loss_ir) def binary_cross_entropy( self: Tensor, target: Tensor, @@ -608,6 +645,7 @@ def binary_cross_entropy( """Binary cross entropy loss. Shape inference via meta-shape: torch.nn.functional.binary_cross_entropy""" ... +@uses_shape_dsl(loss_ir) def binary_cross_entropy_with_logits( self: Tensor, target: Tensor, @@ -620,6 +658,7 @@ def binary_cross_entropy_with_logits( """Binary cross entropy with logits. Shape inference via meta-shape: torch.nn.functional.binary_cross_entropy_with_logits""" ... +@uses_shape_dsl(loss_ir) def kl_div( self: Tensor, target: Tensor, @@ -631,6 +670,7 @@ def kl_div( """KL divergence loss. Shape inference via meta-shape: torch.nn.functional.kl_div""" ... +@uses_shape_dsl(loss_ir) def smooth_l1_loss( self: Tensor, target: Tensor, @@ -642,12 +682,14 @@ def smooth_l1_loss( """Smooth L1 loss. Shape inference via meta-shape: torch.nn.functional.smooth_l1_loss""" ... +@uses_shape_dsl(loss_ir) def huber_loss( self: Tensor, target: Tensor, reduction: str = "mean", delta: float = 1.0 ) -> Tensor: """Huber loss. Shape inference via meta-shape: torch.nn.functional.huber_loss""" ... +@uses_shape_dsl(loss_ir) def poisson_nll_loss( self: Tensor, target: Tensor, @@ -661,6 +703,7 @@ def poisson_nll_loss( """Poisson NLL loss. Shape inference via meta-shape: torch.nn.functional.poisson_nll_loss""" ... +@uses_shape_dsl(loss_ir) def cosine_embedding_loss( self: Tensor, input2: Tensor, @@ -673,6 +716,7 @@ def cosine_embedding_loss( """Cosine embedding loss. Shape inference via meta-shape: torch.nn.functional.cosine_embedding_loss""" ... +@uses_shape_dsl(loss_ir) def margin_ranking_loss( self: Tensor, input2: Tensor, @@ -685,6 +729,7 @@ def margin_ranking_loss( """Margin ranking loss. Shape inference via meta-shape: torch.nn.functional.margin_ranking_loss""" ... +@uses_shape_dsl(loss_ir) def triplet_margin_loss( self: Tensor, positive: Tensor, @@ -700,6 +745,7 @@ def triplet_margin_loss( """Triplet margin loss. Shape inference via meta-shape: torch.nn.functional.triplet_margin_loss""" ... +@uses_shape_dsl(loss_ir) def hinge_embedding_loss( self: Tensor, target: Tensor, @@ -712,6 +758,7 @@ def hinge_embedding_loss( ... # Padding operation +@uses_shape_dsl(pad_ir) def pad( self: Tensor, pad: tuple[int, ...], mode: str = "constant", value: float = 0.0 ) -> Tensor: @@ -829,6 +876,7 @@ def scaled_dot_product_attention[ """Scaled dot product attention. Shape inference via meta-shape: torch.nn.functional.scaled_dot_product_attention""" ... +@uses_shape_dsl(cosine_similarity_ir) def cosine_similarity( x1: Tensor, x2: Tensor, dim: int = 1, eps: float = 1e-8 ) -> Tensor: From 56d9837c4fbbf87205b9e2de6f2f0935e0edce8f Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 19/25] Add `capture_init` annotations to nn.Module forward methods MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Add `uses_shape_dsl(ir_fn, capture_init=[...])` decorators to the `forward` methods of all 15 nn.Module classes that have shape-aware forward inference: MaxPool1d/2d/3d, AvgPool1d/2d/3d, Flatten, PixelShuffle, GLU, LSTM, Upsample, GRU, LSTMCell, ReflectionPad2d, ReplicationPad2d. The `capture_init` plumbing (binding extraction, ClassMetadata propagation, and `maybe_wrap_nn_module` consumption) was completed in Phases 3a and 4c. This commit is purely fixture annotations — once added, the decorators override the registry path for each annotated class. Differential Revision: D105775129 --- .../fixtures/torch/nn/__init__.pyi | 50 ++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/test/tensor_shapes/fixtures/torch/nn/__init__.pyi b/test/tensor_shapes/fixtures/torch/nn/__init__.pyi index a8fee4f2c1..84236a48db 100644 --- a/test/tensor_shapes/fixtures/torch/nn/__init__.pyi +++ b/test/tensor_shapes/fixtures/torch/nn/__init__.pyi @@ -21,8 +21,20 @@ from typing import ( ) if TYPE_CHECKING: - from shape_extensions import Dim as _Dim + from shape_extensions import Dim as _Dim, uses_shape_dsl from torch import Tensor + from torch._shapes import ( + nn_avgpool_forward_ir, + nn_flatten_forward_ir, + nn_glu_forward_ir, + nn_gru_forward_ir, + nn_lstm_forward_ir, + nn_lstmcell_forward_ir, + nn_maxpool_forward_ir, + nn_pixel_shuffle_forward_ir, + nn_reflectionpad2d_forward_ir, + nn_upsample_forward_ir, + ) # Re-export submodules from . import functional as functional, init as init @@ -675,6 +687,10 @@ class MaxPool1d(Module): return_indices: bool = False, ceil_mode: bool = False, ) -> None: ... + @uses_shape_dsl( + nn_maxpool_forward_ir, + capture_init=["kernel_size", "stride", "padding", "dilation"], + ) def forward(self, input: Tensor) -> Tensor: ... class MaxPool2d(Module): @@ -688,6 +704,10 @@ class MaxPool2d(Module): return_indices: bool = False, ceil_mode: bool = False, ) -> None: ... + @uses_shape_dsl( + nn_maxpool_forward_ir, + capture_init=["kernel_size", "stride", "padding", "dilation"], + ) def forward(self, input: Tensor) -> Tensor: ... class MaxPool3d(Module): @@ -701,6 +721,10 @@ class MaxPool3d(Module): return_indices: bool = False, ceil_mode: bool = False, ) -> None: ... + @uses_shape_dsl( + nn_maxpool_forward_ir, + capture_init=["kernel_size", "stride", "padding", "dilation"], + ) def forward(self, input: Tensor) -> Tensor: ... class AvgPool1d(Module): @@ -713,6 +737,9 @@ class AvgPool1d(Module): ceil_mode: bool = False, count_include_pad: bool = True, ) -> None: ... + @uses_shape_dsl( + nn_avgpool_forward_ir, capture_init=["kernel_size", "stride", "padding"] + ) def forward(self, input: Tensor) -> Tensor: ... class AvgPool2d(Module): @@ -726,6 +753,9 @@ class AvgPool2d(Module): count_include_pad: bool = True, divisor_override: int | None = None, ) -> None: ... + @uses_shape_dsl( + nn_avgpool_forward_ir, capture_init=["kernel_size", "stride", "padding"] + ) def forward(self, input: Tensor) -> Tensor: ... class AvgPool3d(Module): @@ -739,6 +769,9 @@ class AvgPool3d(Module): count_include_pad: bool = True, divisor_override: int | None = None, ) -> None: ... + @uses_shape_dsl( + nn_avgpool_forward_ir, capture_init=["kernel_size", "stride", "padding"] + ) def forward(self, input: Tensor) -> Tensor: ... class AdaptiveAvgPool1d[OL](Module): @@ -794,6 +827,7 @@ class PixelShuffle(Module): """ def __init__(self, upscale_factor: int) -> None: ... + @uses_shape_dsl(nn_pixel_shuffle_forward_ir, capture_init=["upscale_factor"]) def forward(self, input: Tensor) -> Tensor: ... class GLU(Module): @@ -806,6 +840,7 @@ class GLU(Module): """ def __init__(self, dim: int = 1) -> None: ... + @uses_shape_dsl(nn_glu_forward_ir, capture_init=["dim"]) def forward(self, input: Tensor) -> Tensor: ... class LSTM(Module): @@ -834,6 +869,10 @@ class LSTM(Module): def flatten_parameters(self) -> None: """Reset parameter data pointer for CUDA contiguous memory. No-op on CPU.""" ... + @uses_shape_dsl( + nn_lstm_forward_ir, + capture_init=["input_size", "hidden_size", "num_layers", "bidirectional"], + ) def forward(self, input: Tensor) -> tuple[Tensor, Tensor, Tensor]: ... class LSTMCell(Module): @@ -853,6 +892,7 @@ class LSTMCell(Module): device: Any = None, dtype: Any = None, ) -> None: ... + @uses_shape_dsl(nn_lstmcell_forward_ir, capture_init=["input_size", "hidden_size"]) def forward( self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None ) -> tuple[Tensor, Tensor]: ... @@ -882,6 +922,10 @@ class GRU(Module): def flatten_parameters(self) -> None: """Reset parameter data pointer for CUDA contiguous memory. No-op on CPU.""" ... + @uses_shape_dsl( + nn_gru_forward_ir, + capture_init=["input_size", "hidden_size", "num_layers", "bidirectional"], + ) def forward( self, input: Tensor, hx: Tensor | None = None ) -> tuple[Tensor, Tensor]: ... @@ -920,6 +964,7 @@ class Upsample(Module): mode: str = "nearest", align_corners: bool | None = None, ) -> None: ... + @uses_shape_dsl(nn_upsample_forward_ir, capture_init=["size", "scale_factor"]) def forward(self, input: Tensor) -> Tensor: ... # ============================================================================== @@ -1073,6 +1118,7 @@ class Flatten(Module): """ def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: ... + @uses_shape_dsl(nn_flatten_forward_ir, capture_init=["start_dim", "end_dim"]) def forward(self, input: Tensor) -> Tensor: ... class Unflatten(Module): @@ -1087,6 +1133,7 @@ class ReflectionPad2d(Module): """ def __init__(self, padding: int) -> None: ... + @uses_shape_dsl(nn_reflectionpad2d_forward_ir, capture_init=["padding"]) def forward(self, input: Tensor) -> Tensor: ... class ReplicationPad2d(Module): @@ -1096,6 +1143,7 @@ class ReplicationPad2d(Module): """ def __init__(self, padding: int) -> None: ... + @uses_shape_dsl(nn_reflectionpad2d_forward_ir, capture_init=["padding"]) def forward(self, input: Tensor) -> Tensor: ... # Embedding variants From 4e4e795ce3823f1e8cd1ffe0f79b40434140645d Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 20/25] Delete `TensorOpsRegistry` and all legacy fallback paths MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Remove the old hardcoded registry now that all ~131 shape functions and 15 nn.Module classes are wired through `uses_shape_dsl` decorators. Deleted: - `tensor_ops_registry.rs` (1,022 lines): `DSL_SOURCE` string, registry struct, ~131 `register*()` calls, `OnceLock` statics - `lookup_meta_shape` in `callable.rs`: registry-based shape function lookup - Registry fallback in `callable_infer_inner`: now decorator-only - Registry fallback in `maybe_wrap_nn_module`: now ClassMetadata-only - Phase 2 unit test in `meta_shape_dsl.rs`: superseded by e2e tests The `tensor_shapes` config flag remains — it gates non-registry behavior (Tensor subscript support, jaxtyping, operator overloads). Differential Revision: D105775132 --- crates/pyrefly_types/src/lib.rs | 1 - crates/pyrefly_types/src/meta_shape_dsl.rs | 43 - .../pyrefly_types/src/tensor_ops_registry.rs | 1022 ----------------- pyrefly/lib/alt/call.rs | 17 +- pyrefly/lib/alt/callable.rs | 37 +- 5 files changed, 2 insertions(+), 1118 deletions(-) delete mode 100644 crates/pyrefly_types/src/tensor_ops_registry.rs diff --git a/crates/pyrefly_types/src/lib.rs b/crates/pyrefly_types/src/lib.rs index 2b6607c4ca..83c7c4ff21 100644 --- a/crates/pyrefly_types/src/lib.rs +++ b/crates/pyrefly_types/src/lib.rs @@ -41,7 +41,6 @@ pub mod simplify; pub mod special_form; pub mod stdlib; pub mod tensor; -pub mod tensor_ops_registry; pub mod tuple; pub mod type_alias; pub mod type_info; diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index 712653ed18..498ebcaef4 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -3334,46 +3334,3 @@ pub fn make_meta_shape_function( ); Some(Box::new(DslMetaShapeFunction { fn_def, fn_lookup })) } - -// TODO: Remove this unit test once the DSL is fully in stubs. -// The e2e tests in pyrefly/test/tensor_shapes will exercise the same code -// paths more thoroughly, making this redundant. -#[cfg(test)] -mod tests { - use pyrefly_python::ast::Ast; - use ruff_python_ast::PySourceType; - use ruff_python_ast::Stmt; - - use super::*; - - /// Sanity check: the public wrapper API composes from AST through to a - /// `MetaShapeFunction` without panicking on a trivial well-formed DSL - /// function. Does not evaluate the DSL — only verifies the surface - /// composes. - #[test] - fn wrapper_api_composes_on_trivial_function() { - let source = "def add_one(x: int) -> int:\n return x + 1\n"; - let (module, errors, _unsupported) = Ast::parse(source, PySourceType::Python); - assert!(errors.is_empty(), "test DSL fixture should parse cleanly"); - let func = module - .body - .iter() - .find_map(|stmt| match stmt { - Stmt::FunctionDef(f) => Some(f), - _ => None, - }) - .expect("test source contains exactly one function def"); - - let shape_fn = convert_shape_dsl_function(func).expect("lowering succeeds"); - let program = build_shape_dsl_program(std::iter::once(shape_fn)); - let meta_fn = make_meta_shape_function(&program, "add_one"); - assert!( - meta_fn.is_some(), - "factory should resolve a function whose name matches the program" - ); - assert!( - make_meta_shape_function(&program, "no_such_function").is_none(), - "factory should return None for an unknown name" - ); - } -} diff --git a/crates/pyrefly_types/src/tensor_ops_registry.rs b/crates/pyrefly_types/src/tensor_ops_registry.rs deleted file mode 100644 index 8fb6d9b2f7..0000000000 --- a/crates/pyrefly_types/src/tensor_ops_registry.rs +++ /dev/null @@ -1,1022 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -//! Meta-shape op registry. -//! -//! This module registers all PyTorch op shape functions in `TensorOpsRegistry`. -//! All op definitions are expressed in the DSL (parsed by `meta_shape_dsl.rs`) and -//! interpreted directly — no IR layer. - -use std::collections::HashMap; - -use pyrefly_python::ast::Ast; -use ruff_python_ast::PySourceType; -use ruff_python_ast::Stmt; - -use crate::meta_shape_dsl::MetaShapeFunction; -use crate::meta_shape_dsl::ShapeDslProgram; -use crate::meta_shape_dsl::build_shape_dsl_program; -use crate::meta_shape_dsl::convert_shape_dsl_function; -use crate::meta_shape_dsl::make_meta_shape_function; - -// Section: DSL-based MetaShapeFunction construction - -/// Look up a DSL function by name in the shared program and wrap it as a -/// `MetaShapeFunction` suitable for registration. Panics if the program does -/// not contain a function with that name — the bundled `DSL_SOURCE` is -/// fixed, so a missing name is a programming error in this file. -fn dsl_fn(program: &ShapeDslProgram, name: &str) -> Box { - make_meta_shape_function(program, name) - .unwrap_or_else(|| panic!("DSL function `{name}` not found")) -} - -// Section: Meta-Shape Registry - -/// Registry mapping PyTorch op names to their shape functions. -/// -/// All shape functions are backed by DSL definitions (see `meta_shape_dsl.rs`). -/// The DSL source is parsed once at registry construction time. Parsed definitions -/// are shared via `Arc` across all shape function instances. -pub struct TensorOpsRegistry { - functions: HashMap>, - /// Maps qualified class names (e.g., "torch.nn.MaxPool2d") to the list of - /// __init__ parameter names to capture. When a class has an init capture - /// registered, `construct_class` builds a `Type::NNModule` instead of a - /// `Type::ClassType`, storing the captured arg values in the NNModule's - /// field map. This allows forward DSL functions to access constructor - /// parameters directly from the type. - init_captures: HashMap>, -} - -impl TensorOpsRegistry { - /// Create a new registry with built-in meta-shape functions. - pub fn new() -> Self { - // Parse DSL_SOURCE to AST, then lower each function via the public - // `convert_shape_dsl_function` wrapper and bundle the results into a - // type-checked `ShapeDslProgram`. Sharing one `ShapeDslProgram` - // across all registrations means the underlying `Arc` - // graph is built once and re-used (each `make_meta_shape_function` - // call just clones a couple of `Arc`s). - let (module, errors, _unsupported) = Ast::parse(DSL_SOURCE, PySourceType::Python); - assert!( - errors.is_empty(), - "DSL source in tensor_ops_registry.rs has parse errors: {errors:?}" - ); - let shape_fns = module.body.iter().filter_map(|stmt| match stmt { - Stmt::FunctionDef(f) => Some( - convert_shape_dsl_function(f) - .expect("DSL source in tensor_ops_registry.rs has errors"), - ), - _ => None, - }); - let program = build_shape_dsl_program(shape_fns); - let mut registry = Self { - functions: HashMap::new(), - init_captures: HashMap::new(), - }; - - // Shape manipulation - registry.register_dual("reshape", || dsl_fn(&program, "reshape_ir")); - registry.register("torch.cat", dsl_fn(&program, "cat_ir")); - registry.register("torch.broadcast_to", dsl_fn(&program, "broadcast_to_ir")); - registry.register_dual("squeeze", || dsl_fn(&program, "squeeze_ir")); - registry.register_dual("unsqueeze", || dsl_fn(&program, "unsqueeze_ir")); - registry.register_dual("transpose", || dsl_fn(&program, "transpose_ir")); - // torch.permute takes dims as a tuple; Tensor.permute takes *dims (variadic). - // Both use the same DSL function; parameter binding matches by name. - registry.register("torch.permute", dsl_fn(&program, "permute_ir")); - registry.register("torch.Tensor.permute", dsl_fn(&program, "permute_ir")); - registry.register("torch.flatten", dsl_fn(&program, "flatten_ir")); - registry.register("torch.stack", dsl_fn(&program, "stack_ir")); - registry.register("torch.tile", dsl_fn(&program, "tile_ir")); - registry.register("torch.view", dsl_fn(&program, "reshape_ir")); - registry.register("torch.unbind", dsl_fn(&program, "unbind_ir")); - registry.register("torch.Tensor.unbind", dsl_fn(&program, "unbind_ir")); - registry.register("torch.movedim", dsl_fn(&program, "movedim_ir")); - registry.register("torch.moveaxis", dsl_fn(&program, "movedim_ir")); - registry.register("torch.Tensor.movedim", dsl_fn(&program, "movedim_ir")); - registry.register("torch.Tensor.moveaxis", dsl_fn(&program, "movedim_ir")); - registry.register("torch.unfold", dsl_fn(&program, "unfold_ir")); - registry.register("torch.Tensor.unfold", dsl_fn(&program, "unfold_ir")); - - // Method-only shape manipulation - registry.register("torch.Tensor.reshape", dsl_fn(&program, "reshape_ir")); - registry.register("torch.Tensor.view", dsl_fn(&program, "reshape_ir")); - registry.register("torch.Tensor.squeeze", dsl_fn(&program, "squeeze_ir")); - registry.register("torch.Tensor.flatten", dsl_fn(&program, "flatten_ir")); - registry.register("torch.Tensor.tile", dsl_fn(&program, "tile_ir")); - registry.register("torch.Tensor.diag_embed", dsl_fn(&program, "diag_embed_ir")); - registry.register("torch.Tensor.repeat", dsl_fn(&program, "repeat_ir")); - registry.register("torch.Tensor.expand", dsl_fn(&program, "expand_ir")); - - // Reduction operations - registry.register_dual("sum", || dsl_fn(&program, "reduce_ir")); - registry.register_dual("mean", || dsl_fn(&program, "reduce_ir")); - registry.register_dual("prod", || dsl_fn(&program, "reduce_ir")); - registry.register_dual("min", || dsl_fn(&program, "min_max_median_ir")); - registry.register_dual("max", || dsl_fn(&program, "min_max_median_ir")); - registry.register_dual("all", || dsl_fn(&program, "reduce_ir")); - registry.register_dual("any", || dsl_fn(&program, "reduce_ir")); - registry.register_dual("std", || dsl_fn(&program, "reduce_ir")); - registry.register_dual("var", || dsl_fn(&program, "reduce_ir")); - registry.register_dual("argmax", || dsl_fn(&program, "reduce_ir")); - registry.register_dual("argmin", || dsl_fn(&program, "reduce_ir")); - registry.register("torch.median", dsl_fn(&program, "min_max_median_ir")); - registry.register("torch.logsumexp", dsl_fn(&program, "reduce_ir")); - registry.register("torch.count_nonzero", dsl_fn(&program, "reduce_ir")); - registry.register("torch.aminmax", dsl_fn(&program, "aminmax_ir")); - registry.register("torch.norm", dsl_fn(&program, "reduce_ir")); - registry.register("torch.mode", dsl_fn(&program, "tuple_reduce_ir")); - registry.register("torch.topk", dsl_fn(&program, "topk_ir")); - registry.register("torch.kthvalue", dsl_fn(&program, "tuple_reduce_ir")); - registry.register("torch.var_mean", dsl_fn(&program, "aminmax_ir")); - registry.register("torch.std_mean", dsl_fn(&program, "aminmax_ir")); - - // Reduction method versions - registry.register("torch.Tensor.median", dsl_fn(&program, "min_max_median_ir")); - registry.register("torch.Tensor.logsumexp", dsl_fn(&program, "reduce_ir")); - registry.register("torch.Tensor.count_nonzero", dsl_fn(&program, "reduce_ir")); - registry.register("torch.Tensor.aminmax", dsl_fn(&program, "aminmax_ir")); - registry.register("torch.Tensor.norm", dsl_fn(&program, "reduce_ir")); - registry.register("torch.Tensor.mode", dsl_fn(&program, "tuple_reduce_ir")); - registry.register("torch.Tensor.topk", dsl_fn(&program, "topk_ir")); - registry.register("torch.Tensor.kthvalue", dsl_fn(&program, "tuple_reduce_ir")); - - // Repeat interleave - registry.register( - "torch.Tensor.repeat_interleave", - dsl_fn(&program, "repeat_interleave_ir"), - ); - registry.register( - "torch.repeat_interleave", - dsl_fn(&program, "repeat_interleave_ir"), - ); - - // Cosine similarity (reduces one dim) - registry.register( - "torch.nn.functional.cosine_similarity", - dsl_fn(&program, "cosine_similarity_ir"), - ); - - // Indexing/slicing - registry.register("torch.select", dsl_fn(&program, "select_ir")); - registry.register("torch.narrow", dsl_fn(&program, "narrow_ir")); - registry.register("torch.split", dsl_fn(&program, "split_ir")); - registry.register("torch.chunk", dsl_fn(&program, "chunk_ir")); - registry.register("torch.index_select", dsl_fn(&program, "index_select_ir")); - registry.register("torch.Tensor.select", dsl_fn(&program, "select_ir")); - registry.register("torch.Tensor.narrow", dsl_fn(&program, "narrow_ir")); - registry.register("torch.Tensor.split", dsl_fn(&program, "split_ir")); - registry.register("torch.Tensor.chunk", dsl_fn(&program, "chunk_ir")); - registry.register( - "torch.Tensor.index_select", - dsl_fn(&program, "index_select_ir"), - ); - - // Tensor creation - registry.register("torch.randn", dsl_fn(&program, "randn_ir")); - registry.register("torch.rand", dsl_fn(&program, "randn_ir")); - registry.register("torch.zeros", dsl_fn(&program, "randn_ir")); - registry.register("torch.ones", dsl_fn(&program, "randn_ir")); - registry.register("torch.empty", dsl_fn(&program, "randn_ir")); - registry.register("torch.full", dsl_fn(&program, "randn_ir")); - registry.register("torch.randint", dsl_fn(&program, "randint_ir")); - registry.register("torch.arange", dsl_fn(&program, "arange_ir")); - registry.register("torch.linspace", dsl_fn(&program, "linspace_ir")); - registry.register("torch.eye", dsl_fn(&program, "eye_ir")); - registry.register("torch.diag_embed", dsl_fn(&program, "diag_embed_ir")); - registry.register("torch.tril_indices", dsl_fn(&program, "tri_indices_ir")); - registry.register("torch.triu_indices", dsl_fn(&program, "tri_indices_ir")); - - // Linear algebra - registry.register("torch.matmul", dsl_fn(&program, "matmul_ir")); - registry.register("torch.mv", dsl_fn(&program, "mv_ir")); - registry.register("torch.outer", dsl_fn(&program, "outer_ir")); - registry.register("torch.tensordot", dsl_fn(&program, "tensordot_ir")); - registry.register("torch.einsum", dsl_fn(&program, "einsum_ir")); - registry.register("torch.Tensor.matmul", dsl_fn(&program, "matmul_ir")); - registry.register("torch.Tensor.__matmul__", dsl_fn(&program, "matmul_ir")); - registry.register("torch.Tensor.mv", dsl_fn(&program, "mv_ir")); - - // Eigenvalue decomposition - registry.register("torch.linalg.eig", dsl_fn(&program, "eig_ir")); - registry.register("torch.eig", dsl_fn(&program, "eig_ir")); - registry.register("torch.linalg.eigh", dsl_fn(&program, "eig_ir")); - registry.register("torch.eigh", dsl_fn(&program, "eig_ir")); - registry.register("torch.linalg.eigvals", dsl_fn(&program, "eigvals_ir")); - registry.register("torch.linalg.eigvalsh", dsl_fn(&program, "eigvals_ir")); - - // Linear solvers - registry.register("torch.linalg.solve", dsl_fn(&program, "solve_ir")); - registry.register("torch.solve", dsl_fn(&program, "solve_ir")); - registry.register( - "torch.linalg.solve_triangular", - dsl_fn(&program, "solve_ir"), - ); - registry.register( - "torch.triangular_solve", - dsl_fn(&program, "solve_reversed_ir"), - ); - registry.register( - "torch.linalg.cholesky_solve", - dsl_fn(&program, "solve_reversed_ir"), - ); - registry.register( - "torch.cholesky_solve", - dsl_fn(&program, "solve_reversed_ir"), - ); - registry.register("torch.lu_solve", dsl_fn(&program, "solve_ir")); - - // Determinant - registry.register("torch.linalg.slogdet", dsl_fn(&program, "slogdet_ir")); - registry.register("torch.slogdet", dsl_fn(&program, "slogdet_ir")); - registry.register("torch.Tensor.slogdet", dsl_fn(&program, "slogdet_ir")); - - // Convolution - registry.register("torch.nn.functional.conv1d", dsl_fn(&program, "conv_ir")); - registry.register("torch.nn.functional.conv2d", dsl_fn(&program, "conv_ir")); - registry.register("torch.nn.functional.conv3d", dsl_fn(&program, "conv_ir")); - registry.register( - "torch.nn.functional.conv_transpose1d", - dsl_fn(&program, "conv_transpose_ir"), - ); - registry.register( - "torch.nn.functional.conv_transpose2d", - dsl_fn(&program, "conv_transpose_ir"), - ); - registry.register( - "torch.nn.functional.conv_transpose3d", - dsl_fn(&program, "conv_transpose_ir"), - ); - - // Pooling - registry.register( - "torch.nn.functional.max_pool1d", - dsl_fn(&program, "pool_ir"), - ); - registry.register( - "torch.nn.functional.max_pool2d", - dsl_fn(&program, "pool_ir"), - ); - registry.register( - "torch.nn.functional.max_pool3d", - dsl_fn(&program, "pool_ir"), - ); - registry.register( - "torch.nn.functional.avg_pool1d", - dsl_fn(&program, "pool_ir"), - ); - registry.register( - "torch.nn.functional.avg_pool2d", - dsl_fn(&program, "pool_ir"), - ); - registry.register( - "torch.nn.functional.avg_pool3d", - dsl_fn(&program, "pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_max_pool1d", - dsl_fn(&program, "adaptive_pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_max_pool2d", - dsl_fn(&program, "adaptive_pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_max_pool3d", - dsl_fn(&program, "adaptive_pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_avg_pool1d", - dsl_fn(&program, "adaptive_pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_avg_pool2d", - dsl_fn(&program, "adaptive_pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_avg_pool3d", - dsl_fn(&program, "adaptive_pool_ir"), - ); - - // Interpolation - registry.register( - "torch.nn.functional.interpolate", - dsl_fn(&program, "interpolate_ir"), - ); - registry.register( - "torch.nn.functional.upsample", - dsl_fn(&program, "interpolate_ir"), - ); - - // Conditional operations - registry.register("torch.where", dsl_fn(&program, "where_ir")); - registry.register( - "torch.take_along_dim", - dsl_fn(&program, "take_along_dim_ir"), - ); - registry.register( - "torch.Tensor.take_along_dim", - dsl_fn(&program, "take_along_dim_ir"), - ); - - // Loss functions - registry.register("torch.nn.functional.mse_loss", dsl_fn(&program, "loss_ir")); - registry.register("torch.nn.functional.l1_loss", dsl_fn(&program, "loss_ir")); - registry.register("torch.nn.functional.nll_loss", dsl_fn(&program, "loss_ir")); - registry.register( - "torch.nn.functional.cross_entropy", - dsl_fn(&program, "loss_ir"), - ); - registry.register( - "torch.nn.functional.binary_cross_entropy", - dsl_fn(&program, "loss_ir"), - ); - registry.register( - "torch.nn.functional.binary_cross_entropy_with_logits", - dsl_fn(&program, "loss_ir"), - ); - registry.register("torch.nn.functional.kl_div", dsl_fn(&program, "loss_ir")); - registry.register( - "torch.nn.functional.smooth_l1_loss", - dsl_fn(&program, "loss_ir"), - ); - registry.register( - "torch.nn.functional.huber_loss", - dsl_fn(&program, "loss_ir"), - ); - registry.register( - "torch.nn.functional.poisson_nll_loss", - dsl_fn(&program, "loss_ir"), - ); - registry.register( - "torch.nn.functional.cosine_embedding_loss", - dsl_fn(&program, "loss_ir"), - ); - registry.register( - "torch.nn.functional.margin_ranking_loss", - dsl_fn(&program, "loss_ir"), - ); - registry.register( - "torch.nn.functional.triplet_margin_loss", - dsl_fn(&program, "loss_ir"), - ); - registry.register( - "torch.nn.functional.hinge_embedding_loss", - dsl_fn(&program, "loss_ir"), - ); - - // Padding - registry.register("torch.nn.functional.pad", dsl_fn(&program, "pad_ir")); - - // FFT - registry.register("torch.fft.rfft", dsl_fn(&program, "rfft_ir")); - registry.register("torch.fft.irfft", dsl_fn(&program, "irfft_ir")); - registry.register("torch.fft.hfft", dsl_fn(&program, "irfft_ir")); - registry.register("torch.fft.ihfft", dsl_fn(&program, "rfft_ir")); - - // Tensor properties - registry.register("torch.Tensor.size", dsl_fn(&program, "size_ir")); - registry.register("torch.Tensor.numel", dsl_fn(&program, "numel_ir")); - registry.register("torch.Tensor.dim", dsl_fn(&program, "dim_ir")); - registry.register("torch.Tensor.nelement", dsl_fn(&program, "numel_ir")); - registry.register("torch.Tensor.item", dsl_fn(&program, "item_ir")); - registry.register("torch.Tensor.tolist", dsl_fn(&program, "tolist_ir")); - registry.register("torch.numel", dsl_fn(&program, "numel_ir")); - - // nn.Module forward methods with init capture. - // register_init_forward registers both the forward DSL function and the - // list of __init__ params to capture in the NNModule type. - let maxpool_captures = &["kernel_size", "stride", "padding", "dilation"]; - registry.register_init_forward( - &program, - "torch.nn.MaxPool1d", - "nn_maxpool_forward_ir", - maxpool_captures, - ); - registry.register_init_forward( - &program, - "torch.nn.MaxPool2d", - "nn_maxpool_forward_ir", - maxpool_captures, - ); - registry.register_init_forward( - &program, - "torch.nn.MaxPool3d", - "nn_maxpool_forward_ir", - maxpool_captures, - ); - - let avgpool_captures = &["kernel_size", "stride", "padding"]; - registry.register_init_forward( - &program, - "torch.nn.AvgPool1d", - "nn_avgpool_forward_ir", - avgpool_captures, - ); - registry.register_init_forward( - &program, - "torch.nn.AvgPool2d", - "nn_avgpool_forward_ir", - avgpool_captures, - ); - registry.register_init_forward( - &program, - "torch.nn.AvgPool3d", - "nn_avgpool_forward_ir", - avgpool_captures, - ); - - registry.register_init_forward( - &program, - "torch.nn.Flatten", - "nn_flatten_forward_ir", - &["start_dim", "end_dim"], - ); - registry.register_init_forward( - &program, - "torch.nn.PixelShuffle", - "nn_pixel_shuffle_forward_ir", - &["upscale_factor"], - ); - registry.register_init_forward(&program, "torch.nn.GLU", "nn_glu_forward_ir", &["dim"]); - registry.register_init_forward( - &program, - "torch.nn.LSTM", - "nn_lstm_forward_ir", - &["input_size", "hidden_size", "num_layers", "bidirectional"], - ); - registry.register_init_forward( - &program, - "torch.nn.Upsample", - "nn_upsample_forward_ir", - &["size", "scale_factor"], - ); - registry.register_init_forward( - &program, - "torch.nn.GRU", - "nn_gru_forward_ir", - &["input_size", "hidden_size", "num_layers", "bidirectional"], - ); - registry.register_init_forward( - &program, - "torch.nn.LSTMCell", - "nn_lstmcell_forward_ir", - &["input_size", "hidden_size"], - ); - registry.register_init_forward( - &program, - "torch.nn.ReflectionPad2d", - "nn_reflectionpad2d_forward_ir", - &["padding"], - ); - registry.register_init_forward( - &program, - "torch.nn.ReplicationPad2d", - "nn_reflectionpad2d_forward_ir", - &["padding"], - ); - - // Random sampling - registry.register("torch.multinomial", dsl_fn(&program, "multinomial_ir")); - registry.register( - "torch.Tensor.multinomial", - dsl_fn(&program, "multinomial_ir"), - ); - registry.register("torch.normal", dsl_fn(&program, "normal_ir")); - - registry - } - - /// Register a meta-shape function. - pub fn register(&mut self, name: impl Into, func: Box) { - self.functions.insert(name.into(), func); - } - - /// Register a meta-shape function for both `torch.X` and `torch.Tensor.X`. - pub fn register_dual Box>( - &mut self, - op_name: &str, - factory: F, - ) { - self.functions - .insert(format!("torch.{}", op_name), factory()); - self.functions - .insert(format!("torch.Tensor.{}", op_name), factory()); - } - - /// Get a meta-shape function by name. - pub fn get(&self, name: &str) -> Option<&dyn MetaShapeFunction> { - self.functions.get(name).map(|b| b.as_ref()) - } - - /// Register an nn.Module: both the forward DSL function and the list of - /// __init__ parameter names to capture in the NNModule type. - /// - /// `class_name` is the qualified class name (e.g., `"torch.nn.MaxPool2d"`). - /// This registers the forward function under `"{class_name}.forward"` and - /// the init captures under `"{class_name}"`. - fn register_init_forward( - &mut self, - program: &ShapeDslProgram, - class_name: &str, - dsl_fn_name: &str, - capture_params: &[&str], - ) { - self.functions.insert( - format!("{class_name}.forward"), - dsl_fn(program, dsl_fn_name), - ); - self.init_captures.insert( - class_name.to_owned(), - capture_params.iter().map(|s| (*s).to_owned()).collect(), - ); - } - - /// Look up init capture config for a qualified class name. - /// Returns the list of __init__ parameter names to capture. - pub fn get_init_capture(&self, class_name: &str) -> Option<&[String]> { - self.init_captures.get(class_name).map(|v| v.as_slice()) - } -} - -impl Default for TensorOpsRegistry { - fn default() -> Self { - Self::new() - } -} - -// Section: DSL source code - -/// The full DSL source defining all tensor shape ops and utility functions. -/// This is valid Python syntax (a strict subset) that we parse with Pyrefly's parser. -const DSL_SOURCE: &str = r#" -import shape_extensions.dsl - -def normalize_dim(rank: int, dim: int) -> int: - if dim < 0: - return dim + rank - return dim - -def int_max(a: int, b: int) -> int: - if a > b: - return a - return b - -def replace_dim(dims: list[int | symint], i: int, value: int | symint) -> list[int | symint]: - return dims[:i] + [value] + dims[i + 1:] - -def remove_dim(dims: list[int | symint], i: int) -> list[int | symint]: - return dims[:i] + dims[i + 1:] - -def insert_dim(dims: list[int | symint], i: int, value: int | symint) -> list[int | symint]: - return dims[:i] + [value] + dims[i:] - -def broadcast(a: list[int | symint], b: list[int | symint]) -> list[int | symint]: - max_len = int_max(len(a), len(b)) - padded_a = [1 for _ in range(max_len - len(a))] + a - padded_b = [1 for _ in range(max_len - len(b))] + b - return [bd if ad == 1 else ad for ad, bd in zip(padded_a, padded_b)] - -def broadcast_int(expr: int | symint | list[int | symint], n: int) -> list[int | symint]: - if isinstance(expr, list): - return expr - return [expr for _ in range(n)] - -def reduce_shape(dims: list[int | symint], dim: int | list[int] | None, keepdim: bool) -> list[int | symint]: - if dim == None: - if keepdim: - return [1 for _ in range(len(dims))] - return [] - dim_list = dim if isinstance(dim, list) else [dim] - norm = [normalize_dim(len(dims), d) for d in dim_list] - return [1 if i in norm else elem for i, elem in enumerate(dims) if not (i in norm) or keepdim] - -def contains(lst: list[int], val: int) -> bool: - return len([x for x in lst if x == val]) > 0 - -def scatter(size: int, indices: list[int], values: list[int], fill: int) -> list[int]: - matches = [[k for k in range(len(indices)) if indices[k] == i] for i in range(size)] - return [values[m[0]] if len(m) > 0 else fill for m in matches] - -def move_dims(dims: list[int | symint], source: int | list[int], dest: int | list[int], rank: int) -> list[int | symint]: - src = broadcast_int(source, 1) - dst = broadcast_int(dest, 1) - src_norm = [normalize_dim(rank, s) for s in src] - dst_norm = [normalize_dim(rank, d) for d in dst] - non_dst = [i for i in range(rank) if not contains(dst_norm, i)] - remaining = [i for i in range(rank) if not contains(src_norm, i)] - perm = scatter(rank, dst_norm + non_dst, src_norm + remaining, 0) - return [dims[p] for p in perm] - -def conv_spatial_out(input_dim: int | symint, kernel: int | symint, stride: int | symint, padding: int | symint, dilation: int | symint) -> int | symint: - return (input_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1 - -def reshape_ir(self: Tensor, shape: list[int | symint]) -> Tensor: - minus_one_count = len([d for d in shape if d == -1]) - if minus_one_count > 1: - raise Error("can only specify one unknown dimension as -1") - has_bad_neg = len([d for d in shape if isinstance(d, int) and d < -1]) > 0 - if has_bad_neg: - raise Error("invalid negative dimension value (only -1 is allowed)") - has_zero = len([d for d in shape if isinstance(d, int) and d == 0]) > 0 - if has_zero: - raise Error("reshape dimensions cannot contain 0") - if minus_one_count > 0: - known = shape_extensions.dsl.prod([d for d in shape if d != -1]) - total = shape_extensions.dsl.prod(self.shape) - if isinstance(total, int) and isinstance(known, int) and total % known != 0: - raise Error("could not infer size for dimension -1: expected " + str(total) + " to be divisible by " + str(known)) - return Tensor(shape=[total // known if d == -1 else d for d in shape]) - return Tensor(shape=shape) - -def squeeze_ir(self: Tensor, dim: int | None = None) -> Tensor: - if dim == None: - return Tensor(shape=[d for d in self.shape if d != 1]) - idx = normalize_dim(len(self.shape), dim) - return Tensor(shape=[d for i, d in enumerate(self.shape) if not (i == idx and d == 1)]) - -def unsqueeze_ir(self: Tensor, dim: int) -> Tensor: - d = normalize_dim(len(self.shape) + 1, dim) - return Tensor(shape=insert_dim(self.shape, d, 1)) - -def transpose_ir(self: Tensor, dim0: int, dim1: int) -> Tensor: - rank = len(self.shape) - d0 = normalize_dim(rank, dim0) - d1 = normalize_dim(rank, dim1) - return Tensor(shape=[self.shape[d1] if i == d0 else self.shape[d0] if i == d1 else d for i, d in enumerate(self.shape)]) - -def permute_ir(self: Tensor, dims: list[int]) -> Tensor: - rank = len(self.shape) - if len(dims) != rank: - raise Error("permute: expected " + str(rank) + " dims, got " + str(len(dims))) - return Tensor(shape=[self.shape[normalize_dim(rank, d)] for d in dims]) - -def flatten_ir(self: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor: - rank = len(self.shape) - s = normalize_dim(rank, start_dim) - e = normalize_dim(rank, end_dim) - return Tensor(shape=self.shape[:s] + [shape_extensions.dsl.prod(self.shape[s:e + 1])] + self.shape[e + 1:]) - -def expand_ir(self: Tensor, sizes: list[int | symint]) -> Tensor: - return Tensor(shape=[d if t == -1 else t for d, t in zip(self.shape, sizes)]) - -def repeat_ir(self: Tensor, sizes: list[int | symint]) -> Tensor: - return Tensor(shape=[d * r for d, r in zip(self.shape, sizes)]) - -def unbind_ir(self: Tensor, dim: int = 0) -> list[Tensor]: - d = normalize_dim(len(self.shape), dim) - return [Tensor(shape=remove_dim(self.shape, d)), ...] - -def movedim_ir(self: Tensor, source: int | list[int], destination: int | list[int]) -> Tensor: - return Tensor(shape=move_dims(self.shape, source, destination, len(self.shape))) - -def unfold_ir(self: Tensor, dimension: int, size: int | symint, step: int = 1) -> Tensor: - d = normalize_dim(len(self.shape), dimension) - new_dim = (self.shape[d] - size) // step + 1 - return Tensor(shape=replace_dim(self.shape, d, new_dim) + [size]) - -def cat_ir(tensors: list[Tensor], dim: int = 0) -> Tensor: - first = tensors[0] - d = normalize_dim(len(first.shape), dim) - return Tensor(shape=[shape_extensions.dsl.sum([t.shape[i] for t in tensors]) if i == d else dim_val for i, dim_val in enumerate(first.shape)]) - -def stack_ir(tensors: list[Tensor], dim: int = 0) -> Tensor: - first = tensors[0] - d = normalize_dim(len(first.shape) + 1, dim) - return Tensor(shape=insert_dim(first.shape, d, len(tensors))) - -def broadcast_to_ir(self: Tensor, shape: list[int | symint]) -> Tensor: - return Tensor(shape=shape) - -def tile_ir(self: Tensor, dims: list[int]) -> Tensor: - rank = len(self.shape) - if len(dims) > rank: - extra = len(dims) - rank - return Tensor(shape=[r for r in dims[:extra]] + [d * r for d, r in zip(self.shape, dims[extra:])]) - return Tensor(shape=[d * r for d, r in zip(self.shape, dims)]) - -def select_ir(self: Tensor, dim: int) -> Tensor: - d = normalize_dim(len(self.shape), dim) - return Tensor(shape=remove_dim(self.shape, d)) - -def narrow_ir(self: Tensor, dim: int, length: int | symint) -> Tensor: - return Tensor(shape=replace_dim(self.shape, normalize_dim(len(self.shape), dim), length)) - -def split_ir(self: Tensor, split_size_or_sections: int | symint | list[int | symint] | None = None, dim: int = 0) -> list[Tensor]: - d = normalize_dim(len(self.shape), dim) - if isinstance(split_size_or_sections, list): - return [Tensor(shape=replace_dim(self.shape, d, section)) for section in split_size_or_sections] - if isinstance(split_size_or_sections, int): - dim_val = self.shape[d] - if isinstance(dim_val, int): - count = (dim_val + split_size_or_sections - 1) // split_size_or_sections - return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections if i < count - 1 else dim_val - (count - 1) * split_size_or_sections)) for i in range(count)] - return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)), ...] - if split_size_or_sections != None: - quotient = self.shape[d] // split_size_or_sections - if isinstance(quotient, int): - return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)) for _ in range(quotient)] - return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)), ...] - return Unknown - -def chunk_ir(self: Tensor, chunks: int, dim: int = 0) -> list[Tensor]: - d = normalize_dim(len(self.shape), dim) - dim_val = self.shape[d] - if isinstance(dim_val, int): - chunk_size = (dim_val + chunks - 1) // chunks - return [Tensor(shape=replace_dim(self.shape, d, chunk_size if i < chunks - 1 else dim_val - (chunks - 1) * chunk_size)) for i in range(chunks)] - return [Tensor(shape=replace_dim(self.shape, d, dim_val // chunks)) for i in range(chunks)] - -def index_select_ir(self: Tensor, dim: int, index: Tensor) -> Tensor: - return Tensor(shape=replace_dim(self.shape, normalize_dim(len(self.shape), dim), index.shape[0])) - -def reduce_ir(self: Tensor, dim: int | list[int] | None = None, keepdim: bool = False) -> Tensor: - if dim == None: - return Tensor(shape=reduce_shape(self.shape, dim, keepdim)) - if isinstance(dim, list): - return Tensor(shape=reduce_shape(self.shape, dim, keepdim)) - return Tensor(shape=reduce_single(self.shape, dim, keepdim)) - -def reduce_single(dims: list[int | symint], dim: int, keepdim: bool) -> list[int | symint]: - before = dims[:dim] - if dim == -1: - if keepdim: - return before + [1] - return before - after = dims[dim + 1:] - if keepdim: - return before + [1] + after - return before + after - -def min_max_median_ir(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: - if dim == None: - return Tensor(shape=[]) - s = reduce_shape(self.shape, dim, keepdim) - return [Tensor(shape=s), Tensor(shape=s)] - -def aminmax_ir(self: Tensor, dim: int | list[int] | None = None, keepdim: bool = False) -> [Tensor, Tensor]: - s = reduce_shape(self.shape, dim, keepdim) - return [Tensor(shape=s), Tensor(shape=s)] - -def tuple_reduce_ir(self: Tensor, dim: int = -1, keepdim: bool = False) -> [Tensor, Tensor]: - s = reduce_shape(self.shape, dim, keepdim) - return [Tensor(shape=s), Tensor(shape=s)] - -def topk_ir(self: Tensor, k: int | symint, dim: int = -1) -> [Tensor, Tensor]: - s = replace_dim(self.shape, normalize_dim(len(self.shape), dim), k) - return [Tensor(shape=s), Tensor(shape=s)] - -def repeat_interleave_ir(self: Tensor, repeats: int | symint, dim: int | None = None) -> Tensor: - if dim == None: - return Tensor(shape=[shape_extensions.dsl.prod(self.shape) * repeats]) - d = normalize_dim(len(self.shape), dim) - return Tensor(shape=replace_dim(self.shape, d, self.shape[d] * repeats)) - -def cosine_similarity_ir(x1: Tensor, x2: Tensor, dim: int = 1) -> Tensor: - s = broadcast(x1.shape, x2.shape) - return Tensor(shape=reduce_single(s, normalize_dim(len(s), dim), False)) - -def randn_ir(size: list[int | symint]) -> Tensor: - return Tensor(shape=size) - -def randint_ir(low: int, high: int, size: list[int | symint]) -> Tensor: - return Tensor(shape=size) - -def linspace_ir(steps: int | symint) -> Tensor: - return Tensor(shape=[steps]) - -def eye_ir(n: int | symint, m: int | symint | None = None) -> Tensor: - if m == None: - return Tensor(shape=[n, n]) - return Tensor(shape=[n, m]) - -def arange_ir(start: int | symint | None = None, end: int | symint | None = None, step: int | symint | None = None) -> Tensor: - if start != None and end != None and step != None: - return Tensor(shape=[(end - start) // step]) - if start != None and end != None: - return Tensor(shape=[end - start]) - if end != None: - return Tensor(shape=[end]) - if start != None: - return Tensor(shape=[start]) - return Unknown - -def normal_ir(mean: Tensor | None = None, std: Tensor | None = None, size: list[int] | None = None) -> Tensor: - if size != None: - return Tensor(shape=[s for s in size]) - if mean != None: - return Tensor(shape=mean.shape) - if std != None: - return Tensor(shape=std.shape) - return Unknown - -def diag_embed_ir(self: Tensor, offset: int = 0) -> Tensor: - new_dim = self.shape[-1] + (offset if offset >= 0 else -offset) - return Tensor(shape=self.shape[:-1] + [new_dim, new_dim]) - -def tri_indices_ir(row: int | symint, col: int | symint, offset: int = 0) -> Tensor: - return Tensor(shape=[2, 0]) - -def matmul_ir(self: Tensor, other: Tensor) -> Tensor: - r1 = len(self.shape) - r2 = len(other.shape) - if r1 == 1 and r2 == 1: - return Tensor(shape=[]) - if r1 == 1 and r2 == 2: - return Tensor(shape=[other.shape[1]]) - if r1 == 2 and r2 == 1: - return Tensor(shape=[self.shape[0]]) - if r1 == 2 and r2 == 2: - return Tensor(shape=[self.shape[0], other.shape[1]]) - if r1 == 2 and r2 >= 3: - return Tensor(shape=other.shape[:-2] + [self.shape[0]] + [other.shape[-1]]) - if r1 >= 3 and r2 == 2: - return Tensor(shape=self.shape[:-2] + [self.shape[-2]] + [other.shape[1]]) - if r1 >= 3 and r2 >= 3: - return Tensor(shape=broadcast(self.shape[:-2], other.shape[:-2]) + [self.shape[-2]] + [other.shape[-1]]) - return Unknown - -def mv_ir(self: Tensor, vec: Tensor) -> Tensor: - if len(self.shape) != 2: - raise Error("mv expects 2D matrix, got " + str(len(self.shape)) + "D tensor") - if len(vec.shape) != 1: - raise Error("mv expects 1D vector, got " + str(len(vec.shape)) + "D tensor") - return Tensor(shape=[self.shape[0]]) - -def outer_ir(self: Tensor, vec2: Tensor) -> Tensor: - if len(self.shape) != 1 or len(vec2.shape) != 1: - raise Error("outer expects 1D tensors, got " + str(len(self.shape)) + "D and " + str(len(vec2.shape)) + "D") - return Tensor(shape=[self.shape[0], vec2.shape[0]]) - -def tensordot_ir(self: Tensor, other: Tensor, dims: int) -> Tensor: - return Tensor(shape=self.shape[:len(self.shape) - dims] + other.shape[dims:]) - -def apply_einsum(output_map: list[list[int]], check_pairs: list[list[int]], inputs: list[Tensor]) -> Tensor: - bad_dims = [1 for i0, d0, i1, d1 in check_pairs if isinstance(inputs[i0].shape[d0], int) and isinstance(inputs[i1].shape[d1], int) and inputs[i0].shape[d0] != inputs[i1].shape[d1]] - if len(bad_dims) > 0: - raise Error("einsum: inconsistent dimensions for repeated index") - return Tensor(shape=[inputs[inp].shape[dim] for inp, dim in output_map]) - -def einsum_ir(spec: str, operands: list[Tensor] | None = None) -> Tensor: - if operands != None: - output_map, check_pairs = shape_extensions.dsl.parse_einsum_equation(spec) - return apply_einsum(output_map, check_pairs, operands) - return Unknown - -def eigvals_ir(self: Tensor) -> Tensor: - if len(self.shape) < 2: - raise Error("eigvals requires at least 2D input, got " + str(len(self.shape)) + "D tensor") - return Tensor(shape=self.shape[:-2] + [self.shape[-2]]) - -def eig_ir(self: Tensor) -> [Tensor, Tensor]: - if len(self.shape) < 2: - raise Error("eig requires at least 2D input, got " + str(len(self.shape)) + "D tensor") - batch = self.shape[:-2] - return [Tensor(shape=batch + [self.shape[-2]]), Tensor(shape=batch + self.shape[-2:])] - -def slogdet_ir(self: Tensor) -> [Tensor, Tensor]: - if len(self.shape) < 2: - raise Error("slogdet requires at least 2D input, got " + str(len(self.shape)) + "D tensor") - return [Tensor(shape=self.shape[:-2]), Tensor(shape=self.shape[:-2])] - -def solve_ir(self: Tensor, other: Tensor) -> Tensor: - return Tensor(shape=other.shape) - -def solve_reversed_ir(self: Tensor, other: Tensor) -> Tensor: - return Tensor(shape=self.shape) - -def conv_ir(self: Tensor, weight: Tensor, stride: int | list[int] = 1, padding: int | list[int] = 0, dilation: int | list[int] = 1) -> Tensor: - spatial_dims = len(self.shape) - 2 - stride_list = broadcast_int(stride, spatial_dims) - padding_list = broadcast_int(padding, spatial_dims) - dilation_list = broadcast_int(dilation, spatial_dims) - return Tensor(shape=[self.shape[0], weight.shape[0]] + [conv_spatial_out(s, k, st, p, dil) for s, k, st, p, dil in zip(self.shape[2:], weight.shape[2:], stride_list, padding_list, dilation_list)]) - -def conv_transpose_ir(self: Tensor, weight: Tensor, stride: int | list[int] = 1, padding: int | list[int] = 0, output_padding: int | list[int] = 0, dilation: int | list[int] = 1) -> Tensor: - spatial_dims = len(self.shape) - 2 - stride_list = broadcast_int(stride, spatial_dims) - padding_list = broadcast_int(padding, spatial_dims) - outpad_list = broadcast_int(output_padding, spatial_dims) - dilation_list = broadcast_int(dilation, spatial_dims) - return Tensor(shape=[self.shape[0], weight.shape[1]] + [(s - 1) * st - 2 * p + dil * (k - 1) + op + 1 for s, k, st, p, op, dil in zip(self.shape[2:], weight.shape[2:], stride_list, padding_list, outpad_list, dilation_list)]) - -def pool_ir(self: Tensor, kernel_size: int | list[int], stride: int | list[int] | None = None, padding: int | list[int] = 0, dilation: int | list[int] = 1, return_indices: bool = False) -> Tensor: - spatial_dims = len(self.shape) - 2 - ks_list = broadcast_int(kernel_size, spatial_dims) - stride_list = ks_list if stride == None else broadcast_int(stride, spatial_dims) - padding_list = broadcast_int(padding, spatial_dims) - dilation_list = broadcast_int(dilation, spatial_dims) - out = [self.shape[0], self.shape[1]] + [conv_spatial_out(s, k, st, p, dil) for s, k, st, p, dil in zip(self.shape[2:], ks_list, stride_list, padding_list, dilation_list)] - if return_indices: - return [Tensor(shape=out), Tensor(shape=out)] - return Tensor(shape=out) - -def adaptive_pool_ir(self: Tensor, output_size: int | symint | list[int | symint]) -> Tensor: - out_sizes = broadcast_int(output_size, len(self.shape) - 2) - return Tensor(shape=[self.shape[0], self.shape[1]] + out_sizes) - -def interpolate_ir(self: Tensor, size: int | symint | list[int | symint] | None = None, scale_factor: int | symint | None = None) -> Tensor: - if size != None: - return Tensor(shape=[self.shape[0], self.shape[1]] + broadcast_int(size, len(self.shape) - 2)) - if scale_factor != None: - return Tensor(shape=[self.shape[0], self.shape[1]] + [d * scale_factor for d in self.shape[2:]]) - raise Error("interpolate requires either 'size' or 'scale_factor' argument") - -def loss_ir(self: Tensor, reduction: str = "mean") -> Tensor: - if reduction == "none": - return Tensor(shape=self.shape) - return Tensor(shape=[]) - -def pad_ir(self: Tensor, pad: list[int]) -> Tensor: - rank = len(self.shape) - num_pad_dims = len(pad) // 2 - offsets = [pad[(rank - 1 - i) * 2] + pad[(rank - 1 - i) * 2 + 1] if i >= rank - num_pad_dims else 0 for i in range(rank)] - return Tensor(shape=[d + offsets[i] for i, d in enumerate(self.shape)]) - -def rfft_ir(self: Tensor, n: int | symint | None = None, dim: int = -1) -> Tensor: - d = normalize_dim(len(self.shape), dim) - if n != None: - return Tensor(shape=replace_dim(self.shape, d, n // 2 + 1)) - return Tensor(shape=replace_dim(self.shape, d, self.shape[d] // 2 + 1)) - -def irfft_ir(self: Tensor, n: int | symint | None = None, dim: int = -1) -> Tensor: - d = normalize_dim(len(self.shape), dim) - if n != None: - return Tensor(shape=replace_dim(self.shape, d, n)) - return Tensor(shape=replace_dim(self.shape, d, 2 * (self.shape[d] - 1))) - -def size_ir(self: Tensor, dim: int | None = None) -> int | symint: - if dim != None: - return self.shape[normalize_dim(len(self.shape), dim)] - return [d for d in self.shape] - -def numel_ir(self: Tensor) -> int | symint: - return shape_extensions.dsl.prod(self.shape) - -def dim_ir(self: Tensor) -> int: - return len(self.shape) - -def item_ir(self: Tensor) -> Tensor: - if len(self.shape) != 0: - raise Error("item() only works on 0-dimensional tensors, got " + str(len(self.shape)) + "D tensor") - return Unknown - -def tolist_ir(self: Tensor) -> Tensor: - return Unknown - -def multinomial_ir(self: Tensor, num_samples: int | symint) -> Tensor: - return Tensor(shape=self.shape[:-1] + [num_samples]) - -def where_ir(condition: Tensor, x: Tensor, y: Tensor) -> Tensor: - return Tensor(shape=x.shape) - -def take_along_dim_ir(self: Tensor, indices: Tensor) -> Tensor: - return Tensor(shape=indices.shape) - -def nn_flatten_forward_ir(input: Tensor, start_dim: symint = 1, end_dim: symint = -1) -> Tensor: - return flatten_ir(input, start_dim, end_dim) - -def nn_maxpool_forward_ir(input: Tensor, kernel_size: symint = 1, stride: symint | None = None, padding: symint = 0, dilation: symint = 1) -> Tensor: - return pool_ir(input, kernel_size, stride, padding, dilation) - -def nn_avgpool_forward_ir(input: Tensor, kernel_size: symint = 1, stride: symint | None = None, padding: symint = 0) -> Tensor: - return pool_ir(input, kernel_size, stride, padding, 1) - -def nn_upsample_forward_ir(input: Tensor, size: symint | None = None, scale_factor: symint | None = None) -> Tensor: - return interpolate_ir(input, size, scale_factor) - -def nn_pixel_shuffle_forward_ir(input: Tensor, upscale_factor: symint) -> Tensor: - r = upscale_factor - return Tensor(shape=[input.shape[0], input.shape[1] // (r * r)] + [d * r for d in input.shape[2:]]) - -def nn_glu_forward_ir(input: Tensor, dim: symint = 1) -> Tensor: - rank = len(input.shape) - d = normalize_dim(rank, dim) - return Tensor(shape=replace_dim(input.shape, d, input.shape[d] // 2)) - -def nn_lstm_forward_ir(input: Tensor, input_size: symint, hidden_size: symint, num_layers: symint = 1, bidirectional: bool = False) -> [Tensor, Tensor, Tensor]: - nd = 2 if bidirectional else 1 - output = Tensor(shape=[input.shape[0], input.shape[1], hidden_size * nd]) - h_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) - c_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) - return [output, h_n, c_n] - -def nn_gru_forward_ir(input: Tensor, input_size: symint, hidden_size: symint, num_layers: symint = 1, bidirectional: bool = False) -> [Tensor, Tensor]: - nd = 2 if bidirectional else 1 - output = Tensor(shape=[input.shape[0], input.shape[1], hidden_size * nd]) - h_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) - return [output, h_n] - -def nn_lstmcell_forward_ir(input: Tensor, input_size: symint, hidden_size: symint) -> [Tensor, Tensor]: - h = Tensor(shape=[input.shape[0], hidden_size]) - c = Tensor(shape=[input.shape[0], hidden_size]) - return [h, c] - -def nn_reflectionpad2d_forward_ir(input: Tensor, padding: symint) -> Tensor: - return Tensor(shape=[input.shape[0], input.shape[1], input.shape[2] + 2 * padding, input.shape[3] + 2 * padding]) -"#; diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index 69280dda59..a7a2c3bcb2 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -13,7 +13,6 @@ use pyrefly_python::dunder; use pyrefly_types::meta_shape_dsl::ShapeTransformRef; use pyrefly_types::quantified::Quantified; use pyrefly_types::special_form::SpecialForm; -use pyrefly_types::tensor_ops_registry::TensorOpsRegistry; use pyrefly_types::typed_dict::TypedDictInner; use pyrefly_types::types::CalleeKind; use pyrefly_types::types::NNModuleType; @@ -1121,27 +1120,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { errors: &ErrorCollector, result: Type, ) -> Type { - // Check ClassMetadata.capture_init first (populated from @uses_shape_dsl - // decorator on the forward method). Fall back to TensorOpsRegistry for - // classes not yet migrated to stub-based declarations. let class_metadata = self.get_metadata_for_class(ct.class_object()); let capture_names_from_metadata: Vec; let capture_names: &[Name] = if let Some(names) = class_metadata.capture_init() { capture_names_from_metadata = names.to_vec(); &capture_names_from_metadata } else { - use std::sync::OnceLock; - static TENSOR_OPS_REGISTRY: OnceLock = OnceLock::new(); - - let class_name = format!("{}.{}", ct.class_object().module_name(), ct.name()); - let registry = TENSOR_OPS_REGISTRY.get_or_init(TensorOpsRegistry::new); - match registry.get_init_capture(&class_name) { - Some(names) => { - capture_names_from_metadata = names.iter().map(Name::new).collect(); - &capture_names_from_metadata - } - None => return result, - } + return result; }; let infer_type_or_expr = |toe: TypeOrExpr, errors: &ErrorCollector| -> Type { diff --git a/pyrefly/lib/alt/callable.rs b/pyrefly/lib/alt/callable.rs index 0cd81b1478..4344ea46f0 100644 --- a/pyrefly/lib/alt/callable.rs +++ b/pyrefly/lib/alt/callable.rs @@ -13,7 +13,6 @@ use pyrefly_python::dunder; use pyrefly_types::callable::FunctionKind; use pyrefly_types::meta_shape_dsl::MetaShapeFunction; use pyrefly_types::meta_shape_dsl::ShapeTransformRef; -use pyrefly_types::tensor_ops_registry::TensorOpsRegistry; use pyrefly_types::tuple::Tuple; use pyrefly_types::typed_dict::ExtraItems; use pyrefly_types::types::TArgs; @@ -1391,22 +1390,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { call_context.set_argument_side(ArgumentSide::Got); call_context.require_boundary_consumption(); - // Prefer decorator-based shape_transform over the legacy registry. - // Fall back to lookup_meta_shape only when shape_transform is None. - // The decorator path is not gated by tensor_shapes — @uses_shape_dsl - // is itself the opt-in. The registry applies implicitly by qualified - // name and needs the feature gate to avoid unnecessary DSL parsing - // and per-call HashMap lookups. let shape_transform_func = shape_transform.map(|t| t.to_meta_shape_function()); - let registry_func = if shape_transform_func.is_none() && self.solver().tensor_shapes { - Self::lookup_meta_shape(callable_name) - } else { - None - }; - let meta_shape_func: Option<&dyn MetaShapeFunction> = shape_transform_func - .as_ref() - .map(|b| &**b) - .or(registry_func); + let meta_shape_func: Option<&dyn MetaShapeFunction> = shape_transform_func.as_deref(); let mut bound_args: Option> = meta_shape_func.map(|_| HashMap::new()); let (callable_qs, mut callable) = if let Some(tparams) = tparams { @@ -1631,26 +1616,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) } - /// Look up whether a callable has a registered meta-shape function. - fn lookup_meta_shape(callable_name: Option<&FunctionKind>) -> Option<&dyn MetaShapeFunction> { - use std::sync::OnceLock; - static TENSOR_OPS_REGISTRY: OnceLock = OnceLock::new(); - - let func_id = callable_name.and_then(|fk| match fk { - FunctionKind::Def(box_func_id) => Some(box_func_id.as_ref()), - _ => None, - })?; - - let qualified_name = if let Some(cls) = &func_id.cls { - format!("{}.{}.{}", func_id.module.name(), cls.name(), func_id.name) - } else { - format!("{}.{}", func_id.module.name(), func_id.name) - }; - - let registry = TENSOR_OPS_REGISTRY.get_or_init(TensorOpsRegistry::new); - registry.get(&qualified_name) - } - /// Auto-inject module field values into `bound_args` for DSL parameters /// that aren't method parameters but match fields on `self`. /// From ef5afd8c511d4b1df79e9e707eff541e6d1e149e Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 21/25] Replace all-siblings fn_lookup with per-caller transitive-callee resolution MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Each `shape_dsl_function` now carries only its transitive callees instead of every DSL function in the module. A leaf function like `randn_ir` gets `helpers = [self]` (1 entry) instead of all 86 siblings. `reshape_ir` gets `[reshape_ir, normalize_dim]` (2 entries). `movedim_ir` gets its 6 transitive callees. Adds `ShapeDslFunction::call_targets()` which walks `DslBody`/`DslExpr` collecting `DslCallTarget::UserDefined` names, and `compute_transitive_helpers` which resolves those names against the per-module `BindingsMetadata` index and computes the fixed-point closure. Behaviorally identical — same shape inference results, same types. The change reduces per-function memory footprint and will enable finer cache invalidation once solver dependency edges are per-helper. Differential Revision: D105783601 --- crates/pyrefly_types/src/meta_shape_dsl.rs | 196 ++++++++++++++++++++- pyrefly/lib/alt/function.rs | 48 +++-- 2 files changed, 229 insertions(+), 15 deletions(-) diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index 498ebcaef4..0065ad2815 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -21,6 +21,7 @@ use std::cmp::Ordering; use std::collections::HashMap; +use std::collections::HashSet; use std::fmt; use std::fmt::Debug; use std::hash::Hash; @@ -3163,12 +3164,118 @@ impl TypeEq for ShapeDslFunction { } } +impl ShapeDslFunction { + /// The function name from the DSL definition. + pub fn name(&self) -> &str { + &self.inner.name + } + + /// Returns the set of user-defined function names called in this function's body. + pub fn call_targets(&self) -> HashSet { + let mut targets = HashSet::new(); + collect_call_targets_body(&self.inner.body, &mut targets); + targets + } +} + +/// Walk a `DslBody` and collect all `DslCallTarget::UserDefined` names. +fn collect_call_targets_body(body: &DslBody, targets: &mut HashSet) { + match body { + DslBody::Assign { expr, rest, .. } => { + collect_call_targets_expr(expr, targets); + collect_call_targets_body(rest, targets); + } + DslBody::If { + cond, + then_body, + rest, + } => { + collect_call_targets_expr(cond, targets); + collect_call_targets_body(then_body, targets); + collect_call_targets_body(rest, targets); + } + DslBody::Return(expr) | DslBody::Raise(expr) => { + collect_call_targets_expr(expr, targets); + } + } +} + +/// Walk a `DslExpr` and collect all `DslCallTarget::UserDefined` names. +fn collect_call_targets_expr(expr: &DslExpr, targets: &mut HashSet) { + match expr { + DslExpr::Call { + func: DslCallTarget::UserDefined(name), + args, + } => { + targets.insert(name.clone()); + for arg in args { + collect_call_targets_expr(arg, targets); + } + } + DslExpr::Call { args, .. } => { + for arg in args { + collect_call_targets_expr(arg, targets); + } + } + DslExpr::List(items) => { + for item in items { + collect_call_targets_expr(item, targets); + } + } + DslExpr::ListComp { + elt, iter, cond, .. + } => { + collect_call_targets_expr(elt, targets); + collect_call_targets_expr(iter, targets); + if let Some(c) = cond { + collect_call_targets_expr(c, targets); + } + } + DslExpr::Index { base, index } => { + collect_call_targets_expr(base, targets); + collect_call_targets_expr(index, targets); + } + DslExpr::Slice { base, lower, upper } => { + collect_call_targets_expr(base, targets); + if let Some(l) = lower { + collect_call_targets_expr(l, targets); + } + if let Some(u) = upper { + collect_call_targets_expr(u, targets); + } + } + DslExpr::BinOp { left, right, .. } => { + collect_call_targets_expr(left, targets); + collect_call_targets_expr(right, targets); + } + DslExpr::UnaryOp { operand, .. } => { + collect_call_targets_expr(operand, targets); + } + DslExpr::IsInstance { expr, .. } => { + collect_call_targets_expr(expr, targets); + } + DslExpr::In { left, right } => { + collect_call_targets_expr(left, targets); + collect_call_targets_expr(right, targets); + } + DslExpr::Shape(expr) | DslExpr::TensorNew(expr) => { + collect_call_targets_expr(expr, targets); + } + DslExpr::IfExpr { body, test, orelse } => { + collect_call_targets_expr(body, targets); + collect_call_targets_expr(test, targets); + collect_call_targets_expr(orelse, targets); + } + DslExpr::Const(_) | DslExpr::Var(_) | DslExpr::Ellipsis | DslExpr::Unknown => {} + } +} + /// Reference to a shape-DSL function that refines a callable's return type. /// Carried on `FuncFlags` for functions decorated with `@uses_shape_dsl`. #[derive(Debug, Clone)] pub struct ShapeTransformRef { pub dsl_fn: Arc, - // TODO: Replace all-siblings snapshot with resolved transitive callees. + /// Transitive closure of user-defined helpers called by `dsl_fn`. pub helpers: Derived>>>, } @@ -3227,10 +3334,10 @@ impl TypeEq for ShapeTransformRef { impl ShapeTransformRef { /// Build a `MetaShapeFunction` evaluator from this shape transform reference. - /// Populates `fn_lookup` with this function and all same-module siblings + /// Populates `fn_lookup` with this function and its transitive callees /// so that cross-function DSL calls resolve correctly. pub fn to_meta_shape_function(&self) -> Box { - // helpers already includes self (it's all same-module siblings). + // helpers contains self and its transitive callees. let fn_lookup: Arc>> = Arc::new( self.helpers .iter() @@ -3334,3 +3441,86 @@ pub fn make_meta_shape_function( ); Some(Box::new(DslMetaShapeFunction { fn_def, fn_lookup })) } + +#[cfg(test)] +mod tests { + use pyrefly_python::ast::Ast; + use ruff_python_ast::PySourceType; + use ruff_python_ast::Stmt; + + use super::*; + + fn parse_dsl_functions(source: &str) -> Vec { + let (module, _, _) = Ast::parse(source, PySourceType::Stub); + module + .body + .iter() + .filter_map(|stmt| { + if let Stmt::FunctionDef(func) = stmt { + Some(convert_shape_dsl_function(func).unwrap()) + } else { + None + } + }) + .collect() + } + + #[test] + fn test_call_targets_disjoint_closures() { + let fns = parse_dsl_functions( + r#" +def helper_a(x: int) -> int: + return x + 1 + +def helper_b(x: int) -> int: + return x + 2 + +def calls_a(x: int) -> int: + return helper_a(x) + +def calls_b(x: int) -> int: + return helper_b(x) + +def leaf(x: int) -> int: + return x +"#, + ); + assert_eq!(fns.len(), 5); + + let calls_a = fns.iter().find(|f| f.name() == "calls_a").unwrap(); + let calls_b = fns.iter().find(|f| f.name() == "calls_b").unwrap(); + let leaf = fns.iter().find(|f| f.name() == "leaf").unwrap(); + + assert_eq!( + calls_a.call_targets(), + HashSet::from(["helper_a".to_owned()]) + ); + assert_eq!( + calls_b.call_targets(), + HashSet::from(["helper_b".to_owned()]) + ); + assert!(leaf.call_targets().is_empty()); + } + + #[test] + fn test_call_targets_transitive() { + let fns = parse_dsl_functions( + r#" +def deep(x: int) -> int: + return x + +def mid(x: int) -> int: + return deep(x) + +def top(x: int) -> int: + return mid(x) +"#, + ); + let top = fns.iter().find(|f| f.name() == "top").unwrap(); + let mid = fns.iter().find(|f| f.name() == "mid").unwrap(); + + // call_targets is direct only, not transitive + assert_eq!(top.call_targets(), HashSet::from(["mid".to_owned()])); + assert_eq!(mid.call_targets(), HashSet::from(["deep".to_owned()])); + } +} diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index 82c784aaff..2dc407e5fe 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -5,6 +5,7 @@ * LICENSE file in the root directory of this source tree. */ +use std::collections::HashSet; use std::collections::VecDeque; use std::ops::Deref; use std::sync::Arc; @@ -551,18 +552,10 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { def_index: Some(def_index), outer_funcs, }); - // Collect all same-module @shape_dsl_function siblings for fn_lookup. - // TODO: Replace all-siblings snapshot with per-caller transitive-callee - // resolution. - let siblings: Arc>> = Arc::new( - self.bindings() - .metadata() - .shape_dsl_functions() - .iter() - .map(|(_, dsl)| Arc::clone(dsl)) - .collect(), - ); - FunctionKind::ShapeDsl(func_id, dsl_fn, Derived(siblings)) + // Build the transitive closure of helper functions this DSL function calls. + let module_dsl_fns = self.bindings().metadata().shape_dsl_functions(); + let helpers = compute_transitive_helpers(&dsl_fn, module_dsl_fns); + FunctionKind::ShapeDsl(func_id, dsl_fn, Derived(helpers)) } else { FunctionKind::from_name( self.module().dupe(), @@ -2551,3 +2544,34 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } } + +/// Compute the transitive closure of DSL helper functions called by `root`. +/// +/// Starting from `root`, follows `call_targets()` through `module_dsl_fns` +/// to collect every user-defined function reachable from `root`. The result +/// always includes `root` itself as the first element. +fn compute_transitive_helpers( + root: &Arc, + module_dsl_fns: &[(Name, Arc)], +) -> Arc>> { + let mut closure: Vec> = vec![Arc::clone(root)]; + let mut visited: HashSet = HashSet::new(); + visited.insert(root.name().to_owned()); + + let mut queue: Vec = root.call_targets().into_iter().collect(); + while let Some(name) = queue.pop() { + if !visited.insert(name.clone()) { + continue; + } + if let Some((_, dsl_fn)) = module_dsl_fns.iter().find(|(n, _)| n.as_str() == name) { + closure.push(Arc::clone(dsl_fn)); + for target in dsl_fn.call_targets() { + if !visited.contains(&target) { + queue.push(target); + } + } + } + } + + Arc::new(closure) +} From 2774b74116b44c8801b28fa35058dfecd3f4a18a Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 22/25] Add per-function `type_check_program` validation Summary: Run `type_check_program` on each DSL function's transitive-callee closure at `shape_dsl_function` solve time. This validates that cross-function call signatures are consistent (e.g., `reshape_ir` calling `normalize_dim` with the right argument types). Previously, validation was done once globally by the deleted `TensorOpsRegistry`. With per-caller resolution from 7a-1, each function is now validated against only its actual callees. Panics on type errors for now; Phase 7b will convert to diagnostics. Differential Revision: D105783602 --- crates/pyrefly_types/src/meta_shape_dsl.rs | 12 ++++++++++++ pyrefly/lib/alt/function.rs | 5 ++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index 0065ad2815..5f56a4c545 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -3270,6 +3270,18 @@ fn collect_call_targets_expr(expr: &DslExpr, targets: &mut HashSet) { } } +/// Validate a set of `ShapeDslFunction`s as a program. +/// +/// Runs `type_check_program` on the inner `DslFnDef`s, verifying that +/// cross-function calls have consistent signatures. Panics on type errors. +/// +/// Intended to be called with a per-caller transitive closure (root + +/// its resolved helpers), not the full module. +pub fn validate_shape_dsl_functions(fns: &[Arc]) { + let defs: Vec = fns.iter().map(|f| (*f.inner).clone()).collect(); + type_check_program(&defs); +} + /// Reference to a shape-DSL function that refines a callable's return type. /// Carried on `FuncFlags` for functions decorated with `@uses_shape_dsl`. #[derive(Debug, Clone)] diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index 2dc407e5fe..869e21ff05 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -25,6 +25,7 @@ use pyrefly_types::class::ClassType; use pyrefly_types::dimension::SizeExpr; use pyrefly_types::meta_shape_dsl::ShapeDslFunction; use pyrefly_types::meta_shape_dsl::ShapeTransformRef; +use pyrefly_types::meta_shape_dsl::validate_shape_dsl_functions; use pyrefly_types::quantified::Quantified; use pyrefly_types::quantified::QuantifiedOrigin; use pyrefly_types::type_var::Restriction; @@ -552,9 +553,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { def_index: Some(def_index), outer_funcs, }); - // Build the transitive closure of helper functions this DSL function calls. + // Build the transitive closure of helper functions this DSL function calls, + // then validate cross-function call signatures. let module_dsl_fns = self.bindings().metadata().shape_dsl_functions(); let helpers = compute_transitive_helpers(&dsl_fn, module_dsl_fns); + validate_shape_dsl_functions(&helpers); FunctionKind::ShapeDsl(func_id, dsl_fn, Derived(helpers)) } else { FunctionKind::from_name( From 606c1701ab19ce341237efb95e452abc14064475 Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 23/25] Emit diagnostic for invalid `@uses_shape_dsl` arguments MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: When `uses_shape_dsl(ir_fn)` references a function that is not decorated with `shape_dsl_function`, emit an `InvalidArgument` diagnostic instead of silently producing a function with no shape inference. The function still falls back to its declared return type. Previously this was a silent no-op — the `if let FunctionKind::ShapeDsl` chain simply didn't match and `shape_transform` stayed `None` with no indication to the user that their decorator was misconfigured. Differential Revision: D105783626 --- pyrefly/lib/alt/function.rs | 8 ++++++++ pyrefly/lib/test/shape_dsl.rs | 20 +++++++++++++++++++- 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index 869e21ff05..99867ba0c1 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -580,6 +580,14 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { dsl_fn: dsl_fn.clone(), helpers: helpers.clone(), })); + } else { + self.error( + errors, + ir_identifier.range(), + ErrorKind::InvalidArgument, + "`@uses_shape_dsl` argument does not resolve to a `@shape_dsl_function`" + .to_owned(), + ); } } diff --git a/pyrefly/lib/test/shape_dsl.rs b/pyrefly/lib/test/shape_dsl.rs index af949cd53f..1f02f7e1fe 100644 --- a/pyrefly/lib/test/shape_dsl.rs +++ b/pyrefly/lib/test/shape_dsl.rs @@ -28,6 +28,8 @@ def times_two(x: int) -> int: @shape_dsl_function def double_ir(x: int) -> int: return times_two(x) + +def not_a_dsl_fn(x: int) -> int: ... "#, ); env.add_with_path( @@ -36,7 +38,7 @@ def double_ir(x: int) -> int: r#" from typing import overload from shape_extensions import uses_shape_dsl -from my_shapes import identity_ir, double_ir +from my_shapes import identity_ir, double_ir, not_a_dsl_fn @uses_shape_dsl(identity_ir) def plain_fn(x: int) -> int: ... @@ -56,6 +58,9 @@ def overloaded_no_impl(x: str) -> str: ... @uses_shape_dsl(double_ir) def double_fn(x: int) -> int: ... + +@uses_shape_dsl(not_a_dsl_fn) # E: `@uses_shape_dsl` argument does not resolve to a `@shape_dsl_function` +def bad_fn(x: int) -> int: ... "#, ); env @@ -109,3 +114,16 @@ from my_lib import double_fn assert_type(double_fn(3), Literal[6]) "#, ); + +testcase!( + test_uses_shape_dsl_not_a_dsl_function, + shape_dsl_env(), + r#" +from typing import assert_type +from my_lib import bad_fn + +# The @uses_shape_dsl argument is not a @shape_dsl_function, so no shape +# transform is applied and the declared return type (int) is used instead. +assert_type(bad_fn(1), int) +"#, +); From c3b9a18d5032e81212bca7f3d642ac7c55bef607 Mon Sep 17 00:00:00 2001 From: stroxler Date: Thu, 21 May 2026 21:19:44 -0700 Subject: [PATCH 24/25] Replace `convert_fndef` panic with diagnostic MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: When a `shape_dsl_function` body uses unsupported Python syntax, emit a diagnostic instead of panicking. The function degrades to a normal `FunctionKind::Def` with no shape inference. Also adds warnings for unsupported parameter kinds (`*args`, `**kwargs`, keyword-only, positional-only) — these are silently dropped by the DSL converter but the user should know they have no effect. Previously, any `convert_fndef` failure crashed the type checker via `.expect()`. Third-party stub authors writing new DSL functions would hit this crash on any syntax mistake. Differential Revision: D105783605 --- pyrefly/lib/binding/function.rs | 53 +++++++++++++++++++++++++++++---- pyrefly/lib/test/shape_dsl.rs | 44 ++++++++++++++++++++++++++- 2 files changed, 90 insertions(+), 7 deletions(-) diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index 73c131d1ea..4ca56a8ae8 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -72,6 +72,7 @@ use crate::binding::scope::UnusedParameter; use crate::binding::scope::UnusedVariable; use crate::binding::scope::YieldsAndReturns; use crate::config::base::InferReturnTypes; +use crate::config::error_kind::ErrorKind; use crate::export::special::SpecialExport; use crate::types::types::AnyStyle; @@ -828,12 +829,52 @@ impl<'a> BindingsBuilder<'a> { // Convert the function to DSL IR before `function_header` takes `returns` // and before `function_body` takes `body`. let shape_dsl_def = if is_shape_dsl { - let dsl_fn = Arc::new( - convert_shape_dsl_function(&x).expect("@shape_dsl_function body must be valid DSL"), - ); - self.metadata - .push_shape_dsl(func_name.id.clone(), Arc::clone(&dsl_fn)); - Some(dsl_fn) + // Warn about parameter kinds the DSL silently ignores. + if let Some(vararg) = &x.parameters.vararg { + self.error( + vararg.range(), + ErrorKind::InvalidArgument, + "@shape_dsl_function: *args parameters are not supported in the shape DSL and will be ignored".to_owned(), + ); + } + if let Some(kwarg) = &x.parameters.kwarg { + self.error( + kwarg.range(), + ErrorKind::InvalidArgument, + "@shape_dsl_function: **kwargs parameters are not supported in the shape DSL and will be ignored".to_owned(), + ); + } + if !x.parameters.kwonlyargs.is_empty() { + self.error( + x.parameters.kwonlyargs[0].range(), + ErrorKind::InvalidArgument, + "@shape_dsl_function: keyword-only parameters are not supported in the shape DSL and will be ignored".to_owned(), + ); + } + if !x.parameters.posonlyargs.is_empty() { + self.error( + x.parameters.posonlyargs[0].range(), + ErrorKind::InvalidArgument, + "@shape_dsl_function: positional-only parameters are not supported in the shape DSL and will be ignored".to_owned(), + ); + } + + match convert_shape_dsl_function(&x) { + Ok(dsl_fn) => { + let dsl_fn = Arc::new(dsl_fn); + self.metadata + .push_shape_dsl(func_name.id.clone(), Arc::clone(&dsl_fn)); + Some(dsl_fn) + } + Err(msg) => { + self.error( + x.range, + ErrorKind::InvalidArgument, + format!("@shape_dsl_function: {msg}"), + ); + None + } + } } else { None }; diff --git a/pyrefly/lib/test/shape_dsl.rs b/pyrefly/lib/test/shape_dsl.rs index 1f02f7e1fe..0b2e026ff3 100644 --- a/pyrefly/lib/test/shape_dsl.rs +++ b/pyrefly/lib/test/shape_dsl.rs @@ -30,6 +30,16 @@ def double_ir(x: int) -> int: return times_two(x) def not_a_dsl_fn(x: int) -> int: ... + +@shape_dsl_function # E: @shape_dsl_function: unexpected statement in DSL body +def bad_syntax_ir(x: int) -> int: + while x > 0: + x = x - 1 + return x + +@shape_dsl_function +def kwargs_ir(x: int, **kwargs) -> int: # E: @shape_dsl_function: **kwargs parameters are not supported + return x "#, ); env.add_with_path( @@ -38,7 +48,7 @@ def not_a_dsl_fn(x: int) -> int: ... r#" from typing import overload from shape_extensions import uses_shape_dsl -from my_shapes import identity_ir, double_ir, not_a_dsl_fn +from my_shapes import identity_ir, double_ir, not_a_dsl_fn, bad_syntax_ir, kwargs_ir @uses_shape_dsl(identity_ir) def plain_fn(x: int) -> int: ... @@ -61,6 +71,12 @@ def double_fn(x: int) -> int: ... @uses_shape_dsl(not_a_dsl_fn) # E: `@uses_shape_dsl` argument does not resolve to a `@shape_dsl_function` def bad_fn(x: int) -> int: ... + +@uses_shape_dsl(bad_syntax_ir) # E: `@uses_shape_dsl` argument does not resolve to a `@shape_dsl_function` +def bad_syntax_fn(x: int) -> int: ... + +@uses_shape_dsl(kwargs_ir) +def kwargs_fn(x: int) -> int: ... "#, ); env @@ -127,3 +143,29 @@ from my_lib import bad_fn assert_type(bad_fn(1), int) "#, ); + +testcase!( + test_shape_dsl_unsupported_syntax, + shape_dsl_env(), + r#" +from typing import assert_type +from my_lib import bad_syntax_fn + +# bad_syntax_ir uses a while loop which is unsupported DSL syntax, so +# bad_syntax_fn falls back to the declared return type. +assert_type(bad_syntax_fn(1), int) +"#, +); + +testcase!( + test_shape_dsl_kwargs_warning, + shape_dsl_env(), + r#" +from typing import Literal, assert_type +from my_lib import kwargs_fn + +# kwargs_ir has **kwargs which triggers a warning but the DSL conversion +# still succeeds (kwargs are silently dropped), so shape inference works. +assert_type(kwargs_fn(1), Literal[1]) +"#, +); From 53dc24753a0b6a26d1b35cb37976a5de35ed4fd0 Mon Sep 17 00:00:00 2001 From: Steven Troxler Date: Thu, 21 May 2026 23:46:16 -0700 Subject: [PATCH 25/25] Convert `type_check_program` from panics to collected errors (#3487) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Summary: Pull Request resolved: https://github.com/facebook/pyrefly/pull/3487 **This stack** Reworks tensor shape operations so that instead of being hardcoded in tensor_ops_registry.rs, only DSL primitives are directly hardcoded into Pyrefly; the actual operations and the association with "normal" (non DSL) stubs all lives in user-space stub files. This allows iterating on the ops without rebuilding Pryefly, and is 100% essential for actually building out full stubs for pytorch (and even more so if we want to extend to other libraries like numpy and jax). The DSL itself is unchanged but we will use a decorator to indicate when a stub function is a DSL function; we use a different decorator to actually register a DSL function as the "shape transform" associated with some normal function (e.g. to associate the DSL function `reshape_ir` with a `torch.reshape` function). Details of the plan are in https://github.com/stroxler/pyrefly-docs/blob/main/tensor-shapes-in-stubs/v2-doc.md **This commit** Replace ~20 panic sites in the DSL type checker with error collection, so type errors in `shape_dsl_function` stubs produce diagnostics instead of crashing the type checker. `type_check_program` now returns `Result<(), Vec>`, threading errors through `check_body`, `check_expr`, `infer_expr`, `infer_call`, and the narrowing/joining helpers. `validate_shape_dsl_functions` propagates these errors to the solver, which emits them as `InvalidArgument` diagnostics on the function definition. Eval-time panics in `eval_dsl_body` / `eval_dsl_expr` are intentionally left as panics — they are correctness assertions that should be unreachable for type-checked programs. Differential Revision: D105783604 --- crates/pyrefly_types/src/meta_shape_dsl.rs | 305 +++++++++++++-------- pyrefly/lib/alt/function.rs | 38 ++- pyrefly/lib/test/shape_dsl.rs | 10 + 3 files changed, 228 insertions(+), 125 deletions(-) diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index 5f56a4c545..e1a8e42e0c 100644 --- a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -1486,16 +1486,19 @@ fn join(a: &DslType, b: &DslType) -> DslType { } } -/// Extract the element type from a list type. Panics on non-list (IR bug). -fn element_type(ty: &DslType) -> DslType { +/// Extract the element type from a list type. Pushes an error on non-list. +fn element_type(ty: &DslType, errors: &mut Vec) -> DslType { match ty { DslType::List(inner) => *inner.clone(), - _ => unreachable!("expected list type, got {}", ty), + _ => { + errors.push(format!("expected list type, got {}", ty)); + DslType::Int + } } } /// Narrow a type to only variants matching a constructor. -fn narrow_to(ty: &DslType, con: DslTypeCon) -> DslType { +fn narrow_to(ty: &DslType, con: DslTypeCon, errors: &mut Vec) -> DslType { match ty { DslType::Union(types) => { let matching: Vec<_> = types @@ -1505,17 +1508,29 @@ fn narrow_to(ty: &DslType, con: DslTypeCon) -> DslType { .collect(); match matching.len() { 1 => matching.into_iter().next().unwrap(), - 0 => unreachable!("isinstance narrowing: no variant of {} matches {}", ty, con), + 0 => { + errors.push(format!( + "isinstance narrowing: no variant of {} matches {}", + ty, con + )); + ty.clone() + } _ => DslType::Union(matching), } } _ if matches_constructor(ty, con) => ty.clone(), - _ => unreachable!("isinstance narrowing: {} does not match {}", ty, con), + _ => { + errors.push(format!( + "isinstance narrowing: {} does not match {}", + ty, con + )); + ty.clone() + } } } /// Narrow a type to exclude variants matching a constructor. -fn narrow_away(ty: &DslType, con: DslTypeCon) -> DslType { +fn narrow_away(ty: &DslType, con: DslTypeCon, errors: &mut Vec) -> DslType { match ty { DslType::Union(types) => { let remaining: Vec<_> = types @@ -1525,7 +1540,10 @@ fn narrow_away(ty: &DslType, con: DslTypeCon) -> DslType { .collect(); match remaining.len() { 1 => remaining.into_iter().next().unwrap(), - 0 => unreachable!("narrowed away all variants of {}", ty), + 0 => { + errors.push(format!("narrowed away all variants of {}", ty)); + ty.clone() + } _ => DslType::Union(remaining), } } @@ -1534,7 +1552,7 @@ fn narrow_away(ty: &DslType, con: DslTypeCon) -> DslType { } /// Narrow a type to exclude None. -fn narrow_away_none(ty: &DslType) -> DslType { +fn narrow_away_none(ty: &DslType, errors: &mut Vec) -> DslType { match ty { DslType::Union(types) => { let remaining: Vec<_> = types @@ -1544,7 +1562,10 @@ fn narrow_away_none(ty: &DslType) -> DslType { .collect(); match remaining.len() { 1 => remaining.into_iter().next().unwrap(), - 0 => unreachable!("narrowed away all variants of {}", ty), + 0 => { + errors.push(format!("narrowed away all variants of {}", ty)); + ty.clone() + } _ => DslType::Union(remaining), } } @@ -1554,7 +1575,7 @@ fn narrow_away_none(ty: &DslType) -> DslType { /// Analyze a condition expression for type narrowing. /// Returns (then_env, else_env) — the environments for the true and false branches. -fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { +fn narrow(cond: &DslExpr, env: &TypeEnv, errors: &mut Vec) -> (TypeEnv, TypeEnv) { match cond { // isinstance(x, con) DslExpr::IsInstance { expr, ty } => { @@ -1563,8 +1584,8 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { { let mut then_env = env.clone(); let mut else_env = env.clone(); - then_env.insert(name.clone(), narrow_to(var_ty, *ty)); - else_env.insert(name.clone(), narrow_away(var_ty, *ty)); + then_env.insert(name.clone(), narrow_to(var_ty, *ty, errors)); + else_env.insert(name.clone(), narrow_away(var_ty, *ty, errors)); return (then_env, else_env); } (env.clone(), env.clone()) @@ -1582,7 +1603,7 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { let mut then_env = env.clone(); let mut else_env = env.clone(); then_env.insert(name.clone(), DslType::None); - else_env.insert(name.clone(), narrow_away_none(var_ty)); + else_env.insert(name.clone(), narrow_away_none(var_ty, errors)); return (then_env, else_env); } (env.clone(), env.clone()) @@ -1599,7 +1620,7 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { { let mut then_env = env.clone(); let mut else_env = env.clone(); - then_env.insert(name.clone(), narrow_away_none(var_ty)); + then_env.insert(name.clone(), narrow_away_none(var_ty, errors)); else_env.insert(name.clone(), DslType::None); return (then_env, else_env); } @@ -1610,7 +1631,7 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { op: DslUnaryOp::Not, operand, } => { - let (then_env, else_env) = narrow(operand, env); + let (then_env, else_env) = narrow(operand, env, errors); (else_env, then_env) } // cond1 and cond2 — narrow both in then-branch, conservative in else @@ -1619,8 +1640,8 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { op: DslOp::And, right, } => { - let (then1, _) = narrow(left, env); - let (then2, _) = narrow(right, &then1); + let (then1, _) = narrow(left, env, errors); + let (then2, _) = narrow(right, &then1, errors); (then2, env.clone()) } _ => (env.clone(), env.clone()), @@ -1628,15 +1649,15 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { } /// Build function return type map from DSL definitions. -fn build_fn_ret_types(fndefs: &[DslFnDef]) -> FnRetTypes { +fn build_fn_ret_types(fndefs: &[DslFnDef], errors: &mut Vec) -> FnRetTypes { fndefs .iter() - .map(|f| { - let return_type = f - .return_type - .clone() - .unwrap_or_else(|| unreachable!("DSL function {} must have a return type", f.name)); - (f.name.clone(), return_type) + .filter_map(|f| match &f.return_type { + Some(rt) => Some((f.name.clone(), rt.clone())), + None => { + errors.push(format!("DSL function {} must have a return type", f.name)); + None + } }) .collect() } @@ -1653,56 +1674,75 @@ fn arithmetic_result(a: &DslType, b: &DslType) -> DslType { } /// Infer the element type of a list literal from its elements. -fn infer_list_elem_type(elts: &[DslExpr], env: &TypeEnv, sigs: &FnRetTypes) -> DslType { - assert!( - !elts.is_empty(), - "infer_list_elem_type called with empty list" - ); - let mut result = infer_expr(&elts[0], env, sigs); +fn infer_list_elem_type( + elts: &[DslExpr], + env: &TypeEnv, + sigs: &FnRetTypes, + errors: &mut Vec, +) -> DslType { + if elts.is_empty() { + errors.push("infer_list_elem_type called with empty list".to_owned()); + return DslType::Int; + } + let mut result = infer_expr(&elts[0], env, sigs, errors); for elt in &elts[1..] { - result = join(&result, &infer_expr(elt, env, sigs)); + result = join(&result, &infer_expr(elt, env, sigs, errors)); } result } /// Bind comprehension variables based on the iterator expression. /// Handles zip (multi-list iteration) and enumerate (index + element). -fn bind_comp_vars(vars: &[String], iter: &DslExpr, env: &TypeEnv, sigs: &FnRetTypes) -> TypeEnv { +fn bind_comp_vars( + vars: &[String], + iter: &DslExpr, + env: &TypeEnv, + sigs: &FnRetTypes, + errors: &mut Vec, +) -> TypeEnv { let mut new_env = env.clone(); match iter { DslExpr::Call { func: DslCallTarget::Builtin(DslBuiltin::Zip), args, } => { - assert_eq!( - vars.len(), - args.len(), - "zip: {} vars but {} args", - vars.len(), - args.len() - ); + if vars.len() != args.len() { + errors.push(format!("zip: {} vars but {} args", vars.len(), args.len())); + } for (var, arg) in vars.iter().zip(args.iter()) { - let arg_ty = infer_expr(arg, env, sigs); - new_env.insert(var.clone(), element_type(&arg_ty)); + let arg_ty = infer_expr(arg, env, sigs, errors); + new_env.insert(var.clone(), element_type(&arg_ty, errors)); } } DslExpr::Call { func: DslCallTarget::Builtin(DslBuiltin::Enumerate), args, } => { - assert_eq!(args.len(), 1, "enumerate takes exactly 1 argument"); - assert_eq!(vars.len(), 2, "enumerate requires exactly 2 variables"); - let list_ty = infer_expr(&args[0], env, sigs); - new_env.insert(vars[0].clone(), DslType::Int); - new_env.insert(vars[1].clone(), element_type(&list_ty)); + if args.len() != 1 { + errors.push(format!( + "enumerate takes exactly 1 argument, got {}", + args.len() + )); + } + if vars.len() != 2 { + errors.push(format!( + "enumerate requires exactly 2 variables, got {}", + vars.len() + )); + } + if !args.is_empty() && vars.len() >= 2 { + let list_ty = infer_expr(&args[0], env, sigs, errors); + new_env.insert(vars[0].clone(), DslType::Int); + new_env.insert(vars[1].clone(), element_type(&list_ty, errors)); + } } _ => { - let iter_ty = infer_expr(iter, env, sigs); + let iter_ty = infer_expr(iter, env, sigs, errors); if vars.len() == 1 { - new_env.insert(vars[0].clone(), element_type(&iter_ty)); + new_env.insert(vars[0].clone(), element_type(&iter_ty, errors)); } else { // Multiple vars iterating over a single list — each gets element type. - let elem = element_type(&iter_ty); + let elem = element_type(&iter_ty, errors); for var in vars { new_env.insert(var.clone(), elem.clone()); } @@ -1713,17 +1753,23 @@ fn bind_comp_vars(vars: &[String], iter: &DslExpr, env: &TypeEnv, sigs: &FnRetTy } /// Infer the return type of a function call. -fn infer_call(func: &DslCallTarget, args: &[DslExpr], env: &TypeEnv, sigs: &FnRetTypes) -> DslType { +fn infer_call( + func: &DslCallTarget, + args: &[DslExpr], + env: &TypeEnv, + sigs: &FnRetTypes, + errors: &mut Vec, +) -> DslType { // Infer all arguments for logging, regardless of whether we need them. for arg in args { - infer_expr(arg, env, sigs); + infer_expr(arg, env, sigs, errors); } match func { DslCallTarget::Builtin(builtin) => match builtin { // prod/sum reduce a list of dims to a single dim. DslBuiltin::Prod | DslBuiltin::Sum => { - let arg_ty = infer_expr(&args[0], env, sigs); - element_type(&arg_ty) + let arg_ty = infer_expr(&args[0], env, sigs, errors); + element_type(&arg_ty, errors) } DslBuiltin::Str => DslType::Str, DslBuiltin::ParseEinsumEquation => DslType::List(Box::new(DslType::List(Box::new( @@ -1732,18 +1778,30 @@ fn infer_call(func: &DslCallTarget, args: &[DslExpr], env: &TypeEnv, sigs: &FnRe DslBuiltin::Len => DslType::Int, DslBuiltin::Range => DslType::List(Box::new(DslType::Int)), DslBuiltin::Zip | DslBuiltin::Enumerate => { - unreachable!("{} should only appear as comprehension iterator", builtin) + errors.push(format!( + "{} should only appear as comprehension iterator", + builtin + )); + DslType::Int + } + }, + DslCallTarget::UserDefined(name) => match sigs.get(name) { + Some(ty) => ty.clone(), + None => { + errors.push(format!("undefined function: {}", name)); + DslType::Int } }, - DslCallTarget::UserDefined(name) => sigs - .get(name) - .unwrap_or_else(|| unreachable!("undefined function: {}", name)) - .clone(), } } /// Infer the type of a DSL expression. -fn infer_expr(expr: &DslExpr, env: &TypeEnv, sigs: &FnRetTypes) -> DslType { +fn infer_expr( + expr: &DslExpr, + env: &TypeEnv, + sigs: &FnRetTypes, + errors: &mut Vec, +) -> DslType { match expr { DslExpr::Const(c) => match c { DslConst::None => DslType::None, @@ -1751,63 +1809,65 @@ fn infer_expr(expr: &DslExpr, env: &TypeEnv, sigs: &FnRetTypes) -> DslType { DslConst::Bool(_) => DslType::Bool, DslConst::Str(_) => DslType::Str, }, - DslExpr::Var(name) => env - .get(name) - .cloned() - .unwrap_or_else(|| unreachable!("undefined variable: {}", name)), + DslExpr::Var(name) => match env.get(name) { + Some(ty) => ty.clone(), + None => { + errors.push(format!("undefined variable: {}", name)); + DslType::Int + } + }, DslExpr::List(elts) => { if elts.is_empty() { // All empty list literals in the DSL are dimension lists. DslType::List(Box::new(dim_type())) } else if matches!(elts.last(), Some(DslExpr::Ellipsis)) { // [expr, ...] — unbounded list sentinel. - let elem_ty = infer_list_elem_type(&elts[..elts.len() - 1], env, sigs); + let elem_ty = infer_list_elem_type(&elts[..elts.len() - 1], env, sigs, errors); DslType::List(Box::new(elem_ty)) } else { - let elem_ty = infer_list_elem_type(elts, env, sigs); + let elem_ty = infer_list_elem_type(elts, env, sigs, errors); DslType::List(Box::new(elem_ty)) } } DslExpr::ListComp { elt, vars, iter, .. } => { - let comp_env = bind_comp_vars(vars, iter, env, sigs); - let elt_ty = infer_expr(elt, &comp_env, sigs); + let comp_env = bind_comp_vars(vars, iter, env, sigs, errors); + let elt_ty = infer_expr(elt, &comp_env, sigs, errors); DslType::List(Box::new(elt_ty)) } DslExpr::Index { base, index } => { - let base_ty = infer_expr(base, env, sigs); - infer_expr(index, env, sigs); - element_type(&base_ty) + let base_ty = infer_expr(base, env, sigs, errors); + infer_expr(index, env, sigs, errors); + element_type(&base_ty, errors) } DslExpr::Slice { base, lower, upper } => { - let base_ty = infer_expr(base, env, sigs); + let base_ty = infer_expr(base, env, sigs, errors); if let Some(l) = lower { - infer_expr(l, env, sigs); + infer_expr(l, env, sigs, errors); } if let Some(u) = upper { - infer_expr(u, env, sigs); + infer_expr(u, env, sigs, errors); } base_ty } DslExpr::BinOp { left, op, right } => { - let lt = infer_expr(left, env, sigs); - let rt = infer_expr(right, env, sigs); + let lt = infer_expr(left, env, sigs, errors); + let rt = infer_expr(right, env, sigs, errors); match op { DslOp::Add => { // List concatenation, string concatenation, or numeric addition. if let DslType::List(a) = < { - let DslType::List(b) = &rt else { - unreachable!("+ with list and non-list: {} + {}", lt, rt) - }; - DslType::List(Box::new(join(a, b))) + if let DslType::List(b) = &rt { + DslType::List(Box::new(join(a, b))) + } else { + errors.push(format!("+ with list and non-list: {} + {}", lt, rt)); + DslType::Int + } } else if matches!(lt, DslType::Str) { - assert!( - matches!(rt, DslType::Str), - "+ with str and non-str: {} + {}", - lt, - rt - ); + if !matches!(rt, DslType::Str) { + errors.push(format!("+ with str and non-str: {} + {}", lt, rt)); + } DslType::Str } else { arithmetic_result(<, &rt) @@ -1828,84 +1888,94 @@ fn infer_expr(expr: &DslExpr, env: &TypeEnv, sigs: &FnRetTypes) -> DslType { } DslExpr::UnaryOp { op, operand } => match op { DslUnaryOp::Not => { - infer_expr(operand, env, sigs); + infer_expr(operand, env, sigs, errors); DslType::Bool } - DslUnaryOp::Neg => infer_expr(operand, env, sigs), + DslUnaryOp::Neg => infer_expr(operand, env, sigs, errors), }, - DslExpr::Call { func, args } => infer_call(func, args, env, sigs), + DslExpr::Call { func, args } => infer_call(func, args, env, sigs, errors), DslExpr::IsInstance { expr, .. } => { - infer_expr(expr, env, sigs); + infer_expr(expr, env, sigs, errors); DslType::Bool } DslExpr::In { left, right } => { - infer_expr(left, env, sigs); - infer_expr(right, env, sigs); + infer_expr(left, env, sigs, errors); + infer_expr(right, env, sigs, errors); DslType::Bool } DslExpr::Shape(inner) => { - infer_expr(inner, env, sigs); + infer_expr(inner, env, sigs, errors); DslType::List(Box::new(dim_type())) } DslExpr::TensorNew(inner) => { - infer_expr(inner, env, sigs); + infer_expr(inner, env, sigs, errors); DslType::Tensor } DslExpr::IfExpr { body, test, orelse } => { - let (then_env, else_env) = narrow(test, env); - let body_ty = infer_expr(body, &then_env, sigs); - let else_ty = infer_expr(orelse, &else_env, sigs); + let (then_env, else_env) = narrow(test, env, errors); + let body_ty = infer_expr(body, &then_env, sigs, errors); + let else_ty = infer_expr(orelse, &else_env, sigs, errors); join(&body_ty, &else_ty) } - DslExpr::Ellipsis => unreachable!("Ellipsis should be handled by List"), + DslExpr::Ellipsis => { + errors.push("Ellipsis should be handled by List".to_owned()); + DslType::Int + } DslExpr::Unknown => DslType::None, // sentinel for fixture fallback } } /// Type-check a function body, updating the environment through assignments /// and narrowing through conditionals. -fn check_body(body: &DslBody, env: &TypeEnv, sigs: &FnRetTypes) { +fn check_body(body: &DslBody, env: &TypeEnv, sigs: &FnRetTypes, errors: &mut Vec) { match body { DslBody::Assign { vars, expr, rest } => { - let ty = infer_expr(expr, env, sigs); + let ty = infer_expr(expr, env, sigs, errors); let mut new_env = env.clone(); if vars.len() == 1 { new_env.insert(vars[0].clone(), ty); } else { - let elem = element_type(&ty); + let elem = element_type(&ty, errors); for var in vars { new_env.insert(var.clone(), elem.clone()); } } - check_body(rest, &new_env, sigs); + check_body(rest, &new_env, sigs, errors); } DslBody::If { cond, then_body, rest, } => { - let (then_env, else_env) = narrow(cond, env); - check_body(then_body, &then_env, sigs); - check_body(rest, &else_env, sigs); + let (then_env, else_env) = narrow(cond, env, errors); + check_body(then_body, &then_env, sigs, errors); + check_body(rest, &else_env, sigs, errors); } DslBody::Return(expr) => { - infer_expr(expr, env, sigs); + infer_expr(expr, env, sigs, errors); } DslBody::Raise(expr) => { - infer_expr(expr, env, sigs); + infer_expr(expr, env, sigs, errors); } } } -/// Type-check all DSL function definitions. -fn type_check_program(fndefs: &[DslFnDef]) { - let sigs = build_fn_ret_types(fndefs); +/// Type-check all DSL function definitions. Returns `Err` with collected +/// error messages if type errors are found. +fn type_check_program(fndefs: &[DslFnDef]) -> Result<(), Vec> { + let mut errors = Vec::new(); + let sigs = build_fn_ret_types(fndefs, &mut errors); for fndef in fndefs { let mut env = TypeEnv::new(); for param in &fndef.params { env.insert(param.name.clone(), param.ty.clone()); } - check_body(&fndef.body, &env, &sigs); + check_body(&fndef.body, &env, &sigs, &mut errors); + } + if errors.is_empty() { + Ok(()) + } else { + Err(errors) } } @@ -3273,13 +3343,14 @@ fn collect_call_targets_expr(expr: &DslExpr, targets: &mut HashSet) { /// Validate a set of `ShapeDslFunction`s as a program. /// /// Runs `type_check_program` on the inner `DslFnDef`s, verifying that -/// cross-function calls have consistent signatures. Panics on type errors. +/// cross-function calls have consistent signatures. Returns collected +/// type error messages on failure. /// /// Intended to be called with a per-caller transitive closure (root + /// its resolved helpers), not the full module. -pub fn validate_shape_dsl_functions(fns: &[Arc]) { +pub fn validate_shape_dsl_functions(fns: &[Arc]) -> Result<(), Vec> { let defs: Vec = fns.iter().map(|f| (*f.inner).clone()).collect(); - type_check_program(&defs); + type_check_program(&defs) } /// Reference to a shape-DSL function that refines a callable's return type. @@ -3414,7 +3485,11 @@ pub fn build_shape_dsl_program(fns: impl IntoIterator) // `type_check_program` panics on type errors today; that matches the // existing `parse_dsl` semantics. let view: Vec = fns.iter().map(|f| (**f).clone()).collect(); - type_check_program(&view); + // Errors from `type_check_program` are intentionally discarded here: + // `build_shape_dsl_program` is called from test/fixture code where + // panicking is acceptable. The production path goes through + // `validate_shape_dsl_functions` which propagates errors as diagnostics. + let _ = type_check_program(&view); ShapeDslProgram { fns } } diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index 99867ba0c1..cb6158a15a 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -545,20 +545,38 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let tparams = self.validated_tparams(def.range, tparams, TParamsSource::Function, errors); let kind = if let Some(dsl_fn) = shape_dsl_def { - // Shape DSL functions carry their parsed IR for call-site evaluation. - let func_id = Arc::new(FuncId { - module: self.module().dupe(), - cls: defining_cls.clone(), - name: def.name.id.clone(), - def_index: Some(def_index), - outer_funcs, - }); // Build the transitive closure of helper functions this DSL function calls, // then validate cross-function call signatures. let module_dsl_fns = self.bindings().metadata().shape_dsl_functions(); let helpers = compute_transitive_helpers(&dsl_fn, module_dsl_fns); - validate_shape_dsl_functions(&helpers); - FunctionKind::ShapeDsl(func_id, dsl_fn, Derived(helpers)) + if let Err(type_errors) = validate_shape_dsl_functions(&helpers) { + for msg in &type_errors { + self.error( + errors, + def.name.range, + ErrorKind::InvalidArgument, + format!("@shape_dsl_function type error: {msg}"), + ); + } + // Fall back to a normal function — the DSL evaluator must + // never run on a program that failed type checking. + FunctionKind::from_name( + self.module().dupe(), + defining_cls.clone(), + &def.name.id, + Some(def_index), + outer_funcs, + ) + } else { + let func_id = Arc::new(FuncId { + module: self.module().dupe(), + cls: defining_cls.clone(), + name: def.name.id.clone(), + def_index: Some(def_index), + outer_funcs, + }); + FunctionKind::ShapeDsl(func_id, dsl_fn, Derived(helpers)) + } } else { FunctionKind::from_name( self.module().dupe(), diff --git a/pyrefly/lib/test/shape_dsl.rs b/pyrefly/lib/test/shape_dsl.rs index 0b2e026ff3..9bc1045862 100644 --- a/pyrefly/lib/test/shape_dsl.rs +++ b/pyrefly/lib/test/shape_dsl.rs @@ -40,6 +40,10 @@ def bad_syntax_ir(x: int) -> int: @shape_dsl_function def kwargs_ir(x: int, **kwargs) -> int: # E: @shape_dsl_function: **kwargs parameters are not supported return x + +@shape_dsl_function +def calls_undefined(x: int) -> int: # E: @shape_dsl_function type error: undefined function: nonexistent + return nonexistent(x) # E: Could not find name `nonexistent` "#, ); env.add_with_path( @@ -77,6 +81,7 @@ def bad_syntax_fn(x: int) -> int: ... @uses_shape_dsl(kwargs_ir) def kwargs_fn(x: int) -> int: ... + "#, ); env @@ -169,3 +174,8 @@ from my_lib import kwargs_fn assert_type(kwargs_fn(1), Literal[1]) "#, ); + +// The `calls_undefined` function in my_shapes.pyi calls `nonexistent()`, +// which produces a type error diagnostic (tested by the `# E:` annotation +// on its definition). No separate test case is needed here — the error is +// validated by every test that uses `shape_dsl_env()`.