Skip to content

Commit 6e88b2b

Browse files
Add tests for LSTM class
1 parent 495531c commit 6e88b2b

File tree

1 file changed

+73
-0
lines changed

1 file changed

+73
-0
lines changed
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
from collections.abc import Callable
2+
from unittest.mock import Mock, MagicMock
3+
4+
import pytest
5+
import torch
6+
from pytest_mock import MockerFixture
7+
8+
from country_workspace.contrib.name_parser.parser import LSTM
9+
10+
11+
@pytest.fixture
def embedding_mock(mocker: MockerFixture) -> torch.nn.Embedding:
    """Patch ``nn.Embedding`` inside the parser module and return the patch mock."""
    patched = mocker.patch("country_workspace.contrib.name_parser.parser.nn.Embedding")
    return patched
14+
15+
16+
@pytest.fixture
def lstm_mock(mocker: MockerFixture) -> torch.nn.LSTM:
    """Patch ``nn.LSTM`` inside the parser module and return the patch mock."""
    patched = mocker.patch("country_workspace.contrib.name_parser.parser.nn.LSTM")
    return patched
19+
20+
21+
@pytest.fixture
def linear_mock(mocker: MockerFixture) -> torch.nn.Linear:
    """Patch ``nn.Linear`` inside the parser module and return the patch mock."""
    patched = mocker.patch("country_workspace.contrib.name_parser.parser.nn.Linear")
    return patched
24+
25+
26+
@pytest.fixture
def log_softmax_mock(mocker: MockerFixture) -> torch.nn.LogSoftmax:
    """Patch ``nn.LogSoftmax`` inside the parser module and return the patch mock."""
    patched = mocker.patch("country_workspace.contrib.name_parser.parser.nn.LogSoftmax")
    return patched
29+
30+
31+
@pytest.fixture
def zeros_mock(mocker: MockerFixture) -> torch.Tensor:
    """Patch ``torch.zeros`` inside the parser module and return the patch mock."""
    patched = mocker.patch("country_workspace.contrib.name_parser.parser.torch.zeros")
    return patched
34+
35+
36+
def test_lstm_class_init(
    embedding_mock: torch.nn.Embedding,
    lstm_mock: torch.nn.LSTM,
    linear_mock: torch.nn.Linear,
    log_softmax_mock: torch.nn.LogSoftmax,
) -> None:
    """``LSTM.__init__`` must store the sizes and build its four sub-modules.

    ``__init__`` is invoked unbound against a ``Mock(spec=LSTM)`` so the real
    constructor logic runs while every ``nn`` layer is patched by a fixture.
    """
    instance = Mock(spec=LSTM)
    # Distinct small integers so each argument is distinguishable in the asserts.
    input_size, hidden_size, output_size, num_layers = range(4)

    LSTM.__init__(instance, input_size, hidden_size, output_size, num_layers)

    assert instance.hidden_size == hidden_size
    assert instance.num_layers == num_layers
    assert instance.embedding == embedding_mock.return_value
    assert instance.lstm == lstm_mock.return_value
    assert instance.fc == linear_mock.return_value
    assert instance.softmax == log_softmax_mock.return_value

    # Use assert_called_once_with for every layer — the original used the weaker
    # assert_called_with for the embedding, which would miss duplicate construction.
    embedding_mock.assert_called_once_with(input_size, hidden_size)
    lstm_mock.assert_called_once_with(hidden_size, hidden_size, num_layers, batch_first=True)
    linear_mock.assert_called_once_with(hidden_size, output_size)
    log_softmax_mock.assert_called_once_with(dim=1)
58+
59+
60+
def test_lstm_class_forward(
    embedding_mock: torch.nn.Embedding,
    lstm_mock: torch.nn.LSTM,
    linear_mock: torch.nn.Linear,
    log_softmax_mock: torch.nn.LSTM,
    zeros_mock: Callable,
) -> None:
    """``LSTM.forward`` must return the log-softmax output of the mocked pipeline.

    All ``nn`` layers and ``torch.zeros`` are patched via the fixtures (some are
    requested only for their patching side effect), so ``forward`` runs against
    mocks end-to-end.
    """
    instance = Mock(spec=LSTM)
    input_ = Mock(spec=torch.Tensor)
    input_size, hidden_size, output_size, num_layers = range(4)
    # The patched nn.LSTM instance, when called, must yield an (output, state) pair.
    lstm_mock.return_value.return_value = (MagicMock(), None)
    LSTM.__init__(instance, input_size, hidden_size, output_size, num_layers)

    result = LSTM.forward(instance, input_)

    assert result == instance.softmax.return_value

0 commit comments

Comments
 (0)