Skip to content

Commit 890a20e

Browse files
authored
Merge pull request #303 from arc53/feature/hf-docs-models
Support for HF (Hugging Face) models optimised for DocsGPT
2 parents e6f48c9 + 909f0af commit 890a20e

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

application/app.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from celery.result import AsyncResult
1515
from flask import Flask, request, render_template, send_from_directory, jsonify, Response
1616
from langchain import FAISS
17-
from langchain import VectorDBQA, HuggingFaceHub, Cohere, OpenAI
17+
from langchain import VectorDBQA, Cohere, OpenAI
1818
from langchain.chains import LLMChain, ConversationalRetrievalChain
1919
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT
2020
from langchain.chains.question_answering import load_qa_chain
@@ -25,7 +25,6 @@
2525
CohereEmbeddings,
2626
HuggingFaceInstructEmbeddings,
2727
)
28-
from langchain.llms import GPT4All
2928
from langchain.prompts import PromptTemplate
3029
from langchain.prompts.chat import (
3130
ChatPromptTemplate,
@@ -50,11 +49,20 @@
5049
else:
5150
gpt_model = 'gpt-3.5-turbo'
5251

53-
if settings.LLM_NAME == "manifest":
54-
from manifest import Manifest
55-
from langchain.llms.manifest import ManifestWrapper
5652

57-
manifest = Manifest(client_name="huggingface", client_connection="http://127.0.0.1:5000")
53+
if settings.SELF_HOSTED_MODEL:
54+
from langchain.llms import HuggingFacePipeline
55+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
56+
57+
model_id = settings.LLM_NAME # hf model id (Arc53/docsgpt-7b-falcon, Arc53/docsgpt-14b)
58+
tokenizer = AutoTokenizer.from_pretrained(model_id)
59+
model = AutoModelForCausalLM.from_pretrained(model_id)
60+
pipe = pipeline(
61+
"text-generation", model=model,
62+
tokenizer=tokenizer, max_new_tokens=2000,
63+
device_map="auto", eos_token_id=tokenizer.eos_token_id
64+
)
65+
hf = HuggingFacePipeline(pipeline=pipe)
5866

5967
# Redirect PosixPath to WindowsPath on Windows
6068

@@ -346,14 +354,10 @@ def api_answer():
346354
p_chat_combine = ChatPromptTemplate.from_messages(messages_combine)
347355
elif settings.LLM_NAME == "openai":
348356
llm = OpenAI(openai_api_key=api_key, temperature=0)
349-
elif settings.LLM_NAME == "manifest":
350-
llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0.001, "max_tokens": 2048})
351-
elif settings.LLM_NAME == "huggingface":
352-
llm = HuggingFaceHub(repo_id="bigscience/bloom", huggingfacehub_api_token=api_key)
357+
elif settings.SELF_HOSTED_MODEL:
358+
llm = hf
353359
elif settings.LLM_NAME == "cohere":
354360
llm = Cohere(model="command-xlarge-nightly", cohere_api_key=api_key)
355-
elif settings.LLM_NAME == "gpt4all":
356-
llm = GPT4All(model=settings.MODEL_PATH)
357361
else:
358362
raise ValueError("unknown LLM model")
359363

@@ -369,7 +373,7 @@ def api_answer():
369373
# result = chain({"question": question, "chat_history": chat_history})
370374
# generate async with async generate method
371375
result = run_async_chain(chain, question, chat_history)
372-
elif settings.LLM_NAME == "gpt4all":
376+
elif settings.SELF_HOSTED_MODEL:
373377
question_generator = LLMChain(llm=llm, prompt=CONDENSE_QUESTION_PROMPT)
374378
doc_chain = load_qa_chain(llm, chain_type="map_reduce", combine_prompt=p_chat_combine)
375379
chain = ConversationalRetrievalChain(

application/core/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ class Settings(BaseSettings):
1111
MONGO_URI: str = "mongodb://localhost:27017/docsgpt"
1212
MODEL_PATH: str = "./models/gpt4all-model.bin"
1313
TOKENS_MAX_HISTORY: int = 150
14+
SELF_HOSTED_MODEL: bool = False
1415

1516
API_URL: str = "http://localhost:7091" # backend url for celery worker
1617

docker-compose.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ services:
1919
- CELERY_BROKER_URL=redis://redis:6379/0
2020
- CELERY_RESULT_BACKEND=redis://redis:6379/1
2121
- MONGO_URI=mongodb://mongo:27017/docsgpt
22+
- SELF_HOSTED_MODEL=$SELF_HOSTED_MODEL
2223
ports:
2324
- "7091:7091"
2425
volumes:

0 commit comments

Comments (0)