diff --git a/bcipy/core/tests/resources/mock_session/parameters.json b/bcipy/core/tests/resources/mock_session/parameters.json index 9331fe71..99b99a23 100644 --- a/bcipy/core/tests/resources/mock_session/parameters.json +++ b/bcipy/core/tests/resources/mock_session/parameters.json @@ -680,7 +680,7 @@ "recommended": [ "UNIFORM", "CAUSAL", - "KENLM", + "NGRAM", "MIXTURE", "ORACLE" ], diff --git a/bcipy/exceptions.py b/bcipy/exceptions.py index e58bb161..23c525a0 100644 --- a/bcipy/exceptions.py +++ b/bcipy/exceptions.py @@ -1,4 +1,3 @@ - class BciPyCoreException(Exception): """BciPy Core Exception. @@ -75,30 +74,29 @@ class InvalidFieldException(FieldException): ... -class UnsupportedResponseType(BciPyCoreException): - """Unsupported ResponseType +class TaskConfigurationException(BciPyCoreException): + """Task Configuration Exception. - Thrown when attempting to set the response type of a language model to an - unsupported value.""" + Thrown when attempting to run a task with invalid configurations""" ... -class TaskConfigurationException(BciPyCoreException): - """Task Configuration Exception. +class KenLMInstallationException(BciPyCoreException): + """KenLM Installation Exception. - Thrown when attempting to run a task with invalid configurations""" + Thrown when attempting to import kenlm without installing the module""" ... -class InvalidLanguageModelException(BciPyCoreException): - """Invalid Language Model Exception. +class InvalidSymbolSetException(BciPyCoreException): + """Invalid Symbol Set Exception. - Thrown when attempting to load a language model from an invalid path""" + Thrown when querying a language model for predictions without setting the symbol set.""" ... -class KenLMInstallationException(BciPyCoreException): - """KenLM Installation Exception. +class LanguageModelNameInUseException(BciPyCoreException): + """Language Model Name In Use Exception. - Thrown when attempting to import kenlm without installing the module""" + Thrown when attempting to register a language model type with a duplicate name.""" ... diff --git a/bcipy/helpers/copy_phrase_wrapper.py b/bcipy/helpers/copy_phrase_wrapper.py index cfa40aa8..c41e4daf 100644 --- a/bcipy/helpers/copy_phrase_wrapper.py +++ b/bcipy/helpers/copy_phrase_wrapper.py @@ -10,7 +10,7 @@ from bcipy.core.symbols import BACKSPACE_CHAR from bcipy.exceptions import BciPyCoreException from bcipy.helpers.language_model import histogram, with_min_prob -from bcipy.language.main import LanguageModel +from bcipy.language.main import CharacterLanguageModel, LanguageModel from bcipy.task.control.criteria import (CriteriaEvaluator, MaxIterationsCriteria, MinIterationsCriteria, @@ -180,6 +180,9 @@ def letter_info(self, triggers: List[Tuple[str, float]], def initialize_series(self) -> Tuple[bool, InquirySchedule]: """If a decision is made initializes the next series.""" assert self.lmodel, "Language model must be initialized." + if not isinstance(self.lmodel, CharacterLanguageModel): + raise BciPyCoreException( + "Only character language models are currently supported.") try: # First, reset the history for this new series @@ -190,7 +193,7 @@ def initialize_series(self) -> Tuple[bool, InquirySchedule]: log.info(f"Querying language model: '{update}'") # update the lmodel and get back the priors - lm_letter_prior = self.lmodel.predict(list(update)) + lm_letter_prior = self.lmodel.predict_character(list(update)) if BACKSPACE_CHAR in self.alp: # Apply configured backspace probability. 
diff --git a/bcipy/helpers/language_model.py b/bcipy/helpers/language_model.py index 2c0bcd4d..81005817 100644 --- a/bcipy/helpers/language_model.py +++ b/bcipy/helpers/language_model.py @@ -1,29 +1,45 @@ """Helper functions for language model use.""" import inspect import math -from typing import Dict, List, Tuple +from typing import Callable, Dict, List, Tuple import numpy as np from bcipy.core.symbols import alphabet -from bcipy.language.main import LanguageModel, ResponseType +from bcipy.exceptions import LanguageModelNameInUseException +from bcipy.language.main import LanguageModel # pylint: disable=unused-import # flake8: noqa """Only imported models will be included in language_models_by_name""" # flake8: noqa -from bcipy.exceptions import InvalidLanguageModelException -from bcipy.language.model.causal import CausalLanguageModel -from bcipy.language.model.kenlm import KenLMLanguageModel -from bcipy.language.model.mixture import MixtureLanguageModel +from bcipy.language.model.causal import CausalLanguageModelAdapter +from bcipy.language.model.mixture import MixtureLanguageModelAdapter +from bcipy.language.model.ngram import NGramLanguageModelAdapter from bcipy.language.model.oracle import OracleLanguageModel from bcipy.language.model.uniform import UniformLanguageModel +VALID_LANGUAGE_MODELS: Dict[str, Callable[[], LanguageModel]] = { + "CAUSAL": CausalLanguageModelAdapter, + "NGRAM": NGramLanguageModelAdapter, + "MIXTURE": MixtureLanguageModelAdapter, + "ORACLE": OracleLanguageModel, + "UNIFORM": UniformLanguageModel +} -def language_models_by_name() -> Dict[str, LanguageModel]: + +def language_models_by_name() -> Dict[str, Callable[[], LanguageModel]]: """Returns available language models indexed by name.""" - return {lm.name(): lm for lm in LanguageModel.__subclasses__()} + return VALID_LANGUAGE_MODELS + + +def register_language_model(name: str, lm_type: Callable[[], LanguageModel]) -> None: + if name in VALID_LANGUAGE_MODELS: + raise LanguageModelNameInUseException( + f"{name} is already registered as {VALID_LANGUAGE_MODELS[name]}.") + else: + VALID_LANGUAGE_MODELS[name] = lm_type def init_language_model(parameters: dict) -> LanguageModel: @@ -50,10 +66,11 @@ def init_language_model(parameters: dict) -> LanguageModel: # select the relevant parameters into a dict. 
params = {key: parameters[key] for key in args & parameters.keys()} - return model( - response_type=ResponseType.SYMBOL, - symbol_set=alphabet(parameters), - **params) + lm = model(**params) + + lm.set_symbol_set(alphabet(parameters)) + + return lm def norm_domain(priors: List[Tuple[str, float]]) -> List[Tuple[str, float]]: diff --git a/bcipy/helpers/tests/test_copy_phrase_wrapper.py b/bcipy/helpers/tests/test_copy_phrase_wrapper.py index ee08a6a5..7eeccc89 100644 --- a/bcipy/helpers/tests/test_copy_phrase_wrapper.py +++ b/bcipy/helpers/tests/test_copy_phrase_wrapper.py @@ -1,7 +1,7 @@ import unittest from bcipy.helpers.copy_phrase_wrapper import CopyPhraseWrapper -from bcipy.core.symbols import alphabet +from bcipy.core.symbols import DEFAULT_SYMBOL_SET from bcipy.language.model.uniform import UniformLanguageModel from bcipy.task.data import EvidenceType @@ -10,12 +10,11 @@ class TestCopyPhraseWrapper(unittest.TestCase): """Test CopyPhraseWrapper""" def test_valid_letters(self): - alp = alphabet() cp = CopyPhraseWrapper( min_num_inq=1, max_num_inq=50, lmodel=None, - alp=alp, + alp=DEFAULT_SYMBOL_SET, task_list=[("HELLO_WORLD", "HE")], is_txt_stim=True, evidence_names=[EvidenceType.LM, EvidenceType.ERP], @@ -104,13 +103,15 @@ def test_valid_letters(self): ["nontarget", "nontarget"]) def test_init_series(self): - alp = alphabet() + + lmodel = UniformLanguageModel() + lmodel.set_symbol_set(DEFAULT_SYMBOL_SET) copy_phrase_task = CopyPhraseWrapper( min_num_inq=1, max_num_inq=50, - lmodel=UniformLanguageModel(symbol_set=alp), - alp=alp, + lmodel=lmodel, + alp=DEFAULT_SYMBOL_SET, task_list=[("HELLO_WORLD", "HE")], is_txt_stim=True, evidence_names=[EvidenceType.LM, EvidenceType.ERP], diff --git a/bcipy/language/README.md b/bcipy/language/README.md index 1438a745..c54ca8e3 100644 --- a/bcipy/language/README.md +++ b/bcipy/language/README.md @@ -1,6 +1,6 @@ # Language -BciPy Language module provides an interface for word and character level predictions. +BciPy Language module provides an interface for word and character level predictions. This module primarily relies upon the TextSlinger package (textslinger on PyPI) for its probability calculations. More information on this package can be found on our [GitHub repo](https://github.com/kdv123/textpredict) The core methods of any `LanguageModel` include: @@ -8,8 +8,6 @@ The core methods of any `LanguageModel` include: > `load` - load a pre-trained model given a path (currently BciPy does not support training language models!) -> `update` - update internal state of your model. - You may of course define other methods, however all integrated BciPy experiments using your model will require those to be defined! The language module has the following structure: @@ -30,20 +28,20 @@ The language module has the following structure: The UniformLanguageModel provides equal probabilities for all symbols in the symbol set. This model is useful for evaluating other aspects of the system, such as EEG signal quality, without any influence from a language model. -## KenLM Model -The KenLMLanguageModel utilizes a pretrained n-gram language model to generate probabilities for all symbols in the symbol set. N-gram models use frequencies of different character sequences to generate their predictions. Models trained on AAC-like data can be found [here](https://imagineville.org/software/lm/dec19_char/). For faster load times, it is recommended to use the binary models located at the bottom of the page. 
The default parameters file utilizes `lm_dec19_char_large_12gram.kenlm`. If you have issues accessing, please reach out to us on GitHub or via email at `cambi_support@googlegroups.com`. +## NGram Model +The NGramLanguageModelAdapter utilizes a pretrained n-gram language model to generate probabilities for all symbols in the symbol set. N-gram models use frequencies of different character sequences to generate their predictions. Models trained on AAC-like data can be found [here](https://imagineville.org/software/lm/dec19_char/). For faster load times, it is recommended to use the binary models located at the bottom of the page. The default parameters file utilizes `lm_dec19_char_large_12gram.kenlm`. If you have issues accessing, please reach out to us on GitHub or via email at `cambi_support@googlegroups.com`. For models that import the kenlm module, this must be manually installed using `pip install kenlm==0.1 --global-option="--max_order=12"`. ## Causal Model -The CausalLanguageModel class can use any causal language model from Huggingface, though it has only been tested with gpt2, facebook/opt, and distilgpt2 families of models. Causal language models predict the next token in a sequence of tokens. For the many of these models, byte-pair encoding (BPE) is used for tokenization. The main idea of BPE is to create a fixed-size vocabulary that contains common English subword units. Then a less common word would be broken down into several subword units in the vocabulary. For example, the tokenization of character sequence `peanut_butter_and_jel` would be: +The CausalLanguageModelAdapter class can use any causal language model from Hugging Face, though it has only been tested with the gpt2, facebook/opt, and distilgpt2 families of models (including the domain-adapted figmtu/opt-350m-aac). Causal language models predict the next token in a sequence of tokens. For many of these models, byte-pair encoding (BPE) is used for tokenization. The main idea of BPE is to create a fixed-size vocabulary that contains common English subword units. Then a less common word would be broken down into several subword units in the vocabulary. For example, the tokenization of character sequence `peanut_butter_and_jel` would be: > *['pe', 'anut', '_butter', '_and', '_j', 'el']* -Therefore, in order to generate a predictive distribution on the next character, we need to examine all the possibilities that could complete the final subword tokens in the input sequences. We must remove at least one token from the end of the context to allow the model the option of extending it, as opposed to only adding a new token. Removing more tokens allows the model more flexibility and may lead to better predictions, but at the cost of a higher prediction time. In this model we remove all of the subword tokens in the current (partially-typed) word to allow it the most flexibility. We then ask the model to estimate the likelihood of the next token and evaluate each token that matches our context. For efficiency, we only track a certain number of hypotheses at a time, known as the beam width, and each hypothesis until it surpasses the context. We can then store the likelihood for each final prediction in a list based on the character that directly follows the context. Once we have no more hypotheses to extend, we can sum the likelihoods stored for each character in our symbol set and normalize so they sum to 1, giving us our final distribution.
+Therefore, in order to generate a predictive distribution on the next character, we need to examine all the possibilities that could complete the final subword tokens in the input sequences. We must remove at least one token from the end of the context to allow the model the option of extending it, as opposed to only adding a new token. Removing more tokens allows the model more flexibility and may lead to better predictions, but at the cost of a higher prediction time. In this model we remove all of the subword tokens in the current (partially-typed) word to allow it the most flexibility. We then ask the model to estimate the likelihood of the next token and evaluate each token that matches our context. For efficiency, we only track a certain number of hypotheses at a time, known as the beam width, and extend each hypothesis until it surpasses the context. We can then store the likelihood for each final prediction in a list based on the character that directly follows the context. Once we have no more hypotheses to extend, we can sum the likelihoods stored for each character in our symbol set and normalize so they sum to 1, giving us our final distribution. More details on this process can be found in our paper, [Adapting Large Language Models for Character-based Augmentative and Alternative Communication](https://arxiv.org/abs/2501.10582). ## Mixture Model -The MixtureLanguageModel class allows for the combination of two or more supported models. The selected models are mixed according to the provided weights, which can be tuned using the Bcipy/scripts/python/mixture_tuning.py script. It is not recommended to use more than one "heavy-weight" model with long prediction times (the CausalLanguageModel) since this model will query each component model and parallelization is not currently supported. +The MixtureLanguageModelAdapter class allows for the combination of two or more supported models. The selected models are mixed according to the provided weights, which can be tuned using the Bcipy/scripts/python/mixture_tuning.py script. It is not recommended to use more than one "heavy-weight" model with long prediction times (the CausalLanguageModelAdapter) since this model will query each component model and parallelization is not currently supported.
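As an illustrative usage sketch of the adapter API described above (not part of the patch itself): the mixture weights and the per-model parameter keys (`lang_model_name`, `lm_path`) are assumptions based on the demos and adapter code in this diff, and the n-gram file name is the one used in `demo_ngram.py`.

```python
from bcipy.core.symbols import DEFAULT_SYMBOL_SET
from bcipy.language.model.mixture import MixtureLanguageModelAdapter

# Hypothetical component choices; the defaults normally come from lm_params.json.
lm = MixtureLanguageModelAdapter(
    lm_types=["CAUSAL", "NGRAM"],
    lm_weights=[0.6, 0.4],  # weights must sum to 1
    lm_params=[{"lang_model_name": "gpt2"},
               {"lm_path": "lm_dec19_char_12gram_1e-5_kenlm_probing.bin"}])

# The symbol set must be set before requesting predictions.
lm.set_symbol_set(DEFAULT_SYMBOL_SET)

# predict_character returns (symbol, probability) pairs sorted most to least likely.
predictions = lm.predict_character(list("does_it_make_sen"))
print(predictions[:5])
```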
# Contact Information diff --git a/bcipy/language/__init__.py b/bcipy/language/__init__.py index 73c549b3..c6f35313 100644 --- a/bcipy/language/__init__.py +++ b/bcipy/language/__init__.py @@ -1,6 +1,7 @@ -from .main import LanguageModel, ResponseType +from .main import LanguageModel, CharacterLanguageModel, WordLanguageModel __all__ = [ "LanguageModel", - "ResponseType", + "CharacterLanguageModel", + "WordLanguageModel" ] diff --git a/bcipy/language/demo/demo_causal.py b/bcipy/language/demo/demo_causal.py index 486b228e..1d6fb342 100644 --- a/bcipy/language/demo/demo_causal.py +++ b/bcipy/language/demo/demo_causal.py @@ -1,26 +1,35 @@ -from bcipy.language.model.causal import CausalLanguageModel -from bcipy.core.symbols import alphabet -from bcipy.language.main import ResponseType +from bcipy.language.model.causal import CausalLanguageModelAdapter +from bcipy.core.symbols import DEFAULT_SYMBOL_SET if __name__ == "__main__": - symbol_set = alphabet() - response_type = ResponseType.SYMBOL - lm = CausalLanguageModel(response_type, symbol_set, lang_model_name="gpt2") + lm = CausalLanguageModelAdapter(lang_model_name="figmtu/opt-350m-aac") + lm.set_symbol_set(DEFAULT_SYMBOL_SET) - next_char_pred = lm.state_update(list("does_it_make_sen")) - print(next_char_pred) + print("Target sentence: does_it_make_sense\n") + + next_char_pred = lm.predict_character(list("does_it_make_sen")) + print("Context: does_it_make_sen") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1 - print(correct_char_rank) - next_char_pred = lm.state_update(list("does_it_make_sens")) - print(next_char_pred) + print(f"Correct character rank: {correct_char_rank}\n") + + next_char_pred = lm.predict_character(list("does_it_make_sens")) + print("Context: does_it_make_sens") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("E") + 1 - print(correct_char_rank) - next_char_pred = lm.state_update(list("does_it_make_sense")) - print(next_char_pred) + print(f"Correct character rank: {correct_char_rank}\n") + + next_char_pred = lm.predict_character(list("does_it_make_sense")) + print("Context: does_it_make_sense") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("_") + 1 - print(correct_char_rank) - next_char_pred = lm.state_update(list("i_like_zebra")) - print(next_char_pred) + print(f"Correct character rank: {correct_char_rank}\n") + + print("Target sentence: i_like_zebras\n") + + next_char_pred = lm.predict_character(list("i_like_zebra")) + print("Context: i_like_zebra") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1 - print(correct_char_rank) + print(f"Correct character rank: {correct_char_rank}\n") diff --git a/bcipy/language/demo/demo_mixture.py b/bcipy/language/demo/demo_mixture.py index fee828e4..aeec3743 100644 --- a/bcipy/language/demo/demo_mixture.py +++ b/bcipy/language/demo/demo_mixture.py @@ -1,22 +1,36 @@ -from bcipy.language.model.mixture import MixtureLanguageModel -from bcipy.core.symbols import alphabet -from bcipy.language.main import ResponseType +from bcipy.language.model.mixture import MixtureLanguageModelAdapter +from bcipy.core.symbols import DEFAULT_SYMBOL_SET if __name__ == "__main__": - symbol_set = alphabet() - response_type = ResponseType.SYMBOL - lm = MixtureLanguageModel(response_type, symbol_set) + # Load the default mixture model from lm_params.json + lm = MixtureLanguageModelAdapter() + 
lm.set_symbol_set(DEFAULT_SYMBOL_SET) - next_char_pred = lm.state_update(list("does_it_make_sen")) - print(next_char_pred) + print("Target sentence: does_it_make_sense\n") + + next_char_pred = lm.predict_character(list("does_it_make_sen")) + print("Context: does_it_make_sen") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1 - print(correct_char_rank) - next_char_pred = lm.state_update(list("does_it_make_sens")) - print(next_char_pred) + print(f"Correct character rank: {correct_char_rank}\n") + + next_char_pred = lm.predict_character(list("does_it_make_sens")) + print("Context: does_it_make_sens") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("E") + 1 - print(correct_char_rank) - next_char_pred = lm.state_update(list("does_it_make_sense")) - print(next_char_pred) + print(f"Correct character rank: {correct_char_rank}\n") + + next_char_pred = lm.predict_character(list("does_it_make_sense")) + print("Context: does_it_make_sense") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("_") + 1 - print(correct_char_rank) + print(f"Correct character rank: {correct_char_rank}\n") + + print("Target sentence: i_like_zebras\n") + + next_char_pred = lm.predict_character(list("i_like_zebra")) + print("Context: i_like_zebra") + print(f"Predictions: {next_char_pred}") + correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1 + print(f"Correct character rank: {correct_char_rank}\n") diff --git a/bcipy/language/demo/demo_kenlm.py b/bcipy/language/demo/demo_ngram.py similarity index 72% rename from bcipy/language/demo/demo_kenlm.py rename to bcipy/language/demo/demo_ngram.py index cf42c503..50e335a1 100644 --- a/bcipy/language/demo/demo_kenlm.py +++ b/bcipy/language/demo/demo_ngram.py @@ -1,8 +1,7 @@ # Basic sanity test of using KenLM to predict a sentence using a 12-gram character model. 
-from bcipy.language.model.kenlm import KenLMLanguageModel -from bcipy.core.symbols import alphabet -from bcipy.language.main import ResponseType +from bcipy.language.model.ngram import NGramLanguageModelAdapter +from bcipy.core.symbols import DEFAULT_SYMBOL_SET from bcipy.config import LM_PATH from bcipy.exceptions import KenLMInstallationException @@ -16,6 +15,8 @@ if __name__ == "__main__": lm_path = f"{LM_PATH}/lm_dec19_char_12gram_1e-5_kenlm_probing.bin" + # Using KenLM directly + # Load a really pruned n-gram language model model = kenlm.LanguageModel(lm_path) @@ -80,27 +81,38 @@ prev = token print(f"sum logprob = {accum:.4f}") - symbol_set = alphabet() - response_type = ResponseType.SYMBOL - lm = KenLMLanguageModel(response_type, symbol_set, lm_path) + # Using the adapter and textslinger toolkit + lm = NGramLanguageModelAdapter(lm_path) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) + + print("Target sentence: i_like_zebras\n") - next_char_pred = lm.state_update(list("i_like_z")) - print(next_char_pred) + next_char_pred = lm.predict_character(list("i_like_z")) + print("Context: i_like_z") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("E") + 1 - print(correct_char_rank) - next_char_pred = lm.state_update(list("i_lik")) - print(next_char_pred) + print(f"Correct character rank: {correct_char_rank}\n") + + next_char_pred = lm.predict_character(list("i_lik")) + print("Context: i_lik") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("E") + 1 - print(correct_char_rank) - next_char_pred = lm.state_update(list("i_like_zebras")) - print(next_char_pred) + print(f"Correct character rank: {correct_char_rank}\n") + + next_char_pred = lm.predict_character(list("i_like_zebras")) + print("Context: i_like_zebras") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("_") + 1 - print(correct_char_rank) - next_char_pred = lm.state_update(list("")) - print(next_char_pred) + print(f"Correct character rank: {correct_char_rank}\n") + + next_char_pred = lm.predict_character(list("")) + print("Context: ") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("I") + 1 - print(correct_char_rank) - next_char_pred = lm.state_update(list("i_like_zebra")) - print(next_char_pred) + print(f"Correct character rank: {correct_char_rank}\n") + + next_char_pred = lm.predict_character(list("i_like_zebra")) + print("Context: i_like_zebra") + print(f"Predictions: {next_char_pred}") correct_char_rank = [c[0] for c in next_char_pred].index("S") + 1 - print(correct_char_rank) + print(f"Correct character rank: {correct_char_rank}\n") diff --git a/bcipy/language/main.py b/bcipy/language/main.py index ed3bfdac..37136a03 100644 --- a/bcipy/language/main.py +++ b/bcipy/language/main.py @@ -1,87 +1,50 @@ """Defines the language model base class.""" -from abc import ABC, abstractmethod -from enum import Enum -from typing import List, Optional, Tuple -import json +from abc import abstractmethod +from typing import List, Optional, Protocol, Tuple, Union, runtime_checkable -from bcipy.exceptions import UnsupportedResponseType -from bcipy.core.symbols import DEFAULT_SYMBOL_SET -from bcipy.config import DEFAULT_LM_PARAMETERS_PATH - -class ResponseType(Enum): - """Language model response type options.""" - SYMBOL = 'Symbol' - WORD = 'Word' - - def __str__(self): - return self.value - - -class LanguageModel(ABC): - """Parent class for Language Models.""" - 
- _response_type: ResponseType = None - symbol_set: List[str] = None - - def __init__(self, - response_type: Optional[ResponseType] = None, - symbol_set: Optional[List[str]] = None): - self.response_type = response_type or ResponseType.SYMBOL - self.symbol_set = symbol_set or DEFAULT_SYMBOL_SET - with open(DEFAULT_LM_PARAMETERS_PATH, 'r') as params_file: - self.parameters = json.load(params_file) - - @classmethod - def name(cls) -> str: - """Model name used for configuration""" - suffix = 'LanguageModel' - if cls.__name__.endswith(suffix): - return cls.__name__[0:-len(suffix)].upper() - return cls.__name__.upper() +class UsesSymbols(Protocol): + """A protocol for classes in which symbols can be set.""" + symbol_set: Optional[List[str]] = None @abstractmethod - def supported_response_types(self) -> List[ResponseType]: - """Returns a list of response types supported by this language model.""" + def set_symbol_set(self, symbol_set: List[str]) -> None: + """Updates the symbol set of the model. Must be called prior to prediction""" - @property - def response_type(self) -> ResponseType: - """Returns the current response type""" - return self._response_type - @response_type.setter - def response_type(self, value: ResponseType): - """Attempts to set the response type to the given value""" - if value not in self.supported_response_types(): - raise UnsupportedResponseType( - f"{value} responses are not supported by this model") - self._response_type = value +@runtime_checkable +class CharacterLanguageModel(UsesSymbols, Protocol): + """Protocol for BciPy Language models that predict characters.""" @abstractmethod - def predict(self, evidence: List[str]) -> List[Tuple]: + def predict_character(self, evidence: Union[str, + List[str]]) -> List[Tuple]: """ - Using the provided data, compute log likelihoods over the entire symbol set. + Using the provided data, compute the probability distribution over the entire symbol set. Args: - evidence - ['A', 'B'] + evidence - ['H', 'E'] Response: - probability - dependant on response type, a list of words or symbols with probability + probability - a list of symbols with probability """ - ... - @abstractmethod - def update(self) -> None: - """Update the model state""" - ... + +@runtime_checkable +class WordLanguageModel(UsesSymbols, Protocol): + """Protocol for BciPy Language models that predict words.""" @abstractmethod - def load(self) -> None: - """Restore model state from the provided checkpoint""" - ... + def predict_word(self, evidence: Union[str, List[str]], + num_predictions: int) -> List[Tuple]: + """ + Using the provided data, compute log likelihoods of word completions + in the symbol set. + Args: + evidence - ['H', 'E'] + + Response: + a list of words with associated log likelihoods + """ - def reset(self) -> None: - """Reset language model state""" - ... 
- def state_update(self, evidence: List[str]) -> List[Tuple]: - """Update state by predicting and updating""" +LanguageModel = Union[CharacterLanguageModel, WordLanguageModel] diff --git a/bcipy/language/model/adapter.py b/bcipy/language/model/adapter.py new file mode 100644 index 00000000..9fb6ff43 --- /dev/null +++ b/bcipy/language/model/adapter.py @@ -0,0 +1,67 @@ +"""Defines the language model adapter base class.""" +import json +from abc import ABC, abstractmethod +from typing import List, Optional, Tuple, Union + +from bcipy.config import DEFAULT_LM_PARAMETERS_PATH +from bcipy.core.symbols import BACKSPACE_CHAR, SPACE_CHAR +from bcipy.exceptions import InvalidSymbolSetException + + +class LanguageModelAdapter(ABC): + """Abstract base class for textslinger language model adapters.""" + + symbol_set: Optional[List[str]] = None + model = None + + def predict_character(self, evidence: Union[str, List[str]]) -> List[Tuple]: + """ + Using the provided data, compute the probability distribution over the entire symbol set. + Args: + evidence - ['H', 'E'] + + Response: + probability - a list of symbols with probability + """ + + if self.symbol_set is None: + raise InvalidSymbolSetException("symbol set must be set prior to requesting predictions.") + + assert self.model is not None, "language model does not exist!" + + context = "".join(evidence) + converted_context = context.replace(SPACE_CHAR, ' ') + + # TODO: If toolkit dependency is updated to >=1.0.0, this will need to change to predict_character() + next_char_pred = dict(self.model.predict(list(converted_context))) + + # Replace space with special space + if ' ' in next_char_pred: + next_char_pred[SPACE_CHAR] = next_char_pred[' '] + del next_char_pred[' '] + + # Add backspace, but return prob 0 from the lm + next_char_pred[BACKSPACE_CHAR] = 0.0 + + return list(sorted(next_char_pred.items(), + key=lambda item: item[1], reverse=True)) + + def _load_parameters(self) -> None: + with open(DEFAULT_LM_PARAMETERS_PATH, 'r', + encoding='utf8') as params_file: + self.parameters = json.load(params_file) + + @abstractmethod + def _load_model(self) -> None: + """Load the model itself using stored parameters""" + + def set_symbol_set(self, symbol_set: List[str]) -> None: + """Update the symbol set and call for the model to be loaded""" + + self.symbol_set = symbol_set + + # LM doesn't care about backspace, needs literal space + self.model_symbol_set = [' ' if ch is SPACE_CHAR else ch for ch in self.symbol_set] + self.model_symbol_set.remove(BACKSPACE_CHAR) + + self._load_model() diff --git a/bcipy/language/model/causal.py b/bcipy/language/model/causal.py index 4a2a475f..ca493d24 100644 --- a/bcipy/language/model/causal.py +++ b/bcipy/language/model/causal.py @@ -1,42 +1,28 @@ -import torch -from typing import Optional, List, Tuple -from transformers import AutoModelForCausalLM, AutoTokenizer -import itertools -import heapq +from typing import Optional -from bcipy.core.symbols import BACKSPACE_CHAR, SPACE_CHAR -from bcipy.language.main import LanguageModel, ResponseType -from bcipy.exceptions import InvalidLanguageModelException +from textslinger.causal import CausalLanguageModel -from scipy.special import logsumexp -from scipy.special import softmax -import time -from collections import defaultdict -from typing import Final from bcipy.config import LM_PATH +from bcipy.language.model.adapter import LanguageModelAdapter -class CausalLanguageModel(LanguageModel): - """Character language model based on a pre-trained causal model, GPT-2 by default.""" 
+class CausalLanguageModelAdapter(LanguageModelAdapter): + """Character language model based on a pre-trained causal model.""" def __init__(self, - response_type: ResponseType, - symbol_set: List[str], lang_model_name: Optional[str] = None, lm_path: Optional[str] = None, lm_device: str = "cpu", lm_left_context: str = "", - beam_width: int = None, + beam_width: Optional[int] = None, fp16: bool = True, mixed_case_context: bool = True, case_simple: bool = True, - max_completed: int = None, + max_completed: Optional[int] = None, ): """ - Initialize instance variables and load the language model with given path + Initialize instance variables and load model parameters Args: - response_type - SYMBOL only - symbol_set - list of symbol strings lang_model_name - name of the Hugging Face casual language model to load lm_path - load fine-tuned model from specified directory lm_device - device to use for making predictions (cpu, mps, or cuda) @@ -47,33 +33,14 @@ def __init__(self, case_simple - simple fixing of left context case max_completed - stop search once we reach this many completed hypotheses, None=don't prune """ - super().__init__(response_type=response_type, symbol_set=symbol_set) + + self._load_parameters() causal_params = self.parameters['causal'] - self.model = None - self.tokenizer = None - self.vocab_size = 0 - self.valid_vocab = [] - self.vocab = defaultdict(list) - # Since subword token ids are integers, use a list instead of a - # dictionary - self.index_to_word = [] - self.index_to_word_lower = [] - self.symbol_set_lower = None - self.device = lm_device - self.left_context = lm_left_context self.beam_width = beam_width or int(causal_params['beam_width']['value']) - self.fp16 = fp16 - self.mixed_case_context = mixed_case_context - self.case_simple = case_simple self.max_completed = max_completed or int(causal_params['max_completed']['value']) - if not self.max_completed and not self.beam_width: - print("WARNING: using causal language model without any pruning, this can be slow!") - else: - print(f"Causal language model, beam_width {self.beam_width}, max_completed {self.max_completed}") - # We optionally load the model from a local directory, but if this is not # specified, we load a Hugging Face model @@ -82,443 +49,23 @@ def __init__(self, local_model_path = lm_path or causal_params['model_path']['value'] self.model_dir = f"{LM_PATH}/{local_model_path}" if local_model_path != "" else self.model_name - # Parameters for the search - - # Simple heuristic to correct case in the LM context - self.simple_upper_words = {"i": "I", - "i'll": "I'll", - "i've": "I've", - "i'd": "I'd", - "i'm": "I'm"} - - # Track how much time spent in different parts of the predict function - self.predict_total_ns = 0 - self.predict_inference_ns = 0 - - # Are we a model that automatically inserts a start token that we need - # to get rid of - self.drop_first_token = False - - self.load() - - def supported_response_types(self) -> List[ResponseType]: - return [ResponseType.SYMBOL] - - def _build_vocab(self) -> None: - """ - Build a vocabulary table mapping token index to word strings - """ - - # Loop over all the subword tokens in the LLM - for i in range(self.vocab_size): - # Create a map from the subword token integer ID to the mixed and - # lowercase string versions - word = self.tokenizer.decode([i]) - word_lower = word.lower() - self.index_to_word += word, - self.index_to_word_lower += word_lower, - - # Check if all the characters in the subword token are in our valid - # symbol set - valid = True - for ch 
in word_lower: - # The space char is only valid once we convert spaces to the - # space char - if ch == SPACE_CHAR: - valid = False - break - if ch == ' ': - continue - elif ch not in self.symbol_set_lower: - valid = False - break - - # If the subword token symbols are all valid, then add it to the - # list of valid token IDs - if valid: - self.valid_vocab += i, - # Add this token ID to all lists for its valid text prefixes - for j in range(len(word)): - key = word_lower[0:j + 1].replace(' ', SPACE_CHAR) - self.vocab[key] += i, - - # When done, self.vocab can be used to map to possible following subword tokens given some text, e.g.: - # self.vocab["cyclo"] = [47495, 49484] - # self.index_to_word[self.vocab["cyclo"][0]] = cyclop - # self.index_to_word[self.vocab["cyclo"][1]] = cyclopedia - - (self.model_name.startswith("facebook/opt") - or self.model_name.startswith("figmtu/opt") - or "Llama-3.1" in self.model_name) - - # Get the index we use for the start or end pseudo-word - if self.left_context == "": - if "gpt2" in self.model_name: - self.left_context = "<|endoftext|>" - elif "Llama-3.1" in self.model_name: - self.left_context = "<|begin_of_text|>" - # Seems to have both sentence start and end tokens: - # https://docs.mistral.ai/guides/tokenization/ - elif "Mistral" in self.model_name: - self.left_context = "" - else: - self.left_context = "" - - # OPT, Llama and Mistral all insert start token - self.drop_first_token = (self.model_name.startswith("facebook/opt") or - self.model_name.startswith("figmtu/opt") or - "Llama-3.1" in self.model_name or - "Mistral" in self.model_name) - - # Get token id(s) for the left context we condition all sentences on - self.left_context_tokens = self._encode(self.left_context) - print(f"Causal: left_context = '{self.left_context}', left_context_tokens = {self.left_context_tokens}") - - def _encode(self, text: str) -> List[int]: - tokens = self.tokenizer.encode(text) - # Both OPT and Llama automatically insert a start token which we want - # to control ourselves - if len(tokens) > 1 and self.drop_first_token: - tokens = tokens[1:] - - return tokens - - def _sequence_string(self, sequence: List[int]) -> str: - """ - Convert a sequence of subword token IDs into a string with each token in ()'s - :param sequence: List of subword token IDs - :return: String - """ - return "".join([f"({self.index_to_word[x]})" for x in sequence]) - - def get_all_tokens_text(self): - """ - Return an array with the text of all subword tokens. - The array is in order by the integer index into the vocabulary. - This is mostly just for exploring the tokens in different LLMs. - :return: Array of subword token text strings. - """ - result = [] - for i in range(self.vocab_size): - result.append(self.tokenizer.decode([i])) - return result - - def predict(self, evidence: List[str]) -> List[Tuple]: - """ - Given an evidence of typed string, predict the probability distribution of - the next symbol - Args: - evidence - a list of characters (typed by the user) - Response: - A list of symbols with probability - """ - - assert self.model is not None, "language model does not exist!" - start_ns = time.time_ns() - - converted_context = "".join(evidence) - converted_context_lower = converted_context.lower() - context = converted_context.replace(SPACE_CHAR, ' ') - - # If using the simple case feature, we need to go through the actual - # left context and capitalize the first letter in the sentence as - # well as any word in our list of words that should be capitalized. 
- if self.case_simple and len(context) > 0: - cased_context = "" - words = context.split() - for i, word in enumerate(words): - if i == 0 and word[0] >= 'a' and word[0] <= 'z': - word = word[0].upper() + word[1:] - if i > 0: - if word in self.simple_upper_words: - word = self.simple_upper_words[word] - cased_context += " " - cased_context += word - # Handle ending space in the context - if context[-1] == ' ': - cased_context += " " - context = cased_context - - context_lower = context.lower() - - # Index in the hypothesis string that is the next character after our - # context - target_pos = len(context_lower) - - # For stats purposes track length of the prefix we are extending from space to match - # prefix_len = target_pos - - # Look for the last space in the context, or -1 if no begin_text in - # context yet - pos = context_lower.rfind(" ") - tokens = [] - tokens.extend(self.left_context_tokens) - if pos >= 0: - # Optionally, we condition on upper and lower case left context - if self.mixed_case_context: - truncated_context = context[0:pos] - else: - truncated_context = context_lower[0:pos] - tokens.extend(self._encode(truncated_context)) - # prefix_len -= pos - - # print(f"DEBUG, {context_lower} pos {pos}, prefix_len {prefix_len}") - - # Constant indexes for use with the hypotheses tuples - LOGP: Final[int] = 0 - SEQ: Final[int] = 1 - LEN: Final[int] = 2 - - # Our starting hypothesis that we'll be extending. - # Format is (log likelihood, token id sequence, text length). - # Note: we only include tokens after any in left context. - start_length = 0 - for x in tokens[len(self.left_context_tokens):]: - start_length += len(self.index_to_word_lower[x]) - current_hypos = [(0.0, tokens, start_length)] - - # We use a priority queue to track the top hypotheses during the beam search. - # For a beam of 8, empirical testing showed this was about the same amount - # of time as a simpler list that used a linear search to replace when - # full. - heapq.heapify(current_hypos) - - # Create a hash mapping each valid following character to a list of log - # probabilities - char_to_log_probs = defaultdict(list) - - # Add new extended hypotheses to this heap - next_hypos = [] - - # Tracks count of completed hypotheses - completed = 0 - - # Used to signal to while loop to stop the search - done = False - - # Start a beam search forward from the backed off token sequence. - # Each iteration of this while loop extends hypotheses by all valid tokens. - # We only keep at most self.beam_width hypotheses in the valid heap. - # Stop extending search once we reach our max completed target. - while len(current_hypos) > 0 and not done: - # We'll explore hypothesis in order from most probable to least. - # This has little impact on how long it takes since this is only sorting a small number of things. - # But it is important with max_completed pruning since we want to - # bias for completing high probability things. - current_hypos.sort(reverse=True) - - # Work on the hypotheses from the last round of extension. - # Create the torch tensor for the inference with a row for each - # hypothesis. 
- tokens_tensor = torch.tensor([x[SEQ] for x in current_hypos]).reshape( - len(current_hypos), -1).to(self.device) - - before_inference_ns = time.time_ns() - # Ask the LLM to predict tokens that come after our current set of - # hypotheses - with torch.no_grad(): - # Compute the probabilities from the logits - log_probs = torch.log_softmax(self.model( - tokens_tensor).logits[:, -1, :], dim=1) - - # Create a big 2D tensor where each row is that hypothesis' current likelihood. - # First create a list of just the hypotheses' likelihoods. - # Then reshape to be a column vector. - # Then duplicate the column based on the number of subword - # tokens in the LLM. - add_tensor = torch.tensor([x[LOGP] for x in current_hypos]).reshape( - (log_probs.size()[0], 1)).repeat(1, log_probs.size()[1]).to(self.device) - - # Add the current likelihoods with each subtoken's probability. - # Move it back to the CPU and convert to numpy since this makes - # it a lot faster to access for some reason. - new_log_probs = torch.add( - log_probs, add_tensor).detach().cpu().numpy() - self.predict_inference_ns += time.time_ns() - before_inference_ns - - for current_index, current in enumerate(current_hypos): - vocab = [] - extra_vocab = [] - # Extending this hypothesis must match the remaining text - remaining_context = converted_context_lower[current[LEN]:] - if len(remaining_context) == 0: - # There is no remaining context thus all subword tokens that are valid under our symbol set - # should be considered when computing the probability of - # the next character. - vocab = self.valid_vocab - else: - if remaining_context in self.vocab: - # We have a list of subword tokens that match the remaining text. - # They could be the same length as the remaining text - # or longer and have the remaining text as a prefix. - vocab = self.vocab[remaining_context] - - # We may need to use a subword token that doesn't completely consume the remaining text. - # Find these by tokenizing all possible lengths of text - # starting from the current position. - for i in range(1, len(remaining_context)): - tokenization = self._encode( - context_lower[current[LEN]:current[LEN] + i]) - # Ignore tokenizations involving multiple tokens since - # they involve an ID we would have already added. - if len(tokenization) == 1: - extra_vocab += tokenization[0], - - # The below code takes the most time, results from pprofile on 5 phrases on an 2080 GPU: - # 299| 22484582| 89.5763| 3.9839e-06| 14.24%| for token_id in itertools.chain(vocab, extra_vocab): - # 300| 0| 0| 0| 0.00%| # For a hypothesis to finish it must extend beyond the existing typed context - # 301| 22483271| 93.7939| 4.17172e-06| 14.91%| subword_len = len(self.index_to_word_lower[token_id]) - # 302| 22483271| 92.8608| 4.13022e-06| 14.76%| if (current[LEN] + subword_len) > len(context): - # 303| 0| 0| 0| 0.00%| # Add this likelihood to the list for the character at the prediction position. - # 304| 0| 0| 0| 0.00%| # Tracking the list and doing logsumpexp later was faster than doing it for each add. - # 305| 22480431| 106.353| 4.73094e-06| 16.90%| char_to_log_probs[self.index_to_word_lower[token_id][target_pos - current[LEN]]] += new_log_probs[current_index][token_id], - # 306| 22480431| 92.689| 4.1231e-06| 14.73%| completed += 1 - # 307| 2840| 0.0124488| 4.38338e-06| 0.00%| elif not self.beam_width or len(next_hypos) < - # - # Tuning notes: - # - With a beam of 8 and max completed of 32,000, getting around 5x speedup on written dev set. 
- # - This results in a PPL increase of 0.0025 versus old results using only beam of >= 8. - # - Pruning based on log probability difference and based on minimum number of hypotheses per symbol in alphabet did worse. - # - Code for these other pruning methods was removed. - # Possible ways to make it faster: - # - Stop part way through the below for loop over (vocab, extra_vocab). But this seems weird since the token IDs are in - # no particular order, we'd be just stopping early on the last hypothesis being explored by the enclosing loop. - # - Sort the rows in the log prob results on the GPU. Use these to limit which token IDs we explore in the below - # for loop. Is it possible to do this without introducing too - # much extra work to limit to the high probability ones? - - # Create a list of token indexes that are a prefix of the target text. - # We go over all the integer IDs in the vocab and extra_vocab - # lists. - for token_id in itertools.chain(vocab, extra_vocab): - # For a hypothesis to finish it must extend beyond the - # existing typed context - subword_len = len(self.index_to_word_lower[token_id]) - if (current[LEN] + subword_len) > len(context): - # Add this likelihood to the list for the character at the prediction position. - # Tracking the list and doing logsumpexp later was - # faster than doing it for each add. - char_to_log_probs[self.index_to_word_lower[token_id][target_pos - - current[LEN]]] += new_log_probs[current_index][token_id], - completed += 1 - elif not self.beam_width or len(next_hypos) < self.beam_width: - # If we are under the beam limit then just add it - heapq.heappush(next_hypos, - (new_log_probs[current_index][token_id], - current[SEQ] + [token_id], - current[LEN] + subword_len)) - elif new_log_probs[current_index][token_id] > next_hypos[0][LOGP]: - # Or replace the worst hypotheses with the new one - heapq.heappushpop(next_hypos, - (new_log_probs[current_index][token_id], - current[SEQ] + [token_id], - current[LEN] + subword_len)) - - # Break out of the for loop over hypotheses and while loop if - # we reach our max completed goal - if self.max_completed and completed >= self.max_completed: - done = True - break - - # Swap in the extended set as the new current working set - current_hypos = next_hypos - next_hypos = [] - - # Parallel array to symbol_set for storing the marginals - char_probs = [] - for ch in self.symbol_set_lower: - # Convert space to the underscore used in BciPy - if ch == SPACE_CHAR: - target_ch = ' ' - else: - target_ch = ch - - # Handle cases when symbols are never seen - if target_ch in char_to_log_probs: - char_probs += logsumexp(char_to_log_probs[target_ch]), - else: - char_probs += float("-inf"), - - # Normalize to a distribution that sums to 1 - char_probs = softmax(char_probs) - - next_char_pred = {} - for i, ch in enumerate(self.symbol_set_lower): - if ch is SPACE_CHAR: - next_char_pred[ch] = char_probs[i] - else: - next_char_pred[ch.upper()] = char_probs[i] - next_char_pred[BACKSPACE_CHAR] = 0.0 - - end_ns = time.time_ns() - self.predict_total_ns += end_ns - start_ns - - return list(sorted(next_char_pred.items(), - key=lambda item: item[1], reverse=True)) - - def dump_predict_times(self) -> None: - """Print some stats about the prediction timing""" - if self.predict_total_ns > 0: - print(f"Predict %: " - f"inference {self.predict_inference_ns / self.predict_total_ns * 100.0:.3f}") - - def update(self) -> None: - """Update the model state""" - ... 
- - def load(self) -> None: - """ - Load the language model and tokenizer, initialize class variables - """ - try: - self.tokenizer = AutoTokenizer.from_pretrained( - self.model_name, use_fast=False) - except BaseException: - raise InvalidLanguageModelException( - f"{self.model_name} is not a valid model identifier on HuggingFace.") - self.vocab_size = self.tokenizer.vocab_size - try: - self.model = AutoModelForCausalLM.from_pretrained(self.model_dir) - if self.fp16 and self.device == "cuda": - self.model = self.model.half() - except BaseException: - raise InvalidLanguageModelException( - f"{self.model_dir} is not a valid local folder or model identifier on HuggingFace.") - - self.model.eval() - - self.model.to(self.device) - - self.symbol_set_lower = [] - - for ch in self.symbol_set: - if ch is SPACE_CHAR: - self.symbol_set_lower.append(SPACE_CHAR) - elif ch is BACKSPACE_CHAR: - continue - else: - self.symbol_set_lower.append(ch.lower()) - - self._build_vocab() - - def get_num_parameters(self) -> int: - """ - Find out how many parameters the loaded model has - Args: - Response: - Integer number of parameters in the transformer model - """ - return sum(p.numel() for p in self.model.parameters()) - - def state_update(self, evidence: List[str]) -> List[Tuple]: - """ - Wrapper method that takes in evidence text, and output probability distribution - of next character - Args: - evidence - a list of characters (typed by the user) - Response: - A list of symbol with probability - """ - next_char_pred = self.predict(evidence) + self.lm_device = lm_device + self.lm_left_context = lm_left_context + self.fp16 = fp16 + self.mixed_case_context = mixed_case_context + self.case_simple = case_simple - return next_char_pred + def _load_model(self) -> None: + """Load the model itself using stored parameters""" + + self.model = CausalLanguageModel( + symbol_set=self.model_symbol_set, + lang_model_name=self.model_name, + lm_path=self.model_dir, + lm_device=self.lm_device, + lm_left_context=self.lm_left_context, + beam_width=self.beam_width, + fp16=self.fp16, + mixed_case_context=self.mixed_case_context, + case_simple=self.case_simple, + max_completed=self.max_completed) diff --git a/bcipy/language/model/kenlm.py b/bcipy/language/model/kenlm.py deleted file mode 100644 index 05b9bbf1..00000000 --- a/bcipy/language/model/kenlm.py +++ /dev/null @@ -1,137 +0,0 @@ -from collections import Counter -from typing import Optional, List, Tuple -from bcipy.core.symbols import BACKSPACE_CHAR, SPACE_CHAR -from bcipy.language.main import LanguageModel, ResponseType -from bcipy.exceptions import InvalidLanguageModelException, KenLMInstallationException -from bcipy.config import LM_PATH -try: - import kenlm -except BaseException: - raise KenLMInstallationException( - "Please install the requisite kenlm package:\n'pip install kenlm==0.1 --global-option=\"--max_order=12\"") -import numpy as np - - -class KenLMLanguageModel(LanguageModel): - """Character n-gram language model using the KenLM library for querying""" - - def __init__(self, response_type: ResponseType, symbol_set: List[str], lm_path: Optional[str] = None): - - super().__init__(response_type=response_type, symbol_set=symbol_set) - self.model = None - kenlm_params = self.parameters['kenlm'] - kenlm_model = kenlm_params['model_file']['value'] - self.lm_path = lm_path or f"{LM_PATH}/{kenlm_model}" - - self.load() - - def supported_response_types(self) -> List[ResponseType]: - return [ResponseType.SYMBOL] - - def predict(self, evidence: List[str]) -> List[Tuple]: - """ 
- Given an evidence of typed string, predict the probability distribution of - the next symbol - Args: - evidence - a list of characters (typed by the user) - Response: - A list of symbols with probability - """ - - # Do not modify the original parameter, could affect mixture model - context = evidence.copy() - - if len(context) > 11: - context = context[-11:] - - evidence_str = ''.join(context).lower() - - for i, ch in enumerate(context): - if ch == SPACE_CHAR: - context[i] = "" - - self.model.BeginSentenceWrite(self.state) - - # Update the state one token at a time based on evidence, alternate states - for i, token in enumerate(context): - if i % 2 == 0: - self.model.BaseScore(self.state, token.lower(), self.state2) - else: - self.model.BaseScore(self.state2, token.lower(), self.state) - - next_char_pred = None - - # Generate the probability distribution based on the final state - if len(context) % 2 == 0: - next_char_pred = self.prob_dist(self.state) - else: - next_char_pred = self.prob_dist(self.state2) - - return next_char_pred - - def update(self) -> None: - """Update the model state""" - ... - - def load(self) -> None: - """ - Load the language model, initialize state variables - Args: - path: language model file path - """ - - try: - self.model = kenlm.LanguageModel(self.lm_path) - except BaseException: - raise InvalidLanguageModelException( - f"A valid model path must be provided for the KenLMLanguageModel.\nPath{self.lm_path} is not valid.") - - self.state = kenlm.State() - self.state2 = kenlm.State() - - def state_update(self, evidence: List[str]) -> List[Tuple]: - """ - Wrapper method that takes in evidence text and outputs probability distribution - of next character - Args: - evidence - a list of characters (typed by the user) - Response: - A list of symbols with probabilities - """ - next_char_pred = self.predict(evidence) - - return next_char_pred - - def prob_dist(self, state: kenlm.State) -> List[Tuple]: - """ - Take in a state and generate the probability distribution of next character - Args: - state - the kenlm state updated with the evidence - Response: - A list of symbols with probability - """ - next_char_pred = Counter() - - temp_state = kenlm.State() - - for char in self.symbol_set: - # Backspace probability under the LM is 0 - if char == BACKSPACE_CHAR: - next - - score = 0.0 - - # Replace the space character with KenLM's token - if char == SPACE_CHAR: - score = self.model.BaseScore(state, '', temp_state) - else: - score = self.model.BaseScore(state, char.lower(), temp_state) - - # BaseScore returns log probs, convert by putting 10 to its power - next_char_pred[char] = pow(10, score) - - sum = np.sum(list(next_char_pred.values())) - for char in self.symbol_set: - next_char_pred[char] /= sum - - return list(sorted(next_char_pred.items(), key=lambda item: item[1], reverse=True)) diff --git a/bcipy/language/model/mixture.py b/bcipy/language/model/mixture.py index 36e8cdd4..f337ca67 100644 --- a/bcipy/language/model/mixture.py +++ b/bcipy/language/model/mixture.py @@ -1,61 +1,25 @@ -from collections import Counter -from typing import Optional, Dict, List, Tuple -from math import isclose +from typing import Dict, List, Optional -from bcipy.language.main import LanguageModel, ResponseType +from textslinger.mixture import MixtureLanguageModel -from bcipy.exceptions import InvalidLanguageModelException +from bcipy.config import LM_PATH +from bcipy.language.model.adapter import LanguageModelAdapter -# pylint: disable=unused-import -# flake8: noqa -"""All supported models 
must be imported""" -from bcipy.language.model.causal import CausalLanguageModel -from bcipy.language.model.kenlm import KenLMLanguageModel - -class MixtureLanguageModel(LanguageModel): +class MixtureLanguageModelAdapter(LanguageModelAdapter): """ Character language model that mixes any combination of other models """ - supported_lm_types = ["CAUSAL", "KENLM"] - - @staticmethod - def language_models_by_name() -> Dict[str, LanguageModel]: - """Returns available language models indexed by name.""" - return {lm.name(): lm for lm in LanguageModel.__subclasses__()} - - @staticmethod - def validate_parameters(types: List[str], weights: List[float], params: List[Dict[str, str]]): - if params is not None: - if (types is None) or (len(types) != len(params)): - raise InvalidLanguageModelException("Length of parameters does not match length of types") - - if weights is not None: - if (types is None) or (len(types) != len(weights)): - raise InvalidLanguageModelException("Length of weights does not match length of types") - if not isclose(sum(weights), 1.0, abs_tol=1e-05): - raise InvalidLanguageModelException("Weights do not sum to 1") - - if types is not None: - if weights is None: - raise InvalidLanguageModelException("Model weights not provided") - if params is None: - raise InvalidLanguageModelException("Model parameters not provided") - if not all(x in MixtureLanguageModel.supported_lm_types for x in types): - raise InvalidLanguageModelException(f"Supported model types: {MixtureLanguageModel.supported_lm_types}") + supported_lm_types = MixtureLanguageModel.supported_lm_types def __init__(self, - response_type: ResponseType, - symbol_set: List[str], lm_types: Optional[List[str]] = None, lm_weights: Optional[List[float]] = None, lm_params: Optional[List[Dict[str, str]]] = None): """ - Initialize instance variables and load the language model with given path + Initialize instance variables and load parameters Args: - response_type - SYMBOL only - symbol_set - list of symbol strings lm_types - list of types of models to mix lm_weights - list of weights to use when mixing the models lm_params - list of dictionaries to pass as parameters for each model's instantiation @@ -63,90 +27,20 @@ def __init__(self, MixtureLanguageModel.validate_parameters(lm_types, lm_weights, lm_params) - super().__init__(response_type=response_type, symbol_set=symbol_set) - self.models = list() - self.response_type = response_type - self.symbol_set = symbol_set + self._load_parameters() mixture_params = self.parameters['mixture'] self.lm_types = lm_types or mixture_params['model_types']['value'] self.lm_weights = lm_weights or mixture_params['model_weights']['value'] self.lm_params = lm_params or mixture_params['model_params']['value'] - MixtureLanguageModel.validate_parameters(self.lm_types, self.lm_weights, self.lm_params) - - self.load() - - def supported_response_types(self) -> List[ResponseType]: - return [ResponseType.SYMBOL] - - @staticmethod - def interpolate_language_models(lms: List[Dict[str, float]], coeffs: List[float]) -> List[Tuple]: - """ - interpolate two or more language models - Args: - lms - output from the language models (a list of dicts with char as keys and prob as values) - coeffs - list of rescale coefficients, lms[0] will be scaled by coeffs[0] and so on - Response: - a list of (char, prob) tuples representing an interpolated language model - """ - combined_lm = Counter() - - for i, lm in enumerate(lms): - for char in lm: - combined_lm[char] += lm[char] * coeffs[i] - - return 
list(sorted(combined_lm.items(), key=lambda item: item[1], reverse=True)) - - def predict(self, evidence: List[str]) -> List[Tuple]: - """ - Given an evidence of typed string, predict the probability distribution of - the next symbol - Args: - evidence - a list of characters (typed by the user) - Response: - A list of symbols with probability - """ - - pred_list = list() - - # Generate predictions from each component language model - pred_list = [dict(model.predict(evidence)) for model in self.models] - - # Mix the component models - next_char_pred = MixtureLanguageModel.interpolate_language_models(pred_list, self.lm_weights) - - return next_char_pred - - def update(self) -> None: - """Update the model state""" - ... - - def load(self) -> None: - """ - Load the language models to be mixed - """ - - language_models = MixtureLanguageModel.language_models_by_name() - for lm_type, params in zip(self.lm_types, self.lm_params): - model = language_models[lm_type] - lm = None - try: - lm = model(self.response_type, self.symbol_set, **params) - except InvalidLanguageModelException as e: - raise InvalidLanguageModelException(f"Error in creation of model type {lm_type}: {e.message}") - - self.models.append(lm) + for type, params in zip(self.lm_types, self.lm_params): + if type == "NGRAM": + params["lm_path"] = f"{LM_PATH}/{params['lm_path']}" - def state_update(self, evidence: List[str]) -> List[Tuple]: - """ - Wrapper method that takes in evidence text, and output probability distribution - of next character - Args: - evidence - a list of characters (typed by the user) - Response: - A list of symbol with probability - """ - next_char_pred = self.predict(evidence) + MixtureLanguageModel.validate_parameters(self.lm_types, self.lm_weights, self.lm_params) - return next_char_pred + def _load_model(self) -> None: + """Load the model itself using stored parameters""" + self.model = MixtureLanguageModel(self.model_symbol_set, self.lm_types, + self.lm_weights, self.lm_params) diff --git a/bcipy/language/model/ngram.py b/bcipy/language/model/ngram.py new file mode 100644 index 00000000..226860a4 --- /dev/null +++ b/bcipy/language/model/ngram.py @@ -0,0 +1,29 @@ +from typing import Optional + +from textslinger.ngram import NGramLanguageModel + +from bcipy.config import LM_PATH +from bcipy.language.model.adapter import LanguageModelAdapter + + +class NGramLanguageModelAdapter(LanguageModelAdapter): + """Character n-gram language model using the KenLM library for querying""" + + def __init__(self, + lm_path: Optional[str] = None): + """ + Initialize instance variables and load parameters + Args: + lm_path - location of local ngram model - loaded from parameters if None + """ + + self._load_parameters() + + ngram_params = self.parameters['ngram'] + ngram_model = ngram_params['model_file']['value'] + self.lm_path = lm_path or f"{LM_PATH}/{ngram_model}" + + def _load_model(self) -> None: + """Load the model itself using stored parameters""" + self.model = NGramLanguageModel(symbol_set=self.model_symbol_set, + lm_path=self.lm_path) diff --git a/bcipy/language/model/oracle.py b/bcipy/language/model/oracle.py index 73896fb5..c1f1157e 100644 --- a/bcipy/language/model/oracle.py +++ b/bcipy/language/model/oracle.py @@ -5,17 +5,18 @@ import numpy as np from bcipy.config import SESSION_LOG_FILENAME -from bcipy.core.symbols import BACKSPACE_CHAR -from bcipy.language.main import LanguageModel, ResponseType +from bcipy.core.symbols import BACKSPACE_CHAR, DEFAULT_SYMBOL_SET +from bcipy.exceptions import 
InvalidSymbolSetException +from bcipy.language.main import CharacterLanguageModel from bcipy.language.model.uniform import equally_probable logger = logging.getLogger(SESSION_LOG_FILENAME) TARGET_BUMP_MIN = 0.0 -TARGET_BUMP_MAX = 0.95 +TARGET_BUMP_MAX = 1.0 -class OracleLanguageModel(LanguageModel): +class OracleLanguageModel(CharacterLanguageModel): """Language model which knows the target phrase the user is attempting to spell. @@ -28,25 +29,28 @@ class OracleLanguageModel(LanguageModel): Parameters ---------- - response_type - SYMBOL only - symbol_set - optional specify the symbol set, otherwise uses DEFAULT_SYMBOL_SET task_text - the phrase the user is attempting to spell (ex. 'HELLO_WORLD') target_bump - the amount by which the probability of the target letter is increased. """ def __init__(self, - response_type: Optional[ResponseType] = None, - symbol_set: Optional[List[str]] = None, - task_text: str = None, + task_text: Optional[str] = None, target_bump: float = 0.1): - super().__init__(response_type=response_type, symbol_set=symbol_set) + self.task_text = task_text self.target_bump = target_bump + + self.symbol_set = DEFAULT_SYMBOL_SET + logger.debug( f"Initialized OracleLanguageModel(task_text='{task_text}', target_bump={target_bump})" ) + def set_symbol_set(self, symbol_set: List[str]) -> None: + """Updates the symbol set of the model. Must be called prior to prediction""" + self.symbol_set = symbol_set + @property def task_text(self): """Get the task_text property""" @@ -70,10 +74,7 @@ def target_bump(self, value: float): assert TARGET_BUMP_MIN <= value <= TARGET_BUMP_MAX, msg self._target_bump = value - def supported_response_types(self) -> List[ResponseType]: - return [ResponseType.SYMBOL] - - def predict(self, evidence: Union[str, List[str]]) -> List[Tuple]: + def predict_character(self, evidence: Union[str, List[str]]) -> List[Tuple]: """ Using the provided data, compute probabilities over the entire symbol. set. @@ -86,25 +87,32 @@ def predict(self, evidence: Union[str, List[str]]) -> List[Tuple]: ------- list of (symbol, probability) tuples """ + + if not self.symbol_set: + raise InvalidSymbolSetException( + "symbol set must be set prior to requesting predictions.") + spelled_text = ''.join(evidence) - probs = equally_probable(self.symbol_set) - symbol_probs = list(zip(self.symbol_set, probs)) target = self._next_target(spelled_text) + symbol_probs = {} + if target: - sym = (target, probs[0] + self.target_bump) - updated_symbol_probs = with_min_prob(symbol_probs, sym) + # non-target prob = x = (1-b)/n where n is len(symbol_set) and b is target_bump + # target prob = x + b + non_target_prob = (1 - self.target_bump) / len(self.symbol_set) + for ch in self.symbol_set: + if ch == target: + symbol_probs[ch] = non_target_prob + self.target_bump + else: + symbol_probs[ch] = non_target_prob else: - updated_symbol_probs = symbol_probs - - return sorted(updated_symbol_probs, - key=lambda pair: self.symbol_set.index(pair[0])) - - def update(self) -> None: - """Update the model state""" + symbol_probs = dict( + zip(self.symbol_set, equally_probable(self.symbol_set))) - def load(self) -> None: - """Restore model state from the provided checkpoint""" + return sorted(symbol_probs.items(), + key=lambda item: item[1], + reverse=True) def _next_target(self, spelled_text: str) -> Optional[str]: """Computes the next target letter based on the currently spelled_text. 
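For reference, the rewritten oracle prediction above assigns every non-target symbol x = (1 - target_bump) / n and the target x + target_bump, so the distribution still sums to 1 and the target exceeds every other symbol by exactly the bump. A minimal standalone sketch of that arithmetic (the 28-symbol set and 0.1 bump below are illustrative assumptions, not BciPy's actual defaults):

# Standalone sketch of the oracle probability arithmetic described above.
# Hypothetical 28-symbol set and 0.1 target bump, used purely for illustration.

def oracle_distribution(symbol_set, target, target_bump=0.1):
    """Return {symbol: probability} with the target boosted by target_bump."""
    non_target_prob = (1 - target_bump) / len(symbol_set)
    return {
        ch: non_target_prob + (target_bump if ch == target else 0.0)
        for ch in symbol_set
    }

symbols = [chr(c) for c in range(ord('A'), ord('Z') + 1)] + ['_', '<']
dist = oracle_distribution(symbols, target='H')

assert abs(sum(dist.values()) - 1.0) < 1e-9     # still a proper distribution
assert abs(dist['H'] - dist['A'] - 0.1) < 1e-9  # target exceeds others by the bump

This matches the tightened test expectation later in the patch, where probs_dict['H'] - probs_dict['A'] is asserted equal to target_bump to four decimal places rather than one.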
diff --git a/bcipy/language/model/uniform.py b/bcipy/language/model/uniform.py index 013735ba..e2dc52e0 100644 --- a/bcipy/language/model/uniform.py +++ b/bcipy/language/model/uniform.py @@ -1,30 +1,34 @@ """Uniform language model""" -from typing import Dict, List, Tuple, Union, Optional +from typing import Dict, List, Optional, Tuple, Union import numpy as np -from bcipy.language.main import LanguageModel, ResponseType +from bcipy.core.symbols import BACKSPACE_CHAR, DEFAULT_SYMBOL_SET +from bcipy.exceptions import InvalidSymbolSetException +from bcipy.language.main import CharacterLanguageModel -class UniformLanguageModel(LanguageModel): +class UniformLanguageModel(CharacterLanguageModel): """Language model in which probabilities for symbols are uniformly distributed. Parameters ---------- - response_type - SYMBOL only - symbol_set - optional specify the symbol set, otherwise uses DEFAULT_SYMBOL_SET + None """ - def __init__(self, - response_type: Optional[ResponseType] = None, - symbol_set: Optional[List[str]] = None): - super().__init__(response_type=response_type, symbol_set=symbol_set) + def __init__(self): + self.set_symbol_set(DEFAULT_SYMBOL_SET) - def supported_response_types(self) -> List[ResponseType]: - return [ResponseType.SYMBOL] + def set_symbol_set(self, symbol_set: List[str]) -> None: + """Updates the symbol set of the model. Must be called prior to prediction""" + self.symbol_set = symbol_set - def predict(self, evidence: Union[str, List[str]]) -> List[Tuple]: + self.model_symbol_set = [ch for ch in symbol_set] + if BACKSPACE_CHAR in symbol_set: + self.model_symbol_set.remove(BACKSPACE_CHAR) + + def predict_character(self, evidence: Union[str, List[str]]) -> List[Tuple]: """ Using the provided data, compute probabilities over the entire symbol. set. @@ -37,18 +41,19 @@ def predict(self, evidence: Union[str, List[str]]) -> List[Tuple]: ------- list of (symbol, probability) tuples """ - probs = equally_probable(self.symbol_set) - return list(zip(self.symbol_set, probs)) - def update(self) -> None: - """Update the model state""" + if not self.symbol_set: + raise InvalidSymbolSetException( + "symbol set must be set prior to requesting predictions.") - def load(self) -> None: - """Restore model state from the provided checkpoint""" + probs = equally_probable(self.model_symbol_set) + return list(zip(self.model_symbol_set, probs)) + [(BACKSPACE_CHAR, 0.0) + ] -def equally_probable(alphabet: List[str], - specified: Optional[Dict[str, float]] = None) -> List[float]: +def equally_probable( + alphabet: List[str], + specified: Optional[Dict[str, float]] = None) -> List[float]: """Returns a list of probabilities which correspond to the provided alphabet. Unless overridden by the specified values, all items will have the same probability. All probabilities sum to 1.0. 
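The equally_probable helper retained above spreads the probability mass evenly, except where a symbol's probability is pinned by the optional specified mapping; the updated UniformLanguageModel uses this together with an explicit (BACKSPACE_CHAR, 0.0) entry. A minimal re-implementation consistent with that docstring, shown only to illustrate the behaviour (it is not the library code itself):

# Illustrative re-implementation of the equally_probable behaviour described above.
from typing import Dict, List, Optional

def equally_probable_sketch(alphabet: List[str],
                            specified: Optional[Dict[str, float]] = None) -> List[float]:
    """Split the probability mass not claimed by `specified` evenly over the rest."""
    specified = specified or {}
    remaining = 1.0 - sum(specified.values())
    n_unspecified = len(alphabet) - len(specified)
    return [specified.get(sym, remaining / n_unspecified) for sym in alphabet]

# Pin backspace to zero, mirroring how the updated uniform model treats it.
probs = equally_probable_sketch(['A', 'B', 'C', '<'], specified={'<': 0.0})
assert abs(sum(probs) - 1.0) < 1e-9   # probabilities still sum to 1.0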
diff --git a/bcipy/language/tests/test_causal.py b/bcipy/language/tests/test_causal.py index da685cb8..e1f13467 100644 --- a/bcipy/language/tests/test_causal.py +++ b/bcipy/language/tests/test_causal.py @@ -4,114 +4,110 @@ import unittest from operator import itemgetter -from bcipy.exceptions import UnsupportedResponseType, InvalidLanguageModelException -from bcipy.core.symbols import alphabet, BACKSPACE_CHAR, SPACE_CHAR -from bcipy.language.model.causal import CausalLanguageModel -from bcipy.language.main import ResponseType +from bcipy.exceptions import InvalidSymbolSetException +from bcipy.core.symbols import DEFAULT_SYMBOL_SET, BACKSPACE_CHAR, SPACE_CHAR +from bcipy.language.model.causal import CausalLanguageModelAdapter +from bcipy.language.main import CharacterLanguageModel + +from textslinger.exceptions import InvalidLanguageModelException @pytest.mark.slow -class TestCausalLanguageModel(unittest.TestCase): +class TestCausalLanguageModelAdapter(unittest.TestCase): """Tests for language model""" @classmethod def setUpClass(cls): - cls.gpt2_model = CausalLanguageModel(response_type=ResponseType.SYMBOL, - symbol_set=alphabet(), lang_model_name="gpt2") - cls.opt_model = CausalLanguageModel(response_type=ResponseType.SYMBOL, - symbol_set=alphabet(), lang_model_name="facebook/opt-125m") + cls.gpt2_model = CausalLanguageModelAdapter(lang_model_name="gpt2") + cls.gpt2_model.set_symbol_set(DEFAULT_SYMBOL_SET) + + cls.opt_model = CausalLanguageModelAdapter(lang_model_name="facebook/opt-125m") + cls.opt_model.set_symbol_set(DEFAULT_SYMBOL_SET) @pytest.mark.slow def test_default_load(self): """Test loading model with parameters from json This test requires a valid lm_params.json file and all requisite models""" - lm = CausalLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet()) + lm = CausalLanguageModelAdapter() + lm.set_symbol_set(DEFAULT_SYMBOL_SET) def test_gpt2_init(self): """Test default parameters for GPT-2 model""" - self.assertEqual(self.gpt2_model.response_type, ResponseType.SYMBOL) - self.assertEqual(self.gpt2_model.symbol_set, alphabet()) - self.assertTrue( - ResponseType.SYMBOL in self.gpt2_model.supported_response_types()) - self.assertEqual(self.gpt2_model.left_context, "<|endoftext|>") - self.assertEqual(self.gpt2_model.device, "cpu") + self.assertEqual(self.gpt2_model.symbol_set, DEFAULT_SYMBOL_SET) + self.assertTrue(isinstance(self.gpt2_model, CharacterLanguageModel)) + self.assertEqual(self.gpt2_model.model.left_context, "<|endoftext|>") + self.assertEqual(self.gpt2_model.model.device, "cpu") def test_opt_init(self): """Test default parameters for Facebook OPT model""" - self.assertEqual(self.opt_model.response_type, ResponseType.SYMBOL) - self.assertEqual(self.opt_model.symbol_set, alphabet()) - self.assertTrue( - ResponseType.SYMBOL in self.opt_model.supported_response_types()) - self.assertEqual(self.opt_model.left_context, "") - self.assertEqual(self.opt_model.device, "cpu") - - def test_name(self): - """Test model name.""" - self.assertEqual("CAUSAL", CausalLanguageModel.name()) + self.assertEqual(self.opt_model.symbol_set, DEFAULT_SYMBOL_SET) + self.assertTrue(isinstance(self.opt_model, CharacterLanguageModel)) + self.assertEqual(self.opt_model.model.left_context, "") + self.assertEqual(self.opt_model.model.device, "cpu") - def test_unsupported_response_type(self): + def test_invalid_symbol_set(self): """Unsupported responses should raise an exception""" - with self.assertRaises(UnsupportedResponseType): - 
CausalLanguageModel(response_type=ResponseType.WORD, - symbol_set=alphabet(), lang_model_name="gpt2") + with self.assertRaises(InvalidSymbolSetException): + lm = CausalLanguageModelAdapter(lang_model_name="gpt2") + lm.predict_character("this_should_fail") def test_invalid_model_name(self): """Test that the proper exception is thrown if given an invalid lang_model_name""" with self.assertRaises(InvalidLanguageModelException): - CausalLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet(), - lang_model_name="phonymodel") + lm = CausalLanguageModelAdapter(lang_model_name="phonymodel") + lm.set_symbol_set(DEFAULT_SYMBOL_SET) def test_invalid_model_path(self): """Test that the proper exception is thrown if given an invalid lm_path""" with self.assertRaises(InvalidLanguageModelException): - CausalLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet(), - lang_model_name="gpt2", lm_path="./phonypath/") + lm = CausalLanguageModelAdapter(lang_model_name="gpt2", lm_path="./phonypath/") + lm.set_symbol_set(DEFAULT_SYMBOL_SET) def test_non_mutable_evidence(self): """Test that the model does not change the evidence variable passed in. This could impact the mixture model if failed""" evidence = list("Test_test") evidence2 = list("Test_test") - self.gpt2_model.predict(evidence) + self.gpt2_model.predict_character(evidence) self.assertEqual(evidence, evidence2) - self.opt_model.predict(evidence) + self.opt_model.predict_character(evidence) self.assertEqual(evidence, evidence2) def test_gpt2_identical(self): """Ensure predictions are the same for subsequent queries with the same evidence.""" - query1 = self.gpt2_model.predict(list("evidenc")) - query2 = self.gpt2_model.predict(list("evidenc")) + query1 = self.gpt2_model.predict_character(list("evidenc")) + query2 = self.gpt2_model.predict_character(list("evidenc")) for ((sym1, prob1), (sym2, prob2)) in zip(query1, query2): self.assertAlmostEqual(prob1, prob2, places=5) self.assertEqual(sym1, sym2) def test_opt_identical(self): """Ensure predictions are the same for subsequent queries with the same evidence.""" - query1 = self.opt_model.predict(list("evidenc")) - query2 = self.opt_model.predict(list("evidenc")) + query1 = self.opt_model.predict_character(list("evidenc")) + query2 = self.opt_model.predict_character(list("evidenc")) for ((sym1, prob1), (sym2, prob2)) in zip(query1, query2): self.assertAlmostEqual(prob1, prob2, places=5) self.assertEqual(sym1, sym2) def test_gpt2_upper_lower_case(self): """Ensure predictions are the same for upper or lower case evidence.""" - lc = self.gpt2_model.predict(list("EVIDENC")) - uc = self.gpt2_model.predict(list("evidenc")) + lc = self.gpt2_model.predict_character(list("EVIDENC")) + uc = self.gpt2_model.predict_character(list("evidenc")) for ((l_sym, l_prob), (u_sym, u_prob)) in zip(lc, uc): self.assertAlmostEqual(l_prob, u_prob, places=5) self.assertEqual(l_sym, u_sym) def test_opt_upper_lower_case(self): """Ensure predictions are the same for upper or lower case evidence.""" - lc = self.opt_model.predict(list("EVIDENC")) - uc = self.opt_model.predict(list("evidenc")) + lc = self.opt_model.predict_character(list("EVIDENC")) + uc = self.opt_model.predict_character(list("evidenc")) for ((l_sym, l_prob), (u_sym, u_prob)) in zip(lc, uc): self.assertAlmostEqual(l_prob, u_prob, places=5) self.assertEqual(l_sym, u_sym) def test_gpt2_predict_start_of_word(self): """Test the gpt2 predict method with no prior evidence.""" - symbol_probs = self.gpt2_model.predict(evidence=[]) + symbol_probs = 
self.gpt2_model.predict_character(evidence=[]) probs = [prob for sym, prob in symbol_probs] self.assertTrue( @@ -126,7 +122,7 @@ def test_gpt2_predict_start_of_word(self): def test_opt_predict_start_of_word(self): """Test the Facebook opt predict method with no prior evidence.""" - symbol_probs = self.opt_model.predict(evidence=[]) + symbol_probs = self.opt_model.predict_character(evidence=[]) probs = [prob for sym, prob in symbol_probs] self.assertTrue( @@ -141,7 +137,7 @@ def test_opt_predict_start_of_word(self): def test_gpt2_predict_middle_of_word(self): """Test the predict method in the middle of a word with gpt2 model.""" - symbol_probs = self.gpt2_model.predict(evidence=list("TH")) + symbol_probs = self.gpt2_model.predict_character(evidence=list("TH")) probs = [prob for sym, prob in symbol_probs] self.assertTrue( @@ -159,7 +155,7 @@ def test_gpt2_predict_middle_of_word(self): def test_opt_predict_middle_of_word(self): """Test the predict method in the middle of a word with Facebook opt model.""" - symbol_probs = self.opt_model.predict(evidence=list("TH")) + symbol_probs = self.opt_model.predict_character(evidence=list("TH")) probs = [prob for sym, prob in symbol_probs] self.assertTrue( @@ -177,7 +173,7 @@ def test_opt_predict_middle_of_word(self): def test_gpt2_phrase(self): """Test that a phrase can be used for input with gpt2 model""" - symbol_probs = self.gpt2_model.predict(list("does_it_make_sen")) + symbol_probs = self.gpt2_model.predict_character(list("does_it_make_sen")) most_likely_sym, _prob = sorted(symbol_probs, key=itemgetter(1), reverse=True)[0] @@ -185,7 +181,7 @@ def test_gpt2_phrase(self): def test_opt_phrase(self): """Test that a phrase can be used for input with Facebook opt model""" - symbol_probs = self.opt_model.predict(list("does_it_make_sen")) + symbol_probs = self.opt_model.predict_character(list("does_it_make_sen")) most_likely_sym, _prob = sorted(symbol_probs, key=itemgetter(1), reverse=True)[0] @@ -193,30 +189,30 @@ def test_opt_phrase(self): def test_gpt2_multiple_spaces(self): """Test that the probability of space after a space is smaller than before the space""" - symbol_probs_before = self.gpt2_model.predict(list("the")) - symbol_probs_after = self.gpt2_model.predict(list("the_")) + symbol_probs_before = self.gpt2_model.predict_character(list("the")) + symbol_probs_after = self.gpt2_model.predict_character(list("the_")) space_prob_before = (dict(symbol_probs_before))[SPACE_CHAR] space_prob_after = (dict(symbol_probs_after))[SPACE_CHAR] self.assertTrue(space_prob_before > space_prob_after) def test_opt_multiple_spaces(self): """Test that the probability of space after a space is smaller than before the space""" - symbol_probs_before = self.opt_model.predict(list("the")) - symbol_probs_after = self.opt_model.predict(list("the_")) + symbol_probs_before = self.opt_model.predict_character(list("the")) + symbol_probs_after = self.opt_model.predict_character(list("the_")) space_prob_before = (dict(symbol_probs_before))[SPACE_CHAR] space_prob_after = (dict(symbol_probs_after))[SPACE_CHAR] self.assertTrue(space_prob_before > space_prob_after) def test_gpt2_nonzero_prob(self): """Test that all letters in the alphabet have nonzero probability except for backspace""" - symbol_probs = self.gpt2_model.predict(list("does_it_make_sens")) + symbol_probs = self.gpt2_model.predict_character(list("does_it_make_sens")) prob_values = [item[1] for item in symbol_probs if item[0] != BACKSPACE_CHAR] for value in prob_values: self.assertTrue(value > 0) def 
test_opt_nonzero_prob(self): """Test that all letters in the alphabet have nonzero probability except for backspace""" - symbol_probs = self.opt_model.predict(list("does_it_make_sens")) + symbol_probs = self.opt_model.predict_character(list("does_it_make_sens")) prob_values = [item[1] for item in symbol_probs if item[0] != BACKSPACE_CHAR] for value in prob_values: self.assertTrue(value > 0) diff --git a/bcipy/language/tests/test_mixture.py b/bcipy/language/tests/test_mixture.py index aefe5b3a..107cfc2c 100644 --- a/bcipy/language/tests/test_mixture.py +++ b/bcipy/language/tests/test_mixture.py @@ -1,108 +1,110 @@ """Tests for MIXTURE Language Model""" -import pytest -import unittest import os +import unittest from operator import itemgetter -from bcipy.exceptions import UnsupportedResponseType, InvalidLanguageModelException -from bcipy.core.symbols import alphabet, BACKSPACE_CHAR, SPACE_CHAR -from bcipy.language.model.mixture import MixtureLanguageModel -from bcipy.language.main import ResponseType +import pytest +from textslinger.exceptions import InvalidLanguageModelException + +from bcipy.core.symbols import BACKSPACE_CHAR, DEFAULT_SYMBOL_SET, SPACE_CHAR +from bcipy.exceptions import InvalidSymbolSetException +from bcipy.language.main import CharacterLanguageModel +from bcipy.language.model.mixture import MixtureLanguageModelAdapter @pytest.mark.slow -class TestMixtureLanguageModel(unittest.TestCase): +class TestMixtureLanguageModelAdapter(unittest.TestCase): """Tests for language model""" @classmethod def setUpClass(cls): dirname = os.path.dirname(__file__) or '.' - cls.kenlm_path = f"{dirname}/resources/lm_dec19_char_tiny_12gram.kenlm" + cls.kenlm_path = "lm_dec19_char_tiny_12gram.kenlm" + print(cls.kenlm_path) cls.lm_params = [{"lm_path": cls.kenlm_path}, {"lang_model_name": "gpt2"}] - cls.lmodel = MixtureLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet(), - lm_types=["KENLM", "CAUSAL"], lm_weights=[0.5, 0.5], - lm_params=cls.lm_params) + cls.lmodel = MixtureLanguageModelAdapter(lm_types=["NGRAM", "CAUSAL"], lm_weights=[0.5, 0.5], + lm_params=cls.lm_params) + cls.lmodel.set_symbol_set(DEFAULT_SYMBOL_SET) @pytest.mark.slow def test_default_load(self): """Test loading model with parameters from json This test requires a valid lm_params.json file and all referenced models""" - lm = MixtureLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet()) + lm = MixtureLanguageModelAdapter() + lm.set_symbol_set(DEFAULT_SYMBOL_SET) def test_init(self): """Test default parameters""" - self.assertEqual(self.lmodel.response_type, ResponseType.SYMBOL) - self.assertEqual(self.lmodel.symbol_set, alphabet()) - self.assertTrue( - ResponseType.SYMBOL in self.lmodel.supported_response_types()) + self.assertEqual(self.lmodel.symbol_set, DEFAULT_SYMBOL_SET) + self.assertTrue(isinstance(self.lmodel, CharacterLanguageModel)) - def test_name(self): - """Test model name.""" - self.assertEqual("MIXTURE", MixtureLanguageModel.name()) - - def test_unsupported_response_type(self): - """Unsupported responses should raise an exception""" - with self.assertRaises(UnsupportedResponseType): - MixtureLanguageModel(response_type=ResponseType.WORD, - symbol_set=alphabet()) + def test_invalid_symbol_set(self): + """Should raise an exception if predict is called without setting symbol set""" + with self.assertRaises(InvalidSymbolSetException): + lm = MixtureLanguageModelAdapter() + lm.predict_character("this_should_fail") def test_invalid_model_type(self): """Test that the proper exception is 
thrown if given an invalid lm_type""" with self.assertRaises(InvalidLanguageModelException): - MixtureLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet(), - lm_types=["PHONY", "CAUSAL"], lm_weights=[0.5, 0.5], - lm_params=[{}, {"lang_model_name": "gpt2"}]) + lm = MixtureLanguageModelAdapter(lm_types=["PHONY", "CAUSAL"], lm_weights=[0.5, 0.5], + lm_params=[{}, {"lang_model_name": "gpt2"}]) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) + with self.assertRaises(InvalidLanguageModelException): - MixtureLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet(), - lm_types=["CAUSAL", "PHONY"], lm_weights=[0.5, 0.5], - lm_params=[{"lang_model_name": "gpt2"}, {}]) + lm = MixtureLanguageModelAdapter(lm_types=["CAUSAL", "PHONY"], lm_weights=[0.5, 0.5], + lm_params=[{"lang_model_name": "gpt2"}, {}]) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) def test_invalid_model_weights(self): """Test that the proper exception is thrown if given an improper number of lm_weights""" with self.assertRaises(InvalidLanguageModelException, msg="Exception not thrown when too few weights given"): - MixtureLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet(), - lm_types=["KENLM", "CAUSAL"], lm_weights=[0.5], - lm_params=self.lm_params) + lm = MixtureLanguageModelAdapter(lm_types=["NGRAM", "CAUSAL"], lm_weights=[0.5], + lm_params=[{"lm_path": self.kenlm_path}, {"lang_model_name": "gpt2"}]) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) + with self.assertRaises(InvalidLanguageModelException, msg="Exception not thrown when no weights given"): - MixtureLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet(), - lm_types=["KENLM", "CAUSAL"], lm_weights=None, - lm_params=self.lm_params) + lm = MixtureLanguageModelAdapter(lm_types=["NGRAM", "CAUSAL"], lm_weights=None, + lm_params=[{"lm_path": self.kenlm_path}, {"lang_model_name": "gpt2"}]) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) + with self.assertRaises(InvalidLanguageModelException, msg="Exception not thrown when too many weights given"): - MixtureLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet(), - lm_types=["KENLM", "CAUSAL"], lm_weights=[0.2, 0.3, 0.5], - lm_params=self.lm_params) + lm = MixtureLanguageModelAdapter(lm_types=["NGRAM", "CAUSAL"], lm_weights=[0.2, 0.3, 0.5], + lm_params=[{"lm_path": self.kenlm_path}, {"lang_model_name": "gpt2"}]) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) + with self.assertRaises(InvalidLanguageModelException, msg="Exception not thrown when weights given do not \ sum to 1"): - MixtureLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet(), - lm_types=["KENLM", "CAUSAL"], lm_weights=[0.5, 0.8], - lm_params=self.lm_params) + lm = MixtureLanguageModelAdapter(lm_types=["NGRAM", "CAUSAL"], lm_weights=[0.5, 0.8], + lm_params=[{"lm_path": self.kenlm_path}, {"lang_model_name": "gpt2"}]) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) def test_non_mutable_evidence(self): """Test that the model does not change the evidence variable passed in.""" evidence = list("Test_test") evidence2 = list("Test_test") - self.lmodel.predict(evidence) + self.lmodel.predict_character(evidence) self.assertEqual(evidence, evidence2) def test_identical(self): """Ensure predictions are the same for subsequent queries with the same evidence.""" - query1 = self.lmodel.predict(list("evidenc")) - query2 = self.lmodel.predict(list("evidenc")) + query1 = self.lmodel.predict_character(list("evidenc")) + query2 = self.lmodel.predict_character(list("evidenc")) for ((sym1, prob1), (sym2, prob2)) in zip(query1, 
query2): self.assertAlmostEqual(prob1, prob2, places=5) self.assertEqual(sym1, sym2) def test_upper_lower_case(self): """Ensure predictions are the same for upper or lower case evidence.""" - lc = self.lmodel.predict(list("EVIDENC")) - uc = self.lmodel.predict(list("evidenc")) + lc = self.lmodel.predict_character(list("EVIDENC")) + uc = self.lmodel.predict_character(list("evidenc")) for ((l_sym, l_prob), (u_sym, u_prob)) in zip(lc, uc): self.assertAlmostEqual(l_prob, u_prob, places=5) self.assertEqual(l_sym, u_sym) def test_predict_start_of_word(self): """Test the predict method with no prior evidence.""" - symbol_probs = self.lmodel.predict(evidence=[]) + symbol_probs = self.lmodel.predict_character(evidence=[]) probs = [prob for sym, prob in symbol_probs] self.assertTrue( @@ -117,7 +119,7 @@ def test_predict_start_of_word(self): def test_predict_middle_of_word(self): """Test the predict method in the middle of a word.""" - symbol_probs = self.lmodel.predict(evidence=list("TH")) + symbol_probs = self.lmodel.predict_character(evidence=list("TH")) probs = [prob for sym, prob in symbol_probs] self.assertTrue( @@ -135,7 +137,7 @@ def test_predict_middle_of_word(self): def test_phrase(self): """Test that a phrase can be used for input""" - symbol_probs = self.lmodel.predict(list("does_it_make_sen")) + symbol_probs = self.lmodel.predict_character(list("does_it_make_sen")) most_likely_sym, _prob = sorted(symbol_probs, key=itemgetter(1), reverse=True)[0] @@ -143,15 +145,15 @@ def test_phrase(self): def test_multiple_spaces(self): """Test that the probability of space after a space is smaller than before the space""" - symbol_probs_before = self.lmodel.predict(list("the")) - symbol_probs_after = self.lmodel.predict(list("the_n")) + symbol_probs_before = self.lmodel.predict_character(list("the")) + symbol_probs_after = self.lmodel.predict_character(list("the_n")) space_prob_before = (dict(symbol_probs_before))[SPACE_CHAR] space_prob_after = (dict(symbol_probs_after))[SPACE_CHAR] self.assertTrue(space_prob_before > space_prob_after) def test_nonzero_prob(self): """Test that all letters in the alphabet have nonzero probability except for backspace""" - symbol_probs = self.lmodel.predict(list("does_it_make_sens")) + symbol_probs = self.lmodel.predict_character(list("does_it_make_sens")) prob_values = [item[1] for item in symbol_probs if item[0] != BACKSPACE_CHAR] for value in prob_values: self.assertTrue(value > 0) diff --git a/bcipy/language/tests/test_kenlm.py b/bcipy/language/tests/test_ngram.py similarity index 64% rename from bcipy/language/tests/test_kenlm.py rename to bcipy/language/tests/test_ngram.py index 6a460d38..b55ac605 100644 --- a/bcipy/language/tests/test_kenlm.py +++ b/bcipy/language/tests/test_ngram.py @@ -1,82 +1,80 @@ -"""Tests for KENLM Language Model""" +"""Tests for NGRAM Language Model""" -import pytest -import unittest import os +import unittest from operator import itemgetter -from bcipy.exceptions import UnsupportedResponseType, InvalidLanguageModelException -from bcipy.core.symbols import alphabet, BACKSPACE_CHAR, SPACE_CHAR -from bcipy.language.model.kenlm import KenLMLanguageModel -from bcipy.language.main import ResponseType +import pytest +from textslinger.exceptions import InvalidLanguageModelException + +from bcipy.core.symbols import BACKSPACE_CHAR, DEFAULT_SYMBOL_SET, SPACE_CHAR +from bcipy.exceptions import InvalidSymbolSetException +from bcipy.language.main import CharacterLanguageModel +from bcipy.language.model.ngram import NGramLanguageModelAdapter 
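The renamed tests below all exercise the same adapter call pattern: construct the adapter, call set_symbol_set before any query, then call predict_character with the typed evidence. A condensed usage sketch, assuming BciPy and textslinger are installed and the tiny .kenlm test model is present under tests/resources:

# Illustrative usage of the n-gram adapter API exercised by the tests below.
import os

from bcipy.core.symbols import DEFAULT_SYMBOL_SET
from bcipy.language.model.ngram import NGramLanguageModelAdapter

dirname = os.path.dirname(__file__) or '.'
lm = NGramLanguageModelAdapter(
    lm_path=f"{dirname}/resources/lm_dec19_char_tiny_12gram.kenlm")
lm.set_symbol_set(DEFAULT_SYMBOL_SET)      # required before any prediction

preds = lm.predict_character(list("TH"))   # list of (symbol, probability) tuples
top_symbol, top_prob = sorted(preds, key=lambda p: p[1], reverse=True)[0]
print(top_symbol, round(top_prob, 4))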
@pytest.mark.slow -class TestKenLMLanguageModel(unittest.TestCase): +class TestNGramLanguageModelAdapter(unittest.TestCase): """Tests for language model""" + @classmethod def setUpClass(cls): dirname = os.path.dirname(__file__) or '.' cls.lm_path = f"{dirname}/resources/lm_dec19_char_tiny_12gram.kenlm" - cls.lmodel = KenLMLanguageModel(response_type=ResponseType.SYMBOL, - symbol_set=alphabet(), lm_path=cls.lm_path) + cls.lmodel = NGramLanguageModelAdapter(lm_path=cls.lm_path) + cls.lmodel.set_symbol_set(DEFAULT_SYMBOL_SET) @pytest.mark.slow def test_default_load(self): """Test loading model with parameters from json This test requires a valid lm_params.json file and all requisite models""" - lm = KenLMLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet()) + lm = NGramLanguageModelAdapter() + lm.set_symbol_set(DEFAULT_SYMBOL_SET) def test_init(self): """Test default parameters""" - self.assertEqual(self.lmodel.response_type, ResponseType.SYMBOL) - self.assertEqual(self.lmodel.symbol_set, alphabet()) - self.assertTrue( - ResponseType.SYMBOL in self.lmodel.supported_response_types()) - - def test_name(self): - """Test model name.""" - self.assertEqual("KENLM", KenLMLanguageModel.name()) + self.assertEqual(self.lmodel.symbol_set, DEFAULT_SYMBOL_SET) + self.assertTrue(isinstance(self.lmodel, CharacterLanguageModel)) - def test_unsupported_response_type(self): - """Unsupported responses should raise an exception""" - with self.assertRaises(UnsupportedResponseType): - KenLMLanguageModel(response_type=ResponseType.WORD, - symbol_set=alphabet(), lm_path=self.lm_path) + def test_invalid_symbol_set(self): + """Should raise an exception if predict is called without settting symbol set""" + with self.assertRaises(InvalidSymbolSetException): + lm = NGramLanguageModelAdapter(lm_path=self.lm_path) + lm.predict_character("this_should_fail") def test_invalid_model_path(self): """Test that the proper exception is thrown if given an invalid lm_path""" with self.assertRaises(InvalidLanguageModelException): - KenLMLanguageModel(response_type=ResponseType.SYMBOL, symbol_set=alphabet(), - lm_path="phonymodel.txt") + lm = NGramLanguageModelAdapter(lm_path="phonymodel.txt") + lm.set_symbol_set(DEFAULT_SYMBOL_SET) def test_non_mutable_evidence(self): """Test that the model does not change the evidence variable passed in. 
This could impact the mixture model if failed""" evidence = list("Test_test") evidence2 = list("Test_test") - self.lmodel.predict(evidence) + self.lmodel.predict_character(evidence) self.assertEqual(evidence, evidence2) def test_identical(self): """Ensure predictions are the same for subsequent queries with the same evidence.""" - query1 = self.lmodel.predict(list("evidenc")) - query2 = self.lmodel.predict(list("evidenc")) + query1 = self.lmodel.predict_character(list("evidenc")) + query2 = self.lmodel.predict_character(list("evidenc")) for ((sym1, prob1), (sym2, prob2)) in zip(query1, query2): self.assertAlmostEqual(prob1, prob2, places=5) self.assertEqual(sym1, sym2) def test_upper_lower_case(self): """Ensure predictions are the same for upper or lower case evidence.""" - lc = self.lmodel.predict(list("EVIDENC")) - uc = self.lmodel.predict(list("evidenc")) + lc = self.lmodel.predict_character(list("EVIDENC")) + uc = self.lmodel.predict_character(list("evidenc")) for ((l_sym, l_prob), (u_sym, u_prob)) in zip(lc, uc): self.assertAlmostEqual(l_prob, u_prob, places=5) self.assertEqual(l_sym, u_sym) def test_predict_start_of_word(self): """Test the predict method with no prior evidence.""" - symbol_probs = self.lmodel.predict(evidence=[]) + symbol_probs = self.lmodel.predict_character(evidence=[]) probs = [prob for sym, prob in symbol_probs] self.assertTrue( @@ -91,7 +89,7 @@ def test_predict_start_of_word(self): def test_predict_middle_of_word(self): """Test the predict method in the middle of a word.""" - symbol_probs = self.lmodel.predict(evidence=list("TH")) + symbol_probs = self.lmodel.predict_character(evidence=list("TH")) probs = [prob for sym, prob in symbol_probs] self.assertTrue( @@ -109,7 +107,7 @@ def test_predict_middle_of_word(self): def test_phrase(self): """Test that a phrase can be used for input""" - symbol_probs = self.lmodel.predict(list("does_it_make_sen")) + symbol_probs = self.lmodel.predict_character(list("does_it_make_sen")) most_likely_sym, _prob = sorted(symbol_probs, key=itemgetter(1), reverse=True)[0] @@ -117,16 +115,18 @@ def test_phrase(self): def test_multiple_spaces(self): """Test that the probability of space after a space is smaller than before the space""" - symbol_probs_before = self.lmodel.predict(list("the")) - symbol_probs_after = self.lmodel.predict(list("the_")) + symbol_probs_before = self.lmodel.predict_character(list("the")) + symbol_probs_after = self.lmodel.predict_character(list("the_")) space_prob_before = (dict(symbol_probs_before))[SPACE_CHAR] space_prob_after = (dict(symbol_probs_after))[SPACE_CHAR] self.assertTrue(space_prob_before > space_prob_after) def test_nonzero_prob(self): """Test that all letters in the alphabet have nonzero probability except for backspace""" - symbol_probs = self.lmodel.predict(list("does_it_make_sens")) - prob_values = [item[1] for item in symbol_probs if item[0] != BACKSPACE_CHAR] + symbol_probs = self.lmodel.predict_character(list("does_it_make_sens")) + prob_values = [ + item[1] for item in symbol_probs if item[0] != BACKSPACE_CHAR + ] for value in prob_values: self.assertTrue(value > 0) diff --git a/bcipy/language/tests/test_oracle.py b/bcipy/language/tests/test_oracle.py index 69a2f7a2..c5cb5489 100644 --- a/bcipy/language/tests/test_oracle.py +++ b/bcipy/language/tests/test_oracle.py @@ -2,8 +2,10 @@ import unittest -from bcipy.language.model.oracle import (BACKSPACE_CHAR, OracleLanguageModel, - ResponseType) +from bcipy.core.symbols import BACKSPACE_CHAR, DEFAULT_SYMBOL_SET +from bcipy.exceptions import 
InvalidSymbolSetException +from bcipy.language.main import CharacterLanguageModel +from bcipy.language.model.oracle import OracleLanguageModel class TestOracleLanguageModel(unittest.TestCase): @@ -17,15 +19,25 @@ def test_init(self): def test_init_with_text(self): """Test with task_text provided""" lmodel = OracleLanguageModel(task_text="HELLO_WORLD") - self.assertEqual(lmodel.response_type, ResponseType.SYMBOL) + lmodel.set_symbol_set(DEFAULT_SYMBOL_SET) self.assertEqual( len(lmodel.symbol_set), 28, "Should be the alphabet plus the backspace and space chars") + self.assertTrue(isinstance(lmodel, CharacterLanguageModel)) + + def test_invalid_symbol_set(self): + """Should raise an exception if predict is called before settting the symbol set""" + lm = OracleLanguageModel(task_text="HELLO_WORLD") + lm.set_symbol_set([]) + with self.assertRaises(InvalidSymbolSetException): + lm.predict_character("this_should_fail") + def test_predict(self): """Test the predict method""" lm = OracleLanguageModel(task_text="HELLO_WORLD") - symbol_probs = lm.predict(evidence=[]) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) + symbol_probs = lm.predict_character(evidence=[]) probs = [prob for sym, prob in symbol_probs] self.assertEqual(len(set(probs)), 2, @@ -38,12 +50,13 @@ def test_predict(self): "Target should have a higher value") self.assertAlmostEqual(lm.target_bump, probs_dict['H'] - probs_dict['A'], - places=1) + places=4) def test_predict_with_spelled_text(self): """Test predictions with previously spelled symbols""" lm = OracleLanguageModel(task_text="HELLO_WORLD") - symbol_probs = lm.predict(evidence=list("HELLO_")) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) + symbol_probs = lm.predict_character(evidence=list("HELLO_")) probs = [prob for sym, prob in symbol_probs] self.assertEqual(len(set(probs)), 2, @@ -55,7 +68,8 @@ def test_predict_with_spelled_text(self): def test_predict_with_incorrectly_spelled_text(self): """Test predictions with incorrectly spelled prior.""" lm = OracleLanguageModel(task_text="HELLO_WORLD") - symbol_probs = lm.predict(evidence=list("HELP")) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) + symbol_probs = lm.predict_character(evidence=list("HELP")) probs = [prob for sym, prob in symbol_probs] self.assertEqual(len(set(probs)), 2) @@ -66,29 +80,32 @@ def test_predict_with_incorrectly_spelled_text(self): def test_target_bump_parameter(self): """Test setting the target_bump parameter.""" lm = OracleLanguageModel(task_text="HELLO_WORLD", target_bump=0.2) - symbol_probs = lm.predict(evidence=[]) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) + symbol_probs = lm.predict_character(evidence=[]) probs_dict = dict(symbol_probs) self.assertTrue(probs_dict['H'] > probs_dict['A'], "Target should have a higher value") self.assertAlmostEqual(0.2, probs_dict['H'] - probs_dict['A'], - places=1) + places=4) def test_setting_task_text_to_none(self): """Test that task_text is required""" lmodel = OracleLanguageModel(task_text="HELLO_WORLD") + lmodel.set_symbol_set(DEFAULT_SYMBOL_SET) with self.assertRaises(AssertionError): lmodel.task_text = None def test_updating_task_text(self): """Test updating the task_text property.""" lm = OracleLanguageModel(task_text="HELLO_WORLD", target_bump=0.2) - probs = dict(lm.predict(evidence=list("HELLO_"))) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) + probs = dict(lm.predict_character(evidence=list("HELLO_"))) self.assertTrue(probs['W'] > probs['T'], "Target should have a higher value") lm.task_text = "HELLO_THERE" - probs = dict(lm.predict(evidence=list("HELLO_"))) + probs = 
dict(lm.predict_character(evidence=list("HELLO_"))) self.assertTrue(probs['T'] > probs['W'], "Target should have a higher value") @@ -101,6 +118,7 @@ def test_target_bump_bounds(self): OracleLanguageModel(task_text="HI", target_bump=1.1) lm = OracleLanguageModel(task_text="HI", target_bump=0.0) + lm.set_symbol_set(DEFAULT_SYMBOL_SET) with self.assertRaises(AssertionError): lm.target_bump = -1.0 @@ -110,22 +128,23 @@ def test_target_bump_bounds(self): def test_evidence_exceeds_task(self): """Test probs when evidence exceeds task_text.""" lm = OracleLanguageModel(task_text="HELLO") + lm.set_symbol_set(DEFAULT_SYMBOL_SET) - probs = dict(lm.predict(evidence="HELL")) + probs = dict(lm.predict_character(evidence="HELL")) self.assertEqual(2, len(set(probs.values()))) self.assertEqual(max(probs.values()), probs['O']) - probs = dict(lm.predict(evidence="HELLO")) + probs = dict(lm.predict_character(evidence="HELLO")) self.assertEqual(1, len(set(probs.values()))) - probs = dict(lm.predict(evidence="HELLP")) + probs = dict(lm.predict_character(evidence="HELLP")) self.assertEqual(2, len(set(probs.values()))) self.assertEqual(max(probs.values()), probs[BACKSPACE_CHAR]) - probs = dict(lm.predict(evidence="HELLO_")) + probs = dict(lm.predict_character(evidence="HELLO_")) self.assertEqual(1, len(set(probs.values()))) - probs = dict(lm.predict(evidence="HELPED")) + probs = dict(lm.predict_character(evidence="HELPED")) self.assertEqual(2, len(set(probs.values()))) self.assertEqual(max(probs.values()), probs[BACKSPACE_CHAR]) diff --git a/bcipy/language/tests/test_uniform.py b/bcipy/language/tests/test_uniform.py index 8e8a4034..7b4b6dca 100644 --- a/bcipy/language/tests/test_uniform.py +++ b/bcipy/language/tests/test_uniform.py @@ -2,25 +2,42 @@ import unittest -from bcipy.language.model.uniform import (ResponseType, UniformLanguageModel, - equally_probable) +from bcipy.core.symbols import BACKSPACE_CHAR, DEFAULT_SYMBOL_SET +from bcipy.exceptions import InvalidSymbolSetException +from bcipy.language.main import CharacterLanguageModel +from bcipy.language.model.uniform import UniformLanguageModel, equally_probable class TestUniformLanguageModel(unittest.TestCase): """Tests for language model""" + @classmethod + def setUpClass(cls): + cls.lm = UniformLanguageModel() + cls.lm.set_symbol_set(DEFAULT_SYMBOL_SET) + def test_init(self): """Test default parameters""" lmodel = UniformLanguageModel() - self.assertEqual(lmodel.response_type, ResponseType.SYMBOL) + lmodel.set_symbol_set(DEFAULT_SYMBOL_SET) self.assertEqual( len(lmodel.symbol_set), 28, "Should be the alphabet plus the backspace and space chars") + self.assertTrue(isinstance(lmodel, CharacterLanguageModel)) + + def test_invalid_symbol_set(self): + """Should raise an exception if predict is called before setting a symbol set""" + lm = UniformLanguageModel() + lm.set_symbol_set([]) + with self.assertRaises(InvalidSymbolSetException): + lm.predict_character("this_should_fail") def test_predict(self): """Test the predict method""" - symbol_probs = UniformLanguageModel().predict(evidence=[]) - probs = [prob for sym, prob in symbol_probs] + symbol_probs = self.lm.predict_character(evidence=[]) + + # Backspace can be 0 + probs = [prob for sym, prob in symbol_probs if sym != BACKSPACE_CHAR] self.assertEqual(len(set(probs)), 1, "All values should be the same") self.assertTrue(0 < probs[0] < 1) diff --git a/bcipy/parameters/lm_params.json b/bcipy/parameters/lm_params.json index 02f93472..ddd81c5a 100644 --- a/bcipy/parameters/lm_params.json +++ 
b/bcipy/parameters/lm_params.json @@ -1,5 +1,5 @@ { - "kenlm": { + "ngram": { "model_file": { "description": "Name of the pretrained model file", "value": "lm_dec19_char_large_12gram.kenlm", @@ -33,7 +33,7 @@ "description": "Defines the types of models to be used by the mixture model.", "value": [ "CAUSAL", - "KENLM" + "NGRAM" ], "type": "List[str]" }, @@ -48,8 +48,8 @@ "model_params": { "description": "Defines the extra parameters of models to be used by the mixture model.", "value": [ - {}, - {} + {"lang_model_name": "figmtu/opt-350m-aac"}, + {"lm_path": "lm_dec19_char_large_12gram.kenlm"} ], "type": "List[Dict[str, str]]" } diff --git a/bcipy/simulator/task/copy_phrase.py b/bcipy/simulator/task/copy_phrase.py index 164a8d13..c9129e1d 100644 --- a/bcipy/simulator/task/copy_phrase.py +++ b/bcipy/simulator/task/copy_phrase.py @@ -4,6 +4,7 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Tuple +from bcipy.acquisition.multimodal import ClientManager from bcipy.core.parameters import Parameters from bcipy.core.stimuli import InquirySchedule from bcipy.display.main import Display @@ -11,10 +12,8 @@ from bcipy.language.main import LanguageModel from bcipy.signal.model.base_model import SignalModel from bcipy.simulator.data.sampler import Sampler -from bcipy.simulator.task.null_display import NullDisplay from bcipy.simulator.task.null_daq import NullDAQ -from bcipy.acquisition.multimodal import ClientManager - +from bcipy.simulator.task.null_display import NullDisplay from bcipy.simulator.util.state import SimState from bcipy.task import TaskMode from bcipy.task.control.evidence import EvidenceEvaluator diff --git a/pyproject.toml b/pyproject.toml index 16cfb2fc..13d82afa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,9 +24,9 @@ classifiers = [ ] dependencies = [ + "textslinger~=0.1.1", "attrdict3==2.0.2", "EDFlib-Python==1.0.8", - "transformers==4.36.0", "torch==2.2.0", "construct==2.8.14", "mne==1.6.1", @@ -56,7 +56,6 @@ dependencies = [ "rich==13.9.4", "reportlab==4.2.0", "tables==3.7.0", - "kenlm==0.1", "pyWinhook==1.6.2;python_version=='3.9' and platform_system=='Windows'", "WxPython==4.2.1;platform_system!='Linux'", ] @@ -180,7 +179,6 @@ exclude = [ "scripts", "acquisition", "display.paradigm", - "language", "signal.model", "task.control", "core.parameters",