
Commit 8a0ee3a

Add example of adding streaming support to run an LLM with transformers
1 parent 2f3443e commit 8a0ee3a

File tree

1 file changed: +60 -0 lines changed

advanced/hf-stream-llm/photon.py

import os
from threading import Thread
from queue import Queue

from loguru import logger
from leptonai.photon import Photon, StreamingResponse


class HfStreamLLM(Photon):

    # Default deployment settings: GPU shape, model path env var, and the
    # Hugging Face Hub token pulled from a secret.
    deployment_template = {
        "resource_shape": "gpu.a10.6xlarge",
        "env": {
            "MODEL_PATH": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        },
        "secret": [
            "HUGGING_FACE_HUB_TOKEN",
        ],
    }

    requirement_dependency = [
        "transformers",
    ]

    handler_max_concurrency = 4

    def init(self):
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_path = os.environ["MODEL_PATH"]

        # Load the tokenizer and move the model to the GPU.
        self._tok = AutoTokenizer.from_pretrained(model_path)
        self._model = AutoModelForCausalLM.from_pretrained(model_path).to("cuda")

        # Generation requests are queued and served by background worker
        # threads, one per concurrent handler slot.
        self._generation_queue = Queue()

        for _ in range(self.handler_max_concurrency):
            Thread(target=self._generate, daemon=True).start()

    def _generate(self):
        # Worker loop: pull a generation request off the queue and run it,
        # pushing tokens into that request's streamer as they are produced.
        while True:
            streamer, args, kwargs = self._generation_queue.get()
            try:
                self._model.generate(*args, **kwargs)
            except Exception as e:
                logger.error(f"Error in generation: {e}")
                # Unblock the consumer so the streamed response terminates.
                streamer.text_queue.put(streamer.stop_signal)

    @Photon.handler
    def run(self, text: str, max_new_tokens: int = 100) -> StreamingResponse:
        from transformers import TextIteratorStreamer

        # The streamer is handed to model.generate (producer) and also returned
        # to the client as an iterator over decoded text chunks (consumer).
        streamer = TextIteratorStreamer(self._tok, skip_prompt=True, timeout=60)
        inputs = self._tok(text, return_tensors="pt").to("cuda")
        self._generation_queue.put_nowait((
            streamer,
            (),
            dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens),
        ))
        return streamer
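
For reference, a minimal client-side sketch of consuming the stream, assuming the photon is served locally on port 8080; the URL, port, prompt, and use of the requests library are assumptions for illustration, not part of this commit:

# Hypothetical client sketch (not part of this commit): stream tokens from a
# locally running HfStreamLLM photon. Host and port are assumptions.
import requests

resp = requests.post(
    "http://localhost:8080/run",
    json={"text": "Tell me a joke.", "max_new_tokens": 100},
    stream=True,
)
resp.raise_for_status()

# Print each decoded chunk as it arrives instead of waiting for the full reply.
for chunk in resp.iter_content(chunk_size=None):
    if chunk:
        print(chunk.decode("utf-8", errors="ignore"), end="", flush=True)
print()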
