Commit 0798b93

Fix for Issue 16, to perform end-to-end testing (#23)
* Fix for Issue 16, to perform end-to-end testing
* tests: somehow chat_model unit tests gone missing
* refactor: move st command into main
* refactor: set up relative path object for mhg spacy model
* tests: make sure streamlit test is run
* tests: move streamlit command from script to module
* tests: increase timeout for streamlit app
* refactor: download spacy modern model on the fly if not found
* ci: try to checkout mhg german model repo in ci
* doc: update readme for automated spacy model install
* docs and bug: update mhg spacy model location to check for a path, remove obsolete streamlit exp rerun command
* ci: use env variable for mhg repo
* bug: correct str and path concatenation
* add simple test for image_search
* add simple test for input_output
* add simple test for text_tagging
* update test_chat_models
* try to mock avoidance of the mhg model
* add simple test for app.py

---------

Co-authored-by: Inga Ulusoy <inga.ulusoy@uni-heidelberg.de>
1 parent 6f18bf0 commit 0798b93

11 files changed

Lines changed: 637 additions & 85 deletions

.github/workflows/ci.yml

Lines changed: 9 additions & 1 deletion
@@ -26,7 +26,13 @@ jobs:
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
-
+        with:
+          path: .
+      - name: Checkout the mhg spacy model repo
+        uses: actions/checkout@v4
+        with:
+          repository: Middle-High-German-Conceptual-Database/Spacy-Model-for-Middle-High-German
+          path: mhg
       - name: Set up Python ${{ matrix.python }}
         uses: actions/setup-python@v5
         with:
@@ -37,6 +43,8 @@ jobs:
           python -m pip install -r requirements-dev.txt
           cd parzivai
           python -m pytest -svv --cov=. --cov-branch --cov-report=xml
+        env:
+          SPACY_MHG_MODEL_PATH: ${{ github.workspace }}/mhg
       - name: Upload coverage reports to Codecov
         uses: codecov/codecov-action@v5
         with:

README.md

Lines changed: 8 additions & 4 deletions
@@ -27,13 +27,17 @@ parzivAI makes use from [spaCy](https://spacy.io/) under the hood. Download the
 ```bash
 python -m spacy download de_core_news_sm
 ```
-(*TODO: Download models on the fly if not found through the spacy cli*)
+Note that the model is downloaded on the fly if not found through the spacy cli.
 
-For Middle High German, a specially trained model must be loaded, and its path needs to be integrated into the code. The model can be found [here](https://github.com/Middle-High-German-Conceptual-Database/Spacy-Model-for-Middle-High-German). Git clone the repository and place it in the same folder as the parzivAI repo:
+For Middle High German, a specially trained model must be loaded, and its path needs to be integrated into the code. The model can be found [here](https://github.com/Middle-High-German-Conceptual-Database/Spacy-Model-for-Middle-High-German). Git clone the repository and either set an environment variable with the model path as
 ```
-you-folder/
+export SPACY_MHG_MODEL_PATH=/path/to/Spacy-Model-for-Middle-High-German-repo
+```
+or place it in the same folder as the parzivAI repo:
+```
+your-folder/
 
-├── parzivai # parzivai
+├── parzivai # parzivai repo
 ├── Spacy-Model-for-Middle-High-German # spaCy model
 ```
 (*TODO: Make sure this is platform-agnostic and can also be done on-the-fly*)
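
The README change above describes two ways of locating the Middle High German model: the `SPACY_MHG_MODEL_PATH` environment variable, or a clone placed next to the parzivAI repo. The lookup order could be sketched roughly as follows. The helper name `resolve_mhg_model_path` and its `repo_root` parameter are hypothetical; the commit's own lookup code is not part of this excerpt and may differ.

```python
import os
from pathlib import Path

# Folder name taken from the README layout; used only as the fallback location.
DEFAULT_MODEL_DIRNAME = "Spacy-Model-for-Middle-High-German"


def resolve_mhg_model_path(repo_root: Path) -> Path:
    """Return the MHG model folder (hypothetical helper, not from the commit).

    The environment variable takes precedence; otherwise the folder is
    expected to sit next to the parzivAI repo.
    """
    env_value = os.environ.get("SPACY_MHG_MODEL_PATH")
    if env_value:
        candidate = Path(env_value)
    else:
        candidate = repo_root.parent / DEFAULT_MODEL_DIRNAME
    if not candidate.is_dir():
        raise FileNotFoundError(f"MHG spaCy model folder not found: {candidate}")
    return candidate
```

Building the fallback with `pathlib` rather than string concatenation avoids the kind of str/path mix-up mentioned in the commit message and keeps the lookup platform-agnostic.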

parzivai/app.py

Lines changed: 44 additions & 19 deletions
@@ -11,13 +11,13 @@
 from langchain.schema import Document
 from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_core.messages import HumanMessage, AIMessage
-
-st.set_page_config(page_title="ParzivAI")
 from parzivai.input_output import get_vectorstore, load_embeddings_model
-from parzivai.image_search import display_images
+from parzivai.image_search import fetch_images_for_topic
 from parzivai.text_tagging import (
     check_attributes,
     POS_DESCRIPTIONS,
+    load_modern_model,
+    load_mhg_model,
     pos_tagging_mhg,
     pos_tagging_modern,
 )
@@ -30,10 +30,26 @@
     SIMPLE_INQUIRIES,
 )
 
+# Page configuration (must be first Streamlit command)
+st.set_page_config(page_title="ParzivAI")
+
 # avoid some torch incompatibility issues with newer Python versions
 # see https://github.com/SaiAkhil066/DeepSeek-RAG-Chatbot/issues/4
 torch.classes.__path__ = []
 
+
+# Add cache approach of getting models here, to make it easier for unit-tests
+@st.cache_resource
+def get_cached_retriever():
+    embedding_model = load_embeddings_model()
+    return get_vectorstore(embedding_model)
+
+
+@st.cache_resource
+def get_models():
+    return load_modern_model(), load_mhg_model()
+
+
 # Set API keys
 load_dotenv()  # TODO create a .env file in the root directory with TAVILY_API_KEY and delete initialization of TAVILY_API_KEY below
 if not os.getenv("TAVILY_API_KEY"):
@@ -47,7 +63,6 @@
 PKG = resources.files("parzivai")
 FILE_PATH = PKG / "data"
 AVATAR_IMAGE = str(FILE_PATH / "parzival.png")
-retriever = get_vectorstore()
 llm = instantiate_llm()
 EMOJI_MAP = {
     "Vectorstore": "📚",
@@ -79,7 +94,7 @@ def append_to_rendered_messages(role, content):
 
 
 @st.cache_data(ttl=3600)
-def retrieve(question) -> dict:
+def retrieve(question, retriever) -> dict:
     documents = retriever.invoke(question)
     return {"documents": documents, "question": question}
 
@@ -150,8 +165,8 @@ def web_search(question):
     }
 
 
-def decide_route(question):
-    documents = retrieve(question)["documents"]
+def decide_route(question, retriever):
+    documents = retrieve(question, retriever)["documents"]
     print("Documents retrieved from Vectorstore:")
     for doc in documents:
         print(doc if isinstance(doc, str) else doc.page_content)
@@ -208,7 +223,7 @@ def save_chat_history_and_messages(role: str, message: str):
     save_chat_to_history(role, message)
 
 
-def process_user_input(user_input):
+def process_user_input(user_input, retriever):
     save_chat_history_and_messages("User", user_input)
     st.session_state.state["question"] = user_input
     st.session_state.state["messages"] = st.session_state.messages
@@ -222,7 +237,7 @@ def process_user_input(user_input):
     elif contains_any(user_input, SIMPLE_INQUIRIES["simple_inquiries"]):
         handle_direct_response(user_input)
     else:
-        handle_routing_and_answer(user_input)
+        handle_routing_and_answer(user_input, retriever)
 
 
 def is_translation_request(text: str) -> bool:
@@ -266,8 +281,8 @@ def handle_direct_response(user_input: str):
     save_chat_history_and_messages("Assistant", response.content)
 
 
-def handle_routing_and_answer(user_input: str):
-    routing_info = decide_route(user_input)
+def handle_routing_and_answer(user_input: str, retriever):
+    routing_info = decide_route(user_input, retriever)
     st.session_state.state.update(routing_info)
 
     if routing_info["route_taken"] == "Vectorstore":
@@ -324,26 +339,25 @@ def build_final_response_message(route: str, result: dict) -> str:
     return message
 
 
-def show_pos_tagging_options(latest_response: str):
+def show_pos_tagging_options(latest_response: str, nlp_modern, nlp_mhg):
     st.markdown("### POS-Tagging Options")
     col1, col2 = st.columns(2)
     with col1:
         if st.button("POS-Tagging (Modernes Deutsch)"):
-            doc = pos_tagging_modern(latest_response)
+            doc = pos_tagging_modern(nlp_modern, latest_response)
             if doc:
                 st.session_state.linguistic_analysis = ("Modernes Deutsch", doc)
                 st.rerun()
     with col2:
         if st.button("POS-Tagging (Mittelhochdeutsch)"):
-            doc = pos_tagging_mhg(latest_response)
+            doc = pos_tagging_mhg(nlp_mhg, latest_response)
             if doc:
                 st.session_state.linguistic_analysis = ("Mittelhochdeutsch", doc)
-                st.experimental_update()
+                st.rerun()
 
 
 def main():
     # Main function to run the Streamlit app
-    # Page configuration (must be first Streamlit command)
     tab1, tab2, tab3, tab4 = st.tabs(
         [
             "ParzivAI Chatbot",
@@ -372,10 +386,12 @@ def main():
         st.sidebar.image(AVATAR_IMAGE, width=150)
         # function to initialize all session state variables
         initialize_session_state()
+        retriever = get_cached_retriever()
+        nlp_modern, nlp_mhg = get_models()
 
         user_input = st.chat_input("Ask ParzivAI a question:")
         if user_input:
-            process_user_input(user_input)
+            process_user_input(user_input, retriever)
 
         with st.sidebar.expander("Cached Data"):
             st.write("Embeddings:")
@@ -413,7 +429,7 @@ def main():
             None,
         )
         if assistant_response:
-            show_pos_tagging_options(assistant_response)
+            show_pos_tagging_options(assistant_response, nlp_modern, nlp_mhg)
 
         # Feedback collection
         feedback = streamlit_feedback(
@@ -479,7 +495,16 @@ def main():
 
         if "image_search_result" in st.session_state:
             st.write("Searching for images...")
-            asyncio.run(display_images(st.session_state.image_search_result))
+            image_data = asyncio.run(
+                fetch_images_for_topic(st.session_state.image_search_result)
+            )
+
+            for data in image_data:
+                st.image(
+                    data["url"],
+                    caption=f"Bildthema: {data['name']}, Archivnummer: {data['archiveNumber']}, URL: {data['url']}",
+                    use_container_width=True,
+                )
 
     with tab3:
         st.header("Linguistische Analyse")
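
The `app.py` changes above replace the module-level `retriever` with an argument that is passed down through `process_user_input`, `handle_routing_and_answer`, `decide_route` and `retrieve`. One likely benefit is that a unit test can hand in a stand-in object instead of building a FAISS index. A minimal sketch of that idea follows; `FakeRetriever` is a hypothetical test double, and `retrieve` is reproduced here without the Streamlit cache decorator so the example runs on its own.

```python
class FakeRetriever:
    """Hypothetical stand-in for the real retriever; only `invoke` is mimicked."""

    def __init__(self, documents):
        self.documents = documents
        self.questions = []  # records every question it was asked

    def invoke(self, question):
        self.questions.append(question)
        return self.documents


def retrieve(question, retriever) -> dict:
    # Same body as in the diff, minus @st.cache_data.
    documents = retriever.invoke(question)
    return {"documents": documents, "question": question}


fake = FakeRetriever(["Parzival is a romance by Wolfram von Eschenbach."])
result = retrieve("Who wrote Parzival?", fake)
```

Because the retriever is injected, the test needs neither the embedding model nor the vector store on disk.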

parzivai/image_search.py

Lines changed: 8 additions & 14 deletions
@@ -1,4 +1,3 @@
-import streamlit as st
 from urllib.parse import quote
 import json
 from playwright.async_api import async_playwright
@@ -16,11 +15,12 @@
         "image_search_url"
     )  # it would be defined in the config file
 except FileNotFoundError:
-    st.error(f"Configuration file not found at {CONFIG_PATH}. Please ensure it exists.")
-    raise
+    raise RuntimeError(
+        f"Configuration file not found at {CONFIG_PATH}. Please ensure it exists."
+    )
+
 except json.JSONDecodeError as e:
-    st.error(f"Error decoding configuration file: {e}")
-    raise
+    raise RuntimeError(f"Error decoding configuration file: {e}") from e
 
 
 def adjust_image_url(url: str) -> str:
@@ -80,12 +80,6 @@ async def fetch_images(topic: str):
     return image_data
 
 
-async def display_images(topic: str):
-    """Display fetched images in Streamlit."""
-    image_data = await fetch_images(topic)
-    for data in image_data:
-        st.image(
-            data["url"],
-            caption=f"Bildthema: {data['name']}, Archivnummer: {data['archiveNumber']}, URL: {data['url']}",
-            use_container_width=True,
-        )
+async def fetch_images_for_topic(topic: str) -> list[dict]:
+    """Return image metadata for a given topic."""
+    return await fetch_images(topic)
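
The `image_search.py` changes above drop the Streamlit dependency from config loading and raise `RuntimeError` instead of calling `st.error`. The sketch below illustrates that pattern in isolation, using `raise ... from e` so the original exception stays attached as `__cause__`. The function name `load_image_search_url` and its `config_path` parameter are hypothetical; in the commit the equivalent logic runs at module level.

```python
import json
from pathlib import Path


def load_image_search_url(config_path: Path) -> str:
    """Read "image_search_url" from a JSON config file (hypothetical helper)."""
    try:
        with open(config_path, "r", encoding="utf-8") as handle:
            config = json.load(handle)
    except FileNotFoundError as e:
        raise RuntimeError(
            f"Configuration file not found at {config_path}. Please ensure it exists."
        ) from e
    except json.JSONDecodeError as e:
        raise RuntimeError(f"Error decoding configuration file: {e}") from e
    return config["image_search_url"]
```

A caller such as the Streamlit app can then catch `RuntimeError` and decide how to display it, while tests can assert on the exception without importing Streamlit.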

parzivai/input_output.py

Lines changed: 12 additions & 16 deletions
@@ -1,7 +1,7 @@
 import os
 import json
+import warnings
 from importlib import resources
-import streamlit as st
 from langchain_huggingface import HuggingFaceEmbeddings
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.document_loaders import (
@@ -27,14 +27,11 @@ def load_config(file):
         with open(FILE_PATH / file, "r") as file:
             return json.load(file)
     except FileNotFoundError:
-        st.error(f"Configuration file not found: {FILE_PATH / file}")
-        return {}
+        raise RuntimeError(f"Configuration file not found: {FILE_PATH / file}")
     except json.JSONDecodeError as e:
-        st.error(f"Error decoding configuration file: {e}")
-        return {}
+        raise RuntimeError(f"Error decoding configuration file: {e}")
 
 
-@st.cache_resource
 def load_embeddings_model():
     model_name_hf = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
     model_kwargs_hf = {"device": "cpu"}
@@ -46,7 +43,7 @@ def load_embeddings_model():
     )
 
 
-def load_documents_and_create_vectorstore():
+def load_documents_and_create_vectorstore(embedding_model):
     """Load documents from URLs and static files to create FAISS vector store."""
     # Load URLs
     urls_data = load_config(file="urls.json")
@@ -74,8 +71,9 @@ def load_documents_and_create_vectorstore():
             else:
                 continue
             static_docs.extend(loader.load())
-        except Exception as e:
-            print(f"Error loading file {file_name}: {e}")
+        except (IOError, ValueError) as e:
+            warnings.warn(f"Problem loading '{file_name}': {e}", UserWarning)
+            continue
 
     # Combine and process documents
     all_docs = web_docs + static_docs
@@ -86,28 +84,26 @@ def load_documents_and_create_vectorstore():
     print("Documents loaded and split successfully.")
 
     # Create and save FAISS vector store
-    vectorstore = FAISS.from_documents(doc_splits, load_embeddings_model())
+    vectorstore = FAISS.from_documents(doc_splits, embedding_model)
     vectorstore.save_local(persist_folder)
     print(f"FAISS index initialized and saved successfully in {persist_folder}.")
     return vectorstore
 
 
-def get_vectorstore():
+def get_vectorstore(embedding_model):
     vectorstore_exists = os.path.exists(index_path)
     if vectorstore_exists:
         try:
             vectorstore = FAISS.load_local(
                 persist_folder,
-                load_embeddings_model(),
+                embedding_model,
                 allow_dangerous_deserialization=True,
             )
             print(f"FAISS index loaded successfully from {persist_folder}.")
         except Exception as e:
-            print(f"Error loading existing FAISS index: {e}")
-            st.error(f"Error loading existing FAISS index: {e}")
-            raise e
+            raise RuntimeError(f"Error loading existing FAISS index: {e}") from e
     else:
-        vectorstore = load_documents_and_create_vectorstore()
+        vectorstore = load_documents_and_create_vectorstore(embedding_model)
 
     retriever = vectorstore.as_retriever()
     return retriever
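
The `input_output.py` changes above narrow the broad `except Exception` to `(IOError, ValueError)` and report failures through `warnings.warn` instead of `print`. The sketch below shows the same pattern on a simplified loop and how a test might capture the warning. The helper `load_static_docs` and its `loaders` mapping are hypothetical and only stand in for the document loaders used in the commit.

```python
import warnings


def load_static_docs(loaders: dict) -> list:
    """Collect documents from several loaders (hypothetical helper).

    `loaders` maps a file name to a zero-argument callable returning a list
    of documents. A failing loader emits a UserWarning and is skipped.
    """
    static_docs = []
    for file_name, load in loaders.items():
        try:
            static_docs.extend(load())
        except (IOError, ValueError) as e:
            warnings.warn(f"Problem loading '{file_name}': {e}", UserWarning)
            continue
    return static_docs


def _ok():
    return ["doc"]


def _bad():
    raise ValueError("unsupported format")


# Record warnings instead of printing them, as a test would do.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    docs = load_static_docs({"a.txt": _ok, "b.bin": _bad})
```

Unlike `print`, a warning can be asserted on in tests or escalated to an error with a warnings filter, and any exception type outside the tuple still propagates.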
