
Commit 0aad5bf

refactor: module lowercase naming
Parent: cc62546


110 files changed: +11,534 −0 lines (only a subset of the changed files is shown below)


irbm/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -0,0 +1,15 @@
"""
Intelligent resbibman,
AI tools & GPU acceleration
"""

from resbibman.initLogger import setupLogger

def initLogger(level = "info"):
    # NOTE: the level argument is accepted but not yet forwarded to setupLogger
    return setupLogger(
        "iRBM",
        term_id="iserver",
    )

initLogger()
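Importing the package therefore configures logging as a side effect. A minimal sketch of that behaviour (it assumes resbibman's setupLogger registers a standard `logging` logger under the given name, which this commit does not show):

import logging
import irbm  # runs initLogger() at import time

logging.getLogger("iRBM").info("package imported")  # assumes a stdlib logger named "iRBM"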

irbm/cmd/__init__.py

Whitespace-only changes.

irbm/cmd/summarize.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
import asyncio
import argparse
from resbibman.core.pdfTools import PDFAnalyser
from ..lmTools import summarize, structuredSummerize, featurize  # sic: "structuredSummerize" is the actual name in lmTools
from ..lmInterface import streamOutput

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('pdf_path', type=str, help='path to pdf file')
    parser.add_argument('--structured', action='store_true', help='structured summarization')
    parser.add_argument('--model', type=str, default="gpt-3.5-turbo", help='model name')
    parser.add_argument("--max-length", type=int, default=-1, help="max length of the input text, the rest will be truncated")

    args = parser.parse_args()
    with PDFAnalyser(args.pdf_path) as doc:
        pdf_text = doc.getText()

    # Truncate the input to at most --max-length words; -1 means no limit.
    max_len = args.max_length
    if max_len == -1:
        max_len = len(pdf_text.split())
    if len(pdf_text.split()) > max_len:
        txt = " ".join(pdf_text.split()[:max_len])
    else:
        txt = pdf_text

    # vec = asyncio.run(featurize(txt, verbose=True))
    # print(vec.shape)
    # exit()

    if args.structured:
        res = asyncio.run(structuredSummerize(txt, print_func=print, model=args.model))
        vec = asyncio.run(featurize(res))
        print("Got vectorized result: ", vec.shape)
    else:
        streamOutput(summarize(txt))


if __name__ == "__main__":
    main()
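Since the module guards a main() behind __name__, it can be run directly. A hypothetical invocation (the module path is inferred from the file location; the commit shows no installed console entry point):

python -m irbm.cmd.summarize paper.pdf --model gpt-3.5-turbo
python -m irbm.cmd.summarize paper.pdf --structured --max-length 2000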

irbm/globalConfig.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
"""
Usage:
    import globalConfig as config
    ...
"""
import openai

# Default OpenAI endpoint, captured at import time.
openai_api_base: str = openai.api_base
# Local FastChat server exposing an OpenAI-compatible API.
fastchat_api_base: str = "http://localhost:8000/v1"
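A minimal sketch of how the two endpoints are meant to be swapped (the pattern mirrors what lmInterface.py below does per model; the package import path is assumed):

import openai
from irbm import globalConfig as config

openai.api_base = config.fastchat_api_base   # route requests to the local FastChat server
# ... make OpenAI-compatible calls against the local model here ...
openai.api_base = config.openai_api_base     # restore the default OpenAI endpoint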

irbm/lmInterface.py

Lines changed: 241 additions & 0 deletions
@@ -0,0 +1,241 @@
"""Language Model Interface"""
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, TypedDict, Literal
import dataclasses
import enum

# check python version: Iterator is subscriptable from collections.abc on
# Python >= 3.9; typing.Iterator covers older interpreters
import sys
if sys.version_info < (3, 9):
    from typing import Iterator
else:
    from collections.abc import Iterator

import openai
from . import globalConfig as config
import basaran.model


ConvRole = Literal["user", "assistant"]
ConvContent = str
# note: the list[...] subscript below is evaluated at runtime and still
# requires Python >= 3.9 despite the fallback import above
ConversationDictT = TypedDict("ConversationDictT", {
    "system": str,
    "conversations": list[tuple[ConvRole, ConvContent]]
})

@dataclasses.dataclass
class Conversation:
    system: str
    conversations: list[tuple[ConvRole, ConvContent]]

    def add(self, role: ConvRole, content: str):
        self.conversations.append((role, content))

    def clear(self):
        self.conversations = []

    def __str__(self) -> str:
        template = "[system]\n> {}\n".format(self.system)
        return template + "\n".join(["[{}]\n> {}".format(c[0], c[1]) for c in self.conversations])

    def toDict(self) -> ConversationDictT:
        return {
            "system": self.system,
            "conversations": self.conversations
        }

    def setFromDict(self, dict: ConversationDictT):
        self.system = dict["system"]
        self.conversations = dict["conversations"]
        return self

    @property
    def openai_conversations(self):
        system = [{"role": "system", "content": self.system}]
        conv = [{"role": c[0], "content": c[1]} for c in self.conversations]
        return system + conv
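The Conversation dataclass is the shared transcript for every backend below. A quick usage sketch (all names come from this file; the import path is assumed):

from irbm.lmInterface import Conversation

conv = Conversation(system="A conversation between a human and an AI assistant.", conversations=[])
conv.add("user", "Hello!")
conv.add("assistant", "Hi, how can I help?")
print(conv)                           # [system] / [user] / [assistant] transcript
messages = conv.openai_conversations  # [{"role": ..., "content": ...}, ...]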
def streamOutput(output_stream: Iterator[StreamData], print_callback: Any = lambda x, end=" ", flush=True: ...):
    """
    Obtain the output from the stream, and maybe print it to stdout
    print_callback: a function that takes a string and prints it to stdout, \
        should have the same interface as print (i.e. print_callback("hello", end=" ", flush=True))
    """
    try:
        print_callback("", end="", flush=True)
    except TypeError:
        raise TypeError("print_callback should have the same interface as print, i.e. accept end and flush")

    # Each StreamData is expected to carry the cumulative text so far; only the
    # words that are new since the previous chunk are printed, and the last
    # word is held back until the stream ends.
    pre = 0
    output_text = ""
    for outputs in output_stream:
        output_text = outputs["text"]
        output_text = output_text.strip().split(" ")
        now = len(output_text) - 1
        if now > pre:
            print_callback(" ".join(output_text[pre:now]), end=" ", flush=True)
            pre = now
    print_callback(" ".join(output_text[pre:]), flush=True)
    return " ".join(output_text)
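A sketch of a custom print_callback (hypothetical; any callable that accepts print's end and flush keywords works):

import sys

def stderr_callback(text, end=" ", flush=True):
    # mirror the streamed words to stderr instead of stdout
    print(text, end=end, flush=flush, file=sys.stderr)

# result = streamOutput(stream, print_callback=stderr_callback)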
class ErrorCodes(enum.Enum):
    """Error codes for the model output stream"""
    OK = 0


class StreamData(TypedDict):
    """A class to represent the data returned by the model output stream"""
    text: str
    error_code: ErrorCodes
class ChatStreamIter(ABC):
    """Abstract class for language model interface"""
    temperature = 0.8
    max_response_length = 1024
    conversations: Conversation

    # whether to yield each new piece of the output stream, or the
    # concatenated output accumulated so far
    return_pieces: bool = False

    @abstractmethod
    def call(self, message: str, temperature: float, max_len: int = 1024) -> Iterator[StreamData]:
        ...

    def __call__(self, prompt) -> Iterator[StreamData]:
        return self.call(prompt, self.temperature, self.max_response_length)
class OpenAIChatStreamIter(ChatStreamIter):
    """
    Connect to the OpenAI API (or an OpenAI-compatible FastChat server for vicuna models)
    """
    def __init__(self, model: str = "gpt-3.5-turbo") -> None:
        super().__init__()
        self.model = model
        self.conversations = Conversation(system="A conversation between a human and an AI assistant.", conversations=[])
        if "vicuna" in model:
            assert config.fastchat_api_base, "fastchat_api_base is not set"

    def generateMessages(self, prompt: str):
        self.conversations.add(role="user", content=prompt)
        return self.conversations.openai_conversations

    @property
    def openai_base(self):
        # vicuna models are served by a local FastChat instance that speaks
        # the OpenAI protocol; everything else goes to the real OpenAI API
        if "vicuna" in self.model:
            return config.fastchat_api_base
        else:
            return config.openai_api_base

    def call(self, prompt: str, temperature: float, max_len: int = 1024) -> Iterator[StreamData]:
        openai.api_base = self.openai_base  # set the api base according to the model (note: process-wide)

        res = openai.ChatCompletion.create(
            model=self.model, messages=self.generateMessages(prompt), temperature=temperature, stream=True
        )
        text = ""
        for chunk in res:
            piece: str = chunk["choices"][0]["delta"].get("content", "")  # type: ignore
            text += piece
            data: StreamData = {
                "text": piece if self.return_pieces else text,
                "error_code": ErrorCodes.OK
            }
            yield data
        self.conversations.add(role="assistant", content=text)
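A minimal sketch of driving this iterator directly (names from this file; a valid OpenAI API key in the environment is assumed):

from irbm.lmInterface import OpenAIChatStreamIter

bot = OpenAIChatStreamIter(model="gpt-3.5-turbo")
last = ""
for data in bot("Hello!"):
    last = data["text"]   # cumulative text so far, since return_pieces is False
print(last)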
class HFChatStreamIter(ChatStreamIter):
    """Offline models from huggingface"""
    def __init__(
        self,
        model: Literal["lmsys/vicuna-7b-v1.5-16k", "meta-llama/Llama-2-7b-chat", "stabilityai/StableBeluga-7B"],
        load_in_8bit: bool = True
    ):
        self.model_name = model
        self.model = basaran.model.load_model(model, load_in_8bit=load_in_8bit)
        self.conversations = Conversation(system="A conversation between a human and an AI assistant.", conversations=[])

    def getConv(self):
        # Render the conversation into the prompt template expected by each
        # model family.
        if "Llama-2" in self.model_name:
            # Not sure if this is correct
            ret = f"[INST]<<SYS>>\n{self.conversations.system.strip()}\n<<SYS>>\n"
            for i, (role, content) in enumerate(self.conversations.conversations):
                if i == 0:
                    assert role == "user"
                    ret += f"{content}[/INST]"
                else:
                    if role == "user":
                        ret += f"[INST]{content}[/INST]"
                    else:
                        ret += f"{content}</s><s>"
            if self.conversations.conversations[-1][0] == "user":
                ret += "[INST]"
            return ret

        elif "vicuna" in self.model_name:
            # Not sure if this is correct
            ret = f"{self.conversations.system.strip()}"
            for i, (role, content) in enumerate(self.conversations.conversations):
                if i == 0:
                    assert role == "user"
                if role == "user":
                    ret += f"USER: {content} "
                else:
                    ret += f"ASSISTANT: {content}</s>"
            if self.conversations.conversations[-1][0] == "user":
                ret += "ASSISTANT: "
            else:
                ret += "USER: "
            return ret

        elif "StableBeluga" in self.model_name:
            """
            ### System:
            This is a system prompt, please behave and help the user.

            ### User:
            Your prompt here

            ### Assistant:
            The output of Stable Beluga 7B
            """
            ret = f"### System:\n{self.conversations.system.strip()}\n\n"
            for i, (role, content) in enumerate(self.conversations.conversations):
                if i == 0:
                    assert role == "user"
                if role == "user":
                    ret += f"### User:\n{content}\n\n"
                else:
                    ret += f"### Assistant:\n{content}\n\n"
            if self.conversations.conversations[-1][0] == "user":
                ret += "### Assistant:\n"
            else:
                ret += "### User:\n"
            return ret

        else:
            raise NotImplementedError("Unknown model: {}".format(self.model_name))

    def call(self, prompt: str, temperature: float, max_len: int = 1024) -> Iterator[StreamData]:
        self.conversations.add(role="user", content=prompt)
        text = ""
        # unlike the OpenAI iterator, this always yields the new piece rather
        # than the accumulated text
        for choice in self.model(prompt=self.getConv(), max_tokens=max_len, temperature=temperature, return_full_text=False):
            piece = choice["text"]
            data: StreamData = {
                "text": piece,
                "error_code": ErrorCodes.OK
            }
            text += piece
            yield data
        self.conversations.add(role="assistant", content=text)
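For concreteness, what getConv renders for the StableBeluga template after a single user turn can be traced from the code above (the __new__ call just skips model loading for illustration; names are from this file):

from irbm.lmInterface import HFChatStreamIter, Conversation

it = HFChatStreamIter.__new__(HFChatStreamIter)  # bypass __init__, no model load
it.model_name = "stabilityai/StableBeluga-7B"
it.conversations = Conversation(system="Be helpful.", conversations=[("user", "Hi")])
print(it.getConv())
# ### System:
# Be helpful.
#
# ### User:
# Hi
#
# ### Assistant: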
ChatStreamIterType = Literal[
    "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "vicuna-13b", "gpt-4", "gpt-4-32k", "vicuna-33b-v1.3-gptq-4bit",
    "lmsys/vicuna-7b-v1.5-16k", "meta-llama/Llama-2-7b-chat", "stabilityai/StableBeluga-7B"
]

def getStreamIter(itype: ChatStreamIterType = "gpt-3.5-turbo") -> ChatStreamIter:
    # API-served names go through the OpenAI-compatible iterator; anything
    # else is treated as a local huggingface model
    if itype in ["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "vicuna-13b", "gpt-4", "gpt-4-32k", "vicuna-33b-v1.3-gptq-4bit"]:
        return OpenAIChatStreamIter(model=itype)
    else:
        try:
            return HFChatStreamIter(model=itype)  # type: ignore
        except Exception:
            raise ValueError("Unknown interface type: {}".format(itype))
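Taken together, a minimal end-to-end sketch of the interface this file adds (all names from lmInterface.py; the package import path is assumed):

from irbm.lmInterface import getStreamIter, streamOutput

chat = getStreamIter("gpt-3.5-turbo")   # an OpenAIChatStreamIter
reply = streamOutput(chat("Summarize this commit in one sentence."))
# chat.conversations now records the exchange, so the next call continues
# the same conversation:
streamOutput(chat("Now make it shorter."))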
