@@ -1,6 +1,7 @@
 # coding: utf-8
 # 2021/8/1 @ tongshiwei

+import torch
 import json
 import os.path
 from typing import List, Tuple
@@ -59,12 +60,12 @@ class I2V(object):
     """

     def __init__(self, tokenizer, t2v, *args, tokenizer_kwargs: dict = None,
-                 pretrained_t2v=False, model_dir=MODEL_DIR, **kwargs):
+                 pretrained_t2v=False, model_dir=MODEL_DIR, device='cpu', **kwargs):
         if pretrained_t2v:
             logger.info("Use pretrained t2v model %s" % t2v)
-            self.t2v = get_t2v_pretrained_model(t2v, model_dir)
+            self.t2v = get_t2v_pretrained_model(t2v, model_dir, device)
         else:
-            self.t2v = T2V(t2v, *args, **kwargs)
+            self.t2v = T2V(t2v, device=device, *args, **kwargs)
         if tokenizer == 'bert':
             self.tokenizer = BertTokenizer.from_pretrained(
                 **tokenizer_kwargs if tokenizer_kwargs is not None else {})
@@ -82,31 +83,53 @@ def __init__(self, tokenizer, t2v, *args, tokenizer_kwargs: dict = None,
                 **tokenizer_kwargs if tokenizer_kwargs is not None else {})
         self.params = {
             "tokenizer": tokenizer,
-            "tokenizer_kwargs": tokenizer_kwargs,
             "t2v": t2v,
             "args": args,
+            "tokenizer_kwargs": tokenizer_kwargs,
+            "pretrained_t2v": pretrained_t2v,
+            "model_dir": model_dir,
             "kwargs": kwargs,
-            "pretrained_t2v": pretrained_t2v
         }
+        self.device = torch.device(device)

     def __call__(self, items, *args, **kwargs):
         """transfer item to vector"""
         return self.infer_vector(items, *args, **kwargs)

     def tokenize(self, items, *args, key=lambda x: x, **kwargs) -> list:
-        # """tokenize item"""
+        """
+        tokenize items
+        Parameters
+        ----------
+        items: a list of questions
+        Returns
+        -------
+        tokens: list
+        """
         return self.tokenizer(items, *args, key=key, **kwargs)

     def infer_vector(self, items, key=lambda x: x, **kwargs) -> tuple:
+        """
+        get the question embedding
+        NotImplemented in the base class
+        """
         raise NotImplementedError

     def infer_item_vector(self, tokens, *args, **kwargs) -> ...:
+        """NotImplemented"""
         return self.infer_vector(tokens, *args, **kwargs)[0]

     def infer_token_vector(self, tokens, *args, **kwargs) -> ...:
+        """NotImplemented"""
         return self.infer_vector(tokens, *args, **kwargs)[1]

     def save(self, config_path):
+        """
+        save the model configuration (self.params) to config_path
+        Parameters
+        ----------
+        config_path: str
+        """
         with open(config_path, "w", encoding="utf-8") as wf:
             json.dump(self.params, wf, ensure_ascii=False, indent=2)
@@ -123,6 +146,7 @@ def load(cls, config_path, *args, **kwargs):

     @classmethod
     def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+        """NotImplemented"""
         raise NotImplementedError

     @property
@@ -327,13 +351,13 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         return self.t2v.infer_vector(inputs, *args, **kwargs), self.t2v.infer_tokens(inputs, *args, **kwargs)

     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
         logger.info("model_path: %s" % model_path)
         tokenizer_kwargs = {"tokenizer_config_dir": model_path}
-        return cls("elmo", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("elmo", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs)

@@ -386,17 +410,19 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         --------
         vector:list
         """
+        is_batch = isinstance(items, list)
+        items = items if is_batch else [items]
         inputs = self.tokenize(items, key=key, return_tensors=return_tensors)
         return self.t2v.infer_vector(inputs, *args, **kwargs), self.t2v.infer_tokens(inputs, *args, **kwargs)

     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
         logger.info("model_path: %s" % model_path)
         tokenizer_kwargs = {"tokenizer_config_dir": model_path}
-        return cls("bert", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("bert", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs)

@@ -452,7 +478,7 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         return i_vec, t_vec

     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
@@ -461,7 +487,7 @@ def from_pretrained(cls, name, model_dir=MODEL_DIR, **kwargs):
         tokenizer_kwargs = {
             "tokenizer_config_dir": model_path,
         }
-        return cls("disenq", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("disenq", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs, **kwargs)

@@ -495,18 +521,20 @@ def infer_vector(self, items: Tuple[List[str], List[dict], str, dict],
         token embeddings
         question embedding
         """
+        is_batch = isinstance(items, list)
+        items = items if is_batch else [items]
         encodes = self.tokenize(items, key=key, meta=meta, *args, **kwargs)
         return self.t2v.infer_vector(encodes), self.t2v.infer_tokens(encodes)

     @classmethod
-    def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
+    def from_pretrained(cls, name, model_dir=MODEL_DIR, device='cpu', *args, **kwargs):
         model_path = path_append(model_dir, get_pretrained_model_info(name)[0].split('/')[-1], to_str=True)
         for i in [".tar.gz", ".tar.bz2", ".tar.bz", ".tar.tgz", ".tar", ".tgz", ".zip", ".rar"]:
             model_path = model_path.replace(i, "")
         logger.info("model_path: %s" % model_path)
         tokenizer_kwargs = {
             "tokenizer_config_dir": model_path}
-        return cls("quesnet", name, pretrained_t2v=True, model_dir=model_dir,
+        return cls("quesnet", name, pretrained_t2v=True, model_dir=model_dir, device=device,
                    tokenizer_kwargs=tokenizer_kwargs)

@@ -520,7 +548,7 @@ def from_pretrained(cls, name, model_dir=MODEL_DIR, *args, **kwargs):
 }


-def get_pretrained_i2v(name, model_dir=MODEL_DIR):
+def get_pretrained_i2v(name, model_dir=MODEL_DIR, device='cpu'):
     """
     It is a good idea if you want to switch item to vector easily.

@@ -560,4 +588,4 @@ def get_pretrained_i2v(name, model_dir=MODEL_DIR):
     )
     _, t2v = get_pretrained_model_info(name)
     _class, *params = MODEL_MAP[t2v], name
-    return _class.from_pretrained(*params, model_dir=model_dir)
+    return _class.from_pretrained(*params, model_dir=model_dir, device=device)
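
A minimal usage sketch of the new device argument (not part of the diff above). It assumes the EduNLP.I2V entry point get_pretrained_i2v and a downloadable pretrained model; the model name "elmo_test" and the directory "./model" are placeholders for illustration only.

    from EduNLP.I2V import get_pretrained_i2v

    # "elmo_test" and "./model" are placeholder values, not guaranteed names
    i2v = get_pretrained_i2v("elmo_test", model_dir="./model", device="cuda")

    # i2v(...) forwards to infer_vector and returns (item_vectors, token_vectors);
    # with the is_batch change above, a single item is wrapped into a batch internally
    item_vectors, token_vectors = i2v(["The bus travels 40 km in 30 minutes. What is its speed?"])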