Open
Description
import torch
from importlib import import_module
import pickle
class TextCNNPredictor:
    """Load a trained TextCNN checkpoint and classify single texts.

    The model module is resolved dynamically as ``models.TextCNN``. Its
    ``Config`` provides ``vocab_path``, ``device``, ``pad_size`` and
    ``class_list``. Its ``Model`` is instantiated on that device and filled
    from the state dict stored at ``model_path``.
    """

    def __init__(self, model_path, dataset='THUCNews', embedding='embedding_SougouNews.npz'):
        """Build the predictor and put the model in eval mode.

        Args:
            model_path: Path to a saved ``state_dict`` for the TextCNN model.
            dataset: Dataset name forwarded to ``Config``.
            embedding: Pretrained-embedding file name forwarded to ``Config``.
        """
        self.model_name = 'TextCNN'
        model_module = import_module('models.' + self.model_name)
        self.config = model_module.Config(dataset, embedding)
        # NOTE(review): pickle.load can execute arbitrary code stored in the
        # file; only load vocabulary files you produced yourself.
        with open(self.config.vocab_path, 'rb') as f:
            self.vocab = pickle.load(f)
        self.model = model_module.Model(self.config).to(self.config.device)
        # map_location lets a checkpoint saved on another device load here.
        self.model.load_state_dict(torch.load(model_path, map_location=self.config.device))
        self.model.eval()
        self.class_list = self.config.class_list

    def preprocess_text(self, text, pad_size=32):
        """Convert *text* into a ``(1, L)`` LongTensor of vocabulary ids.

        Tokenization is character-level. When *pad_size* is truthy, the
        tokens are right-padded with ``'<PAD>'`` or truncated so that
        ``L == pad_size``. Otherwise ``L == len(text)``. Tokens missing from
        the vocabulary map to the id of ``'<UNK>'``.

        Args:
            text: Input string.
            pad_size: Target sequence length. A falsy value disables
                padding and truncation.

        Returns:
            LongTensor of shape ``(1, L)`` on ``self.config.device``.
        """
        tokens = list(text)  # char-level tokenization
        if pad_size:
            if len(tokens) < pad_size:
                tokens.extend(['<PAD>'] * (pad_size - len(tokens)))
            else:
                tokens = tokens[:pad_size]
        # Look up the fallback id once instead of once per token.
        unk_id = self.vocab.get('<UNK>')
        token_ids = [self.vocab.get(tok, unk_id) for tok in tokens]
        return torch.LongTensor([token_ids]).to(self.config.device)

    def predict(self, text):
        """Return the predicted class label for a single *text*."""
        input_tensor = self.preprocess_text(text, self.config.pad_size)
        with torch.no_grad():
            # The tuple mirrors the input structure used during training.
            # Presumably Model.forward only reads element 0 — TODO confirm
            # against models/TextCNN.py.
            outputs = self.model((input_tensor, None, None, None))
        # Batch of one: arg-max over the class dimension, as a Python int.
        pred = outputs.argmax(dim=1)[0].item()
        return self.class_list[pred]
if __name__ == '__main__':
    # Demo: load the trained checkpoint and classify one sample headline.
    demo_predictor = TextCNNPredictor('THUCNews/saved_dict/TextCNN.ckpt')
    sample_text = '皇马3-1战胜巴塞罗那,夺得"西甲联赛冠军'
    label = demo_predictor.predict(sample_text)
    print(f"预测结果: {label}")
Metadata
Assignees
Labels
No labels