Commit 2990a66

Merge pull request #16 from bigdata-ustc/d2v
[FEATURE] Upgrade SIF and enable end2end vectorization
2 parents: 3cf06d6 + 97d6d24

34 files changed: +1693 −31 lines

CHANGE.txt (+3 −1)

@@ -1,8 +1,10 @@
 v0.0.3:
 1. update formula ast: supporting more symbols and functions defined in katex
-2. add item to vector tools, including word2vec and doc2vec using gensim
+2. add tokens to vector tools, including word2vec and doc2vec using gensim
 3. sci4sif support tokenization grouped by segments
 4. add special tokens: \SIFTag and \SIFSep
+5. add item to vector tools
+6. add interface for getting pretrained models, where the supported model names can be accessed by `edunlp i2v` in the command console
 
 v0.0.2:
 1. fix potential ModuleNotFoundError

EduNLP/I2V/__init__.py (+5)

# coding: utf-8
# 2021/8/1 @ tongshiwei

from .i2v import I2V, get_pretrained_i2v
from .i2v import D2V

EduNLP/I2V/i2v.py (+109)

# coding: utf-8
# 2021/8/1 @ tongshiwei

import json
from EduNLP.constant import MODEL_DIR
from ..Vector import T2V, get_pretrained_t2v as get_t2v_pretrained_model
from ..Tokenizer import Tokenizer, get_tokenizer
from EduNLP import logger

__all__ = ["I2V", "D2V", "get_pretrained_i2v"]


class I2V(object):
    def __init__(self, tokenizer, t2v, *args, tokenizer_kwargs: dict = None, pretrained_t2v=False, **kwargs):
        """

        Parameters
        ----------
        tokenizer: str
            the tokenizer name
        t2v: str
            the name of token2vector model
        args:
            the parameters passed to t2v
        tokenizer_kwargs: dict
            the parameters passed to tokenizer
        pretrained_t2v: bool
            True to load a pretrained t2v model, False to build one from args and kwargs
        kwargs:
            the parameters passed to t2v
        """
        self.tokenizer: Tokenizer = get_tokenizer(tokenizer, **tokenizer_kwargs if tokenizer_kwargs is not None else {})
        if pretrained_t2v:
            logger.info("Use pretrained t2v model %s" % t2v)
            self.t2v = get_t2v_pretrained_model(t2v, kwargs.get("model_dir", MODEL_DIR))
        else:
            self.t2v = T2V(t2v, *args, **kwargs)
        self.params = {
            "tokenizer": tokenizer,
            "tokenizer_kwargs": tokenizer_kwargs,
            "t2v": t2v,
            "args": args,
            "kwargs": kwargs,
            "pretrained_t2v": pretrained_t2v
        }

    def __call__(self, items, *args, **kwargs):
        return self.infer_vector(items, *args, **kwargs)

    def tokenize(self, items, indexing=True, padding=False, *args, **kwargs) -> list:
        return self.tokenizer(items, *args, **kwargs)

    def infer_vector(self, items, tokenize=True, indexing=False, padding=False, *args, **kwargs) -> tuple:
        raise NotImplementedError

    def infer_item_vector(self, tokens, *args, **kwargs) -> ...:
        return self.infer_vector(tokens, *args, **kwargs)[0]

    def infer_token_vector(self, tokens, *args, **kwargs) -> ...:
        return self.infer_vector(tokens, *args, **kwargs)[1]

    def save(self, config_path, *args, **kwargs):
        with open(config_path, "w", encoding="utf-8") as wf:
            json.dump(self.params, wf, ensure_ascii=False, indent=2)

    @classmethod
    def load(cls, config_path, *args, **kwargs):
        with open(config_path, encoding="utf-8") as f:
            params: dict = json.load(f)
            tokenizer = params.pop("tokenizer")
            t2v = params.pop("t2v")
            args = params.pop("args")
            kwargs = params.pop("kwargs")
            params.update(kwargs)
            return cls(tokenizer, t2v, *args, **params)

    @classmethod
    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
        raise NotImplementedError

    @property
    def vector_size(self):
        return self.t2v.vector_size


class D2V(I2V):
    def infer_vector(self, items, tokenize=True, indexing=False, padding=False, *args, **kwargs) -> tuple:
        tokens = self.tokenize(items, return_token=True) if tokenize is True else items
        return self.t2v(tokens, *args, **kwargs), None

    @classmethod
    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
        return cls("text", name, pretrained_t2v=True, model_dir=model_dir)


MODELS = {
    "d2v_all_256": [D2V, "d2v_all_256"],
    "d2v_sci_256": [D2V, "d2v_sci_256"],
    "d2v_eng_256": [D2V, "d2v_eng_256"],
    "d2v_lit_256": [D2V, "d2v_lit_256"],
}


def get_pretrained_i2v(name, model_dir=MODEL_DIR):
    if name not in MODELS:
        raise KeyError(
            "Unknown model name %s, use one of the provided models: %s" % (name, ", ".join(MODELS.keys()))
        )
    _class, *params = MODELS[name]
    return _class.from_pretrained(*params, model_dir=model_dir)
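
The snippet below is a usage sketch of this new interface, not part of the diff: the item text is a hypothetical placeholder, and get_pretrained_i2v assumes the named model can be fetched into model_dir (the supported names are the keys of MODELS above, also listed by `edunlp i2v` in the command console).

# usage sketch (hypothetical item text; assumes the pretrained model is retrievable)
from EduNLP.I2V import get_pretrained_i2v

i2v = get_pretrained_i2v("d2v_sci_256", model_dir="./model")
item_vectors, token_vectors = i2v(["An item with a formula $x^2 + y = 1$"])
# D2V.infer_vector returns (item vectors, None), so token_vectors is None here
print(i2v.vector_size)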

EduNLP/ModelZoo/__init__.py (+4)

# coding: utf-8
# 2021/7/12 @ tongshiwei

from .utils import *

EduNLP/ModelZoo/rnn/__init__.py (+4)

# coding: utf-8
# 2021/7/12 @ tongshiwei

from .rnn import LM

EduNLP/ModelZoo/rnn/rnn.py (+74)

# coding: utf-8
# 2021/7/12 @ tongshiwei

import torch
from torch import nn
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence


class LM(nn.Module):
    """
    Examples
    --------
    >>> import torch
    >>> seq_idx = torch.LongTensor([[1, 2, 3], [1, 2, 0], [3, 0, 0]])
    >>> seq_len = torch.LongTensor([3, 2, 1])
    >>> lm = LM("RNN", 4, 3, 2)
    >>> output, hn = lm(seq_idx, seq_len)
    >>> output.shape
    torch.Size([3, 3, 2])
    >>> hn.shape
    torch.Size([1, 3, 2])
    >>> lm = LM("RNN", 4, 3, 2, num_layers=2)
    >>> output, hn = lm(seq_idx, seq_len)
    >>> output.shape
    torch.Size([3, 3, 2])
    >>> hn.shape
    torch.Size([2, 3, 2])
    """

    def __init__(self, rnn_type: str, vocab_size: int, embedding_dim: int, hidden_size: int, num_layers=1,
                 bidirectional=False, embedding=None, **kwargs):
        super(LM, self).__init__()
        rnn_type = rnn_type.upper()
        self.embedding = torch.nn.Embedding(vocab_size, embedding_dim) if embedding is None else embedding
        self.c = False
        if rnn_type == "RNN":
            self.rnn = torch.nn.RNN(
                embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs
            )
        elif rnn_type == "LSTM":
            self.rnn = torch.nn.LSTM(
                embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs
            )
            self.c = True
        elif rnn_type == "GRU":
            self.rnn = torch.nn.GRU(
                embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs
            )
        elif rnn_type == "ELMO":
            bidirectional = True
            self.rnn = torch.nn.LSTM(
                embedding_dim, hidden_size, num_layers, bidirectional=bidirectional, **kwargs
            )
            self.c = True
        else:
            raise TypeError("Unknown rnn_type %s" % rnn_type)
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        if bidirectional is True:
            self.num_layers *= 2
        self.hidden_size = hidden_size

    def forward(self, seq_idx, seq_len):
        seq = self.embedding(seq_idx)
        pack = pack_padded_sequence(seq, seq_len, batch_first=True)
        h0 = torch.randn(self.num_layers, seq.shape[0], self.hidden_size)
        if self.c is True:
            c0 = torch.randn(self.num_layers, seq.shape[0], self.hidden_size)
            output, (hn, _) = self.rnn(pack, (h0, c0))
        else:
            output, hn = self.rnn(pack, h0)
        output, _ = pad_packed_sequence(output, batch_first=True)
        return output, hn
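
As a complement to the doctests above, a short sketch of LM with an LSTM cell (self.c is True, so a cell state is drawn as well); seq_len must be sorted in descending order, since pack_padded_sequence is called with its default enforce_sorted=True.

# sketch: LM wrapping an LSTM; shapes match the RNN doctests above
import torch
from EduNLP.ModelZoo.rnn import LM

seq_idx = torch.LongTensor([[1, 2, 3], [1, 2, 0], [3, 0, 0]])  # padded token ids
seq_len = torch.LongTensor([3, 2, 1])  # true lengths, descending
lm = LM("LSTM", vocab_size=4, embedding_dim=3, hidden_size=2)
output, hn = lm(seq_idx, seq_len)
output.shape  # torch.Size([3, 3, 2])
hn.shape      # torch.Size([1, 3, 2])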

EduNLP/ModelZoo/utils/__init__.py (+5)

# coding: utf-8
# 2021/7/12 @ tongshiwei

from .padder import PadSequence, pad_sequence
from .device import set_device

EduNLP/ModelZoo/utils/device.py (+44)

# coding: utf-8
# 2021/8/2 @ tongshiwei
import logging
import torch
from torch.nn import DataParallel


def set_device(_net, ctx, *args, **kwargs):  # pragma: no cover
    """code from longling v1.3.26"""
    if ctx == "cpu":
        if not isinstance(_net, DataParallel):
            _net = DataParallel(_net)
        return _net.cpu()
    elif any(map(lambda x: x in ctx, ["cuda", "gpu"])):
        if not torch.cuda.is_available():
            try:
                torch.ones((1,), device=torch.device("cuda:0"))
            except AssertionError as e:
                raise TypeError("no cuda detected, only cpu is supported, the detailed error msg: %s" % str(e))
        if torch.cuda.device_count() >= 1:
            if ":" in ctx:
                ctx_name, device_ids = ctx.split(":")
                assert ctx_name in ["cuda", "gpu"], "the device should be 'cpu', 'cuda' or 'gpu', now is %s" % ctx
                device_ids = [int(i) for i in device_ids.strip().split(",")]
                try:
                    if not isinstance(_net, DataParallel):
                        return DataParallel(_net, device_ids).cuda()
                    return _net.cuda(device_ids[0])
                except AssertionError as e:
                    logging.error(device_ids)
                    raise e
            elif ctx in ["cuda", "gpu"]:
                if not isinstance(_net, DataParallel):
                    _net = DataParallel(_net)
                return _net.cuda()
            else:
                raise TypeError("the device should be 'cpu', 'cuda' or 'gpu', now is %s" % ctx)
        else:
            logging.error(torch.cuda.device_count())
            raise TypeError("no gpu can be used, use cpu")
    else:
        if not isinstance(_net, DataParallel):
            return DataParallel(_net, device_ids=ctx).cuda()
        return _net.cuda(ctx)
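
A minimal call sketch for set_device; the commented "cuda:0" line assumes a machine with at least one visible GPU, while the "cpu" call works anywhere.

# sketch: moving a module to a target device via set_device
import torch
from EduNLP.ModelZoo.utils import set_device

net = torch.nn.Linear(4, 2)
net = set_device(net, "cpu")       # wraps the module in DataParallel and keeps it on CPU
# net = set_device(net, "cuda:0")  # would wrap it and move it to GPU 0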

EduNLP/ModelZoo/utils/padder.py (+76)

# coding: utf-8
# 2021/7/12 @ tongshiwei

__all__ = ["PadSequence", "pad_sequence"]


class PadSequence(object):
    """Pad the sequence.

    Pad the sequence to the given `length` by inserting `pad_val`. If `clip` is set,
    a sequence that has a length larger than `length` will be clipped.

    Parameters
    ----------
    length : int
        The maximum length to pad/clip the sequence
    pad_val : number
        The pad value. Default 0
    clip : bool
        Whether to clip a sequence longer than `length`. Default True
    """

    def __init__(self, length, pad_val=0, clip=True):
        self._length = length
        self._pad_val = pad_val
        self._clip = clip

    def __call__(self, sample: list):
        """

        Parameters
        ----------
        sample : list of number

        Returns
        -------
        ret : list of number
        """
        sample_length = len(sample)
        if sample_length >= self._length:
            if self._clip and sample_length > self._length:
                return sample[:self._length]
            else:
                return sample
        else:
            return sample + [
                self._pad_val for _ in range(self._length - sample_length)
            ]


def pad_sequence(sequence: list, max_length=None, pad_val=0, clip=True):
    """
    Pad each sequence in `sequence` to the same length.

    Parameters
    ----------
    sequence : list of list of number
    max_length : int or None
        The target length; defaults to the length of the longest sequence
    pad_val : number
    clip : bool

    Returns
    -------
    list of list of number

    Examples
    --------
    >>> seq = [[4, 3, 3], [2], [3, 3, 2]]
    >>> pad_sequence(seq)
    [[4, 3, 3], [2, 0, 0], [3, 3, 2]]
    >>> pad_sequence(seq, pad_val=1)
    [[4, 3, 3], [2, 1, 1], [3, 3, 2]]
    >>> pad_sequence(seq, max_length=2)
    [[4, 3], [2, 0], [3, 3]]
    >>> pad_sequence(seq, max_length=2, clip=False)
    [[4, 3, 3], [2, 0], [3, 3, 2]]
    """
    padder = PadSequence(max([len(seq) for seq in sequence]) if max_length is None else max_length, pad_val, clip)
    return [padder(seq) for seq in sequence]
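
PadSequence is also usable on its own as a reusable fixed-length padder; a small sketch complementing the pad_sequence doctests:

# sketch: applying one PadSequence instance to individual samples
from EduNLP.ModelZoo.utils import PadSequence

padder = PadSequence(length=4, pad_val=0, clip=True)
padder([5, 2, 7])        # [5, 2, 7, 0]
padder([1, 2, 3, 4, 5])  # [1, 2, 3, 4] (clipped to length 4)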
