
Commit ef33b86

Update summarizer.py
1 parent 00a25dc commit ef33b86

File tree

1 file changed (+100, −64):

  • education-ai-suite/smart-classroom/components/llm/openvino

@@ -1,99 +1,135 @@
 from components.llm.base_summarizer import BaseSummarizer
-import logging
+import logging, threading, gc
 from transformers import AutoTokenizer, TextIteratorStreamer
 from optimum.intel.openvino import OVModelForCausalLM
 from utils import ensure_model
 from utils.config_loader import config
 from utils.locks import audio_pipeline_lock
-import threading

 logger = logging.getLogger(__name__)


 class Summarizer(BaseSummarizer):
     def __init__(self, model_name, device, temperature=0.7, revision=None):
         self.model_name = model_name
-        self.device = device.upper()  # OpenVINO uses "GPU" or "CPU"
+        self.device = device.upper()
         self.temperature = temperature

-        model_path = ensure_model.get_model_path()
-        logger.info(f"Loading Model: model name={self.model_name}, model path={model_path}, device={self.device}")
-
+        self.model_path = ensure_model.get_model_path()
+
+        logger.info(
+            f"Summarizer initialized (lazy load). "
+            f"model={self.model_name}, path={self.model_path}, device={self.device}"
+        )
+
         self.tokenizer = AutoTokenizer.from_pretrained(
-            model_path,
-            trust_remote_code=True,
-            fix_mistral_regex=True
+            self.model_path,
+            trust_remote_code=True,
+            fix_mistral_regex=True,
         )

         if self.tokenizer.pad_token is None:
             self.tokenizer.pad_token = self.tokenizer.eos_token
-
-        self.model = OVModelForCausalLM.from_pretrained(
-            model_path,
-            device=self.device,
-            use_cache=True
+
+    def _load_model(self):
+        logger.info("Loading OVModelForCausalLM instance...")
+        return OVModelForCausalLM.from_pretrained(
+            self.model_path,
+            device=self.device,
+            use_cache=True,
         )

+    def _destroy_model(self, model):
+        try:
+            del model
+            gc.collect()
+            logger.info("OV model instance destroyed")
+        except Exception as e:
+            logger.warning(f"Failed to destroy OV model cleanly: {e}")
+
     def generate(self, prompt: str, stream: bool = True):
         max_new_tokens = config.models.summarizer.max_new_tokens
         inputs = self.tokenizer(prompt, return_tensors="pt")

         if stream:
             class CountingTextIteratorStreamer(TextIteratorStreamer):
-                def __init__(self, tokenizer, skip_special_tokens=True, skip_prompt=True):
-                    super().__init__(tokenizer, skip_special_tokens=skip_special_tokens, skip_prompt=skip_prompt)
-                    self.total_tokens = 0
+                def __init__(self, tokenizer, skip_special_tokens=True, skip_prompt=True):
+                    super().__init__(
+                        tokenizer,
+                        skip_special_tokens=skip_special_tokens,
+                        skip_prompt=skip_prompt,
+                    )
+                    self.total_tokens = 0

-                def put(self, value):
-                    if value is not None:
-                        self.total_tokens += 1
-                    super().put(value)
+                def put(self, value):
+                    if value is not None:
+                        self.total_tokens += 1
+                    super().put(value)

             streamer = CountingTextIteratorStreamer(
-                self.tokenizer,
-                skip_special_tokens=True,
-                skip_prompt=True
+                self.tokenizer,
+                skip_special_tokens=True,
+                skip_prompt=True,
             )
-
+
             def run_generation():
-                with audio_pipeline_lock:
-                    generation_kwargs = {
-                        "input_ids": inputs.input_ids,
-                        "max_new_tokens": max_new_tokens,
-
-                        # 🔑 sampling safety
-                        "do_sample": True,
-                        "temperature": max(self.temperature, 0.1),
-                        "top_p": 0.9,
-                        "top_k": 50,
-
-                        # tokens
-                        "pad_token_id": self.tokenizer.eos_token_id,
-                        "eos_token_id": self.tokenizer.eos_token_id,
-
-                        # streaming
-                        "streamer": streamer,
-                    }
-                    self.model.generate(**generation_kwargs)
-
-            thread = threading.Thread(target=run_generation, daemon=True)
-            thread.start()
-
+                model = None
+                try:
+                    with audio_pipeline_lock:
+                        model = self._load_model()
+                        model.generate(
+                            input_ids=inputs.input_ids,
+                            max_new_tokens=max_new_tokens,
+
+                            # sampling
+                            do_sample=True,
+                            temperature=max(self.temperature, 0.1),
+                            top_p=0.9,
+                            top_k=50,
+
+                            # tokens
+                            pad_token_id=self.tokenizer.eos_token_id,
+                            eos_token_id=self.tokenizer.eos_token_id,
+
+                            # streaming
+                            streamer=streamer,
+                        )
+
+                except Exception:
+                    logger.error(
+                        "Exception occurred in OV streaming generation",
+                        exc_info=True,
+                    )
+                    if hasattr(streamer, "_queue"):
+                        streamer._queue.put(
+                            "[ERROR]: Summary generation failed due to resource constraints."
+                        )
+
+                finally:
+                    if model is not None:
+                        self._destroy_model(model)
+                    streamer.end()
+
+            threading.Thread(target=run_generation, daemon=True).start()
             return streamer
+
         else:
-            with audio_pipeline_lock:
-                generation_kwargs = {
-                    "input_ids": inputs.input_ids,
-                    "max_new_tokens": max_new_tokens,
-
-                    # 🔑 sampling safety
-                    "do_sample": True,
-                    "temperature": max(self.temperature, 0.1),
-                    "top_p": 0.9,
-                    "top_k": 50,
-
-                    # tokens
-                    "pad_token_id": self.tokenizer.eos_token_id,
-                    "eos_token_id": self.tokenizer.eos_token_id,
-                }
-                return self.model.generate(**generation_kwargs)
+            model = None
+            try:
+                with audio_pipeline_lock:
+                    model = self._load_model()
+                    return model.generate(
+                        input_ids=inputs.input_ids,
+                        max_new_tokens=max_new_tokens,
+
+                        do_sample=True,
+                        temperature=max(self.temperature, 0.1),
+                        top_p=0.9,
+                        top_k=50,
+
+                        pad_token_id=self.tokenizer.eos_token_id,
+                        eos_token_id=self.tokenizer.eos_token_id,
+                    )
+            finally:
+                if model is not None:
+                    self._destroy_model(model)
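
For context on how this class is consumed after the change, here is a minimal caller sketch. The Summarizer constructor, the return values of generate() (a streamer when stream=True, raw output ids otherwise), and the total_tokens counter all come from the diff above; the model name, device, and prompts are illustrative assumptions, not part of this commit.

# Minimal caller sketch. Assumptions: model name, device, and prompt text are
# illustrative; Summarizer, generate(), and total_tokens come from the diff above.
summarizer = Summarizer(model_name="qwen2.5-7b-instruct", device="GPU")

# Streaming path: generate() returns the CountingTextIteratorStreamer, which the
# daemon thread feeds; TextIteratorStreamer is iterable and yields decoded text
# chunks until streamer.end() fires in the worker's finally block.
streamer = summarizer.generate("Summarize today's lesson.", stream=True)
for chunk in streamer:
    print(chunk, end="", flush=True)
print(f"\n[streamed tokens: {streamer.total_tokens}]")

# Non-streaming path: generate() returns the raw ids from model.generate(), so
# the caller decodes them with the tokenizer kept on the Summarizer instance.
output_ids = summarizer.generate("Summarize today's lesson.", stream=False)
print(summarizer.tokenizer.decode(output_ids[0], skip_special_tokens=True))

The trade-off this commit appears to make: every generate() call now pays the OpenVINO model load/compile cost inside audio_pipeline_lock, and in exchange del plus gc.collect() in _destroy_model lets the compiled model's memory be reclaimed between requests rather than staying resident for the process lifetime.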
