Conversation

@SpenserCai SpenserCai commented Dec 9, 2025

Summary

This PR adds bilinear interpolation support to Candle, implementing upsample_bilinear2d and upsample_bilinear2d_with_scale methods for the Tensor type.

Motivation

Many modern neural networks require bilinear interpolation for vision tasks, including:

  • Vision transformers (e.g., HunyuanOCR, CLIP)
  • Semantic segmentation models
  • Super-resolution networks
  • Feature pyramid networks

Currently, Candle only supports nearest-neighbor interpolation (upsample_nearest1d, upsample_nearest2d). This PR fills that gap with a PyTorch-compatible bilinear interpolation implementation.

Implementation

API

Two new public methods on Tensor:

// Specify exact output size
pub fn upsample_bilinear2d(
    &self,
    target_h: usize,
    target_w: usize,
    align_corners: bool,
) -> Result<Self>

// Specify scale factors
pub fn upsample_bilinear2d_with_scale(
    &self,
    scale_h: f64,
    scale_w: f64,
    align_corners: bool,
) -> Result<Self>
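As a rough illustration of the semantics these two entry points expose, here is a self-contained reference sketch in plain Rust (single H×W plane, row-major, f32-only; the function name and structure are illustrative and not the PR's actual code), following PyTorch's coordinate mapping for both `align_corners` modes:

```rust
/// Reference bilinear upsampling for one HxW plane stored row-major in
/// `src`. Illustrative only; the real implementation dispatches per
/// backend and per dtype.
fn upsample_bilinear2d_ref(
    src: &[f32],
    (h, w): (usize, usize),
    (out_h, out_w): (usize, usize),
    align_corners: bool,
) -> Vec<f32> {
    // PyTorch-style scale: (in-1)/(out-1) when align_corners,
    // otherwise in/out (size mode).
    let scale = |i: usize, o: usize| -> f64 {
        if align_corners {
            if o > 1 { (i - 1) as f64 / (o - 1) as f64 } else { 0.0 }
        } else {
            i as f64 / o as f64
        }
    };
    // Map a destination index to a fractional source coordinate.
    let src_coord = |s: f64, d: usize| -> f64 {
        if align_corners {
            s * d as f64
        } else {
            // Half-pixel centers, clamped at 0 like PyTorch.
            (s * (d as f64 + 0.5) - 0.5).max(0.0)
        }
    };
    let (sh, sw) = (scale(h, out_h), scale(w, out_w));
    let mut out = vec![0f32; out_h * out_w];
    for oy in 0..out_h {
        let fy = src_coord(sh, oy);
        let (y0, dy) = (fy.floor() as usize, fy.fract());
        let y1 = (y0 + 1).min(h - 1);
        for ox in 0..out_w {
            let fx = src_coord(sw, ox);
            let (x0, dx) = (fx.floor() as usize, fx.fract());
            let x1 = (x0 + 1).min(w - 1);
            // Weighted sum of the four neighboring pixels.
            let v = src[y0 * w + x0] as f64 * (1.0 - dy) * (1.0 - dx)
                + src[y0 * w + x1] as f64 * (1.0 - dy) * dx
                + src[y1 * w + x0] as f64 * dy * (1.0 - dx)
                + src[y1 * w + x1] as f64 * dy * dx;
            out[oy * out_w + ox] = v as f32;
        }
    }
    out
}
```

For example, with `align_corners = true`, upsampling the 2x2 plane `[1, 2, 3, 4]` to 3x3 preserves the corners and interpolates linearly in between.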

Features

  • ✅ CPU Backend - Full implementation with optimized performance

    • Pre-computed indices for better cache locality
    • Support for all data types: F32, F64, F16, BF16, U8, U32
    • Contiguous output allocation
    • Direct dtype operations to avoid unnecessary conversions
  • ✅ CUDA Backend - Complete GPU acceleration support

    • Thread-per-pixel parallel processing
    • Coalesced memory access patterns
    • Double precision intermediates for numerical accuracy
    • Support for F32, F64, F16, BF16, U8, U32
    • Proper handling of half-precision types (__half, __nv_bfloat16)
    • Architecture-specific optimizations (SM 5.3+ for FP16, SM 8.0+ for BF16)
  • ✅ Metal Backend - Apple Silicon optimized implementation

    • Compute shader with 256 threads per group
    • Float precision intermediates (sufficient for most cases)
    • Support for F32, F16, BF16 (when available), U8, U32
    • Unified memory architecture benefits
    • Optimized for Apple M-series chips
  • ✅ PyTorch Compatibility - Exact numerical equivalence

    • Matches torch.nn.functional.interpolate(mode='bilinear') behavior
    • Correct align_corners parameter semantics
    • Proper scale calculation following PyTorch's area_pixel_compute_scale logic
    • Distinction between size mode and scale_factor mode
    • Maximum error < 1e-5 for F32 (validated against PyTorch 2.10.0)
  • ✅ Comprehensive Testing

    • 12 unit tests covering various scenarios
    • Cross-backend validation (CPU/GPU/Metal)
    • PyTorch comparison tests with numerical validation
    • Edge case handling (identity, single pixel, non-square dimensions)
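The scale-handling bullets above (area_pixel_compute_scale semantics, size mode vs. scale_factor mode) can be sketched as follows. This is a hedged reconstruction of PyTorch's logic with illustrative names, not the PR's actual code:

```rust
/// Sketch of PyTorch's area_pixel_compute_scale: in scale_factor mode
/// the reciprocal of the user-provided scale is used directly; in size
/// mode the scale is recomputed from the input/output sizes.
fn area_pixel_compute_scale(
    in_size: usize,
    out_size: usize,
    align_corners: bool,
    scale_factor: Option<f64>,
) -> f64 {
    if align_corners {
        if out_size > 1 {
            (in_size - 1) as f64 / (out_size - 1) as f64
        } else {
            0.0
        }
    } else {
        match scale_factor {
            Some(s) if s > 0.0 => 1.0 / s,         // scale_factor mode
            _ => in_size as f64 / out_size as f64, // size mode
        }
    }
}

/// Destination index -> fractional source coordinate.
fn source_index(scale: f64, dst: usize, align_corners: bool) -> f64 {
    if align_corners {
        scale * dst as f64
    } else {
        // Half-pixel convention, clamped at 0 as in PyTorch.
        (scale * (dst as f64 + 0.5) - 0.5).max(0.0)
    }
}
```

This is also why the distinction between size mode and scale_factor mode matters: `1.0 / scale` is not always bit-identical to `in_size / out_size` in floating point, so the two modes can produce slightly different coordinates for the same nominal ratio.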

Files Modified

Core Implementation:

  • candle-core/src/op.rs - Added UpsampleBilinear2D operation
  • candle-core/src/tensor.rs - Added public API methods with documentation
  • candle-core/src/storage.rs - Added storage dispatch layer
  • candle-core/src/backend.rs - Added backend trait method

Backend Implementations:

  • candle-core/src/cpu_backend/mod.rs - CPU implementation with all dtype support
  • candle-kernels/src/conv.cu - CUDA kernel implementation
  • candle-metal-kernels/src/metal_src/conv.metal - Metal shader implementation
  • candle-metal-kernels/src/lib.rs - Metal helper functions
  • candle-metal-kernels/src/kernels/convolution.rs - Metal kernel registration

Testing:

  • candle-core/tests/interpolate_tests.rs - Comprehensive test suite (new file)
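For the PyTorch comparison tests, the usual pattern is to hard-code expected tensors exported from torch.nn.functional.interpolate and compare elementwise within a tolerance. A minimal helper of the kind such a suite needs (illustrative; not necessarily the helper the PR uses) could be:

```rust
/// Assert that two flat f32 buffers agree elementwise within `tol`,
/// reporting the first offending index on failure.
fn assert_close(got: &[f32], want: &[f32], tol: f32) {
    assert_eq!(got.len(), want.len(), "length mismatch");
    for (i, (g, w)) in got.iter().zip(want.iter()).enumerate() {
        assert!(
            (g - w).abs() <= tol,
            "index {i}: got {g}, want {w} (tol {tol})"
        );
    }
}
```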

Test Results

Metal

cargo test --test interpolate_tests --features metal

running 24 tests
test bilinear_fractional_scale_cpu ... ok
test bilinear_asymmetric_scale_cpu ... ok
test bilinear_batch_cpu ... ok
test bilinear_align_corners_true_cpu ... ok
test bilinear_downscale_cpu ... ok
test bilinear_identity_cpu ... ok
test bilinear_align_corners_cpu ... ok
test bilinear_non_square_cpu ... ok
test bilinear_pytorch_compat_cpu ... ok
test bilinear_scale_factor_cpu ... ok
test bilinear_single_pixel_cpu ... ok
test bilinear_upscale_2x_cpu ... ok
test bilinear_scale_factor_metal ... ok
test bilinear_batch_metal ... ok
test bilinear_fractional_scale_metal ... ok
test bilinear_non_square_metal ... ok
test bilinear_asymmetric_scale_metal ... ok
test bilinear_downscale_metal ... ok
test bilinear_align_corners_true_metal ... ok
test bilinear_pytorch_compat_metal ... ok
test bilinear_upscale_2x_metal ... ok
test bilinear_single_pixel_metal ... ok
test bilinear_identity_metal ... ok
test bilinear_align_corners_metal ... ok

test result: ok. 24 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.13s

CUDA

running 24 tests
test bilinear_align_corners_true_cpu ... ok
test bilinear_align_corners_cpu ... ok
test bilinear_asymmetric_scale_cpu ... ok
test bilinear_batch_cpu ... ok
test bilinear_downscale_cpu ... ok
test bilinear_fractional_scale_cpu ... ok
test bilinear_identity_cpu ... ok
test bilinear_non_square_cpu ... ok
test bilinear_pytorch_compat_cpu ... ok
test bilinear_scale_factor_cpu ... ok
test bilinear_single_pixel_cpu ... ok
test bilinear_upscale_2x_cpu ... ok
test bilinear_single_pixel_gpu ... ok
test bilinear_asymmetric_scale_gpu ... ok
test bilinear_pytorch_compat_gpu ... ok
test bilinear_batch_gpu ... ok
test bilinear_downscale_gpu ... ok
test bilinear_align_corners_true_gpu ... ok
test bilinear_upscale_2x_gpu ... ok
test bilinear_non_square_gpu ... ok
test bilinear_scale_factor_gpu ... ok
test bilinear_fractional_scale_gpu ... ok
test bilinear_identity_gpu ... ok
test bilinear_align_corners_gpu ... ok

test result: ok. 24 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.16s

Breaking Changes

✅ None - this is a new feature addition.


SpenserCai commented Dec 9, 2025

@ivarflakstad Requesting a review, thank you so much.

@ivarflakstad (Member) left a comment

Great! We definitely need this and it is on my todo list so thank you :)

I'll do a detailed review at some later point, but for now I'll only give feedback on the tests. Overall I don't think they are quite strict enough with respect to the functionality they are checking.

Maybe you could get some expected values from pytorch and assert that they match the results from candle?

For the tests that specifically only verify correct output dims I think they could be merged into a single test with several cases.

Does that sound reasonable to you?


SpenserCai commented Dec 12, 2025

Great! We definitely need this and it is on my todo list so thank you :)

I'll do a detailed review at some later point, but for now I'll only give feedback on the tests. Overall I don't think they are quite strict enough with respect to the functionality they are checking.

Maybe you could get some expected values from pytorch and assert that they match the results from candle?

For the tests that specifically only verify correct output dims I think they could be merged into a single test with several cases.

Does that sound reasonable to you?

Thank you for your reply; that sounds reasonable to me. Before writing this PR, I manually implemented HunyuanOCR, and this PR is decoupled from that implementation. I ran end-to-end tests on my HunyuanOCR implementation, although it only used the CPU backend; the CUDA and Metal support is what this PR adds. My original plan was to migrate the general algorithms used in HunyuanOCR into Candle, then refactor my Rust implementation of HunyuanOCR on top of the new Candle, and finally open-source the refactored implementation or submit it as a PR to candle-transformers. That would be the cleanest path.
I completely agree with you that, in any case, independent testing is necessary to ensure rigor.
Thanks again.



@ivarflakstad I have expanded the test cases again by adding more PyTorch numerical comparison verification, and tested them on both the metal and CUDA backends. The results are as follows.

File: candle/candle-core/tests/bilinear_tests.rs

Metal

     Running tests/bilinear_tests.rs (/Users/spensercai/Dev/candle_dev/candle/target/debug/deps/bilinear_tests-0a37c0f4e1ce3b04)

running 24 tests
test bilinear_output_dimensions_cpu ... ok
test bilinear_pytorch_multi_channel_cpu ... ok
test bilinear_align_corners_difference_cpu ... ok
test bilinear_pytorch_2x_upscale_cpu ... ok
test bilinear_pytorch_align_corners_true_cpu ... ok
test bilinear_pytorch_downscale_cpu ... ok
test bilinear_identity_cpu ... ok
test bilinear_pytorch_non_square_exact_cpu ... ok
test bilinear_pytorch_scale_factor_cpu ... ok
test bilinear_pytorch_tiny_1x1_to_3x3_cpu ... ok
test bilinear_pytorch_tiny_1x2_to_3x6_cpu ... ok
test bilinear_pytorch_large_64x64_to_128x128_cpu ... ok
test bilinear_identity_metal ... ok
test bilinear_output_dimensions_metal ... ok
test bilinear_pytorch_non_square_exact_metal ... ok
test bilinear_pytorch_downscale_metal ... ok
test bilinear_pytorch_large_64x64_to_128x128_metal ... ok
test bilinear_pytorch_tiny_1x1_to_3x3_metal ... ok
test bilinear_pytorch_align_corners_true_metal ... ok
test bilinear_pytorch_tiny_1x2_to_3x6_metal ... ok
test bilinear_pytorch_2x_upscale_metal ... ok
test bilinear_pytorch_multi_channel_metal ... ok
test bilinear_align_corners_difference_metal ... ok
test bilinear_pytorch_scale_factor_metal ... ok

test result: ok. 24 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.06s

CUDA

    Finished `test` profile [unoptimized + debuginfo] target(s) in 0.45s
     Running tests/bilinear_tests.rs (target/debug/deps/bilinear_tests-72e19a9c1fdd7dc5)

running 24 tests
test bilinear_align_corners_difference_cpu ... ok
test bilinear_identity_cpu ... ok
test bilinear_pytorch_2x_upscale_cpu ... ok
test bilinear_output_dimensions_cpu ... ok
test bilinear_pytorch_align_corners_true_cpu ... ok
test bilinear_pytorch_downscale_cpu ... ok
test bilinear_pytorch_multi_channel_cpu ... ok
test bilinear_pytorch_non_square_exact_cpu ... ok
test bilinear_pytorch_scale_factor_cpu ... ok
test bilinear_pytorch_tiny_1x1_to_3x3_cpu ... ok
test bilinear_pytorch_tiny_1x2_to_3x6_cpu ... ok
test bilinear_pytorch_large_64x64_to_128x128_cpu ... ok
test bilinear_pytorch_multi_channel_gpu ... ok
test bilinear_pytorch_large_64x64_to_128x128_gpu ... ok
test bilinear_output_dimensions_gpu ... ok
test bilinear_pytorch_tiny_1x2_to_3x6_gpu ... ok
test bilinear_pytorch_downscale_gpu ... ok
test bilinear_pytorch_align_corners_true_gpu ... ok
test bilinear_align_corners_difference_gpu ... ok
test bilinear_identity_gpu ... ok
test bilinear_pytorch_2x_upscale_gpu ... ok
test bilinear_pytorch_scale_factor_gpu ... ok
test bilinear_pytorch_non_square_exact_gpu ... ok
test bilinear_pytorch_tiny_1x1_to_3x3_gpu ... ok

test result: ok. 24 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.41s
