Changes from 32 commits (36 total, all by jhelferty-nv)

c71e591  test using slangpy in pytorch sequence (Feb 5, 2026)
bc3e4af  loss (Feb 5, 2026)
ca43731  Activations (Feb 5, 2026)
132608a  interleaving tests (Feb 5, 2026)
6edd1b6  Fix PyTorch autograd compatibility by preventing spurious tensor vers… (Feb 5, 2026)
6531182  activations with min/max (Feb 5, 2026)
85bbe15  Strided slice with different ops (Feb 5, 2026)
cae7267  minor tidy (Feb 5, 2026)
f9b7910  more test coverage for copyback (Feb 9, 2026)
b590ae3  Fix tensor writeback logic for both scalar broadcast and tensor param… (Feb 9, 2026)
49eb4b9  Fix output copy-back for scalar broadcast functions (Feb 9, 2026)
d5bd496  Clean up test docstrings and comments (Feb 9, 2026)
719233b  Merge branch 'main' into add-pytorch-tests (Feb 9, 2026)
e71a92b  Add explicit gradient copy-back tests (Feb 10, 2026)
2cd616c  Refactor copy-back decision to cache time (Feb 11, 2026)
293d4f5  Rename CachedOffsets to CachedBindingInfo (Feb 11, 2026)
fe5ef48  Update comments to reflect copy-back decision logic (Feb 11, 2026)
777fd6a  Add clarifying comment for copy-back flags in CachedBindingInfo (Feb 11, 2026)
f10146b  Merge branch 'main' into add-pytorch-tests (Feb 11, 2026)
bca0314  Fix renamed variable after merge with main (Feb 11, 2026)
d8596fd  Fix gradient copy-back for raw torch.Tensor inputs on Vulkan/D3D12 (Feb 11, 2026)
0d31d28  Remove unused needs_grad_copyback field from CachedBindingInfo (Feb 11, 2026)
eeea630  Condense copy-back comments to reduce duplication (Feb 11, 2026)
9687116  Fix interop backward crash and broadcast stride zeroing (Feb 12, 2026)
95d7081  Fix CUDA interop memory leak: free mapped buffer before destroying ex… (Feb 12, 2026)
4b469c6  Add end-to-end workflow tests for slang-torch parity validation (Feb 12, 2026)
b22bf96  Clarify e2e test docstrings: these are workflow pattern tests, not sl… (Feb 12, 2026)
87a4314  Rename test_e2e_workflows.py to test_torch_autograd_workflows.py (Feb 12, 2026)
8f9060d  Merge branch 'main' into add-pytorch-tests (Feb 12, 2026)
d9dea92  Merge branch 'main' into add-pytorch-tests (Feb 17, 2026)
4f1213e  Use Slang uniform type name for interop copy-back decisions (Feb 17, 2026)
6930a6d  Use async memset for zeroed interop buffers (Feb 17, 2026)
8b744a5  Use correct CUDA stream for interop buffer memset (Feb 19, 2026)
01dedcd  Merge branch 'main' into add-pytorch-tests (Feb 19, 2026)
e7f7c16  Merge branch 'main' into add-pytorch-tests (Feb 23, 2026)
04e7a3d  Merge branch 'main' into add-pytorch-tests (Feb 25, 2026)
Files changed (7)

1,032 changes: 1,032 additions & 0 deletions slangpy/tests/slangpy_tests/test_pytorch_gradient_parity.py

Large diffs are not rendered by default.

601 changes: 601 additions & 0 deletions slangpy/tests/slangpy_tests/test_torch_autograd_workflows.py

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion slangpy/torchintegration/torchtensormarshall.py
@@ -113,7 +113,10 @@ def __init__(
full_dims = dims + len(slang_dtype.shape)

# Determine writability and tensor type
-writable = True  # Torch tensors are always potentially writable
+# Note: writable=True here signals that the tensor CAN be written to.
+# Actual copy-back decisions are made in C++ (ensure_binding_info_cached)
+# based on the Slang parameter's type and access mode.
+writable = True
has_derivatives = d_in is not None or d_out is not None

# Get the slang tensor type
13 changes: 13 additions & 0 deletions src/sgl/device/cuda_utils.cpp
@@ -48,6 +48,11 @@ void memset_device(void* dst, uint8_t value, size_t count)
SGL_CU_CHECK(cuMemsetD8(reinterpret_cast<CUdeviceptr>(dst), value, count));
}

+void memset_device_async(void* dst, uint8_t value, size_t count, CUstream stream)
+{
+SGL_CU_CHECK(cuMemsetD8Async(reinterpret_cast<CUdeviceptr>(dst), value, count, stream));
+}

CUexternalMemory import_external_memory(const Buffer* buffer)
{
SGL_CU_SCOPE(buffer->device());
@@ -317,6 +322,14 @@ ExternalMemory::ExternalMemory(const Buffer* buffer)

ExternalMemory::~ExternalMemory()
{
+// The mapped device pointer returned by cuExternalMemoryGetMappedBuffer must be
+// freed with cuMemFree before destroying the external memory, otherwise the CUDA
+// driver keeps the underlying allocation alive and we leak ~64KB+ per buffer.
+if (m_mapped_data) {
+SGL_CU_SCOPE(m_resource->device());
+SGL_CU_CHECK(cuMemFree(reinterpret_cast<CUdeviceptr>(m_mapped_data)));
+m_mapped_data = nullptr;
+}
destroy_external_memory(m_external_memory);
}

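For reference, a minimal sketch of the mapped-buffer lifecycle this destructor change enforces, written directly against the CUDA driver API. Error handling is elided, and use_and_release is an illustrative helper, not part of sgl:

#include <cuda.h>
#include <cstddef>

// Sketch: release order for an imported external buffer. Assumes `ext_mem`
// was created with cuImportExternalMemory and `size` is the buffer size.
void use_and_release(CUexternalMemory ext_mem, size_t size)
{
    CUDA_EXTERNAL_MEMORY_BUFFER_DESC desc = {};
    desc.offset = 0;
    desc.size = size;

    CUdeviceptr mapped = 0;
    cuExternalMemoryGetMappedBuffer(&mapped, ext_mem, &desc);

    // ... launch kernels that read/write `mapped` ...

    // Free the mapped pointer first: destroying the external memory alone
    // does not release it, and the driver keeps the allocation alive.
    cuMemFree(mapped);
    cuDestroyExternalMemory(ext_mem);
}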
1 change: 1 addition & 0 deletions src/sgl/device/cuda_utils.h
@@ -35,6 +35,7 @@ SGL_API void memcpy_host_to_device(void* dst, const void* src, size_t count);
SGL_API void memcpy_device_to_host(void* dst, const void* src, size_t count);

SGL_API void memset_device(void* dst, uint8_t value, size_t count);
+SGL_API void memset_device_async(void* dst, uint8_t value, size_t count, CUstream stream = 0);

SGL_API CUexternalMemory import_external_memory(const Buffer* buffer);
SGL_API void destroy_external_memory(CUexternalMemory ext_mem);
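A hedged usage sketch for the new async variant: zeroing an interop buffer on the stream the subsequent kernels run on, so the clear is stream-ordered rather than device-synchronizing. zero_interop_buffer and torch_stream are illustrative names; the body is the driver call that memset_device_async wraps:

#include <cuda.h>
#include <cstddef>

// Sketch: stream-ordered zeroing of a mapped interop buffer. Passing the
// stream PyTorch records on keeps the memset ordered before any kernel
// that consumes the buffer, with no host-side synchronization.
void zero_interop_buffer(void* mapped, size_t size, CUstream torch_stream)
{
    cuMemsetD8Async(reinterpret_cast<CUdeviceptr>(mapped), 0, size, torch_stream);
}

The stream = 0 default in the header falls back to the legacy default stream for callers that do not care about ordering.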
77 changes: 39 additions & 38 deletions src/slangpy_ext/utils/slangpytensor.cpp
@@ -182,9 +182,9 @@ NativeTensorMarshall::TensorFieldOffsets NativeTensorMarshall::extract_tensor_fi
return offsets;
}

-NativeTensorMarshall::CachedOffsets NativeTensorMarshall::extract_offsets(ShaderCursor field)
+NativeTensorMarshall::CachedBindingInfo NativeTensorMarshall::extract_binding_info(ShaderCursor field)
{
-NativeTensorMarshall::CachedOffsets offsets;
+NativeTensorMarshall::CachedBindingInfo offsets;

std::string_view type_name = field.slang_type_layout()->getName();
bool is_diff_tensor_view = type_name.find("DiffTensorView") != std::string_view::npos;
@@ -334,11 +334,11 @@ Shape NativeTensorMarshall::get_shape(nb::object data) const
return buffer->shape();
}

-void NativeTensorMarshall::ensure_offsets_cached(ShaderCursor cursor, NativeBoundVariableRuntime* binding) const
+void NativeTensorMarshall::ensure_binding_info_cached(ShaderCursor cursor, NativeBoundVariableRuntime* binding) const
{
-if (!m_cached_offsets.primal.is_valid) {
+if (!m_cached_binding_info.primal.is_valid) {
ShaderCursor field = cursor[binding->variable_name()];
-m_cached_offsets = extract_offsets(field);
+m_cached_binding_info = extract_binding_info(field);
}
}

@@ -354,14 +354,14 @@ void NativeTensorMarshall::write_native_tensor(
const ref<NativeTensor>& grad_in = primal_tensor->grad_in();
const ref<NativeTensor>& grad_out = primal_tensor->grad_out();

-if (!m_cached_offsets.has_grad_fields) {
+if (!m_cached_binding_info.has_grad_fields) {
// Flat structure - write directly to primal offsets
write_native_tensor_fields(
context,
binding,
shader_object,
base_address,
-m_cached_offsets.primal,
+m_cached_binding_info.primal,
primal_tensor,
read_back
);
@@ -372,32 +372,32 @@
binding,
shader_object,
base_address,
-m_cached_offsets.primal,
+m_cached_binding_info.primal,
primal_tensor,
read_back
);

-if (m_d_in && m_cached_offsets.grad_in.is_valid) {
+if (m_d_in && m_cached_binding_info.grad_in.is_valid) {
SGL_CHECK(grad_in, "Missing required input gradients");
write_native_tensor_fields(
context,
binding,
shader_object,
base_address,
-m_cached_offsets.grad_in,
+m_cached_binding_info.grad_in,
grad_in.get(),
read_back
);
}

-if (m_d_out && m_cached_offsets.grad_out.is_valid) {
+if (m_d_out && m_cached_binding_info.grad_out.is_valid) {
SGL_CHECK(grad_out, "Missing required output gradients");
write_native_tensor_fields(
context,
binding,
shader_object,
base_address,
-m_cached_offsets.grad_out,
+m_cached_binding_info.grad_out,
grad_out.get(),
read_back
);
@@ -421,35 +421,35 @@
) const
{
// Initialize cached offsets on first call
-ensure_offsets_cached(cursor, binding);
+ensure_binding_info_cached(cursor, binding);

#if 0
// Validate offsets on future calls
-if (m_cached_offsets.primal.is_valid) {
-CachedOffsets offsets = extract_offsets(cursor[binding->variable_name()]);
+if (m_cached_binding_info.primal.is_valid) {
+CachedBindingInfo offsets = extract_binding_info(cursor[binding->variable_name()]);
SGL_CHECK(
-offsets.primal.data == m_cached_offsets.primal.data &&
-offsets.primal.shape == m_cached_offsets.primal.shape &&
-offsets.primal.strides == m_cached_offsets.primal.strides &&
-offsets.primal.offset == m_cached_offsets.primal.offset,
+offsets.primal.data == m_cached_binding_info.primal.data &&
+offsets.primal.shape == m_cached_binding_info.primal.shape &&
+offsets.primal.strides == m_cached_binding_info.primal.strides &&
+offsets.primal.offset == m_cached_binding_info.primal.offset,
"Cached primal tensor offsets do not match current shader cursor offsets"
);
if (offsets.grad_in.is_valid) {
SGL_CHECK(
-offsets.grad_in.data == m_cached_offsets.grad_in.data &&
-offsets.grad_in.shape == m_cached_offsets.grad_in.shape &&
-offsets.grad_in.strides == m_cached_offsets.grad_in.strides &&
-offsets.grad_in.offset == m_cached_offsets.grad_in.offset,
+offsets.grad_in.data == m_cached_binding_info.grad_in.data &&
+offsets.grad_in.shape == m_cached_binding_info.grad_in.shape &&
+offsets.grad_in.strides == m_cached_binding_info.grad_in.strides &&
+offsets.grad_in.offset == m_cached_binding_info.grad_in.offset,
"Cached grad_in tensor offsets do not match current shader cursor offsets"
);
}
if (offsets.grad_out.is_valid) {

SGL_CHECK(
-offsets.grad_out.data == m_cached_offsets.grad_out.data &&
-offsets.grad_out.shape == m_cached_offsets.grad_out.shape &&
-offsets.grad_out.strides == m_cached_offsets.grad_out.strides &&
-offsets.grad_out.offset == m_cached_offsets.grad_out.offset,
+offsets.grad_out.data == m_cached_binding_info.grad_out.data &&
+offsets.grad_out.shape == m_cached_binding_info.grad_out.shape &&
+offsets.grad_out.strides == m_cached_binding_info.grad_out.strides &&
+offsets.grad_out.offset == m_cached_binding_info.grad_out.offset,
"Cached grad_out tensor offsets do not match current shader cursor offsets"
);
}
@@ -460,7 +460,8 @@ void NativeTensorMarshall::write_shader_cursor_pre_dispatch(
NativeTensor* primal;
if (nb::try_cast(value, primal)) {
ShaderObject* shader_object = cursor.shader_object();
-void* base_address = shader_object->reserve_data(m_cached_offsets.field_offset, m_cached_offsets.field_size);
+void* base_address
+= shader_object->reserve_data(m_cached_binding_info.field_offset, m_cached_binding_info.field_size);

// Write the differentiated tensor structure
write_native_tensor(context, binding, shader_object, base_address, primal, read_back);
@@ -494,7 +495,7 @@ void NativeTensorMarshall::write_tensor_fields_from_buffer(
if (offsets.data.binding_range_index == offsets.shape.binding_range_index) {
write_value_helper(
base_address,
-offsets.data.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
+offsets.data.uniform_offset - m_cached_binding_info.field_offset.uniform_offset,
buffer->device_address()
);
} else {
@@ -503,21 +504,21 @@

write_strided_array_helper(
base_address,
-offsets.shape.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
+offsets.shape.uniform_offset - m_cached_binding_info.field_offset.uniform_offset,
shape,
offsets.array_stride
);

write_strided_array_helper(
base_address,
-offsets.strides.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
+offsets.strides.uniform_offset - m_cached_binding_info.field_offset.uniform_offset,
strides,
offsets.array_stride
);

write_value_helper(
base_address,
-offsets.offset.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
+offsets.offset.uniform_offset - m_cached_binding_info.field_offset.uniform_offset,
offset
);

@@ -527,7 +528,7 @@
if (offsets.element_byte_stride.is_valid()) {
write_value_helper(
base_address,
-offsets.element_byte_stride.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
+offsets.element_byte_stride.uniform_offset - m_cached_binding_info.field_offset.uniform_offset,
static_cast<uint32_t>(buffer->desc().struct_size)
);
}
@@ -549,28 +550,28 @@
DeviceAddress address = reinterpret_cast<DeviceAddress>(data_ptr);
write_value_helper(
base_address,
-offsets.data.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
+offsets.data.uniform_offset - m_cached_binding_info.field_offset.uniform_offset,
address
);

// Write shape and strides using the same mechanism as write_tensor_fields_from_buffer
write_strided_array_helper(
base_address,
-offsets.shape.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
+offsets.shape.uniform_offset - m_cached_binding_info.field_offset.uniform_offset,
shape,
offsets.array_stride
);

write_strided_array_helper(
base_address,
-offsets.strides.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
+offsets.strides.uniform_offset - m_cached_binding_info.field_offset.uniform_offset,
strides,
offsets.array_stride
);

write_value_helper(
base_address,
-offsets.offset.uniform_offset - m_cached_offsets.field_offset.uniform_offset,
+offsets.offset.uniform_offset - m_cached_binding_info.field_offset.uniform_offset,
offset
);

@@ -612,7 +613,7 @@ void NativeTensorMarshall::write_native_tensor_fields(
tvd.sizes[i] = static_cast<uint32_t>(shape[i]);
}
tvd.dimensionCount = static_cast<uint32_t>(ndim);
-shader_object->set_data(m_cached_offsets.field_offset, &tvd, sizeof(TensorViewData));
+shader_object->set_data(m_cached_binding_info.field_offset, &tvd, sizeof(TensorViewData));
return;
}

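The recurring pattern in this file: binding info is extracted from reflection once, on the first dispatch, and every later field write is rebased against the base address reserved for the whole tensor struct. A minimal sketch of that idiom with simplified stand-in types (FieldOffsets, Marshall, and the literal offsets are illustrative, not the sgl API):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Illustrative stand-ins for the reflection and shader-object types.
struct FieldOffsets {
    bool is_valid = false;
    size_t field_offset = 0; // base of the whole tensor struct in uniform data
    size_t data_offset = 0;  // absolute offset of the `data` field
};

struct Marshall {
    mutable FieldOffsets cached;

    // First dispatch: walk reflection once and cache the result; subsequent
    // dispatches skip reflection entirely (mirrors ensure_binding_info_cached).
    void ensure_cached() const
    {
        if (!cached.is_valid)
            cached = FieldOffsets{true, /*field_offset=*/64, /*data_offset=*/72};
    }

    // Writes land relative to the reserved base, so the absolute reflection
    // offset is rebased by subtracting the struct's own field_offset.
    void write_data(uint8_t* base_address, uint64_t device_address) const
    {
        size_t rel = cached.data_offset - cached.field_offset;
        std::memcpy(base_address + rel, &device_address, sizeof(device_address));
    }
};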
23 changes: 15 additions & 8 deletions src/slangpy_ext/utils/slangpytensor.h
@@ -161,24 +161,31 @@ class NativeTensorMarshall : public NativeMarshall {
bool is_tensorview = false;
};

-/// Cached offsets for all tensor variants (primal, grad_in, grad_out)
-/// Public so NativeTorchTensorMarshall can reuse them
-struct CachedOffsets {
+/// Cached binding info for all tensor variants (primal, grad_in, grad_out)
+/// Contains shader offsets plus copy-back decision flags.
+/// Public so NativeTorchTensorMarshall can reuse this structure.
+struct CachedBindingInfo {
TensorFieldOffsets primal; // Offsets for primal tensor fields
TensorFieldOffsets grad_in; // Offsets for gradient input fields (if present)
TensorFieldOffsets grad_out; // Offsets for gradient output fields (if present)
bool has_grad_fields = false; // Whether tensor uses _primal wrapper (differentiated mode)
ShaderOffset field_offset; // Base offset of the entire field structure
uint32_t field_size = 0; // Total size of the field in uniform data

+// Whether to copy interop buffers back to torch tensors after dispatch.
+// Only used by NativeTorchTensorMarshall; computed in ensure_binding_info_cached()
+// from the Slang uniform type name (Tensor/WTensor/RWTensor/DiffTensor/etc.).
+bool needs_primal_copyback = false;
+bool needs_grad_copyback = false;
};

/// Extract TensorFieldOffsets from a ShaderCursor pointing to a tensor structure
/// Public so NativeTorchTensorMarshall can reuse it
static TensorFieldOffsets extract_tensor_field_offsets(ShaderCursor tensor_cursor);

-/// Extract all cached offsets (primal, grad_in, grad_out) from a field cursor
+/// Extract all cached binding info (primal, grad_in, grad_out) from a field cursor
/// Public so NativeTorchTensorMarshall can reuse it
-static CachedOffsets extract_offsets(ShaderCursor cursor);
+static CachedBindingInfo extract_binding_info(ShaderCursor cursor);

private:
int m_dims;
Expand All @@ -187,11 +194,11 @@ class NativeTensorMarshall : public NativeMarshall {
ref<TypeLayoutReflection> m_element_layout;
ref<NativeTensorMarshall> m_d_in;
ref<NativeTensorMarshall> m_d_out;
-mutable CachedOffsets m_cached_offsets;
+mutable CachedBindingInfo m_cached_binding_info;

-/// Initialize cached offsets if not already done
+/// Initialize cached binding info if not already done
/// This method is called on the first dispatch to cache reflection data for subsequent calls
-void ensure_offsets_cached(ShaderCursor cursor, NativeBoundVariableRuntime* binding) const;
+void ensure_binding_info_cached(ShaderCursor cursor, NativeBoundVariableRuntime* binding) const;

//
// High-Level Write Methods
//
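Per the comment above (and the matching note in torchtensormarshall.py), the copy-back flags are computed once at cache time from the Slang uniform type name. A hedged sketch of what such a classification could look like; the name set comes from the comment, but the matching rules below are an assumption, not the shipped logic:

#include <string_view>

// Illustrative only: derive copy-back flags from the Slang type name.
struct CopybackFlags {
    bool primal = false; // copy interop buffer back into the torch tensor
    bool grad = false;   // copy gradients back after a backward dispatch
};

inline CopybackFlags flags_from_type_name(std::string_view name)
{
    CopybackFlags flags;
    // Writable tensor types need their primal data copied back.
    // (find("WTensor") also matches "RWTensor".)
    if (name.find("WTensor") != std::string_view::npos)
        flags.primal = true;
    // Differentiable types carry gradient buffers that may need copy-back.
    if (name.find("DiffTensor") != std::string_view::npos)
        flags.grad = true;
    return flags;
}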