Skip to content

Commit e71df0c

Browse files
committed
apply black formatting everywhere
1 parent 49b9eb3 commit e71df0c

File tree

72 files changed

+8494
-8622
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+8494
-8622
lines changed

bin/mocks/elasticsearch.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,30 +8,30 @@ def patch_elasticsearch():
88

99
# remove the path entry that refers to this directory
1010
for path in sys.path:
11-
if not path.startswith('/'):
11+
if not path.startswith("/"):
1212
path = os.path.join(os.getcwd(), path)
13-
if __file__ == os.path.join(path, 'elasticsearch.py'):
13+
if __file__ == os.path.join(path, "elasticsearch.py"):
1414
sys.path.remove(path)
1515
break
1616

1717
# remove this module, and import the real one instead
18-
del sys.modules['elasticsearch']
18+
del sys.modules["elasticsearch"]
1919
import elasticsearch
2020

2121
# restore the import path
2222
sys.path = saved_path
2323

24-
# preserve the original Elasticsearch.__init__ method
24+
# preserve the original Elasticsearch.__init__ method
2525
orig_es_init = elasticsearch.Elasticsearch.__init__
2626

2727
# patched version of Elasticsearch.__init__ that connects to self-hosted
2828
# regardless of connection arguments given
2929
def patched_es_init(self, *args, **kwargs):
30-
if 'cloud_id' in kwargs:
31-
assert kwargs['cloud_id'] == 'foo'
32-
if 'api_key' in kwargs:
33-
assert kwargs['api_key'] == 'bar'
34-
return orig_es_init(self, 'http://localhost:9200')
30+
if "cloud_id" in kwargs:
31+
assert kwargs["cloud_id"] == "foo"
32+
if "api_key" in kwargs:
33+
assert kwargs["api_key"] == "bar"
34+
return orig_es_init(self, "http://localhost:9200")
3535

3636
# patch Elasticsearch.__init__
3737
elasticsearch.Elasticsearch.__init__ = patched_es_init

example-apps/chatbot-rag-app/api/chat.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,31 +36,39 @@ def ask_question(question, session_id):
3636
if len(chat_history.messages) > 0:
3737
# create a condensed question
3838
condense_question_prompt = render_template(
39-
'condense_question_prompt.txt', question=question,
40-
chat_history=chat_history.messages)
39+
"condense_question_prompt.txt",
40+
question=question,
41+
chat_history=chat_history.messages,
42+
)
4143
condensed_question = get_llm().invoke(condense_question_prompt).content
4244
else:
4345
condensed_question = question
4446

45-
current_app.logger.debug('Condensed question: %s', condensed_question)
46-
current_app.logger.debug('Question: %s', question)
47+
current_app.logger.debug("Condensed question: %s", condensed_question)
48+
current_app.logger.debug("Question: %s", question)
4749

4850
docs = store.as_retriever().invoke(condensed_question)
4951
for doc in docs:
50-
doc_source = {**doc.metadata, 'page_content': doc.page_content}
51-
current_app.logger.debug('Retrieved document passage from: %s', doc.metadata['name'])
52-
yield f'data: {SOURCE_TAG} {json.dumps(doc_source)}\n\n'
52+
doc_source = {**doc.metadata, "page_content": doc.page_content}
53+
current_app.logger.debug(
54+
"Retrieved document passage from: %s", doc.metadata["name"]
55+
)
56+
yield f"data: {SOURCE_TAG} {json.dumps(doc_source)}\n\n"
5357

54-
qa_prompt = render_template('rag_prompt.txt', question=question, docs=docs,
55-
chat_history=chat_history.messages)
58+
qa_prompt = render_template(
59+
"rag_prompt.txt",
60+
question=question,
61+
docs=docs,
62+
chat_history=chat_history.messages,
63+
)
5664

57-
answer = ''
65+
answer = ""
5866
for chunk in get_llm().stream(qa_prompt):
59-
yield f'data: {chunk.content}\n\n'
67+
yield f"data: {chunk.content}\n\n"
6068
answer += chunk.content
6169

6270
yield f"data: {DONE_TAG}\n\n"
63-
current_app.logger.debug('Answer: %s', answer)
71+
current_app.logger.debug("Answer: %s", answer)
6472

6573
chat_history.add_user_message(question)
6674
chat_history.add_ai_message(answer)

example-apps/chatbot-rag-app/api/llm_integrations.py

Lines changed: 35 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,37 +5,54 @@
55

66
LLM_TYPE = os.getenv("LLM_TYPE", "openai")
77

