Skip to content

Commit 567d6c1

Browse files
committed
Remove autotune package dependency from tests and move fusion examples to standalone package with relative imports
1 parent f41a8fa commit 567d6c1

File tree

6 files changed

+78
-451
lines changed

6 files changed

+78
-451
lines changed

examples/plot_all.py

Lines changed: 0 additions & 13 deletions
This file was deleted.
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import numpy as np
44

5-
from autotune.analysis.metrics import check_correctness
6-
from fusion.fusion_chain import FusionChain
7-
from fusion.operators import Operator
8-
from fusion.tensors import Tensor
5+
from .fusion_chain import FusionChain
6+
from .operators import Operator
7+
from .tensors import Tensor
98

109

1110
class Rowmax(Operator):
@@ -328,10 +327,10 @@ def test_flash_attention_fusion() -> None:
328327
)
329328

330329
result_standard = fusion.execute(fusion_axis="fusion", fusion_step_size=seq_len, input_tensors=input_tensors)
331-
check_correctness(golden, result_standard.data, atol, rtol, verbose=True)
330+
np.testing.assert_allclose(result_standard.data, golden, atol=atol, rtol=rtol)
332331

333332
result_fused = fusion.execute(fusion_axis="fusion", fusion_step_size=32, input_tensors=input_tensors, verbose=True)
334-
check_correctness(golden, result_fused.data, atol, rtol, verbose=True)
333+
np.testing.assert_allclose(result_fused.data, golden, atol=atol, rtol=rtol)
335334

336335

337336
if __name__ == "__main__":
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22

33
import numpy as np
44

5-
from autotune.analysis.metrics import check_correctness
6-
from fusion.fusion_chain import FusionChain
7-
from fusion.operators import Operator
8-
from fusion.tensors import Tensor
5+
from .fusion_chain import FusionChain
6+
from .operators import Operator
7+
from .tensors import Tensor
98

109

1110
class SumSquares(Operator):
@@ -214,8 +213,8 @@ def test_rmsnorm_matmul_fusion() -> None:
214213
result_fused = fusion.execute(fusion_axis="hidden", fusion_step_size=256, input_tensors=input_tensors, verbose=True)
215214
result_standard = fusion.execute(fusion_axis="hidden", fusion_step_size=hidden_dim, input_tensors=input_tensors)
216215
golden = rmsnorm_matmul_golden(lhs, rhs, epsilon)
217-
check_correctness(golden, result_standard.data, atol, rtol, verbose=True)
218-
check_correctness(golden, result_fused.data, atol, rtol, verbose=True)
216+
np.testing.assert_allclose(result_standard.data, golden, atol=atol, rtol=rtol)
217+
np.testing.assert_allclose(result_fused.data, golden, atol=atol, rtol=rtol)
219218

220219

221220
if __name__ == "__main__":

test/golden/autotune.py

Lines changed: 0 additions & 169 deletions
This file was deleted.

0 commit comments

Comments
 (0)