Skip to content

Commit dad8aae

Browse files
committed
add lower and upper bound functions, using binary search
1 parent 99ac006 commit dad8aae

File tree

6 files changed

+255
-3
lines changed

6 files changed

+255
-3
lines changed

pykokkos/core/translators/symbols_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __init__(self, members: PyKokkosMembers, pk_import: str, path: str):
5454
self.global_symbols.update(math_functions)
5555
self.global_symbols.update(allowed_types)
5656
self.global_symbols.update(view_dtypes)
57-
self.global_symbols.update(["self", "range", "math", "List", "abs", "inclusive_scan"])
57+
self.global_symbols.update(["self", "range", "math", "List", "abs", "inclusive_scan", "upper_bound"])
5858
self.global_symbols.add(pk_import)
5959

6060
self.global_symbols.update([field.declname for field in members.fields])

pykokkos/core/visitors/workunit_visitor.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,76 @@ def visit_Call(self, node: ast.Call) -> cppast.CallExpr:
363363

364364
return cppast.CallExpr(function, scan_args)
365365

366+
# Custom `upper_bound` implementation using binary search
367+
if name == "upper_bound":
368+
# Check if it's called via pk.upper_bound
369+
is_pk_call = (
370+
isinstance(node.func, ast.Attribute)
371+
and isinstance(node.func.value, ast.Name)
372+
and node.func.value.id == self.pk_import
373+
)
374+
375+
if not is_pk_call:
376+
return super().visit_Call(node)
377+
378+
# Expected signature: pk.upper_bound(view, size, value)
379+
if len(args) != 3:
380+
self.error(
381+
node,
382+
"pk.upper_bound() takes 3 arguments: view, size, value",
383+
)
384+
385+
view_expr = args[0]
386+
size_expr = args[1]
387+
value_expr = args[2]
388+
389+
# Generate binary search lambda inline
390+
from pykokkos.interface.algorithms.upper_bound import generate_upper_bound_binary_search
391+
392+
# Create lambda body with binary search
393+
lambda_body = generate_upper_bound_binary_search(view_expr, size_expr, value_expr)
394+
395+
# Create and invoke lambda
396+
lambda_expr = cppast.LambdaExpr("[&]", [], lambda_body)
397+
lambda_call = cppast.CallExpr(lambda_expr, [])
398+
399+
return lambda_call
400+
401+
# Custom `lower_bound` implementation using binary search
402+
if name == "lower_bound":
403+
# Check if it's called via pk.lower_bound
404+
is_pk_call = (
405+
isinstance(node.func, ast.Attribute)
406+
and isinstance(node.func.value, ast.Name)
407+
and node.func.value.id == self.pk_import
408+
)
409+
410+
if not is_pk_call:
411+
return super().visit_Call(node)
412+
413+
# Expected signature: pk.lower_bound(view, size, value)
414+
if len(args) != 3:
415+
self.error(
416+
node,
417+
"pk.lower_bound() takes 3 arguments: view, size, value",
418+
)
419+
420+
view_expr = args[0]
421+
size_expr = args[1]
422+
value_expr = args[2]
423+
424+
# Generate binary search lambda inline
425+
from pykokkos.interface.algorithms.lower_bound import generate_lower_bound_binary_search
426+
427+
# Create lambda body with binary search
428+
lambda_body = generate_lower_bound_binary_search(view_expr, size_expr, value_expr)
429+
430+
# Create and invoke lambda
431+
lambda_expr = cppast.LambdaExpr("[&]", [], lambda_body)
432+
lambda_call = cppast.CallExpr(lambda_expr, [])
433+
434+
return lambda_call
435+
366436
return super().visit_Call(node)
367437

368438
def is_nested_call(self, node: ast.FunctionDef) -> bool:

