Add example of adding streaming support to run llm with transformers (#68)

* Add example of adding streaming support to run llm with transformers
* lint
Showing 3 changed files with 73 additions and 7 deletions.
@@ -61,9 +61,11 @@ def ui(self):
 
         with blocks:
             gr.Markdown("# 🧙🏼 Earning Report Assistant")
-            gr.Markdown("""
+            gr.Markdown(
+                """
             This is an earning report assistant built for investors can't make the earning call on time. This sample is using Apple 2023 Q2 report. Feel free to reach out to [email protected] for more advanced features.
-            """)
+            """
+            )
         with gr.Row():
             chatbot = gr.Chatbot(label="Model")
         with gr.Row():
@@ -0,0 +1,62 @@ (new file)
import os
from threading import Thread
from queue import Queue

from loguru import logger
from leptonai.photon import Photon, StreamingResponse


class HfStreamLLM(Photon):

    deployment_template = {
        "resource_shape": "gpu.a10.6xlarge",
        "env": {
            "MODEL_PATH": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        },
        "secret": [
            "HUGGING_FACE_HUB_TOKEN",
        ],
    }

    requirement_dependency = [
        "transformers",
    ]

    handler_max_concurrency = 4

    def init(self):
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_path = os.environ["MODEL_PATH"]

        self._tok = AutoTokenizer.from_pretrained(model_path)
        self._model = AutoModelForCausalLM.from_pretrained(model_path).to("cuda")

        # Generation requests are queued and served by a fixed pool of
        # daemon worker threads, one per concurrent handler slot.
        self._generation_queue = Queue()

        for _ in range(self.handler_max_concurrency):
            Thread(target=self._generate, daemon=True).start()

    def _generate(self):
        # Worker loop: pull (streamer, args, kwargs) tuples off the queue and
        # run generation; tokens flow back to the caller via the streamer.
        while True:
            streamer, args, kwargs = self._generation_queue.get()
            try:
                self._model.generate(*args, **kwargs)
            except Exception as e:
                logger.error(f"Error in generation: {e}")
                # Unblock the consumer so the streamed response terminates.
                streamer.text_queue.put(streamer.stop_signal)

    @Photon.handler
    def run(self, text: str, max_new_tokens: int = 100) -> StreamingResponse:
        from transformers import TextIteratorStreamer

        streamer = TextIteratorStreamer(self._tok, skip_prompt=True, timeout=60)
        inputs = self._tok(text, return_tensors="pt").to("cuda")
        self._generation_queue.put_nowait(
            (
                streamer,
                (),
                dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens),
            )
        )
        # The streamer is an iterator over decoded text pieces, so it can be
        # returned directly as the streaming response body.
        return streamer
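
The photon above builds on the standard TextIteratorStreamer pattern from transformers: generate() runs in a background thread and pushes decoded text into the streamer, which the caller drains as an iterator. A minimal standalone sketch of that pattern, outside of Lepton (the prompt is illustrative):

# Minimal sketch of the TextIteratorStreamer pattern used above.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

inputs = tok("What is streaming generation?", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True)

# generate() blocks until completion, so it runs in a thread
# while the main thread drains the streamer token by token.
thread = Thread(
    target=model.generate,
    kwargs=dict(inputs, streamer=streamer, max_new_tokens=50),
)
thread.start()
for piece in streamer:
    print(piece, end="", flush=True)
thread.join()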
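
One hypothetical way to consume the run handler is a plain streaming HTTP request against a locally running photon; the URL below assumes the local default port and is not part of the commit, so adjust it to your deployment:

# Hypothetical client for the /run handler above, assuming the photon
# is served locally (e.g. on http://localhost:8080).
import requests

resp = requests.post(
    "http://localhost:8080/run",
    json={"text": "Tell me a story.", "max_new_tokens": 100},
    stream=True,
)
# Print text pieces as they arrive instead of waiting for the full reply.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)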