44from __future__ import annotations
55
66import math
7+ import sys
78
89import pytest
910import torch
@@ -400,7 +401,21 @@ def simulator1d(theta):
400401@pytest .mark .parametrize (
401402 "bidirectional" , [True , False ], ids = ["one-directional" , "bi-directional" ]
402403)
403- @pytest .mark .parametrize ("mode" , ["loop" , "scan" ], ids = ["loop" , "scan" ])
404+ @pytest .mark .parametrize (
405+ "mode" ,
406+ [
407+ "loop" ,
408+ pytest .param (
409+ "scan" ,
410+ marks = pytest .mark .xfail (
411+ condition = sys .version_info >= (3 , 13 ),
412+ reason = "torch.compiler is not yet supported on Python >= 3.13" ,
413+ strict = True ,
414+ ),
415+ ),
416+ ],
417+ ids = ["loop" , "scan" ],
418+ )
404419def test_lru_isolated (
405420 bidirectional : bool ,
406421 mode : str ,
@@ -434,7 +449,21 @@ def test_lru_isolated(
434449@pytest .mark .parametrize (
435450 "bidirectional" , [True , False ], ids = ["one-directional" , "bi-directional" ]
436451)
437- @pytest .mark .parametrize ("mode" , ["loop" , "scan" ], ids = ["loop" , "scan" ])
452+ @pytest .mark .parametrize (
453+ "mode" ,
454+ [
455+ "loop" ,
456+ pytest .param (
457+ "scan" ,
458+ marks = pytest .mark .xfail (
459+ condition = sys .version_info >= (3 , 13 ),
460+ reason = "torch.compiler is not yet supported on Python >= 3.13" ,
461+ strict = True ,
462+ ),
463+ ),
464+ ],
465+ ids = ["loop" , "scan" ],
466+ )
438467@pytest .mark .parametrize (
439468 "apply_input_normalization" ,
440469 [True , False ],
@@ -454,6 +483,7 @@ def test_lru_block_isolated(
454483 sequence_len : int = 50 ,
455484):
456485 """Run some random data through an LRUBlock."""
486+
457487 lru_block = LRUBlock (
458488 hidden_dim = hidden_dim ,
459489 state_dim = state_dim ,
@@ -477,7 +507,21 @@ def test_lru_block_isolated(
477507@pytest .mark .parametrize (
478508 "bidirectional" , [True , False ], ids = ["one-directional" , "bi-directional" ]
479509)
480- @pytest .mark .parametrize ("mode" , ["loop" , "scan" ], ids = ["loop" , "scan" ])
510+ @pytest .mark .parametrize (
511+ "mode" ,
512+ [
513+ "loop" ,
514+ pytest .param (
515+ "scan" ,
516+ marks = pytest .mark .xfail (
517+ condition = sys .version_info >= (3 , 13 ),
518+ reason = "torch.compiler is not yet supported on Python >= 3.13" ,
519+ strict = True ,
520+ ),
521+ ),
522+ ],
523+ ids = ["loop" , "scan" ],
524+ )
481525@pytest .mark .parametrize (
482526 "aggregate_fcn" , ["last_step" , "mean" ], ids = ["last-step" , "mean" ]
483527)
@@ -591,6 +635,11 @@ def _simulator(thetas: Tensor, num_time_steps=500, dt=0.002, eps=0.05) -> Tensor
591635 assert samples .shape == (10 , 2 )
592636
593637
638+ @pytest .mark .xfail (
639+ condition = sys .version_info >= (3 , 13 ),
640+ reason = "torch.compiler is not yet supported on Python >= 3.13" ,
641+ strict = True ,
642+ )
594643def test_scan (
595644 input_dim : int = 3 ,
596645 output_dim : int = 3 ,
0 commit comments