Skip to content

Commit bb4b2d1

Browse files
committed
Added progress bar updates and streaming to the summarize tool
Signed-off-by: TikaaVo <tikavod6@gmail.com>
1 parent fbcb4b2 commit bb4b2d1

4 files changed

Lines changed: 84 additions & 12 deletions

File tree

WRITEUP.md

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Coding Challenge - Issue #241
2+
3+
Total time spent: 4 hours
4+
5+
## 1. Environment Setup
6+
7+
Setting up the environment took significantly longer (approximately 2 hours) than expected due to several issues. These are centered around the fact that I used an outdated source to base myself off when writing the docker compose file and thus was running on an older NextCloud version.
8+
9+
- `host.docker.internal` didn't work, as the Nextcloud container couldn't reach llm2 running on the host, giving the following error: `nc_py_api._exceptions.NextcloudException: [400] Bad Request <request: PUT /ocs/v1.php/apps/app_api/ex-app/status>`, which was difficult to guage the issue from. I eventually managed to fix it by switching to the Docker bridge IP `172.17.0.1` and added it to Nextcloud's `trusted_domains`.
10+
- llm2's task processing provider registration endpoint requires Nextcloud 30+, but my usage of `nextcloud:29` returned `ERROR - Failed to register llama-2-7b-chat.Q4_K_M - core:text2text:summary, Error: [501] <request: POST /ocs/v1.php/apps/app_api/api/v1/ai_provider/task_processing>`. Upgrading fixed this.
11+
- The version of nc_py_api installed by Poetry (`0.24.2`) was out of sync with AppAPI 3.2.3, so I upgraded to `0.30.1`.
12+
13+
After fixing this, llm2 successfully initialized, and upon running a summarization task, I saw the issue to be fixed, being that the progress bar was fixed at 0.00% until the task was completed.
14+
15+
## 2. Investigating the issue
16+
17+
Firstly, I wanted to locate the code that set the progress, so I looked through the source code of `nc_py_api` until I found `set_progress` in the `_TaskProcessingProviderAPI` class, which accepted the task_id and the progress as a float value from 0.00 to 100.00.
18+
19+
## 3. Pass needed information
20+
21+
I decided to start with the summarization task under `summary.py`. Firstly, in order to call set_progress, we need `SummarizeProcessor` to have access to `nc` and `task_id`, so they were passed as parameters into the constructor function, then `task_processors.py` and `main.py` were modified to support this.
22+
23+
## 4. Investigate how to estimate the response length
24+
25+
Initially, I was thinking about whether the context window `n_ctx` could be used to estimate the response length. However, then I found that inside `task_processors.py`, the model's `max_tokens` can be extracted from the model config, so I extracted that and passed it to `SummarizeProcessor` as another parameter, as this can be used as a more accurate estimation of response length.
26+
27+
## 5. Updating the Progress Bar
28+
29+
The `__call__` method in `SummarizeProcessor` used `invoke`, which doesn't provide progress. Therefore, I wrote a helped function `_invoke_progress`, which streams the generation and calls `set_progress`, using the max_tokens as the upper bound. For multiple splits, I assumed that each split is roughly equal, so if there are N splits, then split M (1 <= M <= N) would take the progress bar from `(100/N) * (M-1)` to `(100/N) * M`, so the `_invoke_progress` function accepts the current split index and the total number of splits.
30+
31+
## 6. Testing and Limitations
32+
33+
Upon testing, I noticed that the GUI was showing the progress as 0.00% to 1.00%. Therefore, if my code passed 25 to `set_progress`, the GUI would show 0.25%. This seems to be an issue on the side of either the `set_progress` function or the GUI, as there is some kind of division by 100 happening.
34+
35+
One limitation of using the max_tokens as an upper bound is that most responses don't hit this limit or even come close, so it's frequent to see the progress bar jump from a few percentage points to completed or to the boundary of the next split. This upper bound is safe, as responses cannot exceed it, but conservative.

lib/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def background_thread_task():
132132
task["id"], error_message="Requested model is not available"
133133
)
134134
continue
135-
task_processor = task_processor_loader()
135+
task_processor = task_processor_loader(nc, task["id"])
136136
log(nc, LogLvl.INFO, "Generating reply")
137137
time_start = perf_counter()
138138
log(nc, LogLvl.INFO, task.get("input"))

lib/summarize.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from langchain.schema.prompt_template import BasePromptTemplate
88
from langchain_core.runnables import Runnable
99
from langchain.text_splitter import RecursiveCharacterTextSplitter
10+
from nc_py_api import NextcloudApp
1011

1112