8+
89
def init_openai_chat(temperature):
910
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
10-
return ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True, temperature=temperature)
11+
return ChatOpenAI(
12+
openai_api_key=OPENAI_API_KEY, streaming=True, temperature=temperature
13+
)
14+
15+
1116
def init_vertex_chat(temperature):
1217
VERTEX_PROJECT_ID = os.getenv("VERTEX_PROJECT_ID")
1318
VERTEX_REGION = os.getenv("VERTEX_REGION", "us-central1")
1419
vertexai.init(project=VERTEX_PROJECT_ID, location=VERTEX_REGION)
1520
return ChatVertexAI(streaming=True, temperature=temperature)
21+
22+
1623
def init_azure_chat(temperature):
17-
OPENAI_VERSION=os.getenv("OPENAI_VERSION", "2023-05-15")
18-
BASE_URL=os.getenv("OPENAI_BASE_URL")
19-
OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
20-
OPENAI_ENGINE=os.getenv("OPENAI_ENGINE")
24+
OPENAI_VERSION = os.getenv("OPENAI_VERSION", "2023-05-15")
25+
BASE_URL = os.getenv("OPENAI_BASE_URL")
26+
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
27+
OPENAI_ENGINE = os.getenv("OPENAI_ENGINE")
2128
return AzureChatOpenAI(
2229
deployment_name=OPENAI_ENGINE,
2330
openai_api_base=BASE_URL,
2431
openai_api_version=OPENAI_VERSION,
2532
openai_api_key=OPENAI_API_KEY,
2633
streaming=True,
27-
temperature=temperature)
34+
temperature=temperature,
35+
)
36+
37+
2838
def init_bedrock(temperature):
29-
AWS_ACCESS_KEY=os.getenv("AWS_ACCESS_KEY")
30-
AWS_SECRET_KEY=os.getenv("AWS_SECRET_KEY")
31-
AWS_REGION=os.getenv("AWS_REGION")
32-
AWS_MODEL_ID=os.getenv("AWS_MODEL_ID", "anthropic.claude-v2")
33-
BEDROCK_CLIENT=boto3.client(service_name="bedrock-runtime", region_name=AWS_REGION, aws_access_key_id=AWS_ACCESS_KEY, aws_secret_access_key=AWS_SECRET_KEY)
39+
AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
40+
AWS_SECRET_KEY = os.getenv("AWS_SECRET_KEY")
41+
AWS_REGION = os.getenv("AWS_REGION")
42+
AWS_MODEL_ID = os.getenv("AWS_MODEL_ID", "anthropic.claude-v2")
43+
BEDROCK_CLIENT = boto3.client(
44+
service_name="bedrock-runtime",
45+
region_name=AWS_REGION,
46+
aws_access_key_id=AWS_ACCESS_KEY,
47+
aws_secret_access_key=AWS_SECRET_KEY,
48+
)
3449
return BedrockChat(
3550
client=BEDROCK_CLIENT,
3651
model_id=AWS_MODEL_ID,
3752
streaming=True,
38-
model_kwargs={"temperature":temperature})
53+
model_kwargs={"temperature": temperature},
54+
)
55+
3956

4057
MAP_LLM_TYPE_TO_CHAT_MODEL = {
4158
"azure": init_azure_chat,
@@ -44,8 +61,13 @@ def init_bedrock(temperature):
4461
"vertex": init_vertex_chat,
4562
}
4663

64+
4765
def get_llm(temperature=0):
4866
if not LLM_TYPE in MAP_LLM_TYPE_TO_CHAT_MODEL:
49-
raise Exception("LLM type not found. Please set LLM_TYPE to one of: " + ", ".join(MAP_LLM_TYPE_TO_CHAT_MODEL.keys()) + ".")
67+
raise Exception(
68+
"LLM type not found. Please set LLM_TYPE to one of: "
69+
+ ", ".join(MAP_LLM_TYPE_TO_CHAT_MODEL.keys())
70+
+ "."
71+
)
5072

5173
return MAP_LLM_TYPE_TO_CHAT_MODEL[LLM_TYPE](temperature=temperature)

example-apps/chatbot-rag-app/data/index_data.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,16 @@ def main():
6161

6262
print(f"Loading data from ${FILE}")
6363

64-
metadata_keys = ['name', 'summary', 'url', 'category', 'updated_at']
64+
metadata_keys = ["name", "summary", "url", "category", "updated_at"]
6565
workplace_docs = []
66-
with open(FILE, 'rt') as f:
66+
with open(FILE, "rt") as f:
6767
for doc in json.loads(f.read()):
68-
workplace_docs.append(Document(
69-
page_content=doc['content'],
70-
metadata={k: doc.get(k) for k in metadata_keys}
71-
))
68+
workplace_docs.append(
69+
Document(
70+
page_content=doc["content"],
71+
metadata={k: doc.get(k) for k in metadata_keys},
72+
)
73+
)
7274

