Skip to content

Commit 1d52025

Browse files
Return a constant value from len() where possible
1 parent 35bd906 commit 1d52025

4 files changed

Lines changed: 49 additions & 32 deletions

File tree

warp/builtins.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,14 @@ def extract_tuple(arg, as_constant=False):
7676
return out
7777

7878

79+
def static_len_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
80+
if arg_types is None:
81+
return int
82+
83+
length = warp.types.type_length(arg_types["a"])
84+
return Var(None, type=int, constant=length)
85+
86+
7987
# ---------------------------------
8088
# Scalar Math
8189

@@ -7720,7 +7728,7 @@ def static(expr):
77207728
add_builtin(
77217729
"len",
77227730
input_types={"a": vector(length=Any, dtype=Scalar)},
7723-
value_type=int,
7731+
value_func=static_len_value_func,
77247732
doc="Return the number of elements in a vector.",
77257733
group="Utility",
77267734
export=False,
@@ -7729,7 +7737,7 @@ def static(expr):
77297737
add_builtin(
77307738
"len",
77317739
input_types={"a": quaternion(dtype=Scalar)},
7732-
value_type=int,
7740+
value_func=static_len_value_func,
77337741
doc="Return the number of elements in a quaternion.",
77347742
group="Utility",
77357743
export=False,
@@ -7738,7 +7746,7 @@ def static(expr):
77387746
add_builtin(
77397747
"len",
77407748
input_types={"a": matrix(shape=(Any, Any), dtype=Scalar)},
7741-
value_type=int,
7749+
value_func=static_len_value_func,
77427750
doc="Return the number of rows in a matrix.",
77437751
group="Utility",
77447752
export=False,
@@ -7747,7 +7755,7 @@ def static(expr):
77477755
add_builtin(
77487756
"len",
77497757
input_types={"a": transformation(dtype=Float)},
7750-
value_type=int,
7758+
value_func=static_len_value_func,
77517759
doc="Return the number of elements in a transformation.",
77527760
group="Utility",
77537761
export=False,
@@ -7765,7 +7773,7 @@ def static(expr):
77657773
add_builtin(
77667774
"len",
77677775
input_types={"a": tile(dtype=Any, shape=Tuple[int, ...])},
7768-
value_type=int,
7776+
value_func=static_len_value_func,
77697777
doc="Return the number of rows in a tile.",
77707778
group="Utility",
77717779
export=False,
@@ -7835,10 +7843,11 @@ def tuple_extract_dispatch_func(input_types: Mapping[str, type], return_type: An
78357843
missing_grad=True,
78367844
)
78377845

7846+
78387847
add_builtin(
78397848
"len",
78407849
input_types={"a": Tuple},
7841-
value_type=int,
7850+
value_func=static_len_value_func,
78427851
doc="Return the number of elements in a tuple.",
78437852
group="Utility",
78447853
export=False,

warp/codegen.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1423,6 +1423,13 @@ def add_call(adj, func, args, kwargs, type_args, min_outputs=None):
14231423
bound_arg_values,
14241424
)
14251425

1426+
# Handle the special case where a Var instance is returned from the `value_func`
1427+
# callback, in which case we replace the call with a reference to that variable.
1428+
if isinstance(return_type, Var):
1429+
return adj.register_var(return_type)
1430+
elif isinstance(return_type, Sequence) and all(isinstance(x, Var) for x in return_type):
1431+
return tuple(adj.register_var(x) for x in return_type)
1432+
14261433
if get_origin(return_type) is tuple:
14271434
types = get_args(return_type)
14281435
return_type = warp.types.tuple_t(types=types, values=(None,) * len(types))
@@ -2334,7 +2341,6 @@ def emit_Call(adj, node):
23342341
args = tuple(adj.resolve_arg(x) for x in node.args)
23352342
kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
23362343

2337-
# add the call and build the callee adjoint if needed (func.adj)
23382344
out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
23392345

23402346
if warp.config.verify_autograd_array_access:
@@ -3049,29 +3055,10 @@ def evaluate_static_expression(adj, node) -> tuple[Any, str]:
30493055

30503056
# Replace all constant `len()` expressions with their value.
30513057
if "len" in static_code:
3052-
3053-
def eval_len(obj):
3054-
if type_is_vector(obj):
3055-
return obj._length_
3056-
elif type_is_quaternion(obj):
3057-
return obj._length_
3058-
elif type_is_matrix(obj):
3059-
return obj._shape_[0]
3060-
elif type_is_transformation(obj):
3061-
return obj._length_
3062-
elif is_tuple(obj):
3063-
return len(obj.types)
3064-
elif is_tile(obj):
3065-
return obj.shape[0]
3066-
elif get_origin(obj) is tuple:
3067-
return len(get_args(obj))
3068-
3069-
return len(obj)
3070-
30713058
len_expr_ctx = vars_dict.copy()
30723059
constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
30733060
len_expr_ctx.update(constant_types)
3074-
len_expr_ctx.update({"len": eval_len})
3061+
len_expr_ctx.update({"len": warp.types.type_length})
30753062

30763063
# We want to replace the expression code in-place,
30773064
# so reparse it to get the correct column info.

warp/tests/test_tuple.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,17 @@ def loop_user_func(values: Tuple[int, int, int]):
136136
for i in range(wp.static(len(values))):
137137
out += values[i]
138138

139+
for i in range(len(values)):
140+
out += values[i] * 2
141+
139142
return out
140143

141144

142145
@wp.kernel
143146
def test_loop():
144147
t = (1, 2, 3)
145148
res = loop_user_func(t)
146-
wp.expect_eq(res, 6)
149+
wp.expect_eq(res, 18)
147150

148151

149152
@wp.func
@@ -152,26 +155,29 @@ def loop_variadic_any_user_func(values: Any):
152155
for i in range(wp.static(len(values))):
153156
out += values[i]
154157

158+
for i in range(len(values)):
159+
out += values[i] * 2
160+
155161
return out
156162

157163

158164
@wp.kernel
159165
def test_loop_variadic_any():
160166
t1 = (1,)
161167
res = loop_variadic_any_user_func(t1)
162-
wp.expect_eq(res, 1)
168+
wp.expect_eq(res, 3)
163169

164170
t2 = (2, 3)
165171
res = loop_variadic_any_user_func(t2)
166-
wp.expect_eq(res, 5)
172+
wp.expect_eq(res, 15)
167173

168174
t3 = (3, 4, 5)
169175
res = loop_variadic_any_user_func(t3)
170-
wp.expect_eq(res, 12)
176+
wp.expect_eq(res, 36)
171177

172178
t4 = (4, 5, 6, 7)
173179
res = loop_variadic_any_user_func(t4)
174-
wp.expect_eq(res, 22)
180+
wp.expect_eq(res, 66)
175181

176182

177183
@wp.func

warp/types.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1388,6 +1388,21 @@ def type_ctype(dtype):
13881388
return dtype._type_
13891389

13901390

1391+
def type_length(obj):
1392+
if is_tile(obj):
1393+
return obj.shape[0]
1394+
elif is_tuple(obj):
1395+
return len(obj.types)
1396+
elif get_origin(obj) is tuple:
1397+
return len(get_args(obj))
1398+
elif hasattr(obj, "_shape_"):
1399+
return obj._shape_[0]
1400+
elif hasattr(obj, "_length_"):
1401+
return obj._length_
1402+
1403+
return len(obj)
1404+
1405+
13911406
def type_size(dtype):
13921407
if dtype == float or dtype == int or isinstance(dtype, warp.codegen.Struct):
13931408
return 1

0 commit comments

Comments
 (0)