Skip to content

Commit 495531c

Browse files
Refactor for easier testing
1 parent d59465e commit 495531c

File tree

1 file changed

+29
-11
lines changed
  • src/country_workspace/contrib/name_parser

1 file changed

+29
-11
lines changed

src/country_workspace/contrib/name_parser/parser.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,32 +36,50 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:
3636
return self.softmax(out)
3737

3838

39-
def get_parser(country_code: str) -> Parser:
39+
Alphabet = tuple[str, ...]
40+
ModelArgs = tuple[int, ...]
41+
UNKNOWN_CHAR = "_"
42+
43+
44+
def read_config(country_code: str) -> tuple[Alphabet, int, ModelArgs]:
4045
with (BASE_PATH / f"data/name_parser/models/{country_code}.txt").open() as f:
4146
lines = tuple(line.rstrip("\n") for line in f.readlines())
4247

43-
unknown = "_"
44-
alphabet = tuple(lines[0])
45-
alphabet_len = len(alphabet)
46-
name_max_len = int(lines[1])
47-
rnn_args = map(int, lines[2].split())
48+
return (
49+
tuple(lines[0]),
50+
int(lines[1]),
51+
tuple(map(int, lines[2].split())),
52+
)
4853

49-
rnn = LSTM(*rnn_args, num_layers=2)
54+
55+
def load_model(country_code: str, *args: int) -> nn.Module:
56+
rnn = LSTM(*args, num_layers=2)
5057
rnn.load_state_dict(torch.load(BASE_PATH / f"data/name_parser/models/{country_code}.pt"))
5158
rnn.to(DEVICE)
5259
rnn.eval()
60+
return rnn
5361

54-
def letter_to_index(letter: str) -> int:
55-
return alphabet.index(letter) if letter in alphabet else alphabet.index(unknown)
5662

57-
oob = alphabet_len + 1
63+
def get_line_to_tensor_converter(alphabet: Alphabet, max_name_len: int) -> Callable[[str], torch.Tensor]:
64+
oob = len(alphabet) + 1
65+
66+
def letter_to_index(letter: str) -> int:
67+
return alphabet.index(letter) if letter in alphabet else alphabet.index(UNKNOWN_CHAR)
5868

5969
def line_to_tensor(line: str) -> torch.Tensor:
60-
tensor = torch.ones(name_max_len, dtype=torch.long) * oob
70+
tensor = torch.ones(max_name_len, dtype=torch.long) * oob
6171
for li, letter in enumerate(line):
6272
tensor[li] = letter_to_index(letter)
6373
return tensor
6474

75+
return line_to_tensor
76+
77+
78+
def get_parser(country_code: str) -> Parser:
79+
alphabet, max_name_len, rnn_args = read_config(country_code)
80+
rnn = load_model(country_code, *rnn_args)
81+
line_to_tensor = get_line_to_tensor_converter(alphabet, max_name_len)
82+
6583
def parser(name: str) -> list[str]:
6684
name_tokens = [line_to_tensor(i) for i in name.split()]
6785
out = [rnn(i.unsqueeze(0).to(DEVICE)) for i in name_tokens]

0 commit comments

Comments
 (0)