Skip to content

Commit ee37cf1

Browse files
leoguillaumecamilleANDleoguillaume
authored
feat(ui): refacto playground UI (#115)
* feat(playground): change ui + cache optimization * feat(playground): update pyproject.toml * feat(playground): update style for dsfr * feat(playground): update style for dsfr 2 * feat: cleaning and add reranker --------- Co-authored-by: camilleAND <camille.andre@modernisation.gouv.fr> Co-authored-by: leoguillaume <leo.guillaume@modernisation.gouv.fr>
1 parent 741b0cd commit ee37cf1

9 files changed

Lines changed: 419 additions & 296 deletions

File tree

compose.yml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,5 @@
11
services:
22
fastapi:
3-
build:
4-
context: .
5-
dockerfile: ./app/Dockerfile
6-
73
image: ghcr.io/etalab-ia/albert-api/app:latest
84
command: uvicorn app.main:app --host 0.0.0.0 --port 8000
95
environment:
@@ -21,7 +17,7 @@ services:
2117

2218
streamlit:
2319
image: ghcr.io/etalab-ia/albert-api/ui:latest
24-
command: streamlit run /ui/chat.py --server.port=8501 --browser.gatherUsageStats false --theme.base light --server.maxUploadSize=20
20+
command: streamlit run ui/main.py --server.port=8501 --browser.gatherUsageStats false --theme.base=light --theme.primaryColor=#6a6af4 --server.maxUploadSize=20
2521
restart: always
2622
environment:
2723
- BASE_URL=http://fastapi:8000/v1

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ dependencies = [
1111

1212
[project.optional-dependencies]
1313
ui = [
14-
"streamlit==1.39.0",
14+
"streamlit==1.40.2",
1515
"streamlit-extras==0.5.0",
1616
]
1717
app = [

ui/chat.py

Lines changed: 0 additions & 119 deletions
This file was deleted.

ui/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
EMBEDDINGS_MODEL_TYPE = "text-embeddings-inference"
55
LANGUAGE_MODEL_TYPE = "text-generation"
66
AUDIO_MODEL_TYPE = "automatic-speech-recognition"
7+
RERANK_MODEL_TYPE = "text-classification"
78
INTERNET_COLLECTION_DISPLAY_ID = "internet"
89
PRIVATE_COLLECTION_TYPE = "private"
910
SUPPORTED_LANGUAGES = [

ui/main.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import streamlit as st
2+
3+
from config import BASE_URL
4+
5+
st.set_page_config(
6+
page_title="Albert playground",
7+
page_icon="https://www.systeme-de-design.gouv.fr/uploads/apple_touch_icon_8ffa1fa80c.png",
8+
layout="wide",
9+
initial_sidebar_state="expanded",
10+
menu_items={
11+
"Get Help": "mailto:etalab@modernisation.gouv.fr",
12+
"Report a bug": "https://github.com/etalab-ia/albert-api/issues",
13+
"About": "https://github.com/etalab-ia/albert-api",
14+
},
15+
)
16+
17+
st.logo(
18+
image="https://upload.wikimedia.org/wikipedia/fr/thumb/5/50/Bloc_Marianne.svg/1200px-Bloc_Marianne.svg.png",
19+
link=BASE_URL.replace("/v1", "/playground"),
20+
size="large",
21+
)
22+
23+
pg = st.navigation(
24+
pages=[
25+
st.Page(page="pages/chat.py", title="Chat", icon=":material/chat:"),
26+
st.Page(page="pages/documents.py", title="Documents", icon=":material/file_copy:"),
27+
st.Page(page="pages/transcription.py", title="Transcription", icon=":material/graphic_eq:"),
28+
]
29+
)
30+
pg.run()

ui/pages/chat.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import logging
2+
import traceback
3+
4+
5+
import streamlit as st
6+
7+
from config import INTERNET_COLLECTION_DISPLAY_ID
8+
from utils import generate_stream, get_collections, get_models, header
9+
10+
API_KEY = header()
11+
12+
# Data
13+
try:
14+
language_models, embeddings_models, _, rerank_models = get_models(api_key=API_KEY)
15+
collections = get_collections(api_key=API_KEY)
16+
except Exception:
17+
st.error("Error to fetch user data.")
18+
logging.error(traceback.format_exc())
19+
st.stop()
20+
21+
# State
22+
23+
if "selected_model" not in st.session_state:
24+
st.session_state["selected_model"] = language_models[0]
25+
26+
if "selected_collections" not in st.session_state:
27+
st.session_state.selected_collections = []
28+
29+
if "messages" not in st.session_state:
30+
st.session_state["messages"] = []
31+
st.session_state["sources"] = []
32+
33+
# Sidebar
34+
with st.sidebar:
35+
new_chat = st.button(label="**:material/refresh: New chat**", key="new", use_container_width=True)
36+
if new_chat:
37+
st.session_state.pop("messages", None)
38+
st.session_state.pop("sources", None)
39+
st.rerun()
40+
params = {"sampling_params": dict(), "rag": dict()}
41+
42+
st.subheader(body="Chat parameters")
43+
st.session_state["selected_model"] = st.selectbox(
44+
label="Language model", options=language_models, index=language_models.index(st.session_state.selected_model)
45+
)
46+
47+
params["sampling_params"]["model"] = st.session_state["selected_model"]
48+
params["sampling_params"]["temperature"] = st.slider(label="Temperature", value=0.2, min_value=0.0, max_value=1.0, step=0.1)
49+
50+
if st.toggle(label="Max tokens", value=False):
51+
max_tokens = st.number_input(label="Max tokens", value=100, min_value=0, step=100)
52+
params["sampling_params"]["max_tokens"] = max_tokens
53+
54+
st.subheader(body="RAG parameters")
55+
params["rag"]["embeddings_model"] = st.selectbox(label="Embeddings model", options=embeddings_models)
56+
model_collections = [
57+
f"{collection["name"]} - {collection["id"]}" for collection in collections if collection["model"] == params["rag"]["embeddings_model"]
58+
] + [f"Internet - {INTERNET_COLLECTION_DISPLAY_ID}"]
59+
60+
if model_collections:
61+
62+
@st.dialog("Select collections")
63+
def add_collection(collections: list) -> None:
64+
selected_collections = st.session_state.selected_collections
65+
col1, col2 = st.columns(spec=2)
66+
67+
for collection in collections:
68+
collection_id = collection.split(" - ")[1]
69+
if st.checkbox(
70+
label=f"{collection.split(" - ")[0]} (*{collection_id[:8]}*)",
71+
value=False if collection_id not in st.session_state.selected_collections else True,
72+
):
73+
selected_collections.append(collection_id)
74+
elif collection_id in selected_collections:
75+
selected_collections.remove(collection_id)
76+
77+
with col1:
78+
if st.button(label="**Submit :material/check_circle:**", use_container_width=True):
79+
st.session_state.selected_collections = list(set(selected_collections))
80+
st.rerun()
81+
with col2:
82+
if st.button(label="**Clear :material/close:**", use_container_width=True):
83+
st.session_state.selected_collections = []
84+
st.rerun()
85+
86+
option_map = {0: f"{len(set(st.session_state.selected_collections))} selected"}
87+
pill = st.pills(
88+
label="Collections",
89+
options=option_map.keys(),
90+
format_func=lambda option: option_map[option],
91+
selection_mode="single",
92+
default=None,
93+
key="add_collections",
94+
)
95+
if pill == 0:
96+
add_collection(collections=model_collections)
97+
98+
params["rag"]["collections"] = st.session_state.selected_collections
99+
params["rag"]["k"] = st.number_input(label="Number of chunks to retrieve (k)", value=3)
100+
101+
if st.session_state.selected_collections:
102+
rag = st.toggle(label="Activated RAG", value=True, disabled=not bool(params["rag"]["collections"]))
103+
else:
104+
rag = st.toggle(label="Activated RAG", value=False, disabled=True, help="You need to select at least one collection to activate RAG.")
105+
106+
if st.session_state.selected_collections and rag:
107+
rerank = st.toggle(
108+
label="Add rerank",
109+
value=False,
110+
disabled=not bool(params["rag"]["collections"]),
111+
help="When activated, that retrieve the double number of chunks (k*2) and keep the best k chunks after reranking.",
112+
)
113+
if rerank:
114+
params["rag"]["rerank_model"] = st.selectbox(label="Rerank model", options=rerank_models)
115+
else:
116+
rerank = st.toggle(
117+
label="Add rerank", value=False, disabled=True, help="You need to select at least one collection to activate rerank and activate RAG."
118+
)
119+
120+
# Main
121+
with st.chat_message(name="assistant"):
122+
st.markdown(
123+
body="""Bonjour je suis Albert, et je peux vous aider si vous avez des questions administratives !
124+
125+
Je peux me connecter à vos bases de connaissances, pour ça sélectionnez les collections voulues dans le menu de gauche. Je peux également chercher sur les sites officiels de l'État, pour ça sélectionnez la collection "Internet" à gauche. Si vous ne souhaitez pas utiliser de collection, désactivez le RAG en décochant la fonction "Activated RAG".
126+
127+
Comment puis-je vous aider ?
128+
"""
129+
)
130+
131+
for i, message in enumerate(st.session_state.messages):
132+
with st.chat_message(message["role"], avatar=":material/face:" if message["role"] == "user" else None):
133+
st.markdown(message["content"])
134+
if st.session_state.sources[i]:
135+
st.pills(label="Sources", options=st.session_state.sources[i], label_visibility="hidden")
136+
137+
sources = []
138+
if prompt := st.chat_input(placeholder="Message to Albert"):
139+
# send message to the model
140+
user_message = {"role": "user", "content": prompt}
141+
st.session_state.messages.append(user_message)
142+
st.session_state.sources.append([])
143+
with st.chat_message(name="user", avatar=":material/face:"):
144+
st.markdown(body=prompt)
145+
146+
with st.chat_message(name="assistant"):
147+
try:
148+
stream, sources = generate_stream(
149+
messages=st.session_state.messages,
150+
params=params,
151+
api_key=API_KEY,
152+
rag=rag,
153+
rerank=rerank,
154+
)
155+
response = st.write_stream(stream=stream)
156+
except Exception:
157+
st.error(body="Error to generate response.")
158+
logging.error(traceback.format_exc())
159+
st.stop()
160+
161+
formatted_sources = []
162+
if sources:
163+
for source in sources:
164+
formatted_source = source[:15] + "..." if len(source) > 15 else source
165+
if source.lower().startswith("http"):
166+
formatted_sources.append(f":material/globe: [{formatted_source}]({source})")
167+
else:
168+
formatted_sources.append(f":material/import_contacts: {formatted_source}")
169+
st.pills(label="Sources", options=formatted_sources, label_visibility="hidden")
170+
171+
assistant_message = {"role": "assistant", "content": response}
172+
st.session_state.messages.append(assistant_message)
173+
st.session_state.sources.append(formatted_sources)
174+
175+
with st._bottom:
176+
st.caption(
177+
body='<p style="text-align: center;"><i>I can make mistakes, please always verify my sources and answers.</i></p>',
178+
unsafe_allow_html=True,
179+
)

0 commit comments

Comments
 (0)