apply black formatting everywhere
maxjakob committed Feb 13, 2024
1 parent 49b9eb3 commit e71df0c
Showing 72 changed files with 8,494 additions and 8,622 deletions.
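
The change is mechanical: black normalizes strings to double quotes, splits overlong calls one argument per line with a trailing comma, puts spaces around "=" in plain assignments, and enforces two blank lines between top-level definitions. As a rough sketch of what ran here (assuming black is installed; format_str and Mode are black's public programmatic entry points), the same rewrite can be reproduced on any snippet:

# Minimal sketch, assuming "pip install black".
import black

src = "d = {'request_timeout': 60}\nanswer = ''\n"
print(black.format_str(src, mode=black.Mode()))
# d = {"request_timeout": 60}
# answer = ""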
18 changes: 9 additions & 9 deletions bin/mocks/elasticsearch.py
@@ -8,30 +8,30 @@ def patch_elasticsearch():
 
     # remove the path entry that refers to this directory
     for path in sys.path:
-        if not path.startswith('/'):
+        if not path.startswith("/"):
             path = os.path.join(os.getcwd(), path)
-        if __file__ == os.path.join(path, 'elasticsearch.py'):
+        if __file__ == os.path.join(path, "elasticsearch.py"):
             sys.path.remove(path)
             break
 
     # remove this module, and import the real one instead
-    del sys.modules['elasticsearch']
+    del sys.modules["elasticsearch"]
     import elasticsearch
 
     # restore the import path
     sys.path = saved_path
 
-    # preserve the original Elasticsearch.__init__ method
+    # preserve the original Elasticsearch.__init__ method
     orig_es_init = elasticsearch.Elasticsearch.__init__
 
     # patched version of Elasticsearch.__init__ that connects to self-hosted
     # regardless of connection arguments given
     def patched_es_init(self, *args, **kwargs):
-        if 'cloud_id' in kwargs:
-            assert kwargs['cloud_id'] == 'foo'
-        if 'api_key' in kwargs:
-            assert kwargs['api_key'] == 'bar'
-        return orig_es_init(self, 'http://localhost:9200')
+        if "cloud_id" in kwargs:
+            assert kwargs["cloud_id"] == "foo"
+        if "api_key" in kwargs:
+            assert kwargs["api_key"] == "bar"
+        return orig_es_init(self, "http://localhost:9200")
 
     # patch Elasticsearch.__init__
     elasticsearch.Elasticsearch.__init__ = patched_es_init
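
Functionally the mock is untouched: any test that constructs an Elasticsearch client with the dummy cloud_id/api_key is silently redirected to a local server. A sketch of the intended use, assuming the test harness prepends bin/mocks to sys.path and that the mock module invokes patch_elasticsearch() on import (that invocation sits outside the lines shown above):

# Illustrative only; the exact wiring lives outside this hunk.
import sys

sys.path.insert(0, "bin/mocks")
import elasticsearch  # resolves to the mock, which swaps in the real library

# The dummy credentials are asserted and then ignored; the client
# actually connects to http://localhost:9200.
es = elasticsearch.Elasticsearch(cloud_id="foo", api_key="bar")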
32 changes: 20 additions & 12 deletions example-apps/chatbot-rag-app/api/chat.py
@@ -36,31 +36,39 @@ def ask_question(question, session_id):
     if len(chat_history.messages) > 0:
         # create a condensed question
         condense_question_prompt = render_template(
-            'condense_question_prompt.txt', question=question,
-            chat_history=chat_history.messages)
+            "condense_question_prompt.txt",
+            question=question,
+            chat_history=chat_history.messages,
+        )
         condensed_question = get_llm().invoke(condense_question_prompt).content
     else:
         condensed_question = question
 
-    current_app.logger.debug('Condensed question: %s', condensed_question)
-    current_app.logger.debug('Question: %s', question)
+    current_app.logger.debug("Condensed question: %s", condensed_question)
+    current_app.logger.debug("Question: %s", question)
 
     docs = store.as_retriever().invoke(condensed_question)
     for doc in docs:
-        doc_source = {**doc.metadata, 'page_content': doc.page_content}
-        current_app.logger.debug('Retrieved document passage from: %s', doc.metadata['name'])
-        yield f'data: {SOURCE_TAG} {json.dumps(doc_source)}\n\n'
+        doc_source = {**doc.metadata, "page_content": doc.page_content}
+        current_app.logger.debug(
+            "Retrieved document passage from: %s", doc.metadata["name"]
+        )
+        yield f"data: {SOURCE_TAG} {json.dumps(doc_source)}\n\n"
 
-    qa_prompt = render_template('rag_prompt.txt', question=question, docs=docs,
-                                chat_history=chat_history.messages)
+    qa_prompt = render_template(
+        "rag_prompt.txt",
+        question=question,
+        docs=docs,
+        chat_history=chat_history.messages,
+    )
 
-    answer = ''
+    answer = ""
     for chunk in get_llm().stream(qa_prompt):
-        yield f'data: {chunk.content}\n\n'
+        yield f"data: {chunk.content}\n\n"
         answer += chunk.content
 
     yield f"data: {DONE_TAG}\n\n"
-    current_app.logger.debug('Answer: %s', answer)
+    current_app.logger.debug("Answer: %s", answer)
 
     chat_history.add_user_message(question)
     chat_history.add_ai_message(answer)
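
Reformatting aside, ask_question still streams server-sent-event-style frames: one data: frame per retrieved source (prefixed with SOURCE_TAG), then raw answer chunks, then a DONE_TAG sentinel. A rough consumer sketch; the /api/chat route, payload shape, and tag values are assumptions for illustration, since only the data: framing is visible in this diff:

# Hypothetical client; route, payload, and tag values are assumed.
import requests

SOURCE_TAG = "[SOURCE]"  # placeholder value
DONE_TAG = "[DONE]"      # placeholder value

with requests.post(
    "http://localhost:4000/api/chat",
    json={"question": "What is the vacation policy?", "session_id": "demo"},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload.startswith(DONE_TAG):
            break
        if payload.startswith(SOURCE_TAG):
            print("source:", payload)   # JSON-encoded document metadata
        else:
            print(payload, end="")      # streamed answer chunk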
48 changes: 35 additions & 13 deletions example-apps/chatbot-rag-app/api/llm_integrations.py
@@ -5,37 +5,54 @@
 
 LLM_TYPE = os.getenv("LLM_TYPE", "openai")
 
+
 def init_openai_chat(temperature):
     OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
-    return ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True, temperature=temperature)
+    return ChatOpenAI(
+        openai_api_key=OPENAI_API_KEY, streaming=True, temperature=temperature
+    )
+
 
 def init_vertex_chat(temperature):
     VERTEX_PROJECT_ID = os.getenv("VERTEX_PROJECT_ID")
     VERTEX_REGION = os.getenv("VERTEX_REGION", "us-central1")
     vertexai.init(project=VERTEX_PROJECT_ID, location=VERTEX_REGION)
     return ChatVertexAI(streaming=True, temperature=temperature)
 
+
 def init_azure_chat(temperature):
-    OPENAI_VERSION=os.getenv("OPENAI_VERSION", "2023-05-15")
-    BASE_URL=os.getenv("OPENAI_BASE_URL")
-    OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
-    OPENAI_ENGINE=os.getenv("OPENAI_ENGINE")
+    OPENAI_VERSION = os.getenv("OPENAI_VERSION", "2023-05-15")
+    BASE_URL = os.getenv("OPENAI_BASE_URL")
+    OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
+    OPENAI_ENGINE = os.getenv("OPENAI_ENGINE")
     return AzureChatOpenAI(
         deployment_name=OPENAI_ENGINE,
         openai_api_base=BASE_URL,
         openai_api_version=OPENAI_VERSION,
         openai_api_key=OPENAI_API_KEY,
         streaming=True,
-        temperature=temperature)
+        temperature=temperature,
+    )
+
 
 def init_bedrock(temperature):
-    AWS_ACCESS_KEY=os.getenv("AWS_ACCESS_KEY")
-    AWS_SECRET_KEY=os.getenv("AWS_SECRET_KEY")
-    AWS_REGION=os.getenv("AWS_REGION")
-    AWS_MODEL_ID=os.getenv("AWS_MODEL_ID", "anthropic.claude-v2")
-    BEDROCK_CLIENT=boto3.client(service_name="bedrock-runtime", region_name=AWS_REGION, aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
+    AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
+    AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
+    AWS_REGION = os.getenv("AWS_REGION")
+    AWS_MODEL_ID = os.getenv("AWS_MODEL_ID", "anthropic.claude-v2")
+    BEDROCK_CLIENT = boto3.client(
+        service_name="bedrock-runtime",
+        region_name=AWS_REGION,
+        aws_access_key_id=AWS_ACCESS_KEY,
+        aws_secret_access_key=AWS_SECRET_KEY,
+    )
     return BedrockChat(
         client=BEDROCK_CLIENT,
         model_id=AWS_MODEL_ID,
         streaming=True,
-        model_kwargs={"temperature":temperature})
+        model_kwargs={"temperature": temperature},
+    )
+
 
 MAP_LLM_TYPE_TO_CHAT_MODEL = {
     "azure": init_azure_chat,

@@ -44,8 +61,13 @@ def init_bedrock(temperature):
     "vertex": init_vertex_chat,
 }
 
+
 def get_llm(temperature=0):
     if not LLM_TYPE in MAP_LLM_TYPE_TO_CHAT_MODEL:
-        raise Exception("LLM type not found. Please set LLM_TYPE to one of: " + ", ".join(MAP_LLM_TYPE_TO_CHAT_MODEL.keys()) + ".")
+        raise Exception(
+            "LLM type not found. Please set LLM_TYPE to one of: "
+            + ", ".join(MAP_LLM_TYPE_TO_CHAT_MODEL.keys())
+            + "."
+        )
 
     return MAP_LLM_TYPE_TO_CHAT_MODEL[LLM_TYPE](temperature=temperature)
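
get_llm() remains a thin factory keyed off LLM_TYPE. A minimal usage sketch (the env var names are the ones read above; setting them before import matters because LLM_TYPE is read at module load):

# Minimal sketch; the key value is a placeholder.
import os

os.environ["LLM_TYPE"] = "openai"
os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder

from llm_integrations import get_llm

llm = get_llm(temperature=0)
for chunk in llm.stream("Say hello in one word."):
    print(chunk.content, end="")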
16 changes: 9 additions & 7 deletions example-apps/chatbot-rag-app/data/index_data.py
@@ -61,14 +61,16 @@ def main():
 
     print(f"Loading data from ${FILE}")
 
-    metadata_keys = ['name', 'summary', 'url', 'category', 'updated_at']
+    metadata_keys = ["name", "summary", "url", "category", "updated_at"]
     workplace_docs = []
-    with open(FILE, 'rt') as f:
+    with open(FILE, "rt") as f:
         for doc in json.loads(f.read()):
-            workplace_docs.append(Document(
-                page_content=doc['content'],
-                metadata={k: doc.get(k) for k in metadata_keys}
-            ))
+            workplace_docs.append(
+                Document(
+                    page_content=doc["content"],
+                    metadata={k: doc.get(k) for k in metadata_keys},
+                )
+            )
 
     print(f"Loaded {len(workplace_docs)} documents")

@@ -92,7 +94,7 @@ def main():
         index_name=INDEX,
         strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(model_id=ELSER_MODEL),
         bulk_kwargs={
-            'request_timeout': 60,
+            "request_timeout": 60,
        },
     )
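
The reflowed Document construction is behavior-neutral. A small sketch of the input shape it consumes (the field names mirror metadata_keys above; the Document import path is an assumption about this app's LangChain version):

# Sketch of one input record and the Document built from it.
from langchain.schema import Document  # import path assumed

doc = {
    "name": "Vacation Policy",
    "summary": "How time off accrues",
    "url": "https://example.com/vacation",  # placeholder
    "category": "hr",
    "updated_at": "2020-03-01",
    "content": "Full-time employees accrue vacation ...",
}
metadata_keys = ["name", "summary", "url", "category", "updated_at"]
document = Document(
    page_content=doc["content"],
    metadata={k: doc.get(k) for k in metadata_keys},
)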

38 changes: 19 additions & 19 deletions example-apps/internal-knowledge-search/api/app.py
@@ -10,10 +10,8 @@
 
 
 def get_identities_index(search_app_name):
-    search_app = elasticsearch_client.search_application.get(
-        name=search_app_name)
-    identities_indices = elasticsearch_client.indices.get(
-        index=".search-acl-filter*")
+    search_app = elasticsearch_client.search_application.get(name=search_app_name)
+    identities_indices = elasticsearch_client.indices.get(index=".search-acl-filter*")
     secured_index = [
         app_index
         for app_index in search_app["indices"]

@@ -36,19 +34,22 @@ def api_index():
 @app.route("/api/default_settings", methods=["GET"])
 def default_settings():
     return {
-        "elasticsearch_endpoint": os.getenv("ELASTICSEARCH_URL") or "http://localhost:9200"
+        "elasticsearch_endpoint": os.getenv("ELASTICSEARCH_URL")
+        or "http://localhost:9200"
     }
 
 
 @app.route("/api/search_proxy/<path:text>", methods=["POST"])
 def search(text):
     response = requests.request(
         method="POST",
-        url=os.getenv("ELASTICSEARCH_URL") + '/' + text,
+        url=os.getenv("ELASTICSEARCH_URL") + "/" + text,
         data=request.get_data(),
         allow_redirects=False,
-        headers={"Authorization": request.headers.get(
-            "Authorization"), "Content-Type": "application/json"}
+        headers={
+            "Authorization": request.headers.get("Authorization"),
+            "Content-Type": "application/json",
+        },
     )
 
     return response.content

@@ -59,8 +60,7 @@ def personas():
     try:
         search_app_name = request.args.get("app_name")
         identities_index = get_identities_index(search_app_name)
-        response = elasticsearch_client.search(
-            index=identities_index, size=1000)
+        response = elasticsearch_client.search(index=identities_index, size=1000)
         hits = response["hits"]["hits"]
         personas = [x["_id"] for x in hits]
         personas.append("admin")

@@ -77,9 +77,8 @@ def personas():
 def indices():
     try:
         search_app_name = request.args.get("app_name")
-        search_app = elasticsearch_client.search_application.get(
-            name=search_app_name)
-        return search_app['indices']
+        search_app = elasticsearch_client.search_application.get(name=search_app_name)
+        return search_app["indices"]
 
     except Exception as e:
         current_app.logger.warn(

@@ -118,8 +117,7 @@ def api_key():
         if persona == "admin":
             role_descriptor = default_role_descriptor
         else:
-            identity = elasticsearch_client.get(
-                index=identities_index, id=persona)
+            identity = elasticsearch_client.get(index=identities_index, id=persona)
             permissions = identity["_source"]["query"]["template"]["params"][
                 "access_control"
             ]

@@ -161,12 +159,14 @@ def api_key():
             }
         }
         api_key = elasticsearch_client.security.create_api_key(
-            name=search_app_name+"-internal-knowledge-search-example-"+persona, expiration="1h", role_descriptors=role_descriptor)
-        return {"api_key": api_key['encoded']}
+            name=search_app_name + "-internal-knowledge-search-example-" + persona,
+            expiration="1h",
+            role_descriptors=role_descriptor,
+        )
+        return {"api_key": api_key["encoded"]}
 
     except Exception as e:
-        current_app.logger.warn(
-            "Encountered error %s while fetching api key", e)
+        current_app.logger.warn("Encountered error %s while fetching api key", e)
         raise e

@@ -31,8 +31,7 @@
     )
 elif ELASTIC_USERNAME and ELASTIC_PASSWORD:
     elasticsearch_client = Elasticsearch(
-        basic_auth=(ELASTIC_USERNAME, ELASTIC_PASSWORD),
-        cloud_id=ELASTIC_CLOUD_ID
+        basic_auth=(ELASTIC_USERNAME, ELASTIC_PASSWORD), cloud_id=ELASTIC_CLOUD_ID
     )
 else:
     raise ValueError(
