
Commit ed44c0e

Increase coverage
1 parent 74f5779 commit ed44c0e

File tree: 2 files changed, 97 additions & 29 deletions


src/country_workspace/contrib/name_parser/parser.py

Lines changed: 5 additions & 2 deletions
@@ -12,6 +12,9 @@
 
 BASE_PATH = Path(__file__).parent.parent.parent
 
+CONFIG_PATH_TEMPLATE = "data/name_parser/models/{country_code}.txt"
+MODEL_PATH_TEMPLATE = "data/name_parser/models/{country_code}.pt"
+
 
 class LSTM(nn.Module):
     def __init__(self, input_size: int, hidden_size: int, output_size: int, num_layers: int = 1) -> None:
@@ -42,7 +45,7 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:
 
 
 def read_config(country_code: str) -> tuple[Alphabet, int, ModelArgs]:
-    with (BASE_PATH / f"data/name_parser/models/{country_code}.txt").open() as f:
+    with (BASE_PATH / CONFIG_PATH_TEMPLATE.format(country_code=country_code)).open() as f:
         lines = tuple(line.rstrip("\n") for line in f.readlines())
 
     return (
@@ -54,7 +57,7 @@ def read_config(country_code: str) -> tuple[Alphabet, int, ModelArgs]:
 
 def load_model(country_code: str, *args: int) -> nn.Module:
     rnn = LSTM(*args, num_layers=2)
-    rnn.load_state_dict(torch.load(BASE_PATH / f"data/name_parser/models/{country_code}.pt"))
+    rnn.load_state_dict(torch.load(BASE_PATH / MODEL_PATH_TEMPLATE.format(country_code=country_code)))
     rnn.to(DEVICE)
     rnn.eval()
     return rnn
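
The production change is mechanical: the two hard-coded f-string paths become module-level templates, so the tests and any other caller can build the exact same paths. A minimal sketch of how the new constants resolve, assuming a hypothetical country code "af" that is not part of this commit:

from country_workspace.contrib.name_parser.parser import (
    BASE_PATH,
    CONFIG_PATH_TEMPLATE,
    MODEL_PATH_TEMPLATE,
)

country_code = "af"  # hypothetical country code, for illustration only

# The same composition read_config() and load_model() now perform internally.
config_path = BASE_PATH / CONFIG_PATH_TEMPLATE.format(country_code=country_code)
model_path = BASE_PATH / MODEL_PATH_TEMPLATE.format(country_code=country_code)
# config_path -> .../country_workspace/data/name_parser/models/af.txt
# model_path  -> .../country_workspace/data/name_parser/models/af.pt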
Second changed file (test module): Lines changed: 92 additions & 27 deletions
@@ -1,43 +1,50 @@
-from collections.abc import Callable
 from unittest.mock import Mock, MagicMock, call
 
 import pytest
 import torch
 from pytest_mock import MockerFixture
 
-from country_workspace.contrib.name_parser.parser import LSTM
+from country_workspace.contrib.name_parser.parser import (
+    LSTM,
+    read_config,
+    load_model,
+    get_line_to_tensor_converter,
+    get_parser,
+    BASE_PATH,
+    MODEL_PATH_TEMPLATE,
+)
 
 
 @pytest.fixture
-def embedding_mock(mocker: MockerFixture) -> torch.nn.Embedding:
+def nn_embedding_mock(mocker: MockerFixture) -> MagicMock:
     return mocker.patch("country_workspace.contrib.name_parser.parser.nn.Embedding")
 
 
 @pytest.fixture
-def lstm_mock(mocker: MockerFixture) -> torch.nn.LSTM:
+def nn_lstm_mock(mocker: MockerFixture) -> MagicMock:
     return mocker.patch("country_workspace.contrib.name_parser.parser.nn.LSTM")
 
 
 @pytest.fixture
-def linear_mock(mocker: MockerFixture) -> torch.nn.Linear:
+def nn_linear_mock(mocker: MockerFixture) -> MagicMock:
     return mocker.patch("country_workspace.contrib.name_parser.parser.nn.Linear")
 
 
 @pytest.fixture
-def log_softmax_mock(mocker: MockerFixture) -> torch.nn.LogSoftmax:
+def nn_log_softmax_mock(mocker: MockerFixture) -> MagicMock:
     return mocker.patch("country_workspace.contrib.name_parser.parser.nn.LogSoftmax")
 
 
 @pytest.fixture
-def zeros_mock(mocker: MockerFixture) -> torch.Tensor:
+def torch_zeros_mock(mocker: MockerFixture) -> MagicMock:
     return mocker.patch("country_workspace.contrib.name_parser.parser.torch.zeros")
 
 
 def test_lstm_class_init(
-    embedding_mock: torch.nn.Embedding,
-    lstm_mock: torch.nn.LSTM,
-    linear_mock: torch.nn.Linear,
-    log_softmax_mock: torch.nn.LogSoftmax,
+    nn_embedding_mock: MagicMock,
+    nn_lstm_mock: MagicMock,
+    nn_linear_mock: MagicMock,
+    nn_log_softmax_mock: MagicMock,
 ) -> None:
     instance = Mock(spec=LSTM)
     input_size, hidden_size, output_size, num_layers = range(4)
@@ -46,28 +53,28 @@ def test_lstm_class_init(
 
     assert instance.hidden_size == hidden_size
     assert instance.num_layers == num_layers
-    assert instance.embedding == embedding_mock.return_value
-    assert instance.lstm == lstm_mock.return_value
-    assert instance.fc == linear_mock.return_value
-    assert instance.softmax == log_softmax_mock.return_value
+    assert instance.embedding == nn_embedding_mock.return_value
+    assert instance.lstm == nn_lstm_mock.return_value
+    assert instance.fc == nn_linear_mock.return_value
+    assert instance.softmax == nn_log_softmax_mock.return_value
 
-    embedding_mock.assert_called_with(input_size, hidden_size)
-    lstm_mock.assert_called_once_with(hidden_size, hidden_size, num_layers, batch_first=True)
-    linear_mock.assert_called_once_with(hidden_size, output_size)
-    log_softmax_mock.assert_called_once_with(dim=1)
+    nn_embedding_mock.assert_called_with(input_size, hidden_size)
+    nn_lstm_mock.assert_called_once_with(hidden_size, hidden_size, num_layers, batch_first=True)
+    nn_linear_mock.assert_called_once_with(hidden_size, output_size)
+    nn_log_softmax_mock.assert_called_once_with(dim=1)
 
 
 def test_lstm_class_forward(
-    embedding_mock: torch.nn.Embedding,
-    lstm_mock: torch.nn.LSTM,
-    linear_mock: torch.nn.Linear,
-    log_softmax_mock: torch.nn.LogSoftmax,
-    zeros_mock: Callable,
+    nn_embedding_mock: MagicMock,
+    nn_lstm_mock: MagicMock,
+    nn_linear_mock: MagicMock,
+    nn_log_softmax_mock: MagicMock,
+    torch_zeros_mock: MagicMock,
 ) -> None:
     instance = Mock(spec=LSTM)
     input_ = Mock(spec=torch.Tensor)
     input_size, hidden_size, output_size, num_layers = range(4)
-    lstm_mock.return_value.return_value = (lstm_out := MagicMock()), None
+    nn_lstm_mock.return_value.return_value = (lstm_out := MagicMock()), None
     LSTM.__init__(instance, input_size, hidden_size, output_size, num_layers)
 
     assert LSTM.forward(instance, input_) == instance.softmax.return_value
@@ -81,7 +88,7 @@ def test_lstm_class_forward(
             c,
         ]
     )
-    zeros_mock.assert_has_calls(
+    torch_zeros_mock.assert_has_calls(
         [
             c0 := call(num_layers, instance.embedding.return_value.size.return_value, hidden_size),
             c1 := call().to(input_.device),
@@ -90,8 +97,66 @@
         ]
     )
     instance.lstm.assert_called_once_with(
-        instance.embedding.return_value, (zt := zeros_mock.return_value.to.return_value, zt)
+        instance.embedding.return_value, (zt := torch_zeros_mock.return_value.to.return_value, zt)
     )
     lstm_out.__getitem__.assert_called_once_with((slice(None), -1, slice(None)))
     instance.fc.assert_called_once_with(lstm_out.__getitem__.return_value)
     instance.softmax.assert_called_once_with(instance.fc.return_value)
+
+
+def test_read_config(mocker: MockerFixture) -> None:
+    config = (
+        alphabet := ("a", "b", "c"),
+        max_name_len := 42,
+        rnn_args := (1, 2, 3),
+    )
+    open_mock = mocker.patch("country_workspace.contrib.name_parser.parser.Path.open")
+    open_mock.return_value.__enter__.return_value.readlines.return_value = (
+        "".join(alphabet) + "\n",
+        str(max_name_len) + "\n",
+        " ".join(map(str, rnn_args)) + "\n",
+    )
+
+    assert read_config("CNT") == config
+    open_mock.assert_called_once()
+
+
+def test_load_model(mocker: MockerFixture) -> None:
+    device_mock = mocker.patch("country_workspace.contrib.name_parser.parser.DEVICE")
+    lstm_mock = mocker.patch("country_workspace.contrib.name_parser.parser.LSTM")
+    load_mock = mocker.patch("country_workspace.contrib.name_parser.parser.torch.load")
+    rnn_args = (1, 2, 3)
+    country_code = "CNT"
+
+    assert load_model(country_code, *rnn_args) == lstm_mock.return_value
+
+    lstm_mock.assert_called_once_with(*rnn_args, num_layers=2)
+    load_mock.assert_called_once_with(BASE_PATH / MODEL_PATH_TEMPLATE.format(country_code=country_code))
+    lstm_mock.return_value.load_state_dict.assert_called_once_with(load_mock.return_value)
+    lstm_mock.return_value.to.assert_called_once_with(device_mock)
+    lstm_mock.return_value.eval.assert_called_once()
+
+
+def test_get_line_to_tensor_converter(mocker: MockerFixture) -> None:
+    torch_ones_mock = mocker.patch("country_workspace.contrib.name_parser.parser.torch.ones")
+    converter = get_line_to_tensor_converter(_alphabet := MagicMock(), _max_name_len := 42)
+    assert converter("Name") == torch_ones_mock.return_value.__mul__.return_value
+
+
+def test_get_parser(mocker: MockerFixture) -> None:
+    read_config_mock = mocker.patch("country_workspace.contrib.name_parser.parser.read_config")
+    read_config_mock.return_value = (alphabet := Mock(), max_name_len := Mock(), rnn_args := (Mock(),))
+    load_model_mock = mocker.patch("country_workspace.contrib.name_parser.parser.load_model")
+    get_line_to_tensor_converter_mock = mocker.patch(
+        "country_workspace.contrib.name_parser.parser.get_line_to_tensor_converter"
+    )
+    mocker.patch("country_workspace.contrib.name_parser.parser.torch.exp")
+    mocker.patch("country_workspace.contrib.name_parser.parser.torch.argmax")
+    name_types_mock = mocker.patch("country_workspace.contrib.name_parser.parser.NAME_TYPES")
+
+    parser = get_parser(country_code := "CNT")
+    assert parser("Full Name") == [nt := name_types_mock.__getitem__.return_value, nt]
+
+    read_config_mock.assert_called_once_with(country_code)
+    load_model_mock.assert_called_once_with(country_code, *rnn_args)
+    get_line_to_tensor_converter_mock.assert_called_once_with(alphabet, max_name_len)
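
Beyond coverage, test_read_config pins down the on-disk config layout read_config expects: line one is the alphabet as a single string, line two the maximum name length, line three the space-separated model arguments. A minimal sketch under that assumption, for a hypothetical config file data/name_parser/models/CNT.txt containing the three lines "abc", "42" and "1 2 3":

from country_workspace.contrib.name_parser.parser import read_config

# Assumes the hypothetical CNT.txt described above actually exists on disk;
# the test itself avoids this by mocking Path.open.
alphabet, max_name_len, rnn_args = read_config("CNT")
assert alphabet == ("a", "b", "c")  # alphabet line, split into characters
assert max_name_len == 42           # second line, parsed as an int
assert rnn_args == (1, 2, 3)        # third line, parsed as a tuple of ints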

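test_get_parser also documents how the pieces are wired together: get_parser reads the per-country config, loads the model with the configured sizes, builds the line-to-tensor converter, and returns a callable that maps each token of a full name to an entry of NAME_TYPES. A usage sketch, assuming a trained model and config exist for a hypothetical country code "af":

from country_workspace.contrib.name_parser.parser import get_parser

# Assumes data/name_parser/models/af.txt and af.pt exist for the hypothetical code "af".
parse = get_parser("af")
labels = parse("Ada Lovelace")
# One label per name token, each drawn from NAME_TYPES; the concrete values
# depend on the trained model and are not shown in this commit.
print(labels)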