Skip to content

Convert type_check_program from panics to collected errors (#3487)#3487

Open
stroxler wants to merge 25 commits into
mainfrom
export-D105783604
Open

Convert type_check_program from panics to collected errors (#3487)#3487
stroxler wants to merge 25 commits into
mainfrom
export-D105783604

Conversation

@stroxler
Copy link
Copy Markdown
Contributor

@stroxler stroxler commented May 20, 2026

Summary:

This stack

Reworks tensor shape operations so that instead of being hardcoded in tensor_ops_registry.rs, only DSL primitives are
directly hardcoded into Pyrefly; the actual operations and the association with "normal" (non DSL) stubs all lives in user-space
stub files. This allows iterating on the ops without rebuilding Pryefly, and is 100% essential for actually building out full stubs
for pytorch (and even more so if we want to extend to other libraries like numpy and jax).

The DSL itself is unchanged but we will use a decorator to indicate when a stub function is a DSL function; we use a different
decorator to actually register a DSL function as the "shape transform" associated with some normal function (e.g. to
associate the DSL function reshape_ir with a torch.reshape function).

Details of the plan are in https://github.com/stroxler/pyrefly-docs/blob/main/tensor-shapes-in-stubs/v2-doc.md

This commit

Replace ~20 panic sites in the DSL type checker with error collection,
so type errors in shape_dsl_function stubs produce diagnostics instead
of crashing the type checker.

type_check_program now returns Result<(), Vec<String>>, threading
errors through check_body, check_expr, infer_expr, infer_call,
and the narrowing/joining helpers. validate_shape_dsl_functions
propagates these errors to the solver, which emits them as
InvalidArgument diagnostics on the function definition.

Eval-time panics in eval_dsl_body / eval_dsl_expr are intentionally
left as panics — they are correctness assertions that should be
unreachable for type-checked programs.

Differential Revision: D105783604

@meta-cla meta-cla Bot added the cla signed label May 20, 2026
@meta-codesync
Copy link
Copy Markdown
Contributor

meta-codesync Bot commented May 20, 2026

@stroxler has exported this pull request. If you are a Meta employee, you can view the originating Diff in D105783604.

@stroxler stroxler self-assigned this May 20, 2026
stroxler added a commit that referenced this pull request May 20, 2026
Summary:
Pull Request resolved: #3487

**This stack**

Reworks tensor shape operations so that instead of being hardcoded in tensor_ops_registry.rs, only DSL primitives are
directly hardcoded into Pyrefly; the actual operations and the association with "normal" (non DSL) stubs all lives in user-space
stub files. This allows iterating on the ops without rebuilding Pryefly, and is 100% essential for actually building out full stubs
for pytorch (and even more so if we want to extend to other libraries like numpy and jax).

The DSL itself is unchanged but we will use a decorator to indicate when a stub function is a DSL function; we use a different
decorator to actually register a DSL function as the "shape transform" associated with some normal function (e.g. to
associate the DSL function `reshape_ir` with a `torch.reshape` function).

Details of the plan are in https://github.com/stroxler/pyrefly-docs/blob/main/tensor-shapes-in-stubs/v2-doc.md

**This commit**

Replace ~20 panic sites in the DSL type checker with error collection,
so type errors in `shape_dsl_function` stubs produce diagnostics instead
of crashing the type checker.

`type_check_program` now returns `Result<(), Vec<String>>`, threading
errors through `check_body`, `check_expr`, `infer_expr`, `infer_call`,
and the narrowing/joining helpers. `validate_shape_dsl_functions`
propagates these errors to the solver, which emits them as
`InvalidArgument` diagnostics on the function definition.

Eval-time panics in `eval_dsl_body` / `eval_dsl_expr` are intentionally
left as panics — they are correctness assertions that should be
unreachable for type-checked programs.

Differential Revision: D105783604
@stroxler stroxler force-pushed the export-D105783604 branch from 3e69bd6 to 719e973 Compare May 20, 2026 04:50
@meta-codesync meta-codesync Bot changed the title Convert type_check_program from panics to collected errors Convert type_check_program from panics to collected errors (#3487) May 20, 2026
@github-actions github-actions Bot added size/xl and removed size/xl labels May 20, 2026
@github-actions

This comment has been minimized.

stroxler added 18 commits May 21, 2026 21:19
Summary:
Add the `uses_shape_dsl` decorator to `shape_extensions`, the public API decorator that associates a shape DSL function with an API function in library stubs. At runtime it's a no-op passthrough; Pyrefly will use it at type-checking time to route bound arguments through the shape DSL for return-type refinement.

Also register `UsesShapeDsl` as a `SpecialExport` variant in the Rust export system, so that later phases can detect `uses_shape_dsl` decorators during binding.

Differential Revision: D105696519
Summary:
Add the `shape_extensions.dsl` submodule containing the `shape_dsl_function` decorator. This decorator is DSL-internal — it marks a function whose body should be converted to shape DSL IR during binding. It lives in a separate submodule from the public `shape_extensions` API because it's only used inside DSL definition files (like `torch/_shapes.pyi`), not in normal stubs or user code.

Also register `ShapeDslFunction` as a `SpecialExport` variant with `defined_in` matching `shape_extensions.dsl`, so later phases can detect `shape_dsl_function` decorators during binding.

Differential Revision: D105696516
Summary:
Add `prod`, `sum`, and `parse_einsum_equation` stub definitions to `shape_extensions/dsl.py` so DSL function files can import them from the canonical location. Update all DSL builtin references from `shape_extensions.*` to `shape_extensions.dsl.*` in both the Rust string matching (`convert_call`, `Display for DslBuiltin`) and the `DSL_SOURCE` Python code.

Fix `convert_call` to support multi-level dotted names (e.g. `shape_extensions.dsl.prod`) — it previously only handled single-dotted names like `shape_extensions.prod`.

Differential Revision: D105698000
Summary: Add an import statement at the top of `DSL_SOURCE` to make it look more like a normal Python module. The import is silently skipped by `parse_dsl` (which only processes function definitions), but having it there makes the DSL source ready to behave like a real stub file once Phase 5 migrates it out of the Rust string.

Differential Revision: D105698001
Summary:
Phase 2 of the tensor-shapes-in-stubs migration. Adds a public surface
to `pyrefly_types::meta_shape_dsl` that lets later phases (the binder
and solver in `pyrefly/lib`) drive the DSL pipeline without exposing
the grammar-aligned `DslFnDef` internals.

The new surface:

- `ShapeDslFunction` — an opaque, cheap (one `Arc`) handle to a single
  DSL function lowered from its Python AST.
- `ShapeDslProgram` — a bundle of `ShapeDslFunction`s that has been
  validated together as a program. The only way to obtain one is via
  `build_shape_dsl_program`, which type-checks the bundle.
- `convert_shape_dsl_function` — AST → `ShapeDslFunction`.
- `build_shape_dsl_program` — `Iterator<ShapeDslFunction>` →
  `ShapeDslProgram`. Panics on type-check failure today, matching the
  existing `parse_dsl` semantics; Phase 7 of the migration will
  convert this to a `Result` and surface diagnostics.
- `make_meta_shape_function` — `(&ShapeDslProgram, root_name)` →
  `Box<dyn MetaShapeFunction>`. Taking a `&ShapeDslProgram` (not raw
  pieces) enforces that callers cannot build a `MetaShapeFunction`
  from un-type-checked DSL.

The internal helpers (`convert_fndef`, `type_check_program`,
`bind_dsl_params`, `eval_dsl_body`) stay module-private; the wrappers
live in the same module and call them directly. This keeps DSL
internals fully opaque outside the module.

No call sites change in this commit; the new API is exercised by the
follow-up dogfood refactor of `tensor_ops_registry`.

Differential Revision: D105720303
Summary:
Refactor `TensorOpsRegistry::new()` to drive DSL construction through
the public `pyrefly_types::meta_shape_dsl` wrapper API added in the
previous commit, rather than reaching into the now-obsolete
`parse_dsl` / `Arc<DslFnDef>` / `DslMetaShapeFunction` internals.

This validates the new API works end-to-end before Phase 3 depends on
it, and it shrinks the number of code paths through the DSL engine to
one: AST → `convert_shape_dsl_function` → `build_shape_dsl_program`
→ `make_meta_shape_function`.

`parse_dsl` was the only caller of itself; with the registry
converted there are no callers left, so it is deleted (together with
the source-text parsing helper it relied on). The pipeline still
panics on parse / type-check failure, exactly as before — Phase 7
will revisit error handling end-to-end.

Differential Revision: D105720304
Summary: Thread a `capture_init: Option<Vec<Name>>` field from class binding through to solved `ClassMetadata`, following the `pydantic_before_validator_fields` precedent. This field holds `__init__` parameter names extracted from `uses_shape_dsl(..., capture_init=[...])` decorators on `forward` methods — today it is populated but not yet consumed. Phase 4 will wire it into `maybe_wrap_nn_module` to replace the hardcoded `TensorOpsRegistry::get_init_capture` lookup.

Differential Revision: D105720302
Summary:
Add a new `FunctionKind::ShapeDsl(Arc<FuncId>, Arc<ShapeDslFunction>)` variant that will represent functions whose return types are computed by the shape DSL. The `FuncId` provides identity (module, class, name) for display and lookup; the `ShapeDslFunction` carries the parsed DSL IR.

The DSL definition is carried inside the `FunctionKind` variant rather than as a separate `Option` field on `Function` to avoid touching ~30-40 construction sites with `dsl_def: None`. This is semantically equivalent — `dsl_def` is `Some` exactly when `FunctionKind` is `ShapeDsl`, so embedding it in the variant enforces the invariant by construction.

Adds `PartialEq`/`Eq`/`Hash`/`Ord`/`Visit`/`VisitMut`/`TypeEq` implementations on `ShapeDslFunction` (pointer-identity semantics, no-op visiting since DSL IR contains no `Type` values).

Differential Revision: D105720305
Summary:
Wire the binder and solver so that functions decorated with `shape_dsl_function` (from `shape_extensions.dsl`) are converted to DSL IR at binding time and produce `FunctionKind::ShapeDsl` at solve time.

The DSL definition is stored as `shape_dsl_def: Option<Arc<ShapeDslFunction>>` on `BindingUndecoratedFunction` rather than as a separate `Binding` variant, since the function still needs its normal name binding and decorator processing chain. When the solver sees `shape_dsl_def` is `Some`, it constructs `FunctionKind::ShapeDsl(func_id, dsl_fn)` instead of `FunctionKind::Def`. The function's name resolves through the normal binding chain, so `from torch._shapes import reshape_ir` works automatically.

The conversion must happen before `function_body()` consumes the AST body via `mem::take`. Conversion failures panic (Phase 7 will add structured diagnostics).

Differential Revision: D105728837
Summary:
When a function is decorated with `uses_shape_dsl(reshape_ir)`, extract the first positional argument's name from the decorator call AST and store it as `uses_shape_dsl_ir_name: Option<Name>` on `BindingUndecoratedFunction`.

This is needed because the existing `KwCall` mechanism in the decorator pipeline only captures keyword arguments, not positional ones. Binding-time extraction makes the IR function name available for Phase 4, where the solver will resolve it to a `Type::Function` with `FunctionKind::ShapeDsl` and wire it into `FuncFlags.shape_transform`.

Differential Revision: D105728838
Summary:
Add `ShapeTransformRef` type in `meta_shape_dsl.rs` and a new
`shape_transform: Option<Arc<ShapeTransformRef>>` field on `FuncFlags`.

`ShapeTransformRef` carries an `Arc<ShapeDslFunction>` — the resolved DSL
function definition. By the time this field is populated (Phase 4b), the IR
function name has been resolved to a `Type::Function` with
`FunctionKind::ShapeDsl`, so we store the extracted definition directly rather
than a name or binding key.

Trait impls follow the same pointer-identity pattern as `ShapeDslFunction`:
`PartialEq/Eq/Hash/PartialOrd/Ord` delegate to the inner `ShapeDslFunction`,
`Visit/VisitMut` are no-ops (DSL IR contains no `Type` values), and `TypeEq`
delegates to `PartialEq`.

Differential Revision: D105739362
… population

Summary:
Add `FunctionKind::UsesShapeDsl` so that calling `uses_shape_dsl(...)` produces
a `Type::KwCall`, making the decorator identifiable in `get_special_decorator`.
Without this, the decorator's type would be plain `Callable`, indistinguishable
from any other callable-returning decorator, and the generic pipeline would
produce `Any`.

The decorator is consumed via `SpecialDecorator::UsesShapeDsl` →
`set_flag_from_special_decorator` (returns `true` to filter it out). The actual
`FuncFlags.shape_transform` is populated after the decorator loop in
`undecorated_function`, where `uses_shape_dsl_ir_name` (extracted at binding
time in Phase 3d) is resolved via `Key::BoundName` to get the IR function's
`ShapeDslFunction` from its `FunctionKind::ShapeDsl` variant.

To enable solve-time lookup, `uses_shape_dsl_ir_name` is changed from
`Option<Name>` to `Option<(Name, ShortIdentifier)>`, carrying the `TextRange`
needed for `Key::BoundName` resolution (Pyrefly's binding lookup is range-based,
not name-based).

Differential Revision: D105739363
Summary:
Verify that `uses_shape_dsl` decorator recognition works correctly:
- Plain function: decorator is consumed and function type is preserved
- Overloaded with implementation: `shape_transform` flows through
  `merge_overload_metadata_with_implementation` via `FuncFlags`
- Stub-only overloads (no implementation): `shape_transform` flows through
  `merge_overload_metadata_no_implementation` from the first overload

These tests catch regressions if `merge_overload_metadata_*` ever changes
in a way that drops `FuncFlags` fields.

Differential Revision: D105739364
Summary:
The `DslType::Int` and `DslType::Bool` branches in `val_to_type` synthesize `Literal[n]` / `Literal[bool]` from the DSL's traced runtime value. This looks inconsistent with the other branches (Tensor, List, Tuple, None, Str) which return `expected_return_type.clone()`, but the difference is intentional and load-bearing.

Functions like `dim_ir`, `numel_ir`, and `size_ir(dim=N)` trace exact integer results. Downstream consumers (assert_type, reshape validation, shape inference) rely on this literal precision. The fixture return type for these functions is just `int` — the literal value comes solely from DSL evaluation. In contrast, the Tensor/List branches' `expected_return_type` already carries refined structure (e.g. `Tensor[B, C, H, W]` with shape injected), so cloning it is correct there.

This commit adds comments explaining the invariant so future readers don't mistake the asymmetry for a bug.

Differential Revision: D105758573
Summary:
Add `ShapeTransformRef::to_meta_shape_function()` which builds a
`DslMetaShapeFunction` from the decorator-carried DSL definition, and
wire it into `callable_infer_inner` so the decorator-based
`shape_transform` is preferred over the legacy registry lookup. The
registry serves as fallback for functions not yet migrated to
`uses_shape_dsl`. The decorator path is intentionally not gated by the
`tensor_shapes` flag since `uses_shape_dsl` is itself the opt-in.

Thread `shape_transform: Option<&ShapeTransformRef>` through the call
inference chain (`call_infer_with_callee_range` → `call_infer_inner` →
`callable_infer` → `callable_infer_inner`) and the overload path
(`call_overloads` → `find_closest_overload` → `call_overload`).

Fix ordering in the binder: `convert_shape_dsl_function` must run before
`function_header`, which consumes `x.returns` via `mem::take`. Without
this, the DSL converter sees no return type annotation and produces
`DslFnDef.return_type = None`.

Update `maybe_wrap_nn_module` to check `ClassMetadata.capture_init()`
first, falling back to `TensorOpsRegistry::get_init_capture` for classes
not yet migrated.

Differential Revision: D105758771
Summary:
Move all 86 shape DSL functions (14 helpers + 72 IR functions) from
`DSL_SOURCE` in `tensor_ops_registry.rs` to
`test/tensor_shapes/fixtures/torch/_shapes.pyi`, adding
`shape_dsl_function` decorators. The function bodies are verbatim copies.

No behavioral change — nothing references this file yet. The decorators
will be consumed in commit 5c when `uses_shape_dsl` decorators are added
to the torch fixture stubs.

Differential Revision: D105775130
Summary:
Enable DSL functions that call helpers (e.g., `reshape_ir` calling
`normalize_dim`) to work through the decorator path. Previously,
`ShapeTransformRef::to_meta_shape_function()` used an empty `fn_lookup`,
which panicked on any helper call.

Introduces a `Derived<T>` wrapper in `pyrefly_types` for attaching
auxiliary data to types without affecting identity comparisons. Uses it
to carry same-module DSL siblings on `FunctionKind::ShapeDsl` and
`ShapeTransformRef`. At `shape_dsl_function` solve time, all siblings
from the module's `BindingsMetadata` are collected; at `uses_shape_dsl`
consumer sites, the siblings flow through to `to_meta_shape_function`
which builds fn_lookup from self + siblings.

This is a deliberate all-siblings shortcut matching the registry's
flat-namespace behavior. Follow-up #9 replaces it with per-caller
transitive-callee resolution.

Differential Revision: D105775131
Summary:
Wire up ~131 torch fixture stub functions with `uses_shape_dsl(ir_fn)`
decorators, matching the `TensorOpsRegistry` mappings. The decorator path
is now preferred over the registry for all decorated functions.

Modified fixture files:
- `torch/__init__.pyi`: ~112 decorators on module-level functions and
  `Tensor` methods (shape manipulation, reductions, creation, linalg,
  indexing, random, conditional, properties)
- `torch/nn/functional.pyi`: ~30 decorators (conv, pool, loss, pad,
  interpolate, cosine_similarity)
- `torch/fft.pyi`: 4 decorators (rfft, irfft, hfft, ihfft)
- `torch/linalg.pyi`: 9 decorators (eig, eigvals, solve, slogdet, etc.)

All tensor_shapes tests pass through the decorator path.

Differential Revision: D105775133
stroxler added 7 commits May 21, 2026 21:19
Summary:
Add `uses_shape_dsl(ir_fn, capture_init=[...])` decorators to the
`forward` methods of all 15 nn.Module classes that have shape-aware
forward inference: MaxPool1d/2d/3d, AvgPool1d/2d/3d, Flatten,
PixelShuffle, GLU, LSTM, Upsample, GRU, LSTMCell, ReflectionPad2d,
ReplicationPad2d.

The `capture_init` plumbing (binding extraction, ClassMetadata
propagation, and `maybe_wrap_nn_module` consumption) was completed in
Phases 3a and 4c. This commit is purely fixture annotations — once added,
the decorators override the registry path for each annotated class.

Differential Revision: D105775129
Summary:
Remove the old hardcoded registry now that all ~131 shape functions and
15 nn.Module classes are wired through `uses_shape_dsl` decorators.

Deleted:
- `tensor_ops_registry.rs` (1,022 lines): `DSL_SOURCE` string, registry
  struct, ~131 `register*()` calls, `OnceLock` statics
- `lookup_meta_shape` in `callable.rs`: registry-based shape function lookup
- Registry fallback in `callable_infer_inner`: now decorator-only
- Registry fallback in `maybe_wrap_nn_module`: now ClassMetadata-only
- Phase 2 unit test in `meta_shape_dsl.rs`: superseded by e2e tests

The `tensor_shapes` config flag remains — it gates non-registry behavior
(Tensor subscript support, jaxtyping, operator overloads).

Differential Revision: D105775132
…lution

Summary:
Each `shape_dsl_function` now carries only its transitive callees instead
of every DSL function in the module. A leaf function like `randn_ir` gets
`helpers = [self]` (1 entry) instead of all 86 siblings. `reshape_ir` gets
`[reshape_ir, normalize_dim]` (2 entries). `movedim_ir` gets its 6
transitive callees.

Adds `ShapeDslFunction::call_targets()` which walks `DslBody`/`DslExpr`
collecting `DslCallTarget::UserDefined` names, and
`compute_transitive_helpers` which resolves those names against the
per-module `BindingsMetadata` index and computes the fixed-point closure.

Behaviorally identical — same shape inference results, same types. The
change reduces per-function memory footprint and will enable finer cache
invalidation once solver dependency edges are per-helper.

Differential Revision: D105783601
Summary:
Run `type_check_program` on each DSL function's transitive-callee closure
at `shape_dsl_function` solve time. This validates that cross-function
call signatures are consistent (e.g., `reshape_ir` calling `normalize_dim`
with the right argument types).

Previously, validation was done once globally by the deleted
`TensorOpsRegistry`. With per-caller resolution from 7a-1, each function
is now validated against only its actual callees. Panics on type errors
for now; Phase 7b will convert to diagnostics.

Differential Revision: D105783602
Summary:
When `uses_shape_dsl(ir_fn)` references a function that is not decorated
with `shape_dsl_function`, emit an `InvalidArgument` diagnostic instead
of silently producing a function with no shape inference. The function
still falls back to its declared return type.

Previously this was a silent no-op — the `if let FunctionKind::ShapeDsl`
chain simply didn't match and `shape_transform` stayed `None` with no
indication to the user that their decorator was misconfigured.

Differential Revision: D105783626
Summary:
When a `shape_dsl_function` body uses unsupported Python syntax, emit a
diagnostic instead of panicking. The function degrades to a normal
`FunctionKind::Def` with no shape inference.

Also adds warnings for unsupported parameter kinds (`*args`, `**kwargs`,
keyword-only, positional-only) — these are silently dropped by the DSL
converter but the user should know they have no effect.

Previously, any `convert_fndef` failure crashed the type checker via
`.expect()`. Third-party stub authors writing new DSL functions would
hit this crash on any syntax mistake.

Differential Revision: D105783605
Summary:
Pull Request resolved: #3487

**This stack**

Reworks tensor shape operations so that instead of being hardcoded in tensor_ops_registry.rs, only DSL primitives are
directly hardcoded into Pyrefly; the actual operations and the association with "normal" (non DSL) stubs all lives in user-space
stub files. This allows iterating on the ops without rebuilding Pryefly, and is 100% essential for actually building out full stubs
for pytorch (and even more so if we want to extend to other libraries like numpy and jax).

The DSL itself is unchanged but we will use a decorator to indicate when a stub function is a DSL function; we use a different
decorator to actually register a DSL function as the "shape transform" associated with some normal function (e.g. to
associate the DSL function `reshape_ir` with a `torch.reshape` function).

Details of the plan are in https://github.com/stroxler/pyrefly-docs/blob/main/tensor-shapes-in-stubs/v2-doc.md

**This commit**

Replace ~20 panic sites in the DSL type checker with error collection,
so type errors in `shape_dsl_function` stubs produce diagnostics instead
of crashing the type checker.

`type_check_program` now returns `Result<(), Vec<String>>`, threading
errors through `check_body`, `check_expr`, `infer_expr`, `infer_call`,
and the narrowing/joining helpers. `validate_shape_dsl_functions`
propagates these errors to the solver, which emits them as
`InvalidArgument` diagnostics on the function definition.

Eval-time panics in `eval_dsl_body` / `eval_dsl_expr` are intentionally
left as panics — they are correctness assertions that should be
unreachable for type-checked programs.

Differential Revision: D105783604
@stroxler stroxler force-pushed the export-D105783604 branch from 719e973 to 53dc247 Compare May 22, 2026 06:48
@github-actions github-actions Bot added size/xl and removed size/xl labels May 22, 2026
@github-actions
Copy link
Copy Markdown

According to mypy_primer, this change doesn't affect type check results on a corpus of open source code. ✅

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant