With torch backend jit_compile=True causes AssertionError #22282

@innat

Description

To reproduce the error, run this official guide with the torch backend:

import os
os.environ["KERAS_BACKEND"] = "torch"

import keras

# `model` is built earlier in the guide; only jit_compile=True is added here.
model.compile(
    optimizer=keras.optimizers.RMSprop(),  # Optimizer
    # Loss function to minimize
    loss=keras.losses.SparseCategoricalCrossentropy(),
    # List of metrics to monitor
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    jit_compile=True,
)

Logs

Fit model on training data
Epoch 1/2
W0225 07:56:20.566000 571 torch/_inductor/utils.py:1679] [1/0_1] Not enough SMs to use max_autotune_gemm mode
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
/tmp/ipython-input-3880486385.py in <cell line: 0>()
      1 print("Fit model on training data")
----> 2 history = model.fit(
      3     x_train,
      4     y_train,
      5     batch_size=64,

25 frames
/usr/local/lib/python3.12/dist-packages/torch/_dynamo/bytecode_transformation.py in compute_exception_table(instructions)
    988                 instructions, indexof[inst.exn_tab_entry.start]
    989             ).offset
--> 990             assert start is not None
    991             # point to the last 2 bytes of the end instruction
    992             end = (

AssertionError: 

from user code:
   File "/usr/local/lib/python3.12/dist-packages/keras/src/trainers/compile_utils.py", line 699, in call
    loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype

Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
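The assertion is raised inside TorchDynamo's bytecode rewriting, which Keras only enters when jit_compile=True routes the train step through torch.compile. A possible workaround (an assumption on my part, not a confirmed fix) is to compile with jit_compile=False so the torch backend runs the train step eagerly; a minimal stand-in model is used below in place of the guide's MNIST model:

```python
import os
os.environ["KERAS_BACKEND"] = "torch"

import keras

# Hypothetical minimal model, standing in for the one built in the guide.
model = keras.Sequential([keras.layers.Dense(10)])

# Workaround sketch: jit_compile=False keeps the torch backend from tracing
# the train step with torch.compile, where the AssertionError originates.
model.compile(
    optimizer=keras.optimizers.RMSprop(),
    loss=keras.losses.SparseCategoricalCrossentropy(),
    metrics=[keras.metrics.SparseCategoricalAccuracy()],
    jit_compile=False,
)
```

With this change, fit() on the guide's data should train without entering Dynamo, at the cost of losing any torch.compile speedup.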