1213
class SummarizeProcessor:
@@ -35,8 +36,12 @@ class SummarizeProcessor:
3536
"""
3637
)
3738

38-
def __init__(self, runnable: Runnable, n_ctx: int = 8000):
39+
def __init__(self, runnable: Runnable, nc: NextcloudApp, task_id: int, n_ctx: int = 8000, max_tokens: int = 512):
3940
self.runnable = runnable
41+
self.nc = nc
42+
self.task_id = task_id
43+
self.n_ctx = n_ctx
44+
self.max_tokens = max_tokens if max_tokens > 0 else 512
4045
self.text_splitter = RecursiveCharacterTextSplitter(
4146
separators=['\n\n|\\.|\\?|\\!'],
4247
is_separator_regex=True,
@@ -46,6 +51,29 @@ def __init__(self, runnable: Runnable, n_ctx: int = 8000):
4651
length_function=len,
4752
)
4853

54+
def _invoke_progress(self, messages, max_tokens: int, idx: int, total_splits: int) -> str:
55+
# Stream the response and update progress
56+
57+
start_pct = (idx / total_splits) * 100.0
58+
end_pct = ((idx + 1) / total_splits) * 100.0
59+
60+
tokens_generated = 0
61+
full_response = ""
62+
total_range = end_pct - start_pct
63+
64+
for chunk in self.runnable.stream(messages):
65+
token = chunk.content if hasattr(chunk, 'content') else str(chunk)
66+
full_response += token
67+
tokens_generated += 1
68+
69+
fraction = min(1.0, tokens_generated / max_tokens)
70+
progress = start_pct + fraction * total_range
71+
self.nc.providers.task_processing.set_progress(self.task_id, progress)
72+
73+
# Ensure the end percentage is set after completion
74+
self.nc.providers.task_processing.set_progress(self.task_id, end_pct)
75+
return full_response
76+
4977
def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
5078
# Split text if needed
5179
splits = self.text_splitter.split_text(inputs['input'])
@@ -55,23 +83,29 @@ def __call__(self, inputs: dict[str, Any]) -> dict[str, Any]:
5583
SystemMessage(content=self.system_prompt),
5684
HumanMessage(content=self.user_prompt.format(input=splits[0]))
5785
]
58-
output = self.runnable.invoke(messages)
59-
return {'output': output.content}
86+
87+
output = self._invoke_progress(messages, self.max_tokens, 0, 1)
88+
return {'output': output}
6089

6190
# Process each split
91+
total_splits = len(splits)
6292
summaries = []
63-
for split in splits:
93+
94+
for idx, split in enumerate(splits):
95+
96+
6497
messages = [
6598
SystemMessage(content=self.system_prompt),
6699
HumanMessage(content=self.user_prompt.format(input=split))
67100
]
68-
output = self.runnable.invoke(messages)
69-
summaries.append(output.content)
70101

71-
# Merge summaries
72-
messages = [
102+
split_output = self._invoke_progress(messages, self.max_tokens, idx, total_splits)
103+
summaries.append(split_output)
104+
105+
merge_messages = [
73106
SystemMessage(content=self.system_prompt),
74107
HumanMessage(content=self.merge_prompt.format(input="\n\n".join(summaries)))
75108
]
76-
final_output = self.runnable.invoke(messages)
109+
final_output = self.runnable.invoke(merge_messages)
110+
self.nc.providers.task_processing.set_progress(self.task_id, 100.0)
77111
return {'output': final_output.content}

lib/task_processors.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,9 +128,12 @@ def generate_task_processors(task_processors = {}):
128128

129129
def generate_task_processors_for_model(file_name, task_processors):
130130
model_name = file_name.split('.gguf')[0]
131-
n_ctx = get_model_config(file_name)["loader_config"]["n_ctx"]
131+
model_config = get_model_config(file_name)
132+
n_ctx = model_config["loader_config"]["n_ctx"]
133+
max_tokens = model_config["loader_config"].get("max_tokens")
134+
132135

133-
task_processors[model_name + ":core:text2text:summary"] = lambda: SummarizeProcessor(generate_chat_chain(file_name), n_ctx)
136+
task_processors[model_name + ":core:text2text:summary"] = lambda nc, task_id: SummarizeProcessor(generate_chat_chain(file_name), nc, task_id, n_ctx, max_tokens)
134137
task_processors[model_name + ":core:text2text:headline"] = lambda: HeadlineProcessor(generate_chat_chain(file_name))
135138
task_processors[model_name + ":core:text2text:topics"] = lambda: TopicsProcessor(generate_chat_chain(file_name))
136139
task_processors[model_name + ":core:text2text:simplification"] = lambda: SimplifyProcessor(generate_chat_chain(file_name))

0 commit comments

Comments
 (0)