1- from collections .abc import Callable
21from unittest .mock import Mock , MagicMock , call
32
43import pytest
54import torch
65from pytest_mock import MockerFixture
76
8- from country_workspace .contrib .name_parser .parser import LSTM
7+ from country_workspace .contrib .name_parser .parser import (
8+ LSTM ,
9+ read_config ,
10+ load_model ,
11+ get_line_to_tensor_converter ,
12+ get_parser ,
13+ BASE_PATH ,
14+ MODEL_PATH_TEMPLATE ,
15+ )
916
1017
1118@pytest .fixture
12- def embedding_mock (mocker : MockerFixture ) -> torch . nn . Embedding :
19+ def nn_embedding_mock (mocker : MockerFixture ) -> MagicMock :
1320 return mocker .patch ("country_workspace.contrib.name_parser.parser.nn.Embedding" )
1421
1522
1623@pytest .fixture
17- def lstm_mock (mocker : MockerFixture ) -> torch . nn . LSTM :
24+ def nn_lstm_mock (mocker : MockerFixture ) -> MagicMock :
1825 return mocker .patch ("country_workspace.contrib.name_parser.parser.nn.LSTM" )
1926
2027
2128@pytest .fixture
22- def linear_mock (mocker : MockerFixture ) -> torch . nn . Linear :
29+ def nn_linear_mock (mocker : MockerFixture ) -> MagicMock :
2330 return mocker .patch ("country_workspace.contrib.name_parser.parser.nn.Linear" )
2431
2532
2633@pytest .fixture
27- def log_softmax_mock (mocker : MockerFixture ) -> torch . nn . LogSoftmax :
34+ def nn_log_softmax_mock (mocker : MockerFixture ) -> MagicMock :
2835 return mocker .patch ("country_workspace.contrib.name_parser.parser.nn.LogSoftmax" )
2936
3037
3138@pytest .fixture
32- def zeros_mock (mocker : MockerFixture ) -> torch . Tensor :
39+ def torch_zeros_mock (mocker : MockerFixture ) -> MagicMock :
3340 return mocker .patch ("country_workspace.contrib.name_parser.parser.torch.zeros" )
3441
3542
3643def test_lstm_class_init (
37- embedding_mock : torch . nn . Embedding ,
38- lstm_mock : torch . nn . LSTM ,
39- linear_mock : torch . nn . Linear ,
40- log_softmax_mock : torch . nn . LogSoftmax ,
44+ nn_embedding_mock : MagicMock ,
45+ nn_lstm_mock : MagicMock ,
46+ nn_linear_mock : MagicMock ,
47+ nn_log_softmax_mock : MagicMock ,
4148) -> None :
4249 instance = Mock (spec = LSTM )
4350 input_size , hidden_size , output_size , num_layers = range (4 )
@@ -46,28 +53,28 @@ def test_lstm_class_init(
4653
4754 assert instance .hidden_size == hidden_size
4855 assert instance .num_layers == num_layers
49- assert instance .embedding == embedding_mock .return_value
50- assert instance .lstm == lstm_mock .return_value
51- assert instance .fc == linear_mock .return_value
52- assert instance .softmax == log_softmax_mock .return_value
56+ assert instance .embedding == nn_embedding_mock .return_value
57+ assert instance .lstm == nn_lstm_mock .return_value
58+ assert instance .fc == nn_linear_mock .return_value
59+ assert instance .softmax == nn_log_softmax_mock .return_value
5360
54- embedding_mock .assert_called_with (input_size , hidden_size )
55- lstm_mock .assert_called_once_with (hidden_size , hidden_size , num_layers , batch_first = True )
56- linear_mock .assert_called_once_with (hidden_size , output_size )
57- log_softmax_mock .assert_called_once_with (dim = 1 )
61+ nn_embedding_mock .assert_called_with (input_size , hidden_size )
62+ nn_lstm_mock .assert_called_once_with (hidden_size , hidden_size , num_layers , batch_first = True )
63+ nn_linear_mock .assert_called_once_with (hidden_size , output_size )
64+ nn_log_softmax_mock .assert_called_once_with (dim = 1 )
5865
5966
6067def test_lstm_class_forward (
61- embedding_mock : torch . nn . Embedding ,
62- lstm_mock : torch . nn . LSTM ,
63- linear_mock : torch . nn . Linear ,
64- log_softmax_mock : torch . nn . LogSoftmax ,
65- zeros_mock : Callable ,
68+ nn_embedding_mock : MagicMock ,
69+ nn_lstm_mock : MagicMock ,
70+ nn_linear_mock : MagicMock ,
71+ nn_log_softmax_mock : MagicMock ,
72+ torch_zeros_mock : MagicMock ,
6673) -> None :
6774 instance = Mock (spec = LSTM )
6875 input_ = Mock (spec = torch .Tensor )
6976 input_size , hidden_size , output_size , num_layers = range (4 )
70- lstm_mock .return_value .return_value = (lstm_out := MagicMock ()), None
77+ nn_lstm_mock .return_value .return_value = (lstm_out := MagicMock ()), None
7178 LSTM .__init__ (instance , input_size , hidden_size , output_size , num_layers )
7279
7380 assert LSTM .forward (instance , input_ ) == instance .softmax .return_value
@@ -81,7 +88,7 @@ def test_lstm_class_forward(
8188 c ,
8289 ]
8390 )
84- zeros_mock .assert_has_calls (
91+ torch_zeros_mock .assert_has_calls (
8592 [
8693 c0 := call (num_layers , instance .embedding .return_value .size .return_value , hidden_size ),
8794 c1 := call ().to (input_ .device ),
@@ -90,8 +97,66 @@ def test_lstm_class_forward(
9097 ]
9198 )
9299 instance .lstm .assert_called_once_with (
93- instance .embedding .return_value , (zt := zeros_mock .return_value .to .return_value , zt )
100+ instance .embedding .return_value , (zt := torch_zeros_mock .return_value .to .return_value , zt )
94101 )
95102 lstm_out .__getitem__ .assert_called_once_with ((slice (None ), - 1 , slice (None )))
96103 instance .fc .assert_called_once_with (lstm_out .__getitem__ .return_value )
97104 instance .softmax .assert_called_once_with (instance .fc .return_value )
105+
106+
107+ def test_read_config (mocker : MockerFixture ) -> None :
108+ config = (
109+ alphabet := ("a" , "b" , "c" ),
110+ max_name_len := 42 ,
111+ rnn_args := (1 , 2 , 3 ),
112+ )
113+ open_mock = mocker .patch ("country_workspace.contrib.name_parser.parser.Path.open" )
114+ open_mock .return_value .__enter__ .return_value .readlines .return_value = (
115+ "" .join (alphabet ) + "\n " ,
116+ str (max_name_len ) + "\n " ,
117+ " " .join (map (str , rnn_args )) + "\n " ,
118+ )
119+
120+ assert read_config ("CNT" ) == config
121+ open_mock .assert_called_once ()
122+
123+
124+ def test_load_model (mocker : MockerFixture ) -> None :
125+ device_mock = mocker .patch ("country_workspace.contrib.name_parser.parser.DEVICE" )
126+ lstm_mock = mocker .patch ("country_workspace.contrib.name_parser.parser.LSTM" )
127+ load_mock = mocker .patch ("country_workspace.contrib.name_parser.parser.torch.load" )
128+ rnn_args = (1 , 2 , 3 )
129+ country_code = "CNT"
130+
131+ assert load_model (country_code , * rnn_args ) == lstm_mock .return_value
132+
133+ lstm_mock .assert_called_once_with (* rnn_args , num_layers = 2 )
134+ load_mock .assert_called_once_with (BASE_PATH / MODEL_PATH_TEMPLATE .format (country_code = country_code ))
135+ lstm_mock .return_value .load_state_dict .assert_called_once_with (load_mock .return_value )
136+ lstm_mock .return_value .to .assert_called_once_with (device_mock )
137+ lstm_mock .return_value .eval .assert_called_once ()
138+
139+
140+ def test_get_line_to_tensor_converter (mocker : MockerFixture ) -> None :
141+ torch_ones_mock = mocker .patch ("country_workspace.contrib.name_parser.parser.torch.ones" )
142+ converter = get_line_to_tensor_converter (_alphabet := MagicMock (), _max_name_len := 42 )
143+ assert converter ("Name" ) == torch_ones_mock .return_value .__mul__ .return_value
144+
145+
146+ def test_get_parser (mocker : MockerFixture ) -> None :
147+ read_config_mock = mocker .patch ("country_workspace.contrib.name_parser.parser.read_config" )
148+ read_config_mock .return_value = (alphabet := Mock (), max_name_len := Mock (), rnn_args := (Mock (),))
149+ load_model_mock = mocker .patch ("country_workspace.contrib.name_parser.parser.load_model" )
150+ get_line_to_tensor_converter_mock = mocker .patch (
151+ "country_workspace.contrib.name_parser.parser.get_line_to_tensor_converter"
152+ )
153+ mocker .patch ("country_workspace.contrib.name_parser.parser.torch.exp" )
154+ mocker .patch ("country_workspace.contrib.name_parser.parser.torch.argmax" )
155+ name_types_mock = mocker .patch ("country_workspace.contrib.name_parser.parser.NAME_TYPES" )
156+
157+ parser = get_parser (country_code := "CNT" )
158+ assert parser ("Full Name" ) == [nt := name_types_mock .__getitem__ .return_value , nt ]
159+
160+ read_config_mock .assert_called_once_with (country_code )
161+ load_model_mock .assert_called_once_with (country_code , * rnn_args )
162+ get_line_to_tensor_converter_mock .assert_called_once_with (alphabet , max_name_len )
0 commit comments