8bit optimizers #1624

Open · wants to merge 9 commits into master
9 changes: 7 additions & 2 deletions deeppavlov/core/common/requirements_registry.json
@@ -5,7 +5,9 @@
],
"entity_linker": [
"{DEEPPAVLOV_PATH}/requirements/hdt.txt",
"{DEEPPAVLOV_PATH}/requirements/rapidfuzz.txt"
"{DEEPPAVLOV_PATH}/requirements/rapidfuzz.txt",
"{DEEPPAVLOV_PATH}/requirements/en_core_web_sm.txt",
"{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt"
],
"fasttext": [
"{DEEPPAVLOV_PATH}/requirements/fasttext.txt"
@@ -58,6 +60,7 @@
"{DEEPPAVLOV_PATH}/requirements/transformers.txt"
],
"ru_adj_to_noun": [
"{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt",
"{DEEPPAVLOV_PATH}/requirements/udapi.txt"
],
"russian_words_vocab": [
@@ -147,7 +150,9 @@
"{DEEPPAVLOV_PATH}/requirements/transformers.txt"
],
"tree_to_sparql": [
"{DEEPPAVLOV_PATH}/requirements/udapi.txt"
"{DEEPPAVLOV_PATH}/requirements/udapi.txt",
"{DEEPPAVLOV_PATH}/requirements/en_core_web_sm.txt",
"{DEEPPAVLOV_PATH}/requirements/ru_core_news_sm.txt"
],
"typos_custom_reader": [
"{DEEPPAVLOV_PATH}/requirements/lxml.txt"
28 changes: 26 additions & 2 deletions deeppavlov/core/models/torch_model.py
@@ -101,6 +101,31 @@ def __init__(self, device: str = "gpu",
self.model.eval()
log.debug(f"Model was successfully initialized! Model summary:\n {self.model}")

def get_optimizer(self):
    """Initialize the optimizer from bitsandbytes, falling back to ``torch.optim``
    if bitsandbytes is unavailable.
    """
    try:
        import bitsandbytes as bnb
        if 'AdamW' in self.optimizer_name:
            log.info('Weight decay is not supported in bitsandbytes yet - using Adam instead of AdamW')
            opt_name = self.optimizer_name.replace('AdamW', 'Adam')
        else:
            opt_name = self.optimizer_name
        log.info(f'Using bitsandbytes optimizer {opt_name}')
        optimizer = getattr(bnb.optim, opt_name)(
            self.model.parameters(), **self.optimizer_parameters)
    except Exception as e:
        log.info(f'Could not initialize the 8-bit optimizer ({e}) - resorting to torch optimizer')
        optimizer = getattr(torch.optim, self.optimizer_name)(
            self.model.parameters(), **self.optimizer_parameters)
    return optimizer

def init_from_opt(self, model_func: str) -> None:
"""Initialize from scratch `self.model` with the architecture built in `model_func` method of this class
along with `self.optimizer` as `self.optimizer_name` from `torch.optim` and parameters
@@ -115,8 +140,7 @@ def init_from_opt(self, model_func: str) -> None:
"""
if callable(model_func):
self.model = model_func(**self.opt).to(self.device)
self.optimizer = getattr(torch.optim, self.optimizer_name)(
self.model.parameters(), **self.optimizer_parameters)
self.optimizer = self.get_optimizer()
if self.lr_scheduler_name:
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
self.optimizer, **self.lr_scheduler_parameters)
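For reference, a minimal standalone sketch of the fallback pattern that get_optimizer implements; Adam8bit is assumed to be exposed by the pinned bitsandbytes build, and the linear model is only a stand-in:

import torch
import torch.nn as nn

try:
    import bitsandbytes as bnb
    optimizer_cls = bnb.optim.Adam8bit  # 8-bit Adam with quantized optimizer state
except ImportError:
    optimizer_cls = torch.optim.Adam  # standard 32-bit fallback

model = nn.Linear(128, 2)
optimizer = optimizer_cls(model.parameters(), lr=1e-3)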
10 changes: 9 additions & 1 deletion deeppavlov/models/classifiers/torch_nets.py
@@ -13,10 +13,13 @@
# limitations under the License.

from typing import List, Union, Optional
from logging import getLogger

import torch
import torch.nn as nn

log = getLogger(__name__)


class ShallowAndWideCnn(nn.Module):
def __init__(self, n_classes: int, embedding_size: int, kernel_sizes_cnn: List[int],
@@ -27,7 +30,12 @@ def __init__(self, n_classes: int, embedding_size: int, kernel_sizes_cnn: List[int],
self.kernel_sizes_cnn = kernel_sizes_cnn

if not embedded_tokens and vocab_size:
self.embedding = nn.Embedding(vocab_size, embedding_size)
try:
    import bitsandbytes as bnb
    self.embedding = bnb.nn.StableEmbedding(vocab_size, embedding_size)
except ImportError:
    log.info('Could not import bitsandbytes - resorting to torch embedding')
    self.embedding = nn.Embedding(vocab_size, embedding_size)
if isinstance(filters_cnn, int):
filters_cnn = len(kernel_sizes_cnn) * [filters_cnn]

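bitsandbytes pairs its 8-bit optimizers with a StableEmbedding layer (an nn.Embedding variant with layer norm and 32-bit optimizer state for the embedding weights), which is what this hunk adopts. The bitsandbytes README also documents an alternative that keeps a plain nn.Embedding and only pins its optimizer state to 32 bits; a sketch, assuming GlobalOptimManager is available in the pinned 0.25.0 build:

import torch.nn as nn
import bitsandbytes as bnb

# Keep a regular embedding but hold its optimizer state in 32 bits.
embedding = nn.Embedding(10000, 128)
manager = bnb.optim.GlobalOptimManager.get_instance()
manager.register_module_override(embedding, 'weight', {'optim_bits': 32})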
39 changes: 15 additions & 24 deletions deeppavlov/models/entity_extraction/entity_linking.py
@@ -14,19 +14,19 @@

import re
import sqlite3
from collections import defaultdict
from logging import getLogger
from typing import List, Dict, Tuple, Union, Any
from collections import defaultdict

import pymorphy2
import spacy
from hdt import HDTDocument
from nltk.corpus import stopwords
from rapidfuzz import fuzz

from deeppavlov.core.commands.utils import expand_path
from deeppavlov.core.common.registry import register
from deeppavlov.core.models.component import Component
from deeppavlov.core.models.serializable import Serializable
from deeppavlov.core.commands.utils import expand_path

log = getLogger(__name__)

@@ -75,7 +75,6 @@ def __init__(
**kwargs:
"""
super().__init__(save_path=None, load_path=load_path)
self.morph = pymorphy2.MorphAnalyzer()
self.lemmatize = lemmatize
self.entities_database_filename = entities_database_filename
self.num_entities_for_bert_ranking = num_entities_for_bert_ranking
@@ -86,8 +85,10 @@ def __init__(
self.lang = f"@{lang}"
if self.lang == "@en":
self.stopwords = set(stopwords.words("english"))
self.nlp = spacy.load("en_core_web_sm")
elif self.lang == "@ru":
self.stopwords = set(stopwords.words("russian"))
self.nlp = spacy.load("ru_core_news_sm")
self.use_descriptions = use_descriptions
self.use_connections = use_connections
self.max_paragraph_len = max_paragraph_len
@@ -198,7 +199,7 @@ def link_entities(
):
cand_ent_scores = []
if len(entity_substr) > 1:
entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
cand_ent_init = self.find_exact_match(entity_substr, tag)
if not cand_ent_init or entity_substr_split != entity_substr_split_lemm:
cand_ent_init = self.find_fuzzy_match(entity_substr_split, tag)
@@ -297,28 +298,23 @@ def find_exact_match(self, entity_substr, tag):
entity_substr_split = entity_substr_split[1:]
entities_and_ids = self.find_title(entity_substr)
cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split, tag)
if self.lang == "@ru":
entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
entity_substr_lemm = " ".join(entity_substr_split_lemm)
if entity_substr_lemm != entity_substr:
entities_and_ids = self.find_title(entity_substr_lemm)
if entities_and_ids:
cand_ent_init = self.process_cand_ent(
cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag
)

entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
entity_substr_lemm = " ".join(entity_substr_split_lemm)
if entity_substr_lemm != entity_substr:
entities_and_ids = self.find_title(entity_substr_lemm)
if entities_and_ids:
cand_ent_init = self.process_cand_ent(cand_ent_init, entities_and_ids, entity_substr_split_lemm, tag)
return cand_ent_init

def find_fuzzy_match(self, entity_substr_split, tag):
if self.lang == "@ru":
entity_substr_split_lemm = [self.morph.parse(tok)[0].normal_form for tok in entity_substr_split]
else:
entity_substr_split_lemm = entity_substr_split
entity_substr_split_lemm = [self.nlp(tok)[0].lemma_ for tok in entity_substr_split]
cand_ent_init = defaultdict(set)
for word in entity_substr_split:
part_entities_and_ids = self.find_title(word)
cand_ent_init = self.process_cand_ent(cand_ent_init, part_entities_and_ids, entity_substr_split, tag)
if self.lang == "@ru":
word_lemm = self.morph.parse(word)[0].normal_form
word_lemm = self.nlp(word)[0].lemma_
if word != word_lemm:
part_entities_and_ids = self.find_title(word_lemm)
cand_ent_init = self.process_cand_ent(
@@ -329,11 +325,6 @@ def find_fuzzy_match(self, entity_substr_split, tag):
)
return cand_ent_init

def morph_parse(self, word):
morph_parse_tok = self.morph.parse(word)[0]
normal_form = morph_parse_tok.normal_form
return normal_form

def calc_substr_score(self, cand_entity_title, entity_substr_split):
label_tokens = cand_entity_title.split()
cnt = 0.0
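The recurring substitution in this file (and in the KBQA files below) replaces pymorphy2's normal_form with spacy lemmas, which also makes English lemmatization available through the same interface. A minimal sketch, assuming the ru_core_news_sm model from the new requirements file is installed:

import spacy

# pymorphy2: morph.parse(token)[0].normal_form  ->  spacy: token.lemma_
nlp = spacy.load("ru_core_news_sm")
lemmas = [token.lemma_ for token in nlp("самые высокие горы")]
print(lemmas)  # roughly: ['самый', 'высокий', 'гора']

Note that the diff lemmatizes one token at a time via self.nlp(tok)[0].lemma_, which runs the whole spacy pipeline per word; feeding the full phrase through nlp once and reading .lemma_ off each token is usually cheaper.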
20 changes: 10 additions & 10 deletions deeppavlov/models/kbqa/tree_to_sparql.py
@@ -19,7 +19,7 @@
from typing import Any, List, Tuple, Dict, Union

import numpy as np
import pymorphy2
import spacy
from navec import Navec
from scipy.sparse import csr_matrix
from slovnet import Syntax
@@ -66,11 +66,10 @@ def __init__(self, freq_dict_filename: str, candidate_nouns: int = 10, **kwargs)
self.adj_set = set([word for word, freq in pos_freq_dict["a"]])
self.nouns = [noun[0] for noun in self.nouns_with_freq]
self.matrix = self.make_sparse_matrix(self.nouns).transpose()
self.morph = pymorphy2.MorphAnalyzer()
self.nlp = spacy.load("ru_core_news_sm")

def search(self, word: str):
word = self.morph.parse(word)[0]
word = word.normal_form
word = self.nlp(word)[0].lemma_
if word in self.adj_set:
q_matrix = self.make_sparse_matrix([word])
scores = q_matrix * self.matrix
@@ -190,6 +189,7 @@ def __init__(self, sparql_queries_filename: str, lang: str = "rus", adj_to_noun:
self.begin_tokens = {"начинать", "начать"}
self.end_tokens = {"завершить", "завершать", "закончить"}
self.ranking_tokens = {"самый"}
self.nlp = spacy.load("ru_core_news_sm")
elif self.lang == "eng":
self.q_pronouns = {"what", "who", "how", "when", "where", "which"}
self.how_many = "how many"
@@ -199,12 +199,12 @@ def __init__(self, sparql_queries_filename: str, lang: str = "rus", adj_to_noun:
self.begin_tokens = set()
self.end_tokens = set()
self.ranking_tokens = set()
self.nlp = spacy.load("en_core_web_sm")
else:
raise ValueError(f"unsupported language {lang}")
self.sparql_queries_filename = expand_path(sparql_queries_filename)
self.template_queries = read_json(self.sparql_queries_filename)
self.adj_to_noun = adj_to_noun
self.morph = pymorphy2.MorphAnalyzer()

def __call__(self, syntax_tree_batch: List[str],
positions_batch: List[List[List[int]]]) -> Tuple[
@@ -274,7 +274,7 @@ def __call__(self, syntax_tree_batch: List[str],
self.root_entity = True

temporal_order = self.find_first_last(new_root)
new_root_nf = self.morph.parse(new_root.form)[0].normal_form
new_root_nf = self.nlp(new_root.form)[0].lemma_
if new_root_nf in self.begin_tokens or new_root_nf in self.end_tokens:
temporal_order = new_root_nf
ranking_tokens = self.find_ranking_tokens(new_root)
@@ -288,7 +288,7 @@ def __call__(self, syntax_tree_batch: List[str],
question = []
for node in tree.descendants:
if node.ord in ranking_tokens or node.form.lower() in self.q_pronouns:
question.append(self.morph.parse(node.form)[0].normal_form)
question.append(self.nlp(node.form)[0].lemma_)
else:
question.append(node.form)
question = ' '.join(question)
@@ -496,9 +496,9 @@ def find_first_last(self, node: Node) -> str:
for node in nodes:
node_desc = defaultdict(set)
for elem in node.children:
parsed_elem = self.morph.parse(elem.form.lower())[0].inflect({"masc", "sing", "nomn"})
parsed_elem = self.nlp(elem.form.lower())[0].lemma_
if parsed_elem is not None:
node_desc[elem.deprel].add(parsed_elem.word)
node_desc[elem.deprel].add(parsed_elem)
else:
node_desc[elem.deprel].add(elem.form)
if "amod" in node_desc.keys() and "nmod" in node_desc.keys() and \
@@ -511,7 +511,7 @@ def find_ranking_tokens(self, node: Node) -> list:
def find_ranking_tokens(self, node: Node) -> list:
ranking_tokens = []
for elem in node.descendants:
if self.morph.parse(elem.form)[0].normal_form in self.ranking_tokens:
if self.nlp(elem.form)[0].lemma_ in self.ranking_tokens:
ranking_tokens.append(elem.ord)
ranking_tokens.append(elem.parent.ord)
return ranking_tokens
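One behavioral note for this file: the removed code inflected tokens to masculine singular nominative with pymorphy2's inflect, while the new code takes the spacy lemma. For Russian adjectives the dictionary (lemma) form is the masculine singular nominative, so the two should mostly coincide; a quick check, assuming ru_core_news_sm is installed:

import spacy

nlp = spacy.load("ru_core_news_sm")
# Feminine adjective; its lemma is the masculine singular nominative form
print(nlp("высокая")[0].lemma_)  # expected: 'высокий'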
4 changes: 1 addition & 3 deletions deeppavlov/models/kbqa/type_define.py
@@ -15,7 +15,6 @@
import pickle
from typing import List

import pymorphy2
import spacy
from nltk.corpus import stopwords

@@ -43,7 +42,6 @@ def __init__(self, lang: str, types_filename: str, types_sets_filename: str,
self.types_filename = str(expand_path(types_filename))
self.types_sets_filename = str(expand_path(types_sets_filename))
self.num_types_to_return = num_types_to_return
self.morph = pymorphy2.MorphAnalyzer()
if self.lang == "@en":
self.stopwords = set(stopwords.words("english"))
self.nlp = spacy.load("en_core_web_sm")
@@ -102,7 +100,7 @@ def __call__(self, questions_batch: List[str], entity_substr_batch: List[List[st
types_substr_tokens = types_substr.split()
types_substr_tokens = [tok for tok in types_substr_tokens if tok not in self.stopwords]
if self.lang == "@ru":
types_substr_tokens = [self.morph.parse(tok)[0].normal_form for tok in types_substr_tokens]
types_substr_tokens = [self.nlp(tok)[0].lemma_ for tok in types_substr_tokens]
types_substr_tokens = set(types_substr_tokens)
types_scores = []
for entity in self.types_dict:
3 changes: 1 addition & 2 deletions deeppavlov/models/torch_bert/torch_bert_ranker.py
@@ -202,8 +202,7 @@ def load(self, fname=None):

self.model.to(self.device)

self.optimizer = getattr(torch.optim, self.optimizer_name)(
self.model.parameters(), **self.optimizer_parameters)
self.optimizer = self.get_optimizer()
if self.lr_scheduler_name is not None:
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
self.optimizer, **self.lr_scheduler_parameters)
(another changed file; name not shown in the diff view)
@@ -255,8 +255,7 @@ def load(self, fname=None):

self.model.to(self.device)

self.optimizer = getattr(torch.optim, self.optimizer_name)(
self.model.parameters(), **self.optimizer_parameters)
self.optimizer = self.get_optimizer()
if self.lr_scheduler_name is not None:
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
self.optimizer, **self.lr_scheduler_parameters)
(another changed file; name not shown in the diff view)
@@ -182,8 +182,7 @@ def load(self, fname = None):

self.model.to(self.device)

self.optimizer = getattr(torch.optim, self.optimizer_name)(
self.model.parameters(), **self.optimizer_parameters)
self.optimizer = self.get_optimizer()
if self.lr_scheduler_name is not None:
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
self.optimizer, **self.lr_scheduler_parameters)
(another changed file; name not shown in the diff view)
@@ -294,8 +294,7 @@ def load(self, fname=None):
if self.use_crf:
self.crf = CRF(self.n_classes).to(self.device)

self.optimizer = getattr(torch.optim, self.optimizer_name)(
self.model.parameters(), **self.optimizer_parameters)
self.optimizer = self.get_optimizer()
if self.lr_scheduler_name is not None:
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
self.optimizer, **self.lr_scheduler_parameters)
3 changes: 1 addition & 2 deletions deeppavlov/models/torch_bert/torch_transformers_squad.py
@@ -292,8 +292,7 @@ def load(self, fname=None):
self.model = torch.nn.DataParallel(self.model)

self.model.to(self.device)
self.optimizer = getattr(torch.optim, self.optimizer_name)(
self.model.parameters(), **self.optimizer_parameters)
self.optimizer = self.get_optimizer()
if self.lr_scheduler_name is not None:
self.lr_scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_name)(
self.optimizer, **self.lr_scheduler_parameters)
1 change: 1 addition & 0 deletions requirements.txt
@@ -14,3 +14,4 @@ scikit-learn>=0.24,<1.1.0
scipy<1.9.0
tqdm>=4.42.0,<4.65.0
uvicorn>=0.13.0,<0.19.0
bitsandbytes-cuda113==0.25.0
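The pinned wheel targets CUDA 11.3 specifically (the -cuda113 suffix); environments on other CUDA versions would need the matching bitsandbytes-cudaXXX build. A hypothetical smoke test for the install:

import torch
import bitsandbytes as bnb

# Assumes a CUDA 11.3 runtime matching the bitsandbytes-cuda113 wheel.
param = torch.nn.Parameter(torch.randn(16, 16, device="cuda"))
opt = bnb.optim.Adam8bit([param], lr=1e-3)
print(type(opt).__name__)  # Adam8bit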
2 changes: 1 addition & 1 deletion tests/test_quick_start.py
@@ -256,7 +256,7 @@
("kbqa/kbqa_cq_ru.json", "kbqa", ('IP',)):
[
("Кто такой Оксимирон?", ("российский рэп-исполнитель",)),
("Чем питаются коалы?", ("Лист",)),
("Кто написал «Евгений Онегин»?", ("Александр Сергеевич Пушкин",)),
("абв", ("Not Found",))
]
},