
Allow Tensor.store API to receive .var as value #120


Open · wants to merge 2 commits into main

Conversation

aqjune-aws
Copy link
Contributor


This patch allows Tensor.store API (which is connected to `nki.language.store`) to accept a more
generic `Core.Value` type.

The motivation is tracing of interop/test/examples/matmul.py, specifically the `nki_matmul_basic_` function.

After applying leanprover#111, tracing the Python function raised the following error message:

```
error:
line 44:
  nl.store(result[i_out_p, i_out_f], value=result_sbuf)
  ^-- expecting tensor access
```

This is because its `value` keyword argument carried the following expression:
```
KLR.Trace.Term.expr (KLR.Core.Expr.value (KLR.Core.Value.var "5")) (KLR.Trace.TermType.obj `object)
```
which could not be converted to `Access` through the `FromNKI` typeclass.
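
For illustration, here is a minimal, self-contained Lean sketch (using simplified stand-in types, not the actual KLR definitions) of why a conversion that only pattern-matches on access expressions rejects a bare variable:

```lean
-- Simplified stand-ins for the KLR types involved (hypothetical shapes).
inductive Value where
  | var : String → Value
  | int : Int → Value

inductive Expr where
  | value  : Value → Expr
  | access : String → List Int → Expr  -- tensor name + indices

structure Access where
  tensor  : String
  indices : List Int

-- A FromNKI-style partial conversion: only access expressions succeed.
def toAccess : Expr → Option Access
  | .access t ix => some { tensor := t, indices := ix }
  | _            => none

#eval (toAccess (.value (.var "5"))).isNone  -- true: the temporary "5" is rejected
```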

The temporary variable "5" emerged from the right-hand side of the definition of `result_sbuf`:

```
result_sbuf = nl.copy(result_psum, dtype=result.dtype)
```

To convert the value of "5", it seems we would need to walk the generated trace and find the
assignment to "5", because:

```
def RValue : Term -> Trace Term
...
  | .expr e@(.call ..) ty => do
       let v := (<- genName).toString
       add_stmt (.assign v e)
       return .expr (.value $ .var v) ty
```

`add_stmt` merely appends a Core statement to `State.body`.

Scanning `State.body` to find this assignment to "5" didn't seem like something we wanted to do inside `Tensor.store`,
so I chose a conservative approach instead and simply removed the shape checker.

But any other reasonable option is still fine with me.
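
To make the shape of the change concrete, here is a hedged Lean sketch of the before/after signature (names and types are illustrative stand-ins, not the actual KLR code):

```lean
inductive Value where
  | var : String → Value

structure Access where
  tensor : String

-- Hypothetical "before": the value argument had to already be an Access,
-- which is what allowed a shape check against `dst`.
def storeStrict (_dst : Access) (_value : Access) : Unit := ()

-- Hypothetical "after": any Value is accepted, and the shape check on the
-- value side is dropped, since a bare variable like "5" carries no shape
-- information to check.
def storeRelaxed (_dst : Access) (_value : Value) : Unit := ()
```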
@aqjune-aws
Copy link
Contributor Author

After this patch and #111, the `nki_matmul_basic_` function can finally be fully traced! :)

@govereau
Copy link
Collaborator

govereau commented May 7, 2025

I think a better approach would be to simplify the variable by looking it up in the environment, as the resulting store (with the variable) doesn't correspond to anything the HW can do. At the KLR Core level, we should not have anything that does not have a corresponding ISA (or BIR) representation.
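
The lookup-based simplification suggested here could look roughly like the following Lean sketch (the statement and environment representations are hypothetical stand-ins for `State.body`):

```lean
-- Stand-in for an assignment statement recorded by add_stmt.
structure Stmt where
  name : String
  rhs  : String  -- stand-in for the assigned Core.Expr

-- Resolve a generated variable like "5" against the recorded assignments.
def lookupAssign (body : List Stmt) (v : String) : Option String :=
  (body.find? (·.name == v)).map (·.rhs)

#eval lookupAssign [{ name := "5", rhs := "nl.copy(result_psum, ...)" }] "5"
-- some "nl.copy(result_psum, ...)"
```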

However, the problem here is more fundamental. The core issue is that matmul is not referentially transparent, and it doesn't make sense to transform it the way we are doing. This statement:

```
result_psum = nl.matmul(lhs_tile, rhs_tile, transpose_x=True)
```

is nonsensical because matmul takes the left-hand side as an input (and output) argument. This is why there is a lot of discussion along the lines of "you must write `+=` with matmul", etc. The matmul functions need to be changed to something like:

```
nl.matmul(lhs_tile, rhs_tile, dst=result_psum, transpose_x=True, accum_mode=zero)
```

The above "statement form" would trace with no issues.

Of course, there are other operators, like `tensor_tensor` or `tensor_scalar`, which would also have the issue you spotted here. However, the current thinking is that all ISA functions must be statements and not return any values. If we go this way, then we will remove `store` from KLR, as it will not be needed. In fact, the whole `add_stmt` mechanism could be removed.
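
A minimal Lean sketch of this "every ISA function is a statement" direction (all names here are hypothetical, not the real KLR Core types):

```lean
-- Hypothetical operand type for on-chip tiles.
inductive Operand where
  | tile : String → Operand

-- Each ISA op names its destination explicitly and returns no value,
-- so tracing never needs to materialize a temporary like "5".
inductive IsaStmt where
  | matmul : (dst lhs rhs : Operand) → (transposeX : Bool) → IsaStmt
  | copy   : (dst src : Operand) → IsaStmt

-- A traced kernel body is then just a list of statements; `store` and the
-- add_stmt temporary-variable mechanism become unnecessary.
def body : List IsaStmt :=
  [ .matmul (.tile "result_psum") (.tile "lhs_tile") (.tile "rhs_tile") true
  , .copy (.tile "result_sbuf") (.tile "result_psum") ]
```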

3 participants