1 | 1 | from components.llm.base_summarizer import BaseSummarizer |
2 | | -import logging |
|    2 | +import gc
|      | +import logging
|      | +import threading
3 | 3 | from transformers import AutoTokenizer, TextIteratorStreamer |
4 | 4 | from optimum.intel.openvino import OVModelForCausalLM |
5 | 5 | from utils import ensure_model |
6 | 6 | from utils.config_loader import config |
7 | 7 | from utils.locks import audio_pipeline_lock |
8 | | -import threading |
9 | 8 |
|
10 | 9 | logger = logging.getLogger(__name__) |
11 | 10 |
|
12 | 11 |
|
13 | 12 | class Summarizer(BaseSummarizer): |
14 | 13 | def __init__(self, model_name, device, temperature=0.7, revision=None): |
15 | 14 | self.model_name = model_name |
16 | | - self.device = device.upper() # OpenVINO uses "GPU" or "CPU" |
| 15 | + self.device = device.upper() |
17 | 16 | self.temperature = temperature |
18 | 17 |
|
19 | | - model_path = ensure_model.get_model_path() |
20 | | - logger.info(f"Loading Model: model name={self.model_name}, model path={model_path}, device={self.device}") |
21 | | - |
| 18 | + self.model_path = ensure_model.get_model_path() |
| 19 | + |
| 20 | + logger.info( |
| 21 | + f"Summarizer initialized (lazy load). " |
| 22 | + f"model={self.model_name}, path={self.model_path}, device={self.device}" |
| 23 | + ) |
| 24 | + |
22 | 25 | self.tokenizer = AutoTokenizer.from_pretrained( |
23 | | - model_path, |
24 | | - trust_remote_code=True, |
25 | | - fix_mistral_regex=True |
| 26 | + self.model_path, |
| 27 | + trust_remote_code=True, |
| 28 | + fix_mistral_regex=True, |
26 | 29 | ) |
27 | 30 |
|
28 | 31 | if self.tokenizer.pad_token is None: |
29 | 32 | self.tokenizer.pad_token = self.tokenizer.eos_token |
30 | | - |
31 | | - self.model = OVModelForCausalLM.from_pretrained( |
32 | | - model_path, |
33 | | - device=self.device, |
34 | | - use_cache=True |
| 33 | + |
| 34 | + def _load_model(self): |
| 35 | + logger.info("Loading OVModelForCausalLM instance...") |
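|      | +        # from_pretrained loads and compiles the model for the target
|      | +        # device, which is why loading is deferred until it is needed.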
| 36 | + return OVModelForCausalLM.from_pretrained( |
| 37 | + self.model_path, |
| 38 | + device=self.device, |
| 39 | + use_cache=True, |
35 | 40 | ) |
36 | 41 |
|
|   42 | +    def _destroy_model(self, model):
|      | +        # `del` only drops this frame's reference; the caller must also
|      | +        # release its own reference before gc can reclaim the weights.
|   43 | +        try:
|   44 | +            del model
|   45 | +            gc.collect()
|   46 | +            logger.info("OV model instance destroyed")
|   47 | +        except Exception as e:
|   48 | +            logger.warning(f"Failed to destroy OV model cleanly: {e}")
| 49 | + |
37 | 50 | def generate(self, prompt: str, stream: bool = True): |
38 | 51 | max_new_tokens = config.models.summarizer.max_new_tokens |
39 | 52 | inputs = self.tokenizer(prompt, return_tensors="pt") |
40 | 53 |
|
41 | 54 | if stream: |
42 | 55 | class CountingTextIteratorStreamer(TextIteratorStreamer): |
43 | | - def __init__(self, tokenizer, skip_special_tokens=True, skip_prompt=True): |
44 | | - super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, skip_prompt=skip_prompt) |
45 | | - self.total_tokens = 0 |
| 56 | + def __init__(self, tokenizer, skip_special_tokens=True, skip_prompt=True): |
| 57 | + super().__init__( |
| 58 | + tokenizer, |
| 59 | + skip_special_tokens=skip_special_tokens, |
| 60 | + skip_prompt=skip_prompt, |
| 61 | + ) |
| 62 | + self.total_tokens = 0 |
46 | 63 |
|
47 | | - def put(self, value): |
48 | | - if value is not None: |
49 | | - self.total_tokens += 1 |
50 | | - super().put(value) |
|   64 | +            def put(self, value):
|      | +                # Counts put() calls as a rough token count: generate()
|      | +                # calls put() once per decoding step (plus once for the
|      | +                # skipped prompt batch), so this only approximates the
|      | +                # number of generated tokens.
|   65 | +                if value is not None:
|   66 | +                    self.total_tokens += 1
|   67 | +                    super().put(value)
51 | 68 |
|
52 | 69 | streamer = CountingTextIteratorStreamer( |
53 | | - self.tokenizer, |
54 | | - skip_special_tokens=True, |
55 | | - skip_prompt=True |
| 70 | + self.tokenizer, |
| 71 | + skip_special_tokens=True, |
| 72 | + skip_prompt=True, |
56 | 73 | ) |
57 | | - |
| 74 | + |
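|      | +            # Per-request lifecycle: take the pipeline lock, load the
|      | +            # OV model, stream tokens, then destroy the model so memory
|      | +            # is released between summaries.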
58 | 75 | def run_generation(): |
59 | | - with audio_pipeline_lock: |
60 | | - generation_kwargs = { |
61 | | - "input_ids": inputs.input_ids, |
62 | | - "max_new_tokens": max_new_tokens, |
63 | | - |
64 | | - # 🔑 sampling safety |
65 | | - "do_sample": True, |
66 | | - "temperature": max(self.temperature, 0.1), |
67 | | - "top_p": 0.9, |
68 | | - "top_k": 50, |
69 | | - |
70 | | - # tokens |
71 | | - "pad_token_id": self.tokenizer.eos_token_id, |
72 | | - "eos_token_id": self.tokenizer.eos_token_id, |
73 | | - |
74 | | - # streaming |
75 | | - "streamer": streamer, |
76 | | - } |
77 | | - self.model.generate(**generation_kwargs) |
78 | | - |
79 | | - thread = threading.Thread(target=run_generation, daemon=True) |
80 | | - thread.start() |
81 | | - |
| 76 | + model = None |
| 77 | + try: |
| 78 | + with audio_pipeline_lock: |
| 79 | + model = self._load_model() |
|   80 | +                    model.generate(
|   81 | +                        input_ids=inputs.input_ids,
|      | +                        attention_mask=inputs.attention_mask,
|   82 | +                        max_new_tokens=max_new_tokens,
| 83 | + |
| 84 | + # sampling |
| 85 | + do_sample=True, |
| 86 | + temperature=max(self.temperature, 0.1), |
| 87 | + top_p=0.9, |
| 88 | + top_k=50, |
| 89 | + |
| 90 | + # tokens |
| 91 | + pad_token_id=self.tokenizer.eos_token_id, |
| 92 | + eos_token_id=self.tokenizer.eos_token_id, |
| 93 | + |
| 94 | + # streaming |
| 95 | + streamer=streamer, |
| 96 | + ) |
| 97 | + |
| 98 | + except Exception: |
| 99 | + logger.error( |
| 100 | + "Exception occurred in OV streaming generation", |
| 101 | + exc_info=True, |
| 102 | + ) |
|      | +                # TextIteratorStreamer exposes its queue as `text_queue`
|      | +                # (there is no `_queue` attribute); push a sentinel chunk
|      | +                # so the consumer sees the failure.
|  103 | +                if hasattr(streamer, "text_queue"):
|  104 | +                    streamer.text_queue.put(
|  105 | +                        "[ERROR]: Summary generation failed due to resource constraints."
|  106 | +                    )
| 107 | + |
|  108 | +            finally:
|  109 | +                if model is not None:
|  110 | +                    self._destroy_model(model)
|      | +                # make sure the consuming iterator terminates even if
|      | +                # generate() raised before it could signal the end
|  111 | +                streamer.end()
| 112 | + |
| 113 | + threading.Thread(target=run_generation, daemon=True).start() |
82 | 114 | return streamer |
| 115 | + |
83 | 116 | else: |
84 | | - with audio_pipeline_lock: |
85 | | - generation_kwargs = { |
86 | | - "input_ids": inputs.input_ids, |
87 | | - "max_new_tokens": max_new_tokens, |
88 | | - |
89 | | - # 🔑 sampling safety |
90 | | - "do_sample": True, |
91 | | - "temperature": max(self.temperature, 0.1), |
92 | | - "top_p": 0.9, |
93 | | - "top_k": 50, |
94 | | - |
95 | | - # tokens |
96 | | - "pad_token_id": self.tokenizer.eos_token_id, |
97 | | - "eos_token_id": self.tokenizer.eos_token_id, |
98 | | - } |
99 | | - return self.model.generate(**generation_kwargs) |
| 117 | + model = None |
| 118 | + try: |
| 119 | + with audio_pipeline_lock: |
| 120 | + model = self._load_model() |
|  121 | +                    return model.generate(
|  122 | +                        input_ids=inputs.input_ids,
|      | +                        attention_mask=inputs.attention_mask,
|  123 | +                        max_new_tokens=max_new_tokens,
| 124 | + |
| 125 | + do_sample=True, |
| 126 | + temperature=max(self.temperature, 0.1), |
| 127 | + top_p=0.9, |
| 128 | + top_k=50, |
| 129 | + |
| 130 | + pad_token_id=self.tokenizer.eos_token_id, |
| 131 | + eos_token_id=self.tokenizer.eos_token_id, |
| 132 | + ) |
| 133 | + finally: |
| 134 | + if model is not None: |
| 135 | + self._destroy_model(model) |
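
Usage note: a minimal sketch of how a caller might consume the streaming path. The import path, model name, and prompt below are illustrative assumptions, not part of this diff; the returned streamer is a TextIteratorStreamer subclass, so it can be iterated until generate() (or the error path) signals the end.

    from components.llm.summarizer import Summarizer  # hypothetical import path

    # hypothetical model name and settings
    summarizer = Summarizer(model_name="mistral-7b-ov", device="gpu", temperature=0.7)
    streamer = summarizer.generate("Summarize the following transcript: ...", stream=True)

    chunks = []
    for text in streamer:  # yields decoded text until end() is called
        if text.startswith("[ERROR]"):  # the error sentinel arrives as a normal chunk
            raise RuntimeError(text)
        chunks.append(text)

    print("".join(chunks))
    print(f"~{streamer.total_tokens} put() calls counted")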