Skip to content

Commit 0cca5fc

Browse files
committed
feat: allow Tensor.store API to receive .var as value
This patch allows Tensor.store API (which is connected to `nki.language.store`) to accept a more generic `Core.Value` type. The motivation is tracing of interop/test/examples/matmul.py, specifically the `nki_matmul_basic_` function. After apply #111, tracing the Python function was raising the following error message: ``` error: line 44: nl.store(result[i_out_p, i_out_f], value=result_sbuf) ^-- expecting tensor access ``` It is because its `value` keyword argument was having the following expression: ``` KLR.Trace.Term.expr (KLR.Core.Expr.value (KLR.Core.Value.var "5")) (KLR.Trace.TermType.obj `object) ``` which could not be converted to Access through the FromNKI typeclass. The "5" temporary variable was emerging from the right hand side of the definition of `result_sbuf`: ``` result_sbuf = nl.copy(result_psum, dtype=result.dtype) ``` To convert the value of "5", it seems we need to get the generated trace and find assignment to "5" because: ``` def RValue : Term -> Trace Term ... | .expr e@(.call ..) ty => do let v := (<- genName).toString add_stmt (.assign v e) return .expr (.value $ .var v) ty ``` the `add_stmt` is just adding a Core statement to `State.body`. Skimming through `State.body` and finding this assignment to "5" didn't seem something we wanted to do inside Tensor.store, so instead I slightly chose a conservative approach and simply removed the shape checker. But any other reasonable option is still fine with me.
1 parent 6166f2d commit 0cca5fc

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

KLR/Trace/Tensor.lean

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,8 @@ nki load (src : Access) (dtype : Dtype := .float32) := do
5151
let dst <- declare "load" dtype shape .sbuf
5252
return .store (.simple dst) .load [.access src]
5353

54-
nki store (dst : Access) (value : Access) := do
55-
let s1 <- dst.shape
56-
let s2 <- value.shape
57-
if s1 != s2 then
58-
throw s!"incompatible shapes {s1} {s2}"
59-
return Term.store dst .save [.access value]
54+
nki store (dst : Access) (value : Core.Value) := do
55+
return Term.store dst .save [value]
6056

6157
nki tensor_scalar (data : Access)
6258
(op0 : AluOp)

0 commit comments

Comments
 (0)