diff --git a/.cargo/config.toml b/.cargo/config.toml index f3e89cd01e..c3a81b9d69 100644 --- a/.cargo/config.toml +++ b/.cargo/config.toml @@ -10,3 +10,4 @@ MARSHMALLOW_TEST_PATH = { value = "pyrefly/lib/test/marshmallow/third-party", re GLEAN_SNAPSHOTS_PATH = { value = "pyrefly/lib/report/glean/snapshots", relative = true } REPORT_TEST_PATH = { value = "pyrefly/lib/test/report/test_files", relative = true } STUBGEN_TEST_PATH = { value = "pyrefly/lib/test/stubgen", relative = true } +SHAPE_DSL_TEST_PATH = { value = "test/tensor_shapes/fixtures", relative = true } diff --git a/crates/pyrefly_types/src/callable.rs b/crates/pyrefly_types/src/callable.rs index 7ec216a8f8..e60b6cc78b 100644 --- a/crates/pyrefly_types/src/callable.rs +++ b/crates/pyrefly_types/src/callable.rs @@ -12,6 +12,7 @@ use std::fmt; use std::fmt::Display; use std::hash::Hash; use std::hash::Hasher; +use std::ops::Deref; use std::sync::Arc; use dupe::Dupe; @@ -36,10 +37,68 @@ use crate::display::TypeDisplayContext; use crate::equality::TypeEq; use crate::equality::TypeEqCtx; use crate::keywords::DataclassTransformMetadata; +use crate::meta_shape_dsl::ShapeDslFunction; +use crate::meta_shape_dsl::ShapeTransformRef; use crate::type_output::TypeOutput; use crate::types::AnyStyle; use crate::types::Type; +/// A wrapper for derived/cached data that should not participate in +/// equality, hashing, or ordering comparisons. `Derived` always +/// compares as equal, hashes as a no-op, and orders as `Equal`. +/// +/// This is useful for attaching auxiliary data to types that derive +/// `PartialEq`, `Hash`, `Ord`, etc. without affecting their identity. 
+#[derive(Debug, Clone)] +pub struct Derived(pub T); + +impl PartialEq for Derived { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for Derived {} + +impl Hash for Derived { + fn hash(&self, _state: &mut H) {} +} + +impl PartialOrd for Derived { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Derived { + fn cmp(&self, _other: &Self) -> Ordering { + Ordering::Equal + } +} + +impl Visit for Derived { + const RECURSE_CONTAINS: bool = false; + fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {} +} + +impl VisitMut for Derived { + const RECURSE_CONTAINS: bool = false; + fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {} +} + +impl TypeEq for Derived { + fn type_eq(&self, _other: &Self, _ctx: &mut TypeEqCtx) -> bool { + true + } +} + +impl Deref for Derived { + type Target = T; + fn deref(&self) -> &T { + &self.0 + } +} + #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Visit, VisitMut, TypeEq)] pub struct Callable { @@ -645,6 +704,9 @@ pub struct FuncFlags { /// `dataclass_transform` call. See /// https://typing.python.org/en/latest/spec/dataclasses.html#specification. pub dataclass_transform_metadata: Option, + /// A function decorated with `@uses_shape_dsl`, whose return type should be + /// refined by evaluating the referenced shape-DSL function at call sites. + pub shape_transform: Option>, } impl FuncFlags { @@ -810,6 +872,16 @@ pub enum FunctionKind { NumbaJit, /// `numba.njit()` NumbaNjit, + /// A function whose return type is computed by a shape DSL definition. + /// The `FuncId` provides identity (module, class, name) for display and + /// lookup; the `ShapeDslFunction` carries the parsed DSL IR. + ShapeDsl( + Arc, + Arc, + Derived>>>, + ), + /// The `shape_extensions.uses_shape_dsl` decorator function itself. 
+ UsesShapeDsl, } impl Callable { @@ -1185,6 +1257,7 @@ impl FunctionKind { ("typing" | "typing_extensions", None, "disjoint_base") => Self::DisjointBase, ("numba.core.decorators", None, "jit") => Self::NumbaJit, ("numba.core.decorators", None, "njit") => Self::NumbaNjit, + ("shape_extensions", None, "uses_shape_dsl") => Self::UsesShapeDsl, _ => Self::Def(Arc::new(FuncId { module, cls, @@ -1218,6 +1291,8 @@ impl FunctionKind { Self::NumbaJit => ModuleName::from_str("numba"), Self::NumbaNjit => ModuleName::from_str("numba"), Self::Def(func_id) => func_id.module.name().dupe(), + Self::ShapeDsl(id, _, _) => id.module.name().dupe(), + Self::UsesShapeDsl => ModuleName::from_str("shape_extensions"), } } @@ -1244,6 +1319,8 @@ impl FunctionKind { Self::NumbaJit => Cow::Owned(Name::new_static("jit")), Self::NumbaNjit => Cow::Owned(Name::new_static("njit")), Self::Def(func_id) => Cow::Borrowed(&func_id.name), + Self::ShapeDsl(id, _, _) => Cow::Borrowed(&id.name), + Self::UsesShapeDsl => Cow::Owned(Name::new_static("uses_shape_dsl")), } } @@ -1270,12 +1347,14 @@ impl FunctionKind { Self::TotalOrdering => None, Self::DisjointBase => None, Self::Def(func_id) => func_id.cls.clone(), + Self::ShapeDsl(id, _, _) => id.cls.clone(), + Self::UsesShapeDsl => None, } } pub fn outer_funcs(&self) -> Option<&Name> { match self { - Self::Def(func_id) => func_id.outer_funcs.as_ref(), + Self::Def(func_id) | Self::ShapeDsl(func_id, _, _) => func_id.outer_funcs.as_ref(), _ => None, } } diff --git a/crates/pyrefly_types/src/lib.rs b/crates/pyrefly_types/src/lib.rs index 2b6607c4ca..83c7c4ff21 100644 --- a/crates/pyrefly_types/src/lib.rs +++ b/crates/pyrefly_types/src/lib.rs @@ -41,7 +41,6 @@ pub mod simplify; pub mod special_form; pub mod stdlib; pub mod tensor; -pub mod tensor_ops_registry; pub mod tuple; pub mod type_alias; pub mod type_info; diff --git a/crates/pyrefly_types/src/meta_shape_dsl.rs b/crates/pyrefly_types/src/meta_shape_dsl.rs index c6fd1dc27f..e1a8e42e0c 100644 --- 
a/crates/pyrefly_types/src/meta_shape_dsl.rs +++ b/crates/pyrefly_types/src/meta_shape_dsl.rs @@ -19,24 +19,31 @@ //! //! The data types mirror the DSL grammar defined in `meta_shape_pythonic.md`. +use std::cmp::Ordering; use std::collections::HashMap; +use std::collections::HashSet; use std::fmt; use std::fmt::Debug; +use std::hash::Hash; +use std::hash::Hasher; use std::sync::Arc; -use pyrefly_python::ast::Ast; +use pyrefly_util::visit::Visit; +use pyrefly_util::visit::VisitMut; use ruff_python_ast::BoolOp as RuffBoolOp; use ruff_python_ast::CmpOp as RuffCmpOp; use ruff_python_ast::Expr; use ruff_python_ast::Number; use ruff_python_ast::Operator as RuffOperator; -use ruff_python_ast::PySourceType; use ruff_python_ast::Stmt; use ruff_python_ast::UnaryOp as RuffUnaryOp; +use crate::callable::Derived; use crate::dimension::ShapeError; use crate::dimension::SizeExpr; use crate::dimension::canonicalize; +use crate::equality::TypeEq; +use crate::equality::TypeEqCtx; use crate::lit_int::LitInt; use crate::literal::Lit; use crate::tensor::TensorShape; @@ -44,9 +51,7 @@ use crate::tensor::TensorType; use crate::tuple::Tuple; use crate::types::Type; -// ============================================================================ -// Runtime Values -// ============================================================================ +// Section: Runtime Values /// Runtime value produced by parameter extraction and manipulated by the /// interpreter. Bridges between `Type` (the type-checker's representation) @@ -156,9 +161,7 @@ impl Val { } } -// ============================================================================ -// Extraction Helpers -// ============================================================================ +// Section: Extraction Helpers /// Helper functions for extracting typed values from `Type`. 
/// @@ -332,9 +335,7 @@ mod extract { } } -// ============================================================================ -// Meta-Shape Function Trait -// ============================================================================ +// Section: Meta-Shape Function Trait /// A function that computes output shapes from input shapes. /// @@ -369,9 +370,7 @@ pub trait MetaShapeFunction: Debug + Send + Sync { } } -// ============================================================================ -// Grammar-aligned data types -// ============================================================================ +// Section: Grammar-aligned data types /// Binary operators: arithmetic, comparison, and logical. /// Corresponds to OP in ` OP `. @@ -563,15 +562,13 @@ enum DslExpr { /// Function definition. Corresponds to `` in the grammar. #[derive(Debug, Clone)] pub(crate) struct DslFnDef { - pub(crate) name: String, + name: String, params: Vec, return_type: Option, body: DslBody, } -// ============================================================================ -// Display implementations -// ============================================================================ +// Section: Display implementations impl fmt::Display for DslOp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { @@ -607,10 +604,12 @@ impl fmt::Display for DslBuiltin { match self { DslBuiltin::Len => write!(f, "len"), DslBuiltin::Range => write!(f, "range"), - DslBuiltin::Prod => write!(f, "shape_extensions.prod"), - DslBuiltin::Sum => write!(f, "shape_extensions.sum"), + DslBuiltin::Prod => write!(f, "shape_extensions.dsl.prod"), + DslBuiltin::Sum => write!(f, "shape_extensions.dsl.sum"), DslBuiltin::Str => write!(f, "str"), - DslBuiltin::ParseEinsumEquation => write!(f, "shape_extensions.parse_einsum_equation"), + DslBuiltin::ParseEinsumEquation => { + write!(f, "shape_extensions.dsl.parse_einsum_equation") + } DslBuiltin::Enumerate => write!(f, "enumerate"), DslBuiltin::Zip => write!(f, "zip"), } 
@@ -765,9 +764,7 @@ impl fmt::Display for DslParam { } } -// ============================================================================ -// AST conversion: ruff Python AST → DSL grammar types -// ============================================================================ +// Section: AST conversion: ruff Python AST → DSL grammar types /// Convert an isinstance type argument to a DslTypeCon. fn convert_type_constructor(expr: &Expr) -> Result { @@ -1240,21 +1237,25 @@ fn convert_expr(expr: &Expr) -> Result { } } -/// Convert a function call expression, dispatching special forms. -fn convert_call(call: &ruff_python_ast::ExprCall) -> Result { - // Extract function name for dispatch. Supports both simple names (`len`) - // and dotted names (`shape_extensions.prod`). - let func_name = match call.func.as_ref() { - Expr::Name(n) => n.id.to_string(), +/// Recursively extract a dotted name from an expression (e.g. `shape_extensions.dsl.prod`). +fn dotted_name(expr: &Expr) -> Option { + match expr { + Expr::Name(n) => Some(n.id.to_string()), Expr::Attribute(a) => { - let prefix = match a.value.as_ref() { - Expr::Name(n) => n.id.as_str(), - _ => return Err(format!("unsupported call target: {:?}", call.func)), - }; - format!("{}.{}", prefix, a.attr) + let prefix = dotted_name(&a.value)?; + Some(format!("{}.{}", prefix, a.attr)) } - _ => return Err(format!("unsupported call target: {:?}", call.func)), - }; + _ => None, + } +} + +/// Convert a function call expression, dispatching special forms. +fn convert_call(call: &ruff_python_ast::ExprCall) -> Result { + // Extract function name for dispatch. Supports simple names (`len`), + // single-dotted names (`Tensor`), and multi-dotted names + // (`shape_extensions.dsl.prod`). 
+ let func_name = dotted_name(&call.func) + .ok_or_else(|| format!("unsupported call target: {:?}", call.func))?; match func_name.as_str() { // Special forms with non-call syntax — keep as dedicated DslExpr variants @@ -1324,14 +1325,14 @@ fn convert_call(call: &ruff_python_ast::ExprCall) -> Result { "str" | "enumerate" | "zip" - | "shape_extensions.prod" - | "shape_extensions.sum" - | "shape_extensions.parse_einsum_equation" => { + | "shape_extensions.dsl.prod" + | "shape_extensions.dsl.sum" + | "shape_extensions.dsl.parse_einsum_equation" => { let builtin = match func_name.as_str() { - "shape_extensions.prod" => DslBuiltin::Prod, - "shape_extensions.sum" => DslBuiltin::Sum, + "shape_extensions.dsl.prod" => DslBuiltin::Prod, + "shape_extensions.dsl.sum" => DslBuiltin::Sum, "str" => DslBuiltin::Str, - "shape_extensions.parse_einsum_equation" => DslBuiltin::ParseEinsumEquation, + "shape_extensions.dsl.parse_einsum_equation" => DslBuiltin::ParseEinsumEquation, "enumerate" => DslBuiltin::Enumerate, "zip" => DslBuiltin::Zip, _ => unreachable!(), @@ -1411,9 +1412,7 @@ fn convert_fndef(func: &ruff_python_ast::StmtFunctionDef) -> Result DslType { } } -/// Extract the element type from a list type. Panics on non-list (IR bug). -fn element_type(ty: &DslType) -> DslType { +/// Extract the element type from a list type. Pushes an error on non-list. +fn element_type(ty: &DslType, errors: &mut Vec) -> DslType { match ty { DslType::List(inner) => *inner.clone(), - _ => unreachable!("expected list type, got {}", ty), + _ => { + errors.push(format!("expected list type, got {}", ty)); + DslType::Int + } } } /// Narrow a type to only variants matching a constructor. 
-fn narrow_to(ty: &DslType, con: DslTypeCon) -> DslType { +fn narrow_to(ty: &DslType, con: DslTypeCon, errors: &mut Vec) -> DslType { match ty { DslType::Union(types) => { let matching: Vec<_> = types @@ -1506,17 +1508,29 @@ fn narrow_to(ty: &DslType, con: DslTypeCon) -> DslType { .collect(); match matching.len() { 1 => matching.into_iter().next().unwrap(), - 0 => unreachable!("isinstance narrowing: no variant of {} matches {}", ty, con), + 0 => { + errors.push(format!( + "isinstance narrowing: no variant of {} matches {}", + ty, con + )); + ty.clone() + } _ => DslType::Union(matching), } } _ if matches_constructor(ty, con) => ty.clone(), - _ => unreachable!("isinstance narrowing: {} does not match {}", ty, con), + _ => { + errors.push(format!( + "isinstance narrowing: {} does not match {}", + ty, con + )); + ty.clone() + } } } /// Narrow a type to exclude variants matching a constructor. -fn narrow_away(ty: &DslType, con: DslTypeCon) -> DslType { +fn narrow_away(ty: &DslType, con: DslTypeCon, errors: &mut Vec) -> DslType { match ty { DslType::Union(types) => { let remaining: Vec<_> = types @@ -1526,7 +1540,10 @@ fn narrow_away(ty: &DslType, con: DslTypeCon) -> DslType { .collect(); match remaining.len() { 1 => remaining.into_iter().next().unwrap(), - 0 => unreachable!("narrowed away all variants of {}", ty), + 0 => { + errors.push(format!("narrowed away all variants of {}", ty)); + ty.clone() + } _ => DslType::Union(remaining), } } @@ -1535,7 +1552,7 @@ fn narrow_away(ty: &DslType, con: DslTypeCon) -> DslType { } /// Narrow a type to exclude None. 
-fn narrow_away_none(ty: &DslType) -> DslType { +fn narrow_away_none(ty: &DslType, errors: &mut Vec) -> DslType { match ty { DslType::Union(types) => { let remaining: Vec<_> = types @@ -1545,7 +1562,10 @@ fn narrow_away_none(ty: &DslType) -> DslType { .collect(); match remaining.len() { 1 => remaining.into_iter().next().unwrap(), - 0 => unreachable!("narrowed away all variants of {}", ty), + 0 => { + errors.push(format!("narrowed away all variants of {}", ty)); + ty.clone() + } _ => DslType::Union(remaining), } } @@ -1555,7 +1575,7 @@ fn narrow_away_none(ty: &DslType) -> DslType { /// Analyze a condition expression for type narrowing. /// Returns (then_env, else_env) — the environments for the true and false branches. -fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { +fn narrow(cond: &DslExpr, env: &TypeEnv, errors: &mut Vec) -> (TypeEnv, TypeEnv) { match cond { // isinstance(x, con) DslExpr::IsInstance { expr, ty } => { @@ -1564,8 +1584,8 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { { let mut then_env = env.clone(); let mut else_env = env.clone(); - then_env.insert(name.clone(), narrow_to(var_ty, *ty)); - else_env.insert(name.clone(), narrow_away(var_ty, *ty)); + then_env.insert(name.clone(), narrow_to(var_ty, *ty, errors)); + else_env.insert(name.clone(), narrow_away(var_ty, *ty, errors)); return (then_env, else_env); } (env.clone(), env.clone()) @@ -1583,7 +1603,7 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { let mut then_env = env.clone(); let mut else_env = env.clone(); then_env.insert(name.clone(), DslType::None); - else_env.insert(name.clone(), narrow_away_none(var_ty)); + else_env.insert(name.clone(), narrow_away_none(var_ty, errors)); return (then_env, else_env); } (env.clone(), env.clone()) @@ -1600,7 +1620,7 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { { let mut then_env = env.clone(); let mut else_env = env.clone(); - then_env.insert(name.clone(), 
narrow_away_none(var_ty)); + then_env.insert(name.clone(), narrow_away_none(var_ty, errors)); else_env.insert(name.clone(), DslType::None); return (then_env, else_env); } @@ -1611,7 +1631,7 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { op: DslUnaryOp::Not, operand, } => { - let (then_env, else_env) = narrow(operand, env); + let (then_env, else_env) = narrow(operand, env, errors); (else_env, then_env) } // cond1 and cond2 — narrow both in then-branch, conservative in else @@ -1620,8 +1640,8 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { op: DslOp::And, right, } => { - let (then1, _) = narrow(left, env); - let (then2, _) = narrow(right, &then1); + let (then1, _) = narrow(left, env, errors); + let (then2, _) = narrow(right, &then1, errors); (then2, env.clone()) } _ => (env.clone(), env.clone()), @@ -1629,15 +1649,15 @@ fn narrow(cond: &DslExpr, env: &TypeEnv) -> (TypeEnv, TypeEnv) { } /// Build function return type map from DSL definitions. -fn build_fn_ret_types(fndefs: &[DslFnDef]) -> FnRetTypes { +fn build_fn_ret_types(fndefs: &[DslFnDef], errors: &mut Vec) -> FnRetTypes { fndefs .iter() - .map(|f| { - let return_type = f - .return_type - .clone() - .unwrap_or_else(|| unreachable!("DSL function {} must have a return type", f.name)); - (f.name.clone(), return_type) + .filter_map(|f| match &f.return_type { + Some(rt) => Some((f.name.clone(), rt.clone())), + None => { + errors.push(format!("DSL function {} must have a return type", f.name)); + None + } }) .collect() } @@ -1654,56 +1674,75 @@ fn arithmetic_result(a: &DslType, b: &DslType) -> DslType { } /// Infer the element type of a list literal from its elements. 
-fn infer_list_elem_type(elts: &[DslExpr], env: &TypeEnv, sigs: &FnRetTypes) -> DslType { - assert!( - !elts.is_empty(), - "infer_list_elem_type called with empty list" - ); - let mut result = infer_expr(&elts[0], env, sigs); +fn infer_list_elem_type( + elts: &[DslExpr], + env: &TypeEnv, + sigs: &FnRetTypes, + errors: &mut Vec, +) -> DslType { + if elts.is_empty() { + errors.push("infer_list_elem_type called with empty list".to_owned()); + return DslType::Int; + } + let mut result = infer_expr(&elts[0], env, sigs, errors); for elt in &elts[1..] { - result = join(&result, &infer_expr(elt, env, sigs)); + result = join(&result, &infer_expr(elt, env, sigs, errors)); } result } /// Bind comprehension variables based on the iterator expression. /// Handles zip (multi-list iteration) and enumerate (index + element). -fn bind_comp_vars(vars: &[String], iter: &DslExpr, env: &TypeEnv, sigs: &FnRetTypes) -> TypeEnv { +fn bind_comp_vars( + vars: &[String], + iter: &DslExpr, + env: &TypeEnv, + sigs: &FnRetTypes, + errors: &mut Vec, +) -> TypeEnv { let mut new_env = env.clone(); match iter { DslExpr::Call { func: DslCallTarget::Builtin(DslBuiltin::Zip), args, } => { - assert_eq!( - vars.len(), - args.len(), - "zip: {} vars but {} args", - vars.len(), - args.len() - ); + if vars.len() != args.len() { + errors.push(format!("zip: {} vars but {} args", vars.len(), args.len())); + } for (var, arg) in vars.iter().zip(args.iter()) { - let arg_ty = infer_expr(arg, env, sigs); - new_env.insert(var.clone(), element_type(&arg_ty)); + let arg_ty = infer_expr(arg, env, sigs, errors); + new_env.insert(var.clone(), element_type(&arg_ty, errors)); } } DslExpr::Call { func: DslCallTarget::Builtin(DslBuiltin::Enumerate), args, } => { - assert_eq!(args.len(), 1, "enumerate takes exactly 1 argument"); - assert_eq!(vars.len(), 2, "enumerate requires exactly 2 variables"); - let list_ty = infer_expr(&args[0], env, sigs); - new_env.insert(vars[0].clone(), DslType::Int); - 
new_env.insert(vars[1].clone(), element_type(&list_ty)); + if args.len() != 1 { + errors.push(format!( + "enumerate takes exactly 1 argument, got {}", + args.len() + )); + } + if vars.len() != 2 { + errors.push(format!( + "enumerate requires exactly 2 variables, got {}", + vars.len() + )); + } + if !args.is_empty() && vars.len() >= 2 { + let list_ty = infer_expr(&args[0], env, sigs, errors); + new_env.insert(vars[0].clone(), DslType::Int); + new_env.insert(vars[1].clone(), element_type(&list_ty, errors)); + } } _ => { - let iter_ty = infer_expr(iter, env, sigs); + let iter_ty = infer_expr(iter, env, sigs, errors); if vars.len() == 1 { - new_env.insert(vars[0].clone(), element_type(&iter_ty)); + new_env.insert(vars[0].clone(), element_type(&iter_ty, errors)); } else { // Multiple vars iterating over a single list — each gets element type. - let elem = element_type(&iter_ty); + let elem = element_type(&iter_ty, errors); for var in vars { new_env.insert(var.clone(), elem.clone()); } @@ -1714,17 +1753,23 @@ fn bind_comp_vars(vars: &[String], iter: &DslExpr, env: &TypeEnv, sigs: &FnRetTy } /// Infer the return type of a function call. -fn infer_call(func: &DslCallTarget, args: &[DslExpr], env: &TypeEnv, sigs: &FnRetTypes) -> DslType { +fn infer_call( + func: &DslCallTarget, + args: &[DslExpr], + env: &TypeEnv, + sigs: &FnRetTypes, + errors: &mut Vec, +) -> DslType { // Infer all arguments for logging, regardless of whether we need them. for arg in args { - infer_expr(arg, env, sigs); + infer_expr(arg, env, sigs, errors); } match func { DslCallTarget::Builtin(builtin) => match builtin { // prod/sum reduce a list of dims to a single dim. 
DslBuiltin::Prod | DslBuiltin::Sum => { - let arg_ty = infer_expr(&args[0], env, sigs); - element_type(&arg_ty) + let arg_ty = infer_expr(&args[0], env, sigs, errors); + element_type(&arg_ty, errors) } DslBuiltin::Str => DslType::Str, DslBuiltin::ParseEinsumEquation => DslType::List(Box::new(DslType::List(Box::new( @@ -1733,18 +1778,30 @@ fn infer_call(func: &DslCallTarget, args: &[DslExpr], env: &TypeEnv, sigs: &FnRe DslBuiltin::Len => DslType::Int, DslBuiltin::Range => DslType::List(Box::new(DslType::Int)), DslBuiltin::Zip | DslBuiltin::Enumerate => { - unreachable!("{} should only appear as comprehension iterator", builtin) + errors.push(format!( + "{} should only appear as comprehension iterator", + builtin + )); + DslType::Int + } + }, + DslCallTarget::UserDefined(name) => match sigs.get(name) { + Some(ty) => ty.clone(), + None => { + errors.push(format!("undefined function: {}", name)); + DslType::Int } }, - DslCallTarget::UserDefined(name) => sigs - .get(name) - .unwrap_or_else(|| unreachable!("undefined function: {}", name)) - .clone(), } } /// Infer the type of a DSL expression. -fn infer_expr(expr: &DslExpr, env: &TypeEnv, sigs: &FnRetTypes) -> DslType { +fn infer_expr( + expr: &DslExpr, + env: &TypeEnv, + sigs: &FnRetTypes, + errors: &mut Vec, +) -> DslType { match expr { DslExpr::Const(c) => match c { DslConst::None => DslType::None, @@ -1752,63 +1809,65 @@ fn infer_expr(expr: &DslExpr, env: &TypeEnv, sigs: &FnRetTypes) -> DslType { DslConst::Bool(_) => DslType::Bool, DslConst::Str(_) => DslType::Str, }, - DslExpr::Var(name) => env - .get(name) - .cloned() - .unwrap_or_else(|| unreachable!("undefined variable: {}", name)), + DslExpr::Var(name) => match env.get(name) { + Some(ty) => ty.clone(), + None => { + errors.push(format!("undefined variable: {}", name)); + DslType::Int + } + }, DslExpr::List(elts) => { if elts.is_empty() { // All empty list literals in the DSL are dimension lists. 
DslType::List(Box::new(dim_type())) } else if matches!(elts.last(), Some(DslExpr::Ellipsis)) { // [expr, ...] — unbounded list sentinel. - let elem_ty = infer_list_elem_type(&elts[..elts.len() - 1], env, sigs); + let elem_ty = infer_list_elem_type(&elts[..elts.len() - 1], env, sigs, errors); DslType::List(Box::new(elem_ty)) } else { - let elem_ty = infer_list_elem_type(elts, env, sigs); + let elem_ty = infer_list_elem_type(elts, env, sigs, errors); DslType::List(Box::new(elem_ty)) } } DslExpr::ListComp { elt, vars, iter, .. } => { - let comp_env = bind_comp_vars(vars, iter, env, sigs); - let elt_ty = infer_expr(elt, &comp_env, sigs); + let comp_env = bind_comp_vars(vars, iter, env, sigs, errors); + let elt_ty = infer_expr(elt, &comp_env, sigs, errors); DslType::List(Box::new(elt_ty)) } DslExpr::Index { base, index } => { - let base_ty = infer_expr(base, env, sigs); - infer_expr(index, env, sigs); - element_type(&base_ty) + let base_ty = infer_expr(base, env, sigs, errors); + infer_expr(index, env, sigs, errors); + element_type(&base_ty, errors) } DslExpr::Slice { base, lower, upper } => { - let base_ty = infer_expr(base, env, sigs); + let base_ty = infer_expr(base, env, sigs, errors); if let Some(l) = lower { - infer_expr(l, env, sigs); + infer_expr(l, env, sigs, errors); } if let Some(u) = upper { - infer_expr(u, env, sigs); + infer_expr(u, env, sigs, errors); } base_ty } DslExpr::BinOp { left, op, right } => { - let lt = infer_expr(left, env, sigs); - let rt = infer_expr(right, env, sigs); + let lt = infer_expr(left, env, sigs, errors); + let rt = infer_expr(right, env, sigs, errors); match op { DslOp::Add => { // List concatenation, string concatenation, or numeric addition. 
if let DslType::List(a) = < { - let DslType::List(b) = &rt else { - unreachable!("+ with list and non-list: {} + {}", lt, rt) - }; - DslType::List(Box::new(join(a, b))) + if let DslType::List(b) = &rt { + DslType::List(Box::new(join(a, b))) + } else { + errors.push(format!("+ with list and non-list: {} + {}", lt, rt)); + DslType::Int + } } else if matches!(lt, DslType::Str) { - assert!( - matches!(rt, DslType::Str), - "+ with str and non-str: {} + {}", - lt, - rt - ); + if !matches!(rt, DslType::Str) { + errors.push(format!("+ with str and non-str: {} + {}", lt, rt)); + } DslType::Str } else { arithmetic_result(<, &rt) @@ -1829,126 +1888,98 @@ fn infer_expr(expr: &DslExpr, env: &TypeEnv, sigs: &FnRetTypes) -> DslType { } DslExpr::UnaryOp { op, operand } => match op { DslUnaryOp::Not => { - infer_expr(operand, env, sigs); + infer_expr(operand, env, sigs, errors); DslType::Bool } - DslUnaryOp::Neg => infer_expr(operand, env, sigs), + DslUnaryOp::Neg => infer_expr(operand, env, sigs, errors), }, - DslExpr::Call { func, args } => infer_call(func, args, env, sigs), + DslExpr::Call { func, args } => infer_call(func, args, env, sigs, errors), DslExpr::IsInstance { expr, .. 
} => { - infer_expr(expr, env, sigs); + infer_expr(expr, env, sigs, errors); DslType::Bool } DslExpr::In { left, right } => { - infer_expr(left, env, sigs); - infer_expr(right, env, sigs); + infer_expr(left, env, sigs, errors); + infer_expr(right, env, sigs, errors); DslType::Bool } DslExpr::Shape(inner) => { - infer_expr(inner, env, sigs); + infer_expr(inner, env, sigs, errors); DslType::List(Box::new(dim_type())) } DslExpr::TensorNew(inner) => { - infer_expr(inner, env, sigs); + infer_expr(inner, env, sigs, errors); DslType::Tensor } DslExpr::IfExpr { body, test, orelse } => { - let (then_env, else_env) = narrow(test, env); - let body_ty = infer_expr(body, &then_env, sigs); - let else_ty = infer_expr(orelse, &else_env, sigs); + let (then_env, else_env) = narrow(test, env, errors); + let body_ty = infer_expr(body, &then_env, sigs, errors); + let else_ty = infer_expr(orelse, &else_env, sigs, errors); join(&body_ty, &else_ty) } - DslExpr::Ellipsis => unreachable!("Ellipsis should be handled by List"), + DslExpr::Ellipsis => { + errors.push("Ellipsis should be handled by List".to_owned()); + DslType::Int + } DslExpr::Unknown => DslType::None, // sentinel for fixture fallback } } /// Type-check a function body, updating the environment through assignments /// and narrowing through conditionals. 
-fn check_body(body: &DslBody, env: &TypeEnv, sigs: &FnRetTypes) { +fn check_body(body: &DslBody, env: &TypeEnv, sigs: &FnRetTypes, errors: &mut Vec) { match body { DslBody::Assign { vars, expr, rest } => { - let ty = infer_expr(expr, env, sigs); + let ty = infer_expr(expr, env, sigs, errors); let mut new_env = env.clone(); if vars.len() == 1 { new_env.insert(vars[0].clone(), ty); } else { - let elem = element_type(&ty); + let elem = element_type(&ty, errors); for var in vars { new_env.insert(var.clone(), elem.clone()); } } - check_body(rest, &new_env, sigs); + check_body(rest, &new_env, sigs, errors); } DslBody::If { cond, then_body, rest, } => { - let (then_env, else_env) = narrow(cond, env); - check_body(then_body, &then_env, sigs); - check_body(rest, &else_env, sigs); + let (then_env, else_env) = narrow(cond, env, errors); + check_body(then_body, &then_env, sigs, errors); + check_body(rest, &else_env, sigs, errors); } DslBody::Return(expr) => { - infer_expr(expr, env, sigs); + infer_expr(expr, env, sigs, errors); } DslBody::Raise(expr) => { - infer_expr(expr, env, sigs); + infer_expr(expr, env, sigs, errors); } } } -/// Type-check all DSL function definitions. -fn type_check_program(fndefs: &[DslFnDef]) { - let sigs = build_fn_ret_types(fndefs); +/// Type-check all DSL function definitions. Returns `Err` with collected +/// error messages if type errors are found. 
+fn type_check_program(fndefs: &[DslFnDef]) -> Result<(), Vec> { + let mut errors = Vec::new(); + let sigs = build_fn_ret_types(fndefs, &mut errors); for fndef in fndefs { let mut env = TypeEnv::new(); for param in &fndef.params { env.insert(param.name.clone(), param.ty.clone()); } - check_body(&fndef.body, &env, &sigs); + check_body(&fndef.body, &env, &sigs, &mut errors); } -} - -// ============================================================================ -// Entry point -// ============================================================================ - -/// Parse DSL source code, convert to grammar-aligned types, and return the -/// list of function definitions. -pub(crate) fn parse_dsl(source: &str) -> Result, String> { - let (module, errors, _unsupported) = Ast::parse(source, PySourceType::Python); - if !errors.is_empty() { - return Err(format!( - "DSL syntax errors:\n{}", - errors - .iter() - .map(|e| format!(" {}", e)) - .collect::>() - .join("\n") - )); + if errors.is_empty() { + Ok(()) + } else { + Err(errors) } - - let fndefs: Vec = module - .body - .iter() - .filter_map(|stmt| { - if let Stmt::FunctionDef(f) = stmt { - Some(convert_fndef(f)) - } else { - None // skip comments, blank lines (not in AST anyway) - } - }) - .collect::>()?; - - type_check_program(&fndefs); - - Ok(fndefs) } -// ============================================================================ -// Interpreter — evaluate DSL directly against runtime Val values -// ============================================================================ +// Section: Interpreter — evaluate DSL directly against runtime Val values /// Extract a runtime `Val` from a type-checker `Type` based on the declared `DslType`. /// `actual_arg_type` is the type the user passed; `expected_param_type` is the DSL @@ -2945,6 +2976,21 @@ fn val_to_type( ), }, + // Int and Bool synthesize Literal[n] / Literal[bool] from the DSL's + // traced runtime value, just like SymInt does via `val_to_scalar_type`. 
+ // This is intentionally load-bearing: functions like `dim_ir`, + // `numel_ir`, and `size_ir(dim=N)` trace exact integer results, and + // downstream consumers (assert_type, reshape validation, shape + // inference) rely on the literal precision. Returning + // `expected_return_type` here would discard the traced value and + // produce `int` instead of e.g. `Literal[3]`. + // + // This differs from the Tensor/List/Tuple/None/Str branches, which + // return `expected_return_type.clone()`. Those branches are correct + // because their `expected_return_type` already carries the refined + // structure (e.g. `Tensor[B, C, H, W]` with shape injected). For + // scalars, the fixture return type is just `int` — the literal value + // comes solely from DSL evaluation. DslType::Int => match val { Val::Int(n) => Lit::Int(LitInt::new(n)).to_implicit_type(), _ => panic!( @@ -2953,6 +2999,10 @@ fn val_to_type( ), }, + // SymInt synthesizes a type from the traced `Val`: `val_to_scalar_type` + // returns `Type::Dim` for `Val::Dim` and `Literal[n]` for `Val::Int`. + // The trace value is load-bearing for shape inference — downstream + // tensor shape types are built from these dimension representations. DslType::SymInt => val_to_scalar_type(&val), DslType::Bool => match val { @@ -3043,12 +3093,12 @@ fn val_to_type( /// A `MetaShapeFunction` backed by a parsed DSL function definition. /// The DSL is interpreted directly — no IR conversion. #[derive(Debug)] -pub(crate) struct DslMetaShapeFunction { +struct DslMetaShapeFunction { /// The primary function to evaluate. - pub(crate) fn_def: Arc, + fn_def: Arc, /// Precomputed lookup table mapping function names to definitions. /// Shared across all instances — built once at registry init. 
- pub(crate) fn_lookup: Arc>>, + fn_lookup: Arc>>, } impl MetaShapeFunction for DslMetaShapeFunction { @@ -3103,3 +3153,461 @@ impl MetaShapeFunction for DslMetaShapeFunction { } } } +// Section: Public wrapper API +// +// These wrappers form the public surface of +// `pyrefly_types::meta_shape_dsl`. They let callers outside this module (the +// binder and solver in `pyrefly/lib`) drive the DSL pipeline without exposing +// the grammar-aligned `DslFnDef` internals. +// +// Constraint: the public surface never returns `Arc`. Callers store +// `ShapeDslFunction` / `ShapeDslProgram` opaquely; only code inside +// `pyrefly_types` (e.g. `tensor_ops_registry`) may reach the underlying +// `DslFnDef` via the `pub(crate)` fields. + +/// A single DSL function that has been lowered from its Python AST. +/// +/// This is a cheap (one `Arc`) opaque handle. It is the unit produced by +/// [`convert_shape_dsl_function`] and consumed by [`build_shape_dsl_program`]. +#[derive(Debug, Clone)] +pub struct ShapeDslFunction { + pub(crate) inner: Arc, +} + +/// Pointer identity: two `ShapeDslFunction`s are equal iff they point to the +/// same `DslFnDef` allocation. +impl PartialEq for ShapeDslFunction { + fn eq(&self, other: &Self) -> bool { + Arc::ptr_eq(&self.inner, &other.inner) + } +} + +impl Eq for ShapeDslFunction {} + +impl Hash for ShapeDslFunction { + fn hash(&self, state: &mut H) { + (Arc::as_ptr(&self.inner) as *const () as usize).hash(state); + } +} + +impl PartialOrd for ShapeDslFunction { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ShapeDslFunction { + fn cmp(&self, other: &Self) -> Ordering { + let self_ptr = Arc::as_ptr(&self.inner) as *const () as usize; + let other_ptr = Arc::as_ptr(&other.inner) as *const () as usize; + self_ptr.cmp(&other_ptr) + } +} + +/// DSL IR contains no `Type` values, so visiting is a no-op. 
+impl Visit for ShapeDslFunction { + const RECURSE_CONTAINS: bool = false; + fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {} +} + +/// DSL IR contains no `Type` values, so visiting is a no-op. +impl VisitMut for ShapeDslFunction { + const RECURSE_CONTAINS: bool = false; + fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {} +} + +/// DSL IR contains no `Type` values, so visiting through `Arc` is also a no-op. +impl Visit for Arc { + const RECURSE_CONTAINS: bool = false; + fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {} +} + +/// DSL IR contains no `Type` values, so visiting through `Arc` is also a no-op. +impl VisitMut for Arc { + const RECURSE_CONTAINS: bool = false; + fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {} +} + +impl TypeEq for ShapeDslFunction { + fn type_eq(&self, other: &Self, _ctx: &mut TypeEqCtx) -> bool { + self == other + } +} + +impl ShapeDslFunction { + /// The function name from the DSL definition. + pub fn name(&self) -> &str { + &self.inner.name + } + + /// Returns the set of user-defined function names called in this function's body. + pub fn call_targets(&self) -> HashSet { + let mut targets = HashSet::new(); + collect_call_targets_body(&self.inner.body, &mut targets); + targets + } +} + +/// Walk a `DslBody` and collect all `DslCallTarget::UserDefined` names. +fn collect_call_targets_body(body: &DslBody, targets: &mut HashSet) { + match body { + DslBody::Assign { expr, rest, .. } => { + collect_call_targets_expr(expr, targets); + collect_call_targets_body(rest, targets); + } + DslBody::If { + cond, + then_body, + rest, + } => { + collect_call_targets_expr(cond, targets); + collect_call_targets_body(then_body, targets); + collect_call_targets_body(rest, targets); + } + DslBody::Return(expr) | DslBody::Raise(expr) => { + collect_call_targets_expr(expr, targets); + } + } +} + +/// Walk a `DslExpr` and collect all `DslCallTarget::UserDefined` names. 
+fn collect_call_targets_expr(expr: &DslExpr, targets: &mut HashSet) { + match expr { + DslExpr::Call { + func: DslCallTarget::UserDefined(name), + args, + } => { + targets.insert(name.clone()); + for arg in args { + collect_call_targets_expr(arg, targets); + } + } + DslExpr::Call { args, .. } => { + for arg in args { + collect_call_targets_expr(arg, targets); + } + } + DslExpr::List(items) => { + for item in items { + collect_call_targets_expr(item, targets); + } + } + DslExpr::ListComp { + elt, iter, cond, .. + } => { + collect_call_targets_expr(elt, targets); + collect_call_targets_expr(iter, targets); + if let Some(c) = cond { + collect_call_targets_expr(c, targets); + } + } + DslExpr::Index { base, index } => { + collect_call_targets_expr(base, targets); + collect_call_targets_expr(index, targets); + } + DslExpr::Slice { base, lower, upper } => { + collect_call_targets_expr(base, targets); + if let Some(l) = lower { + collect_call_targets_expr(l, targets); + } + if let Some(u) = upper { + collect_call_targets_expr(u, targets); + } + } + DslExpr::BinOp { left, right, .. } => { + collect_call_targets_expr(left, targets); + collect_call_targets_expr(right, targets); + } + DslExpr::UnaryOp { operand, .. } => { + collect_call_targets_expr(operand, targets); + } + DslExpr::IsInstance { expr, .. } => { + collect_call_targets_expr(expr, targets); + } + DslExpr::In { left, right } => { + collect_call_targets_expr(left, targets); + collect_call_targets_expr(right, targets); + } + DslExpr::Shape(expr) | DslExpr::TensorNew(expr) => { + collect_call_targets_expr(expr, targets); + } + DslExpr::IfExpr { body, test, orelse } => { + collect_call_targets_expr(body, targets); + collect_call_targets_expr(test, targets); + collect_call_targets_expr(orelse, targets); + } + DslExpr::Const(_) | DslExpr::Var(_) | DslExpr::Ellipsis | DslExpr::Unknown => {} + } +} + +/// Validate a set of `ShapeDslFunction`s as a program. 
+/// +/// Runs `type_check_program` on the inner `DslFnDef`s, verifying that +/// cross-function calls have consistent signatures. Returns collected +/// type error messages on failure. +/// +/// Intended to be called with a per-caller transitive closure (root + +/// its resolved helpers), not the full module. +pub fn validate_shape_dsl_functions(fns: &[Arc]) -> Result<(), Vec> { + let defs: Vec = fns.iter().map(|f| (*f.inner).clone()).collect(); + type_check_program(&defs) +} + +/// Reference to a shape-DSL function that refines a callable's return type. +/// Carried on `FuncFlags` for functions decorated with `@uses_shape_dsl`. +#[derive(Debug, Clone)] +pub struct ShapeTransformRef { + pub dsl_fn: Arc, + /// Transitive closure of user-defined helpers called by `dsl_fn`. + pub helpers: Derived>>>, +} + +/// Pointer identity: delegates to `ShapeDslFunction`'s pointer-identity equality. +impl PartialEq for ShapeTransformRef { + fn eq(&self, other: &Self) -> bool { + self.dsl_fn == other.dsl_fn + } +} + +impl Eq for ShapeTransformRef {} + +impl Hash for ShapeTransformRef { + fn hash(&self, state: &mut H) { + self.dsl_fn.hash(state); + } +} + +impl PartialOrd for ShapeTransformRef { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for ShapeTransformRef { + fn cmp(&self, other: &Self) -> Ordering { + self.dsl_fn.cmp(&other.dsl_fn) + } +} + +impl Visit for ShapeTransformRef { + const RECURSE_CONTAINS: bool = false; + fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {} +} + +impl VisitMut for ShapeTransformRef { + const RECURSE_CONTAINS: bool = false; + fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {} +} + +impl Visit for Arc { + const RECURSE_CONTAINS: bool = false; + fn recurse<'a>(&'a self, _: &mut dyn FnMut(&'a Type)) {} +} + +impl VisitMut for Arc { + const RECURSE_CONTAINS: bool = false; + fn recurse_mut(&mut self, _: &mut dyn FnMut(&mut Type)) {} +} + +impl TypeEq for ShapeTransformRef { + fn 
type_eq(&self, other: &Self, _ctx: &mut TypeEqCtx) -> bool { + self == other + } +} + +impl ShapeTransformRef { + /// Build a `MetaShapeFunction` evaluator from this shape transform reference. + /// Populates `fn_lookup` with this function and its transitive callees + /// so that cross-function DSL calls resolve correctly. + pub fn to_meta_shape_function(&self) -> Box { + // helpers contains self and its transitive callees. + let fn_lookup: Arc>> = Arc::new( + self.helpers + .iter() + .map(|h| (h.inner.name.clone(), h.inner.clone())) + .collect(), + ); + Box::new(DslMetaShapeFunction { + fn_def: self.dsl_fn.inner.clone(), + fn_lookup, + }) + } +} + +/// A bundle of DSL functions that have been validated together as a program. +/// +/// The functions held by a `ShapeDslProgram` are guaranteed to have passed +/// `type_check_program` against the program's own signature set. The factory +/// [`make_meta_shape_function`] takes a `&ShapeDslProgram` (not raw pieces) +/// precisely to enforce that callers can only build a `MetaShapeFunction` +/// from validated DSL. +/// +/// The names used as keys for helper resolution are the *local* names a +/// caller's body refers to. Programs are populated with +/// `[self_dsl_fn, ...resolved_callees]` where each callee's `name` field +/// has already been rewritten (or annotated) to its caller-local form. This +/// API does not perform that rewriting itself — it accepts programs that +/// already satisfy that constraint. +#[derive(Debug, Clone)] +pub struct ShapeDslProgram { + pub(crate) fns: Vec>, +} + +/// Convert a single Python function definition into a [`ShapeDslFunction`]. +/// +/// This is pure AST-to-IR lowering — it does not parse source text or run +/// the type checker. The output is a single opaque handle; the caller is +/// expected to combine handles from this function (and possibly other +/// modules) into a [`ShapeDslProgram`] via [`build_shape_dsl_program`]. 
+/// +/// Returns `Err` with a terse description if the function body uses Python +/// syntax outside the DSL subset. +pub fn convert_shape_dsl_function( + func: &ruff_python_ast::StmtFunctionDef, +) -> Result { + let fndef = convert_fndef(func)?; + Ok(ShapeDslFunction { + inner: Arc::new(fndef), + }) +} + +/// Bundle a set of [`ShapeDslFunction`]s into a validated [`ShapeDslProgram`]. +/// +/// Type-checks the bundle as a whole, building the global signature map that +/// `infer_call` needs in order to resolve cross-function calls. Today this +/// step *panics* on type errors (undefined variable, unknown call target, +/// etc.), matching the existing `parse_dsl` behavior. +pub fn build_shape_dsl_program(fns: impl IntoIterator) -> ShapeDslProgram { + let fns: Vec> = fns.into_iter().map(|f| f.inner).collect(); + // `type_check_program` borrows a `&[DslFnDef]`; clone the Arc'd defs into + // a temporary slice for the call. The clone is one-time per program build + // and avoids changing the internal `type_check_program` signature. + // `type_check_program` panics on type errors today; that matches the + // existing `parse_dsl` semantics. + let view: Vec = fns.iter().map(|f| (**f).clone()).collect(); + // Errors from `type_check_program` are intentionally discarded here: + // `build_shape_dsl_program` is called from test/fixture code where + // panicking is acceptable. The production path goes through + // `validate_shape_dsl_functions` which propagates errors as diagnostics. + let _ = type_check_program(&view); + ShapeDslProgram { fns } +} + +/// Construct a [`MetaShapeFunction`] keyed at `root_name` from a validated +/// [`ShapeDslProgram`]. +/// +/// The factory takes a `&ShapeDslProgram` (not raw pieces) so that a caller +/// cannot build a `MetaShapeFunction` from un-type-checked DSL — the only +/// way to obtain a `ShapeDslProgram` is via [`build_shape_dsl_program`], +/// which runs the type checker. 
+/// +/// `root_name` selects which function inside the program is the entry point +/// (the one whose parameters get bound from call-site arguments). All +/// functions in the program — including `root_name` — become part of the +/// resulting `fn_lookup`, which the interpreter consults for cross-function +/// calls. The lookup is keyed by each function's `name`, which the caller +/// is responsible for setting to the caller-local name; see +/// [`ShapeDslProgram`] for that invariant. +/// +/// Returns `None` if no function in the program has `name == root_name`. +pub fn make_meta_shape_function( + program: &ShapeDslProgram, + root_name: &str, +) -> Option> { + let fn_def = program + .fns + .iter() + .find(|f| f.name == root_name) + .map(Arc::clone)?; + let fn_lookup: Arc>> = Arc::new( + program + .fns + .iter() + .map(|f| (f.name.clone(), Arc::clone(f))) + .collect(), + ); + Some(Box::new(DslMetaShapeFunction { fn_def, fn_lookup })) +} + +#[cfg(test)] +mod tests { + use pyrefly_python::ast::Ast; + use ruff_python_ast::PySourceType; + use ruff_python_ast::Stmt; + + use super::*; + + fn parse_dsl_functions(source: &str) -> Vec { + let (module, _, _) = Ast::parse(source, PySourceType::Stub); + module + .body + .iter() + .filter_map(|stmt| { + if let Stmt::FunctionDef(func) = stmt { + Some(convert_shape_dsl_function(func).unwrap()) + } else { + None + } + }) + .collect() + } + + #[test] + fn test_call_targets_disjoint_closures() { + let fns = parse_dsl_functions( + r#" +def helper_a(x: int) -> int: + return x + 1 + +def helper_b(x: int) -> int: + return x + 2 + +def calls_a(x: int) -> int: + return helper_a(x) + +def calls_b(x: int) -> int: + return helper_b(x) + +def leaf(x: int) -> int: + return x +"#, + ); + assert_eq!(fns.len(), 5); + + let calls_a = fns.iter().find(|f| f.name() == "calls_a").unwrap(); + let calls_b = fns.iter().find(|f| f.name() == "calls_b").unwrap(); + let leaf = fns.iter().find(|f| f.name() == "leaf").unwrap(); + + assert_eq!( + 
calls_a.call_targets(), + HashSet::from(["helper_a".to_owned()]) + ); + assert_eq!( + calls_b.call_targets(), + HashSet::from(["helper_b".to_owned()]) + ); + assert!(leaf.call_targets().is_empty()); + } + + #[test] + fn test_call_targets_transitive() { + let fns = parse_dsl_functions( + r#" +def deep(x: int) -> int: + return x + +def mid(x: int) -> int: + return deep(x) + +def top(x: int) -> int: + return mid(x) +"#, + ); + let top = fns.iter().find(|f| f.name() == "top").unwrap(); + let mid = fns.iter().find(|f| f.name() == "mid").unwrap(); + + // call_targets is direct only, not transitive + assert_eq!(top.call_targets(), HashSet::from(["mid".to_owned()])); + assert_eq!(mid.call_targets(), HashSet::from(["deep".to_owned()])); + } +} diff --git a/crates/pyrefly_types/src/tensor_ops_registry.rs b/crates/pyrefly_types/src/tensor_ops_registry.rs deleted file mode 100644 index a40bfd257c..0000000000 --- a/crates/pyrefly_types/src/tensor_ops_registry.rs +++ /dev/null @@ -1,1041 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -//! Meta-shape op registry. -//! -//! This module registers all PyTorch op shape functions in `TensorOpsRegistry`. -//! All op definitions are expressed in the DSL (parsed by `meta_shape_dsl.rs`) and -//! interpreted directly — no IR layer. - -use std::collections::HashMap; -use std::sync::Arc; - -use crate::meta_shape_dsl::DslFnDef; -use crate::meta_shape_dsl::DslMetaShapeFunction; -use crate::meta_shape_dsl::MetaShapeFunction; -use crate::meta_shape_dsl::parse_dsl; - -// ============================================================================ -// DSL-based MetaShapeFunction construction -// ============================================================================ - -/// Look up a DSL function by name and create a `DslMetaShapeFunction`. 
-fn dsl_fn( - fn_lookup: &Arc>>, - name: &str, -) -> Box { - let fn_def = Arc::clone( - fn_lookup - .get(name) - .unwrap_or_else(|| panic!("DSL function `{name}` not found")), - ); - Box::new(DslMetaShapeFunction { - fn_def, - fn_lookup: Arc::clone(fn_lookup), - }) -} - -// ============================================================================ -// Meta-Shape Registry -// ============================================================================ - -/// Registry mapping PyTorch op names to their shape functions. -/// -/// All shape functions are backed by DSL definitions (see `meta_shape_dsl.rs`). -/// The DSL source is parsed once at registry construction time. Parsed definitions -/// are shared via `Arc` across all shape function instances. -pub struct TensorOpsRegistry { - functions: HashMap>, - /// Maps qualified class names (e.g., "torch.nn.MaxPool2d") to the list of - /// __init__ parameter names to capture. When a class has an init capture - /// registered, `construct_class` builds a `Type::NNModule` instead of a - /// `Type::ClassType`, storing the captured arg values in the NNModule's - /// field map. This allows forward DSL functions to access constructor - /// parameters directly from the type. - init_captures: HashMap>, -} - -impl TensorOpsRegistry { - /// Create a new registry with built-in meta-shape functions. - pub fn new() -> Self { - // Parse DSL once; definitions are shared via Arc across all instances. - let dsl_fns: Vec> = parse_dsl(DSL_SOURCE) - .expect("DSL source in tensor_ops_registry.rs has errors") - .into_iter() - .map(Arc::new) - .collect(); - // Build function lookup table once, shared by all DslMetaShapeFunctions. 
- let fn_lookup: Arc>> = Arc::new( - dsl_fns - .iter() - .map(|f| (f.name.clone(), Arc::clone(f))) - .collect(), - ); - let mut registry = Self { - functions: HashMap::new(), - init_captures: HashMap::new(), - }; - - // Shape manipulation - registry.register_dual("reshape", || dsl_fn(&fn_lookup, "reshape_ir")); - registry.register("torch.cat", dsl_fn(&fn_lookup, "cat_ir")); - registry.register("torch.broadcast_to", dsl_fn(&fn_lookup, "broadcast_to_ir")); - registry.register_dual("squeeze", || dsl_fn(&fn_lookup, "squeeze_ir")); - registry.register_dual("unsqueeze", || dsl_fn(&fn_lookup, "unsqueeze_ir")); - registry.register_dual("transpose", || dsl_fn(&fn_lookup, "transpose_ir")); - // torch.permute takes dims as a tuple; Tensor.permute takes *dims (variadic). - // Both use the same DSL function; parameter binding matches by name. - registry.register("torch.permute", dsl_fn(&fn_lookup, "permute_ir")); - registry.register("torch.Tensor.permute", dsl_fn(&fn_lookup, "permute_ir")); - registry.register("torch.flatten", dsl_fn(&fn_lookup, "flatten_ir")); - registry.register("torch.stack", dsl_fn(&fn_lookup, "stack_ir")); - registry.register("torch.tile", dsl_fn(&fn_lookup, "tile_ir")); - registry.register("torch.view", dsl_fn(&fn_lookup, "reshape_ir")); - registry.register("torch.unbind", dsl_fn(&fn_lookup, "unbind_ir")); - registry.register("torch.Tensor.unbind", dsl_fn(&fn_lookup, "unbind_ir")); - registry.register("torch.movedim", dsl_fn(&fn_lookup, "movedim_ir")); - registry.register("torch.moveaxis", dsl_fn(&fn_lookup, "movedim_ir")); - registry.register("torch.Tensor.movedim", dsl_fn(&fn_lookup, "movedim_ir")); - registry.register("torch.Tensor.moveaxis", dsl_fn(&fn_lookup, "movedim_ir")); - registry.register("torch.unfold", dsl_fn(&fn_lookup, "unfold_ir")); - registry.register("torch.Tensor.unfold", dsl_fn(&fn_lookup, "unfold_ir")); - - // Method-only shape manipulation - registry.register("torch.Tensor.reshape", dsl_fn(&fn_lookup, "reshape_ir")); - 
registry.register("torch.Tensor.view", dsl_fn(&fn_lookup, "reshape_ir")); - registry.register("torch.Tensor.squeeze", dsl_fn(&fn_lookup, "squeeze_ir")); - registry.register("torch.Tensor.flatten", dsl_fn(&fn_lookup, "flatten_ir")); - registry.register("torch.Tensor.tile", dsl_fn(&fn_lookup, "tile_ir")); - registry.register( - "torch.Tensor.diag_embed", - dsl_fn(&fn_lookup, "diag_embed_ir"), - ); - registry.register("torch.Tensor.repeat", dsl_fn(&fn_lookup, "repeat_ir")); - registry.register("torch.Tensor.expand", dsl_fn(&fn_lookup, "expand_ir")); - - // Reduction operations - registry.register_dual("sum", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("mean", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("prod", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("min", || dsl_fn(&fn_lookup, "min_max_median_ir")); - registry.register_dual("max", || dsl_fn(&fn_lookup, "min_max_median_ir")); - registry.register_dual("all", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("any", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("std", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("var", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("argmax", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register_dual("argmin", || dsl_fn(&fn_lookup, "reduce_ir")); - registry.register("torch.median", dsl_fn(&fn_lookup, "min_max_median_ir")); - registry.register("torch.logsumexp", dsl_fn(&fn_lookup, "reduce_ir")); - registry.register("torch.count_nonzero", dsl_fn(&fn_lookup, "reduce_ir")); - registry.register("torch.aminmax", dsl_fn(&fn_lookup, "aminmax_ir")); - registry.register("torch.norm", dsl_fn(&fn_lookup, "reduce_ir")); - registry.register("torch.mode", dsl_fn(&fn_lookup, "tuple_reduce_ir")); - registry.register("torch.topk", dsl_fn(&fn_lookup, "topk_ir")); - registry.register("torch.kthvalue", dsl_fn(&fn_lookup, "tuple_reduce_ir")); - registry.register("torch.var_mean", 
dsl_fn(&fn_lookup, "aminmax_ir")); - registry.register("torch.std_mean", dsl_fn(&fn_lookup, "aminmax_ir")); - - // Reduction method versions - registry.register( - "torch.Tensor.median", - dsl_fn(&fn_lookup, "min_max_median_ir"), - ); - registry.register("torch.Tensor.logsumexp", dsl_fn(&fn_lookup, "reduce_ir")); - registry.register( - "torch.Tensor.count_nonzero", - dsl_fn(&fn_lookup, "reduce_ir"), - ); - registry.register("torch.Tensor.aminmax", dsl_fn(&fn_lookup, "aminmax_ir")); - registry.register("torch.Tensor.norm", dsl_fn(&fn_lookup, "reduce_ir")); - registry.register("torch.Tensor.mode", dsl_fn(&fn_lookup, "tuple_reduce_ir")); - registry.register("torch.Tensor.topk", dsl_fn(&fn_lookup, "topk_ir")); - registry.register( - "torch.Tensor.kthvalue", - dsl_fn(&fn_lookup, "tuple_reduce_ir"), - ); - - // Repeat interleave - registry.register( - "torch.Tensor.repeat_interleave", - dsl_fn(&fn_lookup, "repeat_interleave_ir"), - ); - registry.register( - "torch.repeat_interleave", - dsl_fn(&fn_lookup, "repeat_interleave_ir"), - ); - - // Cosine similarity (reduces one dim) - registry.register( - "torch.nn.functional.cosine_similarity", - dsl_fn(&fn_lookup, "cosine_similarity_ir"), - ); - - // Indexing/slicing - registry.register("torch.select", dsl_fn(&fn_lookup, "select_ir")); - registry.register("torch.narrow", dsl_fn(&fn_lookup, "narrow_ir")); - registry.register("torch.split", dsl_fn(&fn_lookup, "split_ir")); - registry.register("torch.chunk", dsl_fn(&fn_lookup, "chunk_ir")); - registry.register("torch.index_select", dsl_fn(&fn_lookup, "index_select_ir")); - registry.register("torch.Tensor.select", dsl_fn(&fn_lookup, "select_ir")); - registry.register("torch.Tensor.narrow", dsl_fn(&fn_lookup, "narrow_ir")); - registry.register("torch.Tensor.split", dsl_fn(&fn_lookup, "split_ir")); - registry.register("torch.Tensor.chunk", dsl_fn(&fn_lookup, "chunk_ir")); - registry.register( - "torch.Tensor.index_select", - dsl_fn(&fn_lookup, "index_select_ir"), - ); - - // Tensor 
creation - registry.register("torch.randn", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.rand", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.zeros", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.ones", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.empty", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.full", dsl_fn(&fn_lookup, "randn_ir")); - registry.register("torch.randint", dsl_fn(&fn_lookup, "randint_ir")); - registry.register("torch.arange", dsl_fn(&fn_lookup, "arange_ir")); - registry.register("torch.linspace", dsl_fn(&fn_lookup, "linspace_ir")); - registry.register("torch.eye", dsl_fn(&fn_lookup, "eye_ir")); - registry.register("torch.diag_embed", dsl_fn(&fn_lookup, "diag_embed_ir")); - registry.register("torch.tril_indices", dsl_fn(&fn_lookup, "tri_indices_ir")); - registry.register("torch.triu_indices", dsl_fn(&fn_lookup, "tri_indices_ir")); - - // Linear algebra - registry.register("torch.matmul", dsl_fn(&fn_lookup, "matmul_ir")); - registry.register("torch.mv", dsl_fn(&fn_lookup, "mv_ir")); - registry.register("torch.outer", dsl_fn(&fn_lookup, "outer_ir")); - registry.register("torch.tensordot", dsl_fn(&fn_lookup, "tensordot_ir")); - registry.register("torch.einsum", dsl_fn(&fn_lookup, "einsum_ir")); - registry.register("torch.Tensor.matmul", dsl_fn(&fn_lookup, "matmul_ir")); - registry.register("torch.Tensor.__matmul__", dsl_fn(&fn_lookup, "matmul_ir")); - registry.register("torch.Tensor.mv", dsl_fn(&fn_lookup, "mv_ir")); - - // Eigenvalue decomposition - registry.register("torch.linalg.eig", dsl_fn(&fn_lookup, "eig_ir")); - registry.register("torch.eig", dsl_fn(&fn_lookup, "eig_ir")); - registry.register("torch.linalg.eigh", dsl_fn(&fn_lookup, "eig_ir")); - registry.register("torch.eigh", dsl_fn(&fn_lookup, "eig_ir")); - registry.register("torch.linalg.eigvals", dsl_fn(&fn_lookup, "eigvals_ir")); - registry.register("torch.linalg.eigvalsh", dsl_fn(&fn_lookup, "eigvals_ir")); - - // 
Linear solvers - registry.register("torch.linalg.solve", dsl_fn(&fn_lookup, "solve_ir")); - registry.register("torch.solve", dsl_fn(&fn_lookup, "solve_ir")); - registry.register( - "torch.linalg.solve_triangular", - dsl_fn(&fn_lookup, "solve_ir"), - ); - registry.register( - "torch.triangular_solve", - dsl_fn(&fn_lookup, "solve_reversed_ir"), - ); - registry.register( - "torch.linalg.cholesky_solve", - dsl_fn(&fn_lookup, "solve_reversed_ir"), - ); - registry.register( - "torch.cholesky_solve", - dsl_fn(&fn_lookup, "solve_reversed_ir"), - ); - registry.register("torch.lu_solve", dsl_fn(&fn_lookup, "solve_ir")); - - // Determinant - registry.register("torch.linalg.slogdet", dsl_fn(&fn_lookup, "slogdet_ir")); - registry.register("torch.slogdet", dsl_fn(&fn_lookup, "slogdet_ir")); - registry.register("torch.Tensor.slogdet", dsl_fn(&fn_lookup, "slogdet_ir")); - - // Convolution - registry.register("torch.nn.functional.conv1d", dsl_fn(&fn_lookup, "conv_ir")); - registry.register("torch.nn.functional.conv2d", dsl_fn(&fn_lookup, "conv_ir")); - registry.register("torch.nn.functional.conv3d", dsl_fn(&fn_lookup, "conv_ir")); - registry.register( - "torch.nn.functional.conv_transpose1d", - dsl_fn(&fn_lookup, "conv_transpose_ir"), - ); - registry.register( - "torch.nn.functional.conv_transpose2d", - dsl_fn(&fn_lookup, "conv_transpose_ir"), - ); - registry.register( - "torch.nn.functional.conv_transpose3d", - dsl_fn(&fn_lookup, "conv_transpose_ir"), - ); - - // Pooling - registry.register( - "torch.nn.functional.max_pool1d", - dsl_fn(&fn_lookup, "pool_ir"), - ); - registry.register( - "torch.nn.functional.max_pool2d", - dsl_fn(&fn_lookup, "pool_ir"), - ); - registry.register( - "torch.nn.functional.max_pool3d", - dsl_fn(&fn_lookup, "pool_ir"), - ); - registry.register( - "torch.nn.functional.avg_pool1d", - dsl_fn(&fn_lookup, "pool_ir"), - ); - registry.register( - "torch.nn.functional.avg_pool2d", - dsl_fn(&fn_lookup, "pool_ir"), - ); - registry.register( - 
"torch.nn.functional.avg_pool3d", - dsl_fn(&fn_lookup, "pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_max_pool1d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_max_pool2d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_max_pool3d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_avg_pool1d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_avg_pool2d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), - ); - registry.register( - "torch.nn.functional.adaptive_avg_pool3d", - dsl_fn(&fn_lookup, "adaptive_pool_ir"), - ); - - // Interpolation - registry.register( - "torch.nn.functional.interpolate", - dsl_fn(&fn_lookup, "interpolate_ir"), - ); - registry.register( - "torch.nn.functional.upsample", - dsl_fn(&fn_lookup, "interpolate_ir"), - ); - - // Conditional operations - registry.register("torch.where", dsl_fn(&fn_lookup, "where_ir")); - registry.register( - "torch.take_along_dim", - dsl_fn(&fn_lookup, "take_along_dim_ir"), - ); - registry.register( - "torch.Tensor.take_along_dim", - dsl_fn(&fn_lookup, "take_along_dim_ir"), - ); - - // Loss functions - registry.register( - "torch.nn.functional.mse_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register("torch.nn.functional.l1_loss", dsl_fn(&fn_lookup, "loss_ir")); - registry.register( - "torch.nn.functional.nll_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register( - "torch.nn.functional.cross_entropy", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register( - "torch.nn.functional.binary_cross_entropy", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register( - "torch.nn.functional.binary_cross_entropy_with_logits", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register("torch.nn.functional.kl_div", dsl_fn(&fn_lookup, "loss_ir")); - registry.register( - 
"torch.nn.functional.smooth_l1_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register( - "torch.nn.functional.huber_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register( - "torch.nn.functional.poisson_nll_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register( - "torch.nn.functional.cosine_embedding_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register( - "torch.nn.functional.margin_ranking_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register( - "torch.nn.functional.triplet_margin_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); - registry.register( - "torch.nn.functional.hinge_embedding_loss", - dsl_fn(&fn_lookup, "loss_ir"), - ); - - // Padding - registry.register("torch.nn.functional.pad", dsl_fn(&fn_lookup, "pad_ir")); - - // FFT - registry.register("torch.fft.rfft", dsl_fn(&fn_lookup, "rfft_ir")); - registry.register("torch.fft.irfft", dsl_fn(&fn_lookup, "irfft_ir")); - registry.register("torch.fft.hfft", dsl_fn(&fn_lookup, "irfft_ir")); - registry.register("torch.fft.ihfft", dsl_fn(&fn_lookup, "rfft_ir")); - - // Tensor properties - registry.register("torch.Tensor.size", dsl_fn(&fn_lookup, "size_ir")); - registry.register("torch.Tensor.numel", dsl_fn(&fn_lookup, "numel_ir")); - registry.register("torch.Tensor.dim", dsl_fn(&fn_lookup, "dim_ir")); - registry.register("torch.Tensor.nelement", dsl_fn(&fn_lookup, "numel_ir")); - registry.register("torch.Tensor.item", dsl_fn(&fn_lookup, "item_ir")); - registry.register("torch.Tensor.tolist", dsl_fn(&fn_lookup, "tolist_ir")); - registry.register("torch.numel", dsl_fn(&fn_lookup, "numel_ir")); - - // nn.Module forward methods with init capture. - // register_init_forward registers both the forward DSL function and the - // list of __init__ params to capture in the NNModule type. 
- let maxpool_captures = &["kernel_size", "stride", "padding", "dilation"]; - registry.register_init_forward( - &fn_lookup, - "torch.nn.MaxPool1d", - "nn_maxpool_forward_ir", - maxpool_captures, - ); - registry.register_init_forward( - &fn_lookup, - "torch.nn.MaxPool2d", - "nn_maxpool_forward_ir", - maxpool_captures, - ); - registry.register_init_forward( - &fn_lookup, - "torch.nn.MaxPool3d", - "nn_maxpool_forward_ir", - maxpool_captures, - ); - - let avgpool_captures = &["kernel_size", "stride", "padding"]; - registry.register_init_forward( - &fn_lookup, - "torch.nn.AvgPool1d", - "nn_avgpool_forward_ir", - avgpool_captures, - ); - registry.register_init_forward( - &fn_lookup, - "torch.nn.AvgPool2d", - "nn_avgpool_forward_ir", - avgpool_captures, - ); - registry.register_init_forward( - &fn_lookup, - "torch.nn.AvgPool3d", - "nn_avgpool_forward_ir", - avgpool_captures, - ); - - registry.register_init_forward( - &fn_lookup, - "torch.nn.Flatten", - "nn_flatten_forward_ir", - &["start_dim", "end_dim"], - ); - registry.register_init_forward( - &fn_lookup, - "torch.nn.PixelShuffle", - "nn_pixel_shuffle_forward_ir", - &["upscale_factor"], - ); - registry.register_init_forward(&fn_lookup, "torch.nn.GLU", "nn_glu_forward_ir", &["dim"]); - registry.register_init_forward( - &fn_lookup, - "torch.nn.LSTM", - "nn_lstm_forward_ir", - &["input_size", "hidden_size", "num_layers", "bidirectional"], - ); - registry.register_init_forward( - &fn_lookup, - "torch.nn.Upsample", - "nn_upsample_forward_ir", - &["size", "scale_factor"], - ); - registry.register_init_forward( - &fn_lookup, - "torch.nn.GRU", - "nn_gru_forward_ir", - &["input_size", "hidden_size", "num_layers", "bidirectional"], - ); - registry.register_init_forward( - &fn_lookup, - "torch.nn.LSTMCell", - "nn_lstmcell_forward_ir", - &["input_size", "hidden_size"], - ); - registry.register_init_forward( - &fn_lookup, - "torch.nn.ReflectionPad2d", - "nn_reflectionpad2d_forward_ir", - &["padding"], - ); - 
registry.register_init_forward( - &fn_lookup, - "torch.nn.ReplicationPad2d", - "nn_reflectionpad2d_forward_ir", - &["padding"], - ); - - // Random sampling - registry.register("torch.multinomial", dsl_fn(&fn_lookup, "multinomial_ir")); - registry.register( - "torch.Tensor.multinomial", - dsl_fn(&fn_lookup, "multinomial_ir"), - ); - registry.register("torch.normal", dsl_fn(&fn_lookup, "normal_ir")); - - registry - } - - /// Register a meta-shape function. - pub fn register(&mut self, name: impl Into, func: Box) { - self.functions.insert(name.into(), func); - } - - /// Register a meta-shape function for both `torch.X` and `torch.Tensor.X`. - pub fn register_dual Box>( - &mut self, - op_name: &str, - factory: F, - ) { - self.functions - .insert(format!("torch.{}", op_name), factory()); - self.functions - .insert(format!("torch.Tensor.{}", op_name), factory()); - } - - /// Get a meta-shape function by name. - pub fn get(&self, name: &str) -> Option<&dyn MetaShapeFunction> { - self.functions.get(name).map(|b| b.as_ref()) - } - - /// Register an nn.Module: both the forward DSL function and the list of - /// __init__ parameter names to capture in the NNModule type. - /// - /// `class_name` is the qualified class name (e.g., `"torch.nn.MaxPool2d"`). - /// This registers the forward function under `"{class_name}.forward"` and - /// the init captures under `"{class_name}"`. - fn register_init_forward( - &mut self, - fn_lookup: &Arc>>, - class_name: &str, - dsl_fn_name: &str, - capture_params: &[&str], - ) { - self.functions.insert( - format!("{class_name}.forward"), - dsl_fn(fn_lookup, dsl_fn_name), - ); - self.init_captures.insert( - class_name.to_owned(), - capture_params.iter().map(|s| (*s).to_owned()).collect(), - ); - } - - /// Look up init capture config for a qualified class name. - /// Returns the list of __init__ parameter names to capture. 
- pub fn get_init_capture(&self, class_name: &str) -> Option<&[String]> { - self.init_captures.get(class_name).map(|v| v.as_slice()) - } -} - -impl Default for TensorOpsRegistry { - fn default() -> Self { - Self::new() - } -} - -// ============================================================================ -// DSL source code -// ============================================================================ - -/// The full DSL source defining all tensor shape ops and utility functions. -/// This is valid Python syntax (a strict subset) that we parse with Pyrefly's parser. -const DSL_SOURCE: &str = r#" -def normalize_dim(rank: int, dim: int) -> int: - if dim < 0: - return dim + rank - return dim - -def int_max(a: int, b: int) -> int: - if a > b: - return a - return b - -def replace_dim(dims: list[int | symint], i: int, value: int | symint) -> list[int | symint]: - return dims[:i] + [value] + dims[i + 1:] - -def remove_dim(dims: list[int | symint], i: int) -> list[int | symint]: - return dims[:i] + dims[i + 1:] - -def insert_dim(dims: list[int | symint], i: int, value: int | symint) -> list[int | symint]: - return dims[:i] + [value] + dims[i:] - -def broadcast(a: list[int | symint], b: list[int | symint]) -> list[int | symint]: - max_len = int_max(len(a), len(b)) - padded_a = [1 for _ in range(max_len - len(a))] + a - padded_b = [1 for _ in range(max_len - len(b))] + b - return [bd if ad == 1 else ad for ad, bd in zip(padded_a, padded_b)] - -def broadcast_int(expr: int | symint | list[int | symint], n: int) -> list[int | symint]: - if isinstance(expr, list): - return expr - return [expr for _ in range(n)] - -def reduce_shape(dims: list[int | symint], dim: int | list[int] | None, keepdim: bool) -> list[int | symint]: - if dim == None: - if keepdim: - return [1 for _ in range(len(dims))] - return [] - dim_list = dim if isinstance(dim, list) else [dim] - norm = [normalize_dim(len(dims), d) for d in dim_list] - return [1 if i in norm else elem for i, elem in 
enumerate(dims) if not (i in norm) or keepdim] - -def contains(lst: list[int], val: int) -> bool: - return len([x for x in lst if x == val]) > 0 - -def scatter(size: int, indices: list[int], values: list[int], fill: int) -> list[int]: - matches = [[k for k in range(len(indices)) if indices[k] == i] for i in range(size)] - return [values[m[0]] if len(m) > 0 else fill for m in matches] - -def move_dims(dims: list[int | symint], source: int | list[int], dest: int | list[int], rank: int) -> list[int | symint]: - src = broadcast_int(source, 1) - dst = broadcast_int(dest, 1) - src_norm = [normalize_dim(rank, s) for s in src] - dst_norm = [normalize_dim(rank, d) for d in dst] - non_dst = [i for i in range(rank) if not contains(dst_norm, i)] - remaining = [i for i in range(rank) if not contains(src_norm, i)] - perm = scatter(rank, dst_norm + non_dst, src_norm + remaining, 0) - return [dims[p] for p in perm] - -def conv_spatial_out(input_dim: int | symint, kernel: int | symint, stride: int | symint, padding: int | symint, dilation: int | symint) -> int | symint: - return (input_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1 - -def reshape_ir(self: Tensor, shape: list[int | symint]) -> Tensor: - minus_one_count = len([d for d in shape if d == -1]) - if minus_one_count > 1: - raise Error("can only specify one unknown dimension as -1") - has_bad_neg = len([d for d in shape if isinstance(d, int) and d < -1]) > 0 - if has_bad_neg: - raise Error("invalid negative dimension value (only -1 is allowed)") - has_zero = len([d for d in shape if isinstance(d, int) and d == 0]) > 0 - if has_zero: - raise Error("reshape dimensions cannot contain 0") - if minus_one_count > 0: - known = shape_extensions.prod([d for d in shape if d != -1]) - total = shape_extensions.prod(self.shape) - if isinstance(total, int) and isinstance(known, int) and total % known != 0: - raise Error("could not infer size for dimension -1: expected " + str(total) + " to be divisible by " + str(known)) 
- return Tensor(shape=[total // known if d == -1 else d for d in shape]) - return Tensor(shape=shape) - -def squeeze_ir(self: Tensor, dim: int | None = None) -> Tensor: - if dim == None: - return Tensor(shape=[d for d in self.shape if d != 1]) - idx = normalize_dim(len(self.shape), dim) - return Tensor(shape=[d for i, d in enumerate(self.shape) if not (i == idx and d == 1)]) - -def unsqueeze_ir(self: Tensor, dim: int) -> Tensor: - d = normalize_dim(len(self.shape) + 1, dim) - return Tensor(shape=insert_dim(self.shape, d, 1)) - -def transpose_ir(self: Tensor, dim0: int, dim1: int) -> Tensor: - rank = len(self.shape) - d0 = normalize_dim(rank, dim0) - d1 = normalize_dim(rank, dim1) - return Tensor(shape=[self.shape[d1] if i == d0 else self.shape[d0] if i == d1 else d for i, d in enumerate(self.shape)]) - -def permute_ir(self: Tensor, dims: list[int]) -> Tensor: - rank = len(self.shape) - if len(dims) != rank: - raise Error("permute: expected " + str(rank) + " dims, got " + str(len(dims))) - return Tensor(shape=[self.shape[normalize_dim(rank, d)] for d in dims]) - -def flatten_ir(self: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor: - rank = len(self.shape) - s = normalize_dim(rank, start_dim) - e = normalize_dim(rank, end_dim) - return Tensor(shape=self.shape[:s] + [shape_extensions.prod(self.shape[s:e + 1])] + self.shape[e + 1:]) - -def expand_ir(self: Tensor, sizes: list[int | symint]) -> Tensor: - return Tensor(shape=[d if t == -1 else t for d, t in zip(self.shape, sizes)]) - -def repeat_ir(self: Tensor, sizes: list[int | symint]) -> Tensor: - return Tensor(shape=[d * r for d, r in zip(self.shape, sizes)]) - -def unbind_ir(self: Tensor, dim: int = 0) -> list[Tensor]: - d = normalize_dim(len(self.shape), dim) - return [Tensor(shape=remove_dim(self.shape, d)), ...] 
- -def movedim_ir(self: Tensor, source: int | list[int], destination: int | list[int]) -> Tensor: - return Tensor(shape=move_dims(self.shape, source, destination, len(self.shape))) - -def unfold_ir(self: Tensor, dimension: int, size: int | symint, step: int = 1) -> Tensor: - d = normalize_dim(len(self.shape), dimension) - new_dim = (self.shape[d] - size) // step + 1 - return Tensor(shape=replace_dim(self.shape, d, new_dim) + [size]) - -def cat_ir(tensors: list[Tensor], dim: int = 0) -> Tensor: - first = tensors[0] - d = normalize_dim(len(first.shape), dim) - return Tensor(shape=[shape_extensions.sum([t.shape[i] for t in tensors]) if i == d else dim_val for i, dim_val in enumerate(first.shape)]) - -def stack_ir(tensors: list[Tensor], dim: int = 0) -> Tensor: - first = tensors[0] - d = normalize_dim(len(first.shape) + 1, dim) - return Tensor(shape=insert_dim(first.shape, d, len(tensors))) - -def broadcast_to_ir(self: Tensor, shape: list[int | symint]) -> Tensor: - return Tensor(shape=shape) - -def tile_ir(self: Tensor, dims: list[int]) -> Tensor: - rank = len(self.shape) - if len(dims) > rank: - extra = len(dims) - rank - return Tensor(shape=[r for r in dims[:extra]] + [d * r for d, r in zip(self.shape, dims[extra:])]) - return Tensor(shape=[d * r for d, r in zip(self.shape, dims)]) - -def select_ir(self: Tensor, dim: int) -> Tensor: - d = normalize_dim(len(self.shape), dim) - return Tensor(shape=remove_dim(self.shape, d)) - -def narrow_ir(self: Tensor, dim: int, length: int | symint) -> Tensor: - return Tensor(shape=replace_dim(self.shape, normalize_dim(len(self.shape), dim), length)) - -def split_ir(self: Tensor, split_size_or_sections: int | symint | list[int | symint] | None = None, dim: int = 0) -> list[Tensor]: - d = normalize_dim(len(self.shape), dim) - if isinstance(split_size_or_sections, list): - return [Tensor(shape=replace_dim(self.shape, d, section)) for section in split_size_or_sections] - if isinstance(split_size_or_sections, int): - dim_val = 
self.shape[d] - if isinstance(dim_val, int): - count = (dim_val + split_size_or_sections - 1) // split_size_or_sections - return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections if i < count - 1 else dim_val - (count - 1) * split_size_or_sections)) for i in range(count)] - return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)), ...] - if split_size_or_sections != None: - quotient = self.shape[d] // split_size_or_sections - if isinstance(quotient, int): - return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)) for _ in range(quotient)] - return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)), ...] - return Unknown - -def chunk_ir(self: Tensor, chunks: int, dim: int = 0) -> list[Tensor]: - d = normalize_dim(len(self.shape), dim) - dim_val = self.shape[d] - if isinstance(dim_val, int): - chunk_size = (dim_val + chunks - 1) // chunks - return [Tensor(shape=replace_dim(self.shape, d, chunk_size if i < chunks - 1 else dim_val - (chunks - 1) * chunk_size)) for i in range(chunks)] - return [Tensor(shape=replace_dim(self.shape, d, dim_val // chunks)) for i in range(chunks)] - -def index_select_ir(self: Tensor, dim: int, index: Tensor) -> Tensor: - return Tensor(shape=replace_dim(self.shape, normalize_dim(len(self.shape), dim), index.shape[0])) - -def reduce_ir(self: Tensor, dim: int | list[int] | None = None, keepdim: bool = False) -> Tensor: - if dim == None: - return Tensor(shape=reduce_shape(self.shape, dim, keepdim)) - if isinstance(dim, list): - return Tensor(shape=reduce_shape(self.shape, dim, keepdim)) - return Tensor(shape=reduce_single(self.shape, dim, keepdim)) - -def reduce_single(dims: list[int | symint], dim: int, keepdim: bool) -> list[int | symint]: - before = dims[:dim] - if dim == -1: - if keepdim: - return before + [1] - return before - after = dims[dim + 1:] - if keepdim: - return before + [1] + after - return before + after - -def min_max_median_ir(self: Tensor, dim: int | None = None, 
keepdim: bool = False) -> Tensor: - if dim == None: - return Tensor(shape=[]) - s = reduce_shape(self.shape, dim, keepdim) - return [Tensor(shape=s), Tensor(shape=s)] - -def aminmax_ir(self: Tensor, dim: int | list[int] | None = None, keepdim: bool = False) -> [Tensor, Tensor]: - s = reduce_shape(self.shape, dim, keepdim) - return [Tensor(shape=s), Tensor(shape=s)] - -def tuple_reduce_ir(self: Tensor, dim: int = -1, keepdim: bool = False) -> [Tensor, Tensor]: - s = reduce_shape(self.shape, dim, keepdim) - return [Tensor(shape=s), Tensor(shape=s)] - -def topk_ir(self: Tensor, k: int | symint, dim: int = -1) -> [Tensor, Tensor]: - s = replace_dim(self.shape, normalize_dim(len(self.shape), dim), k) - return [Tensor(shape=s), Tensor(shape=s)] - -def repeat_interleave_ir(self: Tensor, repeats: int | symint, dim: int | None = None) -> Tensor: - if dim == None: - return Tensor(shape=[shape_extensions.prod(self.shape) * repeats]) - d = normalize_dim(len(self.shape), dim) - return Tensor(shape=replace_dim(self.shape, d, self.shape[d] * repeats)) - -def cosine_similarity_ir(x1: Tensor, x2: Tensor, dim: int = 1) -> Tensor: - s = broadcast(x1.shape, x2.shape) - return Tensor(shape=reduce_single(s, normalize_dim(len(s), dim), False)) - -def randn_ir(size: list[int | symint]) -> Tensor: - return Tensor(shape=size) - -def randint_ir(low: int, high: int, size: list[int | symint]) -> Tensor: - return Tensor(shape=size) - -def linspace_ir(steps: int | symint) -> Tensor: - return Tensor(shape=[steps]) - -def eye_ir(n: int | symint, m: int | symint | None = None) -> Tensor: - if m == None: - return Tensor(shape=[n, n]) - return Tensor(shape=[n, m]) - -def arange_ir(start: int | symint | None = None, end: int | symint | None = None, step: int | symint | None = None) -> Tensor: - if start != None and end != None and step != None: - return Tensor(shape=[(end - start) // step]) - if start != None and end != None: - return Tensor(shape=[end - start]) - if end != None: - return 
Tensor(shape=[end]) - if start != None: - return Tensor(shape=[start]) - return Unknown - -def normal_ir(mean: Tensor | None = None, std: Tensor | None = None, size: list[int] | None = None) -> Tensor: - if size != None: - return Tensor(shape=[s for s in size]) - if mean != None: - return Tensor(shape=mean.shape) - if std != None: - return Tensor(shape=std.shape) - return Unknown - -def diag_embed_ir(self: Tensor, offset: int = 0) -> Tensor: - new_dim = self.shape[-1] + (offset if offset >= 0 else -offset) - return Tensor(shape=self.shape[:-1] + [new_dim, new_dim]) - -def tri_indices_ir(row: int | symint, col: int | symint, offset: int = 0) -> Tensor: - return Tensor(shape=[2, 0]) - -def matmul_ir(self: Tensor, other: Tensor) -> Tensor: - r1 = len(self.shape) - r2 = len(other.shape) - if r1 == 1 and r2 == 1: - return Tensor(shape=[]) - if r1 == 1 and r2 == 2: - return Tensor(shape=[other.shape[1]]) - if r1 == 2 and r2 == 1: - return Tensor(shape=[self.shape[0]]) - if r1 == 2 and r2 == 2: - return Tensor(shape=[self.shape[0], other.shape[1]]) - if r1 == 2 and r2 >= 3: - return Tensor(shape=other.shape[:-2] + [self.shape[0]] + [other.shape[-1]]) - if r1 >= 3 and r2 == 2: - return Tensor(shape=self.shape[:-2] + [self.shape[-2]] + [other.shape[1]]) - if r1 >= 3 and r2 >= 3: - return Tensor(shape=broadcast(self.shape[:-2], other.shape[:-2]) + [self.shape[-2]] + [other.shape[-1]]) - return Unknown - -def mv_ir(self: Tensor, vec: Tensor) -> Tensor: - if len(self.shape) != 2: - raise Error("mv expects 2D matrix, got " + str(len(self.shape)) + "D tensor") - if len(vec.shape) != 1: - raise Error("mv expects 1D vector, got " + str(len(vec.shape)) + "D tensor") - return Tensor(shape=[self.shape[0]]) - -def outer_ir(self: Tensor, vec2: Tensor) -> Tensor: - if len(self.shape) != 1 or len(vec2.shape) != 1: - raise Error("outer expects 1D tensors, got " + str(len(self.shape)) + "D and " + str(len(vec2.shape)) + "D") - return Tensor(shape=[self.shape[0], vec2.shape[0]]) - -def 
tensordot_ir(self: Tensor, other: Tensor, dims: int) -> Tensor: - return Tensor(shape=self.shape[:len(self.shape) - dims] + other.shape[dims:]) - -def apply_einsum(output_map: list[list[int]], check_pairs: list[list[int]], inputs: list[Tensor]) -> Tensor: - bad_dims = [1 for i0, d0, i1, d1 in check_pairs if isinstance(inputs[i0].shape[d0], int) and isinstance(inputs[i1].shape[d1], int) and inputs[i0].shape[d0] != inputs[i1].shape[d1]] - if len(bad_dims) > 0: - raise Error("einsum: inconsistent dimensions for repeated index") - return Tensor(shape=[inputs[inp].shape[dim] for inp, dim in output_map]) - -def einsum_ir(spec: str, operands: list[Tensor] | None = None) -> Tensor: - if operands != None: - output_map, check_pairs = shape_extensions.parse_einsum_equation(spec) - return apply_einsum(output_map, check_pairs, operands) - return Unknown - -def eigvals_ir(self: Tensor) -> Tensor: - if len(self.shape) < 2: - raise Error("eigvals requires at least 2D input, got " + str(len(self.shape)) + "D tensor") - return Tensor(shape=self.shape[:-2] + [self.shape[-2]]) - -def eig_ir(self: Tensor) -> [Tensor, Tensor]: - if len(self.shape) < 2: - raise Error("eig requires at least 2D input, got " + str(len(self.shape)) + "D tensor") - batch = self.shape[:-2] - return [Tensor(shape=batch + [self.shape[-2]]), Tensor(shape=batch + self.shape[-2:])] - -def slogdet_ir(self: Tensor) -> [Tensor, Tensor]: - if len(self.shape) < 2: - raise Error("slogdet requires at least 2D input, got " + str(len(self.shape)) + "D tensor") - return [Tensor(shape=self.shape[:-2]), Tensor(shape=self.shape[:-2])] - -def solve_ir(self: Tensor, other: Tensor) -> Tensor: - return Tensor(shape=other.shape) - -def solve_reversed_ir(self: Tensor, other: Tensor) -> Tensor: - return Tensor(shape=self.shape) - -def conv_ir(self: Tensor, weight: Tensor, stride: int | list[int] = 1, padding: int | list[int] = 0, dilation: int | list[int] = 1) -> Tensor: - spatial_dims = len(self.shape) - 2 - stride_list = 
broadcast_int(stride, spatial_dims) - padding_list = broadcast_int(padding, spatial_dims) - dilation_list = broadcast_int(dilation, spatial_dims) - return Tensor(shape=[self.shape[0], weight.shape[0]] + [conv_spatial_out(s, k, st, p, dil) for s, k, st, p, dil in zip(self.shape[2:], weight.shape[2:], stride_list, padding_list, dilation_list)]) - -def conv_transpose_ir(self: Tensor, weight: Tensor, stride: int | list[int] = 1, padding: int | list[int] = 0, output_padding: int | list[int] = 0, dilation: int | list[int] = 1) -> Tensor: - spatial_dims = len(self.shape) - 2 - stride_list = broadcast_int(stride, spatial_dims) - padding_list = broadcast_int(padding, spatial_dims) - outpad_list = broadcast_int(output_padding, spatial_dims) - dilation_list = broadcast_int(dilation, spatial_dims) - return Tensor(shape=[self.shape[0], weight.shape[1]] + [(s - 1) * st - 2 * p + dil * (k - 1) + op + 1 for s, k, st, p, op, dil in zip(self.shape[2:], weight.shape[2:], stride_list, padding_list, outpad_list, dilation_list)]) - -def pool_ir(self: Tensor, kernel_size: int | list[int], stride: int | list[int] | None = None, padding: int | list[int] = 0, dilation: int | list[int] = 1, return_indices: bool = False) -> Tensor: - spatial_dims = len(self.shape) - 2 - ks_list = broadcast_int(kernel_size, spatial_dims) - stride_list = ks_list if stride == None else broadcast_int(stride, spatial_dims) - padding_list = broadcast_int(padding, spatial_dims) - dilation_list = broadcast_int(dilation, spatial_dims) - out = [self.shape[0], self.shape[1]] + [conv_spatial_out(s, k, st, p, dil) for s, k, st, p, dil in zip(self.shape[2:], ks_list, stride_list, padding_list, dilation_list)] - if return_indices: - return [Tensor(shape=out), Tensor(shape=out)] - return Tensor(shape=out) - -def adaptive_pool_ir(self: Tensor, output_size: int | symint | list[int | symint]) -> Tensor: - out_sizes = broadcast_int(output_size, len(self.shape) - 2) - return Tensor(shape=[self.shape[0], self.shape[1]] + 
out_sizes) - -def interpolate_ir(self: Tensor, size: int | symint | list[int | symint] | None = None, scale_factor: int | symint | None = None) -> Tensor: - if size != None: - return Tensor(shape=[self.shape[0], self.shape[1]] + broadcast_int(size, len(self.shape) - 2)) - if scale_factor != None: - return Tensor(shape=[self.shape[0], self.shape[1]] + [d * scale_factor for d in self.shape[2:]]) - raise Error("interpolate requires either 'size' or 'scale_factor' argument") - -def loss_ir(self: Tensor, reduction: str = "mean") -> Tensor: - if reduction == "none": - return Tensor(shape=self.shape) - return Tensor(shape=[]) - -def pad_ir(self: Tensor, pad: list[int]) -> Tensor: - rank = len(self.shape) - num_pad_dims = len(pad) // 2 - offsets = [pad[(rank - 1 - i) * 2] + pad[(rank - 1 - i) * 2 + 1] if i >= rank - num_pad_dims else 0 for i in range(rank)] - return Tensor(shape=[d + offsets[i] for i, d in enumerate(self.shape)]) - -def rfft_ir(self: Tensor, n: int | symint | None = None, dim: int = -1) -> Tensor: - d = normalize_dim(len(self.shape), dim) - if n != None: - return Tensor(shape=replace_dim(self.shape, d, n // 2 + 1)) - return Tensor(shape=replace_dim(self.shape, d, self.shape[d] // 2 + 1)) - -def irfft_ir(self: Tensor, n: int | symint | None = None, dim: int = -1) -> Tensor: - d = normalize_dim(len(self.shape), dim) - if n != None: - return Tensor(shape=replace_dim(self.shape, d, n)) - return Tensor(shape=replace_dim(self.shape, d, 2 * (self.shape[d] - 1))) - -def size_ir(self: Tensor, dim: int | None = None) -> int | symint: - if dim != None: - return self.shape[normalize_dim(len(self.shape), dim)] - return [d for d in self.shape] - -def numel_ir(self: Tensor) -> int | symint: - return shape_extensions.prod(self.shape) - -def dim_ir(self: Tensor) -> int: - return len(self.shape) - -def item_ir(self: Tensor) -> Tensor: - if len(self.shape) != 0: - raise Error("item() only works on 0-dimensional tensors, got " + str(len(self.shape)) + "D tensor") - return 
Unknown - -def tolist_ir(self: Tensor) -> Tensor: - return Unknown - -def multinomial_ir(self: Tensor, num_samples: int | symint) -> Tensor: - return Tensor(shape=self.shape[:-1] + [num_samples]) - -def where_ir(condition: Tensor, x: Tensor, y: Tensor) -> Tensor: - return Tensor(shape=x.shape) - -def take_along_dim_ir(self: Tensor, indices: Tensor) -> Tensor: - return Tensor(shape=indices.shape) - -def nn_flatten_forward_ir(input: Tensor, start_dim: symint = 1, end_dim: symint = -1) -> Tensor: - return flatten_ir(input, start_dim, end_dim) - -def nn_maxpool_forward_ir(input: Tensor, kernel_size: symint = 1, stride: symint | None = None, padding: symint = 0, dilation: symint = 1) -> Tensor: - return pool_ir(input, kernel_size, stride, padding, dilation) - -def nn_avgpool_forward_ir(input: Tensor, kernel_size: symint = 1, stride: symint | None = None, padding: symint = 0) -> Tensor: - return pool_ir(input, kernel_size, stride, padding, 1) - -def nn_upsample_forward_ir(input: Tensor, size: symint | None = None, scale_factor: symint | None = None) -> Tensor: - return interpolate_ir(input, size, scale_factor) - -def nn_pixel_shuffle_forward_ir(input: Tensor, upscale_factor: symint) -> Tensor: - r = upscale_factor - return Tensor(shape=[input.shape[0], input.shape[1] // (r * r)] + [d * r for d in input.shape[2:]]) - -def nn_glu_forward_ir(input: Tensor, dim: symint = 1) -> Tensor: - rank = len(input.shape) - d = normalize_dim(rank, dim) - return Tensor(shape=replace_dim(input.shape, d, input.shape[d] // 2)) - -def nn_lstm_forward_ir(input: Tensor, input_size: symint, hidden_size: symint, num_layers: symint = 1, bidirectional: bool = False) -> [Tensor, Tensor, Tensor]: - nd = 2 if bidirectional else 1 - output = Tensor(shape=[input.shape[0], input.shape[1], hidden_size * nd]) - h_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) - c_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) - return [output, h_n, c_n] - -def nn_gru_forward_ir(input: 
Tensor, input_size: symint, hidden_size: symint, num_layers: symint = 1, bidirectional: bool = False) -> [Tensor, Tensor]: - nd = 2 if bidirectional else 1 - output = Tensor(shape=[input.shape[0], input.shape[1], hidden_size * nd]) - h_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) - return [output, h_n] - -def nn_lstmcell_forward_ir(input: Tensor, input_size: symint, hidden_size: symint) -> [Tensor, Tensor]: - h = Tensor(shape=[input.shape[0], hidden_size]) - c = Tensor(shape=[input.shape[0], hidden_size]) - return [h, c] - -def nn_reflectionpad2d_forward_ir(input: Tensor, padding: symint) -> Tensor: - return Tensor(shape=[input.shape[0], input.shape[1], input.shape[2] + 2 * padding, input.shape[3] + 2 * padding]) -"#; diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index 1c085cdbf1..a7a2c3bcb2 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -10,9 +10,9 @@ use std::sync::Arc; use dupe::Dupe; use pyrefly_python::dunder; +use pyrefly_types::meta_shape_dsl::ShapeTransformRef; use pyrefly_types::quantified::Quantified; use pyrefly_types::special_form::SpecialForm; -use pyrefly_types::tensor_ops_registry::TensorOpsRegistry; use pyrefly_types::typed_dict::TypedDictInner; use pyrefly_types::types::CalleeKind; use pyrefly_types::types::NNModuleType; @@ -1120,14 +1120,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { errors: &ErrorCollector, result: Type, ) -> Type { - use std::sync::OnceLock; - static TENSOR_OPS_REGISTRY: OnceLock = OnceLock::new(); - - let class_name = format!("{}.{}", ct.class_object().module_name(), ct.name()); - let registry = TENSOR_OPS_REGISTRY.get_or_init(TensorOpsRegistry::new); - let capture_names = match registry.get_init_capture(&class_name) { - Some(names) => names, - None => return result, + let class_metadata = self.get_metadata_for_class(ct.class_object()); + let capture_names_from_metadata: Vec; + let capture_names: &[Name] = if let Some(names) = class_metadata.capture_init() 
{ + capture_names_from_metadata = names.to_vec(); + &capture_names_from_metadata + } else { + return result; }; let infer_type_or_expr = |toe: TypeOrExpr, errors: &ErrorCollector| -> Type { @@ -1139,17 +1138,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let mut fields = SmallMap::new(); for (i, param_name) in capture_names.iter().enumerate() { - let name = Name::new(param_name); // First check keyword args. if let Some(kw) = keywords.iter().find(|k| { k.arg .is_some_and(|id| id.id.as_str() == param_name.as_str()) }) { - fields.insert(name, infer_type_or_expr(kw.value, errors)); + fields.insert(param_name.clone(), infer_type_or_expr(kw.value, errors)); } else if i < args.len() { // Map positional arg by index to the capture param name. if let CallArg::Arg(toe) = &args[i] { - fields.insert(name, infer_type_or_expr(*toe, errors)); + fields.insert(param_name.clone(), infer_type_or_expr(*toe, errors)); } } // If neither keyword nor positional, the param uses its default. @@ -1313,7 +1311,9 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { if let Some(m) = metadata && (matches!( m.kind, - FunctionKind::Dataclass | FunctionKind::DataclassTransform + FunctionKind::Dataclass + | FunctionKind::DataclassTransform + | FunctionKind::UsesShapeDsl ) || m.kind.is_signature_preserving_decorator() || m.flags.dataclass_transform_metadata.is_some()) { @@ -1420,6 +1420,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) => self.call_infer_inner( signature, Some(&metadata.kind), + metadata.flags.shape_transform.as_deref(), tparams.as_deref(), Some(obj), args, @@ -1434,6 +1435,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { CallTarget::Callable(TargetWithTParams(tparams, callable)) => self.call_infer_inner( callable, None, + None, tparams.as_deref(), None, args, @@ -1454,6 +1456,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { )) => self.call_infer_inner( callable, Some(&metadata.kind), + metadata.flags.shape_transform.as_deref(), 
tparams.as_deref(), None, args, @@ -1469,6 +1472,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.call_overloads( overloads, &metadata, + metadata.flags.shape_transform.as_deref(), None, args, keywords, @@ -1484,6 +1488,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.call_overloads( overloads, &meta, + meta.flags.shape_transform.as_deref(), Some(obj), args, keywords, @@ -1551,6 +1556,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, callable: Callable, callable_name: Option<&FunctionKind>, + shape_transform: Option<&ShapeTransformRef>, tparams: Option<&TParams>, self_obj: Option, args: &[CallArg], @@ -1568,6 +1574,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let res_no_hint = self.callable_infer( callable.clone(), callable_name, + shape_transform, tparams, self_obj.clone(), args, @@ -1587,6 +1594,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let res_with_hint = self.callable_infer( callable, callable_name, + shape_transform, tparams, self_obj, args, diff --git a/pyrefly/lib/alt/callable.rs b/pyrefly/lib/alt/callable.rs index 551cf0e905..4344ea46f0 100644 --- a/pyrefly/lib/alt/callable.rs +++ b/pyrefly/lib/alt/callable.rs @@ -12,7 +12,7 @@ use itertools::Itertools; use pyrefly_python::dunder; use pyrefly_types::callable::FunctionKind; use pyrefly_types::meta_shape_dsl::MetaShapeFunction; -use pyrefly_types::tensor_ops_registry::TensorOpsRegistry; +use pyrefly_types::meta_shape_dsl::ShapeTransformRef; use pyrefly_types::tuple::Tuple; use pyrefly_types::typed_dict::ExtraItems; use pyrefly_types::types::TArgs; @@ -1326,6 +1326,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, callable: Callable, callable_name: Option<&FunctionKind>, + shape_transform: Option<&ShapeTransformRef>, tparams: Option<&TParams>, self_obj: Option, args: &[CallArg], @@ -1348,6 +1349,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { self.callable_infer_inner( callable.clone(), callable_name, + 
shape_transform, tparams, self_obj.clone(), args, @@ -1368,6 +1370,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, callable: Callable, callable_name: Option<&FunctionKind>, + shape_transform: Option<&ShapeTransformRef>, tparams: Option<&TParams>, mut self_obj: Option, mut args: &[CallArg], @@ -1387,14 +1390,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { call_context.set_argument_side(ArgumentSide::Got); call_context.require_boundary_consumption(); - // Look up meta-shape early so we can conditionally collect bound args. - // Only consult the registry when tensor_shapes is enabled to avoid - // unnecessary DSL parsing and per-call HashMap lookups. - let meta_shape_func = if self.solver().tensor_shapes { - Self::lookup_meta_shape(callable_name) - } else { - None - }; + let shape_transform_func = shape_transform.map(|t| t.to_meta_shape_function()); + let meta_shape_func: Option<&dyn MetaShapeFunction> = shape_transform_func.as_deref(); let mut bound_args: Option> = meta_shape_func.map(|_| HashMap::new()); let (callable_qs, mut callable) = if let Some(tparams) = tparams { @@ -1619,26 +1616,6 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ) } - /// Look up whether a callable has a registered meta-shape function. 
- fn lookup_meta_shape(callable_name: Option<&FunctionKind>) -> Option<&dyn MetaShapeFunction> { - use std::sync::OnceLock; - static TENSOR_OPS_REGISTRY: OnceLock = OnceLock::new(); - - let func_id = callable_name.and_then(|fk| match fk { - FunctionKind::Def(box_func_id) => Some(box_func_id.as_ref()), - _ => None, - })?; - - let qualified_name = if let Some(cls) = &func_id.cls { - format!("{}.{}.{}", func_id.module.name(), cls.name(), func_id.name) - } else { - format!("{}.{}", func_id.module.name(), func_id.name) - }; - - let registry = TENSOR_OPS_REGISTRY.get_or_init(TensorOpsRegistry::new); - registry.get(&qualified_name) - } - /// Auto-inject module field values into `bound_args` for DSL parameters /// that aren't method parameters but match fields on `self`. /// diff --git a/pyrefly/lib/alt/class/class_metadata.rs b/pyrefly/lib/alt/class/class_metadata.rs index c0a3853bcc..a4e5994e0d 100644 --- a/pyrefly/lib/alt/class/class_metadata.rs +++ b/pyrefly/lib/alt/class/class_metadata.rs @@ -129,6 +129,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { pydantic_config_dict: &PydanticConfigDict, pydantic_before_validator_fields: &[Name], django_field_info: &DjangoFieldInfo, + capture_init: Option<&[Name]>, errors: &ErrorCollector, ) -> ClassMetadata { // Get class decorators. 
@@ -507,6 +508,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { is_factory_boy_factory, is_metaclass, slots_info, + capture_init.map(|names| names.to_vec()), ) } diff --git a/pyrefly/lib/alt/class/dataclass.rs b/pyrefly/lib/alt/class/dataclass.rs index 85b4969cba..a1a8b8c03b 100644 --- a/pyrefly/lib/alt/class/dataclass.rs +++ b/pyrefly/lib/alt/class/dataclass.rs @@ -678,6 +678,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { })) .unwrap(), &overload.metadata, + None, // no shape_transform for dataclass constructors None, &args.args.map(CallArg::expr_maybe_starred), &args.keywords.map(CallKeyword::new), diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index efad7d3617..cb6158a15a 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -5,6 +5,7 @@ * LICENSE file in the root directory of this source tree. */ +use std::collections::HashSet; use std::collections::VecDeque; use std::ops::Deref; use std::sync::Arc; @@ -15,11 +16,16 @@ use pyrefly_python::ast::Ast; use pyrefly_python::dunder; use pyrefly_python::module_path::ModuleStyle; use pyrefly_python::short_identifier::ShortIdentifier; +use pyrefly_types::callable::Derived; +use pyrefly_types::callable::FuncId; use pyrefly_types::callable::Params; use pyrefly_types::callable::PlaceholderBodyKind; use pyrefly_types::class::Class; use pyrefly_types::class::ClassType; use pyrefly_types::dimension::SizeExpr; +use pyrefly_types::meta_shape_dsl::ShapeDslFunction; +use pyrefly_types::meta_shape_dsl::ShapeTransformRef; +use pyrefly_types::meta_shape_dsl::validate_shape_dsl_functions; use pyrefly_types::quantified::Quantified; use pyrefly_types::quantified::QuantifiedOrigin; use pyrefly_types::type_var::Restriction; @@ -435,6 +441,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { legacy_tparams: &[Idx], module_style: ModuleStyle, outer_funcs: Option, + shape_dsl_def: Option>, + uses_shape_dsl_ir_name: Option<(Name, ShortIdentifier)>, errors: 
&ErrorCollector, ) -> Arc { let defining_cls = class_key.and_then(|k| self.get_idx(*k).0.dupe()); @@ -536,13 +544,71 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { tparams.extend(legacy_tparams); let tparams = self.validated_tparams(def.range, tparams, TParamsSource::Function, errors); - let kind = FunctionKind::from_name( - self.module().dupe(), - defining_cls.clone(), - &def.name.id, - Some(def_index), - outer_funcs, - ); + let kind = if let Some(dsl_fn) = shape_dsl_def { + // Build the transitive closure of helper functions this DSL function calls, + // then validate cross-function call signatures. + let module_dsl_fns = self.bindings().metadata().shape_dsl_functions(); + let helpers = compute_transitive_helpers(&dsl_fn, module_dsl_fns); + if let Err(type_errors) = validate_shape_dsl_functions(&helpers) { + for msg in &type_errors { + self.error( + errors, + def.name.range, + ErrorKind::InvalidArgument, + format!("@shape_dsl_function type error: {msg}"), + ); + } + // Fall back to a normal function — the DSL evaluator must + // never run on a program that failed type checking. + FunctionKind::from_name( + self.module().dupe(), + defining_cls.clone(), + &def.name.id, + Some(def_index), + outer_funcs, + ) + } else { + let func_id = Arc::new(FuncId { + module: self.module().dupe(), + cls: defining_cls.clone(), + name: def.name.id.clone(), + def_index: Some(def_index), + outer_funcs, + }); + FunctionKind::ShapeDsl(func_id, dsl_fn, Derived(helpers)) + } + } else { + FunctionKind::from_name( + self.module().dupe(), + defining_cls.clone(), + &def.name.id, + Some(def_index), + outer_funcs, + ) + }; + + // Resolve the IR function reference from @uses_shape_dsl(ir_fn) and + // populate `flags.shape_transform` with the DSL function it points to. 
+ if let Some((_name, ir_identifier)) = uses_shape_dsl_ir_name { + let ir_type = self.get(&Key::BoundName(ir_identifier)).arc_clone_ty(); + if let Type::Function(func) = &ir_type + && let FunctionKind::ShapeDsl(_, dsl_fn, helpers) = &func.metadata.kind + { + flags.shape_transform = Some(Arc::new(ShapeTransformRef { + dsl_fn: dsl_fn.clone(), + helpers: helpers.clone(), + })); + } else { + self.error( + errors, + ir_identifier.range(), + ErrorKind::InvalidArgument, + "`@uses_shape_dsl` argument does not resolve to a `@shape_dsl_function`" + .to_owned(), + ); + } + } + let metadata = FuncMetadata { kind, flags }; Arc::new(UndecoratedFunction { @@ -784,6 +850,11 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { { Some(SpecialDecorator::DataclassTransformCall(&call.keywords)) } + _ if let Type::KwCall(call) = &decorator.ty + && call.has_function_kind(FunctionKind::UsesShapeDsl) => + { + Some(SpecialDecorator::UsesShapeDsl) + } Some(CalleeKind::Class(ClassKind::EnumNonmember)) => { Some(SpecialDecorator::EnumNonmember) } @@ -872,6 +943,13 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { flags.is_abstract_method = true; true } + SpecialDecorator::UsesShapeDsl => { + // The actual shape_transform flag is populated after the decorator + // loop in undecorated_function, where uses_shape_dsl_ir_name is + // available. Returning true here just filters the decorator out of + // the list so it doesn't go through the generic decorator pipeline. + true + } _ => false, } } @@ -2495,3 +2573,34 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } } } + +/// Compute the transitive closure of DSL helper functions called by `root`. +/// +/// Starting from `root`, follows `call_targets()` through `module_dsl_fns` +/// to collect every user-defined function reachable from `root`. The result +/// always includes `root` itself as the first element. 
+fn compute_transitive_helpers( + root: &Arc, + module_dsl_fns: &[(Name, Arc)], +) -> Arc>> { + let mut closure: Vec> = vec![Arc::clone(root)]; + let mut visited: HashSet = HashSet::new(); + visited.insert(root.name().to_owned()); + + let mut queue: Vec = root.call_targets().into_iter().collect(); + while let Some(name) = queue.pop() { + if !visited.insert(name.clone()) { + continue; + } + if let Some((_, dsl_fn)) = module_dsl_fns.iter().find(|(n, _)| n.as_str() == name) { + closure.push(Arc::clone(dsl_fn)); + for target in dsl_fn.call_targets() { + if !visited.contains(&target) { + queue.push(target); + } + } + } + } + + Arc::new(closure) +} diff --git a/pyrefly/lib/alt/overload.rs b/pyrefly/lib/alt/overload.rs index 28311de3c7..aba14988e4 100644 --- a/pyrefly/lib/alt/overload.rs +++ b/pyrefly/lib/alt/overload.rs @@ -13,6 +13,7 @@ use itertools::Itertools; use pyrefly_types::callable::ArgCount; use pyrefly_types::callable::ArgCounts; use pyrefly_types::callable::Param; +use pyrefly_types::meta_shape_dsl::ShapeTransformRef; use pyrefly_types::tuple::Tuple; use pyrefly_types::types::TArgs; use pyrefly_util::gas::Gas; @@ -231,6 +232,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, overloads: Vec1>, metadata: &FuncMetadata, + shape_transform: Option<&ShapeTransformRef>, self_obj: Option, args: &[CallArg], keywords: &[CallKeyword], @@ -286,6 +288,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let (mut closest_overload, mut matched) = self.find_closest_overload( &arity_compatible_overloads, metadata, + shape_transform, self_obj.as_ref(), &args, &keywords, @@ -310,6 +313,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let (cur_closest, cur_matched) = self.find_closest_overload( &arity_compatible_overloads, metadata, + shape_transform, self_obj.as_ref(), cur_args, cur_keywords, @@ -516,6 +520,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, overloads: &Vec1<&'c TargetWithTParams>, metadata: &FuncMetadata, + 
shape_transform: Option<&ShapeTransformRef>, self_obj: Option<&Type>, args: &[CallArg], keywords: &[CallKeyword], @@ -535,6 +540,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let called_overload = self.call_overload( callable, metadata, + shape_transform, self_obj, args, keywords, @@ -656,6 +662,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let res = self.call_overload( o.func, metadata, + shape_transform, self_obj, &materialized_args, &materialized_keywords, @@ -687,6 +694,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let contextual_overload = self.call_overload( overload.func, metadata, + shape_transform, self_obj, args, keywords, @@ -811,6 +819,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &self, callable: &'c TargetWithTParams, metadata: &FuncMetadata, + shape_transform: Option<&ShapeTransformRef>, self_obj: Option<&Type>, args: &[CallArg], keywords: &[CallKeyword], @@ -830,6 +839,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { let (res, specialization_errors, expected_types) = self.callable_infer( callable.1.signature.clone(), Some(&metadata.kind), + shape_transform, tparams, self_obj.cloned(), args, diff --git a/pyrefly/lib/alt/solve.rs b/pyrefly/lib/alt/solve.rs index dd49d2f0c5..5401e4c59a 100644 --- a/pyrefly/lib/alt/solve.rs +++ b/pyrefly/lib/alt/solve.rs @@ -344,6 +344,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { pydantic_config_dict, pydantic_before_validator_fields, django_field_info, + capture_init, } = binding; let metadata = match &self.get_idx(*k).0 { None => ClassMetadata::recursive(), @@ -356,6 +357,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { pydantic_config_dict, pydantic_before_validator_fields, django_field_info, + capture_init.as_deref(), errors, ), }; @@ -5349,6 +5351,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { &x.legacy_tparams, x.module_style, x.outer_funcs.clone(), + x.shape_dsl_def.clone(), + x.uses_shape_dsl_ir_name.clone(), errors, ) } 
diff --git a/pyrefly/lib/alt/types/class_metadata.rs b/pyrefly/lib/alt/types/class_metadata.rs index 8d4b97eb7b..7c8d8636e3 100644 --- a/pyrefly/lib/alt/types/class_metadata.rs +++ b/pyrefly/lib/alt/types/class_metadata.rs @@ -74,6 +74,9 @@ pub struct ClassMetadata { /// Whether this class is a metaclass (i.e., a subclass of `type`). is_metaclass: bool, slots_info: Option, + /// `__init__` parameter names to capture for shape inference, extracted from + /// `@uses_shape_dsl(..., capture_init=[...])` on a `forward` method. + capture_init: Option>, } impl VisitMut for ClassMetadata { @@ -132,6 +135,7 @@ impl ClassMetadata { is_factory_boy_factory: bool, is_metaclass: bool, slots_info: Option, + capture_init: Option>, ) -> ClassMetadata { ClassMetadata { metaclass, @@ -158,6 +162,7 @@ impl ClassMetadata { is_factory_boy_factory, is_metaclass, slots_info, + capture_init, } } @@ -187,6 +192,7 @@ impl ClassMetadata { is_factory_boy_factory: false, is_metaclass: false, slots_info: None, + capture_init: None, } } @@ -346,6 +352,10 @@ impl ClassMetadata { pub fn django_model_metadata(&self) -> Option<&DjangoModelMetadata> { self.django_model_metadata.as_ref() } + + pub fn capture_init(&self) -> Option<&[Name]> { + self.capture_init.as_deref() + } } /// A field that we synthesize and add to a class. 
Note that if a non-synthesized field already diff --git a/pyrefly/lib/alt/types/decorated_function.rs b/pyrefly/lib/alt/types/decorated_function.rs index 3d97bb0c34..45c6fffda9 100644 --- a/pyrefly/lib/alt/types/decorated_function.rs +++ b/pyrefly/lib/alt/types/decorated_function.rs @@ -98,6 +98,7 @@ pub enum SpecialDecorator<'a> { DataclassTransformCall(&'a TypeMap), EnumNonmember, AbstractMethod, + UsesShapeDsl, } impl UndecoratedFunction { diff --git a/pyrefly/lib/binding/binding.rs b/pyrefly/lib/binding/binding.rs index 005dc9ea8f..d642f282b4 100644 --- a/pyrefly/lib/binding/binding.rs +++ b/pyrefly/lib/binding/binding.rs @@ -9,6 +9,7 @@ use std::fmt; use std::fmt::Debug; use std::fmt::Display; use std::hash::Hash; +use std::sync::Arc; use dupe::Dupe; use pyrefly_derive::TypeEq; @@ -22,6 +23,7 @@ use pyrefly_python::short_identifier::ShortIdentifier; use pyrefly_python::symbol_kind::SymbolKind; use pyrefly_types::callable::PlaceholderBodyKind; use pyrefly_types::heap::TypeHeap; +use pyrefly_types::meta_shape_dsl::ShapeDslFunction; use pyrefly_types::special_form::SpecialForm; use pyrefly_types::type_alias::TypeAlias; use pyrefly_types::type_alias::TypeAliasIndex; @@ -118,7 +120,7 @@ assert_words!(BindingAnnotation, 15); assert_words!(BindingClass, 11); assert_words!(BindingTParams, 10); assert_words!(BindingClassBaseType, 3); -assert_words!(BindingClassMetadata, 11); +assert_words!(BindingClassMetadata, 13); assert_bytes!(BindingClassMro, 4); assert_bytes!(BindingAbstractClassCheck, 4); assert_bytes!(BindingClassSubscriptSymmetry, 4); @@ -129,7 +131,7 @@ assert_words!(BindingYield, 4); assert_words!(BindingYieldFrom, 4); assert_words!(BindingDecorator, 10); assert_bytes!(BindingDecoratedFunction, 20); -assert_words!(BindingUndecoratedFunction, 18); +assert_words!(BindingUndecoratedFunction, 23); #[derive(Clone, Dupe, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub enum AnyIdx { @@ -1856,6 +1858,13 @@ pub struct BindingUndecoratedFunction { /// Dot-separated 
path of enclosing function names (e.g. `"f1"` for `f2` defined inside `f1`, /// or `"f1.g1"` for two levels deep). `None` for top-level or class-method functions. pub outer_funcs: Option, + /// When the function is decorated with `@shape_dsl_function`, this holds the + /// parsed DSL IR so the solver can produce `FunctionKind::ShapeDsl`. + pub shape_dsl_def: Option>, + /// Name of the IR function passed as the first positional argument to + /// `@uses_shape_dsl(ir_fn)`. Extracted at binding time so the solver can + /// resolve it to a `FunctionKind::ShapeDsl` type. + pub uses_shape_dsl_ir_name: Option<(Name, ShortIdentifier)>, } impl DisplayWith for BindingUndecoratedFunction { @@ -3167,6 +3176,9 @@ pub struct BindingClassMetadata { pub pydantic_before_validator_fields: Box<[Name]>, /// Django-specific field information. pub django_field_info: Box, + /// `__init__` parameter names to capture for shape inference, extracted from + /// `@uses_shape_dsl(..., capture_init=[...])` on a `forward` method. 
+ pub capture_init: Option>, } impl DisplayWith for BindingClassMetadata { diff --git a/pyrefly/lib/binding/class.rs b/pyrefly/lib/binding/class.rs index 8faa1ce0c8..fb821221fb 100644 --- a/pyrefly/lib/binding/class.rs +++ b/pyrefly/lib/binding/class.rs @@ -246,6 +246,7 @@ impl<'a> BindingsBuilder<'a> { let body = mem::take(&mut x.body); let field_docstrings = self.extract_field_docstrings(&body); let pydantic_before_validator_fields = self.extract_field_validator_fields(&body); + let capture_init = self.extract_capture_init(&body); let decorators = self.ensure_and_bind_decorators(mem::take(&mut x.decorator_list), class_object.usage()); @@ -540,6 +541,7 @@ impl<'a> BindingsBuilder<'a> { pydantic_before_validator_fields: pydantic_before_validator_fields .into_boxed_slice(), django_field_info: Box::new(django_field_info), + capture_init: capture_init.map(|v| v.into_boxed_slice()), }, ); self.insert_binding_idx( @@ -556,6 +558,37 @@ impl<'a> BindingsBuilder<'a> { ); } + /// Scan a class body for a `forward` method decorated with + /// `@uses_shape_dsl(..., capture_init=[...])` and return the list of `__init__` + /// parameter names to capture for shape inference. 
+ fn extract_capture_init(&self, body: &[Stmt]) -> Option> { + body.iter() + .filter_map(|stmt| stmt.as_function_def_stmt()) + .filter(|func_def| func_def.name.as_str() == "forward") + .flat_map(|func_def| &func_def.decorator_list) + .find_map(|decorator| { + let call = decorator.expression.as_call_expr()?; + if self.as_special_export(&call.func) != Some(SpecialExport::UsesShapeDsl) { + return None; + } + let capture_init_kw = call.arguments.keywords.iter().find(|kw| { + kw.arg + .as_ref() + .is_some_and(|a| a.as_str() == "capture_init") + })?; + let list = capture_init_kw.value.as_list_expr()?; + let names: Vec = list + .elts + .iter() + .filter_map(|elt| { + elt.as_string_literal_expr() + .map(|s| Name::new(s.value.to_str())) + }) + .collect(); + Some(names) + }) + } + /// Extracts docstrings for each field, mapping the field's range to the docstring's range. fn extract_field_docstrings( &self, @@ -994,6 +1027,7 @@ impl<'a> BindingsBuilder<'a> { pydantic_config_dict: PydanticConfigDict::default(), pydantic_before_validator_fields: Box::default(), django_field_info: Box::default(), + capture_init: None, }, ); self.insert_binding_idx( diff --git a/pyrefly/lib/binding/function.rs b/pyrefly/lib/binding/function.rs index 8d187964f7..4ca56a8ae8 100644 --- a/pyrefly/lib/binding/function.rs +++ b/pyrefly/lib/binding/function.rs @@ -6,6 +6,7 @@ */ use std::mem; +use std::sync::Arc; use dupe::Dupe as _; use pyrefly_graph::index::Idx; @@ -16,6 +17,7 @@ use pyrefly_python::nesting_context::NestingContext; use pyrefly_python::short_identifier::ShortIdentifier; use pyrefly_python::sys_info::SysInfo; use pyrefly_types::callable::PlaceholderBodyKind; +use pyrefly_types::meta_shape_dsl::convert_shape_dsl_function; use pyrefly_util::prelude::VecExt; use pyrefly_util::visit::Visit; use ruff_python_ast::Decorator; @@ -70,6 +72,7 @@ use crate::binding::scope::UnusedParameter; use crate::binding::scope::UnusedVariable; use crate::binding::scope::YieldsAndReturns; use 
crate::config::base::InferReturnTypes; +use crate::config::error_kind::ErrorKind; use crate::export::special::SpecialExport; use crate::types::types::AnyStyle; @@ -804,6 +807,78 @@ impl<'a> BindingsBuilder<'a> { _ => (None, None), }; + // Check whether this function is decorated with `@shape_dsl_function` + // before `decorators()` takes the decorator list. + let is_shape_dsl = x.decorator_list.iter().any(|d| { + self.as_special_export(&d.expression) == Some(SpecialExport::ShapeDslFunction) + }); + + // Extract the IR function name from @uses_shape_dsl(ir_fn) if present. + let uses_shape_dsl_ir_name = x.decorator_list.iter().find_map(|d| { + let call = d.expression.as_call_expr()?; + if self.as_special_export(&call.func) != Some(SpecialExport::UsesShapeDsl) { + return None; + } + // The first positional argument is the IR function reference. + let first_arg = call.arguments.args.first()?; + // Must be a simple name (not a dotted path or arbitrary expression). + let name_expr = first_arg.as_name_expr()?; + Some((name_expr.id.clone(), ShortIdentifier::expr_name(name_expr))) + }); + + // Convert the function to DSL IR before `function_header` takes `returns` + // and before `function_body` takes `body`. + let shape_dsl_def = if is_shape_dsl { + // Warn about parameter kinds the DSL silently ignores. 
+ if let Some(vararg) = &x.parameters.vararg { + self.error( + vararg.range(), + ErrorKind::InvalidArgument, + "@shape_dsl_function: *args parameters are not supported in the shape DSL and will be ignored".to_owned(), + ); + } + if let Some(kwarg) = &x.parameters.kwarg { + self.error( + kwarg.range(), + ErrorKind::InvalidArgument, + "@shape_dsl_function: **kwargs parameters are not supported in the shape DSL and will be ignored".to_owned(), + ); + } + if !x.parameters.kwonlyargs.is_empty() { + self.error( + x.parameters.kwonlyargs[0].range(), + ErrorKind::InvalidArgument, + "@shape_dsl_function: keyword-only parameters are not supported in the shape DSL and will be ignored".to_owned(), + ); + } + if !x.parameters.posonlyargs.is_empty() { + self.error( + x.parameters.posonlyargs[0].range(), + ErrorKind::InvalidArgument, + "@shape_dsl_function: positional-only parameters are not supported in the shape DSL and will be ignored".to_owned(), + ); + } + + match convert_shape_dsl_function(&x) { + Ok(dsl_fn) => { + let dsl_fn = Arc::new(dsl_fn); + self.metadata + .push_shape_dsl(func_name.id.clone(), Arc::clone(&dsl_fn)); + Some(dsl_fn) + } + Err(msg) => { + self.error( + x.range, + ErrorKind::InvalidArgument, + format!("@shape_dsl_function: {msg}"), + ); + None + } + } + } else { + None + }; + self.scopes.push(Scope::annotation(x.range)); let (return_ann_with_range, legacy_tparams) = self.function_header(&mut x, &func_name, class_key, def_idx.usage(), parent); @@ -844,6 +919,8 @@ impl<'a> BindingsBuilder<'a> { legacy_tparams: legacy_tparams.into_boxed_slice(), module_style: self.module_info.path().style(), outer_funcs, + shape_dsl_def, + uses_shape_dsl_ir_name, }, ); diff --git a/pyrefly/lib/binding/metadata.rs b/pyrefly/lib/binding/metadata.rs index 628c78c82b..fe3a38478b 100644 --- a/pyrefly/lib/binding/metadata.rs +++ b/pyrefly/lib/binding/metadata.rs @@ -11,8 +11,12 @@ //! This module stores per-class metadata (starting with field information) //! 
in a `Vec` indexed by `ClassDefIndex`, enabling efficient lookups. +use std::sync::Arc; + use pyrefly_types::class::ClassDefIndex; use pyrefly_types::class::ClassFields; +use pyrefly_types::meta_shape_dsl::ShapeDslFunction; +use ruff_python_ast::name::Name; /// Metadata for a single class definition, populated during binding. #[derive(Debug, Clone, Default)] @@ -38,12 +42,14 @@ pub struct ClassMetadata { #[derive(Debug, Clone)] pub struct BindingsMetadata { classes: Vec, + shape_dsl_functions: Vec<(Name, Arc)>, } impl BindingsMetadata { pub fn new() -> Self { Self { classes: Vec::new(), + shape_dsl_functions: Vec::new(), } } @@ -67,4 +73,15 @@ impl BindingsMetadata { pub fn get_class_mut(&mut self, idx: ClassDefIndex) -> &mut ClassMetadata { &mut self.classes[idx.0 as usize] } + + /// Record a `@shape_dsl_function` definition so sibling DSL functions + /// can be discovered for cross-function helper resolution. + pub fn push_shape_dsl(&mut self, name: Name, dsl_fn: Arc) { + self.shape_dsl_functions.push((name, dsl_fn)); + } + + /// All `@shape_dsl_function` definitions recorded in this module. 
+ pub fn shape_dsl_functions(&self) -> &[(Name, Arc)] { + &self.shape_dsl_functions + } } diff --git a/pyrefly/lib/export/special.rs b/pyrefly/lib/export/special.rs index 3d38a6f98a..f8de550faf 100644 --- a/pyrefly/lib/export/special.rs +++ b/pyrefly/lib/export/special.rs @@ -71,6 +71,8 @@ pub enum SpecialExport { Final, TypingMapping, TypeForm, + UsesShapeDsl, + ShapeDslFunction, } impl SpecialExport { @@ -133,6 +135,8 @@ impl SpecialExport { "Final" => Some(Self::Final), "Mapping" => Some(Self::TypingMapping), "TypeForm" => Some(Self::TypeForm), + "uses_shape_dsl" => Some(Self::UsesShapeDsl), + "shape_dsl_function" => Some(Self::ShapeDslFunction), _ => None, } } @@ -204,6 +208,8 @@ impl SpecialExport { "typing" | "typing_extensions" | "collections.abc" ), Self::Deprecated => matches!(m.as_str(), "warnings" | "typing_extensions"), + Self::UsesShapeDsl => matches!(m.as_str(), "shape_extensions"), + Self::ShapeDslFunction => matches!(m.as_str(), "shape_extensions.dsl"), } } diff --git a/pyrefly/lib/report/binding_memory.rs b/pyrefly/lib/report/binding_memory.rs index 80661881df..fac7aa3c0d 100644 --- a/pyrefly/lib/report/binding_memory.rs +++ b/pyrefly/lib/report/binding_memory.rs @@ -166,6 +166,7 @@ mod tests { pydantic_config_dict: PydanticConfigDict::default(), pydantic_before_validator_fields: Box::default(), django_field_info: Box::default(), + capture_init: None, }; assert_eq!( ReportKey::new(module, &v), diff --git a/pyrefly/lib/test/mod.rs b/pyrefly/lib/test/mod.rs index 032d2bcbbe..7f9cc219bc 100644 --- a/pyrefly/lib/test/mod.rs +++ b/pyrefly/lib/test/mod.rs @@ -67,6 +67,7 @@ mod redundant_cast; mod returns; mod scope; mod semantic_syntax_errors; +mod shape_dsl; mod simple; mod slots; mod state; diff --git a/pyrefly/lib/test/shape_dsl.rs b/pyrefly/lib/test/shape_dsl.rs new file mode 100644 index 0000000000..9bc1045862 --- /dev/null +++ b/pyrefly/lib/test/shape_dsl.rs @@ -0,0 +1,181 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. 
+ * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +use crate::test::util::TestEnv; +use crate::testcase; + +fn shape_dsl_env() -> TestEnv { + let path = std::env::var("SHAPE_DSL_TEST_PATH").expect("SHAPE_DSL_TEST_PATH must be set"); + let mut env = TestEnv::new_with_site_package_paths(&[&path]); + env.add_with_path( + "my_shapes", + "my_shapes.pyi", + r#" +from shape_extensions.dsl import shape_dsl_function + +@shape_dsl_function +def identity_ir(x: int) -> int: + return x + +@shape_dsl_function +def times_two(x: int) -> int: + return x + x + +@shape_dsl_function +def double_ir(x: int) -> int: + return times_two(x) + +def not_a_dsl_fn(x: int) -> int: ... + +@shape_dsl_function # E: @shape_dsl_function: unexpected statement in DSL body +def bad_syntax_ir(x: int) -> int: + while x > 0: + x = x - 1 + return x + +@shape_dsl_function +def kwargs_ir(x: int, **kwargs) -> int: # E: @shape_dsl_function: **kwargs parameters are not supported + return x + +@shape_dsl_function +def calls_undefined(x: int) -> int: # E: @shape_dsl_function type error: undefined function: nonexistent + return nonexistent(x) # E: Could not find name `nonexistent` +"#, + ); + env.add_with_path( + "my_lib", + "my_lib.pyi", + r#" +from typing import overload +from shape_extensions import uses_shape_dsl +from my_shapes import identity_ir, double_ir, not_a_dsl_fn, bad_syntax_ir, kwargs_ir + +@uses_shape_dsl(identity_ir) +def plain_fn(x: int) -> int: ... + +@overload +def overloaded_with_impl(x: int) -> int: ... +@overload +def overloaded_with_impl(x: str) -> str: ... +@uses_shape_dsl(identity_ir) +def overloaded_with_impl(x: int | str) -> int | str: ... + +@uses_shape_dsl(identity_ir) +@overload +def overloaded_no_impl(x: int) -> int: ... +@overload +def overloaded_no_impl(x: str) -> str: ... + +@uses_shape_dsl(double_ir) +def double_fn(x: int) -> int: ... 
+ +@uses_shape_dsl(not_a_dsl_fn) # E: `@uses_shape_dsl` argument does not resolve to a `@shape_dsl_function` +def bad_fn(x: int) -> int: ... + +@uses_shape_dsl(bad_syntax_ir) # E: `@uses_shape_dsl` argument does not resolve to a `@shape_dsl_function` +def bad_syntax_fn(x: int) -> int: ... + +@uses_shape_dsl(kwargs_ir) +def kwargs_fn(x: int) -> int: ... + +"#, + ); + env +} + +testcase!( + test_uses_shape_dsl_preserves_type, + shape_dsl_env(), + r#" +from typing import Literal, assert_type +from my_lib import plain_fn + +# identity_ir returns its input unchanged. Because val_to_type synthesizes +# Literal[n] from the DSL's traced integer value (not the declared return +# type), the result is Literal[1], not int. +assert_type(plain_fn(1), Literal[1]) +"#, +); + +testcase!( + test_uses_shape_dsl_overload_with_implementation, + shape_dsl_env(), + r#" +from typing import Literal, assert_type +from my_lib import overloaded_with_impl + +assert_type(overloaded_with_impl(1), Literal[1]) +assert_type(overloaded_with_impl("a"), str) +"#, +); + +testcase!( + test_uses_shape_dsl_overload_no_implementation, + shape_dsl_env(), + r#" +from typing import Literal, assert_type +from my_lib import overloaded_no_impl + +assert_type(overloaded_no_impl(1), Literal[1]) +assert_type(overloaded_no_impl("a"), str) +"#, +); + +testcase!( + test_uses_shape_dsl_cross_function_call, + shape_dsl_env(), + r#" +from typing import Literal, assert_type +from my_lib import double_fn + +assert_type(double_fn(3), Literal[6]) +"#, +); + +testcase!( + test_uses_shape_dsl_not_a_dsl_function, + shape_dsl_env(), + r#" +from typing import assert_type +from my_lib import bad_fn + +# The @uses_shape_dsl argument is not a @shape_dsl_function, so no shape +# transform is applied and the declared return type (int) is used instead. 
+assert_type(bad_fn(1), int) +"#, +); + +testcase!( + test_shape_dsl_unsupported_syntax, + shape_dsl_env(), + r#" +from typing import assert_type +from my_lib import bad_syntax_fn + +# bad_syntax_ir uses a while loop which is unsupported DSL syntax, so +# bad_syntax_fn falls back to the declared return type. +assert_type(bad_syntax_fn(1), int) +"#, +); + +testcase!( + test_shape_dsl_kwargs_warning, + shape_dsl_env(), + r#" +from typing import Literal, assert_type +from my_lib import kwargs_fn + +# kwargs_ir has **kwargs which triggers a warning but the DSL conversion +# still succeeds (kwargs are silently dropped), so shape inference works. +assert_type(kwargs_fn(1), Literal[1]) +"#, +); + +// The `calls_undefined` function in my_shapes.pyi calls `nonexistent()`, +// which produces a type error diagnostic (tested by the `# E:` annotation +// on its definition). No separate test case is needed here — the error is +// validated by every test that uses `shape_dsl_env()`. diff --git a/test/tensor_shapes/fixtures/shape_extensions/__init__.py b/test/tensor_shapes/fixtures/shape_extensions/__init__.py index 52f764be46..990d24529f 100644 --- a/test/tensor_shapes/fixtures/shape_extensions/__init__.py +++ b/test/tensor_shapes/fixtures/shape_extensions/__init__.py @@ -149,3 +149,22 @@ def __iter__(self): @property def __typing_is_unpacked_typevartuple__(self): return True + + +def uses_shape_dsl( + ir_fn: typing.Callable, + *, + capture_init: list[str] | None = None, +) -> typing.Callable[[typing.Callable], typing.Callable]: + """Decorator that associates a shape DSL function with an API function. + + At runtime this is a no-op: the decorator arguments are ignored and the + decorated function is returned unchanged. Pyrefly uses this decorator + at type-checking time to route bound arguments through the shape DSL + for return-type refinement. 
+ """ + + def decorator(fn: typing.Callable) -> typing.Callable: + return fn + + return decorator diff --git a/test/tensor_shapes/fixtures/shape_extensions/dsl.py b/test/tensor_shapes/fixtures/shape_extensions/dsl.py new file mode 100644 index 0000000000..3d8cfeee7e --- /dev/null +++ b/test/tensor_shapes/fixtures/shape_extensions/dsl.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-ignore-all-errors + +"""DSL internals for shape typing. + +Only used inside DSL definition files (e.g. torch/_shapes.pyi), not in +normal stubs or user code. +""" + +import typing + + +def shape_dsl_function(fn: typing.Callable) -> typing.Callable: + """Marks a function as a shape DSL function. + + At runtime this is a no-op: the decorated function is returned unchanged. + Pyrefly uses this decorator at type-checking time to convert the function + body to DSL IR via convert_fndef. + """ + return fn + + +def prod(xs: list[int]) -> int: + """Compute the product of a list of dimension sizes.""" + ... + + +def sum(xs: list[int]) -> int: + """Compute the sum of a list of dimension sizes.""" + ... + + +def parse_einsum_equation(spec: str) -> list[list[list[int]]]: + """Parse an einsum equation string into output map and check pairs.""" + ... 
diff --git a/test/tensor_shapes/fixtures/torch/__init__.pyi b/test/tensor_shapes/fixtures/torch/__init__.pyi index d0480d3b60..c17aeca161 100644 --- a/test/tensor_shapes/fixtures/torch/__init__.pyi +++ b/test/tensor_shapes/fixtures/torch/__init__.pyi @@ -19,6 +19,61 @@ For operations handled by meta-shapes, see pyrefly_types/src/meta_shape.rs: import builtins from typing import Any, overload, Self, TYPE_CHECKING +from shape_extensions import uses_shape_dsl +from torch._shapes import ( + aminmax_ir, + arange_ir, + broadcast_to_ir, + cat_ir, + chunk_ir, + diag_embed_ir, + dim_ir, + eig_ir, + einsum_ir, + expand_ir, + eye_ir, + flatten_ir, + index_select_ir, + item_ir, + linspace_ir, + matmul_ir, + min_max_median_ir, + movedim_ir, + multinomial_ir, + mv_ir, + narrow_ir, + normal_ir, + numel_ir, + outer_ir, + permute_ir, + randint_ir, + randn_ir, + reduce_ir, + repeat_interleave_ir, + repeat_ir, + reshape_ir, + select_ir, + size_ir, + slogdet_ir, + solve_ir, + solve_reversed_ir, + split_ir, + squeeze_ir, + stack_ir, + take_along_dim_ir, + tensordot_ir, + tile_ir, + tolist_ir, + topk_ir, + transpose_ir, + tri_indices_ir, + tuple_reduce_ir, + unbind_ir, + unfold_ir, + unsqueeze_ir, + where_ir, +) + if TYPE_CHECKING: from shape_extensions import Dim as _Dim @@ -95,6 +150,7 @@ class Tensor[*Shape]: # ==== Matrix Multiplication ==== # Uses meta-shape for shape inference + @uses_shape_dsl(matmul_ir) def __matmul__(self: Tensor, other: Tensor) -> Tensor: """Matrix multiplication (@). Shape inference via meta-shape: torch.Tensor.matmul""" ... @@ -134,6 +190,7 @@ class Tensor[*Shape]: # ==== Shape Manipulation Operations ==== # Handled by meta-shape functions - simplified signatures + @uses_shape_dsl(reshape_ir) @overload def reshape(self: Tensor, *shape: int) -> Tensor: """Reshape tensor. Shape inference via meta-shape: torch.Tensor.reshape""" @@ -144,6 +201,7 @@ class Tensor[*Shape]: """Reshape tensor. Shape inference via meta-shape: torch.Tensor.reshape""" ... 
+ @uses_shape_dsl(reshape_ir) @overload def view(self: Tensor, *shape: int) -> Tensor: """View (alias for reshape). Shape inference via meta-shape: torch.Tensor.view""" @@ -154,14 +212,17 @@ class Tensor[*Shape]: """View (alias for reshape). Shape inference via meta-shape: torch.Tensor.view""" ... + @uses_shape_dsl(flatten_ir) def flatten(self: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor: """Flatten dimensions. Shape inference via meta-shape: torch.flatten""" ... + @uses_shape_dsl(transpose_ir) def transpose(self: Tensor, dim0: int, dim1: int) -> Tensor: """Transpose two dimensions. Shape inference via meta-shape: torch.transpose""" ... + @uses_shape_dsl(permute_ir) @overload def permute(self: Tensor, *dims: int) -> Tensor: """Permute dimensions. Shape inference via meta-shape: torch.Tensor.permute""" @@ -172,14 +233,17 @@ class Tensor[*Shape]: """Permute dimensions. Shape inference via meta-shape: torch.Tensor.permute""" ... + @uses_shape_dsl(squeeze_ir) def squeeze(self: Tensor, dim: int | None = None) -> Tensor: """Remove dimensions of size 1. Shape inference via meta-shape: torch.squeeze""" ... + @uses_shape_dsl(unsqueeze_ir) def unsqueeze(self: Tensor, dim: int) -> Tensor: """Add dimension of size 1. Shape inference via meta-shape: torch.unsqueeze""" ... + @uses_shape_dsl(repeat_ir) @overload def repeat(self: Tensor, *sizes: int) -> Tensor: """Repeat tensor. Shape inference via meta-shape: torch.Tensor.repeat""" @@ -194,6 +258,7 @@ class Tensor[*Shape]: """Transpose 2D tensor. Swaps dimensions.""" ... + @uses_shape_dsl(expand_ir) def expand(self: Tensor, *sizes: int) -> Tensor: """Expand tensor. Shape inference via meta-shape: torch.Tensor.expand""" ... @@ -202,6 +267,7 @@ class Tensor[*Shape]: """Expand tensor to match the shape of `other`.""" ... 
+ @uses_shape_dsl(repeat_interleave_ir) def repeat_interleave( self: Tensor, repeats: int | Tensor, dim: int | None = None ) -> Tensor: @@ -332,26 +398,32 @@ class Tensor[*Shape]: """Enable/disable gradient tracking in-place. Shape-preserving.""" ... + @uses_shape_dsl(item_ir) def item(self: Tensor) -> float | int: """Returns Python scalar from 0-dimensional tensor. Shape inference via meta-shape: torch.Tensor.item""" ... + @uses_shape_dsl(tolist_ir) def tolist(self: Tensor) -> Any: """Returns tensor as nested Python list. Shape inference via meta-shape: torch.Tensor.tolist""" ... + @uses_shape_dsl(tile_ir) def tile(self: Tensor, dims: tuple[int, ...]) -> Tensor: """Tile tensor. Shape inference via meta-shape: torch.Tensor.tile""" ... + @uses_shape_dsl(select_ir) def select(self: Tensor, dim: int, index: int) -> Tensor: """Select along dimension. Shape inference via meta-shape: torch.Tensor.select""" ... + @uses_shape_dsl(narrow_ir) def narrow(self: Tensor, dim: int, start: int, length: int) -> Tensor: """Narrow tensor along dimension. Shape inference via meta-shape: torch.Tensor.narrow""" ... + @uses_shape_dsl(split_ir) @overload def split( self: Tensor, split_size_or_sections: int, dim: int = 0 @@ -366,10 +438,12 @@ class Tensor[*Shape]: """Split tensor into variable-sized chunks. Shape inference via meta-shape: torch.Tensor.split""" ... + @uses_shape_dsl(chunk_ir) def chunk(self: Tensor, chunks: int, dim: int = 0) -> tuple[Tensor, ...]: """Split tensor into chunks. Shape inference via meta-shape: torch.Tensor.chunk""" ... + @uses_shape_dsl(index_select_ir) def index_select(self: Tensor, dim: int, index: Tensor) -> Tensor: """Select elements along dimension. Shape inference via meta-shape: torch.Tensor.index_select""" ... @@ -392,10 +466,12 @@ class Tensor[*Shape]: # ==== Phase 1.1: Missing Shape Operations (Methods) ==== + @uses_shape_dsl(unbind_ir) def unbind(self: Tensor, dim: int = 0) -> tuple[Tensor, ...]: """Remove dimension by slicing along it. 
Shape inference via meta-shape: torch.Tensor.unbind""" ... + @uses_shape_dsl(movedim_ir) @overload def movedim(self: Tensor, source: int, destination: int) -> Tensor: """Move single dimension to new position. Shape inference via meta-shape: torch.Tensor.movedim""" @@ -408,6 +484,7 @@ class Tensor[*Shape]: """Move multiple dimensions to new positions. Shape inference via meta-shape: torch.Tensor.movedim""" ... + @uses_shape_dsl(movedim_ir) @overload def moveaxis(self: Tensor, source: int, destination: int) -> Tensor: """Alias for movedim. Shape inference via meta-shape: torch.Tensor.moveaxis""" @@ -420,10 +497,12 @@ class Tensor[*Shape]: """Alias for movedim. Shape inference via meta-shape: torch.Tensor.moveaxis""" ... + @uses_shape_dsl(unfold_ir) def unfold(self: Tensor, dimension: int, size: int, step: int) -> Tensor: """Returns sliding window view. Shape inference via meta-shape: torch.Tensor.unfold""" ... + @uses_shape_dsl(size_ir) @overload def size(self: Tensor) -> tuple[builtins.int, ...]: """Returns the size of the tensor as a tuple. Shape inference via meta-shape: torch.Tensor.size""" @@ -437,6 +516,7 @@ class Tensor[*Shape]: # ==== Reduction Operations ==== # Handled by meta-shape functions - simplified signatures + @uses_shape_dsl(reduce_ir) @overload def sum(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Sum along dimension(s). Shape inference via meta-shape: torch.Tensor.sum""" @@ -447,10 +527,12 @@ class Tensor[*Shape]: """Sum along multiple dimensions. Shape inference via meta-shape: torch.Tensor.sum""" ... + @uses_shape_dsl(reduce_ir) def mean(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Mean along dimension(s). Shape inference via meta-shape: torch.mean""" ... + @uses_shape_dsl(min_max_median_ir) @overload def max(self: Tensor) -> Tensor: """Max of all elements (scalar). Shape inference via meta-shape: torch.Tensor.max""" @@ -461,6 +543,7 @@ class Tensor[*Shape]: """Max along dimension. 
Returns (values, indices). Shape inference via meta-shape: torch.Tensor.max""" ... + @uses_shape_dsl(min_max_median_ir) @overload def min(self: Tensor) -> Tensor: """Min of all elements (scalar). Shape inference via meta-shape: torch.Tensor.min""" @@ -471,28 +554,34 @@ class Tensor[*Shape]: """Min along dimension. Returns (values, indices). Shape inference via meta-shape: torch.Tensor.min""" ... + @uses_shape_dsl(reduce_ir) def prod(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Product along dimension(s). Shape inference via meta-shape: torch.prod""" ... + @uses_shape_dsl(reduce_ir) def std(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Standard deviation along dimension(s). Shape inference via meta-shape: torch.std""" ... + @uses_shape_dsl(reduce_ir) def var(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Variance along dimension(s). Shape inference via meta-shape: torch.var""" ... + @uses_shape_dsl(reduce_ir) def argmax(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Argmax along dimension(s). Shape inference via meta-shape: torch.argmax""" ... + @uses_shape_dsl(reduce_ir) def argmin(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Argmin along dimension(s). Shape inference via meta-shape: torch.argmin""" ... # ==== Phase 1.2: Missing Reduction Operations (Methods) ==== + @uses_shape_dsl(min_max_median_ir) @overload def median(self: Tensor) -> Tensor: """Median of all elements (scalar). Shape inference via meta-shape: torch.Tensor.median""" @@ -503,22 +592,26 @@ class Tensor[*Shape]: """Median along dimension. Returns (values, indices). Shape inference via meta-shape: torch.Tensor.median""" ... + @uses_shape_dsl(reduce_ir) def logsumexp( self: Tensor, dim: int | None = None, keepdim: bool = False ) -> Tensor: """Log-sum-exp along dimension(s). Shape inference via meta-shape: torch.Tensor.logsumexp""" ... 
+ @uses_shape_dsl(reduce_ir) def count_nonzero(self: Tensor, dim: int | None = None) -> Tensor: """Count non-zero elements. Shape inference via meta-shape: torch.Tensor.count_nonzero""" ... + @uses_shape_dsl(aminmax_ir) def aminmax( self: Tensor, dim: int | None = None, keepdim: bool = False ) -> tuple[Tensor, Tensor]: """Min and max along dimension(s). Shape inference via meta-shape: torch.Tensor.aminmax""" ... + @uses_shape_dsl(reduce_ir) def norm( self: Tensor, p: int | float = 2, @@ -554,12 +647,14 @@ class Tensor[*Shape]: # ==== Tier 2: Additional Reduction Methods ==== + @uses_shape_dsl(tuple_reduce_ir) def mode( self: Tensor, dim: int = -1, keepdim: bool = False ) -> tuple[Tensor, Tensor]: """Mode along dimension. Returns (values, indices). Shape inference via meta-shape: torch.Tensor.mode""" ... + @uses_shape_dsl(topk_ir) def topk( self: Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True ) -> tuple[Tensor, Tensor]: @@ -575,6 +670,7 @@ class Tensor[*Shape]: """Sort tensor. Returns (values, indices). Shape-preserving operation.""" ... + @uses_shape_dsl(tuple_reduce_ir) def kthvalue( self: Tensor, k: int, dim: int = -1, keepdim: bool = False ) -> tuple[Tensor, Tensor]: @@ -583,6 +679,7 @@ class Tensor[*Shape]: # ==== Phase 1.3: Tensor Creation Operations (Methods) ==== + @uses_shape_dsl(diag_embed_ir) def diag_embed( self: Tensor, offset: int = 0, dim1: int = -2, dim2: int = -1 ) -> Tensor: @@ -599,6 +696,7 @@ class Tensor[*Shape]: # ==== Phase 1.4: Basic Linear Algebra Operations (Methods) ==== + @uses_shape_dsl(matmul_ir) def matmul(self: Tensor, other: Tensor) -> Tensor: """Matrix multiplication. Shape inference via meta-shape: torch.Tensor.matmul""" ... @@ -613,6 +711,7 @@ class Tensor[*Shape]: """Batch matrix multiplication (3D @ 3D). Output: [B, N, M].""" ... + @uses_shape_dsl(mv_ir) def mv(self: Tensor, vec: Tensor) -> Tensor: """Matrix-vector multiplication. Shape inference via meta-shape: torch.Tensor.mv""" ... 
@@ -976,6 +1075,7 @@ class Tensor[*Shape]: """Log determinant. Returns batch dimensions only (drops last 2 dims).""" ... + @uses_shape_dsl(slogdet_ir) def slogdet(self: Tensor) -> tuple[Tensor, Tensor]: """Sign and log determinant. Shape inference via meta-shape: torch.Tensor.slogdet""" ... @@ -1058,6 +1158,7 @@ class Tensor[*Shape]: """Take elements at indices. Output shape matches index shape.""" ... + @uses_shape_dsl(take_along_dim_ir) def take_along_dim(self: Tensor, indices: Tensor, dim: int) -> Tensor: """Take along dimension. Shape inference via meta-shape: torch.Tensor.take_along_dim""" ... @@ -1080,6 +1181,7 @@ class Tensor[*Shape]: """Sample from Bernoulli distribution in-place. Shape inference via generic fixture signature.""" ... + @uses_shape_dsl(multinomial_ir) def multinomial( self: Tensor, num_samples: int, replacement: bool = False ) -> Tensor: @@ -1098,14 +1200,17 @@ class Tensor[*Shape]: """Fill with uniform distribution in-place. Shape inference via generic fixture signature.""" ... + @uses_shape_dsl(numel_ir) def numel(self: Tensor) -> int: """Number of elements. Shape inference via meta-shape: torch.Tensor.numel""" ... + @uses_shape_dsl(dim_ir) def dim(self: Tensor) -> int: """Number of dimensions. Shape inference via meta-shape: torch.Tensor.dim""" ... + @uses_shape_dsl(numel_ir) def nelement(self: Tensor) -> int: """Number of elements. Shape inference via meta-shape: torch.Tensor.nelement""" ... @@ -1114,46 +1219,57 @@ class Tensor[*Shape]: # Module-level Functions # ============================================================================ +@uses_shape_dsl(matmul_ir) def matmul(self: Tensor, other: Tensor) -> Tensor: """Matrix multiplication function. Shape inference via meta-shape: torch.matmul""" ... +@uses_shape_dsl(cat_ir) def cat(tensors: list[Tensor] | tuple[Tensor, ...], dim: int = 0) -> Tensor: """Concatenate tensors. Shape inference via meta-shape: torch.cat""" ... 
+@uses_shape_dsl(stack_ir) def stack(tensors: list[Tensor] | tuple[Tensor, ...], dim: int = 0) -> Tensor: """Stack tensors (adds new dimension).""" ... +@uses_shape_dsl(transpose_ir) def transpose(self: Tensor, dim0: int, dim1: int) -> Tensor: """Transpose two dimensions. Shape inference via meta-shape: torch.transpose""" ... +@uses_shape_dsl(reshape_ir) def reshape(self: Tensor, shape: tuple[int, ...]) -> Tensor: """Reshape tensor. Shape inference via meta-shape: torch.reshape""" ... +@uses_shape_dsl(squeeze_ir) def squeeze(self: Tensor, dim: int | None = None) -> Tensor: """Remove dimensions of size 1. Shape inference via meta-shape: torch.squeeze""" ... +@uses_shape_dsl(unsqueeze_ir) def unsqueeze(self: Tensor, dim: int) -> Tensor: """Add dimension of size 1. Shape inference via meta-shape: torch.unsqueeze""" ... +@uses_shape_dsl(permute_ir) def permute(self: Tensor, dims: tuple[int, ...]) -> Tensor: """Permute dimensions. Shape inference via meta-shape: torch.permute""" ... +@uses_shape_dsl(reduce_ir) def sum(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Sum along dimension(s). Shape inference via meta-shape: torch.sum""" ... +@uses_shape_dsl(reduce_ir) def mean(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Mean along dimension(s). Shape inference via meta-shape: torch.mean""" ... +@uses_shape_dsl(min_max_median_ir) @overload def max(self: Tensor) -> Tensor: """Max of all elements (scalar). Shape inference via meta-shape: torch.max""" @@ -1164,6 +1280,7 @@ def max(self: Tensor, dim: int, keepdim: bool = False) -> tuple[Tensor, Tensor]: """Max along dimension. Returns (values, indices). Shape inference via meta-shape: torch.max""" ... +@uses_shape_dsl(min_max_median_ir) @overload def min(self: Tensor) -> Tensor: """Min of all elements (scalar). 
Shape inference via meta-shape: torch.min""" @@ -1179,36 +1296,44 @@ def min[*S](input: Tensor[*S], other: Tensor) -> Tensor[*S]: """Element-wise minimum of two tensors.""" ... +@uses_shape_dsl(reduce_ir) def prod(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Product along dimension(s). Shape inference via meta-shape: torch.prod""" ... +@uses_shape_dsl(reduce_ir) def std(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Standard deviation. Shape inference via meta-shape: torch.std""" ... +@uses_shape_dsl(reduce_ir) def var(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Variance. Shape inference via meta-shape: torch.var""" ... +@uses_shape_dsl(reduce_ir) def argmax(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Argmax. Shape inference via meta-shape: torch.argmax""" ... +@uses_shape_dsl(reduce_ir) def argmin(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Argmin. Shape inference via meta-shape: torch.argmin""" ... +@uses_shape_dsl(flatten_ir) def flatten(self: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor: """Flatten dimensions. Shape inference via meta-shape: torch.flatten""" ... +@uses_shape_dsl(stack_ir) def stack(tensors: list[Tensor] | tuple[Tensor, ...], dim: int = 0) -> Tensor: """Stack tensors. Shape inference via meta-shape: torch.stack""" ... # ==== Tensor Creation Functions ==== +@uses_shape_dsl(randn_ir) @overload def randn(*size: int, dtype: Any = None, device: Any = None) -> Tensor: """Create tensor with random values. Shape inference via meta-shape: torch.randn""" @@ -1219,6 +1344,7 @@ def randn(size: tuple[int, ...], dtype: Any = None, device: Any = None) -> Tenso """Create tensor with random values (tuple size). Shape inference via meta-shape: torch.randn""" ... 
+@uses_shape_dsl(randn_ir) @overload def rand(*size: int, dtype: Any = None, device: Any = None) -> Tensor: """Create tensor with random values [0, 1). Shape inference via meta-shape: torch.rand""" @@ -1229,6 +1355,7 @@ def rand(size: tuple[int, ...], dtype: Any = None, device: Any = None) -> Tensor """Create tensor with random values (tuple size). Shape inference via meta-shape: torch.rand""" ... +@uses_shape_dsl(randn_ir) @overload def zeros(*size: int, dtype: Any = None, device: Any = None) -> Tensor: """Create tensor filled with zeros. Shape inference via meta-shape: torch.zeros""" @@ -1239,6 +1366,7 @@ def zeros(size: tuple[int, ...], dtype: Any = None, device: Any = None) -> Tenso """Create tensor filled with zeros (tuple size). Shape inference via meta-shape: torch.zeros""" ... +@uses_shape_dsl(randn_ir) @overload def ones(*size: int, dtype: Any = None, device: Any = None) -> Tensor: """Create tensor filled with ones. Shape inference via meta-shape: torch.ones""" @@ -1249,6 +1377,7 @@ def ones(size: tuple[int, ...], dtype: Any = None, device: Any = None) -> Tensor """Create tensor filled with ones (tuple size). Shape inference via meta-shape: torch.ones""" ... +@uses_shape_dsl(randn_ir) @overload def empty(*size: int, dtype: Any = None, device: Any = None) -> Tensor: """Create uninitialized tensor. Shape inference via meta-shape: torch.empty""" @@ -1259,11 +1388,13 @@ def empty(size: tuple[int, ...], dtype: Any = None, device: Any = None) -> Tenso """Create uninitialized tensor (tuple size). Shape inference via meta-shape: torch.empty""" ... +@uses_shape_dsl(randn_ir) def full(size: tuple[int, ...], fill_value: float) -> Tensor: """Create tensor filled with value. Shape inference via meta-shape: torch.full""" ... # arange overloads - Dim is compatible with int, so meta-shape handles both +@uses_shape_dsl(arange_ir) @overload def arange(end: int) -> Tensor: """Create 1D tensor with range [0, end). 
Shape inference via meta-shape: torch.arange""" @@ -1291,44 +1422,53 @@ def arange( """Create 1D tensor with range [start, end) with step. Shape inference via meta-shape: torch.arange""" ... +@uses_shape_dsl(linspace_ir) def linspace( start: float, end: float, steps: int, *, dtype: Any = None, device: Any = None ) -> Tensor: """Create 1D tensor with linearly spaced values. Shape inference via meta-shape: torch.linspace""" ... +@uses_shape_dsl(eye_ir) def eye(n: int) -> Tensor: """Create 2D identity matrix. Shape inference via meta-shape: torch.eye""" ... # ==== Shape Manipulation Functions ==== +@uses_shape_dsl(broadcast_to_ir) def broadcast_to(self: Tensor, shape: tuple[int, ...]) -> Tensor: """Broadcast tensor to shape. Shape inference via meta-shape: torch.broadcast_to""" ... +@uses_shape_dsl(tile_ir) def tile(self: Tensor, dims: tuple[int, ...]) -> Tensor: """Tile tensor by repeating. Shape inference via meta-shape: torch.tile""" ... +@uses_shape_dsl(select_ir) def select(self: Tensor, dim: int, index: int) -> Tensor: """Select along dimension. Shape inference via meta-shape: torch.select""" ... +@uses_shape_dsl(narrow_ir) def narrow(self: Tensor, dim: int, start: int, length: int) -> Tensor: """Narrow tensor along dimension. Shape inference via meta-shape: torch.narrow""" ... +@uses_shape_dsl(split_ir) def split( self: Tensor, split_size_or_sections: int, dim: int = 0 ) -> tuple[Tensor, ...]: """Split tensor into chunks. Shape inference via meta-shape: torch.split""" ... +@uses_shape_dsl(chunk_ir) def chunk(self: Tensor, chunks: int, dim: int = 0) -> tuple[Tensor, ...]: """Split tensor into chunks. Shape inference via meta-shape: torch.chunk""" ... +@uses_shape_dsl(index_select_ir) def index_select(self: Tensor, dim: int, index: Tensor) -> Tensor: """Select elements along dimension. Shape inference via meta-shape: torch.index_select""" ... 
@@ -1351,10 +1491,12 @@ def masked_select(self: Tensor, mask: Tensor) -> Tensor[Any]: # ==== Phase 1.1: Missing Shape Operations ==== +@uses_shape_dsl(unbind_ir) def unbind(self: Tensor, dim: int = 0) -> tuple[Tensor, ...]: """Remove dimension by slicing along it. Shape inference via meta-shape: torch.unbind""" ... +@uses_shape_dsl(movedim_ir) @overload def movedim(self: Tensor, source: int, destination: int) -> Tensor: """Move single dimension to new position. Shape inference via meta-shape: torch.movedim""" @@ -1367,6 +1509,7 @@ def movedim( """Move multiple dimensions to new positions. Shape inference via meta-shape: torch.movedim""" ... +@uses_shape_dsl(movedim_ir) @overload def moveaxis(self: Tensor, source: int, destination: int) -> Tensor: """Alias for movedim. Shape inference via meta-shape: torch.moveaxis""" @@ -1379,22 +1522,26 @@ def moveaxis( """Alias for movedim. Shape inference via meta-shape: torch.moveaxis""" ... +@uses_shape_dsl(unfold_ir) def unfold(self: Tensor, dimension: int, size: int, step: int) -> Tensor: """Returns sliding window view. Shape inference via meta-shape: torch.unfold""" ... # ==== Additional Reduction Functions ==== +@uses_shape_dsl(reduce_ir) def all(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Check if all elements are True. Shape inference via meta-shape: torch.all""" ... +@uses_shape_dsl(reduce_ir) def any(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Check if any element is True. Shape inference via meta-shape: torch.any""" ... # ==== Phase 1.2: Missing Reduction Operations ==== +@uses_shape_dsl(min_max_median_ir) @overload def median(self: Tensor) -> Tensor: """Median of all elements (scalar). Shape inference via meta-shape: torch.median""" @@ -1405,20 +1552,24 @@ def median(self: Tensor, dim: int, keepdim: bool = False) -> tuple[Tensor, Tenso """Median along dimension. Returns (values, indices). Shape inference via meta-shape: torch.median""" ... 
+@uses_shape_dsl(reduce_ir) def logsumexp(self: Tensor, dim: int | None = None, keepdim: bool = False) -> Tensor: """Log-sum-exp along dimension(s). Shape inference via meta-shape: torch.logsumexp""" ... +@uses_shape_dsl(reduce_ir) def count_nonzero(self: Tensor, dim: int | None = None) -> Tensor: """Count non-zero elements. Shape inference via meta-shape: torch.count_nonzero""" ... +@uses_shape_dsl(aminmax_ir) def aminmax( self: Tensor, dim: int | None = None, keepdim: bool = False ) -> tuple[Tensor, Tensor]: """Min and max along dimension(s). Shape inference via meta-shape: torch.aminmax""" ... +@uses_shape_dsl(reduce_ir) def norm( self: Tensor, p: int | float = 2, @@ -1453,10 +1604,12 @@ def cummin[*Shape]( ... # Tier 2: Additional reduction operations (always return tuples) +@uses_shape_dsl(tuple_reduce_ir) def mode(self: Tensor, dim: int = -1, keepdim: bool = False) -> tuple[Tensor, Tensor]: """Mode along dimension. Returns (values, indices). Shape inference via meta-shape: torch.mode""" ... +@uses_shape_dsl(topk_ir) def topk( self: Tensor, k: int, dim: int = -1, largest: bool = True, sorted: bool = True ) -> tuple[Tensor, Tensor]: @@ -1469,6 +1622,7 @@ def sort[*Shape]( """Sort tensor. Returns (values, indices). Shape-preserving operation.""" ... +@uses_shape_dsl(tuple_reduce_ir) def kthvalue( self: Tensor, k: int, dim: int = -1, keepdim: bool = False ) -> tuple[Tensor, Tensor]: @@ -1476,6 +1630,7 @@ def kthvalue( ... # Tier 3: Statistical operations returning tuples +@uses_shape_dsl(aminmax_ir) def var_mean( self: Tensor, dim: int | tuple[int, ...] | None = None, @@ -1485,6 +1640,7 @@ def var_mean( """Variance and mean. Returns (var, mean). Shape inference via meta-shape: torch.var_mean""" ... +@uses_shape_dsl(aminmax_ir) def std_mean( self: Tensor, dim: int | tuple[int, ...] | None = None, @@ -1520,6 +1676,7 @@ def randn_like[*Shape](input: Tensor[*Shape]) -> Tensor[*Shape]: """Create random normal tensor with same shape. 
Shape inference via generic fixture signature.""" ... +@uses_shape_dsl(diag_embed_ir) def diag_embed(self: Tensor, offset: int = 0, dim1: int = -2, dim2: int = -1) -> Tensor: """Create diagonal tensor. Shape inference via meta-shape: torch.diag_embed""" ... @@ -1532,10 +1689,12 @@ def triu[*Shape](input: Tensor[*Shape], diagonal: int = 0) -> Tensor[*Shape]: """Upper triangular part. Shape inference via generic fixture signature.""" ... +@uses_shape_dsl(tri_indices_ir) def tril_indices(row: int, col: int, offset: int = 0) -> Tensor: """Indices of lower triangular part. Shape inference via meta-shape: torch.tril_indices""" ... +@uses_shape_dsl(tri_indices_ir) def triu_indices(row: int, col: int, offset: int = 0) -> Tensor: """Indices of upper triangular part. Shape inference via meta-shape: torch.triu_indices""" ... @@ -1553,6 +1712,7 @@ def bmm[B, N, K, M](input: Tensor[B, N, K], mat2: Tensor[B, K, M]) -> Tensor[B, """Batch matrix multiplication (3D @ 3D). Output: [B, N, M].""" ... +@uses_shape_dsl(mv_ir) def mv(self: Tensor, vec: Tensor) -> Tensor: """Matrix-vector multiplication (2D @ 1D). Shape inference via meta-shape: torch.mv""" ... @@ -1839,21 +1999,25 @@ def fmin[*Shape](input: Tensor[*Shape], other: Tensor) -> Tensor[*Shape]: # ============================================================================== # Advanced matmul operations +@uses_shape_dsl(tensordot_ir) def tensordot( self: Tensor, other: Tensor, dims: int | tuple[list[int], list[int]] = 2 ) -> Tensor: """Tensor contraction over specified dimensions. Shape inference via meta-shape: torch.tensordot""" ... +@uses_shape_dsl(einsum_ir) def einsum(spec: str, *operands: Tensor) -> Tensor: """Einstein summation convention. Shape inference via meta-shape: torch.einsum""" ... # Eigenvalue decomposition +@uses_shape_dsl(eig_ir) def eig(self: Tensor, eigenvectors: bool = False) -> tuple[Tensor, Tensor]: """Eigenvalue decomposition. Shape inference via meta-shape: torch.eig""" ... 
+@uses_shape_dsl(eig_ir) def eigh(self: Tensor, UPLO: str = "L") -> tuple[Tensor, Tensor]: """Hermitian eigenvalue decomposition. Shape inference via meta-shape: torch.eigh""" ... @@ -1864,18 +2028,22 @@ def cholesky[*Shape](input: Tensor[*Shape], upper: bool = False) -> Tensor[*Shap ... # Linear system solvers +@uses_shape_dsl(solve_ir) def solve(self: Tensor, other: Tensor) -> Tensor: """Solve linear system. Shape inference via meta-shape: torch.solve""" ... +@uses_shape_dsl(solve_reversed_ir) def triangular_solve(self: Tensor, other: Tensor, upper: bool = True) -> Tensor: """Solve triangular system. Shape inference via meta-shape: torch.triangular_solve""" ... +@uses_shape_dsl(solve_reversed_ir) def cholesky_solve(self: Tensor, other: Tensor, upper: bool = False) -> Tensor: """Solve using Cholesky. Shape inference via meta-shape: torch.cholesky_solve""" ... +@uses_shape_dsl(solve_ir) def lu_solve(self: Tensor, other: Tensor, LU_pivots: Tensor) -> Tensor: """Solve using LU decomposition. Shape inference via meta-shape: torch.lu_solve""" ... @@ -1894,6 +2062,7 @@ def logdet[*Batch, M, N](input: Tensor[*Batch, M, N]) -> Tensor[*Batch]: """Log determinant. Returns batch dimensions only (drops last 2 dims).""" ... +@uses_shape_dsl(slogdet_ir) def slogdet(self: Tensor) -> tuple[Tensor, Tensor]: """Sign and log determinant. Shape inference via meta-shape: torch.slogdet""" ... @@ -1924,6 +2093,7 @@ def matrix_rank[*Batch, M, N]( # ============================================================================== # Conditional operations +@uses_shape_dsl(where_ir) def where(condition: Tensor, x: Tensor, y: Tensor) -> Tensor: """Conditional element-wise selection. Shape inference via meta-shape: torch.where""" ... @@ -1973,6 +2143,7 @@ def take[*IndexShape](input: Tensor, index: Tensor[*IndexShape]) -> Tensor[*Inde """Take elements at indices. Output shape matches index shape.""" ... 
+@uses_shape_dsl(take_along_dim_ir) def take_along_dim(self: Tensor, indices: Tensor, dim: int) -> Tensor: """Take along dimension. Shape inference via meta-shape: torch.take_along_dim""" ... @@ -1992,10 +2163,12 @@ def bernoulli[*Shape](input: Tensor[*Shape], p: float = 0.5) -> Tensor[*Shape]: """Sample from Bernoulli distribution. Shape inference via generic fixture signature.""" ... +@uses_shape_dsl(multinomial_ir) def multinomial(self: Tensor, num_samples: int, replacement: bool = False) -> Tensor: """Sample from multinomial distribution. Shape inference via meta-shape: torch.multinomial""" ... +@uses_shape_dsl(normal_ir) @overload def normal(mean: Tensor, std: Tensor) -> Tensor: """Sample from normal distribution (tensor mean, tensor std). Shape inference via meta-shape: torch.normal""" @@ -2021,6 +2194,7 @@ def poisson[*Shape](input: Tensor[*Shape]) -> Tensor[*Shape]: ... # Tensor property functions +@uses_shape_dsl(numel_ir) def numel[*Dims](self: Tensor[*Dims]) -> int: """Number of elements. Shape inference via meta-shape: torch.numel""" ... @@ -2056,6 +2230,7 @@ def tensor( """Create tensor from data. Returns shapeless tensor (shape depends on input data).""" ... +@uses_shape_dsl(randint_ir) def randint( low: int, high: int, @@ -2076,6 +2251,7 @@ def rsqrt[*Shape](input: Tensor[*Shape]) -> Tensor[*Shape]: """Reciprocal square root (1/sqrt(x)). Shape-preserving element-wise operation.""" ... +@uses_shape_dsl(outer_ir) def outer(self: Tensor, vec2: Tensor) -> Tensor: """Outer product of two 1D tensors. Shape inference via meta-shape: torch.outer""" ... @@ -2142,6 +2318,7 @@ def cross[*B]( """Cross product of two tensors along a dimension of size 3.""" ... 
+@uses_shape_dsl(flatten_ir) def flatten( self: Tensor, start_dim: int = 0, diff --git a/test/tensor_shapes/fixtures/torch/_shapes.pyi b/test/tensor_shapes/fixtures/torch/_shapes.pyi new file mode 100644 index 0000000000..56d8b37f13 --- /dev/null +++ b/test/tensor_shapes/fixtures/torch/_shapes.pyi @@ -0,0 +1,805 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from shape_extensions.dsl import shape_dsl_function + +@shape_dsl_function +def normalize_dim(rank: int, dim: int) -> int: + if dim < 0: + return dim + rank + return dim + +@shape_dsl_function +def int_max(a: int, b: int) -> int: + if a > b: + return a + return b + +@shape_dsl_function +def replace_dim( + dims: list[int | symint], i: int, value: int | symint +) -> list[int | symint]: + return dims[:i] + [value] + dims[i + 1 :] + +@shape_dsl_function +def remove_dim(dims: list[int | symint], i: int) -> list[int | symint]: + return dims[:i] + dims[i + 1 :] + +@shape_dsl_function +def insert_dim( + dims: list[int | symint], i: int, value: int | symint +) -> list[int | symint]: + return dims[:i] + [value] + dims[i:] + +@shape_dsl_function +def broadcast(a: list[int | symint], b: list[int | symint]) -> list[int | symint]: + max_len = int_max(len(a), len(b)) + padded_a = [1 for _ in range(max_len - len(a))] + a + padded_b = [1 for _ in range(max_len - len(b))] + b + return [bd if ad == 1 else ad for ad, bd in zip(padded_a, padded_b)] + +@shape_dsl_function +def broadcast_int( + expr: int | symint | list[int | symint], n: int +) -> list[int | symint]: + if isinstance(expr, list): + return expr + return [expr for _ in range(n)] + +@shape_dsl_function +def reduce_shape( + dims: list[int | symint], dim: int | list[int] | None, keepdim: bool +) -> list[int | symint]: + if dim == None: + if keepdim: + return [1 for _ in range(len(dims))] + return [] + dim_list = dim if 
isinstance(dim, list) else [dim] + norm = [normalize_dim(len(dims), d) for d in dim_list] + return [ + 1 if i in norm else elem + for i, elem in enumerate(dims) + if not (i in norm) or keepdim + ] + +@shape_dsl_function +def contains(lst: list[int], val: int) -> bool: + return len([x for x in lst if x == val]) > 0 + +@shape_dsl_function +def scatter(size: int, indices: list[int], values: list[int], fill: int) -> list[int]: + matches = [[k for k in range(len(indices)) if indices[k] == i] for i in range(size)] + return [values[m[0]] if len(m) > 0 else fill for m in matches] + +@shape_dsl_function +def move_dims( + dims: list[int | symint], source: int | list[int], dest: int | list[int], rank: int +) -> list[int | symint]: + src = broadcast_int(source, 1) + dst = broadcast_int(dest, 1) + src_norm = [normalize_dim(rank, s) for s in src] + dst_norm = [normalize_dim(rank, d) for d in dst] + non_dst = [i for i in range(rank) if not contains(dst_norm, i)] + remaining = [i for i in range(rank) if not contains(src_norm, i)] + perm = scatter(rank, dst_norm + non_dst, src_norm + remaining, 0) + return [dims[p] for p in perm] + +@shape_dsl_function +def conv_spatial_out( + input_dim: int | symint, + kernel: int | symint, + stride: int | symint, + padding: int | symint, + dilation: int | symint, +) -> int | symint: + return (input_dim + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1 + +@shape_dsl_function +def reshape_ir(self: Tensor, shape: list[int | symint]) -> Tensor: + minus_one_count = len([d for d in shape if d == -1]) + if minus_one_count > 1: + raise Error("can only specify one unknown dimension as -1") + has_bad_neg = len([d for d in shape if isinstance(d, int) and d < -1]) > 0 + if has_bad_neg: + raise Error("invalid negative dimension value (only -1 is allowed)") + has_zero = len([d for d in shape if isinstance(d, int) and d == 0]) > 0 + if has_zero: + raise Error("reshape dimensions cannot contain 0") + if minus_one_count > 0: + known = 
shape_extensions.dsl.prod([d for d in shape if d != -1]) + total = shape_extensions.dsl.prod(self.shape) + if isinstance(total, int) and isinstance(known, int) and total % known != 0: + raise Error( + "could not infer size for dimension -1: expected " + + str(total) + + " to be divisible by " + + str(known) + ) + return Tensor(shape=[total // known if d == -1 else d for d in shape]) + return Tensor(shape=shape) + +@shape_dsl_function +def squeeze_ir(self: Tensor, dim: int | None = None) -> Tensor: + if dim == None: + return Tensor(shape=[d for d in self.shape if d != 1]) + idx = normalize_dim(len(self.shape), dim) + return Tensor( + shape=[d for i, d in enumerate(self.shape) if not (i == idx and d == 1)] + ) + +@shape_dsl_function +def unsqueeze_ir(self: Tensor, dim: int) -> Tensor: + d = normalize_dim(len(self.shape) + 1, dim) + return Tensor(shape=insert_dim(self.shape, d, 1)) + +@shape_dsl_function +def transpose_ir(self: Tensor, dim0: int, dim1: int) -> Tensor: + rank = len(self.shape) + d0 = normalize_dim(rank, dim0) + d1 = normalize_dim(rank, dim1) + return Tensor( + shape=[ + self.shape[d1] if i == d0 else self.shape[d0] if i == d1 else d + for i, d in enumerate(self.shape) + ] + ) + +@shape_dsl_function +def permute_ir(self: Tensor, dims: list[int]) -> Tensor: + rank = len(self.shape) + if len(dims) != rank: + raise Error("permute: expected " + str(rank) + " dims, got " + str(len(dims))) + return Tensor(shape=[self.shape[normalize_dim(rank, d)] for d in dims]) + +@shape_dsl_function +def flatten_ir(self: Tensor, start_dim: int = 0, end_dim: int = -1) -> Tensor: + rank = len(self.shape) + s = normalize_dim(rank, start_dim) + e = normalize_dim(rank, end_dim) + return Tensor( + shape=self.shape[:s] + + [shape_extensions.dsl.prod(self.shape[s : e + 1])] + + self.shape[e + 1 :] + ) + +@shape_dsl_function +def expand_ir(self: Tensor, sizes: list[int | symint]) -> Tensor: + return Tensor(shape=sizes[: len(sizes) - len(self.shape)] + [d if t == -1 else t for d, t in zip(self.shape, sizes[len(sizes) - len(self.shape) :])]) # sizes align to the trailing dims of self; surplus leading sizes become new leading dims, -1 keeps the existing dim +
+@shape_dsl_function +def repeat_ir(self: Tensor, sizes: list[int | symint]) -> Tensor: + return Tensor(shape=sizes[: len(sizes) - len(self.shape)] + [d * r for d, r in zip(self.shape, sizes[len(sizes) - len(self.shape) :])]) # repeats align to the trailing dims; surplus leading repeats become new leading dims (same rule tile_ir applies) + +@shape_dsl_function +def unbind_ir(self: Tensor, dim: int = 0) -> list[Tensor]: + d = normalize_dim(len(self.shape), dim) + return [Tensor(shape=remove_dim(self.shape, d)), ...] + +@shape_dsl_function +def movedim_ir( + self: Tensor, source: int | list[int], destination: int | list[int] +) -> Tensor: + return Tensor(shape=move_dims(self.shape, source, destination, len(self.shape))) + +@shape_dsl_function +def unfold_ir( + self: Tensor, dimension: int, size: int | symint, step: int = 1 +) -> Tensor: + d = normalize_dim(len(self.shape), dimension) + new_dim = (self.shape[d] - size) // step + 1 + return Tensor(shape=replace_dim(self.shape, d, new_dim) + [size]) + +@shape_dsl_function +def cat_ir(tensors: list[Tensor], dim: int = 0) -> Tensor: + first = tensors[0] + d = normalize_dim(len(first.shape), dim) + return Tensor( + shape=[ + shape_extensions.dsl.sum([t.shape[i] for t in tensors]) + if i == d + else dim_val + for i, dim_val in enumerate(first.shape) + ] + ) + +@shape_dsl_function +def stack_ir(tensors: list[Tensor], dim: int = 0) -> Tensor: + first = tensors[0] + d = normalize_dim(len(first.shape) + 1, dim) + return Tensor(shape=insert_dim(first.shape, d, len(tensors))) + +@shape_dsl_function +def broadcast_to_ir(self: Tensor, shape: list[int | symint]) -> Tensor: + return Tensor(shape=shape) + +@shape_dsl_function +def tile_ir(self: Tensor, dims: list[int]) -> Tensor: + rank = len(self.shape) + if len(dims) > rank: + extra = len(dims) - rank + return Tensor( + shape=[r for r in dims[:extra]] + + [d * r for d, r in zip(self.shape, dims[extra:])] + ) + return Tensor(shape=self.shape[: rank - len(dims)] + [d * r for d, r in zip(self.shape[rank - len(dims) :], dims)]) # fewer dims than rank: leading dims pass through unchanged, dims scale the trailing ones + +@shape_dsl_function +def select_ir(self: Tensor, dim: int) -> Tensor: + d = normalize_dim(len(self.shape), dim) + return Tensor(shape=remove_dim(self.shape, d)) + +@shape_dsl_function +def 
narrow_ir(self: Tensor, dim: int, length: int | symint) -> Tensor: + return Tensor( + shape=replace_dim(self.shape, normalize_dim(len(self.shape), dim), length) + ) + +@shape_dsl_function +def split_ir( + self: Tensor, + split_size_or_sections: int | symint | list[int | symint] | None = None, + dim: int = 0, +) -> list[Tensor]: + d = normalize_dim(len(self.shape), dim) + if isinstance(split_size_or_sections, list): + return [ + Tensor(shape=replace_dim(self.shape, d, section)) + for section in split_size_or_sections + ] + if isinstance(split_size_or_sections, int): + dim_val = self.shape[d] + if isinstance(dim_val, int): + count = (dim_val + split_size_or_sections - 1) // split_size_or_sections + return [ + Tensor( + shape=replace_dim( + self.shape, + d, + split_size_or_sections + if i < count - 1 + else dim_val - (count - 1) * split_size_or_sections, + ) + ) + for i in range(count) + ] + return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)), ...] + if split_size_or_sections != None: + quotient = self.shape[d] // split_size_or_sections + if isinstance(quotient, int): + return [ + Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)) + for _ in range(quotient) + ] + return [Tensor(shape=replace_dim(self.shape, d, split_size_or_sections)), ...] 
+ return Unknown + +@shape_dsl_function +def chunk_ir(self: Tensor, chunks: int, dim: int = 0) -> list[Tensor]: + d = normalize_dim(len(self.shape), dim) + dim_val = self.shape[d] + if isinstance(dim_val, int): + chunk_size = (dim_val + chunks - 1) // chunks # ceil(dim_val / chunks); fewer than `chunks` pieces may then cover the whole dim + return [ + Tensor( + shape=replace_dim( + self.shape, + d, + chunk_size + if (i + 1) * chunk_size <= dim_val + else dim_val - i * chunk_size, + ) + ) + for i in range(chunks) if i * chunk_size < dim_val or dim_val == 0 # skip pieces starting past the end; a zero-sized dim still yields `chunks` empty pieces + ] + return [ + Tensor(shape=replace_dim(self.shape, d, dim_val // chunks)) + for i in range(chunks) + ] + +@shape_dsl_function +def index_select_ir(self: Tensor, dim: int, index: Tensor) -> Tensor: + return Tensor( + shape=replace_dim( + self.shape, normalize_dim(len(self.shape), dim), index.shape[0] + ) + ) + +@shape_dsl_function +def reduce_ir( + self: Tensor, dim: int | list[int] | None = None, keepdim: bool = False +) -> Tensor: + if dim == None: + return Tensor(shape=reduce_shape(self.shape, dim, keepdim)) + if isinstance(dim, list): + return Tensor(shape=reduce_shape(self.shape, dim, keepdim)) + return Tensor(shape=reduce_single(self.shape, dim, keepdim)) + +@shape_dsl_function +def reduce_single( + dims: list[int | symint], dim: int, keepdim: bool +) -> list[int | symint]: + before = dims[:dim] + if dim == -1: + if keepdim: + return before + [1] + return before + after = dims[dim + 1 :] + if keepdim: + return before + [1] + after + return before + after + +@shape_dsl_function +def min_max_median_ir( + self: Tensor, dim: int | None = None, keepdim: bool = False +) -> Tensor: + if dim == None: + return Tensor(shape=[]) + s = reduce_shape(self.shape, dim, keepdim) + return [Tensor(shape=s), Tensor(shape=s)] + +@shape_dsl_function +def aminmax_ir( + self: Tensor, dim: int | list[int] | None = None, keepdim: bool = False +) -> [Tensor, Tensor]: + s = reduce_shape(self.shape, dim, keepdim) + return [Tensor(shape=s), Tensor(shape=s)] + +@shape_dsl_function +def tuple_reduce_ir( + self: Tensor, dim: int = -1, keepdim: bool = False +) -> 
[Tensor, Tensor]: + s = reduce_shape(self.shape, dim, keepdim) + return [Tensor(shape=s), Tensor(shape=s)] + +@shape_dsl_function +def topk_ir(self: Tensor, k: int | symint, dim: int = -1) -> [Tensor, Tensor]: + s = replace_dim(self.shape, normalize_dim(len(self.shape), dim), k) + return [Tensor(shape=s), Tensor(shape=s)] + +@shape_dsl_function +def repeat_interleave_ir( + self: Tensor, repeats: int | symint, dim: int | None = None +) -> Tensor: + if dim == None: + return Tensor(shape=[shape_extensions.dsl.prod(self.shape) * repeats]) + d = normalize_dim(len(self.shape), dim) + return Tensor(shape=replace_dim(self.shape, d, self.shape[d] * repeats)) + +@shape_dsl_function +def cosine_similarity_ir(x1: Tensor, x2: Tensor, dim: int = 1) -> Tensor: + s = broadcast(x1.shape, x2.shape) + return Tensor(shape=reduce_single(s, normalize_dim(len(s), dim), False)) + +@shape_dsl_function +def randn_ir(size: list[int | symint]) -> Tensor: + return Tensor(shape=size) + +@shape_dsl_function +def randint_ir(low: int, high: int, size: list[int | symint]) -> Tensor: + return Tensor(shape=size) + +@shape_dsl_function +def linspace_ir(steps: int | symint) -> Tensor: + return Tensor(shape=[steps]) + +@shape_dsl_function +def eye_ir(n: int | symint, m: int | symint | None = None) -> Tensor: + if m == None: + return Tensor(shape=[n, n]) + return Tensor(shape=[n, m]) + +@shape_dsl_function +def arange_ir( + start: int | symint | None = None, + end: int | symint | None = None, + step: int | symint | None = None, +) -> Tensor: + if start != None and end != None and step != None: + return Tensor(shape=[0 - (start - end) // step]) # ceil((end - start) / step) via negated floor division; plain floor undercounts when step does not divide the span + if start != None and end != None: + return Tensor(shape=[end - start]) + if end != None: + return Tensor(shape=[end]) + if start != None: + return Tensor(shape=[start]) + return Unknown + +@shape_dsl_function +def normal_ir( + mean: Tensor | None = None, std: Tensor | None = None, size: list[int] | None = None +) -> Tensor: + if size != None: + return Tensor(shape=[s 
for s in size]) + if mean != None: + return Tensor(shape=mean.shape) + if std != None: + return Tensor(shape=std.shape) + return Unknown + +@shape_dsl_function +def diag_embed_ir(self: Tensor, offset: int = 0) -> Tensor: + new_dim = self.shape[-1] + (offset if offset >= 0 else -offset) + return Tensor(shape=self.shape[:-1] + [new_dim, new_dim]) + +@shape_dsl_function +def tri_indices_ir(row: int | symint, col: int | symint, offset: int = 0) -> Tensor: + return Tensor(shape=[2, 0]) + +@shape_dsl_function +def matmul_ir(self: Tensor, other: Tensor) -> Tensor: + r1 = len(self.shape) + r2 = len(other.shape) + if r1 == 1 and r2 == 1: + return Tensor(shape=[]) + if r1 == 1 and r2 == 2: + return Tensor(shape=[other.shape[1]]) + if r1 == 2 and r2 == 1: + return Tensor(shape=[self.shape[0]]) + if r1 == 2 and r2 == 2: + return Tensor(shape=[self.shape[0], other.shape[1]]) + if r1 == 2 and r2 >= 3: + return Tensor(shape=other.shape[:-2] + [self.shape[0]] + [other.shape[-1]]) + if r1 >= 3 and r2 == 2: + return Tensor(shape=self.shape[:-2] + [self.shape[-2]] + [other.shape[1]]) + if r1 >= 3 and r2 >= 3: + return Tensor( + shape=broadcast(self.shape[:-2], other.shape[:-2]) + + [self.shape[-2]] + + [other.shape[-1]] + ) + return Unknown + +@shape_dsl_function +def mv_ir(self: Tensor, vec: Tensor) -> Tensor: + if len(self.shape) != 2: + raise Error("mv expects 2D matrix, got " + str(len(self.shape)) + "D tensor") + if len(vec.shape) != 1: + raise Error("mv expects 1D vector, got " + str(len(vec.shape)) + "D tensor") + return Tensor(shape=[self.shape[0]]) + +@shape_dsl_function +def outer_ir(self: Tensor, vec2: Tensor) -> Tensor: + if len(self.shape) != 1 or len(vec2.shape) != 1: + raise Error( + "outer expects 1D tensors, got " + + str(len(self.shape)) + + "D and " + + str(len(vec2.shape)) + + "D" + ) + return Tensor(shape=[self.shape[0], vec2.shape[0]]) + +@shape_dsl_function +def tensordot_ir(self: Tensor, other: Tensor, dims: int) -> Tensor: + return Tensor(shape=self.shape[: 
len(self.shape) - dims] + other.shape[dims:]) + +@shape_dsl_function +def apply_einsum( + output_map: list[list[int]], check_pairs: list[list[int]], inputs: list[Tensor] +) -> Tensor: + bad_dims = [ + 1 + for i0, d0, i1, d1 in check_pairs + if isinstance(inputs[i0].shape[d0], int) + and isinstance(inputs[i1].shape[d1], int) + and inputs[i0].shape[d0] != inputs[i1].shape[d1] + ] + if len(bad_dims) > 0: + raise Error("einsum: inconsistent dimensions for repeated index") + return Tensor(shape=[inputs[inp].shape[dim] for inp, dim in output_map]) + +@shape_dsl_function +def einsum_ir(spec: str, operands: list[Tensor] | None = None) -> Tensor: + if operands != None: + output_map, check_pairs = shape_extensions.dsl.parse_einsum_equation(spec) + return apply_einsum(output_map, check_pairs, operands) + return Unknown + +@shape_dsl_function +def eigvals_ir(self: Tensor) -> Tensor: + if len(self.shape) < 2: + raise Error( + "eigvals requires at least 2D input, got " + + str(len(self.shape)) + + "D tensor" + ) + return Tensor(shape=self.shape[:-2] + [self.shape[-2]]) + +@shape_dsl_function +def eig_ir(self: Tensor) -> [Tensor, Tensor]: + if len(self.shape) < 2: + raise Error( + "eig requires at least 2D input, got " + str(len(self.shape)) + "D tensor" + ) + batch = self.shape[:-2] + return [ + Tensor(shape=batch + [self.shape[-2]]), + Tensor(shape=batch + self.shape[-2:]), + ] + +@shape_dsl_function +def slogdet_ir(self: Tensor) -> [Tensor, Tensor]: + if len(self.shape) < 2: + raise Error( + "slogdet requires at least 2D input, got " + + str(len(self.shape)) + + "D tensor" + ) + return [Tensor(shape=self.shape[:-2]), Tensor(shape=self.shape[:-2])] + +@shape_dsl_function +def solve_ir(self: Tensor, other: Tensor) -> Tensor: + return Tensor(shape=other.shape) + +@shape_dsl_function +def solve_reversed_ir(self: Tensor, other: Tensor) -> Tensor: + return Tensor(shape=self.shape) + +@shape_dsl_function +def conv_ir( + self: Tensor, + weight: Tensor, + stride: int | list[int] = 1, + 
padding: int | list[int] = 0, + dilation: int | list[int] = 1, +) -> Tensor: + spatial_dims = len(self.shape) - 2 + stride_list = broadcast_int(stride, spatial_dims) + padding_list = broadcast_int(padding, spatial_dims) + dilation_list = broadcast_int(dilation, spatial_dims) + return Tensor( + shape=[self.shape[0], weight.shape[0]] + + [ + conv_spatial_out(s, k, st, p, dil) + for s, k, st, p, dil in zip( + self.shape[2:], + weight.shape[2:], + stride_list, + padding_list, + dilation_list, + ) + ] + ) + +@shape_dsl_function +def conv_transpose_ir( + self: Tensor, + weight: Tensor, + stride: int | list[int] = 1, + padding: int | list[int] = 0, + output_padding: int | list[int] = 0, + dilation: int | list[int] = 1, +) -> Tensor: + spatial_dims = len(self.shape) - 2 + stride_list = broadcast_int(stride, spatial_dims) + padding_list = broadcast_int(padding, spatial_dims) + outpad_list = broadcast_int(output_padding, spatial_dims) + dilation_list = broadcast_int(dilation, spatial_dims) + return Tensor( + shape=[self.shape[0], weight.shape[1]] + + [ + (s - 1) * st - 2 * p + dil * (k - 1) + op + 1 + for s, k, st, p, op, dil in zip( + self.shape[2:], + weight.shape[2:], + stride_list, + padding_list, + outpad_list, + dilation_list, + ) + ] + ) + +@shape_dsl_function +def pool_ir( + self: Tensor, + kernel_size: int | list[int], + stride: int | list[int] | None = None, + padding: int | list[int] = 0, + dilation: int | list[int] = 1, + return_indices: bool = False, +) -> Tensor: + spatial_dims = len(self.shape) - 2 + ks_list = broadcast_int(kernel_size, spatial_dims) + stride_list = ks_list if stride == None else broadcast_int(stride, spatial_dims) + padding_list = broadcast_int(padding, spatial_dims) + dilation_list = broadcast_int(dilation, spatial_dims) + out = [self.shape[0], self.shape[1]] + [ + conv_spatial_out(s, k, st, p, dil) + for s, k, st, p, dil in zip( + self.shape[2:], ks_list, stride_list, padding_list, dilation_list + ) + ] + if return_indices: + return 
[Tensor(shape=out), Tensor(shape=out)] + return Tensor(shape=out) + +@shape_dsl_function +def adaptive_pool_ir( + self: Tensor, output_size: int | symint | list[int | symint] +) -> Tensor: + out_sizes = broadcast_int(output_size, len(self.shape) - 2) + return Tensor(shape=[self.shape[0], self.shape[1]] + out_sizes) + +@shape_dsl_function +def interpolate_ir( + self: Tensor, + size: int | symint | list[int | symint] | None = None, + scale_factor: int | symint | None = None, +) -> Tensor: + if size != None: + return Tensor( + shape=[self.shape[0], self.shape[1]] + + broadcast_int(size, len(self.shape) - 2) + ) + if scale_factor != None: + return Tensor( + shape=[self.shape[0], self.shape[1]] + + [d * scale_factor for d in self.shape[2:]] + ) + raise Error("interpolate requires either 'size' or 'scale_factor' argument") + +@shape_dsl_function +def loss_ir(self: Tensor, reduction: str = "mean") -> Tensor: + if reduction == "none": + return Tensor(shape=self.shape) + return Tensor(shape=[]) + +@shape_dsl_function +def pad_ir(self: Tensor, pad: list[int]) -> Tensor: + rank = len(self.shape) + num_pad_dims = len(pad) // 2 + offsets = [ + pad[(rank - 1 - i) * 2] + pad[(rank - 1 - i) * 2 + 1] + if i >= rank - num_pad_dims + else 0 + for i in range(rank) + ] + return Tensor(shape=[d + offsets[i] for i, d in enumerate(self.shape)]) + +@shape_dsl_function +def rfft_ir(self: Tensor, n: int | symint | None = None, dim: int = -1) -> Tensor: + d = normalize_dim(len(self.shape), dim) + if n != None: + return Tensor(shape=replace_dim(self.shape, d, n // 2 + 1)) + return Tensor(shape=replace_dim(self.shape, d, self.shape[d] // 2 + 1)) + +@shape_dsl_function +def irfft_ir(self: Tensor, n: int | symint | None = None, dim: int = -1) -> Tensor: + d = normalize_dim(len(self.shape), dim) + if n != None: + return Tensor(shape=replace_dim(self.shape, d, n)) + return Tensor(shape=replace_dim(self.shape, d, 2 * (self.shape[d] - 1))) + +@shape_dsl_function +def size_ir(self: Tensor, dim: int | 
None = None) -> int | symint: + if dim != None: + return self.shape[normalize_dim(len(self.shape), dim)] + return [d for d in self.shape] + +@shape_dsl_function +def numel_ir(self: Tensor) -> int | symint: + return shape_extensions.dsl.prod(self.shape) + +@shape_dsl_function +def dim_ir(self: Tensor) -> int: + return len(self.shape) + +@shape_dsl_function +def item_ir(self: Tensor) -> Tensor: + if len(self.shape) != 0: + raise Error( + "item() only works on 0-dimensional tensors, got " + + str(len(self.shape)) + + "D tensor" + ) + return Unknown + +@shape_dsl_function +def tolist_ir(self: Tensor) -> Tensor: + return Unknown + +@shape_dsl_function +def multinomial_ir(self: Tensor, num_samples: int | symint) -> Tensor: + return Tensor(shape=self.shape[:-1] + [num_samples]) + +@shape_dsl_function +def where_ir(condition: Tensor, x: Tensor, y: Tensor) -> Tensor: + return Tensor(shape=x.shape) + +@shape_dsl_function +def take_along_dim_ir(self: Tensor, indices: Tensor) -> Tensor: + return Tensor(shape=indices.shape) + +@shape_dsl_function +def nn_flatten_forward_ir( + input: Tensor, start_dim: symint = 1, end_dim: symint = -1 +) -> Tensor: + return flatten_ir(input, start_dim, end_dim) + +@shape_dsl_function +def nn_maxpool_forward_ir( + input: Tensor, + kernel_size: symint = 1, + stride: symint | None = None, + padding: symint = 0, + dilation: symint = 1, +) -> Tensor: + return pool_ir(input, kernel_size, stride, padding, dilation) + +@shape_dsl_function +def nn_avgpool_forward_ir( + input: Tensor, + kernel_size: symint = 1, + stride: symint | None = None, + padding: symint = 0, +) -> Tensor: + return pool_ir(input, kernel_size, stride, padding, 1) + +@shape_dsl_function +def nn_upsample_forward_ir( + input: Tensor, size: symint | None = None, scale_factor: symint | None = None +) -> Tensor: + return interpolate_ir(input, size, scale_factor) + +@shape_dsl_function +def nn_pixel_shuffle_forward_ir(input: Tensor, upscale_factor: symint) -> Tensor: + r = upscale_factor + 
return Tensor( + shape=[input.shape[0], input.shape[1] // (r * r)] + + [d * r for d in input.shape[2:]] + ) + +@shape_dsl_function +def nn_glu_forward_ir(input: Tensor, dim: symint = 1) -> Tensor: + rank = len(input.shape) + d = normalize_dim(rank, dim) + return Tensor(shape=replace_dim(input.shape, d, input.shape[d] // 2)) + +@shape_dsl_function +def nn_lstm_forward_ir( + input: Tensor, + input_size: symint, + hidden_size: symint, + num_layers: symint = 1, + bidirectional: bool = False, +) -> [Tensor, Tensor, Tensor]: + nd = 2 if bidirectional else 1 + output = Tensor(shape=[input.shape[0], input.shape[1], hidden_size * nd]) + h_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) + c_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) + return [output, h_n, c_n] + +@shape_dsl_function +def nn_gru_forward_ir( + input: Tensor, + input_size: symint, + hidden_size: symint, + num_layers: symint = 1, + bidirectional: bool = False, +) -> [Tensor, Tensor]: + nd = 2 if bidirectional else 1 + output = Tensor(shape=[input.shape[0], input.shape[1], hidden_size * nd]) + h_n = Tensor(shape=[num_layers * nd, input.shape[0], hidden_size]) + return [output, h_n] + +@shape_dsl_function +def nn_lstmcell_forward_ir( + input: Tensor, input_size: symint, hidden_size: symint +) -> [Tensor, Tensor]: + h = Tensor(shape=[input.shape[0], hidden_size]) + c = Tensor(shape=[input.shape[0], hidden_size]) + return [h, c] + +@shape_dsl_function +def nn_reflectionpad2d_forward_ir(input: Tensor, padding: symint) -> Tensor: + return Tensor( + shape=[ + input.shape[0], + input.shape[1], + input.shape[2] + 2 * padding, + input.shape[3] + 2 * padding, + ] + ) diff --git a/test/tensor_shapes/fixtures/torch/fft.pyi b/test/tensor_shapes/fixtures/torch/fft.pyi index d03502ab22..64720f95ca 100644 --- a/test/tensor_shapes/fixtures/torch/fft.pyi +++ b/test/tensor_shapes/fixtures/torch/fft.pyi @@ -4,7 +4,9 @@ # LICENSE file in the root directory of this source tree. 
# Type stubs for torch.fft module (Phase 6: FFT Operations) +from shape_extensions import uses_shape_dsl from torch import Tensor +from torch._shapes import irfft_ir, rfft_ir # 1D FFT operations def fft[*Shape]( @@ -13,9 +15,13 @@ def fft[*Shape]( def ifft[*Shape]( input: Tensor[*Shape], n: int = None, dim: int = -1, norm: str = None ) -> Tensor[*Shape]: ... +@uses_shape_dsl(rfft_ir) def rfft(self: Tensor, n: int = None, dim: int = -1, norm: str = None) -> Tensor: ... +@uses_shape_dsl(irfft_ir) def irfft(self: Tensor, n: int = None, dim: int = -1, norm: str = None) -> Tensor: ... +@uses_shape_dsl(irfft_ir) def hfft(self: Tensor, n: int = None, dim: int = -1, norm: str = None) -> Tensor: ... +@uses_shape_dsl(rfft_ir) def ihfft(self: Tensor, n: int = None, dim: int = -1, norm: str = None) -> Tensor: ... # 2D FFT operations diff --git a/test/tensor_shapes/fixtures/torch/linalg.pyi b/test/tensor_shapes/fixtures/torch/linalg.pyi index 2154a55a84..73d9cdec8f 100644 --- a/test/tensor_shapes/fixtures/torch/linalg.pyi +++ b/test/tensor_shapes/fixtures/torch/linalg.pyi @@ -4,22 +4,31 @@ # LICENSE file in the root directory of this source tree. # Type stubs for torch.linalg module (Phase 4: Advanced Linear Algebra) +from shape_extensions import uses_shape_dsl from torch import Tensor +from torch._shapes import eig_ir, eigvals_ir, slogdet_ir, solve_ir, solve_reversed_ir # Eigenvalue decomposition +@uses_shape_dsl(eig_ir) def eig(self: Tensor) -> tuple[Tensor, Tensor]: ... +@uses_shape_dsl(eig_ir) def eigh(self: Tensor, UPLO: str = "L") -> tuple[Tensor, Tensor]: ... # Tier 3: Eigenvalues only (no eigenvectors) +@uses_shape_dsl(eigvals_ir) def eigvals(self: Tensor) -> Tensor: ... +@uses_shape_dsl(eigvals_ir) def eigvalsh(self: Tensor, UPLO: str = "L") -> Tensor: ... # Cholesky decomposition def cholesky[*Shape](input: Tensor[*Shape], upper: bool = False) -> Tensor[*Shape]: ... 
# Linear system solvers +@uses_shape_dsl(solve_ir) def solve(self: Tensor, other: Tensor) -> Tensor: ... +@uses_shape_dsl(solve_ir) def solve_triangular(self: Tensor, other: Tensor, upper: bool = False) -> Tensor: ... +@uses_shape_dsl(solve_reversed_ir) def cholesky_solve(self: Tensor, other: Tensor, upper: bool = False) -> Tensor: ... # Matrix inverse @@ -29,6 +38,7 @@ def inv[*Shape](input: Tensor[*Shape]) -> Tensor[*Shape]: ... def det[*Batch, M, N](input: Tensor[*Batch, M, N]) -> Tensor[*Batch]: ... # Sign and log determinant +@uses_shape_dsl(slogdet_ir) def slogdet(self: Tensor) -> tuple[Tensor, Tensor]: ... # Matrix power diff --git a/test/tensor_shapes/fixtures/torch/nn/__init__.pyi b/test/tensor_shapes/fixtures/torch/nn/__init__.pyi index a8fee4f2c1..84236a48db 100644 --- a/test/tensor_shapes/fixtures/torch/nn/__init__.pyi +++ b/test/tensor_shapes/fixtures/torch/nn/__init__.pyi @@ -21,8 +21,20 @@ from typing import ( ) if TYPE_CHECKING: - from shape_extensions import Dim as _Dim + from shape_extensions import Dim as _Dim, uses_shape_dsl from torch import Tensor + from torch._shapes import ( + nn_avgpool_forward_ir, + nn_flatten_forward_ir, + nn_glu_forward_ir, + nn_gru_forward_ir, + nn_lstm_forward_ir, + nn_lstmcell_forward_ir, + nn_maxpool_forward_ir, + nn_pixel_shuffle_forward_ir, + nn_reflectionpad2d_forward_ir, + nn_upsample_forward_ir, + ) # Re-export submodules from . import functional as functional, init as init @@ -675,6 +687,10 @@ class MaxPool1d(Module): return_indices: bool = False, ceil_mode: bool = False, ) -> None: ... + @uses_shape_dsl( + nn_maxpool_forward_ir, + capture_init=["kernel_size", "stride", "padding", "dilation"], + ) def forward(self, input: Tensor) -> Tensor: ... class MaxPool2d(Module): @@ -688,6 +704,10 @@ class MaxPool2d(Module): return_indices: bool = False, ceil_mode: bool = False, ) -> None: ... 
+ @uses_shape_dsl( + nn_maxpool_forward_ir, + capture_init=["kernel_size", "stride", "padding", "dilation"], + ) def forward(self, input: Tensor) -> Tensor: ... class MaxPool3d(Module): @@ -701,6 +721,10 @@ class MaxPool3d(Module): return_indices: bool = False, ceil_mode: bool = False, ) -> None: ... + @uses_shape_dsl( + nn_maxpool_forward_ir, + capture_init=["kernel_size", "stride", "padding", "dilation"], + ) def forward(self, input: Tensor) -> Tensor: ... class AvgPool1d(Module): @@ -713,6 +737,9 @@ class AvgPool1d(Module): ceil_mode: bool = False, count_include_pad: bool = True, ) -> None: ... + @uses_shape_dsl( + nn_avgpool_forward_ir, capture_init=["kernel_size", "stride", "padding"] + ) def forward(self, input: Tensor) -> Tensor: ... class AvgPool2d(Module): @@ -726,6 +753,9 @@ class AvgPool2d(Module): count_include_pad: bool = True, divisor_override: int | None = None, ) -> None: ... + @uses_shape_dsl( + nn_avgpool_forward_ir, capture_init=["kernel_size", "stride", "padding"] + ) def forward(self, input: Tensor) -> Tensor: ... class AvgPool3d(Module): @@ -739,6 +769,9 @@ class AvgPool3d(Module): count_include_pad: bool = True, divisor_override: int | None = None, ) -> None: ... + @uses_shape_dsl( + nn_avgpool_forward_ir, capture_init=["kernel_size", "stride", "padding"] + ) def forward(self, input: Tensor) -> Tensor: ... class AdaptiveAvgPool1d[OL](Module): @@ -794,6 +827,7 @@ class PixelShuffle(Module): """ def __init__(self, upscale_factor: int) -> None: ... + @uses_shape_dsl(nn_pixel_shuffle_forward_ir, capture_init=["upscale_factor"]) def forward(self, input: Tensor) -> Tensor: ... class GLU(Module): @@ -806,6 +840,7 @@ class GLU(Module): """ def __init__(self, dim: int = 1) -> None: ... + @uses_shape_dsl(nn_glu_forward_ir, capture_init=["dim"]) def forward(self, input: Tensor) -> Tensor: ... class LSTM(Module): @@ -834,6 +869,10 @@ class LSTM(Module): def flatten_parameters(self) -> None: """Reset parameter data pointer for CUDA contiguous memory. 
No-op on CPU.""" ... + @uses_shape_dsl( + nn_lstm_forward_ir, + capture_init=["input_size", "hidden_size", "num_layers", "bidirectional"], + ) def forward(self, input: Tensor) -> tuple[Tensor, Tensor, Tensor]: ... class LSTMCell(Module): @@ -853,6 +892,7 @@ class LSTMCell(Module): device: Any = None, dtype: Any = None, ) -> None: ... + @uses_shape_dsl(nn_lstmcell_forward_ir, capture_init=["input_size", "hidden_size"]) def forward( self, input: Tensor, hx: tuple[Tensor, Tensor] | None = None ) -> tuple[Tensor, Tensor]: ... @@ -882,6 +922,10 @@ class GRU(Module): def flatten_parameters(self) -> None: """Reset parameter data pointer for CUDA contiguous memory. No-op on CPU.""" ... + @uses_shape_dsl( + nn_gru_forward_ir, + capture_init=["input_size", "hidden_size", "num_layers", "bidirectional"], + ) def forward( self, input: Tensor, hx: Tensor | None = None ) -> tuple[Tensor, Tensor]: ... @@ -920,6 +964,7 @@ class Upsample(Module): mode: str = "nearest", align_corners: bool | None = None, ) -> None: ... + @uses_shape_dsl(nn_upsample_forward_ir, capture_init=["size", "scale_factor"]) def forward(self, input: Tensor) -> Tensor: ... # ============================================================================== @@ -1073,6 +1118,7 @@ class Flatten(Module): """ def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None: ... + @uses_shape_dsl(nn_flatten_forward_ir, capture_init=["start_dim", "end_dim"]) def forward(self, input: Tensor) -> Tensor: ... class Unflatten(Module): @@ -1087,6 +1133,7 @@ class ReflectionPad2d(Module): """ def __init__(self, padding: int) -> None: ... + @uses_shape_dsl(nn_reflectionpad2d_forward_ir, capture_init=["padding"]) def forward(self, input: Tensor) -> Tensor: ... class ReplicationPad2d(Module): @@ -1096,6 +1143,7 @@ class ReplicationPad2d(Module): """ def __init__(self, padding: int) -> None: ... + @uses_shape_dsl(nn_reflectionpad2d_forward_ir, capture_init=["padding"]) def forward(self, input: Tensor) -> Tensor: ... 
# Embedding variants diff --git a/test/tensor_shapes/fixtures/torch/nn/functional.pyi b/test/tensor_shapes/fixtures/torch/nn/functional.pyi index 9198e9b808..04a648262b 100644 --- a/test/tensor_shapes/fixtures/torch/nn/functional.pyi +++ b/test/tensor_shapes/fixtures/torch/nn/functional.pyi @@ -10,6 +10,18 @@ Functional neural network operations including convolution, pooling, activation, from typing import Literal, overload +from shape_extensions import uses_shape_dsl +from torch._shapes import ( + adaptive_pool_ir, + conv_ir, + conv_transpose_ir, + cosine_similarity_ir, + interpolate_ir, + loss_ir, + pad_ir, + pool_ir, +) + from .. import Tensor __all__ = [ @@ -93,6 +105,7 @@ __all__ = [ # ==================================================================== # Convolution operations +@uses_shape_dsl(conv_ir) def conv1d( self: Tensor, weight: Tensor, @@ -105,6 +118,7 @@ def conv1d( """1D convolution. Shape inference via meta-shape: torch.nn.functional.conv1d""" ... +@uses_shape_dsl(conv_ir) def conv2d( self: Tensor, weight: Tensor, @@ -117,6 +131,7 @@ def conv2d( """2D convolution. Shape inference via meta-shape: torch.nn.functional.conv2d""" ... +@uses_shape_dsl(conv_ir) def conv3d( self: Tensor, weight: Tensor, @@ -130,6 +145,7 @@ def conv3d( ... # Transposed convolution operations +@uses_shape_dsl(conv_transpose_ir) def conv_transpose1d( self: Tensor, weight: Tensor, @@ -143,6 +159,7 @@ def conv_transpose1d( """1D transposed convolution. Shape inference via meta-shape: torch.nn.functional.conv_transpose1d""" ... +@uses_shape_dsl(conv_transpose_ir) def conv_transpose2d( self: Tensor, weight: Tensor, @@ -156,6 +173,7 @@ def conv_transpose2d( """2D transposed convolution. Shape inference via meta-shape: torch.nn.functional.conv_transpose2d""" ... +@uses_shape_dsl(conv_transpose_ir) def conv_transpose3d( self: Tensor, weight: Tensor, @@ -170,6 +188,7 @@ def conv_transpose3d( ... 
# Max pooling operations +@uses_shape_dsl(pool_ir) @overload def max_pool1d( self: Tensor, @@ -196,6 +215,7 @@ def max_pool1d( """1D max pooling with indices. Shape inference via meta-shape: torch.nn.functional.max_pool1d""" ... +@uses_shape_dsl(pool_ir) @overload def max_pool2d( self: Tensor, @@ -222,6 +242,7 @@ def max_pool2d( """2D max pooling with indices. Shape inference via meta-shape: torch.nn.functional.max_pool2d""" ... +@uses_shape_dsl(pool_ir) @overload def max_pool3d( self: Tensor, @@ -249,6 +270,7 @@ def max_pool3d( ... # Average pooling operations +@uses_shape_dsl(pool_ir) def avg_pool1d( self: Tensor, kernel_size: int | tuple[int], @@ -260,6 +282,7 @@ def avg_pool1d( """1D average pooling. Shape inference via meta-shape: torch.nn.functional.avg_pool1d""" ... +@uses_shape_dsl(pool_ir) def avg_pool2d( self: Tensor, kernel_size: int | tuple[int, int], @@ -272,6 +295,7 @@ def avg_pool2d( """2D average pooling. Shape inference via meta-shape: torch.nn.functional.avg_pool2d""" ... +@uses_shape_dsl(pool_ir) def avg_pool3d( self: Tensor, kernel_size: int | tuple[int, int, int], @@ -285,12 +309,14 @@ def avg_pool3d( ... # Adaptive max pooling operations +@uses_shape_dsl(adaptive_pool_ir) def adaptive_max_pool1d( self: Tensor, output_size: int | tuple[int], return_indices: bool = False ) -> Tensor: """1D adaptive max pooling. Shape inference via meta-shape: torch.nn.functional.adaptive_max_pool1d""" ... +@uses_shape_dsl(adaptive_pool_ir) def adaptive_max_pool2d( self: Tensor, output_size: int | tuple[int, int] | None, @@ -299,6 +325,7 @@ def adaptive_max_pool2d( """2D adaptive max pooling. Shape inference via meta-shape: torch.nn.functional.adaptive_max_pool2d""" ... +@uses_shape_dsl(adaptive_pool_ir) def adaptive_max_pool3d( self: Tensor, output_size: int | tuple[int, int, int] | None, @@ -308,16 +335,19 @@ def adaptive_max_pool3d( ... 
# Adaptive average pooling operations +@uses_shape_dsl(adaptive_pool_ir) def adaptive_avg_pool1d(self: Tensor, output_size: int | tuple[int]) -> Tensor: """1D adaptive average pooling. Shape inference via meta-shape: torch.nn.functional.adaptive_avg_pool1d""" ... +@uses_shape_dsl(adaptive_pool_ir) def adaptive_avg_pool2d( self: Tensor, output_size: int | tuple[int, int] | None ) -> Tensor: """2D adaptive average pooling. Shape inference via meta-shape: torch.nn.functional.adaptive_avg_pool2d""" ... +@uses_shape_dsl(adaptive_pool_ir) def adaptive_avg_pool3d( self: Tensor, output_size: int | tuple[int, int, int] | None ) -> Tensor: @@ -325,6 +355,7 @@ def adaptive_avg_pool3d( ... # Interpolation/upsampling operations +@uses_shape_dsl(interpolate_ir) def interpolate( self: Tensor, size: int | tuple[int, ...] | None = None, @@ -337,6 +368,7 @@ def interpolate( """Interpolate/upsample tensor. Shape inference via meta-shape: torch.nn.functional.interpolate""" ... +@uses_shape_dsl(interpolate_ir) def upsample( self: Tensor, size: int | tuple[int, ...] | None = None, @@ -552,6 +584,7 @@ def logsigmoid[*Shape](input: Tensor[*Shape]) -> Tensor[*Shape]: # Phase 6: Loss Functions # ============================================================================== +@uses_shape_dsl(loss_ir) def mse_loss( self: Tensor, target: Tensor, @@ -562,6 +595,7 @@ def mse_loss( """Mean squared error loss. Shape inference via meta-shape: torch.nn.functional.mse_loss""" ... +@uses_shape_dsl(loss_ir) def l1_loss( self: Tensor, target: Tensor, @@ -572,6 +606,7 @@ def l1_loss( """L1 loss. Shape inference via meta-shape: torch.nn.functional.l1_loss""" ... +@uses_shape_dsl(loss_ir) def nll_loss( self: Tensor, target: Tensor, @@ -584,6 +619,7 @@ def nll_loss( """Negative log likelihood loss. Shape inference via meta-shape: torch.nn.functional.nll_loss""" ... +@uses_shape_dsl(loss_ir) def cross_entropy( self: Tensor, target: Tensor, @@ -597,6 +633,7 @@ def cross_entropy( """Cross entropy loss. 
Shape inference via meta-shape: torch.nn.functional.cross_entropy""" ... +@uses_shape_dsl(loss_ir) def binary_cross_entropy( self: Tensor, target: Tensor, @@ -608,6 +645,7 @@ def binary_cross_entropy( """Binary cross entropy loss. Shape inference via meta-shape: torch.nn.functional.binary_cross_entropy""" ... +@uses_shape_dsl(loss_ir) def binary_cross_entropy_with_logits( self: Tensor, target: Tensor, @@ -620,6 +658,7 @@ def binary_cross_entropy_with_logits( """Binary cross entropy with logits. Shape inference via meta-shape: torch.nn.functional.binary_cross_entropy_with_logits""" ... +@uses_shape_dsl(loss_ir) def kl_div( self: Tensor, target: Tensor, @@ -631,6 +670,7 @@ def kl_div( """KL divergence loss. Shape inference via meta-shape: torch.nn.functional.kl_div""" ... +@uses_shape_dsl(loss_ir) def smooth_l1_loss( self: Tensor, target: Tensor, @@ -642,12 +682,14 @@ def smooth_l1_loss( """Smooth L1 loss. Shape inference via meta-shape: torch.nn.functional.smooth_l1_loss""" ... +@uses_shape_dsl(loss_ir) def huber_loss( self: Tensor, target: Tensor, reduction: str = "mean", delta: float = 1.0 ) -> Tensor: """Huber loss. Shape inference via meta-shape: torch.nn.functional.huber_loss""" ... +@uses_shape_dsl(loss_ir) def poisson_nll_loss( self: Tensor, target: Tensor, @@ -661,6 +703,7 @@ def poisson_nll_loss( """Poisson NLL loss. Shape inference via meta-shape: torch.nn.functional.poisson_nll_loss""" ... +@uses_shape_dsl(loss_ir) def cosine_embedding_loss( self: Tensor, input2: Tensor, @@ -673,6 +716,7 @@ def cosine_embedding_loss( """Cosine embedding loss. Shape inference via meta-shape: torch.nn.functional.cosine_embedding_loss""" ... +@uses_shape_dsl(loss_ir) def margin_ranking_loss( self: Tensor, input2: Tensor, @@ -685,6 +729,7 @@ def margin_ranking_loss( """Margin ranking loss. Shape inference via meta-shape: torch.nn.functional.margin_ranking_loss""" ... 
+@uses_shape_dsl(loss_ir) def triplet_margin_loss( self: Tensor, positive: Tensor, @@ -700,6 +745,7 @@ def triplet_margin_loss( """Triplet margin loss. Shape inference via meta-shape: torch.nn.functional.triplet_margin_loss""" ... +@uses_shape_dsl(loss_ir) def hinge_embedding_loss( self: Tensor, target: Tensor, @@ -712,6 +758,7 @@ def hinge_embedding_loss( ... # Padding operation +@uses_shape_dsl(pad_ir) def pad( self: Tensor, pad: tuple[int, ...], mode: str = "constant", value: float = 0.0 ) -> Tensor: @@ -829,6 +876,7 @@ def scaled_dot_product_attention[ """Scaled dot product attention. Shape inference via meta-shape: torch.nn.functional.scaled_dot_product_attention""" ... +@uses_shape_dsl(cosine_similarity_ir) def cosine_similarity( x1: Tensor, x2: Tensor, dim: int = 1, eps: float = 1e-8 ) -> Tensor: