@@ -36,32 +36,50 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:
3636 return self .softmax (out )
3737
3838
39- def get_parser (country_code : str ) -> Parser :
39+ Alphabet = tuple [str , ...]
40+ ModelArgs = tuple [int , ...]
41+ UNKNOWN_CHAR = "_"
42+
43+
44+ def read_config (country_code : str ) -> tuple [Alphabet , int , ModelArgs ]:
4045 with (BASE_PATH / f"data/name_parser/models/{ country_code } .txt" ).open () as f :
4146 lines = tuple (line .rstrip ("\n " ) for line in f .readlines ())
4247
43- unknown = "_"
44- alphabet = tuple (lines [0 ])
45- alphabet_len = len ( alphabet )
46- name_max_len = int ( lines [1 ])
47- rnn_args = map ( int , lines [ 2 ]. split () )
48+ return (
49+ tuple (lines [0 ]),
50+ int ( lines [ 1 ]),
51+ tuple ( map ( int , lines [2 ]. split ())),
52+ )
4853
49- rnn = LSTM (* rnn_args , num_layers = 2 )
54+
55+ def load_model (country_code : str , * args : int ) -> nn .Module :
56+ rnn = LSTM (* args , num_layers = 2 )
5057 rnn .load_state_dict (torch .load (BASE_PATH / f"data/name_parser/models/{ country_code } .pt" ))
5158 rnn .to (DEVICE )
5259 rnn .eval ()
60+ return rnn
5361
54- def letter_to_index (letter : str ) -> int :
55- return alphabet .index (letter ) if letter in alphabet else alphabet .index (unknown )
5662
57- oob = alphabet_len + 1
63+ def get_line_to_tensor_converter (alphabet : Alphabet , max_name_len : int ) -> Callable [[str ], torch .Tensor ]:
64+ oob = len (alphabet ) + 1
65+
66+ def letter_to_index (letter : str ) -> int :
67+ return alphabet .index (letter ) if letter in alphabet else alphabet .index (UNKNOWN_CHAR )
5868
5969 def line_to_tensor (line : str ) -> torch .Tensor :
60- tensor = torch .ones (name_max_len , dtype = torch .long ) * oob
70+ tensor = torch .ones (max_name_len , dtype = torch .long ) * oob
6171 for li , letter in enumerate (line ):
6272 tensor [li ] = letter_to_index (letter )
6373 return tensor
6474
75+ return line_to_tensor
76+
77+
78+ def get_parser (country_code : str ) -> Parser :
79+ alphabet , max_name_len , rnn_args = read_config (country_code )
80+ rnn = load_model (country_code , * rnn_args )
81+ line_to_tensor = get_line_to_tensor_converter (alphabet , max_name_len )
82+
6583 def parser (name : str ) -> list [str ]:
6684 name_tokens = [line_to_tensor (i ) for i in name .split ()]
6785 out = [rnn (i .unsqueeze (0 ).to (DEVICE )) for i in name_tokens ]
0 commit comments