Skip to content

Commit 52dadd5

Browse files
committed
Add CLI tests
1 parent 4a9486b commit 52dadd5

File tree

4 files changed

+19983
-0
lines changed

4 files changed

+19983
-0
lines changed
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os
2+
from unittest.mock import patch
3+
4+
import pytest
5+
6+
from cfpq_decomposer_cli.decompose_cflr_matrix import main as decompose_main
7+
from src.utils.useful_paths import POCR_FORMAT_DATA
8+
from test.cfpq_decomposer_cli.utils import parse_decomposer_cli_output, CorruptDecomposer
9+
from test.utils import find_graph_file, find_grammar_file
10+
11+
EXPECTED_SEDGES = 3_968_276
12+
DATA_PATH = os.path.join(POCR_FORMAT_DATA, 'leela')
13+
GRAPH_PATH = find_graph_file(DATA_PATH)
14+
GRAMMAR_PATH = find_grammar_file(DATA_PATH)
15+
16+
17+
@pytest.mark.CI
18+
@pytest.mark.parametrize("args", [[], ["--prototype"]], ids=["default", "prototype"])
19+
def test_cli_decompose_cflr_matrix_modes(args, capsys):
20+
decompose_main([GRAPH_PATH, GRAMMAR_PATH] + args)
21+
captured = capsys.readouterr()
22+
metrics = parse_decomposer_cli_output(captured.out)
23+
24+
assert metrics.s_edges == EXPECTED_SEDGES, f"#SEdges mismatch for args={args}"
25+
assert metrics.compression_factor > 10, f"Compression factor too low for args={args}"
26+
assert metrics.is_valid, f"Compression invalid for args={args}"
27+
28+
@pytest.mark.CI
29+
@patch('cfpq_decomposer_cli.decompose_cflr_matrix.HighPerformanceDecomposer', new=CorruptDecomposer)
30+
def test_cli_decompose_cflr_matrix_corrupted(monkeypatch, capsys):
31+
decompose_main([GRAPH_PATH, GRAMMAR_PATH])
32+
captured = capsys.readouterr()
33+
metrics = parse_decomposer_cli_output(captured.out)
34+
assert not metrics.is_valid, "Expected compression to be invalid when decomposer is corrupted"

test/cfpq_decomposer_cli/utils.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from dataclasses import dataclass
2+
from typing import Tuple
3+
4+
from graphblas.core.matrix import Matrix
5+
6+
from cfpq_decomposer.high_performance_decomposer import HighPerformanceDecomposer
7+
from cfpq_matrix.matrix_utils import stack
8+
9+
10+
@dataclass
11+
class DecomposerCliMetrics:
12+
s_edges: int
13+
compression_factor: float
14+
is_valid: bool
15+
16+
def parse_decomposer_cli_output(output: str) -> DecomposerCliMetrics:
17+
lines = output.strip().splitlines()
18+
sedges_line = next(l for l in lines if l.startswith('#SEdges'))
19+
s_edges = int(sedges_line.split()[1])
20+
cf_line = next(l for l in lines if l.startswith('Compression factor'))
21+
compression_factor = float(cf_line.split()[2])
22+
valid_line = next(l for l in lines if l.startswith('Is compression valid'))
23+
is_valid = valid_line.split()[3] == 'True'
24+
return DecomposerCliMetrics(
25+
s_edges=s_edges,
26+
compression_factor=compression_factor,
27+
is_valid=is_valid
28+
)
29+
30+
class CorruptDecomposer(HighPerformanceDecomposer):
31+
def decompose(self, matrix: Matrix) -> Tuple[Matrix, Matrix]:
32+
left, right = super().decompose(matrix)
33+
nrows, _ = left.shape
34+
_, ncols = right.shape
35+
36+
rows = list(range(nrows))
37+
cols = [0] * nrows
38+
extra_col = Matrix.from_coo(rows, cols, True, nrows=nrows, ncols=1)
39+
40+
rows2 = [0] * ncols
41+
cols2 = list(range(ncols))
42+
extra_row = Matrix.from_coo(rows2, cols2, True, nrows=1, ncols=ncols)
43+
44+
return stack([[left, extra_col]]), stack([[right], [extra_row]])

test/pocr_data/leela/aa.cnf

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
M DV d
2+
DV dbar V
3+
V A_r V
4+
V V A
5+
V FV_i f_i
6+
V M
7+
V
8+
FV_i fbar_i V
9+
A a M
10+
A a
11+
A
12+
A_r M abar
13+
A_r abar
14+
A_r
15+
16+
Count:
17+
V

0 commit comments

Comments
 (0)