Skip to content

根据run.py代码编写了一个调用TextCNN模型进行预测的代码,或许能帮上忙! #123

Open
@Gerogell

Description

@Gerogell
import torch
from importlib import import_module
import pickle

class TextCNNPredictor:
    """Single-sentence inference wrapper around a trained TextCNN model.

    The constructor imports ``models.TextCNN``, builds its ``Config`` and
    ``Model``, loads the pickled vocabulary found at ``config.vocab_path``
    and restores the weights stored at ``model_path``.
    """

    def __init__(self, model_path, dataset='THUCNews', embedding='embedding_SougouNews.npz'):
        """Load the config, vocabulary and trained weights.

        Args:
            model_path: Path of the saved ``state_dict`` checkpoint.
            dataset: Forwarded as the first argument of ``Config``.
            embedding: Forwarded as the second argument of ``Config``.
        """
        self.model_name = 'TextCNN'
        module = import_module('models.' + self.model_name)
        self.config = module.Config(dataset, embedding)

        # SECURITY: unpickling can execute arbitrary code; only point
        # ``config.vocab_path`` at a vocabulary file you trust.
        with open(self.config.vocab_path, 'rb') as f:
            self.vocab = pickle.load(f)

        self.model = module.Model(self.config).to(self.config.device)
        self.model.load_state_dict(torch.load(model_path, map_location=self.config.device))
        self.model.eval()  # inference mode
        self.class_list = self.config.class_list

    def preprocess_text(self, text, pad_size=32):
        """Convert ``text`` into a batch of one sequence of vocabulary ids.

        The text is split into single characters (char-level tokenization).
        When ``pad_size`` is truthy the token list is truncated to
        ``pad_size`` items or right-padded with ``'<PAD>'`` up to that
        length; a falsy ``pad_size`` leaves the length untouched. Tokens
        missing from the vocabulary map to the id of ``'<UNK>'``.

        Args:
            text: The sentence to encode.
            pad_size: Target sequence length, or a falsy value to disable
                padding and truncation.

        Returns:
            A ``torch.long`` tensor of shape ``(1, sequence_length)`` placed
            on ``self.config.device``.
        """
        tokens = list(text)  # char-level tokenization
        if pad_size:
            tokens = tokens[:pad_size]
            tokens.extend(['<PAD>'] * (pad_size - len(tokens)))

        # Look the fallback id up once instead of once per token.
        unk_id = self.vocab.get('<UNK>')
        token_ids = [self.vocab.get(tok, unk_id) for tok in tokens]

        return torch.tensor([token_ids], dtype=torch.long, device=self.config.device)

    def predict(self, text):
        """Return the entry of ``class_list`` with the highest model score for ``text``."""
        input_tensor = self.preprocess_text(text, self.config.pad_size)
        with torch.no_grad():
            # Per the original author's note, the 4-tuple keeps the same
            # input structure as training; only the first slot is filled here.
            outputs = self.model((input_tensor, None, None, None))
            pred = torch.argmax(outputs, dim=1)[0].item()

        return self.class_list[pred]


if __name__ == '__main__':
    # Demo: restore the trained checkpoint and classify one sample headline.
    checkpoint_path = 'THUCNews/saved_dict/TextCNN.ckpt'
    classifier = TextCNNPredictor(checkpoint_path)

    sample_headline = '皇马3-1战胜巴塞罗那,夺得"西甲联赛冠军'
    label = classifier.predict(sample_headline)
    print(f"预测结果: {label}")

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions