- 
                Notifications
    
You must be signed in to change notification settings  - Fork 328
 
feat: adds planning time validation of udf function signature #5470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Greptile Overview
Greptile Summary
Adds planning-time type validation for rowwise Python UDFs by extracting type hints from function signatures and validating them against actual expression types in Rust.
Key changes:
- Extracts input parameter types from function signatures using 
inspect.signature()andget_type_hints() - Passes type information through the Python/Rust boundary to the planning phase
 - Validates expected vs actual types in 
PyScalarFn::to_field(), producing clear error messages like "Expects input to 'func_name' to be Float64, but received Int64" - Treats 
DataType::Pythonas a wildcard (used when no type hint provided or type isAny) - Applies to both 
@daft.funcand@daft.clsdecorated functions 
Implementation:
- Python side collects dtypes per parameter in 
_get_input_dtypes()and tracks them in__call__() - Rust side validates during logical plan construction in conditional compilation block (only when 
pythonfeature enabled) - Proto schema extended to serialize/deserialize input types
 
Test coverage: Comprehensive tests covering positional args, kwargs, keyword-only args, defaults, and both sync/async variants
Confidence Score: 4/5
- PR is safe to merge with one minor edge case consideration around mixed Expression/non-Expression positional arguments
 - Solid implementation with comprehensive tests. The Rust validation logic is correct, proto serialization is proper, and test coverage is thorough. One potential edge case exists with mixed positional args (Expression + literal), but this appears to be an existing design constraint rather than a new bug introduced by this PR
 - daft/udf/udf_v2.py - review the iterator-based dtype matching for positional args (lines 235-240)
 
Important Files Changed
File Analysis
| Filename | Score | Overview | 
|---|---|---|
| daft/udf/udf_v2.py | 3/5 | Adds _get_input_dtypes method and tracks input types, but iterator-based positional arg matching could misalign types if mixing Expression and non-Expression positional args | 
| src/daft-dsl/src/python_udf/mod.rs | 5/5 | Implements type validation at planning time, correctly validates expected vs actual types with proper Python type handling | 
| tests/udf/test_row_wise_udf.py | 5/5 | Comprehensive tests for sync and async rowwise UDF type validation with various argument patterns | 
| tests/udf/test_cls.py | 5/5 | Comprehensive tests for class-based UDF type validation with various argument patterns | 
Sequence Diagram
sequenceDiagram
    participant User
    participant Func as Func.__call__()
    participant Python as Python (udf_v2.py)
    participant Rust as Rust (python_udf/mod.rs)
    participant Schema
    User->>Func: @daft.func decorated function(expr1, expr2)
    
    Note over Func: Initialization Phase
    Func->>Python: _get_input_dtypes(fn)
    Python->>Python: inspect.signature(fn)
    Python->>Python: get_type_hints(fn)
    Python-->>Func: input_dtypes dict {param: DataType}
    
    Note over Func: Call Phase
    User->>Func: func(df["col1"], df["col2"])
    Func->>Func: Extract Expression args
    Func->>Func: Build input_dtypes list from dict
    Func->>Rust: row_wise_udf(name, cls, method, ..., input_dtypes)
    
    Note over Rust: Planning Phase (to_field)
    Rust->>Rust: PyScalarFn::to_field(schema)
    Rust->>Schema: Get actual types from args
    Schema-->>Rust: actual_inputs: Vec<DataType>
    Rust->>Rust: Validate len(expected) == len(actual)
    loop For each (expected, actual) pair
        alt expected != Python type
            Rust->>Rust: Validate expected == actual
            alt Types mismatch
                Rust-->>User: TypeError: Expects input to 'func' to be X, but received Y
            end
        end
    end
    Rust-->>Func: Field with validated type
    Func-->>User: Expression (validated)
    10 files reviewed, 1 comment
          Codecov Report❌ Patch coverage is  
 Additional details and impacted files@@            Coverage Diff             @@
##             main    #5470      +/-   ##
==========================================
- Coverage   71.66%   71.66%   -0.01%     
==========================================
  Files         998      998              
  Lines      127368   127423      +55     
==========================================
+ Hits        91279    91315      +36     
- Misses      36089    36108      +19     
 🚀 New features to boost your workflow:
  | 
    
Changes Made
UDF's are our most expensive operation to run, so we want to make sure that we're not wasting resources running functions if we can know in advance that the types are wrong and it would likely fail at runtime.
This PR adds planning time validation on rowwise python udf function signatures. This is consistent with our built in expressions. If the dtype is wrong, we error during planning, preventing potentially costly runtime errors or confusing results
example:
previously this would work with any datatype, which could produce confusing results, as the
s: strwas not actually enforced.This is opt in as we only check if there are function signatures. If a user does not provide typing, then we don't do anything
as such, this would continue to work for all datatypes.
Related Issues
Closes #5462
Checklist
docs/mkdocs.ymlnavigation