|
| 1 | +import logging |
| 2 | +import traceback |
| 3 | + |
| 4 | + |
| 5 | +import streamlit as st |
| 6 | + |
| 7 | +from config import INTERNET_COLLECTION_DISPLAY_ID |
| 8 | +from utils import generate_stream, get_collections, get_models, header |
| 9 | + |
| 10 | +API_KEY = header() |
| 11 | + |
| 12 | +# Data |
| 13 | +try: |
| 14 | + language_models, embeddings_models, _, rerank_models = get_models(api_key=API_KEY) |
| 15 | + collections = get_collections(api_key=API_KEY) |
| 16 | +except Exception: |
| 17 | + st.error("Error to fetch user data.") |
| 18 | + logging.error(traceback.format_exc()) |
| 19 | + st.stop() |
| 20 | + |
| 21 | +# State |
| 22 | + |
| 23 | +if "selected_model" not in st.session_state: |
| 24 | + st.session_state["selected_model"] = language_models[0] |
| 25 | + |
| 26 | +if "selected_collections" not in st.session_state: |
| 27 | + st.session_state.selected_collections = [] |
| 28 | + |
| 29 | +if "messages" not in st.session_state: |
| 30 | + st.session_state["messages"] = [] |
| 31 | + st.session_state["sources"] = [] |
| 32 | + |
| 33 | +# Sidebar |
| 34 | +with st.sidebar: |
| 35 | + new_chat = st.button(label="**:material/refresh: New chat**", key="new", use_container_width=True) |
| 36 | + if new_chat: |
| 37 | + st.session_state.pop("messages", None) |
| 38 | + st.session_state.pop("sources", None) |
| 39 | + st.rerun() |
| 40 | + params = {"sampling_params": dict(), "rag": dict()} |
| 41 | + |
| 42 | + st.subheader(body="Chat parameters") |
| 43 | + st.session_state["selected_model"] = st.selectbox( |
| 44 | + label="Language model", options=language_models, index=language_models.index(st.session_state.selected_model) |
| 45 | + ) |
| 46 | + |
| 47 | + params["sampling_params"]["model"] = st.session_state["selected_model"] |
| 48 | + params["sampling_params"]["temperature"] = st.slider(label="Temperature", value=0.2, min_value=0.0, max_value=1.0, step=0.1) |
| 49 | + |
| 50 | + if st.toggle(label="Max tokens", value=False): |
| 51 | + max_tokens = st.number_input(label="Max tokens", value=100, min_value=0, step=100) |
| 52 | + params["sampling_params"]["max_tokens"] = max_tokens |
| 53 | + |
| 54 | + st.subheader(body="RAG parameters") |
| 55 | + params["rag"]["embeddings_model"] = st.selectbox(label="Embeddings model", options=embeddings_models) |
| 56 | + model_collections = [ |
| 57 | + f"{collection["name"]} - {collection["id"]}" for collection in collections if collection["model"] == params["rag"]["embeddings_model"] |
| 58 | + ] + [f"Internet - {INTERNET_COLLECTION_DISPLAY_ID}"] |
| 59 | + |
| 60 | + if model_collections: |
| 61 | + |
| 62 | + @st.dialog("Select collections") |
| 63 | + def add_collection(collections: list) -> None: |
| 64 | + selected_collections = st.session_state.selected_collections |
| 65 | + col1, col2 = st.columns(spec=2) |
| 66 | + |
| 67 | + for collection in collections: |
| 68 | + collection_id = collection.split(" - ")[1] |
| 69 | + if st.checkbox( |
| 70 | + label=f"{collection.split(" - ")[0]} (*{collection_id[:8]}*)", |
| 71 | + value=False if collection_id not in st.session_state.selected_collections else True, |
| 72 | + ): |
| 73 | + selected_collections.append(collection_id) |
| 74 | + elif collection_id in selected_collections: |
| 75 | + selected_collections.remove(collection_id) |
| 76 | + |
| 77 | + with col1: |
| 78 | + if st.button(label="**Submit :material/check_circle:**", use_container_width=True): |
| 79 | + st.session_state.selected_collections = list(set(selected_collections)) |
| 80 | + st.rerun() |
| 81 | + with col2: |
| 82 | + if st.button(label="**Clear :material/close:**", use_container_width=True): |
| 83 | + st.session_state.selected_collections = [] |
| 84 | + st.rerun() |
| 85 | + |
| 86 | + option_map = {0: f"{len(set(st.session_state.selected_collections))} selected"} |
| 87 | + pill = st.pills( |
| 88 | + label="Collections", |
| 89 | + options=option_map.keys(), |
| 90 | + format_func=lambda option: option_map[option], |
| 91 | + selection_mode="single", |
| 92 | + default=None, |
| 93 | + key="add_collections", |
| 94 | + ) |
| 95 | + if pill == 0: |
| 96 | + add_collection(collections=model_collections) |
| 97 | + |
| 98 | + params["rag"]["collections"] = st.session_state.selected_collections |
| 99 | + params["rag"]["k"] = st.number_input(label="Number of chunks to retrieve (k)", value=3) |
| 100 | + |
| 101 | + if st.session_state.selected_collections: |
| 102 | + rag = st.toggle(label="Activated RAG", value=True, disabled=not bool(params["rag"]["collections"])) |
| 103 | + else: |
| 104 | + rag = st.toggle(label="Activated RAG", value=False, disabled=True, help="You need to select at least one collection to activate RAG.") |
| 105 | + |
| 106 | + if st.session_state.selected_collections and rag: |
| 107 | + rerank = st.toggle( |
| 108 | + label="Add rerank", |
| 109 | + value=False, |
| 110 | + disabled=not bool(params["rag"]["collections"]), |
| 111 | + help="When activated, that retrieve the double number of chunks (k*2) and keep the best k chunks after reranking.", |
| 112 | + ) |
| 113 | + if rerank: |
| 114 | + params["rag"]["rerank_model"] = st.selectbox(label="Rerank model", options=rerank_models) |
| 115 | + else: |
| 116 | + rerank = st.toggle( |
| 117 | + label="Add rerank", value=False, disabled=True, help="You need to select at least one collection to activate rerank and activate RAG." |
| 118 | + ) |
| 119 | + |
| 120 | +# Main |
| 121 | +with st.chat_message(name="assistant"): |
| 122 | + st.markdown( |
| 123 | + body="""Bonjour je suis Albert, et je peux vous aider si vous avez des questions administratives ! |
| 124 | +
|
| 125 | +Je peux me connecter à vos bases de connaissances, pour ça sélectionnez les collections voulues dans le menu de gauche. Je peux également chercher sur les sites officiels de l'État, pour ça sélectionnez la collection "Internet" à gauche. Si vous ne souhaitez pas utiliser de collection, désactivez le RAG en décochant la fonction "Activated RAG". |
| 126 | + |
| 127 | +Comment puis-je vous aider ? |
| 128 | +""" |
| 129 | + ) |
| 130 | + |
| 131 | +for i, message in enumerate(st.session_state.messages): |
| 132 | + with st.chat_message(message["role"], avatar=":material/face:" if message["role"] == "user" else None): |
| 133 | + st.markdown(message["content"]) |
| 134 | + if st.session_state.sources[i]: |
| 135 | + st.pills(label="Sources", options=st.session_state.sources[i], label_visibility="hidden") |
| 136 | + |
| 137 | +sources = [] |
| 138 | +if prompt := st.chat_input(placeholder="Message to Albert"): |
| 139 | + # send message to the model |
| 140 | + user_message = {"role": "user", "content": prompt} |
| 141 | + st.session_state.messages.append(user_message) |
| 142 | + st.session_state.sources.append([]) |
| 143 | + with st.chat_message(name="user", avatar=":material/face:"): |
| 144 | + st.markdown(body=prompt) |
| 145 | + |
| 146 | + with st.chat_message(name="assistant"): |
| 147 | + try: |
| 148 | + stream, sources = generate_stream( |
| 149 | + messages=st.session_state.messages, |
| 150 | + params=params, |
| 151 | + api_key=API_KEY, |
| 152 | + rag=rag, |
| 153 | + rerank=rerank, |
| 154 | + ) |
| 155 | + response = st.write_stream(stream=stream) |
| 156 | + except Exception: |
| 157 | + st.error(body="Error to generate response.") |
| 158 | + logging.error(traceback.format_exc()) |
| 159 | + st.stop() |
| 160 | + |
| 161 | + formatted_sources = [] |
| 162 | + if sources: |
| 163 | + for source in sources: |
| 164 | + formatted_source = source[:15] + "..." if len(source) > 15 else source |
| 165 | + if source.lower().startswith("http"): |
| 166 | + formatted_sources.append(f":material/globe: [{formatted_source}]({source})") |
| 167 | + else: |
| 168 | + formatted_sources.append(f":material/import_contacts: {formatted_source}") |
| 169 | + st.pills(label="Sources", options=formatted_sources, label_visibility="hidden") |
| 170 | + |
| 171 | + assistant_message = {"role": "assistant", "content": response} |
| 172 | + st.session_state.messages.append(assistant_message) |
| 173 | + st.session_state.sources.append(formatted_sources) |
| 174 | + |
| 175 | +with st._bottom: |
| 176 | + st.caption( |
| 177 | + body='<p style="text-align: center;"><i>I can make mistakes, please always verify my sources and answers.</i></p>', |
| 178 | + unsafe_allow_html=True, |
| 179 | + ) |
0 commit comments