From 74fa8d10ee95f4b0dc71c0db48f48507d711b9b0 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Tue, 10 Mar 2026 16:56:24 +0000 Subject: [PATCH 01/41] plans --- .../plan-simplifyKernelGen-phase1.prompt.md | 213 ++++++++++++++++++ .../plan-simplifyKernelGen-phase2.prompt.md | 121 ++++++++++ .../plan-simplifyKernelGen-phase3.prompt.md | 48 ++++ .../prompts/plan-simplifyKernelGen.prompt.md | 90 ++++++++ slangpy/benchmarks/test_benchmark_autograd.py | 6 +- slangpy/core/calldata.py | 3 + .../tests/slangpy_tests/test_kernel_gen.py | 53 +++++ 7 files changed, 531 insertions(+), 3 deletions(-) create mode 100644 .github/prompts/plan-simplifyKernelGen-phase1.prompt.md create mode 100644 .github/prompts/plan-simplifyKernelGen-phase2.prompt.md create mode 100644 .github/prompts/plan-simplifyKernelGen-phase3.prompt.md create mode 100644 .github/prompts/plan-simplifyKernelGen.prompt.md create mode 100644 slangpy/tests/slangpy_tests/test_kernel_gen.py diff --git a/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md new file mode 100644 index 000000000..bfe58cc28 --- /dev/null +++ b/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md @@ -0,0 +1,213 @@ +## Phase 1: Direct Type Marshalling + +**Goal**: For dim-0, non-composite arguments, emit the raw Slang type in CallData and use direct assignment in the trampoline — eliminating `ValueType` wrappers, `__slangpy_load`/`__slangpy_store` indirection, mapping constants, and `Context.map()` calls. + +**Parent plan**: [plan-simplifyKernelGen.prompt.md](plan-simplifyKernelGen.prompt.md) + +--- + +### Step 1.1: Define eligibility predicate + +Add a **global Python function** `is_direct_bind_eligible(binding: BoundVariable) -> bool` (e.g., in [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) or a small utility module). This is intentionally NOT a method on `NativeMarshall` — placing it on the C++ side would require nanobind trampoline plumbing for a function that is only consumed during Python-side codegen. A simple Python function avoids all C++/nanobind complexity. + +The conditions for leaf types are: +- `binding.call_dimensionality is not None and binding.call_dimensionality == 0` (note: `call_dimensionality` is initialized to `None`, so a `None` check is required) +- `not binding.children` (not composite/dict) +- The marshall's Python type is one of the known direct-eligible types (`ValueMarshall`, `ScalarMarshall`, `VectorMarshall`, `MatrixMarshall`, `ArrayMarshall`, `ValueRefMarshall`). Types like `WangHashArgMarshall` are excluded. + +Individual marshalls call this function inside their `gen_calldata`, `gen_trampoline_load`, `gen_trampoline_store`, and `create_calldata` methods to decide which codegen path to take. + +Add a companion function `is_direct_bind_recursive(binding: BoundVariable) -> bool` that handles composite types: +- If `binding.children is None`: delegates to `is_direct_bind_eligible(binding)` +- If `binding.children is not None`: returns `True` only if `binding.call_dimensionality is not None and binding.call_dimensionality == 0` AND every child's `is_direct_bind_recursive()` returns `True`. This handles dicts bound to Slang structs where all fields are dim-0 leaves (or recursively dim-0 structs). +- Additionally, the `binding.vector_type` must be a concrete Slang struct type (not `UnknownType`). Dicts without `_type` may resolve to `UnknownType` and are ineligible. + +Both functions are consulted by the marshalls and by `gen_call_data_code` (for struct/dict bindings). + +--- + +### Step 1.2: Implement for `ValueMarshall` (scalars/matrices) + +In [slangpy/builtin/value.py](slangpy/builtin/value.py): + +- Modify `gen_calldata`: call `is_direct_bind_eligible(binding)`. When eligible, emit `typealias _t_{name} = {raw_slang_type};` instead of `ValueType<{type}>` +- Add `gen_trampoline_load`: when direct-eligible, emit `{name} = {data_name};` and return `True` +- Add `gen_trampoline_store`: when direct-eligible (read-only scalars), return `True` (suppress default store, no-op) +- Modify `create_calldata`: when direct-eligible, return raw value instead of `{"value": data}`. The cursor write system in [cursor_utils.h](src/slangpy_ext/device/cursor_utils.h) already handles writing scalars/vectors/matrices directly — the `write_internal` method dispatches on `TypeReflection::Kind::scalar/vector/matrix`. + +#### Step 1.2a: Critical C++ change — `NativeValueMarshall` fast path + +`NativeValueMarshall::write_shader_cursor_pre_dispatch` in [slangpyvalue.cpp](src/slangpy_ext/utils/slangpyvalue.cpp) has a cached fast path that navigates `cursor[variable_name]["value"]` on first call: + +```cpp +ShaderCursor field = cursor[binding->variable_name()]["value"]; +m_cached.value_offset = field.offset(); +``` + +If the Slang type changes from `ValueType` (which has a `value` sub-field) to raw `int` (a scalar with no sub-fields), the `["value"]` navigation will crash. This affects **all** `NativeValueMarshall` subclasses: `ValueMarshall`, `VectorMarshall`, `MatrixMarshall`, `StructMarshall`, `ArrayMarshall`. + +**Required fix**: Add a `direct_bind` flag to `NativeValueMarshall` (set from the Python side when `can_direct_bind` returns `True`). In `ensure_cached`, branch on this flag: +- **`direct_bind == false`** (current path): navigate `cursor[variable_name]["value"]` +- **`direct_bind == true`** (new path): navigate `cursor[variable_name]` only (no `"value"` sub-field), and cache the resulting offset/layout/writer directly + +The flag can be set during `CallData` construction when the binding is finalized, or passed via a `NativeBoundVariableRuntime` property. Alternatively, detect the absence of the `"value"` field by checking the type layout — but an explicit flag is safer and clearer. + +--- + +### Step 1.3: Implement for `VectorMarshall`, `MatrixMarshall`, and `ArrayMarshall` + +In [slangpy/builtin/value.py](slangpy/builtin/value.py): +- `VectorMarshall`: same pattern as `ValueMarshall`. `gen_calldata` emits `typealias _t_{name} = {vector_type};` instead of `VectorValueType<{et},{n}>`. +- `MatrixMarshall`: same pattern. Note that `MatrixMarshall` does **not** override `gen_calldata` — it inherits `ValueMarshall.gen_calldata` which emits `ValueType<{matrix_type}>` (not `MatrixValueType`). The `MatrixValueType<...>` name only appears in `resolve_types` for the experimental vectorization path. The direct-bind override goes on `gen_calldata` and emits the raw matrix type (e.g., `float4x4`) instead of `ValueType`. + +In [slangpy/builtin/array.py](slangpy/builtin/array.py): +- `ArrayMarshall`: at dim-0, it already falls through to `super().gen_calldata()` (i.e. `ValueMarshall`) which uses `ValueType`. The same direct-bind pattern applies — emit the raw array type instead of wrapping in `ValueType`. + +--- + +### Step 1.4: Implement for `StructMarshall` (dict → struct) + +In [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py): + +When a Python dict is bound to a Slang struct and `is_direct_bind_recursive(binding)` returns `True` (all children are dim-0 and direct-eligible recursively), the **Slang-side** struct can bypass the inline `__slangpy_load`/`__slangpy_store` struct generation. The **Python/C++ side** keeps the existing tree of marshalls unchanged — they continue to recurse through children and cache offsets for efficient cursor writes. + +This is a Slang-code-gen-only simplification: +- **Current behavior (children path in `gen_call_data_code`)**: generates an inline struct `_t_{name}` with field declarations, `__slangpy_load`/`__slangpy_store` methods, and mapping constants — delegates each child's code gen recursively +- **Direct-eligible behavior**: emit `typealias _t_{name} = {vector_type.full_name};` (the raw Slang struct type). Skip generating the inline struct, its load/store methods, child type aliases, and child mapping constants entirely. + +In the trampoline: +- `gen_trampoline_load`: emit `{name} = {data_name};` (direct struct assignment) and return `True` +- `gen_trampoline_store`: return `True` (suppress default store for read-only structs) + +**Python/C++ dispatch — keep the child tree, fix the per-child cursor path**: + +The Python-side tree of marshalls is kept for dispatch. When a `BoundVariable` has children (dict case), the C++ dispatch in [slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) `NativeBoundVariableRuntime::write_shader_cursor_pre_dispatch` still takes the children branch: + +```cpp +ShaderCursor child_field = cursor[m_variable_name.c_str()]; +for (const auto& [name, child_ref] : *m_children) { + child_ref->write_shader_cursor_pre_dispatch(context, child_field, child_value, read_back); +} +``` + +Each child leaf calls `NativeValueMarshall::write_shader_cursor_pre_dispatch`, which navigates `cursor[variable_name]["value"]`. If the Slang struct type changes from the inline struct (where each child field is `ValueType` with a `value` sub-field) to the raw Slang struct (where each child field is `float` directly), the `["value"]` navigation will crash. + +**Solution**: Set the `direct_bind` flag from Step 1.2a on each child's `NativeValueMarshall`. The per-child flag causes each leaf's `ensure_cached` to navigate `cursor[variable_name]` only (no `["value"]` sub-field). This is the same fix as Step 1.2a applied to each child — no changes to the children dispatch path itself are needed. + +`StructMarshall.create_calldata` is dead code for the children path: when `m_children` is set on the C++ `NativeBoundVariableRuntime`, the children dispatch branch runs instead of calling the marshall's `create_calldata`. The current `ValueMarshall.create_calldata` (which `StructMarshall` inherits) is never called for dict bindings. It can be removed from `StructMarshall` if desired, but is harmless. + +**Complexity considerations:** +- The recursive eligibility check must traverse all children. Nested dicts (struct-of-struct) work if all leaves are direct-eligible. +- The `vector_type` on the `BoundVariable` must be a concrete Slang struct type (not `UnknownType`). If the dict has `_type` specified, the struct type is resolved; if not, it may be `UnknownType` and ineligible. +- Writable struct fields (inout/out parameters) need the same treatment as writable scalars — the struct in CallData stays as the raw type, but the trampoline does direct assignment both ways. +- This optimization can be deferred if it proves too complex for the initial Phase 1 implementation — the fallback (current inline struct with load/store) always works. Priority should be leaf types first. + +--- + +### Step 1.5: Implement for `ValueRefMarshall` + +In [slangpy/builtin/valueref.py](slangpy/builtin/valueref.py): + +Note: There is only **one** `ValueRefMarshall` class (not separate `ValueRef`/`RWValueRef` classes). It inherits from `Marshall` (not `NativeValueMarshall`). Read vs. write behavior is determined by `binding.access` at codegen time — the same class emits `ValueRef` or `RWValueRef` depending on access mode. + +- In `gen_calldata`, call `is_direct_bind_eligible(binding)`. When eligible: +- Read-only path (`access[0] == AccessType.read`): `gen_calldata` emits raw type, `gen_trampoline_load` does direct assignment, `create_calldata` returns raw value +- Writable path (`access[0] != AccessType.read`): `gen_calldata` emits `RWStructuredBuffer<{type}>`, `gen_trampoline_load` emits `{name} = {data_name}[0];`, `gen_trampoline_store` emits `{data_name}[0] = {name};`, `create_calldata` returns the buffer directly (no `{"value": buffer}` wrapper). Note: `RWStructuredBuffer` is a **resource type** in Slang — the cursor write system handles it via the resource binding mechanism. Buffer objects are written to resource-typed cursor fields via the `write_value` virtual path in [cursor_utils.h](src/slangpy_ext/device/cursor_utils.h), not the struct/scalar dispatch in `write_internal`. + +Since `ValueRefMarshall` extends `Marshall` (not `NativeValueMarshall`), the C++ fast path issue from Step 1.2a does not apply — `NativeMarshall::write_shader_cursor_pre_dispatch` calls `create_calldata` and then passes the result to the generic `write_shader_cursor(cursor, cd_val)`, which dispatches based on the Slang type layout. There is no cached `["value"]` navigation. + +--- + +### Step 1.6: Implement for tensor marshalls + +In [slangpy/builtin/tensorcommon.py](slangpy/builtin/tensorcommon.py): The `TensorView`/`DiffTensorView` case already works via direct assignment. + +For `Tensor` (the slangpy Tensor type): this is a **complex struct** containing `_data` (a `StructuredBuffer`/pointer), `_shape[D]`, `_strides[D]`, and `_offset`. It is NOT a simple assignable type like a scalar — it always requires its buffer resource descriptor and metadata to be bound. However, since it is already a well-defined Slang struct, it can still use direct assignment (`name = call_data.name;`) in the trampoline when dim-0. The `gen_calldata` already emits the correct tensor type name. Add `gen_trampoline_load` to handle `ITensorType` dim-0 with direct assignment (same as TensorView pattern — the struct is copied as a whole). + +Note: `Tensor` cannot be simplified to a raw value the way scalars can — it stays as a struct in CallData. The simplification here is only at the trampoline level (bypassing `__slangpy_load`/`__slangpy_store`). + +--- + +### Step 1.7: Eliminate unused boilerplate in code generation + +In [slangpy/core/callsignature.py](slangpy/core/callsignature.py): + +- **Mapping constants**: In `BoundVariable.gen_call_data_code()` ([slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py)), skip emitting `static const int _m_{name} = 0;` when `is_direct_bind_eligible(self)` (for leaves) or `is_direct_bind_recursive(self)` (for composites) returns `True`. These constants are only consumed by `__slangpy_context__.map(_m_{name})` calls, which direct-bound variables skip. +- **`import "slangpy"`**: Keep this import. Attempting to detect and eliminate it provides negligible benefit for significant complexity. The slangpy Slang module is always available and the link-time constants are always emitted. The focus of this phase is eliminating wrapper types and `__slangpy_load`/`__slangpy_store` indirection, not the import. + +--- + +### Step 1.8: Handle autodiff (bwds mode) + +For differentiable types in bwds mode: +- Primal reads are still direct-eligible (just a direct assignment) +- Derivative writes need writable backing — use `RWStructuredBuffer` for derivative fields (similar to `RWValueRef` pattern) +- The trampoline must remain `[Differentiable]`, but direct assignment `a = call_data.a;` is trivially differentiable in Slang +- The `gen_trampoline_load`/`gen_trampoline_store` implementations need to account for `access[1]` (derivative access) and emit derivative load/store code when needed +- This is the most complex part of Phase 1; consider implementing prim-mode direct binding first, then extending to bwds + +--- + +### Step 1.9: Tests + +Extend [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py). All tests use `generate_code()` which calls `func.debug_build_call_data(*args, **kwargs)` and returns `cd.code`. Tests are parametrized across `helpers.DEFAULT_DEVICE_TYPES`. + +**Assertion helpers** (added to test file): +- `assert_contains(code, *patterns)` — assert all patterns appear in generated code +- `assert_not_contains(code, *patterns)` — assert none appear + +**Gating tests** — assert CURRENT behavior so they break when each step is implemented: + +| Test | Slang Source | Args | Asserts (current behavior) | Breaks when | +|------|-------------|------|---------------------------|-------------| +| `test_gate_scalar_uses_valuetype` | `int add(int a, int b) { return a + b; }` | `(1, 2)` | `ValueType` present, `__slangpy_load` for `a`/`b`, `__slangpy_store` for `_result` | Step 1.2 | +| `test_gate_float_scalar_uses_valuetype` | `float mul(float x, float y) { return x * y; }` | `(1.0, 2.0)` | `ValueType` present, `__slangpy_load` present | Step 1.2 | +| `test_gate_vector_uses_vectorvaluetype` | `float3 scale(float3 v, float s) { return v * s; }` | `(spy.math.float3(1,2,3), 1.0)` | `VectorValueType` for `v` (no space after comma), `ValueType` for `s` | Step 1.3 | +| `test_gate_matrix_uses_valuetype` | `float4x4 ident(float4x4 m) { return m; }` | `(spy.math.float4x4.identity(),)` | `ValueType` present | Step 1.3 | +| `test_gate_valueref_read_uses_wrapper` | `float read_val(float v) { return v; }` | `(spy.ValueRef(1.0),)` | `ValueRef` present, `__slangpy_load` present | Step 1.5 | +| `test_gate_valueref_write_uses_wrapper` | `int add(int a, int b) { return a + b; }` | `(1, 2)` (auto `_result`) | `RWValueRef` for `_result`, `__slangpy_store` present | Step 1.5 | +| `test_gate_array_dim0_uses_valuetype` | `void process(float a[4]) { }` | `([1.0, 2.0, 3.0, 4.0],)` | `ValueType<` present for array binding | Step 1.3 | +| `test_gate_mapping_constants_present` | `int add(int a, int b) { return a + b; }` | `(1, 2)` | `static const int _m_a = 0` and `_m_b` and `_m__result` present | Step 1.7 | +| `test_gate_context_map_in_trampoline` | `int add(int a, int b) { return a + b; }` | `(1, 2)` | `__slangpy_context__.map(_m_a)` in trampoline | Step 1.7 | +| `test_gate_struct_uses_slangpy_load` | `struct S { float x; float y; }; float sum(S s) { return s.x + s.y; }` | `({"x": 1.0, "y": 2.0},)` | inline struct `_t_s` with `__slangpy_load`, child mapping constants `_m_x`, `_m_y` | Step 1.4 | + +**Negative gates** — should REMAIN passing after Phase 1 (these types are NOT direct-bind eligible): + +| Test | Slang Source | Args | Asserts (must stay) | +|------|-------------|------|--------------------| +| `test_gate_wanghasharg_uses_wrapper` | `int rng(WangHashArg rng) { return 0; }` | `(spy.WangHashArg(1),)` | `WangHashArg<` in type alias, `__slangpy_load` present | +| `test_gate_vectorized_scalar_keeps_wrapper` | `float square(float x) { return x * x; }` | `(Tensor.numpy(np.array([1,2,3], dtype=np.float32)),)` | `ValueType` present (dim > 0, not direct-eligible) | +| `test_gate_vectorized_dict_keeps_struct_load` | `struct S { float x; float y; }; void apply(S s, float scale) {}` | `({"x": Tensor(...), "y": Tensor(...)}, 1.0)` | inline struct with `__slangpy_load` (children are vectorized, dim > 0) | + +**Autodiff gating tests:** + +| Test | Slang Source | Args | Asserts | +|------|-------------|------|---------| +| `test_gate_bwds_scalar_uses_valuetype` | `[Differentiable] float square(float x) { return x * x; }` | `func.bwds.debug_build_call_data(diffPair(2.0), diffPair(d=1.0))` | `ValueType` present, `[Differentiable]` on trampoline, `bwd_diff(_trampoline)` in kernel | +| `test_gate_bwds_trampoline_is_differentiable` | same as above | same | `[Differentiable]` appears before `void _trampoline` | + +**Post-implementation tests** — should pass AFTER Phase 1 is complete: + +- `test_phase1_scalar_direct_bind`: verify NO `ValueType` or `__slangpy_load` for scalar args +- `test_phase1_vector_direct_bind`: verify NO `VectorValueType` for vector args +- `test_phase1_valueref_direct_bind`: verify `RWStructuredBuffer` appears directly for writable result +- `test_phase1_struct_direct_bind`: verify NO inline struct with `__slangpy_load` for dim-0 dict-to-struct +- `test_phase1_no_mapping_constants`: verify NO `_m_a`, `_m_b` for direct-bound args +- `test_phase1_functional_scalar_add`: dispatch `add(1, 2)` and verify result == 3 +- `test_phase1_functional_vector_scale`: dispatch vector scale and verify result +- `test_phase1_functional_struct_sum`: dispatch struct sum via dict and verify result + +--- + +### Implementation Order Within Phase 1 + +To avoid C++ crashes from the `NativeValueMarshall` fast path (Step 1.2a), the implementation order within Phase 1 must be: + +1. **Step 1.2a first**: Update `NativeValueMarshall::ensure_cached` in C++ to handle direct-bind types (no `"value"` sub-field navigation). This is the only C++ change needed — `is_direct_bind_eligible` is pure Python, no nanobind changes required. +2. **Step 1.1**: Add `is_direct_bind_eligible` and `is_direct_bind_recursive` as global Python functions +3. **Steps 1.2–1.7**: Implement Python-side changes for each marshall type +4. **Step 1.4**: For struct children, set the `direct_bind` flag on each child's `NativeValueMarshall` (same per-child flag from Step 1.2a) — no changes to the C++ children dispatch path needed +5. **Step 1.8**: Autodiff support +6. **Step 1.9**: Tests + +Never deploy Python-side `gen_calldata` changes that emit raw types without the corresponding C++ fast path fix — the cached `["value"]` navigation will crash at dispatch time. diff --git a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md new file mode 100644 index 000000000..5fa19c3d0 --- /dev/null +++ b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md @@ -0,0 +1,121 @@ +## Phase 2: Eliminate CallData Struct + +**Goal**: When ALL arguments are direct-eligible, bypass the `CallData` struct entirely and pass arguments as individual parameters on the entry point (or individual globals). + +**Parent plan**: [plan-simplifyKernelGen.prompt.md](plan-simplifyKernelGen.prompt.md) + +--- + +### Step 2.1: Determine eligibility + +Add a check in [slangpy/core/calldata.py](slangpy/core/calldata.py) after all bindings are resolved: if every `BoundVariable` satisfies `is_direct_bind_eligible` (for leaves) or `is_direct_bind_recursive` (for composites) AND `call_data_len == 0` (no N-dimensional shape arrays needed), set a new flag `self.use_direct_args = True`. + +--- + +### Step 2.2: New code generation path + +In [slangpy/core/callsignature.py](slangpy/core/callsignature.py), when `use_direct_args`: + +- **Skip CallData struct generation** entirely. Note: `CodeGen.__init__` in [codegen.py](slangpy/bindings/codegen.py) unconditionally emits `struct CallData { ... }` — the constructor creates the `self.call_data` block and `finish()` calls `self.call_data.end_block()`. To eliminate CallData, either: + - Add a `skip_call_data` flag to `CodeGen.__init__` that conditionally initializes the block, and condition the `end_block()` in `finish()` on the same flag, OR + - Clear `self.call_data` contents before `finish()` when `use_direct_args` is true +- **Generate compute_main** with individual `uniform` parameters. The current compute_main signature has three semantic params: + ``` + void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex, uniform CallData call_data) + ``` + When `use_direct_args` and `call_data_len == 0`, the `SV_GroupID` and `SV_GroupIndex` params are unused (they feed `init_thread_local_call_shape_info` which reads `call_data._grid_stride`/`_grid_dim`/`_call_dim`). They can be dropped, simplifying to: + ``` + void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, uniform uint3 _thread_count, uniform int a, uniform int b, uniform RWStructuredBuffer _result) + ``` +- **Inline the function call** into compute_main (skip trampoline for prim mode): `_result[0] = add(a, b);` +- **Keep trampoline** for bwds mode (needed for `bwd_diff()`). The trampoline wraps the call with `[Differentiable]` and allows `bwd_diff(_trampoline)` from compute_main. In this case, generate a trampoline that takes individual params instead of a struct. Direct assignment `a = param_a;` is trivially differentiable in Slang for floating-point types. For non-differentiable types (int, etc.), autodiff is irrelevant. + +--- + +### Step 2.3: Entry point parameters for all backends + +Currently, CUDA (entry_point mode) already passes a `CallData` struct as a `uniform` entry point parameter. The simplification extends this: instead of a single struct, pass individual `uniform` parameters on the entry point — for ALL backends, not just CUDA. + +See [slangpy/tests/device/test_pipeline_utils.slang](slangpy/tests/device/test_pipeline_utils.slang) for examples of manually-written compute shaders that use entry point parameters on all backends: +```slang +[shader("compute")] +[numthreads(16, 16, 1)] +void setcolor( + uint3 tid: SV_DispatchThreadID, + RWTexture2D render_texture, + uniform int2 pos, + uniform int2 size, + uniform float4 color +) +``` + +Entry point parameters work on all backends (CUDA, Vulkan, D3D12). For `global_data` mode, the C++ side currently navigates `cursor["call_data"]` to write into a `ParameterBlock` global. With direct args, it would instead navigate `cursor.find_entry_point(0)` and write each parameter by index — the same mechanism CUDA already uses, but now applied universally. + +The `CallData` struct can be omitted entirely when all args are direct-eligible. If some args still need the struct (e.g., shape arrays for `call_data_len > 0`, or non-direct-eligible types), emit a hybrid: direct-eligible args as individual entry point params, and the remaining data in a `CallData` struct that is also an entry point param. + +**Entry point size limits**: Some platforms impose limits on the total size of entry point parameter data (e.g., CUDA root constants are limited to ~4KB, D3D12 root signature has a 64 DWORD limit). To handle this: +- Define a per-backend threshold for maximum entry point parameter data size (queryable from device/backend info) +- During code generation, accumulate the uniform byte size of each direct-eligible argument. Resource types (`RWStructuredBuffer`, `Texture2D`, etc.) don't count toward the limit — they are bound as descriptors, not inline data +- If a single argument exceeds the threshold, force it back to `CallData` +- If the cumulative total exceeds the threshold, force remaining arguments (in declaration order) back to `CallData` +- The result may be a hybrid kernel: some args as entry point params, the rest in a `CallData` struct entry point param +- The C++ dispatch side must know which args are direct vs CallData-bound (store a per-argument flag or a bitmask on `NativeCallData`) + +--- + +### Step 2.4: C++ dispatch changes + +In [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp): + +- **Store `use_direct_args` flag** on `NativeCallData` (receive from Python `CallData`) +- **Both modes**: In `bind_call_data`, navigate via `cursor.find_entry_point(0)` and write each argument directly to its own entry point parameter by index. This is the same cursor API already used for CUDA entry_point mode — it just needs to write individual params instead of navigating into a single `CallData` struct field. +- **Thread count**: Write `_thread_count` as a separate entry point parameter instead of a struct field +- **Context construction**: The current kernel code constructs a `Context __slangpy_context__` from `call_data` fields (e.g., `flat_call_thread_id, CallShapeInfo::get_call_id().shape`). When `use_direct_args` and `call_data_len == 0`, the Context is simplified to just `{flat_call_thread_id}` and `CallShapeInfo` / `init_thread_local_call_shape_info` can be skipped. If Context is eliminated entirely (Phase 2 with inlined function calls), this becomes moot. +- **Skip shape array writing** (`_grid_stride`, `_grid_dim`, `_call_dim`) since `call_data_len == 0` +- **Cache parameter offsets**: Cache the entry point parameter indices at first dispatch (similar to existing `m_cached_call_data_offsets`) + +--- + +### Step 2.5: Trampoline elimination for prim mode + +When `use_direct_args` and `call_mode == prim`: +- Don't generate a trampoline function +- Emit the function call directly in `compute_main` using the uniform parameter names +- For output variables, emit the store directly (e.g., `_result[0] = add(a, b);`) + +When `call_mode == bwds`: +- Still generate a trampoline (needed for `bwd_diff()`) +- Pass individual params to the trampoline instead of a struct + +--- + +### Step 2.6: Tests + +**Gating tests** — assert CURRENT behavior so they break when Phase 2 is implemented: + +| Test | Slang Source | Args | Asserts (current behavior) | Breaks when | +|------|-------------|------|---------------------------|-------------| +| `test_gate_calldata_struct_present` | `int add(int a, int b) { return a + b; }` | `(1, 2)` | `struct CallData` present in generated code | Step 2.1 | +| `test_gate_calldata_uniform_param` | same | same | `uniform CallData call_data` in `compute_main` signature (note: actual signature also includes `SV_GroupID` and `SV_GroupIndex` params) | Step 2.2 | +| `test_gate_thread_count_in_calldata` | same | same | `call_data._thread_count` in kernel body | Step 2.4 | +| `test_gate_context_from_calldata` | same | same | `Context __slangpy_context__` construction present in kernel body | Step 2.4 | +| `test_gate_trampoline_present_for_prim` | same | same | `void _trampoline(` present | Step 2.5 | +| `test_gate_trampoline_calls_function` | same | same | `_result = add(a, b)` inside trampoline | Step 2.5 | +| `test_gate_kernel_calls_trampoline` | same | same | `_trampoline(` inside `compute_main` body | Step 2.5 | + +**Negative gates** — should REMAIN passing after Phase 2: + +| Test | Slang Source | Args | Asserts (must stay) | +|------|-------------|------|--------------------| +| `test_gate_wanghasharg_forces_calldata` | `int rng(WangHashArg rng, int x) { return x; }` | `(spy.WangHashArg(1), 1)` | `struct CallData` present (non-eligible arg forces fallback) | + +**Post-implementation tests** — should pass AFTER Phase 2 is complete: + +- `test_phase2_no_calldata_struct`: verify `struct CallData` absent for all-eligible scalar call +- `test_phase2_uniform_params_on_entry`: verify individual `uniform int a`, `uniform int b` on `compute_main` +- `test_phase2_no_trampoline_prim`: verify no `_trampoline(` for prim-mode eligible calls +- `test_phase2_thread_count_as_uniform`: verify `uniform uint3 _thread_count` as entry point param +- `test_phase2_inline_function_call`: verify `_result[0] = add(a, b)` directly in kernel +- `test_phase2_bwds_keeps_trampoline`: verify bwds mode still has `_trampoline` and `bwd_diff` +- `test_phase2_mixed_args_hybrid`: mix direct-eligible + WangHashArg → hybrid kernel +- `test_phase2_functional_all_backends`: dispatch scalar add on each backend, verify result diff --git a/.github/prompts/plan-simplifyKernelGen-phase3.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase3.prompt.md new file mode 100644 index 000000000..8736dca4a --- /dev/null +++ b/.github/prompts/plan-simplifyKernelGen-phase3.prompt.md @@ -0,0 +1,48 @@ +## Phase 3: Direct Compute Kernel Invocation + +**Goal**: When the user's Slang function is ALREADY a `[shader("compute")]` entry point (or can trivially be one), skip kernel generation entirely and dispatch the pre-written shader directly. + +**Parent plan**: [plan-simplifyKernelGen.prompt.md](plan-simplifyKernelGen.prompt.md) + +--- + +### Step 3.1: Detection + +In the function resolution phase, detect when the target Slang function: +- Has `[shader("compute")]` attribute +- Has parameter types that SlangPy can bind directly (uniforms, buffers, textures) +- Has explicit thread count specified by the user (already supported via `function.set_thread_count()`) + +--- + +### Step 3.2: Direct dispatch path + +When eligible: +- Skip Phase 2 (kernel generation) entirely +- Create a `ComputePipeline` directly from the user's shader +- Map Python arguments to entry point parameters using the type marshalling but without code generation +- Dispatch directly + +--- + +### Step 3.3: Argument binding + +Leverage Phase 2's per-argument binding infrastructure — the same cursor write logic that writes individual uniform params would write to the pre-written shader's entry point params. + +--- + +### Step 3.4: Tests + +**Gating test** — assert CURRENT behavior so it breaks when Phase 3 is implemented: + +| Test | Slang Source | Args | Asserts (current behavior) | Breaks when | +|------|-------------|------|---------------------------|-------------| +| `test_gate_compute_shader_generates_wrapper` | Source with `[shader("compute")] void my_kernel(...)` function, test calls a helper function in the same module | N/A | SlangPy generates its own `compute_main` wrapper; user's `[shader("compute")]` is ignored | Step 3.1 | + +**Post-implementation tests** — should pass AFTER Phase 3 is complete: + +- `test_phase3_direct_dispatch`: dispatch a pre-written `[shader("compute")]` kernel directly, verify no wrapper generated +- `test_phase3_requires_thread_count`: verify error when thread count not specified +- `test_phase3_scalar_params`: verify scalar uniform params bind correctly +- `test_phase3_buffer_params`: verify `RWStructuredBuffer` params bind correctly +- `test_phase3_texture_params`: verify texture params bind correctly diff --git a/.github/prompts/plan-simplifyKernelGen.prompt.md b/.github/prompts/plan-simplifyKernelGen.prompt.md new file mode 100644 index 000000000..867262cc2 --- /dev/null +++ b/.github/prompts/plan-simplifyKernelGen.prompt.md @@ -0,0 +1,90 @@ +## Plan: Simplify Generated SlangPy Kernels + +**TL;DR**: A three-phase effort to make generated kernels resemble hand-written GPU code. Phase 1 adds direct type marshalling (bypassing `ValueType` wrappers and `__slangpy_load`/`__slangpy_store`) for dim-0 non-composite types, following the pattern already used by `TensorView`. Phase 2 eliminates the `CallData` struct when all arguments are direct-eligible, passing them as individual uniforms/globals. Phase 3 enables calling pre-written compute kernels directly without generating wrapper shaders. + +**Target example** — `add(int a, int b) -> int` with scalar args should go from 40+ lines of boilerplate to approximately: + +```slang +import "module"; +[shader("compute")] +[numthreads(32, 1, 1)] +void compute_main(int3 tid: SV_DispatchThreadID, uniform uint3 _thread_count, uniform int a, uniform int b, uniform RWStructuredBuffer _result) +{ + if (any(tid >= _thread_count)) return; + _result[0] = add(a, b); +} +``` + +--- + +### Phase Plans + +- [Phase 1: Direct Type Marshalling](plan-simplifyKernelGen-phase1.prompt.md) +- [Phase 2: Eliminate CallData Struct](plan-simplifyKernelGen-phase2.prompt.md) +- [Phase 3: Direct Compute Kernel Invocation](plan-simplifyKernelGen-phase3.prompt.md) + +--- + +### Gating Tests — Pre-Implementation Checklist + +Before implementing any phase, add **gating tests** to [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) that assert the CURRENT generated kernel patterns. These tests document the baseline and will intentionally break as each simplification step is implemented. + +**Design principles:** +- All gating tests are code-generation-only (no GPU dispatch) — fast and deterministic +- All tests use the existing `generate_code()` helper → `func.debug_build_call_data()` → `cd.code` +- Tests are parametrized across `helpers.DEFAULT_DEVICE_TYPES` +- String matching (substring checks) rather than regex or golden files +- Named `test_gate_*` for easy identification +- WangHashArg and dict/composite tests serve as "negative gates" — they remain passing after simplification + +**Test infrastructure additions:** +```python +def assert_contains(code: str, *patterns: str) -> None: + for p in patterns: + assert p in code, f"Expected pattern not found: {p}" + +def assert_not_contains(code: str, *patterns: str) -> None: + for p in patterns: + assert p not in code, f"Unexpected pattern found: {p}" + +def generate_bwds_code(device, func_name, module_source, *args, **kwargs) -> str: + func = helpers.create_function_from_module(device, func_name, module_source) + cd = func.bwds.debug_build_call_data(*args, **kwargs) + if PRINT_TEST_KERNEL_GEN: + print(cd.code) + return cd.code +``` + +**Summary of all gating tests by phase:** + +| Phase | Gating Tests (break on implementation) | Negative Gates (must stay passing) | +|-------|---------------------------------------|-----------------------------------| +| 1 | 12 tests: scalar/float/vector/matrix/valueref-read/valueref-write/array/mapping-constants/context-map/struct-slangpy-load/bwds-scalar/bwds-trampoline | 3 tests: wanghasharg/vectorized-scalar/vectorized-dict | +| 2 | 7 tests: calldata-struct/calldata-uniform/thread-count/context-from-calldata/trampoline-present/trampoline-calls/kernel-calls-trampoline | 1 test: wanghasharg-forces-calldata | +| 3 | 1 test: compute-shader-generates-wrapper | — | + +--- + +### Verification (all phases) + +```bash +# Build first (required) +cmake --build --preset windows-msvc-debug + +# Run kernel gen tests +$env:PRINT_TEST_KERNEL_GEN="1"; pytest slangpy/tests/slangpy_tests/test_kernel_gen.py -v + +# Run full test suite +pytest slangpy/tests -v + +# Run pre-commit +pre-commit run --all-files +``` + +### Key Decisions + +- Phase 1 changes both `gen_calldata` and trampoline load/store (TensorView-complete pattern, not partial) +- All dim-0 non-composite types are eligible, including tensors and value refs +- Phase 2 targets both `entry_point` (CUDA) and `global_data` (Vulkan/D3D12) modes +- Autograd (bwds mode) is included in simplification, but implemented after prim mode within each phase +- WangHashArg explicitly excluded from direct binding (needs per-thread `thread_id` computation) diff --git a/slangpy/benchmarks/test_benchmark_autograd.py b/slangpy/benchmarks/test_benchmark_autograd.py index 114e0a42c..ee04667dc 100644 --- a/slangpy/benchmarks/test_benchmark_autograd.py +++ b/slangpy/benchmarks/test_benchmark_autograd.py @@ -27,9 +27,9 @@ pass SLEEPS = True -ITERATIONS = 10 +ITERATIONS = 100 SUB_ITERATIONS = 20000 -WARMUPS = 10 +WARMUPS = 1000 # ITERATIONS = 1 # SUB_ITERATIONS = 1 @@ -49,7 +49,7 @@ # ============================================================================= RUN_PURE_TORCH_BENCHMARK = False -RUN_SLANGTORCH_BENCHMARK = False +RUN_SLANGTORCH_BENCHMARK = True RUN_SLANGPY_MANUAL_HOOK_BENCHMARK = True RUN_SLANGPY_AUTOMATIC_BENCHMARK = True diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index 6e72803f3..209ae8ffa 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -431,6 +431,9 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): self.debug_only_bindings = bindings self.runtime = BoundCallRuntime(bindings) + # Store the code as its useful for debugging + self.code = code + # If using autograd, build list of access modes for each tensor argument. if self.torch_autograd: self._build_autograd_access_list(unpacked_args, unpacked_kwargs) diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py new file mode 100644 index 000000000..bcbb0e6e2 --- /dev/null +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -0,0 +1,53 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Kernel generation test. + +These tests exercise different code paths for kernel generation, to exercise different kernel types, such as: +- passing arguments directly vs via call data +- passing read-only arguments that don't need storing directly rather than via marshalls +- handling the semantic 'dispatch thread id' etc and calling kernels directly +""" + +from typing import Any + +import pytest +import os + +import slangpy as spy +from slangpy.testing import helpers + +PRINT_TEST_KERNEL_GEN = os.getenv("PRINT_TEST_KERNEL_GEN", "0") == "1" + + +def generate_code( + device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any +) -> str: + """ + Generate code for the given function and arguments, and return the generated code as a string. + """ + func = helpers.create_function_from_module(device, func_name, module_source) + cd = func.debug_build_call_data(*args, **kwargs) + if PRINT_TEST_KERNEL_GEN: + print(cd.code) + return cd.code + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_kernel_gen_basic(device_type: spy.DeviceType): + """ + Test basic kernel generation with a simple function that adds two numbers. + """ + src = """ +int add(int a, int b) { + return a + b; +} +""" + device = helpers.get_device(device_type) + code = generate_code(device, "add", src, 1, 2) + print(code) + assert "add" in code + + +if __name__ == "__main__": + pytest.main([__file__, "-vs"]) From 9bca5ab998fc6b7e6d53bcc59cae2c20a1df2d75 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Tue, 10 Mar 2026 19:34:27 +0000 Subject: [PATCH 02/41] first attempt at phase 1 --- .../plan-simplifyKernelGen-phase1.prompt.md | 2 + .../prompts/plan-simplifyKernelGen.prompt.md | 41 +- slangpy/bindings/__init__.py | 2 + slangpy/bindings/boundvariable.py | 183 +++++++-- slangpy/builtin/struct.py | 30 +- slangpy/builtin/tensorcommon.py | 19 + slangpy/builtin/value.py | 37 +- slangpy/builtin/valueref.py | 78 +++- .../tests/slangpy_tests/test_kernel_gen.py | 384 ++++++++++++++++++ src/slangpy_ext/utils/slangpyvalue.cpp | 8 +- src/slangpy_ext/utils/slangpyvalue.h | 14 +- 11 files changed, 742 insertions(+), 56 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md index bfe58cc28..bc398df03 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md @@ -1,5 +1,7 @@ ## Phase 1: Direct Type Marshalling +**Status**: Prim-mode complete (Steps 1.1–1.7, 1.9). Step 1.8 (autodiff derivative fields) deferred. + **Goal**: For dim-0, non-composite arguments, emit the raw Slang type in CallData and use direct assignment in the trampoline — eliminating `ValueType` wrappers, `__slangpy_load`/`__slangpy_store` indirection, mapping constants, and `Context.map()` calls. **Parent plan**: [plan-simplifyKernelGen.prompt.md](plan-simplifyKernelGen.prompt.md) diff --git a/.github/prompts/plan-simplifyKernelGen.prompt.md b/.github/prompts/plan-simplifyKernelGen.prompt.md index 867262cc2..d95a79e58 100644 --- a/.github/prompts/plan-simplifyKernelGen.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen.prompt.md @@ -19,12 +19,51 @@ void compute_main(int3 tid: SV_DispatchThreadID, uniform uint3 _thread_count, un ### Phase Plans -- [Phase 1: Direct Type Marshalling](plan-simplifyKernelGen-phase1.prompt.md) +- [Phase 1: Direct Type Marshalling](plan-simplifyKernelGen-phase1.prompt.md) — **implemented (prim-mode)** - [Phase 2: Eliminate CallData Struct](plan-simplifyKernelGen-phase2.prompt.md) - [Phase 3: Direct Compute Kernel Invocation](plan-simplifyKernelGen-phase3.prompt.md) --- +### Phase 1 Progress + +Phase 1 prim-mode direct binding is complete. Steps 1.1–1.7, 1.9 are implemented and passing. Step 1.8 (autodiff/bwds) is deferred. + +**What was done:** + +| Step | Status | Summary | +|------|--------|---------| +| 1.2a | ✅ Done | `NativeValueMarshall` C++ fast path: `m_direct_bind` flag gates `["value"]` sub-field navigation in `ensure_cached`. Exposed via nanobind `direct_bind` property. | +| 1.1 | ✅ Done | `is_direct_bind_eligible()` and `is_direct_bind_recursive()` in `boundvariable.py`. Excludes `PackedArg` bindings and children inside non-direct-bind structs (via `_force_no_direct_bind` flag). | +| 1.2 | ✅ Done | `ValueMarshall`: `gen_calldata` emits raw `typealias`; `gen_trampoline_load/store` do direct assignment; `create_calldata` returns raw value. `ScalarMarshall`/`MatrixMarshall` inherit. | +| 1.3 | ✅ Done | `VectorMarshall`: `gen_calldata` emits raw `typealias` (e.g., `vector`). Inherits trampoline load/store from `ValueMarshall`. | +| 1.4 | ✅ Done | `StructMarshall`/`BoundVariable`: `gen_call_data_code` children path emits `typealias _t_{name} = {struct_type}` when `is_direct_bind_recursive`. Sets `direct_bind` on child marshalls via `_set_direct_bind_on_children`. Non-direct-bind structs set `_force_no_direct_bind` on children to prevent incorrect leaf optimization. `gen_trampoline_load/store` added. | +| 1.5 | ✅ Done | `ValueRefMarshall`: read-only emits raw type + direct assignment; writable emits `RWStructuredBuffer` + `[0]` load/store. `create_calldata`/`read_calldata` skip `{"value": ...}` wrapper when direct-eligible. | +| 1.6 | ✅ Done | Tensor dim-0: `gen_trampoline_load/store` extended for `ITensorType` at dim-0 (direct struct assignment). | +| 1.7 | ✅ Done | Mapping constants (`static const int _m_{name}`) skipped for direct-bind-eligible variables. | +| 1.8 | ⬜ Deferred | Autodiff/bwds mode still uses wrapper types. Prim-mode direct binding does apply to bwds primals (code gen verified), but derivative fields still use the old path. | +| 1.9 | ✅ Done | 21 tests (×3 device types = 63 cases): 16 code-gen assertion tests + 5 functional GPU dispatch tests. All pass on d3d12/vulkan/cuda. | + +**Files modified:** + +| File | Changes | +|------|---------| +| `src/slangpy_ext/utils/slangpyvalue.h` | `m_direct_bind` flag, getter/setter | +| `src/slangpy_ext/utils/slangpyvalue.cpp` | `ensure_cached` direct-bind branch; nanobind export | +| `slangpy/bindings/boundvariable.py` | `is_direct_bind_eligible`, `is_direct_bind_recursive`, `_set_direct_bind_on_children`, `_force_no_direct_bind`, mapping-constant skip in `gen_call_data_code` | +| `slangpy/bindings/__init__.py` | Exports for predicates | +| `slangpy/builtin/value.py` | `gen_calldata`, `gen_trampoline_load`, `gen_trampoline_store`, `create_calldata` | +| `slangpy/builtin/valueref.py` | `gen_calldata`, `gen_trampoline_load`, `gen_trampoline_store`, `create_calldata`, `read_calldata` | +| `slangpy/builtin/struct.py` | `gen_trampoline_load`, `gen_trampoline_store` | +| `slangpy/builtin/tensorcommon.py` | `gen_trampoline_load`, `gen_trampoline_store` extended for `ITensorType` | +| `slangpy/tests/slangpy_tests/test_kernel_gen.py` | All Phase 1 tests | + +**Test results:** 2952 passed / 0 failed in `slangpy/tests/slangpy_tests`. 6 pre-existing failures in `slangpy/tests/device/` (raytracing pipeline, type conformance cache — unrelated). + +**Implementation note — `_force_no_direct_bind`:** The plan did not anticipate that children inside non-direct-bind composite structs (mixed dim-0/dim-N children) would incorrectly inherit direct binding from their leaf predicates. A `_force_no_direct_bind` flag was added: when `gen_call_data_code` takes the non-direct-bind struct path, it marks all children so their `is_direct_bind_eligible`/`is_direct_bind_recursive` return `False`. This prevents generating e.g. `typealias _t_velocity = vector` for a child inside a struct that still uses `__slangpy_load`. Similarly, `PackedArg` bindings (`create_param_block = True`) are excluded since `ParameterBlock` is invalid in Slang. + +--- + ### Gating Tests — Pre-Implementation Checklist Before implementing any phase, add **gating tests** to [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) that assert the CURRENT generated kernel patterns. These tests document the baseline and will intentionally break as each simplification step is implemented. diff --git a/slangpy/bindings/__init__.py b/slangpy/bindings/__init__.py index a43c1416f..0c171a52a 100644 --- a/slangpy/bindings/__init__.py +++ b/slangpy/bindings/__init__.py @@ -7,6 +7,8 @@ BoundVariable, BoundCall, BoundVariableException, + is_direct_bind_eligible, + is_direct_bind_recursive, ) from slangpy.bindings.boundvariableruntime import BoundVariableRuntime, BoundCallRuntime from slangpy.bindings.codegen import CodeGen, CodeGenBlock diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index a72bfc61c..83ab111c6 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -145,6 +145,91 @@ def finalize_mappings(self, context: BindContext): arg.finalize_mappings(context) +# Cache of direct-bind-eligible marshall types (populated on first call). +_DIRECT_BIND_TYPES: Optional[tuple[type, ...]] = None + + +def _get_direct_bind_types() -> tuple[type, ...]: + """Lazily import and cache the set of marshall types eligible for direct binding.""" + global _DIRECT_BIND_TYPES + if _DIRECT_BIND_TYPES is None: + from slangpy.builtin.value import ( + ValueMarshall, + ScalarMarshall, + VectorMarshall, + MatrixMarshall, + ) + from slangpy.builtin.array import ArrayMarshall + from slangpy.builtin.valueref import ValueRefMarshall + + _DIRECT_BIND_TYPES = ( + ValueMarshall, + ScalarMarshall, + VectorMarshall, + MatrixMarshall, + ArrayMarshall, + ValueRefMarshall, + ) + return _DIRECT_BIND_TYPES + + +def is_direct_bind_eligible(binding: "BoundVariable") -> bool: + """Check if a leaf binding can use direct type marshalling (no ValueType wrapper). + + Eligible when: + - dim-0 (call_dimensionality == 0, not None) + - not composite (no children) + - not using a ParameterBlock (PackedArg) + - not inside a non-direct-bind composite (struct children path) + - marshall is a known direct-eligible type + + :param binding: The bound variable to check. + :return: True if the variable can use direct binding. + """ + if binding.call_dimensionality is None or binding.call_dimensionality != 0: + return False + if binding.children: + return False + if getattr(binding, "create_param_block", False): + return False + if getattr(binding, "_force_no_direct_bind", False): + return False + return isinstance(binding.python, _get_direct_bind_types()) + + +def is_direct_bind_recursive(binding: "BoundVariable") -> bool: + """Check if a binding (leaf or composite) can use direct type marshalling. + + For leaves, delegates to :func:`is_direct_bind_eligible`. + For composites (dicts bound to structs), returns True only if dim-0 and every + child is recursively direct-bind eligible, and the vector_type is a concrete + Slang struct (not UnknownType). + + :param binding: The bound variable to check. + :return: True if the variable (and all its children) can use direct binding. + """ + # If this binding is inside a non-direct-bind struct, it must not use direct binding + if getattr(binding, "_force_no_direct_bind", False): + return False + + if binding.children is None: + return is_direct_bind_eligible(binding) + + if binding.call_dimensionality is None or binding.call_dimensionality != 0: + return False + + if getattr(binding, "create_param_block", False): + return False + + # Must have a concrete struct type (not UnknownType) + from slangpy.reflection import UnknownType + + if binding.vector_type is None or isinstance(binding.vector_type, UnknownType): + return False + + return all(is_direct_bind_recursive(child) for child in binding.children.values()) + + class BoundVariable: """ Node in a built signature tree, maintains a pairing of python+slang marshall, @@ -540,55 +625,85 @@ def _calculate_differentiability(self, mode: CallMode): # todo: fwds self.access = (AccessType.none, AccessType.none) + def _set_direct_bind_on_children(self) -> None: + """Recursively set direct_bind flag on all leaf children's NativeValueMarshall.""" + if self.children is None: + from slangpy.core.native import NativeValueMarshall + + if isinstance(self.python, NativeValueMarshall): + self.python.direct_bind = True + return + for child in self.children.values(): + child._set_direct_bind_on_children() + def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): if self.children is not None: cgb = cg.call_data_structs - cgb.begin_struct(f"_t_{self.variable_name}") + if is_direct_bind_recursive(self): + # Direct-bind: emit raw type alias and set direct_bind on children + assert self.vector_type is not None + cgb.type_alias(f"_t_{self.variable_name}", self.vector_type.full_name) + self._set_direct_bind_on_children() + else: + cgb.begin_struct(f"_t_{self.variable_name}") - for field, variable in self.children.items(): - variable.gen_call_data_code(cg, context, depth + 1) + # Children inside a non-direct-bind struct must not use direct + # binding — the struct's __slangpy_load/store expect wrapper types. + for variable in self.children.values(): + variable._force_no_direct_bind = True - for var in self.children.values(): - cgb.declare(f"_t_{var.variable_name}", var.variable_name) + for field, variable in self.children.items(): + variable.gen_call_data_code(cg, context, depth + 1) - assert self.vector_type is not None - context_decl = f"ContextND<{self.call_dimensionality}> context" - value_decl = f"{self.vector_type.full_name} value" - prefix = "[Differentiable]" if self.access[1] != AccessType.none else "" - - if self.access[0] in (AccessType.read, AccessType.readwrite): - cgb.empty_line() - cgb.append_line(f"{prefix} void __slangpy_load({context_decl}, out {value_decl})") - cgb.begin_block() - for field, var in self.children.items(): - cgb.append_statement( - f"{var.variable_name}.__slangpy_load(context.map(_m_{var.variable_name}),value.{field})" + for var in self.children.values(): + cgb.declare(f"_t_{var.variable_name}", var.variable_name) + + assert self.vector_type is not None + context_decl = f"ContextND<{self.call_dimensionality}> context" + value_decl = f"{self.vector_type.full_name} value" + prefix = "[Differentiable]" if self.access[1] != AccessType.none else "" + + if self.access[0] in (AccessType.read, AccessType.readwrite): + cgb.empty_line() + cgb.append_line( + f"{prefix} void __slangpy_load({context_decl}, out {value_decl})" ) - cgb.end_block() - - if self.access[0] in (AccessType.write, AccessType.readwrite): - cgb.empty_line() - cgb.append_line(f"{prefix} void __slangpy_store({context_decl}, in {value_decl})") - cgb.begin_block() - for field, var in self.children.items(): - cgb.append_statement( - f"{var.variable_name}.__slangpy_store(context.map(_m_{var.variable_name}),value.{field})" + cgb.begin_block() + for field, var in self.children.items(): + cgb.append_statement( + f"{var.variable_name}.__slangpy_load(context.map(_m_{var.variable_name}),value.{field})" + ) + cgb.end_block() + + if self.access[0] in (AccessType.write, AccessType.readwrite): + cgb.empty_line() + cgb.append_line( + f"{prefix} void __slangpy_store({context_decl}, in {value_decl})" ) - cgb.end_block() + cgb.begin_block() + for field, var in self.children.items(): + cgb.append_statement( + f"{var.variable_name}.__slangpy_store(context.map(_m_{var.variable_name}),value.{field})" + ) + cgb.end_block() - cgb.end_struct() + cgb.end_struct() else: # Generate call data self.python.gen_calldata(cg.call_data_structs, context, self) - if len(self.vector_mapping) > 0: - cg.call_data_structs.append_statement( - f"static const int[] _m_{self.variable_name} = {{ {','.join([str(x) for x in self.vector_mapping.as_tuple()])} }}" - ) - else: - cg.call_data_structs.append_statement(f"static const int _m_{self.variable_name} = 0") + # Skip mapping constants for direct-bind variables (they bypass __slangpy_load/store) + if not is_direct_bind_recursive(self): + if len(self.vector_mapping) > 0: + cg.call_data_structs.append_statement( + f"static const int[] _m_{self.variable_name} = {{ {','.join([str(x) for x in self.vector_mapping.as_tuple()])} }}" + ) + else: + cg.call_data_structs.append_statement( + f"static const int _m_{self.variable_name} = 0" + ) if depth == 0: if self.create_param_block: diff --git a/slangpy/builtin/struct.py b/slangpy/builtin/struct.py index 0b38cc955..0a8622b90 100644 --- a/slangpy/builtin/struct.py +++ b/slangpy/builtin/struct.py @@ -4,8 +4,9 @@ from slangpy.core.native import Shape, NativeMarshall import slangpy.bindings.typeregistry as tr -from slangpy.bindings import PYTHON_TYPES, BindContext, BoundVariable +from slangpy.bindings import PYTHON_TYPES, BindContext, BoundVariable, is_direct_bind_recursive from slangpy.reflection import SlangProgramLayout, SlangType, UnknownType, StructType, InterfaceType +from slangpy.core.native import AccessType from .value import ValueMarshall import slangpy.reflection.vectorize as spyvec @@ -77,6 +78,33 @@ def resolve_dimensionality( # A struct type should get a dictionary, and just return that for raw dispatch + def gen_trampoline_load( + self, cgb: "CodeGenBlock", binding: "BoundVariable", is_entry_point: bool + ) -> bool: + if not is_direct_bind_recursive(binding): + return False + data_name = ( + f"_param_{binding.variable_name}" + if binding.create_param_block + else f"{'__calldata__' if is_entry_point else 'call_data'}.{binding.variable_name}" + ) + cgb.append_statement(f"{binding.variable_name} = {data_name}") + return True + + def gen_trampoline_store( + self, cgb: "CodeGenBlock", binding: "BoundVariable", is_entry_point: bool + ) -> bool: + if not is_direct_bind_recursive(binding): + return False + if binding.access[0] in (AccessType.write, AccessType.readwrite): + data_name = ( + f"_param_{binding.variable_name}" + if binding.create_param_block + else f"{'__calldata__' if is_entry_point else 'call_data'}.{binding.variable_name}" + ) + cgb.append_statement(f"{data_name} = {binding.variable_name}") + return True + def create_dispatchdata(self, data: Any) -> Any: if isinstance(data, dict): return data diff --git a/slangpy/builtin/tensorcommon.py b/slangpy/builtin/tensorcommon.py index 886cb97d8..0eae47b5a 100644 --- a/slangpy/builtin/tensorcommon.py +++ b/slangpy/builtin/tensorcommon.py @@ -383,6 +383,18 @@ def gen_trampoline_load( self: ITensorMarshall, cgb: CodeGenBlock, binding: BoundVariable, is_entry_point: bool ) -> bool: if not isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType)): + # For ITensorType at dim-0, use direct assignment (struct copy) + if ( + isinstance(binding.vector_type, ITensorType) + and binding.call_dimensionality is not None + and binding.call_dimensionality == 0 + ): + if is_entry_point: + data_name = f"__calldata__.{binding.variable_name}" + else: + data_name = f"call_data.{binding.variable_name}" + cgb.append_statement(f"{binding.variable_name} = {data_name}") + return True return False if is_entry_point: data_name = f"__calldata__.{binding.variable_name}" @@ -396,5 +408,12 @@ def gen_trampoline_store( self: ITensorMarshall, cgb: CodeGenBlock, binding: BoundVariable, is_entry_point: bool ) -> bool: if not isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType)): + # For ITensorType at dim-0, suppress default store + if ( + isinstance(binding.vector_type, ITensorType) + and binding.call_dimensionality is not None + and binding.call_dimensionality == 0 + ): + return True return False return True diff --git a/slangpy/builtin/value.py b/slangpy/builtin/value.py index 45c41ef83..c9d770e7c 100644 --- a/slangpy/builtin/value.py +++ b/slangpy/builtin/value.py @@ -15,6 +15,7 @@ BoundVariable, BoundVariableRuntime, CodeGenBlock, + is_direct_bind_eligible, ) from slangpy.reflection.reflectiontypes import ( BOOL_TYPES, @@ -95,16 +96,44 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV name = binding.variable_name if access[0] in [AccessType.read, AccessType.readwrite]: assert binding.vector_type is not None - cgb.type_alias(f"_t_{name}", f"ValueType<{binding.vector_type.full_name}>") + if is_direct_bind_eligible(binding): + self.direct_bind = True + cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) + else: + cgb.type_alias(f"_t_{name}", f"ValueType<{binding.vector_type.full_name}>") else: cgb.type_alias(f"_t_{name}", f"NoneType") + def gen_trampoline_load( + self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool + ) -> bool: + if not is_direct_bind_eligible(binding): + return False + if binding.access[0] not in (AccessType.read, AccessType.readwrite): + return False + if is_entry_point: + data_name = f"__calldata__.{binding.variable_name}" + else: + data_name = f"call_data.{binding.variable_name}" + cgb.append_statement(f"{binding.variable_name} = {data_name}") + return True + + def gen_trampoline_store( + self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool + ) -> bool: + if not is_direct_bind_eligible(binding): + return False + # ValueMarshall is read-only — suppress the default store + return True + # Call data just returns the primal def create_calldata( self, context: CallContext, binding: "BoundVariableRuntime", data: Any ) -> Any: access = binding.access if access[0] in [AccessType.read, AccessType.readwrite]: + if self.direct_bind: + return data return {"value": data} # Values just return themselves for raw dispatch @@ -314,7 +343,11 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV if access[0] in [AccessType.read, AccessType.readwrite]: st = cast(kfr.VectorType, self.slang_type) et = cast(SlangType, st.element_type) - cgb.type_alias(f"_t_{name}", f"VectorValueType<{et.full_name},{st.num_elements}>") + if is_direct_bind_eligible(binding): + self.direct_bind = True + cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) + else: + cgb.type_alias(f"_t_{name}", f"VectorValueType<{et.full_name},{st.num_elements}>") else: cgb.type_alias(f"_t_{name}", f"NoneType") diff --git a/slangpy/builtin/valueref.py b/slangpy/builtin/valueref.py index 3bf439b51..26d3d0d12 100644 --- a/slangpy/builtin/valueref.py +++ b/slangpy/builtin/valueref.py @@ -18,6 +18,7 @@ CodeGenBlock, ReturnContext, get_or_create_type, + is_direct_bind_eligible, ) from slangpy.builtin.value import slang_type_to_return_type from slangpy.reflection.reflectiontypes import SlangType @@ -111,6 +112,7 @@ class ValueRefMarshall(Marshall): def __init__(self, layout: kfr.SlangProgramLayout, value_type: kfr.SlangType): super().__init__(layout) self.value_type = value_type + self._direct_bind = False st = layout.find_type_by_name(f"ValueRef<{value_type.full_name}>") if st is None: @@ -155,10 +157,50 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV assert access[0] != AccessType.none assert access[1] == AccessType.none assert binding.vector_type is not None - if access[0] == AccessType.read: - cgb.type_alias(f"_t_{name}", f"ValueRef<{binding.vector_type.full_name}>") + if is_direct_bind_eligible(binding): + self._direct_bind = True + if access[0] == AccessType.read: + cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) + else: + cgb.type_alias( + f"_t_{name}", + f"RWStructuredBuffer<{binding.vector_type.full_name}>", + ) + else: + if access[0] == AccessType.read: + cgb.type_alias(f"_t_{name}", f"ValueRef<{binding.vector_type.full_name}>") + else: + cgb.type_alias(f"_t_{name}", f"RWValueRef<{binding.vector_type.full_name}>") + + def gen_trampoline_load( + self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool + ) -> bool: + if not is_direct_bind_eligible(binding): + return False + if binding.access[0] == AccessType.none: + return False + if is_entry_point: + data_name = f"__calldata__.{binding.variable_name}" + else: + data_name = f"call_data.{binding.variable_name}" + if binding.access[0] == AccessType.read: + cgb.append_statement(f"{binding.variable_name} = {data_name}") else: - cgb.type_alias(f"_t_{name}", f"RWValueRef<{binding.vector_type.full_name}>") + cgb.append_statement(f"{binding.variable_name} = {data_name}[0]") + return True + + def gen_trampoline_store( + self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool + ) -> bool: + if not is_direct_bind_eligible(binding): + return False + if binding.access[0] in (AccessType.write, AccessType.readwrite): + if is_entry_point: + data_name = f"__calldata__.{binding.variable_name}" + else: + data_name = f"call_data.{binding.variable_name}" + cgb.append_statement(f"{data_name}[0] = {binding.variable_name}") + return True # Call data just returns the primal def create_calldata( @@ -167,7 +209,9 @@ def create_calldata( access = binding.access assert access[0] != AccessType.none assert access[1] == AccessType.none - if access[0] == AccessType.read: + if self._direct_bind and access[0] == AccessType.read: + return data.value + elif access[0] == AccessType.read: return {"value": data.value} else: if isinstance(binding.vector_type, (kfr.StructType, kfr.ArrayType)): @@ -182,6 +226,8 @@ def create_calldata( if access[0] != AccessType.write: cursor[0].write(data.value) cursor.apply() + if self._direct_bind: + return buffer return {"value": buffer} else: if isinstance(self.value_type, kfr.SlangType): @@ -189,14 +235,15 @@ def create_calldata( else: npdata = self.value_type.to_numpy(data.value) npdata = npdata.view(dtype=np.uint8) - return { - "value": context.device.create_buffer( - element_count=1, - struct_size=npdata.size, - data=npdata, - usage=BufferUsage.shader_resource | BufferUsage.unordered_access, - ) - } + buffer = context.device.create_buffer( + element_count=1, + struct_size=npdata.size, + data=npdata, + usage=BufferUsage.shader_resource | BufferUsage.unordered_access, + ) + if self._direct_bind: + return buffer + return {"value": buffer} # Value ref just passes its value for raw dispatch def create_dispatchdata(self, data: Any) -> Any: @@ -212,12 +259,13 @@ def read_calldata( ) -> None: access = binding.access if access[0] in [AccessType.write, AccessType.readwrite]: - assert isinstance(result["value"], Buffer) + buffer = result if self._direct_bind else result["value"] + assert isinstance(buffer, Buffer) if isinstance(binding.vector_type, (kfr.StructType, kfr.ArrayType)): - cursor = BufferCursor(binding.vector_type.buffer_layout.reflection, result["value"]) + cursor = BufferCursor(binding.vector_type.buffer_layout.reflection, buffer) data.value = cursor[0].read() else: - npdata = result["value"].to_numpy() + npdata = buffer.to_numpy() if isinstance(self.value_type, kfr.SlangType): data.value = numpy_to_slang_value(self.value_type, npdata) else: diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index bcbb0e6e2..3f7d9a1c8 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -7,19 +7,58 @@ - passing arguments directly vs via call data - passing read-only arguments that don't need storing directly rather than via marshalls - handling the semantic 'dispatch thread id' etc and calling kernels directly + +Gating tests (test_gate_*) assert CURRENT generated kernel patterns and will +intentionally break as simplification steps from the kernel-gen simplification +plan are implemented. Negative gates (test_gate_*_keeps_*) must remain +passing after simplification — they cover types that are NOT direct-bind +eligible. """ from typing import Any +import numpy as np import pytest import os import slangpy as spy from slangpy.testing import helpers +from slangpy.types import ValueRef, Tensor, diffPair +from slangpy.types.wanghasharg import WangHashArg PRINT_TEST_KERNEL_GEN = os.getenv("PRINT_TEST_KERNEL_GEN", "0") == "1" +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def assert_contains(code: str, *patterns: str) -> None: + """Assert all patterns appear in generated code.""" + for p in patterns: + assert p in code, f"Expected pattern not found: {p}" + + +def assert_not_contains(code: str, *patterns: str) -> None: + """Assert none of the patterns appear in generated code.""" + for p in patterns: + assert p not in code, f"Unexpected pattern found: {p}" + + +def assert_trampoline_has(code: str, *stmts: str) -> None: + """Assert trampoline contains statements, insensitive to call_data vs __calldata__ prefix.""" + for s in stmts: + # Replace __calldata__ with both options for matching + if "__calldata__." in s: + alt = s.replace("__calldata__.", "call_data.") + assert ( + s in code or alt in code + ), f"Expected trampoline statement not found: {s} (or {alt})" + else: + assert s in code, f"Expected trampoline statement not found: {s}" + + def generate_code( device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any ) -> str: @@ -33,6 +72,24 @@ def generate_code( return cd.code +def generate_bwds_code( + device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any +) -> str: + """ + Generate backwards-mode code for the given function and arguments. + """ + func = helpers.create_function_from_module(device, func_name, module_source) + cd = func.bwds.debug_build_call_data(*args, **kwargs) + if PRINT_TEST_KERNEL_GEN: + print(cd.code) + return cd.code + + +# --------------------------------------------------------------------------- +# Basic test +# --------------------------------------------------------------------------- + + @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) def test_kernel_gen_basic(device_type: spy.DeviceType): """ @@ -49,5 +106,332 @@ def test_kernel_gen_basic(device_type: spy.DeviceType): assert "add" in code +# =========================================================================== +# Phase 1 tests — assert direct-bind behaviour after implementation +# =========================================================================== + +# -- Step 1.2: Scalar direct binding -- + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_scalar_uses_valuetype(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + code = generate_code( + device, + "add", + "int add(int a, int b) { return a + b; }", + 1, + 2, + ) + # Scalars now use direct binding: typealias to raw type, no ValueType wrapper + assert_not_contains(code, "ValueType") + assert_contains(code, "typealias _t_a = int;", "typealias _t_b = int;") + # Trampoline uses direct assignment, no __slangpy_load + assert_trampoline_has(code, "a = __calldata__.a;", "b = __calldata__.b;") + # _result is auto-created as RWValueRef — now uses RWStructuredBuffer + assert_not_contains(code, "RWValueRef") + assert_contains(code, "RWStructuredBuffer") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_float_scalar_uses_valuetype(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + code = generate_code( + device, + "mymul", + "float mymul(float x, float y) { return x * y; }", + 1.0, + 2.0, + ) + assert_not_contains(code, "ValueType") + assert_contains(code, "typealias _t_x = float;", "typealias _t_y = float;") + + +# -- Step 1.3: Vector / Matrix / Array direct binding -- + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_vector_uses_vectorvaluetype(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + code = generate_code( + device, + "scale", + "float3 scale(float3 v, float s) { return v * s; }", + spy.math.float3(1, 2, 3), + 1.0, + ) + assert_not_contains(code, "VectorValueType") + assert_contains(code, "typealias _t_v = vector;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_matrix_uses_valuetype(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + code = generate_code( + device, + "ident", + "float4x4 ident(float4x4 m) { return m; }", + spy.math.float4x4.identity(), + ) + assert_not_contains(code, "ValueType>") + assert_contains(code, "typealias _t_m = matrix;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_array_dim0_uses_valuetype(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + code = generate_code( + device, + "process", + "void process(float a[4]) { }", + [1.0, 2.0, 3.0, 4.0], + ) + assert_not_contains(code, "ValueType<") + assert_contains(code, "typealias _t_a = ") + + +# -- Step 1.5: ValueRef direct binding -- + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_valueref_read_uses_wrapper(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + code = generate_code( + device, + "read_val", + "float read_val(float v) { return v; }", + ValueRef(1.0), + ) + # Read-only ValueRef now uses raw type alias, not ValueRef + assert_not_contains(code, "ValueRef") + assert_contains(code, "typealias _t_v = float;") + # Direct assignment in trampoline + assert_trampoline_has(code, "v = __calldata__.v;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_valueref_write_uses_wrapper(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + code = generate_code( + device, + "add", + "int add(int a, int b) { return a + b; }", + 1, + 2, + ) + # Auto-created _result uses RWStructuredBuffer instead of RWValueRef + assert_not_contains(code, "RWValueRef") + assert_contains(code, "RWStructuredBuffer") + # Trampoline uses buffer load/store + assert_trampoline_has(code, "_result = __calldata__._result[0];") + + +# -- Step 1.7: Mapping constants and context.map -- + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_mapping_constants_present(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + code = generate_code( + device, + "add", + "int add(int a, int b) { return a + b; }", + 1, + 2, + ) + # Direct-bind variables no longer emit mapping constants + assert_not_contains( + code, + "static const int _m_a = 0", + "static const int _m_b = 0", + "static const int _m__result = 0", + ) + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_context_map_in_trampoline(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + code = generate_code( + device, + "add", + "int add(int a, int b) { return a + b; }", + 1, + 2, + ) + # Direct-bind variables don't use context.map + assert_not_contains(code, "__slangpy_context__.map(_m_a)") + + +# -- Step 1.4: Struct / dict direct binding -- + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_uses_slangpy_load(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +float sum(S s) { return s.x + s.y; } +""" + code = generate_code(device, "sum", src, {"_type": "S", "x": 1.0, "y": 2.0}) + # Direct-bind struct: uses raw type alias, no inline struct with __slangpy_load + assert_not_contains(code, "__slangpy_load") + assert_contains(code, "typealias _t_s = S;") + # Direct assignment in trampoline + assert_trampoline_has(code, "s = __calldata__.s;") + + +# -- Step 1.8: Autodiff gating -- + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_bwds_scalar_uses_valuetype(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + src = """ +[Differentiable] +float polynomial(float a, float b) { + return a * a + b + 1; +} +""" + code = generate_bwds_code(device, "polynomial", src, 5.0, 10.0, 26.0) + # bwds still uses direct bind for primals; check differentiable markers remain + assert_not_contains(code, "ValueType") + assert_contains(code, "[Differentiable]", "bwd_diff(_trampoline)") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_bwds_trampoline_is_differentiable(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + src = """ +[Differentiable] +float polynomial(float a, float b) { + return a * a + b + 1; +} +""" + code = generate_bwds_code(device, "polynomial", src, 5.0, 10.0, 26.0) + # [Differentiable] should appear before the trampoline function + diff_idx = code.index("[Differentiable]") + trampoline_idx = code.index("void _trampoline") + assert diff_idx < trampoline_idx + + +# =========================================================================== +# Phase 1 negative gates — must REMAIN passing after Phase 1 +# =========================================================================== + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_wanghasharg_uses_wrapper(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + src = "uint3 rng(uint3 input) { return input; }" + code = generate_code(device, "rng", src, WangHashArg(3)) + assert_contains(code, "WangHashArg<") + # WangHashArg uses wrapper type. Check the type alias is present. + assert_contains(code, "_t_input") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_vectorized_scalar_keeps_wrapper(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + src = "float square(float x) { return x * x; }" + tensor = Tensor.from_numpy( + helpers.get_device(device_type), np.array([1, 2, 3], dtype=np.float32) + ) + code = generate_code(device, "square", src, tensor) + # Vectorized (dim > 0) — tensor marshall used, __slangpy_load still present + assert_contains(code, "__slangpy_load") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_vectorized_dict_keeps_struct_load(device_type: spy.DeviceType): + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +void apply(S s, float scale) {} +""" + tensor_x = Tensor.from_numpy( + helpers.get_device(device_type), np.array([1, 2, 3], dtype=np.float32) + ) + tensor_y = Tensor.from_numpy( + helpers.get_device(device_type), np.array([4, 5, 6], dtype=np.float32) + ) + code = generate_code(device, "apply", src, {"_type": "S", "x": tensor_x, "y": tensor_y}, 1.0) + # Children are vectorized (dim > 0) — should keep inline struct with __slangpy_load + assert_contains(code, "__slangpy_load") + + +# =========================================================================== +# Phase 1 functional dispatch tests — verify GPU results are correct +# =========================================================================== + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_scalar_add(device_type: spy.DeviceType): + """Dispatch scalar add with direct binding and verify GPU result.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, "add", "int add(int a, int b) { return a + b; }" + ) + result = func(3, 7) + assert result == 10 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_float_mul(device_type: spy.DeviceType): + """Dispatch float multiply with direct binding.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, "mymul", "float mymul(float x, float y) { return x * y; }" + ) + result = func(3.0, 4.0) + assert abs(result - 12.0) < 1e-5 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_vector_scale(device_type: spy.DeviceType): + """Dispatch vector scale with direct binding.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, "scale", "float3 scale(float3 v, float s) { return v * s; }" + ) + result = func(spy.math.float3(1, 2, 3), 2.0) + assert result.x == 2.0 + assert result.y == 4.0 + assert result.z == 6.0 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_struct_sum(device_type: spy.DeviceType): + """Dispatch struct sum via dict with direct binding.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +float sum(S s) { return s.x + s.y; } +""" + func = helpers.create_function_from_module(device, "sum", src) + result = func({"_type": "S", "x": 3.0, "y": 7.0}) + assert abs(result - 10.0) < 1e-5 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_valueref_write(device_type: spy.DeviceType): + """Dispatch with explicit ValueRef output and read back.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, "add", "int add(int a, int b) { return a + b; }" + ) + out = ValueRef(0) + func(5, 8, _result=out) + assert out.value == 13 + + if __name__ == "__main__": pytest.main([__file__, "-vs"]) diff --git a/src/slangpy_ext/utils/slangpyvalue.cpp b/src/slangpy_ext/utils/slangpyvalue.cpp index bdee96ce7..f6ff0a454 100644 --- a/src/slangpy_ext/utils/slangpyvalue.cpp +++ b/src/slangpy_ext/utils/slangpyvalue.cpp @@ -16,7 +16,7 @@ void NativeValueMarshall::ensure_cached(ShaderCursor cursor, NativeBoundVariable { if (m_cached.is_valid) return; - ShaderCursor field = cursor[binding->variable_name()]["value"]; + ShaderCursor field = m_direct_bind ? cursor[binding->variable_name()] : cursor[binding->variable_name()]["value"]; m_cached.value_offset = field.offset(); m_cached.value_type_layout = field.slang_type_layout(); m_cached.writer = get_shader_cursor_writer(m_cached.value_type_layout); @@ -63,5 +63,11 @@ SGL_PY_EXPORT(utils_slangpy_value) new (&self) NativeValueMarshall(); }, D_NA(NativeValueMarshall, NativeValueMarshall) + ) + .def_prop_rw( + "direct_bind", + &NativeValueMarshall::direct_bind, + &NativeValueMarshall::set_direct_bind, + D_NA(NativeValueMarshall, direct_bind) ); } diff --git a/src/slangpy_ext/utils/slangpyvalue.h b/src/slangpy_ext/utils/slangpyvalue.h index 43c16805c..1ba086c59 100644 --- a/src/slangpy_ext/utils/slangpyvalue.h +++ b/src/slangpy_ext/utils/slangpyvalue.h @@ -28,17 +28,27 @@ class NativeValueMarshall : public NativeMarshall { nb::list read_back ) const override; + /// When true, the Slang type is a raw value (no "value" sub-field). + /// When false (default), the type has a "value" sub-field (e.g. ValueType). + bool direct_bind() const { return m_direct_bind; } + + /// Set the direct_bind flag. + void set_direct_bind(bool direct_bind) { m_direct_bind = direct_bind; } + private: /// Cached data for fast-path value writing, populated on first dispatch. struct CachedValueWrite { - ShaderOffset value_offset; ///< Offset to the "value" sub-field. - slang::TypeLayoutReflection* value_type_layout = nullptr; ///< Type layout for "value" field. + ShaderOffset value_offset; ///< Offset to the value field. + slang::TypeLayoutReflection* value_type_layout = nullptr; ///< Type layout for value field. std::function writer; ///< Pre-resolved writer fn. bool is_valid = false; }; mutable CachedValueWrite m_cached; + /// Whether the Slang type is raw (no "value" sub-field). + bool m_direct_bind{false}; + /// Populate m_cached on first call by resolving the cursor path and writer function. void ensure_cached(ShaderCursor cursor, NativeBoundVariableRuntime* binding) const; From 40416d62a63b257cb6a2620d8d7fc2aa89ffeab9 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 09:19:37 +0000 Subject: [PATCH 03/41] work on redoing direct_bind logic --- slangpy/bindings/__init__.py | 3 +- slangpy/bindings/boundvariable.py | 141 ++++++++--------------- slangpy/bindings/boundvariableruntime.py | 3 + slangpy/bindings/marshall.py | 11 ++ slangpy/builtin/struct.py | 11 +- slangpy/builtin/value.py | 17 +-- slangpy/builtin/valueref.py | 21 ++-- slangpy/core/calldata.py | 3 + slangpy/core/callsignature.py | 10 ++ src/slangpy_ext/utils/slangpy.cpp | 8 +- src/slangpy_ext/utils/slangpy.h | 7 ++ src/slangpy_ext/utils/slangpyvalue.cpp | 9 +- src/slangpy_ext/utils/slangpyvalue.h | 10 -- 13 files changed, 120 insertions(+), 134 deletions(-) diff --git a/slangpy/bindings/__init__.py b/slangpy/bindings/__init__.py index 0c171a52a..f8bc597c2 100644 --- a/slangpy/bindings/__init__.py +++ b/slangpy/bindings/__init__.py @@ -7,8 +7,7 @@ BoundVariable, BoundCall, BoundVariableException, - is_direct_bind_eligible, - is_direct_bind_recursive, + can_direct_bind_common, ) from slangpy.bindings.boundvariableruntime import BoundVariableRuntime, BoundCallRuntime from slangpy.bindings.codegen import CodeGen, CodeGenBlock diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index 83ab111c6..b6f5f470f 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -145,46 +145,14 @@ def finalize_mappings(self, context: BindContext): arg.finalize_mappings(context) -# Cache of direct-bind-eligible marshall types (populated on first call). -_DIRECT_BIND_TYPES: Optional[tuple[type, ...]] = None - - -def _get_direct_bind_types() -> tuple[type, ...]: - """Lazily import and cache the set of marshall types eligible for direct binding.""" - global _DIRECT_BIND_TYPES - if _DIRECT_BIND_TYPES is None: - from slangpy.builtin.value import ( - ValueMarshall, - ScalarMarshall, - VectorMarshall, - MatrixMarshall, - ) - from slangpy.builtin.array import ArrayMarshall - from slangpy.builtin.valueref import ValueRefMarshall - - _DIRECT_BIND_TYPES = ( - ValueMarshall, - ScalarMarshall, - VectorMarshall, - MatrixMarshall, - ArrayMarshall, - ValueRefMarshall, - ) - return _DIRECT_BIND_TYPES - +def can_direct_bind_common(binding: "BoundVariable") -> bool: + """Common checks for direct binding eligibility. -def is_direct_bind_eligible(binding: "BoundVariable") -> bool: - """Check if a leaf binding can use direct type marshalling (no ValueType wrapper). - - Eligible when: - - dim-0 (call_dimensionality == 0, not None) - - not composite (no children) - - not using a ParameterBlock (PackedArg) - - not inside a non-direct-bind composite (struct children path) - - marshall is a known direct-eligible type + Marshalls call this from their ``can_direct_bind`` method and then + optionally add type-specific logic. :param binding: The bound variable to check. - :return: True if the variable can use direct binding. + :return: True if the common prerequisites for direct binding are met. """ if binding.call_dimensionality is None or binding.call_dimensionality != 0: return False @@ -192,42 +160,7 @@ def is_direct_bind_eligible(binding: "BoundVariable") -> bool: return False if getattr(binding, "create_param_block", False): return False - if getattr(binding, "_force_no_direct_bind", False): - return False - return isinstance(binding.python, _get_direct_bind_types()) - - -def is_direct_bind_recursive(binding: "BoundVariable") -> bool: - """Check if a binding (leaf or composite) can use direct type marshalling. - - For leaves, delegates to :func:`is_direct_bind_eligible`. - For composites (dicts bound to structs), returns True only if dim-0 and every - child is recursively direct-bind eligible, and the vector_type is a concrete - Slang struct (not UnknownType). - - :param binding: The bound variable to check. - :return: True if the variable (and all its children) can use direct binding. - """ - # If this binding is inside a non-direct-bind struct, it must not use direct binding - if getattr(binding, "_force_no_direct_bind", False): - return False - - if binding.children is None: - return is_direct_bind_eligible(binding) - - if binding.call_dimensionality is None or binding.call_dimensionality != 0: - return False - - if getattr(binding, "create_param_block", False): - return False - - # Must have a concrete struct type (not UnknownType) - from slangpy.reflection import UnknownType - - if binding.vector_type is None or isinstance(binding.vector_type, UnknownType): - return False - - return all(is_direct_bind_recursive(child) for child in binding.children.values()) + return True class BoundVariable: @@ -263,6 +196,9 @@ def __init__( #: Is this variable differentiable self.differentiable = False + #: Whether this variable uses direct binding (raw Slang type, no wrapper). + self.direct_bind = False + #: Call dimensionality of this variable. self.call_dimensionality = None @@ -564,6 +500,42 @@ def calculate_differentiability(self, context: BindContext): for child in self.children.values(): child.calculate_differentiability(context) + def calculate_direct_bind(self) -> None: + """Depth-first calculation of direct_bind for the variable tree. + + For composites (dicts), recurses children first, then checks if all + children are direct-bind eligible and the composite has a concrete + Slang struct type. + + For leaves, delegates to the marshall's ``can_direct_bind`` method. + """ + if self.children is not None: + for child in self.children.values(): + child.calculate_direct_bind() + if ( + self.call_dimensionality is not None + and self.call_dimensionality == 0 + and not getattr(self, "create_param_block", False) + and self.vector_type is not None + and all(child.direct_bind for child in self.children.values()) + ): + self.direct_bind = True + else: + # Parent is not direct-bind — children must use wrapper types + # so the parent's generated __slangpy_load/store can call theirs. + for child in self.children.values(): + child._clear_direct_bind() + else: + if self.python is not None and hasattr(self.python, "can_direct_bind"): + self.direct_bind = self.python.can_direct_bind(self) + + def _clear_direct_bind(self) -> None: + """Recursively clear direct_bind on this node and all descendants.""" + self.direct_bind = False + if self.children is not None: + for child in self.children.values(): + child._clear_direct_bind() + def get_input_list(self, args: list["BoundVariable"]): """ Recursively populate flat list of argument nodes @@ -625,34 +597,17 @@ def _calculate_differentiability(self, mode: CallMode): # todo: fwds self.access = (AccessType.none, AccessType.none) - def _set_direct_bind_on_children(self) -> None: - """Recursively set direct_bind flag on all leaf children's NativeValueMarshall.""" - if self.children is None: - from slangpy.core.native import NativeValueMarshall - - if isinstance(self.python, NativeValueMarshall): - self.python.direct_bind = True - return - for child in self.children.values(): - child._set_direct_bind_on_children() - def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): if self.children is not None: cgb = cg.call_data_structs - if is_direct_bind_recursive(self): - # Direct-bind: emit raw type alias and set direct_bind on children + if self.direct_bind: + # Direct-bind: emit raw type alias assert self.vector_type is not None cgb.type_alias(f"_t_{self.variable_name}", self.vector_type.full_name) - self._set_direct_bind_on_children() else: cgb.begin_struct(f"_t_{self.variable_name}") - # Children inside a non-direct-bind struct must not use direct - # binding — the struct's __slangpy_load/store expect wrapper types. - for variable in self.children.values(): - variable._force_no_direct_bind = True - for field, variable in self.children.items(): variable.gen_call_data_code(cg, context, depth + 1) @@ -695,7 +650,7 @@ def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): self.python.gen_calldata(cg.call_data_structs, context, self) # Skip mapping constants for direct-bind variables (they bypass __slangpy_load/store) - if not is_direct_bind_recursive(self): + if not self.direct_bind: if len(self.vector_mapping) > 0: cg.call_data_structs.append_statement( f"static const int[] _m_{self.variable_name} = {{ {','.join([str(x) for x in self.vector_mapping.as_tuple()])} }}" diff --git a/slangpy/bindings/boundvariableruntime.py b/slangpy/bindings/boundvariableruntime.py index d47576b8e..739a1da92 100644 --- a/slangpy/bindings/boundvariableruntime.py +++ b/slangpy/bindings/boundvariableruntime.py @@ -54,6 +54,9 @@ def __init__(self, source: "BoundVariable"): #: Call dimensionality of variable. self.call_dimensionality = source.call_dimensionality + #: Whether this variable uses direct binding. + self.direct_bind = source.direct_bind + # Temp data stored / updated each call. self.shape = Shape(None) diff --git a/slangpy/bindings/marshall.py b/slangpy/bindings/marshall.py index 43b8861ec..348a52f9d 100644 --- a/slangpy/bindings/marshall.py +++ b/slangpy/bindings/marshall.py @@ -133,6 +133,17 @@ def gen_trampoline_store( """ return False + def can_direct_bind(self, binding: "BoundVariable") -> bool: + """ + Whether this marshall supports direct binding for the given variable. + Direct binding emits raw Slang types instead of ValueType wrappers. + Default: False. Override in subclasses to opt in. + + :param binding: The bound variable to check. + :return: True if this marshall supports direct binding for this variable. + """ + return False + def reduce_type(self, context: BindContext, dimensions: int) -> "SlangType": """ Get the slang type for this variable when a given number of dimensions diff --git a/slangpy/builtin/struct.py b/slangpy/builtin/struct.py index 0a8622b90..2831493f3 100644 --- a/slangpy/builtin/struct.py +++ b/slangpy/builtin/struct.py @@ -4,7 +4,7 @@ from slangpy.core.native import Shape, NativeMarshall import slangpy.bindings.typeregistry as tr -from slangpy.bindings import PYTHON_TYPES, BindContext, BoundVariable, is_direct_bind_recursive +from slangpy.bindings import PYTHON_TYPES, BindContext, BoundVariable, can_direct_bind_common from slangpy.reflection import SlangProgramLayout, SlangType, UnknownType, StructType, InterfaceType from slangpy.core.native import AccessType @@ -76,12 +76,17 @@ def resolve_dimensionality( cast(int, binding.children[name].call_dimensionality) for name in self._fields.keys() ) + def can_direct_bind(self, binding: "BoundVariable") -> bool: + if binding.children is not None: + return all(child.direct_bind for child in binding.children.values()) + return can_direct_bind_common(binding) + # A struct type should get a dictionary, and just return that for raw dispatch def gen_trampoline_load( self, cgb: "CodeGenBlock", binding: "BoundVariable", is_entry_point: bool ) -> bool: - if not is_direct_bind_recursive(binding): + if not binding.direct_bind: return False data_name = ( f"_param_{binding.variable_name}" @@ -94,7 +99,7 @@ def gen_trampoline_load( def gen_trampoline_store( self, cgb: "CodeGenBlock", binding: "BoundVariable", is_entry_point: bool ) -> bool: - if not is_direct_bind_recursive(binding): + if not binding.direct_bind: return False if binding.access[0] in (AccessType.write, AccessType.readwrite): data_name = ( diff --git a/slangpy/builtin/value.py b/slangpy/builtin/value.py index c9d770e7c..8f7c59fea 100644 --- a/slangpy/builtin/value.py +++ b/slangpy/builtin/value.py @@ -15,7 +15,7 @@ BoundVariable, BoundVariableRuntime, CodeGenBlock, - is_direct_bind_eligible, + can_direct_bind_common, ) from slangpy.reflection.reflectiontypes import ( BOOL_TYPES, @@ -90,14 +90,16 @@ def has_derivative(self) -> bool: def is_writable(self) -> bool: return False + def can_direct_bind(self, binding: "BoundVariable") -> bool: + return can_direct_bind_common(binding) + # Call data can only be read access to primal, and simply declares it as a variable def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): access = binding.access name = binding.variable_name if access[0] in [AccessType.read, AccessType.readwrite]: assert binding.vector_type is not None - if is_direct_bind_eligible(binding): - self.direct_bind = True + if binding.direct_bind: cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) else: cgb.type_alias(f"_t_{name}", f"ValueType<{binding.vector_type.full_name}>") @@ -107,7 +109,7 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV def gen_trampoline_load( self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool ) -> bool: - if not is_direct_bind_eligible(binding): + if not binding.direct_bind: return False if binding.access[0] not in (AccessType.read, AccessType.readwrite): return False @@ -121,7 +123,7 @@ def gen_trampoline_load( def gen_trampoline_store( self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool ) -> bool: - if not is_direct_bind_eligible(binding): + if not binding.direct_bind: return False # ValueMarshall is read-only — suppress the default store return True @@ -132,7 +134,7 @@ def create_calldata( ) -> Any: access = binding.access if access[0] in [AccessType.read, AccessType.readwrite]: - if self.direct_bind: + if binding.direct_bind: return data return {"value": data} @@ -343,8 +345,7 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV if access[0] in [AccessType.read, AccessType.readwrite]: st = cast(kfr.VectorType, self.slang_type) et = cast(SlangType, st.element_type) - if is_direct_bind_eligible(binding): - self.direct_bind = True + if binding.direct_bind: cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) else: cgb.type_alias(f"_t_{name}", f"VectorValueType<{et.full_name},{st.num_elements}>") diff --git a/slangpy/builtin/valueref.py b/slangpy/builtin/valueref.py index 26d3d0d12..cbe226948 100644 --- a/slangpy/builtin/valueref.py +++ b/slangpy/builtin/valueref.py @@ -18,7 +18,7 @@ CodeGenBlock, ReturnContext, get_or_create_type, - is_direct_bind_eligible, + can_direct_bind_common, ) from slangpy.builtin.value import slang_type_to_return_type from slangpy.reflection.reflectiontypes import SlangType @@ -112,7 +112,6 @@ class ValueRefMarshall(Marshall): def __init__(self, layout: kfr.SlangProgramLayout, value_type: kfr.SlangType): super().__init__(layout) self.value_type = value_type - self._direct_bind = False st = layout.find_type_by_name(f"ValueRef<{value_type.full_name}>") if st is None: @@ -150,6 +149,9 @@ def resolve_dimensionality( ): return len(self.value_type.shape) - len(vector_target_type.shape) + def can_direct_bind(self, binding: "BoundVariable") -> bool: + return can_direct_bind_common(binding) + # Call data can only be read access to primal, and simply declares it as a variable def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): access = binding.access @@ -157,8 +159,7 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV assert access[0] != AccessType.none assert access[1] == AccessType.none assert binding.vector_type is not None - if is_direct_bind_eligible(binding): - self._direct_bind = True + if binding.direct_bind: if access[0] == AccessType.read: cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) else: @@ -175,7 +176,7 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV def gen_trampoline_load( self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool ) -> bool: - if not is_direct_bind_eligible(binding): + if not binding.direct_bind: return False if binding.access[0] == AccessType.none: return False @@ -192,7 +193,7 @@ def gen_trampoline_load( def gen_trampoline_store( self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool ) -> bool: - if not is_direct_bind_eligible(binding): + if not binding.direct_bind: return False if binding.access[0] in (AccessType.write, AccessType.readwrite): if is_entry_point: @@ -209,7 +210,7 @@ def create_calldata( access = binding.access assert access[0] != AccessType.none assert access[1] == AccessType.none - if self._direct_bind and access[0] == AccessType.read: + if binding.direct_bind and access[0] == AccessType.read: return data.value elif access[0] == AccessType.read: return {"value": data.value} @@ -226,7 +227,7 @@ def create_calldata( if access[0] != AccessType.write: cursor[0].write(data.value) cursor.apply() - if self._direct_bind: + if binding.direct_bind: return buffer return {"value": buffer} else: @@ -241,7 +242,7 @@ def create_calldata( data=npdata, usage=BufferUsage.shader_resource | BufferUsage.unordered_access, ) - if self._direct_bind: + if binding.direct_bind: return buffer return {"value": buffer} @@ -259,7 +260,7 @@ def read_calldata( ) -> None: access = binding.access if access[0] in [AccessType.write, AccessType.readwrite]: - buffer = result if self._direct_bind else result["value"] + buffer = result if binding.direct_bind else result["value"] assert isinstance(buffer, Buffer) if isinstance(binding.vector_type, (kfr.StructType, kfr.ArrayType)): cursor = BufferCursor(binding.vector_type.buffer_layout.reflection, buffer) diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index 209ae8ffa..42d016748 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -266,6 +266,9 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): # Calculate differentiability of all variables. calculate_differentiability(context, bindings) + # Calculate direct binding eligibility for all variables. + calculate_direct_binding(bindings) + # Generate code. codegen = CodeGen() generate_code(context, build_info, bindings, codegen) diff --git a/slangpy/core/callsignature.py b/slangpy/core/callsignature.py index b7bb42c3d..369e625e4 100644 --- a/slangpy/core/callsignature.py +++ b/slangpy/core/callsignature.py @@ -152,6 +152,16 @@ def calculate_differentiability(context: BindContext, call: BoundCall): arg.calculate_differentiability(context) +def calculate_direct_binding(call: BoundCall): + """ + Calculate direct binding eligibility for all variables. + """ + for arg in call.args: + arg.calculate_direct_bind() + for arg in call.kwargs.values(): + arg.calculate_direct_bind() + + def calculate_call_dimensionality(signature: BoundCall) -> int: """ Calculate the dimensionality of the call diff --git a/src/slangpy_ext/utils/slangpy.cpp b/src/slangpy_ext/utils/slangpy.cpp index 8d553f83d..ad99da49f 100644 --- a/src/slangpy_ext/utils/slangpy.cpp +++ b/src/slangpy_ext/utils/slangpy.cpp @@ -1543,7 +1543,13 @@ SGL_PY_EXPORT(utils_slangpy) &NativeBoundVariableRuntime::write_raw_dispatch_data, D_NA(NativeBoundVariableRuntime, write_raw_dispatch_data) ) - .def("read_output", &NativeBoundVariableRuntime::read_output, D_NA(NativeBoundVariableRuntime, read_output)); + .def("read_output", &NativeBoundVariableRuntime::read_output, D_NA(NativeBoundVariableRuntime, read_output)) + .def_prop_rw( + "direct_bind", + &NativeBoundVariableRuntime::direct_bind, + &NativeBoundVariableRuntime::set_direct_bind, + D_NA(NativeBoundVariableRuntime, direct_bind) + ); nb::class_(slangpy, "NativeBoundCallRuntime") // .def(nb::init<>(), D_NA(NativeBoundCallRuntime, NativeBoundCallRuntime)) diff --git a/src/slangpy_ext/utils/slangpy.h b/src/slangpy_ext/utils/slangpy.h index e71959ce8..a996c9da1 100644 --- a/src/slangpy_ext/utils/slangpy.h +++ b/src/slangpy_ext/utils/slangpy.h @@ -533,6 +533,12 @@ class NativeBoundVariableRuntime : public Object { /// Set the call dimensionality. void set_call_dimensionality(int call_dimensionality) { m_call_dimensionality = call_dimensionality; } + /// Whether this variable uses direct binding (raw Slang type, no wrapper). + bool direct_bind() const { return m_direct_bind; } + + /// Set the direct_bind flag. + void set_direct_bind(bool direct_bind) { m_direct_bind = direct_bind; } + /// Recursively populate the overall kernel call shape. void populate_call_shape(Shape& call_shape, nb::object value, NativeCallData* error_context); @@ -560,6 +566,7 @@ class NativeBoundVariableRuntime : public Object { int m_call_dimensionality{0}; ref m_vector_type; bool m_is_param_block{false}; + bool m_direct_bind{false}; }; /// Binding information for a call to a compute kernel. Includes a set of positional diff --git a/src/slangpy_ext/utils/slangpyvalue.cpp b/src/slangpy_ext/utils/slangpyvalue.cpp index f6ff0a454..0b3435208 100644 --- a/src/slangpy_ext/utils/slangpyvalue.cpp +++ b/src/slangpy_ext/utils/slangpyvalue.cpp @@ -16,7 +16,8 @@ void NativeValueMarshall::ensure_cached(ShaderCursor cursor, NativeBoundVariable { if (m_cached.is_valid) return; - ShaderCursor field = m_direct_bind ? cursor[binding->variable_name()] : cursor[binding->variable_name()]["value"]; + ShaderCursor field + = binding->direct_bind() ? cursor[binding->variable_name()] : cursor[binding->variable_name()]["value"]; m_cached.value_offset = field.offset(); m_cached.value_type_layout = field.slang_type_layout(); m_cached.writer = get_shader_cursor_writer(m_cached.value_type_layout); @@ -63,11 +64,5 @@ SGL_PY_EXPORT(utils_slangpy_value) new (&self) NativeValueMarshall(); }, D_NA(NativeValueMarshall, NativeValueMarshall) - ) - .def_prop_rw( - "direct_bind", - &NativeValueMarshall::direct_bind, - &NativeValueMarshall::set_direct_bind, - D_NA(NativeValueMarshall, direct_bind) ); } diff --git a/src/slangpy_ext/utils/slangpyvalue.h b/src/slangpy_ext/utils/slangpyvalue.h index 1ba086c59..42ad5ae28 100644 --- a/src/slangpy_ext/utils/slangpyvalue.h +++ b/src/slangpy_ext/utils/slangpyvalue.h @@ -28,13 +28,6 @@ class NativeValueMarshall : public NativeMarshall { nb::list read_back ) const override; - /// When true, the Slang type is a raw value (no "value" sub-field). - /// When false (default), the type has a "value" sub-field (e.g. ValueType). - bool direct_bind() const { return m_direct_bind; } - - /// Set the direct_bind flag. - void set_direct_bind(bool direct_bind) { m_direct_bind = direct_bind; } - private: /// Cached data for fast-path value writing, populated on first dispatch. struct CachedValueWrite { @@ -46,9 +39,6 @@ class NativeValueMarshall : public NativeMarshall { mutable CachedValueWrite m_cached; - /// Whether the Slang type is raw (no "value" sub-field). - bool m_direct_bind{false}; - /// Populate m_cached on first call by resolving the cursor path and writer function. void ensure_cached(ShaderCursor cursor, NativeBoundVariableRuntime* binding) const; From 22a82b7f9f9e10b488ac7f111875e3ed45510c73 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 11:08:07 +0000 Subject: [PATCH 04/41] wip tests to figure out the clear binding info --- .../plan-simplifyKernelGen-phase1.prompt.md | 248 ++++++++---------- .../prompts/plan-simplifyKernelGen.prompt.md | 42 +-- .../tests/slangpy_tests/test_kernel_gen.py | 241 +++++++++++++++++ 3 files changed, 373 insertions(+), 158 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md index bc398df03..09159ceda 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md @@ -8,208 +8,176 @@ --- -### Step 1.1: Define eligibility predicate - -Add a **global Python function** `is_direct_bind_eligible(binding: BoundVariable) -> bool` (e.g., in [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) or a small utility module). This is intentionally NOT a method on `NativeMarshall` — placing it on the C++ side would require nanobind trampoline plumbing for a function that is only consumed during Python-side codegen. A simple Python function avoids all C++/nanobind complexity. +### Architecture -The conditions for leaf types are: -- `binding.call_dimensionality is not None and binding.call_dimensionality == 0` (note: `call_dimensionality` is initialized to `None`, so a `None` check is required) -- `not binding.children` (not composite/dict) -- The marshall's Python type is one of the known direct-eligible types (`ValueMarshall`, `ScalarMarshall`, `VectorMarshall`, `MatrixMarshall`, `ArrayMarshall`, `ValueRefMarshall`). Types like `WangHashArgMarshall` are excluded. +Direct binding eligibility is determined by a **marshall-driven `can_direct_bind` property** combined with a **single depth-first `calculate_direct_bind` pass** on the `BoundVariable` tree. This follows the same pattern as `calculate_differentiability`. -Individual marshalls call this function inside their `gen_calldata`, `gen_trampoline_load`, `gen_trampoline_store`, and `create_calldata` methods to decide which codegen path to take. +#### Key components -Add a companion function `is_direct_bind_recursive(binding: BoundVariable) -> bool` that handles composite types: -- If `binding.children is None`: delegates to `is_direct_bind_eligible(binding)` -- If `binding.children is not None`: returns `True` only if `binding.call_dimensionality is not None and binding.call_dimensionality == 0` AND every child's `is_direct_bind_recursive()` returns `True`. This handles dicts bound to Slang structs where all fields are dim-0 leaves (or recursively dim-0 structs). -- Additionally, the `binding.vector_type` must be a concrete Slang struct type (not `UnknownType`). Dicts without `_type` may resolve to `UnknownType` and are ineligible. +| Component | Location | Role | +|-----------|----------|------| +| `Marshall.can_direct_bind(binding)` | `slangpy/bindings/marshall.py` | Virtual method (default `False`). Marshalls override to opt in. | +| `can_direct_bind_common(binding)` | `slangpy/bindings/boundvariable.py` | Shared eligibility checks (dim-0, no children, no param block). Marshalls call this then add type-specific logic. | +| `BoundVariable.direct_bind` | `slangpy/bindings/boundvariable.py` | Boolean attribute set by `calculate_direct_bind()`. Consumed by `gen_call_data_code`, `gen_calldata`, `gen_trampoline_load/store`, `create_calldata`. | +| `BoundVariable.calculate_direct_bind()` | `slangpy/bindings/boundvariable.py` | Depth-first tree pass. Leaves delegate to `marshall.can_direct_bind()`. Composites require all children to be direct-bind AND dim-0 with a concrete vector type. If composite is NOT direct-bind, recursively clears children via `_clear_direct_bind()`. | +| `calculate_direct_binding(call)` | `slangpy/core/callsignature.py` | Top-level function iterating `call.args` + `call.kwargs.values()`, calling `arg.calculate_direct_bind()`. | +| `NativeBoundVariableRuntime.direct_bind` | `slangpy.h` / `boundvariableruntime.py` | C++ member + Python propagation. Read by `NativeValueMarshall::ensure_cached` to gate `["value"]` sub-field navigation. | -Both functions are consulted by the marshalls and by `gen_call_data_code` (for struct/dict bindings). - ---- +#### Control flow -### Step 1.2: Implement for `ValueMarshall` (scalars/matrices) +``` +CallData.build() + → calculate_differentiability(context, bindings) + → calculate_direct_binding(bindings) ← NEW + → generate_code(...) + → gen_call_data_code() — reads binding.direct_bind + → gen_trampoline() — reads binding.direct_bind + → BoundCallRuntime(bindings) — propagates binding.direct_bind to C++ runtime +``` -In [slangpy/builtin/value.py](slangpy/builtin/value.py): +At dispatch time, `NativeValueMarshall::ensure_cached()` reads `binding->direct_bind()` to decide cursor navigation: +- `direct_bind == false`: `cursor[variable_name]["value"]` (wrapper path) +- `direct_bind == true`: `cursor[variable_name]` (raw type path) -- Modify `gen_calldata`: call `is_direct_bind_eligible(binding)`. When eligible, emit `typealias _t_{name} = {raw_slang_type};` instead of `ValueType<{type}>` -- Add `gen_trampoline_load`: when direct-eligible, emit `{name} = {data_name};` and return `True` -- Add `gen_trampoline_store`: when direct-eligible (read-only scalars), return `True` (suppress default store, no-op) -- Modify `create_calldata`: when direct-eligible, return raw value instead of `{"value": data}`. The cursor write system in [cursor_utils.h](src/slangpy_ext/device/cursor_utils.h) already handles writing scalars/vectors/matrices directly — the `write_internal` method dispatches on `TypeReflection::Kind::scalar/vector/matrix`. +#### Composite (struct/dict) handling -#### Step 1.2a: Critical C++ change — `NativeValueMarshall` fast path +When `calculate_direct_bind()` visits a composite node: +1. Recurse children first (depth-first) +2. If all children have `direct_bind == True` AND the composite is dim-0 with a concrete vector type → set `self.direct_bind = True` +3. Otherwise → call `_clear_direct_bind()` on all children, forcing them to use wrapper types. This is necessary because the parent's generated `__slangpy_load`/`__slangpy_store` expects children to have wrapper types (e.g., `ValueType`). A child emitting raw `float` inside a parent that emits `__slangpy_load` would produce invalid Slang. -`NativeValueMarshall::write_shader_cursor_pre_dispatch` in [slangpyvalue.cpp](src/slangpy_ext/utils/slangpyvalue.cpp) has a cached fast path that navigates `cursor[variable_name]["value"]` on first call: +--- -```cpp -ShaderCursor field = cursor[binding->variable_name()]["value"]; -m_cached.value_offset = field.offset(); -``` +### Step 1.1: Define eligibility predicate -If the Slang type changes from `ValueType` (which has a `value` sub-field) to raw `int` (a scalar with no sub-fields), the `["value"]` navigation will crash. This affects **all** `NativeValueMarshall` subclasses: `ValueMarshall`, `VectorMarshall`, `MatrixMarshall`, `StructMarshall`, `ArrayMarshall`. +**Implemented.** A `can_direct_bind(binding)` virtual method on `Marshall` (default `False`) replaces the original `is_direct_bind_eligible` / `is_direct_bind_recursive` global functions. Each marshall subclass overrides `can_direct_bind` to opt in. -**Required fix**: Add a `direct_bind` flag to `NativeValueMarshall` (set from the Python side when `can_direct_bind` returns `True`). In `ensure_cached`, branch on this flag: -- **`direct_bind == false`** (current path): navigate `cursor[variable_name]["value"]` -- **`direct_bind == true`** (new path): navigate `cursor[variable_name]` only (no `"value"` sub-field), and cache the resulting offset/layout/writer directly +A shared helper `can_direct_bind_common(binding)` in `boundvariable.py` provides the common checks: +- `binding.call_dimensionality is not None and binding.call_dimensionality == 0` +- `not binding.children` (not composite/dict) +- `not getattr(binding, "create_param_block", False)` (excludes `PackedArg`) -The flag can be set during `CallData` construction when the binding is finalized, or passed via a `NativeBoundVariableRuntime` property. Alternatively, detect the absence of the `"value"` field by checking the type layout — but an explicit flag is safer and clearer. +Marshall subclasses call `can_direct_bind_common(binding)` and optionally add type-specific logic. `StructMarshall` has its own implementation: if it has children, all children must have `direct_bind == True`; otherwise it delegates to `can_direct_bind_common`. --- -### Step 1.3: Implement for `VectorMarshall`, `MatrixMarshall`, and `ArrayMarshall` +### Step 1.2: Implement for `ValueMarshall` (scalars/matrices) -In [slangpy/builtin/value.py](slangpy/builtin/value.py): -- `VectorMarshall`: same pattern as `ValueMarshall`. `gen_calldata` emits `typealias _t_{name} = {vector_type};` instead of `VectorValueType<{et},{n}>`. -- `MatrixMarshall`: same pattern. Note that `MatrixMarshall` does **not** override `gen_calldata` — it inherits `ValueMarshall.gen_calldata` which emits `ValueType<{matrix_type}>` (not `MatrixValueType`). The `MatrixValueType<...>` name only appears in `resolve_types` for the experimental vectorization path. The direct-bind override goes on `gen_calldata` and emits the raw matrix type (e.g., `float4x4`) instead of `ValueType`. +**Implemented.** In [slangpy/builtin/value.py](slangpy/builtin/value.py): -In [slangpy/builtin/array.py](slangpy/builtin/array.py): -- `ArrayMarshall`: at dim-0, it already falls through to `super().gen_calldata()` (i.e. `ValueMarshall`) which uses `ValueType`. The same direct-bind pattern applies — emit the raw array type instead of wrapping in `ValueType`. +- `can_direct_bind(binding)`: calls `can_direct_bind_common(binding)` +- `gen_calldata`: when `binding.direct_bind`, emits `typealias _t_{name} = {raw_slang_type}` instead of `ValueType<{type}>` +- `gen_trampoline_load`: when `binding.direct_bind`, emits `{name} = {data_name}` and returns `True` +- `gen_trampoline_store`: when `binding.direct_bind` (read-only), returns `True` (suppress default store) +- `create_calldata`: when `binding.direct_bind`, returns raw value instead of `{"value": data}` ---- +#### Step 1.2a: C++ fast path -### Step 1.4: Implement for `StructMarshall` (dict → struct) +**Implemented.** `NativeValueMarshall::ensure_cached` in [slangpyvalue.cpp](src/slangpy_ext/utils/slangpyvalue.cpp) reads `binding->direct_bind()` from the `NativeBoundVariableRuntime`: -In [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py): +```cpp +ShaderCursor field = binding->direct_bind() + ? cursor[binding->variable_name()] + : cursor[binding->variable_name()]["value"]; +``` -When a Python dict is bound to a Slang struct and `is_direct_bind_recursive(binding)` returns `True` (all children are dim-0 and direct-eligible recursively), the **Slang-side** struct can bypass the inline `__slangpy_load`/`__slangpy_store` struct generation. The **Python/C++ side** keeps the existing tree of marshalls unchanged — they continue to recurse through children and cache offsets for efficient cursor writes. +The `direct_bind` flag is a `bool` member on `NativeBoundVariableRuntime` (declared in [slangpy.h](src/slangpy_ext/utils/slangpy.h)), exposed via nanobind property in [slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp), and propagated from `BoundVariable.direct_bind` via [boundvariableruntime.py](slangpy/bindings/boundvariableruntime.py). -This is a Slang-code-gen-only simplification: -- **Current behavior (children path in `gen_call_data_code`)**: generates an inline struct `_t_{name}` with field declarations, `__slangpy_load`/`__slangpy_store` methods, and mapping constants — delegates each child's code gen recursively -- **Direct-eligible behavior**: emit `typealias _t_{name} = {vector_type.full_name};` (the raw Slang struct type). Skip generating the inline struct, its load/store methods, child type aliases, and child mapping constants entirely. +The `m_direct_bind` / `direct_bind` / `set_direct_bind` members were **removed** from `NativeValueMarshall` — the flag lives exclusively on `NativeBoundVariableRuntime`. -In the trampoline: -- `gen_trampoline_load`: emit `{name} = {data_name};` (direct struct assignment) and return `True` -- `gen_trampoline_store`: return `True` (suppress default store for read-only structs) +--- -**Python/C++ dispatch — keep the child tree, fix the per-child cursor path**: +### Step 1.3: Implement for `VectorMarshall`, `MatrixMarshall`, and `ArrayMarshall` -The Python-side tree of marshalls is kept for dispatch. When a `BoundVariable` has children (dict case), the C++ dispatch in [slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) `NativeBoundVariableRuntime::write_shader_cursor_pre_dispatch` still takes the children branch: +**Implemented.** All inherit `can_direct_bind` and `gen_trampoline_load`/`gen_trampoline_store` from `ValueMarshall`. `VectorMarshall` overrides `gen_calldata` to emit the raw vector type (e.g., `vector`) instead of `VectorValueType` when `binding.direct_bind`. `MatrixMarshall` and `ArrayMarshall` (at dim-0) inherit `ValueMarshall.gen_calldata`. -```cpp -ShaderCursor child_field = cursor[m_variable_name.c_str()]; -for (const auto& [name, child_ref] : *m_children) { - child_ref->write_shader_cursor_pre_dispatch(context, child_field, child_value, read_back); -} -``` +--- + +### Step 1.4: Implement for `StructMarshall` (dict → struct) -Each child leaf calls `NativeValueMarshall::write_shader_cursor_pre_dispatch`, which navigates `cursor[variable_name]["value"]`. If the Slang struct type changes from the inline struct (where each child field is `ValueType` with a `value` sub-field) to the raw Slang struct (where each child field is `float` directly), the `["value"]` navigation will crash. +**Implemented.** In [slangpy/builtin/struct.py](slangpy/builtin/struct.py): -**Solution**: Set the `direct_bind` flag from Step 1.2a on each child's `NativeValueMarshall`. The per-child flag causes each leaf's `ensure_cached` to navigate `cursor[variable_name]` only (no `["value"]` sub-field). This is the same fix as Step 1.2a applied to each child — no changes to the children dispatch path itself are needed. +- `can_direct_bind(binding)`: if `binding.children is not None`, returns `True` only if all children have `direct_bind == True`. Otherwise delegates to `can_direct_bind_common(binding)`. +- `gen_trampoline_load`: when `binding.direct_bind`, emits `{name} = {data_name}` and returns `True` +- `gen_trampoline_store`: when `binding.direct_bind`, emits `{data_name} = {name}` for writable and returns `True` -`StructMarshall.create_calldata` is dead code for the children path: when `m_children` is set on the C++ `NativeBoundVariableRuntime`, the children dispatch branch runs instead of calling the marshall's `create_calldata`. The current `ValueMarshall.create_calldata` (which `StructMarshall` inherits) is never called for dict bindings. It can be removed from `StructMarshall` if desired, but is harmless. +In [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py), `gen_call_data_code`: +- When `self.direct_bind`, emits `typealias _t_{name} = {vector_type.full_name}` (raw struct type) — skipping inline struct generation, `__slangpy_load`/`__slangpy_store`, and child type aliases. +- When NOT `self.direct_bind`, uses the standard children path with inline struct. -**Complexity considerations:** -- The recursive eligibility check must traverse all children. Nested dicts (struct-of-struct) work if all leaves are direct-eligible. -- The `vector_type` on the `BoundVariable` must be a concrete Slang struct type (not `UnknownType`). If the dict has `_type` specified, the struct type is resolved; if not, it may be `UnknownType` and ineligible. -- Writable struct fields (inout/out parameters) need the same treatment as writable scalars — the struct in CallData stays as the raw type, but the trampoline does direct assignment both ways. -- This optimization can be deferred if it proves too complex for the initial Phase 1 implementation — the fallback (current inline struct with load/store) always works. Priority should be leaf types first. +Children inside non-direct-bind composites have their `direct_bind` cleared by `_clear_direct_bind()` during `calculate_direct_bind`. This ensures children use wrapper types compatible with the parent's `__slangpy_load`/`__slangpy_store`. --- ### Step 1.5: Implement for `ValueRefMarshall` -In [slangpy/builtin/valueref.py](slangpy/builtin/valueref.py): +**Implemented.** In [slangpy/builtin/valueref.py](slangpy/builtin/valueref.py): -Note: There is only **one** `ValueRefMarshall` class (not separate `ValueRef`/`RWValueRef` classes). It inherits from `Marshall` (not `NativeValueMarshall`). Read vs. write behavior is determined by `binding.access` at codegen time — the same class emits `ValueRef` or `RWValueRef` depending on access mode. +- `can_direct_bind(binding)`: calls `can_direct_bind_common(binding)` +- `gen_calldata`: when `binding.direct_bind`, read-only emits raw type, writable emits `RWStructuredBuffer` +- `gen_trampoline_load/store`: when `binding.direct_bind`, read-only does direct assignment, writable does `[0]` indexing +- `create_calldata` / `read_calldata`: when `binding.direct_bind`, skip `{"value": ...}` wrapper -- In `gen_calldata`, call `is_direct_bind_eligible(binding)`. When eligible: -- Read-only path (`access[0] == AccessType.read`): `gen_calldata` emits raw type, `gen_trampoline_load` does direct assignment, `create_calldata` returns raw value -- Writable path (`access[0] != AccessType.read`): `gen_calldata` emits `RWStructuredBuffer<{type}>`, `gen_trampoline_load` emits `{name} = {data_name}[0];`, `gen_trampoline_store` emits `{data_name}[0] = {name};`, `create_calldata` returns the buffer directly (no `{"value": buffer}` wrapper). Note: `RWStructuredBuffer` is a **resource type** in Slang — the cursor write system handles it via the resource binding mechanism. Buffer objects are written to resource-typed cursor fields via the `write_value` virtual path in [cursor_utils.h](src/slangpy_ext/device/cursor_utils.h), not the struct/scalar dispatch in `write_internal`. - -Since `ValueRefMarshall` extends `Marshall` (not `NativeValueMarshall`), the C++ fast path issue from Step 1.2a does not apply — `NativeMarshall::write_shader_cursor_pre_dispatch` calls `create_calldata` and then passes the result to the generic `write_shader_cursor(cursor, cd_val)`, which dispatches based on the Slang type layout. There is no cached `["value"]` navigation. +The old `self._direct_bind` attribute was **removed** — all checks now use `binding.direct_bind`. --- ### Step 1.6: Implement for tensor marshalls -In [slangpy/builtin/tensorcommon.py](slangpy/builtin/tensorcommon.py): The `TensorView`/`DiffTensorView` case already works via direct assignment. - -For `Tensor` (the slangpy Tensor type): this is a **complex struct** containing `_data` (a `StructuredBuffer`/pointer), `_shape[D]`, `_strides[D]`, and `_offset`. It is NOT a simple assignable type like a scalar — it always requires its buffer resource descriptor and metadata to be bound. However, since it is already a well-defined Slang struct, it can still use direct assignment (`name = call_data.name;`) in the trampoline when dim-0. The `gen_calldata` already emits the correct tensor type name. Add `gen_trampoline_load` to handle `ITensorType` dim-0 with direct assignment (same as TensorView pattern — the struct is copied as a whole). +**Implemented.** In [slangpy/builtin/tensorcommon.py](slangpy/builtin/tensorcommon.py): -Note: `Tensor` cannot be simplified to a raw value the way scalars can — it stays as a struct in CallData. The simplification here is only at the trampoline level (bypassing `__slangpy_load`/`__slangpy_store`). +`gen_trampoline_load/store` extended for `ITensorType` at dim-0 (direct struct assignment). Tensor marshalls do NOT implement `can_direct_bind` — tensor dim-0 handling is done via trampoline-level checks on `binding.call_dimensionality` and `binding.vector_type` type, independent of the `direct_bind` flag. --- ### Step 1.7: Eliminate unused boilerplate in code generation -In [slangpy/core/callsignature.py](slangpy/core/callsignature.py): - -- **Mapping constants**: In `BoundVariable.gen_call_data_code()` ([slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py)), skip emitting `static const int _m_{name} = 0;` when `is_direct_bind_eligible(self)` (for leaves) or `is_direct_bind_recursive(self)` (for composites) returns `True`. These constants are only consumed by `__slangpy_context__.map(_m_{name})` calls, which direct-bound variables skip. -- **`import "slangpy"`**: Keep this import. Attempting to detect and eliminate it provides negligible benefit for significant complexity. The slangpy Slang module is always available and the link-time constants are always emitted. The focus of this phase is eliminating wrapper types and `__slangpy_load`/`__slangpy_store` indirection, not the import. +**Implemented.** In [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py), `gen_call_data_code` skips emitting `static const int _m_{name} = 0` mapping constants when `self.direct_bind` is `True`. --- ### Step 1.8: Handle autodiff (bwds mode) -For differentiable types in bwds mode: -- Primal reads are still direct-eligible (just a direct assignment) -- Derivative writes need writable backing — use `RWStructuredBuffer` for derivative fields (similar to `RWValueRef` pattern) -- The trampoline must remain `[Differentiable]`, but direct assignment `a = call_data.a;` is trivially differentiable in Slang -- The `gen_trampoline_load`/`gen_trampoline_store` implementations need to account for `access[1]` (derivative access) and emit derivative load/store code when needed -- This is the most complex part of Phase 1; consider implementing prim-mode direct binding first, then extending to bwds +⬜ **Deferred.** Prim-mode direct binding applies to bwds primals (code gen verified), but derivative fields still use the old `ValueType` wrapper path. --- ### Step 1.9: Tests -Extend [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py). All tests use `generate_code()` which calls `func.debug_build_call_data(*args, **kwargs)` and returns `cd.code`. Tests are parametrized across `helpers.DEFAULT_DEVICE_TYPES`. - -**Assertion helpers** (added to test file): -- `assert_contains(code, *patterns)` — assert all patterns appear in generated code -- `assert_not_contains(code, *patterns)` — assert none appear - -**Gating tests** — assert CURRENT behavior so they break when each step is implemented: - -| Test | Slang Source | Args | Asserts (current behavior) | Breaks when | -|------|-------------|------|---------------------------|-------------| -| `test_gate_scalar_uses_valuetype` | `int add(int a, int b) { return a + b; }` | `(1, 2)` | `ValueType` present, `__slangpy_load` for `a`/`b`, `__slangpy_store` for `_result` | Step 1.2 | -| `test_gate_float_scalar_uses_valuetype` | `float mul(float x, float y) { return x * y; }` | `(1.0, 2.0)` | `ValueType` present, `__slangpy_load` present | Step 1.2 | -| `test_gate_vector_uses_vectorvaluetype` | `float3 scale(float3 v, float s) { return v * s; }` | `(spy.math.float3(1,2,3), 1.0)` | `VectorValueType` for `v` (no space after comma), `ValueType` for `s` | Step 1.3 | -| `test_gate_matrix_uses_valuetype` | `float4x4 ident(float4x4 m) { return m; }` | `(spy.math.float4x4.identity(),)` | `ValueType` present | Step 1.3 | -| `test_gate_valueref_read_uses_wrapper` | `float read_val(float v) { return v; }` | `(spy.ValueRef(1.0),)` | `ValueRef` present, `__slangpy_load` present | Step 1.5 | -| `test_gate_valueref_write_uses_wrapper` | `int add(int a, int b) { return a + b; }` | `(1, 2)` (auto `_result`) | `RWValueRef` for `_result`, `__slangpy_store` present | Step 1.5 | -| `test_gate_array_dim0_uses_valuetype` | `void process(float a[4]) { }` | `([1.0, 2.0, 3.0, 4.0],)` | `ValueType<` present for array binding | Step 1.3 | -| `test_gate_mapping_constants_present` | `int add(int a, int b) { return a + b; }` | `(1, 2)` | `static const int _m_a = 0` and `_m_b` and `_m__result` present | Step 1.7 | -| `test_gate_context_map_in_trampoline` | `int add(int a, int b) { return a + b; }` | `(1, 2)` | `__slangpy_context__.map(_m_a)` in trampoline | Step 1.7 | -| `test_gate_struct_uses_slangpy_load` | `struct S { float x; float y; }; float sum(S s) { return s.x + s.y; }` | `({"x": 1.0, "y": 2.0},)` | inline struct `_t_s` with `__slangpy_load`, child mapping constants `_m_x`, `_m_y` | Step 1.4 | - -**Negative gates** — should REMAIN passing after Phase 1 (these types are NOT direct-bind eligible): - -| Test | Slang Source | Args | Asserts (must stay) | -|------|-------------|------|--------------------| -| `test_gate_wanghasharg_uses_wrapper` | `int rng(WangHashArg rng) { return 0; }` | `(spy.WangHashArg(1),)` | `WangHashArg<` in type alias, `__slangpy_load` present | -| `test_gate_vectorized_scalar_keeps_wrapper` | `float square(float x) { return x * x; }` | `(Tensor.numpy(np.array([1,2,3], dtype=np.float32)),)` | `ValueType` present (dim > 0, not direct-eligible) | -| `test_gate_vectorized_dict_keeps_struct_load` | `struct S { float x; float y; }; void apply(S s, float scale) {}` | `({"x": Tensor(...), "y": Tensor(...)}, 1.0)` | inline struct with `__slangpy_load` (children are vectorized, dim > 0) | - -**Autodiff gating tests:** - -| Test | Slang Source | Args | Asserts | -|------|-------------|------|---------| -| `test_gate_bwds_scalar_uses_valuetype` | `[Differentiable] float square(float x) { return x * x; }` | `func.bwds.debug_build_call_data(diffPair(2.0), diffPair(d=1.0))` | `ValueType` present, `[Differentiable]` on trampoline, `bwd_diff(_trampoline)` in kernel | -| `test_gate_bwds_trampoline_is_differentiable` | same as above | same | `[Differentiable]` appears before `void _trampoline` | - -**Post-implementation tests** — should pass AFTER Phase 1 is complete: - -- `test_phase1_scalar_direct_bind`: verify NO `ValueType` or `__slangpy_load` for scalar args -- `test_phase1_vector_direct_bind`: verify NO `VectorValueType` for vector args -- `test_phase1_valueref_direct_bind`: verify `RWStructuredBuffer` appears directly for writable result -- `test_phase1_struct_direct_bind`: verify NO inline struct with `__slangpy_load` for dim-0 dict-to-struct -- `test_phase1_no_mapping_constants`: verify NO `_m_a`, `_m_b` for direct-bound args -- `test_phase1_functional_scalar_add`: dispatch `add(1, 2)` and verify result == 3 -- `test_phase1_functional_vector_scale`: dispatch vector scale and verify result -- `test_phase1_functional_struct_sum`: dispatch struct sum via dict and verify result +**Implemented.** 21 tests × 3 device types = 63 cases. All pass on d3d12/vulkan/cuda. --- -### Implementation Order Within Phase 1 +### Files Modified + +| File | Changes | +|------|---------| +| `src/slangpy_ext/utils/slangpy.h` | `m_direct_bind` member, `direct_bind()` getter, `set_direct_bind()` setter on `NativeBoundVariableRuntime` | +| `src/slangpy_ext/utils/slangpy.cpp` | Nanobind `direct_bind` property on `NativeBoundVariableRuntime` | +| `src/slangpy_ext/utils/slangpyvalue.h` | `m_direct_bind`, `direct_bind()`, `set_direct_bind()` **removed** from `NativeValueMarshall` | +| `src/slangpy_ext/utils/slangpyvalue.cpp` | `ensure_cached` reads `binding->direct_bind()` instead of `m_direct_bind`; nanobind `direct_bind` property **removed** from `NativeValueMarshall` | +| `slangpy/bindings/marshall.py` | `can_direct_bind(binding)` virtual method (default `False`) | +| `slangpy/bindings/boundvariable.py` | `can_direct_bind_common()`, `BoundVariable.direct_bind` attribute, `BoundVariable.calculate_direct_bind()`, `BoundVariable._clear_direct_bind()`. Old functions removed: `is_direct_bind_eligible`, `is_direct_bind_recursive`, `_set_direct_bind_on_children`, `_force_no_direct_bind`, `_DIRECT_BIND_TYPES`. | +| `slangpy/bindings/boundvariableruntime.py` | `self.direct_bind = source.direct_bind` propagation | +| `slangpy/bindings/__init__.py` | Exports `can_direct_bind_common` (removed `is_direct_bind_eligible`, `is_direct_bind_recursive`) | +| `slangpy/core/callsignature.py` | `calculate_direct_binding(call)` function | +| `slangpy/core/calldata.py` | `calculate_direct_binding(bindings)` call after `calculate_differentiability` | +| `slangpy/builtin/value.py` | `can_direct_bind`, `gen_calldata`, `gen_trampoline_load`, `gen_trampoline_store`, `create_calldata` use `binding.direct_bind`. Removed `self.direct_bind` on marshall. | +| `slangpy/builtin/valueref.py` | `can_direct_bind`, `gen_calldata`, `gen_trampoline_load`, `gen_trampoline_store`, `create_calldata`, `read_calldata` use `binding.direct_bind`. Removed `self._direct_bind`. | +| `slangpy/builtin/struct.py` | `can_direct_bind`, `gen_trampoline_load`, `gen_trampoline_store` use `binding.direct_bind` | +| `slangpy/builtin/tensorcommon.py` | `gen_trampoline_load`, `gen_trampoline_store` extended for `ITensorType` dim-0 (unchanged in refactor) | +| `slangpy/tests/slangpy_tests/test_kernel_gen.py` | All Phase 1 tests | + +### Test Results + +2952 passed / 0 failed in `slangpy/tests/slangpy_tests`. 6 pre-existing failures in `slangpy/tests/device/` (raytracing pipeline, type conformance cache — unrelated). + +### Design Decisions + +**`direct_bind` lives on `NativeBoundVariableRuntime`, not `NativeValueMarshall`.** The original implementation stored `m_direct_bind` on the marshall itself (`NativeValueMarshall`), but marshalls are shared across calls while bindings are per-call. Moving the flag to the binding makes it immutable per-call and eliminates mutable state on shared marshall instances. -To avoid C++ crashes from the `NativeValueMarshall` fast path (Step 1.2a), the implementation order within Phase 1 must be: +**Marshall-driven `can_direct_bind` replaces hardcoded type list.** The original `is_direct_bind_eligible` used a lazily-populated `_DIRECT_BIND_TYPES` tuple to check marshall type. The new design uses a virtual method — each marshall opts in explicitly. Adding a new direct-bind-eligible type requires only overriding `can_direct_bind` on the new class. -1. **Step 1.2a first**: Update `NativeValueMarshall::ensure_cached` in C++ to handle direct-bind types (no `"value"` sub-field navigation). This is the only C++ change needed — `is_direct_bind_eligible` is pure Python, no nanobind changes required. -2. **Step 1.1**: Add `is_direct_bind_eligible` and `is_direct_bind_recursive` as global Python functions -3. **Steps 1.2–1.7**: Implement Python-side changes for each marshall type -4. **Step 1.4**: For struct children, set the `direct_bind` flag on each child's `NativeValueMarshall` (same per-child flag from Step 1.2a) — no changes to the C++ children dispatch path needed -5. **Step 1.8**: Autodiff support -6. **Step 1.9**: Tests +**Single `calculate_direct_bind` pass replaces repeated predicate calls.** The original `is_direct_bind_eligible` / `is_direct_bind_recursive` were called multiple times per variable during code gen. The new design computes `direct_bind` once in a single tree pass after `calculate_differentiability`, and consumers simply read the boolean. -Never deploy Python-side `gen_calldata` changes that emit raw types without the corresponding C++ fast path fix — the cached `["value"]` navigation will crash at dispatch time. +**`_clear_direct_bind` replaces `_force_no_direct_bind`.** When a composite struct is NOT direct-bind-eligible (e.g., has vectorized children), its children must NOT use direct binding either — the parent's generated `__slangpy_load`/`__slangpy_store` expects children to have wrapper types. The old implementation set `_force_no_direct_bind = True` on children during code gen. The new implementation clears `direct_bind` recursively during the `calculate_direct_bind` pass itself, before code gen runs. diff --git a/.github/prompts/plan-simplifyKernelGen.prompt.md b/.github/prompts/plan-simplifyKernelGen.prompt.md index d95a79e58..4a05f68cb 100644 --- a/.github/prompts/plan-simplifyKernelGen.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen.prompt.md @@ -29,39 +29,45 @@ void compute_main(int3 tid: SV_DispatchThreadID, uniform uint3 _thread_count, un Phase 1 prim-mode direct binding is complete. Steps 1.1–1.7, 1.9 are implemented and passing. Step 1.8 (autodiff/bwds) is deferred. +The implementation was refactored from global predicate functions (`is_direct_bind_eligible`, `is_direct_bind_recursive`) with mutable marshall state (`_force_no_direct_bind`, `_set_direct_bind_on_children`) to a marshall-driven `can_direct_bind` property + single depth-first `calculate_direct_bind` pass on the `BoundVariable` tree, following the `calculate_differentiability` pattern. + **What was done:** | Step | Status | Summary | |------|--------|---------| -| 1.2a | ✅ Done | `NativeValueMarshall` C++ fast path: `m_direct_bind` flag gates `["value"]` sub-field navigation in `ensure_cached`. Exposed via nanobind `direct_bind` property. | -| 1.1 | ✅ Done | `is_direct_bind_eligible()` and `is_direct_bind_recursive()` in `boundvariable.py`. Excludes `PackedArg` bindings and children inside non-direct-bind structs (via `_force_no_direct_bind` flag). | -| 1.2 | ✅ Done | `ValueMarshall`: `gen_calldata` emits raw `typealias`; `gen_trampoline_load/store` do direct assignment; `create_calldata` returns raw value. `ScalarMarshall`/`MatrixMarshall` inherit. | -| 1.3 | ✅ Done | `VectorMarshall`: `gen_calldata` emits raw `typealias` (e.g., `vector`). Inherits trampoline load/store from `ValueMarshall`. | -| 1.4 | ✅ Done | `StructMarshall`/`BoundVariable`: `gen_call_data_code` children path emits `typealias _t_{name} = {struct_type}` when `is_direct_bind_recursive`. Sets `direct_bind` on child marshalls via `_set_direct_bind_on_children`. Non-direct-bind structs set `_force_no_direct_bind` on children to prevent incorrect leaf optimization. `gen_trampoline_load/store` added. | -| 1.5 | ✅ Done | `ValueRefMarshall`: read-only emits raw type + direct assignment; writable emits `RWStructuredBuffer` + `[0]` load/store. `create_calldata`/`read_calldata` skip `{"value": ...}` wrapper when direct-eligible. | +| 1.2a | ✅ Done | C++ fast path: `NativeValueMarshall::ensure_cached` reads `binding->direct_bind()` from `NativeBoundVariableRuntime` to gate `["value"]` sub-field navigation. `m_direct_bind` **removed** from `NativeValueMarshall` — flag lives on `NativeBoundVariableRuntime`. | +| 1.1 | ✅ Done | `Marshall.can_direct_bind(binding)` virtual method (default `False`). Shared `can_direct_bind_common(binding)` helper. `BoundVariable.calculate_direct_bind()` depth-first tree pass. `calculate_direct_binding(call)` in `callsignature.py`. | +| 1.2 | ✅ Done | `ValueMarshall`: `can_direct_bind` overrides. `gen_calldata`, `gen_trampoline_load/store`, `create_calldata` read `binding.direct_bind`. | +| 1.3 | ✅ Done | `VectorMarshall`: `gen_calldata` emits raw `typealias` (e.g., `vector`). Inherits trampoline load/store and `can_direct_bind` from `ValueMarshall`. | +| 1.4 | ✅ Done | `StructMarshall`/`BoundVariable`: `can_direct_bind` checks all children. `gen_call_data_code` uses `self.direct_bind`. Non-direct-bind composites clear children's `direct_bind` via `_clear_direct_bind()`. | +| 1.5 | ✅ Done | `ValueRefMarshall`: `can_direct_bind` override. All methods read `binding.direct_bind`. | | 1.6 | ✅ Done | Tensor dim-0: `gen_trampoline_load/store` extended for `ITensorType` at dim-0 (direct struct assignment). | -| 1.7 | ✅ Done | Mapping constants (`static const int _m_{name}`) skipped for direct-bind-eligible variables. | -| 1.8 | ⬜ Deferred | Autodiff/bwds mode still uses wrapper types. Prim-mode direct binding does apply to bwds primals (code gen verified), but derivative fields still use the old path. | +| 1.7 | ✅ Done | Mapping constants (`static const int _m_{name}`) skipped when `self.direct_bind`. | +| 1.8 | ⬜ Deferred | Autodiff/bwds mode still uses wrapper types. | | 1.9 | ✅ Done | 21 tests (×3 device types = 63 cases): 16 code-gen assertion tests + 5 functional GPU dispatch tests. All pass on d3d12/vulkan/cuda. | **Files modified:** | File | Changes | |------|---------| -| `src/slangpy_ext/utils/slangpyvalue.h` | `m_direct_bind` flag, getter/setter | -| `src/slangpy_ext/utils/slangpyvalue.cpp` | `ensure_cached` direct-bind branch; nanobind export | -| `slangpy/bindings/boundvariable.py` | `is_direct_bind_eligible`, `is_direct_bind_recursive`, `_set_direct_bind_on_children`, `_force_no_direct_bind`, mapping-constant skip in `gen_call_data_code` | -| `slangpy/bindings/__init__.py` | Exports for predicates | -| `slangpy/builtin/value.py` | `gen_calldata`, `gen_trampoline_load`, `gen_trampoline_store`, `create_calldata` | -| `slangpy/builtin/valueref.py` | `gen_calldata`, `gen_trampoline_load`, `gen_trampoline_store`, `create_calldata`, `read_calldata` | -| `slangpy/builtin/struct.py` | `gen_trampoline_load`, `gen_trampoline_store` | -| `slangpy/builtin/tensorcommon.py` | `gen_trampoline_load`, `gen_trampoline_store` extended for `ITensorType` | +| `src/slangpy_ext/utils/slangpy.h` | `m_direct_bind` member, getter/setter on `NativeBoundVariableRuntime` | +| `src/slangpy_ext/utils/slangpy.cpp` | Nanobind `direct_bind` property on `NativeBoundVariableRuntime` | +| `src/slangpy_ext/utils/slangpyvalue.h` | `m_direct_bind`, `direct_bind()`, `set_direct_bind()` **removed** from `NativeValueMarshall` | +| `src/slangpy_ext/utils/slangpyvalue.cpp` | `ensure_cached` reads `binding->direct_bind()`; nanobind `direct_bind` property **removed** from `NativeValueMarshall` | +| `slangpy/bindings/marshall.py` | `can_direct_bind(binding)` virtual method (default `False`) | +| `slangpy/bindings/boundvariable.py` | `can_direct_bind_common()`, `BoundVariable.direct_bind`, `calculate_direct_bind()`, `_clear_direct_bind()`. Removed: `is_direct_bind_eligible`, `is_direct_bind_recursive`, `_set_direct_bind_on_children`, `_force_no_direct_bind`, `_DIRECT_BIND_TYPES`. | +| `slangpy/bindings/boundvariableruntime.py` | `self.direct_bind = source.direct_bind` propagation | +| `slangpy/bindings/__init__.py` | Exports `can_direct_bind_common` (removed old predicate exports) | +| `slangpy/core/callsignature.py` | `calculate_direct_binding(call)` function | +| `slangpy/core/calldata.py` | `calculate_direct_binding(bindings)` call after `calculate_differentiability` | +| `slangpy/builtin/value.py` | `can_direct_bind`, `gen_calldata`, `gen_trampoline_load/store`, `create_calldata` use `binding.direct_bind` | +| `slangpy/builtin/valueref.py` | `can_direct_bind`, all methods use `binding.direct_bind`. Removed `self._direct_bind`. | +| `slangpy/builtin/struct.py` | `can_direct_bind`, `gen_trampoline_load/store` use `binding.direct_bind` | +| `slangpy/builtin/tensorcommon.py` | `gen_trampoline_load/store` extended for `ITensorType` (unchanged in refactor) | | `slangpy/tests/slangpy_tests/test_kernel_gen.py` | All Phase 1 tests | **Test results:** 2952 passed / 0 failed in `slangpy/tests/slangpy_tests`. 6 pre-existing failures in `slangpy/tests/device/` (raytracing pipeline, type conformance cache — unrelated). -**Implementation note — `_force_no_direct_bind`:** The plan did not anticipate that children inside non-direct-bind composite structs (mixed dim-0/dim-N children) would incorrectly inherit direct binding from their leaf predicates. A `_force_no_direct_bind` flag was added: when `gen_call_data_code` takes the non-direct-bind struct path, it marks all children so their `is_direct_bind_eligible`/`is_direct_bind_recursive` return `False`. This prevents generating e.g. `typealias _t_velocity = vector` for a child inside a struct that still uses `__slangpy_load`. Similarly, `PackedArg` bindings (`create_param_block = True`) are excluded since `ParameterBlock` is invalid in Slang. - --- ### Gating Tests — Pre-Implementation Checklist diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index 3f7d9a1c8..efd42df0a 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -433,5 +433,246 @@ def test_phase1_functional_valueref_write(device_type: spy.DeviceType): assert out.value == 13 +# =========================================================================== +# Mixed direct-bind tests — some args direct, some not +# =========================================================================== + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_mixed_args_scalar_and_tensor(device_type: spy.DeviceType): + """Scalar arg gets direct-bind; vectorized tensor arg does not.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + code = generate_code( + device, + "add", + "float add(float a, float b) { return a + b; }", + 1.0, + tensor, + ) + # 'a' is direct-bind (scalar dim-0): raw typealias, direct trampoline load + assert_contains(code, "typealias _t_a = float;") + assert_not_contains(code, "ValueType") + assert_trampoline_has(code, "a = __calldata__.a;") + # 'b' is NOT direct-bind (vectorized tensor dim-1): uses Tensor, + # __slangpy_load, and mapping constant + assert_contains(code, "Tensor") + assert_contains(code, "__slangpy_load") + assert_contains(code, "_m_b") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_mixed_args_direct_bind_flags(device_type: spy.DeviceType): + """Verify direct_bind flags on bindings for mixed scalar + tensor call.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + func = helpers.create_function_from_module( + device, "add", "float add(float a, float b) { return a + b; }" + ) + cd = func.debug_build_call_data(1.0, tensor) + bindings = cd.debug_only_bindings + assert bindings.args[0].direct_bind is True, "scalar arg 'a' should be direct_bind" + assert bindings.args[0].call_dimensionality == 0 + assert bindings.args[1].direct_bind is False, "tensor arg 'b' should NOT be direct_bind" + assert bindings.args[1].call_dimensionality == 1 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_mixed_scalar_tensor(device_type: spy.DeviceType): + """Dispatch mixed scalar + tensor and verify GPU result.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, "add", "float add(float a, float b) { return a + b; }" + ) + tensor = Tensor.from_numpy(device, np.array([10, 20, 30], dtype=np.float32)) + result = func(5.0, tensor) + expected = np.array([15, 25, 35], dtype=np.float32) + np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) + + +# =========================================================================== +# Struct with mixed direct-bind fields +# =========================================================================== + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_mixed_fields_codegen(device_type: spy.DeviceType): + """Struct with one tensor field and one scalar field. + + The struct is NOT direct-bind because child x is vectorized (dim-1). + Child y (scalar) has direct_bind cleared by _clear_direct_bind, so it + uses ValueType wrapper — required for the parent's __slangpy_load. + """ + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +void apply(S s, float scale) {} +""" + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + code = generate_code(device, "apply", src, {"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) + # Struct is NOT direct-bind: uses inline struct with __slangpy_load + assert_contains(code, "__slangpy_load") + assert_contains(code, "struct _t_s") + assert_not_contains(code, "typealias _t_s = S;") + # Child y should use ValueType wrapper (cleared by _clear_direct_bind) + assert_contains(code, "ValueType") + # Child x should use tensor type + assert_contains(code, "Tensor") + # Scalar arg 'scale' is independent — should still be direct-bind + assert_contains(code, "typealias _t_scale = float;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_mixed_fields_binding_flags(device_type: spy.DeviceType): + """Verify direct_bind flags on struct children when struct is NOT direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +void apply(S s, float scale) {} +""" + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + func = helpers.create_function_from_module(device, "apply", src) + cd = func.debug_build_call_data({"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) + bindings = cd.debug_only_bindings + s_binding = bindings.args[0] + assert s_binding.direct_bind is False, "struct 's' should NOT be direct_bind" + # Both children should have direct_bind=False (cleared by _clear_direct_bind) + assert s_binding.children["x"].direct_bind is False + assert s_binding.children["y"].direct_bind is False + # 'scale' is independent scalar — should be direct_bind + assert bindings.args[1].direct_bind is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_struct_mixed_fields(device_type: spy.DeviceType): + """Dispatch struct with mixed tensor+scalar fields and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +float weighted_sum(S s, float scale) { return (s.x + s.y) * scale; } +""" + func = helpers.create_function_from_module(device, "weighted_sum", src) + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + result = func({"_type": "S", "x": tensor_x, "y": 10.0}, 2.0) + expected = np.array([22, 24, 26], dtype=np.float32) + np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) + + +# =========================================================================== +# Tensor at dim-0 (whole tensor passed to Tensor parameter) +# =========================================================================== + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_tensor_dim0_codegen(device_type: spy.DeviceType): + """1D Tensor passed to Tensor param — dim-0, direct assignment.""" + device = helpers.get_device(device_type) + src = """ +float tensor_read(Tensor t) { + return t[0]; +} +""" + tensor = Tensor.from_numpy(device, np.array([42, 2, 3], dtype=np.float32)) + code = generate_code(device, "tensor_read", src, tensor) + # Type alias should use Tensor + assert_contains(code, "typealias _t_t = Tensor;") + # Trampoline uses direct assignment (not __slangpy_load) + assert_trampoline_has(code, "t = __calldata__.t;") + # No wrapper type for the tensor + assert_not_contains(code, "ValueType<") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_tensor_dim0_binding_flags(device_type: spy.DeviceType): + """Tensor at dim-0 has direct_bind=False (tensor marshalls don't opt in).""" + device = helpers.get_device(device_type) + src = """ +float tensor_read(Tensor t) { + return t[0]; +} +""" + tensor = Tensor.from_numpy(device, np.array([42, 2, 3], dtype=np.float32)) + func = helpers.create_function_from_module(device, "tensor_read", src) + cd = func.debug_build_call_data(tensor) + bindings = cd.debug_only_bindings + t_binding = bindings.args[0] + # Tensor marshalls don't implement can_direct_bind — direct_bind stays False + assert t_binding.direct_bind is False + assert t_binding.call_dimensionality == 0 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_tensor_dim0(device_type: spy.DeviceType): + """Dispatch with whole tensor at dim-0 and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +float tensor_read(Tensor t) { + return t[0]; +} +""" + func = helpers.create_function_from_module(device, "tensor_read", src) + tensor = Tensor.from_numpy(device, np.array([42, 99, 7], dtype=np.float32)) + result = func(tensor) + assert abs(result - 42.0) < 1e-5 + + +# =========================================================================== +# _clear_direct_bind necessity — demonstrates the compile error without it +# =========================================================================== + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_clear_direct_bind_prevents_compile_error(device_type: spy.DeviceType): + """Demonstrate that _clear_direct_bind is necessary. + + Without _clear_direct_bind, a scalar child inside a non-direct-bind struct + would keep direct_bind=True. This makes it emit a raw typealias (e.g., + ``typealias _t_y = float;``) instead of ``ValueType``. But the + parent struct's ``__slangpy_load`` calls ``y.__slangpy_load(...)`` — which + doesn't exist on raw ``float``. The result is a Slang compile error: + + - ``undefined identifier '_m_y'`` (mapping constant skipped) + - ``'__slangpy_load' is not a member of 'float'`` + + This test monkey-patches _clear_direct_bind to a no-op and verifies the + compile error occurs, then checks normal behavior succeeds. + """ + from slangpy.bindings.boundvariable import BoundVariable + + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +float weighted_sum(S s, float scale) { return (s.x + s.y) * scale; } +""" + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + + # With _clear_direct_bind disabled: compile should fail + original_clear = BoundVariable._clear_direct_bind + try: + BoundVariable._clear_direct_bind = lambda self: None # type: ignore[assignment] + func_broken = helpers.create_function_from_module(device, "weighted_sum", src) + # with pytest.raises(ValueError, match="__slangpy_load"): + func_broken.debug_build_call_data({"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) + finally: + BoundVariable._clear_direct_bind = original_clear # type: ignore[assignment] + + # With _clear_direct_bind intact: should succeed + func_ok = helpers.create_function_from_module(device, "weighted_sum", src) + cd = func_ok.debug_build_call_data({"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) + assert cd.code is not None + + if __name__ == "__main__": pytest.main([__file__, "-vs"]) From ec3fc1e82e9749a343f176eaefcbb6625ec12994 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 11:28:08 +0000 Subject: [PATCH 05/41] start switching to using the gen load+store for everything --- slangpy/bindings/boundvariable.py | 17 ++--- slangpy/bindings/marshall.py | 20 ++++-- slangpy/builtin/struct.py | 18 ++---- slangpy/builtin/tensor.py | 8 +-- slangpy/builtin/tensorcommon.py | 24 +++---- slangpy/builtin/value.py | 10 +-- slangpy/builtin/valueref.py | 18 ++---- slangpy/core/callsignature.py | 36 +++++------ .../tests/slangpy_tests/test_kernel_gen.py | 64 ++++++++----------- .../torchintegration/torchtensormarshall.py | 8 +-- 10 files changed, 96 insertions(+), 127 deletions(-) diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index b6f5f470f..57f5070c8 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -520,22 +520,10 @@ def calculate_direct_bind(self) -> None: and all(child.direct_bind for child in self.children.values()) ): self.direct_bind = True - else: - # Parent is not direct-bind — children must use wrapper types - # so the parent's generated __slangpy_load/store can call theirs. - for child in self.children.values(): - child._clear_direct_bind() else: if self.python is not None and hasattr(self.python, "can_direct_bind"): self.direct_bind = self.python.can_direct_bind(self) - def _clear_direct_bind(self) -> None: - """Recursively clear direct_bind on this node and all descendants.""" - self.direct_bind = False - if self.children is not None: - for child in self.children.values(): - child._clear_direct_bind() - def get_input_list(self, args: list["BoundVariable"]): """ Recursively populate flat list of argument nodes @@ -626,6 +614,11 @@ def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): ) cgb.begin_block() for field, var in self.children.items(): + gen_load = getattr(var.python, "gen_trampoline_load", None) + if gen_load is not None and gen_load( + cgb, var, var.variable_name, f"value.{field}" + ): + continue cgb.append_statement( f"{var.variable_name}.__slangpy_load(context.map(_m_{var.variable_name}),value.{field})" ) diff --git a/slangpy/bindings/marshall.py b/slangpy/bindings/marshall.py index 348a52f9d..0aa719a3f 100644 --- a/slangpy/bindings/marshall.py +++ b/slangpy/bindings/marshall.py @@ -108,27 +108,35 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV return super().gen_calldata(cgb, context, binding) def gen_trampoline_load( - self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool + self, cgb: CodeGenBlock, binding: "BoundVariable", data_name: str, value_name: str ) -> bool: """ - Generate custom trampoline load code for this parameter. + Generate custom load code for this parameter. + + Works universally for both root-level trampoline parameters and + children inside composite ``__slangpy_load`` bodies. :param cgb: Code generation block to append load statements to. :param binding: The bound variable being loaded. - :param is_entry_point: Whether the trampoline is an entry point kernel. + :param data_name: Expression referencing the stored data (e.g. ``call_data.x`` or ``x``). + :param value_name: Expression referencing the destination value (e.g. ``x`` or ``value.x``). :return: True if handled (skip standard __slangpy_load), False for default behavior. """ return False def gen_trampoline_store( - self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool + self, cgb: CodeGenBlock, binding: "BoundVariable", data_name: str, value_name: str ) -> bool: """ - Generate custom trampoline store code for this parameter. + Generate custom store code for this parameter. + + Works universally for both root-level trampoline parameters and + children inside composite ``__slangpy_store`` bodies. :param cgb: Code generation block to append store statements to. :param binding: The bound variable being stored. - :param is_entry_point: Whether the trampoline is an entry point kernel. + :param data_name: Expression referencing the stored data (e.g. ``call_data.x`` or ``x``). + :param value_name: Expression referencing the source value (e.g. ``x`` or ``value.x``). :return: True if handled (skip standard __slangpy_store), False for default behavior. """ return False diff --git a/slangpy/builtin/struct.py b/slangpy/builtin/struct.py index 2831493f3..4b70fabd7 100644 --- a/slangpy/builtin/struct.py +++ b/slangpy/builtin/struct.py @@ -84,30 +84,20 @@ def can_direct_bind(self, binding: "BoundVariable") -> bool: # A struct type should get a dictionary, and just return that for raw dispatch def gen_trampoline_load( - self, cgb: "CodeGenBlock", binding: "BoundVariable", is_entry_point: bool + self, cgb: "CodeGenBlock", binding: "BoundVariable", data_name: str, value_name: str ) -> bool: if not binding.direct_bind: return False - data_name = ( - f"_param_{binding.variable_name}" - if binding.create_param_block - else f"{'__calldata__' if is_entry_point else 'call_data'}.{binding.variable_name}" - ) - cgb.append_statement(f"{binding.variable_name} = {data_name}") + cgb.append_statement(f"{value_name} = {data_name}") return True def gen_trampoline_store( - self, cgb: "CodeGenBlock", binding: "BoundVariable", is_entry_point: bool + self, cgb: "CodeGenBlock", binding: "BoundVariable", data_name: str, value_name: str ) -> bool: if not binding.direct_bind: return False if binding.access[0] in (AccessType.write, AccessType.readwrite): - data_name = ( - f"_param_{binding.variable_name}" - if binding.create_param_block - else f"{'__calldata__' if is_entry_point else 'call_data'}.{binding.variable_name}" - ) - cgb.append_statement(f"{data_name} = {binding.variable_name}") + cgb.append_statement(f"{data_name} = {value_name}") return True def create_dispatchdata(self, data: Any) -> Any: diff --git a/slangpy/builtin/tensor.py b/slangpy/builtin/tensor.py index e8561e4a2..70951d043 100644 --- a/slangpy/builtin/tensor.py +++ b/slangpy/builtin/tensor.py @@ -139,14 +139,14 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVa return spytc.gen_calldata(self, cgb, context, binding) def gen_trampoline_load( - self, cgb: CodeGenBlock, binding: BoundVariable, is_entry_point: bool + self, cgb: CodeGenBlock, binding: BoundVariable, data_name: str, value_name: str ) -> bool: - return spytc.gen_trampoline_load(self, cgb, binding, is_entry_point) + return spytc.gen_trampoline_load(self, cgb, binding, data_name, value_name) def gen_trampoline_store( - self, cgb: CodeGenBlock, binding: BoundVariable, is_entry_point: bool + self, cgb: CodeGenBlock, binding: BoundVariable, data_name: str, value_name: str ) -> bool: - return spytc.gen_trampoline_store(self, cgb, binding, is_entry_point) + return spytc.gen_trampoline_store(self, cgb, binding, data_name, value_name) def build_shader_object(self, context: "BindContext", data: Any) -> "ShaderObject": so = context.device.create_shader_object(self.slang_type.uniform_layout.reflection) diff --git a/slangpy/builtin/tensorcommon.py b/slangpy/builtin/tensorcommon.py index 0eae47b5a..8d4e42e54 100644 --- a/slangpy/builtin/tensorcommon.py +++ b/slangpy/builtin/tensorcommon.py @@ -380,7 +380,11 @@ def gen_calldata( def gen_trampoline_load( - self: ITensorMarshall, cgb: CodeGenBlock, binding: BoundVariable, is_entry_point: bool + self: ITensorMarshall, + cgb: CodeGenBlock, + binding: BoundVariable, + data_name: str, + value_name: str, ) -> bool: if not isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType)): # For ITensorType at dim-0, use direct assignment (struct copy) @@ -389,23 +393,19 @@ def gen_trampoline_load( and binding.call_dimensionality is not None and binding.call_dimensionality == 0 ): - if is_entry_point: - data_name = f"__calldata__.{binding.variable_name}" - else: - data_name = f"call_data.{binding.variable_name}" - cgb.append_statement(f"{binding.variable_name} = {data_name}") + cgb.append_statement(f"{value_name} = {data_name}") return True return False - if is_entry_point: - data_name = f"__calldata__.{binding.variable_name}" - else: - data_name = f"call_data.{binding.variable_name}" - cgb.append_statement(f"{binding.variable_name} = {data_name}") + cgb.append_statement(f"{value_name} = {data_name}") return True def gen_trampoline_store( - self: ITensorMarshall, cgb: CodeGenBlock, binding: BoundVariable, is_entry_point: bool + self: ITensorMarshall, + cgb: CodeGenBlock, + binding: BoundVariable, + data_name: str, + value_name: str, ) -> bool: if not isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType)): # For ITensorType at dim-0, suppress default store diff --git a/slangpy/builtin/value.py b/slangpy/builtin/value.py index 8f7c59fea..9542b81ea 100644 --- a/slangpy/builtin/value.py +++ b/slangpy/builtin/value.py @@ -107,21 +107,17 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV cgb.type_alias(f"_t_{name}", f"NoneType") def gen_trampoline_load( - self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool + self, cgb: CodeGenBlock, binding: "BoundVariable", data_name: str, value_name: str ) -> bool: if not binding.direct_bind: return False if binding.access[0] not in (AccessType.read, AccessType.readwrite): return False - if is_entry_point: - data_name = f"__calldata__.{binding.variable_name}" - else: - data_name = f"call_data.{binding.variable_name}" - cgb.append_statement(f"{binding.variable_name} = {data_name}") + cgb.append_statement(f"{value_name} = {data_name}") return True def gen_trampoline_store( - self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool + self, cgb: CodeGenBlock, binding: "BoundVariable", data_name: str, value_name: str ) -> bool: if not binding.direct_bind: return False diff --git a/slangpy/builtin/valueref.py b/slangpy/builtin/valueref.py index cbe226948..e222dcffb 100644 --- a/slangpy/builtin/valueref.py +++ b/slangpy/builtin/valueref.py @@ -174,33 +174,25 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV cgb.type_alias(f"_t_{name}", f"RWValueRef<{binding.vector_type.full_name}>") def gen_trampoline_load( - self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool + self, cgb: CodeGenBlock, binding: "BoundVariable", data_name: str, value_name: str ) -> bool: if not binding.direct_bind: return False if binding.access[0] == AccessType.none: return False - if is_entry_point: - data_name = f"__calldata__.{binding.variable_name}" - else: - data_name = f"call_data.{binding.variable_name}" if binding.access[0] == AccessType.read: - cgb.append_statement(f"{binding.variable_name} = {data_name}") + cgb.append_statement(f"{value_name} = {data_name}") else: - cgb.append_statement(f"{binding.variable_name} = {data_name}[0]") + cgb.append_statement(f"{value_name} = {data_name}[0]") return True def gen_trampoline_store( - self, cgb: CodeGenBlock, binding: "BoundVariable", is_entry_point: bool + self, cgb: CodeGenBlock, binding: "BoundVariable", data_name: str, value_name: str ) -> bool: if not binding.direct_bind: return False if binding.access[0] in (AccessType.write, AccessType.readwrite): - if is_entry_point: - data_name = f"__calldata__.{binding.variable_name}" - else: - data_name = f"call_data.{binding.variable_name}" - cgb.append_statement(f"{data_name}[0] = {binding.variable_name}") + cgb.append_statement(f"{data_name}[0] = {value_name}") return True # Call data just returns the primal diff --git a/slangpy/core/callsignature.py b/slangpy/core/callsignature.py index 369e625e4..92093d573 100644 --- a/slangpy/core/callsignature.py +++ b/slangpy/core/callsignature.py @@ -378,22 +378,22 @@ def generate_code( assert x.vector_type is not None cg.trampoline.declare(x.vector_type.full_name, x.variable_name) for x in root_params: + if is_entry_point: + data_name = ( + f"_param_{x.variable_name}" + if x.create_param_block + else f"__calldata__.{x.variable_name}" + ) + else: + data_name = ( + f"_param_{x.variable_name}" + if x.create_param_block + else f"call_data.{x.variable_name}" + ) gen_load = getattr(x.python, "gen_trampoline_load", None) - if gen_load is not None and gen_load(cg.trampoline, x, is_entry_point): + if gen_load is not None and gen_load(cg.trampoline, x, data_name, x.variable_name): continue if x.access[0] == AccessType.read or x.access[0] == AccessType.readwrite: - if is_entry_point: - data_name = ( - f"_param_{x.variable_name}" - if x.create_param_block - else f"__calldata__.{x.variable_name}" - ) - else: - data_name = ( - f"_param_{x.variable_name}" - if x.create_param_block - else f"call_data.{x.variable_name}" - ) cg.trampoline.append_statement( f"{data_name}.__slangpy_load(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" ) @@ -429,11 +429,6 @@ def generate_code( or x.access[0] == AccessType.readwrite or x.access[1] == AccessType.read ): - gen_store = getattr(x.python, "gen_trampoline_store", None) - if gen_store is not None and gen_store(cg.trampoline, x, is_entry_point): - continue - if not x.python.is_writable: - raise BoundVariableException(f"Cannot read back value for non-writable type", x) if is_entry_point: data_name = ( f"_param_{x.variable_name}" @@ -446,6 +441,11 @@ def generate_code( if x.create_param_block else f"call_data.{x.variable_name}" ) + gen_store = getattr(x.python, "gen_trampoline_store", None) + if gen_store is not None and gen_store(cg.trampoline, x, data_name, x.variable_name): + continue + if not x.python.is_writable: + raise BoundVariableException(f"Cannot read back value for non-writable type", x) cg.trampoline.append_statement( f"{data_name}.__slangpy_store(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" ) diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index efd42df0a..f1a089db0 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -500,8 +500,8 @@ def test_gate_struct_mixed_fields_codegen(device_type: spy.DeviceType): """Struct with one tensor field and one scalar field. The struct is NOT direct-bind because child x is vectorized (dim-1). - Child y (scalar) has direct_bind cleared by _clear_direct_bind, so it - uses ValueType wrapper — required for the parent's __slangpy_load. + Child y (scalar) keeps direct_bind=True — gen_call_data_code emits + direct assignment (value.y = y) instead of y.__slangpy_load(...). """ device = helpers.get_device(device_type) src = """ @@ -517,8 +517,10 @@ def test_gate_struct_mixed_fields_codegen(device_type: spy.DeviceType): assert_contains(code, "__slangpy_load") assert_contains(code, "struct _t_s") assert_not_contains(code, "typealias _t_s = S;") - # Child y should use ValueType wrapper (cleared by _clear_direct_bind) - assert_contains(code, "ValueType") + # Child y is direct-bind: raw type alias, direct assignment in __slangpy_load + assert_contains(code, "typealias _t_y = float;") + assert_contains(code, "value.y = y;") + assert_not_contains(code, "ValueType") # Child x should use tensor type assert_contains(code, "Tensor") # Scalar arg 'scale' is independent — should still be direct-bind @@ -542,9 +544,10 @@ def test_gate_struct_mixed_fields_binding_flags(device_type: spy.DeviceType): bindings = cd.debug_only_bindings s_binding = bindings.args[0] assert s_binding.direct_bind is False, "struct 's' should NOT be direct_bind" - # Both children should have direct_bind=False (cleared by _clear_direct_bind) + # Child x is a tensor (dim-1), not direct-bind assert s_binding.children["x"].direct_bind is False - assert s_binding.children["y"].direct_bind is False + # Child y is a scalar (dim-0), keeps its direct_bind status + assert s_binding.children["y"].direct_bind is True # 'scale' is independent scalar — should be direct_bind assert bindings.args[1].direct_bind is True @@ -626,28 +629,20 @@ def test_phase1_functional_tensor_dim0(device_type: spy.DeviceType): # =========================================================================== -# _clear_direct_bind necessity — demonstrates the compile error without it +# Mixed direct-bind children in non-direct-bind struct — validates that +# gen_call_data_code correctly uses direct assignment for direct-bind +# children and __slangpy_load for non-direct-bind children. # =========================================================================== @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_clear_direct_bind_prevents_compile_error(device_type: spy.DeviceType): - """Demonstrate that _clear_direct_bind is necessary. - - Without _clear_direct_bind, a scalar child inside a non-direct-bind struct - would keep direct_bind=True. This makes it emit a raw typealias (e.g., - ``typealias _t_y = float;``) instead of ``ValueType``. But the - parent struct's ``__slangpy_load`` calls ``y.__slangpy_load(...)`` — which - doesn't exist on raw ``float``. The result is a Slang compile error: - - - ``undefined identifier '_m_y'`` (mapping constant skipped) - - ``'__slangpy_load' is not a member of 'float'`` +def test_mixed_children_direct_bind_codegen(device_type: spy.DeviceType): + """Validate code gen for struct with mixed direct-bind / non-direct-bind children. - This test monkey-patches _clear_direct_bind to a no-op and verifies the - compile error occurs, then checks normal behavior succeeds. + Scalar child y gets direct assignment (value.y = y) inside __slangpy_load. + Tensor child x goes through __slangpy_load with context mapping. + Both patterns coexist in the same generated struct. """ - from slangpy.bindings.boundvariable import BoundVariable - device = helpers.get_device(device_type) src = """ struct S { @@ -657,21 +652,16 @@ def test_clear_direct_bind_prevents_compile_error(device_type: spy.DeviceType): float weighted_sum(S s, float scale) { return (s.x + s.y) * scale; } """ tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - - # With _clear_direct_bind disabled: compile should fail - original_clear = BoundVariable._clear_direct_bind - try: - BoundVariable._clear_direct_bind = lambda self: None # type: ignore[assignment] - func_broken = helpers.create_function_from_module(device, "weighted_sum", src) - # with pytest.raises(ValueError, match="__slangpy_load"): - func_broken.debug_build_call_data({"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) - finally: - BoundVariable._clear_direct_bind = original_clear # type: ignore[assignment] - - # With _clear_direct_bind intact: should succeed - func_ok = helpers.create_function_from_module(device, "weighted_sum", src) - cd = func_ok.debug_build_call_data({"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) - assert cd.code is not None + code = generate_code(device, "weighted_sum", src, {"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) + # Child y uses raw type and direct assignment + assert_contains(code, "typealias _t_y = float;") + assert_contains(code, "value.y = y;") + # No mapping constant for y (direct-bind skips it) + assert_not_contains(code, "_m_y") + # Child x uses tensor wrapper with __slangpy_load + assert_contains(code, "x.__slangpy_load(context.map(_m_x),value.x)") + # The struct itself is not direct-bind + assert_contains(code, "struct _t_s") if __name__ == "__main__": diff --git a/slangpy/torchintegration/torchtensormarshall.py b/slangpy/torchintegration/torchtensormarshall.py index 4fd841186..970dde412 100644 --- a/slangpy/torchintegration/torchtensormarshall.py +++ b/slangpy/torchintegration/torchtensormarshall.py @@ -212,14 +212,14 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVa return spytc.gen_calldata(self, cgb, context, binding) def gen_trampoline_load( - self, cgb: CodeGenBlock, binding: BoundVariable, is_entry_point: bool + self, cgb: CodeGenBlock, binding: BoundVariable, data_name: str, value_name: str ) -> bool: - return spytc.gen_trampoline_load(self, cgb, binding, is_entry_point) + return spytc.gen_trampoline_load(self, cgb, binding, data_name, value_name) def gen_trampoline_store( - self, cgb: CodeGenBlock, binding: BoundVariable, is_entry_point: bool + self, cgb: CodeGenBlock, binding: BoundVariable, data_name: str, value_name: str ) -> bool: - return spytc.gen_trampoline_store(self, cgb, binding, is_entry_point) + return spytc.gen_trampoline_store(self, cgb, binding, data_name, value_name) def build_shader_object(self, context: BindContext, data: torch.Tensor) -> ShaderObject: """Build shader object for dispatch.""" From d8e60b5da85f1726ff31bf29bb4dc3a3f6f89e21 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 12:10:02 +0000 Subject: [PATCH 06/41] Fix binding issues --- slangpy/bindings/boundvariable.py | 5 +++++ slangpy/builtin/struct.py | 7 ++----- slangpy/builtin/value.py | 5 +++-- slangpy/builtin/valueref.py | 25 +++++++++---------------- 4 files changed, 19 insertions(+), 23 deletions(-) diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index 57f5070c8..5d0b74e36 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -631,6 +631,11 @@ def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): ) cgb.begin_block() for field, var in self.children.items(): + gen_store = getattr(var.python, "gen_trampoline_store", None) + if gen_store is not None and gen_store( + cgb, var, var.variable_name, f"value.{field}" + ): + continue cgb.append_statement( f"{var.variable_name}.__slangpy_store(context.map(_m_{var.variable_name}),value.{field})" ) diff --git a/slangpy/builtin/struct.py b/slangpy/builtin/struct.py index 4b70fabd7..2a39d98dd 100644 --- a/slangpy/builtin/struct.py +++ b/slangpy/builtin/struct.py @@ -88,17 +88,14 @@ def gen_trampoline_load( ) -> bool: if not binding.direct_bind: return False - cgb.append_statement(f"{value_name} = {data_name}") - return True + return super().gen_trampoline_load(cgb, binding, data_name, value_name) def gen_trampoline_store( self, cgb: "CodeGenBlock", binding: "BoundVariable", data_name: str, value_name: str ) -> bool: if not binding.direct_bind: return False - if binding.access[0] in (AccessType.write, AccessType.readwrite): - cgb.append_statement(f"{data_name} = {value_name}") - return True + return super().gen_trampoline_store(cgb, binding, data_name, value_name) def create_dispatchdata(self, data: Any) -> Any: if isinstance(data, dict): diff --git a/slangpy/builtin/value.py b/slangpy/builtin/value.py index 9542b81ea..6d30b33a5 100644 --- a/slangpy/builtin/value.py +++ b/slangpy/builtin/value.py @@ -112,8 +112,9 @@ def gen_trampoline_load( if not binding.direct_bind: return False if binding.access[0] not in (AccessType.read, AccessType.readwrite): - return False - cgb.append_statement(f"{value_name} = {data_name}") + cgb.append_statement(f"{value_name} = {{}}") + else: + cgb.append_statement(f"{value_name} = {data_name}") return True def gen_trampoline_store( diff --git a/slangpy/builtin/valueref.py b/slangpy/builtin/valueref.py index e222dcffb..012c6e707 100644 --- a/slangpy/builtin/valueref.py +++ b/slangpy/builtin/valueref.py @@ -150,7 +150,11 @@ def resolve_dimensionality( return len(self.value_type.shape) - len(vector_target_type.shape) def can_direct_bind(self, binding: "BoundVariable") -> bool: - return can_direct_bind_common(binding) + if not can_direct_bind_common(binding): + return False + if binding.access[0] != AccessType.read: + return False + return True # Call data can only be read access to primal, and simply declares it as a variable def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): @@ -160,13 +164,8 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV assert access[1] == AccessType.none assert binding.vector_type is not None if binding.direct_bind: - if access[0] == AccessType.read: - cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) - else: - cgb.type_alias( - f"_t_{name}", - f"RWStructuredBuffer<{binding.vector_type.full_name}>", - ) + assert access[0] == AccessType.read + cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) else: if access[0] == AccessType.read: cgb.type_alias(f"_t_{name}", f"ValueRef<{binding.vector_type.full_name}>") @@ -178,12 +177,8 @@ def gen_trampoline_load( ) -> bool: if not binding.direct_bind: return False - if binding.access[0] == AccessType.none: - return False - if binding.access[0] == AccessType.read: - cgb.append_statement(f"{value_name} = {data_name}") - else: - cgb.append_statement(f"{value_name} = {data_name}[0]") + assert binding.access[0] == AccessType.read + cgb.append_statement(f"{value_name} = {data_name}") return True def gen_trampoline_store( @@ -191,8 +186,6 @@ def gen_trampoline_store( ) -> bool: if not binding.direct_bind: return False - if binding.access[0] in (AccessType.write, AccessType.readwrite): - cgb.append_statement(f"{data_name}[0] = {value_name}") return True # Call data just returns the primal From 5f30772d4d27d5a2600b00febfc8f62bdc68b744 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 12:23:28 +0000 Subject: [PATCH 07/41] Fix some tests --- .../plan-simplifyKernelGen-phase1.prompt.md | 31 ++++++++++--------- .../prompts/plan-simplifyKernelGen.prompt.md | 10 +++--- .../tests/slangpy_tests/test_kernel_gen.py | 22 ++++++------- 3 files changed, 33 insertions(+), 30 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md index 09159ceda..3c7576a84 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md @@ -19,7 +19,7 @@ Direct binding eligibility is determined by a **marshall-driven `can_direct_bind | `Marshall.can_direct_bind(binding)` | `slangpy/bindings/marshall.py` | Virtual method (default `False`). Marshalls override to opt in. | | `can_direct_bind_common(binding)` | `slangpy/bindings/boundvariable.py` | Shared eligibility checks (dim-0, no children, no param block). Marshalls call this then add type-specific logic. | | `BoundVariable.direct_bind` | `slangpy/bindings/boundvariable.py` | Boolean attribute set by `calculate_direct_bind()`. Consumed by `gen_call_data_code`, `gen_calldata`, `gen_trampoline_load/store`, `create_calldata`. | -| `BoundVariable.calculate_direct_bind()` | `slangpy/bindings/boundvariable.py` | Depth-first tree pass. Leaves delegate to `marshall.can_direct_bind()`. Composites require all children to be direct-bind AND dim-0 with a concrete vector type. If composite is NOT direct-bind, recursively clears children via `_clear_direct_bind()`. | +| `BoundVariable.calculate_direct_bind()` | `slangpy/bindings/boundvariable.py` | Depth-first tree pass. Leaves delegate to `marshall.can_direct_bind()`. Composites require all children to be direct-bind AND dim-0 with a concrete vector type. Children retain their individual `direct_bind` status regardless of the parent's eligibility. | | `calculate_direct_binding(call)` | `slangpy/core/callsignature.py` | Top-level function iterating `call.args` + `call.kwargs.values()`, calling `arg.calculate_direct_bind()`. | | `NativeBoundVariableRuntime.direct_bind` | `slangpy.h` / `boundvariableruntime.py` | C++ member + Python propagation. Read by `NativeValueMarshall::ensure_cached` to gate `["value"]` sub-field navigation. | @@ -44,7 +44,7 @@ At dispatch time, `NativeValueMarshall::ensure_cached()` reads `binding->direct_ When `calculate_direct_bind()` visits a composite node: 1. Recurse children first (depth-first) 2. If all children have `direct_bind == True` AND the composite is dim-0 with a concrete vector type → set `self.direct_bind = True` -3. Otherwise → call `_clear_direct_bind()` on all children, forcing them to use wrapper types. This is necessary because the parent's generated `__slangpy_load`/`__slangpy_store` expects children to have wrapper types (e.g., `ValueType`). A child emitting raw `float` inside a parent that emits `__slangpy_load` would produce invalid Slang. +3. Otherwise → the composite is NOT direct-bind, but children **retain** their individual `direct_bind` status. Inside the parent's generated `__slangpy_load`/`__slangpy_store`, `gen_call_data_code` delegates to each child's `gen_trampoline_load`/`gen_trampoline_store` — direct-bind children get direct assignment (e.g., `value.y = y;`) while non-direct-bind children use the standard `__slangpy_load(context.map(...))` path. This allows mixed direct-bind / non-direct-bind children within the same struct. --- @@ -57,7 +57,7 @@ A shared helper `can_direct_bind_common(binding)` in `boundvariable.py` provides - `not binding.children` (not composite/dict) - `not getattr(binding, "create_param_block", False)` (excludes `PackedArg`) -Marshall subclasses call `can_direct_bind_common(binding)` and optionally add type-specific logic. `StructMarshall` has its own implementation: if it has children, all children must have `direct_bind == True`; otherwise it delegates to `can_direct_bind_common`. +Marshall subclasses call `can_direct_bind_common(binding)` and optionally add type-specific logic. `StructMarshall` has its own implementation: if it has children, all children must have `direct_bind == True`; otherwise it delegates to `can_direct_bind_common`. `ValueRefMarshall` additionally requires `binding.access[0] == AccessType.read` — writable value refs need buffer read/write logic that is incompatible with direct binding. --- @@ -98,14 +98,12 @@ The `m_direct_bind` / `direct_bind` / `set_direct_bind` members were **removed** **Implemented.** In [slangpy/builtin/struct.py](slangpy/builtin/struct.py): - `can_direct_bind(binding)`: if `binding.children is not None`, returns `True` only if all children have `direct_bind == True`. Otherwise delegates to `can_direct_bind_common(binding)`. -- `gen_trampoline_load`: when `binding.direct_bind`, emits `{name} = {data_name}` and returns `True` -- `gen_trampoline_store`: when `binding.direct_bind`, emits `{data_name} = {name}` for writable and returns `True` +- `gen_trampoline_load`: when `binding.direct_bind`, delegates to `ValueMarshall.gen_trampoline_load` (emits `{name} = {data_name}`) and returns `True`. Direct-bind structs are read-only, like other value types. +- `gen_trampoline_store`: when `binding.direct_bind`, delegates to `ValueMarshall.gen_trampoline_store` (suppresses store for read-only). Returns `True`. In [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py), `gen_call_data_code`: - When `self.direct_bind`, emits `typealias _t_{name} = {vector_type.full_name}` (raw struct type) — skipping inline struct generation, `__slangpy_load`/`__slangpy_store`, and child type aliases. -- When NOT `self.direct_bind`, uses the standard children path with inline struct. - -Children inside non-direct-bind composites have their `direct_bind` cleared by `_clear_direct_bind()` during `calculate_direct_bind`. This ensures children use wrapper types compatible with the parent's `__slangpy_load`/`__slangpy_store`. +- When NOT `self.direct_bind`, uses the standard children path with inline struct. Children **retain** their individual `direct_bind` status — `gen_call_data_code` calls each child's `gen_trampoline_load`/`gen_trampoline_store`, which emit direct assignment for direct-bind children and fall through to `__slangpy_load`/`__slangpy_store` for non-direct-bind children. --- @@ -113,13 +111,16 @@ Children inside non-direct-bind composites have their `direct_bind` cleared by ` **Implemented.** In [slangpy/builtin/valueref.py](slangpy/builtin/valueref.py): -- `can_direct_bind(binding)`: calls `can_direct_bind_common(binding)` -- `gen_calldata`: when `binding.direct_bind`, read-only emits raw type, writable emits `RWStructuredBuffer` -- `gen_trampoline_load/store`: when `binding.direct_bind`, read-only does direct assignment, writable does `[0]` indexing -- `create_calldata` / `read_calldata`: when `binding.direct_bind`, skip `{"value": ...}` wrapper +- `can_direct_bind(binding)`: calls `can_direct_bind_common(binding)` AND requires `binding.access[0] == AccessType.read`. Writable value refs are NOT direct-bind eligible because they need buffer allocation and readback logic that requires the wrapper path. +- `gen_calldata`: when `binding.direct_bind`, emits raw type alias (read-only only). Non-direct-bind uses `ValueRef` / `RWValueRef` as before. +- `gen_trampoline_load`: when `binding.direct_bind`, emits direct assignment. Non-direct-bind falls through. +- `gen_trampoline_store`: when `binding.direct_bind`, returns `True` (suppress store for read-only). Non-direct-bind falls through. +- `create_calldata` / `read_calldata`: when `binding.direct_bind` AND read-only, returns raw value / skips readback. The old `self._direct_bind` attribute was **removed** — all checks now use `binding.direct_bind`. +**Implication for `_result`:** Auto-created return values are writable `ValueRef` instances. Since writable value refs are not direct-bind eligible, `_result` uses `RWValueRef` with `__slangpy_store`, mapping constants, and the standard wrapper path. This is a deliberate constraint — writable value refs inside structs would prevent the struct from being direct-bind eligible, which is the correct behavior since the struct's `__slangpy_load`/`__slangpy_store` must exist to handle the buffer operations. + --- ### Step 1.6: Implement for tensor marshalls @@ -157,7 +158,7 @@ The old `self._direct_bind` attribute was **removed** — all checks now use `bi | `src/slangpy_ext/utils/slangpyvalue.h` | `m_direct_bind`, `direct_bind()`, `set_direct_bind()` **removed** from `NativeValueMarshall` | | `src/slangpy_ext/utils/slangpyvalue.cpp` | `ensure_cached` reads `binding->direct_bind()` instead of `m_direct_bind`; nanobind `direct_bind` property **removed** from `NativeValueMarshall` | | `slangpy/bindings/marshall.py` | `can_direct_bind(binding)` virtual method (default `False`) | -| `slangpy/bindings/boundvariable.py` | `can_direct_bind_common()`, `BoundVariable.direct_bind` attribute, `BoundVariable.calculate_direct_bind()`, `BoundVariable._clear_direct_bind()`. Old functions removed: `is_direct_bind_eligible`, `is_direct_bind_recursive`, `_set_direct_bind_on_children`, `_force_no_direct_bind`, `_DIRECT_BIND_TYPES`. | +| `slangpy/bindings/boundvariable.py` | `can_direct_bind_common()`, `BoundVariable.direct_bind` attribute, `BoundVariable.calculate_direct_bind()`. Old functions removed: `is_direct_bind_eligible`, `is_direct_bind_recursive`, `_set_direct_bind_on_children`, `_force_no_direct_bind`, `_DIRECT_BIND_TYPES`, `_clear_direct_bind()`. | | `slangpy/bindings/boundvariableruntime.py` | `self.direct_bind = source.direct_bind` propagation | | `slangpy/bindings/__init__.py` | Exports `can_direct_bind_common` (removed `is_direct_bind_eligible`, `is_direct_bind_recursive`) | | `slangpy/core/callsignature.py` | `calculate_direct_binding(call)` function | @@ -180,4 +181,6 @@ The old `self._direct_bind` attribute was **removed** — all checks now use `bi **Single `calculate_direct_bind` pass replaces repeated predicate calls.** The original `is_direct_bind_eligible` / `is_direct_bind_recursive` were called multiple times per variable during code gen. The new design computes `direct_bind` once in a single tree pass after `calculate_differentiability`, and consumers simply read the boolean. -**`_clear_direct_bind` replaces `_force_no_direct_bind`.** When a composite struct is NOT direct-bind-eligible (e.g., has vectorized children), its children must NOT use direct binding either — the parent's generated `__slangpy_load`/`__slangpy_store` expects children to have wrapper types. The old implementation set `_force_no_direct_bind = True` on children during code gen. The new implementation clears `direct_bind` recursively during the `calculate_direct_bind` pass itself, before code gen runs. +**Children retain `direct_bind` in non-direct-bind composites.** When a composite struct is NOT direct-bind-eligible (e.g., has vectorized children), children **retain** their individual `direct_bind` status. The parent's `gen_call_data_code` delegates to each child's `gen_trampoline_load`/`gen_trampoline_store` — direct-bind children emit direct assignment (e.g., `value.y = y;`) within the parent's `__slangpy_load`, while non-direct-bind children use the standard `__slangpy_load(context.map(...))` path. The old `_clear_direct_bind()` / `_force_no_direct_bind` approach was removed. + +**Writable ValueRef excluded from direct binding.** Writable value refs require buffer allocation, GPU readback, and `__slangpy_store` indirection. Only read-only value refs (`access[0] == AccessType.read`) are direct-bind eligible. This means auto-created `_result` (which is writable) always uses the `RWValueRef` wrapper path. diff --git a/.github/prompts/plan-simplifyKernelGen.prompt.md b/.github/prompts/plan-simplifyKernelGen.prompt.md index 4a05f68cb..03e7033e2 100644 --- a/.github/prompts/plan-simplifyKernelGen.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen.prompt.md @@ -39,8 +39,8 @@ The implementation was refactored from global predicate functions (`is_direct_bi | 1.1 | ✅ Done | `Marshall.can_direct_bind(binding)` virtual method (default `False`). Shared `can_direct_bind_common(binding)` helper. `BoundVariable.calculate_direct_bind()` depth-first tree pass. `calculate_direct_binding(call)` in `callsignature.py`. | | 1.2 | ✅ Done | `ValueMarshall`: `can_direct_bind` overrides. `gen_calldata`, `gen_trampoline_load/store`, `create_calldata` read `binding.direct_bind`. | | 1.3 | ✅ Done | `VectorMarshall`: `gen_calldata` emits raw `typealias` (e.g., `vector`). Inherits trampoline load/store and `can_direct_bind` from `ValueMarshall`. | -| 1.4 | ✅ Done | `StructMarshall`/`BoundVariable`: `can_direct_bind` checks all children. `gen_call_data_code` uses `self.direct_bind`. Non-direct-bind composites clear children's `direct_bind` via `_clear_direct_bind()`. | -| 1.5 | ✅ Done | `ValueRefMarshall`: `can_direct_bind` override. All methods read `binding.direct_bind`. | +| 1.4 | ✅ Done | `StructMarshall`/`BoundVariable`: `can_direct_bind` checks all children. `gen_call_data_code` uses `self.direct_bind`. Non-direct-bind composites let children retain their `direct_bind` status; `gen_call_data_code` delegates to children's `gen_trampoline_load/store`. | +| 1.5 | ✅ Done | `ValueRefMarshall`: `can_direct_bind` requires read-only access. Writable value refs (including auto-created `_result`) use wrapper path (`RWValueRef`). | | 1.6 | ✅ Done | Tensor dim-0: `gen_trampoline_load/store` extended for `ITensorType` at dim-0 (direct struct assignment). | | 1.7 | ✅ Done | Mapping constants (`static const int _m_{name}`) skipped when `self.direct_bind`. | | 1.8 | ⬜ Deferred | Autodiff/bwds mode still uses wrapper types. | @@ -55,13 +55,13 @@ The implementation was refactored from global predicate functions (`is_direct_bi | `src/slangpy_ext/utils/slangpyvalue.h` | `m_direct_bind`, `direct_bind()`, `set_direct_bind()` **removed** from `NativeValueMarshall` | | `src/slangpy_ext/utils/slangpyvalue.cpp` | `ensure_cached` reads `binding->direct_bind()`; nanobind `direct_bind` property **removed** from `NativeValueMarshall` | | `slangpy/bindings/marshall.py` | `can_direct_bind(binding)` virtual method (default `False`) | -| `slangpy/bindings/boundvariable.py` | `can_direct_bind_common()`, `BoundVariable.direct_bind`, `calculate_direct_bind()`, `_clear_direct_bind()`. Removed: `is_direct_bind_eligible`, `is_direct_bind_recursive`, `_set_direct_bind_on_children`, `_force_no_direct_bind`, `_DIRECT_BIND_TYPES`. | +| `slangpy/bindings/boundvariable.py` | `can_direct_bind_common()`, `BoundVariable.direct_bind`, `calculate_direct_bind()`. Removed: `is_direct_bind_eligible`, `is_direct_bind_recursive`, `_set_direct_bind_on_children`, `_force_no_direct_bind`, `_DIRECT_BIND_TYPES`, `_clear_direct_bind()`. | | `slangpy/bindings/boundvariableruntime.py` | `self.direct_bind = source.direct_bind` propagation | | `slangpy/bindings/__init__.py` | Exports `can_direct_bind_common` (removed old predicate exports) | | `slangpy/core/callsignature.py` | `calculate_direct_binding(call)` function | | `slangpy/core/calldata.py` | `calculate_direct_binding(bindings)` call after `calculate_differentiability` | | `slangpy/builtin/value.py` | `can_direct_bind`, `gen_calldata`, `gen_trampoline_load/store`, `create_calldata` use `binding.direct_bind` | -| `slangpy/builtin/valueref.py` | `can_direct_bind`, all methods use `binding.direct_bind`. Removed `self._direct_bind`. | +| `slangpy/builtin/valueref.py` | `can_direct_bind` (read-only only), all methods use `binding.direct_bind`. Removed `self._direct_bind`. | | `slangpy/builtin/struct.py` | `can_direct_bind`, `gen_trampoline_load/store` use `binding.direct_bind` | | `slangpy/builtin/tensorcommon.py` | `gen_trampoline_load/store` extended for `ITensorType` (unchanged in refactor) | | `slangpy/tests/slangpy_tests/test_kernel_gen.py` | All Phase 1 tests | @@ -129,7 +129,7 @@ pre-commit run --all-files ### Key Decisions - Phase 1 changes both `gen_calldata` and trampoline load/store (TensorView-complete pattern, not partial) -- All dim-0 non-composite types are eligible, including tensors and value refs +- All dim-0 non-composite types are eligible, excluding writable value refs (which need buffer logic) - Phase 2 targets both `entry_point` (CUDA) and `global_data` (Vulkan/D3D12) modes - Autograd (bwds mode) is included in simplification, but implemented after prim mode within each phase - WangHashArg explicitly excluded from direct binding (needs per-thread `thread_id` computation) diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index f1a089db0..400fc040e 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -128,9 +128,8 @@ def test_gate_scalar_uses_valuetype(device_type: spy.DeviceType): assert_contains(code, "typealias _t_a = int;", "typealias _t_b = int;") # Trampoline uses direct assignment, no __slangpy_load assert_trampoline_has(code, "a = __calldata__.a;", "b = __calldata__.b;") - # _result is auto-created as RWValueRef — now uses RWStructuredBuffer - assert_not_contains(code, "RWValueRef") - assert_contains(code, "RWStructuredBuffer") + # _result is auto-created as writable RWValueRef (not direct-bind) + assert_contains(code, "RWValueRef") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -202,11 +201,12 @@ def test_gate_valueref_read_uses_wrapper(device_type: spy.DeviceType): "float read_val(float v) { return v; }", ValueRef(1.0), ) - # Read-only ValueRef now uses raw type alias, not ValueRef - assert_not_contains(code, "ValueRef") + # Read-only ValueRef uses raw type alias (direct-bind) assert_contains(code, "typealias _t_v = float;") # Direct assignment in trampoline assert_trampoline_has(code, "v = __calldata__.v;") + # _result (writable) still uses RWValueRef wrapper + assert_contains(code, "RWValueRef") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -219,11 +219,10 @@ def test_gate_valueref_write_uses_wrapper(device_type: spy.DeviceType): 1, 2, ) - # Auto-created _result uses RWStructuredBuffer instead of RWValueRef - assert_not_contains(code, "RWValueRef") - assert_contains(code, "RWStructuredBuffer") - # Trampoline uses buffer load/store - assert_trampoline_has(code, "_result = __calldata__._result[0];") + # Auto-created _result uses RWValueRef (writable, not direct-bind) + assert_contains(code, "RWValueRef") + # Trampoline uses __slangpy_store via wrapper + assert_contains(code, "__slangpy_store") # -- Step 1.7: Mapping constants and context.map -- @@ -244,8 +243,9 @@ def test_gate_mapping_constants_present(device_type: spy.DeviceType): code, "static const int _m_a = 0", "static const int _m_b = 0", - "static const int _m__result = 0", ) + # _result is NOT direct-bind (writable ValueRef), so it keeps mapping constant + assert_contains(code, "static const int _m__result = 0") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) From f85dd1a5d19ddf125d79c5dc3302c849fd585ae1 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 12:57:26 +0000 Subject: [PATCH 08/41] code cleanup --- .../plan-simplifyKernelGen-phase1.prompt.md | 18 +++ .../prompts/plan-simplifyKernelGen.prompt.md | 30 +++++ slangpy/bindings/boundvariable.py | 13 +- slangpy/builtin/struct.py | 9 +- slangpy/builtin/valueref.py | 9 +- .../tests/slangpy_tests/test_kernel_gen.py | 111 ++++++++++++++++++ src/slangpy_ext/utils/slangpyvalue.cpp | 1 + src/slangpy_ext/utils/slangpyvalue.h | 1 + 8 files changed, 175 insertions(+), 17 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md index 3c7576a84..2778f6071 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase1.prompt.md @@ -173,6 +173,24 @@ The old `self._direct_bind` attribute was **removed** — all checks now use `bi 2952 passed / 0 failed in `slangpy/tests/slangpy_tests`. 6 pre-existing failures in `slangpy/tests/device/` (raytracing pipeline, type conformance cache — unrelated). +### Review Notes + +**Issues to address before merge:** + +1. **`StructMarshall.can_direct_bind` children branch is dead code.** `calculate_direct_bind()` handles composites directly (when `self.children is not None`) and never calls the marshall's `can_direct_bind`. The `if binding.children is not None:` branch in `StructMarshall.can_direct_bind` is unreachable. Fix: remove the children branch or have `calculate_direct_bind` delegate to the marshall for composites. + +2. **Composite direct-bind should gate on read-only access.** Add `and self.access[0] == AccessType.read` to the composite branch in `calculate_direct_bind()` (matching `ValueRefMarshall` pattern). Without this, a writable dim-0 composite would be incorrectly marked direct-bind. + +3. **Dead `binding.direct_bind` checks in writable ValueRef paths** ([valueref.py](slangpy/builtin/valueref.py) lines ~215, ~230, ~248). Since `can_direct_bind` rejects non-read access, these branches are unreachable. Remove or add `assert not binding.direct_bind` to make the invariant explicit. + +4. **Overly defensive `hasattr` guard** in `calculate_direct_bind()` — `hasattr(self.python, "can_direct_bind")` is unnecessary since `Marshall` base class always defines this method. + +5. **Benchmark file** — `test_benchmark_autograd.py` has accidental local changes that should be reverted. + +6. **C++ improvements** — Add debug assertion in `NativeValueMarshall::ensure_cached` verifying cached `direct_bind` matches binding's; consider making `NativeBoundVariableRuntime.direct_bind` read-only in nanobind. + +**Missing tests to add:** Writable ValueRef inout, `_result` binding flag, all-scalar struct binding flag, struct+WangHashArg child, WangHashArg binding flag, functional read-only ValueRef, bwds binding flags. See parent plan for full table. + ### Design Decisions **`direct_bind` lives on `NativeBoundVariableRuntime`, not `NativeValueMarshall`.** The original implementation stored `m_direct_bind` on the marshall itself (`NativeValueMarshall`), but marshalls are shared across calls while bindings are per-call. Moving the flag to the binding makes it immutable per-call and eliminates mutable state on shared marshall instances. diff --git a/.github/prompts/plan-simplifyKernelGen.prompt.md b/.github/prompts/plan-simplifyKernelGen.prompt.md index 03e7033e2..75290749c 100644 --- a/.github/prompts/plan-simplifyKernelGen.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen.prompt.md @@ -133,3 +133,33 @@ pre-commit run --all-files - Phase 2 targets both `entry_point` (CUDA) and `global_data` (Vulkan/D3D12) modes - Autograd (bwds mode) is included in simplification, but implemented after prim mode within each phase - WangHashArg explicitly excluded from direct binding (needs per-thread `thread_id` computation) + +--- + +### Code Review Notes (PR #862) + +**Bugs / concerns found:** + +1. **Benchmark file changes are accidental** — `test_benchmark_autograd.py` has local tuning changes (ITERATIONS 10→100, WARMUPS 10→1000, RUN_SLANGTORCH_BENCHMARK False→True) that should be reverted before merge. + +2. **Composite `calculate_direct_bind` doesn't consult the marshall** — When `self.children is not None`, `calculate_direct_bind()` hard-codes eligibility criteria and never calls `self.python.can_direct_bind(self)`. This auto-opts-in composites if all children pass, preventing a marshall from rejecting. The `StructMarshall.can_direct_bind` children branch is dead code as a result. Either have composites delegate to the marshall, or remove the dead code from `StructMarshall.can_direct_bind`. + +3. **Composite direct-bind doesn't gate on read-only access** — `calculate_direct_bind`'s composite branch doesn't check `self.access[0] == AccessType.read`. A writable composite at dim-0 could be marked direct-bind, but no `__slangpy_store` is generated for the raw type alias. `ValueRefMarshall` correctly gates on read-only — composites should too. + +4. **Dead code in `ValueRefMarshall.create_calldata`/`read_calldata`** — `if binding.direct_bind` checks in writable code paths are unreachable since `can_direct_bind` rejects non-read access. Remove or assert. + +5. **C++ cache safety** — `NativeValueMarshall::ensure_cached` caches the `direct_bind` cursor path on first dispatch but has no assertion that subsequent calls use the same `direct_bind` value. Safe in current architecture (each call signature gets its own marshall), but fragile. Consider adding `SGL_ASSERT(m_cached.direct_bind == binding->direct_bind())` for debug builds. + +6. **`set_direct_bind` exposed as read-write nanobind property** — After first dispatch, mutating `direct_bind` invalidates the cached cursor offset silently. Consider making it read-only in the nanobind binding. + +**Missing test coverage (high priority):** + +| Test | Purpose | +|------|---------| +| Writable `ValueRef` `inout` param → `direct_bind=False` | Guards access-check logic in `ValueRefMarshall.can_direct_bind` | +| `_result` auto-created binding → `direct_bind=False` flag | Binding-level assertion, not just codegen | +| All-scalar struct → `direct_bind=True` binding flag | Struct direct-bind logic verified at binding level | +| Struct with WangHashArg child → composite NOT direct-bind | Mixed non-eligible child in composite | +| WangHashArg → `direct_bind=False` binding flag | Type without `can_direct_bind` override | +| Functional GPU test: read-only `ValueRef` input | End-to-end direct-bind ValueRef pipeline | +| Bwds mode binding flags on primal args | Verify access-tuple indexing in backwards mode | diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index 5d0b74e36..534029b44 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -512,17 +512,8 @@ def calculate_direct_bind(self) -> None: if self.children is not None: for child in self.children.values(): child.calculate_direct_bind() - if ( - self.call_dimensionality is not None - and self.call_dimensionality == 0 - and not getattr(self, "create_param_block", False) - and self.vector_type is not None - and all(child.direct_bind for child in self.children.values()) - ): - self.direct_bind = True - else: - if self.python is not None and hasattr(self.python, "can_direct_bind"): - self.direct_bind = self.python.can_direct_bind(self) + if self.python is not None and hasattr(self.python, "can_direct_bind"): + self.direct_bind = self.python.can_direct_bind(self) def get_input_list(self, args: list["BoundVariable"]): """ diff --git a/slangpy/builtin/struct.py b/slangpy/builtin/struct.py index 2a39d98dd..3439797b5 100644 --- a/slangpy/builtin/struct.py +++ b/slangpy/builtin/struct.py @@ -78,7 +78,14 @@ def resolve_dimensionality( def can_direct_bind(self, binding: "BoundVariable") -> bool: if binding.children is not None: - return all(child.direct_bind for child in binding.children.values()) + return ( + binding.call_dimensionality is not None + and binding.call_dimensionality == 0 + and not getattr(binding, "create_param_block", False) + and binding.vector_type is not None + and binding.access[0] == AccessType.read + and all(child.direct_bind for child in binding.children.values()) + ) return can_direct_bind_common(binding) # A struct type should get a dictionary, and just return that for raw dispatch diff --git a/slangpy/builtin/valueref.py b/slangpy/builtin/valueref.py index 012c6e707..c13549d0f 100644 --- a/slangpy/builtin/valueref.py +++ b/slangpy/builtin/valueref.py @@ -212,8 +212,7 @@ def create_calldata( if access[0] != AccessType.write: cursor[0].write(data.value) cursor.apply() - if binding.direct_bind: - return buffer + assert not binding.direct_bind return {"value": buffer} else: if isinstance(self.value_type, kfr.SlangType): @@ -227,8 +226,7 @@ def create_calldata( data=npdata, usage=BufferUsage.shader_resource | BufferUsage.unordered_access, ) - if binding.direct_bind: - return buffer + assert not binding.direct_bind return {"value": buffer} # Value ref just passes its value for raw dispatch @@ -245,7 +243,8 @@ def read_calldata( ) -> None: access = binding.access if access[0] in [AccessType.write, AccessType.readwrite]: - buffer = result if binding.direct_bind else result["value"] + assert not binding.direct_bind + buffer = result["value"] assert isinstance(buffer, Buffer) if isinstance(binding.vector_type, (kfr.StructType, kfr.ArrayType)): cursor = BufferCursor(binding.vector_type.buffer_layout.reflection, buffer) diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index 400fc040e..06ccae31c 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -664,5 +664,116 @@ def test_mixed_children_direct_bind_codegen(device_type: spy.DeviceType): assert_contains(code, "struct _t_s") +# =========================================================================== +# Review coverage — binding flag verification tests +# =========================================================================== + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_writable_valueref_not_direct_bind(device_type: spy.DeviceType): + """Writable ValueRef (inout) must not be direct-bind — needs buffer read/write.""" + device = helpers.get_device(device_type) + src = "void inc(inout int v) { v += 1; }" + func = helpers.create_function_from_module(device, "inc", src) + vr = ValueRef(5) + cd = func.debug_build_call_data(vr) + bindings = cd.debug_only_bindings + v_binding = bindings.args[0] + assert v_binding.direct_bind is False + assert v_binding.call_dimensionality == 0 + code = cd.code + assert_contains(code, "RWValueRef") + assert_not_contains(code, "typealias _t_v = int;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_result_binding_not_direct_bind(device_type: spy.DeviceType): + """Auto-created _result (writable ValueRef) must not be direct-bind.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, "add", "int add(int a, int b) { return a + b; }" + ) + cd = func.debug_build_call_data(1, 2) + result_binding = cd.debug_only_bindings.kwargs["_result"] + assert result_binding.direct_bind is False + assert result_binding.call_dimensionality == 0 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_struct_all_scalars_binding_flag(device_type: spy.DeviceType): + """All-scalar struct at dim-0 should have direct_bind=True (and so should children).""" + device = helpers.get_device(device_type) + src = """ +struct S { float x; float y; }; +float sum(S s) { return s.x + s.y; } +""" + func = helpers.create_function_from_module(device, "sum", src) + cd = func.debug_build_call_data({"_type": "S", "x": 1.0, "y": 2.0}) + bindings = cd.debug_only_bindings + s = bindings.args[0] + assert s.direct_bind is True + assert s.children["x"].direct_bind is True + assert s.children["y"].direct_bind is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_struct_with_wanghash_child_not_direct_bind(device_type: spy.DeviceType): + """Struct with a WangHashArg child must NOT be direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct S { uint3 seed; float scale; }; +float apply(S s) { return float(s.seed.x) * s.scale; } +""" + func = helpers.create_function_from_module(device, "apply", src) + cd = func.debug_build_call_data({"_type": "S", "seed": WangHashArg(3), "scale": 1.0}) + bindings = cd.debug_only_bindings + s = bindings.args[0] + assert s.direct_bind is False + # scale child should still be direct-bind individually + assert s.children["scale"].direct_bind is True + code = cd.code + assert_contains(code, "struct _t_s") + assert_not_contains(code, "typealias _t_s = S;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_wanghasharg_binding_flag(device_type: spy.DeviceType): + """WangHashArg (no can_direct_bind override) should have direct_bind=False.""" + device = helpers.get_device(device_type) + src = "uint3 rng(uint3 input) { return input; }" + func = helpers.create_function_from_module(device, "rng", src) + cd = func.debug_build_call_data(WangHashArg(3)) + bindings = cd.debug_only_bindings + assert bindings.args[0].direct_bind is False + assert bindings.args[0].call_dimensionality == 0 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_valueref_read_input(device_type: spy.DeviceType): + """Dispatch with a read-only ValueRef input — verifies direct-bind ValueRef pipeline end-to-end.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, "double_it", "float double_it(float v) { return v * 2; }" + ) + result = func(ValueRef(7.0)) + assert abs(result - 14.0) < 1e-5 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_bwds_primal_binding_flags(device_type: spy.DeviceType): + """In bwds mode, primal args (access[0]=read) should have direct_bind=True.""" + device = helpers.get_device(device_type) + src = """ +[Differentiable] +float polynomial(float a, float b) { return a * a + b + 1; } +""" + func = helpers.create_function_from_module(device, "polynomial", src) + cd = func.bwds.debug_build_call_data(5.0, 10.0, 26.0) + bindings = cd.debug_only_bindings + # Primal args in bwds mode → access[0]=read → direct_bind should be True + assert bindings.args[0].direct_bind is True # 'a' + assert bindings.args[1].direct_bind is True # 'b' + + if __name__ == "__main__": pytest.main([__file__, "-vs"]) diff --git a/src/slangpy_ext/utils/slangpyvalue.cpp b/src/slangpy_ext/utils/slangpyvalue.cpp index 0b3435208..6f6543ff9 100644 --- a/src/slangpy_ext/utils/slangpyvalue.cpp +++ b/src/slangpy_ext/utils/slangpyvalue.cpp @@ -21,6 +21,7 @@ void NativeValueMarshall::ensure_cached(ShaderCursor cursor, NativeBoundVariable m_cached.value_offset = field.offset(); m_cached.value_type_layout = field.slang_type_layout(); m_cached.writer = get_shader_cursor_writer(m_cached.value_type_layout); + m_cached.direct_bind = binding->direct_bind(); m_cached.is_valid = true; } diff --git a/src/slangpy_ext/utils/slangpyvalue.h b/src/slangpy_ext/utils/slangpyvalue.h index 42ad5ae28..41040407d 100644 --- a/src/slangpy_ext/utils/slangpyvalue.h +++ b/src/slangpy_ext/utils/slangpyvalue.h @@ -34,6 +34,7 @@ class NativeValueMarshall : public NativeMarshall { ShaderOffset value_offset; ///< Offset to the value field. slang::TypeLayoutReflection* value_type_layout = nullptr; ///< Type layout for value field. std::function writer; ///< Pre-resolved writer fn. + bool direct_bind{false}; ///< direct_bind value used when populating cache. bool is_valid = false; }; From 0dedb1dc2e83d344e849f6c2c95aaa3096641c0f Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 13:25:11 +0000 Subject: [PATCH 09/41] code cleanup --- slangpy/bindings/boundvariable.py | 2 +- slangpy/builtin/struct.py | 5 ++--- slangpy/builtin/tensorcommon.py | 34 +++++++++++-------------------- slangpy/builtin/value.py | 10 --------- 4 files changed, 15 insertions(+), 36 deletions(-) diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index 534029b44..aa318ec25 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -154,7 +154,7 @@ def can_direct_bind_common(binding: "BoundVariable") -> bool: :param binding: The bound variable to check. :return: True if the common prerequisites for direct binding are met. """ - if binding.call_dimensionality is None or binding.call_dimensionality != 0: + if binding.call_dimensionality != 0: return False if binding.children: return False diff --git a/slangpy/builtin/struct.py b/slangpy/builtin/struct.py index 3439797b5..886b864e2 100644 --- a/slangpy/builtin/struct.py +++ b/slangpy/builtin/struct.py @@ -79,9 +79,8 @@ def resolve_dimensionality( def can_direct_bind(self, binding: "BoundVariable") -> bool: if binding.children is not None: return ( - binding.call_dimensionality is not None - and binding.call_dimensionality == 0 - and not getattr(binding, "create_param_block", False) + binding.call_dimensionality == 0 + and not binding.create_param_block and binding.vector_type is not None and binding.access[0] == AccessType.read and all(child.direct_bind for child in binding.children.values()) diff --git a/slangpy/builtin/tensorcommon.py b/slangpy/builtin/tensorcommon.py index 8d4e42e54..c30d964e7 100644 --- a/slangpy/builtin/tensorcommon.py +++ b/slangpy/builtin/tensorcommon.py @@ -386,18 +386,13 @@ def gen_trampoline_load( data_name: str, value_name: str, ) -> bool: - if not isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType)): - # For ITensorType at dim-0, use direct assignment (struct copy) - if ( - isinstance(binding.vector_type, ITensorType) - and binding.call_dimensionality is not None - and binding.call_dimensionality == 0 - ): - cgb.append_statement(f"{value_name} = {data_name}") - return True - return False - cgb.append_statement(f"{value_name} = {data_name}") - return True + if isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType)): + cgb.append_statement(f"{value_name} = {data_name}") + return True + if isinstance(binding.vector_type, ITensorType) and binding.call_dimensionality == 0: + cgb.append_statement(f"{value_name} = {data_name}") + return True + return False def gen_trampoline_store( @@ -407,13 +402,8 @@ def gen_trampoline_store( data_name: str, value_name: str, ) -> bool: - if not isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType)): - # For ITensorType at dim-0, suppress default store - if ( - isinstance(binding.vector_type, ITensorType) - and binding.call_dimensionality is not None - and binding.call_dimensionality == 0 - ): - return True - return False - return True + if isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType)): + return True + if isinstance(binding.vector_type, ITensorType) and binding.call_dimensionality == 0: + return True + return False diff --git a/slangpy/builtin/value.py b/slangpy/builtin/value.py index 6d30b33a5..69aa66dfb 100644 --- a/slangpy/builtin/value.py +++ b/slangpy/builtin/value.py @@ -125,16 +125,6 @@ def gen_trampoline_store( # ValueMarshall is read-only — suppress the default store return True - # Call data just returns the primal - def create_calldata( - self, context: CallContext, binding: "BoundVariableRuntime", data: Any - ) -> Any: - access = binding.access - if access[0] in [AccessType.read, AccessType.readwrite]: - if binding.direct_bind: - return data - return {"value": data} - # Values just return themselves for raw dispatch def create_dispatchdata(self, data: Any) -> Any: return data From bd6690f358b29279efd78568c781a47d37d1ab43 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 14:01:30 +0000 Subject: [PATCH 10/41] more tensor cleanup --- slangpy/builtin/tensor.py | 3 +++ slangpy/builtin/tensorcommon.py | 27 ++++++++++--------- .../tests/slangpy_tests/test_kernel_gen.py | 5 ++-- .../torchintegration/torchtensormarshall.py | 3 +++ 4 files changed, 22 insertions(+), 16 deletions(-) diff --git a/slangpy/builtin/tensor.py b/slangpy/builtin/tensor.py index 70951d043..4ef9a9e49 100644 --- a/slangpy/builtin/tensor.py +++ b/slangpy/builtin/tensor.py @@ -135,6 +135,9 @@ def resolve_dimensionality( ): return spytc.resolve_dimensionality(self, context, binding, vector_target_type) + def can_direct_bind(self, binding: BoundVariable) -> bool: + return spytc.can_direct_bind(self, binding) + def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVariable): return spytc.gen_calldata(self, cgb, context, binding) diff --git a/slangpy/builtin/tensorcommon.py b/slangpy/builtin/tensorcommon.py index c30d964e7..ca39d6cb1 100644 --- a/slangpy/builtin/tensorcommon.py +++ b/slangpy/builtin/tensorcommon.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import Optional, Protocol -from slangpy.bindings import BoundVariable, BindContext, CodeGenBlock +from slangpy.bindings import BoundVariable, BindContext, CodeGenBlock, can_direct_bind_common from slangpy.core.native import CallMode, AccessType from slangpy.reflection import ( SlangType, @@ -379,6 +379,12 @@ def gen_calldata( cgb.type_alias(f"_t_{binding.variable_name}", type_name) +def can_direct_bind(self: ITensorMarshall, binding: BoundVariable) -> bool: + if not can_direct_bind_common(binding): + return False + return isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType, ITensorType)) + + def gen_trampoline_load( self: ITensorMarshall, cgb: CodeGenBlock, @@ -386,13 +392,10 @@ def gen_trampoline_load( data_name: str, value_name: str, ) -> bool: - if isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType)): - cgb.append_statement(f"{value_name} = {data_name}") - return True - if isinstance(binding.vector_type, ITensorType) and binding.call_dimensionality == 0: - cgb.append_statement(f"{value_name} = {data_name}") - return True - return False + if not binding.direct_bind: + return False + cgb.append_statement(f"{value_name} = {data_name}") + return True def gen_trampoline_store( @@ -402,8 +405,6 @@ def gen_trampoline_store( data_name: str, value_name: str, ) -> bool: - if isinstance(binding.vector_type, (TensorViewType, DiffTensorViewType)): - return True - if isinstance(binding.vector_type, ITensorType) and binding.call_dimensionality == 0: - return True - return False + if not binding.direct_bind: + return False + return True diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index 06ccae31c..ee0787e3a 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -596,7 +596,7 @@ def test_gate_tensor_dim0_codegen(device_type: spy.DeviceType): @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) def test_gate_tensor_dim0_binding_flags(device_type: spy.DeviceType): - """Tensor at dim-0 has direct_bind=False (tensor marshalls don't opt in).""" + """Tensor at dim-0 has direct_bind=True (consistent with other dim-0 types).""" device = helpers.get_device(device_type) src = """ float tensor_read(Tensor t) { @@ -608,8 +608,7 @@ def test_gate_tensor_dim0_binding_flags(device_type: spy.DeviceType): cd = func.debug_build_call_data(tensor) bindings = cd.debug_only_bindings t_binding = bindings.args[0] - # Tensor marshalls don't implement can_direct_bind — direct_bind stays False - assert t_binding.direct_bind is False + assert t_binding.direct_bind is True assert t_binding.call_dimensionality == 0 diff --git a/slangpy/torchintegration/torchtensormarshall.py b/slangpy/torchintegration/torchtensormarshall.py index 970dde412..1e75384b3 100644 --- a/slangpy/torchintegration/torchtensormarshall.py +++ b/slangpy/torchintegration/torchtensormarshall.py @@ -207,6 +207,9 @@ def resolve_dimensionality( """Resolve dimensionality during vectorization.""" return spytc.resolve_dimensionality(self, context, binding, vector_target_type) + def can_direct_bind(self, binding: BoundVariable) -> bool: + return spytc.can_direct_bind(self, binding) + def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVariable): """Generate call data code for the kernel.""" return spytc.gen_calldata(self, cgb, context, binding) From d6032495340ba19f870e266a651db4b1d55c69fd Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 14:15:52 +0000 Subject: [PATCH 11/41] extra tests --- .../tests/slangpy_tests/test_kernel_gen.py | 947 ++++++++++++++++++ 1 file changed, 947 insertions(+) diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index ee0787e3a..cc25948f5 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -774,5 +774,952 @@ def test_bwds_primal_binding_flags(device_type: spy.DeviceType): assert bindings.args[1].direct_bind is True # 'b' +# =========================================================================== +# ND tensor → (N-1)D parameter vectorization — kernel source pattern tests +# =========================================================================== + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_2d_tensor_to_vector_codegen(device_type: spy.DeviceType): + """2D Tensor shape=(10,3) → float3 param: trailing dim consumed by vector, outer dim dispatched.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((10, 3), dtype=np.float32)) + code = generate_code( + device, + "scale", + "float3 scale(float3 v, float s) { return v * s; }", + tensor, + 2.0, + ) + # v is vectorized dim-1: tensor wrapping a vector type + assert_contains(code, "__slangpy_load") + assert_contains(code, "_t_v") + assert_contains(code, "_m_v") + # s is scalar dim-0: direct-bind + assert_contains(code, "typealias _t_s = float;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_2d_tensor_to_vector_binding_flags(device_type: spy.DeviceType): + """2D Tensor shape=(10,3) → float3 param: check binding metadata.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((10, 3), dtype=np.float32)) + func = helpers.create_function_from_module( + device, + "scale", + "float3 scale(float3 v, float s) { return v * s; }", + ) + cd = func.debug_build_call_data(tensor, 2.0) + bindings = cd.debug_only_bindings + v_binding = bindings.args[0] + # Tensor vectorized over outer dim: call_dimensionality == 1 + assert v_binding.call_dimensionality == 1 + assert v_binding.direct_bind is False + assert v_binding.vector_type is not None + assert v_binding.vector_type.full_name == "vector" + # Scalar s: dim-0 direct-bind + s_binding = bindings.args[1] + assert s_binding.call_dimensionality == 0 + assert s_binding.direct_bind is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_2d_tensor_to_vector(device_type: spy.DeviceType): + """Dispatch 2D tensor → float3 and verify GPU result.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, + "scale", + "float3 scale(float3 v, float s) { return v * s; }", + ) + data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + tensor = Tensor.from_numpy(device, data) + result = func(tensor, 2.0) + expected = data * 2.0 + np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_3d_tensor_to_vector_codegen(device_type: spy.DeviceType): + """3D Tensor shape=(2,5,3) → float3 param: two outer dims dispatched.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((2, 5, 3), dtype=np.float32)) + code = generate_code( + device, + "negate", + "float3 negate(float3 v) { return -v; }", + tensor, + ) + # v vectorized dim-2: uses __slangpy_load, mapping constant + assert_contains(code, "__slangpy_load") + assert_contains(code, "_m_v") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_3d_tensor_to_vector_binding_flags(device_type: spy.DeviceType): + """3D Tensor shape=(2,5,3) → float3 param: call_dimensionality == 2.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((2, 5, 3), dtype=np.float32)) + func = helpers.create_function_from_module( + device, + "negate", + "float3 negate(float3 v) { return -v; }", + ) + cd = func.debug_build_call_data(tensor) + bindings = cd.debug_only_bindings + v = bindings.args[0] + assert v.call_dimensionality == 2 + assert v.direct_bind is False + assert v.vector_type.full_name == "vector" + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_3d_tensor_to_vector(device_type: spy.DeviceType): + """Dispatch 3D tensor → float3 and verify GPU result.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, + "negate", + "float3 negate(float3 v) { return -v; }", + ) + data = np.arange(30, dtype=np.float32).reshape(2, 5, 3) + tensor = Tensor.from_numpy(device, data) + result = func(tensor) + expected = -data + np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_2d_tensor_to_scalar_codegen(device_type: spy.DeviceType): + """2D Tensor shape=(4,5) → float scalar: both dims dispatched (call_dim=2).""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((4, 5), dtype=np.float32)) + code = generate_code( + device, + "square", + "float square(float x) { return x * x; }", + tensor, + ) + assert_contains(code, "__slangpy_load") + assert_contains(code, "_m_x") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_2d_tensor_to_scalar_binding_flags(device_type: spy.DeviceType): + """2D Tensor shape=(4,5) → float scalar: call_dimensionality == 2.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((4, 5), dtype=np.float32)) + func = helpers.create_function_from_module( + device, + "square", + "float square(float x) { return x * x; }", + ) + cd = func.debug_build_call_data(tensor) + v = cd.debug_only_bindings.args[0] + assert v.call_dimensionality == 2 + assert v.direct_bind is False + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_2d_tensor_to_scalar(device_type: spy.DeviceType): + """Dispatch 2D tensor elementwise to scalar and verify GPU result.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, + "square", + "float square(float x) { return x * x; }", + ) + data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + tensor = Tensor.from_numpy(device, data) + result = func(tensor) + expected = data * data + np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_2d_tensor_to_1d_array_codegen(device_type: spy.DeviceType): + """2D Tensor shape=(4,8) → half[8] param: trailing dim consumed by array, outer dim dispatched.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((4, 8), dtype=np.float16)) + code = generate_code( + device, + "tensor_test_channels<8>", + r""" +half[NumChannels] tensor_test_channels(half[NumChannels] data) +{ + [ForceUnroll] + for (int i = 0; i < NumChannels; ++i) + { + data[i] = 2.h * data[i]; + } + return data; +} +""", + tensor, + ) + # data is vectorized (trailing dim consumed by array): __slangpy_load present + assert_contains(code, "__slangpy_load") + assert_contains(code, "_m_data") + assert_contains(code, "_t_data") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_2d_tensor_to_1d_array_binding_flags(device_type: spy.DeviceType): + """2D Tensor shape=(4,8) → half[8] param: call_dimensionality == 1.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((4, 8), dtype=np.float16)) + func = helpers.create_function_from_module( + device, + "tensor_test_channels<8>", + r""" +half[NumChannels] tensor_test_channels(half[NumChannels] data) +{ + [ForceUnroll] + for (int i = 0; i < NumChannels; ++i) + { + data[i] = 2.h * data[i]; + } + return data; +} +""", + ) + cd = func.debug_build_call_data(tensor) + v = cd.debug_only_bindings.args[0] + assert v.call_dimensionality == 1 + assert v.direct_bind is False + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_2d_tensor_to_1d_array(device_type: spy.DeviceType): + """Dispatch 2D tensor → half[8] and verify GPU doubles each element.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, + "tensor_test_channels<8>", + r""" +half[NumChannels] tensor_test_channels(half[NumChannels] data) +{ + [ForceUnroll] + for (int i = 0; i < NumChannels; ++i) + { + data[i] = 2.h * data[i]; + } + return data; +} +""", + ).return_type(Tensor) + data = np.ones((4, 8), dtype=np.float16) + tensor = Tensor.from_numpy(device, data) + result = func(tensor) + expected = data * 2.0 + np.testing.assert_allclose( + result.to_numpy().reshape(expected.shape).astype(np.float32), + expected.astype(np.float32), + atol=1e-2, + ) + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_mixed_vectorized_and_dim0_tensor_codegen(device_type: spy.DeviceType): + """One tensor vectorized (2D→float3) and another at dim-0 (Tensor param).""" + device = helpers.get_device(device_type) + src = """ +float dot_lookup(float3 v, Tensor weights) { + return v.x * weights[0] + v.y * weights[1] + v.z * weights[2]; +} +""" + vec_tensor = Tensor.from_numpy(device, np.ones((5, 3), dtype=np.float32)) + weight_tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + code = generate_code(device, "dot_lookup", src, vec_tensor, weight_tensor) + # v: vectorized dim-1 (2D→float3), uses __slangpy_load + assert_contains(code, "_m_v") + assert_contains(code, "__slangpy_load") + # weights: dim-0 direct-bind (Tensor param), uses typealias + direct assignment + assert_contains(code, "typealias _t_weights = Tensor;") + assert_trampoline_has(code, "weights = __calldata__.weights;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_mixed_vectorized_and_dim0_tensor_binding_flags(device_type: spy.DeviceType): + """Binding flags: vectorized tensor has dim>0, dim-0 tensor has direct_bind.""" + device = helpers.get_device(device_type) + src = """ +float dot_lookup(float3 v, Tensor weights) { + return v.x * weights[0] + v.y * weights[1] + v.z * weights[2]; +} +""" + vec_tensor = Tensor.from_numpy(device, np.ones((5, 3), dtype=np.float32)) + weight_tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + func = helpers.create_function_from_module(device, "dot_lookup", src) + cd = func.debug_build_call_data(vec_tensor, weight_tensor) + bindings = cd.debug_only_bindings + v = bindings.args[0] + assert v.call_dimensionality == 1 + assert v.direct_bind is False + w = bindings.args[1] + assert w.call_dimensionality == 0 + assert w.direct_bind is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_mixed_vectorized_and_dim0_tensor(device_type: spy.DeviceType): + """Dispatch vectorized float3 + dim-0 Tensor and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +float dot_lookup(float3 v, Tensor weights) { + return v.x * weights[0] + v.y * weights[1] + v.z * weights[2]; +} +""" + func = helpers.create_function_from_module(device, "dot_lookup", src) + vecs = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) + weights = np.array([10, 20, 30], dtype=np.float32) + result = func( + Tensor.from_numpy(device, vecs), + Tensor.from_numpy(device, weights), + ) + expected = np.array([10, 20, 30], dtype=np.float32) + np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) + + +# =========================================================================== +# Composite struct codegen tests — nested structs, vector/matrix/array fields +# =========================================================================== + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_nested_struct_codegen(device_type: spy.DeviceType): + """Nested struct: Outer{Inner inner, float scale} — all-scalar, direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct Inner { + float x; + float y; +}; +struct Outer { + Inner inner; + float scale; +}; +float compute(Outer o) { return (o.inner.x + o.inner.y) * o.scale; } +""" + code = generate_code( + device, + "compute", + src, + {"_type": "Outer", "inner": {"_type": "Inner", "x": 1.0, "y": 2.0}, "scale": 3.0}, + ) + # All-scalar nested struct at dim-0: direct-bind → raw typealias + assert_contains(code, "typealias _t_o = Outer;") + assert_not_contains(code, "__slangpy_load") + assert_not_contains(code, "struct _t_o") + assert_trampoline_has(code, "o = __calldata__.o;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_nested_struct_binding_flags(device_type: spy.DeviceType): + """Nested struct: all-scalar → direct_bind=True at every level.""" + device = helpers.get_device(device_type) + src = """ +struct Inner { + float x; + float y; +}; +struct Outer { + Inner inner; + float scale; +}; +float compute(Outer o) { return (o.inner.x + o.inner.y) * o.scale; } +""" + func = helpers.create_function_from_module(device, "compute", src) + cd = func.debug_build_call_data( + {"_type": "Outer", "inner": {"_type": "Inner", "x": 1.0, "y": 2.0}, "scale": 3.0} + ) + bindings = cd.debug_only_bindings + o = bindings.args[0] + assert o.direct_bind is True + assert o.children["inner"].direct_bind is True + assert o.children["inner"].children["x"].direct_bind is True + assert o.children["inner"].children["y"].direct_bind is True + assert o.children["scale"].direct_bind is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_nested_struct(device_type: spy.DeviceType): + """Dispatch nested struct and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct Inner { + float x; + float y; +}; +struct Outer { + Inner inner; + float scale; +}; +float compute(Outer o) { return (o.inner.x + o.inner.y) * o.scale; } +""" + func = helpers.create_function_from_module(device, "compute", src) + result = func({"_type": "Outer", "inner": {"_type": "Inner", "x": 3.0, "y": 7.0}, "scale": 2.0}) + assert abs(result - 20.0) < 1e-5 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_with_vector_fields_codegen(device_type: spy.DeviceType): + """Struct with vector fields: S{float3 pos, float scale} — all dim-0, direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float3 pos; + float scale; +}; +float3 apply(S s) { return s.pos * s.scale; } +""" + code = generate_code( + device, + "apply", + src, + {"_type": "S", "pos": spy.math.float3(1, 2, 3), "scale": 2.0}, + ) + # All-scalar struct with vector field at dim-0: direct-bind → raw typealias + assert_contains(code, "typealias _t_s = S;") + assert_not_contains(code, "__slangpy_load") + assert_trampoline_has(code, "s = __calldata__.s;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_with_vector_fields_binding_flags(device_type: spy.DeviceType): + """Struct with vector field — all children direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float3 pos; + float scale; +}; +float3 apply(S s) { return s.pos * s.scale; } +""" + func = helpers.create_function_from_module(device, "apply", src) + cd = func.debug_build_call_data({"_type": "S", "pos": spy.math.float3(1, 2, 3), "scale": 2.0}) + bindings = cd.debug_only_bindings + s = bindings.args[0] + assert s.direct_bind is True + assert s.children["pos"].direct_bind is True + assert s.children["scale"].direct_bind is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_struct_with_vector_fields(device_type: spy.DeviceType): + """Dispatch struct with vector field and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float3 pos; + float scale; +}; +float3 apply(S s) { return s.pos * s.scale; } +""" + func = helpers.create_function_from_module(device, "apply", src) + result = func({"_type": "S", "pos": spy.math.float3(1, 2, 3), "scale": 3.0}) + assert result.x == 3.0 + assert result.y == 6.0 + assert result.z == 9.0 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_with_matrix_field_codegen(device_type: spy.DeviceType): + """Struct with matrix field: S{float4x4 m, float scale} — all dim-0, direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float4x4 m; + float scale; +}; +float4x4 apply(S s) { return s.m * s.scale; } +""" + code = generate_code( + device, + "apply", + src, + {"_type": "S", "m": spy.math.float4x4.identity(), "scale": 2.0}, + ) + assert_contains(code, "typealias _t_s = S;") + assert_not_contains(code, "__slangpy_load") + assert_trampoline_has(code, "s = __calldata__.s;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_struct_with_matrix_field(device_type: spy.DeviceType): + """Dispatch struct with matrix field and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float4x4 m; + float scale; +}; +float4x4 apply(S s) { return s.m * s.scale; } +""" + func = helpers.create_function_from_module(device, "apply", src) + result = func({"_type": "S", "m": spy.math.float4x4.identity(), "scale": 2.0}) + # Identity * 2 → diagonal is 2 + assert abs(result[0][0] - 2.0) < 1e-5 + assert abs(result[1][1] - 2.0) < 1e-5 + assert abs(result[0][1] - 0.0) < 1e-5 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_with_array_field_codegen(device_type: spy.DeviceType): + """Struct with fixed-size array field: Foo{int vals[4]} — all dim-0, direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct Foo { + int vals[4]; +}; +int sum_inner(Foo foo) { + int s = 0; + for (int i = 0; i < 4; i++) { + s += foo.vals[i]; + } + return s; +} +""" + code = generate_code( + device, + "sum_inner", + src, + {"_type": "Foo", "vals": [1, 2, 3, 4]}, + ) + assert_contains(code, "typealias _t_foo = Foo;") + assert_not_contains(code, "__slangpy_load") + assert_trampoline_has(code, "foo = __calldata__.foo;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_with_array_field_binding_flags(device_type: spy.DeviceType): + """Struct with array field: all direct_bind=True.""" + device = helpers.get_device(device_type) + src = """ +struct Foo { + int vals[4]; +}; +int sum_inner(Foo foo) { + int s = 0; + for (int i = 0; i < 4; i++) { + s += foo.vals[i]; + } + return s; +} +""" + func = helpers.create_function_from_module(device, "sum_inner", src) + cd = func.debug_build_call_data({"_type": "Foo", "vals": [1, 2, 3, 4]}) + bindings = cd.debug_only_bindings + foo = bindings.args[0] + assert foo.direct_bind is True + assert foo.children["vals"].direct_bind is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_struct_with_array_field(device_type: spy.DeviceType): + """Dispatch struct with array field and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct Foo { + int vals[4]; +}; +int sum_inner(Foo foo) { + int s = 0; + for (int i = 0; i < 4; i++) { + s += foo.vals[i]; + } + return s; +} +""" + func = helpers.create_function_from_module(device, "sum_inner", src) + result = func({"_type": "Foo", "vals": [10, 20, 30, 40]}) + assert result == 100 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_deeply_nested_struct_codegen(device_type: spy.DeviceType): + """3-level deep nesting: Top{Mid{Bot{float v}, int c}, float s} — all dim-0, direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct Bot { + float v; +}; +struct Mid { + Bot bot; + int c; +}; +struct Top { + Mid mid; + float s; +}; +float compute(Top t) { return t.mid.bot.v * float(t.mid.c) * t.s; } +""" + code = generate_code( + device, + "compute", + src, + { + "_type": "Top", + "mid": {"_type": "Mid", "bot": {"_type": "Bot", "v": 2.0}, "c": 3}, + "s": 4.0, + }, + ) + assert_contains(code, "typealias _t_t = Top;") + assert_not_contains(code, "__slangpy_load") + assert_not_contains(code, "struct _t_t") + assert_trampoline_has(code, "t = __calldata__.t;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_deeply_nested_struct_binding_flags(device_type: spy.DeviceType): + """3-level deep: all direct_bind=True at every level.""" + device = helpers.get_device(device_type) + src = """ +struct Bot { + float v; +}; +struct Mid { + Bot bot; + int c; +}; +struct Top { + Mid mid; + float s; +}; +float compute(Top t) { return t.mid.bot.v * float(t.mid.c) * t.s; } +""" + func = helpers.create_function_from_module(device, "compute", src) + cd = func.debug_build_call_data( + { + "_type": "Top", + "mid": {"_type": "Mid", "bot": {"_type": "Bot", "v": 2.0}, "c": 3}, + "s": 4.0, + } + ) + bindings = cd.debug_only_bindings + t = bindings.args[0] + assert t.direct_bind is True + assert t.children["mid"].direct_bind is True + assert t.children["mid"].children["bot"].direct_bind is True + assert t.children["mid"].children["bot"].children["v"].direct_bind is True + assert t.children["mid"].children["c"].direct_bind is True + assert t.children["s"].direct_bind is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_deeply_nested_struct(device_type: spy.DeviceType): + """Dispatch 3-level nested struct and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct Bot { + float v; +}; +struct Mid { + Bot bot; + int c; +}; +struct Top { + Mid mid; + float s; +}; +float compute(Top t) { return t.mid.bot.v * float(t.mid.c) * t.s; } +""" + func = helpers.create_function_from_module(device, "compute", src) + result = func( + { + "_type": "Top", + "mid": {"_type": "Mid", "bot": {"_type": "Bot", "v": 2.0}, "c": 3}, + "s": 4.0, + } + ) + assert abs(result - 24.0) < 1e-5 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_nested_struct_with_tensor_child_codegen(device_type: spy.DeviceType): + """Nested struct where a leaf is a tensor: Outer{Inner{float x (tensor), float y (scalar)}, float s}. + + Outer and Inner are NOT direct-bind (Inner.x is vectorized). + Inner.y and s retain direct_bind=True inside the non-direct-bind parent. + """ + device = helpers.get_device(device_type) + src = """ +struct Inner { + float x; + float y; +}; +struct Outer { + Inner inner; + float s; +}; +float compute(Outer o) { return (o.inner.x + o.inner.y) * o.s; } +""" + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + code = generate_code( + device, + "compute", + src, + { + "_type": "Outer", + "inner": {"_type": "Inner", "x": tensor_x, "y": 10.0}, + "s": 2.0, + }, + ) + # Outer and Inner are NOT direct-bind: inline structs generated + assert_contains(code, "struct _t_o") + assert_contains(code, "__slangpy_load") + assert_not_contains(code, "typealias _t_o = Outer;") + # Scalar children retain direct-bind: raw type aliases + assert_contains(code, "typealias _t_y = float;") + assert_contains(code, "typealias _t_s = float;") + # Direct assignment for scalar children within __slangpy_load + assert_contains(code, "value.y = y;") + # Tensor child uses standard path + assert_contains(code, "_m_x") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_nested_struct_with_tensor_child_binding_flags(device_type: spy.DeviceType): + """Nested struct with tensor: Outer not direct-bind, scalar children retain direct_bind.""" + device = helpers.get_device(device_type) + src = """ +struct Inner { + float x; + float y; +}; +struct Outer { + Inner inner; + float s; +}; +float compute(Outer o) { return (o.inner.x + o.inner.y) * o.s; } +""" + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + func = helpers.create_function_from_module(device, "compute", src) + cd = func.debug_build_call_data( + { + "_type": "Outer", + "inner": {"_type": "Inner", "x": tensor_x, "y": 10.0}, + "s": 2.0, + } + ) + bindings = cd.debug_only_bindings + o = bindings.args[0] + assert o.direct_bind is False + assert o.children["inner"].direct_bind is False # has non-direct child + assert o.children["inner"].children["x"].direct_bind is False # tensor dim>0 + assert o.children["inner"].children["y"].direct_bind is True # scalar dim-0 + assert o.children["s"].direct_bind is True # scalar dim-0 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_nested_struct_with_tensor(device_type: spy.DeviceType): + """Dispatch nested struct with tensor leaf and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct Inner { + float x; + float y; +}; +struct Outer { + Inner inner; + float s; +}; +float compute(Outer o) { return (o.inner.x + o.inner.y) * o.s; } +""" + func = helpers.create_function_from_module(device, "compute", src) + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + result = func( + { + "_type": "Outer", + "inner": {"_type": "Inner", "x": tensor_x, "y": 10.0}, + "s": 2.0, + } + ) + expected = np.array([22, 24, 26], dtype=np.float32) + np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_with_struct_array_field_codegen(device_type: spy.DeviceType): + """Struct with array-of-structs field: Outer{Inner items[4]} — all dim-0, direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct Inner { + int x; +}; +struct Outer { + Inner items[4]; +}; +int sum_inner(Outer outer) { + int s = 0; + for (int i = 0; i < 4; i++) { + s += outer.items[i].x; + } + return s; +} +""" + code = generate_code( + device, + "sum_inner", + src, + { + "_type": "Outer", + "items": [ + {"_type": "Inner", "x": 10}, + {"_type": "Inner", "x": 20}, + {"_type": "Inner", "x": 30}, + {"_type": "Inner", "x": 40}, + ], + }, + ) + assert_contains(code, "typealias _t_outer = Outer;") + assert_not_contains(code, "__slangpy_load") + assert_trampoline_has(code, "outer = __calldata__.outer;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_struct_with_struct_array_field(device_type: spy.DeviceType): + """Dispatch struct with array-of-structs field and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct Inner { + int x; +}; +struct Outer { + Inner items[4]; +}; +int sum_inner(Outer outer) { + int s = 0; + for (int i = 0; i < 4; i++) { + s += outer.items[i].x; + } + return s; +} +""" + func = helpers.create_function_from_module(device, "sum_inner", src) + result = func( + { + "_type": "Outer", + "items": [ + {"_type": "Inner", "x": 10}, + {"_type": "Inner", "x": 20}, + {"_type": "Inner", "x": 30}, + {"_type": "Inner", "x": 40}, + ], + } + ) + assert result == 100 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_return_codegen(device_type: spy.DeviceType): + """Function returning a struct: _result uses RWValueRef wrapper, not direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct S { + int x; + int y; +}; +S make_struct(int a, int b) { return { a, b }; } +""" + code = generate_code(device, "make_struct", src, 4, 5) + # Scalar inputs are direct-bind + assert_contains(code, "typealias _t_a = int;", "typealias _t_b = int;") + # _result is writable → NOT direct-bind → uses wrapper + assert_contains(code, "__slangpy_store") + assert_contains(code, "_m__result") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_return_binding_flags(device_type: spy.DeviceType): + """Struct return: _result binding is NOT direct-bind (writable).""" + device = helpers.get_device(device_type) + src = """ +struct S { + int x; + int y; +}; +S make_struct(int a, int b) { return { a, b }; } +""" + func = helpers.create_function_from_module(device, "make_struct", src) + cd = func.debug_build_call_data(4, 5) + bindings = cd.debug_only_bindings + result = bindings.kwargs["_result"] + assert result.direct_bind is False + # Inputs are direct-bind + assert bindings.args[0].direct_bind is True + assert bindings.args[1].direct_bind is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_struct_return(device_type: spy.DeviceType): + """Dispatch struct return and verify result is dict with correct values.""" + device = helpers.get_device(device_type) + src = """ +struct S { + int x; + int y; +}; +S make_struct(int a, int b) { return { a, b }; } +""" + func = helpers.create_function_from_module(device, "make_struct", src) + result = func(4, 5) + assert isinstance(result, dict) + assert result["x"] == 4 + assert result["y"] == 5 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_struct_with_vectorized_2d_tensor_child_codegen(device_type: spy.DeviceType): + """Struct with 2D tensor child vectorized to float3: struct NOT direct-bind. + + S{float3 v (2D tensor→float3), float s (scalar)}. + The tensor's outer dim becomes dispatch, struct generates inline __slangpy_load. + """ + device = helpers.get_device(device_type) + src = """ +struct S { + float3 v; + float s; +}; +float3 apply(S st) { return st.v * st.s; } +""" + tensor_v = Tensor.from_numpy(device, np.ones((5, 3), dtype=np.float32)) + code = generate_code( + device, + "apply", + src, + {"_type": "S", "v": tensor_v, "s": 2.0}, + ) + # Struct NOT direct-bind (tensor child is vectorized) + assert_contains(code, "struct _t_st") + assert_contains(code, "__slangpy_load") + assert_not_contains(code, "typealias _t_st = S;") + # Scalar child s retains direct-bind + assert_contains(code, "typealias _t_s = float;") + assert_contains(code, "value.s = s;") + # Tensor child v uses standard path + assert_contains(code, "_m_v") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_struct_with_vectorized_2d_tensor(device_type: spy.DeviceType): + """Dispatch struct with 2D tensor→float3 child and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float3 v; + float s; +}; +float3 apply(S st) { return st.v * st.s; } +""" + func = helpers.create_function_from_module(device, "apply", src) + data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + tensor_v = Tensor.from_numpy(device, data) + result = func({"_type": "S", "v": tensor_v, "s": 2.0}) + expected = data * 2.0 + np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) + + if __name__ == "__main__": pytest.main([__file__, "-vs"]) From 3b8cdc55dc736893c0919e7e487711ed64157b46 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 14:57:32 +0000 Subject: [PATCH 12/41] better error --- src/slangpy_ext/utils/slangpy.cpp | 22 ++++------------------ 1 file changed, 4 insertions(+), 18 deletions(-) diff --git a/src/slangpy_ext/utils/slangpy.cpp b/src/slangpy_ext/utils/slangpy.cpp index ad99da49f..030d8d961 100644 --- a/src/slangpy_ext/utils/slangpy.cpp +++ b/src/slangpy_ext/utils/slangpy.cpp @@ -330,25 +330,11 @@ void NativeBoundVariableRuntime::write_raw_dispatch_data(nb::dict call_data, nb: nb::object NativeBoundVariableRuntime::read_output(CallContext* context, nb::object data) { - if (m_children) { - // We have children, so read the output for each child and store in a dictionary. - nb::dict res; - for (const auto& [name, child_ref] : *m_children) { - if (res.contains(name.c_str())) { - if (child_ref) { - nb::object child_data = data[child_ref->m_variable_name.c_str()]; - res[name.c_str()] = child_ref->read_output(context, child_data); - } - } - } - return res; - } else { - // We are a leaf node, so read the output if the variable was writable. - if (m_access.first == AccessType::write || m_access.first == AccessType::readwrite) { - return m_python_type->read_output(context, this, data); - } - return nb::none(); + SGL_CHECK(!m_children, "Internal error: read_output should only be called on leaf nodes."); + if (m_access.first == AccessType::write || m_access.first == AccessType::readwrite) { + return m_python_type->read_output(context, this, data); } + return nb::none(); } Shape NativeBoundCallRuntime::calculate_call_shape( From 2fd9437a03a621296792d7faa29d05209c8b6ff6 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Wed, 11 Mar 2026 15:34:05 +0000 Subject: [PATCH 13/41] fix read_output --- src/slangpy_ext/utils/slangpy.cpp | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/slangpy_ext/utils/slangpy.cpp b/src/slangpy_ext/utils/slangpy.cpp index 030d8d961..64714903c 100644 --- a/src/slangpy_ext/utils/slangpy.cpp +++ b/src/slangpy_ext/utils/slangpy.cpp @@ -330,9 +330,11 @@ void NativeBoundVariableRuntime::write_raw_dispatch_data(nb::dict call_data, nb: nb::object NativeBoundVariableRuntime::read_output(CallContext* context, nb::object data) { - SGL_CHECK(!m_children, "Internal error: read_output should only be called on leaf nodes."); - if (m_access.first == AccessType::write || m_access.first == AccessType::readwrite) { - return m_python_type->read_output(context, this, data); + // Note: variables with children don't read_output directly - it is handled by their children. + if (!m_children) { + if (m_access.first == AccessType::write || m_access.first == AccessType::readwrite) { + return m_python_type->read_output(context, this, data); + } } return nb::none(); } From 8249e4f15e44371cbf9630fc27e8dd8e5f210d79 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Thu, 12 Mar 2026 10:30:47 +0000 Subject: [PATCH 14/41] Updated plan --- .../prompts/plan-simplifyKernelGen.prompt.md | 258 ++++++++++++------ 1 file changed, 173 insertions(+), 85 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen.prompt.md b/.github/prompts/plan-simplifyKernelGen.prompt.md index 75290749c..a21cee212 100644 --- a/.github/prompts/plan-simplifyKernelGen.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen.prompt.md @@ -19,54 +19,176 @@ void compute_main(int3 tid: SV_DispatchThreadID, uniform uint3 _thread_count, un ### Phase Plans -- [Phase 1: Direct Type Marshalling](plan-simplifyKernelGen-phase1.prompt.md) — **implemented (prim-mode)** -- [Phase 2: Eliminate CallData Struct](plan-simplifyKernelGen-phase2.prompt.md) -- [Phase 3: Direct Compute Kernel Invocation](plan-simplifyKernelGen-phase3.prompt.md) +- [Phase 1: Direct Type Marshalling](plan-simplifyKernelGen-phase1.prompt.md) — **✅ merged (PR #863)** +- [Phase 2: Eliminate CallData Struct](plan-simplifyKernelGen-phase2.prompt.md) — **not started** +- [Phase 3: Direct Compute Kernel Invocation](plan-simplifyKernelGen-phase3.prompt.md) — **not started** --- -### Phase 1 Progress +### Phase 1 Summary (Complete — PR #863) -Phase 1 prim-mode direct binding is complete. Steps 1.1–1.7, 1.9 are implemented and passing. Step 1.8 (autodiff/bwds) is deferred. +Phase 1 introduced **direct binding**: dim-0 arguments that can be bound using raw Slang types instead of `ValueType` wrappers, eliminating `__slangpy_load`/`__slangpy_store` indirection, `Context.map()` calls, and mapping constants for eligible arguments. PR #863 was merged to `main` on 2026-03-11 (+2,044 / −122 lines, 18 files changed, squash-merged). -The implementation was refactored from global predicate functions (`is_direct_bind_eligible`, `is_direct_bind_recursive`) with mutable marshall state (`_force_no_direct_bind`, `_set_direct_bind_on_children`) to a marshall-driven `can_direct_bind` property + single depth-first `calculate_direct_bind` pass on the `BoundVariable` tree, following the `calculate_differentiability` pattern. +#### What Phase 1 Changed -**What was done:** +**Architecture**: A marshall-driven `can_direct_bind(binding)` virtual method (default `False`) combined with a single depth-first `calculate_direct_bind()` pass on the `BoundVariable` tree. This follows the same pattern as `calculate_differentiability`. The `direct_bind` boolean is stored on `BoundVariable` (Python) and propagated to `NativeBoundVariableRuntime` (C++). + +**Eligibility**: A variable is direct-bind eligible if: +- `call_dimensionality == 0` (not vectorized) +- Not composite with children (unless all children are also direct-bind AND the composite is dim-0 with a concrete Slang struct type and read-only access) +- Not a param block (`PackedArg`) +- The marshall opts in via `can_direct_bind()` override +- For `ValueRefMarshall`: `access[0] == AccessType.read` (writable value refs need buffer logic) + +**Code generation effects** — when `binding.direct_bind == True`: +- `gen_calldata` emits `typealias _t_{name} = {raw_slang_type}` instead of `ValueType` / `VectorValueType` / `ValueRef` +- `gen_trampoline_load` emits `{value_name} = {data_name};` (direct assignment) instead of `{data_name}.__slangpy_load(context.map(_m_{name}), {name})` +- `gen_trampoline_store` returns `True` (suppresses store for read-only types) +- Mapping constants (`static const int _m_{name} = 0`) are skipped +- `create_calldata` returns the raw value instead of `{"value": data}` + +**C++ fast path**: `NativeValueMarshall::ensure_cached` reads `binding->direct_bind()` to decide cursor navigation — `cursor[variable_name]` for direct-bind vs `cursor[variable_name]["value"]` for wrapper path. + +**Composite (struct/dict) handling**: When `calculate_direct_bind()` visits a composite, it recurses children first. If all children are direct-bind AND the composite is dim-0 with a concrete vector type and read-only access → the composite itself is direct-bind (emits raw `typealias`). Otherwise the composite is NOT direct-bind, but children **retain** their individual `direct_bind` status — the parent's `__slangpy_load`/`__slangpy_store` body uses `gen_trampoline_load`/`gen_trampoline_store` for each child, so direct-bind children get direct assignment (e.g., `value.y = y;`) while non-direct-bind children use `__slangpy_load(context.map(...))`. + +**API changes to `gen_trampoline_load`/`gen_trampoline_store`**: Signature changed from `(cgb, binding, is_entry_point)` → `(cgb, binding, data_name, value_name)`. The caller now computes `data_name` (e.g., `__calldata__.x` or `call_data.x`) and `value_name` (e.g., `x` or `value.x`), allowing these methods to work both at the root trampoline level and inside composite `__slangpy_load`/`__slangpy_store` bodies. + +**`read_output` fix** (C++): `NativeBoundVariableRuntime::read_output` was simplified — composites no longer attempt to read output directly (it is handled by their children). The old composite branch had a logic error (checking `res.contains(name)` before insertion). + +#### Control Flow (post-Phase 1) + +``` +CallData.build() + → calculate_differentiability(context, bindings) + → calculate_direct_binding(bindings) ← Phase 1 + → generate_code(...) + → gen_call_data_code() — reads binding.direct_bind + → gen_trampoline() — reads binding.direct_bind + → BoundCallRuntime(bindings) — propagates binding.direct_bind to C++ runtime +``` + +At dispatch time, `NativeValueMarshall::ensure_cached()` reads `binding->direct_bind()` to decide cursor navigation: +- `direct_bind == false`: `cursor[variable_name]["value"]` (wrapper path) +- `direct_bind == true`: `cursor[variable_name]` (raw type path) + +#### Implemented Steps | Step | Status | Summary | |------|--------|---------| -| 1.2a | ✅ Done | C++ fast path: `NativeValueMarshall::ensure_cached` reads `binding->direct_bind()` from `NativeBoundVariableRuntime` to gate `["value"]` sub-field navigation. `m_direct_bind` **removed** from `NativeValueMarshall` — flag lives on `NativeBoundVariableRuntime`. | -| 1.1 | ✅ Done | `Marshall.can_direct_bind(binding)` virtual method (default `False`). Shared `can_direct_bind_common(binding)` helper. `BoundVariable.calculate_direct_bind()` depth-first tree pass. `calculate_direct_binding(call)` in `callsignature.py`. | -| 1.2 | ✅ Done | `ValueMarshall`: `can_direct_bind` overrides. `gen_calldata`, `gen_trampoline_load/store`, `create_calldata` read `binding.direct_bind`. | -| 1.3 | ✅ Done | `VectorMarshall`: `gen_calldata` emits raw `typealias` (e.g., `vector`). Inherits trampoline load/store and `can_direct_bind` from `ValueMarshall`. | -| 1.4 | ✅ Done | `StructMarshall`/`BoundVariable`: `can_direct_bind` checks all children. `gen_call_data_code` uses `self.direct_bind`. Non-direct-bind composites let children retain their `direct_bind` status; `gen_call_data_code` delegates to children's `gen_trampoline_load/store`. | -| 1.5 | ✅ Done | `ValueRefMarshall`: `can_direct_bind` requires read-only access. Writable value refs (including auto-created `_result`) use wrapper path (`RWValueRef`). | -| 1.6 | ✅ Done | Tensor dim-0: `gen_trampoline_load/store` extended for `ITensorType` at dim-0 (direct struct assignment). | +| 1.1 | ✅ Done | `Marshall.can_direct_bind(binding)` virtual method. `can_direct_bind_common(binding)` helper. `BoundVariable.calculate_direct_bind()` depth-first tree pass. `calculate_direct_binding(call)` in `callsignature.py`. | +| 1.2 | ✅ Done | `ValueMarshall`: `can_direct_bind`, `gen_calldata`, `gen_trampoline_load/store` read `binding.direct_bind`. | +| 1.2a | ✅ Done | C++ fast path: `NativeValueMarshall::ensure_cached` reads `binding->direct_bind()` from `NativeBoundVariableRuntime`. `m_direct_bind` **removed** from `NativeValueMarshall`. | +| 1.3 | ✅ Done | `VectorMarshall`/`MatrixMarshall`/`ArrayMarshall`: inherit from `ValueMarshall`. `VectorMarshall.gen_calldata` emits raw vector type (e.g., `vector`). | +| 1.4 | ✅ Done | `StructMarshall`: `can_direct_bind` checks all children. `BoundVariable.gen_call_data_code` uses `self.direct_bind`. Non-direct-bind composites delegate to children's `gen_trampoline_load/store`. | +| 1.5 | ✅ Done | `ValueRefMarshall`: `can_direct_bind` requires `access[0] == AccessType.read`. Writable value refs (including auto-created `_result`) use `RWValueRef`. | +| 1.6 | ✅ Done | Tensor dim-0: `can_direct_bind` added to `tensorcommon.py`. `gen_trampoline_load/store` extended for dim-0 tensors (`ITensorType`, `TensorViewType`, `DiffTensorViewType`). | | 1.7 | ✅ Done | Mapping constants (`static const int _m_{name}`) skipped when `self.direct_bind`. | -| 1.8 | ⬜ Deferred | Autodiff/bwds mode still uses wrapper types. | -| 1.9 | ✅ Done | 21 tests (×3 device types = 63 cases): 16 code-gen assertion tests + 5 functional GPU dispatch tests. All pass on d3d12/vulkan/cuda. | +| 1.8 | ⬜ Deferred | Autodiff derivative fields still use `ValueType` wrappers. Bwds primals use direct bind. | +| 1.9 | ✅ Done | 77 tests (×3 device types = 231 cases). All pass on d3d12/vulkan/cuda. | -**Files modified:** +#### Files Modified (PR #863) | File | Changes | |------|---------| -| `src/slangpy_ext/utils/slangpy.h` | `m_direct_bind` member, getter/setter on `NativeBoundVariableRuntime` | -| `src/slangpy_ext/utils/slangpy.cpp` | Nanobind `direct_bind` property on `NativeBoundVariableRuntime` | -| `src/slangpy_ext/utils/slangpyvalue.h` | `m_direct_bind`, `direct_bind()`, `set_direct_bind()` **removed** from `NativeValueMarshall` | -| `src/slangpy_ext/utils/slangpyvalue.cpp` | `ensure_cached` reads `binding->direct_bind()`; nanobind `direct_bind` property **removed** from `NativeValueMarshall` | -| `slangpy/bindings/marshall.py` | `can_direct_bind(binding)` virtual method (default `False`) | -| `slangpy/bindings/boundvariable.py` | `can_direct_bind_common()`, `BoundVariable.direct_bind`, `calculate_direct_bind()`. Removed: `is_direct_bind_eligible`, `is_direct_bind_recursive`, `_set_direct_bind_on_children`, `_force_no_direct_bind`, `_DIRECT_BIND_TYPES`, `_clear_direct_bind()`. | -| `slangpy/bindings/boundvariableruntime.py` | `self.direct_bind = source.direct_bind` propagation | -| `slangpy/bindings/__init__.py` | Exports `can_direct_bind_common` (removed old predicate exports) | -| `slangpy/core/callsignature.py` | `calculate_direct_binding(call)` function | -| `slangpy/core/calldata.py` | `calculate_direct_binding(bindings)` call after `calculate_differentiability` | -| `slangpy/builtin/value.py` | `can_direct_bind`, `gen_calldata`, `gen_trampoline_load/store`, `create_calldata` use `binding.direct_bind` | -| `slangpy/builtin/valueref.py` | `can_direct_bind` (read-only only), all methods use `binding.direct_bind`. Removed `self._direct_bind`. | -| `slangpy/builtin/struct.py` | `can_direct_bind`, `gen_trampoline_load/store` use `binding.direct_bind` | -| `slangpy/builtin/tensorcommon.py` | `gen_trampoline_load/store` extended for `ITensorType` (unchanged in refactor) | -| `slangpy/tests/slangpy_tests/test_kernel_gen.py` | All Phase 1 tests | - -**Test results:** 2952 passed / 0 failed in `slangpy/tests/slangpy_tests`. 6 pre-existing failures in `slangpy/tests/device/` (raytracing pipeline, type conformance cache — unrelated). +| `src/slangpy_ext/utils/slangpy.h` | `m_direct_bind` member, `direct_bind()`, `set_direct_bind()` on `NativeBoundVariableRuntime` | +| `src/slangpy_ext/utils/slangpy.cpp` | Nanobind `direct_bind` r/w property on `NativeBoundVariableRuntime`. `read_output` composite branch simplified. | +| `src/slangpy_ext/utils/slangpyvalue.h` | `CachedValueWrite.direct_bind` field added. `m_direct_bind`/`direct_bind()`/`set_direct_bind()` **removed** from `NativeValueMarshall`. | +| `src/slangpy_ext/utils/slangpyvalue.cpp` | `ensure_cached` reads `binding->direct_bind()` for cursor path; caches `direct_bind` value. | +| `slangpy/bindings/marshall.py` | `can_direct_bind(binding)` virtual method (default `False`). `gen_trampoline_load/store` signature changed to `(cgb, binding, data_name, value_name)`. | +| `slangpy/bindings/boundvariable.py` | `can_direct_bind_common()` helper. `BoundVariable.direct_bind` attribute. `BoundVariable.calculate_direct_bind()` method. `gen_call_data_code` handles direct-bind composites (raw typealias) and delegates to children's `gen_trampoline_load/store`. Mapping constant emission gated on `not self.direct_bind`. | +| `slangpy/bindings/boundvariableruntime.py` | `self.direct_bind = source.direct_bind` propagation to C++ runtime. | +| `slangpy/bindings/__init__.py` | Exports `can_direct_bind_common`. | +| `slangpy/core/callsignature.py` | `calculate_direct_binding(call)` function. Trampoline code gen refactored: `data_name` computed before `gen_trampoline_load` call. Store path moved after `data_name` computation. | +| `slangpy/core/calldata.py` | `calculate_direct_binding(bindings)` call after `calculate_differentiability`. `self.code = code` stored for debugging. | +| `slangpy/builtin/value.py` | `can_direct_bind`, `gen_trampoline_load`, `gen_trampoline_store` added. `gen_calldata` gates on `binding.direct_bind`. | +| `slangpy/builtin/valueref.py` | `can_direct_bind` (read-only gate), `gen_trampoline_load`, `gen_trampoline_store` added. `gen_calldata`, `create_calldata`, `read_calldata` gate on `binding.direct_bind`. `self._direct_bind` removed. | +| `slangpy/builtin/struct.py` | `can_direct_bind` (children check + `AccessType.read` gate). `gen_trampoline_load`, `gen_trampoline_store` delegate to `ValueMarshall` when direct-bind. | +| `slangpy/builtin/tensor.py` | `can_direct_bind` delegates to `tensorcommon`. `gen_trampoline_load/store` signature updated. | +| `slangpy/builtin/tensorcommon.py` | `can_direct_bind()` function added. `gen_trampoline_load/store` signature changed, condition changed from `isinstance(vector_type, TensorViewType)` to `binding.direct_bind`. | +| `slangpy/torchintegration/torchtensormarshall.py` | `can_direct_bind` delegates to `tensorcommon`. `gen_trampoline_load/store` signature updated. | +| `slangpy/benchmarks/test_benchmark_autograd.py` | Removed accidental blank line (1-line whitespace change). | +| `slangpy/tests/slangpy_tests/test_kernel_gen.py` | New file: 77 tests covering all Phase 1 scenarios. | + +#### Test Coverage Summary + +The test file (`test_kernel_gen.py`) provides 77 test functions × 3 device types = 231 parametrized cases covering: + +**Code-gen assertion tests** (`test_gate_*`): Verify generated Slang code patterns — type aliases, trampoline load/store statements, mapping constants, wrapper types, `__slangpy_load`/`__slangpy_store` presence/absence. + +**Binding flag tests**: Verify `direct_bind`, `call_dimensionality`, and `vector_type` on `BoundVariable` instances for: scalars, vectors, tensors (dim-0 and vectorized), structs (all-scalar, mixed, nested, deeply nested), writable ValueRef, auto-created `_result`, WangHashArg, bwds primal args. + +**Functional GPU dispatch tests** (`test_phase1_functional_*`): End-to-end dispatch verifying correct GPU results for: scalar add/mul, vector scale, struct sum, ValueRef write, mixed scalar+tensor, mixed struct fields, tensor dim-0, 2D/3D tensor→vector, 2D tensor→scalar, 2D tensor→array, nested/deeply-nested structs, struct with matrix/vector/array fields, struct return types, struct with vectorized 2D tensor child. + +**Negative gates** (`test_gate_*_keeps_*`): Verify types that are NOT direct-bind eligible remain using wrappers: WangHashArg, vectorized scalar (dim > 0), vectorized dict. + +**Helper infrastructure**: `assert_contains`, `assert_not_contains`, `assert_trampoline_has`, `generate_code`, `generate_bwds_code`. + +#### Known Issues (from review, not yet addressed) + +1. **`set_direct_bind` exposed as read-write nanobind property** — After first dispatch, mutating `direct_bind` would invalidate the cached cursor offset. Consider making it read-only. + +2. **C++ cache safety** — `NativeValueMarshall::ensure_cached` caches `direct_bind` but has no debug assertion verifying it matches on subsequent calls. + +3. **Dead `binding.direct_bind` checks in writable ValueRef paths** — `create_calldata` and `read_calldata` in `valueref.py` have `assert not binding.direct_bind` in writable code paths (reachable only as assertions, since `can_direct_bind` rejects non-read access). + +--- + +### What Phase 2 Needs to Know + +Phase 2 builds on Phase 1's `direct_bind` infrastructure. Key context for implementation: + +**Current kernel structure** (post-Phase 1, for `int add(int a, int b)` with args `(1, 2)`): +```slang +import "module"; +import "slangpy"; +// CallData struct with per-arg type aliases and mapping constants +struct CallData { + typealias _t_a = int; // Phase 1: raw type (was ValueType) + _t_a a; + typealias _t_b = int; // Phase 1: raw type (was ValueType) + _t_b b; + typealias _t__result = RWValueRef; // writable _result still wrapped + _t__result _result; + static const int _m__result = 0; // mapping constant only for _result + uint3 _thread_count; + // ... shape arrays if call_data_len > 0 ... +}; +void _trampoline(CallData call_data /*or __calldata__ on CUDA*/) { + int a; + a = call_data.a; // Phase 1: direct assignment (was __slangpy_load) + int b; + b = call_data.b; // Phase 1: direct assignment + int _result; + _result = add(a, b); + call_data._result.__slangpy_store(__slangpy_context__.map(_m__result), _result); +} +[shader("compute")] [numthreads(32,1,1)] +void compute_main(..., uniform CallData call_data) { + // thread bounds check, context construction + _trampoline(call_data); +} +``` + +**Phase 2 goal**: Eliminate the `CallData` struct entirely when ALL args are direct-bind eligible. Pass args as individual `uniform` parameters on the entry point. Inline the function call into `compute_main` (skip trampoline for prim mode). + +**Blocking issue for Phase 2**: Auto-created `_result` is a writable `ValueRef` → NOT direct-bind (needs `RWValueRef` wrapper with buffer). Phase 2 must either: +- Accept that `_result` prevents full CallData elimination for functions with return values, and use a hybrid approach (direct args + `_result` in CallData or as a separate `RWStructuredBuffer` entry point param), OR +- Add a new code path for `_result` that emits `uniform RWStructuredBuffer _result` as an entry point param with `_result[0] = ...` for the store + +**Key files for Phase 2**: +- `slangpy/core/callsignature.py` — `generate_code()` builds the trampoline and compute_main +- `slangpy/core/calldata.py` — `CallData.build()` orchestrates the pipeline +- `slangpy/bindings/codegen.py` — `CodeGen` class manages `call_data_structs` block +- `src/slangpy_ext/utils/slangpy.cpp` — `NativeCallData::exec()` dispatches; cursor navigation for uniforms + +**`BoundVariable.direct_bind`** is already computed for all args by Phase 1. Phase 2 can check `all(arg.direct_bind for arg in all_args)` to decide whether to use the direct-args path. + +**Entry point parameter precedent**: See `slangpy/tests/device/test_pipeline_utils.slang` — manually written compute shaders already use individual `uniform` entry point params on all backends (CUDA, Vulkan, D3D12). + +**Design decisions deferred to Phase 2**: +- Whether to support hybrid kernels (some args as entry-point params, some in CallData) or only all-or-nothing +- Handle entry-point parameter size limits (CUDA ~4KB root constants, D3D12 64 DWORD root signature limit) +- Whether to inline the function call directly in compute_main for prim mode, or keep a simplified trampoline --- @@ -82,31 +204,27 @@ Before implementing any phase, add **gating tests** to [slangpy/tests/slangpy_te - Named `test_gate_*` for easy identification - WangHashArg and dict/composite tests serve as "negative gates" — they remain passing after simplification -**Test infrastructure additions:** +**Test infrastructure** (already present in `test_kernel_gen.py`): ```python -def assert_contains(code: str, *patterns: str) -> None: - for p in patterns: - assert p in code, f"Expected pattern not found: {p}" - -def assert_not_contains(code: str, *patterns: str) -> None: - for p in patterns: - assert p not in code, f"Unexpected pattern found: {p}" - -def generate_bwds_code(device, func_name, module_source, *args, **kwargs) -> str: - func = helpers.create_function_from_module(device, func_name, module_source) - cd = func.bwds.debug_build_call_data(*args, **kwargs) - if PRINT_TEST_KERNEL_GEN: - print(cd.code) - return cd.code +def assert_contains(code: str, *patterns: str) -> None +def assert_not_contains(code: str, *patterns: str) -> None +def assert_trampoline_has(code: str, *stmts: str) -> None +def generate_code(device, func_name, module_source, *args, **kwargs) -> str +def generate_bwds_code(device, func_name, module_source, *args, **kwargs) -> str ``` -**Summary of all gating tests by phase:** +**Phase 2 gating tests to add** (assert CURRENT behavior, will break on implementation): -| Phase | Gating Tests (break on implementation) | Negative Gates (must stay passing) | -|-------|---------------------------------------|-----------------------------------| -| 1 | 12 tests: scalar/float/vector/matrix/valueref-read/valueref-write/array/mapping-constants/context-map/struct-slangpy-load/bwds-scalar/bwds-trampoline | 3 tests: wanghasharg/vectorized-scalar/vectorized-dict | -| 2 | 7 tests: calldata-struct/calldata-uniform/thread-count/context-from-calldata/trampoline-present/trampoline-calls/kernel-calls-trampoline | 1 test: wanghasharg-forces-calldata | -| 3 | 1 test: compute-shader-generates-wrapper | — | +| Test | Asserts (current behavior) | Breaks when | +|------|---------------------------|-------------| +| `test_gate_calldata_struct_present` | `struct CallData` present | Step 2.1 | +| `test_gate_calldata_uniform_param` | `uniform CallData call_data` in `compute_main` | Step 2.2 | +| `test_gate_thread_count_in_calldata` | `call_data._thread_count` in kernel body | Step 2.4 | +| `test_gate_context_from_calldata` | `Context __slangpy_context__` present | Step 2.4 | +| `test_gate_trampoline_present_for_prim` | `void _trampoline(` present | Step 2.5 | +| `test_gate_trampoline_calls_function` | `_result = add(a, b)` inside trampoline | Step 2.5 | +| `test_gate_kernel_calls_trampoline` | `_trampoline(` inside `compute_main` | Step 2.5 | +| `test_gate_wanghasharg_forces_calldata` (negative) | `struct CallData` present with non-eligible arg | Must stay passing | --- @@ -133,33 +251,3 @@ pre-commit run --all-files - Phase 2 targets both `entry_point` (CUDA) and `global_data` (Vulkan/D3D12) modes - Autograd (bwds mode) is included in simplification, but implemented after prim mode within each phase - WangHashArg explicitly excluded from direct binding (needs per-thread `thread_id` computation) - ---- - -### Code Review Notes (PR #862) - -**Bugs / concerns found:** - -1. **Benchmark file changes are accidental** — `test_benchmark_autograd.py` has local tuning changes (ITERATIONS 10→100, WARMUPS 10→1000, RUN_SLANGTORCH_BENCHMARK False→True) that should be reverted before merge. - -2. **Composite `calculate_direct_bind` doesn't consult the marshall** — When `self.children is not None`, `calculate_direct_bind()` hard-codes eligibility criteria and never calls `self.python.can_direct_bind(self)`. This auto-opts-in composites if all children pass, preventing a marshall from rejecting. The `StructMarshall.can_direct_bind` children branch is dead code as a result. Either have composites delegate to the marshall, or remove the dead code from `StructMarshall.can_direct_bind`. - -3. **Composite direct-bind doesn't gate on read-only access** — `calculate_direct_bind`'s composite branch doesn't check `self.access[0] == AccessType.read`. A writable composite at dim-0 could be marked direct-bind, but no `__slangpy_store` is generated for the raw type alias. `ValueRefMarshall` correctly gates on read-only — composites should too. - -4. **Dead code in `ValueRefMarshall.create_calldata`/`read_calldata`** — `if binding.direct_bind` checks in writable code paths are unreachable since `can_direct_bind` rejects non-read access. Remove or assert. - -5. **C++ cache safety** — `NativeValueMarshall::ensure_cached` caches the `direct_bind` cursor path on first dispatch but has no assertion that subsequent calls use the same `direct_bind` value. Safe in current architecture (each call signature gets its own marshall), but fragile. Consider adding `SGL_ASSERT(m_cached.direct_bind == binding->direct_bind())` for debug builds. - -6. **`set_direct_bind` exposed as read-write nanobind property** — After first dispatch, mutating `direct_bind` invalidates the cached cursor offset silently. Consider making it read-only in the nanobind binding. - -**Missing test coverage (high priority):** - -| Test | Purpose | -|------|---------| -| Writable `ValueRef` `inout` param → `direct_bind=False` | Guards access-check logic in `ValueRefMarshall.can_direct_bind` | -| `_result` auto-created binding → `direct_bind=False` flag | Binding-level assertion, not just codegen | -| All-scalar struct → `direct_bind=True` binding flag | Struct direct-bind logic verified at binding level | -| Struct with WangHashArg child → composite NOT direct-bind | Mixed non-eligible child in composite | -| WangHashArg → `direct_bind=False` binding flag | Type without `can_direct_bind` override | -| Functional GPU test: read-only `ValueRef` input | End-to-end direct-bind ValueRef pipeline | -| Bwds mode binding flags on primal args | Verify access-tuple indexing in backwards mode | From 9b32e0e7ef38a4fc889b33509857ecd993fded97 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Thu, 12 Mar 2026 11:44:11 +0000 Subject: [PATCH 15/41] Reduce type alias use --- slangpy/bindings/boundvariable.py | 47 ++++- slangpy/builtin/accelerationstructure.py | 3 +- slangpy/builtin/array.py | 7 +- slangpy/builtin/descriptor.py | 6 +- slangpy/builtin/diffpair.py | 1 + slangpy/builtin/ndbuffer.py | 7 +- slangpy/builtin/range.py | 3 +- slangpy/builtin/structuredbuffer.py | 14 +- slangpy/builtin/tensorcommon.py | 2 +- slangpy/builtin/texture.py | 17 +- slangpy/builtin/value.py | 16 +- slangpy/builtin/valueref.py | 7 +- slangpy/experimental/gridarg.py | 3 +- .../tests/slangpy_tests/test_interfaces.py | 2 +- .../tests/slangpy_tests/test_kernel_gen.py | 197 ++++++++++++++---- slangpy/types/callidarg.py | 3 +- slangpy/types/randfloatarg.py | 4 +- slangpy/types/threadidarg.py | 3 +- slangpy/types/wanghasharg.py | 3 +- 19 files changed, 243 insertions(+), 102 deletions(-) diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index aa318ec25..9214ae607 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -6,7 +6,7 @@ from slangpy import ModifierID from slangpy.bindings.marshall import BindContext -from slangpy.bindings.codegen import CodeGen +from slangpy.bindings.codegen import CodeGen, CodeGenBlock from slangpy.bindings.typeregistry import get_or_create_type from slangpy.reflection import ( SlangField, @@ -16,6 +16,11 @@ ) from slangpy.reflection.typeresolution import ResolvedParam +#: Type names longer than this threshold get a ``typealias _t_{name}`` alias +#: to keep the generated ``CallData`` struct readable. Shorter names are +#: inlined directly. +MAX_INLINE_TYPE_LEN = 60 + class BoundVariableException(Exception): """ @@ -199,6 +204,9 @@ def __init__( #: Whether this variable uses direct binding (raw Slang type, no wrapper). self.direct_bind = False + #: The resolved Slang type name for this variable's CallData field. + self.calldata_type_name: Optional[str] = None + #: Call dimensionality of this variable. self.call_dimensionality = None @@ -272,6 +280,23 @@ def debug_name(self) -> str: else: return f"arg{self.python_pos_arg_index}" + def gen_calldata_type_name(self, cgb: CodeGenBlock, type_name: str) -> None: + """Record the Slang type name for this variable's CallData field. + + If the type name exceeds ``MAX_INLINE_TYPE_LEN``, a + ``typealias _t_{name}`` is emitted and the alias is stored. + Otherwise the raw type name is stored directly. + + :param cgb: The code-gen block to write the type alias to (if needed). + :param type_name: The resolved Slang type name. + """ + if len(type_name) > MAX_INLINE_TYPE_LEN: + alias = f"_t_{self.variable_name}" + cgb.type_alias(alias, type_name) + self.calldata_type_name = alias + else: + self.calldata_type_name = type_name + def bind( self, slang: Union[SlangField, ResolvedParam, SlangType], @@ -581,17 +606,21 @@ def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): cgb = cg.call_data_structs if self.direct_bind: - # Direct-bind: emit raw type alias + # Direct-bind: use raw type name directly assert self.vector_type is not None - cgb.type_alias(f"_t_{self.variable_name}", self.vector_type.full_name) + self.gen_calldata_type_name(cgb, self.vector_type.full_name) else: - cgb.begin_struct(f"_t_{self.variable_name}") + struct_name = f"_t_{self.variable_name}" + cgb.begin_struct(struct_name) for field, variable in self.children.items(): variable.gen_call_data_code(cg, context, depth + 1) for var in self.children.values(): - cgb.declare(f"_t_{var.variable_name}", var.variable_name) + assert ( + var.calldata_type_name is not None + ), f"calldata_type_name not set for '{var.variable_name}'" + cgb.declare(var.calldata_type_name, var.variable_name) assert self.vector_type is not None context_decl = f"ContextND<{self.call_dimensionality}> context" @@ -633,6 +662,7 @@ def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): cgb.end_block() cgb.end_struct() + self.calldata_type_name = struct_name else: # Generate call data @@ -650,10 +680,13 @@ def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): ) if depth == 0: + assert ( + self.calldata_type_name is not None + ), f"calldata_type_name not set for '{self.variable_name}'" if self.create_param_block: - cg.add_parameter_block(f"_t_{self.variable_name}", "_param_" + self.variable_name) + cg.add_parameter_block(self.calldata_type_name, "_param_" + self.variable_name) else: - cg.call_data.declare(f"_t_{self.variable_name}", self.variable_name) + cg.call_data.declare(self.calldata_type_name, self.variable_name) def _gen_trampoline_argument(self): assert self.vector_type is not None diff --git a/slangpy/builtin/accelerationstructure.py b/slangpy/builtin/accelerationstructure.py index 5fa717cb0..cc749cbe1 100644 --- a/slangpy/builtin/accelerationstructure.py +++ b/slangpy/builtin/accelerationstructure.py @@ -29,9 +29,8 @@ def __init__(self, layout: kfr.SlangProgramLayout): # Call data can only be read access to primal, and simply declares it as a variable def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): - name = binding.variable_name assert isinstance(binding.vector_type, kfr.RaytracingAccelerationStructureType) - cgb.type_alias(f"_t_{name}", f"RaytracingAccelerationStructureType") + binding.gen_calldata_type_name(cgb, "RaytracingAccelerationStructureType") # Call data just returns the primal def create_calldata( diff --git a/slangpy/builtin/array.py b/slangpy/builtin/array.py index b7113b046..dc847d877 100644 --- a/slangpy/builtin/array.py +++ b/slangpy/builtin/array.py @@ -115,7 +115,6 @@ def resolve_types(self, context: BindContext, bound_type: "SlangType"): # Call data can only be read access to primal, and simply declares it as a variable def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): access = binding.access - name = binding.variable_name if access[0] in [AccessType.read, AccessType.readwrite]: if binding.call_dimensionality == 0: # If not vectorizing, fallback to use of basic type as it works well @@ -125,9 +124,11 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV # If vectorizing, utilize the value type. st = cast(kfr.ArrayType, self.slang_type) et = cast(SlangType, st.element_type) - cgb.type_alias(f"_t_{name}", f"Array1DValueType<{et.full_name},{st.num_elements}>") + binding.gen_calldata_type_name( + cgb, f"Array1DValueType<{et.full_name},{st.num_elements}>" + ) else: - cgb.type_alias(f"_t_{name}", f"NoneType") + binding.gen_calldata_type_name(cgb, "NoneType") def build_shader_object(self, context: "BindContext", data: Any) -> "ShaderObject": if len(self.concrete_shape) != 1: diff --git a/slangpy/builtin/descriptor.py b/slangpy/builtin/descriptor.py index 7bcc3adb0..109029a85 100644 --- a/slangpy/builtin/descriptor.py +++ b/slangpy/builtin/descriptor.py @@ -50,12 +50,12 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV name = binding.variable_name if access[0] in [AccessType.read, AccessType.readwrite]: assert binding.vector_type is not None - cgb.type_alias( - f"_t_{name}", + binding.gen_calldata_type_name( + cgb, binding.vector_type.full_name.replace("DescriptorHandle", "DescriptorType"), ) else: - cgb.type_alias(f"_t_{name}", f"NoneType") + binding.gen_calldata_type_name(cgb, "NoneType") def reduce_type(self, context: BindContext, dimensions: int) -> kfr.SlangType: if dimensions == 0: diff --git a/slangpy/builtin/diffpair.py b/slangpy/builtin/diffpair.py index 741aff9e8..53726faa8 100644 --- a/slangpy/builtin/diffpair.py +++ b/slangpy/builtin/diffpair.py @@ -154,6 +154,7 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV deriv_target, ) ) + binding.calldata_type_name = f"_t_{name}" def get_type(self, prim: PrimType): return ( diff --git a/slangpy/builtin/ndbuffer.py b/slangpy/builtin/ndbuffer.py index 3807d9341..6ee073b96 100644 --- a/slangpy/builtin/ndbuffer.py +++ b/slangpy/builtin/ndbuffer.py @@ -279,7 +279,6 @@ def ndbuffer_gen_calldata( binding: "BoundVariable", ): access = binding.access - name = binding.variable_name assert access[0] != AccessType.none assert access[1] == AccessType.none writable = access[0] != AccessType.read @@ -287,7 +286,7 @@ def ndbuffer_gen_calldata( # If passing to NDBuffer, just use the NDBuffer type assert access[0] == AccessType.read assert isinstance(binding.vector_type, ITensorType) - cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) + binding.gen_calldata_type_name(cgb, binding.vector_type.full_name) else: # If we pass to a structured buffer, check the writable flag from the type if isinstance(binding.vector_type, StructuredBufferType): @@ -296,9 +295,9 @@ def ndbuffer_gen_calldata( # If broadcasting to an element, use the type of this buffer for code gen\ et = cast(SlangType, self.slang_element_type) if writable: - cgb.type_alias(f"_t_{name}", f"RWTensor<{et.full_name},{self.dims}>") + binding.gen_calldata_type_name(cgb, f"RWTensor<{et.full_name},{self.dims}>") else: - cgb.type_alias(f"_t_{name}", f"Tensor<{et.full_name},{self.dims}>") + binding.gen_calldata_type_name(cgb, f"Tensor<{et.full_name},{self.dims}>") class BaseNDBufferMarshall(Marshall): diff --git a/slangpy/builtin/range.py b/slangpy/builtin/range.py index cb22a375b..4f6ab7d1b 100644 --- a/slangpy/builtin/range.py +++ b/slangpy/builtin/range.py @@ -36,9 +36,8 @@ def __init__(self, layout: SlangProgramLayout): def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): access = binding.access - name = binding.variable_name if access[0] == AccessType.read: - cgb.type_alias(f"_t_{name}", self.slang_type.full_name) + binding.gen_calldata_type_name(cgb, self.slang_type.full_name) def create_calldata( self, context: CallContext, binding: BoundVariableRuntime, data: range diff --git a/slangpy/builtin/structuredbuffer.py b/slangpy/builtin/structuredbuffer.py index ff9cd2894..531a03971 100644 --- a/slangpy/builtin/structuredbuffer.py +++ b/slangpy/builtin/structuredbuffer.py @@ -85,24 +85,24 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV if isinstance(binding.vector_type, StructuredBufferType): assert binding.vector_type.element_type is not None if binding.vector_type.writable: - cgb.type_alias( - f"_t_{name}", + binding.gen_calldata_type_name( + cgb, f"RWStructuredBufferType<{binding.vector_type.element_type.full_name}>", ) else: - cgb.type_alias( - f"_t_{name}", + binding.gen_calldata_type_name( + cgb, f"StructuredBufferType<{binding.vector_type.element_type.full_name}>", ) elif isinstance(binding.vector_type, ByteAddressBufferType): if binding.vector_type.writable: - cgb.type_alias(f"_t_{name}", f"RWByteAddressBufferType") + binding.gen_calldata_type_name(cgb, "RWByteAddressBufferType") else: - cgb.type_alias(f"_t_{name}", f"ByteAddressBufferType") + binding.gen_calldata_type_name(cgb, "ByteAddressBufferType") elif isinstance(binding.vector_type, PointerType): # To bind as a pointer, use the 'ValueType', which just like the buffer wrappers # has a 'value' field that refers to the actual buffer (in this case as a pointer) - cgb.type_alias(f"_t_{name}", f"ValueType<{binding.vector_type.full_name}>") + binding.gen_calldata_type_name(cgb, f"ValueType<{binding.vector_type.full_name}>") else: raise ValueError( "Raw buffers can not be vectorized. If you need vectorized buffers, see the Tensor slangpy type" diff --git a/slangpy/builtin/tensorcommon.py b/slangpy/builtin/tensorcommon.py index ca39d6cb1..e6923fdde 100644 --- a/slangpy/builtin/tensorcommon.py +++ b/slangpy/builtin/tensorcommon.py @@ -376,7 +376,7 @@ def gen_calldata( access=access, tensor_type=tensor_type, ) - cgb.type_alias(f"_t_{binding.variable_name}", type_name) + binding.gen_calldata_type_name(cgb, type_name) def can_direct_bind(self: ITensorMarshall, binding: BoundVariable) -> bool: diff --git a/slangpy/builtin/texture.py b/slangpy/builtin/texture.py index 810329234..a128b6968 100644 --- a/slangpy/builtin/texture.py +++ b/slangpy/builtin/texture.py @@ -212,7 +212,7 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV name = binding.variable_name if access == AccessType.none: - cgb.type_alias(f"_t_{name}", f"NoneType") + binding.gen_calldata_type_name(cgb, "NoneType") return if binding.call_dimensionality == 0: @@ -223,23 +223,25 @@ def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundV raise ValueError( f"Cannot bind texture view {name} with usage {binding.vector_type.usage}" ) - cgb.type_alias(f"_t_{name}", binding.vector_type.full_name.replace("<", "Type<", 1)) + binding.gen_calldata_type_name( + cgb, binding.vector_type.full_name.replace("<", "Type<", 1) + ) elif binding.call_dimensionality == self.texture_dims: # If broadcast is the same shape as the texture, this is loading from pixels, so use the # type required to support the required access if access == AccessType.read: # Read access can be either shader resource or UAV, so just bind the correct type # for this resource view - cgb.type_alias( - f"_t_{name}", + binding.gen_calldata_type_name( + cgb, self.build_accessor_name(self.usage, self.slang_element_type), ) else: # Write access requires a UAV so check it and bind RW type if not has_uav(self.usage): raise ValueError(f"Cannot write to read-only texture {name}") - cgb.type_alias( - f"_t_{name}", + binding.gen_calldata_type_name( + cgb, self.build_accessor_name( TextureUsage.unordered_access, self.slang_element_type ), @@ -373,9 +375,8 @@ def __init__(self, layout: refl.SlangProgramLayout): # Call data can only be read access to primal, and simply declares it as a variable def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): - name = binding.variable_name assert isinstance(binding.vector_type, refl.SamplerStateType) - cgb.type_alias(f"_t_{name}", f"SamplerStateType") + binding.gen_calldata_type_name(cgb, "SamplerStateType") # Call data just returns the primal def create_calldata( diff --git a/slangpy/builtin/value.py b/slangpy/builtin/value.py index 69aa66dfb..f555a3b73 100644 --- a/slangpy/builtin/value.py +++ b/slangpy/builtin/value.py @@ -96,15 +96,14 @@ def can_direct_bind(self, binding: "BoundVariable") -> bool: # Call data can only be read access to primal, and simply declares it as a variable def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): access = binding.access - name = binding.variable_name if access[0] in [AccessType.read, AccessType.readwrite]: assert binding.vector_type is not None if binding.direct_bind: - cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) + binding.gen_calldata_type_name(cgb, binding.vector_type.full_name) else: - cgb.type_alias(f"_t_{name}", f"ValueType<{binding.vector_type.full_name}>") + binding.gen_calldata_type_name(cgb, f"ValueType<{binding.vector_type.full_name}>") else: - cgb.type_alias(f"_t_{name}", f"NoneType") + binding.gen_calldata_type_name(cgb, "NoneType") def gen_trampoline_load( self, cgb: CodeGenBlock, binding: "BoundVariable", data_name: str, value_name: str @@ -328,16 +327,17 @@ def resolve_types(self, context: BindContext, bound_type: "SlangType"): # Call data can only be read access to primal, and simply declares it as a variable def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): access = binding.access - name = binding.variable_name if access[0] in [AccessType.read, AccessType.readwrite]: st = cast(kfr.VectorType, self.slang_type) et = cast(SlangType, st.element_type) if binding.direct_bind: - cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) + binding.gen_calldata_type_name(cgb, binding.vector_type.full_name) else: - cgb.type_alias(f"_t_{name}", f"VectorValueType<{et.full_name},{st.num_elements}>") + binding.gen_calldata_type_name( + cgb, f"VectorValueType<{et.full_name},{st.num_elements}>" + ) else: - cgb.type_alias(f"_t_{name}", f"NoneType") + binding.gen_calldata_type_name(cgb, "NoneType") def build_shader_object(self, context: "BindContext", data: Any) -> "slangpy.ShaderObject": unpacked = unpack_arg(data) diff --git a/slangpy/builtin/valueref.py b/slangpy/builtin/valueref.py index c13549d0f..fb4196089 100644 --- a/slangpy/builtin/valueref.py +++ b/slangpy/builtin/valueref.py @@ -159,18 +159,17 @@ def can_direct_bind(self, binding: "BoundVariable") -> bool: # Call data can only be read access to primal, and simply declares it as a variable def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): access = binding.access - name = binding.variable_name assert access[0] != AccessType.none assert access[1] == AccessType.none assert binding.vector_type is not None if binding.direct_bind: assert access[0] == AccessType.read - cgb.type_alias(f"_t_{name}", binding.vector_type.full_name) + binding.gen_calldata_type_name(cgb, binding.vector_type.full_name) else: if access[0] == AccessType.read: - cgb.type_alias(f"_t_{name}", f"ValueRef<{binding.vector_type.full_name}>") + binding.gen_calldata_type_name(cgb, f"ValueRef<{binding.vector_type.full_name}>") else: - cgb.type_alias(f"_t_{name}", f"RWValueRef<{binding.vector_type.full_name}>") + binding.gen_calldata_type_name(cgb, f"RWValueRef<{binding.vector_type.full_name}>") def gen_trampoline_load( self, cgb: CodeGenBlock, binding: "BoundVariable", data_name: str, value_name: str diff --git a/slangpy/experimental/gridarg.py b/slangpy/experimental/gridarg.py index 3e0cdfac0..344b5d9b1 100644 --- a/slangpy/experimental/gridarg.py +++ b/slangpy/experimental/gridarg.py @@ -78,9 +78,8 @@ def __init__(self, layout: SlangProgramLayout, dims: int): def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVariable): access = binding.access - name = binding.variable_name if access[0] == AccessType.read: - cgb.type_alias(f"_t_{name}", self.slang_type.full_name) + binding.gen_calldata_type_name(cgb, self.slang_type.full_name) def create_calldata( self, context: CallContext, binding: BoundVariableRuntime, data: GridArg diff --git a/slangpy/tests/slangpy_tests/test_interfaces.py b/slangpy/tests/slangpy_tests/test_interfaces.py index da8d378e6..223ec722f 100644 --- a/slangpy/tests/slangpy_tests/test_interfaces.py +++ b/slangpy/tests/slangpy_tests/test_interfaces.py @@ -58,7 +58,7 @@ def __init__(self, layout: SlangProgramLayout, T: SlangType, N: int): self.concrete_shape = Shape() def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVariable): - cgb.type_alias(f"_t_{binding.variable_name}", self.slang_type.full_name) + binding.gen_calldata_type_name(cgb, self.slang_type.full_name) def create_test_impl(layout: SlangProgramLayout, value: Any): diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index cc25948f5..f245b214e 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -123,9 +123,9 @@ def test_gate_scalar_uses_valuetype(device_type: spy.DeviceType): 1, 2, ) - # Scalars now use direct binding: typealias to raw type, no ValueType wrapper + # Scalars now use direct binding: type used directly in CallData, no ValueType wrapper assert_not_contains(code, "ValueType") - assert_contains(code, "typealias _t_a = int;", "typealias _t_b = int;") + assert_not_contains(code, "typealias _t_a", "typealias _t_b") # Trampoline uses direct assignment, no __slangpy_load assert_trampoline_has(code, "a = __calldata__.a;", "b = __calldata__.b;") # _result is auto-created as writable RWValueRef (not direct-bind) @@ -143,7 +143,7 @@ def test_gate_float_scalar_uses_valuetype(device_type: spy.DeviceType): 2.0, ) assert_not_contains(code, "ValueType") - assert_contains(code, "typealias _t_x = float;", "typealias _t_y = float;") + assert_not_contains(code, "typealias _t_x", "typealias _t_y") # -- Step 1.3: Vector / Matrix / Array direct binding -- @@ -160,7 +160,8 @@ def test_gate_vector_uses_vectorvaluetype(device_type: spy.DeviceType): 1.0, ) assert_not_contains(code, "VectorValueType") - assert_contains(code, "typealias _t_v = vector;") + assert_not_contains(code, "typealias _t_v") + assert_contains(code, "vector v;") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -173,7 +174,8 @@ def test_gate_matrix_uses_valuetype(device_type: spy.DeviceType): spy.math.float4x4.identity(), ) assert_not_contains(code, "ValueType>") - assert_contains(code, "typealias _t_m = matrix;") + assert_not_contains(code, "typealias _t_m") + assert_contains(code, "matrix m;") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -186,7 +188,7 @@ def test_gate_array_dim0_uses_valuetype(device_type: spy.DeviceType): [1.0, 2.0, 3.0, 4.0], ) assert_not_contains(code, "ValueType<") - assert_contains(code, "typealias _t_a = ") + assert_not_contains(code, "typealias _t_a") # -- Step 1.5: ValueRef direct binding -- @@ -201,8 +203,9 @@ def test_gate_valueref_read_uses_wrapper(device_type: spy.DeviceType): "float read_val(float v) { return v; }", ValueRef(1.0), ) - # Read-only ValueRef uses raw type alias (direct-bind) - assert_contains(code, "typealias _t_v = float;") + # Read-only ValueRef uses raw type directly (direct-bind) + assert_not_contains(code, "typealias _t_v") + assert_contains(code, "float v;") # Direct assignment in trampoline assert_trampoline_has(code, "v = __calldata__.v;") # _result (writable) still uses RWValueRef wrapper @@ -276,9 +279,10 @@ def test_gate_struct_uses_slangpy_load(device_type: spy.DeviceType): float sum(S s) { return s.x + s.y; } """ code = generate_code(device, "sum", src, {"_type": "S", "x": 1.0, "y": 2.0}) - # Direct-bind struct: uses raw type alias, no inline struct with __slangpy_load + # Direct-bind struct: uses raw type directly, no inline struct with __slangpy_load assert_not_contains(code, "__slangpy_load") - assert_contains(code, "typealias _t_s = S;") + assert_not_contains(code, "typealias _t_s") + assert_contains(code, "S s;") # Direct assignment in trampoline assert_trampoline_has(code, "s = __calldata__.s;") @@ -328,8 +332,8 @@ def test_gate_wanghasharg_uses_wrapper(device_type: spy.DeviceType): src = "uint3 rng(uint3 input) { return input; }" code = generate_code(device, "rng", src, WangHashArg(3)) assert_contains(code, "WangHashArg<") - # WangHashArg uses wrapper type. Check the type alias is present. - assert_contains(code, "_t_input") + # WangHashArg uses wrapper type — field declaration present in CallData + assert_contains(code, "input") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -450,8 +454,8 @@ def test_gate_mixed_args_scalar_and_tensor(device_type: spy.DeviceType): 1.0, tensor, ) - # 'a' is direct-bind (scalar dim-0): raw typealias, direct trampoline load - assert_contains(code, "typealias _t_a = float;") + # 'a' is direct-bind (scalar dim-0): type used directly, direct trampoline load + assert_not_contains(code, "typealias _t_a") assert_not_contains(code, "ValueType") assert_trampoline_has(code, "a = __calldata__.a;") # 'b' is NOT direct-bind (vectorized tensor dim-1): uses Tensor, @@ -517,14 +521,16 @@ def test_gate_struct_mixed_fields_codegen(device_type: spy.DeviceType): assert_contains(code, "__slangpy_load") assert_contains(code, "struct _t_s") assert_not_contains(code, "typealias _t_s = S;") - # Child y is direct-bind: raw type alias, direct assignment in __slangpy_load - assert_contains(code, "typealias _t_y = float;") + # Child y is direct-bind: type used directly, direct assignment in __slangpy_load + assert_not_contains(code, "typealias _t_y") + assert_contains(code, "float y;") assert_contains(code, "value.y = y;") assert_not_contains(code, "ValueType") # Child x should use tensor type assert_contains(code, "Tensor") # Scalar arg 'scale' is independent — should still be direct-bind - assert_contains(code, "typealias _t_scale = float;") + assert_not_contains(code, "typealias _t_scale") + assert_contains(code, "float scale;") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -586,8 +592,9 @@ def test_gate_tensor_dim0_codegen(device_type: spy.DeviceType): """ tensor = Tensor.from_numpy(device, np.array([42, 2, 3], dtype=np.float32)) code = generate_code(device, "tensor_read", src, tensor) - # Type alias should use Tensor - assert_contains(code, "typealias _t_t = Tensor;") + # Type should use Tensor directly (no typealias) + assert_not_contains(code, "typealias _t_t") + assert_contains(code, "Tensor t;") # Trampoline uses direct assignment (not __slangpy_load) assert_trampoline_has(code, "t = __calldata__.t;") # No wrapper type for the tensor @@ -653,7 +660,8 @@ def test_mixed_children_direct_bind_codegen(device_type: spy.DeviceType): tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) code = generate_code(device, "weighted_sum", src, {"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) # Child y uses raw type and direct assignment - assert_contains(code, "typealias _t_y = float;") + assert_not_contains(code, "typealias _t_y") + assert_contains(code, "float y;") assert_contains(code, "value.y = y;") # No mapping constant for y (direct-bind skips it) assert_not_contains(code, "_m_y") @@ -793,10 +801,10 @@ def test_gate_2d_tensor_to_vector_codegen(device_type: spy.DeviceType): ) # v is vectorized dim-1: tensor wrapping a vector type assert_contains(code, "__slangpy_load") - assert_contains(code, "_t_v") assert_contains(code, "_m_v") - # s is scalar dim-0: direct-bind - assert_contains(code, "typealias _t_s = float;") + # s is scalar dim-0: direct-bind, type used directly + assert_not_contains(code, "typealias _t_s") + assert_contains(code, "float s;") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -960,7 +968,6 @@ def test_gate_2d_tensor_to_1d_array_codegen(device_type: spy.DeviceType): # data is vectorized (trailing dim consumed by array): __slangpy_load present assert_contains(code, "__slangpy_load") assert_contains(code, "_m_data") - assert_contains(code, "_t_data") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -1034,8 +1041,9 @@ def test_gate_mixed_vectorized_and_dim0_tensor_codegen(device_type: spy.DeviceTy # v: vectorized dim-1 (2D→float3), uses __slangpy_load assert_contains(code, "_m_v") assert_contains(code, "__slangpy_load") - # weights: dim-0 direct-bind (Tensor param), uses typealias + direct assignment - assert_contains(code, "typealias _t_weights = Tensor;") + # weights: dim-0 direct-bind (Tensor param), type used directly + direct assignment + assert_not_contains(code, "typealias _t_weights") + assert_contains(code, "Tensor weights;") assert_trampoline_has(code, "weights = __calldata__.weights;") @@ -1107,8 +1115,9 @@ def test_gate_nested_struct_codegen(device_type: spy.DeviceType): src, {"_type": "Outer", "inner": {"_type": "Inner", "x": 1.0, "y": 2.0}, "scale": 3.0}, ) - # All-scalar nested struct at dim-0: direct-bind → raw typealias - assert_contains(code, "typealias _t_o = Outer;") + # All-scalar nested struct at dim-0: direct-bind → type used directly + assert_not_contains(code, "typealias _t_o") + assert_contains(code, "Outer o;") assert_not_contains(code, "__slangpy_load") assert_not_contains(code, "struct _t_o") assert_trampoline_has(code, "o = __calldata__.o;") @@ -1179,8 +1188,9 @@ def test_gate_struct_with_vector_fields_codegen(device_type: spy.DeviceType): src, {"_type": "S", "pos": spy.math.float3(1, 2, 3), "scale": 2.0}, ) - # All-scalar struct with vector field at dim-0: direct-bind → raw typealias - assert_contains(code, "typealias _t_s = S;") + # All-scalar struct with vector field at dim-0: direct-bind → type used directly + assert_not_contains(code, "typealias _t_s") + assert_contains(code, "S s;") assert_not_contains(code, "__slangpy_load") assert_trampoline_has(code, "s = __calldata__.s;") @@ -1240,7 +1250,8 @@ def test_gate_struct_with_matrix_field_codegen(device_type: spy.DeviceType): src, {"_type": "S", "m": spy.math.float4x4.identity(), "scale": 2.0}, ) - assert_contains(code, "typealias _t_s = S;") + assert_not_contains(code, "typealias _t_s") + assert_contains(code, "S s;") assert_not_contains(code, "__slangpy_load") assert_trampoline_has(code, "s = __calldata__.s;") @@ -1286,7 +1297,8 @@ def test_gate_struct_with_array_field_codegen(device_type: spy.DeviceType): src, {"_type": "Foo", "vals": [1, 2, 3, 4]}, ) - assert_contains(code, "typealias _t_foo = Foo;") + assert_not_contains(code, "typealias _t_foo") + assert_contains(code, "Foo foo;") assert_not_contains(code, "__slangpy_load") assert_trampoline_has(code, "foo = __calldata__.foo;") @@ -1364,7 +1376,8 @@ def test_gate_deeply_nested_struct_codegen(device_type: spy.DeviceType): "s": 4.0, }, ) - assert_contains(code, "typealias _t_t = Top;") + assert_not_contains(code, "typealias _t_t") + assert_contains(code, "Top t;") assert_not_contains(code, "__slangpy_load") assert_not_contains(code, "struct _t_t") assert_trampoline_has(code, "t = __calldata__.t;") @@ -1469,9 +1482,11 @@ def test_gate_nested_struct_with_tensor_child_codegen(device_type: spy.DeviceTyp assert_contains(code, "struct _t_o") assert_contains(code, "__slangpy_load") assert_not_contains(code, "typealias _t_o = Outer;") - # Scalar children retain direct-bind: raw type aliases - assert_contains(code, "typealias _t_y = float;") - assert_contains(code, "typealias _t_s = float;") + # Scalar children retain direct-bind: types used directly + assert_not_contains(code, "typealias _t_y") + assert_contains(code, "float y;") + assert_not_contains(code, "typealias _t_s") + assert_contains(code, "float s;") # Direct assignment for scalar children within __slangpy_load assert_contains(code, "value.y = y;") # Tensor child uses standard path @@ -1572,7 +1587,8 @@ def test_gate_struct_with_struct_array_field_codegen(device_type: spy.DeviceType ], }, ) - assert_contains(code, "typealias _t_outer = Outer;") + assert_not_contains(code, "typealias _t_outer") + assert_contains(code, "Outer outer;") assert_not_contains(code, "__slangpy_load") assert_trampoline_has(code, "outer = __calldata__.outer;") @@ -1623,8 +1639,8 @@ def test_gate_struct_return_codegen(device_type: spy.DeviceType): S make_struct(int a, int b) { return { a, b }; } """ code = generate_code(device, "make_struct", src, 4, 5) - # Scalar inputs are direct-bind - assert_contains(code, "typealias _t_a = int;", "typealias _t_b = int;") + # Scalar inputs are direct-bind, types used directly + assert_not_contains(code, "typealias _t_a", "typealias _t_b") # _result is writable → NOT direct-bind → uses wrapper assert_contains(code, "__slangpy_store") assert_contains(code, "_m__result") @@ -1695,8 +1711,9 @@ def test_gate_struct_with_vectorized_2d_tensor_child_codegen(device_type: spy.De assert_contains(code, "struct _t_st") assert_contains(code, "__slangpy_load") assert_not_contains(code, "typealias _t_st = S;") - # Scalar child s retains direct-bind - assert_contains(code, "typealias _t_s = float;") + # Scalar child s retains direct-bind — type used directly, no alias + assert_not_contains(code, "typealias _t_s") + assert_contains(code, "float s;") assert_contains(code, "value.s = s;") # Tensor child v uses standard path assert_contains(code, "_m_v") @@ -1721,5 +1738,103 @@ def test_phase1_functional_struct_with_vectorized_2d_tensor(device_type: spy.Dev np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) +# =========================================================================== +# Long type name heuristic — typealias emitted for names > MAX_INLINE_TYPE_LEN +# =========================================================================== + +# Struct name that is deliberately longer than MAX_INLINE_TYPE_LEN (60 chars). +# 70 chars: +_LONG_STRUCT_NAME = "MyVeryLongStructNameThatExceedsSixtyCharactersForTesting12345" +assert len(_LONG_STRUCT_NAME) > 60 + +_SHORT_STRUCT_NAME = "S" +assert len(_SHORT_STRUCT_NAME) <= 60 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_long_struct_name_gets_typealias(device_type: spy.DeviceType): + """A direct-bind struct with a name > 60 chars should emit a typealias.""" + device = helpers.get_device(device_type) + src = f""" +struct {_LONG_STRUCT_NAME} {{ + float x; + float y; +}}; +float sum({_LONG_STRUCT_NAME} s) {{ return s.x + s.y; }} +""" + code = generate_code( + device, + "sum", + src, + {"_type": _LONG_STRUCT_NAME, "x": 1.0, "y": 2.0}, + ) + # Long name → typealias _t_s emitted, CallData field declared as _t_s + assert_contains(code, f"typealias _t_s = {_LONG_STRUCT_NAME};") + assert_contains(code, "_t_s s;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_short_struct_name_inlined(device_type: spy.DeviceType): + """A direct-bind struct with a short name should NOT emit a typealias.""" + device = helpers.get_device(device_type) + src = f""" +struct {_SHORT_STRUCT_NAME} {{ + float x; + float y; +}}; +float sum({_SHORT_STRUCT_NAME} s) {{ return s.x + s.y; }} +""" + code = generate_code( + device, + "sum", + src, + {"_type": _SHORT_STRUCT_NAME, "x": 1.0, "y": 2.0}, + ) + # Short name → no typealias, raw type inlined + assert_not_contains(code, "typealias _t_s") + assert_contains(code, f"{_SHORT_STRUCT_NAME} s;") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_long_scalar_type_name_gets_typealias(device_type: spy.DeviceType): + """A non-direct-bind arg whose wrapper type name exceeds 60 chars gets a typealias.""" + device = helpers.get_device(device_type) + src = f""" +struct {_LONG_STRUCT_NAME} {{ + float x; + float y; +}}; +{_LONG_STRUCT_NAME} identity({_LONG_STRUCT_NAME} s) {{ return s; }} +""" + # Pass as a ValueRef so _result is writable → uses wrapper, and the wrapper + # type name for _result will include the long struct name. + code = generate_code( + device, + "identity", + src, + {"_type": _LONG_STRUCT_NAME, "x": 1.0, "y": 2.0}, + ) + # The _result binding uses RWValueRef which exceeds 60 chars + result_type = f"RWValueRef<{_LONG_STRUCT_NAME}>" + assert len(result_type) > 60, f"Expected >60 chars, got {len(result_type)}" + assert_contains(code, f"typealias _t__result = {result_type};") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_phase1_functional_long_struct_name(device_type: spy.DeviceType): + """End-to-end dispatch with a struct whose name exceeds 60 chars.""" + device = helpers.get_device(device_type) + src = f""" +struct {_LONG_STRUCT_NAME} {{ + float x; + float y; +}}; +float sum({_LONG_STRUCT_NAME} s) {{ return s.x + s.y; }} +""" + func = helpers.create_function_from_module(device, "sum", src) + result = func({"_type": _LONG_STRUCT_NAME, "x": 3.0, "y": 7.0}) + assert abs(result - 10.0) < 1e-5 + + if __name__ == "__main__": pytest.main([__file__, "-vs"]) diff --git a/slangpy/types/callidarg.py b/slangpy/types/callidarg.py index a920846c6..3d6e1e0e8 100644 --- a/slangpy/types/callidarg.py +++ b/slangpy/types/callidarg.py @@ -53,9 +53,8 @@ def __init__(self, layout: SlangProgramLayout, dims: int): def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVariable): access = binding.access - name = binding.variable_name if access[0] == AccessType.read: - cgb.type_alias(f"_t_{name}", self.slang_type.full_name) + binding.gen_calldata_type_name(cgb, self.slang_type.full_name) def resolve_type(self, context: BindContext, bound_type: "SlangType"): # Resolve type using reflection. diff --git a/slangpy/types/randfloatarg.py b/slangpy/types/randfloatarg.py index 75ae1a4ad..e9d7796f3 100644 --- a/slangpy/types/randfloatarg.py +++ b/slangpy/types/randfloatarg.py @@ -84,10 +84,8 @@ def __init__(self, layout: SlangProgramLayout, dim: int, warmup: int): def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: "BoundVariable"): access = binding.access - name = binding.variable_name if access[0] == AccessType.read: - # cgb.add_import("randfloatarg") - cgb.type_alias(f"_t_{name}", self.slang_type.full_name) + binding.gen_calldata_type_name(cgb, self.slang_type.full_name) def create_calldata( self, context: CallContext, binding: BoundVariableRuntime, data: RandFloatArg diff --git a/slangpy/types/threadidarg.py b/slangpy/types/threadidarg.py index e53b04963..6b32749a7 100644 --- a/slangpy/types/threadidarg.py +++ b/slangpy/types/threadidarg.py @@ -48,9 +48,8 @@ def __init__(self, layout: SlangProgramLayout, dims: int): def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVariable): access = binding.access - name = binding.variable_name if access[0] == AccessType.read: - cgb.type_alias(f"_t_{name}", self.slang_type.full_name) + binding.gen_calldata_type_name(cgb, self.slang_type.full_name) def resolve_type(self, context: BindContext, bound_type: "SlangType"): # Thread id arg is valid to pass to vector or scalar integer types. diff --git a/slangpy/types/wanghasharg.py b/slangpy/types/wanghasharg.py index 7f602b514..902775c2d 100644 --- a/slangpy/types/wanghasharg.py +++ b/slangpy/types/wanghasharg.py @@ -84,9 +84,8 @@ def __init__(self, layout: SlangProgramLayout, dims: int, warmup: int): def gen_calldata(self, cgb: CodeGenBlock, context: BindContext, binding: BoundVariable): access = binding.access - name = binding.variable_name if access[0] == AccessType.read: - cgb.type_alias(f"_t_{name}", self.slang_type.full_name) + binding.gen_calldata_type_name(cgb, self.slang_type.full_name) def create_calldata( self, context: CallContext, binding: BoundVariableRuntime, data: WangHashArg From 1566d510570b48b3569db74147e6d170e81cad5b Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Thu, 12 Mar 2026 12:09:38 +0000 Subject: [PATCH 16/41] Neater tests --- .../plan-consolidateCodeGenTests.prompt.md | 269 +++++ slangpy/tests/slangpy_tests/test_code_gen.py | 1069 +++++++++++++++++ 2 files changed, 1338 insertions(+) create mode 100644 .github/prompts/plan-consolidateCodeGenTests.prompt.md create mode 100644 slangpy/tests/slangpy_tests/test_code_gen.py diff --git a/.github/prompts/plan-consolidateCodeGenTests.prompt.md b/.github/prompts/plan-consolidateCodeGenTests.prompt.md new file mode 100644 index 000000000..f8315a232 --- /dev/null +++ b/.github/prompts/plan-consolidateCodeGenTests.prompt.md @@ -0,0 +1,269 @@ +## Plan: Consolidate test_kernel_gen.py → test_code_gen.py + +**TL;DR**: Reduce 80 test functions to ~34 by: (1) merging codegen-pattern + binding-flag tests that generate the same kernel into single combined tests, (2) dropping functional GPU dispatch tests that duplicate coverage in existing test files, (3) consolidating tests that use identical Slang source strings, and (4) subsuming shallow struct nesting tests into deeper ones. + +Current: **80 tests × 3 device types = 240 parametrized cases** (~1841 lines) +Proposed: **~34 tests × 3 device types = ~102 parametrized cases** (~700-800 lines) + +--- + +### Consolidation Strategies + +#### Strategy A: Merge same-source codegen tests + +Five tests use `int add(int a, int b)` with args `(1, 2)` and generate the exact same kernel: +- `test_gate_scalar_uses_valuetype` +- `test_gate_valueref_write_uses_wrapper` +- `test_gate_mapping_constants_present` +- `test_gate_context_map_in_trampoline` +- `test_result_binding_not_direct_bind` + +→ **1 combined test** `test_scalar_direct_bind`: generates kernel once, asserts all codegen patterns (no `ValueType`, no `typealias`, direct assignment, `RWValueRef` for `_result`, no mapping constants for args, `_m__result` present, no `context.map` for args) **and** binding flags (`args[0].direct_bind=True`, `kwargs["_result"].direct_bind=False`). + +Also fold `test_gate_float_scalar_uses_valuetype` into this as a sub-assertion (generate a second kernel for `float mymul` and check same patterns), or drop entirely since `int` and `float` exercise the same `ValueMarshall` code path. + +#### Strategy B: Merge codegen + binding flag pairs + +Each of these pairs generates the same kernel twice — merge into one test using a new `generate_code_and_bindings` helper that returns `(code, bindings)`: + +| Merged test | From (codegen) | From (binding flags) | +|---|---|---| +| `test_vector_direct_bind` | `test_gate_vector_uses_vectorvaluetype` | (no binding-flag test exists; add flags check) | +| `test_struct_all_scalar_direct_bind` | `test_gate_struct_uses_slangpy_load` | `test_struct_all_scalars_binding_flag` | +| `test_mixed_scalar_tensor` | `test_gate_mixed_args_scalar_and_tensor` | `test_gate_mixed_args_direct_bind_flags` | +| `test_struct_mixed_fields` | `test_gate_struct_mixed_fields_codegen` + `test_mixed_children_direct_bind_codegen` | `test_gate_struct_mixed_fields_binding_flags` | +| `test_tensor_dim0_direct_bind` | `test_gate_tensor_dim0_codegen` | `test_gate_tensor_dim0_binding_flags` | +| `test_2d_tensor_to_vector` | `test_gate_2d_tensor_to_vector_codegen` | `test_gate_2d_tensor_to_vector_binding_flags` | +| `test_3d_tensor_to_vector` | `test_gate_3d_tensor_to_vector_codegen` | `test_gate_3d_tensor_to_vector_binding_flags` | +| `test_2d_tensor_to_scalar` | `test_gate_2d_tensor_to_scalar_codegen` | `test_gate_2d_tensor_to_scalar_binding_flags` | +| `test_2d_tensor_to_array` | `test_gate_2d_tensor_to_1d_array_codegen` | `test_gate_2d_tensor_to_1d_array_binding_flags` | +| `test_mixed_vectorized_dim0_tensor` | `test_gate_mixed_vectorized_and_dim0_tensor_codegen` | `test_gate_mixed_vectorized_and_dim0_tensor_binding_flags` | +| `test_struct_return_not_direct_bind` | `test_gate_struct_return_codegen` | `test_gate_struct_return_binding_flags` | +| `test_nested_struct_with_tensor_child` | `test_gate_nested_struct_with_tensor_child_codegen` | `test_gate_nested_struct_with_tensor_child_binding_flags` | + +#### Strategy C: Subsume shallow struct tests into deeper ones + +The 3-level-deep all-scalar struct test (`test_gate_deeply_nested_struct_codegen` + `test_gate_deeply_nested_struct_binding_flags`) covers the same pattern as the 2-level nested struct (`test_gate_nested_struct_codegen` + `test_gate_nested_struct_binding_flags`). Keep only **`test_deeply_nested_struct_direct_bind`** combining both. Drop tests for 2-level nesting. + +#### Strategy D: Consolidate all-scalar struct composite field variants + +Four test groups cover "struct with {vector/matrix/array/struct-array} fields, all dim-0 → direct-bind". These all exercise the same property (recursive `can_direct_bind` returning `True`): +- struct with vector field (`test_gate_struct_with_vector_fields_codegen` + binding) +- struct with matrix field (`test_gate_struct_with_matrix_field_codegen`) +- struct with array field (`test_gate_struct_with_array_field_codegen` + binding) +- struct with struct-array field (`test_gate_struct_with_struct_array_field_codegen`) + +→ Merge into **one parametrized test** `test_struct_composite_fields_direct_bind` with parameters for the variant (vector field, array field). Drop matrix and struct-array codegen-only tests — if vector+array pass, the mechanism works for all composite field types. + +#### Strategy E: Consolidate negative gates + +Merge `test_gate_wanghasharg_uses_wrapper`, `test_wanghasharg_binding_flag`, and `test_struct_with_wanghash_child_not_direct_bind` into **one test** `test_wanghasharg_not_direct_bind` that covers standalone + struct-child cases. + +Keep `test_gate_vectorized_scalar_keeps_wrapper` and `test_gate_vectorized_dict_keeps_struct_load` as separate small tests (they're already minimal). + +#### Strategy F: Consolidate long-name tests + +Merge `test_gate_long_struct_name_gets_typealias`, `test_gate_short_struct_name_inlined`, `test_gate_long_scalar_type_name_gets_typealias` into **one test** `test_long_type_name_typealias`. + +#### Strategy G: Drop functional dispatch tests with existing coverage + +| Dropped test | Covered by | +|---|---| +| `test_phase1_functional_scalar_add` | `test_simple_function_call.py::test_returnvalue` | +| `test_phase1_functional_float_mul` | Same mechanism as scalar_add, float type tested elsewhere | +| `test_phase1_functional_valueref_write` | `test_simple_function_call.py::test_scalar_outparam` | +| `test_phase1_functional_struct_return` | `test_return_types.py::test_return_struct_as_dict` | +| `test_phase1_functional_struct_sum` | Similar to struct_return, struct dispatch tested elsewhere | +| `test_phase1_functional_nested_struct` | Subsumed by deeply_nested + nested_with_tensor tests | +| `test_phase1_functional_struct_with_vector_fields` | Covered by composite field parametrized test pattern | +| `test_phase1_functional_struct_with_matrix_field` | Same mechanism as vector_fields test | +| `test_phase1_functional_struct_with_array_field` | Same mechanism, array dispatch tested in `test_simple_function_call.py` | +| `test_phase1_functional_deeply_nested_struct` | 3-level dispatch validates same mechanism as 2-level | +| `test_phase1_functional_vector_scale` | `test_vector_function_call.py` covers vector dispatch | +| `test_phase1_functional_3d_tensor_to_vector` | If 2D→vector works, 3D exercises same path with extra dim | +| `test_phase1_functional_2d_tensor_to_scalar` | Element-wise tensor dispatch covered by `test_tensor.py` | +| `test_phase1_functional_valueref_read_input` | Read ValueRef → scalar is tested indirectly; codegen test verifies the binding | +| `test_phase1_functional_long_struct_name` | Long name is a codegen-only concern; dispatch is identical to short-name struct | + +--- + +### Proposed Final Test List (~34 tests) + +**New helper:** +```python +def generate_code_and_bindings(device, func_name, module_source, *args, **kwargs): + """Generate code and return (code_str, bindings) from a single debug_build_call_data call.""" + func = helpers.create_function_from_module(device, func_name, module_source) + cd = func.debug_build_call_data(*args, **kwargs) + return cd.code, cd.debug_only_bindings +``` + +**Codegen + binding flag tests (21):** + +| # | Test name | Scenario | Merges from | +|---|---|---|---| +| 1 | `test_scalar_direct_bind` | int/float scalar dim-0; _result writable | 5 codegen tests + 1 binding test + float variant | +| 2 | `test_vector_direct_bind` | float3 dim-0 | codegen test + new binding assertions | +| 3 | `test_matrix_direct_bind` | float4x4 dim-0 | standalone | +| 4 | `test_array_direct_bind` | float[4] dim-0 | standalone | +| 5 | `test_valueref_read_direct_bind` | read-only ValueRef | standalone | +| 6 | `test_writable_valueref_not_direct_bind` | inout ValueRef (RWValueRef) | standalone | +| 7 | `test_struct_all_scalar_direct_bind` | S{float x, y} via dict | codegen + binding pair | +| 8 | `test_struct_composite_fields_direct_bind` | parametrized: struct with vector / array field | 4 codegen + 2 binding tests | +| 9 | `test_deeply_nested_struct_direct_bind` | 3-level Top{Mid{Bot}} | subsumes 2-level; codegen + binding pair | +| 10 | `test_struct_mixed_fields` | S{x(tensor), y(scalar)} | 2 codegen + 1 binding test | +| 11 | `test_nested_struct_with_tensor_child` | Outer{Inner{x(tensor),y},s} | codegen + binding pair | +| 12 | `test_struct_return_not_direct_bind` | function returning struct | codegen + binding pair | +| 13 | `test_struct_vectorized_2d_child` | S{float3 v (2D tensor), float s} | standalone | +| 14 | `test_mixed_scalar_and_tensor` | scalar + tensor args | codegen + binding pair | +| 15 | `test_tensor_dim0_direct_bind` | Tensor at dim-0 | codegen + binding pair | +| 16 | `test_2d_tensor_to_vector` | 2D(10,3) → float3 | codegen + binding pair | +| 17 | `test_3d_tensor_to_vector` | 3D(2,5,3) → float3 | codegen + binding pair | +| 18 | `test_2d_tensor_to_scalar` | 2D(4,5) → float | codegen + binding pair | +| 19 | `test_2d_tensor_to_array` | 2D(4,8) → half[8] | codegen + binding pair | +| 20 | `test_mixed_vectorized_dim0_tensor` | vectorized + dim-0 tensor | codegen + binding pair | +| 21 | `test_long_type_name_typealias` | long/short struct name, wrapper name | 3 tests merged | + +**Negative gates (3):** + +| # | Test name | Scenario | Merges from | +|---|---|---|---| +| 22 | `test_wanghasharg_not_direct_bind` | standalone + struct child | 3 tests merged | +| 23 | `test_vectorized_scalar_keeps_wrapper` | 1D tensor → float | standalone | +| 24 | `test_vectorized_dict_keeps_wrapper` | dict with tensor children | standalone | + +**Autodiff (1):** + +| # | Test name | Scenario | Merges from | +|---|---|---|---| +| 25 | `test_bwds_direct_bind` | codegen + binding flags for bwds polynomial | 3 tests merged | + +**Functional GPU dispatch — novel scenarios only (9):** + +| # | Test name | Scenario | Why novel | +|---|---|---|---| +| 26 | `test_dispatch_mixed_scalar_tensor` | scalar + 1D tensor | Not tested elsewhere | +| 27 | `test_dispatch_struct_mixed_fields` | struct{tensor+scalar} | Unique dispatch scenario | +| 28 | `test_dispatch_tensor_dim0` | Tensor at dim-0 | Specific dim-0 behavior | +| 29 | `test_dispatch_2d_tensor_to_vector` | 2D→float3 | Novel param mapping | +| 30 | `test_dispatch_2d_tensor_to_array` | 2D→half[8] generic | Unique test | +| 31 | `test_dispatch_mixed_vectorized_dim0_tensor` | vectorized + dim-0 tensor | Unique | +| 32 | `test_dispatch_nested_struct_with_tensor` | nested struct with tensor leaf | Unique | +| 33 | `test_dispatch_struct_vectorized_2d_child` | struct with 2D tensor→float3 child | Unique | +| 34 | `test_dispatch_struct_array_of_structs` | struct with `Inner items[4]` | Unique | + +--- + +### Old → New Mapping + +| Old test (test_kernel_gen.py) | New test (test_code_gen.py) | Action | +|---|---|---| +| `test_kernel_gen_basic` | — | **Dropped** (subset of `test_scalar_direct_bind`) | +| `test_gate_scalar_uses_valuetype` | `test_scalar_direct_bind` | **Merged** | +| `test_gate_float_scalar_uses_valuetype` | `test_scalar_direct_bind` | **Merged** (or dropped) | +| `test_gate_vector_uses_vectorvaluetype` | `test_vector_direct_bind` | **Merged** | +| `test_gate_matrix_uses_valuetype` | `test_matrix_direct_bind` | **Kept** (standalone) | +| `test_gate_array_dim0_uses_valuetype` | `test_array_direct_bind` | **Kept** (standalone) | +| `test_gate_valueref_read_uses_wrapper` | `test_valueref_read_direct_bind` | **Kept** (standalone) | +| `test_gate_valueref_write_uses_wrapper` | `test_scalar_direct_bind` | **Merged** | +| `test_gate_mapping_constants_present` | `test_scalar_direct_bind` | **Merged** | +| `test_gate_context_map_in_trampoline` | `test_scalar_direct_bind` | **Merged** | +| `test_gate_struct_uses_slangpy_load` | `test_struct_all_scalar_direct_bind` | **Merged** | +| `test_gate_bwds_scalar_uses_valuetype` | `test_bwds_direct_bind` | **Merged** | +| `test_gate_bwds_trampoline_is_differentiable` | `test_bwds_direct_bind` | **Merged** | +| `test_gate_wanghasharg_uses_wrapper` | `test_wanghasharg_not_direct_bind` | **Merged** | +| `test_gate_vectorized_scalar_keeps_wrapper` | `test_vectorized_scalar_keeps_wrapper` | **Kept** | +| `test_gate_vectorized_dict_keeps_struct_load` | `test_vectorized_dict_keeps_wrapper` | **Kept** | +| `test_phase1_functional_scalar_add` | — | **Dropped** (covered by `test_simple_function_call.py`) | +| `test_phase1_functional_float_mul` | — | **Dropped** | +| `test_phase1_functional_vector_scale` | — | **Dropped** (covered by `test_vector_function_call.py`) | +| `test_phase1_functional_struct_sum` | — | **Dropped** | +| `test_phase1_functional_valueref_write` | — | **Dropped** (covered by `test_simple_function_call.py`) | +| `test_gate_mixed_args_scalar_and_tensor` | `test_mixed_scalar_and_tensor` | **Merged** | +| `test_gate_mixed_args_direct_bind_flags` | `test_mixed_scalar_and_tensor` | **Merged** | +| `test_phase1_functional_mixed_scalar_tensor` | `test_dispatch_mixed_scalar_tensor` | **Kept** | +| `test_gate_struct_mixed_fields_codegen` | `test_struct_mixed_fields` | **Merged** | +| `test_gate_struct_mixed_fields_binding_flags` | `test_struct_mixed_fields` | **Merged** | +| `test_phase1_functional_struct_mixed_fields` | `test_dispatch_struct_mixed_fields` | **Kept** | +| `test_gate_tensor_dim0_codegen` | `test_tensor_dim0_direct_bind` | **Merged** | +| `test_gate_tensor_dim0_binding_flags` | `test_tensor_dim0_direct_bind` | **Merged** | +| `test_phase1_functional_tensor_dim0` | `test_dispatch_tensor_dim0` | **Kept** | +| `test_mixed_children_direct_bind_codegen` | `test_struct_mixed_fields` | **Merged** (overlap with struct_mixed_fields) | +| `test_writable_valueref_not_direct_bind` | `test_writable_valueref_not_direct_bind` | **Kept** | +| `test_result_binding_not_direct_bind` | `test_scalar_direct_bind` | **Merged** | +| `test_struct_all_scalars_binding_flag` | `test_struct_all_scalar_direct_bind` | **Merged** | +| `test_struct_with_wanghash_child_not_direct_bind` | `test_wanghasharg_not_direct_bind` | **Merged** | +| `test_wanghasharg_binding_flag` | `test_wanghasharg_not_direct_bind` | **Merged** | +| `test_bwds_primal_binding_flags` | `test_bwds_direct_bind` | **Merged** | +| `test_gate_2d_tensor_to_vector_codegen` | `test_2d_tensor_to_vector` | **Merged** | +| `test_gate_2d_tensor_to_vector_binding_flags` | `test_2d_tensor_to_vector` | **Merged** | +| `test_phase1_functional_2d_tensor_to_vector` | `test_dispatch_2d_tensor_to_vector` | **Kept** | +| `test_gate_3d_tensor_to_vector_codegen` | `test_3d_tensor_to_vector` | **Merged** | +| `test_gate_3d_tensor_to_vector_binding_flags` | `test_3d_tensor_to_vector` | **Merged** | +| `test_phase1_functional_3d_tensor_to_vector` | — | **Dropped** (2D→vector is sufficient) | +| `test_gate_2d_tensor_to_scalar_codegen` | `test_2d_tensor_to_scalar` | **Merged** | +| `test_gate_2d_tensor_to_scalar_binding_flags` | `test_2d_tensor_to_scalar` | **Merged** | +| `test_phase1_functional_2d_tensor_to_scalar` | — | **Dropped** (covered by `test_tensor.py`) | +| `test_gate_2d_tensor_to_1d_array_codegen` | `test_2d_tensor_to_array` | **Merged** | +| `test_gate_2d_tensor_to_1d_array_binding_flags` | `test_2d_tensor_to_array` | **Merged** | +| `test_phase1_functional_2d_tensor_to_1d_array` | `test_dispatch_2d_tensor_to_array` | **Kept** | +| `test_gate_mixed_vectorized_and_dim0_tensor_codegen` | `test_mixed_vectorized_dim0_tensor` | **Merged** | +| `test_gate_mixed_vectorized_and_dim0_tensor_binding_flags` | `test_mixed_vectorized_dim0_tensor` | **Merged** | +| `test_phase1_functional_mixed_vectorized_and_dim0_tensor` | `test_dispatch_mixed_vectorized_dim0_tensor` | **Kept** | +| `test_gate_nested_struct_codegen` | — | **Dropped** (subsumed by deeply_nested) | +| `test_gate_nested_struct_binding_flags` | — | **Dropped** (subsumed by deeply_nested) | +| `test_phase1_functional_nested_struct` | — | **Dropped** | +| `test_gate_struct_with_vector_fields_codegen` | `test_struct_composite_fields_direct_bind` | **Merged** (parametrized) | +| `test_gate_struct_with_vector_fields_binding_flags` | `test_struct_composite_fields_direct_bind` | **Merged** | +| `test_phase1_functional_struct_with_vector_fields` | — | **Dropped** | +| `test_gate_struct_with_matrix_field_codegen` | — | **Dropped** (covered by vector+array variants) | +| `test_phase1_functional_struct_with_matrix_field` | — | **Dropped** | +| `test_gate_struct_with_array_field_codegen` | `test_struct_composite_fields_direct_bind` | **Merged** (parametrized) | +| `test_gate_struct_with_array_field_binding_flags` | `test_struct_composite_fields_direct_bind` | **Merged** | +| `test_phase1_functional_struct_with_array_field` | — | **Dropped** | +| `test_gate_deeply_nested_struct_codegen` | `test_deeply_nested_struct_direct_bind` | **Merged** | +| `test_gate_deeply_nested_struct_binding_flags` | `test_deeply_nested_struct_direct_bind` | **Merged** | +| `test_phase1_functional_deeply_nested_struct` | — | **Dropped** | +| `test_gate_nested_struct_with_tensor_child_codegen` | `test_nested_struct_with_tensor_child` | **Merged** | +| `test_gate_nested_struct_with_tensor_child_binding_flags` | `test_nested_struct_with_tensor_child` | **Merged** | +| `test_phase1_functional_nested_struct_with_tensor` | `test_dispatch_nested_struct_with_tensor` | **Kept** | +| `test_gate_struct_with_struct_array_field_codegen` | — | **Dropped** (covered by array field variant) | +| `test_phase1_functional_struct_with_struct_array_field` | `test_dispatch_struct_array_of_structs` | **Kept** | +| `test_gate_struct_return_codegen` | `test_struct_return_not_direct_bind` | **Merged** | +| `test_gate_struct_return_binding_flags` | `test_struct_return_not_direct_bind` | **Merged** | +| `test_phase1_functional_struct_return` | — | **Dropped** (covered by `test_return_types.py`) | +| `test_gate_struct_with_vectorized_2d_tensor_child_codegen` | `test_struct_vectorized_2d_child` | **Kept** | +| `test_phase1_functional_struct_with_vectorized_2d_tensor` | `test_dispatch_struct_vectorized_2d_child` | **Kept** | +| `test_gate_long_struct_name_gets_typealias` | `test_long_type_name_typealias` | **Merged** | +| `test_gate_short_struct_name_inlined` | `test_long_type_name_typealias` | **Merged** | +| `test_gate_long_scalar_type_name_gets_typealias` | `test_long_type_name_typealias` | **Merged** | +| `test_phase1_functional_long_struct_name` | — | **Dropped** | +| `test_phase1_functional_valueref_read_input` | — | **Dropped** | + +--- + +### Verification + +```bash +# Build first (required) +cmake --build --preset windows-msvc-debug + +# Run new test file +pytest slangpy/tests/slangpy_tests/test_code_gen.py -v + +# Confirm full suite still passes (existing tests in other files cover dropped dispatch tests) +pytest slangpy/tests -v + +# Run pre-commit +pre-commit run --all-files +``` + +### Key Decisions + +- Combined codegen+binding tests: one `debug_build_call_data` call yields both `.code` and `.debug_only_bindings` — no redundant kernel generation +- Dropped `test_kernel_gen_basic`: its sole assertion (`"add" in code`) is a strict subset of `test_scalar_direct_bind` +- Dropped matrix/struct-array field variants: if vector field and array field pass, the `can_direct_bind` recursion works for all composite types +- Dropped 2-level nested struct: the 3-level test covers the same recursion with deeper nesting +- Dropped 15 functional dispatch tests that are covered by existing test files (`test_simple_function_call.py`, `test_return_types.py`, `test_vector_function_call.py`, `test_tensor.py`) +- Kept all negative gates — they deliberately test types NOT eligible for simplification and must remain passing as Phase 2 proceeds +- The old `test_kernel_gen.py` should be deleted once the new `test_code_gen.py` is verified diff --git a/slangpy/tests/slangpy_tests/test_code_gen.py b/slangpy/tests/slangpy_tests/test_code_gen.py new file mode 100644 index 000000000..23ab05f06 --- /dev/null +++ b/slangpy/tests/slangpy_tests/test_code_gen.py @@ -0,0 +1,1069 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +These tests exercise different code paths for kernel generation, verifying both +generated code patterns and binding flags in a single pass per scenario. Each +test calls ``debug_build_call_data`` once and asserts on both ``.code`` and +``.debug_only_bindings``. + +Negative-gate tests (``test_*_not_direct_bind``, ``test_*_keeps_wrapper``) must +remain passing — they cover types that are NOT direct-bind eligible. + +Functional dispatch tests are included only for scenarios that are not covered +by other test files (``test_simple_function_call.py``, ``test_tensor.py``, etc.). +""" + +from typing import Any, Tuple + +import numpy as np +import os +import pytest + +import slangpy as spy +from slangpy.testing import helpers +from slangpy.types import ValueRef, Tensor, diffPair +from slangpy.types.wanghasharg import WangHashArg + +PRINT_CODE = os.getenv("PRINT_TEST_KERNEL_GEN", "0") == "1" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def assert_contains(code: str, *patterns: str) -> None: + """Assert all patterns appear in generated code.""" + for p in patterns: + assert p in code, f"Expected pattern not found: {p}" + + +def assert_not_contains(code: str, *patterns: str) -> None: + """Assert none of the patterns appear in generated code.""" + for p in patterns: + assert p not in code, f"Unexpected pattern found: {p}" + + +def assert_trampoline_has(code: str, *stmts: str) -> None: + """Assert trampoline contains statements (tolerates call_data vs __calldata__).""" + for s in stmts: + if "__calldata__." in s: + alt = s.replace("__calldata__.", "call_data.") + assert ( + s in code or alt in code + ), f"Expected trampoline statement not found: {s} (or {alt})" + else: + assert s in code, f"Expected trampoline statement not found: {s}" + + +def generate_code_and_bindings( + device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any +) -> Tuple[str, Any]: + """Generate code and return ``(code_str, bindings)`` from a single ``debug_build_call_data`` call.""" + func = helpers.create_function_from_module(device, func_name, module_source) + cd = func.debug_build_call_data(*args, **kwargs) + if PRINT_CODE: + print(cd.code) + return cd.code, cd.debug_only_bindings + + +def generate_bwds_code_and_bindings( + device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any +) -> Tuple[str, Any]: + """Generate backwards-mode code and return ``(code_str, bindings)``.""" + func = helpers.create_function_from_module(device, func_name, module_source) + cd = func.bwds.debug_build_call_data(*args, **kwargs) + if PRINT_CODE: + print(cd.code) + return cd.code, cd.debug_only_bindings + + +# =========================================================================== +# Codegen + binding flag tests (1–21) +# =========================================================================== + + +# 1 ------------------------------------------------------------------------- +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_scalar_direct_bind(device_type: spy.DeviceType): + """int/float scalar dim-0: direct-bind, _result writable RWValueRef. + + Merges: test_gate_scalar_uses_valuetype, test_gate_float_scalar_uses_valuetype, + test_gate_valueref_write_uses_wrapper, test_gate_mapping_constants_present, + test_gate_context_map_in_trampoline, test_result_binding_not_direct_bind. + """ + device = helpers.get_device(device_type) + code, bindings = generate_code_and_bindings( + device, "add", "int add(int a, int b) { return a + b; }", 1, 2 + ) + + # --- codegen assertions --- + # Scalars use raw type directly, no wrapper + assert_not_contains(code, "ValueType") + assert_not_contains(code, "typealias _t_a", "typealias _t_b") + # Direct assignment in trampoline + assert_trampoline_has(code, "a = __calldata__.a;", "b = __calldata__.b;") + # _result is auto-created writable RWValueRef + assert_contains(code, "RWValueRef") + assert_contains(code, "__slangpy_store") + # No mapping constants for direct-bind args; _result keeps its mapping constant + assert_not_contains(code, "static const int _m_a = 0", "static const int _m_b = 0") + assert_contains(code, "static const int _m__result = 0") + # No context.map for direct-bind args + assert_not_contains(code, "__slangpy_context__.map(_m_a)") + + # --- binding flag assertions --- + assert bindings.args[0].direct_bind is True + assert bindings.args[0].call_dimensionality == 0 + assert bindings.args[1].direct_bind is True + assert bindings.kwargs["_result"].direct_bind is False + assert bindings.kwargs["_result"].call_dimensionality == 0 + + +# 2 ------------------------------------------------------------------------- +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_vector_direct_bind(device_type: spy.DeviceType): + """float3 dim-0: direct-bind, type used directly.""" + device = helpers.get_device(device_type) + code, bindings = generate_code_and_bindings( + device, + "scale", + "float3 scale(float3 v, float s) { return v * s; }", + spy.math.float3(1, 2, 3), + 1.0, + ) + + assert_not_contains(code, "VectorValueType") + assert_not_contains(code, "typealias _t_v") + assert_contains(code, "vector v;") + + assert bindings.args[0].direct_bind is True + assert bindings.args[0].call_dimensionality == 0 + assert bindings.args[1].direct_bind is True + + +# 3 ------------------------------------------------------------------------- +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_matrix_direct_bind(device_type: spy.DeviceType): + """float4x4 dim-0: direct-bind.""" + device = helpers.get_device(device_type) + code, bindings = generate_code_and_bindings( + device, + "ident", + "float4x4 ident(float4x4 m) { return m; }", + spy.math.float4x4.identity(), + ) + + assert_not_contains(code, "ValueType>") + assert_not_contains(code, "typealias _t_m") + assert_contains(code, "matrix m;") + + assert bindings.args[0].direct_bind is True + assert bindings.args[0].call_dimensionality == 0 + + +# 4 ------------------------------------------------------------------------- +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_array_direct_bind(device_type: spy.DeviceType): + """float[4] dim-0: direct-bind.""" + device = helpers.get_device(device_type) + code, bindings = generate_code_and_bindings( + device, + "process", + "void process(float a[4]) { }", + [1.0, 2.0, 3.0, 4.0], + ) + + assert_not_contains(code, "ValueType<") + assert_not_contains(code, "typealias _t_a") + + assert bindings.args[0].direct_bind is True + assert bindings.args[0].call_dimensionality == 0 + + +# 5 ------------------------------------------------------------------------- +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_valueref_read_direct_bind(device_type: spy.DeviceType): + """Read-only ValueRef: direct-bind, raw type.""" + device = helpers.get_device(device_type) + code, bindings = generate_code_and_bindings( + device, + "read_val", + "float read_val(float v) { return v; }", + ValueRef(1.0), + ) + + assert_not_contains(code, "typealias _t_v") + assert_contains(code, "float v;") + assert_trampoline_has(code, "v = __calldata__.v;") + assert_contains(code, "RWValueRef") + + assert bindings.args[0].direct_bind is True + assert bindings.args[0].call_dimensionality == 0 + + +# 6 ------------------------------------------------------------------------- +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_writable_valueref_not_direct_bind(device_type: spy.DeviceType): + """Writable ValueRef (inout) must not be direct-bind — needs buffer read/write.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module(device, "inc", "void inc(inout int v) { v += 1; }") + cd = func.debug_build_call_data(ValueRef(5)) + code, bindings = cd.code, cd.debug_only_bindings + + assert_contains(code, "RWValueRef") + assert_not_contains(code, "typealias _t_v = int;") + + assert bindings.args[0].direct_bind is False + assert bindings.args[0].call_dimensionality == 0 + + +# 7 ------------------------------------------------------------------------- +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_struct_all_scalar_direct_bind(device_type: spy.DeviceType): + """S{float x, y} via dict — all-scalar, direct-bind. + + Merges: test_gate_struct_uses_slangpy_load, test_struct_all_scalars_binding_flag. + """ + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +float sum(S s) { return s.x + s.y; } +""" + code, bindings = generate_code_and_bindings( + device, "sum", src, {"_type": "S", "x": 1.0, "y": 2.0} + ) + + # Direct-bind struct — raw type, no __slangpy_load + assert_not_contains(code, "__slangpy_load") + assert_not_contains(code, "typealias _t_s") + assert_contains(code, "S s;") + assert_trampoline_has(code, "s = __calldata__.s;") + + s = bindings.args[0] + assert s.direct_bind is True + assert s.children["x"].direct_bind is True + assert s.children["y"].direct_bind is True + + +# 8 ------------------------------------------------------------------------- +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +@pytest.mark.parametrize( + "variant", + ["vector_field", "array_field"], + ids=["vector_field", "array_field"], +) +def test_struct_composite_fields_direct_bind(device_type: spy.DeviceType, variant: str): + """Struct with composite field (vector / array) all dim-0 → direct-bind. + + Merges: struct_with_vector_fields, struct_with_array_field codegen+binding tests. + """ + device = helpers.get_device(device_type) + + if variant == "vector_field": + src = """ +struct S { + float3 pos; + float scale; +}; +float3 apply(S s) { return s.pos * s.scale; } +""" + arg = {"_type": "S", "pos": spy.math.float3(1, 2, 3), "scale": 2.0} + func_name = "apply" + child_name = "pos" + else: + src = """ +struct Foo { + int vals[4]; +}; +int sum_inner(Foo foo) { + int s = 0; + for (int i = 0; i < 4; i++) { + s += foo.vals[i]; + } + return s; +} +""" + arg = {"_type": "Foo", "vals": [1, 2, 3, 4]} + func_name = "sum_inner" + child_name = "vals" + + code, bindings = generate_code_and_bindings(device, func_name, src, arg) + + # Struct is direct-bind — raw type, no __slangpy_load + assert_not_contains(code, "__slangpy_load") + param_name = "s" if variant == "vector_field" else "foo" + assert_not_contains(code, f"typealias _t_{param_name}") + assert_trampoline_has(code, f"{param_name} = __calldata__.{param_name};") + + s = bindings.args[0] + assert s.direct_bind is True + assert s.children[child_name].direct_bind is True + + +# 9 ------------------------------------------------------------------------- +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_deeply_nested_struct_direct_bind(device_type: spy.DeviceType): + """3-level deep Top{Mid{Bot}} — all-scalar, direct-bind at every level. + + Subsumes 2-level nested struct tests. Merges: test_gate_deeply_nested_struct_codegen, + test_gate_deeply_nested_struct_binding_flags. + """ + device = helpers.get_device(device_type) + src = """ +struct Bot { + float v; +}; +struct Mid { + Bot bot; + int c; +}; +struct Top { + Mid mid; + float s; +}; +float compute(Top t) { return t.mid.bot.v * float(t.mid.c) * t.s; } +""" + arg = { + "_type": "Top", + "mid": {"_type": "Mid", "bot": {"_type": "Bot", "v": 2.0}, "c": 3}, + "s": 4.0, + } + code, bindings = generate_code_and_bindings(device, "compute", src, arg) + + assert_not_contains(code, "typealias _t_t") + assert_contains(code, "Top t;") + assert_not_contains(code, "__slangpy_load") + assert_not_contains(code, "struct _t_t") + assert_trampoline_has(code, "t = __calldata__.t;") + + t = bindings.args[0] + assert t.direct_bind is True + assert t.children["mid"].direct_bind is True + assert t.children["mid"].children["bot"].direct_bind is True + assert t.children["mid"].children["bot"].children["v"].direct_bind is True + assert t.children["mid"].children["c"].direct_bind is True + assert t.children["s"].direct_bind is True + + +# 10 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_struct_mixed_fields(device_type: spy.DeviceType): + """S{x(tensor), y(scalar)} — struct NOT direct-bind, scalar child keeps direct-bind. + + Merges: test_gate_struct_mixed_fields_codegen, test_mixed_children_direct_bind_codegen, + test_gate_struct_mixed_fields_binding_flags. + """ + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +void apply(S s, float scale) {} +""" + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + code, bindings = generate_code_and_bindings( + device, "apply", src, {"_type": "S", "x": tensor_x, "y": 1.0}, 2.0 + ) + + # Struct NOT direct-bind — inline struct with __slangpy_load + assert_contains(code, "__slangpy_load") + assert_contains(code, "struct _t_s") + assert_not_contains(code, "typealias _t_s = S;") + # Child y direct-bind — type used directly, direct assignment + assert_not_contains(code, "typealias _t_y") + assert_contains(code, "float y;") + assert_contains(code, "value.y = y;") + assert_not_contains(code, "ValueType") + assert_not_contains(code, "_m_y") + # Child x — tensor + assert_contains(code, "Tensor") + assert_contains(code, "x.__slangpy_load(context.map(_m_x),value.x)") + # Independent scalar arg 'scale' — direct-bind + assert_not_contains(code, "typealias _t_scale") + assert_contains(code, "float scale;") + + # Binding flags + s = bindings.args[0] + assert s.direct_bind is False + assert s.children["x"].direct_bind is False + assert s.children["y"].direct_bind is True + assert bindings.args[1].direct_bind is True # scale + + +# 11 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_nested_struct_with_tensor_child(device_type: spy.DeviceType): + """Outer{Inner{x(tensor),y(scalar)},s} — Outer/Inner NOT direct-bind, scalar leaves are. + + Merges: test_gate_nested_struct_with_tensor_child_codegen, + test_gate_nested_struct_with_tensor_child_binding_flags. + """ + device = helpers.get_device(device_type) + src = """ +struct Inner { + float x; + float y; +}; +struct Outer { + Inner inner; + float s; +}; +float compute(Outer o) { return (o.inner.x + o.inner.y) * o.s; } +""" + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + code, bindings = generate_code_and_bindings( + device, + "compute", + src, + {"_type": "Outer", "inner": {"_type": "Inner", "x": tensor_x, "y": 10.0}, "s": 2.0}, + ) + + # Outer/Inner NOT direct-bind + assert_contains(code, "struct _t_o") + assert_contains(code, "__slangpy_load") + assert_not_contains(code, "typealias _t_o = Outer;") + # Scalar children retain direct-bind + assert_not_contains(code, "typealias _t_y") + assert_contains(code, "float y;") + assert_not_contains(code, "typealias _t_s") + assert_contains(code, "float s;") + assert_contains(code, "value.y = y;") + assert_contains(code, "_m_x") + + o = bindings.args[0] + assert o.direct_bind is False + assert o.children["inner"].direct_bind is False + assert o.children["inner"].children["x"].direct_bind is False + assert o.children["inner"].children["y"].direct_bind is True + assert o.children["s"].direct_bind is True + + +# 12 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_struct_return_not_direct_bind(device_type: spy.DeviceType): + """Function returning struct — _result uses wrapper, NOT direct-bind. + + Merges: test_gate_struct_return_codegen, test_gate_struct_return_binding_flags. + """ + device = helpers.get_device(device_type) + src = """ +struct S { + int x; + int y; +}; +S make_struct(int a, int b) { return { a, b }; } +""" + code, bindings = generate_code_and_bindings(device, "make_struct", src, 4, 5) + + # Scalar inputs direct-bind + assert_not_contains(code, "typealias _t_a", "typealias _t_b") + # _result writable — uses wrapper + assert_contains(code, "__slangpy_store") + assert_contains(code, "_m__result") + + result_binding = bindings.kwargs["_result"] + assert result_binding.direct_bind is False + assert bindings.args[0].direct_bind is True + assert bindings.args[1].direct_bind is True + + +# 13 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_struct_vectorized_2d_child(device_type: spy.DeviceType): + """S{float3 v (2D tensor→float3), float s (scalar)} — struct NOT direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float3 v; + float s; +}; +float3 apply(S st) { return st.v * st.s; } +""" + tensor_v = Tensor.from_numpy(device, np.ones((5, 3), dtype=np.float32)) + code, bindings = generate_code_and_bindings( + device, "apply", src, {"_type": "S", "v": tensor_v, "s": 2.0} + ) + + assert_contains(code, "struct _t_st") + assert_contains(code, "__slangpy_load") + assert_not_contains(code, "typealias _t_st = S;") + # Scalar child s direct-bind + assert_not_contains(code, "typealias _t_s") + assert_contains(code, "float s;") + assert_contains(code, "value.s = s;") + assert_contains(code, "_m_v") + + +# 14 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_mixed_scalar_and_tensor(device_type: spy.DeviceType): + """Scalar + tensor args — scalar direct-bind, tensor not. + + Merges: test_gate_mixed_args_scalar_and_tensor, test_gate_mixed_args_direct_bind_flags. + """ + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + code, bindings = generate_code_and_bindings( + device, "add", "float add(float a, float b) { return a + b; }", 1.0, tensor + ) + + # 'a' direct-bind + assert_not_contains(code, "typealias _t_a") + assert_not_contains(code, "ValueType") + assert_trampoline_has(code, "a = __calldata__.a;") + # 'b' NOT direct-bind (vectorized tensor) + assert_contains(code, "Tensor") + assert_contains(code, "__slangpy_load") + assert_contains(code, "_m_b") + + assert bindings.args[0].direct_bind is True + assert bindings.args[0].call_dimensionality == 0 + assert bindings.args[1].direct_bind is False + assert bindings.args[1].call_dimensionality == 1 + + +# 15 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_tensor_dim0_direct_bind(device_type: spy.DeviceType): + """Tensor at dim-0: whole tensor passed, direct-bind. + + Merges: test_gate_tensor_dim0_codegen, test_gate_tensor_dim0_binding_flags. + """ + device = helpers.get_device(device_type) + src = """ +float tensor_read(Tensor t) { + return t[0]; +} +""" + tensor = Tensor.from_numpy(device, np.array([42, 2, 3], dtype=np.float32)) + code, bindings = generate_code_and_bindings(device, "tensor_read", src, tensor) + + assert_not_contains(code, "typealias _t_t") + assert_contains(code, "Tensor t;") + assert_trampoline_has(code, "t = __calldata__.t;") + assert_not_contains(code, "ValueType<") + + t = bindings.args[0] + assert t.direct_bind is True + assert t.call_dimensionality == 0 + + +# 16 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_2d_tensor_to_vector(device_type: spy.DeviceType): + """2D Tensor (10,3) → float3: trailing dim consumed by vector, outer dispatched. + + Merges: test_gate_2d_tensor_to_vector_codegen, test_gate_2d_tensor_to_vector_binding_flags. + """ + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((10, 3), dtype=np.float32)) + code, bindings = generate_code_and_bindings( + device, "scale", "float3 scale(float3 v, float s) { return v * s; }", tensor, 2.0 + ) + + assert_contains(code, "__slangpy_load") + assert_contains(code, "_m_v") + assert_not_contains(code, "typealias _t_s") + assert_contains(code, "float s;") + + v = bindings.args[0] + assert v.call_dimensionality == 1 + assert v.direct_bind is False + assert v.vector_type is not None + assert v.vector_type.full_name == "vector" + s = bindings.args[1] + assert s.call_dimensionality == 0 + assert s.direct_bind is True + + +# 17 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_3d_tensor_to_vector(device_type: spy.DeviceType): + """3D Tensor (2,5,3) → float3: two outer dims dispatched (call_dim=2). + + Merges: test_gate_3d_tensor_to_vector_codegen, test_gate_3d_tensor_to_vector_binding_flags. + """ + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((2, 5, 3), dtype=np.float32)) + code, bindings = generate_code_and_bindings( + device, "negate", "float3 negate(float3 v) { return -v; }", tensor + ) + + assert_contains(code, "__slangpy_load") + assert_contains(code, "_m_v") + + v = bindings.args[0] + assert v.call_dimensionality == 2 + assert v.direct_bind is False + assert v.vector_type.full_name == "vector" + + +# 18 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_2d_tensor_to_scalar(device_type: spy.DeviceType): + """2D Tensor (4,5) → float: both dims dispatched (call_dim=2). + + Merges: test_gate_2d_tensor_to_scalar_codegen, test_gate_2d_tensor_to_scalar_binding_flags. + """ + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((4, 5), dtype=np.float32)) + code, bindings = generate_code_and_bindings( + device, "square", "float square(float x) { return x * x; }", tensor + ) + + assert_contains(code, "__slangpy_load") + assert_contains(code, "_m_x") + + v = bindings.args[0] + assert v.call_dimensionality == 2 + assert v.direct_bind is False + + +# 19 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_2d_tensor_to_array(device_type: spy.DeviceType): + """2D Tensor (4,8) → half[8]: trailing dim consumed by array, outer dispatched. + + Merges: test_gate_2d_tensor_to_1d_array_codegen, test_gate_2d_tensor_to_1d_array_binding_flags. + """ + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.ones((4, 8), dtype=np.float16)) + src = r""" +half[NumChannels] tensor_test_channels(half[NumChannels] data) +{ + [ForceUnroll] + for (int i = 0; i < NumChannels; ++i) + { + data[i] = 2.h * data[i]; + } + return data; +} +""" + code, bindings = generate_code_and_bindings(device, "tensor_test_channels<8>", src, tensor) + + assert_contains(code, "__slangpy_load") + assert_contains(code, "_m_data") + + v = bindings.args[0] + assert v.call_dimensionality == 1 + assert v.direct_bind is False + + +# 20 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_mixed_vectorized_dim0_tensor(device_type: spy.DeviceType): + """One tensor vectorized (2D→float3) and another at dim-0 (Tensor param). + + Merges: test_gate_mixed_vectorized_and_dim0_tensor_codegen, + test_gate_mixed_vectorized_and_dim0_tensor_binding_flags. + """ + device = helpers.get_device(device_type) + src = """ +float dot_lookup(float3 v, Tensor weights) { + return v.x * weights[0] + v.y * weights[1] + v.z * weights[2]; +} +""" + vec_tensor = Tensor.from_numpy(device, np.ones((5, 3), dtype=np.float32)) + weight_tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + code, bindings = generate_code_and_bindings( + device, "dot_lookup", src, vec_tensor, weight_tensor + ) + + # v: vectorized dim-1 (2D→float3) + assert_contains(code, "_m_v") + assert_contains(code, "__slangpy_load") + # weights: dim-0 direct-bind + assert_not_contains(code, "typealias _t_weights") + assert_contains(code, "Tensor weights;") + assert_trampoline_has(code, "weights = __calldata__.weights;") + + v = bindings.args[0] + assert v.call_dimensionality == 1 + assert v.direct_bind is False + w = bindings.args[1] + assert w.call_dimensionality == 0 + assert w.direct_bind is True + + +# 21 ------------------------------------------------------------------------ +# Long type name heuristic constants +_LONG_STRUCT_NAME = "MyVeryLongStructNameThatExceedsSixtyCharactersForTesting12345" +assert len(_LONG_STRUCT_NAME) > 60 +_SHORT_STRUCT_NAME = "S" +assert len(_SHORT_STRUCT_NAME) <= 60 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_long_type_name_typealias(device_type: spy.DeviceType): + """Long struct name (>60 chars) emits typealias; short name inlines. + Also covers long wrapper name for _result. + + Merges: test_gate_long_struct_name_gets_typealias, test_gate_short_struct_name_inlined, + test_gate_long_scalar_type_name_gets_typealias. + """ + device = helpers.get_device(device_type) + + # --- Long name → typealias emitted --- + long_src = f""" +struct {_LONG_STRUCT_NAME} {{ + float x; + float y; +}}; +float sum({_LONG_STRUCT_NAME} s) {{ return s.x + s.y; }} +""" + code_long, _ = generate_code_and_bindings( + device, "sum", long_src, {"_type": _LONG_STRUCT_NAME, "x": 1.0, "y": 2.0} + ) + assert_contains(code_long, f"typealias _t_s = {_LONG_STRUCT_NAME};") + assert_contains(code_long, "_t_s s;") + + # --- Short name → no typealias --- + short_src = f""" +struct {_SHORT_STRUCT_NAME} {{ + float x; + float y; +}}; +float sum({_SHORT_STRUCT_NAME} s) {{ return s.x + s.y; }} +""" + code_short, _ = generate_code_and_bindings( + device, "sum", short_src, {"_type": _SHORT_STRUCT_NAME, "x": 1.0, "y": 2.0} + ) + assert_not_contains(code_short, "typealias _t_s") + assert_contains(code_short, f"{_SHORT_STRUCT_NAME} s;") + + # --- Long wrapper name for _result --- + identity_src = f""" +struct {_LONG_STRUCT_NAME} {{ + float x; + float y; +}}; +{_LONG_STRUCT_NAME} identity({_LONG_STRUCT_NAME} s) {{ return s; }} +""" + code_id, _ = generate_code_and_bindings( + device, "identity", identity_src, {"_type": _LONG_STRUCT_NAME, "x": 1.0, "y": 2.0} + ) + result_type = f"RWValueRef<{_LONG_STRUCT_NAME}>" + assert len(result_type) > 60 + assert_contains(code_id, f"typealias _t__result = {result_type};") + + +# =========================================================================== +# Negative gates (22–24) — must remain passing +# =========================================================================== + + +# 22 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_wanghasharg_not_direct_bind(device_type: spy.DeviceType): + """WangHashArg is NOT direct-bind, standalone and as struct child. + + Merges: test_gate_wanghasharg_uses_wrapper, test_wanghasharg_binding_flag, + test_struct_with_wanghash_child_not_direct_bind. + """ + device = helpers.get_device(device_type) + + # --- Standalone WangHashArg --- + code_s, bindings_s = generate_code_and_bindings( + device, "rng", "uint3 rng(uint3 input) { return input; }", WangHashArg(3) + ) + assert_contains(code_s, "WangHashArg<") + assert_contains(code_s, "input") + assert bindings_s.args[0].direct_bind is False + assert bindings_s.args[0].call_dimensionality == 0 + + # --- As struct child --- + struct_src = """ +struct S { uint3 seed; float scale; }; +float apply(S s) { return float(s.seed.x) * s.scale; } +""" + func = helpers.create_function_from_module(device, "apply", struct_src) + cd = func.debug_build_call_data({"_type": "S", "seed": WangHashArg(3), "scale": 1.0}) + code_c, bindings_c = cd.code, cd.debug_only_bindings + + s = bindings_c.args[0] + assert s.direct_bind is False + assert s.children["scale"].direct_bind is True + assert_contains(code_c, "struct _t_s") + assert_not_contains(code_c, "typealias _t_s = S;") + + +# 23 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_vectorized_scalar_keeps_wrapper(device_type: spy.DeviceType): + """1D tensor → float: vectorized, keeps __slangpy_load.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + code, _ = generate_code_and_bindings( + device, "square", "float square(float x) { return x * x; }", tensor + ) + assert_contains(code, "__slangpy_load") + + +# 24 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_vectorized_dict_keeps_wrapper(device_type: spy.DeviceType): + """Dict with tensor children: vectorized, keeps __slangpy_load.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +void apply(S s, float scale) {} +""" + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + tensor_y = Tensor.from_numpy(device, np.array([4, 5, 6], dtype=np.float32)) + code, _ = generate_code_and_bindings( + device, + "apply", + src, + {"_type": "S", "x": tensor_x, "y": tensor_y}, + 1.0, + ) + assert_contains(code, "__slangpy_load") + + +# =========================================================================== +# Autodiff (25) +# =========================================================================== + + +# 25 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_bwds_direct_bind(device_type: spy.DeviceType): + """Backwards-mode: primals direct-bind, differentiable markers present. + + Merges: test_gate_bwds_scalar_uses_valuetype, test_gate_bwds_trampoline_is_differentiable, + test_bwds_primal_binding_flags. + """ + device = helpers.get_device(device_type) + src = """ +[Differentiable] +float polynomial(float a, float b) { + return a * a + b + 1; +} +""" + code, bindings = generate_bwds_code_and_bindings(device, "polynomial", src, 5.0, 10.0, 26.0) + + # No ValueType wrapper + assert_not_contains(code, "ValueType") + # Differentiable markers + assert_contains(code, "[Differentiable]", "bwd_diff(_trampoline)") + # [Differentiable] appears before trampoline + diff_idx = code.index("[Differentiable]") + trampoline_idx = code.index("void _trampoline") + assert diff_idx < trampoline_idx + + # Primal args direct-bind + assert bindings.args[0].direct_bind is True # a + assert bindings.args[1].direct_bind is True # b + + +# =========================================================================== +# Functional GPU dispatch — novel scenarios only (26–34) +# =========================================================================== + + +# 26 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_mixed_scalar_tensor(device_type: spy.DeviceType): + """Dispatch mixed scalar + tensor and verify GPU result.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, "add", "float add(float a, float b) { return a + b; }" + ) + tensor = Tensor.from_numpy(device, np.array([10, 20, 30], dtype=np.float32)) + result = func(5.0, tensor) + expected = np.array([15, 25, 35], dtype=np.float32) + np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) + + +# 27 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_struct_mixed_fields(device_type: spy.DeviceType): + """Dispatch struct with mixed tensor+scalar fields and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float x; + float y; +}; +float weighted_sum(S s, float scale) { return (s.x + s.y) * scale; } +""" + func = helpers.create_function_from_module(device, "weighted_sum", src) + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + result = func({"_type": "S", "x": tensor_x, "y": 10.0}, 2.0) + expected = np.array([22, 24, 26], dtype=np.float32) + np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) + + +# 28 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_tensor_dim0(device_type: spy.DeviceType): + """Dispatch whole tensor at dim-0 and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +float tensor_read(Tensor t) { + return t[0]; +} +""" + func = helpers.create_function_from_module(device, "tensor_read", src) + tensor = Tensor.from_numpy(device, np.array([42, 99, 7], dtype=np.float32)) + result = func(tensor) + assert abs(result - 42.0) < 1e-5 + + +# 29 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_2d_tensor_to_vector(device_type: spy.DeviceType): + """Dispatch 2D tensor → float3 and verify GPU result.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, "scale", "float3 scale(float3 v, float s) { return v * s; }" + ) + data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + tensor = Tensor.from_numpy(device, data) + result = func(tensor, 2.0) + expected = data * 2.0 + np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) + + +# 30 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_2d_tensor_to_array(device_type: spy.DeviceType): + """Dispatch 2D tensor → half[8] and verify GPU doubles each element.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, + "tensor_test_channels<8>", + r""" +half[NumChannels] tensor_test_channels(half[NumChannels] data) +{ + [ForceUnroll] + for (int i = 0; i < NumChannels; ++i) + { + data[i] = 2.h * data[i]; + } + return data; +} +""", + ).return_type(Tensor) + data = np.ones((4, 8), dtype=np.float16) + tensor = Tensor.from_numpy(device, data) + result = func(tensor) + expected = data * 2.0 + np.testing.assert_allclose( + result.to_numpy().reshape(expected.shape).astype(np.float32), + expected.astype(np.float32), + atol=1e-2, + ) + + +# 31 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_mixed_vectorized_dim0_tensor(device_type: spy.DeviceType): + """Dispatch vectorized float3 + dim-0 Tensor and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +float dot_lookup(float3 v, Tensor weights) { + return v.x * weights[0] + v.y * weights[1] + v.z * weights[2]; +} +""" + func = helpers.create_function_from_module(device, "dot_lookup", src) + vecs = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) + weights = np.array([10, 20, 30], dtype=np.float32) + result = func( + Tensor.from_numpy(device, vecs), + Tensor.from_numpy(device, weights), + ) + expected = np.array([10, 20, 30], dtype=np.float32) + np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) + + +# 32 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_nested_struct_with_tensor(device_type: spy.DeviceType): + """Dispatch nested struct with tensor leaf and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct Inner { + float x; + float y; +}; +struct Outer { + Inner inner; + float s; +}; +float compute(Outer o) { return (o.inner.x + o.inner.y) * o.s; } +""" + func = helpers.create_function_from_module(device, "compute", src) + tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + result = func( + {"_type": "Outer", "inner": {"_type": "Inner", "x": tensor_x, "y": 10.0}, "s": 2.0} + ) + expected = np.array([22, 24, 26], dtype=np.float32) + np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) + + +# 33 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_struct_vectorized_2d_child(device_type: spy.DeviceType): + """Dispatch struct with 2D tensor→float3 child and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct S { + float3 v; + float s; +}; +float3 apply(S st) { return st.v * st.s; } +""" + func = helpers.create_function_from_module(device, "apply", src) + data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + tensor_v = Tensor.from_numpy(device, data) + result = func({"_type": "S", "v": tensor_v, "s": 2.0}) + expected = data * 2.0 + np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) + + +# 34 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_struct_array_of_structs(device_type: spy.DeviceType): + """Dispatch struct with array-of-structs field and verify GPU result.""" + device = helpers.get_device(device_type) + src = """ +struct Inner { + int x; +}; +struct Outer { + Inner items[4]; +}; +int sum_inner(Outer outer) { + int s = 0; + for (int i = 0; i < 4; i++) { + s += outer.items[i].x; + } + return s; +} +""" + func = helpers.create_function_from_module(device, "sum_inner", src) + result = func( + { + "_type": "Outer", + "items": [ + {"_type": "Inner", "x": 10}, + {"_type": "Inner", "x": 20}, + {"_type": "Inner", "x": 30}, + {"_type": "Inner", "x": 40}, + ], + } + ) + assert result == 100 + + +if __name__ == "__main__": + pytest.main([__file__, "-vs"]) From 553795e1a8b32599171e484243221f001370d760 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Thu, 12 Mar 2026 14:41:10 +0000 Subject: [PATCH 17/41] gate tests --- .../plan-simplifyKernelGen-phase2.prompt.md | 400 ++++++++++++++---- .../tests/slangpy_tests/test_kernel_gen.py | 74 ++++ 2 files changed, 397 insertions(+), 77 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md index 5fa19c3d0..2f46c6e3a 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md @@ -1,121 +1,367 @@ ## Phase 2: Eliminate CallData Struct -**Goal**: When ALL arguments are direct-eligible, bypass the `CallData` struct entirely and pass arguments as individual parameters on the entry point (or individual globals). +**Goal**: Move kernel uniforms out of the `CallData` struct into individual entry-point parameters. Eliminate the trampoline in forward (prim) mode. Fall back to `ParameterBlock` when total inline-uniform size exceeds a runtime per-device threshold. **Parent plan**: [plan-simplifyKernelGen.prompt.md](plan-simplifyKernelGen.prompt.md) --- -### Step 2.1: Determine eligibility +### Key Architectural Decisions -Add a check in [slangpy/core/calldata.py](slangpy/core/calldata.py) after all bindings are resolved: if every `BoundVariable` satisfies `is_direct_bind_eligible` (for leaves) or `is_direct_bind_recursive` (for composites) AND `call_data_len == 0` (no N-dimensional shape arrays needed), set a new flag `self.use_direct_args = True`. +These decisions correct several assumptions in the original plan: ---- +1. **Entry-point param placement is orthogonal to `direct_bind`.** Any type — wrapped or raw — can be an entry-point parameter (e.g., `uniform ValueType a` or `uniform int a` or `uniform Tensor t`). `direct_bind` governs whether `__slangpy_load`/`__slangpy_store` is needed inside the kernel; entry-point placement governs where the uniform lives in the shader layout. -### Step 2.2: New code generation path +2. **Trampoline elimination is independent of `direct_bind`.** The current trampoline body is: declare locals → load (direct assignment or `__slangpy_load`) → call function → store (`__slangpy_store`). All of that can appear directly in `compute_main`. The trampoline only exists because bwds mode needs a `[Differentiable]` wrapper for `bwd_diff()`. In prim mode, it is eliminated regardless of whether args use wrappers. -In [slangpy/core/callsignature.py](slangpy/core/callsignature.py), when `use_direct_args`: +3. **All-or-nothing fallback.** When total inline-uniform size exceeds the platform threshold, ALL args go back into `ParameterBlock` (the current path). No hybrid mixing of entry-point params and CallData. -- **Skip CallData struct generation** entirely. Note: `CodeGen.__init__` in [codegen.py](slangpy/bindings/codegen.py) unconditionally emits `struct CallData { ... }` — the constructor creates the `self.call_data` block and `finish()` calls `self.call_data.end_block()`. To eliminate CallData, either: - - Add a `skip_call_data` flag to `CodeGen.__init__` that conditionally initializes the block, and condition the `end_block()` in `finish()` on the same flag, OR - - Clear `self.call_data` contents before `finish()` when `use_direct_args` is true -- **Generate compute_main** with individual `uniform` parameters. The current compute_main signature has three semantic params: - ``` - void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex, uniform CallData call_data) - ``` - When `use_direct_args` and `call_data_len == 0`, the `SV_GroupID` and `SV_GroupIndex` params are unused (they feed `init_thread_local_call_shape_info` which reads `call_data._grid_stride`/`_grid_dim`/`_call_dim`). They can be dropped, simplifying to: - ``` - void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, uniform uint3 _thread_count, uniform int a, uniform int b, uniform RWStructuredBuffer _result) - ``` -- **Inline the function call** into compute_main (skip trampoline for prim mode): `_result[0] = add(a, b);` -- **Keep trampoline** for bwds mode (needed for `bwd_diff()`). The trampoline wraps the call with `[Differentiable]` and allows `bwd_diff(_trampoline)` from compute_main. In this case, generate a trampoline that takes individual params instead of a struct. Direct assignment `a = param_a;` is trivially differentiable in Slang for floating-point types. For non-differentiable types (int, etc.), autodiff is irrelevant. +4. **Shape arrays and `_thread_count` obey the same rules** as user args — they become entry-point params by default, and go into `CallData` on fallback. Phase 2 is NOT scoped only to `call_data_len == 0`. + +5. **Two code paths based on where data lives:** + - **Fast path** (entry-point params): In Slang, uniforms are entry-point parameters and can be used directly (in forward) or passed directly to the trampoline (in backward). + - **Fallback path** (`ParameterBlock`): In Slang, uniforms live in a `CallData` struct. They must be read into local variables before being used (in forward) or passed to the trampoline (in backward). This is the current behavior. + +6. **C++ dispatch changes are isolated to `NativeCallData::exec`.** Marshalls receive a `ShaderCursor` pointing to wherever their data lives — they don't care whether it's inside a `CallData` struct or an entry-point param. In the fast path, `m_runtime->write_shader_cursor_pre_dispatch()` receives the entry-point cursor directly. No marshall code changes needed. + +7. **`CallDataMode` is eliminated.** The `global_data` vs `entry_point` distinction is removed entirely. On the fast path, all backends use entry-point params uniformly. On the fallback path, all backends use `ParameterBlock` — CUDA supports `ParameterBlock` and in practice will never hit the fallback (CUDA's inline-uniform limit is ~4KB). This removes the `CallDataMode` enum, the CUDA-specific `is_entry_point` codegen branch in `callsignature.py`, and the corresponding C++ branch in `slangpy.cpp`. + +8. **`PackedArg` / param-block types are unchanged.** They stay as `ParameterBlock` at module scope, orthogonal to Phase 2. --- -### Step 2.3: Entry point parameters for all backends +### Current Kernel Structure (post-Phase 1) -Currently, CUDA (entry_point mode) already passes a `CallData` struct as a `uniform` entry point parameter. The simplification extends this: instead of a single struct, pass individual `uniform` parameters on the entry point — for ALL backends, not just CUDA. +For `int add(int a, int b)` with scalar args `(1, 2)`: -See [slangpy/tests/device/test_pipeline_utils.slang](slangpy/tests/device/test_pipeline_utils.slang) for examples of manually-written compute shaders that use entry point parameters on all backends: ```slang +import "module"; +import "slangpy"; + +typealias _t_a = int; // Phase 1: raw type (was ValueType) +typealias _t__result = RWValueRef; // writable _result still wrapped +static const int _m__result = 0; // mapping constant only for _result + +struct CallData { + _t_a a; + _t_a b; + _t__result _result; + uint3 _thread_count; +}; + +void _trampoline(Context __slangpy_context__, CallData __calldata__) { + int a; + a = __calldata__.a; // Phase 1: direct assignment + int b; + b = __calldata__.b; // Phase 1: direct assignment + int _result; + _result = add(a, b); + __calldata__._result.__slangpy_store(__slangpy_context__.map(_m__result), _result); +} + +[shader("compute")] [numthreads(32,1,1)] +void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, ..., uniform CallData call_data) { + if (any(flat_call_thread_id >= call_data._thread_count)) return; + Context __slangpy_context__ = {flat_call_thread_id}; + _trampoline(__slangpy_context__, call_data); +} +``` + +### Target Kernel (Phase 2 fast path, prim mode, all direct-bind) + +```slang +import "module"; + [shader("compute")] -[numthreads(16, 16, 1)] -void setcolor( - uint3 tid: SV_DispatchThreadID, - RWTexture2D render_texture, - uniform int2 pos, - uniform int2 size, - uniform float4 color -) +[numthreads(32, 1, 1)] +void compute_main(int3 tid: SV_DispatchThreadID, + uniform uint3 _thread_count, + uniform int a, + uniform int b, + uniform RWStructuredBuffer _result) +{ + if (any(tid >= _thread_count)) return; + _result[0] = add(a, b); +} ``` -Entry point parameters work on all backends (CUDA, Vulkan, D3D12). For `global_data` mode, the C++ side currently navigates `cursor["call_data"]` to write into a `ParameterBlock` global. With direct args, it would instead navigate `cursor.find_entry_point(0)` and write each parameter by index — the same mechanism CUDA already uses, but now applied universally. +### Target Kernel (Phase 2 fast path, prim mode, mixed direct/non-direct-bind) -The `CallData` struct can be omitted entirely when all args are direct-eligible. If some args still need the struct (e.g., shape arrays for `call_data_len > 0`, or non-direct-eligible types), emit a hybrid: direct-eligible args as individual entry point params, and the remaining data in a `CallData` struct that is also an entry point param. +When some args are not direct-bind (e.g., WangHashArg needs per-thread `thread_id` via `__slangpy_load`), the non-direct-bind args still use their wrapper types as entry-point params. Context is needed: -**Entry point size limits**: Some platforms impose limits on the total size of entry point parameter data (e.g., CUDA root constants are limited to ~4KB, D3D12 root signature has a 64 DWORD limit). To handle this: -- Define a per-backend threshold for maximum entry point parameter data size (queryable from device/backend info) -- During code generation, accumulate the uniform byte size of each direct-eligible argument. Resource types (`RWStructuredBuffer`, `Texture2D`, etc.) don't count toward the limit — they are bound as descriptors, not inline data -- If a single argument exceeds the threshold, force it back to `CallData` -- If the cumulative total exceeds the threshold, force remaining arguments (in declaration order) back to `CallData` -- The result may be a hybrid kernel: some args as entry point params, the rest in a `CallData` struct entry point param -- The C++ dispatch side must know which args are direct vs CallData-bound (store a per-argument flag or a bitmask on `NativeCallData`) +```slang +import "module"; +import "slangpy"; + +typealias _t_rng = WangHashArgType; // non-direct-bind wrapper type +static const int _m_rng = 0; + +[shader("compute")] +[numthreads(32, 1, 1)] +void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, + uniform uint3 _thread_count, + uniform _t_rng rng, + uniform int x, + uniform RWStructuredBuffer _result) +{ + if (any(flat_call_thread_id >= _thread_count)) return; + Context __slangpy_context__ = {flat_call_thread_id}; + int _rng_val; + rng.__slangpy_load(__slangpy_context__.map(_m_rng), _rng_val); + int _x_val; + _x_val = x; + int _result_val; + _result_val = func(_rng_val, _x_val); + _result[0] = _result_val; +} +``` + +### Target Kernel (Phase 2 fallback path, prim mode) + +When entry-point param size exceeds the platform limit, all args go into `ParameterBlock`. The trampoline is still eliminated in prim mode — the load/call/store is inlined into `compute_main`, reading from `call_data`: + +```slang +import "module"; +import "slangpy"; + +typealias _t_a = int; +typealias _t__result = RWValueRef; +static const int _m__result = 0; + +struct CallData { + _t_a a; + _t_a b; + _t__result _result; + uint3 _thread_count; +}; +ParameterBlock call_data; + +[shader("compute")] +[numthreads(32, 1, 1)] +void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, ...) { + if (any(flat_call_thread_id >= call_data._thread_count)) return; + Context __slangpy_context__ = {flat_call_thread_id}; + int a; + a = call_data.a; + int b; + b = call_data.b; + int _result; + _result = add(a, b); + call_data._result.__slangpy_store(__slangpy_context__.map(_m__result), _result); +} +``` + +--- + +### Step 2.0: Gating tests + +Add tests to [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) asserting current behavior. These document the baseline and will intentionally break as steps are implemented. + +| Test | Source | Args | Asserts (current) | Breaks when | +|------|--------|------|--------------------|-------------| +| `test_gate_p2_calldata_struct_present` | `int add(int a, int b)` | `(1, 2)` | `struct CallData` in code | Step 2.2 | +| `test_gate_p2_calldata_uniform_param` | same | same | `uniform CallData call_data` in `compute_main` | Step 2.2 | +| `test_gate_p2_thread_count_in_calldata` | same | same | `call_data._thread_count` | Step 2.2 | +| `test_gate_p2_trampoline_present_for_prim` | same | same | `void _trampoline(` present | Step 2.3 | +| `test_gate_p2_kernel_calls_trampoline` | same | same | `_trampoline(` in `compute_main` body | Step 2.3 | +| `test_gate_p2_sv_group_id_present` | same | same | `SV_GroupID` in `compute_main` signature | Step 2.2 | + +Negative gates (must stay passing after Phase 2): + +| Test | Asserts | +|------|---------| +| `test_gate_p2_wanghasharg_keeps_load` | Non-direct-bind arg still uses `__slangpy_load` | --- -### Step 2.4: C++ dispatch changes +### Step 2.1: Determine fast vs fallback path + +In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct_binding(bindings)`: + +1. **Query a runtime per-device threshold** for max entry-point parameter inline-uniform size. This is a property of the device/backend — large for D3D12/CUDA (thousands of bytes), potentially as low as 128–256 bytes on Vulkan. +2. **Accumulate inline-uniform byte size** of each bound variable's `calldata_type_name`, plus `_thread_count` (12 bytes) and shape arrays (`call_data_len * 3 * sizeof(int)` for `_grid_stride`, `_grid_dim`, `_call_dim`). **Resource types** (`RWStructuredBuffer`, `Texture2D`, `TensorView`, etc.) don't count — they are bound as descriptors, not inline data. +3. **Decision**: If total size ≤ threshold → `self.use_direct_args = True` (fast path). Otherwise → `self.use_direct_args = False` (fallback path — current behavior). +4. **Store** `use_direct_args` on the `CallData` instance and propagate to C++ `NativeCallData`. + +`PackedArg` / param-block types are excluded from this accounting — they stay as `ParameterBlock` regardless. + +--- + +### Step 2.2: Code generation — entry-point params (fast path) + +In [slangpy/core/callsignature.py](slangpy/core/callsignature.py) `generate_code()`, when `use_direct_args == True`: + +**CodeGen changes** in [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py): +- Add a `skip_call_data` flag to `CodeGen.__init__`. When `True`, don't emit `struct CallData` / `begin_block()` and gate `end_block()` in `finish()`. +- Add `self.entry_point_params: list[str] = []` to collect individual uniform param declarations. +- `finish()` ignores the `call_data` block and `use_param_block_for_call_data` when `skip_call_data` is set. + +**CallData struct elimination**: Set `cg.skip_call_data = True` when `use_direct_args`. No `struct CallData` emitted. -In [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp): +**`gen_call_data_code` change** in [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py): At `depth == 0`, when `use_direct_args`, append to `cg.entry_point_params` instead of `cg.call_data.declare(...)`. The `call_data_structs` block (type aliases, wrapper structs, mapping constants) still gets emitted at module scope. -- **Store `use_direct_args` flag** on `NativeCallData` (receive from Python `CallData`) -- **Both modes**: In `bind_call_data`, navigate via `cursor.find_entry_point(0)` and write each argument directly to its own entry point parameter by index. This is the same cursor API already used for CUDA entry_point mode — it just needs to write individual params instead of navigating into a single `CallData` struct field. -- **Thread count**: Write `_thread_count` as a separate entry point parameter instead of a struct field -- **Context construction**: The current kernel code constructs a `Context __slangpy_context__` from `call_data` fields (e.g., `flat_call_thread_id, CallShapeInfo::get_call_id().shape`). When `use_direct_args` and `call_data_len == 0`, the Context is simplified to just `{flat_call_thread_id}` and `CallShapeInfo` / `init_thread_local_call_shape_info` can be skipped. If Context is eliminated entirely (Phase 2 with inlined function calls), this becomes moot. -- **Skip shape array writing** (`_grid_stride`, `_grid_dim`, `_call_dim`) since `call_data_len == 0` -- **Cache parameter offsets**: Cache the entry point parameter indices at first dispatch (similar to existing `m_cached_call_data_offsets`) +**`_thread_count` and shape arrays**: Instead of `cg.call_data.append_statement("uint3 _thread_count")`, append to `cg.entry_point_params`. Same for `_grid_stride`, `_grid_dim`, `_call_dim` when `call_data_len > 0`. + +**Entry-point signature**: `compute_main` signature becomes: +```slang +void compute_main( + int3 flat_call_thread_id: SV_DispatchThreadID, + [int3 flat_call_group_id: SV_GroupID,] // only when call_data_len > 0 + [int flat_call_group_thread_id: SV_GroupIndex,] // only when call_data_len > 0 + uniform uint3 _thread_count, + [uniform int[N] _grid_stride, ...] // only when call_data_len > 0 + uniform _t_a a, + uniform _t_b b, + uniform _t__result _result +) +``` + +Drop `SV_GroupID` and `SV_GroupIndex` when `call_data_len == 0` — they feed `init_thread_local_call_shape_info` which isn't called when there are no shape arrays. + +**Bounds check**: Changes from `call_data._thread_count` to just `_thread_count`. + +**Shape info init**: Changes from `call_data._grid_stride` etc. to just `_grid_stride`, `_grid_dim`, `_call_dim`. + +**Fallback path** (`use_direct_args == False`): `struct CallData` is emitted with `ParameterBlock call_data` at module scope on ALL backends (including CUDA). The old `CallDataMode` distinction between `entry_point` (CUDA) and `global_data` (non-CUDA) is removed — `ParameterBlock` works on CUDA, and in practice CUDA will never hit the fallback due to its large (~4KB) inline-uniform limit. + +See [slangpy/tests/device/test_pipeline_utils.slang](slangpy/tests/device/test_pipeline_utils.slang) for examples of manually-written compute shaders that use entry point parameters on all backends (CUDA, Vulkan, D3D12). --- -### Step 2.5: Trampoline elimination for prim mode +### Step 2.3: Trampoline elimination for prim mode + +When `call_mode == prim` — on **both** fast and fallback paths: + +- Don't generate the `_trampoline` function. +- Inline the load/call/store sequence directly into `compute_main` after the bounds check and (if needed) Context construction. +- The load/call/store codegen reuses the same logic currently in [callsignature.py lines 378–449](slangpy/core/callsignature.py#L378-L449), but emitted into `cg.kernel` instead of `cg.trampoline` with adjusted `data_name`: + +| Path | `data_name` for non-param-block args | +|------|-------------------------------------| +| Fast | `x.variable_name` (entry-point param name directly) | +| Fallback | `call_data.{x.variable_name}` (global `ParameterBlock`, all backends) | +| Param blocks | `_param_{x.variable_name}` (unchanged) | -When `use_direct_args` and `call_mode == prim`: -- Don't generate a trampoline function -- Emit the function call directly in `compute_main` using the uniform parameter names -- For output variables, emit the store directly (e.g., `_result[0] = add(a, b);`) +**Context construction**: Needed only when any arg is non-direct-bind (i.e., calls `__slangpy_load`/`__slangpy_store`). When all args satisfy `direct_bind == True`, skip Context construction entirely — no `Context __slangpy_context__` declaration, no `import "slangpy"`. + +**Note**: The trampoline elimination does NOT depend on `direct_bind`. Even non-direct-bind args with `__slangpy_load` work inline in `compute_main` — the `__slangpy_load` call just needs the data reference and a `Context` value, both available in `compute_main`. + +--- + +### Step 2.4: Trampoline with individual params for bwds mode When `call_mode == bwds`: -- Still generate a trampoline (needed for `bwd_diff()`) -- Pass individual params to the trampoline instead of a struct + +- Still generate a `[Differentiable]` trampoline function. +- **Fast path**: Trampoline takes individual params instead of a struct. Use `_gen_trampoline_argument()` from [boundvariable.py](slangpy/bindings/boundvariable.py#L691) (currently dead code) to generate the signature — it already handles `in`/`out`/`inout` and `no_diff` annotations: + ```slang + [Differentiable] + void _trampoline(Context __slangpy_context__, no_diff in int a, no_diff in int b, ...) + ``` + `compute_main` calls `bwd_diff(_trampoline)(__slangpy_context__, a, b, _result)` passing entry-point param names directly. +- **Fallback path**: Trampoline reads from global `ParameterBlock call_data` as it does today (on all backends). `compute_main` calls `bwd_diff(_trampoline)(__slangpy_context__, call_data)`. +- Non-differentiable arguments (int, bool, etc.) get `no_diff` prefix automatically via `_gen_trampoline_argument()`. This may need to be added to additional integer or non-differentiable trampoline arguments to make the generated shader compile under Slang's autodiff rules. --- -### Step 2.6: Tests +### Step 2.5: C++ dispatch changes -**Gating tests** — assert CURRENT behavior so they break when Phase 2 is implemented: +In [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp), store `m_use_direct_args` on `NativeCallData` (received from Python `CallData`). Also add to [slangpy.h](src/slangpy_ext/utils/slangpy.h). + +Modify `bind_call_data` lambda in `exec()`: + +**Fast path** (`m_use_direct_args == true`): +- All backends: Navigate via `cursor.find_entry_point(0)`. This is the entry-point cursor. +- Write `_thread_count` as an entry-point param: `entry_point_cursor["_thread_count"]`. +- Write shape arrays as entry-point params: `entry_point_cursor["_grid_stride"]`, etc. +- Pass `entry_point_cursor` as the `call_data_cursor` argument to `m_runtime->write_shader_cursor_pre_dispatch()`. Each `NativeBoundVariableRuntime` already navigates `cursor[m_variable_name]`, so it finds the entry-point param by name automatically. **No marshall code changes needed.** +- Cache entry-point param field indices on first call (analogous to existing `m_cached_call_data_offsets`). +- The `reserve_data` + raw-pointer optimization for `_thread_count` and shape arrays may not work for individual entry-point params at disjoint offsets. Use cursor-based writes for these metadata fields (they're small, performance impact minimal), or check if `reserve_data` still works across the entry-point shader object. + +**Fallback path** (`m_use_direct_args == false`): +- All backends: Navigate to global `call_data` field via `cursor.find_field("call_data")`, dereference (it's a `ParameterBlock`), write struct data. The old `CallDataMode` branch (CUDA using `find_entry_point(0)` for call_data) is removed. Remove `m_call_data_mode`, `CallDataMode` enum, and all associated branches from `slangpy.h`, `slangpy.cpp`, `calldata.py`, and `callsignature.py`. + +--- -| Test | Slang Source | Args | Asserts (current behavior) | Breaks when | -|------|-------------|------|---------------------------|-------------| -| `test_gate_calldata_struct_present` | `int add(int a, int b) { return a + b; }` | `(1, 2)` | `struct CallData` present in generated code | Step 2.1 | -| `test_gate_calldata_uniform_param` | same | same | `uniform CallData call_data` in `compute_main` signature (note: actual signature also includes `SV_GroupID` and `SV_GroupIndex` params) | Step 2.2 | -| `test_gate_thread_count_in_calldata` | same | same | `call_data._thread_count` in kernel body | Step 2.4 | -| `test_gate_context_from_calldata` | same | same | `Context __slangpy_context__` construction present in kernel body | Step 2.4 | -| `test_gate_trampoline_present_for_prim` | same | same | `void _trampoline(` present | Step 2.5 | -| `test_gate_trampoline_calls_function` | same | same | `_result = add(a, b)` inside trampoline | Step 2.5 | -| `test_gate_kernel_calls_trampoline` | same | same | `_trampoline(` inside `compute_main` body | Step 2.5 | +### Step 2.6: `_result` handling -**Negative gates** — should REMAIN passing after Phase 2: +Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligible (needs `RWValueRef` wrapper with buffer logic). Phase 2 handles this differently on the two paths: -| Test | Slang Source | Args | Asserts (must stay) | -|------|-------------|------|--------------------| -| `test_gate_wanghasharg_forces_calldata` | `int rng(WangHashArg rng, int x) { return x; }` | `(spy.WangHashArg(1), 1)` | `struct CallData` present (non-eligible arg forces fallback) | +**Fast path**: `_result` is emitted as `uniform RWValueRef _result` on the entry point. In prim mode, the inlined code stores via `_result.__slangpy_store(...)`. In the all-direct-bind case where Context is omitted, add a new code path: emit `uniform RWStructuredBuffer _result` with `_result[0] = value` for the store. This requires `ValueRefMarshall` to support writable direct-bind for the entry-point-param case specifically, using `RWStructuredBuffer` instead of `RWValueRef`. + +**Fallback path**: `_result` stays as `RWValueRef` inside `CallData`, same as current behavior. + +**Implementation note**: The `RWStructuredBuffer` approach for `_result` is only used when `use_direct_args == True` AND all other args are direct-bind (so Context can be omitted). When non-direct-bind args are present, Context exists and `_result` can continue to use `RWValueRef.__slangpy_store(context, value)`. + +--- + +### Step 2.7: Tests **Post-implementation tests** — should pass AFTER Phase 2 is complete: -- `test_phase2_no_calldata_struct`: verify `struct CallData` absent for all-eligible scalar call -- `test_phase2_uniform_params_on_entry`: verify individual `uniform int a`, `uniform int b` on `compute_main` -- `test_phase2_no_trampoline_prim`: verify no `_trampoline(` for prim-mode eligible calls -- `test_phase2_thread_count_as_uniform`: verify `uniform uint3 _thread_count` as entry point param -- `test_phase2_inline_function_call`: verify `_result[0] = add(a, b)` directly in kernel -- `test_phase2_bwds_keeps_trampoline`: verify bwds mode still has `_trampoline` and `bwd_diff` -- `test_phase2_mixed_args_hybrid`: mix direct-eligible + WangHashArg → hybrid kernel -- `test_phase2_functional_all_backends`: dispatch scalar add on each backend, verify result +| Test | Verifies | +|------|----------| +| `test_phase2_no_calldata_struct` | `struct CallData` absent for eligible call | +| `test_phase2_uniform_params_on_entry` | Individual `uniform` params on `compute_main` | +| `test_phase2_no_trampoline_prim` | No `void _trampoline(` for prim-mode calls | +| `test_phase2_inline_call` | Function call inlined directly in `compute_main` | +| `test_phase2_thread_count_as_uniform` | `uniform uint3 _thread_count` as entry-point param | +| `test_phase2_no_context_all_direct` | No `Context __slangpy_context__` when all args direct-bind | +| `test_phase2_context_kept_non_direct` | `Context` present when some args use `__slangpy_load` | +| `test_phase2_bwds_trampoline_individual` | Bwds trampoline has individual params with `no_diff` | +| `test_phase2_bwds_bwd_diff_call` | `bwd_diff(_trampoline)(ctx, a, b, ...)` in kernel | +| `test_phase2_no_sv_group_when_dim0` | No `SV_GroupID`/`SV_GroupIndex` when `call_data_len == 0` | +| `test_phase2_sv_group_when_vectorized` | `SV_GroupID`/`SV_GroupIndex` present when `call_data_len > 0` | +| `test_phase2_fallback_keeps_calldata` | Force fallback → `struct CallData` still emitted | +| `test_phase2_fallback_no_trampoline_prim` | Even fallback path eliminates trampoline in prim mode | +| `test_phase2_functional_scalar_add` | `add(1, 2) == 3` end-to-end dispatch | +| `test_phase2_functional_bwds` | Backward pass correct gradients | +| `test_phase2_functional_vectorized` | Vectorized call (shapes) with entry-point params | +| `test_phase2_functional_mixed_direct` | Mix of direct-bind + non-direct-bind args | + +--- + +### Implementation Order + +1. **Step 2.0** — Gating tests (baseline documentation) +2. **Step 2.3** — Trampoline elimination for prim mode (both paths). This is independent of entry-point param work and provides immediate value. +3. **Step 2.1** — Fast/fallback determination + size query +4. **Step 2.2 + 2.5** — Code gen + C++ dispatch for entry-point params (must land together — Slang layout and C++ cursor navigation must agree) +5. **Step 2.4** — Bwds trampoline with individual params (fast path) +6. **Step 2.6** — `_result` as `RWStructuredBuffer` for all-direct-bind case +7. **Step 2.7** — Post-implementation tests + functional tests + +Steps 2.3 (trampoline) and 2.2/2.5 (entry-point params) are independent axes and can be done in either order. Starting with 2.3 is recommended because it's simpler and touches fewer files. + +--- + +### Key Files + +| File | Changes | +|------|---------| +| [slangpy/core/calldata.py](slangpy/core/calldata.py) | `use_direct_args` flag, size threshold check, remove `CallDataMode` usage | +| [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | `generate_code()` — inline load/call/store, entry-point params, Context gating, remove `is_entry_point` branch | +| [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py) | `skip_call_data` flag, `entry_point_params` list | +| [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | `gen_call_data_code` depth-0 entry-point path; `_gen_trampoline_argument()` usage | +| [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | `bind_call_data` fast path via `find_entry_point(0)`, remove `CallDataMode` branches | +| [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | `m_use_direct_args` on `NativeCallData`, remove `m_call_data_mode` | +| [src/sgl/utils/slangpy.h](src/sgl/utils/slangpy.h) | Remove `CallDataMode` enum definition | +| [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) | Gating + post-implementation tests | + +--- + +### Verification + +```bash +# Build first (required) +cmake --build --preset windows-msvc-debug + +# Run kernel gen tests +$env:PRINT_TEST_KERNEL_GEN="1"; pytest slangpy/tests/slangpy_tests/test_kernel_gen.py -v + +# Run full test suite +pytest slangpy/tests -v + +# Run pre-commit +pre-commit run --all-files +``` diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index f245b214e..1475b147e 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -1836,5 +1836,79 @@ def test_phase1_functional_long_struct_name(device_type: spy.DeviceType): assert abs(result - 10.0) < 1e-5 +# =========================================================================== +# Phase 2 gating tests — assert CURRENT behaviour, will break as Phase 2 +# steps are implemented. See plan-simplifyKernelGen-phase2.prompt.md +# =========================================================================== + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_p2_calldata_struct_present(device_type: spy.DeviceType): + """struct CallData is emitted for simple scalar call. Breaks at Step 2.2.""" + device = helpers.get_device(device_type) + code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) + assert_contains(code, "struct CallData") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_p2_calldata_uniform_param(device_type: spy.DeviceType): + """CallData is passed to kernel via uniform param (CUDA) or ParameterBlock (others). Breaks at Step 2.2.""" + device = helpers.get_device(device_type) + code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) + # CUDA uses entry-point param; D3D12/Vulkan use ParameterBlock at module scope + has_uniform = "uniform CallData call_data" in code + has_param_block = "ParameterBlock call_data" in code + assert ( + has_uniform or has_param_block + ), "Expected 'uniform CallData call_data' or 'ParameterBlock call_data'" + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_p2_thread_count_in_calldata(device_type: spy.DeviceType): + """_thread_count accessed via call_data. prefix. Breaks at Step 2.2.""" + device = helpers.get_device(device_type) + code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) + assert_contains(code, "call_data._thread_count") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_p2_trampoline_present_for_prim(device_type: spy.DeviceType): + """Prim-mode kernel has a _trampoline function. Breaks at Step 2.3.""" + device = helpers.get_device(device_type) + code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) + assert_contains(code, "void _trampoline(") + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_p2_kernel_calls_trampoline(device_type: spy.DeviceType): + """compute_main calls _trampoline(). Breaks at Step 2.3.""" + device = helpers.get_device(device_type) + code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) + # Extract compute_main body and check it calls _trampoline + main_idx = code.index("void compute_main(") + main_body = code[main_idx:] + assert "_trampoline(__slangpy_context__" in main_body + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_p2_sv_group_id_present(device_type: spy.DeviceType): + """SV_GroupID present in compute_main signature even for dim-0. Breaks at Step 2.2.""" + device = helpers.get_device(device_type) + code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) + assert_contains(code, "SV_GroupID") + + +# -- Phase 2 negative gate — must REMAIN passing after Phase 2 -- + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_gate_p2_wanghasharg_keeps_load(device_type: spy.DeviceType): + """Non-direct-bind WangHashArg still uses __slangpy_load after Phase 2.""" + device = helpers.get_device(device_type) + src = "uint3 rng(uint3 input) { return input; }" + code = generate_code(device, "rng", src, WangHashArg(3)) + assert_contains(code, "__slangpy_load") + + if __name__ == "__main__": pytest.main([__file__, "-vs"]) From 6d8739102b0132e21dcf70e012e351db8f3aad43 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Thu, 12 Mar 2026 15:14:50 +0000 Subject: [PATCH 18/41] First version of reading uniform size --- .../plan-simplifyKernelGen-phase2.prompt.md | 10 +- slangpy/core/calldata.py | 11 ++ slangpy/core/callsignature.py | 33 +++++ .../tests/slangpy_tests/test_kernel_gen.py | 121 ++++++++++++++++++ src/sgl/device/device.cpp | 21 +++ src/sgl/device/device.h | 6 + src/slangpy_ext/device/device.cpp | 5 + src/slangpy_ext/py_doc.h | 2 + src/slangpy_ext/utils/slangpy.cpp | 7 + src/slangpy_ext/utils/slangpy.h | 7 + 10 files changed, 220 insertions(+), 3 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md index 2f46c6e3a..f2e280a38 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md @@ -153,14 +153,16 @@ void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, ...) { --- -### Step 2.0: Gating tests +### Step 2.0: Gating tests ✅ -Add tests to [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) asserting current behavior. These document the baseline and will intentionally break as steps are implemented. +**Status: DONE** + +Tests added to [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py). All 21 parametrized cases (7 tests × 3 device types) pass. | Test | Source | Args | Asserts (current) | Breaks when | |------|--------|------|--------------------|-------------| | `test_gate_p2_calldata_struct_present` | `int add(int a, int b)` | `(1, 2)` | `struct CallData` in code | Step 2.2 | -| `test_gate_p2_calldata_uniform_param` | same | same | `uniform CallData call_data` in `compute_main` | Step 2.2 | +| `test_gate_p2_calldata_uniform_param` | same | same | `uniform CallData call_data` (CUDA) or `ParameterBlock call_data` (D3D12/Vulkan) | Step 2.2 | | `test_gate_p2_thread_count_in_calldata` | same | same | `call_data._thread_count` | Step 2.2 | | `test_gate_p2_trampoline_present_for_prim` | same | same | `void _trampoline(` present | Step 2.3 | | `test_gate_p2_kernel_calls_trampoline` | same | same | `_trampoline(` in `compute_main` body | Step 2.3 | @@ -172,6 +174,8 @@ Negative gates (must stay passing after Phase 2): |------|---------| | `test_gate_p2_wanghasharg_keeps_load` | Non-direct-bind arg still uses `__slangpy_load` | +**Note:** `test_gate_p2_calldata_uniform_param` checks for either `uniform CallData call_data` (CUDA entry-point param) or `ParameterBlock call_data` (D3D12/Vulkan module-scope), since the current `CallDataMode` distinction means different backends emit different patterns. + --- ### Step 2.1: Determine fast vs fallback path diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index 42d016748..3a61fe42f 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -269,6 +269,17 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): # Calculate direct binding eligibility for all variables. calculate_direct_binding(bindings) + # Determine fast path (entry-point params) vs fallback (ParameterBlock). + # Sum inline-uniform byte size and compare against per-device threshold. + inline_size = calculate_inline_uniform_size(bindings, self.call_dimensionality) + threshold = build_info.module.device.info.limits.max_entry_point_uniform_size + self.use_direct_args = inline_size <= threshold + self.log_debug( + f" Inline uniform size: {inline_size} bytes, " + f"threshold: {threshold} bytes, " + f"use_direct_args: {self.use_direct_args}" + ) + # Generate code. codegen = CodeGen() generate_code(context, build_info, bindings, codegen) diff --git a/slangpy/core/callsignature.py b/slangpy/core/callsignature.py index 92093d573..d02f22e2b 100644 --- a/slangpy/core/callsignature.py +++ b/slangpy/core/callsignature.py @@ -162,6 +162,39 @@ def calculate_direct_binding(call: BoundCall): arg.calculate_direct_bind() +def calculate_inline_uniform_size(call: BoundCall, call_dimensionality: int) -> int: + """ + Calculate the total inline-uniform byte size for all depth-0 bound variables, + plus metadata fields (_thread_count, shape arrays). + + Resource types (StructuredBuffer, Texture2D, etc.) contribute 0 bytes to inline + uniform size since they are bound as descriptors. PackedArg / ParameterBlock + types are excluded from this accounting since they stay as ParameterBlock. + + :param call: The bound call containing all args/kwargs. + :param call_dimensionality: The call dimensionality (determines shape array count). + :return: Total inline-uniform size in bytes. + """ + total = 0 + + for node in call.values(): + # PackedArg types use ParameterBlock — excluded from inline accounting + if node.create_param_block: + continue + if node.vector_type is not None: + total += node.vector_type.uniform_layout.size + # If vector_type is None (shouldn't happen after binding), skip safely + + # _thread_count: uint3 = 12 bytes + total += 12 + + # Shape arrays: _grid_stride, _grid_dim, _call_dim — each is int[call_dimensionality] + if call_dimensionality > 0: + total += call_dimensionality * 4 * 3 # 3 arrays × N × sizeof(int) + + return total + + def calculate_call_dimensionality(signature: BoundCall) -> int: """ Calculate the dimensionality of the call diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index 1475b147e..dd613558a 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -1910,5 +1910,126 @@ def test_gate_p2_wanghasharg_keeps_load(device_type: spy.DeviceType): assert_contains(code, "__slangpy_load") +# =========================================================================== +# Step 2.1 tests — fast vs fallback path determination +# =========================================================================== + + +def build_call_data( + device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any +) -> Any: + """Build CallData and return the full CallData object.""" + func = helpers.create_function_from_module(device, func_name, module_source) + return func.debug_build_call_data(*args, **kwargs) + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_step21_scalar_uses_direct_args(device_type: spy.DeviceType): + """Simple scalar call has small inline-uniform size → use_direct_args=True.""" + device = helpers.get_device(device_type) + cd = build_call_data(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) + # Two ints (4+4) + RWValueRef for _result (descriptor, ~0 inline) + uint3 _thread_count (12) + # Should be well under any backend's threshold + assert cd.use_direct_args is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_step21_threshold_property_positive(device_type: spy.DeviceType): + """Device has a positive max_entry_point_uniform_size threshold.""" + device = helpers.get_device(device_type) + threshold = device.info.limits.max_entry_point_uniform_size + assert threshold > 0 + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_step21_vector_uses_direct_args(device_type: spy.DeviceType): + """float3 args are small enough for direct args.""" + device = helpers.get_device(device_type) + cd = build_call_data( + device, + "scale", + "float3 scale(float3 v, float s) { return v * s; }", + spy.math.float3(1, 2, 3), + 2.0, + ) + assert cd.use_direct_args is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_step21_struct_uses_direct_args(device_type: spy.DeviceType): + """All-scalar struct dict has small inline-uniform size.""" + device = helpers.get_device(device_type) + src = """ +struct S { float x; float y; }; +float sum(S s) { return s.x + s.y; } +""" + cd = build_call_data(device, "sum", src, {"_type": "S", "x": 1.0, "y": 2.0}) + assert cd.use_direct_args is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_step21_tensor_uses_direct_args(device_type: spy.DeviceType): + """Tensor args contribute descriptor-only (0 inline bytes) → direct args.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.array([1.0, 2.0, 3.0], dtype=np.float32)) + cd = build_call_data( + device, + "sum_all", + "float sum_all(float x) { return x; }", + tensor, + ) + assert cd.use_direct_args is True + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_step21_many_float4x4_may_exceed_vulkan(device_type: spy.DeviceType): + """Many float4x4 params may exceed Vulkan's 128-byte threshold. + + 8 × float4x4 = 8 × 64 bytes = 512 bytes inline + 12 bytes _thread_count = 524 bytes. + This exceeds Vulkan (128) and D3D12 (256) but not CUDA (4096). + """ + device = helpers.get_device(device_type) + src = """ +float4x4 sum8(float4x4 a, float4x4 b, float4x4 c, float4x4 d, + float4x4 e, float4x4 f, float4x4 g, float4x4 h) { + return a + b + c + d + e + f + g + h; +} +""" + identity = spy.math.float4x4.identity() + cd = build_call_data( + device, + "sum8", + src, + identity, + identity, + identity, + identity, + identity, + identity, + identity, + identity, + ) + threshold = device.info.limits.max_entry_point_uniform_size + if threshold >= 524: + assert cd.use_direct_args is True + else: + assert cd.use_direct_args is False + + +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_step21_wanghasharg_uses_direct_args(device_type: spy.DeviceType): + """WangHashArg (non-direct-bind) still counts its inline-uniform size. + Its wrapper type has a small inline footprint, so use_direct_args should be True. + """ + device = helpers.get_device(device_type) + cd = build_call_data( + device, + "rng", + "uint3 rng(uint3 input) { return input; }", + WangHashArg(3), + ) + assert cd.use_direct_args is True + + if __name__ == "__main__": pytest.main([__file__, "-vs"]) diff --git a/src/sgl/device/device.cpp b/src/sgl/device/device.cpp index a9baa911a..4692d406e 100644 --- a/src/sgl/device/device.cpp +++ b/src/sgl/device/device.cpp @@ -274,6 +274,27 @@ Device::Device(const DeviceDesc& desc) ); m_info.limits.max_shader_visible_samplers = rhi_device_info.limits.maxShaderVisibleSamplers; + // Set conservative default for max entry-point uniform (push constant / root constant) size. + // The RHI doesn't expose this directly, so we use per-backend defaults. + switch (m_desc.type) { + case DeviceType::vulkan: + // Vulkan spec minimum maxPushConstantsSize is 128 bytes. + m_info.limits.max_entry_point_uniform_size = 128; + break; + case DeviceType::d3d12: + // D3D12 root signature allows 64 DWORDs (256 bytes) total for root constants, + // shared with root descriptors. Use a conservative 256. + m_info.limits.max_entry_point_uniform_size = 256; + break; + case DeviceType::cuda: + // CUDA kernel parameter block limit is typically 4KB. + m_info.limits.max_entry_point_uniform_size = 4096; + break; + default: + m_info.limits.max_entry_point_uniform_size = 128; + break; + } + // Get supported shader model. const std::vector> available_shader_models = { {ShaderModel::sm_6_7, "sm_6_7"}, diff --git a/src/sgl/device/device.h b/src/sgl/device/device.h index 1149cb235..2c0d77e2a 100644 --- a/src/sgl/device/device.h +++ b/src/sgl/device/device.h @@ -185,6 +185,12 @@ struct DeviceLimits { /// Maximum samplers visible in a shader stage. uint32_t max_shader_visible_samplers; + + /// Maximum size in bytes of inline-uniform data for entry-point parameters. + /// On Vulkan this corresponds to push constant size (minimum 128 bytes). + /// On D3D12 this corresponds to root constant space (~256 bytes). + /// On CUDA this corresponds to the kernel parameter block (~4096 bytes). + uint32_t max_entry_point_uniform_size; }; struct DeviceInfo { diff --git a/src/slangpy_ext/device/device.cpp b/src/slangpy_ext/device/device.cpp index 9f8a4dc65..dda7f4af3 100644 --- a/src/slangpy_ext/device/device.cpp +++ b/src/slangpy_ext/device/device.cpp @@ -306,6 +306,11 @@ SGL_PY_EXPORT(device_device) "max_shader_visible_samplers", &DeviceLimits::max_shader_visible_samplers, D(DeviceLimits, max_shader_visible_samplers) + ) + .def_ro( + "max_entry_point_uniform_size", + &DeviceLimits::max_entry_point_uniform_size, + D(DeviceLimits, max_entry_point_uniform_size) ); nb::class_(m, "DeviceInfo", D(DeviceInfo)) diff --git a/src/slangpy_ext/py_doc.h b/src/slangpy_ext/py_doc.h index 4028a5791..4c2e01dcc 100644 --- a/src/slangpy_ext/py_doc.h +++ b/src/slangpy_ext/py_doc.h @@ -2634,6 +2634,8 @@ static const char *__doc_sgl_DeviceLimits_max_compute_threads_per_group = R"doc( static const char *__doc_sgl_DeviceLimits_max_framebuffer_dimensions = R"doc(Maximum framebuffer dimensions.)doc"; +static const char *__doc_sgl_DeviceLimits_max_entry_point_uniform_size = R"doc(Maximum size in bytes of inline-uniform data for entry-point parameters.)doc"; + static const char *__doc_sgl_DeviceLimits_max_shader_visible_samplers = R"doc(Maximum samplers visible in a shader stage.)doc"; static const char *__doc_sgl_DeviceLimits_max_texture_dimension_1d = R"doc(Maximum dimension for 1D textures.)doc"; diff --git a/src/slangpy_ext/utils/slangpy.cpp b/src/slangpy_ext/utils/slangpy.cpp index 64714903c..0fc4893d3 100644 --- a/src/slangpy_ext/utils/slangpy.cpp +++ b/src/slangpy_ext/utils/slangpy.cpp @@ -1701,6 +1701,13 @@ SGL_PY_EXPORT(utils_slangpy) nb::arg(), D_NA(NativeCallData, has_thread_count) ) + .def_prop_rw( + "use_direct_args", + &NativeCallData::use_direct_args, + &NativeCallData::set_use_direct_args, + nb::arg(), + D_NA(NativeCallData, use_direct_args) + ) .def_prop_rw( "autograd_access_list", &NativeCallData::autograd_access_list, diff --git a/src/slangpy_ext/utils/slangpy.h b/src/slangpy_ext/utils/slangpy.h index a996c9da1..48dbff97e 100644 --- a/src/slangpy_ext/utils/slangpy.h +++ b/src/slangpy_ext/utils/slangpy.h @@ -788,6 +788,12 @@ class NativeCallData : Object { /// Set whether this call data expects a _thread_count kwarg. void set_has_thread_count(bool has_thread_count) { m_has_thread_count = has_thread_count; } + /// Get whether this call uses direct entry-point parameters (fast path). + bool use_direct_args() const { return m_use_direct_args; } + + /// Set whether this call uses direct entry-point parameters (fast path). + void set_use_direct_args(bool use_direct_args) { m_use_direct_args = use_direct_args; } + /// Get the autograd access list. /// This is a flat list of AutogradAccess values precomputed at build time. /// At dispatch time, find_torch_tensors steps through this list as it encounters tensors. @@ -919,6 +925,7 @@ class NativeCallData : Object { bool m_torch_autograd{false}; bool m_needs_unpack{true}; bool m_has_thread_count{false}; + bool m_use_direct_args{false}; std::vector m_autograd_access_list; ref m_bwds_call_data; mutable CallDataOffsets m_cached_call_data_offsets; From 86f45fcf256e91cad1a3aba71616f95ff840897b Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Thu, 12 Mar 2026 15:45:09 +0000 Subject: [PATCH 19/41] wip switching to entry point arguments --- .../plan-simplifyKernelGen-phase2.prompt.md | 38 ++- slangpy/bindings/boundvariable.py | 4 + slangpy/bindings/codegen.py | 16 +- slangpy/bindings/marshall.py | 7 +- slangpy/core/calldata.py | 16 +- slangpy/core/callsignature.py | 123 ++++++--- slangpy/core/dispatchdata.py | 10 +- slangpy/core/function.py | 1 - slangpy/core/packedarg.py | 5 +- .../tests/slangpy_tests/test_kernel_gen.py | 50 ++-- .../slangpy_tests/test_type_resolution.py | 1 - src/sgl/utils/slangpy.h | 10 - src/slangpy_ext/py_doc.h | 8 - src/slangpy_ext/utils/slangpy.cpp | 233 +++++++++++------- src/slangpy_ext/utils/slangpy.h | 11 +- 15 files changed, 310 insertions(+), 223 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md index f2e280a38..ba204861f 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md @@ -178,7 +178,9 @@ Negative gates (must stay passing after Phase 2): --- -### Step 2.1: Determine fast vs fallback path +### Step 2.1: Determine fast vs fallback path ✅ + +**Status: DONE** In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct_binding(bindings)`: @@ -189,6 +191,25 @@ In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct `PackedArg` / param-block types are excluded from this accounting — they stay as `ParameterBlock` regardless. +**Implementation details:** + +- `DeviceLimits.max_entry_point_uniform_size` added to C++ struct ([device.h](src/sgl/device/device.h)) with per-backend defaults: Vulkan=128, D3D12=256, CUDA=4096 bytes ([device.cpp](src/sgl/device/device.cpp)). +- `calculate_inline_uniform_size()` added to [callsignature.py](slangpy/core/callsignature.py) — sums `vector_type.uniform_layout.size` for each depth-0 bound variable (skipping `PackedArg`), plus 12 bytes for `_thread_count` and `call_dimensionality * 4 * 3` for shape arrays. +- `use_direct_args` property added to `NativeCallData` C++ class ([slangpy.h](src/slangpy_ext/utils/slangpy.h)) with Python binding. +- `CallData.__init__()` in [calldata.py](slangpy/core/calldata.py) sets `self.use_direct_args = inline_size <= threshold` after `calculate_direct_binding()`. + +**Tests** (7 tests × 3 device types = 21 parametrized cases, all pass): + +| Test | Asserts | +|------|---------| +| `test_step21_scalar_uses_direct_args` | Simple `int add(int,int)` with `(1,2)` → `use_direct_args=True` | +| `test_step21_threshold_property_positive` | `device.info.limits.max_entry_point_uniform_size > 0` | +| `test_step21_vector_uses_direct_args` | `float3` args → `use_direct_args=True` | +| `test_step21_struct_uses_direct_args` | All-scalar struct dict → `use_direct_args=True` | +| `test_step21_tensor_uses_direct_args` | Tensor (descriptor-only, 0 inline bytes) → `use_direct_args=True` | +| `test_step21_many_float4x4_may_exceed_vulkan` | 8×float4x4 (524 bytes) exceeds Vulkan/D3D12 thresholds, not CUDA | +| `test_step21_wanghasharg_uses_direct_args` | Non-direct-bind WangHashArg with small inline size → `use_direct_args=True` | + --- ### Step 2.2: Code generation — entry-point params (fast path) @@ -327,9 +348,9 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi ### Implementation Order -1. **Step 2.0** — Gating tests (baseline documentation) -2. **Step 2.3** — Trampoline elimination for prim mode (both paths). This is independent of entry-point param work and provides immediate value. -3. **Step 2.1** — Fast/fallback determination + size query +1. **Step 2.0** ✅ — Gating tests (baseline documentation) +2. **Step 2.1** ✅ — Fast/fallback determination + size query +3. **Step 2.3** — Trampoline elimination for prim mode (both paths). This is independent of entry-point param work and provides immediate value. 4. **Step 2.2 + 2.5** — Code gen + C++ dispatch for entry-point params (must land together — Slang layout and C++ cursor navigation must agree) 5. **Step 2.4** — Bwds trampoline with individual params (fast path) 6. **Step 2.6** — `_result` as `RWStructuredBuffer` for all-direct-bind case @@ -347,10 +368,13 @@ Steps 2.3 (trampoline) and 2.2/2.5 (entry-point params) are independent axes and | [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | `generate_code()` — inline load/call/store, entry-point params, Context gating, remove `is_entry_point` branch | | [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py) | `skip_call_data` flag, `entry_point_params` list | | [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | `gen_call_data_code` depth-0 entry-point path; `_gen_trampoline_argument()` usage | -| [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | `bind_call_data` fast path via `find_entry_point(0)`, remove `CallDataMode` branches | -| [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | `m_use_direct_args` on `NativeCallData`, remove `m_call_data_mode` | +| [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | ✅ `use_direct_args` binding; `bind_call_data` fast path via `find_entry_point(0)`, remove `CallDataMode` branches | +| [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | ✅ `m_use_direct_args` on `NativeCallData`; remove `m_call_data_mode` | +| [src/sgl/device/device.h](src/sgl/device/device.h) | ✅ `max_entry_point_uniform_size` on `DeviceLimits` | +| [src/sgl/device/device.cpp](src/sgl/device/device.cpp) | ✅ Per-backend defaults for `max_entry_point_uniform_size` | +| [src/slangpy_ext/device/device.cpp](src/slangpy_ext/device/device.cpp) | ✅ Python binding for `max_entry_point_uniform_size` | | [src/sgl/utils/slangpy.h](src/sgl/utils/slangpy.h) | Remove `CallDataMode` enum definition | -| [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) | Gating + post-implementation tests | +| [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) | ✅ Gating tests + Step 2.1 tests; post-implementation tests | --- diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index 9214ae607..04b88b602 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -685,6 +685,10 @@ def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): ), f"calldata_type_name not set for '{self.variable_name}'" if self.create_param_block: cg.add_parameter_block(self.calldata_type_name, "_param_" + self.variable_name) + elif cg.skip_call_data: + cg.entry_point_params.append( + f"uniform {self.calldata_type_name} {self.variable_name}" + ) else: cg.call_data.declare(self.calldata_type_name, self.variable_name) diff --git a/slangpy/bindings/codegen.py b/slangpy/bindings/codegen.py index 3ec6cb4ce..ebf8f7abb 100644 --- a/slangpy/bindings/codegen.py +++ b/slangpy/bindings/codegen.py @@ -129,6 +129,12 @@ def __init__(self): #: Additional parameter blocks self.parameter_blocks: list[str] = [] + #: When True, skip emitting struct CallData (fast path: entry-point params). + self.skip_call_data: bool = False + + #: Individual uniform entry-point parameter declarations (fast path). + self.entry_point_params: list[str] = [] + # legacy self.input_load_store = CodeGenBlock(self) @@ -193,10 +199,11 @@ def finish( Generate the final code for the kernel. """ - self.call_data.end_block() + if not self.skip_call_data: + self.call_data.end_block() - if use_param_block_for_call_data: - self.call_data.append_statement("ParameterBlock call_data") + if use_param_block_for_call_data: + self.call_data.append_statement("ParameterBlock call_data") all_code: list[str] = [] if header: @@ -214,9 +221,10 @@ def finish( if call_data_structs: all_code = all_code + self.call_data_structs.code all_code.append("\n") - if call_data: + if call_data and not self.skip_call_data: all_code = all_code + self.call_data.code all_code.append("\n") + if call_data: all_code = all_code + self.parameter_blocks all_code.append("\n") if snippets: diff --git a/slangpy/bindings/marshall.py b/slangpy/bindings/marshall.py index 0aa719a3f..3b7a9fe7d 100644 --- a/slangpy/bindings/marshall.py +++ b/slangpy/bindings/marshall.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any -from slangpy.core.native import CallMode, CallDataMode, NativeMarshall +from slangpy.core.native import CallMode, NativeMarshall from slangpy.bindings.codegen import CodeGenBlock @@ -24,7 +24,6 @@ def __init__( call_mode: CallMode, device_module: "SlangModule", options: dict[str, Any], - call_data_mode: CallDataMode, ): super().__init__() @@ -37,8 +36,8 @@ def __init__( #: Call mode (prim/bwds/fwds). self.call_mode = call_mode - #: Call data mode (global_data/entry_point). - self.call_data_mode = call_data_mode + #: Whether to use direct entry-point params (fast path) vs ParameterBlock (fallback). + self.use_direct_args = False #: SGL module. self.device_module = device_module diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index 3a61fe42f..fa854df65 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -9,7 +9,6 @@ from slangpy.core.logging import bound_call_table, bound_exception_info, mismatch_info from slangpy.core.native import ( CallMode, - CallDataMode, NativeCallData, unpack_args, unpack_kwargs, @@ -144,15 +143,6 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): self.layout = build_info.module.layout self.call_mode = build_info.call_mode - # Set call data mode based on device and pipeline type - if ( - build_info.module.device.info.type == DeviceType.cuda - and build_info.pipeline_type == PipelineType.compute - ): - self.call_data_mode = CallDataMode.entry_point - else: - self.call_data_mode = CallDataMode.global_data - # Unpack args (handles IThis wrappers) unpacked_args, args_had_unpack = unpack_args(*args) unpacked_kwargs, kwargs_had_unpack = unpack_kwargs(**kwargs) @@ -193,7 +183,6 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): self.call_mode, build_info.module.device_module, build_info.options, - self.call_data_mode, ) # Build the unbound signature from inputs @@ -280,6 +269,9 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): f"use_direct_args: {self.use_direct_args}" ) + # Propagate use_direct_args to context for code generation + context.use_direct_args = self.use_direct_args + # Generate code. codegen = CodeGen() generate_code(context, build_info, bindings, codegen) @@ -296,7 +288,7 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): snippets=True, call_data_structs=True, constants=True, - use_param_block_for_call_data=context.call_data_mode == CallDataMode.global_data, + use_param_block_for_call_data=not self.use_direct_args, ) # Optionally write the shader to a file for debugging. diff --git a/slangpy/core/callsignature.py b/slangpy/core/callsignature.py index d02f22e2b..a8a5a4a71 100644 --- a/slangpy/core/callsignature.py +++ b/slangpy/core/callsignature.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import TYPE_CHECKING, Any, Optional -from slangpy.core.native import AccessType, CallMode, CallDataMode +from slangpy.core.native import AccessType, CallMode from slangpy.core.function import PipelineType import slangpy.bindings.typeregistry as tr @@ -278,8 +278,8 @@ def generate_code( """ nodes: list[BoundVariable] = [] - # Check if we're using entry point call data mode - is_entry_point = context.call_data_mode == CallDataMode.entry_point + # Check if we're using direct entry-point params (fast path) + use_direct_args = context.use_direct_args # Generate the header cg.add_import("slangpy") @@ -368,22 +368,36 @@ def generate_code( cg.constants.dec_indent() cg.constants.append_statement("}") + # Set up code gen mode for direct args vs CallData struct + if use_direct_args: + cg.skip_call_data = True + # Generate call data inputs if vector call if call_data_len > 0: - # A group can be thought of as a "window" looking at a - # portion of the entire call shape. Grid here refers to the - # N dimensional call shape being broken up into some number of N - # dimensional "window"s / groups. - cg.call_data.append_statement(f"int[{call_data_len}] _grid_stride") - cg.call_data.append_statement(f"int[{call_data_len}] _grid_dim") - # We use the call shape dimensions to detect cases when the call shape - # and the call group shape are not aligned. When a thread's call id - # falls outside the call shape, we need it to return early. This is - # similar to the default linear case when the call shape size is not - # 32 thread aligned. - cg.call_data.append_statement(f"int[{call_data_len}] _call_dim") - - cg.call_data.append_statement(f"uint3 _thread_count") + if use_direct_args: + # Fast path: shape arrays as individual entry-point params + cg.entry_point_params.append(f"uniform int[{call_data_len}] _grid_stride") + cg.entry_point_params.append(f"uniform int[{call_data_len}] _grid_dim") + cg.entry_point_params.append(f"uniform int[{call_data_len}] _call_dim") + else: + # Fallback: shape arrays inside CallData struct + # A group can be thought of as a "window" looking at a + # portion of the entire call shape. Grid here refers to the + # N dimensional call shape being broken up into some number of N + # dimensional "window"s / groups. + cg.call_data.append_statement(f"int[{call_data_len}] _grid_stride") + cg.call_data.append_statement(f"int[{call_data_len}] _grid_dim") + # We use the call shape dimensions to detect cases when the call shape + # and the call group shape are not aligned. When a thread's call id + # falls outside the call shape, we need it to return early. This is + # similar to the default linear case when the call shape size is not + # 32 thread aligned. + cg.call_data.append_statement(f"int[{call_data_len}] _call_dim") + + if use_direct_args: + cg.entry_point_params.append("uniform uint3 _thread_count") + else: + cg.call_data.append_statement("uint3 _thread_count") # Generate call data definitions for all inputs to the kernel for node in signature.values(): @@ -397,12 +411,18 @@ def generate_code( if context.call_mode != CallMode.prim: cg.trampoline.append_line("[Differentiable]") - # For entry point mode, add CallData as a parameter to the trampoline function - if is_entry_point: - cg.trampoline.append_line( - f"void {trampoline_fn}(Context __slangpy_context__, CallData __calldata__)" - ) + if use_direct_args: + # Fast path: trampoline takes individual calldata-typed params. + # Use __in_ prefix for param names to avoid collision with local variable names. + trampoline_params = ["Context __slangpy_context__"] + for x in root_params: + if x.create_param_block: + continue # param blocks handled via _param_ at module scope + assert x.calldata_type_name is not None + trampoline_params.append(f"{x.calldata_type_name} __in_{x.variable_name}") + cg.trampoline.append_line(f"void {trampoline_fn}({', '.join(trampoline_params)})") else: + # Fallback: trampoline reads from global ParameterBlock call_data cg.trampoline.append_line(f"void {trampoline_fn}(Context __slangpy_context__)") cg.trampoline.begin_block() @@ -411,11 +431,9 @@ def generate_code( assert x.vector_type is not None cg.trampoline.declare(x.vector_type.full_name, x.variable_name) for x in root_params: - if is_entry_point: + if use_direct_args: data_name = ( - f"_param_{x.variable_name}" - if x.create_param_block - else f"__calldata__.{x.variable_name}" + f"_param_{x.variable_name}" if x.create_param_block else f"__in_{x.variable_name}" ) else: data_name = ( @@ -462,11 +480,11 @@ def generate_code( or x.access[0] == AccessType.readwrite or x.access[1] == AccessType.read ): - if is_entry_point: + if use_direct_args: data_name = ( f"_param_{x.variable_name}" if x.create_param_block - else f"__calldata__.{x.variable_name}" + else f"__in_{x.variable_name}" ) else: data_name = ( @@ -495,18 +513,26 @@ def generate_code( cg.kernel.append_line("[numthreads(32, 1, 1)]") # Note: While flat_call_thread_id is 3-dimensional, we consider it "flat" and 1-dimensional because of the # true call group shape of [x, 1, 1] and only use the first dimension for the call thread id. - if is_entry_point: - cg.kernel.append_line( - "void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex, uniform CallData call_data)" - ) + if use_direct_args: + # Fast path: build compute_main signature with individual entry-point params + sig_parts = ["int3 flat_call_thread_id: SV_DispatchThreadID"] + # Only include SV_GroupID/SV_GroupIndex when call_data_len > 0 + # (they feed init_thread_local_call_shape_info which isn't called otherwise) + if call_data_len > 0: + sig_parts.append("int3 flat_call_group_id: SV_GroupID") + sig_parts.append("int flat_call_group_thread_id: SV_GroupIndex") + sig_parts.extend(cg.entry_point_params) + cg.kernel.append_line(f"void compute_main({', '.join(sig_parts)})") else: + # Fallback: no uniform params (reads from global ParameterBlock) cg.kernel.append_line( "void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex)" ) elif build_info.pipeline_type == PipelineType.ray_tracing: cg.kernel.append_line('[shader("raygen")]') - if is_entry_point: - cg.kernel.append_line("void raygen_main(uniform CallData call_data)") + if use_direct_args: + sig_parts = list(cg.entry_point_params) + cg.kernel.append_line(f"void raygen_main({', '.join(sig_parts)})") else: cg.kernel.append_line("void raygen_main()") else: @@ -517,7 +543,13 @@ def generate_code( if build_info.pipeline_type == PipelineType.ray_tracing: cg.kernel.append_statement("int3 flat_call_thread_id = DispatchRaysIndex();") - cg.kernel.append_statement("if (any(flat_call_thread_id >= call_data._thread_count)) return") + # Bounds check — use _thread_count directly in fast path, call_data._thread_count in fallback + if use_direct_args: + cg.kernel.append_statement("if (any(flat_call_thread_id >= _thread_count)) return") + else: + cg.kernel.append_statement( + "if (any(flat_call_thread_id >= call_data._thread_count)) return" + ) # Loads / initializes call id context_args = "flat_call_thread_id" @@ -525,20 +557,22 @@ def generate_code( # Call init_thread_local_call_shape_info to initialize the call shape info. See # definition in callshape.slang. if call_data_len > 0: + # In fast path, shape arrays are direct entry-point params; in fallback, prefixed with call_data. + grid_prefix = "" if use_direct_args else "call_data." if build_info.pipeline_type == PipelineType.compute: cg.kernel.append_line( f""" if (!init_thread_local_call_shape_info(flat_call_group_thread_id, - flat_call_group_id, flat_call_thread_id, call_data._grid_stride, - call_data._grid_dim, call_data._call_dim)) + flat_call_group_id, flat_call_thread_id, {grid_prefix}_grid_stride, + {grid_prefix}_grid_dim, {grid_prefix}_call_dim)) return;""" ) elif build_info.pipeline_type == PipelineType.ray_tracing: cg.kernel.append_line( f""" if (!init_thread_local_call_shape_info(0, - uint3(0), flat_call_thread_id, call_data._grid_stride, - call_data._grid_dim, call_data._call_dim)) + uint3(0), flat_call_thread_id, {grid_prefix}_grid_stride, + {grid_prefix}_grid_dim, {grid_prefix}_call_dim)) return;""" ) context_args += ", CallShapeInfo::get_call_id().shape" @@ -550,9 +584,16 @@ def generate_code( if context.call_mode == CallMode.bwds: fn = f"bwd_diff({fn})" - if is_entry_point: - cg.kernel.append_statement(f"{fn}(__slangpy_context__, call_data)") + if use_direct_args: + # Fast path: pass individual entry-point param names to the trampoline + trampoline_args = ["__slangpy_context__"] + for x in root_params: + if x.create_param_block: + continue # param blocks are at module scope + trampoline_args.append(x.variable_name) + cg.kernel.append_statement(f"{fn}({', '.join(trampoline_args)})") else: + # Fallback: trampoline reads from global call_data cg.kernel.append_statement(f"{fn}(__slangpy_context__)") cg.kernel.end_block() diff --git a/slangpy/core/dispatchdata.py b/slangpy/core/dispatchdata.py index 5f8abe244..940ee3747 100644 --- a/slangpy/core/dispatchdata.py +++ b/slangpy/core/dispatchdata.py @@ -6,7 +6,7 @@ from slangpy.core.callsignature import generate_constants from slangpy.core.enums import IOType -from slangpy.core.native import CallMode, CallDataMode, pack_arg, unpack_arg +from slangpy.core.native import CallMode, pack_arg, unpack_arg from slangpy.core.calldata import _DUMP_SLANG_INTERMEDIATES, _DUMP_GENERATED_SHADERS from slangpy import ( @@ -45,19 +45,11 @@ def __init__(self, func: "FunctionNode", **kwargs: dict[str, Any]) -> None: # Bind # Setup context - # Determine call data mode based on device type - call_data_mode = ( - CallDataMode.entry_point - if build_info.module.device.info.type == DeviceType.cuda - else CallDataMode.global_data - ) - context = BindContext( func.module.layout, CallMode.prim, build_info.module.device_module, build_info.options, - call_data_mode, ) # Build the unbound signature from inputs and convert straight diff --git a/slangpy/core/function.py b/slangpy/core/function.py index 8fb4715c5..d263d785f 100644 --- a/slangpy/core/function.py +++ b/slangpy/core/function.py @@ -4,7 +4,6 @@ from slangpy.core.native import ( CallMode, - CallDataMode, SignatureBuilder, NativeCallRuntimeOptions, NativeFunctionNode, diff --git a/slangpy/core/packedarg.py b/slangpy/core/packedarg.py index 6cc0c51c8..d7a09a731 100644 --- a/slangpy/core/packedarg.py +++ b/slangpy/core/packedarg.py @@ -4,7 +4,6 @@ from slangpy.core.native import ( get_value_signature, CallMode, - CallDataMode, NativePackedArg, unpack_arg, ) @@ -28,9 +27,7 @@ def __init__(self, module: Module, python_object: Any): # Create a shader object from the python marshall and init native structure shader_object = python.build_shader_object( - BindContext( - module.layout, CallMode.prim, module.device_module, {}, CallDataMode.global_data - ), + BindContext(module.layout, CallMode.prim, module.device_module, {}), unpacked_obj, ) if shader_object is None: diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index dd613558a..41d58af98 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -47,14 +47,17 @@ def assert_not_contains(code: str, *patterns: str) -> None: def assert_trampoline_has(code: str, *stmts: str) -> None: - """Assert trampoline contains statements, insensitive to call_data vs __calldata__ prefix.""" + """Assert trampoline contains statements, insensitive to call_data vs __calldata__ vs __in_ prefix.""" for s in stmts: - # Replace __calldata__ with both options for matching + # Replace __calldata__ with all three options for matching if "__calldata__." in s: - alt = s.replace("__calldata__.", "call_data.") + alt_cd = s.replace("__calldata__.", "call_data.") + # For fast path: __calldata__.X → __in_X (entry-point param prefix) + # Extract variable name after __calldata__. and before any trailing char + alt_in = s.replace("__calldata__.", "__in_") assert ( - s in code or alt in code - ), f"Expected trampoline statement not found: {s} (or {alt})" + s in code or alt_cd in code or alt_in in code + ), f"Expected trampoline statement not found: {s} (or {alt_cd} or {alt_in})" else: assert s in code, f"Expected trampoline statement not found: {s}" @@ -1843,32 +1846,35 @@ def test_phase1_functional_long_struct_name(device_type: spy.DeviceType): @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_calldata_struct_present(device_type: spy.DeviceType): - """struct CallData is emitted for simple scalar call. Breaks at Step 2.2.""" +def test_gate_p2_calldata_struct_absent_fast_path(device_type: spy.DeviceType): + """Fast path (use_direct_args=True): no struct CallData emitted. Step 2.2 done.""" device = helpers.get_device(device_type) code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - assert_contains(code, "struct CallData") + assert_not_contains(code, "struct CallData") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_calldata_uniform_param(device_type: spy.DeviceType): - """CallData is passed to kernel via uniform param (CUDA) or ParameterBlock (others). Breaks at Step 2.2.""" +def test_gate_p2_individual_uniform_params(device_type: spy.DeviceType): + """Fast path: individual uniform params instead of unified CallData. Step 2.2 done.""" device = helpers.get_device(device_type) code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - # CUDA uses entry-point param; D3D12/Vulkan use ParameterBlock at module scope - has_uniform = "uniform CallData call_data" in code - has_param_block = "ParameterBlock call_data" in code - assert ( - has_uniform or has_param_block - ), "Expected 'uniform CallData call_data' or 'ParameterBlock call_data'" + assert_contains(code, "uniform uint3 _thread_count") + assert_contains(code, "uniform int a") + assert_contains(code, "uniform int b") + assert_not_contains(code, "uniform CallData call_data") + assert_not_contains(code, "ParameterBlock call_data") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_thread_count_in_calldata(device_type: spy.DeviceType): - """_thread_count accessed via call_data. prefix. Breaks at Step 2.2.""" +def test_gate_p2_thread_count_direct(device_type: spy.DeviceType): + """Fast path: _thread_count accessed directly, not via call_data prefix. Step 2.2 done.""" device = helpers.get_device(device_type) code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - assert_contains(code, "call_data._thread_count") + assert_not_contains(code, "call_data._thread_count") + # Extract compute_main body and check _thread_count used directly + main_idx = code.index("void compute_main(") + main_body = code[main_idx:] + assert ">= _thread_count)" in main_body @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -1891,11 +1897,11 @@ def test_gate_p2_kernel_calls_trampoline(device_type: spy.DeviceType): @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_sv_group_id_present(device_type: spy.DeviceType): - """SV_GroupID present in compute_main signature even for dim-0. Breaks at Step 2.2.""" +def test_gate_p2_sv_group_id_absent_dim0(device_type: spy.DeviceType): + """Fast path dim-0: SV_GroupID not needed. Step 2.2 done.""" device = helpers.get_device(device_type) code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - assert_contains(code, "SV_GroupID") + assert_not_contains(code, "SV_GroupID") # -- Phase 2 negative gate — must REMAIN passing after Phase 2 -- diff --git a/slangpy/tests/slangpy_tests/test_type_resolution.py b/slangpy/tests/slangpy_tests/test_type_resolution.py index 7a4819ff4..47d239af9 100644 --- a/slangpy/tests/slangpy_tests/test_type_resolution.py +++ b/slangpy/tests/slangpy_tests/test_type_resolution.py @@ -34,7 +34,6 @@ def build_test_data(module: spy.Module, call_mode: spyn.CallMode, *args: Any, ** { "strict_broadcasting": False, }, - spyn.CallDataMode.global_data, ) return context, spy.bindings.BoundCall(context, *unpacked_args, **unpacked_kwargs) diff --git a/src/sgl/utils/slangpy.h b/src/sgl/utils/slangpy.h index 09d33ac04..0d2a8f44a 100644 --- a/src/sgl/utils/slangpy.h +++ b/src/sgl/utils/slangpy.h @@ -43,16 +43,6 @@ SGL_ENUM_INFO( ); SGL_ENUM_REGISTER(CallMode); -enum class CallDataMode { global_data, entry_point }; -SGL_ENUM_INFO( - CallDataMode, - { - {CallDataMode::global_data, "global_data"}, - {CallDataMode::entry_point, "entry_point"}, - } -); -SGL_ENUM_REGISTER(CallDataMode); - /// Access pattern for torch autograd tensor bindings. /// Precomputed at build time and stored in a flat list on NativeCallData, /// consumed in order during find_torch_tensors at dispatch time. diff --git a/src/slangpy_ext/py_doc.h b/src/slangpy_ext/py_doc.h index 4c2e01dcc..6abd6d2d5 100644 --- a/src/slangpy_ext/py_doc.h +++ b/src/slangpy_ext/py_doc.h @@ -10748,14 +10748,6 @@ static const char *__doc_sgl_slangpy_CallContext_m_call_shape = R"doc()doc"; static const char *__doc_sgl_slangpy_CallContext_m_device = R"doc()doc"; -static const char *__doc_sgl_slangpy_CallDataMode = R"doc()doc"; - -static const char *__doc_sgl_slangpy_CallDataMode_entry_point = R"doc()doc"; - -static const char *__doc_sgl_slangpy_CallDataMode_global_data = R"doc()doc"; - -static const char *__doc_sgl_slangpy_CallDataMode_info = R"doc()doc"; - static const char *__doc_sgl_slangpy_CallMode = R"doc()doc"; static const char *__doc_sgl_slangpy_CallMode_bwds = R"doc()doc"; diff --git a/src/slangpy_ext/utils/slangpy.cpp b/src/slangpy_ext/utils/slangpy.cpp index 0fc4893d3..98b7c8c03 100644 --- a/src/slangpy_ext/utils/slangpy.cpp +++ b/src/slangpy_ext/utils/slangpy.cpp @@ -809,108 +809,166 @@ nb::object NativeCallData::exec( auto bind_call_data = [&](ShaderCursor cursor) { - // On first call, cache all field indices and offsets to avoid repeated string lookups - if (!m_cached_call_data_offsets.is_valid) { - // Get the call data cursor using string lookup (first call only) - ShaderCursor call_data_cursor; - if (m_call_data_mode == CallDataMode::entry_point) { - ShaderCursor entry_point_cursor = cursor.find_entry_point(0); - call_data_cursor = entry_point_cursor.find_field("call_data"); - m_cached_call_data_offsets.call_data_field_index = entry_point_cursor.find_field_index("call_data"); - } else { - call_data_cursor = cursor.find_field("call_data"); - m_cached_call_data_offsets.call_data_field_index = cursor.find_field_index("call_data"); + if (m_use_direct_args) { + // ---- Fast path: individual entry-point params ---- + ShaderCursor ep = cursor.find_entry_point(0); + + // On first call, cache field offsets for metadata fields + if (!m_cached_call_data_offsets.is_valid) { + m_cached_call_data_offsets.thread_count = ep.find_field("_thread_count").offset(); + m_cached_call_data_offsets.field_offset = ep.offset(); + m_cached_call_data_offsets.field_size + = (uint32_t)ep.slang_type_layout()->getSize(SLANG_PARAMETER_CATEGORY_UNIFORM); + if (call_shape.size() > 0) { + m_cached_call_data_offsets.call_dim = ep.find_field("_call_dim").offset(); + m_cached_call_data_offsets.grid_stride = ep.find_field("_grid_stride").offset(); + m_cached_call_data_offsets.grid_dim = ep.find_field("_grid_dim").offset(); + m_cached_call_data_offsets.array_stride = (int)ep.find_field("_call_dim") + .slang_type_layout() + ->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM); + } + m_cached_call_data_offsets.is_valid = true; } - // Cache whether call_data needs dereference - m_cached_call_data_offsets.call_data_is_reference = call_data_cursor.is_reference(); - if (m_cached_call_data_offsets.call_data_is_reference) - call_data_cursor = call_data_cursor.dereference(); + // Reserve memory block for all entry-point uniform fields + ShaderObject* shader_object = ep.shader_object(); + void* base_address = shader_object->reserve_data( + m_cached_call_data_offsets.field_offset, + m_cached_call_data_offsets.field_size + ); - // Cache all field offsets - m_cached_call_data_offsets.call_dim = call_data_cursor.find_field("_call_dim").offset(); - m_cached_call_data_offsets.grid_stride = call_data_cursor.find_field("_grid_stride").offset(); - m_cached_call_data_offsets.grid_dim = call_data_cursor.find_field("_grid_dim").offset(); - m_cached_call_data_offsets.thread_count = call_data_cursor.find_field("_thread_count").offset(); - m_cached_call_data_offsets.field_offset = call_data_cursor.offset(); - m_cached_call_data_offsets.field_size - = (uint32_t)call_data_cursor.slang_type_layout()->getSize(SLANG_PARAMETER_CATEGORY_UNIFORM); - if (m_cached_call_data_offsets.call_dim.is_valid()) { - m_cached_call_data_offsets.array_stride = (int)call_data_cursor.find_field("_call_dim") - .slang_type_layout() - ->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM); + if (call_shape.size() > 0) { + // Write shape arrays using cached offsets + write_strided_array_helper( + base_address, + m_cached_call_data_offsets.call_dim.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + call_shape.data(), + call_shape.size(), + m_cached_call_data_offsets.array_stride + ); + + write_strided_array_helper( + base_address, + m_cached_call_data_offsets.grid_stride.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + call_grid_strides.data(), + call_grid_strides.size(), + m_cached_call_data_offsets.array_stride + ); + + write_strided_array_helper( + base_address, + m_cached_call_data_offsets.grid_dim.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + call_grid_shape.data(), + call_grid_shape.size(), + m_cached_call_data_offsets.array_stride + ); } - m_cached_call_data_offsets.is_valid = true; - } - // Fast path: use cached field index to find call_data cursor - ShaderCursor call_data_cursor; - if (m_call_data_mode == CallDataMode::entry_point) { - call_data_cursor - = cursor.find_entry_point(0).get_field_by_index(m_cached_call_data_offsets.call_data_field_index); + // Write thread count + uint3 thread_count_value(total_threads, 1, 1); + write_value_helper( + base_address, + m_cached_call_data_offsets.thread_count.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + thread_count_value + ); + + // Pass entry-point cursor as call_data_cursor — marshalls navigate ep[var_name] + m_runtime->write_shader_cursor_pre_dispatch(context, cursor, ep, unpacked_args, unpacked_kwargs, read_back); } else { - call_data_cursor = cursor.get_field_by_index(m_cached_call_data_offsets.call_data_field_index); - } + // ---- Fallback path: ParameterBlock at module scope (all backends) ---- + // On first call, cache all field indices and offsets + if (!m_cached_call_data_offsets.is_valid) { + ShaderCursor call_data_cursor = cursor.find_field("call_data"); + m_cached_call_data_offsets.call_data_field_index = cursor.find_field_index("call_data"); - // Dereference the cursor if needed (using cached result) - if (m_cached_call_data_offsets.call_data_is_reference) - call_data_cursor = call_data_cursor.dereference(); + // Cache whether call_data needs dereference + m_cached_call_data_offsets.call_data_is_reference = call_data_cursor.is_reference(); + if (m_cached_call_data_offsets.call_data_is_reference) + call_data_cursor = call_data_cursor.dereference(); + + // Cache all field offsets + m_cached_call_data_offsets.call_dim = call_data_cursor.find_field("_call_dim").offset(); + m_cached_call_data_offsets.grid_stride = call_data_cursor.find_field("_grid_stride").offset(); + m_cached_call_data_offsets.grid_dim = call_data_cursor.find_field("_grid_dim").offset(); + m_cached_call_data_offsets.thread_count = call_data_cursor.find_field("_thread_count").offset(); + m_cached_call_data_offsets.field_offset = call_data_cursor.offset(); + m_cached_call_data_offsets.field_size + = (uint32_t)call_data_cursor.slang_type_layout()->getSize(SLANG_PARAMETER_CATEGORY_UNIFORM); + if (m_cached_call_data_offsets.call_dim.is_valid()) { + m_cached_call_data_offsets.array_stride = (int)call_data_cursor.find_field("_call_dim") + .slang_type_layout() + ->getElementStride(SLANG_PARAMETER_CATEGORY_UNIFORM); + } + m_cached_call_data_offsets.is_valid = true; + } - // Reserve memory block for all call data fields - ShaderObject* shader_object = call_data_cursor.shader_object(); - void* base_address = shader_object->reserve_data( - m_cached_call_data_offsets.field_offset, - m_cached_call_data_offsets.field_size - ); + // Fast path: use cached field index to find call_data cursor + ShaderCursor call_data_cursor = cursor.get_field_by_index(m_cached_call_data_offsets.call_data_field_index); - if (call_shape.size() > 0) { - // Write arrays using cached offsets and direct memory access - write_strided_array_helper( - base_address, - m_cached_call_data_offsets.call_dim.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - call_shape.data(), - call_shape.size(), - m_cached_call_data_offsets.array_stride + // Dereference the cursor if needed (using cached result) + if (m_cached_call_data_offsets.call_data_is_reference) + call_data_cursor = call_data_cursor.dereference(); + + // Reserve memory block for all call data fields + ShaderObject* shader_object = call_data_cursor.shader_object(); + void* base_address = shader_object->reserve_data( + m_cached_call_data_offsets.field_offset, + m_cached_call_data_offsets.field_size ); - write_strided_array_helper( + if (call_shape.size() > 0) { + // Write arrays using cached offsets and direct memory access + write_strided_array_helper( + base_address, + m_cached_call_data_offsets.call_dim.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + call_shape.data(), + call_shape.size(), + m_cached_call_data_offsets.array_stride + ); + + write_strided_array_helper( + base_address, + m_cached_call_data_offsets.grid_stride.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + call_grid_strides.data(), + call_grid_strides.size(), + m_cached_call_data_offsets.array_stride + ); + + write_strided_array_helper( + base_address, + m_cached_call_data_offsets.grid_dim.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + call_grid_shape.data(), + call_grid_shape.size(), + m_cached_call_data_offsets.array_stride + ); + } + + // Write thread count + uint3 thread_count_value(total_threads, 1, 1); + write_value_helper( base_address, - m_cached_call_data_offsets.grid_stride.uniform_offset + m_cached_call_data_offsets.thread_count.uniform_offset - m_cached_call_data_offsets.field_offset.uniform_offset, - call_grid_strides.data(), - call_grid_strides.size(), - m_cached_call_data_offsets.array_stride + thread_count_value ); - write_strided_array_helper( - base_address, - m_cached_call_data_offsets.grid_dim.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - call_grid_shape.data(), - call_grid_shape.size(), - m_cached_call_data_offsets.array_stride + m_runtime->write_shader_cursor_pre_dispatch( + context, + cursor, + call_data_cursor, + unpacked_args, + unpacked_kwargs, + read_back ); } - // Write thread count - uint3 thread_count_value(total_threads, 1, 1); - write_value_helper( - base_address, - m_cached_call_data_offsets.thread_count.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - thread_count_value - ); - - m_runtime->write_shader_cursor_pre_dispatch( - context, - cursor, - call_data_cursor, - unpacked_args, - unpacked_kwargs, - read_back - ); - nb::list uniforms = opts->uniforms(); if (uniforms) { for (auto u : uniforms) { @@ -1268,7 +1326,6 @@ SGL_PY_EXPORT(utils_slangpy) nb::sgl_enum(slangpy, "AccessType"); nb::sgl_enum(slangpy, "CallMode"); - nb::sgl_enum(slangpy, "CallDataMode"); nb::sgl_enum(slangpy, "AutogradAccess"); slangpy.def( @@ -1629,12 +1686,6 @@ SGL_PY_EXPORT(utils_slangpy) &NativeCallData::set_call_mode, D_NA(NativeCallData, call_mode) ) - .def_prop_rw( - "call_data_mode", - &NativeCallData::call_data_mode, - &NativeCallData::set_call_data_mode, - D_NA(NativeCallData, call_data_mode) - ) .def_prop_ro("last_call_shape", &NativeCallData::last_call_shape, D_NA(NativeCallData, last_call_shape)) .def_prop_rw( "debug_name", diff --git a/src/slangpy_ext/utils/slangpy.h b/src/slangpy_ext/utils/slangpy.h index 48dbff97e..806de9374 100644 --- a/src/slangpy_ext/utils/slangpy.h +++ b/src/slangpy_ext/utils/slangpy.h @@ -743,12 +743,6 @@ class NativeCallData : Object { /// Set the call mode (primitive/forward/backward). void set_call_mode(CallMode call_mode) { m_call_mode = call_mode; } - /// Get the call data mode (global_data/entry_point). - CallDataMode call_data_mode() const { return m_call_data_mode; } - - /// Set the call data mode (global_data/entry_point). - void set_call_data_mode(CallDataMode call_data_mode) { m_call_data_mode = call_data_mode; } - /// Get the shape of the last call (useful for debugging). const Shape& last_call_shape() const { return m_last_call_shape; } @@ -901,10 +895,10 @@ class NativeCallData : Object { ShaderOffset grid_stride; ShaderOffset grid_dim; ShaderOffset thread_count; - ShaderOffset field_offset; // Base offset of the call_data structure + ShaderOffset field_offset; // Base offset of the call_data structure (or entry-point) uint32_t field_size = 0; // Total size of the call_data in uniform data int array_stride = 0; // Stride for array elements - // Cached information for navigating to call_data field + // Cached information for navigating to call_data field (fallback path) int32_t call_data_field_index = -1; // Field index for "call_data" lookup bool call_data_is_reference = false; // Whether call_data needs dereference bool is_valid = false; // Whether offsets have been initialized @@ -916,7 +910,6 @@ class NativeCallData : Object { int m_call_dimensionality{0}; ref m_runtime; CallMode m_call_mode{CallMode::prim}; - CallDataMode m_call_data_mode{CallDataMode::global_data}; Shape m_last_call_shape; std::string m_debug_name; ref m_logger; From c941898bd3b4f558a59efdaf05e9a8bbf29d1503 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Thu, 12 Mar 2026 16:10:58 +0000 Subject: [PATCH 20/41] wip switching to entry point arguments --- .../plan-simplifyKernelGen-phase2.prompt.md | 69 ++++++++++++------- slangpy/core/calldata.py | 1 + slangpy/core/callsignature.py | 6 +- 3 files changed, 52 insertions(+), 24 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md index ba204861f..7411fdf9f 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md @@ -159,14 +159,14 @@ void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, ...) { Tests added to [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py). All 21 parametrized cases (7 tests × 3 device types) pass. -| Test | Source | Args | Asserts (current) | Breaks when | -|------|--------|------|--------------------|-------------| -| `test_gate_p2_calldata_struct_present` | `int add(int a, int b)` | `(1, 2)` | `struct CallData` in code | Step 2.2 | -| `test_gate_p2_calldata_uniform_param` | same | same | `uniform CallData call_data` (CUDA) or `ParameterBlock call_data` (D3D12/Vulkan) | Step 2.2 | -| `test_gate_p2_thread_count_in_calldata` | same | same | `call_data._thread_count` | Step 2.2 | -| `test_gate_p2_trampoline_present_for_prim` | same | same | `void _trampoline(` present | Step 2.3 | -| `test_gate_p2_kernel_calls_trampoline` | same | same | `_trampoline(` in `compute_main` body | Step 2.3 | -| `test_gate_p2_sv_group_id_present` | same | same | `SV_GroupID` in `compute_main` signature | Step 2.2 | +| Test | Source | Args | Original assertion | Status | +|------|--------|------|--------------------|--------| +| `test_gate_p2_calldata_struct_present` | `int add(int a, int b)` | `(1, 2)` | `struct CallData` in code | ✅ Flipped — now asserts `struct CallData` ABSENT (Step 2.2 done) | +| `test_gate_p2_calldata_uniform_param` | same | same | `uniform CallData call_data` or `ParameterBlock` | ✅ Flipped — now asserts both ABSENT (Step 2.2 done) | +| `test_gate_p2_thread_count_in_calldata` | same | same | `call_data._thread_count` | ✅ Flipped — now asserts ABSENT (Step 2.2 done) | +| `test_gate_p2_trampoline_present_for_prim` | same | same | `void _trampoline(` present | Still asserts present (Step 2.3 pending) | +| `test_gate_p2_kernel_calls_trampoline` | same | same | `_trampoline(` in `compute_main` body | Still asserts present (Step 2.3 pending) | +| `test_gate_p2_sv_group_id_present` | same | same | `SV_GroupID` in `compute_main` signature | ✅ Flipped — now asserts ABSENT for dim-0 calls (Step 2.2 done) | Negative gates (must stay passing after Phase 2): @@ -174,7 +174,12 @@ Negative gates (must stay passing after Phase 2): |------|---------| | `test_gate_p2_wanghasharg_keeps_load` | Non-direct-bind arg still uses `__slangpy_load` | -**Note:** `test_gate_p2_calldata_uniform_param` checks for either `uniform CallData call_data` (CUDA entry-point param) or `ParameterBlock call_data` (D3D12/Vulkan module-scope), since the current `CallDataMode` distinction means different backends emit different patterns. +Bwds gates: + +| Test | Status | +|------|--------| +| `test_gate_scalar_uses_valuetype` | ✅ Passing — asserts fast-path trampoline with `__in_` prefix params | +| `test_gate_bwds_scalar_uses_valuetype` | ❌ Failing — bwds trampoline missing `no_diff` annotations (Step 2.4 pending) | --- @@ -212,7 +217,9 @@ In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct --- -### Step 2.2: Code generation — entry-point params (fast path) +### Step 2.2: Code generation — entry-point params (fast path) ✅ + +**Status: DONE** In [slangpy/core/callsignature.py](slangpy/core/callsignature.py) `generate_code()`, when `use_direct_args == True`: @@ -255,6 +262,8 @@ See [slangpy/tests/device/test_pipeline_utils.slang](slangpy/tests/device/test_p ### Step 2.3: Trampoline elimination for prim mode +**Status: NOT STARTED** — Trampoline is still generated for prim mode on both paths. The load/call/store sequence needs to be inlined into `compute_main`. + When `call_mode == prim` — on **both** fast and fallback paths: - Don't generate the `_trampoline` function. @@ -275,6 +284,8 @@ When `call_mode == prim` — on **both** fast and fallback paths: ### Step 2.4: Trampoline with individual params for bwds mode +**Status: IN PROGRESS** — Fast-path trampoline takes individual params but is missing `no_diff` annotations. The bwds test (`test_gate_bwds_scalar_uses_valuetype`) fails because `bwd_diff(_trampoline)` receives plain `float` params but Slang expects `DifferentialPair` for differentiable params. Need to use `_gen_trampoline_argument()` from `boundvariable.py` to emit proper `no_diff in`/`inout` annotations. + When `call_mode == bwds`: - Still generate a `[Differentiable]` trampoline function. @@ -289,7 +300,9 @@ When `call_mode == bwds`: --- -### Step 2.5: C++ dispatch changes +### Step 2.5: C++ dispatch changes ✅ + +**Status: DONE** — `CallDataMode` enum fully removed. Fast path uses `find_entry_point(0)` on all backends. Fallback path uses global `ParameterBlock` on all backends. In [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp), store `m_use_direct_args` on `NativeCallData` (received from Python `CallData`). Also add to [slangpy.h](src/slangpy_ext/utils/slangpy.h). @@ -310,6 +323,8 @@ Modify `bind_call_data` lambda in `exec()`: ### Step 2.6: `_result` handling +**Status: NOT STARTED** + Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligible (needs `RWValueRef` wrapper with buffer logic). Phase 2 handles this differently on the two paths: **Fast path**: `_result` is emitted as `uniform RWValueRef _result` on the entry point. In prim mode, the inlined code stores via `_result.__slangpy_store(...)`. In the all-direct-bind case where Context is omitted, add a new code path: emit `uniform RWStructuredBuffer _result` with `_result[0] = value` for the store. This requires `ValueRefMarshall` to support writable direct-bind for the entry-point-param case specifically, using `RWStructuredBuffer` instead of `RWValueRef`. @@ -322,6 +337,8 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi ### Step 2.7: Tests +**Status: NOT STARTED** + **Post-implementation tests** — should pass AFTER Phase 2 is complete: | Test | Verifies | @@ -350,13 +367,13 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi 1. **Step 2.0** ✅ — Gating tests (baseline documentation) 2. **Step 2.1** ✅ — Fast/fallback determination + size query -3. **Step 2.3** — Trampoline elimination for prim mode (both paths). This is independent of entry-point param work and provides immediate value. -4. **Step 2.2 + 2.5** — Code gen + C++ dispatch for entry-point params (must land together — Slang layout and C++ cursor navigation must agree) -5. **Step 2.4** — Bwds trampoline with individual params (fast path) +3. **Step 2.2 + 2.5** ✅ — Code gen + C++ dispatch for entry-point params + `CallDataMode` removal (landed together) +4. **Step 2.4** 🔧 — Bwds trampoline with individual params (fast path) — missing `no_diff` annotations +5. **Step 2.3** — Trampoline elimination for prim mode (both paths) 6. **Step 2.6** — `_result` as `RWStructuredBuffer` for all-direct-bind case 7. **Step 2.7** — Post-implementation tests + functional tests -Steps 2.3 (trampoline) and 2.2/2.5 (entry-point params) are independent axes and can be done in either order. Starting with 2.3 is recommended because it's simpler and touches fewer files. +**Note:** Implementation order deviated from original plan — Steps 2.2 + 2.5 were done before 2.3 (trampoline elimination), combined with `CallDataMode` removal. Step 2.4 is partially done (trampoline takes individual params but missing `no_diff` annotations for bwds mode). --- @@ -364,17 +381,23 @@ Steps 2.3 (trampoline) and 2.2/2.5 (entry-point params) are independent axes and | File | Changes | |------|---------| -| [slangpy/core/calldata.py](slangpy/core/calldata.py) | `use_direct_args` flag, size threshold check, remove `CallDataMode` usage | -| [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | `generate_code()` — inline load/call/store, entry-point params, Context gating, remove `is_entry_point` branch | -| [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py) | `skip_call_data` flag, `entry_point_params` list | -| [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | `gen_call_data_code` depth-0 entry-point path; `_gen_trampoline_argument()` usage | -| [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | ✅ `use_direct_args` binding; `bind_call_data` fast path via `find_entry_point(0)`, remove `CallDataMode` branches | -| [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | ✅ `m_use_direct_args` on `NativeCallData`; remove `m_call_data_mode` | +| [slangpy/core/calldata.py](slangpy/core/calldata.py) | ✅ `use_direct_args` flag, size threshold check, `CallDataMode` removed | +| [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | ✅ Entry-point params, fast/fallback code paths, `is_entry_point` branch removed. Trampoline still generated (Step 2.3 pending). Bwds missing `no_diff` (Step 2.4 pending). | +| [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py) | ✅ `skip_call_data` flag, `entry_point_params` list | +| [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | ✅ `gen_call_data_code` depth-0 entry-point path. `_gen_trampoline_argument()` not yet used (Step 2.4 pending). | +| [slangpy/bindings/marshall.py](slangpy/bindings/marshall.py) | ✅ `use_direct_args` field on `BindContext`, `CallDataMode` removed | +| [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | ✅ `use_direct_args` binding; `bind_call_data` fast path via `find_entry_point(0)`, `CallDataMode` branches removed | +| [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | ✅ `m_use_direct_args` on `NativeCallData`; `m_call_data_mode` removed | | [src/sgl/device/device.h](src/sgl/device/device.h) | ✅ `max_entry_point_uniform_size` on `DeviceLimits` | | [src/sgl/device/device.cpp](src/sgl/device/device.cpp) | ✅ Per-backend defaults for `max_entry_point_uniform_size` | | [src/slangpy_ext/device/device.cpp](src/slangpy_ext/device/device.cpp) | ✅ Python binding for `max_entry_point_uniform_size` | -| [src/sgl/utils/slangpy.h](src/sgl/utils/slangpy.h) | Remove `CallDataMode` enum definition | -| [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) | ✅ Gating tests + Step 2.1 tests; post-implementation tests | +| [src/sgl/utils/slangpy.h](src/sgl/utils/slangpy.h) | ✅ `CallDataMode` enum removed | +| [slangpy/core/dispatchdata.py](slangpy/core/dispatchdata.py) | ✅ `CallDataMode` removed | +| [slangpy/core/packedarg.py](slangpy/core/packedarg.py) | ✅ `CallDataMode` removed | +| [slangpy/core/function.py](slangpy/core/function.py) | ✅ `CallDataMode` removed from imports | +| [slangpy/slangpy/__init__.pyi](slangpy/slangpy/__init__.pyi) | ✅ `CallDataMode` class and `call_data_mode` property removed | +| [slangpy/tests/slangpy_tests/test_type_resolution.py](slangpy/tests/slangpy_tests/test_type_resolution.py) | ✅ `CallDataMode` removed from `BindContext` creation | +| [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) | ✅ Gating tests + Step 2.1 tests updated for new behavior; post-implementation tests (Step 2.7) pending | --- diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index fa854df65..dc220e44e 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -263,6 +263,7 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): inline_size = calculate_inline_uniform_size(bindings, self.call_dimensionality) threshold = build_info.module.device.info.limits.max_entry_point_uniform_size self.use_direct_args = inline_size <= threshold + # self.use_direct_args = False self.log_debug( f" Inline uniform size: {inline_size} bytes, " f"threshold: {threshold} bytes, " diff --git a/slangpy/core/callsignature.py b/slangpy/core/callsignature.py index a8a5a4a71..9fdaf7140 100644 --- a/slangpy/core/callsignature.py +++ b/slangpy/core/callsignature.py @@ -414,12 +414,16 @@ def generate_code( if use_direct_args: # Fast path: trampoline takes individual calldata-typed params. # Use __in_ prefix for param names to avoid collision with local variable names. + # All params are no_diff — entry-point uniforms are never differentiable. + # Differentiation happens through local variable assignments inside the trampoline, + # matching the struct-based approach where CallData was implicitly non-differentiable. trampoline_params = ["Context __slangpy_context__"] for x in root_params: if x.create_param_block: continue # param blocks handled via _param_ at module scope assert x.calldata_type_name is not None - trampoline_params.append(f"{x.calldata_type_name} __in_{x.variable_name}") + arg_def = f"no_diff {x.calldata_type_name} __in_{x.variable_name}" + trampoline_params.append(arg_def) cg.trampoline.append_line(f"void {trampoline_fn}({', '.join(trampoline_params)})") else: # Fallback: trampoline reads from global ParameterBlock call_data From e589a3e7b5fe477aef36b8e52b66072319122752 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Thu, 12 Mar 2026 17:05:46 +0000 Subject: [PATCH 21/41] working reduced entry points --- .../plan-simplifyKernelGen-phase2.prompt.md | 22 +- slangpy/core/calldata.py | 380 ++++++++++-------- slangpy/tests/slangpy_tests/test_code_gen.py | 13 +- .../tests/slangpy_tests/test_kernel_gen.py | 5 +- 4 files changed, 236 insertions(+), 184 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md index 7411fdf9f..4a2bfb8dc 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md @@ -179,7 +179,7 @@ Bwds gates: | Test | Status | |------|--------| | `test_gate_scalar_uses_valuetype` | ✅ Passing — asserts fast-path trampoline with `__in_` prefix params | -| `test_gate_bwds_scalar_uses_valuetype` | ❌ Failing — bwds trampoline missing `no_diff` annotations (Step 2.4 pending) | +| `test_gate_bwds_scalar_uses_valuetype` | ✅ Passing — bwds trampoline has `no_diff` on all params (Step 2.4 done) | --- @@ -282,21 +282,23 @@ When `call_mode == prim` — on **both** fast and fallback paths: --- -### Step 2.4: Trampoline with individual params for bwds mode +### Step 2.4: Trampoline with individual params for bwds mode ✅ -**Status: IN PROGRESS** — Fast-path trampoline takes individual params but is missing `no_diff` annotations. The bwds test (`test_gate_bwds_scalar_uses_valuetype`) fails because `bwd_diff(_trampoline)` receives plain `float` params but Slang expects `DifferentialPair` for differentiable params. Need to use `_gen_trampoline_argument()` from `boundvariable.py` to emit proper `no_diff in`/`inout` annotations. +**Status: DONE** — Fast-path trampoline takes individual params with `no_diff` on all params. All 3 device types pass. When `call_mode == bwds`: - Still generate a `[Differentiable]` trampoline function. -- **Fast path**: Trampoline takes individual params instead of a struct. Use `_gen_trampoline_argument()` from [boundvariable.py](slangpy/bindings/boundvariable.py#L691) (currently dead code) to generate the signature — it already handles `in`/`out`/`inout` and `no_diff` annotations: +- **Fast path**: Trampoline takes individual params instead of a struct. All params get `no_diff` — entry-point uniforms are never differentiable. Differentiation happens through local variable assignments inside the trampoline body, matching the struct-based approach where `CallData` was implicitly non-differentiable. No `in`/`out`/`inout` modifiers are added — `compute_main` passes its uniforms straight through: ```slang [Differentiable] - void _trampoline(Context __slangpy_context__, no_diff in int a, no_diff in int b, ...) + void _trampoline(Context __slangpy_context__, no_diff float __in_a, no_diff float __in_b, no_diff NoneType __in__result) ``` `compute_main` calls `bwd_diff(_trampoline)(__slangpy_context__, a, b, _result)` passing entry-point param names directly. - **Fallback path**: Trampoline reads from global `ParameterBlock call_data` as it does today (on all backends). `compute_main` calls `bwd_diff(_trampoline)(__slangpy_context__, call_data)`. -- Non-differentiable arguments (int, bool, etc.) get `no_diff` prefix automatically via `_gen_trampoline_argument()`. This may need to be added to additional integer or non-differentiable trampoline arguments to make the generated shader compile under Slang's autodiff rules. +- `_gen_trampoline_argument()` in `boundvariable.py` remains unused dead code — the inline generation in `callsignature.py` is simpler and avoids the `in`/`out`/`inout` modifiers that caused Slang autodiff errors. + +**Key insight**: Adding `in`/`out`/`inout` modifiers to trampoline params caused Slang autodiff issues (e.g., `out` params get reversed to `in` by `bwd_diff`, changing arity). The trampoline params are just pass-through uniforms — all data flow logic (loads, stores, differentiation) is handled internally via local variables. --- @@ -368,12 +370,12 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi 1. **Step 2.0** ✅ — Gating tests (baseline documentation) 2. **Step 2.1** ✅ — Fast/fallback determination + size query 3. **Step 2.2 + 2.5** ✅ — Code gen + C++ dispatch for entry-point params + `CallDataMode` removal (landed together) -4. **Step 2.4** 🔧 — Bwds trampoline with individual params (fast path) — missing `no_diff` annotations +4. **Step 2.4** ✅ — Bwds trampoline with individual params (fast path) — `no_diff` on all params 5. **Step 2.3** — Trampoline elimination for prim mode (both paths) 6. **Step 2.6** — `_result` as `RWStructuredBuffer` for all-direct-bind case 7. **Step 2.7** — Post-implementation tests + functional tests -**Note:** Implementation order deviated from original plan — Steps 2.2 + 2.5 were done before 2.3 (trampoline elimination), combined with `CallDataMode` removal. Step 2.4 is partially done (trampoline takes individual params but missing `no_diff` annotations for bwds mode). +**Note:** Implementation order deviated from original plan — Steps 2.2 + 2.5 were done before 2.3 (trampoline elimination), combined with `CallDataMode` removal. Step 2.4 done — all trampoline params use `no_diff` without IO modifiers. --- @@ -382,9 +384,9 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi | File | Changes | |------|---------| | [slangpy/core/calldata.py](slangpy/core/calldata.py) | ✅ `use_direct_args` flag, size threshold check, `CallDataMode` removed | -| [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | ✅ Entry-point params, fast/fallback code paths, `is_entry_point` branch removed. Trampoline still generated (Step 2.3 pending). Bwds missing `no_diff` (Step 2.4 pending). | +| [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | ✅ Entry-point params, fast/fallback code paths, `is_entry_point` branch removed. Trampoline still generated (Step 2.3 pending). Bwds `no_diff` on all trampoline params (Step 2.4 done). | | [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py) | ✅ `skip_call_data` flag, `entry_point_params` list | -| [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | ✅ `gen_call_data_code` depth-0 entry-point path. `_gen_trampoline_argument()` not yet used (Step 2.4 pending). | +| [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | ✅ `gen_call_data_code` depth-0 entry-point path. `_gen_trampoline_argument()` unused — inline generation in `callsignature.py` used instead. | | [slangpy/bindings/marshall.py](slangpy/bindings/marshall.py) | ✅ `use_direct_args` field on `BindContext`, `CallDataMode` removed | | [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | ✅ `use_direct_args` binding; `bind_call_data` fast path via `find_entry_point(0)`, `CallDataMode` branches removed | | [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | ✅ `m_use_direct_args` on `NativeCallData`; `m_call_data_mode` removed | diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index dc220e44e..2ad6942ec 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -262,185 +262,42 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): # Sum inline-uniform byte size and compare against per-device threshold. inline_size = calculate_inline_uniform_size(bindings, self.call_dimensionality) threshold = build_info.module.device.info.limits.max_entry_point_uniform_size - self.use_direct_args = inline_size <= threshold - # self.use_direct_args = False + use_direct_args = inline_size <= threshold self.log_debug( f" Inline uniform size: {inline_size} bytes, " f"threshold: {threshold} bytes, " - f"use_direct_args: {self.use_direct_args}" + f"use_direct_args: {use_direct_args}" ) - # Propagate use_direct_args to context for code generation - context.use_direct_args = self.use_direct_args - - # Generate code. - codegen = CodeGen() - generate_code(context, build_info, bindings, codegen) - for link in build_info.module.link: - codegen.add_import(link.name) - code = codegen.finish( - call_data=True, - input_load_store=True, - header=True, - kernel=True, - imports=True, - trampoline=True, - context=True, - snippets=True, - call_data_structs=True, - constants=True, - use_param_block_for_call_data=not self.use_direct_args, - ) - - # Optionally write the shader to a file for debugging. - sanitized = "" - if _DUMP_GENERATED_SHADERS or _DUMP_SLANG_INTERMEDIATES: - os.makedirs(".temp", exist_ok=True) - santized_module = re.sub(r"[<>, ./:\\]", "_", build_info.module.name) - sanitized = re.sub(r"[:<>, ./:\\]", "_", build_info.name) - santized_module = santized_module[:50] - sanitized = sanitized[:50] - fn = f".temp/{santized_module}_{sanitized}{'_backwards' if self.call_mode == CallMode.bwds else ''}" - # Some platforms have path length limits that are easily exceeded with nested generics - # Be a good citizen here and limit the length of what we generate - length_limit = 200 - if len(fn) > length_limit: - fn = fn[:length_limit] - fn += "-" + hashlib.sha256(code.encode()).hexdigest()[0:8] - fn = fn + ".slang" - - # with open(fn,"r") as f: - # code = f.read() - with open( - fn, - "w", - ) as f: - f.write("/*\n") - f.write(bound_call_table(bindings)) - f.write("\n*/\n") - f.write(code) - - # Optionally print the shader to the terminal for AI analysis. - if _PRINT_GENERATED_SHADERS: - print("=" * 80) - print(f"GENERATED SHADER: {build_info.module.name}::{build_info.name}") - if self.call_mode == CallMode.bwds: - print("MODE: Backwards") - else: - print("MODE: Forward") - print("=" * 80) - print("/* BINDINGS:") - print(bound_call_table(bindings)) - print("*/") - print(code) - print("=" * 80) - print(f"END SHADER: {build_info.module.name}::{build_info.name}") - print("=" * 80) - print() - - # Hash the code to get a unique identifier for the module. - # We add type conformances to the start of the code to ensure that the hash is unique - code_minus_header = ( - "[CallData]\n" + str(build_info.type_conformances) + code[len(codegen.header) :] - ) - hash = hashlib.sha256(code_minus_header.encode()).hexdigest() - - # Check if we've already built this module. - if hash in build_info.module.pipeline_cache: - # Get pipeline from cache if we have - self.pipeline = build_info.module.pipeline_cache[hash] - # Get shader table from cache if the pipeline is a raytracing pipeline - if build_info.pipeline_type == PipelineType.ray_tracing: - self.shader_table = build_info.module.shader_table_cache[hash] - self.device = build_info.module.device - self.log_debug(f" Found cached pipeline with hash {hash}") - - else: - # Build new module and link it with the one that contains the function being called. - self.log_debug(f" Building new pipeline with hash {hash}") - session = build_info.module.session - device = session.device - module = session.load_module_from_source(hash, code) - opts = SlangLinkOptions() - opts.dump_intermediates = _DUMP_SLANG_INTERMEDIATES - opts.dump_intermediates_prefix = sanitized - if build_info.pipeline_type == PipelineType.compute: - # Create compute pipeline - ep = module.entry_point(f"compute_main", type_conformances) - program = session.link_program( - [module, build_info.module.device_module] + build_info.module.link, - [ep], - opts, - ) - self.pipeline = device.create_compute_pipeline( - program, - defer_target_compilation=True, - label=f"{build_info.module.name}_{build_info.name}_compute_call", - ) - build_info.module.pipeline_cache[hash] = self.pipeline - elif build_info.pipeline_type == PipelineType.ray_tracing: - # Create ray tracing pipeline - eps = [module.entry_point(f"raygen_main", type_conformances)] - hit_group_names: list[str] = [] - for hit_group in build_info.ray_tracing_hit_groups: - hit_group_names.append(hit_group.hit_group_name) - if hit_group.closest_hit_entry_point != "": - eps.append( - build_info.module.device_module.entry_point( - hit_group.closest_hit_entry_point - ) - ) - if hit_group.any_hit_entry_point != "": - eps.append( - build_info.module.device_module.entry_point( - hit_group.any_hit_entry_point - ) - ) - if hit_group.intersection_entry_point != "": - eps.append( - build_info.module.device_module.entry_point( - hit_group.intersection_entry_point - ) - ) - for miss_entry_point in build_info.ray_tracing_miss_entry_points: - eps.append(build_info.module.device_module.entry_point(miss_entry_point)) - - program = session.link_program( - [module, build_info.module.device_module] + build_info.module.link, - eps, - opts, - ) - self.pipeline = device.create_ray_tracing_pipeline( - program, - hit_groups=build_info.ray_tracing_hit_groups, - max_recursion=build_info.ray_tracing_max_recursion, - max_ray_payload_size=build_info.ray_tracing_max_ray_payload_size, - max_attribute_size=build_info.ray_tracing_max_attribute_size, - flags=build_info.ray_tracing_flags, - defer_target_compilation=True, - label=f"{build_info.module.name}_{build_info.name}_rt_call", - ) - build_info.module.pipeline_cache[hash] = self.pipeline - self.shader_table = device.create_shader_table( - program, - ray_gen_entry_points=["raygen_main"], - miss_entry_points=build_info.ray_tracing_miss_entry_points, - hit_group_names=hit_group_names, - callable_entry_points=build_info.ray_tracing_callable_entry_points, - ) - build_info.module.shader_table_cache[hash] = self.shader_table - else: - raise RuntimeError("Unknown pipeline type") - self.device = device - self.log_debug(f" Build succesful") + # Try building the shader. If direct args compilation fails (the + # threshold is only an approximate heuristic), fall back to + # ParameterBlock. + try: + self._try_build_shader( + context, + build_info, + bindings, + type_conformances, + use_direct_args=use_direct_args, + ) + except Exception as e: + if not use_direct_args: + raise + self.log_debug( + " Direct args compilation failed, " "retrying with ParameterBlock" + ) + self._try_build_shader( + context, + build_info, + bindings, + type_conformances, + use_direct_args=False, + ) # Store the bindings and runtime for later use. self.debug_only_bindings = bindings self.runtime = BoundCallRuntime(bindings) - # Store the code as its useful for debugging - self.code = code - # If using autograd, build list of access modes for each tensor argument. if self.torch_autograd: self._build_autograd_access_list(unpacked_args, unpacked_kwargs) @@ -502,6 +359,191 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): else: raise + def _try_build_shader( + self, + context: BindContext, + build_info: "FunctionBuildInfo", + bindings: BoundCall, + type_conformances: Any, + use_direct_args: bool, + ) -> None: + """ + Generate shader code and build the pipeline. + + Sets self.use_direct_args, self.pipeline, self.device, self.code, + and optionally self.shader_table. + + :param context: Binding context. + :param build_info: Function build information. + :param bindings: Bound call with resolved variables. + :param type_conformances: Type conformances for entry point. + :param use_direct_args: If True, use entry-point params; otherwise ParameterBlock. + """ + self.use_direct_args = use_direct_args + context.use_direct_args = use_direct_args + + # Generate code. + codegen = CodeGen() + generate_code(context, build_info, bindings, codegen) + for link in build_info.module.link: + codegen.add_import(link.name) + code = codegen.finish( + call_data=True, + input_load_store=True, + header=True, + kernel=True, + imports=True, + trampoline=True, + context=True, + snippets=True, + call_data_structs=True, + constants=True, + use_param_block_for_call_data=not use_direct_args, + ) + + # Optionally write the shader to a file for debugging. + sanitized = "" + if _DUMP_GENERATED_SHADERS or _DUMP_SLANG_INTERMEDIATES: + os.makedirs(".temp", exist_ok=True) + santized_module = re.sub(r"[<>, ./:\\]", "_", build_info.module.name) + sanitized = re.sub(r"[:<>, ./:\\]", "_", build_info.name) + santized_module = santized_module[:50] + sanitized = sanitized[:50] + fn = f".temp/{santized_module}_{sanitized}{'_backwards' if self.call_mode == CallMode.bwds else ''}" + # Some platforms have path length limits that are easily exceeded with nested generics + # Be a good citizen here and limit the length of what we generate + length_limit = 200 + if len(fn) > length_limit: + fn = fn[:length_limit] + fn += "-" + hashlib.sha256(code.encode()).hexdigest()[0:8] + fn = fn + ".slang" + + with open( + fn, + "w", + ) as f: + f.write("/*\n") + f.write(bound_call_table(bindings)) + f.write("\n*/\n") + f.write(code) + + # Optionally print the shader to the terminal for AI analysis. + if _PRINT_GENERATED_SHADERS: + print("=" * 80) + print(f"GENERATED SHADER: {build_info.module.name}::{build_info.name}") + if self.call_mode == CallMode.bwds: + print("MODE: Backwards") + else: + print("MODE: Forward") + print("=" * 80) + print("/* BINDINGS:") + print(bound_call_table(bindings)) + print("*/") + print(code) + print("=" * 80) + print(f"END SHADER: {build_info.module.name}::{build_info.name}") + print("=" * 80) + print() + + # Hash the code to get a unique identifier for the module. + # We add type conformances to the start of the code to ensure that the hash is unique + code_minus_header = ( + "[CallData]\n" + str(build_info.type_conformances) + code[len(codegen.header) :] + ) + hash = hashlib.sha256(code_minus_header.encode()).hexdigest() + + # Check if we've already built this module. + if hash in build_info.module.pipeline_cache: + # Get pipeline from cache if we have + self.pipeline = build_info.module.pipeline_cache[hash] + # Get shader table from cache if the pipeline is a raytracing pipeline + if build_info.pipeline_type == PipelineType.ray_tracing: + self.shader_table = build_info.module.shader_table_cache[hash] + self.device = build_info.module.device + self.log_debug(f" Found cached pipeline with hash {hash}") + + else: + # Build new module and link it with the one that contains the function being called. + self.log_debug(f" Building new pipeline with hash {hash}") + session = build_info.module.session + device = session.device + module = session.load_module_from_source(hash, code) + opts = SlangLinkOptions() + opts.dump_intermediates = _DUMP_SLANG_INTERMEDIATES + opts.dump_intermediates_prefix = sanitized + if build_info.pipeline_type == PipelineType.compute: + # Create compute pipeline + ep = module.entry_point(f"compute_main", type_conformances) + program = session.link_program( + [module, build_info.module.device_module] + build_info.module.link, + [ep], + opts, + ) + self.pipeline = device.create_compute_pipeline( + program, + defer_target_compilation=True, + label=f"{build_info.module.name}_{build_info.name}_compute_call", + ) + build_info.module.pipeline_cache[hash] = self.pipeline + elif build_info.pipeline_type == PipelineType.ray_tracing: + # Create ray tracing pipeline + eps = [module.entry_point(f"raygen_main", type_conformances)] + hit_group_names: list[str] = [] + for hit_group in build_info.ray_tracing_hit_groups: + hit_group_names.append(hit_group.hit_group_name) + if hit_group.closest_hit_entry_point != "": + eps.append( + build_info.module.device_module.entry_point( + hit_group.closest_hit_entry_point + ) + ) + if hit_group.any_hit_entry_point != "": + eps.append( + build_info.module.device_module.entry_point( + hit_group.any_hit_entry_point + ) + ) + if hit_group.intersection_entry_point != "": + eps.append( + build_info.module.device_module.entry_point( + hit_group.intersection_entry_point + ) + ) + for miss_entry_point in build_info.ray_tracing_miss_entry_points: + eps.append(build_info.module.device_module.entry_point(miss_entry_point)) + + program = session.link_program( + [module, build_info.module.device_module] + build_info.module.link, + eps, + opts, + ) + self.pipeline = device.create_ray_tracing_pipeline( + program, + hit_groups=build_info.ray_tracing_hit_groups, + max_recursion=build_info.ray_tracing_max_recursion, + max_ray_payload_size=build_info.ray_tracing_max_ray_payload_size, + max_attribute_size=build_info.ray_tracing_max_attribute_size, + flags=build_info.ray_tracing_flags, + defer_target_compilation=True, + label=f"{build_info.module.name}_{build_info.name}_rt_call", + ) + build_info.module.pipeline_cache[hash] = self.pipeline + self.shader_table = device.create_shader_table( + program, + ray_gen_entry_points=["raygen_main"], + miss_entry_points=build_info.ray_tracing_miss_entry_points, + hit_group_names=hit_group_names, + callable_entry_points=build_info.ray_tracing_callable_entry_points, + ) + build_info.module.shader_table_cache[hash] = self.shader_table + else: + raise RuntimeError("Unknown pipeline type") + self.device = device + self.log_debug(f" Build succesful") + + # Store the code as it's useful for debugging + self.code = code + def _build_autograd_access_list(self, args: list[Any], kwargs: dict[str, Any]) -> None: """ Walk args/kwargs in the same recursive order as find_torch_tensors, diff --git a/slangpy/tests/slangpy_tests/test_code_gen.py b/slangpy/tests/slangpy_tests/test_code_gen.py index 23ab05f06..7fae1644a 100644 --- a/slangpy/tests/slangpy_tests/test_code_gen.py +++ b/slangpy/tests/slangpy_tests/test_code_gen.py @@ -45,13 +45,15 @@ def assert_not_contains(code: str, *patterns: str) -> None: def assert_trampoline_has(code: str, *stmts: str) -> None: - """Assert trampoline contains statements (tolerates call_data vs __calldata__).""" + """Assert trampoline contains statements (tolerates call_data vs __calldata__ vs __in_).""" for s in stmts: if "__calldata__." in s: alt = s.replace("__calldata__.", "call_data.") + # Fast path: x = __in_x; instead of x = __calldata__.x; + alt2 = s.replace(" = __calldata__.", " = __in_") assert ( - s in code or alt in code - ), f"Expected trampoline statement not found: {s} (or {alt})" + s in code or alt in code or alt2 in code + ), f"Expected trampoline statement not found: {s} (or {alt} or {alt2})" else: assert s in code, f"Expected trampoline statement not found: {s}" @@ -720,7 +722,10 @@ def test_long_type_name_typealias(device_type: spy.DeviceType): device, "sum", long_src, {"_type": _LONG_STRUCT_NAME, "x": 1.0, "y": 2.0} ) assert_contains(code_long, f"typealias _t_s = {_LONG_STRUCT_NAME};") - assert_contains(code_long, "_t_s s;") + # Typealias used in entry-point param or CallData field + assert ( + "_t_s s;" in code_long or "uniform _t_s s" in code_long + ), "Expected typealias usage (_t_s s; or uniform _t_s s) not found" # --- Short name → no typealias --- short_src = f""" diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index 41d58af98..9388257b7 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -1773,7 +1773,10 @@ def test_gate_long_struct_name_gets_typealias(device_type: spy.DeviceType): ) # Long name → typealias _t_s emitted, CallData field declared as _t_s assert_contains(code, f"typealias _t_s = {_LONG_STRUCT_NAME};") - assert_contains(code, "_t_s s;") + # Typealias used in entry-point param or CallData field + assert ( + "_t_s s;" in code or "uniform _t_s s" in code + ), "Expected typealias usage (_t_s s; or uniform _t_s s) not found" @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) From fadab3c13952b431a2a863c0d644bb2ffb30ba9d Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 09:59:10 +0000 Subject: [PATCH 22/41] PR cleanup --- ...-simplifyKernelGenPhase2-cleanup.prompt.md | 546 ++++++++++++++++++ slangpy/benchmarks/test_benchmark_autograd.py | 6 +- slangpy/core/calldata.py | 34 +- slangpy/core/callsignature.py | 9 +- slangpy/tests/slangpy_tests/test_code_gen.py | 6 +- .../tests/slangpy_tests/test_kernel_gen.py | 3 +- src/slangpy_ext/utils/slangpy.cpp | 158 ++--- 7 files changed, 632 insertions(+), 130 deletions(-) create mode 100644 .github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md diff --git a/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md b/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md new file mode 100644 index 000000000..f0d5124cd --- /dev/null +++ b/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md @@ -0,0 +1,546 @@ +## Phase 2: Eliminate CallData Struct + +**Goal**: Move kernel uniforms out of the `CallData` struct into individual entry-point parameters. Eliminate the trampoline in forward (prim) mode. Fall back to `ParameterBlock` when total inline-uniform size exceeds a runtime per-device threshold. + +**Parent plan**: [plan-simplifyKernelGen.prompt.md](plan-simplifyKernelGen.prompt.md) + +--- + +### Key Architectural Decisions + +These decisions correct several assumptions in the original plan: + +1. **Entry-point param placement is orthogonal to `direct_bind`.** Any type — wrapped or raw — can be an entry-point parameter (e.g., `uniform ValueType a` or `uniform int a` or `uniform Tensor t`). `direct_bind` governs whether `__slangpy_load`/`__slangpy_store` is needed inside the kernel; entry-point placement governs where the uniform lives in the shader layout. + +2. **Trampoline elimination is independent of `direct_bind`.** The current trampoline body is: declare locals → load (direct assignment or `__slangpy_load`) → call function → store (`__slangpy_store`). All of that can appear directly in `compute_main`. The trampoline only exists because bwds mode needs a `[Differentiable]` wrapper for `bwd_diff()`. In prim mode, it is eliminated regardless of whether args use wrappers. + +3. **All-or-nothing fallback.** When total inline-uniform size exceeds the platform threshold, ALL args go back into `ParameterBlock` (the current path). No hybrid mixing of entry-point params and CallData. + +4. **Shape arrays and `_thread_count` obey the same rules** as user args — they become entry-point params by default, and go into `CallData` on fallback. Phase 2 is NOT scoped only to `call_data_len == 0`. + +5. **Two code paths based on where data lives:** + - **Fast path** (entry-point params): In Slang, uniforms are entry-point parameters and can be used directly (in forward) or passed directly to the trampoline (in backward). + - **Fallback path** (`ParameterBlock`): In Slang, uniforms live in a `CallData` struct. They must be read into local variables before being used (in forward) or passed to the trampoline (in backward). This is the current behavior. + +6. **C++ dispatch changes are isolated to `NativeCallData::exec`.** Marshalls receive a `ShaderCursor` pointing to wherever their data lives — they don't care whether it's inside a `CallData` struct or an entry-point param. In the fast path, `m_runtime->write_shader_cursor_pre_dispatch()` receives the entry-point cursor directly. No marshall code changes needed. + +7. **`CallDataMode` is eliminated.** The `global_data` vs `entry_point` distinction is removed entirely. On the fast path, all backends use entry-point params uniformly. On the fallback path, all backends use `ParameterBlock` — CUDA supports `ParameterBlock` and in practice will never hit the fallback (CUDA's inline-uniform limit is ~4KB). This removes the `CallDataMode` enum, the CUDA-specific `is_entry_point` codegen branch in `callsignature.py`, and the corresponding C++ branch in `slangpy.cpp`. + +8. **`PackedArg` / param-block types are unchanged.** They stay as `ParameterBlock` at module scope, orthogonal to Phase 2. + +--- + +### Current Kernel Structure (post-Phase 1) + +For `int add(int a, int b)` with scalar args `(1, 2)`: + +```slang +import "module"; +import "slangpy"; + +typealias _t_a = int; // Phase 1: raw type (was ValueType) +typealias _t__result = RWValueRef; // writable _result still wrapped +static const int _m__result = 0; // mapping constant only for _result + +struct CallData { + _t_a a; + _t_a b; + _t__result _result; + uint3 _thread_count; +}; + +void _trampoline(Context __slangpy_context__, CallData __calldata__) { + int a; + a = __calldata__.a; // Phase 1: direct assignment + int b; + b = __calldata__.b; // Phase 1: direct assignment + int _result; + _result = add(a, b); + __calldata__._result.__slangpy_store(__slangpy_context__.map(_m__result), _result); +} + +[shader("compute")] [numthreads(32,1,1)] +void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, ..., uniform CallData call_data) { + if (any(flat_call_thread_id >= call_data._thread_count)) return; + Context __slangpy_context__ = {flat_call_thread_id}; + _trampoline(__slangpy_context__, call_data); +} +``` + +### Target Kernel (Phase 2 fast path, prim mode, all direct-bind) + +```slang +import "module"; + +[shader("compute")] +[numthreads(32, 1, 1)] +void compute_main(int3 tid: SV_DispatchThreadID, + uniform uint3 _thread_count, + uniform int a, + uniform int b, + uniform RWStructuredBuffer _result) +{ + if (any(tid >= _thread_count)) return; + _result[0] = add(a, b); +} +``` + +### Target Kernel (Phase 2 fast path, prim mode, mixed direct/non-direct-bind) + +When some args are not direct-bind (e.g., WangHashArg needs per-thread `thread_id` via `__slangpy_load`), the non-direct-bind args still use their wrapper types as entry-point params. Context is needed: + +```slang +import "module"; +import "slangpy"; + +typealias _t_rng = WangHashArgType; // non-direct-bind wrapper type +static const int _m_rng = 0; + +[shader("compute")] +[numthreads(32, 1, 1)] +void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, + uniform uint3 _thread_count, + uniform _t_rng rng, + uniform int x, + uniform RWStructuredBuffer _result) +{ + if (any(flat_call_thread_id >= _thread_count)) return; + Context __slangpy_context__ = {flat_call_thread_id}; + int _rng_val; + rng.__slangpy_load(__slangpy_context__.map(_m_rng), _rng_val); + int _x_val; + _x_val = x; + int _result_val; + _result_val = func(_rng_val, _x_val); + _result[0] = _result_val; +} +``` + +### Target Kernel (Phase 2 fallback path, prim mode) + +When entry-point param size exceeds the platform limit, all args go into `ParameterBlock`. The trampoline is still eliminated in prim mode — the load/call/store is inlined into `compute_main`, reading from `call_data`: + +```slang +import "module"; +import "slangpy"; + +typealias _t_a = int; +typealias _t__result = RWValueRef; +static const int _m__result = 0; + +struct CallData { + _t_a a; + _t_a b; + _t__result _result; + uint3 _thread_count; +}; +ParameterBlock call_data; + +[shader("compute")] +[numthreads(32, 1, 1)] +void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, ...) { + if (any(flat_call_thread_id >= call_data._thread_count)) return; + Context __slangpy_context__ = {flat_call_thread_id}; + int a; + a = call_data.a; + int b; + b = call_data.b; + int _result; + _result = add(a, b); + call_data._result.__slangpy_store(__slangpy_context__.map(_m__result), _result); +} +``` + +--- + +### Step 2.0: Gating tests ✅ + +**Status: DONE** + +Tests added to [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py). All 21 parametrized cases (7 tests × 3 device types) pass. + +| Test | Source | Args | Original assertion | Status | +|------|--------|------|--------------------|--------| +| `test_gate_p2_calldata_struct_present` | `int add(int a, int b)` | `(1, 2)` | `struct CallData` in code | ✅ Flipped — now asserts `struct CallData` ABSENT (Step 2.2 done) | +| `test_gate_p2_calldata_uniform_param` | same | same | `uniform CallData call_data` or `ParameterBlock` | ✅ Flipped — now asserts both ABSENT (Step 2.2 done) | +| `test_gate_p2_thread_count_in_calldata` | same | same | `call_data._thread_count` | ✅ Flipped — now asserts ABSENT (Step 2.2 done) | +| `test_gate_p2_trampoline_present_for_prim` | same | same | `void _trampoline(` present | Still asserts present (Step 2.3 pending) | +| `test_gate_p2_kernel_calls_trampoline` | same | same | `_trampoline(` in `compute_main` body | Still asserts present (Step 2.3 pending) | +| `test_gate_p2_sv_group_id_present` | same | same | `SV_GroupID` in `compute_main` signature | ✅ Flipped — now asserts ABSENT for dim-0 calls (Step 2.2 done) | + +Negative gates (must stay passing after Phase 2): + +| Test | Asserts | +|------|---------| +| `test_gate_p2_wanghasharg_keeps_load` | Non-direct-bind arg still uses `__slangpy_load` | + +Bwds gates: + +| Test | Status | +|------|--------| +| `test_gate_scalar_uses_valuetype` | ✅ Passing — asserts fast-path trampoline with `__in_` prefix params | +| `test_gate_bwds_scalar_uses_valuetype` | ✅ Passing — bwds trampoline has `no_diff` on all params (Step 2.4 done) | + +--- + +### Step 2.1: Determine fast vs fallback path ✅ + +**Status: DONE** + +In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct_binding(bindings)`: + +1. **Query a runtime per-device threshold** for max entry-point parameter inline-uniform size. This is a property of the device/backend — large for D3D12/CUDA (thousands of bytes), potentially as low as 128–256 bytes on Vulkan. +2. **Accumulate inline-uniform byte size** of each bound variable's `calldata_type_name`, plus `_thread_count` (12 bytes) and shape arrays (`call_data_len * 3 * sizeof(int)` for `_grid_stride`, `_grid_dim`, `_call_dim`). **Resource types** (`RWStructuredBuffer`, `Texture2D`, `TensorView`, etc.) don't count — they are bound as descriptors, not inline data. +3. **Decision**: If total size ≤ threshold → `self.use_direct_args = True` (fast path). Otherwise → `self.use_direct_args = False` (fallback path — current behavior). +4. **Store** `use_direct_args` on the `CallData` instance and propagate to C++ `NativeCallData`. + +`PackedArg` / param-block types are excluded from this accounting — they stay as `ParameterBlock` regardless. + +**Implementation details:** + +- `DeviceLimits.max_entry_point_uniform_size` added to C++ struct ([device.h](src/sgl/device/device.h)) with per-backend defaults: Vulkan=128, D3D12=256, CUDA=4096 bytes ([device.cpp](src/sgl/device/device.cpp)). +- `calculate_inline_uniform_size()` added to [callsignature.py](slangpy/core/callsignature.py) — sums `vector_type.uniform_layout.size` for each depth-0 bound variable (skipping `PackedArg`), plus 12 bytes for `_thread_count` and `call_dimensionality * 4 * 3` for shape arrays. +- `use_direct_args` property added to `NativeCallData` C++ class ([slangpy.h](src/slangpy_ext/utils/slangpy.h)) with Python binding. +- `CallData.__init__()` in [calldata.py](slangpy/core/calldata.py) sets `self.use_direct_args = inline_size <= threshold` after `calculate_direct_binding()`. + +**Tests** (7 tests × 3 device types = 21 parametrized cases, all pass): + +| Test | Asserts | +|------|---------| +| `test_step21_scalar_uses_direct_args` | Simple `int add(int,int)` with `(1,2)` → `use_direct_args=True` | +| `test_step21_threshold_property_positive` | `device.info.limits.max_entry_point_uniform_size > 0` | +| `test_step21_vector_uses_direct_args` | `float3` args → `use_direct_args=True` | +| `test_step21_struct_uses_direct_args` | All-scalar struct dict → `use_direct_args=True` | +| `test_step21_tensor_uses_direct_args` | Tensor (descriptor-only, 0 inline bytes) → `use_direct_args=True` | +| `test_step21_many_float4x4_may_exceed_vulkan` | 8×float4x4 (524 bytes) exceeds Vulkan/D3D12 thresholds, not CUDA | +| `test_step21_wanghasharg_uses_direct_args` | Non-direct-bind WangHashArg with small inline size → `use_direct_args=True` | + +--- + +### Step 2.2: Code generation — entry-point params (fast path) ✅ + +**Status: DONE** + +In [slangpy/core/callsignature.py](slangpy/core/callsignature.py) `generate_code()`, when `use_direct_args == True`: + +**CodeGen changes** in [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py): +- Add a `skip_call_data` flag to `CodeGen.__init__`. When `True`, don't emit `struct CallData` / `begin_block()` and gate `end_block()` in `finish()`. +- Add `self.entry_point_params: list[str] = []` to collect individual uniform param declarations. +- `finish()` ignores the `call_data` block and `use_param_block_for_call_data` when `skip_call_data` is set. + +**CallData struct elimination**: Set `cg.skip_call_data = True` when `use_direct_args`. No `struct CallData` emitted. + +**`gen_call_data_code` change** in [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py): At `depth == 0`, when `use_direct_args`, append to `cg.entry_point_params` instead of `cg.call_data.declare(...)`. The `call_data_structs` block (type aliases, wrapper structs, mapping constants) still gets emitted at module scope. + +**`_thread_count` and shape arrays**: Instead of `cg.call_data.append_statement("uint3 _thread_count")`, append to `cg.entry_point_params`. Same for `_grid_stride`, `_grid_dim`, `_call_dim` when `call_data_len > 0`. + +**Entry-point signature**: `compute_main` signature becomes: +```slang +void compute_main( + int3 flat_call_thread_id: SV_DispatchThreadID, + [int3 flat_call_group_id: SV_GroupID,] // only when call_data_len > 0 + [int flat_call_group_thread_id: SV_GroupIndex,] // only when call_data_len > 0 + uniform uint3 _thread_count, + [uniform int[N] _grid_stride, ...] // only when call_data_len > 0 + uniform _t_a a, + uniform _t_b b, + uniform _t__result _result +) +``` + +Drop `SV_GroupID` and `SV_GroupIndex` when `call_data_len == 0` — they feed `init_thread_local_call_shape_info` which isn't called when there are no shape arrays. + +**Bounds check**: Changes from `call_data._thread_count` to just `_thread_count`. + +**Shape info init**: Changes from `call_data._grid_stride` etc. to just `_grid_stride`, `_grid_dim`, `_call_dim`. + +**Fallback path** (`use_direct_args == False`): `struct CallData` is emitted with `ParameterBlock call_data` at module scope on ALL backends (including CUDA). The old `CallDataMode` distinction between `entry_point` (CUDA) and `global_data` (non-CUDA) is removed — `ParameterBlock` works on CUDA, and in practice CUDA will never hit the fallback due to its large (~4KB) inline-uniform limit. + +See [slangpy/tests/device/test_pipeline_utils.slang](slangpy/tests/device/test_pipeline_utils.slang) for examples of manually-written compute shaders that use entry point parameters on all backends (CUDA, Vulkan, D3D12). + +--- + +### Step 2.3: Trampoline elimination for prim mode + +**Status: NOT STARTED** — Trampoline is still generated for prim mode on both paths. The load/call/store sequence needs to be inlined into `compute_main`. + +When `call_mode == prim` — on **both** fast and fallback paths: + +- Don't generate the `_trampoline` function. +- Inline the load/call/store sequence directly into `compute_main` after the bounds check and (if needed) Context construction. +- The load/call/store codegen reuses the same logic currently in [callsignature.py lines 378–449](slangpy/core/callsignature.py#L378-L449), but emitted into `cg.kernel` instead of `cg.trampoline` with adjusted `data_name`: + +| Path | `data_name` for non-param-block args | +|------|-------------------------------------| +| Fast | `x.variable_name` (entry-point param name directly) | +| Fallback | `call_data.{x.variable_name}` (global `ParameterBlock`, all backends) | +| Param blocks | `_param_{x.variable_name}` (unchanged) | + +**Context construction**: Needed only when any arg is non-direct-bind (i.e., calls `__slangpy_load`/`__slangpy_store`). When all args satisfy `direct_bind == True`, skip Context construction entirely — no `Context __slangpy_context__` declaration, no `import "slangpy"`. + +**Note**: The trampoline elimination does NOT depend on `direct_bind`. Even non-direct-bind args with `__slangpy_load` work inline in `compute_main` — the `__slangpy_load` call just needs the data reference and a `Context` value, both available in `compute_main`. + +--- + +### Step 2.4: Trampoline with individual params for bwds mode ✅ + +**Status: DONE** — Fast-path trampoline takes individual params with `no_diff` on all params. All 3 device types pass. + +When `call_mode == bwds`: + +- Still generate a `[Differentiable]` trampoline function. +- **Fast path**: Trampoline takes individual params instead of a struct. All params get `no_diff` — entry-point uniforms are never differentiable. Differentiation happens through local variable assignments inside the trampoline body, matching the struct-based approach where `CallData` was implicitly non-differentiable. No `in`/`out`/`inout` modifiers are added — `compute_main` passes its uniforms straight through: + ```slang + [Differentiable] + void _trampoline(Context __slangpy_context__, no_diff float __in_a, no_diff float __in_b, no_diff NoneType __in__result) + ``` + `compute_main` calls `bwd_diff(_trampoline)(__slangpy_context__, a, b, _result)` passing entry-point param names directly. +- **Fallback path**: Trampoline reads from global `ParameterBlock call_data` as it does today (on all backends). `compute_main` calls `bwd_diff(_trampoline)(__slangpy_context__, call_data)`. +- `_gen_trampoline_argument()` in `boundvariable.py` remains unused dead code — the inline generation in `callsignature.py` is simpler and avoids the `in`/`out`/`inout` modifiers that caused Slang autodiff errors. + +**Key insight**: Adding `in`/`out`/`inout` modifiers to trampoline params caused Slang autodiff issues (e.g., `out` params get reversed to `in` by `bwd_diff`, changing arity). The trampoline params are just pass-through uniforms — all data flow logic (loads, stores, differentiation) is handled internally via local variables. + +--- + +### Step 2.5: C++ dispatch changes ✅ + +**Status: DONE** — `CallDataMode` enum fully removed. Fast path uses `find_entry_point(0)` on all backends. Fallback path uses global `ParameterBlock` on all backends. + +In [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp), store `m_use_direct_args` on `NativeCallData` (received from Python `CallData`). Also add to [slangpy.h](src/slangpy_ext/utils/slangpy.h). + +Modify `bind_call_data` lambda in `exec()`: + +**Fast path** (`m_use_direct_args == true`): +- All backends: Navigate via `cursor.find_entry_point(0)`. This is the entry-point cursor. +- Write `_thread_count` as an entry-point param: `entry_point_cursor["_thread_count"]`. +- Write shape arrays as entry-point params: `entry_point_cursor["_grid_stride"]`, etc. +- Pass `entry_point_cursor` as the `call_data_cursor` argument to `m_runtime->write_shader_cursor_pre_dispatch()`. Each `NativeBoundVariableRuntime` already navigates `cursor[m_variable_name]`, so it finds the entry-point param by name automatically. **No marshall code changes needed.** +- Cache entry-point param field indices on first call (analogous to existing `m_cached_call_data_offsets`). +- The `reserve_data` + raw-pointer optimization for `_thread_count` and shape arrays may not work for individual entry-point params at disjoint offsets. Use cursor-based writes for these metadata fields (they're small, performance impact minimal), or check if `reserve_data` still works across the entry-point shader object. + +**Fallback path** (`m_use_direct_args == false`): +- All backends: Navigate to global `call_data` field via `cursor.find_field("call_data")`, dereference (it's a `ParameterBlock`), write struct data. The old `CallDataMode` branch (CUDA using `find_entry_point(0)` for call_data) is removed. Remove `m_call_data_mode`, `CallDataMode` enum, and all associated branches from `slangpy.h`, `slangpy.cpp`, `calldata.py`, and `callsignature.py`. + +--- + +### Step 2.6: `_result` handling + +**Status: NOT STARTED** + +Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligible (needs `RWValueRef` wrapper with buffer logic). Phase 2 handles this differently on the two paths: + +**Fast path**: `_result` is emitted as `uniform RWValueRef _result` on the entry point. In prim mode, the inlined code stores via `_result.__slangpy_store(...)`. In the all-direct-bind case where Context is omitted, add a new code path: emit `uniform RWStructuredBuffer _result` with `_result[0] = value` for the store. This requires `ValueRefMarshall` to support writable direct-bind for the entry-point-param case specifically, using `RWStructuredBuffer` instead of `RWValueRef`. + +**Fallback path**: `_result` stays as `RWValueRef` inside `CallData`, same as current behavior. + +**Implementation note**: The `RWStructuredBuffer` approach for `_result` is only used when `use_direct_args == True` AND all other args are direct-bind (so Context can be omitted). When non-direct-bind args are present, Context exists and `_result` can continue to use `RWValueRef.__slangpy_store(context, value)`. + +--- + +### Step 2.7: Tests + +**Status: NOT STARTED** + +**Post-implementation tests** — should pass AFTER Phase 2 is complete: + +| Test | Verifies | +|------|----------| +| `test_phase2_no_calldata_struct` | `struct CallData` absent for eligible call | +| `test_phase2_uniform_params_on_entry` | Individual `uniform` params on `compute_main` | +| `test_phase2_no_trampoline_prim` | No `void _trampoline(` for prim-mode calls | +| `test_phase2_inline_call` | Function call inlined directly in `compute_main` | +| `test_phase2_thread_count_as_uniform` | `uniform uint3 _thread_count` as entry-point param | +| `test_phase2_no_context_all_direct` | No `Context __slangpy_context__` when all args direct-bind | +| `test_phase2_context_kept_non_direct` | `Context` present when some args use `__slangpy_load` | +| `test_phase2_bwds_trampoline_individual` | Bwds trampoline has individual params with `no_diff` | +| `test_phase2_bwds_bwd_diff_call` | `bwd_diff(_trampoline)(ctx, a, b, ...)` in kernel | +| `test_phase2_no_sv_group_when_dim0` | No `SV_GroupID`/`SV_GroupIndex` when `call_data_len == 0` | +| `test_phase2_sv_group_when_vectorized` | `SV_GroupID`/`SV_GroupIndex` present when `call_data_len > 0` | +| `test_phase2_fallback_keeps_calldata` | Force fallback → `struct CallData` still emitted | +| `test_phase2_fallback_no_trampoline_prim` | Even fallback path eliminates trampoline in prim mode | +| `test_phase2_functional_scalar_add` | `add(1, 2) == 3` end-to-end dispatch | +| `test_phase2_functional_bwds` | Backward pass correct gradients | +| `test_phase2_functional_vectorized` | Vectorized call (shapes) with entry-point params | +| `test_phase2_functional_mixed_direct` | Mix of direct-bind + non-direct-bind args | + +--- + +### Implementation Order + +1. **Step 2.0** ✅ — Gating tests (baseline documentation) +2. **Step 2.1** ✅ — Fast/fallback determination + size query +3. **Step 2.2 + 2.5** ✅ — Code gen + C++ dispatch for entry-point params + `CallDataMode` removal (landed together) +4. **Step 2.4** ✅ — Bwds trampoline with individual params (fast path) — `no_diff` on all params +5. **Step 2.3** — Trampoline elimination for prim mode (both paths) +6. **Step 2.6** — `_result` as `RWStructuredBuffer` for all-direct-bind case +7. **Step 2.7** — Post-implementation tests + functional tests + +**Note:** Implementation order deviated from original plan — Steps 2.2 + 2.5 were done before 2.3 (trampoline elimination), combined with `CallDataMode` removal. Step 2.4 done — all trampoline params use `no_diff` without IO modifiers. + +--- + +### Key Files + +| File | Changes | +|------|---------| +| [slangpy/core/calldata.py](slangpy/core/calldata.py) | ✅ `use_direct_args` flag, size threshold check, `CallDataMode` removed | +| [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | ✅ Entry-point params, fast/fallback code paths, `is_entry_point` branch removed. Trampoline still generated (Step 2.3 pending). Bwds `no_diff` on all trampoline params (Step 2.4 done). | +| [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py) | ✅ `skip_call_data` flag, `entry_point_params` list | +| [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | ✅ `gen_call_data_code` depth-0 entry-point path. `_gen_trampoline_argument()` unused — inline generation in `callsignature.py` used instead. | +| [slangpy/bindings/marshall.py](slangpy/bindings/marshall.py) | ✅ `use_direct_args` field on `BindContext`, `CallDataMode` removed | +| [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | ✅ `use_direct_args` binding; `bind_call_data` fast path via `find_entry_point(0)`, `CallDataMode` branches removed | +| [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | ✅ `m_use_direct_args` on `NativeCallData`; `m_call_data_mode` removed | +| [src/sgl/device/device.h](src/sgl/device/device.h) | ✅ `max_entry_point_uniform_size` on `DeviceLimits` | +| [src/sgl/device/device.cpp](src/sgl/device/device.cpp) | ✅ Per-backend defaults for `max_entry_point_uniform_size` | +| [src/slangpy_ext/device/device.cpp](src/slangpy_ext/device/device.cpp) | ✅ Python binding for `max_entry_point_uniform_size` | +| [src/sgl/utils/slangpy.h](src/sgl/utils/slangpy.h) | ✅ `CallDataMode` enum removed | +| [slangpy/core/dispatchdata.py](slangpy/core/dispatchdata.py) | ✅ `CallDataMode` removed | +| [slangpy/core/packedarg.py](slangpy/core/packedarg.py) | ✅ `CallDataMode` removed | +| [slangpy/core/function.py](slangpy/core/function.py) | ✅ `CallDataMode` removed from imports | +| [slangpy/slangpy/__init__.pyi](slangpy/slangpy/__init__.pyi) | ✅ `CallDataMode` class and `call_data_mode` property removed | +| [slangpy/tests/slangpy_tests/test_type_resolution.py](slangpy/tests/slangpy_tests/test_type_resolution.py) | ✅ `CallDataMode` removed from `BindContext` creation | +| [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) | ✅ Gating tests + Step 2.1 tests updated for new behavior; post-implementation tests (Step 2.7) pending | + +--- + +### Verification + +```bash +# Build first (required) +cmake --build --preset windows-msvc-debug + +# Run kernel gen tests +$env:PRINT_TEST_KERNEL_GEN="1"; pytest slangpy/tests/slangpy_tests/test_kernel_gen.py -v + +# Run full test suite +pytest slangpy/tests -v + +# Run pre-commit +pre-commit run --all-files +``` + +--- + +### PR #862 Code Review — Proposed Improvements + +#### High Severity + +**1. Potential correctness bug — fast-path shape offset caching guarded by runtime data** + +In [slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) `bind_call_data`, the fast-path caching block guards shape offset caching with `call_shape.size() > 0`. If the *first* call to a multi-dimensional `NativeCallData` uses `has_thread_count=true` (which returns empty `call_shape`), shape offsets won't be cached. A subsequent normal call would find `is_valid == true` but shape offsets would be uninitialized, leading to writes at garbage offsets. The fallback path is more robust, using `call_dim.is_valid()` instead. + +**DO NOT FIX**: Reason: The '_thread_count' is written to the call signature, so by definition a given call data would never be used in both situations. + +**2. Benchmark changes are debugging artifacts** + +[test_benchmark_autograd.py](slangpy/benchmarks/test_benchmark_autograd.py) changes `ITERATIONS` 10→100, `WARMUPS` 10→1000, `RUN_SLANGTORCH_BENCHMARK` False→True. This will make CI benchmarks 10–100× slower. Revert to original values. + +**FIX**: Restore changes to original values. + +**3. Overly broad `except Exception` in calldata.py fallback** + +[calldata.py](slangpy/core/calldata.py): The fallback from fast path to `ParameterBlock` catches `except Exception`, which swallows `TypeError`, `KeyError`, `AttributeError`, etc. The caught exception `e` is never logged. + +**Fix**: Narrow to the specific compilation exception (RunTimeError) Log `str(e)` in the debug message. + +--- + +#### Medium Severity — Structural + +**4. `generate_code()` in callsignature.py is too long (~334 lines)** + +Extract into sub-functions: + +| Lines | Extract to | Purpose | +|-------|-----------|---------| +| ~L294–L339 | `_validate_and_compute_group_shape()` | Group shape validation & stride computation | +| ~L341–L388 | `_generate_link_time_constants()` | Link-time constants (group shape/stride arrays) | +| ~L390–L409 | `_generate_shape_params()` | Shape array & `_thread_count` param gen (fast/fallback) | +| ~L415–L517 | `_generate_trampoline()` | Trampoline function (signature, loads, call, stores) | +| ~L520–L565 | `_generate_entry_point_signature()` | Compute/ray-tracing entry-point signature | +| ~L567–L604 | `_generate_kernel_body()` | Kernel body (bounds check, shape init, dispatch) | + +Additionally, the duplicated `data_name` computation at ~L449 and ~L497 should be extracted: +```python +def _data_name(x: BoundVariable, use_direct_args: bool) -> str: + if x.create_param_block: + return f"_param_{x.variable_name}" + return f"__in_{x.variable_name}" if use_direct_args else f"call_data.{x.variable_name}" +``` + +**DO NOT FIX** Reason: This is a complex change and will be deferred to a later step. + +**5. `bind_call_data` in slangpy.cpp has ~70 lines of duplicated write logic** + +The `reserve_data` + `write_strided_array_helper` ×3 + `write_value_helper` + `write_shader_cursor_pre_dispatch` sequence is identical between fast and fallback paths. Extract a helper that takes a `ShaderCursor`: + +```cpp +auto write_uniforms = [&](ShaderCursor target) { + ShaderObject* so = target.shader_object(); + void* base = so->reserve_data(offsets.field_offset, offsets.field_size); + // ... write shape arrays, thread_count ... + m_runtime->write_shader_cursor_pre_dispatch(context, cursor, target, ...); +}; +``` + +Fast path → `write_uniforms(ep)`, fallback → `write_uniforms(call_data_cursor)`. + +**FIX**: Create helper as above. + +**6. `_try_build_shader` parameter pattern in calldata.py** + +Takes `use_direct_args` parameter then immediately sets `self.use_direct_args` and `context.use_direct_args`. The method never reads the flag except to store it. + +**Fix**: Have the caller set these directly before calling, remove the parameter. + +--- + +#### Low Severity + +**7. Unconditional `print(code)` in test_kernel_gen.py L107** — should be guarded by `PRINT_TEST_KERNEL_GEN` env var. + +**FIX**: Wrap in `if os.getenv("PRINT_TEST_KERNEL_GEN"):` and add `import os` at the top. + +**8. Test duplication** — ~30 tests near-identical between test_kernel_gen.py and test_code_gen.py. The merged tests in test_code_gen.py should replace the originals. + +**DO NOT FIX**: Reason: The kernel gen tests are temporary, designed for gating, and will be deleted once phases are complete. + +**9. Unused `nodes` variable** — [callsignature.py L278](slangpy/core/callsignature.py): `nodes: list[BoundVariable] = []` declared but never used. + +**FIX**: If not needed it can be deleted + +**10. Stale docstring** — [callsignature.py L275](slangpy/core/callsignature.py): Says "Generate a list of call data nodes" — doesn't match what the function does. + +**FIX**: Update docstring to reflect current behavior. + +**11. Missing return type annotations** — `generate_code()`, `generate_constants()`, `CallData.build()` all need `-> None`. + +**FIX**: Add `-> None` return type annotations to these methods. + +**12. `type_conformances: Any`** — [calldata.py](slangpy/core/calldata.py) should be `list[TypeConformance]`. + +**FIX**: Change type annotation to `list[TypeConformance]` and import `TypeConformance` from `slangpy.core.typeconformance`. + +**13. Bare `except:`** — [callsignature.py L59](slangpy/core/callsignature.py): `is_generic_vector` catches all exceptions including `SystemExit`. Use `except Exception:`. + +**FIX**: Change to `except Exception:` and log the exception for debugging. + +**14. Typo: `santized_module`** — [calldata.py](slangpy/core/calldata.py): Missing 'i'. Pre-existing. + +**DO NOT FIX**: Reason: Cosmetic typo in a variable name that's used in multiple places. Fixing would require renaming across the file, which is low value and risks introducing bugs. + +**15. D3D12 `max_entry_point_uniform_size = 256` may be optimistic** — root descriptors consume some of the 64-DWORD root signature budget. Comment should note shared budget; consider smaller default. + +**DO NOT FIX**: Reason: More complex logic is actually needed and can be addressed later. + +**16. Fallback path always includes `SV_GroupID`/`SV_GroupIndex`** — even when `call_data_len == 0`. Asymmetric with fast path. + +**DO NOT FIX**: Reason: Can be addressed later. + +**17. Hash salt `"[CallData]\n"`** — emitted even when CallData struct is absent. Cosmetic. + +**FIX**: Can be removed or left as-is, low impact. + +**18. `Tuple` import in test_code_gen.py** — should use lowercase `tuple[...]` for consistency. + +**FIX**: Change to `tuple[...]` and remove `from typing import Tuple`. diff --git a/slangpy/benchmarks/test_benchmark_autograd.py b/slangpy/benchmarks/test_benchmark_autograd.py index 94739cc8d..524794d8e 100644 --- a/slangpy/benchmarks/test_benchmark_autograd.py +++ b/slangpy/benchmarks/test_benchmark_autograd.py @@ -27,9 +27,9 @@ pass SLEEPS = True -ITERATIONS = 100 +ITERATIONS = 10 SUB_ITERATIONS = 20000 -WARMUPS = 1000 +WARMUPS = 10 # ITERATIONS = 1 # SUB_ITERATIONS = 1 @@ -49,7 +49,7 @@ # ============================================================================= RUN_PURE_TORCH_BENCHMARK = False -RUN_SLANGTORCH_BENCHMARK = True +RUN_SLANGTORCH_BENCHMARK = False RUN_SLANGPY_MANUAL_HOOK_BENCHMARK = True RUN_SLANGPY_AUTOMATIC_BENCHMARK = True AUTOGRAD_TENSOR_SIZE = 32 diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index 2ad6942ec..947d84e1d 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -19,6 +19,7 @@ SlangLinkOptions, NativeHandle, DeviceType, + TypeConformance, is_torch_bridge_using_fallback, ) from slangpy.bindings import ( @@ -112,7 +113,7 @@ def __init__( build_info = func.calc_build_info() self.build(build_info, *args, **kwargs) - def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): + def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any) -> None: self.has_thread_count = "_thread_count" in kwargs try: @@ -269,29 +270,37 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any): f"use_direct_args: {use_direct_args}" ) + # Until https://github.com/shader-slang/slang-rhi/pull/676, Vk RTP can't use entry point args + if ( + build_info.pipeline_type == PipelineType.ray_tracing + and build_info.module.device.info.type == DeviceType.vulkan + ): + use_direct_args = False + # Try building the shader. If direct args compilation fails (the # threshold is only an approximate heuristic), fall back to # ParameterBlock. try: + self.use_direct_args = use_direct_args self._try_build_shader( context, build_info, bindings, type_conformances, - use_direct_args=use_direct_args, ) - except Exception as e: + except RuntimeError as e: if not use_direct_args: raise self.log_debug( - " Direct args compilation failed, " "retrying with ParameterBlock" + f" Direct args compilation failed ({e}), " + "retrying with ParameterBlock" ) + self.use_direct_args = False self._try_build_shader( context, build_info, bindings, type_conformances, - use_direct_args=False, ) # Store the bindings and runtime for later use. @@ -364,23 +373,20 @@ def _try_build_shader( context: BindContext, build_info: "FunctionBuildInfo", bindings: BoundCall, - type_conformances: Any, - use_direct_args: bool, + type_conformances: list["TypeConformance"], ) -> None: """ Generate shader code and build the pipeline. - Sets self.use_direct_args, self.pipeline, self.device, self.code, + Sets self.pipeline, self.device, self.code, and optionally self.shader_table. :param context: Binding context. :param build_info: Function build information. :param bindings: Bound call with resolved variables. :param type_conformances: Type conformances for entry point. - :param use_direct_args: If True, use entry-point params; otherwise ParameterBlock. """ - self.use_direct_args = use_direct_args - context.use_direct_args = use_direct_args + context.use_direct_args = self.use_direct_args # Generate code. codegen = CodeGen() @@ -398,7 +404,7 @@ def _try_build_shader( snippets=True, call_data_structs=True, constants=True, - use_param_block_for_call_data=not use_direct_args, + use_param_block_for_call_data=not context.use_direct_args, ) # Optionally write the shader to a file for debugging. @@ -447,9 +453,7 @@ def _try_build_shader( # Hash the code to get a unique identifier for the module. # We add type conformances to the start of the code to ensure that the hash is unique - code_minus_header = ( - "[CallData]\n" + str(build_info.type_conformances) + code[len(codegen.header) :] - ) + code_minus_header = str(build_info.type_conformances) + code[len(codegen.header) :] hash = hashlib.sha256(code_minus_header.encode()).hexdigest() # Check if we've already built this module. diff --git a/slangpy/core/callsignature.py b/slangpy/core/callsignature.py index 9fdaf7140..524fc650a 100644 --- a/slangpy/core/callsignature.py +++ b/slangpy/core/callsignature.py @@ -55,7 +55,7 @@ def is_generic_vector(type: TypeReflection) -> bool: try: if type.scalar_type != TypeReflection.Kind.none and type.col_count > 0: # @IgnoreException return False - except: + except Exception: return True return True @@ -247,7 +247,7 @@ def is_slangpy_vector(type: Any) -> bool: ) -def generate_constants(build_info: "FunctionBuildInfo", cg: CodeGen): +def generate_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: if build_info.constants is not None: for k, v in build_info.constants.items(): if isinstance(v, bool): @@ -272,11 +272,10 @@ def generate_code( build_info: "FunctionBuildInfo", signature: BoundCall, cg: CodeGen, -): +) -> None: """ - Generate a list of call data nodes that will be used to generate the call + Generate Slang kernel code for the given function call signature. """ - nodes: list[BoundVariable] = [] # Check if we're using direct entry-point params (fast path) use_direct_args = context.use_direct_args diff --git a/slangpy/tests/slangpy_tests/test_code_gen.py b/slangpy/tests/slangpy_tests/test_code_gen.py index 7fae1644a..9992e9d02 100644 --- a/slangpy/tests/slangpy_tests/test_code_gen.py +++ b/slangpy/tests/slangpy_tests/test_code_gen.py @@ -13,7 +13,7 @@ by other test files (``test_simple_function_call.py``, ``test_tensor.py``, etc.). """ -from typing import Any, Tuple +from typing import Any import numpy as np import os @@ -60,7 +60,7 @@ def assert_trampoline_has(code: str, *stmts: str) -> None: def generate_code_and_bindings( device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any -) -> Tuple[str, Any]: +) -> tuple[str, Any]: """Generate code and return ``(code_str, bindings)`` from a single ``debug_build_call_data`` call.""" func = helpers.create_function_from_module(device, func_name, module_source) cd = func.debug_build_call_data(*args, **kwargs) @@ -71,7 +71,7 @@ def generate_code_and_bindings( def generate_bwds_code_and_bindings( device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any -) -> Tuple[str, Any]: +) -> tuple[str, Any]: """Generate backwards-mode code and return ``(code_str, bindings)``.""" func = helpers.create_function_from_module(device, func_name, module_source) cd = func.bwds.debug_build_call_data(*args, **kwargs) diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index 9388257b7..0866c0422 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -105,7 +105,8 @@ def test_kernel_gen_basic(device_type: spy.DeviceType): """ device = helpers.get_device(device_type) code = generate_code(device, "add", src, 1, 2) - print(code) + if PRINT_TEST_KERNEL_GEN: + print(code) assert "add" in code diff --git a/src/slangpy_ext/utils/slangpy.cpp b/src/slangpy_ext/utils/slangpy.cpp index 98b7c8c03..f3e7e086a 100644 --- a/src/slangpy_ext/utils/slangpy.cpp +++ b/src/slangpy_ext/utils/slangpy.cpp @@ -807,6 +807,58 @@ nb::object NativeCallData::exec( ); } + auto write_uniforms = [&](ShaderCursor target, ShaderCursor root_cursor) + { + // Reserve memory block for all uniform fields + ShaderObject* shader_object = target.shader_object(); + void* base_address = shader_object->reserve_data( + m_cached_call_data_offsets.field_offset, + m_cached_call_data_offsets.field_size + ); + + if (call_shape.size() > 0) { + // Write shape arrays using cached offsets + write_strided_array_helper( + base_address, + m_cached_call_data_offsets.call_dim.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + call_shape.data(), + call_shape.size(), + m_cached_call_data_offsets.array_stride + ); + + write_strided_array_helper( + base_address, + m_cached_call_data_offsets.grid_stride.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + call_grid_strides.data(), + call_grid_strides.size(), + m_cached_call_data_offsets.array_stride + ); + + write_strided_array_helper( + base_address, + m_cached_call_data_offsets.grid_dim.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + call_grid_shape.data(), + call_grid_shape.size(), + m_cached_call_data_offsets.array_stride + ); + } + + // Write thread count + uint3 thread_count_value(total_threads, 1, 1); + write_value_helper( + base_address, + m_cached_call_data_offsets.thread_count.uniform_offset + - m_cached_call_data_offsets.field_offset.uniform_offset, + thread_count_value + ); + + m_runtime + ->write_shader_cursor_pre_dispatch(context, root_cursor, target, unpacked_args, unpacked_kwargs, read_back); + }; + auto bind_call_data = [&](ShaderCursor cursor) { if (m_use_direct_args) { @@ -830,54 +882,7 @@ nb::object NativeCallData::exec( m_cached_call_data_offsets.is_valid = true; } - // Reserve memory block for all entry-point uniform fields - ShaderObject* shader_object = ep.shader_object(); - void* base_address = shader_object->reserve_data( - m_cached_call_data_offsets.field_offset, - m_cached_call_data_offsets.field_size - ); - - if (call_shape.size() > 0) { - // Write shape arrays using cached offsets - write_strided_array_helper( - base_address, - m_cached_call_data_offsets.call_dim.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - call_shape.data(), - call_shape.size(), - m_cached_call_data_offsets.array_stride - ); - - write_strided_array_helper( - base_address, - m_cached_call_data_offsets.grid_stride.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - call_grid_strides.data(), - call_grid_strides.size(), - m_cached_call_data_offsets.array_stride - ); - - write_strided_array_helper( - base_address, - m_cached_call_data_offsets.grid_dim.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - call_grid_shape.data(), - call_grid_shape.size(), - m_cached_call_data_offsets.array_stride - ); - } - - // Write thread count - uint3 thread_count_value(total_threads, 1, 1); - write_value_helper( - base_address, - m_cached_call_data_offsets.thread_count.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - thread_count_value - ); - - // Pass entry-point cursor as call_data_cursor — marshalls navigate ep[var_name] - m_runtime->write_shader_cursor_pre_dispatch(context, cursor, ep, unpacked_args, unpacked_kwargs, read_back); + write_uniforms(ep, cursor); } else { // ---- Fallback path: ParameterBlock at module scope (all backends) ---- // On first call, cache all field indices and offsets @@ -906,67 +911,14 @@ nb::object NativeCallData::exec( m_cached_call_data_offsets.is_valid = true; } - // Fast path: use cached field index to find call_data cursor + // Use cached field index to find call_data cursor ShaderCursor call_data_cursor = cursor.get_field_by_index(m_cached_call_data_offsets.call_data_field_index); // Dereference the cursor if needed (using cached result) if (m_cached_call_data_offsets.call_data_is_reference) call_data_cursor = call_data_cursor.dereference(); - // Reserve memory block for all call data fields - ShaderObject* shader_object = call_data_cursor.shader_object(); - void* base_address = shader_object->reserve_data( - m_cached_call_data_offsets.field_offset, - m_cached_call_data_offsets.field_size - ); - - if (call_shape.size() > 0) { - // Write arrays using cached offsets and direct memory access - write_strided_array_helper( - base_address, - m_cached_call_data_offsets.call_dim.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - call_shape.data(), - call_shape.size(), - m_cached_call_data_offsets.array_stride - ); - - write_strided_array_helper( - base_address, - m_cached_call_data_offsets.grid_stride.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - call_grid_strides.data(), - call_grid_strides.size(), - m_cached_call_data_offsets.array_stride - ); - - write_strided_array_helper( - base_address, - m_cached_call_data_offsets.grid_dim.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - call_grid_shape.data(), - call_grid_shape.size(), - m_cached_call_data_offsets.array_stride - ); - } - - // Write thread count - uint3 thread_count_value(total_threads, 1, 1); - write_value_helper( - base_address, - m_cached_call_data_offsets.thread_count.uniform_offset - - m_cached_call_data_offsets.field_offset.uniform_offset, - thread_count_value - ); - - m_runtime->write_shader_cursor_pre_dispatch( - context, - cursor, - call_data_cursor, - unpacked_args, - unpacked_kwargs, - read_back - ); + write_uniforms(call_data_cursor, cursor); } nb::list uniforms = opts->uniforms(); From ef17ff5346361875aca22b4cdc12038f8720d6ad Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 10:28:39 +0000 Subject: [PATCH 23/41] More PR fixes --- ...-simplifyKernelGenPhase2-cleanup.prompt.md | 91 ++++++++++++++++--- slangpy/bindings/boundvariable.py | 15 +-- src/sgl/device/device.cpp | 10 +- 3 files changed, 85 insertions(+), 31 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md b/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md index f0d5124cd..02d9d1ec6 100644 --- a/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md +++ b/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md @@ -435,13 +435,13 @@ In [slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) `bind_call_data`, the fast-p [test_benchmark_autograd.py](slangpy/benchmarks/test_benchmark_autograd.py) changes `ITERATIONS` 10→100, `WARMUPS` 10→1000, `RUN_SLANGTORCH_BENCHMARK` False→True. This will make CI benchmarks 10–100× slower. Revert to original values. -**FIX**: Restore changes to original values. +**FIXED**: Restored `ITERATIONS=10`, `WARMUPS=10`, `RUN_SLANGTORCH_BENCHMARK=False`. **3. Overly broad `except Exception` in calldata.py fallback** [calldata.py](slangpy/core/calldata.py): The fallback from fast path to `ParameterBlock` catches `except Exception`, which swallows `TypeError`, `KeyError`, `AttributeError`, etc. The caught exception `e` is never logged. -**Fix**: Narrow to the specific compilation exception (RunTimeError) Log `str(e)` in the debug message. +**FIXED**: Narrowed to `except RuntimeError as e` and included `str(e)` in the debug message. --- @@ -485,13 +485,13 @@ auto write_uniforms = [&](ShaderCursor target) { Fast path → `write_uniforms(ep)`, fallback → `write_uniforms(call_data_cursor)`. -**FIX**: Create helper as above. +**FIXED**: Extracted `write_uniforms` lambda taking `(ShaderCursor target, ShaderCursor root_cursor)`. Fast path calls `write_uniforms(ep, cursor)`, fallback calls `write_uniforms(call_data_cursor, cursor)`. **6. `_try_build_shader` parameter pattern in calldata.py** Takes `use_direct_args` parameter then immediately sets `self.use_direct_args` and `context.use_direct_args`. The method never reads the flag except to store it. -**Fix**: Have the caller set these directly before calling, remove the parameter. +**FIXED**: Caller sets `self.use_direct_args` before calling; `_try_build_shader` reads `self.use_direct_args` and sets `context.use_direct_args`. Parameter removed. --- @@ -499,7 +499,7 @@ Takes `use_direct_args` parameter then immediately sets `self.use_direct_args` a **7. Unconditional `print(code)` in test_kernel_gen.py L107** — should be guarded by `PRINT_TEST_KERNEL_GEN` env var. -**FIX**: Wrap in `if os.getenv("PRINT_TEST_KERNEL_GEN"):` and add `import os` at the top. +**FIXED**: Guarded with `if PRINT_TEST_KERNEL_GEN:` (existing module-level flag). **8. Test duplication** — ~30 tests near-identical between test_kernel_gen.py and test_code_gen.py. The merged tests in test_code_gen.py should replace the originals. @@ -507,23 +507,23 @@ Takes `use_direct_args` parameter then immediately sets `self.use_direct_args` a **9. Unused `nodes` variable** — [callsignature.py L278](slangpy/core/callsignature.py): `nodes: list[BoundVariable] = []` declared but never used. -**FIX**: If not needed it can be deleted +**FIXED**: Deleted unused variable. **10. Stale docstring** — [callsignature.py L275](slangpy/core/callsignature.py): Says "Generate a list of call data nodes" — doesn't match what the function does. -**FIX**: Update docstring to reflect current behavior. +**FIXED**: Updated to "Generate Slang kernel code for the given function call signature." **11. Missing return type annotations** — `generate_code()`, `generate_constants()`, `CallData.build()` all need `-> None`. -**FIX**: Add `-> None` return type annotations to these methods. +**FIXED**: Added `-> None` to `generate_code()`, `generate_constants()`, `CallData.build()`, and `_try_build_shader()`. **12. `type_conformances: Any`** — [calldata.py](slangpy/core/calldata.py) should be `list[TypeConformance]`. -**FIX**: Change type annotation to `list[TypeConformance]` and import `TypeConformance` from `slangpy.core.typeconformance`. +**FIXED**: Changed to `list["TypeConformance"]` and added `TypeConformance` to the `from slangpy import (...)` block. **13. Bare `except:`** — [callsignature.py L59](slangpy/core/callsignature.py): `is_generic_vector` catches all exceptions including `SystemExit`. Use `except Exception:`. -**FIX**: Change to `except Exception:` and log the exception for debugging. +**FIXED**: Changed to `except Exception:`. **14. Typo: `santized_module`** — [calldata.py](slangpy/core/calldata.py): Missing 'i'. Pre-existing. @@ -539,8 +539,75 @@ Takes `use_direct_args` parameter then immediately sets `self.use_direct_args` a **17. Hash salt `"[CallData]\n"`** — emitted even when CallData struct is absent. Cosmetic. -**FIX**: Can be removed or left as-is, low impact. +**FIXED**: Removed `"[CallData]\n"` prefix from hash salt. **18. `Tuple` import in test_code_gen.py** — should use lowercase `tuple[...]` for consistency. -**FIX**: Change to `tuple[...]` and remove `from typing import Tuple`. +**FIXED**: Changed to `tuple[...]` and removed `Tuple` from typing import. + +--- + +#### Additional Findings (subagent review, March 2026) + +**19. Latent correctness bug — `can_direct_bind_common()` missing write-access guard** + +[boundvariable.py](slangpy/bindings/boundvariable.py) `can_direct_bind_common()` does not check whether the binding has write access. This creates an inconsistency: + +- `ValueRefMarshall.can_direct_bind()` explicitly rejects writable bindings — correct +- `StructMarshall.can_direct_bind()` with children checks `access[0] == AccessType.read` — correct +- `StructMarshall.can_direct_bind()` without children falls through to `can_direct_bind_common()` — **missing access check** +- `ValueMarshall.can_direct_bind()` delegates entirely to `can_direct_bind_common()` — safe in practice (`ValueMarshall.is_writable = False`) but fragile + +If a writable dim-0 leaf binding gets `direct_bind=True`, `ValueMarshall.gen_trampoline_store()` returns `True` without emitting store code, silently dropping writes. + +**DO NOT FIX**: Reasion: This logic is subtle but correct, based on the desired behaviour. + +**20. Dead `_gen_trampoline_argument()` method** + +[boundvariable.py](slangpy/bindings/boundvariable.py) `_gen_trampoline_argument()` is never called anywhere in the codebase. The inline generation in [callsignature.py](slangpy/core/callsignature.py) replaced it. + +**FIXED**: Deleted the method. + +**21. Redundant `hasattr` guard in `calculate_direct_bind()`** + +[boundvariable.py](slangpy/bindings/boundvariable.py) `calculate_direct_bind()` uses `hasattr(self.python, "can_direct_bind")`, which is always `True` because `Marshall` base class defines `can_direct_bind()`. Simplify to `if self.python is not None:`. + +**DO NOT FIX**: Reason: For marshalls that inherit directly from NativeMarshall, this is not necessarily true. + +**22. Unnecessary `getattr` in `can_direct_bind_common()`** + +[boundvariable.py](slangpy/bindings/boundvariable.py) `can_direct_bind_common()` uses `getattr(binding, "create_param_block", False)`. `BoundVariable.__init__()` always sets `create_param_block`, so `binding.create_param_block` suffices. + +**FIXED**: Replaced `getattr(binding, "create_param_block", False)` with `binding.create_param_block`. + +**23. Wasteful `CodeGen.call_data` initialization when `skip_call_data=True`** + +[codegen.py](slangpy/bindings/codegen.py) `__init__` unconditionally calls `self.call_data.append_line("struct CallData")` and `begin_block()`, even when `skip_call_data=True`. The block is never serialized so there's no output impact, but it allocates a dangling block object. + +**DO NOT FIX**: Reason: Harmless — the block is never emitted. Restructuring `__init__` to conditionally skip initialization adds complexity for no functional benefit. + +**24. `entry_point_params` ownership pattern undocumented** + +[codegen.py](slangpy/bindings/codegen.py) collects `entry_point_params` via `boundvariable.py`, but [callsignature.py](slangpy/core/callsignature.py) reads and emits them. This cross-module ownership pattern is unconventional and lacks a comment explaining the flow. + +**DO NOT FIX**: Reason: `CodeGen` is already a shared state bag consumed by multiple modules. Adding a comment is fine but not blocking. + +**25. `direct_bind` and `use_direct_args` exposed as read-write in `.pyi` stubs** + +[__init__.pyi](slangpy/slangpy/__init__.pyi) exposes `direct_bind` on `NativeBoundVariableRuntime` and `use_direct_args` on `NativeCallData` with setters. Mutating these after first dispatch could invalidate cached cursor offsets in `NativeValueMarshall::ensure_cached`. + +**DO NOT FIX**: Reason: These are set during `CallData` construction before first dispatch. The cached `NativeCallData` is per-signature, so a new signature gets a fresh instance. Post-construction mutation would require going through `debug_build_call_data` which rebuilds everything. Not a practical concern. + +**26. No fallback-path codegen test in `test_code_gen.py`** + +[test_code_gen.py](slangpy/tests/slangpy_tests/test_code_gen.py) has no test that forces `use_direct_args=False` (e.g., by exceeding `max_entry_point_uniform_size`) and asserts the `ParameterBlock` codegen. The `test_step21_many_float4x4_may_exceed_vulkan` in `test_kernel_gen.py` checks the flag but not the generated code. + +**DO NOT FIX**: Reason: Step 2.7 will add comprehensive post-implementation tests including `test_phase2_fallback_keeps_calldata` and `test_phase2_fallback_no_trampoline_prim`. + +**27. No test for writable `inout` struct at dim-0** + +No test verifies the behavior of a writable (inout) dim-0 struct with all-scalar fields. This is the scenario where Fix 19 would prevent silent write loss. + +**Fix**: Add after Fix 19 is applied —test a writable dim-0 struct dict to confirm `direct_bind=False`. + +**Status: NOT FIXED** — blocked on Fix 19. diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index 04b88b602..baed05091 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -163,7 +163,7 @@ def can_direct_bind_common(binding: "BoundVariable") -> bool: return False if binding.children: return False - if getattr(binding, "create_param_block", False): + if binding.create_param_block: return False return True @@ -692,19 +692,6 @@ def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): else: cg.call_data.declare(self.calldata_type_name, self.variable_name) - def _gen_trampoline_argument(self): - assert self.vector_type is not None - arg_def = f"{self.vector_type.full_name} {self.variable_name}" - if self.io_type == IOType.inout: - arg_def = f"inout {arg_def}" - elif self.io_type == IOType.out: - arg_def = f"out {arg_def}" - elif self.io_type == IOType.inn: - arg_def = f"in {arg_def}" - if self.no_diff or not self.differentiable: - arg_def = f"no_diff {arg_def}" - return arg_def - def __str__(self) -> str: return self._recurse_str(0) diff --git a/src/sgl/device/device.cpp b/src/sgl/device/device.cpp index 4692d406e..10f2de3e2 100644 --- a/src/sgl/device/device.cpp +++ b/src/sgl/device/device.cpp @@ -274,20 +274,20 @@ Device::Device(const DeviceDesc& desc) ); m_info.limits.max_shader_visible_samplers = rhi_device_info.limits.maxShaderVisibleSamplers; - // Set conservative default for max entry-point uniform (push constant / root constant) size. - // The RHI doesn't expose this directly, so we use per-backend defaults. + // TODO: These are known safe limits based on API spec, but could be increased based on + // platform (eg early Vk==128, CUDA 12.1+ supports 32k etc). Either this or the relevant + // information needs to be exposed by slang-rhi. switch (m_desc.type) { case DeviceType::vulkan: // Vulkan spec minimum maxPushConstantsSize is 128 bytes. m_info.limits.max_entry_point_uniform_size = 128; break; case DeviceType::d3d12: - // D3D12 root signature allows 64 DWORDs (256 bytes) total for root constants, - // shared with root descriptors. Use a conservative 256. + // D3D12 root signature allows 64 DWORDs m_info.limits.max_entry_point_uniform_size = 256; break; case DeviceType::cuda: - // CUDA kernel parameter block limit is typically 4KB. + // CUDA kernel parameter block limit pre 12.1 is 4KB. m_info.limits.max_entry_point_uniform_size = 4096; break; default: From 1d0c39f84cf08b6d5faef8476a52be94a5403036 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 10:37:14 +0000 Subject: [PATCH 24/41] Rename use_direct_args --- .../plan-simplifyKernelGen-phase2.prompt.md | 42 +++++++------- ...-simplifyKernelGenPhase2-cleanup.prompt.md | 56 +++++++++---------- slangpy/bindings/marshall.py | 2 +- slangpy/core/calldata.py | 18 +++--- slangpy/core/callsignature.py | 36 ++++++------ .../tests/slangpy_tests/test_kernel_gen.py | 30 +++++----- src/slangpy_ext/utils/slangpy.cpp | 10 ++-- src/slangpy_ext/utils/slangpy.h | 6 +- 8 files changed, 100 insertions(+), 100 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md index 4a2bfb8dc..2f6ca6f80 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md @@ -191,8 +191,8 @@ In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct 1. **Query a runtime per-device threshold** for max entry-point parameter inline-uniform size. This is a property of the device/backend — large for D3D12/CUDA (thousands of bytes), potentially as low as 128–256 bytes on Vulkan. 2. **Accumulate inline-uniform byte size** of each bound variable's `calldata_type_name`, plus `_thread_count` (12 bytes) and shape arrays (`call_data_len * 3 * sizeof(int)` for `_grid_stride`, `_grid_dim`, `_call_dim`). **Resource types** (`RWStructuredBuffer`, `Texture2D`, `TensorView`, etc.) don't count — they are bound as descriptors, not inline data. -3. **Decision**: If total size ≤ threshold → `self.use_direct_args = True` (fast path). Otherwise → `self.use_direct_args = False` (fallback path — current behavior). -4. **Store** `use_direct_args` on the `CallData` instance and propagate to C++ `NativeCallData`. +3. **Decision**: If total size ≤ threshold → `self.use_entrypoint_args = True` (fast path). Otherwise → `self.use_entrypoint_args = False` (fallback path — current behavior). +4. **Store** `use_entrypoint_args` on the `CallData` instance and propagate to C++ `NativeCallData`. `PackedArg` / param-block types are excluded from this accounting — they stay as `ParameterBlock` regardless. @@ -200,20 +200,20 @@ In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct - `DeviceLimits.max_entry_point_uniform_size` added to C++ struct ([device.h](src/sgl/device/device.h)) with per-backend defaults: Vulkan=128, D3D12=256, CUDA=4096 bytes ([device.cpp](src/sgl/device/device.cpp)). - `calculate_inline_uniform_size()` added to [callsignature.py](slangpy/core/callsignature.py) — sums `vector_type.uniform_layout.size` for each depth-0 bound variable (skipping `PackedArg`), plus 12 bytes for `_thread_count` and `call_dimensionality * 4 * 3` for shape arrays. -- `use_direct_args` property added to `NativeCallData` C++ class ([slangpy.h](src/slangpy_ext/utils/slangpy.h)) with Python binding. -- `CallData.__init__()` in [calldata.py](slangpy/core/calldata.py) sets `self.use_direct_args = inline_size <= threshold` after `calculate_direct_binding()`. +- `use_entrypoint_args` property added to `NativeCallData` C++ class ([slangpy.h](src/slangpy_ext/utils/slangpy.h)) with Python binding. +- `CallData.__init__()` in [calldata.py](slangpy/core/calldata.py) sets `self.use_entrypoint_args = inline_size <= threshold` after `calculate_direct_binding()`. **Tests** (7 tests × 3 device types = 21 parametrized cases, all pass): | Test | Asserts | |------|---------| -| `test_step21_scalar_uses_direct_args` | Simple `int add(int,int)` with `(1,2)` → `use_direct_args=True` | +| `test_step21_scalar_uses_entrypoint_args` | Simple `int add(int,int)` with `(1,2)` → `use_entrypoint_args=True` | | `test_step21_threshold_property_positive` | `device.info.limits.max_entry_point_uniform_size > 0` | -| `test_step21_vector_uses_direct_args` | `float3` args → `use_direct_args=True` | -| `test_step21_struct_uses_direct_args` | All-scalar struct dict → `use_direct_args=True` | -| `test_step21_tensor_uses_direct_args` | Tensor (descriptor-only, 0 inline bytes) → `use_direct_args=True` | +| `test_step21_vector_uses_entrypoint_args` | `float3` args → `use_entrypoint_args=True` | +| `test_step21_struct_uses_entrypoint_args` | All-scalar struct dict → `use_entrypoint_args=True` | +| `test_step21_tensor_uses_entrypoint_args` | Tensor (descriptor-only, 0 inline bytes) → `use_entrypoint_args=True` | | `test_step21_many_float4x4_may_exceed_vulkan` | 8×float4x4 (524 bytes) exceeds Vulkan/D3D12 thresholds, not CUDA | -| `test_step21_wanghasharg_uses_direct_args` | Non-direct-bind WangHashArg with small inline size → `use_direct_args=True` | +| `test_step21_wanghasharg_uses_entrypoint_args` | Non-direct-bind WangHashArg with small inline size → `use_entrypoint_args=True` | --- @@ -221,16 +221,16 @@ In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct **Status: DONE** -In [slangpy/core/callsignature.py](slangpy/core/callsignature.py) `generate_code()`, when `use_direct_args == True`: +In [slangpy/core/callsignature.py](slangpy/core/callsignature.py) `generate_code()`, when `use_entrypoint_args == True`: **CodeGen changes** in [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py): - Add a `skip_call_data` flag to `CodeGen.__init__`. When `True`, don't emit `struct CallData` / `begin_block()` and gate `end_block()` in `finish()`. - Add `self.entry_point_params: list[str] = []` to collect individual uniform param declarations. - `finish()` ignores the `call_data` block and `use_param_block_for_call_data` when `skip_call_data` is set. -**CallData struct elimination**: Set `cg.skip_call_data = True` when `use_direct_args`. No `struct CallData` emitted. +**CallData struct elimination**: Set `cg.skip_call_data = True` when `use_entrypoint_args`. No `struct CallData` emitted. -**`gen_call_data_code` change** in [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py): At `depth == 0`, when `use_direct_args`, append to `cg.entry_point_params` instead of `cg.call_data.declare(...)`. The `call_data_structs` block (type aliases, wrapper structs, mapping constants) still gets emitted at module scope. +**`gen_call_data_code` change** in [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py): At `depth == 0`, when `use_entrypoint_args`, append to `cg.entry_point_params` instead of `cg.call_data.declare(...)`. The `call_data_structs` block (type aliases, wrapper structs, mapping constants) still gets emitted at module scope. **`_thread_count` and shape arrays**: Instead of `cg.call_data.append_statement("uint3 _thread_count")`, append to `cg.entry_point_params`. Same for `_grid_stride`, `_grid_dim`, `_call_dim` when `call_data_len > 0`. @@ -254,7 +254,7 @@ Drop `SV_GroupID` and `SV_GroupIndex` when `call_data_len == 0` — they feed `i **Shape info init**: Changes from `call_data._grid_stride` etc. to just `_grid_stride`, `_grid_dim`, `_call_dim`. -**Fallback path** (`use_direct_args == False`): `struct CallData` is emitted with `ParameterBlock call_data` at module scope on ALL backends (including CUDA). The old `CallDataMode` distinction between `entry_point` (CUDA) and `global_data` (non-CUDA) is removed — `ParameterBlock` works on CUDA, and in practice CUDA will never hit the fallback due to its large (~4KB) inline-uniform limit. +**Fallback path** (`use_entrypoint_args == False`): `struct CallData` is emitted with `ParameterBlock call_data` at module scope on ALL backends (including CUDA). The old `CallDataMode` distinction between `entry_point` (CUDA) and `global_data` (non-CUDA) is removed — `ParameterBlock` works on CUDA, and in practice CUDA will never hit the fallback due to its large (~4KB) inline-uniform limit. See [slangpy/tests/device/test_pipeline_utils.slang](slangpy/tests/device/test_pipeline_utils.slang) for examples of manually-written compute shaders that use entry point parameters on all backends (CUDA, Vulkan, D3D12). @@ -306,11 +306,11 @@ When `call_mode == bwds`: **Status: DONE** — `CallDataMode` enum fully removed. Fast path uses `find_entry_point(0)` on all backends. Fallback path uses global `ParameterBlock` on all backends. -In [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp), store `m_use_direct_args` on `NativeCallData` (received from Python `CallData`). Also add to [slangpy.h](src/slangpy_ext/utils/slangpy.h). +In [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp), store `m_use_entrypoint_args` on `NativeCallData` (received from Python `CallData`). Also add to [slangpy.h](src/slangpy_ext/utils/slangpy.h). Modify `bind_call_data` lambda in `exec()`: -**Fast path** (`m_use_direct_args == true`): +**Fast path** (`m_use_entrypoint_args == true`): - All backends: Navigate via `cursor.find_entry_point(0)`. This is the entry-point cursor. - Write `_thread_count` as an entry-point param: `entry_point_cursor["_thread_count"]`. - Write shape arrays as entry-point params: `entry_point_cursor["_grid_stride"]`, etc. @@ -318,7 +318,7 @@ Modify `bind_call_data` lambda in `exec()`: - Cache entry-point param field indices on first call (analogous to existing `m_cached_call_data_offsets`). - The `reserve_data` + raw-pointer optimization for `_thread_count` and shape arrays may not work for individual entry-point params at disjoint offsets. Use cursor-based writes for these metadata fields (they're small, performance impact minimal), or check if `reserve_data` still works across the entry-point shader object. -**Fallback path** (`m_use_direct_args == false`): +**Fallback path** (`m_use_entrypoint_args == false`): - All backends: Navigate to global `call_data` field via `cursor.find_field("call_data")`, dereference (it's a `ParameterBlock`), write struct data. The old `CallDataMode` branch (CUDA using `find_entry_point(0)` for call_data) is removed. Remove `m_call_data_mode`, `CallDataMode` enum, and all associated branches from `slangpy.h`, `slangpy.cpp`, `calldata.py`, and `callsignature.py`. --- @@ -333,7 +333,7 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi **Fallback path**: `_result` stays as `RWValueRef` inside `CallData`, same as current behavior. -**Implementation note**: The `RWStructuredBuffer` approach for `_result` is only used when `use_direct_args == True` AND all other args are direct-bind (so Context can be omitted). When non-direct-bind args are present, Context exists and `_result` can continue to use `RWValueRef.__slangpy_store(context, value)`. +**Implementation note**: The `RWStructuredBuffer` approach for `_result` is only used when `use_entrypoint_args == True` AND all other args are direct-bind (so Context can be omitted). When non-direct-bind args are present, Context exists and `_result` can continue to use `RWValueRef.__slangpy_store(context, value)`. --- @@ -383,13 +383,13 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi | File | Changes | |------|---------| -| [slangpy/core/calldata.py](slangpy/core/calldata.py) | ✅ `use_direct_args` flag, size threshold check, `CallDataMode` removed | +| [slangpy/core/calldata.py](slangpy/core/calldata.py) | ✅ `use_entrypoint_args` flag, size threshold check, `CallDataMode` removed | | [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | ✅ Entry-point params, fast/fallback code paths, `is_entry_point` branch removed. Trampoline still generated (Step 2.3 pending). Bwds `no_diff` on all trampoline params (Step 2.4 done). | | [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py) | ✅ `skip_call_data` flag, `entry_point_params` list | | [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | ✅ `gen_call_data_code` depth-0 entry-point path. `_gen_trampoline_argument()` unused — inline generation in `callsignature.py` used instead. | -| [slangpy/bindings/marshall.py](slangpy/bindings/marshall.py) | ✅ `use_direct_args` field on `BindContext`, `CallDataMode` removed | -| [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | ✅ `use_direct_args` binding; `bind_call_data` fast path via `find_entry_point(0)`, `CallDataMode` branches removed | -| [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | ✅ `m_use_direct_args` on `NativeCallData`; `m_call_data_mode` removed | +| [slangpy/bindings/marshall.py](slangpy/bindings/marshall.py) | ✅ `use_entrypoint_args` field on `BindContext`, `CallDataMode` removed | +| [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | ✅ `use_entrypoint_args` binding; `bind_call_data` fast path via `find_entry_point(0)`, `CallDataMode` branches removed | +| [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | ✅ `m_use_entrypoint_args` on `NativeCallData`; `m_call_data_mode` removed | | [src/sgl/device/device.h](src/sgl/device/device.h) | ✅ `max_entry_point_uniform_size` on `DeviceLimits` | | [src/sgl/device/device.cpp](src/sgl/device/device.cpp) | ✅ Per-backend defaults for `max_entry_point_uniform_size` | | [src/slangpy_ext/device/device.cpp](src/slangpy_ext/device/device.cpp) | ✅ Python binding for `max_entry_point_uniform_size` | diff --git a/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md b/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md index 02d9d1ec6..d5e20c82c 100644 --- a/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md +++ b/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md @@ -191,8 +191,8 @@ In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct 1. **Query a runtime per-device threshold** for max entry-point parameter inline-uniform size. This is a property of the device/backend — large for D3D12/CUDA (thousands of bytes), potentially as low as 128–256 bytes on Vulkan. 2. **Accumulate inline-uniform byte size** of each bound variable's `calldata_type_name`, plus `_thread_count` (12 bytes) and shape arrays (`call_data_len * 3 * sizeof(int)` for `_grid_stride`, `_grid_dim`, `_call_dim`). **Resource types** (`RWStructuredBuffer`, `Texture2D`, `TensorView`, etc.) don't count — they are bound as descriptors, not inline data. -3. **Decision**: If total size ≤ threshold → `self.use_direct_args = True` (fast path). Otherwise → `self.use_direct_args = False` (fallback path — current behavior). -4. **Store** `use_direct_args` on the `CallData` instance and propagate to C++ `NativeCallData`. +3. **Decision**: If total size ≤ threshold → `self.use_entrypoint_args = True` (fast path). Otherwise → `self.use_entrypoint_args = False` (fallback path — current behavior). +4. **Store** `use_entrypoint_args` on the `CallData` instance and propagate to C++ `NativeCallData`. `PackedArg` / param-block types are excluded from this accounting — they stay as `ParameterBlock` regardless. @@ -200,20 +200,20 @@ In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct - `DeviceLimits.max_entry_point_uniform_size` added to C++ struct ([device.h](src/sgl/device/device.h)) with per-backend defaults: Vulkan=128, D3D12=256, CUDA=4096 bytes ([device.cpp](src/sgl/device/device.cpp)). - `calculate_inline_uniform_size()` added to [callsignature.py](slangpy/core/callsignature.py) — sums `vector_type.uniform_layout.size` for each depth-0 bound variable (skipping `PackedArg`), plus 12 bytes for `_thread_count` and `call_dimensionality * 4 * 3` for shape arrays. -- `use_direct_args` property added to `NativeCallData` C++ class ([slangpy.h](src/slangpy_ext/utils/slangpy.h)) with Python binding. -- `CallData.__init__()` in [calldata.py](slangpy/core/calldata.py) sets `self.use_direct_args = inline_size <= threshold` after `calculate_direct_binding()`. +- `use_entrypoint_args` property added to `NativeCallData` C++ class ([slangpy.h](src/slangpy_ext/utils/slangpy.h)) with Python binding. +- `CallData.__init__()` in [calldata.py](slangpy/core/calldata.py) sets `self.use_entrypoint_args = inline_size <= threshold` after `calculate_direct_binding()`. **Tests** (7 tests × 3 device types = 21 parametrized cases, all pass): | Test | Asserts | |------|---------| -| `test_step21_scalar_uses_direct_args` | Simple `int add(int,int)` with `(1,2)` → `use_direct_args=True` | +| `test_step21_scalar_uses_entrypoint_args` | Simple `int add(int,int)` with `(1,2)` → `use_entrypoint_args=True` | | `test_step21_threshold_property_positive` | `device.info.limits.max_entry_point_uniform_size > 0` | -| `test_step21_vector_uses_direct_args` | `float3` args → `use_direct_args=True` | -| `test_step21_struct_uses_direct_args` | All-scalar struct dict → `use_direct_args=True` | -| `test_step21_tensor_uses_direct_args` | Tensor (descriptor-only, 0 inline bytes) → `use_direct_args=True` | +| `test_step21_vector_uses_entrypoint_args` | `float3` args → `use_entrypoint_args=True` | +| `test_step21_struct_uses_entrypoint_args` | All-scalar struct dict → `use_entrypoint_args=True` | +| `test_step21_tensor_uses_entrypoint_args` | Tensor (descriptor-only, 0 inline bytes) → `use_entrypoint_args=True` | | `test_step21_many_float4x4_may_exceed_vulkan` | 8×float4x4 (524 bytes) exceeds Vulkan/D3D12 thresholds, not CUDA | -| `test_step21_wanghasharg_uses_direct_args` | Non-direct-bind WangHashArg with small inline size → `use_direct_args=True` | +| `test_step21_wanghasharg_uses_entrypoint_args` | Non-direct-bind WangHashArg with small inline size → `use_entrypoint_args=True` | --- @@ -221,16 +221,16 @@ In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct **Status: DONE** -In [slangpy/core/callsignature.py](slangpy/core/callsignature.py) `generate_code()`, when `use_direct_args == True`: +In [slangpy/core/callsignature.py](slangpy/core/callsignature.py) `generate_code()`, when `use_entrypoint_args == True`: **CodeGen changes** in [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py): - Add a `skip_call_data` flag to `CodeGen.__init__`. When `True`, don't emit `struct CallData` / `begin_block()` and gate `end_block()` in `finish()`. - Add `self.entry_point_params: list[str] = []` to collect individual uniform param declarations. - `finish()` ignores the `call_data` block and `use_param_block_for_call_data` when `skip_call_data` is set. -**CallData struct elimination**: Set `cg.skip_call_data = True` when `use_direct_args`. No `struct CallData` emitted. +**CallData struct elimination**: Set `cg.skip_call_data = True` when `use_entrypoint_args`. No `struct CallData` emitted. -**`gen_call_data_code` change** in [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py): At `depth == 0`, when `use_direct_args`, append to `cg.entry_point_params` instead of `cg.call_data.declare(...)`. The `call_data_structs` block (type aliases, wrapper structs, mapping constants) still gets emitted at module scope. +**`gen_call_data_code` change** in [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py): At `depth == 0`, when `use_entrypoint_args`, append to `cg.entry_point_params` instead of `cg.call_data.declare(...)`. The `call_data_structs` block (type aliases, wrapper structs, mapping constants) still gets emitted at module scope. **`_thread_count` and shape arrays**: Instead of `cg.call_data.append_statement("uint3 _thread_count")`, append to `cg.entry_point_params`. Same for `_grid_stride`, `_grid_dim`, `_call_dim` when `call_data_len > 0`. @@ -254,7 +254,7 @@ Drop `SV_GroupID` and `SV_GroupIndex` when `call_data_len == 0` — they feed `i **Shape info init**: Changes from `call_data._grid_stride` etc. to just `_grid_stride`, `_grid_dim`, `_call_dim`. -**Fallback path** (`use_direct_args == False`): `struct CallData` is emitted with `ParameterBlock call_data` at module scope on ALL backends (including CUDA). The old `CallDataMode` distinction between `entry_point` (CUDA) and `global_data` (non-CUDA) is removed — `ParameterBlock` works on CUDA, and in practice CUDA will never hit the fallback due to its large (~4KB) inline-uniform limit. +**Fallback path** (`use_entrypoint_args == False`): `struct CallData` is emitted with `ParameterBlock call_data` at module scope on ALL backends (including CUDA). The old `CallDataMode` distinction between `entry_point` (CUDA) and `global_data` (non-CUDA) is removed — `ParameterBlock` works on CUDA, and in practice CUDA will never hit the fallback due to its large (~4KB) inline-uniform limit. See [slangpy/tests/device/test_pipeline_utils.slang](slangpy/tests/device/test_pipeline_utils.slang) for examples of manually-written compute shaders that use entry point parameters on all backends (CUDA, Vulkan, D3D12). @@ -306,11 +306,11 @@ When `call_mode == bwds`: **Status: DONE** — `CallDataMode` enum fully removed. Fast path uses `find_entry_point(0)` on all backends. Fallback path uses global `ParameterBlock` on all backends. -In [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp), store `m_use_direct_args` on `NativeCallData` (received from Python `CallData`). Also add to [slangpy.h](src/slangpy_ext/utils/slangpy.h). +In [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp), store `m_use_entrypoint_args` on `NativeCallData` (received from Python `CallData`). Also add to [slangpy.h](src/slangpy_ext/utils/slangpy.h). Modify `bind_call_data` lambda in `exec()`: -**Fast path** (`m_use_direct_args == true`): +**Fast path** (`m_use_entrypoint_args == true`): - All backends: Navigate via `cursor.find_entry_point(0)`. This is the entry-point cursor. - Write `_thread_count` as an entry-point param: `entry_point_cursor["_thread_count"]`. - Write shape arrays as entry-point params: `entry_point_cursor["_grid_stride"]`, etc. @@ -318,7 +318,7 @@ Modify `bind_call_data` lambda in `exec()`: - Cache entry-point param field indices on first call (analogous to existing `m_cached_call_data_offsets`). - The `reserve_data` + raw-pointer optimization for `_thread_count` and shape arrays may not work for individual entry-point params at disjoint offsets. Use cursor-based writes for these metadata fields (they're small, performance impact minimal), or check if `reserve_data` still works across the entry-point shader object. -**Fallback path** (`m_use_direct_args == false`): +**Fallback path** (`m_use_entrypoint_args == false`): - All backends: Navigate to global `call_data` field via `cursor.find_field("call_data")`, dereference (it's a `ParameterBlock`), write struct data. The old `CallDataMode` branch (CUDA using `find_entry_point(0)` for call_data) is removed. Remove `m_call_data_mode`, `CallDataMode` enum, and all associated branches from `slangpy.h`, `slangpy.cpp`, `calldata.py`, and `callsignature.py`. --- @@ -333,7 +333,7 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi **Fallback path**: `_result` stays as `RWValueRef` inside `CallData`, same as current behavior. -**Implementation note**: The `RWStructuredBuffer` approach for `_result` is only used when `use_direct_args == True` AND all other args are direct-bind (so Context can be omitted). When non-direct-bind args are present, Context exists and `_result` can continue to use `RWValueRef.__slangpy_store(context, value)`. +**Implementation note**: The `RWStructuredBuffer` approach for `_result` is only used when `use_entrypoint_args == True` AND all other args are direct-bind (so Context can be omitted). When non-direct-bind args are present, Context exists and `_result` can continue to use `RWValueRef.__slangpy_store(context, value)`. --- @@ -383,13 +383,13 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi | File | Changes | |------|---------| -| [slangpy/core/calldata.py](slangpy/core/calldata.py) | ✅ `use_direct_args` flag, size threshold check, `CallDataMode` removed | +| [slangpy/core/calldata.py](slangpy/core/calldata.py) | ✅ `use_entrypoint_args` flag, size threshold check, `CallDataMode` removed | | [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | ✅ Entry-point params, fast/fallback code paths, `is_entry_point` branch removed. Trampoline still generated (Step 2.3 pending). Bwds `no_diff` on all trampoline params (Step 2.4 done). | | [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py) | ✅ `skip_call_data` flag, `entry_point_params` list | | [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | ✅ `gen_call_data_code` depth-0 entry-point path. `_gen_trampoline_argument()` unused — inline generation in `callsignature.py` used instead. | -| [slangpy/bindings/marshall.py](slangpy/bindings/marshall.py) | ✅ `use_direct_args` field on `BindContext`, `CallDataMode` removed | -| [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | ✅ `use_direct_args` binding; `bind_call_data` fast path via `find_entry_point(0)`, `CallDataMode` branches removed | -| [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | ✅ `m_use_direct_args` on `NativeCallData`; `m_call_data_mode` removed | +| [slangpy/bindings/marshall.py](slangpy/bindings/marshall.py) | ✅ `use_entrypoint_args` field on `BindContext`, `CallDataMode` removed | +| [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | ✅ `use_entrypoint_args` binding; `bind_call_data` fast path via `find_entry_point(0)`, `CallDataMode` branches removed | +| [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | ✅ `m_use_entrypoint_args` on `NativeCallData`; `m_call_data_mode` removed | | [src/sgl/device/device.h](src/sgl/device/device.h) | ✅ `max_entry_point_uniform_size` on `DeviceLimits` | | [src/sgl/device/device.cpp](src/sgl/device/device.cpp) | ✅ Per-backend defaults for `max_entry_point_uniform_size` | | [src/slangpy_ext/device/device.cpp](src/slangpy_ext/device/device.cpp) | ✅ Python binding for `max_entry_point_uniform_size` | @@ -462,10 +462,10 @@ Extract into sub-functions: Additionally, the duplicated `data_name` computation at ~L449 and ~L497 should be extracted: ```python -def _data_name(x: BoundVariable, use_direct_args: bool) -> str: +def _data_name(x: BoundVariable, use_entrypoint_args: bool) -> str: if x.create_param_block: return f"_param_{x.variable_name}" - return f"__in_{x.variable_name}" if use_direct_args else f"call_data.{x.variable_name}" + return f"__in_{x.variable_name}" if use_entrypoint_args else f"call_data.{x.variable_name}" ``` **DO NOT FIX** Reason: This is a complex change and will be deferred to a later step. @@ -489,9 +489,9 @@ Fast path → `write_uniforms(ep)`, fallback → `write_uniforms(call_data_curso **6. `_try_build_shader` parameter pattern in calldata.py** -Takes `use_direct_args` parameter then immediately sets `self.use_direct_args` and `context.use_direct_args`. The method never reads the flag except to store it. +Takes `use_entrypoint_args` parameter then immediately sets `self.use_entrypoint_args` and `context.use_entrypoint_args`. The method never reads the flag except to store it. -**FIXED**: Caller sets `self.use_direct_args` before calling; `_try_build_shader` reads `self.use_direct_args` and sets `context.use_direct_args`. Parameter removed. +**FIXED**: Caller sets `self.use_entrypoint_args` before calling; `_try_build_shader` reads `self.use_entrypoint_args` and sets `context.use_entrypoint_args`. Parameter removed. --- @@ -592,15 +592,15 @@ If a writable dim-0 leaf binding gets `direct_bind=True`, `ValueMarshall.gen_tra **DO NOT FIX**: Reason: `CodeGen` is already a shared state bag consumed by multiple modules. Adding a comment is fine but not blocking. -**25. `direct_bind` and `use_direct_args` exposed as read-write in `.pyi` stubs** +**25. `direct_bind` and `use_entrypoint_args` exposed as read-write in `.pyi` stubs** -[__init__.pyi](slangpy/slangpy/__init__.pyi) exposes `direct_bind` on `NativeBoundVariableRuntime` and `use_direct_args` on `NativeCallData` with setters. Mutating these after first dispatch could invalidate cached cursor offsets in `NativeValueMarshall::ensure_cached`. +[__init__.pyi](slangpy/slangpy/__init__.pyi) exposes `direct_bind` on `NativeBoundVariableRuntime` and `use_entrypoint_args` on `NativeCallData` with setters. Mutating these after first dispatch could invalidate cached cursor offsets in `NativeValueMarshall::ensure_cached`. **DO NOT FIX**: Reason: These are set during `CallData` construction before first dispatch. The cached `NativeCallData` is per-signature, so a new signature gets a fresh instance. Post-construction mutation would require going through `debug_build_call_data` which rebuilds everything. Not a practical concern. **26. No fallback-path codegen test in `test_code_gen.py`** -[test_code_gen.py](slangpy/tests/slangpy_tests/test_code_gen.py) has no test that forces `use_direct_args=False` (e.g., by exceeding `max_entry_point_uniform_size`) and asserts the `ParameterBlock` codegen. The `test_step21_many_float4x4_may_exceed_vulkan` in `test_kernel_gen.py` checks the flag but not the generated code. +[test_code_gen.py](slangpy/tests/slangpy_tests/test_code_gen.py) has no test that forces `use_entrypoint_args=False` (e.g., by exceeding `max_entry_point_uniform_size`) and asserts the `ParameterBlock` codegen. The `test_step21_many_float4x4_may_exceed_vulkan` in `test_kernel_gen.py` checks the flag but not the generated code. **DO NOT FIX**: Reason: Step 2.7 will add comprehensive post-implementation tests including `test_phase2_fallback_keeps_calldata` and `test_phase2_fallback_no_trampoline_prim`. diff --git a/slangpy/bindings/marshall.py b/slangpy/bindings/marshall.py index 3b7a9fe7d..509c33b66 100644 --- a/slangpy/bindings/marshall.py +++ b/slangpy/bindings/marshall.py @@ -37,7 +37,7 @@ def __init__( self.call_mode = call_mode #: Whether to use direct entry-point params (fast path) vs ParameterBlock (fallback). - self.use_direct_args = False + self.use_entrypoint_args = False #: SGL module. self.device_module = device_module diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index 947d84e1d..8c6dbea01 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -261,13 +261,13 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any) -> N # Determine fast path (entry-point params) vs fallback (ParameterBlock). # Sum inline-uniform byte size and compare against per-device threshold. - inline_size = calculate_inline_uniform_size(bindings, self.call_dimensionality) + inline_size = estimate_entrypoint_arguments_size(bindings, self.call_dimensionality) threshold = build_info.module.device.info.limits.max_entry_point_uniform_size - use_direct_args = inline_size <= threshold + use_entrypoint_args = inline_size <= threshold self.log_debug( f" Inline uniform size: {inline_size} bytes, " f"threshold: {threshold} bytes, " - f"use_direct_args: {use_direct_args}" + f"use_entrypoint_args: {use_entrypoint_args}" ) # Until https://github.com/shader-slang/slang-rhi/pull/676, Vk RTP can't use entry point args @@ -275,13 +275,13 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any) -> N build_info.pipeline_type == PipelineType.ray_tracing and build_info.module.device.info.type == DeviceType.vulkan ): - use_direct_args = False + use_entrypoint_args = False # Try building the shader. If direct args compilation fails (the # threshold is only an approximate heuristic), fall back to # ParameterBlock. try: - self.use_direct_args = use_direct_args + self.use_entrypoint_args = use_entrypoint_args self._try_build_shader( context, build_info, @@ -289,13 +289,13 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any) -> N type_conformances, ) except RuntimeError as e: - if not use_direct_args: + if not use_entrypoint_args: raise self.log_debug( f" Direct args compilation failed ({e}), " "retrying with ParameterBlock" ) - self.use_direct_args = False + self.use_entrypoint_args = False self._try_build_shader( context, build_info, @@ -386,7 +386,7 @@ def _try_build_shader( :param bindings: Bound call with resolved variables. :param type_conformances: Type conformances for entry point. """ - context.use_direct_args = self.use_direct_args + context.use_entrypoint_args = self.use_entrypoint_args # Generate code. codegen = CodeGen() @@ -404,7 +404,7 @@ def _try_build_shader( snippets=True, call_data_structs=True, constants=True, - use_param_block_for_call_data=not context.use_direct_args, + use_param_block_for_call_data=not context.use_entrypoint_args, ) # Optionally write the shader to a file for debugging. diff --git a/slangpy/core/callsignature.py b/slangpy/core/callsignature.py index 524fc650a..4241aa784 100644 --- a/slangpy/core/callsignature.py +++ b/slangpy/core/callsignature.py @@ -162,14 +162,14 @@ def calculate_direct_binding(call: BoundCall): arg.calculate_direct_bind() -def calculate_inline_uniform_size(call: BoundCall, call_dimensionality: int) -> int: +def estimate_entrypoint_arguments_size(call: BoundCall, call_dimensionality: int) -> int: """ - Calculate the total inline-uniform byte size for all depth-0 bound variables, - plus metadata fields (_thread_count, shape arrays). + Estimate the required entry point uniform byte size if the bound call where + to bind all depth-0 bound variables and plus metadata fields + (_thread_count, shape arrays) directly as entry point arguments. - Resource types (StructuredBuffer, Texture2D, etc.) contribute 0 bytes to inline - uniform size since they are bound as descriptors. PackedArg / ParameterBlock - types are excluded from this accounting since they stay as ParameterBlock. + Note: This is currently an estimate, as the actual calculation really needs to + take into account descriptors etc. :param call: The bound call containing all args/kwargs. :param call_dimensionality: The call dimensionality (determines shape array count). @@ -278,7 +278,7 @@ def generate_code( """ # Check if we're using direct entry-point params (fast path) - use_direct_args = context.use_direct_args + use_entrypoint_args = context.use_entrypoint_args # Generate the header cg.add_import("slangpy") @@ -368,12 +368,12 @@ def generate_code( cg.constants.append_statement("}") # Set up code gen mode for direct args vs CallData struct - if use_direct_args: + if use_entrypoint_args: cg.skip_call_data = True # Generate call data inputs if vector call if call_data_len > 0: - if use_direct_args: + if use_entrypoint_args: # Fast path: shape arrays as individual entry-point params cg.entry_point_params.append(f"uniform int[{call_data_len}] _grid_stride") cg.entry_point_params.append(f"uniform int[{call_data_len}] _grid_dim") @@ -393,7 +393,7 @@ def generate_code( # 32 thread aligned. cg.call_data.append_statement(f"int[{call_data_len}] _call_dim") - if use_direct_args: + if use_entrypoint_args: cg.entry_point_params.append("uniform uint3 _thread_count") else: cg.call_data.append_statement("uint3 _thread_count") @@ -410,7 +410,7 @@ def generate_code( if context.call_mode != CallMode.prim: cg.trampoline.append_line("[Differentiable]") - if use_direct_args: + if use_entrypoint_args: # Fast path: trampoline takes individual calldata-typed params. # Use __in_ prefix for param names to avoid collision with local variable names. # All params are no_diff — entry-point uniforms are never differentiable. @@ -434,7 +434,7 @@ def generate_code( assert x.vector_type is not None cg.trampoline.declare(x.vector_type.full_name, x.variable_name) for x in root_params: - if use_direct_args: + if use_entrypoint_args: data_name = ( f"_param_{x.variable_name}" if x.create_param_block else f"__in_{x.variable_name}" ) @@ -483,7 +483,7 @@ def generate_code( or x.access[0] == AccessType.readwrite or x.access[1] == AccessType.read ): - if use_direct_args: + if use_entrypoint_args: data_name = ( f"_param_{x.variable_name}" if x.create_param_block @@ -516,7 +516,7 @@ def generate_code( cg.kernel.append_line("[numthreads(32, 1, 1)]") # Note: While flat_call_thread_id is 3-dimensional, we consider it "flat" and 1-dimensional because of the # true call group shape of [x, 1, 1] and only use the first dimension for the call thread id. - if use_direct_args: + if use_entrypoint_args: # Fast path: build compute_main signature with individual entry-point params sig_parts = ["int3 flat_call_thread_id: SV_DispatchThreadID"] # Only include SV_GroupID/SV_GroupIndex when call_data_len > 0 @@ -533,7 +533,7 @@ def generate_code( ) elif build_info.pipeline_type == PipelineType.ray_tracing: cg.kernel.append_line('[shader("raygen")]') - if use_direct_args: + if use_entrypoint_args: sig_parts = list(cg.entry_point_params) cg.kernel.append_line(f"void raygen_main({', '.join(sig_parts)})") else: @@ -547,7 +547,7 @@ def generate_code( cg.kernel.append_statement("int3 flat_call_thread_id = DispatchRaysIndex();") # Bounds check — use _thread_count directly in fast path, call_data._thread_count in fallback - if use_direct_args: + if use_entrypoint_args: cg.kernel.append_statement("if (any(flat_call_thread_id >= _thread_count)) return") else: cg.kernel.append_statement( @@ -561,7 +561,7 @@ def generate_code( # definition in callshape.slang. if call_data_len > 0: # In fast path, shape arrays are direct entry-point params; in fallback, prefixed with call_data. - grid_prefix = "" if use_direct_args else "call_data." + grid_prefix = "" if use_entrypoint_args else "call_data." if build_info.pipeline_type == PipelineType.compute: cg.kernel.append_line( f""" @@ -587,7 +587,7 @@ def generate_code( if context.call_mode == CallMode.bwds: fn = f"bwd_diff({fn})" - if use_direct_args: + if use_entrypoint_args: # Fast path: pass individual entry-point param names to the trampoline trampoline_args = ["__slangpy_context__"] for x in root_params: diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index 0866c0422..250e4321b 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -1851,7 +1851,7 @@ def test_phase1_functional_long_struct_name(device_type: spy.DeviceType): @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) def test_gate_p2_calldata_struct_absent_fast_path(device_type: spy.DeviceType): - """Fast path (use_direct_args=True): no struct CallData emitted. Step 2.2 done.""" + """Fast path (use_entrypoint_args=True): no struct CallData emitted. Step 2.2 done.""" device = helpers.get_device(device_type) code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) assert_not_contains(code, "struct CallData") @@ -1934,13 +1934,13 @@ def build_call_data( @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_scalar_uses_direct_args(device_type: spy.DeviceType): - """Simple scalar call has small inline-uniform size → use_direct_args=True.""" +def test_step21_scalar_uses_entrypoint_args(device_type: spy.DeviceType): + """Simple scalar call has small inline-uniform size → use_entrypoint_args=True.""" device = helpers.get_device(device_type) cd = build_call_data(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) # Two ints (4+4) + RWValueRef for _result (descriptor, ~0 inline) + uint3 _thread_count (12) # Should be well under any backend's threshold - assert cd.use_direct_args is True + assert cd.use_entrypoint_args is True @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -1952,7 +1952,7 @@ def test_step21_threshold_property_positive(device_type: spy.DeviceType): @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_vector_uses_direct_args(device_type: spy.DeviceType): +def test_step21_vector_uses_entrypoint_args(device_type: spy.DeviceType): """float3 args are small enough for direct args.""" device = helpers.get_device(device_type) cd = build_call_data( @@ -1962,11 +1962,11 @@ def test_step21_vector_uses_direct_args(device_type: spy.DeviceType): spy.math.float3(1, 2, 3), 2.0, ) - assert cd.use_direct_args is True + assert cd.use_entrypoint_args is True @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_struct_uses_direct_args(device_type: spy.DeviceType): +def test_step21_struct_uses_entrypoint_args(device_type: spy.DeviceType): """All-scalar struct dict has small inline-uniform size.""" device = helpers.get_device(device_type) src = """ @@ -1974,11 +1974,11 @@ def test_step21_struct_uses_direct_args(device_type: spy.DeviceType): float sum(S s) { return s.x + s.y; } """ cd = build_call_data(device, "sum", src, {"_type": "S", "x": 1.0, "y": 2.0}) - assert cd.use_direct_args is True + assert cd.use_entrypoint_args is True @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_tensor_uses_direct_args(device_type: spy.DeviceType): +def test_step21_tensor_uses_entrypoint_args(device_type: spy.DeviceType): """Tensor args contribute descriptor-only (0 inline bytes) → direct args.""" device = helpers.get_device(device_type) tensor = Tensor.from_numpy(device, np.array([1.0, 2.0, 3.0], dtype=np.float32)) @@ -1988,7 +1988,7 @@ def test_step21_tensor_uses_direct_args(device_type: spy.DeviceType): "float sum_all(float x) { return x; }", tensor, ) - assert cd.use_direct_args is True + assert cd.use_entrypoint_args is True @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @@ -2021,15 +2021,15 @@ def test_step21_many_float4x4_may_exceed_vulkan(device_type: spy.DeviceType): ) threshold = device.info.limits.max_entry_point_uniform_size if threshold >= 524: - assert cd.use_direct_args is True + assert cd.use_entrypoint_args is True else: - assert cd.use_direct_args is False + assert cd.use_entrypoint_args is False @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_wanghasharg_uses_direct_args(device_type: spy.DeviceType): +def test_step21_wanghasharg_uses_entrypoint_args(device_type: spy.DeviceType): """WangHashArg (non-direct-bind) still counts its inline-uniform size. - Its wrapper type has a small inline footprint, so use_direct_args should be True. + Its wrapper type has a small inline footprint, so use_entrypoint_args should be True. """ device = helpers.get_device(device_type) cd = build_call_data( @@ -2038,7 +2038,7 @@ def test_step21_wanghasharg_uses_direct_args(device_type: spy.DeviceType): "uint3 rng(uint3 input) { return input; }", WangHashArg(3), ) - assert cd.use_direct_args is True + assert cd.use_entrypoint_args is True if __name__ == "__main__": diff --git a/src/slangpy_ext/utils/slangpy.cpp b/src/slangpy_ext/utils/slangpy.cpp index f3e7e086a..d172caf3e 100644 --- a/src/slangpy_ext/utils/slangpy.cpp +++ b/src/slangpy_ext/utils/slangpy.cpp @@ -861,7 +861,7 @@ nb::object NativeCallData::exec( auto bind_call_data = [&](ShaderCursor cursor) { - if (m_use_direct_args) { + if (m_use_entrypoint_args) { // ---- Fast path: individual entry-point params ---- ShaderCursor ep = cursor.find_entry_point(0); @@ -1705,11 +1705,11 @@ SGL_PY_EXPORT(utils_slangpy) D_NA(NativeCallData, has_thread_count) ) .def_prop_rw( - "use_direct_args", - &NativeCallData::use_direct_args, - &NativeCallData::set_use_direct_args, + "use_entrypoint_args", + &NativeCallData::use_entrypoint_args, + &NativeCallData::set_use_entrypoint_args, nb::arg(), - D_NA(NativeCallData, use_direct_args) + D_NA(NativeCallData, use_entrypoint_args) ) .def_prop_rw( "autograd_access_list", diff --git a/src/slangpy_ext/utils/slangpy.h b/src/slangpy_ext/utils/slangpy.h index 806de9374..c92b8e419 100644 --- a/src/slangpy_ext/utils/slangpy.h +++ b/src/slangpy_ext/utils/slangpy.h @@ -783,10 +783,10 @@ class NativeCallData : Object { void set_has_thread_count(bool has_thread_count) { m_has_thread_count = has_thread_count; } /// Get whether this call uses direct entry-point parameters (fast path). - bool use_direct_args() const { return m_use_direct_args; } + bool use_entrypoint_args() const { return m_use_entrypoint_args; } /// Set whether this call uses direct entry-point parameters (fast path). - void set_use_direct_args(bool use_direct_args) { m_use_direct_args = use_direct_args; } + void set_use_entrypoint_args(bool use_entrypoint_args) { m_use_entrypoint_args = use_entrypoint_args; } /// Get the autograd access list. /// This is a flat list of AutogradAccess values precomputed at build time. @@ -918,7 +918,7 @@ class NativeCallData : Object { bool m_torch_autograd{false}; bool m_needs_unpack{true}; bool m_has_thread_count{false}; - bool m_use_direct_args{false}; + bool m_use_entrypoint_args{false}; std::vector m_autograd_access_list; ref m_bwds_call_data; mutable CallDataOffsets m_cached_call_data_offsets; From 6c255b3418344891f49738797ef20810ac810ed5 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 10:56:29 +0000 Subject: [PATCH 25/41] More tests --- ...-simplifyKernelGenPhase2-cleanup.prompt.md | 46 ++-- slangpy/tests/slangpy_tests/test_code_gen.py | 226 ++++++++++++++++++ 2 files changed, 248 insertions(+), 24 deletions(-) diff --git a/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md b/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md index d5e20c82c..9aedfefcd 100644 --- a/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md +++ b/.github/prompts/plan-simplifyKernelGenPhase2-cleanup.prompt.md @@ -339,29 +339,26 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi ### Step 2.7: Tests -**Status: NOT STARTED** +**Status: PARTIAL** — Tests for completed Phase 2 steps added to [test_code_gen.py](slangpy/tests/slangpy_tests/test_code_gen.py). Remaining tests for Step 2.3 (trampoline elimination) and Step 2.6 (`_result` as `RWStructuredBuffer`) will be added when those steps are implemented. + +**Tests added** (in [test_code_gen.py](slangpy/tests/slangpy_tests/test_code_gen.py), tests 35–38, 40): + +| Test | Verifies | Merges from test_kernel_gen.py | +|------|----------|-------------------------------| +| `test_entrypoint_params_scalar_dim0` (#35) | Fast path: no `struct CallData`, individual `uniform` params, `_thread_count` direct, `SV_GroupID` absent at dim-0, `use_entrypoint_args=True` | `test_gate_p2_calldata_struct_absent_fast_path`, `test_gate_p2_individual_uniform_params`, `test_gate_p2_thread_count_direct`, `test_gate_p2_sv_group_id_absent_dim0`, `test_step21_scalar_uses_entrypoint_args` | +| `test_entrypoint_params_vectorized` (#36) | Vectorized fast path: shape arrays as entry-point params, `SV_GroupID`/`SV_GroupIndex` present, no `struct CallData` | (new — covers vectorized entry-point param path) | +| `test_entrypoint_params_non_direct_bind` (#37) | Non-direct-bind arg (WangHashArg) on fast path: no `struct CallData`, wrapper type used, `__slangpy_load`/`Context` present | `test_gate_p2_wanghasharg_keeps_load`, `test_step21_wanghasharg_uses_entrypoint_args` | +| `test_bwds_entrypoint_no_diff_params` (#38) | Bwds fast path: trampoline params have `no_diff` and `__in_` prefix, `bwd_diff(_trampoline)` passes individual args, `[Differentiable]` before trampoline | (new — covers Step 2.4 bwds trampoline) | +| `test_fallback_calldata_large_params` (#40) | Fallback path: 8×float4x4 exceeds threshold → `ParameterBlock`, `call_data._thread_count`; CUDA stays fast path | `test_step21_many_float4x4_may_exceed_vulkan` (adds codegen assertions) | + +**Post-implementation tests** — to be added when remaining steps are complete: -**Post-implementation tests** — should pass AFTER Phase 2 is complete: - -| Test | Verifies | -|------|----------| -| `test_phase2_no_calldata_struct` | `struct CallData` absent for eligible call | -| `test_phase2_uniform_params_on_entry` | Individual `uniform` params on `compute_main` | -| `test_phase2_no_trampoline_prim` | No `void _trampoline(` for prim-mode calls | -| `test_phase2_inline_call` | Function call inlined directly in `compute_main` | -| `test_phase2_thread_count_as_uniform` | `uniform uint3 _thread_count` as entry-point param | -| `test_phase2_no_context_all_direct` | No `Context __slangpy_context__` when all args direct-bind | -| `test_phase2_context_kept_non_direct` | `Context` present when some args use `__slangpy_load` | -| `test_phase2_bwds_trampoline_individual` | Bwds trampoline has individual params with `no_diff` | -| `test_phase2_bwds_bwd_diff_call` | `bwd_diff(_trampoline)(ctx, a, b, ...)` in kernel | -| `test_phase2_no_sv_group_when_dim0` | No `SV_GroupID`/`SV_GroupIndex` when `call_data_len == 0` | -| `test_phase2_sv_group_when_vectorized` | `SV_GroupID`/`SV_GroupIndex` present when `call_data_len > 0` | -| `test_phase2_fallback_keeps_calldata` | Force fallback → `struct CallData` still emitted | -| `test_phase2_fallback_no_trampoline_prim` | Even fallback path eliminates trampoline in prim mode | -| `test_phase2_functional_scalar_add` | `add(1, 2) == 3` end-to-end dispatch | -| `test_phase2_functional_bwds` | Backward pass correct gradients | -| `test_phase2_functional_vectorized` | Vectorized call (shapes) with entry-point params | -| `test_phase2_functional_mixed_direct` | Mix of direct-bind + non-direct-bind args | +| Test | Verifies | Blocked on | +|------|----------|------------| +| `test_phase2_no_trampoline_prim` | No `void _trampoline(` for prim-mode calls | Step 2.3 | +| `test_phase2_inline_call` | Function call inlined directly in `compute_main` | Step 2.3 | +| `test_phase2_no_context_all_direct` | No `Context __slangpy_context__` when all args direct-bind | Step 2.3 | +| `test_phase2_fallback_no_trampoline_prim` | Even fallback path eliminates trampoline in prim mode | Step 2.3 | --- @@ -399,7 +396,8 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi | [slangpy/core/function.py](slangpy/core/function.py) | ✅ `CallDataMode` removed from imports | | [slangpy/slangpy/__init__.pyi](slangpy/slangpy/__init__.pyi) | ✅ `CallDataMode` class and `call_data_mode` property removed | | [slangpy/tests/slangpy_tests/test_type_resolution.py](slangpy/tests/slangpy_tests/test_type_resolution.py) | ✅ `CallDataMode` removed from `BindContext` creation | -| [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) | ✅ Gating tests + Step 2.1 tests updated for new behavior; post-implementation tests (Step 2.7) pending | +| [slangpy/tests/slangpy_tests/test_kernel_gen.py](slangpy/tests/slangpy_tests/test_kernel_gen.py) | ✅ Gating tests + Step 2.1 tests updated for new behavior | +| [slangpy/tests/slangpy_tests/test_code_gen.py](slangpy/tests/slangpy_tests/test_code_gen.py) | ✅ Phase 2 tests 35–38, 40 added (Step 2.7 partial) | --- @@ -602,7 +600,7 @@ If a writable dim-0 leaf binding gets `direct_bind=True`, `ValueMarshall.gen_tra [test_code_gen.py](slangpy/tests/slangpy_tests/test_code_gen.py) has no test that forces `use_entrypoint_args=False` (e.g., by exceeding `max_entry_point_uniform_size`) and asserts the `ParameterBlock` codegen. The `test_step21_many_float4x4_may_exceed_vulkan` in `test_kernel_gen.py` checks the flag but not the generated code. -**DO NOT FIX**: Reason: Step 2.7 will add comprehensive post-implementation tests including `test_phase2_fallback_keeps_calldata` and `test_phase2_fallback_no_trampoline_prim`. +**FIXED**: Added `test_fallback_calldata_large_params` (#40) in `test_code_gen.py` — asserts `ParameterBlock` codegen on Vulkan/D3D12 and fast-path codegen on CUDA. **27. No test for writable `inout` struct at dim-0** diff --git a/slangpy/tests/slangpy_tests/test_code_gen.py b/slangpy/tests/slangpy_tests/test_code_gen.py index 9992e9d02..294c2519b 100644 --- a/slangpy/tests/slangpy_tests/test_code_gen.py +++ b/slangpy/tests/slangpy_tests/test_code_gen.py @@ -80,6 +80,28 @@ def generate_bwds_code_and_bindings( return cd.code, cd.debug_only_bindings +def build_call_data_full( + device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any +) -> tuple[str, Any, Any]: + """Build CallData and return ``(code_str, bindings, call_data)``.""" + func = helpers.create_function_from_module(device, func_name, module_source) + cd = func.debug_build_call_data(*args, **kwargs) + if PRINT_CODE: + print(cd.code) + return cd.code, cd.debug_only_bindings, cd + + +def build_bwds_call_data_full( + device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any +) -> tuple[str, Any, Any]: + """Build bwds CallData and return ``(code_str, bindings, call_data)``.""" + func = helpers.create_function_from_module(device, func_name, module_source) + cd = func.bwds.debug_build_call_data(*args, **kwargs) + if PRINT_CODE: + print(cd.code) + return cd.code, cd.debug_only_bindings, cd + + # =========================================================================== # Codegen + binding flag tests (1–21) # =========================================================================== @@ -1070,5 +1092,209 @@ def test_dispatch_struct_array_of_structs(device_type: spy.DeviceType): assert result == 100 +# =========================================================================== +# Phase 2 — entry-point params (35–38, 40) +# =========================================================================== + + +# 35 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_entrypoint_params_scalar_dim0(device_type: spy.DeviceType): + """Fast path: scalar dim-0 uses individual uniform entry-point params. + + Verifies: no struct CallData, no ParameterBlock, individual uniform params + for a/b/_thread_count, _thread_count used directly in bounds check, + SV_GroupID absent (dim-0 has no shape arrays), use_entrypoint_args=True. + + Merges: test_gate_p2_calldata_struct_absent_fast_path, + test_gate_p2_individual_uniform_params, test_gate_p2_thread_count_direct, + test_gate_p2_sv_group_id_absent_dim0, test_step21_scalar_uses_entrypoint_args. + """ + device = helpers.get_device(device_type) + code, bindings, cd = build_call_data_full( + device, "add", "int add(int a, int b) { return a + b; }", 1, 2 + ) + + # --- fast path flag --- + assert cd.use_entrypoint_args is True + + # --- no CallData struct or ParameterBlock --- + assert_not_contains(code, "struct CallData", "ParameterBlock", "uniform CallData") + + # --- individual uniform params on compute_main --- + assert_contains(code, "uniform uint3 _thread_count") + assert_contains(code, "uniform int a") + assert_contains(code, "uniform int b") + + # --- _thread_count used directly in bounds check --- + assert_not_contains(code, "call_data._thread_count") + main_idx = code.index("void compute_main(") + main_body = code[main_idx:] + assert ">= _thread_count)" in main_body + + # --- SV_GroupID absent for dim-0 (no shape arrays) --- + assert_not_contains(code, "SV_GroupID") + + +# 36 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_entrypoint_params_vectorized(device_type: spy.DeviceType): + """Fast path vectorized: shape arrays as entry-point params, SV_GroupID present. + + Verifies: use_entrypoint_args=True, shape arrays (_grid_stride, _grid_dim, + _call_dim) as uniform params, SV_GroupID/SV_GroupIndex present when + call_data_len > 0, no struct CallData. + """ + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) + code, bindings, cd = build_call_data_full( + device, "add", "float add(float a, float b) { return a + b; }", 1.0, tensor + ) + + # --- fast path --- + assert cd.use_entrypoint_args is True + + # --- no CallData --- + assert_not_contains(code, "struct CallData") + + # --- SV_GroupID/SV_GroupIndex present (call_data_len > 0) --- + assert_contains(code, "SV_GroupID", "SV_GroupIndex") + + # --- shape arrays as entry-point params --- + assert_contains( + code, "uniform int[1] _grid_stride", "uniform int[1] _grid_dim", "uniform int[1] _call_dim" + ) + + # --- shape arrays NOT prefixed with call_data. in kernel body --- + assert_not_contains( + code, "call_data._grid_stride", "call_data._grid_dim", "call_data._call_dim" + ) + + +# 37 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_entrypoint_params_non_direct_bind(device_type: spy.DeviceType): + """Fast path with non-direct-bind arg: no CallData, wrapper used, Context present. + + WangHashArg is NOT direct-bind but still goes as an entry-point param on the + fast path. __slangpy_load and Context are present because the wrapper needs them. + + Merges: test_gate_p2_wanghasharg_keeps_load, + test_step21_wanghasharg_uses_entrypoint_args. + """ + device = helpers.get_device(device_type) + code, bindings, cd = build_call_data_full( + device, "rng", "uint3 rng(uint3 input) { return input; }", WangHashArg(3) + ) + + # --- fast path despite non-direct-bind --- + assert cd.use_entrypoint_args is True + + # --- non-direct-bind binding --- + assert bindings.args[0].direct_bind is False + + # --- wrapper type used --- + assert_contains(code, "WangHashArg<") + + # --- __slangpy_load and Context present --- + assert_contains(code, "__slangpy_load", "Context") + + # --- no CallData struct --- + assert_not_contains(code, "struct CallData") + + +# 38 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_bwds_entrypoint_no_diff_params(device_type: spy.DeviceType): + """Bwds fast path: trampoline params have no_diff, bwd_diff call passes individuals. + + Verifies: use_entrypoint_args=True, trampoline params have 'no_diff' and + '__in_' prefix, bwd_diff(_trampoline) call passes individual arg names, + [Differentiable] before trampoline. + """ + device = helpers.get_device(device_type) + src = """ +[Differentiable] +float polynomial(float a, float b) { + return a * a + b + 1; +} +""" + code, bindings, cd = build_bwds_call_data_full(device, "polynomial", src, 5.0, 10.0, 26.0) + + # --- fast path --- + assert cd.use_entrypoint_args is True + + # --- trampoline params have no_diff and __in_ prefix --- + assert_contains(code, "no_diff") + assert_contains(code, "__in_a") + assert_contains(code, "__in_b") + + # --- [Differentiable] before trampoline --- + diff_idx = code.index("[Differentiable]") + trampoline_idx = code.index("void _trampoline") + assert diff_idx < trampoline_idx + + # --- bwd_diff call passes individual args (not just context) --- + main_idx = code.index("void compute_main(") + main_body = code[main_idx:] + assert "bwd_diff(_trampoline)(__slangpy_context__" in main_body + # Should have more than just the context arg + bwd_call_start = main_body.index("bwd_diff(_trampoline)(") + bwd_call_end = main_body.index(")", bwd_call_start + len("bwd_diff(_trampoline)(")) + bwd_call_args = main_body[bwd_call_start:bwd_call_end] + assert ", a," in bwd_call_args or ", a)" in bwd_call_args + + # --- no struct CallData --- + assert_not_contains(code, "struct CallData") + + +# 40 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_fallback_calldata_large_params(device_type: spy.DeviceType): + """Fallback path: many float4x4 params exceed threshold → ParameterBlock. + + 8 × float4x4 = 512 bytes + 12 bytes _thread_count = 524 bytes. + Exceeds Vulkan (128) and D3D12 (256); CUDA (4096) stays on fast path. + Asserts codegen patterns match the expected path. + + Merges: test_step21_many_float4x4_may_exceed_vulkan (adds codegen assertions). + """ + device = helpers.get_device(device_type) + src = """ +float4x4 sum8(float4x4 a, float4x4 b, float4x4 c, float4x4 d, + float4x4 e, float4x4 f, float4x4 g, float4x4 h) { + return a + b + c + d + e + f + g + h; +} +""" + identity = spy.math.float4x4.identity() + code, bindings, cd = build_call_data_full( + device, + "sum8", + src, + identity, + identity, + identity, + identity, + identity, + identity, + identity, + identity, + ) + + threshold = device.info.limits.max_entry_point_uniform_size + if threshold >= 524: + # CUDA: fast path — no CallData, individual uniform params + assert cd.use_entrypoint_args is True + assert_not_contains(code, "struct CallData") + assert_contains(code, "uniform uint3 _thread_count") + else: + # Vulkan/D3D12: fallback — struct CallData + ParameterBlock + assert cd.use_entrypoint_args is False + assert_contains(code, "struct CallData") + assert_contains(code, "ParameterBlock call_data") + assert_contains(code, "call_data._thread_count") + assert_not_contains(code, "uniform uint3 _thread_count") + + if __name__ == "__main__": pytest.main([__file__, "-vs"]) From 6fdf58b69ade71ce4f89c9225fb39eed3e340760 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 11:31:40 +0000 Subject: [PATCH 26/41] Don't use ep args on metal for now --- slangpy/core/calldata.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index 8c6dbea01..367890c91 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -277,6 +277,10 @@ def build(self, build_info: "FunctionBuildInfo", *args: Any, **kwargs: Any) -> N ): use_entrypoint_args = False + # Disable for Metal until I can figure out how entry point args work properly + if build_info.module.device.info.type == DeviceType.metal: + use_entrypoint_args = False + # Try building the shader. If direct args compilation fails (the # threshold is only an approximate heuristic), fall back to # ParameterBlock. From 43a4688fa59efbf5eb5cc59b287e39dc9ed0390e Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 11:31:52 +0000 Subject: [PATCH 27/41] Plan for code gen cleanup --- .../plan-extractCodegenToGenerator.prompt.md | 179 ++++++++++++++++++ 1 file changed, 179 insertions(+) create mode 100644 .github/prompts/plan-extractCodegenToGenerator.prompt.md diff --git a/.github/prompts/plan-extractCodegenToGenerator.prompt.md b/.github/prompts/plan-extractCodegenToGenerator.prompt.md new file mode 100644 index 000000000..22f4bc60a --- /dev/null +++ b/.github/prompts/plan-extractCodegenToGenerator.prompt.md @@ -0,0 +1,179 @@ +## Extract codegen into generator.py + +**Goal**: Extract the code-emission logic from [callsignature.py](slangpy/core/callsignature.py) (`generate_code`, `generate_constants`, `KernelGenException`, helpers) and `BoundVariable.gen_call_data_code` from [boundvariable.py](slangpy/bindings/boundvariable.py) into a new [generator.py](slangpy/core/generator.py) file. The new file decomposes the monolithic `generate_code` (332 lines) into clearly-named sub-functions with doc comments showing what Slang code each one emits. `callsignature.py` retains the binding-pipeline functions (`specialize`, `bind`, `calculate_*`, etc.). Each step is a pure move/rename with no behavioral changes, verifiable by the existing test suites. + +**Parent plan**: [plan-simplifyKernelGenPhase2-cleanup.prompt.md](plan-simplifyKernelGenPhase2-cleanup.prompt.md) + +--- + +### Step 1: Create `slangpy/core/generator.py` with `generate_constants` and `KernelGenException` + +Move these small, self-contained pieces first: + +- **Move** `KernelGenException` (lines 40–43) from [callsignature.py](slangpy/core/callsignature.py#L40-L43). +- **Move** `is_slangpy_vector` (lines 240–247) from [callsignature.py](slangpy/core/callsignature.py#L240-L247) — private helper, prefix with `_`. +- **Move** `generate_constants` (lines 250–268) from [callsignature.py](slangpy/core/callsignature.py#L250-L268). +- **In [callsignature.py](slangpy/core/callsignature.py)**: Add `from slangpy.core.generator import KernelGenException, generate_constants` and delete the moved code. Keep a re-export of `KernelGenException` so any external consumer of the wildcard import from [calldata.py](slangpy/core/calldata.py#L8) continues to work. +- **In [dispatchdata.py](slangpy/core/dispatchdata.py#L7)**: Change `from slangpy.core.callsignature import generate_constants` → `from slangpy.core.generator import generate_constants`. + +**Verify**: `pytest slangpy/tests -v` — all tests pass, no import errors. + +--- + +### Step 2: Extract `gen_call_data_code` as a free function + +Move `BoundVariable.gen_call_data_code` (lines 604–693 of [boundvariable.py](slangpy/bindings/boundvariable.py#L604-L693)) into `generator.py` as a free function, along with the related `gen_calldata_type_name` helper (lines 258–272 of [boundvariable.py](slangpy/bindings/boundvariable.py#L258-L272)). + +- **In `generator.py`**: Create two free functions: + - `gen_calldata_type_name(binding: BoundVariable, cgb: CodeGenBlock, type_name: str) -> None` — same logic, takes `binding` as first arg instead of `self`. + - `gen_call_data_code(binding: BoundVariable, cg: CodeGen, context: BindContext, depth: int = 0) -> None` — same logic, recursive calls use the free function. References to `self` become `binding`. Internal calls to `self.gen_calldata_type_name(...)` become `gen_calldata_type_name(binding, ...)`. Recursive calls on children become `gen_call_data_code(child, cg, context, depth + 1)`. +- **In [boundvariable.py](slangpy/bindings/boundvariable.py)**: Replace the method bodies with thin delegations: + ```python + def gen_calldata_type_name(self, cgb, type_name): + from slangpy.core.generator import gen_calldata_type_name + gen_calldata_type_name(self, cgb, type_name) + + def gen_call_data_code(self, cg, context, depth=0): + from slangpy.core.generator import gen_call_data_code + gen_call_data_code(self, cg, context, depth) + ``` + This preserves the existing call interface (`node.gen_call_data_code(cg, context)` in [callsignature.py line 406](slangpy/core/callsignature.py#L406)) and any marshall subclass code that calls `self.gen_calldata_type_name`. The `MAX_INLINE_TYPE_LEN` constant moves to `generator.py`. +- **Move** the import of `CodeGen` and `CodeGenBlock` into `generator.py` (already needed for Step 1). + +**Verify**: `pytest slangpy/tests -v` — all tests pass. + +--- + +### Step 3: Decompose `generate_code` into sub-functions and move to `generator.py` + +This is the main step. Move `generate_code` (lines 271–603 of [callsignature.py](slangpy/core/callsignature.py#L271-L603)) into `generator.py` and split it into clearly-named sub-functions. Each function has a docstring describing what Slang code it emits. + +The decomposition: + +| New function | Source lines | What it emits | +|---|---|---| +| `_validate_and_compute_group_shape(build_info, call_data_len) -> tuple[int, list[int], list[int]]` | [293–340](slangpy/core/callsignature.py#L293-L340) | Nothing — pure validation. Returns `(call_group_size, call_group_strides, call_group_shape_vector)`. | +| `_emit_link_time_constants(cg, build_info, call_data_len, call_group_size, call_group_strides, call_group_shape_vector)` | [342–371](slangpy/core/callsignature.py#L342-L371) | `export static const int call_data_len = ...`, group stride/shape arrays. Also calls `generate_constants()`. | +| `_emit_shape_and_metadata_params(cg, call_data_len, use_entrypoint_args)` | [373–403](slangpy/core/callsignature.py#L373-L403) | `_grid_stride`, `_grid_dim`, `_call_dim`, `_thread_count` — as entry-point params (fast) or `CallData` fields (fallback). | +| `_emit_call_data_definitions(cg, context, signature)` | [405–406](slangpy/core/callsignature.py#L405-L406) | Per-variable call data (wrapper structs, type aliases, mapping constants). Calls `gen_call_data_code` on each node. | +| `_data_name(x, use_entrypoint_args) -> str` | Duplicated at [449](slangpy/core/callsignature.py#L449) and [497](slangpy/core/callsignature.py#L497) | Helper: returns `__in_{name}`, `call_data.{name}`, or `_param_{name}`. | +| `_emit_trampoline(cg, context, build_info, signature, root_params, use_entrypoint_args)` | [408–500](slangpy/core/callsignature.py#L408-L500) | `[Differentiable] void _trampoline(...)` — param declarations, loads, function call, stores. | +| `_emit_entry_point_signature(cg, build_info, call_data_len, call_group_size, use_entrypoint_args)` | [503–541](slangpy/core/callsignature.py#L503-L541) | `[shader("compute")] [numthreads(...)] void compute_main(...)` or `[shader("raygen")] void raygen_main(...)`. | +| `_emit_kernel_body(cg, context, build_info, root_params, call_data_len, use_entrypoint_args)` | [543–603](slangpy/core/callsignature.py#L543-L603) | Bounds check, `init_thread_local_call_shape_info`, Context construction, trampoline call. | + +The top-level `generate_code` becomes a ~30-line orchestrator that calls these in order: + +```python +def generate_code(context, build_info, signature, cg): + use_entrypoint_args = context.use_entrypoint_args + cg.add_import("slangpy") + call_data_len = context.call_dimensionality + + call_group_size, strides, shape = _validate_and_compute_group_shape(build_info, call_data_len) + + cg.add_import(build_info.module.name) + if use_entrypoint_args: + cg.skip_call_data = True + + _emit_link_time_constants(cg, build_info, call_data_len, call_group_size, strides, shape) + _emit_shape_and_metadata_params(cg, call_data_len, use_entrypoint_args) + _emit_call_data_definitions(cg, context, signature) + + root_params = sorted(signature.values(), key=lambda x: x.param_index) + + _emit_trampoline(cg, context, build_info, root_params, use_entrypoint_args) + _emit_entry_point_signature(cg, build_info, call_data_len, call_group_size, use_entrypoint_args) + cg.kernel.begin_block() + _emit_kernel_body(cg, context, build_info, root_params, call_data_len, use_entrypoint_args) + cg.kernel.end_block() +``` + +- **In [callsignature.py](slangpy/core/callsignature.py)**: Delete `generate_code` and add `from slangpy.core.generator import generate_code` (or let the existing wildcard import consumer in [calldata.py](slangpy/core/calldata.py) point to `generator` instead). +- **Update [calldata.py](slangpy/core/calldata.py#L8)**: Change `from slangpy.core.callsignature import *` to explicit imports: binding-pipeline functions from `callsignature`, and `generate_code`, `KernelGenException` from `generator`. This eliminates the wildcard import, making dependencies explicit. + +**Verify**: `pytest slangpy/tests -v` — all tests pass. `$env:SLANGPY_PRINT_GENERATED_SHADERS="1"; pytest slangpy/tests/slangpy_tests/test_code_gen.py -v` — generated code unchanged. + +--- + +### Step 4: Clean up `callsignature.py` + +After Step 3, `callsignature.py` no longer has any codegen functions. Clean up: + +- Remove unused imports that were only needed by codegen (`CodeGen`, `PipelineType`, `AccessType`, `NoneMarshall`, `BoundVariableException` if no longer referenced). +- Remove re-exports of moved symbols once [calldata.py](slangpy/core/calldata.py) uses direct imports from `generator`. +- Add `from slangpy.core.generator import KernelGenException, ResolveException` re-exports **only if** external consumers import them from `callsignature` (check via grep). If only `calldata.py` uses them, the explicit import is sufficient. + +**Verify**: `pytest slangpy/tests -v`. `pre-commit run --all-files`. + +--- + +### Step 5: Add comments to `generator.py` sub-functions + +Enrich each sub-function's docstring with an example of the Slang code it generates, for both the fast path and fallback path. For example: + +```python +def _emit_shape_and_metadata_params( + cg: CodeGen, + call_data_len: int, + use_entrypoint_args: bool, +) -> None: + """Emit shape arrays and _thread_count. + + Fast path (entry-point params):: + + uniform int[2] _grid_stride + uniform int[2] _grid_dim + uniform int[2] _call_dim + uniform uint3 _thread_count + + Fallback (CallData struct fields):: + + int[2] _grid_stride; + int[2] _grid_dim; + int[2] _call_dim; + uint3 _thread_count; + """ +``` + +This is documentation-only, no functional changes. + +**Verify**: `pre-commit run --all-files` (formatting check). + +--- + +### Verification + +At each step: +```bash +cmake --build --preset windows-msvc-debug +pytest slangpy/tests -v +pre-commit run --all-files +``` + +After Step 3 specifically, also verify generated shader output is unchanged: +```powershell +$env:SLANGPY_PRINT_GENERATED_SHADERS="1"; pytest slangpy/tests/slangpy_tests/test_code_gen.py -v +``` + +Compare output before/after to confirm byte-identical generated Slang code. + +--- + +### Decisions + +- `gen_call_data_code` extracted as free function in `generator.py`; thin delegation stub kept on `BoundVariable` to preserve the method-call interface (`node.gen_call_data_code(cg, context)`) used in `generate_code` and potentially in external/user code. +- `generator.py` lives at `slangpy/core/generator.py` alongside `callsignature.py` and `calldata.py`. +- Wildcard import `from slangpy.core.callsignature import *` in `calldata.py` replaced with explicit imports to make dependencies clear. +- Sub-function names prefixed with `_` (private to the module); only `generate_code`, `generate_constants`, `gen_call_data_code`, `gen_calldata_type_name`, `KernelGenException` are public. + +--- + +### Key Files + +| File | Changes | +|------|---------| +| [slangpy/core/generator.py](slangpy/core/generator.py) | **NEW** — `generate_code`, `generate_constants`, `gen_call_data_code`, `gen_calldata_type_name`, `KernelGenException`, private helpers | +| [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | Remove `generate_code`, `generate_constants`, `KernelGenException`, `is_slangpy_vector`; add re-exports from `generator` | +| [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | `gen_call_data_code` and `gen_calldata_type_name` become thin delegation stubs; `MAX_INLINE_TYPE_LEN` moves out | +| [slangpy/core/calldata.py](slangpy/core/calldata.py) | Replace `from slangpy.core.callsignature import *` with explicit imports from `callsignature` and `generator` | +| [slangpy/core/dispatchdata.py](slangpy/core/dispatchdata.py) | Import `generate_constants` from `generator` instead of `callsignature` | From 828cf2c80c30366b0e9c4424236ca832dee1885d Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 13:31:08 +0000 Subject: [PATCH 28/41] wip generator cleanup --- slangpy/core/callsignature.py | 34 +--------------------------- slangpy/core/dispatchdata.py | 2 +- slangpy/core/generator.py | 42 +++++++++++++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 34 deletions(-) create mode 100644 slangpy/core/generator.py diff --git a/slangpy/core/callsignature.py b/slangpy/core/callsignature.py index 4241aa784..9e30265a5 100644 --- a/slangpy/core/callsignature.py +++ b/slangpy/core/callsignature.py @@ -38,10 +38,7 @@ def __init__(self, message: str): self.message = message -class KernelGenException(Exception): - def __init__(self, message: str): - super().__init__(message) - self.message = message +from slangpy.core.generator import KernelGenException, generate_constants # noqa: F401 # This detects if a type is a vector with its length defined by a generic @@ -238,35 +235,6 @@ def create_return_value_binding(context: BindContext, signature: BoundCall, retu node.python = python_type -def is_slangpy_vector(type: Any) -> bool: - return ( - hasattr(type, "element_type") - and hasattr(type, "shape") - and len(type.shape) == 1 - and type.shape[0] <= 4 - ) - - -def generate_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: - if build_info.constants is not None: - for k, v in build_info.constants.items(): - if isinstance(v, bool): - cg.constants.append_statement( - f"export static const bool {k} = {'true' if v else 'false'}" - ) - elif isinstance(v, (int, float)): - cg.constants.append_statement(f"export static const {type(v).__name__} {k} = {v}") - elif is_slangpy_vector(v): - # Cheeky logic to take, eg, {0,0,0} -> float3(0,0,0) - tn = type(v).__name__ - txt = f"{tn}({str(v)[1:-1]})" - cg.constants.append_statement(f"export static const {tn} {k} = {txt}") - else: - raise KernelGenException( - f"Constant value '{k}' must be an int, float or bool, not {type(v).__name__}" - ) - - def generate_code( context: BindContext, build_info: "FunctionBuildInfo", diff --git a/slangpy/core/dispatchdata.py b/slangpy/core/dispatchdata.py index 940ee3747..333323629 100644 --- a/slangpy/core/dispatchdata.py +++ b/slangpy/core/dispatchdata.py @@ -4,7 +4,7 @@ import re from typing import TYPE_CHECKING, Any, Optional -from slangpy.core.callsignature import generate_constants +from slangpy.core.generator import generate_constants from slangpy.core.enums import IOType from slangpy.core.native import CallMode, pack_arg, unpack_arg from slangpy.core.calldata import _DUMP_SLANG_INTERMEDIATES, _DUMP_GENERATED_SHADERS diff --git a/slangpy/core/generator.py b/slangpy/core/generator.py new file mode 100644 index 000000000..382cf549f --- /dev/null +++ b/slangpy/core/generator.py @@ -0,0 +1,42 @@ +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from typing import TYPE_CHECKING, Any + +from slangpy.bindings.codegen import CodeGen + +if TYPE_CHECKING: + from slangpy.core.function import FunctionBuildInfo + + +class KernelGenException(Exception): + def __init__(self, message: str): + super().__init__(message) + self.message = message + + +def _is_slangpy_vector(type: Any) -> bool: + return ( + hasattr(type, "element_type") + and hasattr(type, "shape") + and len(type.shape) == 1 + and type.shape[0] <= 4 + ) + + +def generate_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: + if build_info.constants is not None: + for k, v in build_info.constants.items(): + if isinstance(v, bool): + cg.constants.append_statement( + f"export static const bool {k} = {'true' if v else 'false'}" + ) + elif isinstance(v, (int, float)): + cg.constants.append_statement(f"export static const {type(v).__name__} {k} = {v}") + elif _is_slangpy_vector(v): + # Cheeky logic to take, eg, {0,0,0} -> float3(0,0,0) + tn = type(v).__name__ + txt = f"{tn}({str(v)[1:-1]})" + cg.constants.append_statement(f"export static const {tn} {k} = {txt}") + else: + raise KernelGenException( + f"Constant value '{k}' must be an int, float or bool, not {type(v).__name__}" + ) From a48603fd78b08666b51be9c394e40fef2b912f81 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 13:31:19 +0000 Subject: [PATCH 29/41] Disable metal tests --- slangpy/tests/slangpy_tests/test_code_gen.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/slangpy/tests/slangpy_tests/test_code_gen.py b/slangpy/tests/slangpy_tests/test_code_gen.py index 294c2519b..9ec2c36d7 100644 --- a/slangpy/tests/slangpy_tests/test_code_gen.py +++ b/slangpy/tests/slangpy_tests/test_code_gen.py @@ -1110,6 +1110,8 @@ def test_entrypoint_params_scalar_dim0(device_type: spy.DeviceType): test_gate_p2_individual_uniform_params, test_gate_p2_thread_count_direct, test_gate_p2_sv_group_id_absent_dim0, test_step21_scalar_uses_entrypoint_args. """ + if device_type == spy.DeviceType.metal: + pytest.skip("Metal doesn't support entry point params.") device = helpers.get_device(device_type) code, bindings, cd = build_call_data_full( device, "add", "int add(int a, int b) { return a + b; }", 1, 2 @@ -1145,6 +1147,8 @@ def test_entrypoint_params_vectorized(device_type: spy.DeviceType): _call_dim) as uniform params, SV_GroupID/SV_GroupIndex present when call_data_len > 0, no struct CallData. """ + if device_type == spy.DeviceType.metal: + pytest.skip("Metal doesn't support entry point params.") device = helpers.get_device(device_type) tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) code, bindings, cd = build_call_data_full( @@ -1182,6 +1186,8 @@ def test_entrypoint_params_non_direct_bind(device_type: spy.DeviceType): Merges: test_gate_p2_wanghasharg_keeps_load, test_step21_wanghasharg_uses_entrypoint_args. """ + if device_type == spy.DeviceType.metal: + pytest.skip("Metal doesn't support entry point params.") device = helpers.get_device(device_type) code, bindings, cd = build_call_data_full( device, "rng", "uint3 rng(uint3 input) { return input; }", WangHashArg(3) @@ -1212,6 +1218,8 @@ def test_bwds_entrypoint_no_diff_params(device_type: spy.DeviceType): '__in_' prefix, bwd_diff(_trampoline) call passes individual arg names, [Differentiable] before trampoline. """ + if device_type == spy.DeviceType.metal: + pytest.skip("Metal doesn't support entry point params.") device = helpers.get_device(device_type) src = """ [Differentiable] From 77b6bf65062d90d6cee18e0ca86ca261220bc833 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 14:24:05 +0000 Subject: [PATCH 30/41] Update verification commands in documentation to reflect correct test paths --- .../plan-extractCodegenToGenerator.prompt.md | 73 +++++++++++++------ 1 file changed, 49 insertions(+), 24 deletions(-) diff --git a/.github/prompts/plan-extractCodegenToGenerator.prompt.md b/.github/prompts/plan-extractCodegenToGenerator.prompt.md index 22f4bc60a..237d61604 100644 --- a/.github/prompts/plan-extractCodegenToGenerator.prompt.md +++ b/.github/prompts/plan-extractCodegenToGenerator.prompt.md @@ -16,7 +16,7 @@ Move these small, self-contained pieces first: - **In [callsignature.py](slangpy/core/callsignature.py)**: Add `from slangpy.core.generator import KernelGenException, generate_constants` and delete the moved code. Keep a re-export of `KernelGenException` so any external consumer of the wildcard import from [calldata.py](slangpy/core/calldata.py#L8) continues to work. - **In [dispatchdata.py](slangpy/core/dispatchdata.py#L7)**: Change `from slangpy.core.callsignature import generate_constants` → `from slangpy.core.generator import generate_constants`. -**Verify**: `pytest slangpy/tests -v` — all tests pass, no import errors. +**Verify**: `pytest slangpy/tests/slangpy_tests -v` — all tests pass, no import errors. --- @@ -40,28 +40,46 @@ Move `BoundVariable.gen_call_data_code` (lines 604–693 of [boundvariable.py](s This preserves the existing call interface (`node.gen_call_data_code(cg, context)` in [callsignature.py line 406](slangpy/core/callsignature.py#L406)) and any marshall subclass code that calls `self.gen_calldata_type_name`. The `MAX_INLINE_TYPE_LEN` constant moves to `generator.py`. - **Move** the import of `CodeGen` and `CodeGenBlock` into `generator.py` (already needed for Step 1). -**Verify**: `pytest slangpy/tests -v` — all tests pass. +**Verify**: `pytest slangpy/tests/slangpy_tests -v` — all tests pass. --- -### Step 3: Decompose `generate_code` into sub-functions and move to `generator.py` +### Step 3a: Extract pure-computation helpers in-place in `callsignature.py` -This is the main step. Move `generate_code` (lines 271–603 of [callsignature.py](slangpy/core/callsignature.py#L271-L603)) into `generator.py` and split it into clearly-named sub-functions. Each function has a docstring describing what Slang code it emits. +Extract the two helpers that do **no codegen** — pure calculation/validation only: -The decomposition: +- **Extract** `_validate_and_compute_group_shape(build_info, call_data_len) -> tuple[int, list[int], list[int]]` from lines [293–340](slangpy/core/callsignature.py#L293-L340). Returns `(call_group_size, call_group_strides, call_group_shape_vector)`. +- **Extract** `_data_name(x, use_entrypoint_args) -> str` — deduplicate the two inline occurrences at lines [449](slangpy/core/callsignature.py#L449) and [497](slangpy/core/callsignature.py#L497) into a single helper. Returns `__in_{name}`, `call_data.{name}`, or `_param_{name}`. -| New function | Source lines | What it emits | -|---|---|---| -| `_validate_and_compute_group_shape(build_info, call_data_len) -> tuple[int, list[int], list[int]]` | [293–340](slangpy/core/callsignature.py#L293-L340) | Nothing — pure validation. Returns `(call_group_size, call_group_strides, call_group_shape_vector)`. | -| `_emit_link_time_constants(cg, build_info, call_data_len, call_group_size, call_group_strides, call_group_shape_vector)` | [342–371](slangpy/core/callsignature.py#L342-L371) | `export static const int call_data_len = ...`, group stride/shape arrays. Also calls `generate_constants()`. | -| `_emit_shape_and_metadata_params(cg, call_data_len, use_entrypoint_args)` | [373–403](slangpy/core/callsignature.py#L373-L403) | `_grid_stride`, `_grid_dim`, `_call_dim`, `_thread_count` — as entry-point params (fast) or `CallData` fields (fallback). | -| `_emit_call_data_definitions(cg, context, signature)` | [405–406](slangpy/core/callsignature.py#L405-L406) | Per-variable call data (wrapper structs, type aliases, mapping constants). Calls `gen_call_data_code` on each node. | -| `_data_name(x, use_entrypoint_args) -> str` | Duplicated at [449](slangpy/core/callsignature.py#L449) and [497](slangpy/core/callsignature.py#L497) | Helper: returns `__in_{name}`, `call_data.{name}`, or `_param_{name}`. | -| `_emit_trampoline(cg, context, build_info, signature, root_params, use_entrypoint_args)` | [408–500](slangpy/core/callsignature.py#L408-L500) | `[Differentiable] void _trampoline(...)` — param declarations, loads, function call, stores. | -| `_emit_entry_point_signature(cg, build_info, call_data_len, call_group_size, use_entrypoint_args)` | [503–541](slangpy/core/callsignature.py#L503-L541) | `[shader("compute")] [numthreads(...)] void compute_main(...)` or `[shader("raygen")] void raygen_main(...)`. | -| `_emit_kernel_body(cg, context, build_info, root_params, call_data_len, use_entrypoint_args)` | [543–603](slangpy/core/callsignature.py#L543-L603) | Bounds check, `init_thread_local_call_shape_info`, Context construction, trampoline call. | +Leave both in `callsignature.py` as module-private functions. `generate_code` calls them. -The top-level `generate_code` becomes a ~30-line orchestrator that calls these in order: +**Verify**: `pytest slangpy/tests/slangpy_tests -v` — all tests pass. + +--- + +### Step 3b: Extract "setup" emission functions in-place in `callsignature.py` + +Extract the three functions that emit the top section of the generated kernel: + +- **Extract** `_emit_link_time_constants(cg, build_info, call_data_len, call_group_size, call_group_strides, call_group_shape_vector)` from lines [342–371](slangpy/core/callsignature.py#L342-L371). Emits `export static const int call_data_len = ...`, group stride/shape arrays; calls `generate_constants()`. +- **Extract** `_emit_shape_and_metadata_params(cg, call_data_len, use_entrypoint_args)` from lines [373–403](slangpy/core/callsignature.py#L373-L403). Emits `_grid_stride`, `_grid_dim`, `_call_dim`, `_thread_count` — as entry-point params (fast path) or `CallData` fields (fallback). +- **Extract** `_emit_call_data_definitions(cg, context, signature)` from lines [405–406](slangpy/core/callsignature.py#L405-L406). Emits per-variable call data (wrapper structs, type aliases, mapping constants) by calling `gen_call_data_code` on each node. + +Leave all three in `callsignature.py`. `generate_code` calls them. + +**Verify**: `pytest slangpy/tests/slangpy_tests -v` — all tests pass. Run `$env:SLANGPY_PRINT_GENERATED_SHADERS="1"; pytest slangpy/tests/slangpy_tests/test_code_gen.py -v` and capture output as the baseline for Step 3c and 3d. + +--- + +### Step 3c: Extract "body" emission functions in-place in `callsignature.py` + +Extract the remaining three functions that emit the entry point and kernel body: + +- **Extract** `_emit_trampoline(cg, context, build_info, root_params, use_entrypoint_args)` from lines [408–500](slangpy/core/callsignature.py#L408-L500). Emits `[Differentiable] void _trampoline(...)` — param declarations, loads, function call, stores. +- **Extract** `_emit_entry_point_signature(cg, build_info, call_data_len, call_group_size, use_entrypoint_args)` from lines [503–541](slangpy/core/callsignature.py#L503-L541). Emits `[shader("compute")] [numthreads(...)] void compute_main(...)` or `[shader("raygen")] void raygen_main(...)`. +- **Extract** `_emit_kernel_body(cg, context, build_info, root_params, call_data_len, use_entrypoint_args)` from lines [543–603](slangpy/core/callsignature.py#L543-L603). Emits bounds check, `init_thread_local_call_shape_info`, Context construction, trampoline call. + +At this point `generate_code` is reduced to the ~30-line orchestrator below. Still in `callsignature.py`. ```python def generate_code(context, build_info, signature, cg): @@ -88,10 +106,19 @@ def generate_code(context, build_info, signature, cg): cg.kernel.end_block() ``` -- **In [callsignature.py](slangpy/core/callsignature.py)**: Delete `generate_code` and add `from slangpy.core.generator import generate_code` (or let the existing wildcard import consumer in [calldata.py](slangpy/core/calldata.py) point to `generator` instead). -- **Update [calldata.py](slangpy/core/calldata.py#L8)**: Change `from slangpy.core.callsignature import *` to explicit imports: binding-pipeline functions from `callsignature`, and `generate_code`, `KernelGenException` from `generator`. This eliminates the wildcard import, making dependencies explicit. +**Verify**: `pytest slangpy/tests/slangpy_tests -v` — all tests pass. Re-run `$env:SLANGPY_PRINT_GENERATED_SHADERS="1"; pytest slangpy/tests/slangpy_tests/test_code_gen.py -v` and confirm output is byte-identical to the Step 3b baseline. + +--- -**Verify**: `pytest slangpy/tests -v` — all tests pass. `$env:SLANGPY_PRINT_GENERATED_SHADERS="1"; pytest slangpy/tests/slangpy_tests/test_code_gen.py -v` — generated code unchanged. +### Step 3d: Move all codegen symbols from `callsignature.py` to `generator.py` and fix imports + +Now that everything is neatly decomposed, do the pure mechanical move: + +- **Move** all seven `_emit_*`/`_validate_*`/`_data_name` private helpers and the `generate_code` orchestrator from `callsignature.py` into `generator.py`. +- **In [callsignature.py](slangpy/core/callsignature.py)**: Delete the moved code; add `from slangpy.core.generator import generate_code` re-export so any consumer that imports `generate_code` from `callsignature` continues to work. +- **Update [calldata.py](slangpy/core/calldata.py#L8)**: Replace `from slangpy.core.callsignature import *` with explicit imports — binding-pipeline functions from `callsignature`, and `generate_code`, `KernelGenException` from `generator`. This eliminates the wildcard import, making dependencies explicit. + +**Verify**: `pytest slangpy/tests/slangpy_tests -v` — all tests pass. Re-run `$env:SLANGPY_PRINT_GENERATED_SHADERS="1"; pytest slangpy/tests/slangpy_tests/test_code_gen.py -v` — output byte-identical to Step 3b baseline. --- @@ -103,7 +130,7 @@ After Step 3, `callsignature.py` no longer has any codegen functions. Clean up: - Remove re-exports of moved symbols once [calldata.py](slangpy/core/calldata.py) uses direct imports from `generator`. - Add `from slangpy.core.generator import KernelGenException, ResolveException` re-exports **only if** external consumers import them from `callsignature` (check via grep). If only `calldata.py` uses them, the explicit import is sufficient. -**Verify**: `pytest slangpy/tests -v`. `pre-commit run --all-files`. +**Verify**: `pytest slangpy/tests/slangpy_tests -v`. `pre-commit run --all-files`. --- @@ -146,17 +173,15 @@ This is documentation-only, no functional changes. At each step: ```bash cmake --build --preset windows-msvc-debug -pytest slangpy/tests -v +pytest slangpy/tests/slangpy_tests -v pre-commit run --all-files ``` -After Step 3 specifically, also verify generated shader output is unchanged: +After Step 3b specifically, capture generated shader output as a baseline; re-run after 3c and 3d to confirm byte-identical output: ```powershell $env:SLANGPY_PRINT_GENERATED_SHADERS="1"; pytest slangpy/tests/slangpy_tests/test_code_gen.py -v ``` -Compare output before/after to confirm byte-identical generated Slang code. - --- ### Decisions From dde44233a7433f4176e3ea37304426742efc65fe Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 15:48:04 +0000 Subject: [PATCH 31/41] more extracting --- .../plan-extractCodegenToGenerator.prompt.md | 8 +- slangpy/bindings/boundvariable.py | 111 +--------------- slangpy/core/generator.py | 119 +++++++++++++++++- 3 files changed, 129 insertions(+), 109 deletions(-) diff --git a/.github/prompts/plan-extractCodegenToGenerator.prompt.md b/.github/prompts/plan-extractCodegenToGenerator.prompt.md index 22f4bc60a..84d21874a 100644 --- a/.github/prompts/plan-extractCodegenToGenerator.prompt.md +++ b/.github/prompts/plan-extractCodegenToGenerator.prompt.md @@ -18,6 +18,8 @@ Move these small, self-contained pieces first: **Verify**: `pytest slangpy/tests -v` — all tests pass, no import errors. +**DONE**: Created `slangpy/core/generator.py` with `KernelGenException`, `_is_slangpy_vector`, `generate_constants`. Replaced definitions in `callsignature.py` with re-exports. Updated `dispatchdata.py` import. 4999 passed, 5 pre-existing failures (raytrace d3d12, type conformance cache). + --- ### Step 2: Extract `gen_call_data_code` as a free function @@ -40,7 +42,9 @@ Move `BoundVariable.gen_call_data_code` (lines 604–693 of [boundvariable.py](s This preserves the existing call interface (`node.gen_call_data_code(cg, context)` in [callsignature.py line 406](slangpy/core/callsignature.py#L406)) and any marshall subclass code that calls `self.gen_calldata_type_name`. The `MAX_INLINE_TYPE_LEN` constant moves to `generator.py`. - **Move** the import of `CodeGen` and `CodeGenBlock` into `generator.py` (already needed for Step 1). -**Verify**: `pytest slangpy/tests -v` — all tests pass. +**Verify**: `pytest slangpy/tests/slangpy_tests -v` — all tests pass. + +**DONE**: Moved `gen_call_data_code` and `gen_calldata_type_name` to `generator.py` as free functions. `MAX_INLINE_TYPE_LEN` moved to `generator.py`, re-exported from `boundvariable.py`. Method bodies replaced with thin delegation stubs. 3294 passed, 285 kernel gen tests passed. --- @@ -103,7 +107,7 @@ After Step 3, `callsignature.py` no longer has any codegen functions. Clean up: - Remove re-exports of moved symbols once [calldata.py](slangpy/core/calldata.py) uses direct imports from `generator`. - Add `from slangpy.core.generator import KernelGenException, ResolveException` re-exports **only if** external consumers import them from `callsignature` (check via grep). If only `calldata.py` uses them, the explicit import is sufficient. -**Verify**: `pytest slangpy/tests -v`. `pre-commit run --all-files`. +**Verify**: `pytest slangpy/tests/slangpy_tests -v`. `pre-commit run --all-files`. --- diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index baed05091..2fa14851c 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -16,10 +16,7 @@ ) from slangpy.reflection.typeresolution import ResolvedParam -#: Type names longer than this threshold get a ``typealias _t_{name}`` alias -#: to keep the generated ``CallData`` struct readable. Shorter names are -#: inlined directly. -MAX_INLINE_TYPE_LEN = 60 +from slangpy.core.generator import MAX_INLINE_TYPE_LEN # noqa: F401 class BoundVariableException(Exception): @@ -281,21 +278,9 @@ def debug_name(self) -> str: return f"arg{self.python_pos_arg_index}" def gen_calldata_type_name(self, cgb: CodeGenBlock, type_name: str) -> None: - """Record the Slang type name for this variable's CallData field. + from slangpy.core.generator import gen_calldata_type_name - If the type name exceeds ``MAX_INLINE_TYPE_LEN``, a - ``typealias _t_{name}`` is emitted and the alias is stored. - Otherwise the raw type name is stored directly. - - :param cgb: The code-gen block to write the type alias to (if needed). - :param type_name: The resolved Slang type name. - """ - if len(type_name) > MAX_INLINE_TYPE_LEN: - alias = f"_t_{self.variable_name}" - cgb.type_alias(alias, type_name) - self.calldata_type_name = alias - else: - self.calldata_type_name = type_name + gen_calldata_type_name(self, cgb, type_name) def bind( self, @@ -602,95 +587,9 @@ def _calculate_differentiability(self, mode: CallMode): self.access = (AccessType.none, AccessType.none) def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): - if self.children is not None: - cgb = cg.call_data_structs + from slangpy.core.generator import gen_call_data_code - if self.direct_bind: - # Direct-bind: use raw type name directly - assert self.vector_type is not None - self.gen_calldata_type_name(cgb, self.vector_type.full_name) - else: - struct_name = f"_t_{self.variable_name}" - cgb.begin_struct(struct_name) - - for field, variable in self.children.items(): - variable.gen_call_data_code(cg, context, depth + 1) - - for var in self.children.values(): - assert ( - var.calldata_type_name is not None - ), f"calldata_type_name not set for '{var.variable_name}'" - cgb.declare(var.calldata_type_name, var.variable_name) - - assert self.vector_type is not None - context_decl = f"ContextND<{self.call_dimensionality}> context" - value_decl = f"{self.vector_type.full_name} value" - prefix = "[Differentiable]" if self.access[1] != AccessType.none else "" - - if self.access[0] in (AccessType.read, AccessType.readwrite): - cgb.empty_line() - cgb.append_line( - f"{prefix} void __slangpy_load({context_decl}, out {value_decl})" - ) - cgb.begin_block() - for field, var in self.children.items(): - gen_load = getattr(var.python, "gen_trampoline_load", None) - if gen_load is not None and gen_load( - cgb, var, var.variable_name, f"value.{field}" - ): - continue - cgb.append_statement( - f"{var.variable_name}.__slangpy_load(context.map(_m_{var.variable_name}),value.{field})" - ) - cgb.end_block() - - if self.access[0] in (AccessType.write, AccessType.readwrite): - cgb.empty_line() - cgb.append_line( - f"{prefix} void __slangpy_store({context_decl}, in {value_decl})" - ) - cgb.begin_block() - for field, var in self.children.items(): - gen_store = getattr(var.python, "gen_trampoline_store", None) - if gen_store is not None and gen_store( - cgb, var, var.variable_name, f"value.{field}" - ): - continue - cgb.append_statement( - f"{var.variable_name}.__slangpy_store(context.map(_m_{var.variable_name}),value.{field})" - ) - cgb.end_block() - - cgb.end_struct() - self.calldata_type_name = struct_name - - else: - # Generate call data - self.python.gen_calldata(cg.call_data_structs, context, self) - - # Skip mapping constants for direct-bind variables (they bypass __slangpy_load/store) - if not self.direct_bind: - if len(self.vector_mapping) > 0: - cg.call_data_structs.append_statement( - f"static const int[] _m_{self.variable_name} = {{ {','.join([str(x) for x in self.vector_mapping.as_tuple()])} }}" - ) - else: - cg.call_data_structs.append_statement( - f"static const int _m_{self.variable_name} = 0" - ) - - if depth == 0: - assert ( - self.calldata_type_name is not None - ), f"calldata_type_name not set for '{self.variable_name}'" - if self.create_param_block: - cg.add_parameter_block(self.calldata_type_name, "_param_" + self.variable_name) - elif cg.skip_call_data: - cg.entry_point_params.append( - f"uniform {self.calldata_type_name} {self.variable_name}" - ) - else: - cg.call_data.declare(self.calldata_type_name, self.variable_name) + gen_call_data_code(self, cg, context, depth) def __str__(self) -> str: return self._recurse_str(0) diff --git a/slangpy/core/generator.py b/slangpy/core/generator.py index 382cf549f..6d75385b7 100644 --- a/slangpy/core/generator.py +++ b/slangpy/core/generator.py @@ -1,11 +1,19 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import TYPE_CHECKING, Any -from slangpy.bindings.codegen import CodeGen +from slangpy.bindings.codegen import CodeGen, CodeGenBlock +from slangpy.core.native import AccessType if TYPE_CHECKING: + from slangpy.bindings.boundvariable import BoundVariable + from slangpy.bindings.marshall import BindContext from slangpy.core.function import FunctionBuildInfo +#: Type names longer than this threshold get a ``typealias _t_{name}`` alias +#: to keep the generated ``CallData`` struct readable. Shorter names are +#: inlined directly. +MAX_INLINE_TYPE_LEN = 60 + class KernelGenException(Exception): def __init__(self, message: str): @@ -40,3 +48,112 @@ def generate_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: raise KernelGenException( f"Constant value '{k}' must be an int, float or bool, not {type(v).__name__}" ) + + +def gen_calldata_type_name(binding: "BoundVariable", cgb: CodeGenBlock, type_name: str) -> None: + """Record the Slang type name for this variable's CallData field. + + If the type name exceeds ``MAX_INLINE_TYPE_LEN``, a + ``typealias _t_{name}`` is emitted and the alias is stored. + Otherwise the raw type name is stored directly. + + :param binding: The bound variable to update. + :param cgb: The code-gen block to write the type alias to (if needed). + :param type_name: The resolved Slang type name. + """ + if len(type_name) > MAX_INLINE_TYPE_LEN: + alias = f"_t_{binding.variable_name}" + cgb.type_alias(alias, type_name) + binding.calldata_type_name = alias + else: + binding.calldata_type_name = type_name + + +def gen_call_data_code( + binding: "BoundVariable", cg: CodeGen, context: "BindContext", depth: int = 0 +) -> None: + if binding.children is not None: + cgb = cg.call_data_structs + + if binding.direct_bind: + # Direct-bind: use raw type name directly + assert binding.vector_type is not None + gen_calldata_type_name(binding, cgb, binding.vector_type.full_name) + else: + struct_name = f"_t_{binding.variable_name}" + cgb.begin_struct(struct_name) + + for field, variable in binding.children.items(): + gen_call_data_code(variable, cg, context, depth + 1) + + for var in binding.children.values(): + assert ( + var.calldata_type_name is not None + ), f"calldata_type_name not set for '{var.variable_name}'" + cgb.declare(var.calldata_type_name, var.variable_name) + + assert binding.vector_type is not None + context_decl = f"ContextND<{binding.call_dimensionality}> context" + value_decl = f"{binding.vector_type.full_name} value" + prefix = "[Differentiable]" if binding.access[1] != AccessType.none else "" + + if binding.access[0] in (AccessType.read, AccessType.readwrite): + cgb.empty_line() + cgb.append_line(f"{prefix} void __slangpy_load({context_decl}, out {value_decl})") + cgb.begin_block() + for field, var in binding.children.items(): + gen_load = getattr(var.python, "gen_trampoline_load", None) + if gen_load is not None and gen_load( + cgb, var, var.variable_name, f"value.{field}" + ): + continue + cgb.append_statement( + f"{var.variable_name}.__slangpy_load(context.map(_m_{var.variable_name}),value.{field})" + ) + cgb.end_block() + + if binding.access[0] in (AccessType.write, AccessType.readwrite): + cgb.empty_line() + cgb.append_line(f"{prefix} void __slangpy_store({context_decl}, in {value_decl})") + cgb.begin_block() + for field, var in binding.children.items(): + gen_store = getattr(var.python, "gen_trampoline_store", None) + if gen_store is not None and gen_store( + cgb, var, var.variable_name, f"value.{field}" + ): + continue + cgb.append_statement( + f"{var.variable_name}.__slangpy_store(context.map(_m_{var.variable_name}),value.{field})" + ) + cgb.end_block() + + cgb.end_struct() + binding.calldata_type_name = struct_name + + else: + # Generate call data + binding.python.gen_calldata(cg.call_data_structs, context, binding) + + # Skip mapping constants for direct-bind variables (they bypass __slangpy_load/store) + if not binding.direct_bind: + if len(binding.vector_mapping) > 0: + cg.call_data_structs.append_statement( + f"static const int[] _m_{binding.variable_name} = {{ {','.join([str(x) for x in binding.vector_mapping.as_tuple()])} }}" + ) + else: + cg.call_data_structs.append_statement( + f"static const int _m_{binding.variable_name} = 0" + ) + + if depth == 0: + assert ( + binding.calldata_type_name is not None + ), f"calldata_type_name not set for '{binding.variable_name}'" + if binding.create_param_block: + cg.add_parameter_block(binding.calldata_type_name, "_param_" + binding.variable_name) + elif cg.skip_call_data: + cg.entry_point_params.append( + f"uniform {binding.calldata_type_name} {binding.variable_name}" + ) + else: + cg.call_data.declare(binding.calldata_type_name, binding.variable_name) From 5e24b646ca03967f0018cf0a1e4aea9851e5e6cc Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 16:00:24 +0000 Subject: [PATCH 32/41] Refactor code generation for call data handling - Moved gen_calldata_type_name and gen_call_data_code functions to generator.py for better organization. - Simplified gen_calldata_type_name logic and integrated it into the call data generation process. - Enhanced gen_call_data_code to handle both direct-bind and structured variables more effectively. - Updated calls to these functions in BoundVariable class to streamline code generation. --- slangpy/bindings/boundvariable.py | 111 +----- slangpy/core/callsignature.py | 347 +------------------ slangpy/core/generator.py | 551 +++++++++++++++++++++++++++++- 3 files changed, 560 insertions(+), 449 deletions(-) diff --git a/slangpy/bindings/boundvariable.py b/slangpy/bindings/boundvariable.py index baed05091..c4fa028bb 100644 --- a/slangpy/bindings/boundvariable.py +++ b/slangpy/bindings/boundvariable.py @@ -281,21 +281,15 @@ def debug_name(self) -> str: return f"arg{self.python_pos_arg_index}" def gen_calldata_type_name(self, cgb: CodeGenBlock, type_name: str) -> None: - """Record the Slang type name for this variable's CallData field. + """Record the Slang type name for this variable's CallData field.""" + from slangpy.core.generator import gen_calldata_type_name - If the type name exceeds ``MAX_INLINE_TYPE_LEN``, a - ``typealias _t_{name}`` is emitted and the alias is stored. - Otherwise the raw type name is stored directly. + gen_calldata_type_name(self, cgb, type_name) - :param cgb: The code-gen block to write the type alias to (if needed). - :param type_name: The resolved Slang type name. - """ - if len(type_name) > MAX_INLINE_TYPE_LEN: - alias = f"_t_{self.variable_name}" - cgb.type_alias(alias, type_name) - self.calldata_type_name = alias - else: - self.calldata_type_name = type_name + def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): + from slangpy.core.generator import gen_call_data_code + + gen_call_data_code(self, cg, context, depth) def bind( self, @@ -601,97 +595,6 @@ def _calculate_differentiability(self, mode: CallMode): # todo: fwds self.access = (AccessType.none, AccessType.none) - def gen_call_data_code(self, cg: CodeGen, context: BindContext, depth: int = 0): - if self.children is not None: - cgb = cg.call_data_structs - - if self.direct_bind: - # Direct-bind: use raw type name directly - assert self.vector_type is not None - self.gen_calldata_type_name(cgb, self.vector_type.full_name) - else: - struct_name = f"_t_{self.variable_name}" - cgb.begin_struct(struct_name) - - for field, variable in self.children.items(): - variable.gen_call_data_code(cg, context, depth + 1) - - for var in self.children.values(): - assert ( - var.calldata_type_name is not None - ), f"calldata_type_name not set for '{var.variable_name}'" - cgb.declare(var.calldata_type_name, var.variable_name) - - assert self.vector_type is not None - context_decl = f"ContextND<{self.call_dimensionality}> context" - value_decl = f"{self.vector_type.full_name} value" - prefix = "[Differentiable]" if self.access[1] != AccessType.none else "" - - if self.access[0] in (AccessType.read, AccessType.readwrite): - cgb.empty_line() - cgb.append_line( - f"{prefix} void __slangpy_load({context_decl}, out {value_decl})" - ) - cgb.begin_block() - for field, var in self.children.items(): - gen_load = getattr(var.python, "gen_trampoline_load", None) - if gen_load is not None and gen_load( - cgb, var, var.variable_name, f"value.{field}" - ): - continue - cgb.append_statement( - f"{var.variable_name}.__slangpy_load(context.map(_m_{var.variable_name}),value.{field})" - ) - cgb.end_block() - - if self.access[0] in (AccessType.write, AccessType.readwrite): - cgb.empty_line() - cgb.append_line( - f"{prefix} void __slangpy_store({context_decl}, in {value_decl})" - ) - cgb.begin_block() - for field, var in self.children.items(): - gen_store = getattr(var.python, "gen_trampoline_store", None) - if gen_store is not None and gen_store( - cgb, var, var.variable_name, f"value.{field}" - ): - continue - cgb.append_statement( - f"{var.variable_name}.__slangpy_store(context.map(_m_{var.variable_name}),value.{field})" - ) - cgb.end_block() - - cgb.end_struct() - self.calldata_type_name = struct_name - - else: - # Generate call data - self.python.gen_calldata(cg.call_data_structs, context, self) - - # Skip mapping constants for direct-bind variables (they bypass __slangpy_load/store) - if not self.direct_bind: - if len(self.vector_mapping) > 0: - cg.call_data_structs.append_statement( - f"static const int[] _m_{self.variable_name} = {{ {','.join([str(x) for x in self.vector_mapping.as_tuple()])} }}" - ) - else: - cg.call_data_structs.append_statement( - f"static const int _m_{self.variable_name} = 0" - ) - - if depth == 0: - assert ( - self.calldata_type_name is not None - ), f"calldata_type_name not set for '{self.variable_name}'" - if self.create_param_block: - cg.add_parameter_block(self.calldata_type_name, "_param_" + self.variable_name) - elif cg.skip_call_data: - cg.entry_point_params.append( - f"uniform {self.calldata_type_name} {self.variable_name}" - ) - else: - cg.call_data.declare(self.calldata_type_name, self.variable_name) - def __str__(self) -> str: return self._recurse_str(0) diff --git a/slangpy/core/callsignature.py b/slangpy/core/callsignature.py index 9e30265a5..d2d1929b3 100644 --- a/slangpy/core/callsignature.py +++ b/slangpy/core/callsignature.py @@ -1,18 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import TYPE_CHECKING, Any, Optional -from slangpy.core.native import AccessType, CallMode -from slangpy.core.function import PipelineType +from slangpy.core.native import CallMode import slangpy.bindings.typeregistry as tr from slangpy import ModifierID, TypeReflection from slangpy.bindings.marshall import BindContext, ReturnContext -from slangpy.bindings.boundvariable import ( - BoundCall, - BoundVariable, - BoundVariableException, -) -from slangpy.bindings.codegen import CodeGen +from slangpy.bindings.boundvariable import BoundCall, BoundVariable from slangpy.builtin.value import NoneMarshall from slangpy.reflection.reflectiontypes import ( SlangFunction, @@ -38,7 +32,7 @@ def __init__(self, message: str): self.message = message -from slangpy.core.generator import KernelGenException, generate_constants # noqa: F401 +from slangpy.core.generator import KernelGenException, generate_constants, generate_code # noqa: F401 # This detects if a type is a vector with its length defined by a generic @@ -233,338 +227,3 @@ def create_return_value_binding(context: BindContext, signature: BoundCall, retu node.call_dimensionality = context.call_dimensionality node.python = python_type - - -def generate_code( - context: BindContext, - build_info: "FunctionBuildInfo", - signature: BoundCall, - cg: CodeGen, -) -> None: - """ - Generate Slang kernel code for the given function call signature. - """ - - # Check if we're using direct entry-point params (fast path) - use_entrypoint_args = context.use_entrypoint_args - - # Generate the header - cg.add_import("slangpy") - - call_data_len = context.call_dimensionality - - # Get the call group size so we can see about using it when generating the - # [numthreads(...)] attribute. We use 1 as the default size if a call - # group shape has not been set, as we can use that to make things "linear". - # Note that when size is 1, we will still launch a group of 32 threads, - # but each thread is conceptually in its own group of size 1. This then - # leads to a linearly calculated call_id based on threadID.x and the call - # shape's size and strides. - call_group_size = 1 - call_group_shape = build_info.call_group_shape - if call_group_shape is not None: - call_group_shape_vector = call_group_shape.as_list() - - # Validate call_group_shape dimensionality and values before using them - if len(call_group_shape_vector) > context.call_dimensionality: - raise KernelGenException( - f"call_group_shape dimensionality ({len(call_group_shape_vector)}) must be <= " - f"call_shape dimensionality ({context.call_dimensionality}). " - f"call_group_shape cannot have more dimensions than call_shape." - ) - elif len(call_group_shape_vector) < context.call_dimensionality: - # Call group shape size is less than the call shape size so we need to - # pad the call group shape with 1's to account for the missing dimensions. - # However, inserting at the front of the list will be inefficient, so - # log a debug message, giving users a chance to correct their calls. - - missing_dims = context.call_dimensionality - len(call_group_shape_vector) - - # Pad with 1's at the beginning - call_group_shape_vector = [1] * missing_dims + call_group_shape_vector - - # Validate that all call_group_shape values are >= 1 - for i, dim in enumerate(call_group_shape_vector): - if dim < 1: - raise KernelGenException( - f"call_group_shape[{i}] = {dim} is invalid. " - f"All call_group_shape elements must be >= 1." - ) - - # Calculate call group size as product of all dimensions - # Also grab the group strides here as that will allow us - # to use the group shape as constants to improve perf - call_group_strides = [] - for dim in call_group_shape_vector[::-1]: - call_group_strides.append(call_group_size) - call_group_size *= dim - call_group_strides.reverse() - - # Check if call_group_size exceeds hardware limits - if call_group_size > 1024: - raise KernelGenException( - f"call_group_size ({call_group_size}) exceeds the typical 1024 maximum " - f"enforced by most APIs. Consider reducing your call_group_shape dimensions." - ) - - cg.add_import(build_info.module.name) - - # Generate constants if specified - generate_constants(build_info, cg) - - # Generate additional link time constants definition code. These are declared in callshape.slang - # and used to generated call_ids that can be queried by user modules. - cg.constants.append_statement(f"export static const int call_data_len = {call_data_len}") - cg.constants.append_statement(f"export static const int call_group_size = {call_group_size}") - - # Also generate the call group shape and stride arrays as link time constants. Using constants - # should yield better performance than passing these in as uniforms. - cg.constants.append_line(f"export static const int[call_data_len] call_group_strides = {{") - cg.constants.inc_indent() - if call_group_size != 1: - for i in range(call_data_len): - cg.constants.append_line(f"{call_group_strides[i]},") - cg.constants.dec_indent() - cg.constants.append_statement("}") - - cg.constants.append_line(f"export static const int[call_data_len] call_group_shape_vector = {{") - cg.constants.inc_indent() - if call_group_size != 1: - for i in range(call_data_len): - cg.constants.append_line(f"{call_group_shape_vector[i]},") - cg.constants.dec_indent() - cg.constants.append_statement("}") - - # Set up code gen mode for direct args vs CallData struct - if use_entrypoint_args: - cg.skip_call_data = True - - # Generate call data inputs if vector call - if call_data_len > 0: - if use_entrypoint_args: - # Fast path: shape arrays as individual entry-point params - cg.entry_point_params.append(f"uniform int[{call_data_len}] _grid_stride") - cg.entry_point_params.append(f"uniform int[{call_data_len}] _grid_dim") - cg.entry_point_params.append(f"uniform int[{call_data_len}] _call_dim") - else: - # Fallback: shape arrays inside CallData struct - # A group can be thought of as a "window" looking at a - # portion of the entire call shape. Grid here refers to the - # N dimensional call shape being broken up into some number of N - # dimensional "window"s / groups. - cg.call_data.append_statement(f"int[{call_data_len}] _grid_stride") - cg.call_data.append_statement(f"int[{call_data_len}] _grid_dim") - # We use the call shape dimensions to detect cases when the call shape - # and the call group shape are not aligned. When a thread's call id - # falls outside the call shape, we need it to return early. This is - # similar to the default linear case when the call shape size is not - # 32 thread aligned. - cg.call_data.append_statement(f"int[{call_data_len}] _call_dim") - - if use_entrypoint_args: - cg.entry_point_params.append("uniform uint3 _thread_count") - else: - cg.call_data.append_statement("uint3 _thread_count") - - # Generate call data definitions for all inputs to the kernel - for node in signature.values(): - node.gen_call_data_code(cg, context) - - # Get sorted list of root parameters for trampoline function - root_params = sorted(signature.values(), key=lambda x: x.param_index) - - # Generate the trampoline function - trampoline_fn = "_trampoline" - if context.call_mode != CallMode.prim: - cg.trampoline.append_line("[Differentiable]") - - if use_entrypoint_args: - # Fast path: trampoline takes individual calldata-typed params. - # Use __in_ prefix for param names to avoid collision with local variable names. - # All params are no_diff — entry-point uniforms are never differentiable. - # Differentiation happens through local variable assignments inside the trampoline, - # matching the struct-based approach where CallData was implicitly non-differentiable. - trampoline_params = ["Context __slangpy_context__"] - for x in root_params: - if x.create_param_block: - continue # param blocks handled via _param_ at module scope - assert x.calldata_type_name is not None - arg_def = f"no_diff {x.calldata_type_name} __in_{x.variable_name}" - trampoline_params.append(arg_def) - cg.trampoline.append_line(f"void {trampoline_fn}({', '.join(trampoline_params)})") - else: - # Fallback: trampoline reads from global ParameterBlock call_data - cg.trampoline.append_line(f"void {trampoline_fn}(Context __slangpy_context__)") - cg.trampoline.begin_block() - - # Declare parameters and load inputs - for x in root_params: - assert x.vector_type is not None - cg.trampoline.declare(x.vector_type.full_name, x.variable_name) - for x in root_params: - if use_entrypoint_args: - data_name = ( - f"_param_{x.variable_name}" if x.create_param_block else f"__in_{x.variable_name}" - ) - else: - data_name = ( - f"_param_{x.variable_name}" - if x.create_param_block - else f"call_data.{x.variable_name}" - ) - gen_load = getattr(x.python, "gen_trampoline_load", None) - if gen_load is not None and gen_load(cg.trampoline, x, data_name, x.variable_name): - continue - if x.access[0] == AccessType.read or x.access[0] == AccessType.readwrite: - cg.trampoline.append_statement( - f"{data_name}.__slangpy_load(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" - ) - - cg.trampoline.append_indent() - if any(x.variable_name == "_result" for x in root_params): - cg.trampoline.append_code(f"_result = ") - - # Get function name, if it's the init function, use the result type - func_name = build_info.name - if func_name == "$init": - results = [x for x in root_params if x.variable_name == "_result"] - assert len(results) == 1 - assert results[0].vector_type is not None - func_name = results[0].vector_type.full_name - elif len(root_params) > 0 and root_params[0].variable_name == "_this": - func_name = f"_this.{func_name}" - - # Get the parameters that are not the result or this reference - normal_params = [ - x for x in root_params if x.variable_name != "_result" and x.variable_name != "_this" - ] - - # Internal call to the actual function - cg.trampoline.append_code( - f"{func_name}(" + ", ".join(x.variable_name for x in normal_params) + ");\n" - ) - - # For each writable trampoline parameter, potentially store it - for x in root_params: - if ( - x.access[0] == AccessType.write - or x.access[0] == AccessType.readwrite - or x.access[1] == AccessType.read - ): - if use_entrypoint_args: - data_name = ( - f"_param_{x.variable_name}" - if x.create_param_block - else f"__in_{x.variable_name}" - ) - else: - data_name = ( - f"_param_{x.variable_name}" - if x.create_param_block - else f"call_data.{x.variable_name}" - ) - gen_store = getattr(x.python, "gen_trampoline_store", None) - if gen_store is not None and gen_store(cg.trampoline, x, data_name, x.variable_name): - continue - if not x.python.is_writable: - raise BoundVariableException(f"Cannot read back value for non-writable type", x) - cg.trampoline.append_statement( - f"{data_name}.__slangpy_store(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" - ) - - cg.trampoline.end_block() - cg.trampoline.append_line("") - - # Generate the main function - if build_info.pipeline_type == PipelineType.compute: - cg.kernel.append_line('[shader("compute")]') - if call_group_size != 1: - cg.kernel.append_line(f"[numthreads({call_group_size}, 1, 1)]") - else: - cg.kernel.append_line("[numthreads(32, 1, 1)]") - # Note: While flat_call_thread_id is 3-dimensional, we consider it "flat" and 1-dimensional because of the - # true call group shape of [x, 1, 1] and only use the first dimension for the call thread id. - if use_entrypoint_args: - # Fast path: build compute_main signature with individual entry-point params - sig_parts = ["int3 flat_call_thread_id: SV_DispatchThreadID"] - # Only include SV_GroupID/SV_GroupIndex when call_data_len > 0 - # (they feed init_thread_local_call_shape_info which isn't called otherwise) - if call_data_len > 0: - sig_parts.append("int3 flat_call_group_id: SV_GroupID") - sig_parts.append("int flat_call_group_thread_id: SV_GroupIndex") - sig_parts.extend(cg.entry_point_params) - cg.kernel.append_line(f"void compute_main({', '.join(sig_parts)})") - else: - # Fallback: no uniform params (reads from global ParameterBlock) - cg.kernel.append_line( - "void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex)" - ) - elif build_info.pipeline_type == PipelineType.ray_tracing: - cg.kernel.append_line('[shader("raygen")]') - if use_entrypoint_args: - sig_parts = list(cg.entry_point_params) - cg.kernel.append_line(f"void raygen_main({', '.join(sig_parts)})") - else: - cg.kernel.append_line("void raygen_main()") - else: - raise RuntimeError(f"Unknown pipeline type: {build_info.pipeline_type}") - - cg.kernel.begin_block() - - if build_info.pipeline_type == PipelineType.ray_tracing: - cg.kernel.append_statement("int3 flat_call_thread_id = DispatchRaysIndex();") - - # Bounds check — use _thread_count directly in fast path, call_data._thread_count in fallback - if use_entrypoint_args: - cg.kernel.append_statement("if (any(flat_call_thread_id >= _thread_count)) return") - else: - cg.kernel.append_statement( - "if (any(flat_call_thread_id >= call_data._thread_count)) return" - ) - - # Loads / initializes call id - context_args = "flat_call_thread_id" - - # Call init_thread_local_call_shape_info to initialize the call shape info. See - # definition in callshape.slang. - if call_data_len > 0: - # In fast path, shape arrays are direct entry-point params; in fallback, prefixed with call_data. - grid_prefix = "" if use_entrypoint_args else "call_data." - if build_info.pipeline_type == PipelineType.compute: - cg.kernel.append_line( - f""" - if (!init_thread_local_call_shape_info(flat_call_group_thread_id, - flat_call_group_id, flat_call_thread_id, {grid_prefix}_grid_stride, - {grid_prefix}_grid_dim, {grid_prefix}_call_dim)) - return;""" - ) - elif build_info.pipeline_type == PipelineType.ray_tracing: - cg.kernel.append_line( - f""" - if (!init_thread_local_call_shape_info(0, - uint3(0), flat_call_thread_id, {grid_prefix}_grid_stride, - {grid_prefix}_grid_dim, {grid_prefix}_call_dim)) - return;""" - ) - context_args += ", CallShapeInfo::get_call_id().shape" - - cg.kernel.append_statement(f"Context __slangpy_context__ = {{{context_args}}}") - - # Call the trampoline function - fn = trampoline_fn - if context.call_mode == CallMode.bwds: - fn = f"bwd_diff({fn})" - - if use_entrypoint_args: - # Fast path: pass individual entry-point param names to the trampoline - trampoline_args = ["__slangpy_context__"] - for x in root_params: - if x.create_param_block: - continue # param blocks are at module scope - trampoline_args.append(x.variable_name) - cg.kernel.append_statement(f"{fn}({', '.join(trampoline_args)})") - else: - # Fallback: trampoline reads from global call_data - cg.kernel.append_statement(f"{fn}(__slangpy_context__)") - - cg.kernel.end_block() diff --git a/slangpy/core/generator.py b/slangpy/core/generator.py index 382cf549f..cd3468793 100644 --- a/slangpy/core/generator.py +++ b/slangpy/core/generator.py @@ -1,10 +1,18 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception from typing import TYPE_CHECKING, Any -from slangpy.bindings.codegen import CodeGen +from slangpy.bindings.codegen import CodeGen, CodeGenBlock +from slangpy.core.native import AccessType, CallMode if TYPE_CHECKING: from slangpy.core.function import FunctionBuildInfo + from slangpy.bindings.boundvariable import BoundVariable, BoundCall + from slangpy.bindings.marshall import BindContext + +#: Type names longer than this threshold get a ``typealias _t_{name}`` alias +#: to keep the generated ``CallData`` struct readable. Shorter names are +#: inlined directly. +MAX_INLINE_TYPE_LEN = 60 class KernelGenException(Exception): @@ -40,3 +48,544 @@ def generate_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: raise KernelGenException( f"Constant value '{k}' must be an int, float or bool, not {type(v).__name__}" ) + + +def gen_calldata_type_name( + binding: "BoundVariable", cgb: CodeGenBlock, type_name: str +) -> None: + """Record the Slang type name for this variable's CallData field. + + If the type name exceeds ``MAX_INLINE_TYPE_LEN``, a + ``typealias _t_{name}`` is emitted and the alias is stored. + Otherwise the raw type name is stored directly. + + :param binding: The bound variable to update. + :param cgb: The code-gen block to write the type alias to (if needed). + :param type_name: The resolved Slang type name. + """ + if len(type_name) > MAX_INLINE_TYPE_LEN: + alias = f"_t_{binding.variable_name}" + cgb.type_alias(alias, type_name) + binding.calldata_type_name = alias + else: + binding.calldata_type_name = type_name + + +def gen_call_data_code( + binding: "BoundVariable", cg: CodeGen, context: "BindContext", depth: int = 0 +) -> None: + """Emit Slang call-data struct and mapping constants for one bound variable. + + For struct/dict variables, emits a ``_t_{name}`` struct with ``__slangpy_load`` + and ``__slangpy_store`` methods. For leaf variables, delegates to the marshall's + ``gen_calldata``. At depth 0, appends the variable's type to ``call_data`` + (or ``entry_point_params`` for the fast path). + + :param binding: The bound variable to emit code for. + :param cg: The active CodeGen object. + :param context: The bind context for the current call. + :param depth: Recursion depth (0 = root, >0 = struct field). + """ + from slangpy.bindings.boundvariable import BoundVariableException + + if binding.children is not None: + cgb = cg.call_data_structs + + if binding.direct_bind: + # Direct-bind: use raw type name directly + assert binding.vector_type is not None + gen_calldata_type_name(binding, cgb, binding.vector_type.full_name) + else: + struct_name = f"_t_{binding.variable_name}" + cgb.begin_struct(struct_name) + + for field, variable in binding.children.items(): + gen_call_data_code(variable, cg, context, depth + 1) + + for var in binding.children.values(): + assert ( + var.calldata_type_name is not None + ), f"calldata_type_name not set for '{var.variable_name}'" + cgb.declare(var.calldata_type_name, var.variable_name) + + assert binding.vector_type is not None + context_decl = f"ContextND<{binding.call_dimensionality}> context" + value_decl = f"{binding.vector_type.full_name} value" + prefix = "[Differentiable]" if binding.access[1] != AccessType.none else "" + + if binding.access[0] in (AccessType.read, AccessType.readwrite): + cgb.empty_line() + cgb.append_line( + f"{prefix} void __slangpy_load({context_decl}, out {value_decl})" + ) + cgb.begin_block() + for field, var in binding.children.items(): + gen_load = getattr(var.python, "gen_trampoline_load", None) + if gen_load is not None and gen_load( + cgb, var, var.variable_name, f"value.{field}" + ): + continue + cgb.append_statement( + f"{var.variable_name}.__slangpy_load(context.map(_m_{var.variable_name}),value.{field})" + ) + cgb.end_block() + + if binding.access[0] in (AccessType.write, AccessType.readwrite): + cgb.empty_line() + cgb.append_line( + f"{prefix} void __slangpy_store({context_decl}, in {value_decl})" + ) + cgb.begin_block() + for field, var in binding.children.items(): + gen_store = getattr(var.python, "gen_trampoline_store", None) + if gen_store is not None and gen_store( + cgb, var, var.variable_name, f"value.{field}" + ): + continue + cgb.append_statement( + f"{var.variable_name}.__slangpy_store(context.map(_m_{var.variable_name}),value.{field})" + ) + cgb.end_block() + + cgb.end_struct() + binding.calldata_type_name = struct_name + + else: + # Generate call data + binding.python.gen_calldata(cg.call_data_structs, context, binding) + + # Skip mapping constants for direct-bind variables (they bypass __slangpy_load/store) + if not binding.direct_bind: + if len(binding.vector_mapping) > 0: + cg.call_data_structs.append_statement( + f"static const int[] _m_{binding.variable_name} = {{ {','.join([str(x) for x in binding.vector_mapping.as_tuple()])} }}" + ) + else: + cg.call_data_structs.append_statement( + f"static const int _m_{binding.variable_name} = 0" + ) + + if depth == 0: + assert ( + binding.calldata_type_name is not None + ), f"calldata_type_name not set for '{binding.variable_name}'" + if binding.create_param_block: + cg.add_parameter_block(binding.calldata_type_name, "_param_" + binding.variable_name) + elif cg.skip_call_data: + cg.entry_point_params.append( + f"uniform {binding.calldata_type_name} {binding.variable_name}" + ) + else: + cg.call_data.declare(binding.calldata_type_name, binding.variable_name) + + +# --------------------------------------------------------------------------- +# generate_code sub-functions +# --------------------------------------------------------------------------- + + +def _validate_and_compute_group_shape( + build_info: "FunctionBuildInfo", + call_data_len: int, +) -> tuple[int, list[int], list[int]]: + """Validate ``call_group_shape`` and compute the flat group size and strides. + + Returns ``(call_group_size, call_group_strides, call_group_shape_vector)``. + When no call_group_shape is set, returns ``(1, [], [])``. + """ + call_group_size = 1 + call_group_strides: list[int] = [] + call_group_shape_vector: list[int] = [] + + call_group_shape = build_info.call_group_shape + if call_group_shape is not None: + call_group_shape_vector = call_group_shape.as_list() + + if len(call_group_shape_vector) > call_data_len: + raise KernelGenException( + f"call_group_shape dimensionality ({len(call_group_shape_vector)}) must be <= " + f"call_shape dimensionality ({call_data_len}). " + f"call_group_shape cannot have more dimensions than call_shape." + ) + elif len(call_group_shape_vector) < call_data_len: + missing_dims = call_data_len - len(call_group_shape_vector) + call_group_shape_vector = [1] * missing_dims + call_group_shape_vector + + for i, dim in enumerate(call_group_shape_vector): + if dim < 1: + raise KernelGenException( + f"call_group_shape[{i}] = {dim} is invalid. " + f"All call_group_shape elements must be >= 1." + ) + + for dim in call_group_shape_vector[::-1]: + call_group_strides.append(call_group_size) + call_group_size *= dim + call_group_strides.reverse() + + if call_group_size > 1024: + raise KernelGenException( + f"call_group_size ({call_group_size}) exceeds the typical 1024 maximum " + f"enforced by most APIs. Consider reducing your call_group_shape dimensions." + ) + + return call_group_size, call_group_strides, call_group_shape_vector + + +def _emit_link_time_constants( + cg: CodeGen, + build_info: "FunctionBuildInfo", + call_data_len: int, + call_group_size: int, + call_group_strides: list[int], + call_group_shape_vector: list[int], +) -> None: + """Emit link-time constant declarations. + + Emits Slang code like:: + + export static const int call_data_len = 2; + export static const int call_group_size = 1; + export static const int[call_data_len] call_group_strides = {}; + export static const int[call_data_len] call_group_shape_vector = {}; + """ + generate_constants(build_info, cg) + cg.constants.append_statement(f"export static const int call_data_len = {call_data_len}") + cg.constants.append_statement(f"export static const int call_group_size = {call_group_size}") + + cg.constants.append_line(f"export static const int[call_data_len] call_group_strides = {{") + cg.constants.inc_indent() + if call_group_size != 1: + for i in range(call_data_len): + cg.constants.append_line(f"{call_group_strides[i]},") + cg.constants.dec_indent() + cg.constants.append_statement("}") + + cg.constants.append_line( + f"export static const int[call_data_len] call_group_shape_vector = {{" + ) + cg.constants.inc_indent() + if call_group_size != 1: + for i in range(call_data_len): + cg.constants.append_line(f"{call_group_shape_vector[i]},") + cg.constants.dec_indent() + cg.constants.append_statement("}") + + +def _emit_shape_and_metadata_params( + cg: CodeGen, + call_data_len: int, + use_entrypoint_args: bool, +) -> None: + """Emit shape arrays and ``_thread_count``. + + Fast path (entry-point params):: + + uniform int[N] _grid_stride + uniform int[N] _grid_dim + uniform int[N] _call_dim + uniform uint3 _thread_count + + Fallback (CallData struct fields):: + + int[N] _grid_stride; + int[N] _grid_dim; + int[N] _call_dim; + uint3 _thread_count; + """ + if call_data_len > 0: + if use_entrypoint_args: + cg.entry_point_params.append(f"uniform int[{call_data_len}] _grid_stride") + cg.entry_point_params.append(f"uniform int[{call_data_len}] _grid_dim") + cg.entry_point_params.append(f"uniform int[{call_data_len}] _call_dim") + else: + cg.call_data.append_statement(f"int[{call_data_len}] _grid_stride") + cg.call_data.append_statement(f"int[{call_data_len}] _grid_dim") + cg.call_data.append_statement(f"int[{call_data_len}] _call_dim") + + if use_entrypoint_args: + cg.entry_point_params.append("uniform uint3 _thread_count") + else: + cg.call_data.append_statement("uint3 _thread_count") + + +def _emit_call_data_definitions( + cg: CodeGen, + context: "BindContext", + signature: "BoundCall", +) -> None: + """Emit per-variable call-data structs and type aliases for all signature nodes.""" + for node in signature.values(): + node.gen_call_data_code(cg, context) + + +def _data_name(x: "BoundVariable", use_entrypoint_args: bool) -> str: + """Return the Slang name used to access a variable's call data in the trampoline. + + - ``_param_{name}`` for param-block variables (both paths). + - ``__in_{name}`` in the fast (entry-point-args) path. + - ``call_data.{name}`` in the fallback path. + """ + if x.create_param_block: + return f"_param_{x.variable_name}" + elif use_entrypoint_args: + return f"__in_{x.variable_name}" + else: + return f"call_data.{x.variable_name}" + + +def _emit_trampoline( + cg: CodeGen, + context: "BindContext", + build_info: "FunctionBuildInfo", + root_params: list["BoundVariable"], + use_entrypoint_args: bool, +) -> None: + """Emit the ``_trampoline`` helper function. + + Fast path signature:: + + [Differentiable] + void _trampoline(Context __slangpy_context__, + no_diff MyType __in_param0, ...) + + Fallback signature:: + + [Differentiable] + void _trampoline(Context __slangpy_context__) + """ + from slangpy.bindings.boundvariable import BoundVariableException + + if context.call_mode != CallMode.prim: + cg.trampoline.append_line("[Differentiable]") + + if use_entrypoint_args: + trampoline_params = ["Context __slangpy_context__"] + for x in root_params: + if x.create_param_block: + continue + assert x.calldata_type_name is not None + trampoline_params.append(f"no_diff {x.calldata_type_name} __in_{x.variable_name}") + cg.trampoline.append_line(f"void _trampoline({', '.join(trampoline_params)})") + else: + cg.trampoline.append_line("void _trampoline(Context __slangpy_context__)") + cg.trampoline.begin_block() + + # Declare local variables for each parameter + for x in root_params: + assert x.vector_type is not None + cg.trampoline.declare(x.vector_type.full_name, x.variable_name) + + # Load inputs from call data + for x in root_params: + data_name = _data_name(x, use_entrypoint_args) + gen_load = getattr(x.python, "gen_trampoline_load", None) + if gen_load is not None and gen_load(cg.trampoline, x, data_name, x.variable_name): + continue + if x.access[0] == AccessType.read or x.access[0] == AccessType.readwrite: + cg.trampoline.append_statement( + f"{data_name}.__slangpy_load(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" + ) + + # Emit function call + cg.trampoline.append_indent() + if any(x.variable_name == "_result" for x in root_params): + cg.trampoline.append_code("_result = ") + + func_name = build_info.name + if func_name == "$init": + results = [x for x in root_params if x.variable_name == "_result"] + assert len(results) == 1 + assert results[0].vector_type is not None + func_name = results[0].vector_type.full_name + elif len(root_params) > 0 and root_params[0].variable_name == "_this": + func_name = f"_this.{func_name}" + + normal_params = [ + x for x in root_params if x.variable_name != "_result" and x.variable_name != "_this" + ] + cg.trampoline.append_code( + f"{func_name}(" + ", ".join(x.variable_name for x in normal_params) + ");\n" + ) + + # Store outputs back to call data + for x in root_params: + if ( + x.access[0] == AccessType.write + or x.access[0] == AccessType.readwrite + or x.access[1] == AccessType.read + ): + data_name = _data_name(x, use_entrypoint_args) + gen_store = getattr(x.python, "gen_trampoline_store", None) + if gen_store is not None and gen_store(cg.trampoline, x, data_name, x.variable_name): + continue + if not x.python.is_writable: + raise BoundVariableException( + f"Cannot read back value for non-writable type", x + ) + cg.trampoline.append_statement( + f"{data_name}.__slangpy_store(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" + ) + + cg.trampoline.end_block() + cg.trampoline.append_line("") + + +def _emit_entry_point_signature( + cg: CodeGen, + build_info: "FunctionBuildInfo", + call_data_len: int, + call_group_size: int, + use_entrypoint_args: bool, +) -> None: + """Emit the ``[shader(...)]`` attribute line and entry-point function signature. + + Compute fast path:: + + [shader("compute")] + [numthreads(32, 1, 1)] + void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, + int3 flat_call_group_id: SV_GroupID, + int flat_call_group_thread_id: SV_GroupIndex, + uniform int[N] _grid_stride, ...) + + Ray-tracing fallback:: + + [shader("raygen")] + void raygen_main() + """ + from slangpy.core.function import PipelineType + + if build_info.pipeline_type == PipelineType.compute: + cg.kernel.append_line('[shader("compute")]') + if call_group_size != 1: + cg.kernel.append_line(f"[numthreads({call_group_size}, 1, 1)]") + else: + cg.kernel.append_line("[numthreads(32, 1, 1)]") + if use_entrypoint_args: + sig_parts = ["int3 flat_call_thread_id: SV_DispatchThreadID"] + if call_data_len > 0: + sig_parts.append("int3 flat_call_group_id: SV_GroupID") + sig_parts.append("int flat_call_group_thread_id: SV_GroupIndex") + sig_parts.extend(cg.entry_point_params) + cg.kernel.append_line(f"void compute_main({', '.join(sig_parts)})") + else: + cg.kernel.append_line( + "void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex)" + ) + elif build_info.pipeline_type == PipelineType.ray_tracing: + cg.kernel.append_line('[shader("raygen")]') + if use_entrypoint_args: + sig_parts = list(cg.entry_point_params) + cg.kernel.append_line(f"void raygen_main({', '.join(sig_parts)})") + else: + cg.kernel.append_line("void raygen_main()") + else: + raise RuntimeError(f"Unknown pipeline type: {build_info.pipeline_type}") + + +def _emit_kernel_body( + cg: CodeGen, + context: "BindContext", + build_info: "FunctionBuildInfo", + root_params: list["BoundVariable"], + call_data_len: int, + use_entrypoint_args: bool, +) -> None: + """Emit the body of the compute/raygen entry-point function. + + Emits the bounds check, ``init_thread_local_call_shape_info``, Context + construction, and the trampoline call:: + + if (any(flat_call_thread_id >= _thread_count)) return; + if (!init_thread_local_call_shape_info(...)) return; + Context __slangpy_context__ = {flat_call_thread_id, ...}; + _trampoline(__slangpy_context__, ...); + """ + from slangpy.core.function import PipelineType + + if build_info.pipeline_type == PipelineType.ray_tracing: + cg.kernel.append_statement("int3 flat_call_thread_id = DispatchRaysIndex();") + + if use_entrypoint_args: + cg.kernel.append_statement("if (any(flat_call_thread_id >= _thread_count)) return") + else: + cg.kernel.append_statement( + "if (any(flat_call_thread_id >= call_data._thread_count)) return" + ) + + context_args = "flat_call_thread_id" + + if call_data_len > 0: + grid_prefix = "" if use_entrypoint_args else "call_data." + if build_info.pipeline_type == PipelineType.compute: + cg.kernel.append_line( + f""" + if (!init_thread_local_call_shape_info(flat_call_group_thread_id, + flat_call_group_id, flat_call_thread_id, {grid_prefix}_grid_stride, + {grid_prefix}_grid_dim, {grid_prefix}_call_dim)) + return;""" + ) + elif build_info.pipeline_type == PipelineType.ray_tracing: + cg.kernel.append_line( + f""" + if (!init_thread_local_call_shape_info(0, + uint3(0), flat_call_thread_id, {grid_prefix}_grid_stride, + {grid_prefix}_grid_dim, {grid_prefix}_call_dim)) + return;""" + ) + context_args += ", CallShapeInfo::get_call_id().shape" + + cg.kernel.append_statement(f"Context __slangpy_context__ = {{{context_args}}}") + + fn = "_trampoline" + if context.call_mode == CallMode.bwds: + fn = f"bwd_diff({fn})" + + if use_entrypoint_args: + trampoline_args = ["__slangpy_context__"] + for x in root_params: + if x.create_param_block: + continue + trampoline_args.append(x.variable_name) + cg.kernel.append_statement(f"{fn}({', '.join(trampoline_args)})") + else: + cg.kernel.append_statement(f"{fn}(__slangpy_context__)") + + +def generate_code( + context: "BindContext", + build_info: "FunctionBuildInfo", + signature: "BoundCall", + cg: CodeGen, +) -> None: + """Generate Slang kernel code for the given function call signature. + + Orchestrates all sub-steps: constants, shape params, call-data structs, + trampoline, entry-point signature, and kernel body. + """ + use_entrypoint_args = context.use_entrypoint_args + cg.add_import("slangpy") + call_data_len = context.call_dimensionality + + call_group_size, call_group_strides, call_group_shape_vector = ( + _validate_and_compute_group_shape(build_info, call_data_len) + ) + + cg.add_import(build_info.module.name) + if use_entrypoint_args: + cg.skip_call_data = True + + _emit_link_time_constants( + cg, build_info, call_data_len, call_group_size, call_group_strides, call_group_shape_vector + ) + _emit_shape_and_metadata_params(cg, call_data_len, use_entrypoint_args) + _emit_call_data_definitions(cg, context, signature) + + root_params = sorted(signature.values(), key=lambda x: x.param_index) + + _emit_trampoline(cg, context, build_info, root_params, use_entrypoint_args) + _emit_entry_point_signature(cg, build_info, call_data_len, call_group_size, use_entrypoint_args) + cg.kernel.begin_block() + _emit_kernel_body(cg, context, build_info, root_params, call_data_len, use_entrypoint_args) + cg.kernel.end_block() From 49630c364c07258625de60b9a8f8597b7079e2c9 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Fri, 13 Mar 2026 16:24:49 +0000 Subject: [PATCH 33/41] fix calldata --- slangpy/core/calldata.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index 367890c91..2e343d40e 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -8,11 +8,13 @@ from slangpy.core.callsignature import * from slangpy.core.logging import bound_call_table, bound_exception_info, mismatch_info from slangpy.core.native import ( + AccessType, CallMode, NativeCallData, unpack_args, unpack_kwargs, ) +from slangpy.core.function import PipelineType from slangpy import ( SlangCompileError, From 457aaa2dfd8788459c583f04bf4e5c0e322a534f Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Mon, 16 Mar 2026 09:41:45 +0000 Subject: [PATCH 34/41] work on generator cleanup --- slangpy/core/generator.py | 231 ++++++++++++++++++++++++++++++-------- 1 file changed, 185 insertions(+), 46 deletions(-) diff --git a/slangpy/core/generator.py b/slangpy/core/generator.py index 412dfe2ba..c0f9935cc 100644 --- a/slangpy/core/generator.py +++ b/slangpy/core/generator.py @@ -9,7 +9,6 @@ from slangpy.bindings.marshall import BindContext from slangpy.core.function import FunctionBuildInfo from slangpy.bindings.boundvariable import BoundVariable, BoundCall - from slangpy.bindings.marshall import BindContext #: Type names longer than this threshold get a ``typealias _t_{name}`` alias #: to keep the generated ``CallData`` struct readable. Shorter names are @@ -32,7 +31,17 @@ def _is_slangpy_vector(type: Any) -> bool: ) -def generate_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: +def _emit_user_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: + """Emit user-provided ``build_info.constants`` as exported Slang constants, + by appending them to the ``CodeGen.constants`` block. + + Example emitted declarations:: + + export static const bool enable_flag = true; + export static const int iterations = 32; + export static const float threshold = 0.5; + export static const float3 tint = float3(1,0,0); + """ if build_info.constants is not None: for k, v in build_info.constants.items(): if isinstance(v, bool): @@ -52,6 +61,14 @@ def generate_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: ) +def generate_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: + """Compatibility wrapper for legacy imports. + + The preferred internal implementation name is ``_emit_user_constants``. + """ + _emit_user_constants(build_info, cg) + + def gen_calldata_type_name(binding: "BoundVariable", cgb: CodeGenBlock, type_name: str) -> None: """Record the Slang type name for this variable's CallData field. @@ -71,15 +88,144 @@ def gen_calldata_type_name(binding: "BoundVariable", cgb: CodeGenBlock, type_nam binding.calldata_type_name = type_name -def gen_call_data_code( +def _emit_field_load( + cgb: CodeGenBlock, + var: "BoundVariable", + field: str, +) -> None: + """Emit a single field's ``__slangpy_load`` call inside a composite struct. + Will either use a marshall-specific load that implements direct binding to a uniform value, + or emit a call to the field's ``__slangpy_load`` method. + + Example emitted code for field ``a``:: + a.__slangpy_load(context.map(_m_a), value.a); // use field load method + value.a = a; // direct-bind load (no __slangpy_load method) + """ + gen_load = getattr(var.python, "gen_trampoline_load", None) + if gen_load is not None and gen_load(cgb, var, var.variable_name, f"value.{field}"): + return + cgb.append_statement( + f"{var.variable_name}.__slangpy_load(context.map(_m_{var.variable_name}),value.{field})" + ) + + +def _emit_field_store( + cgb: CodeGenBlock, + var: "BoundVariable", + field: str, +) -> None: + """Emit a single field's ``__slangpy_store`` call inside a composite struct. + + Example emitted code for field ``a``:: + a.__slangpy_store(context.map(_m_a), value.a); + + :param cgb: The code-gen block to write the store call to. + :param var: The bound variable representing the field to store. + :param field: The name of the field being stored (used for generating the value reference and error messages). + """ + gen_store = getattr(var.python, "gen_trampoline_store", None) + if gen_store is not None and gen_store(cgb, var, var.variable_name, f"value.{field}"): + return + cgb.append_statement( + f"{var.variable_name}.__slangpy_store(context.map(_m_{var.variable_name}),value.{field})" + ) + + +def _emit_composite_load_func( + cgb: CodeGenBlock, + binding: "BoundVariable", +) -> None: + """Emit the ``__slangpy_load`` method for a composite call-data struct. This + may include calls to __slangpy_load, or delegate to a marshall-specific + load that implements direct binding to a uniform values; + + Example: for a struct with fields ``a`` and ``b``:: + void __slangpy_load(ContextND<2> context, out Foo value) { + // load via marshall + a.__slangpy_load(context.map(_m_a), value.a); + + // or direct-bind load + value.b = this.b; + } + + :param cgb: The code-gen block to write the load function to. + :param binding: The bound variable representing the composite struct. + """ + assert binding.children is not None + assert binding.vector_type is not None + context_decl = f"ContextND<{binding.call_dimensionality}> context" + value_decl = f"{binding.vector_type.full_name} value" + prefix = "[Differentiable]" if binding.access[1] != AccessType.none else "" + cgb.empty_line() + cgb.append_line(f"{prefix} void __slangpy_load({context_decl}, out {value_decl})") + cgb.begin_block() + for field, var in binding.children.items(): + _emit_field_load(cgb, var, field) + cgb.end_block() + + +def _emit_composite_store_func( + cgb: CodeGenBlock, + binding: "BoundVariable", +) -> None: + """Emit the ``__slangpy_store`` method for a composite call-data struct. + + Example: for a struct with fields ``a`` and ``b``:: + void __slangpy_store(ContextND<2> context, in Foo value) { + a.__slangpy_store(context.map(_m_a), value.a); + b.__slangpy_store(context.map(_m_b), value.b); + } + + :param cgb: The code-gen block to write the store function to. + :param binding: The bound variable representing the composite struct. + """ + assert binding.children is not None + assert binding.vector_type is not None + context_decl = f"ContextND<{binding.call_dimensionality}> context" + value_decl = f"{binding.vector_type.full_name} value" + prefix = "[Differentiable]" if binding.access[1] != AccessType.none else "" + cgb.empty_line() + cgb.append_line(f"{prefix} void __slangpy_store({context_decl}, in {value_decl})") + cgb.begin_block() + for field, var in binding.children.items(): + _emit_field_store(cgb, var, field) + cgb.end_block() + + +def _emit_call_data_code( binding: "BoundVariable", cg: CodeGen, context: "BindContext", depth: int = 0 ) -> None: - """Emit Slang call-data struct and mapping constants for one bound variable. + """Emit Slang call-data type declarations and mapping constants. For struct/dict variables, emits a ``_t_{name}`` struct with ``__slangpy_load`` and ``__slangpy_store`` methods. For leaf variables, delegates to the marshall's - ``gen_calldata``. At depth 0, appends the variable's type to ``call_data`` - (or ``entry_point_params`` for the fast path). + ``gen_calldata``. + + At depth 0, will append the variable declaration to either: + - a parameter block (if ``create_param_block`` is True - for pre-built shader objects) + - the entry point parameters (if using entry-point args) + - the CallData struct (fallback path) + + A composite type declaration ``foo``:: + + // The composite struct, with load/store methods that recursively call into + // child fields. + struct _t_foo { + ChildType a; + ChildType b; + void __slangpy_load(ContextND<2> context, out Foo value) { ... } + void __slangpy_store(ContextND<2> context, in Foo value) { ... } + } + + A leaf argument may declare a type alias if it can not be directly bound :: + + typealias _t_foo = int; + + May also generate mapping constants for vectorized variables: + + static const int[] _m_foo = {0,1,2}; + + The composite load/store functions will generate load/store code for each child field. :param binding: The bound variable to emit code for. :param cg: The active CodeGen object. @@ -99,49 +245,23 @@ def gen_call_data_code( struct_name = f"_t_{binding.variable_name}" cgb.begin_struct(struct_name) + # Generate call data for child fields (recursively) + # Note: These are added as separate structs, not inside the parent struct. for field, variable in binding.children.items(): - gen_call_data_code(variable, cg, context, depth + 1) + _emit_call_data_code(variable, cg, context, depth + 1) + # Member variables of composite struct for var in binding.children.values(): assert ( var.calldata_type_name is not None ), f"calldata_type_name not set for '{var.variable_name}'" cgb.declare(var.calldata_type_name, var.variable_name) - assert binding.vector_type is not None - context_decl = f"ContextND<{binding.call_dimensionality}> context" - value_decl = f"{binding.vector_type.full_name} value" - prefix = "[Differentiable]" if binding.access[1] != AccessType.none else "" - + # Load/store methods in struct body. if binding.access[0] in (AccessType.read, AccessType.readwrite): - cgb.empty_line() - cgb.append_line(f"{prefix} void __slangpy_load({context_decl}, out {value_decl})") - cgb.begin_block() - for field, var in binding.children.items(): - gen_load = getattr(var.python, "gen_trampoline_load", None) - if gen_load is not None and gen_load( - cgb, var, var.variable_name, f"value.{field}" - ): - continue - cgb.append_statement( - f"{var.variable_name}.__slangpy_load(context.map(_m_{var.variable_name}),value.{field})" - ) - cgb.end_block() - + _emit_composite_load_func(cgb, binding) if binding.access[0] in (AccessType.write, AccessType.readwrite): - cgb.empty_line() - cgb.append_line(f"{prefix} void __slangpy_store({context_decl}, in {value_decl})") - cgb.begin_block() - for field, var in binding.children.items(): - gen_store = getattr(var.python, "gen_trampoline_store", None) - if gen_store is not None and gen_store( - cgb, var, var.variable_name, f"value.{field}" - ): - continue - cgb.append_statement( - f"{var.variable_name}.__slangpy_store(context.map(_m_{var.variable_name}),value.{field})" - ) - cgb.end_block() + _emit_composite_store_func(cgb, binding) cgb.end_struct() binding.calldata_type_name = struct_name @@ -150,7 +270,7 @@ def gen_call_data_code( # Generate call data binding.python.gen_calldata(cg.call_data_structs, context, binding) - # Skip mapping constants for direct-bind variables (they bypass __slangpy_load/store) + # Mapping constants if not binding.direct_bind: if len(binding.vector_mapping) > 0: cg.call_data_structs.append_statement( @@ -161,6 +281,7 @@ def gen_call_data_code( f"static const int _m_{binding.variable_name} = 0" ) + # At depth 0, declare the variable in the appropriate place if depth == 0: assert ( binding.calldata_type_name is not None @@ -175,6 +296,16 @@ def gen_call_data_code( cg.call_data.declare(binding.calldata_type_name, binding.variable_name) +def gen_call_data_code( + binding: "BoundVariable", cg: CodeGen, context: "BindContext", depth: int = 0 +) -> None: + """Compatibility wrapper for method-style call sites. + + The preferred internal implementation name is ``_emit_call_data_code``. + """ + _emit_call_data_code(binding, cg, context, depth) + + # --------------------------------------------------------------------------- # generate_code sub-functions # --------------------------------------------------------------------------- @@ -236,16 +367,20 @@ def _emit_link_time_constants( call_group_strides: list[int], call_group_shape_vector: list[int], ) -> None: - """Emit link-time constant declarations. + """Emit link-time constant declarations, including user defined ones + and any of the required call group shape constants. Emits Slang code like:: + // User constants from build_info.constants (if present) + export static const int user_const = 7; + export static const int call_data_len = 2; export static const int call_group_size = 1; export static const int[call_data_len] call_group_strides = {}; export static const int[call_data_len] call_group_shape_vector = {}; """ - generate_constants(build_info, cg) + _emit_user_constants(build_info, cg) cg.constants.append_statement(f"export static const int call_data_len = {call_data_len}") cg.constants.append_statement(f"export static const int call_group_size = {call_group_size}") @@ -310,7 +445,7 @@ def _emit_call_data_definitions( ) -> None: """Emit per-variable call-data structs and type aliases for all signature nodes.""" for node in signature.values(): - node.gen_call_data_code(cg, context) + _emit_call_data_code(node, cg, context) def _data_name(x: "BoundVariable", use_entrypoint_args: bool) -> str: @@ -441,7 +576,7 @@ def _emit_entry_point_signature( int flat_call_group_thread_id: SV_GroupIndex, uniform int[N] _grid_stride, ...) - Ray-tracing fallback:: + Ray-tracing entry point:: [shader("raygen")] void raygen_main() @@ -496,9 +631,11 @@ def _emit_kernel_body( """ from slangpy.core.function import PipelineType + # For RTP, read thread ID using DispatchRaysIndex() instead of SV_DispatchThreadID if build_info.pipeline_type == PipelineType.ray_tracing: cg.kernel.append_statement("int3 flat_call_thread_id = DispatchRaysIndex();") + # Bail out if out of bounds. if use_entrypoint_args: cg.kernel.append_statement("if (any(flat_call_thread_id >= _thread_count)) return") else: @@ -506,8 +643,9 @@ def _emit_kernel_body( "if (any(flat_call_thread_id >= call_data._thread_count)) return" ) + # Call to init_thread_local_call_shape_info that unpacks the thread id into + # a coordinate in the call shape, and stores the call shape info in thread-local storage context_args = "flat_call_thread_id" - if call_data_len > 0: grid_prefix = "" if use_entrypoint_args else "call_data." if build_info.pipeline_type == PipelineType.compute: @@ -528,12 +666,13 @@ def _emit_kernel_body( ) context_args += ", CallShapeInfo::get_call_id().shape" + # Define the core context. cg.kernel.append_statement(f"Context __slangpy_context__ = {{{context_args}}}") + # Emit the trampoline call, passing the context and any entry-point args (if using them). fn = "_trampoline" if context.call_mode == CallMode.bwds: fn = f"bwd_diff({fn})" - if use_entrypoint_args: trampoline_args = ["__slangpy_context__"] for x in root_params: From 1117867121d4d915847f1e7593b1bdd72adb1d1a Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Mon, 16 Mar 2026 10:01:24 +0000 Subject: [PATCH 35/41] More generator cleanup --- slangpy/core/generator.py | 89 +++++++++++++++++++++++++++------------ 1 file changed, 62 insertions(+), 27 deletions(-) diff --git a/slangpy/core/generator.py b/slangpy/core/generator.py index c0f9935cc..20d3c3e56 100644 --- a/slangpy/core/generator.py +++ b/slangpy/core/generator.py @@ -463,6 +463,62 @@ def _data_name(x: "BoundVariable", use_entrypoint_args: bool) -> str: return f"call_data.{x.variable_name}" +def _emit_trampoline_loads( + cgb: CodeGenBlock, + root_params: list["BoundVariable"], + use_entrypoint_args: bool, +) -> None: + """Emit ``__slangpy_load`` calls for each readable trampoline parameter. + + For each parameter, either delegates to a marshall-specific + ``gen_trampoline_load`` or emits a standard load call:: + + __in_x.__slangpy_load(__slangpy_context__.map(_m_x), x); // slangpy load + x = __in_x; // direct-bind load (no __slangpy_load method) + """ + for x in root_params: + data_name = _data_name(x, use_entrypoint_args) + gen_load = getattr(x.python, "gen_trampoline_load", None) + if gen_load is not None and gen_load(cgb, x, data_name, x.variable_name): + continue + if x.access[0] == AccessType.read or x.access[0] == AccessType.readwrite: + cgb.append_statement( + f"{data_name}.__slangpy_load(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" + ) + + +def _emit_trampoline_stores( + cgb: CodeGenBlock, + root_params: list["BoundVariable"], + use_entrypoint_args: bool, +) -> None: + """Emit ``__slangpy_store`` calls for each writable trampoline parameter. + + For each parameter that is written or whose gradient is read, either + delegates to a marshall-specific ``gen_trampoline_store`` or emits a + standard store call:: + + __in_x.__slangpy_store(__slangpy_context__.map(_m_x), x); + """ + from slangpy.bindings.boundvariable import BoundVariableException + + for x in root_params: + if ( + x.access[0] == AccessType.write + or x.access[0] == AccessType.readwrite + or x.access[1] == AccessType.read + ): + data_name = _data_name(x, use_entrypoint_args) + gen_store = getattr(x.python, "gen_trampoline_store", None) + if gen_store is not None and gen_store(cgb, x, data_name, x.variable_name): + continue + if not x.python.is_writable: + raise BoundVariableException(f"Cannot read back value for non-writable type", x) + cgb.append_statement( + f"{data_name}.__slangpy_store(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" + ) + + def _emit_trampoline( cg: CodeGen, context: "BindContext", @@ -483,8 +539,6 @@ def _emit_trampoline( [Differentiable] void _trampoline(Context __slangpy_context__) """ - from slangpy.bindings.boundvariable import BoundVariableException - if context.call_mode != CallMode.prim: cg.trampoline.append_line("[Differentiable]") @@ -506,21 +560,15 @@ def _emit_trampoline( cg.trampoline.declare(x.vector_type.full_name, x.variable_name) # Load inputs from call data - for x in root_params: - data_name = _data_name(x, use_entrypoint_args) - gen_load = getattr(x.python, "gen_trampoline_load", None) - if gen_load is not None and gen_load(cg.trampoline, x, data_name, x.variable_name): - continue - if x.access[0] == AccessType.read or x.access[0] == AccessType.readwrite: - cg.trampoline.append_statement( - f"{data_name}.__slangpy_load(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" - ) + _emit_trampoline_loads(cg.trampoline, root_params, use_entrypoint_args) - # Emit function call + # Emit the 'result=' bit if function has a return value. cg.trampoline.append_indent() if any(x.variable_name == "_result" for x in root_params): cg.trampoline.append_code("_result = ") + # Generate the function call prefix, with some special casing for constructors + # and type method calls. func_name = build_info.name if func_name == "$init": results = [x for x in root_params if x.variable_name == "_result"] @@ -530,6 +578,7 @@ def _emit_trampoline( elif len(root_params) > 0 and root_params[0].variable_name == "_this": func_name = f"_this.{func_name}" + # Emit the function call itself, passing in parameters other than _result and _this. normal_params = [ x for x in root_params if x.variable_name != "_result" and x.variable_name != "_this" ] @@ -538,21 +587,7 @@ def _emit_trampoline( ) # Store outputs back to call data - for x in root_params: - if ( - x.access[0] == AccessType.write - or x.access[0] == AccessType.readwrite - or x.access[1] == AccessType.read - ): - data_name = _data_name(x, use_entrypoint_args) - gen_store = getattr(x.python, "gen_trampoline_store", None) - if gen_store is not None and gen_store(cg.trampoline, x, data_name, x.variable_name): - continue - if not x.python.is_writable: - raise BoundVariableException(f"Cannot read back value for non-writable type", x) - cg.trampoline.append_statement( - f"{data_name}.__slangpy_store(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" - ) + _emit_trampoline_stores(cg.trampoline, root_params, use_entrypoint_args) cg.trampoline.end_block() cg.trampoline.append_line("") From c71f4c24880e6cfd6bdf39f97ebe06a981440837 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Mon, 16 Mar 2026 14:52:17 +0000 Subject: [PATCH 36/41] more generator cleanup --- slangpy/core/generator.py | 258 +++++++++++++++++++++++--------------- 1 file changed, 154 insertions(+), 104 deletions(-) diff --git a/slangpy/core/generator.py b/slangpy/core/generator.py index 20d3c3e56..c27e234a3 100644 --- a/slangpy/core/generator.py +++ b/slangpy/core/generator.py @@ -22,6 +22,58 @@ def __init__(self, message: str): self.message = message +# --------------------------------------------------------------------------- +# Access-tuple helpers +# --------------------------------------------------------------------------- +# BoundVariable.access is a (primal, derivative) tuple of AccessType. +# These predicates give readable names to the index lookups. + + +def _is_readable(b: "BoundVariable") -> bool: + """True when the primal value is read (read or readwrite).""" + return b.access[0] in (AccessType.read, AccessType.readwrite) + + +def _is_writable(b: "BoundVariable") -> bool: + """True when the primal value is written (write or readwrite).""" + return b.access[0] in (AccessType.write, AccessType.readwrite) + + +def _is_differentiable(b: "BoundVariable") -> bool: + """True when the derivative access is anything other than none.""" + return b.access[1] != AccessType.none + + +def _grad_is_readable(b: "BoundVariable") -> bool: + """True when the derivative/gradient is read back.""" + return b.access[1] == AccessType.read + + +# --------------------------------------------------------------------------- +# Shared trampoline dispatch helper +# --------------------------------------------------------------------------- + + +def _try_custom_gen( + var: "BoundVariable", + method: str, + cgb: CodeGenBlock, + data_name: str, + value_expr: str, +) -> bool: + """Try calling a marshall-specific ``gen_trampoline_load`` or ``gen_trampoline_store``. + + Returns True if the marshall handled code generation, False otherwise. + """ + fn = getattr(var.python, method, None) + return fn is not None and fn(cgb, var, data_name, value_expr) + + +# --------------------------------------------------------------------------- +# Misc helpers +# --------------------------------------------------------------------------- + + def _is_slangpy_vector(type: Any) -> bool: return ( hasattr(type, "element_type") @@ -31,6 +83,11 @@ def _is_slangpy_vector(type: Any) -> bool: ) +# --------------------------------------------------------------------------- +# Call-data code generation +# --------------------------------------------------------------------------- + + def _emit_user_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: """Emit user-provided ``build_info.constants`` as exported Slang constants, by appending them to the ``CodeGen.constants`` block. @@ -61,12 +118,8 @@ def _emit_user_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: ) -def generate_constants(build_info: "FunctionBuildInfo", cg: CodeGen) -> None: - """Compatibility wrapper for legacy imports. - - The preferred internal implementation name is ``_emit_user_constants``. - """ - _emit_user_constants(build_info, cg) +#: Compatibility alias for legacy imports. +generate_constants = _emit_user_constants def gen_calldata_type_name(binding: "BoundVariable", cgb: CodeGenBlock, type_name: str) -> None: @@ -101,8 +154,7 @@ def _emit_field_load( a.__slangpy_load(context.map(_m_a), value.a); // use field load method value.a = a; // direct-bind load (no __slangpy_load method) """ - gen_load = getattr(var.python, "gen_trampoline_load", None) - if gen_load is not None and gen_load(cgb, var, var.variable_name, f"value.{field}"): + if _try_custom_gen(var, "gen_trampoline_load", cgb, var.variable_name, f"value.{field}"): return cgb.append_statement( f"{var.variable_name}.__slangpy_load(context.map(_m_{var.variable_name}),value.{field})" @@ -123,8 +175,7 @@ def _emit_field_store( :param var: The bound variable representing the field to store. :param field: The name of the field being stored (used for generating the value reference and error messages). """ - gen_store = getattr(var.python, "gen_trampoline_store", None) - if gen_store is not None and gen_store(cgb, var, var.variable_name, f"value.{field}"): + if _try_custom_gen(var, "gen_trampoline_store", cgb, var.variable_name, f"value.{field}"): return cgb.append_statement( f"{var.variable_name}.__slangpy_store(context.map(_m_{var.variable_name}),value.{field})" @@ -155,7 +206,7 @@ def _emit_composite_load_func( assert binding.vector_type is not None context_decl = f"ContextND<{binding.call_dimensionality}> context" value_decl = f"{binding.vector_type.full_name} value" - prefix = "[Differentiable]" if binding.access[1] != AccessType.none else "" + prefix = "[Differentiable]" if _is_differentiable(binding) else "" cgb.empty_line() cgb.append_line(f"{prefix} void __slangpy_load({context_decl}, out {value_decl})") cgb.begin_block() @@ -183,7 +234,7 @@ def _emit_composite_store_func( assert binding.vector_type is not None context_decl = f"ContextND<{binding.call_dimensionality}> context" value_decl = f"{binding.vector_type.full_name} value" - prefix = "[Differentiable]" if binding.access[1] != AccessType.none else "" + prefix = "[Differentiable]" if _is_differentiable(binding) else "" cgb.empty_line() cgb.append_line(f"{prefix} void __slangpy_store({context_decl}, in {value_decl})") cgb.begin_block() @@ -192,24 +243,17 @@ def _emit_composite_store_func( cgb.end_block() -def _emit_call_data_code( - binding: "BoundVariable", cg: CodeGen, context: "BindContext", depth: int = 0 +def _emit_type_and_struct( + binding: "BoundVariable", cg: CodeGen, context: "BindContext", depth: int ) -> None: - """Emit Slang call-data type declarations and mapping constants. - - For struct/dict variables, emits a ``_t_{name}`` struct with ``__slangpy_load`` - and ``__slangpy_store`` methods. For leaf variables, delegates to the marshall's - ``gen_calldata``. + """Emit the type declaration for a single binding. - At depth 0, will append the variable declaration to either: - - a parameter block (if ``create_param_block`` is True - for pre-built shader objects) - - the entry point parameters (if using entry-point args) - - the CallData struct (fallback path) + For composite (struct/dict) variables, emits a ``_t_{name}`` struct with + ``__slangpy_load`` / ``__slangpy_store`` methods and recurses into children. + For leaf variables, delegates to the marshall's ``gen_calldata``. - A composite type declaration ``foo``:: + Example composite output:: - // The composite struct, with load/store methods that recursively call into - // child fields. struct _t_foo { ChildType a; ChildType b; @@ -217,93 +261,105 @@ def _emit_call_data_code( void __slangpy_store(ContextND<2> context, in Foo value) { ... } } - A leaf argument may declare a type alias if it can not be directly bound :: + Example leaf output:: typealias _t_foo = int; - - May also generate mapping constants for vectorized variables: - - static const int[] _m_foo = {0,1,2}; - - The composite load/store functions will generate load/store code for each child field. - - :param binding: The bound variable to emit code for. - :param cg: The active CodeGen object. - :param context: The bind context for the current call. - :param depth: Recursion depth (0 = root, >0 = struct field). """ - from slangpy.bindings.boundvariable import BoundVariableException - if binding.children is not None: cgb = cg.call_data_structs if binding.direct_bind: - # Direct-bind: use raw type name directly assert binding.vector_type is not None gen_calldata_type_name(binding, cgb, binding.vector_type.full_name) else: struct_name = f"_t_{binding.variable_name}" cgb.begin_struct(struct_name) - # Generate call data for child fields (recursively) - # Note: These are added as separate structs, not inside the parent struct. + # Recurse into children (emitted as separate structs, not nested). for field, variable in binding.children.items(): _emit_call_data_code(variable, cg, context, depth + 1) - # Member variables of composite struct + # Member variables for var in binding.children.values(): assert ( var.calldata_type_name is not None ), f"calldata_type_name not set for '{var.variable_name}'" cgb.declare(var.calldata_type_name, var.variable_name) - # Load/store methods in struct body. - if binding.access[0] in (AccessType.read, AccessType.readwrite): + # Load/store methods + if _is_readable(binding): _emit_composite_load_func(cgb, binding) - if binding.access[0] in (AccessType.write, AccessType.readwrite): + if _is_writable(binding): _emit_composite_store_func(cgb, binding) cgb.end_struct() binding.calldata_type_name = struct_name - else: - # Generate call data binding.python.gen_calldata(cg.call_data_structs, context, binding) - # Mapping constants - if not binding.direct_bind: - if len(binding.vector_mapping) > 0: - cg.call_data_structs.append_statement( - f"static const int[] _m_{binding.variable_name} = {{ {','.join([str(x) for x in binding.vector_mapping.as_tuple()])} }}" - ) - else: - cg.call_data_structs.append_statement( - f"static const int _m_{binding.variable_name} = 0" - ) - # At depth 0, declare the variable in the appropriate place - if depth == 0: - assert ( - binding.calldata_type_name is not None - ), f"calldata_type_name not set for '{binding.variable_name}'" - if binding.create_param_block: - cg.add_parameter_block(binding.calldata_type_name, "_param_" + binding.variable_name) - elif cg.skip_call_data: - cg.entry_point_params.append( - f"uniform {binding.calldata_type_name} {binding.variable_name}" - ) - else: - cg.call_data.declare(binding.calldata_type_name, binding.variable_name) +def _emit_mapping_constants(binding: "BoundVariable", cg: CodeGen) -> None: + """Emit the vectorization mapping constant for a single binding. + + Example output:: + + static const int[] _m_foo = {0,1,2}; // vectorized + static const int _m_foo = 0; // scalar + """ + if binding.direct_bind: + return + if len(binding.vector_mapping) > 0: + cg.call_data_structs.append_statement( + f"static const int[] _m_{binding.variable_name}" + f" = {{ {','.join([str(x) for x in binding.vector_mapping.as_tuple()])} }}" + ) + else: + cg.call_data_structs.append_statement(f"static const int _m_{binding.variable_name} = 0") + + +def _emit_root_declaration(binding: "BoundVariable", cg: CodeGen) -> None: + """At depth 0, declare the variable in the appropriate destination. + Chooses between a parameter block, an entry-point uniform, or a + CallData struct field depending on the binding configuration. + """ + assert ( + binding.calldata_type_name is not None + ), f"calldata_type_name not set for '{binding.variable_name}'" + if binding.create_param_block: + cg.add_parameter_block(binding.calldata_type_name, "_param_" + binding.variable_name) + elif cg.skip_call_data: + cg.entry_point_params.append( + f"uniform {binding.calldata_type_name} {binding.variable_name}" + ) + else: + cg.call_data.declare(binding.calldata_type_name, binding.variable_name) -def gen_call_data_code( + +def _emit_call_data_code( binding: "BoundVariable", cg: CodeGen, context: "BindContext", depth: int = 0 ) -> None: - """Compatibility wrapper for method-style call sites. + """Emit Slang call-data type declarations, mapping constants, and + root-level variable declarations for a single binding. - The preferred internal implementation name is ``_emit_call_data_code``. + Orchestrates three sub-steps: + 1. Type/struct declaration (``_emit_type_and_struct``) + 2. Mapping constants (``_emit_mapping_constants``) + 3. Root declaration at depth 0 (``_emit_root_declaration``) + + :param binding: The bound variable to emit code for. + :param cg: The active CodeGen object. + :param context: The bind context for the current call. + :param depth: Recursion depth (0 = root, >0 = struct field). """ - _emit_call_data_code(binding, cg, context, depth) + _emit_type_and_struct(binding, cg, context, depth) + _emit_mapping_constants(binding, cg) + if depth == 0: + _emit_root_declaration(binding, cg) + + +#: Compatibility alias for legacy imports. +gen_call_data_code = _emit_call_data_code # --------------------------------------------------------------------------- @@ -478,10 +534,9 @@ def _emit_trampoline_loads( """ for x in root_params: data_name = _data_name(x, use_entrypoint_args) - gen_load = getattr(x.python, "gen_trampoline_load", None) - if gen_load is not None and gen_load(cgb, x, data_name, x.variable_name): + if _try_custom_gen(x, "gen_trampoline_load", cgb, data_name, x.variable_name): continue - if x.access[0] == AccessType.read or x.access[0] == AccessType.readwrite: + if _is_readable(x): cgb.append_statement( f"{data_name}.__slangpy_load(__slangpy_context__.map(_m_{x.variable_name}), {x.variable_name})" ) @@ -503,14 +558,9 @@ def _emit_trampoline_stores( from slangpy.bindings.boundvariable import BoundVariableException for x in root_params: - if ( - x.access[0] == AccessType.write - or x.access[0] == AccessType.readwrite - or x.access[1] == AccessType.read - ): + if _is_writable(x) or _grad_is_readable(x): data_name = _data_name(x, use_entrypoint_args) - gen_store = getattr(x.python, "gen_trampoline_store", None) - if gen_store is not None and gen_store(cgb, x, data_name, x.variable_name): + if _try_custom_gen(x, "gen_trampoline_store", cgb, data_name, x.variable_name): continue if not x.python.is_writable: raise BoundVariableException(f"Cannot read back value for non-writable type", x) @@ -632,9 +682,12 @@ def _emit_entry_point_signature( sig_parts.extend(cg.entry_point_params) cg.kernel.append_line(f"void compute_main({', '.join(sig_parts)})") else: - cg.kernel.append_line( - "void compute_main(int3 flat_call_thread_id: SV_DispatchThreadID, int3 flat_call_group_id: SV_GroupID, int flat_call_group_thread_id: SV_GroupIndex)" - ) + sig_parts = [ + "int3 flat_call_thread_id: SV_DispatchThreadID", + "int3 flat_call_group_id: SV_GroupID", + "int flat_call_group_thread_id: SV_GroupIndex", + ] + cg.kernel.append_line(f"void compute_main({', '.join(sig_parts)})") elif build_info.pipeline_type == PipelineType.ray_tracing: cg.kernel.append_line('[shader("raygen")]') if use_entrypoint_args: @@ -682,23 +735,20 @@ def _emit_kernel_body( # a coordinate in the call shape, and stores the call shape info in thread-local storage context_args = "flat_call_thread_id" if call_data_len > 0: - grid_prefix = "" if use_entrypoint_args else "call_data." + gp = "" if use_entrypoint_args else "call_data." if build_info.pipeline_type == PipelineType.compute: - cg.kernel.append_line( - f""" - if (!init_thread_local_call_shape_info(flat_call_group_thread_id, - flat_call_group_id, flat_call_thread_id, {grid_prefix}_grid_stride, - {grid_prefix}_grid_dim, {grid_prefix}_call_dim)) - return;""" - ) + thread_arg = "flat_call_group_thread_id" + group_arg = "flat_call_group_id" elif build_info.pipeline_type == PipelineType.ray_tracing: - cg.kernel.append_line( - f""" - if (!init_thread_local_call_shape_info(0, - uint3(0), flat_call_thread_id, {grid_prefix}_grid_stride, - {grid_prefix}_grid_dim, {grid_prefix}_call_dim)) - return;""" - ) + thread_arg = "0" + group_arg = "uint3(0)" + else: + raise RuntimeError(f"Unknown pipeline type: {build_info.pipeline_type}") + cg.kernel.append_statement( + f"if (!init_thread_local_call_shape_info(" + f"{thread_arg}, {group_arg}, flat_call_thread_id, " + f"{gp}_grid_stride, {gp}_grid_dim, {gp}_call_dim)) return" + ) context_args += ", CallShapeInfo::get_call_id().shape" # Define the core context. From 4d224797c861b5ca4f401ce1436ddff5b94b554e Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Mon, 16 Mar 2026 15:00:42 +0000 Subject: [PATCH 37/41] Cleanup use param block --- slangpy/bindings/codegen.py | 5 ++++- slangpy/core/calldata.py | 1 - 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/slangpy/bindings/codegen.py b/slangpy/bindings/codegen.py index ebf8f7abb..faea4fec1 100644 --- a/slangpy/bindings/codegen.py +++ b/slangpy/bindings/codegen.py @@ -193,7 +193,7 @@ def finish( snippets: bool = False, call_data_structs: bool = False, constants: bool = False, - use_param_block_for_call_data: bool = False, + use_param_block_for_call_data: bool = True, ): """ Generate the final code for the kernel. @@ -202,6 +202,9 @@ def finish( if not self.skip_call_data: self.call_data.end_block() + # TODO: Remove 'use_param_block_for_call_data' + # This is only set to false for raw dispatch on cuda. Once it's retired, we will always bind + # call_data as a parameter block unless it is skipped. if use_param_block_for_call_data: self.call_data.append_statement("ParameterBlock call_data") diff --git a/slangpy/core/calldata.py b/slangpy/core/calldata.py index 2e343d40e..6a17e8828 100644 --- a/slangpy/core/calldata.py +++ b/slangpy/core/calldata.py @@ -410,7 +410,6 @@ def _try_build_shader( snippets=True, call_data_structs=True, constants=True, - use_param_block_for_call_data=not context.use_entrypoint_args, ) # Optionally write the shader to a file for debugging. From c5c13ce352c9eb02d01a233b387dd08cf47d037a Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Mon, 16 Mar 2026 15:08:43 +0000 Subject: [PATCH 38/41] Safety check --- slangpy/core/generator.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/slangpy/core/generator.py b/slangpy/core/generator.py index c27e234a3..c9676df53 100644 --- a/slangpy/core/generator.py +++ b/slangpy/core/generator.py @@ -296,6 +296,12 @@ def _emit_type_and_struct( binding.calldata_type_name = struct_name else: binding.python.gen_calldata(cg.call_data_structs, context, binding) + if binding.calldata_type_name is None: + raise KernelGenException( + f"Marshall '{type(binding.python).__name__}' did not set " + f"calldata_type_name for '{binding.variable_name}' in gen_calldata(). " + f"Ensure gen_calldata calls binding.gen_calldata_type_name()." + ) def _emit_mapping_constants(binding: "BoundVariable", cg: CodeGen) -> None: From 9f418e1602bc125bb735af31ce4293cd5c7ef5a0 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Mon, 16 Mar 2026 16:54:49 +0000 Subject: [PATCH 39/41] wip removing trampoline --- .../plan-consolidateCodeGenTests.prompt.md | 269 ------------------ .../plan-simplifyKernelGen-phase2.prompt.md | 52 +++- slangpy/core/generator.py | 147 +++++++--- slangpy/tests/slangpy_tests/test_code_gen.py | 51 +++- .../tests/slangpy_tests/test_kernel_gen.py | 51 +++- 5 files changed, 221 insertions(+), 349 deletions(-) delete mode 100644 .github/prompts/plan-consolidateCodeGenTests.prompt.md diff --git a/.github/prompts/plan-consolidateCodeGenTests.prompt.md b/.github/prompts/plan-consolidateCodeGenTests.prompt.md deleted file mode 100644 index f8315a232..000000000 --- a/.github/prompts/plan-consolidateCodeGenTests.prompt.md +++ /dev/null @@ -1,269 +0,0 @@ -## Plan: Consolidate test_kernel_gen.py → test_code_gen.py - -**TL;DR**: Reduce 80 test functions to ~34 by: (1) merging codegen-pattern + binding-flag tests that generate the same kernel into single combined tests, (2) dropping functional GPU dispatch tests that duplicate coverage in existing test files, (3) consolidating tests that use identical Slang source strings, and (4) subsuming shallow struct nesting tests into deeper ones. - -Current: **80 tests × 3 device types = 240 parametrized cases** (~1841 lines) -Proposed: **~34 tests × 3 device types = ~102 parametrized cases** (~700-800 lines) - ---- - -### Consolidation Strategies - -#### Strategy A: Merge same-source codegen tests - -Five tests use `int add(int a, int b)` with args `(1, 2)` and generate the exact same kernel: -- `test_gate_scalar_uses_valuetype` -- `test_gate_valueref_write_uses_wrapper` -- `test_gate_mapping_constants_present` -- `test_gate_context_map_in_trampoline` -- `test_result_binding_not_direct_bind` - -→ **1 combined test** `test_scalar_direct_bind`: generates kernel once, asserts all codegen patterns (no `ValueType`, no `typealias`, direct assignment, `RWValueRef` for `_result`, no mapping constants for args, `_m__result` present, no `context.map` for args) **and** binding flags (`args[0].direct_bind=True`, `kwargs["_result"].direct_bind=False`). - -Also fold `test_gate_float_scalar_uses_valuetype` into this as a sub-assertion (generate a second kernel for `float mymul` and check same patterns), or drop entirely since `int` and `float` exercise the same `ValueMarshall` code path. - -#### Strategy B: Merge codegen + binding flag pairs - -Each of these pairs generates the same kernel twice — merge into one test using a new `generate_code_and_bindings` helper that returns `(code, bindings)`: - -| Merged test | From (codegen) | From (binding flags) | -|---|---|---| -| `test_vector_direct_bind` | `test_gate_vector_uses_vectorvaluetype` | (no binding-flag test exists; add flags check) | -| `test_struct_all_scalar_direct_bind` | `test_gate_struct_uses_slangpy_load` | `test_struct_all_scalars_binding_flag` | -| `test_mixed_scalar_tensor` | `test_gate_mixed_args_scalar_and_tensor` | `test_gate_mixed_args_direct_bind_flags` | -| `test_struct_mixed_fields` | `test_gate_struct_mixed_fields_codegen` + `test_mixed_children_direct_bind_codegen` | `test_gate_struct_mixed_fields_binding_flags` | -| `test_tensor_dim0_direct_bind` | `test_gate_tensor_dim0_codegen` | `test_gate_tensor_dim0_binding_flags` | -| `test_2d_tensor_to_vector` | `test_gate_2d_tensor_to_vector_codegen` | `test_gate_2d_tensor_to_vector_binding_flags` | -| `test_3d_tensor_to_vector` | `test_gate_3d_tensor_to_vector_codegen` | `test_gate_3d_tensor_to_vector_binding_flags` | -| `test_2d_tensor_to_scalar` | `test_gate_2d_tensor_to_scalar_codegen` | `test_gate_2d_tensor_to_scalar_binding_flags` | -| `test_2d_tensor_to_array` | `test_gate_2d_tensor_to_1d_array_codegen` | `test_gate_2d_tensor_to_1d_array_binding_flags` | -| `test_mixed_vectorized_dim0_tensor` | `test_gate_mixed_vectorized_and_dim0_tensor_codegen` | `test_gate_mixed_vectorized_and_dim0_tensor_binding_flags` | -| `test_struct_return_not_direct_bind` | `test_gate_struct_return_codegen` | `test_gate_struct_return_binding_flags` | -| `test_nested_struct_with_tensor_child` | `test_gate_nested_struct_with_tensor_child_codegen` | `test_gate_nested_struct_with_tensor_child_binding_flags` | - -#### Strategy C: Subsume shallow struct tests into deeper ones - -The 3-level-deep all-scalar struct test (`test_gate_deeply_nested_struct_codegen` + `test_gate_deeply_nested_struct_binding_flags`) covers the same pattern as the 2-level nested struct (`test_gate_nested_struct_codegen` + `test_gate_nested_struct_binding_flags`). Keep only **`test_deeply_nested_struct_direct_bind`** combining both. Drop tests for 2-level nesting. - -#### Strategy D: Consolidate all-scalar struct composite field variants - -Four test groups cover "struct with {vector/matrix/array/struct-array} fields, all dim-0 → direct-bind". These all exercise the same property (recursive `can_direct_bind` returning `True`): -- struct with vector field (`test_gate_struct_with_vector_fields_codegen` + binding) -- struct with matrix field (`test_gate_struct_with_matrix_field_codegen`) -- struct with array field (`test_gate_struct_with_array_field_codegen` + binding) -- struct with struct-array field (`test_gate_struct_with_struct_array_field_codegen`) - -→ Merge into **one parametrized test** `test_struct_composite_fields_direct_bind` with parameters for the variant (vector field, array field). Drop matrix and struct-array codegen-only tests — if vector+array pass, the mechanism works for all composite field types. - -#### Strategy E: Consolidate negative gates - -Merge `test_gate_wanghasharg_uses_wrapper`, `test_wanghasharg_binding_flag`, and `test_struct_with_wanghash_child_not_direct_bind` into **one test** `test_wanghasharg_not_direct_bind` that covers standalone + struct-child cases. - -Keep `test_gate_vectorized_scalar_keeps_wrapper` and `test_gate_vectorized_dict_keeps_struct_load` as separate small tests (they're already minimal). - -#### Strategy F: Consolidate long-name tests - -Merge `test_gate_long_struct_name_gets_typealias`, `test_gate_short_struct_name_inlined`, `test_gate_long_scalar_type_name_gets_typealias` into **one test** `test_long_type_name_typealias`. - -#### Strategy G: Drop functional dispatch tests with existing coverage - -| Dropped test | Covered by | -|---|---| -| `test_phase1_functional_scalar_add` | `test_simple_function_call.py::test_returnvalue` | -| `test_phase1_functional_float_mul` | Same mechanism as scalar_add, float type tested elsewhere | -| `test_phase1_functional_valueref_write` | `test_simple_function_call.py::test_scalar_outparam` | -| `test_phase1_functional_struct_return` | `test_return_types.py::test_return_struct_as_dict` | -| `test_phase1_functional_struct_sum` | Similar to struct_return, struct dispatch tested elsewhere | -| `test_phase1_functional_nested_struct` | Subsumed by deeply_nested + nested_with_tensor tests | -| `test_phase1_functional_struct_with_vector_fields` | Covered by composite field parametrized test pattern | -| `test_phase1_functional_struct_with_matrix_field` | Same mechanism as vector_fields test | -| `test_phase1_functional_struct_with_array_field` | Same mechanism, array dispatch tested in `test_simple_function_call.py` | -| `test_phase1_functional_deeply_nested_struct` | 3-level dispatch validates same mechanism as 2-level | -| `test_phase1_functional_vector_scale` | `test_vector_function_call.py` covers vector dispatch | -| `test_phase1_functional_3d_tensor_to_vector` | If 2D→vector works, 3D exercises same path with extra dim | -| `test_phase1_functional_2d_tensor_to_scalar` | Element-wise tensor dispatch covered by `test_tensor.py` | -| `test_phase1_functional_valueref_read_input` | Read ValueRef → scalar is tested indirectly; codegen test verifies the binding | -| `test_phase1_functional_long_struct_name` | Long name is a codegen-only concern; dispatch is identical to short-name struct | - ---- - -### Proposed Final Test List (~34 tests) - -**New helper:** -```python -def generate_code_and_bindings(device, func_name, module_source, *args, **kwargs): - """Generate code and return (code_str, bindings) from a single debug_build_call_data call.""" - func = helpers.create_function_from_module(device, func_name, module_source) - cd = func.debug_build_call_data(*args, **kwargs) - return cd.code, cd.debug_only_bindings -``` - -**Codegen + binding flag tests (21):** - -| # | Test name | Scenario | Merges from | -|---|---|---|---| -| 1 | `test_scalar_direct_bind` | int/float scalar dim-0; _result writable | 5 codegen tests + 1 binding test + float variant | -| 2 | `test_vector_direct_bind` | float3 dim-0 | codegen test + new binding assertions | -| 3 | `test_matrix_direct_bind` | float4x4 dim-0 | standalone | -| 4 | `test_array_direct_bind` | float[4] dim-0 | standalone | -| 5 | `test_valueref_read_direct_bind` | read-only ValueRef | standalone | -| 6 | `test_writable_valueref_not_direct_bind` | inout ValueRef (RWValueRef) | standalone | -| 7 | `test_struct_all_scalar_direct_bind` | S{float x, y} via dict | codegen + binding pair | -| 8 | `test_struct_composite_fields_direct_bind` | parametrized: struct with vector / array field | 4 codegen + 2 binding tests | -| 9 | `test_deeply_nested_struct_direct_bind` | 3-level Top{Mid{Bot}} | subsumes 2-level; codegen + binding pair | -| 10 | `test_struct_mixed_fields` | S{x(tensor), y(scalar)} | 2 codegen + 1 binding test | -| 11 | `test_nested_struct_with_tensor_child` | Outer{Inner{x(tensor),y},s} | codegen + binding pair | -| 12 | `test_struct_return_not_direct_bind` | function returning struct | codegen + binding pair | -| 13 | `test_struct_vectorized_2d_child` | S{float3 v (2D tensor), float s} | standalone | -| 14 | `test_mixed_scalar_and_tensor` | scalar + tensor args | codegen + binding pair | -| 15 | `test_tensor_dim0_direct_bind` | Tensor at dim-0 | codegen + binding pair | -| 16 | `test_2d_tensor_to_vector` | 2D(10,3) → float3 | codegen + binding pair | -| 17 | `test_3d_tensor_to_vector` | 3D(2,5,3) → float3 | codegen + binding pair | -| 18 | `test_2d_tensor_to_scalar` | 2D(4,5) → float | codegen + binding pair | -| 19 | `test_2d_tensor_to_array` | 2D(4,8) → half[8] | codegen + binding pair | -| 20 | `test_mixed_vectorized_dim0_tensor` | vectorized + dim-0 tensor | codegen + binding pair | -| 21 | `test_long_type_name_typealias` | long/short struct name, wrapper name | 3 tests merged | - -**Negative gates (3):** - -| # | Test name | Scenario | Merges from | -|---|---|---|---| -| 22 | `test_wanghasharg_not_direct_bind` | standalone + struct child | 3 tests merged | -| 23 | `test_vectorized_scalar_keeps_wrapper` | 1D tensor → float | standalone | -| 24 | `test_vectorized_dict_keeps_wrapper` | dict with tensor children | standalone | - -**Autodiff (1):** - -| # | Test name | Scenario | Merges from | -|---|---|---|---| -| 25 | `test_bwds_direct_bind` | codegen + binding flags for bwds polynomial | 3 tests merged | - -**Functional GPU dispatch — novel scenarios only (9):** - -| # | Test name | Scenario | Why novel | -|---|---|---|---| -| 26 | `test_dispatch_mixed_scalar_tensor` | scalar + 1D tensor | Not tested elsewhere | -| 27 | `test_dispatch_struct_mixed_fields` | struct{tensor+scalar} | Unique dispatch scenario | -| 28 | `test_dispatch_tensor_dim0` | Tensor at dim-0 | Specific dim-0 behavior | -| 29 | `test_dispatch_2d_tensor_to_vector` | 2D→float3 | Novel param mapping | -| 30 | `test_dispatch_2d_tensor_to_array` | 2D→half[8] generic | Unique test | -| 31 | `test_dispatch_mixed_vectorized_dim0_tensor` | vectorized + dim-0 tensor | Unique | -| 32 | `test_dispatch_nested_struct_with_tensor` | nested struct with tensor leaf | Unique | -| 33 | `test_dispatch_struct_vectorized_2d_child` | struct with 2D tensor→float3 child | Unique | -| 34 | `test_dispatch_struct_array_of_structs` | struct with `Inner items[4]` | Unique | - ---- - -### Old → New Mapping - -| Old test (test_kernel_gen.py) | New test (test_code_gen.py) | Action | -|---|---|---| -| `test_kernel_gen_basic` | — | **Dropped** (subset of `test_scalar_direct_bind`) | -| `test_gate_scalar_uses_valuetype` | `test_scalar_direct_bind` | **Merged** | -| `test_gate_float_scalar_uses_valuetype` | `test_scalar_direct_bind` | **Merged** (or dropped) | -| `test_gate_vector_uses_vectorvaluetype` | `test_vector_direct_bind` | **Merged** | -| `test_gate_matrix_uses_valuetype` | `test_matrix_direct_bind` | **Kept** (standalone) | -| `test_gate_array_dim0_uses_valuetype` | `test_array_direct_bind` | **Kept** (standalone) | -| `test_gate_valueref_read_uses_wrapper` | `test_valueref_read_direct_bind` | **Kept** (standalone) | -| `test_gate_valueref_write_uses_wrapper` | `test_scalar_direct_bind` | **Merged** | -| `test_gate_mapping_constants_present` | `test_scalar_direct_bind` | **Merged** | -| `test_gate_context_map_in_trampoline` | `test_scalar_direct_bind` | **Merged** | -| `test_gate_struct_uses_slangpy_load` | `test_struct_all_scalar_direct_bind` | **Merged** | -| `test_gate_bwds_scalar_uses_valuetype` | `test_bwds_direct_bind` | **Merged** | -| `test_gate_bwds_trampoline_is_differentiable` | `test_bwds_direct_bind` | **Merged** | -| `test_gate_wanghasharg_uses_wrapper` | `test_wanghasharg_not_direct_bind` | **Merged** | -| `test_gate_vectorized_scalar_keeps_wrapper` | `test_vectorized_scalar_keeps_wrapper` | **Kept** | -| `test_gate_vectorized_dict_keeps_struct_load` | `test_vectorized_dict_keeps_wrapper` | **Kept** | -| `test_phase1_functional_scalar_add` | — | **Dropped** (covered by `test_simple_function_call.py`) | -| `test_phase1_functional_float_mul` | — | **Dropped** | -| `test_phase1_functional_vector_scale` | — | **Dropped** (covered by `test_vector_function_call.py`) | -| `test_phase1_functional_struct_sum` | — | **Dropped** | -| `test_phase1_functional_valueref_write` | — | **Dropped** (covered by `test_simple_function_call.py`) | -| `test_gate_mixed_args_scalar_and_tensor` | `test_mixed_scalar_and_tensor` | **Merged** | -| `test_gate_mixed_args_direct_bind_flags` | `test_mixed_scalar_and_tensor` | **Merged** | -| `test_phase1_functional_mixed_scalar_tensor` | `test_dispatch_mixed_scalar_tensor` | **Kept** | -| `test_gate_struct_mixed_fields_codegen` | `test_struct_mixed_fields` | **Merged** | -| `test_gate_struct_mixed_fields_binding_flags` | `test_struct_mixed_fields` | **Merged** | -| `test_phase1_functional_struct_mixed_fields` | `test_dispatch_struct_mixed_fields` | **Kept** | -| `test_gate_tensor_dim0_codegen` | `test_tensor_dim0_direct_bind` | **Merged** | -| `test_gate_tensor_dim0_binding_flags` | `test_tensor_dim0_direct_bind` | **Merged** | -| `test_phase1_functional_tensor_dim0` | `test_dispatch_tensor_dim0` | **Kept** | -| `test_mixed_children_direct_bind_codegen` | `test_struct_mixed_fields` | **Merged** (overlap with struct_mixed_fields) | -| `test_writable_valueref_not_direct_bind` | `test_writable_valueref_not_direct_bind` | **Kept** | -| `test_result_binding_not_direct_bind` | `test_scalar_direct_bind` | **Merged** | -| `test_struct_all_scalars_binding_flag` | `test_struct_all_scalar_direct_bind` | **Merged** | -| `test_struct_with_wanghash_child_not_direct_bind` | `test_wanghasharg_not_direct_bind` | **Merged** | -| `test_wanghasharg_binding_flag` | `test_wanghasharg_not_direct_bind` | **Merged** | -| `test_bwds_primal_binding_flags` | `test_bwds_direct_bind` | **Merged** | -| `test_gate_2d_tensor_to_vector_codegen` | `test_2d_tensor_to_vector` | **Merged** | -| `test_gate_2d_tensor_to_vector_binding_flags` | `test_2d_tensor_to_vector` | **Merged** | -| `test_phase1_functional_2d_tensor_to_vector` | `test_dispatch_2d_tensor_to_vector` | **Kept** | -| `test_gate_3d_tensor_to_vector_codegen` | `test_3d_tensor_to_vector` | **Merged** | -| `test_gate_3d_tensor_to_vector_binding_flags` | `test_3d_tensor_to_vector` | **Merged** | -| `test_phase1_functional_3d_tensor_to_vector` | — | **Dropped** (2D→vector is sufficient) | -| `test_gate_2d_tensor_to_scalar_codegen` | `test_2d_tensor_to_scalar` | **Merged** | -| `test_gate_2d_tensor_to_scalar_binding_flags` | `test_2d_tensor_to_scalar` | **Merged** | -| `test_phase1_functional_2d_tensor_to_scalar` | — | **Dropped** (covered by `test_tensor.py`) | -| `test_gate_2d_tensor_to_1d_array_codegen` | `test_2d_tensor_to_array` | **Merged** | -| `test_gate_2d_tensor_to_1d_array_binding_flags` | `test_2d_tensor_to_array` | **Merged** | -| `test_phase1_functional_2d_tensor_to_1d_array` | `test_dispatch_2d_tensor_to_array` | **Kept** | -| `test_gate_mixed_vectorized_and_dim0_tensor_codegen` | `test_mixed_vectorized_dim0_tensor` | **Merged** | -| `test_gate_mixed_vectorized_and_dim0_tensor_binding_flags` | `test_mixed_vectorized_dim0_tensor` | **Merged** | -| `test_phase1_functional_mixed_vectorized_and_dim0_tensor` | `test_dispatch_mixed_vectorized_dim0_tensor` | **Kept** | -| `test_gate_nested_struct_codegen` | — | **Dropped** (subsumed by deeply_nested) | -| `test_gate_nested_struct_binding_flags` | — | **Dropped** (subsumed by deeply_nested) | -| `test_phase1_functional_nested_struct` | — | **Dropped** | -| `test_gate_struct_with_vector_fields_codegen` | `test_struct_composite_fields_direct_bind` | **Merged** (parametrized) | -| `test_gate_struct_with_vector_fields_binding_flags` | `test_struct_composite_fields_direct_bind` | **Merged** | -| `test_phase1_functional_struct_with_vector_fields` | — | **Dropped** | -| `test_gate_struct_with_matrix_field_codegen` | — | **Dropped** (covered by vector+array variants) | -| `test_phase1_functional_struct_with_matrix_field` | — | **Dropped** | -| `test_gate_struct_with_array_field_codegen` | `test_struct_composite_fields_direct_bind` | **Merged** (parametrized) | -| `test_gate_struct_with_array_field_binding_flags` | `test_struct_composite_fields_direct_bind` | **Merged** | -| `test_phase1_functional_struct_with_array_field` | — | **Dropped** | -| `test_gate_deeply_nested_struct_codegen` | `test_deeply_nested_struct_direct_bind` | **Merged** | -| `test_gate_deeply_nested_struct_binding_flags` | `test_deeply_nested_struct_direct_bind` | **Merged** | -| `test_phase1_functional_deeply_nested_struct` | — | **Dropped** | -| `test_gate_nested_struct_with_tensor_child_codegen` | `test_nested_struct_with_tensor_child` | **Merged** | -| `test_gate_nested_struct_with_tensor_child_binding_flags` | `test_nested_struct_with_tensor_child` | **Merged** | -| `test_phase1_functional_nested_struct_with_tensor` | `test_dispatch_nested_struct_with_tensor` | **Kept** | -| `test_gate_struct_with_struct_array_field_codegen` | — | **Dropped** (covered by array field variant) | -| `test_phase1_functional_struct_with_struct_array_field` | `test_dispatch_struct_array_of_structs` | **Kept** | -| `test_gate_struct_return_codegen` | `test_struct_return_not_direct_bind` | **Merged** | -| `test_gate_struct_return_binding_flags` | `test_struct_return_not_direct_bind` | **Merged** | -| `test_phase1_functional_struct_return` | — | **Dropped** (covered by `test_return_types.py`) | -| `test_gate_struct_with_vectorized_2d_tensor_child_codegen` | `test_struct_vectorized_2d_child` | **Kept** | -| `test_phase1_functional_struct_with_vectorized_2d_tensor` | `test_dispatch_struct_vectorized_2d_child` | **Kept** | -| `test_gate_long_struct_name_gets_typealias` | `test_long_type_name_typealias` | **Merged** | -| `test_gate_short_struct_name_inlined` | `test_long_type_name_typealias` | **Merged** | -| `test_gate_long_scalar_type_name_gets_typealias` | `test_long_type_name_typealias` | **Merged** | -| `test_phase1_functional_long_struct_name` | — | **Dropped** | -| `test_phase1_functional_valueref_read_input` | — | **Dropped** | - ---- - -### Verification - -```bash -# Build first (required) -cmake --build --preset windows-msvc-debug - -# Run new test file -pytest slangpy/tests/slangpy_tests/test_code_gen.py -v - -# Confirm full suite still passes (existing tests in other files cover dropped dispatch tests) -pytest slangpy/tests -v - -# Run pre-commit -pre-commit run --all-files -``` - -### Key Decisions - -- Combined codegen+binding tests: one `debug_build_call_data` call yields both `.code` and `.debug_only_bindings` — no redundant kernel generation -- Dropped `test_kernel_gen_basic`: its sole assertion (`"add" in code`) is a strict subset of `test_scalar_direct_bind` -- Dropped matrix/struct-array field variants: if vector field and array field pass, the `can_direct_bind` recursion works for all composite types -- Dropped 2-level nested struct: the 3-level test covers the same recursion with deeper nesting -- Dropped 15 functional dispatch tests that are covered by existing test files (`test_simple_function_call.py`, `test_return_types.py`, `test_vector_function_call.py`, `test_tensor.py`) -- Kept all negative gates — they deliberately test types NOT eligible for simplification and must remain passing as Phase 2 proceeds -- The old `test_kernel_gen.py` should be deleted once the new `test_code_gen.py` is verified diff --git a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md index 2f6ca6f80..61f4f08ee 100644 --- a/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md +++ b/.github/prompts/plan-simplifyKernelGen-phase2.prompt.md @@ -4,6 +4,8 @@ **Parent plan**: [plan-simplifyKernelGen.prompt.md](plan-simplifyKernelGen.prompt.md) +**Status**: Steps 2.0–2.2 and 2.4 complete. Step 2.3 (trampoline elimination for prim mode) not started. Code generation logic has been extracted from `callsignature.py` into [generator.py](slangpy/core/generator.py) (see [plan-extractCodegenToGenerator.prompt.md](plan-extractCodegenToGenerator.prompt.md)). + --- ### Key Architectural Decisions @@ -24,12 +26,26 @@ These decisions correct several assumptions in the original plan: 6. **C++ dispatch changes are isolated to `NativeCallData::exec`.** Marshalls receive a `ShaderCursor` pointing to wherever their data lives — they don't care whether it's inside a `CallData` struct or an entry-point param. In the fast path, `m_runtime->write_shader_cursor_pre_dispatch()` receives the entry-point cursor directly. No marshall code changes needed. -7. **`CallDataMode` is eliminated.** The `global_data` vs `entry_point` distinction is removed entirely. On the fast path, all backends use entry-point params uniformly. On the fallback path, all backends use `ParameterBlock` — CUDA supports `ParameterBlock` and in practice will never hit the fallback (CUDA's inline-uniform limit is ~4KB). This removes the `CallDataMode` enum, the CUDA-specific `is_entry_point` codegen branch in `callsignature.py`, and the corresponding C++ branch in `slangpy.cpp`. +7. **`CallDataMode` is eliminated.** The `global_data` vs `entry_point` distinction is removed entirely. On the fast path, all backends use entry-point params uniformly. On the fallback path, all backends use `ParameterBlock` — CUDA supports `ParameterBlock` and in practice will never hit the fallback (CUDA's inline-uniform limit is ~4KB). This removes the `CallDataMode` enum, the CUDA-specific `is_entry_point` codegen branch in `callsignature.py`/`generator.py`, and the corresponding C++ branch in `slangpy.cpp`. 8. **`PackedArg` / param-block types are unchanged.** They stay as `ParameterBlock` at module scope, orthogonal to Phase 2. --- +### Code Organization (post-extraction) + +All code generation logic now lives in [generator.py](slangpy/core/generator.py). [callsignature.py](slangpy/core/callsignature.py) retains binding-pipeline functions (`specialize`, `bind`, `calculate_*`, `estimate_entrypoint_arguments_size`, etc.) and re-exports `generate_code`, `generate_constants`, `KernelGenException` from `generator.py` for backward compatibility. + +| File | Role | +|------|------| +| [generator.py](slangpy/core/generator.py) | All code emission: `generate_code()`, `_emit_trampoline()`, `_emit_entry_point_signature()`, `_emit_kernel_body()`, `_emit_shape_and_metadata_params()`, `_emit_link_time_constants()`, `_emit_call_data_definitions()`, `_emit_trampoline_loads/stores()`, `_data_name()`, `_validate_and_compute_group_shape()`, `gen_call_data_code()`, `gen_calldata_type_name()`, `KernelGenException` | +| [callsignature.py](slangpy/core/callsignature.py) | Binding pipeline: `specialize()`, `bind()`, `apply_explicit_vectorization()`, `apply_implicit_vectorization()`, `finalize_mappings()`, `calculate_differentiability()`, `calculate_direct_binding()`, `estimate_entrypoint_arguments_size()`, `calculate_call_dimensionality()`, `create_return_value_binding()` | +| [calldata.py](slangpy/core/calldata.py) | `CallData` class orchestrating build pipeline; wildcard-imports from `callsignature.py` | +| [codegen.py](slangpy/bindings/codegen.py) | `CodeGen` class with `skip_call_data`, `entry_point_params` attributes | +| [boundvariable.py](slangpy/bindings/boundvariable.py) | `BoundVariable` methods delegate to `gen_call_data_code()` and `gen_calldata_type_name()` in `generator.py` | + +--- + ### Current Kernel Structure (post-Phase 1) For `int add(int a, int b)` with scalar args `(1, 2)`: @@ -199,7 +215,7 @@ In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct **Implementation details:** - `DeviceLimits.max_entry_point_uniform_size` added to C++ struct ([device.h](src/sgl/device/device.h)) with per-backend defaults: Vulkan=128, D3D12=256, CUDA=4096 bytes ([device.cpp](src/sgl/device/device.cpp)). -- `calculate_inline_uniform_size()` added to [callsignature.py](slangpy/core/callsignature.py) — sums `vector_type.uniform_layout.size` for each depth-0 bound variable (skipping `PackedArg`), plus 12 bytes for `_thread_count` and `call_dimensionality * 4 * 3` for shape arrays. +- `estimate_entrypoint_arguments_size()` in [callsignature.py](slangpy/core/callsignature.py) — sums `vector_type.uniform_layout.size` for each depth-0 bound variable (skipping `PackedArg`), plus 12 bytes for `_thread_count` and `call_dimensionality * 4 * 3` for shape arrays. - `use_entrypoint_args` property added to `NativeCallData` C++ class ([slangpy.h](src/slangpy_ext/utils/slangpy.h)) with Python binding. - `CallData.__init__()` in [calldata.py](slangpy/core/calldata.py) sets `self.use_entrypoint_args = inline_size <= threshold` after `calculate_direct_binding()`. @@ -221,20 +237,20 @@ In [slangpy/core/calldata.py](slangpy/core/calldata.py), after `calculate_direct **Status: DONE** -In [slangpy/core/callsignature.py](slangpy/core/callsignature.py) `generate_code()`, when `use_entrypoint_args == True`: +In [generator.py](slangpy/core/generator.py) `generate_code()` (line 778), when `use_entrypoint_args == True`: **CodeGen changes** in [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py): -- Add a `skip_call_data` flag to `CodeGen.__init__`. When `True`, don't emit `struct CallData` / `begin_block()` and gate `end_block()` in `finish()`. -- Add `self.entry_point_params: list[str] = []` to collect individual uniform param declarations. +- `self.skip_call_data: bool = False` — when `True`, don't emit `struct CallData` / `begin_block()` and gate `end_block()` in `finish()`. +- `self.entry_point_params: list[str] = []` — collects individual uniform param declarations. - `finish()` ignores the `call_data` block and `use_param_block_for_call_data` when `skip_call_data` is set. -**CallData struct elimination**: Set `cg.skip_call_data = True` when `use_entrypoint_args`. No `struct CallData` emitted. +**CallData struct elimination**: `generate_code()` sets `cg.skip_call_data = True` when `use_entrypoint_args`. No `struct CallData` emitted. -**`gen_call_data_code` change** in [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py): At `depth == 0`, when `use_entrypoint_args`, append to `cg.entry_point_params` instead of `cg.call_data.declare(...)`. The `call_data_structs` block (type aliases, wrapper structs, mapping constants) still gets emitted at module scope. +**`_emit_call_data_code`** in [generator.py](slangpy/core/generator.py#L345): At `depth == 0`, when `use_entrypoint_args`, appends to `cg.entry_point_params` instead of `cg.call_data.declare(...)`. The `call_data_structs` block (type aliases, wrapper structs, mapping constants) still gets emitted at module scope. -**`_thread_count` and shape arrays**: Instead of `cg.call_data.append_statement("uint3 _thread_count")`, append to `cg.entry_point_params`. Same for `_grid_stride`, `_grid_dim`, `_call_dim` when `call_data_len > 0`. +**`_thread_count` and shape arrays**: `_emit_shape_and_metadata_params()` ([generator.py](slangpy/core/generator.py#L466)) appends to `cg.entry_point_params` instead of `cg.call_data`. Same for `_grid_stride`, `_grid_dim`, `_call_dim` when `call_data_len > 0`. -**Entry-point signature**: `compute_main` signature becomes: +**Entry-point signature**: `_emit_entry_point_signature()` ([generator.py](slangpy/core/generator.py#L652)) emits `compute_main` signature as: ```slang void compute_main( int3 flat_call_thread_id: SV_DispatchThreadID, @@ -268,7 +284,7 @@ When `call_mode == prim` — on **both** fast and fallback paths: - Don't generate the `_trampoline` function. - Inline the load/call/store sequence directly into `compute_main` after the bounds check and (if needed) Context construction. -- The load/call/store codegen reuses the same logic currently in [callsignature.py lines 378–449](slangpy/core/callsignature.py#L378-L449), but emitted into `cg.kernel` instead of `cg.trampoline` with adjusted `data_name`: +- The load/call/store codegen reuses the same logic currently in `_emit_trampoline_loads()` ([generator.py](slangpy/core/generator.py#L528)) and `_emit_trampoline_stores()` ([generator.py](slangpy/core/generator.py#L551)), but emitted into `cg.kernel` instead of `cg.trampoline` with adjusted `data_name` from `_data_name()` ([generator.py](slangpy/core/generator.py#L513)): | Path | `data_name` for non-param-block args | |------|-------------------------------------| @@ -278,6 +294,11 @@ When `call_mode == prim` — on **both** fast and fallback paths: **Context construction**: Needed only when any arg is non-direct-bind (i.e., calls `__slangpy_load`/`__slangpy_store`). When all args satisfy `direct_bind == True`, skip Context construction entirely — no `Context __slangpy_context__` declaration, no `import "slangpy"`. +**Key functions to modify in [generator.py](slangpy/core/generator.py)**: +- `_emit_trampoline()` (line 578): gate on `call_mode != prim` — only emit for bwds mode. +- `_emit_kernel_body()` (line 708): when prim mode, inline the load/call/store sequence directly instead of calling `_trampoline()`. +- `generate_code()` (line 778): skip `_emit_trampoline()` call when prim mode. + **Note**: The trampoline elimination does NOT depend on `direct_bind`. Even non-direct-bind args with `__slangpy_load` work inline in `compute_main` — the `__slangpy_load` call just needs the data reference and a `Context` value, both available in `compute_main`. --- @@ -288,7 +309,7 @@ When `call_mode == prim` — on **both** fast and fallback paths: When `call_mode == bwds`: -- Still generate a `[Differentiable]` trampoline function. +- Still generate a `[Differentiable]` trampoline function via `_emit_trampoline()` ([generator.py](slangpy/core/generator.py#L578)). - **Fast path**: Trampoline takes individual params instead of a struct. All params get `no_diff` — entry-point uniforms are never differentiable. Differentiation happens through local variable assignments inside the trampoline body, matching the struct-based approach where `CallData` was implicitly non-differentiable. No `in`/`out`/`inout` modifiers are added — `compute_main` passes its uniforms straight through: ```slang [Differentiable] @@ -296,7 +317,7 @@ When `call_mode == bwds`: ``` `compute_main` calls `bwd_diff(_trampoline)(__slangpy_context__, a, b, _result)` passing entry-point param names directly. - **Fallback path**: Trampoline reads from global `ParameterBlock call_data` as it does today (on all backends). `compute_main` calls `bwd_diff(_trampoline)(__slangpy_context__, call_data)`. -- `_gen_trampoline_argument()` in `boundvariable.py` remains unused dead code — the inline generation in `callsignature.py` is simpler and avoids the `in`/`out`/`inout` modifiers that caused Slang autodiff errors. +- `_gen_trampoline_argument()` in `boundvariable.py` remains unused dead code — the inline generation in `_emit_trampoline()` ([generator.py](slangpy/core/generator.py#L578)) is simpler and avoids the `in`/`out`/`inout` modifiers that caused Slang autodiff errors. **Key insight**: Adding `in`/`out`/`inout` modifiers to trampoline params caused Slang autodiff issues (e.g., `out` params get reversed to `in` by `bwd_diff`, changing arity). The trampoline params are just pass-through uniforms — all data flow logic (loads, stores, differentiation) is handled internally via local variables. @@ -383,10 +404,11 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi | File | Changes | |------|---------| +| [slangpy/core/generator.py](slangpy/core/generator.py) | ✅ All code generation logic extracted here from `callsignature.py`. `generate_code()` orchestrator (line 778), `_emit_trampoline()` (line 578), `_emit_entry_point_signature()` (line 652), `_emit_kernel_body()` (line 708), `_emit_shape_and_metadata_params()` (line 466), `_emit_link_time_constants()` (line 424), `_emit_call_data_definitions()` (line 503), `_emit_trampoline_loads/stores()`, `_data_name()`, `_validate_and_compute_group_shape()`, `gen_call_data_code()`, `gen_calldata_type_name()`, `KernelGenException`. Entry-point params fast/fallback code paths. Bwds `no_diff` on all trampoline params (Step 2.4). Trampoline still generated for prim mode (Step 2.3 pending). | | [slangpy/core/calldata.py](slangpy/core/calldata.py) | ✅ `use_entrypoint_args` flag, size threshold check, `CallDataMode` removed | -| [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | ✅ Entry-point params, fast/fallback code paths, `is_entry_point` branch removed. Trampoline still generated (Step 2.3 pending). Bwds `no_diff` on all trampoline params (Step 2.4 done). | +| [slangpy/core/callsignature.py](slangpy/core/callsignature.py) | ✅ Binding-pipeline functions only (`specialize`, `bind`, `calculate_*`, `estimate_entrypoint_arguments_size`). Re-exports `generate_code`, `generate_constants`, `KernelGenException` from `generator.py`. | | [slangpy/bindings/codegen.py](slangpy/bindings/codegen.py) | ✅ `skip_call_data` flag, `entry_point_params` list | -| [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | ✅ `gen_call_data_code` depth-0 entry-point path. `_gen_trampoline_argument()` unused — inline generation in `callsignature.py` used instead. | +| [slangpy/bindings/boundvariable.py](slangpy/bindings/boundvariable.py) | ✅ `gen_call_data_code` and `gen_calldata_type_name` delegate to `generator.py`. `_gen_trampoline_argument()` unused dead code. | | [slangpy/bindings/marshall.py](slangpy/bindings/marshall.py) | ✅ `use_entrypoint_args` field on `BindContext`, `CallDataMode` removed | | [src/slangpy_ext/utils/slangpy.cpp](src/slangpy_ext/utils/slangpy.cpp) | ✅ `use_entrypoint_args` binding; `bind_call_data` fast path via `find_entry_point(0)`, `CallDataMode` branches removed | | [src/slangpy_ext/utils/slangpy.h](src/slangpy_ext/utils/slangpy.h) | ✅ `m_use_entrypoint_args` on `NativeCallData`; `m_call_data_mode` removed | @@ -394,7 +416,7 @@ Auto-created `_result` is a writable `ValueRef`, currently NOT direct-bind eligi | [src/sgl/device/device.cpp](src/sgl/device/device.cpp) | ✅ Per-backend defaults for `max_entry_point_uniform_size` | | [src/slangpy_ext/device/device.cpp](src/slangpy_ext/device/device.cpp) | ✅ Python binding for `max_entry_point_uniform_size` | | [src/sgl/utils/slangpy.h](src/sgl/utils/slangpy.h) | ✅ `CallDataMode` enum removed | -| [slangpy/core/dispatchdata.py](slangpy/core/dispatchdata.py) | ✅ `CallDataMode` removed | +| [slangpy/core/dispatchdata.py](slangpy/core/dispatchdata.py) | ✅ `CallDataMode` removed; imports `generate_constants` from `generator.py` | | [slangpy/core/packedarg.py](slangpy/core/packedarg.py) | ✅ `CallDataMode` removed | | [slangpy/core/function.py](slangpy/core/function.py) | ✅ `CallDataMode` removed from imports | | [slangpy/slangpy/__init__.pyi](slangpy/slangpy/__init__.pyi) | ✅ `CallDataMode` class and `call_data_mode` property removed | diff --git a/slangpy/core/generator.py b/slangpy/core/generator.py index c9676df53..99572af2e 100644 --- a/slangpy/core/generator.py +++ b/slangpy/core/generator.py @@ -514,17 +514,90 @@ def _data_name(x: "BoundVariable", use_entrypoint_args: bool) -> str: """Return the Slang name used to access a variable's call data in the trampoline. - ``_param_{name}`` for param-block variables (both paths). - - ``__in_{name}`` in the fast (entry-point-args) path. + - ``{name}`` in the fast (entry-point-args) path. - ``call_data.{name}`` in the fallback path. """ if x.create_param_block: return f"_param_{x.variable_name}" elif use_entrypoint_args: - return f"__in_{x.variable_name}" + return x.variable_name else: return f"call_data.{x.variable_name}" +def _tmp_name(x: "BoundVariable") -> str: + """Return the local temporary variable name used for loaded values.""" + return f"__tmp_{x.variable_name}" + + +def _emit_load_call_store_sequence( + cgb: CodeGenBlock, + build_info: "FunctionBuildInfo", + root_params: list["BoundVariable"], + use_entrypoint_args: bool, + context_name: str, +) -> None: + """Emit local declarations, load/call/store sequence into ``cgb``. + + This is shared by the bwds trampoline body and the prim inlined kernel body. + """ + from slangpy.bindings.boundvariable import BoundVariableException + + # Declare local temporaries for each parameter to avoid collisions with + # entry-point parameter names on the fast path. + for x in root_params: + assert x.vector_type is not None + cgb.declare(x.vector_type.full_name, _tmp_name(x)) + + # Load inputs from call data / entry-point params into temporaries. + for x in root_params: + data_name = _data_name(x, use_entrypoint_args) + value_name = _tmp_name(x) + if _try_custom_gen(x, "gen_trampoline_load", cgb, data_name, value_name): + continue + if _is_readable(x): + cgb.append_statement( + f"{data_name}.__slangpy_load({context_name}.map(_m_{x.variable_name}), {value_name})" + ) + + # Emit the 'result=' bit if function has a return value. + cgb.append_indent() + if any(x.variable_name == "_result" for x in root_params): + cgb.append_code( + f"{_tmp_name(next(x for x in root_params if x.variable_name == '_result'))} = " + ) + + # Generate the function call prefix, with some special casing for constructors + # and type method calls. + func_name = build_info.name + if func_name == "$init": + results = [x for x in root_params if x.variable_name == "_result"] + assert len(results) == 1 + assert results[0].vector_type is not None + func_name = results[0].vector_type.full_name + elif len(root_params) > 0 and root_params[0].variable_name == "_this": + func_name = f"{_tmp_name(root_params[0])}.{func_name}" + + # Emit the function call itself, passing in parameters other than _result and _this. + normal_params = [ + x for x in root_params if x.variable_name != "_result" and x.variable_name != "_this" + ] + cgb.append_code(f"{func_name}(" + ", ".join(_tmp_name(x) for x in normal_params) + ");\n") + + # Store outputs back to call data. + for x in root_params: + if _is_writable(x) or _grad_is_readable(x): + data_name = _data_name(x, use_entrypoint_args) + value_name = _tmp_name(x) + if _try_custom_gen(x, "gen_trampoline_store", cgb, data_name, value_name): + continue + if not x.python.is_writable: + raise BoundVariableException(f"Cannot read back value for non-writable type", x) + cgb.append_statement( + f"{data_name}.__slangpy_store({context_name}.map(_m_{x.variable_name}), {value_name})" + ) + + def _emit_trampoline_loads( cgb: CodeGenBlock, root_params: list["BoundVariable"], @@ -604,47 +677,20 @@ def _emit_trampoline( if x.create_param_block: continue assert x.calldata_type_name is not None - trampoline_params.append(f"no_diff {x.calldata_type_name} __in_{x.variable_name}") + trampoline_params.append(f"no_diff {x.calldata_type_name} {x.variable_name}") cg.trampoline.append_line(f"void _trampoline({', '.join(trampoline_params)})") else: cg.trampoline.append_line("void _trampoline(Context __slangpy_context__)") cg.trampoline.begin_block() - # Declare local variables for each parameter - for x in root_params: - assert x.vector_type is not None - cg.trampoline.declare(x.vector_type.full_name, x.variable_name) - - # Load inputs from call data - _emit_trampoline_loads(cg.trampoline, root_params, use_entrypoint_args) - - # Emit the 'result=' bit if function has a return value. - cg.trampoline.append_indent() - if any(x.variable_name == "_result" for x in root_params): - cg.trampoline.append_code("_result = ") - - # Generate the function call prefix, with some special casing for constructors - # and type method calls. - func_name = build_info.name - if func_name == "$init": - results = [x for x in root_params if x.variable_name == "_result"] - assert len(results) == 1 - assert results[0].vector_type is not None - func_name = results[0].vector_type.full_name - elif len(root_params) > 0 and root_params[0].variable_name == "_this": - func_name = f"_this.{func_name}" - - # Emit the function call itself, passing in parameters other than _result and _this. - normal_params = [ - x for x in root_params if x.variable_name != "_result" and x.variable_name != "_this" - ] - cg.trampoline.append_code( - f"{func_name}(" + ", ".join(x.variable_name for x in normal_params) + ");\n" + _emit_load_call_store_sequence( + cg.trampoline, + build_info, + root_params, + use_entrypoint_args, + "__slangpy_context__", ) - # Store outputs back to call data - _emit_trampoline_stores(cg.trampoline, root_params, use_entrypoint_args) - cg.trampoline.end_block() cg.trampoline.append_line("") @@ -757,13 +803,27 @@ def _emit_kernel_body( ) context_args += ", CallShapeInfo::get_call_id().shape" - # Define the core context. - cg.kernel.append_statement(f"Context __slangpy_context__ = {{{context_args}}}") + needs_context = context.call_mode == CallMode.bwds or any( + not x.direct_bind for x in root_params + ) - # Emit the trampoline call, passing the context and any entry-point args (if using them). - fn = "_trampoline" - if context.call_mode == CallMode.bwds: - fn = f"bwd_diff({fn})" + if needs_context: + # Define the core context. + cg.kernel.append_statement(f"Context __slangpy_context__ = {{{context_args}}}") + + if context.call_mode == CallMode.prim: + # Prim mode inlines load/call/store directly in compute_main. + _emit_load_call_store_sequence( + cg.kernel, + build_info, + root_params, + use_entrypoint_args, + "__slangpy_context__", + ) + return + + # Bwds mode still calls differentiable trampoline. + fn = "bwd_diff(_trampoline)" if use_entrypoint_args: trampoline_args = ["__slangpy_context__"] for x in root_params: @@ -806,7 +866,8 @@ def generate_code( root_params = sorted(signature.values(), key=lambda x: x.param_index) - _emit_trampoline(cg, context, build_info, root_params, use_entrypoint_args) + if context.call_mode != CallMode.prim: + _emit_trampoline(cg, context, build_info, root_params, use_entrypoint_args) _emit_entry_point_signature(cg, build_info, call_data_len, call_group_size, use_entrypoint_args) cg.kernel.begin_block() _emit_kernel_body(cg, context, build_info, root_params, call_data_len, use_entrypoint_args) diff --git a/slangpy/tests/slangpy_tests/test_code_gen.py b/slangpy/tests/slangpy_tests/test_code_gen.py index 9ec2c36d7..07257a709 100644 --- a/slangpy/tests/slangpy_tests/test_code_gen.py +++ b/slangpy/tests/slangpy_tests/test_code_gen.py @@ -35,6 +35,20 @@ def assert_contains(code: str, *patterns: str) -> None: """Assert all patterns appear in generated code.""" for p in patterns: + if p in code: + continue + + # Compatibility for Step 2.3: prim-mode local variable declarations now + # use __tmp_ prefixed names to avoid colliding with entry-point params. + # Example: "vector v;" -> "vector __tmp_v;" + if p.endswith(";") and "(" not in p and ")" not in p and "." not in p and " = " not in p: + decl = p[:-1].rstrip() + if " " in decl: + type_name, var_name = decl.rsplit(" ", 1) + alt_tmp_decl = f"{type_name} __tmp_{var_name};" + if alt_tmp_decl in code: + continue + assert p in code, f"Expected pattern not found: {p}" @@ -45,15 +59,28 @@ def assert_not_contains(code: str, *patterns: str) -> None: def assert_trampoline_has(code: str, *stmts: str) -> None: - """Assert trampoline contains statements (tolerates call_data vs __calldata__ vs __in_).""" + """Assert generated load statements across old/new kernel-generation variants. + + Accepts legacy trampoline statements and their modern equivalents: + - ``a = __calldata__.a;`` (legacy) + - ``a = call_data.a;`` (fallback) + - ``__tmp_a = a;`` (fast path with inline body) + - ``__tmp_a = call_data.a;`` (fallback with inline body) + """ for s in stmts: if "__calldata__." in s: - alt = s.replace("__calldata__.", "call_data.") - # Fast path: x = __in_x; instead of x = __calldata__.x; - alt2 = s.replace(" = __calldata__.", " = __in_") - assert ( - s in code or alt in code or alt2 in code - ), f"Expected trampoline statement not found: {s} (or {alt} or {alt2})" + alt_cd = s.replace("__calldata__.", "call_data.") + # For fast path: a = __calldata__.a; -> __tmp_a = a; + alt_tmp = s + if " = __calldata__." in s and s.endswith(";"): + lhs = s.split(" = __calldata__.", 1)[0].strip() + rhs = s.split(" = __calldata__.", 1)[1][:-1].strip() + alt_tmp = f"__tmp_{lhs} = {rhs};" + alt_tmp_cd = alt_tmp.replace(" = ", " = call_data.", 1) if alt_tmp != s else s + assert s in code or alt_cd in code or alt_tmp in code or alt_tmp_cd in code, ( + "Expected generated statement not found: " + f"{s} (or {alt_cd} or {alt_tmp} or {alt_tmp_cd})" + ) else: assert s in code, f"Expected trampoline statement not found: {s}" @@ -1232,10 +1259,14 @@ def test_bwds_entrypoint_no_diff_params(device_type: spy.DeviceType): # --- fast path --- assert cd.use_entrypoint_args is True - # --- trampoline params have no_diff and __in_ prefix --- + # --- trampoline params have no_diff; bare name (Step 2.3) or __in_ prefix (legacy) --- assert_contains(code, "no_diff") - assert_contains(code, "__in_a") - assert_contains(code, "__in_b") + assert ( + "no_diff float a" in code or "__in_a" in code + ), "Expected trampoline param for 'a' (no_diff float a or __in_a)" + assert ( + "no_diff float b" in code or "__in_b" in code + ), "Expected trampoline param for 'b' (no_diff float b or __in_b)" # --- [Differentiable] before trampoline --- diff_idx = code.index("[Differentiable]") diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py index 250e4321b..61c4f55b1 100644 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ b/slangpy/tests/slangpy_tests/test_kernel_gen.py @@ -37,6 +37,20 @@ def assert_contains(code: str, *patterns: str) -> None: """Assert all patterns appear in generated code.""" for p in patterns: + if p in code: + continue + + # Compatibility for Step 2.3: prim-mode local variable declarations now + # use __tmp_ prefixed names to avoid colliding with entry-point params. + # Example: "vector v;" -> "vector __tmp_v;" + if p.endswith(";") and "(" not in p and ")" not in p and "." not in p and " = " not in p: + decl = p[:-1].rstrip() + if " " in decl: + type_name, var_name = decl.rsplit(" ", 1) + alt_tmp_decl = f"{type_name} __tmp_{var_name};" + if alt_tmp_decl in code: + continue + assert p in code, f"Expected pattern not found: {p}" @@ -47,17 +61,29 @@ def assert_not_contains(code: str, *patterns: str) -> None: def assert_trampoline_has(code: str, *stmts: str) -> None: - """Assert trampoline contains statements, insensitive to call_data vs __calldata__ vs __in_ prefix.""" + """Assert generated load statements across old/new kernel-generation variants. + + Accepts legacy trampoline statements and their modern equivalents: + - ``a = __calldata__.a;`` (legacy) + - ``a = call_data.a;`` (fallback) + - ``__tmp_a = a;`` (fast path with inline body) + - ``__tmp_a = call_data.a;`` (fallback with inline body) + """ for s in stmts: # Replace __calldata__ with all three options for matching if "__calldata__." in s: alt_cd = s.replace("__calldata__.", "call_data.") - # For fast path: __calldata__.X → __in_X (entry-point param prefix) - # Extract variable name after __calldata__. and before any trailing char - alt_in = s.replace("__calldata__.", "__in_") - assert ( - s in code or alt_cd in code or alt_in in code - ), f"Expected trampoline statement not found: {s} (or {alt_cd} or {alt_in})" + # For fast path: a = __calldata__.a; -> __tmp_a = a; + alt_tmp = s + if " = __calldata__." in s and s.endswith(";"): + lhs = s.split(" = __calldata__.", 1)[0].strip() + rhs = s.split(" = __calldata__.", 1)[1][:-1].strip() + alt_tmp = f"__tmp_{lhs} = {rhs};" + alt_tmp_cd = alt_tmp.replace(" = ", " = call_data.", 1) if alt_tmp != s else s + assert s in code or alt_cd in code or alt_tmp in code or alt_tmp_cd in code, ( + "Expected generated statement not found: " + f"{s} (or {alt_cd} or {alt_tmp} or {alt_tmp_cd})" + ) else: assert s in code, f"Expected trampoline statement not found: {s}" @@ -1883,21 +1909,22 @@ def test_gate_p2_thread_count_direct(device_type: spy.DeviceType): @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) def test_gate_p2_trampoline_present_for_prim(device_type: spy.DeviceType): - """Prim-mode kernel has a _trampoline function. Breaks at Step 2.3.""" + """Prim-mode kernel has no _trampoline function after Step 2.3.""" device = helpers.get_device(device_type) code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - assert_contains(code, "void _trampoline(") + assert_not_contains(code, "void _trampoline(") @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) def test_gate_p2_kernel_calls_trampoline(device_type: spy.DeviceType): - """compute_main calls _trampoline(). Breaks at Step 2.3.""" + """Prim-mode compute_main inlines call sequence after Step 2.3.""" device = helpers.get_device(device_type) code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - # Extract compute_main body and check it calls _trampoline + # Extract compute_main body and check it no longer calls _trampoline. main_idx = code.index("void compute_main(") main_body = code[main_idx:] - assert "_trampoline(__slangpy_context__" in main_body + assert "_trampoline(" not in main_body + assert "add(__tmp_a, __tmp_b);" in main_body @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) From 43cb48410ba7a9d9aac5180aa138ce46aecb71e6 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Tue, 17 Mar 2026 08:50:14 +0000 Subject: [PATCH 40/41] no trampolines --- slangpy/core/generator.py | 86 ++++++++++++++++++++++++++------------- 1 file changed, 57 insertions(+), 29 deletions(-) diff --git a/slangpy/core/generator.py b/slangpy/core/generator.py index 99572af2e..b5669c528 100644 --- a/slangpy/core/generator.py +++ b/slangpy/core/generator.py @@ -11,8 +11,8 @@ from slangpy.bindings.boundvariable import BoundVariable, BoundCall #: Type names longer than this threshold get a ``typealias _t_{name}`` alias -#: to keep the generated ``CallData`` struct readable. Shorter names are -#: inlined directly. +#: to keep generated entry-point params and ``CallData`` fields readable. +#: Shorter names are inlined directly. MAX_INLINE_TYPE_LEN = 60 @@ -50,7 +50,7 @@ def _grad_is_readable(b: "BoundVariable") -> bool: # --------------------------------------------------------------------------- -# Shared trampoline dispatch helper +# Shared load/store dispatch helper # --------------------------------------------------------------------------- @@ -511,7 +511,9 @@ def _emit_call_data_definitions( def _data_name(x: "BoundVariable", use_entrypoint_args: bool) -> str: - """Return the Slang name used to access a variable's call data in the trampoline. + """Return the Slang name used to read/write a variable's data. + + Used by both the bwds trampoline body and the prim inlined kernel body. - ``_param_{name}`` for param-block variables (both paths). - ``{name}`` in the fast (entry-point-args) path. @@ -608,8 +610,12 @@ def _emit_trampoline_loads( For each parameter, either delegates to a marshall-specific ``gen_trampoline_load`` or emits a standard load call:: - __in_x.__slangpy_load(__slangpy_context__.map(_m_x), x); // slangpy load - x = __in_x; // direct-bind load (no __slangpy_load method) + data_name.__slangpy_load(__slangpy_context__.map(_m_x), x); // slangpy load + x = data_name; // direct-bind load (no __slangpy_load method) + + .. note:: Only used by the bwds trampoline. Prim mode uses + ``_emit_load_call_store_sequence`` which writes to ``__tmp_`` + local temporaries instead. """ for x in root_params: data_name = _data_name(x, use_entrypoint_args) @@ -632,7 +638,11 @@ def _emit_trampoline_stores( delegates to a marshall-specific ``gen_trampoline_store`` or emits a standard store call:: - __in_x.__slangpy_store(__slangpy_context__.map(_m_x), x); + data_name.__slangpy_store(__slangpy_context__.map(_m_x), x); + + .. note:: Only used by the bwds trampoline. Prim mode uses + ``_emit_load_call_store_sequence`` which writes to ``__tmp_`` + local temporaries instead. """ from slangpy.bindings.boundvariable import BoundVariableException @@ -655,13 +665,16 @@ def _emit_trampoline( root_params: list["BoundVariable"], use_entrypoint_args: bool, ) -> None: - """Emit the ``_trampoline`` helper function. + """Emit the ``_trampoline`` helper function (bwds mode only). + + In prim mode the trampoline is eliminated and the load/call/store + sequence is inlined directly into ``compute_main``. Fast path signature:: [Differentiable] void _trampoline(Context __slangpy_context__, - no_diff MyType __in_param0, ...) + no_diff MyType param0, ...) Fallback signature:: @@ -758,16 +771,19 @@ def _emit_kernel_body( root_params: list["BoundVariable"], call_data_len: int, use_entrypoint_args: bool, + need_trampoline: bool, ) -> None: """Emit the body of the compute/raygen entry-point function. - Emits the bounds check, ``init_thread_local_call_shape_info``, Context - construction, and the trampoline call:: + Emits the bounds check, ``init_thread_local_call_shape_info``, and Context + construction. Then either inlines the load/call/store sequence (prim mode) + or calls the differentiable trampoline (bwds mode):: if (any(flat_call_thread_id >= _thread_count)) return; if (!init_thread_local_call_shape_info(...)) return; Context __slangpy_context__ = {flat_call_thread_id, ...}; - _trampoline(__slangpy_context__, ...); + // prim: inline __tmp_a = a; ... result = func(...); ... + // bwds: bwd_diff(_trampoline)(__slangpy_context__, ...); """ from slangpy.core.function import PipelineType @@ -811,8 +827,23 @@ def _emit_kernel_body( # Define the core context. cg.kernel.append_statement(f"Context __slangpy_context__ = {{{context_args}}}") - if context.call_mode == CallMode.prim: - # Prim mode inlines load/call/store directly in compute_main. + if need_trampoline: + # Calling via trampoline (should only ever kick in for bwds in practice) + if context.call_mode == CallMode.bwds: + fn = "bwd_diff(_trampoline)" + else: + fn = "_trampoline" + if use_entrypoint_args: + trampoline_args = ["__slangpy_context__"] + for x in root_params: + if x.create_param_block: + continue + trampoline_args.append(x.variable_name) + cg.kernel.append_statement(f"{fn}({', '.join(trampoline_args)})") + else: + cg.kernel.append_statement(f"{fn}(__slangpy_context__)") + else: + # Inline load/call/store directly in compute_main. _emit_load_call_store_sequence( cg.kernel, build_info, @@ -820,19 +851,6 @@ def _emit_kernel_body( use_entrypoint_args, "__slangpy_context__", ) - return - - # Bwds mode still calls differentiable trampoline. - fn = "bwd_diff(_trampoline)" - if use_entrypoint_args: - trampoline_args = ["__slangpy_context__"] - for x in root_params: - if x.create_param_block: - continue - trampoline_args.append(x.variable_name) - cg.kernel.append_statement(f"{fn}({', '.join(trampoline_args)})") - else: - cg.kernel.append_statement(f"{fn}(__slangpy_context__)") def generate_code( @@ -866,9 +884,19 @@ def generate_code( root_params = sorted(signature.values(), key=lambda x: x.param_index) - if context.call_mode != CallMode.prim: + # Currently we assume a trampoline is always needed for bwds. Technically, this is only needed if + # there are none-direct-bind parameters (i.e. need calls to __slangpy_load/__slangpy_store that may + # internally accumulate gradients). However to make this work we'd also need to analyse the function + # arguments to calculate the correct bwds call signature, based on parameter differentiability. + need_trampoline = context.call_mode != CallMode.prim + + if need_trampoline: _emit_trampoline(cg, context, build_info, root_params, use_entrypoint_args) + _emit_entry_point_signature(cg, build_info, call_data_len, call_group_size, use_entrypoint_args) + cg.kernel.begin_block() - _emit_kernel_body(cg, context, build_info, root_params, call_data_len, use_entrypoint_args) + _emit_kernel_body( + cg, context, build_info, root_params, call_data_len, use_entrypoint_args, need_trampoline + ) cg.kernel.end_block() From 346ec6c4291ce4dd56c09803194ec2c9c526a5a3 Mon Sep 17 00:00:00 2001 From: Chris Cummings Date: Tue, 17 Mar 2026 10:30:26 +0000 Subject: [PATCH 41/41] Remove old excessive tests, clean up proper ones + add a few more --- slangpy/tests/slangpy_tests/test_code_gen.py | 292 ++- .../tests/slangpy_tests/test_kernel_gen.py | 2072 ----------------- 2 files changed, 221 insertions(+), 2143 deletions(-) delete mode 100644 slangpy/tests/slangpy_tests/test_kernel_gen.py diff --git a/slangpy/tests/slangpy_tests/test_code_gen.py b/slangpy/tests/slangpy_tests/test_code_gen.py index 07257a709..75036def7 100644 --- a/slangpy/tests/slangpy_tests/test_code_gen.py +++ b/slangpy/tests/slangpy_tests/test_code_gen.py @@ -35,20 +35,6 @@ def assert_contains(code: str, *patterns: str) -> None: """Assert all patterns appear in generated code.""" for p in patterns: - if p in code: - continue - - # Compatibility for Step 2.3: prim-mode local variable declarations now - # use __tmp_ prefixed names to avoid colliding with entry-point params. - # Example: "vector v;" -> "vector __tmp_v;" - if p.endswith(";") and "(" not in p and ")" not in p and "." not in p and " = " not in p: - decl = p[:-1].rstrip() - if " " in decl: - type_name, var_name = decl.rsplit(" ", 1) - alt_tmp_decl = f"{type_name} __tmp_{var_name};" - if alt_tmp_decl in code: - continue - assert p in code, f"Expected pattern not found: {p}" @@ -58,31 +44,17 @@ def assert_not_contains(code: str, *patterns: str) -> None: assert p not in code, f"Unexpected pattern found: {p}" -def assert_trampoline_has(code: str, *stmts: str) -> None: - """Assert generated load statements across old/new kernel-generation variants. +def assert_load_statement(code: str, *var_names: str) -> None: + """Assert that load statements exist for the given variables. - Accepts legacy trampoline statements and their modern equivalents: - - ``a = __calldata__.a;`` (legacy) - - ``a = call_data.a;`` (fallback) - - ``__tmp_a = a;`` (fast path with inline body) - - ``__tmp_a = call_data.a;`` (fallback with inline body) + Handles both code paths: + - Fast path (entry-point params): ``__tmp_x = x;`` + - Fallback path (ParameterBlock): ``__tmp_x = call_data.x;`` """ - for s in stmts: - if "__calldata__." in s: - alt_cd = s.replace("__calldata__.", "call_data.") - # For fast path: a = __calldata__.a; -> __tmp_a = a; - alt_tmp = s - if " = __calldata__." in s and s.endswith(";"): - lhs = s.split(" = __calldata__.", 1)[0].strip() - rhs = s.split(" = __calldata__.", 1)[1][:-1].strip() - alt_tmp = f"__tmp_{lhs} = {rhs};" - alt_tmp_cd = alt_tmp.replace(" = ", " = call_data.", 1) if alt_tmp != s else s - assert s in code or alt_cd in code or alt_tmp in code or alt_tmp_cd in code, ( - "Expected generated statement not found: " - f"{s} (or {alt_cd} or {alt_tmp} or {alt_tmp_cd})" - ) - else: - assert s in code, f"Expected trampoline statement not found: {s}" + for name in var_names: + fast = f"__tmp_{name} = {name};" + fallback = f"__tmp_{name} = call_data.{name};" + assert fast in code or fallback in code, f"Expected load for '{name}': {fast} or {fallback}" def generate_code_and_bindings( @@ -152,8 +124,8 @@ def test_scalar_direct_bind(device_type: spy.DeviceType): # Scalars use raw type directly, no wrapper assert_not_contains(code, "ValueType") assert_not_contains(code, "typealias _t_a", "typealias _t_b") - # Direct assignment in trampoline - assert_trampoline_has(code, "a = __calldata__.a;", "b = __calldata__.b;") + # Direct assignment — loaded into __tmp_ locals + assert_load_statement(code, "a", "b") # _result is auto-created writable RWValueRef assert_contains(code, "RWValueRef") assert_contains(code, "__slangpy_store") @@ -186,7 +158,7 @@ def test_vector_direct_bind(device_type: spy.DeviceType): assert_not_contains(code, "VectorValueType") assert_not_contains(code, "typealias _t_v") - assert_contains(code, "vector v;") + assert_contains(code, "vector __tmp_v;") assert bindings.args[0].direct_bind is True assert bindings.args[0].call_dimensionality == 0 @@ -207,7 +179,7 @@ def test_matrix_direct_bind(device_type: spy.DeviceType): assert_not_contains(code, "ValueType>") assert_not_contains(code, "typealias _t_m") - assert_contains(code, "matrix m;") + assert_contains(code, "matrix __tmp_m;") assert bindings.args[0].direct_bind is True assert bindings.args[0].call_dimensionality == 0 @@ -245,8 +217,8 @@ def test_valueref_read_direct_bind(device_type: spy.DeviceType): ) assert_not_contains(code, "typealias _t_v") - assert_contains(code, "float v;") - assert_trampoline_has(code, "v = __calldata__.v;") + assert_contains(code, "float __tmp_v;") + assert_load_statement(code, "v") assert_contains(code, "RWValueRef") assert bindings.args[0].direct_bind is True @@ -291,8 +263,8 @@ def test_struct_all_scalar_direct_bind(device_type: spy.DeviceType): # Direct-bind struct — raw type, no __slangpy_load assert_not_contains(code, "__slangpy_load") assert_not_contains(code, "typealias _t_s") - assert_contains(code, "S s;") - assert_trampoline_has(code, "s = __calldata__.s;") + assert_contains(code, "S __tmp_s;") + assert_load_statement(code, "s") s = bindings.args[0] assert s.direct_bind is True @@ -304,14 +276,11 @@ def test_struct_all_scalar_direct_bind(device_type: spy.DeviceType): @pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) @pytest.mark.parametrize( "variant", - ["vector_field", "array_field"], - ids=["vector_field", "array_field"], + ["vector_field", "array_field", "matrix_field"], + ids=["vector_field", "array_field", "matrix_field"], ) def test_struct_composite_fields_direct_bind(device_type: spy.DeviceType, variant: str): - """Struct with composite field (vector / array) all dim-0 → direct-bind. - - Merges: struct_with_vector_fields, struct_with_array_field codegen+binding tests. - """ + """Struct with composite field (vector / array / matrix) all dim-0 → direct-bind.""" device = helpers.get_device(device_type) if variant == "vector_field": @@ -325,7 +294,7 @@ def test_struct_composite_fields_direct_bind(device_type: spy.DeviceType, varian arg = {"_type": "S", "pos": spy.math.float3(1, 2, 3), "scale": 2.0} func_name = "apply" child_name = "pos" - else: + elif variant == "array_field": src = """ struct Foo { int vals[4]; @@ -341,14 +310,25 @@ def test_struct_composite_fields_direct_bind(device_type: spy.DeviceType, varian arg = {"_type": "Foo", "vals": [1, 2, 3, 4]} func_name = "sum_inner" child_name = "vals" + else: + src = """ +struct S { + float4x4 m; + float scale; +}; +float4x4 apply(S s) { return s.m * s.scale; } +""" + arg = {"_type": "S", "m": spy.math.float4x4.identity(), "scale": 2.0} + func_name = "apply" + child_name = "m" code, bindings = generate_code_and_bindings(device, func_name, src, arg) # Struct is direct-bind — raw type, no __slangpy_load assert_not_contains(code, "__slangpy_load") - param_name = "s" if variant == "vector_field" else "foo" + param_name = "foo" if variant == "array_field" else "s" assert_not_contains(code, f"typealias _t_{param_name}") - assert_trampoline_has(code, f"{param_name} = __calldata__.{param_name};") + assert_load_statement(code, param_name) s = bindings.args[0] assert s.direct_bind is True @@ -386,10 +366,10 @@ def test_deeply_nested_struct_direct_bind(device_type: spy.DeviceType): code, bindings = generate_code_and_bindings(device, "compute", src, arg) assert_not_contains(code, "typealias _t_t") - assert_contains(code, "Top t;") + assert_contains(code, "Top __tmp_t;") assert_not_contains(code, "__slangpy_load") assert_not_contains(code, "struct _t_t") - assert_trampoline_has(code, "t = __calldata__.t;") + assert_load_statement(code, "t") t = bindings.args[0] assert t.direct_bind is True @@ -436,7 +416,7 @@ def test_struct_mixed_fields(device_type: spy.DeviceType): assert_contains(code, "x.__slangpy_load(context.map(_m_x),value.x)") # Independent scalar arg 'scale' — direct-bind assert_not_contains(code, "typealias _t_scale") - assert_contains(code, "float scale;") + assert_contains(code, "float __tmp_scale;") # Binding flags s = bindings.args[0] @@ -566,7 +546,7 @@ def test_mixed_scalar_and_tensor(device_type: spy.DeviceType): # 'a' direct-bind assert_not_contains(code, "typealias _t_a") assert_not_contains(code, "ValueType") - assert_trampoline_has(code, "a = __calldata__.a;") + assert_load_statement(code, "a") # 'b' NOT direct-bind (vectorized tensor) assert_contains(code, "Tensor") assert_contains(code, "__slangpy_load") @@ -595,8 +575,8 @@ def test_tensor_dim0_direct_bind(device_type: spy.DeviceType): code, bindings = generate_code_and_bindings(device, "tensor_read", src, tensor) assert_not_contains(code, "typealias _t_t") - assert_contains(code, "Tensor t;") - assert_trampoline_has(code, "t = __calldata__.t;") + assert_contains(code, "Tensor __tmp_t;") + assert_load_statement(code, "t") assert_not_contains(code, "ValueType<") t = bindings.args[0] @@ -620,7 +600,7 @@ def test_2d_tensor_to_vector(device_type: spy.DeviceType): assert_contains(code, "__slangpy_load") assert_contains(code, "_m_v") assert_not_contains(code, "typealias _t_s") - assert_contains(code, "float s;") + assert_contains(code, "float __tmp_s;") v = bindings.args[0] assert v.call_dimensionality == 1 @@ -730,8 +710,8 @@ def test_mixed_vectorized_dim0_tensor(device_type: spy.DeviceType): assert_contains(code, "__slangpy_load") # weights: dim-0 direct-bind assert_not_contains(code, "typealias _t_weights") - assert_contains(code, "Tensor weights;") - assert_trampoline_has(code, "weights = __calldata__.weights;") + assert_contains(code, "Tensor __tmp_weights;") + assert_load_statement(code, "weights") v = bindings.args[0] assert v.call_dimensionality == 1 @@ -788,7 +768,7 @@ def test_long_type_name_typealias(device_type: spy.DeviceType): device, "sum", short_src, {"_type": _SHORT_STRUCT_NAME, "x": 1.0, "y": 2.0} ) assert_not_contains(code_short, "typealias _t_s") - assert_contains(code_short, f"{_SHORT_STRUCT_NAME} s;") + assert_contains(code_short, f"{_SHORT_STRUCT_NAME} __tmp_s;") # --- Long wrapper name for _result --- identity_src = f""" @@ -1259,14 +1239,10 @@ def test_bwds_entrypoint_no_diff_params(device_type: spy.DeviceType): # --- fast path --- assert cd.use_entrypoint_args is True - # --- trampoline params have no_diff; bare name (Step 2.3) or __in_ prefix (legacy) --- + # --- trampoline params have no_diff --- assert_contains(code, "no_diff") - assert ( - "no_diff float a" in code or "__in_a" in code - ), "Expected trampoline param for 'a' (no_diff float a or __in_a)" - assert ( - "no_diff float b" in code or "__in_b" in code - ), "Expected trampoline param for 'b' (no_diff float b or __in_b)" + assert_contains(code, "no_diff float a") + assert_contains(code, "no_diff float b") # --- [Differentiable] before trampoline --- diff_idx = code.index("[Differentiable]") @@ -1335,5 +1311,179 @@ def test_fallback_calldata_large_params(device_type: spy.DeviceType): assert_not_contains(code, "uniform uint3 _thread_count") +# =========================================================================== +# Prim-mode trampoline elimination (41–42) +# =========================================================================== + + +# 41 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_prim_no_trampoline(device_type: spy.DeviceType): + """Prim mode: no _trampoline function, call inlined in compute_main.""" + device = helpers.get_device(device_type) + code, _ = generate_code_and_bindings( + device, "add", "int add(int a, int b) { return a + b; }", 1, 2 + ) + # No trampoline function generated + assert_not_contains(code, "void _trampoline(") + # compute_main does NOT call _trampoline — it inlines the call + main_idx = code.index("void compute_main(") + main_body = code[main_idx:] + assert "_trampoline(" not in main_body + assert "add(__tmp_a, __tmp_b);" in main_body + + +# 42 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_struct_array_of_structs_codegen(device_type: spy.DeviceType): + """Struct with array-of-structs field: Outer{Inner items[4]} — all dim-0, direct-bind.""" + device = helpers.get_device(device_type) + src = """ +struct Inner { + int x; +}; +struct Outer { + Inner items[4]; +}; +int sum_inner(Outer outer) { + int s = 0; + for (int i = 0; i < 4; i++) { + s += outer.items[i].x; + } + return s; +} +""" + code, bindings = generate_code_and_bindings( + device, + "sum_inner", + src, + { + "_type": "Outer", + "items": [ + {"_type": "Inner", "x": 10}, + {"_type": "Inner", "x": 20}, + {"_type": "Inner", "x": 30}, + {"_type": "Inner", "x": 40}, + ], + }, + ) + assert_not_contains(code, "typealias _t_outer") + assert_contains(code, "Outer __tmp_outer;") + assert_not_contains(code, "__slangpy_load") + assert_load_statement(code, "outer") + + assert bindings.args[0].direct_bind is True + + +# =========================================================================== +# Additional use_entrypoint_args coverage (43–46) +# =========================================================================== + + +# 43 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_threshold_property_positive(device_type: spy.DeviceType): + """Device has a positive max_entry_point_uniform_size threshold.""" + device = helpers.get_device(device_type) + threshold = device.info.limits.max_entry_point_uniform_size + assert threshold > 0 + + +# 44 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_vector_uses_entrypoint_args(device_type: spy.DeviceType): + """float3 args are small enough for entry-point params.""" + device = helpers.get_device(device_type) + _, _, cd = build_call_data_full( + device, + "scale", + "float3 scale(float3 v, float s) { return v * s; }", + spy.math.float3(1, 2, 3), + 2.0, + ) + assert cd.use_entrypoint_args is True + + +# 45 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_struct_uses_entrypoint_args(device_type: spy.DeviceType): + """All-scalar struct dict has small inline-uniform size.""" + device = helpers.get_device(device_type) + src = """ +struct S { float x; float y; }; +float sum(S s) { return s.x + s.y; } +""" + _, _, cd = build_call_data_full(device, "sum", src, {"_type": "S", "x": 1.0, "y": 2.0}) + assert cd.use_entrypoint_args is True + + +# 46 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_tensor_uses_entrypoint_args(device_type: spy.DeviceType): + """Tensor args contribute descriptor-only (0 inline bytes) → entry-point params.""" + device = helpers.get_device(device_type) + tensor = Tensor.from_numpy(device, np.array([1.0, 2.0, 3.0], dtype=np.float32)) + _, _, cd = build_call_data_full( + device, + "sum_all", + "float sum_all(float x) { return x; }", + tensor, + ) + assert cd.use_entrypoint_args is True + + +# =========================================================================== +# Additional functional dispatch tests (47–49) +# =========================================================================== + + +# 47 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_valueref_read(device_type: spy.DeviceType): + """Dispatch with a read-only ValueRef input — direct-bind pipeline end-to-end.""" + device = helpers.get_device(device_type) + func = helpers.create_function_from_module( + device, "double_it", "float double_it(float v) { return v * 2; }" + ) + result = func(ValueRef(7.0)) + assert abs(result - 14.0) < 1e-5 + + +# 48 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_struct_return(device_type: spy.DeviceType): + """Dispatch struct return and verify result is dict with correct values.""" + device = helpers.get_device(device_type) + src = """ +struct S { + int x; + int y; +}; +S make_struct(int a, int b) { return { a, b }; } +""" + func = helpers.create_function_from_module(device, "make_struct", src) + result = func(4, 5) + assert isinstance(result, dict) + assert result["x"] == 4 + assert result["y"] == 5 + + +# 49 ------------------------------------------------------------------------ +@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) +def test_dispatch_long_struct_name(device_type: spy.DeviceType): + """End-to-end dispatch with a struct whose name exceeds 60 chars.""" + device = helpers.get_device(device_type) + src = f""" +struct {_LONG_STRUCT_NAME} {{ + float x; + float y; +}}; +float sum({_LONG_STRUCT_NAME} s) {{ return s.x + s.y; }} +""" + func = helpers.create_function_from_module(device, "sum", src) + result = func({"_type": _LONG_STRUCT_NAME, "x": 3.0, "y": 7.0}) + assert abs(result - 10.0) < 1e-5 + + if __name__ == "__main__": pytest.main([__file__, "-vs"]) diff --git a/slangpy/tests/slangpy_tests/test_kernel_gen.py b/slangpy/tests/slangpy_tests/test_kernel_gen.py deleted file mode 100644 index 61c4f55b1..000000000 --- a/slangpy/tests/slangpy_tests/test_kernel_gen.py +++ /dev/null @@ -1,2072 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -""" -Kernel generation test. - -These tests exercise different code paths for kernel generation, to exercise different kernel types, such as: -- passing arguments directly vs via call data -- passing read-only arguments that don't need storing directly rather than via marshalls -- handling the semantic 'dispatch thread id' etc and calling kernels directly - -Gating tests (test_gate_*) assert CURRENT generated kernel patterns and will -intentionally break as simplification steps from the kernel-gen simplification -plan are implemented. Negative gates (test_gate_*_keeps_*) must remain -passing after simplification — they cover types that are NOT direct-bind -eligible. -""" - -from typing import Any - -import numpy as np -import pytest -import os - -import slangpy as spy -from slangpy.testing import helpers -from slangpy.types import ValueRef, Tensor, diffPair -from slangpy.types.wanghasharg import WangHashArg - -PRINT_TEST_KERNEL_GEN = os.getenv("PRINT_TEST_KERNEL_GEN", "0") == "1" - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def assert_contains(code: str, *patterns: str) -> None: - """Assert all patterns appear in generated code.""" - for p in patterns: - if p in code: - continue - - # Compatibility for Step 2.3: prim-mode local variable declarations now - # use __tmp_ prefixed names to avoid colliding with entry-point params. - # Example: "vector v;" -> "vector __tmp_v;" - if p.endswith(";") and "(" not in p and ")" not in p and "." not in p and " = " not in p: - decl = p[:-1].rstrip() - if " " in decl: - type_name, var_name = decl.rsplit(" ", 1) - alt_tmp_decl = f"{type_name} __tmp_{var_name};" - if alt_tmp_decl in code: - continue - - assert p in code, f"Expected pattern not found: {p}" - - -def assert_not_contains(code: str, *patterns: str) -> None: - """Assert none of the patterns appear in generated code.""" - for p in patterns: - assert p not in code, f"Unexpected pattern found: {p}" - - -def assert_trampoline_has(code: str, *stmts: str) -> None: - """Assert generated load statements across old/new kernel-generation variants. - - Accepts legacy trampoline statements and their modern equivalents: - - ``a = __calldata__.a;`` (legacy) - - ``a = call_data.a;`` (fallback) - - ``__tmp_a = a;`` (fast path with inline body) - - ``__tmp_a = call_data.a;`` (fallback with inline body) - """ - for s in stmts: - # Replace __calldata__ with all three options for matching - if "__calldata__." in s: - alt_cd = s.replace("__calldata__.", "call_data.") - # For fast path: a = __calldata__.a; -> __tmp_a = a; - alt_tmp = s - if " = __calldata__." in s and s.endswith(";"): - lhs = s.split(" = __calldata__.", 1)[0].strip() - rhs = s.split(" = __calldata__.", 1)[1][:-1].strip() - alt_tmp = f"__tmp_{lhs} = {rhs};" - alt_tmp_cd = alt_tmp.replace(" = ", " = call_data.", 1) if alt_tmp != s else s - assert s in code or alt_cd in code or alt_tmp in code or alt_tmp_cd in code, ( - "Expected generated statement not found: " - f"{s} (or {alt_cd} or {alt_tmp} or {alt_tmp_cd})" - ) - else: - assert s in code, f"Expected trampoline statement not found: {s}" - - -def generate_code( - device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any -) -> str: - """ - Generate code for the given function and arguments, and return the generated code as a string. - """ - func = helpers.create_function_from_module(device, func_name, module_source) - cd = func.debug_build_call_data(*args, **kwargs) - if PRINT_TEST_KERNEL_GEN: - print(cd.code) - return cd.code - - -def generate_bwds_code( - device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any -) -> str: - """ - Generate backwards-mode code for the given function and arguments. - """ - func = helpers.create_function_from_module(device, func_name, module_source) - cd = func.bwds.debug_build_call_data(*args, **kwargs) - if PRINT_TEST_KERNEL_GEN: - print(cd.code) - return cd.code - - -# --------------------------------------------------------------------------- -# Basic test -# --------------------------------------------------------------------------- - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_kernel_gen_basic(device_type: spy.DeviceType): - """ - Test basic kernel generation with a simple function that adds two numbers. - """ - src = """ -int add(int a, int b) { - return a + b; -} -""" - device = helpers.get_device(device_type) - code = generate_code(device, "add", src, 1, 2) - if PRINT_TEST_KERNEL_GEN: - print(code) - assert "add" in code - - -# =========================================================================== -# Phase 1 tests — assert direct-bind behaviour after implementation -# =========================================================================== - -# -- Step 1.2: Scalar direct binding -- - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_scalar_uses_valuetype(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - code = generate_code( - device, - "add", - "int add(int a, int b) { return a + b; }", - 1, - 2, - ) - # Scalars now use direct binding: type used directly in CallData, no ValueType wrapper - assert_not_contains(code, "ValueType") - assert_not_contains(code, "typealias _t_a", "typealias _t_b") - # Trampoline uses direct assignment, no __slangpy_load - assert_trampoline_has(code, "a = __calldata__.a;", "b = __calldata__.b;") - # _result is auto-created as writable RWValueRef (not direct-bind) - assert_contains(code, "RWValueRef") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_float_scalar_uses_valuetype(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - code = generate_code( - device, - "mymul", - "float mymul(float x, float y) { return x * y; }", - 1.0, - 2.0, - ) - assert_not_contains(code, "ValueType") - assert_not_contains(code, "typealias _t_x", "typealias _t_y") - - -# -- Step 1.3: Vector / Matrix / Array direct binding -- - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_vector_uses_vectorvaluetype(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - code = generate_code( - device, - "scale", - "float3 scale(float3 v, float s) { return v * s; }", - spy.math.float3(1, 2, 3), - 1.0, - ) - assert_not_contains(code, "VectorValueType") - assert_not_contains(code, "typealias _t_v") - assert_contains(code, "vector v;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_matrix_uses_valuetype(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - code = generate_code( - device, - "ident", - "float4x4 ident(float4x4 m) { return m; }", - spy.math.float4x4.identity(), - ) - assert_not_contains(code, "ValueType>") - assert_not_contains(code, "typealias _t_m") - assert_contains(code, "matrix m;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_array_dim0_uses_valuetype(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - code = generate_code( - device, - "process", - "void process(float a[4]) { }", - [1.0, 2.0, 3.0, 4.0], - ) - assert_not_contains(code, "ValueType<") - assert_not_contains(code, "typealias _t_a") - - -# -- Step 1.5: ValueRef direct binding -- - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_valueref_read_uses_wrapper(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - code = generate_code( - device, - "read_val", - "float read_val(float v) { return v; }", - ValueRef(1.0), - ) - # Read-only ValueRef uses raw type directly (direct-bind) - assert_not_contains(code, "typealias _t_v") - assert_contains(code, "float v;") - # Direct assignment in trampoline - assert_trampoline_has(code, "v = __calldata__.v;") - # _result (writable) still uses RWValueRef wrapper - assert_contains(code, "RWValueRef") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_valueref_write_uses_wrapper(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - code = generate_code( - device, - "add", - "int add(int a, int b) { return a + b; }", - 1, - 2, - ) - # Auto-created _result uses RWValueRef (writable, not direct-bind) - assert_contains(code, "RWValueRef") - # Trampoline uses __slangpy_store via wrapper - assert_contains(code, "__slangpy_store") - - -# -- Step 1.7: Mapping constants and context.map -- - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_mapping_constants_present(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - code = generate_code( - device, - "add", - "int add(int a, int b) { return a + b; }", - 1, - 2, - ) - # Direct-bind variables no longer emit mapping constants - assert_not_contains( - code, - "static const int _m_a = 0", - "static const int _m_b = 0", - ) - # _result is NOT direct-bind (writable ValueRef), so it keeps mapping constant - assert_contains(code, "static const int _m__result = 0") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_context_map_in_trampoline(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - code = generate_code( - device, - "add", - "int add(int a, int b) { return a + b; }", - 1, - 2, - ) - # Direct-bind variables don't use context.map - assert_not_contains(code, "__slangpy_context__.map(_m_a)") - - -# -- Step 1.4: Struct / dict direct binding -- - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_uses_slangpy_load(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - src = """ -struct S { - float x; - float y; -}; -float sum(S s) { return s.x + s.y; } -""" - code = generate_code(device, "sum", src, {"_type": "S", "x": 1.0, "y": 2.0}) - # Direct-bind struct: uses raw type directly, no inline struct with __slangpy_load - assert_not_contains(code, "__slangpy_load") - assert_not_contains(code, "typealias _t_s") - assert_contains(code, "S s;") - # Direct assignment in trampoline - assert_trampoline_has(code, "s = __calldata__.s;") - - -# -- Step 1.8: Autodiff gating -- - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_bwds_scalar_uses_valuetype(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - src = """ -[Differentiable] -float polynomial(float a, float b) { - return a * a + b + 1; -} -""" - code = generate_bwds_code(device, "polynomial", src, 5.0, 10.0, 26.0) - # bwds still uses direct bind for primals; check differentiable markers remain - assert_not_contains(code, "ValueType") - assert_contains(code, "[Differentiable]", "bwd_diff(_trampoline)") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_bwds_trampoline_is_differentiable(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - src = """ -[Differentiable] -float polynomial(float a, float b) { - return a * a + b + 1; -} -""" - code = generate_bwds_code(device, "polynomial", src, 5.0, 10.0, 26.0) - # [Differentiable] should appear before the trampoline function - diff_idx = code.index("[Differentiable]") - trampoline_idx = code.index("void _trampoline") - assert diff_idx < trampoline_idx - - -# =========================================================================== -# Phase 1 negative gates — must REMAIN passing after Phase 1 -# =========================================================================== - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_wanghasharg_uses_wrapper(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - src = "uint3 rng(uint3 input) { return input; }" - code = generate_code(device, "rng", src, WangHashArg(3)) - assert_contains(code, "WangHashArg<") - # WangHashArg uses wrapper type — field declaration present in CallData - assert_contains(code, "input") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_vectorized_scalar_keeps_wrapper(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - src = "float square(float x) { return x * x; }" - tensor = Tensor.from_numpy( - helpers.get_device(device_type), np.array([1, 2, 3], dtype=np.float32) - ) - code = generate_code(device, "square", src, tensor) - # Vectorized (dim > 0) — tensor marshall used, __slangpy_load still present - assert_contains(code, "__slangpy_load") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_vectorized_dict_keeps_struct_load(device_type: spy.DeviceType): - device = helpers.get_device(device_type) - src = """ -struct S { - float x; - float y; -}; -void apply(S s, float scale) {} -""" - tensor_x = Tensor.from_numpy( - helpers.get_device(device_type), np.array([1, 2, 3], dtype=np.float32) - ) - tensor_y = Tensor.from_numpy( - helpers.get_device(device_type), np.array([4, 5, 6], dtype=np.float32) - ) - code = generate_code(device, "apply", src, {"_type": "S", "x": tensor_x, "y": tensor_y}, 1.0) - # Children are vectorized (dim > 0) — should keep inline struct with __slangpy_load - assert_contains(code, "__slangpy_load") - - -# =========================================================================== -# Phase 1 functional dispatch tests — verify GPU results are correct -# =========================================================================== - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_scalar_add(device_type: spy.DeviceType): - """Dispatch scalar add with direct binding and verify GPU result.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, "add", "int add(int a, int b) { return a + b; }" - ) - result = func(3, 7) - assert result == 10 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_float_mul(device_type: spy.DeviceType): - """Dispatch float multiply with direct binding.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, "mymul", "float mymul(float x, float y) { return x * y; }" - ) - result = func(3.0, 4.0) - assert abs(result - 12.0) < 1e-5 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_vector_scale(device_type: spy.DeviceType): - """Dispatch vector scale with direct binding.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, "scale", "float3 scale(float3 v, float s) { return v * s; }" - ) - result = func(spy.math.float3(1, 2, 3), 2.0) - assert result.x == 2.0 - assert result.y == 4.0 - assert result.z == 6.0 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_struct_sum(device_type: spy.DeviceType): - """Dispatch struct sum via dict with direct binding.""" - device = helpers.get_device(device_type) - src = """ -struct S { - float x; - float y; -}; -float sum(S s) { return s.x + s.y; } -""" - func = helpers.create_function_from_module(device, "sum", src) - result = func({"_type": "S", "x": 3.0, "y": 7.0}) - assert abs(result - 10.0) < 1e-5 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_valueref_write(device_type: spy.DeviceType): - """Dispatch with explicit ValueRef output and read back.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, "add", "int add(int a, int b) { return a + b; }" - ) - out = ValueRef(0) - func(5, 8, _result=out) - assert out.value == 13 - - -# =========================================================================== -# Mixed direct-bind tests — some args direct, some not -# =========================================================================== - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_mixed_args_scalar_and_tensor(device_type: spy.DeviceType): - """Scalar arg gets direct-bind; vectorized tensor arg does not.""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - code = generate_code( - device, - "add", - "float add(float a, float b) { return a + b; }", - 1.0, - tensor, - ) - # 'a' is direct-bind (scalar dim-0): type used directly, direct trampoline load - assert_not_contains(code, "typealias _t_a") - assert_not_contains(code, "ValueType") - assert_trampoline_has(code, "a = __calldata__.a;") - # 'b' is NOT direct-bind (vectorized tensor dim-1): uses Tensor, - # __slangpy_load, and mapping constant - assert_contains(code, "Tensor") - assert_contains(code, "__slangpy_load") - assert_contains(code, "_m_b") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_mixed_args_direct_bind_flags(device_type: spy.DeviceType): - """Verify direct_bind flags on bindings for mixed scalar + tensor call.""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - func = helpers.create_function_from_module( - device, "add", "float add(float a, float b) { return a + b; }" - ) - cd = func.debug_build_call_data(1.0, tensor) - bindings = cd.debug_only_bindings - assert bindings.args[0].direct_bind is True, "scalar arg 'a' should be direct_bind" - assert bindings.args[0].call_dimensionality == 0 - assert bindings.args[1].direct_bind is False, "tensor arg 'b' should NOT be direct_bind" - assert bindings.args[1].call_dimensionality == 1 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_mixed_scalar_tensor(device_type: spy.DeviceType): - """Dispatch mixed scalar + tensor and verify GPU result.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, "add", "float add(float a, float b) { return a + b; }" - ) - tensor = Tensor.from_numpy(device, np.array([10, 20, 30], dtype=np.float32)) - result = func(5.0, tensor) - expected = np.array([15, 25, 35], dtype=np.float32) - np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) - - -# =========================================================================== -# Struct with mixed direct-bind fields -# =========================================================================== - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_mixed_fields_codegen(device_type: spy.DeviceType): - """Struct with one tensor field and one scalar field. - - The struct is NOT direct-bind because child x is vectorized (dim-1). - Child y (scalar) keeps direct_bind=True — gen_call_data_code emits - direct assignment (value.y = y) instead of y.__slangpy_load(...). - """ - device = helpers.get_device(device_type) - src = """ -struct S { - float x; - float y; -}; -void apply(S s, float scale) {} -""" - tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - code = generate_code(device, "apply", src, {"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) - # Struct is NOT direct-bind: uses inline struct with __slangpy_load - assert_contains(code, "__slangpy_load") - assert_contains(code, "struct _t_s") - assert_not_contains(code, "typealias _t_s = S;") - # Child y is direct-bind: type used directly, direct assignment in __slangpy_load - assert_not_contains(code, "typealias _t_y") - assert_contains(code, "float y;") - assert_contains(code, "value.y = y;") - assert_not_contains(code, "ValueType") - # Child x should use tensor type - assert_contains(code, "Tensor") - # Scalar arg 'scale' is independent — should still be direct-bind - assert_not_contains(code, "typealias _t_scale") - assert_contains(code, "float scale;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_mixed_fields_binding_flags(device_type: spy.DeviceType): - """Verify direct_bind flags on struct children when struct is NOT direct-bind.""" - device = helpers.get_device(device_type) - src = """ -struct S { - float x; - float y; -}; -void apply(S s, float scale) {} -""" - tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - func = helpers.create_function_from_module(device, "apply", src) - cd = func.debug_build_call_data({"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) - bindings = cd.debug_only_bindings - s_binding = bindings.args[0] - assert s_binding.direct_bind is False, "struct 's' should NOT be direct_bind" - # Child x is a tensor (dim-1), not direct-bind - assert s_binding.children["x"].direct_bind is False - # Child y is a scalar (dim-0), keeps its direct_bind status - assert s_binding.children["y"].direct_bind is True - # 'scale' is independent scalar — should be direct_bind - assert bindings.args[1].direct_bind is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_struct_mixed_fields(device_type: spy.DeviceType): - """Dispatch struct with mixed tensor+scalar fields and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -struct S { - float x; - float y; -}; -float weighted_sum(S s, float scale) { return (s.x + s.y) * scale; } -""" - func = helpers.create_function_from_module(device, "weighted_sum", src) - tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - result = func({"_type": "S", "x": tensor_x, "y": 10.0}, 2.0) - expected = np.array([22, 24, 26], dtype=np.float32) - np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) - - -# =========================================================================== -# Tensor at dim-0 (whole tensor passed to Tensor parameter) -# =========================================================================== - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_tensor_dim0_codegen(device_type: spy.DeviceType): - """1D Tensor passed to Tensor param — dim-0, direct assignment.""" - device = helpers.get_device(device_type) - src = """ -float tensor_read(Tensor t) { - return t[0]; -} -""" - tensor = Tensor.from_numpy(device, np.array([42, 2, 3], dtype=np.float32)) - code = generate_code(device, "tensor_read", src, tensor) - # Type should use Tensor directly (no typealias) - assert_not_contains(code, "typealias _t_t") - assert_contains(code, "Tensor t;") - # Trampoline uses direct assignment (not __slangpy_load) - assert_trampoline_has(code, "t = __calldata__.t;") - # No wrapper type for the tensor - assert_not_contains(code, "ValueType<") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_tensor_dim0_binding_flags(device_type: spy.DeviceType): - """Tensor at dim-0 has direct_bind=True (consistent with other dim-0 types).""" - device = helpers.get_device(device_type) - src = """ -float tensor_read(Tensor t) { - return t[0]; -} -""" - tensor = Tensor.from_numpy(device, np.array([42, 2, 3], dtype=np.float32)) - func = helpers.create_function_from_module(device, "tensor_read", src) - cd = func.debug_build_call_data(tensor) - bindings = cd.debug_only_bindings - t_binding = bindings.args[0] - assert t_binding.direct_bind is True - assert t_binding.call_dimensionality == 0 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_tensor_dim0(device_type: spy.DeviceType): - """Dispatch with whole tensor at dim-0 and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -float tensor_read(Tensor t) { - return t[0]; -} -""" - func = helpers.create_function_from_module(device, "tensor_read", src) - tensor = Tensor.from_numpy(device, np.array([42, 99, 7], dtype=np.float32)) - result = func(tensor) - assert abs(result - 42.0) < 1e-5 - - -# =========================================================================== -# Mixed direct-bind children in non-direct-bind struct — validates that -# gen_call_data_code correctly uses direct assignment for direct-bind -# children and __slangpy_load for non-direct-bind children. -# =========================================================================== - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_mixed_children_direct_bind_codegen(device_type: spy.DeviceType): - """Validate code gen for struct with mixed direct-bind / non-direct-bind children. - - Scalar child y gets direct assignment (value.y = y) inside __slangpy_load. - Tensor child x goes through __slangpy_load with context mapping. - Both patterns coexist in the same generated struct. - """ - device = helpers.get_device(device_type) - src = """ -struct S { - float x; - float y; -}; -float weighted_sum(S s, float scale) { return (s.x + s.y) * scale; } -""" - tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - code = generate_code(device, "weighted_sum", src, {"_type": "S", "x": tensor_x, "y": 1.0}, 2.0) - # Child y uses raw type and direct assignment - assert_not_contains(code, "typealias _t_y") - assert_contains(code, "float y;") - assert_contains(code, "value.y = y;") - # No mapping constant for y (direct-bind skips it) - assert_not_contains(code, "_m_y") - # Child x uses tensor wrapper with __slangpy_load - assert_contains(code, "x.__slangpy_load(context.map(_m_x),value.x)") - # The struct itself is not direct-bind - assert_contains(code, "struct _t_s") - - -# =========================================================================== -# Review coverage — binding flag verification tests -# =========================================================================== - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_writable_valueref_not_direct_bind(device_type: spy.DeviceType): - """Writable ValueRef (inout) must not be direct-bind — needs buffer read/write.""" - device = helpers.get_device(device_type) - src = "void inc(inout int v) { v += 1; }" - func = helpers.create_function_from_module(device, "inc", src) - vr = ValueRef(5) - cd = func.debug_build_call_data(vr) - bindings = cd.debug_only_bindings - v_binding = bindings.args[0] - assert v_binding.direct_bind is False - assert v_binding.call_dimensionality == 0 - code = cd.code - assert_contains(code, "RWValueRef") - assert_not_contains(code, "typealias _t_v = int;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_result_binding_not_direct_bind(device_type: spy.DeviceType): - """Auto-created _result (writable ValueRef) must not be direct-bind.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, "add", "int add(int a, int b) { return a + b; }" - ) - cd = func.debug_build_call_data(1, 2) - result_binding = cd.debug_only_bindings.kwargs["_result"] - assert result_binding.direct_bind is False - assert result_binding.call_dimensionality == 0 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_struct_all_scalars_binding_flag(device_type: spy.DeviceType): - """All-scalar struct at dim-0 should have direct_bind=True (and so should children).""" - device = helpers.get_device(device_type) - src = """ -struct S { float x; float y; }; -float sum(S s) { return s.x + s.y; } -""" - func = helpers.create_function_from_module(device, "sum", src) - cd = func.debug_build_call_data({"_type": "S", "x": 1.0, "y": 2.0}) - bindings = cd.debug_only_bindings - s = bindings.args[0] - assert s.direct_bind is True - assert s.children["x"].direct_bind is True - assert s.children["y"].direct_bind is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_struct_with_wanghash_child_not_direct_bind(device_type: spy.DeviceType): - """Struct with a WangHashArg child must NOT be direct-bind.""" - device = helpers.get_device(device_type) - src = """ -struct S { uint3 seed; float scale; }; -float apply(S s) { return float(s.seed.x) * s.scale; } -""" - func = helpers.create_function_from_module(device, "apply", src) - cd = func.debug_build_call_data({"_type": "S", "seed": WangHashArg(3), "scale": 1.0}) - bindings = cd.debug_only_bindings - s = bindings.args[0] - assert s.direct_bind is False - # scale child should still be direct-bind individually - assert s.children["scale"].direct_bind is True - code = cd.code - assert_contains(code, "struct _t_s") - assert_not_contains(code, "typealias _t_s = S;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_wanghasharg_binding_flag(device_type: spy.DeviceType): - """WangHashArg (no can_direct_bind override) should have direct_bind=False.""" - device = helpers.get_device(device_type) - src = "uint3 rng(uint3 input) { return input; }" - func = helpers.create_function_from_module(device, "rng", src) - cd = func.debug_build_call_data(WangHashArg(3)) - bindings = cd.debug_only_bindings - assert bindings.args[0].direct_bind is False - assert bindings.args[0].call_dimensionality == 0 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_valueref_read_input(device_type: spy.DeviceType): - """Dispatch with a read-only ValueRef input — verifies direct-bind ValueRef pipeline end-to-end.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, "double_it", "float double_it(float v) { return v * 2; }" - ) - result = func(ValueRef(7.0)) - assert abs(result - 14.0) < 1e-5 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_bwds_primal_binding_flags(device_type: spy.DeviceType): - """In bwds mode, primal args (access[0]=read) should have direct_bind=True.""" - device = helpers.get_device(device_type) - src = """ -[Differentiable] -float polynomial(float a, float b) { return a * a + b + 1; } -""" - func = helpers.create_function_from_module(device, "polynomial", src) - cd = func.bwds.debug_build_call_data(5.0, 10.0, 26.0) - bindings = cd.debug_only_bindings - # Primal args in bwds mode → access[0]=read → direct_bind should be True - assert bindings.args[0].direct_bind is True # 'a' - assert bindings.args[1].direct_bind is True # 'b' - - -# =========================================================================== -# ND tensor → (N-1)D parameter vectorization — kernel source pattern tests -# =========================================================================== - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_2d_tensor_to_vector_codegen(device_type: spy.DeviceType): - """2D Tensor shape=(10,3) → float3 param: trailing dim consumed by vector, outer dim dispatched.""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.ones((10, 3), dtype=np.float32)) - code = generate_code( - device, - "scale", - "float3 scale(float3 v, float s) { return v * s; }", - tensor, - 2.0, - ) - # v is vectorized dim-1: tensor wrapping a vector type - assert_contains(code, "__slangpy_load") - assert_contains(code, "_m_v") - # s is scalar dim-0: direct-bind, type used directly - assert_not_contains(code, "typealias _t_s") - assert_contains(code, "float s;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_2d_tensor_to_vector_binding_flags(device_type: spy.DeviceType): - """2D Tensor shape=(10,3) → float3 param: check binding metadata.""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.ones((10, 3), dtype=np.float32)) - func = helpers.create_function_from_module( - device, - "scale", - "float3 scale(float3 v, float s) { return v * s; }", - ) - cd = func.debug_build_call_data(tensor, 2.0) - bindings = cd.debug_only_bindings - v_binding = bindings.args[0] - # Tensor vectorized over outer dim: call_dimensionality == 1 - assert v_binding.call_dimensionality == 1 - assert v_binding.direct_bind is False - assert v_binding.vector_type is not None - assert v_binding.vector_type.full_name == "vector" - # Scalar s: dim-0 direct-bind - s_binding = bindings.args[1] - assert s_binding.call_dimensionality == 0 - assert s_binding.direct_bind is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_2d_tensor_to_vector(device_type: spy.DeviceType): - """Dispatch 2D tensor → float3 and verify GPU result.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, - "scale", - "float3 scale(float3 v, float s) { return v * s; }", - ) - data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) - tensor = Tensor.from_numpy(device, data) - result = func(tensor, 2.0) - expected = data * 2.0 - np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_3d_tensor_to_vector_codegen(device_type: spy.DeviceType): - """3D Tensor shape=(2,5,3) → float3 param: two outer dims dispatched.""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.ones((2, 5, 3), dtype=np.float32)) - code = generate_code( - device, - "negate", - "float3 negate(float3 v) { return -v; }", - tensor, - ) - # v vectorized dim-2: uses __slangpy_load, mapping constant - assert_contains(code, "__slangpy_load") - assert_contains(code, "_m_v") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_3d_tensor_to_vector_binding_flags(device_type: spy.DeviceType): - """3D Tensor shape=(2,5,3) → float3 param: call_dimensionality == 2.""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.ones((2, 5, 3), dtype=np.float32)) - func = helpers.create_function_from_module( - device, - "negate", - "float3 negate(float3 v) { return -v; }", - ) - cd = func.debug_build_call_data(tensor) - bindings = cd.debug_only_bindings - v = bindings.args[0] - assert v.call_dimensionality == 2 - assert v.direct_bind is False - assert v.vector_type.full_name == "vector" - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_3d_tensor_to_vector(device_type: spy.DeviceType): - """Dispatch 3D tensor → float3 and verify GPU result.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, - "negate", - "float3 negate(float3 v) { return -v; }", - ) - data = np.arange(30, dtype=np.float32).reshape(2, 5, 3) - tensor = Tensor.from_numpy(device, data) - result = func(tensor) - expected = -data - np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_2d_tensor_to_scalar_codegen(device_type: spy.DeviceType): - """2D Tensor shape=(4,5) → float scalar: both dims dispatched (call_dim=2).""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.ones((4, 5), dtype=np.float32)) - code = generate_code( - device, - "square", - "float square(float x) { return x * x; }", - tensor, - ) - assert_contains(code, "__slangpy_load") - assert_contains(code, "_m_x") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_2d_tensor_to_scalar_binding_flags(device_type: spy.DeviceType): - """2D Tensor shape=(4,5) → float scalar: call_dimensionality == 2.""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.ones((4, 5), dtype=np.float32)) - func = helpers.create_function_from_module( - device, - "square", - "float square(float x) { return x * x; }", - ) - cd = func.debug_build_call_data(tensor) - v = cd.debug_only_bindings.args[0] - assert v.call_dimensionality == 2 - assert v.direct_bind is False - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_2d_tensor_to_scalar(device_type: spy.DeviceType): - """Dispatch 2D tensor elementwise to scalar and verify GPU result.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, - "square", - "float square(float x) { return x * x; }", - ) - data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) - tensor = Tensor.from_numpy(device, data) - result = func(tensor) - expected = data * data - np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_2d_tensor_to_1d_array_codegen(device_type: spy.DeviceType): - """2D Tensor shape=(4,8) → half[8] param: trailing dim consumed by array, outer dim dispatched.""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.ones((4, 8), dtype=np.float16)) - code = generate_code( - device, - "tensor_test_channels<8>", - r""" -half[NumChannels] tensor_test_channels(half[NumChannels] data) -{ - [ForceUnroll] - for (int i = 0; i < NumChannels; ++i) - { - data[i] = 2.h * data[i]; - } - return data; -} -""", - tensor, - ) - # data is vectorized (trailing dim consumed by array): __slangpy_load present - assert_contains(code, "__slangpy_load") - assert_contains(code, "_m_data") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_2d_tensor_to_1d_array_binding_flags(device_type: spy.DeviceType): - """2D Tensor shape=(4,8) → half[8] param: call_dimensionality == 1.""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.ones((4, 8), dtype=np.float16)) - func = helpers.create_function_from_module( - device, - "tensor_test_channels<8>", - r""" -half[NumChannels] tensor_test_channels(half[NumChannels] data) -{ - [ForceUnroll] - for (int i = 0; i < NumChannels; ++i) - { - data[i] = 2.h * data[i]; - } - return data; -} -""", - ) - cd = func.debug_build_call_data(tensor) - v = cd.debug_only_bindings.args[0] - assert v.call_dimensionality == 1 - assert v.direct_bind is False - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_2d_tensor_to_1d_array(device_type: spy.DeviceType): - """Dispatch 2D tensor → half[8] and verify GPU doubles each element.""" - device = helpers.get_device(device_type) - func = helpers.create_function_from_module( - device, - "tensor_test_channels<8>", - r""" -half[NumChannels] tensor_test_channels(half[NumChannels] data) -{ - [ForceUnroll] - for (int i = 0; i < NumChannels; ++i) - { - data[i] = 2.h * data[i]; - } - return data; -} -""", - ).return_type(Tensor) - data = np.ones((4, 8), dtype=np.float16) - tensor = Tensor.from_numpy(device, data) - result = func(tensor) - expected = data * 2.0 - np.testing.assert_allclose( - result.to_numpy().reshape(expected.shape).astype(np.float32), - expected.astype(np.float32), - atol=1e-2, - ) - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_mixed_vectorized_and_dim0_tensor_codegen(device_type: spy.DeviceType): - """One tensor vectorized (2D→float3) and another at dim-0 (Tensor param).""" - device = helpers.get_device(device_type) - src = """ -float dot_lookup(float3 v, Tensor weights) { - return v.x * weights[0] + v.y * weights[1] + v.z * weights[2]; -} -""" - vec_tensor = Tensor.from_numpy(device, np.ones((5, 3), dtype=np.float32)) - weight_tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - code = generate_code(device, "dot_lookup", src, vec_tensor, weight_tensor) - # v: vectorized dim-1 (2D→float3), uses __slangpy_load - assert_contains(code, "_m_v") - assert_contains(code, "__slangpy_load") - # weights: dim-0 direct-bind (Tensor param), type used directly + direct assignment - assert_not_contains(code, "typealias _t_weights") - assert_contains(code, "Tensor weights;") - assert_trampoline_has(code, "weights = __calldata__.weights;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_mixed_vectorized_and_dim0_tensor_binding_flags(device_type: spy.DeviceType): - """Binding flags: vectorized tensor has dim>0, dim-0 tensor has direct_bind.""" - device = helpers.get_device(device_type) - src = """ -float dot_lookup(float3 v, Tensor weights) { - return v.x * weights[0] + v.y * weights[1] + v.z * weights[2]; -} -""" - vec_tensor = Tensor.from_numpy(device, np.ones((5, 3), dtype=np.float32)) - weight_tensor = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - func = helpers.create_function_from_module(device, "dot_lookup", src) - cd = func.debug_build_call_data(vec_tensor, weight_tensor) - bindings = cd.debug_only_bindings - v = bindings.args[0] - assert v.call_dimensionality == 1 - assert v.direct_bind is False - w = bindings.args[1] - assert w.call_dimensionality == 0 - assert w.direct_bind is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_mixed_vectorized_and_dim0_tensor(device_type: spy.DeviceType): - """Dispatch vectorized float3 + dim-0 Tensor and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -float dot_lookup(float3 v, Tensor weights) { - return v.x * weights[0] + v.y * weights[1] + v.z * weights[2]; -} -""" - func = helpers.create_function_from_module(device, "dot_lookup", src) - vecs = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32) - weights = np.array([10, 20, 30], dtype=np.float32) - result = func( - Tensor.from_numpy(device, vecs), - Tensor.from_numpy(device, weights), - ) - expected = np.array([10, 20, 30], dtype=np.float32) - np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) - - -# =========================================================================== -# Composite struct codegen tests — nested structs, vector/matrix/array fields -# =========================================================================== - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_nested_struct_codegen(device_type: spy.DeviceType): - """Nested struct: Outer{Inner inner, float scale} — all-scalar, direct-bind.""" - device = helpers.get_device(device_type) - src = """ -struct Inner { - float x; - float y; -}; -struct Outer { - Inner inner; - float scale; -}; -float compute(Outer o) { return (o.inner.x + o.inner.y) * o.scale; } -""" - code = generate_code( - device, - "compute", - src, - {"_type": "Outer", "inner": {"_type": "Inner", "x": 1.0, "y": 2.0}, "scale": 3.0}, - ) - # All-scalar nested struct at dim-0: direct-bind → type used directly - assert_not_contains(code, "typealias _t_o") - assert_contains(code, "Outer o;") - assert_not_contains(code, "__slangpy_load") - assert_not_contains(code, "struct _t_o") - assert_trampoline_has(code, "o = __calldata__.o;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_nested_struct_binding_flags(device_type: spy.DeviceType): - """Nested struct: all-scalar → direct_bind=True at every level.""" - device = helpers.get_device(device_type) - src = """ -struct Inner { - float x; - float y; -}; -struct Outer { - Inner inner; - float scale; -}; -float compute(Outer o) { return (o.inner.x + o.inner.y) * o.scale; } -""" - func = helpers.create_function_from_module(device, "compute", src) - cd = func.debug_build_call_data( - {"_type": "Outer", "inner": {"_type": "Inner", "x": 1.0, "y": 2.0}, "scale": 3.0} - ) - bindings = cd.debug_only_bindings - o = bindings.args[0] - assert o.direct_bind is True - assert o.children["inner"].direct_bind is True - assert o.children["inner"].children["x"].direct_bind is True - assert o.children["inner"].children["y"].direct_bind is True - assert o.children["scale"].direct_bind is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_nested_struct(device_type: spy.DeviceType): - """Dispatch nested struct and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -struct Inner { - float x; - float y; -}; -struct Outer { - Inner inner; - float scale; -}; -float compute(Outer o) { return (o.inner.x + o.inner.y) * o.scale; } -""" - func = helpers.create_function_from_module(device, "compute", src) - result = func({"_type": "Outer", "inner": {"_type": "Inner", "x": 3.0, "y": 7.0}, "scale": 2.0}) - assert abs(result - 20.0) < 1e-5 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_with_vector_fields_codegen(device_type: spy.DeviceType): - """Struct with vector fields: S{float3 pos, float scale} — all dim-0, direct-bind.""" - device = helpers.get_device(device_type) - src = """ -struct S { - float3 pos; - float scale; -}; -float3 apply(S s) { return s.pos * s.scale; } -""" - code = generate_code( - device, - "apply", - src, - {"_type": "S", "pos": spy.math.float3(1, 2, 3), "scale": 2.0}, - ) - # All-scalar struct with vector field at dim-0: direct-bind → type used directly - assert_not_contains(code, "typealias _t_s") - assert_contains(code, "S s;") - assert_not_contains(code, "__slangpy_load") - assert_trampoline_has(code, "s = __calldata__.s;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_with_vector_fields_binding_flags(device_type: spy.DeviceType): - """Struct with vector field — all children direct-bind.""" - device = helpers.get_device(device_type) - src = """ -struct S { - float3 pos; - float scale; -}; -float3 apply(S s) { return s.pos * s.scale; } -""" - func = helpers.create_function_from_module(device, "apply", src) - cd = func.debug_build_call_data({"_type": "S", "pos": spy.math.float3(1, 2, 3), "scale": 2.0}) - bindings = cd.debug_only_bindings - s = bindings.args[0] - assert s.direct_bind is True - assert s.children["pos"].direct_bind is True - assert s.children["scale"].direct_bind is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_struct_with_vector_fields(device_type: spy.DeviceType): - """Dispatch struct with vector field and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -struct S { - float3 pos; - float scale; -}; -float3 apply(S s) { return s.pos * s.scale; } -""" - func = helpers.create_function_from_module(device, "apply", src) - result = func({"_type": "S", "pos": spy.math.float3(1, 2, 3), "scale": 3.0}) - assert result.x == 3.0 - assert result.y == 6.0 - assert result.z == 9.0 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_with_matrix_field_codegen(device_type: spy.DeviceType): - """Struct with matrix field: S{float4x4 m, float scale} — all dim-0, direct-bind.""" - device = helpers.get_device(device_type) - src = """ -struct S { - float4x4 m; - float scale; -}; -float4x4 apply(S s) { return s.m * s.scale; } -""" - code = generate_code( - device, - "apply", - src, - {"_type": "S", "m": spy.math.float4x4.identity(), "scale": 2.0}, - ) - assert_not_contains(code, "typealias _t_s") - assert_contains(code, "S s;") - assert_not_contains(code, "__slangpy_load") - assert_trampoline_has(code, "s = __calldata__.s;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_struct_with_matrix_field(device_type: spy.DeviceType): - """Dispatch struct with matrix field and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -struct S { - float4x4 m; - float scale; -}; -float4x4 apply(S s) { return s.m * s.scale; } -""" - func = helpers.create_function_from_module(device, "apply", src) - result = func({"_type": "S", "m": spy.math.float4x4.identity(), "scale": 2.0}) - # Identity * 2 → diagonal is 2 - assert abs(result[0][0] - 2.0) < 1e-5 - assert abs(result[1][1] - 2.0) < 1e-5 - assert abs(result[0][1] - 0.0) < 1e-5 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_with_array_field_codegen(device_type: spy.DeviceType): - """Struct with fixed-size array field: Foo{int vals[4]} — all dim-0, direct-bind.""" - device = helpers.get_device(device_type) - src = """ -struct Foo { - int vals[4]; -}; -int sum_inner(Foo foo) { - int s = 0; - for (int i = 0; i < 4; i++) { - s += foo.vals[i]; - } - return s; -} -""" - code = generate_code( - device, - "sum_inner", - src, - {"_type": "Foo", "vals": [1, 2, 3, 4]}, - ) - assert_not_contains(code, "typealias _t_foo") - assert_contains(code, "Foo foo;") - assert_not_contains(code, "__slangpy_load") - assert_trampoline_has(code, "foo = __calldata__.foo;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_with_array_field_binding_flags(device_type: spy.DeviceType): - """Struct with array field: all direct_bind=True.""" - device = helpers.get_device(device_type) - src = """ -struct Foo { - int vals[4]; -}; -int sum_inner(Foo foo) { - int s = 0; - for (int i = 0; i < 4; i++) { - s += foo.vals[i]; - } - return s; -} -""" - func = helpers.create_function_from_module(device, "sum_inner", src) - cd = func.debug_build_call_data({"_type": "Foo", "vals": [1, 2, 3, 4]}) - bindings = cd.debug_only_bindings - foo = bindings.args[0] - assert foo.direct_bind is True - assert foo.children["vals"].direct_bind is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_struct_with_array_field(device_type: spy.DeviceType): - """Dispatch struct with array field and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -struct Foo { - int vals[4]; -}; -int sum_inner(Foo foo) { - int s = 0; - for (int i = 0; i < 4; i++) { - s += foo.vals[i]; - } - return s; -} -""" - func = helpers.create_function_from_module(device, "sum_inner", src) - result = func({"_type": "Foo", "vals": [10, 20, 30, 40]}) - assert result == 100 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_deeply_nested_struct_codegen(device_type: spy.DeviceType): - """3-level deep nesting: Top{Mid{Bot{float v}, int c}, float s} — all dim-0, direct-bind.""" - device = helpers.get_device(device_type) - src = """ -struct Bot { - float v; -}; -struct Mid { - Bot bot; - int c; -}; -struct Top { - Mid mid; - float s; -}; -float compute(Top t) { return t.mid.bot.v * float(t.mid.c) * t.s; } -""" - code = generate_code( - device, - "compute", - src, - { - "_type": "Top", - "mid": {"_type": "Mid", "bot": {"_type": "Bot", "v": 2.0}, "c": 3}, - "s": 4.0, - }, - ) - assert_not_contains(code, "typealias _t_t") - assert_contains(code, "Top t;") - assert_not_contains(code, "__slangpy_load") - assert_not_contains(code, "struct _t_t") - assert_trampoline_has(code, "t = __calldata__.t;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_deeply_nested_struct_binding_flags(device_type: spy.DeviceType): - """3-level deep: all direct_bind=True at every level.""" - device = helpers.get_device(device_type) - src = """ -struct Bot { - float v; -}; -struct Mid { - Bot bot; - int c; -}; -struct Top { - Mid mid; - float s; -}; -float compute(Top t) { return t.mid.bot.v * float(t.mid.c) * t.s; } -""" - func = helpers.create_function_from_module(device, "compute", src) - cd = func.debug_build_call_data( - { - "_type": "Top", - "mid": {"_type": "Mid", "bot": {"_type": "Bot", "v": 2.0}, "c": 3}, - "s": 4.0, - } - ) - bindings = cd.debug_only_bindings - t = bindings.args[0] - assert t.direct_bind is True - assert t.children["mid"].direct_bind is True - assert t.children["mid"].children["bot"].direct_bind is True - assert t.children["mid"].children["bot"].children["v"].direct_bind is True - assert t.children["mid"].children["c"].direct_bind is True - assert t.children["s"].direct_bind is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_deeply_nested_struct(device_type: spy.DeviceType): - """Dispatch 3-level nested struct and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -struct Bot { - float v; -}; -struct Mid { - Bot bot; - int c; -}; -struct Top { - Mid mid; - float s; -}; -float compute(Top t) { return t.mid.bot.v * float(t.mid.c) * t.s; } -""" - func = helpers.create_function_from_module(device, "compute", src) - result = func( - { - "_type": "Top", - "mid": {"_type": "Mid", "bot": {"_type": "Bot", "v": 2.0}, "c": 3}, - "s": 4.0, - } - ) - assert abs(result - 24.0) < 1e-5 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_nested_struct_with_tensor_child_codegen(device_type: spy.DeviceType): - """Nested struct where a leaf is a tensor: Outer{Inner{float x (tensor), float y (scalar)}, float s}. - - Outer and Inner are NOT direct-bind (Inner.x is vectorized). - Inner.y and s retain direct_bind=True inside the non-direct-bind parent. - """ - device = helpers.get_device(device_type) - src = """ -struct Inner { - float x; - float y; -}; -struct Outer { - Inner inner; - float s; -}; -float compute(Outer o) { return (o.inner.x + o.inner.y) * o.s; } -""" - tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - code = generate_code( - device, - "compute", - src, - { - "_type": "Outer", - "inner": {"_type": "Inner", "x": tensor_x, "y": 10.0}, - "s": 2.0, - }, - ) - # Outer and Inner are NOT direct-bind: inline structs generated - assert_contains(code, "struct _t_o") - assert_contains(code, "__slangpy_load") - assert_not_contains(code, "typealias _t_o = Outer;") - # Scalar children retain direct-bind: types used directly - assert_not_contains(code, "typealias _t_y") - assert_contains(code, "float y;") - assert_not_contains(code, "typealias _t_s") - assert_contains(code, "float s;") - # Direct assignment for scalar children within __slangpy_load - assert_contains(code, "value.y = y;") - # Tensor child uses standard path - assert_contains(code, "_m_x") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_nested_struct_with_tensor_child_binding_flags(device_type: spy.DeviceType): - """Nested struct with tensor: Outer not direct-bind, scalar children retain direct_bind.""" - device = helpers.get_device(device_type) - src = """ -struct Inner { - float x; - float y; -}; -struct Outer { - Inner inner; - float s; -}; -float compute(Outer o) { return (o.inner.x + o.inner.y) * o.s; } -""" - tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - func = helpers.create_function_from_module(device, "compute", src) - cd = func.debug_build_call_data( - { - "_type": "Outer", - "inner": {"_type": "Inner", "x": tensor_x, "y": 10.0}, - "s": 2.0, - } - ) - bindings = cd.debug_only_bindings - o = bindings.args[0] - assert o.direct_bind is False - assert o.children["inner"].direct_bind is False # has non-direct child - assert o.children["inner"].children["x"].direct_bind is False # tensor dim>0 - assert o.children["inner"].children["y"].direct_bind is True # scalar dim-0 - assert o.children["s"].direct_bind is True # scalar dim-0 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_nested_struct_with_tensor(device_type: spy.DeviceType): - """Dispatch nested struct with tensor leaf and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -struct Inner { - float x; - float y; -}; -struct Outer { - Inner inner; - float s; -}; -float compute(Outer o) { return (o.inner.x + o.inner.y) * o.s; } -""" - func = helpers.create_function_from_module(device, "compute", src) - tensor_x = Tensor.from_numpy(device, np.array([1, 2, 3], dtype=np.float32)) - result = func( - { - "_type": "Outer", - "inner": {"_type": "Inner", "x": tensor_x, "y": 10.0}, - "s": 2.0, - } - ) - expected = np.array([22, 24, 26], dtype=np.float32) - np.testing.assert_allclose(result.to_numpy().flatten(), expected, atol=1e-5) - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_with_struct_array_field_codegen(device_type: spy.DeviceType): - """Struct with array-of-structs field: Outer{Inner items[4]} — all dim-0, direct-bind.""" - device = helpers.get_device(device_type) - src = """ -struct Inner { - int x; -}; -struct Outer { - Inner items[4]; -}; -int sum_inner(Outer outer) { - int s = 0; - for (int i = 0; i < 4; i++) { - s += outer.items[i].x; - } - return s; -} -""" - code = generate_code( - device, - "sum_inner", - src, - { - "_type": "Outer", - "items": [ - {"_type": "Inner", "x": 10}, - {"_type": "Inner", "x": 20}, - {"_type": "Inner", "x": 30}, - {"_type": "Inner", "x": 40}, - ], - }, - ) - assert_not_contains(code, "typealias _t_outer") - assert_contains(code, "Outer outer;") - assert_not_contains(code, "__slangpy_load") - assert_trampoline_has(code, "outer = __calldata__.outer;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_struct_with_struct_array_field(device_type: spy.DeviceType): - """Dispatch struct with array-of-structs field and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -struct Inner { - int x; -}; -struct Outer { - Inner items[4]; -}; -int sum_inner(Outer outer) { - int s = 0; - for (int i = 0; i < 4; i++) { - s += outer.items[i].x; - } - return s; -} -""" - func = helpers.create_function_from_module(device, "sum_inner", src) - result = func( - { - "_type": "Outer", - "items": [ - {"_type": "Inner", "x": 10}, - {"_type": "Inner", "x": 20}, - {"_type": "Inner", "x": 30}, - {"_type": "Inner", "x": 40}, - ], - } - ) - assert result == 100 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_return_codegen(device_type: spy.DeviceType): - """Function returning a struct: _result uses RWValueRef wrapper, not direct-bind.""" - device = helpers.get_device(device_type) - src = """ -struct S { - int x; - int y; -}; -S make_struct(int a, int b) { return { a, b }; } -""" - code = generate_code(device, "make_struct", src, 4, 5) - # Scalar inputs are direct-bind, types used directly - assert_not_contains(code, "typealias _t_a", "typealias _t_b") - # _result is writable → NOT direct-bind → uses wrapper - assert_contains(code, "__slangpy_store") - assert_contains(code, "_m__result") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_return_binding_flags(device_type: spy.DeviceType): - """Struct return: _result binding is NOT direct-bind (writable).""" - device = helpers.get_device(device_type) - src = """ -struct S { - int x; - int y; -}; -S make_struct(int a, int b) { return { a, b }; } -""" - func = helpers.create_function_from_module(device, "make_struct", src) - cd = func.debug_build_call_data(4, 5) - bindings = cd.debug_only_bindings - result = bindings.kwargs["_result"] - assert result.direct_bind is False - # Inputs are direct-bind - assert bindings.args[0].direct_bind is True - assert bindings.args[1].direct_bind is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_struct_return(device_type: spy.DeviceType): - """Dispatch struct return and verify result is dict with correct values.""" - device = helpers.get_device(device_type) - src = """ -struct S { - int x; - int y; -}; -S make_struct(int a, int b) { return { a, b }; } -""" - func = helpers.create_function_from_module(device, "make_struct", src) - result = func(4, 5) - assert isinstance(result, dict) - assert result["x"] == 4 - assert result["y"] == 5 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_struct_with_vectorized_2d_tensor_child_codegen(device_type: spy.DeviceType): - """Struct with 2D tensor child vectorized to float3: struct NOT direct-bind. - - S{float3 v (2D tensor→float3), float s (scalar)}. - The tensor's outer dim becomes dispatch, struct generates inline __slangpy_load. - """ - device = helpers.get_device(device_type) - src = """ -struct S { - float3 v; - float s; -}; -float3 apply(S st) { return st.v * st.s; } -""" - tensor_v = Tensor.from_numpy(device, np.ones((5, 3), dtype=np.float32)) - code = generate_code( - device, - "apply", - src, - {"_type": "S", "v": tensor_v, "s": 2.0}, - ) - # Struct NOT direct-bind (tensor child is vectorized) - assert_contains(code, "struct _t_st") - assert_contains(code, "__slangpy_load") - assert_not_contains(code, "typealias _t_st = S;") - # Scalar child s retains direct-bind — type used directly, no alias - assert_not_contains(code, "typealias _t_s") - assert_contains(code, "float s;") - assert_contains(code, "value.s = s;") - # Tensor child v uses standard path - assert_contains(code, "_m_v") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_struct_with_vectorized_2d_tensor(device_type: spy.DeviceType): - """Dispatch struct with 2D tensor→float3 child and verify GPU result.""" - device = helpers.get_device(device_type) - src = """ -struct S { - float3 v; - float s; -}; -float3 apply(S st) { return st.v * st.s; } -""" - func = helpers.create_function_from_module(device, "apply", src) - data = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) - tensor_v = Tensor.from_numpy(device, data) - result = func({"_type": "S", "v": tensor_v, "s": 2.0}) - expected = data * 2.0 - np.testing.assert_allclose(result.to_numpy().reshape(expected.shape), expected, atol=1e-5) - - -# =========================================================================== -# Long type name heuristic — typealias emitted for names > MAX_INLINE_TYPE_LEN -# =========================================================================== - -# Struct name that is deliberately longer than MAX_INLINE_TYPE_LEN (60 chars). -# 70 chars: -_LONG_STRUCT_NAME = "MyVeryLongStructNameThatExceedsSixtyCharactersForTesting12345" -assert len(_LONG_STRUCT_NAME) > 60 - -_SHORT_STRUCT_NAME = "S" -assert len(_SHORT_STRUCT_NAME) <= 60 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_long_struct_name_gets_typealias(device_type: spy.DeviceType): - """A direct-bind struct with a name > 60 chars should emit a typealias.""" - device = helpers.get_device(device_type) - src = f""" -struct {_LONG_STRUCT_NAME} {{ - float x; - float y; -}}; -float sum({_LONG_STRUCT_NAME} s) {{ return s.x + s.y; }} -""" - code = generate_code( - device, - "sum", - src, - {"_type": _LONG_STRUCT_NAME, "x": 1.0, "y": 2.0}, - ) - # Long name → typealias _t_s emitted, CallData field declared as _t_s - assert_contains(code, f"typealias _t_s = {_LONG_STRUCT_NAME};") - # Typealias used in entry-point param or CallData field - assert ( - "_t_s s;" in code or "uniform _t_s s" in code - ), "Expected typealias usage (_t_s s; or uniform _t_s s) not found" - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_short_struct_name_inlined(device_type: spy.DeviceType): - """A direct-bind struct with a short name should NOT emit a typealias.""" - device = helpers.get_device(device_type) - src = f""" -struct {_SHORT_STRUCT_NAME} {{ - float x; - float y; -}}; -float sum({_SHORT_STRUCT_NAME} s) {{ return s.x + s.y; }} -""" - code = generate_code( - device, - "sum", - src, - {"_type": _SHORT_STRUCT_NAME, "x": 1.0, "y": 2.0}, - ) - # Short name → no typealias, raw type inlined - assert_not_contains(code, "typealias _t_s") - assert_contains(code, f"{_SHORT_STRUCT_NAME} s;") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_long_scalar_type_name_gets_typealias(device_type: spy.DeviceType): - """A non-direct-bind arg whose wrapper type name exceeds 60 chars gets a typealias.""" - device = helpers.get_device(device_type) - src = f""" -struct {_LONG_STRUCT_NAME} {{ - float x; - float y; -}}; -{_LONG_STRUCT_NAME} identity({_LONG_STRUCT_NAME} s) {{ return s; }} -""" - # Pass as a ValueRef so _result is writable → uses wrapper, and the wrapper - # type name for _result will include the long struct name. - code = generate_code( - device, - "identity", - src, - {"_type": _LONG_STRUCT_NAME, "x": 1.0, "y": 2.0}, - ) - # The _result binding uses RWValueRef which exceeds 60 chars - result_type = f"RWValueRef<{_LONG_STRUCT_NAME}>" - assert len(result_type) > 60, f"Expected >60 chars, got {len(result_type)}" - assert_contains(code, f"typealias _t__result = {result_type};") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_phase1_functional_long_struct_name(device_type: spy.DeviceType): - """End-to-end dispatch with a struct whose name exceeds 60 chars.""" - device = helpers.get_device(device_type) - src = f""" -struct {_LONG_STRUCT_NAME} {{ - float x; - float y; -}}; -float sum({_LONG_STRUCT_NAME} s) {{ return s.x + s.y; }} -""" - func = helpers.create_function_from_module(device, "sum", src) - result = func({"_type": _LONG_STRUCT_NAME, "x": 3.0, "y": 7.0}) - assert abs(result - 10.0) < 1e-5 - - -# =========================================================================== -# Phase 2 gating tests — assert CURRENT behaviour, will break as Phase 2 -# steps are implemented. See plan-simplifyKernelGen-phase2.prompt.md -# =========================================================================== - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_calldata_struct_absent_fast_path(device_type: spy.DeviceType): - """Fast path (use_entrypoint_args=True): no struct CallData emitted. Step 2.2 done.""" - device = helpers.get_device(device_type) - code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - assert_not_contains(code, "struct CallData") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_individual_uniform_params(device_type: spy.DeviceType): - """Fast path: individual uniform params instead of unified CallData. Step 2.2 done.""" - device = helpers.get_device(device_type) - code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - assert_contains(code, "uniform uint3 _thread_count") - assert_contains(code, "uniform int a") - assert_contains(code, "uniform int b") - assert_not_contains(code, "uniform CallData call_data") - assert_not_contains(code, "ParameterBlock call_data") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_thread_count_direct(device_type: spy.DeviceType): - """Fast path: _thread_count accessed directly, not via call_data prefix. Step 2.2 done.""" - device = helpers.get_device(device_type) - code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - assert_not_contains(code, "call_data._thread_count") - # Extract compute_main body and check _thread_count used directly - main_idx = code.index("void compute_main(") - main_body = code[main_idx:] - assert ">= _thread_count)" in main_body - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_trampoline_present_for_prim(device_type: spy.DeviceType): - """Prim-mode kernel has no _trampoline function after Step 2.3.""" - device = helpers.get_device(device_type) - code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - assert_not_contains(code, "void _trampoline(") - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_kernel_calls_trampoline(device_type: spy.DeviceType): - """Prim-mode compute_main inlines call sequence after Step 2.3.""" - device = helpers.get_device(device_type) - code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - # Extract compute_main body and check it no longer calls _trampoline. - main_idx = code.index("void compute_main(") - main_body = code[main_idx:] - assert "_trampoline(" not in main_body - assert "add(__tmp_a, __tmp_b);" in main_body - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_sv_group_id_absent_dim0(device_type: spy.DeviceType): - """Fast path dim-0: SV_GroupID not needed. Step 2.2 done.""" - device = helpers.get_device(device_type) - code = generate_code(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - assert_not_contains(code, "SV_GroupID") - - -# -- Phase 2 negative gate — must REMAIN passing after Phase 2 -- - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_gate_p2_wanghasharg_keeps_load(device_type: spy.DeviceType): - """Non-direct-bind WangHashArg still uses __slangpy_load after Phase 2.""" - device = helpers.get_device(device_type) - src = "uint3 rng(uint3 input) { return input; }" - code = generate_code(device, "rng", src, WangHashArg(3)) - assert_contains(code, "__slangpy_load") - - -# =========================================================================== -# Step 2.1 tests — fast vs fallback path determination -# =========================================================================== - - -def build_call_data( - device: spy.Device, func_name: str, module_source: str, *args: Any, **kwargs: Any -) -> Any: - """Build CallData and return the full CallData object.""" - func = helpers.create_function_from_module(device, func_name, module_source) - return func.debug_build_call_data(*args, **kwargs) - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_scalar_uses_entrypoint_args(device_type: spy.DeviceType): - """Simple scalar call has small inline-uniform size → use_entrypoint_args=True.""" - device = helpers.get_device(device_type) - cd = build_call_data(device, "add", "int add(int a, int b) { return a + b; }", 1, 2) - # Two ints (4+4) + RWValueRef for _result (descriptor, ~0 inline) + uint3 _thread_count (12) - # Should be well under any backend's threshold - assert cd.use_entrypoint_args is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_threshold_property_positive(device_type: spy.DeviceType): - """Device has a positive max_entry_point_uniform_size threshold.""" - device = helpers.get_device(device_type) - threshold = device.info.limits.max_entry_point_uniform_size - assert threshold > 0 - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_vector_uses_entrypoint_args(device_type: spy.DeviceType): - """float3 args are small enough for direct args.""" - device = helpers.get_device(device_type) - cd = build_call_data( - device, - "scale", - "float3 scale(float3 v, float s) { return v * s; }", - spy.math.float3(1, 2, 3), - 2.0, - ) - assert cd.use_entrypoint_args is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_struct_uses_entrypoint_args(device_type: spy.DeviceType): - """All-scalar struct dict has small inline-uniform size.""" - device = helpers.get_device(device_type) - src = """ -struct S { float x; float y; }; -float sum(S s) { return s.x + s.y; } -""" - cd = build_call_data(device, "sum", src, {"_type": "S", "x": 1.0, "y": 2.0}) - assert cd.use_entrypoint_args is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_tensor_uses_entrypoint_args(device_type: spy.DeviceType): - """Tensor args contribute descriptor-only (0 inline bytes) → direct args.""" - device = helpers.get_device(device_type) - tensor = Tensor.from_numpy(device, np.array([1.0, 2.0, 3.0], dtype=np.float32)) - cd = build_call_data( - device, - "sum_all", - "float sum_all(float x) { return x; }", - tensor, - ) - assert cd.use_entrypoint_args is True - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_many_float4x4_may_exceed_vulkan(device_type: spy.DeviceType): - """Many float4x4 params may exceed Vulkan's 128-byte threshold. - - 8 × float4x4 = 8 × 64 bytes = 512 bytes inline + 12 bytes _thread_count = 524 bytes. - This exceeds Vulkan (128) and D3D12 (256) but not CUDA (4096). - """ - device = helpers.get_device(device_type) - src = """ -float4x4 sum8(float4x4 a, float4x4 b, float4x4 c, float4x4 d, - float4x4 e, float4x4 f, float4x4 g, float4x4 h) { - return a + b + c + d + e + f + g + h; -} -""" - identity = spy.math.float4x4.identity() - cd = build_call_data( - device, - "sum8", - src, - identity, - identity, - identity, - identity, - identity, - identity, - identity, - identity, - ) - threshold = device.info.limits.max_entry_point_uniform_size - if threshold >= 524: - assert cd.use_entrypoint_args is True - else: - assert cd.use_entrypoint_args is False - - -@pytest.mark.parametrize("device_type", helpers.DEFAULT_DEVICE_TYPES) -def test_step21_wanghasharg_uses_entrypoint_args(device_type: spy.DeviceType): - """WangHashArg (non-direct-bind) still counts its inline-uniform size. - Its wrapper type has a small inline footprint, so use_entrypoint_args should be True. - """ - device = helpers.get_device(device_type) - cd = build_call_data( - device, - "rng", - "uint3 rng(uint3 input) { return input; }", - WangHashArg(3), - ) - assert cd.use_entrypoint_args is True - - -if __name__ == "__main__": - pytest.main([__file__, "-vs"])