-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathlanguage_model.py
138 lines (107 loc) · 4.5 KB
/
language_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""Helper functions for language model use."""
import inspect
import math
from typing import Dict, List, Tuple, Type

import numpy as np

from bcipy.core.symbols import alphabet
from bcipy.language.main import BciPyLanguageModel, ResponseType
# pylint: disable=unused-import
# flake8: noqa
"""Only imported models will be included in language_models_by_name"""
# flake8: noqa
from bcipy.language.model.causal import CausalLanguageModelAdapter
from bcipy.language.model.ngram import NGramLanguageModelAdapter
from bcipy.language.model.mixture import MixtureLanguageModelAdapter
from bcipy.language.model.oracle import OracleLanguageModel
from bcipy.language.model.uniform import UniformLanguageModel
def language_models_by_name() -> Dict[str, Type[BciPyLanguageModel]]:
    """Returns available language models indexed by name.

    The registry is built from ``BciPyLanguageModel.__subclasses__()``, so
    it contains only the *direct* subclasses whose defining modules have
    already been imported.

    Returns
    -------
    dict mapping each model's ``name()`` to the model class itself (not an
    instance). If two subclasses report the same name, the one listed later
    by ``__subclasses__()`` wins.
    """
    return {lm.name(): lm for lm in BciPyLanguageModel.__subclasses__()}
def init_language_model(parameters: dict) -> BciPyLanguageModel:
    """
    Init Language Model configured in the parameters. If no language model is
    specified, a uniform language model is returned.

    The model class is looked up by the ``lang_model_type`` parameter
    (default ``"UNIFORM"``). Entries of ``parameters`` whose keys match the
    model constructor's parameter names are forwarded as keyword arguments,
    except ``response_type`` and ``symbol_set``, which are always supplied
    explicitly.

    Parameters
    ----------
    parameters : dict
        configuration details and path locations

    Returns
    -------
    instance of a BciPyLanguageModel

    Raises
    ------
    KeyError
        if ``lang_model_type`` does not name an available language model.
    """
    language_models = language_models_by_name()
    model_type = parameters.get("lang_model_type", "UNIFORM")
    try:
        model = language_models[model_type]
    except KeyError:
        # Keep the KeyError type callers may rely on, but say what went wrong.
        raise KeyError(
            f"Unknown lang_model_type {model_type!r}; available models: "
            f"{list(language_models)}") from None

    # These are passed explicitly below; forwarding them from `parameters`
    # as well would raise "got multiple values for keyword argument".
    explicit_args = {"response_type", "symbol_set"}
    # Introspect the model constructor to determine which parameters to pass.
    accepted_args = inspect.signature(model).parameters.keys()
    params = {
        key: parameters[key]
        for key in (accepted_args & parameters.keys()) - explicit_args
    }
    return model(
        response_type=ResponseType.SYMBOL,
        symbol_set=alphabet(parameters),
        **params)
def norm_domain(priors: List[Tuple[str, float]]) -> List[Tuple[str, float]]:
    """Map (symbol, negative log likelihood) pairs into the probability domain.

    Each value v is replaced by exp(-v), so the smallest (most likely) input
    becomes the largest output. For non-negative inputs the results fall
    between 0 and 1.

    Parameters:
        priors - list of (symbol, likelihood) pairs whose likelihoods are
            expressed as negative log likelihoods.
    Returns:
        list of (symbol, probability) tuples, in the same order as priors.
    """
    def to_probability(entry: Tuple[str, float]) -> Tuple[str, float]:
        symbol, neg_log_likelihood = entry
        return symbol, math.exp(-neg_log_likelihood)

    return list(map(to_probability, priors))
def with_min_prob(symbol_probs: List[Tuple[str, float]],
                  sym_prob: Tuple[str, float]) -> List[Tuple[str, float]]:
    """Returns a new list of symbol-probability pairs where the provided
    symbol has a minimum probability given in the sym_prob.

    If the provided symbol is already in the list with a probability greater
    than or equal to the minimum, symbol_probs itself is returned unmodified.

    Otherwise the symbol is assigned exactly the minimum probability and
    placed at the end of the list. The other symbols are rescaled
    proportionally so that together they sum to ``1 - min_probability``.
    If the minimum is 1 or more, or the other symbols have no probability
    mass, the symbol receives 1.0 and all others 0.0.

    Parameters:
    -----------
        symbol_probs - list of symbol, probability pairs
        sym_prob - (symbol, min_probability) defines the minimum probability
            for the given symbol in the returned list.

    Returns:
    -------
        list of (symbol, probability) pairs such that the sum of the
        probabilities is approx. 1.0.
    """
    new_sym, new_prob = sym_prob

    # Split out symbols and probabilities into separate lists, excluding the
    # symbol to be adjusted.
    symbols = []
    probs = []
    for sym, prob in symbol_probs:
        if sym != new_sym:
            symbols.append(sym)
            probs.append(prob)
        elif prob >= new_prob:
            # symbol prob in list is already at least the minimum.
            return symbol_probs

    others = np.array(probs, dtype=float)
    others_total = others.sum()
    if new_prob >= 1.0 or others_total <= 0.0:
        # No mass is left for (or available from) the other symbols; the
        # adjusted symbol takes all of it. Avoids division by zero / NaN.
        all_probs = np.append(np.zeros_like(others), 1.0)
    else:
        # Scale the remaining symbols to share exactly 1 - new_prob while
        # keeping their relative proportions. Dividing by their actual total
        # (rather than assuming it is 1) keeps the result correct when the
        # adjusted symbol was already present with a smaller probability.
        scale = (1.0 - new_prob) / others_total
        all_probs = np.append(others * scale, new_prob)

    all_symbols = symbols + [new_sym]
    return list(zip(all_symbols, all_probs))
def histogram(letter_prior: List[Tuple[str, float]]) -> str:
    """Render (letter, probability) pairs as a text histogram for the console.

    Rows are ordered by sorting the pairs (letter first, then probability).
    Each row reads ``<letter> (<prob>) :`` followed by a tab and one ``*``
    per percentage point, rounded to the nearest integer.

    Parameters:
    -----------
        letter_prior - list of letter, probability pairs
    Returns:
    --------
        newline-separated string with one histogram row per pair.
    """
    def render_row(letter: str, prob: float) -> str:
        bar = '*' * int(round(prob * 100))
        label = letter + ' (' + "%03.2f" % (prob) + ") :"
        return label + "\t" + bar

    return '\n'.join(
        render_row(letter, prob) for letter, prob in sorted(letter_prior))