pykokkos/interface/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from .accumulator import Acc
2-
from .algorithms.inclusive_scan import inclusive_scan
2+
from .algorithms import inclusive_scan, lower_bound, upper_bound
33
from .atomic.atomic_fetch_op import (
44
atomic_fetch_add, atomic_fetch_and, atomic_fetch_div,
55
atomic_fetch_lshift, atomic_fetch_max, atomic_fetch_min,
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
from .inclusive_scan import inclusive_scan
2+
from .lower_bound import lower_bound
3+
from .upper_bound import upper_bound
24

3-
__all__ = ["inclusive_scan"]
5+
__all__ = ["inclusive_scan", "lower_bound", "upper_bound"]
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from pykokkos.interface.views import ViewType
2+
from pykokkos.core import cppast
3+
4+
def lower_bound(view: ViewType, size: int, value) -> int:
5+
"""
6+
Perform a lower bound search on a view
7+
8+
Returns the index of the first element not less than (i.e. greater or equal to) value,
9+
similar to std::lower_bound or thrust::lower_bound.
10+
11+
:param view: the view to search (must be sorted)
12+
:param size: the number of elements to search
13+
:param value: the value to search for
14+
:returns: the index of the first element >= value
15+
"""
16+
pass
17+
18+
19+
def generate_lower_bound_binary_search(
20+
view_expr: cppast.Expr, size_expr: cppast.Expr, value_expr: cppast.Expr
21+
) -> cppast.CompoundStmt:
22+
"""
23+
Generate binary search implementation for lower_bound.
24+
Returns a CompoundStmt that implements:
25+
26+
int left = 0;
27+
int right = size;
28+
int mid;
29+
while (left < right) {
30+
mid = left + (right - left) / 2;
31+
if (view[mid] < value) {
32+
left = mid + 1;
33+
} else {
34+
right = mid;
35+
}
36+
}
37+
return left;
38+
"""
39+
40+
# Variable declarations
41+
int_type = cppast.PrimitiveType("int32_t")
42+
43+
# int left = 0;
44+
left_var = cppast.DeclRefExpr("left")
45+
left_init = cppast.IntegerLiteral(0)
46+
left_decl = cppast.VarDecl(int_type, left_var, left_init)
47+
left_stmt = cppast.DeclStmt(left_decl)
48+
49+
# int right = size;
50+
right_var = cppast.DeclRefExpr("right")
51+
right_decl = cppast.VarDecl(int_type, right_var, size_expr)
52+
right_stmt = cppast.DeclStmt(right_decl)
53+
54+
# int mid;
55+
mid_var = cppast.DeclRefExpr("mid")
56+
mid_decl = cppast.VarDecl(int_type, mid_var, None)
57+
mid_stmt = cppast.DeclStmt(mid_decl)
58+
59+
# while (left < right)
60+
while_cond = cppast.BinaryOperator(left_var, right_var, cppast.BinaryOperatorKind.LT)
61+
62+
# mid = left + (right - left) / 2;
63+
right_minus_left = cppast.BinaryOperator(right_var, left_var, cppast.BinaryOperatorKind.Sub)
64+
div_expr = cppast.BinaryOperator(right_minus_left, cppast.IntegerLiteral(2), cppast.BinaryOperatorKind.Div)
65+
mid_calc = cppast.BinaryOperator(left_var, div_expr, cppast.BinaryOperatorKind.Add)
66+
mid_assign = cppast.AssignOperator([mid_var], mid_calc, cppast.BinaryOperatorKind.Assign)
67+
68+
# if (view[mid] < value)
69+
view_access = cppast.CallExpr(view_expr, [mid_var])
70+
if_cond = cppast.BinaryOperator(view_access, value_expr, cppast.BinaryOperatorKind.LT)
71+
72+
# left = mid + 1;
73+
mid_plus_one = cppast.BinaryOperator(mid_var, cppast.IntegerLiteral(1), cppast.BinaryOperatorKind.Add)
74+
left_assign = cppast.AssignOperator([left_var], mid_plus_one, cppast.BinaryOperatorKind.Assign)
75+
76+
# right = mid;
77+
right_assign = cppast.AssignOperator([right_var], mid_var, cppast.BinaryOperatorKind.Assign)
78+
79+
# if-else statement
80+
if_stmt = cppast.IfStmt(if_cond, left_assign, right_assign)
81+
82+
# while body
83+
while_body = cppast.CompoundStmt([mid_assign, if_stmt])
84+
while_stmt = cppast.WhileStmt(while_cond, while_body)
85+
86+
# return left;
87+
return_stmt = cppast.ReturnStmt(left_var)
88+
89+
# Complete function body
90+
return cppast.CompoundStmt([left_stmt, right_stmt, mid_stmt, while_stmt, return_stmt])
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
from pykokkos.interface.views import ViewType
2+
from pykokkos.core import cppast
3+
4+
def upper_bound(view: ViewType, size: int, value) -> int:
5+
"""
6+
Perform an upper bound search on a view
7+
8+
Returns the index of the first element greater than value,
9+
similar to std::upper_bound or thrust::upper_bound.
10+
11+
:param view: the view to search (must be sorted)
12+
:param size: the number of elements to search
13+
:param value: the value to search for
14+
:returns: the index of the first element greater than value
15+
"""
16+
pass
17+
18+
19+
def generate_upper_bound_binary_search(
20+
view_expr: cppast.Expr, size_expr: cppast.Expr, value_expr: cppast.Expr
21+
) -> cppast.CompoundStmt:
22+
"""
23+
Generate binary search implementation for upper_bound.
24+
Returns a CompoundStmt that implements:
25+
26+
int left = -1;
27+
int right = size;
28+
int mid;
29+
while (left + 1 < right) {
30+
mid = left + ((right - left) >> 1);
31+
if (view[mid] > value) {
32+
right = mid;
33+
} else {
34+
left = mid;
35+
}
36+
}
37+
return right;
38+
"""
39+
40+
# Variable declarations
41+
int_type = cppast.PrimitiveType("int32_t")
42+
43+
# int left = 0;
44+
left_var = cppast.DeclRefExpr("left")
45+
left_init = cppast.IntegerLiteral(0)
46+
left_decl = cppast.VarDecl(int_type, left_var, left_init)
47+
left_stmt = cppast.DeclStmt(left_decl)
48+
49+
# int right = size;
50+
right_var = cppast.DeclRefExpr("right")
51+
right_decl = cppast.VarDecl(int_type, right_var, size_expr)
52+
right_stmt = cppast.DeclStmt(right_decl)
53+
54+
# int mid;
55+
mid_var = cppast.DeclRefExpr("mid")
56+
mid_decl = cppast.VarDecl(int_type, mid_var, None)
57+
mid_stmt = cppast.DeclStmt(mid_decl)
58+
59+
# while (left < right)
60+
while_cond = cppast.BinaryOperator(left_var, right_var, cppast.BinaryOperatorKind.LT)
61+
62+
# mid = left + (right - left) / 2;
63+
right_minus_left = cppast.BinaryOperator(right_var, left_var, cppast.BinaryOperatorKind.Sub)
64+
div_expr = cppast.BinaryOperator(right_minus_left, cppast.IntegerLiteral(2), cppast.BinaryOperatorKind.Div)
65+
mid_calc = cppast.BinaryOperator(left_var, div_expr, cppast.BinaryOperatorKind.Add)
66+
mid_assign = cppast.AssignOperator([mid_var], mid_calc, cppast.BinaryOperatorKind.Assign)
67+
68+
# if (view[mid] > value)
69+
view_access = cppast.CallExpr(view_expr, [mid_var])
70+
if_cond = cppast.BinaryOperator(view_access, value_expr, cppast.BinaryOperatorKind.GT)
71+
72+
# right = mid;
73+
right_assign = cppast.AssignOperator([right_var], mid_var, cppast.BinaryOperatorKind.Assign)
74+
75+
# left = mid + 1;
76+
mid_plus_one = cppast.BinaryOperator(mid_var, cppast.IntegerLiteral(1), cppast.BinaryOperatorKind.Add)
77+
left_assign = cppast.AssignOperator([left_var], mid_plus_one, cppast.BinaryOperatorKind.Assign)
78+
79+
# if-else statement
80+
if_stmt = cppast.IfStmt(if_cond, right_assign, left_assign)
81+
82+
# while body
83+
while_body = cppast.CompoundStmt([mid_assign, if_stmt])
84+
while_stmt = cppast.WhileStmt(while_cond, while_body)
85+
86+
# return left;
87+
return_stmt = cppast.ReturnStmt(left_var)
88+
89+
# Complete function body
90+
return cppast.CompoundStmt([left_stmt, right_stmt, mid_stmt, while_stmt, return_stmt])

0 commit comments

Comments
 (0)