7375
print(f"Loaded {len(workplace_docs)} documents")
7476

@@ -92,7 +94,7 @@ def main():
9294
index_name=INDEX,
9395
strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(model_id=ELSER_MODEL),
9496
bulk_kwargs={
95-
'request_timeout': 60,
97+
"request_timeout": 60,
9698
},
9799
)
98100

example-apps/internal-knowledge-search/api/app.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@
1010

1111

1212
def get_identities_index(search_app_name):
13-
search_app = elasticsearch_client.search_application.get(
14-
name=search_app_name)
15-
identities_indices = elasticsearch_client.indices.get(
16-
index=".search-acl-filter*")
13+
search_app = elasticsearch_client.search_application.get(name=search_app_name)
14+
identities_indices = elasticsearch_client.indices.get(index=".search-acl-filter*")
1715
secured_index = [
1816
app_index
1917
for app_index in search_app["indices"]
@@ -36,19 +34,22 @@ def api_index():
3634
@app.route("/api/default_settings", methods=["GET"])
3735
def default_settings():
3836
return {
39-
"elasticsearch_endpoint": os.getenv("ELASTICSEARCH_URL") or "http://localhost:9200"
37+
"elasticsearch_endpoint": os.getenv("ELASTICSEARCH_URL")
38+
or "http://localhost:9200"
4039
}
4140

4241

4342
@app.route("/api/search_proxy/<path:text>", methods=["POST"])
4443
def search(text):
4544
response = requests.request(
4645
method="POST",
47-
url=os.getenv("ELASTICSEARCH_URL") + '/' + text,
46+
url=os.getenv("ELASTICSEARCH_URL") + "/" + text,
4847
data=request.get_data(),
4948
allow_redirects=False,
50-
headers={"Authorization": request.headers.get(
51-
"Authorization"), "Content-Type": "application/json"}
49+
headers={
50+
"Authorization": request.headers.get("Authorization"),
51+
"Content-Type": "application/json",
52+
},
5253
)
5354

5455
return response.content
@@ -59,8 +60,7 @@ def personas():
5960
try:
6061
search_app_name = request.args.get("app_name")
6162
identities_index = get_identities_index(search_app_name)
62-
response = elasticsearch_client.search(
63-
index=identities_index, size=1000)
63+
response = elasticsearch_client.search(index=identities_index, size=1000)
6464
hits = response["hits"]["hits"]
6565
personas = [x["_id"] for x in hits]
6666
personas.append("admin")
@@ -77,9 +77,8 @@ def personas():
7777
def indices():
7878
try:
7979
search_app_name = request.args.get("app_name")
80-
search_app = elasticsearch_client.search_application.get(
81-
name=search_app_name)
82-
return search_app['indices']
80+
search_app = elasticsearch_client.search_application.get(name=search_app_name)
81+
return search_app["indices"]
8382

8483
except Exception as e:
8584
current_app.logger.warn(
@@ -118,8 +117,7 @@ def api_key():
118117
if persona == "admin":
119118
role_descriptor = default_role_descriptor
120119
else:
121-
identity = elasticsearch_client.get(
122-
index=identities_index, id=persona)
120+
identity = elasticsearch_client.get(index=identities_index, id=persona)
123121
permissions = identity["_source"]["query"]["template"]["params"][
124122
"access_control"
125123
]
@@ -161,12 +159,14 @@ def api_key():
161159
}
162160
}
163161
api_key = elasticsearch_client.security.create_api_key(
164-
name=search_app_name+"-internal-knowledge-search-example-"+persona, expiration="1h", role_descriptors=role_descriptor)
165-
return {"api_key": api_key['encoded']}
162+
name=search_app_name + "-internal-knowledge-search-example-" + persona,
163+
expiration="1h",
164+
role_descriptors=role_descriptor,
165+
)
166+
return {"api_key": api_key["encoded"]}
166167

167168
except Exception as e:
168-
current_app.logger.warn(
169-
"Encountered error %s while fetching api key", e)
169+
current_app.logger.warn("Encountered error %s while fetching api key", e)
170170
raise e
171171

172172

example-apps/internal-knowledge-search/api/elasticsearch_client.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,7 @@
3131
)
3232
elif ELASTIC_USERNAME and ELASTIC_PASSWORD:
3333
elasticsearch_client = Elasticsearch(
34-
basic_auth=(ELASTIC_USERNAME, ELASTIC_PASSWORD),
35-
cloud_id=ELASTIC_CLOUD_ID
34+
basic_auth=(ELASTIC_USERNAME, ELASTIC_PASSWORD), cloud_id=ELASTIC_CLOUD_ID
3635
)
3736
else:
3837
raise ValueError(

0 commit comments

Comments
 (0)