Skip to content

Commit 31e4eee

Browse files
committed
Add loading state
1 parent 3b269d3 commit 31e4eee

File tree

2 files changed

+24
-21
lines changed

2 files changed

+24
-21
lines changed

main.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,10 @@ def execute_sql(query, conn, retries=2):
135135
and st.session_state["messages"][-1]["role"] != "assistant"
136136
):
137137
user_input_content = st.session_state["messages"][-1]["content"]
138-
# print(f"User input content is: {user_input_content}")
139138

140139
if isinstance(user_input_content, str):
140+
callback_handler.start_loading_message()
141+
141142
result = chain.invoke(
142143
{
143144
"question": user_input_content,

utils/snowchat_ui.py

+22-20
Original file line numberDiff line numberDiff line change
@@ -84,18 +84,38 @@ def message_func(text, is_user=False, is_df=False):
8484

8585
class StreamlitUICallbackHandler(BaseCallbackHandler):
8686
def __init__(self):
87-
# Buffer to accumulate tokens
8887
self.token_buffer = []
8988
self.placeholder = st.empty()
9089
self.has_streaming_ended = False
90+
self.has_streaming_started = False
91+
92+
def start_loading_message(self):
93+
loading_message_content = self._get_bot_message_container("Thinking...")
94+
self.placeholder.markdown(loading_message_content, unsafe_allow_html=True)
95+
96+
def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
97+
if not self.has_streaming_started:
98+
self.has_streaming_started = True
99+
100+
self.token_buffer.append(token)
101+
complete_message = "".join(self.token_buffer)
102+
container_content = self._get_bot_message_container(complete_message)
103+
self.placeholder.markdown(container_content, unsafe_allow_html=True)
104+
105+
def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):
106+
self.token_buffer = []
107+
self.has_streaming_ended = True
108+
self.has_streaming_started = False
91109

92110
def _get_bot_message_container(self, text):
93111
"""Generate the bot's message container style for the given text."""
94112
avatar_url = "https://avataaars.io/?avatarStyle=Transparent&topType=WinterHat2&accessoriesType=Kurt&hatColor=Blue01&facialHairType=MoustacheMagnum&facialHairColor=Blonde&clotheType=Overall&clotheColor=Gray01&eyeType=WinkWacky&eyebrowType=SadConcernedNatural&mouthType=Sad&skinColor=Light"
95113
message_alignment = "flex-start"
96114
message_bg_color = "#71797E"
97115
avatar_class = "bot-avatar"
98-
formatted_text = format_message(text)
116+
formatted_text = format_message(
117+
text
118+
) # Ensure this handles "Thinking..." appropriately.
99119
container_content = f"""
100120
<div style="display: flex; align-items: center; margin-bottom: 10px; justify-content: {message_alignment};">
101121
<img src="{avatar_url}" class="{avatar_class}" alt="avatar" style="width: 50px; height: 50px;" />
@@ -105,17 +125,6 @@ def _get_bot_message_container(self, text):
105125
"""
106126
return container_content
107127

108-
def on_llm_new_token(self, token, run_id, parent_run_id=None, **kwargs):
109-
"""
110-
Handle the new token from the model. Accumulate tokens in a buffer and update the Streamlit UI.
111-
"""
112-
self.token_buffer.append(token)
113-
complete_message = "".join(self.token_buffer)
114-
115-
# Update the placeholder content with the complete message
116-
container_content = self._get_bot_message_container(complete_message)
117-
self.placeholder.markdown(container_content, unsafe_allow_html=True)
118-
119128
def display_dataframe(self, df):
120129
"""
121130
Display the dataframe in Streamlit UI within the chat container.
@@ -134,12 +143,5 @@ def display_dataframe(self, df):
134143
)
135144
st.write(df)
136145

137-
def on_llm_end(self, response, run_id, parent_run_id=None, **kwargs):
138-
"""
139-
Reset the buffer when the LLM finishes running.
140-
"""
141-
self.token_buffer = [] # Reset the buffer
142-
self.has_streaming_ended = True
143-
144146
def __call__(self, *args, **kwargs):
145147
pass

0 commit comments

Comments (0)