Skip to content

Commit cd3cbce

Browse files
committed
Improve KG extractor reliability and CLI local model flow
1 parent ee66307 commit cd3cbce

7 files changed

Lines changed: 484 additions & 63 deletions

File tree

dwata-agents/src/kg_email_extractor/agent.rs

Lines changed: 353 additions & 36 deletions
Large diffs are not rendered by default.

dwata-agents/src/kg_email_extractor/document_labeler.rs

Lines changed: 105 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,62 @@
1-
use crate::kg_email_extractor::types::LabelDocumentParams;
1+
use crate::kg_email_extractor::types::{DocumentType, LabelDocumentParams};
22
use crate::storage::{AgentStorage, Message};
33
use nocodo_llm_sdk::client::LlmClient;
44
use nocodo_llm_sdk::types::{CompletionRequest, ContentBlock, Message as LlmMessage};
55
use nocodo_llm_sdk::Tool;
66
use std::sync::Arc;
77

8+
/// Upper bound, in characters, on the assistant-text preview attached to log events.
const RESPONSE_PREVIEW_CHARS: usize = 1200;
9+
10+
/// Returns the leading `max_chars` characters of `value` as an owned `String`.
///
/// Counting is done in Unicode scalar values (`char`s), not bytes, so the cut
/// never lands inside a multi-byte character. When `value` holds `max_chars`
/// characters or fewer, the whole string is returned.
fn preview_text(value: &str, max_chars: usize) -> String {
    // The byte offset of the character at index `max_chars` is exactly where
    // the preview has to stop; if there is no such character, nothing is cut.
    let cut = match value.char_indices().nth(max_chars) {
        Some((byte_offset, _)) => byte_offset,
        None => value.len(),
    };
    value[..cut].to_owned()
}
13+
14+
fn fallback_label_from_template(template: &str) -> LabelDocumentParams {
15+
let text = template.to_ascii_lowercase();
16+
17+
let has_receipt = text.contains("receipt");
18+
let has_payment = text.contains("payment")
19+
|| text.contains("paid")
20+
|| text.contains("debited")
21+
|| text.contains("charged");
22+
let has_bill = text.contains("amount due")
23+
|| text.contains("due date")
24+
|| text.contains("pay by")
25+
|| text.contains("billing period")
26+
|| text.contains("invoice");
27+
let has_order = text.contains("order")
28+
|| text.contains("shipment")
29+
|| text.contains("tracking")
30+
|| text.contains("delivered");
31+
let has_event = text.contains("meeting")
32+
|| text.contains("appointment")
33+
|| text.contains("invite")
34+
|| text.contains("calendar")
35+
|| text.contains("event");
36+
37+
let doc_type = if has_receipt {
38+
DocumentType::Receipt
39+
} else if has_bill {
40+
DocumentType::Bill
41+
} else if has_payment {
42+
DocumentType::PaymentConfirmation
43+
} else if has_order {
44+
DocumentType::Unknown
45+
} else if has_event {
46+
DocumentType::Unknown
47+
} else {
48+
DocumentType::Unknown
49+
};
50+
51+
LabelDocumentParams {
52+
doc_type,
53+
has_bill,
54+
has_transaction: has_payment || has_receipt,
55+
has_event,
56+
has_order,
57+
}
58+
}
59+
860
pub struct TemplateDocumentLabelerAgent {
961
llm_client: Arc<dyn LlmClient>,
1062
storage: Arc<dyn AgentStorage>,
@@ -29,6 +81,12 @@ impl TemplateDocumentLabelerAgent {
2981

3082
pub async fn execute(&self, session_id: i64) -> anyhow::Result<LabelDocumentParams> {
3183
let system_prompt = super::document_labeler_prompt::build_system_prompt(&self.template);
84+
tracing::info!(
85+
model = %self.model,
86+
prompt_len = system_prompt.len(),
87+
template_len = self.template.len(),
88+
"Document labeler starting"
89+
);
3290

3391
let label_tool = Tool::from_type::<LabelDocumentParams>()
3492
.name("label_document")
@@ -82,26 +140,42 @@ impl TemplateDocumentLabelerAgent {
82140
};
83141

84142
let response = self.llm_client.complete(request).await?;
143+
let assistant_text = response
144+
.content
145+
.iter()
146+
.filter_map(|block| match block {
147+
ContentBlock::Text { text } => Some(text.clone()),
148+
_ => None,
149+
})
150+
.collect::<Vec<_>>()
151+
.join("\n");
152+
tracing::debug!(
153+
model = %self.model,
154+
iteration = iteration + 1,
155+
content_blocks = response.content.len(),
156+
tool_calls = response.tool_calls.as_ref().map(|t| t.len()).unwrap_or(0),
157+
assistant_text_preview = %preview_text(&assistant_text, RESPONSE_PREVIEW_CHARS),
158+
"Document labeler response received"
159+
);
85160

86161
self.storage
87162
.create_message(Message {
88163
id: None,
89164
session_id,
90165
role: "assistant".to_string(),
91-
content: response
92-
.content
93-
.iter()
94-
.filter_map(|block| match block {
95-
ContentBlock::Text { text } => Some(text.clone()),
96-
_ => None,
97-
})
98-
.collect::<Vec<_>>()
99-
.join("\n"),
166+
content: assistant_text.clone(),
100167
})
101168
.await?;
102169

103170
if let Some(tool_calls) = response.tool_calls {
104171
for tool_call in tool_calls {
172+
tracing::debug!(
173+
iteration = iteration + 1,
174+
tool_id = %tool_call.id(),
175+
tool_name = %tool_call.name(),
176+
raw_arguments = %tool_call.raw_arguments(),
177+
"Document labeler tool call"
178+
);
105179
if tool_call.name() == "label_document" {
106180
let params: LabelDocumentParams = tool_call.parse_arguments()?;
107181

@@ -120,6 +194,16 @@ impl TemplateDocumentLabelerAgent {
120194
return Ok(params);
121195
}
122196
}
197+
tracing::warn!(
198+
iteration = iteration + 1,
199+
"Document labeler returned tool calls but none matched label_document"
200+
);
201+
} else {
202+
tracing::warn!(
203+
iteration = iteration + 1,
204+
assistant_text_preview = %preview_text(&assistant_text, RESPONSE_PREVIEW_CHARS),
205+
"Document labeler returned no tool calls"
206+
);
123207
}
124208

125209
self.storage
@@ -133,8 +217,16 @@ impl TemplateDocumentLabelerAgent {
133217
.await?;
134218
}
135219

136-
Err(anyhow::anyhow!(
137-
"Document labeler did not call label_document after 2 iterations"
138-
))
220+
let fallback = fallback_label_from_template(&self.template);
221+
tracing::warn!(
222+
model = %self.model,
223+
fallback_doc_type = ?fallback.doc_type,
224+
fallback_has_bill = fallback.has_bill,
225+
fallback_has_transaction = fallback.has_transaction,
226+
fallback_has_event = fallback.has_event,
227+
fallback_has_order = fallback.has_order,
228+
"Document labeler did not call label_document after 2 iterations; using heuristic fallback"
229+
);
230+
Ok(fallback)
139231
}
140232
}

dwata-api/src/bin/extract_kg_entities.rs

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ use dwata_api::helpers::database::initialize_database;
1212
use dwata_api::search::entity_index::{
1313
open_or_create_index, reindex_all_entities, DbEntitySearchProvider,
1414
};
15-
use nocodo_llm_sdk::models::ollama::MINISTRAL_3_3B_ID;
16-
use nocodo_llm_sdk::ollama::OllamaClient;
15+
use nocodo_llm_sdk::llama_cpp::LlamaCppClient;
16+
use nocodo_llm_sdk::models::llama_cpp::QWEN_3_5_0_8B;
1717
use std::sync::Arc;
1818
use tracing_subscriber::EnvFilter;
1919

@@ -29,6 +29,10 @@ struct Args {
2929
/// Skip document labeler and run all four passes unconditionally
3030
#[arg(long, default_value_t = false)]
3131
all_passes: bool,
32+
33+
/// Base URL for llama.cpp OpenAI-compatible server
34+
#[arg(long, default_value = "http://localhost:8080")]
35+
llama_base_url: String,
3236
}
3337

3438
// ---------------------------------------------------------------------------
@@ -438,6 +442,7 @@ async fn main() -> Result<()> {
438442
.init();
439443

440444
let args = Args::parse();
445+
let selected_model = QWEN_3_5_0_8B;
441446

442447
let db = initialize_database().context("Failed to initialize database")?;
443448

@@ -455,7 +460,11 @@ async fn main() -> Result<()> {
455460
.await
456461
.context("Failed to reindex entities")?;
457462

458-
let llm_client = Arc::new(OllamaClient::new().context("Failed to initialize Ollama client")?);
463+
let llm_client = Arc::new(
464+
LlamaCppClient::new()
465+
.context("Failed to initialize llama.cpp client")?
466+
.with_base_url(args.llama_base_url.clone()),
467+
);
459468
let storage = Arc::new(InMemoryAgentStorage::new());
460469

461470
let email = emails_db::get_email(db.async_connection.clone(), args.email_id)
@@ -470,6 +479,8 @@ async fn main() -> Result<()> {
470479

471480
println!("Email ID: {}", email.id);
472481
println!("Subject: {}", simple.subject);
482+
println!("Model: {}", selected_model);
483+
println!("Provider: llama.cpp ({})", args.llama_base_url);
473484
println!();
474485

475486
// --- Step 1: Document labeling (skip if --all-passes) ---
@@ -493,7 +504,7 @@ async fn main() -> Result<()> {
493504
let labeler = TemplateDocumentLabelerAgent::new(
494505
llm_client.clone(),
495506
storage.clone(),
496-
MINISTRAL_3_3B_ID.to_string(),
507+
selected_model.to_string(),
497508
simple.body.clone(),
498509
);
499510

@@ -545,9 +556,10 @@ async fn main() -> Result<()> {
545556
llm_client,
546557
storage,
547558
persistence,
548-
MINISTRAL_3_3B_ID.to_string(),
559+
selected_model.to_string(),
549560
email_content,
550561
)
562+
.with_single_tool_submission(true)
551563
.with_search_provider(search_provider)
552564
.with_source_email_id(email.id)
553565
.with_sender(email.from_name.clone(), Some(email.from_address.clone()));

dwata-api/src/config.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use nocodo_llm_sdk::models::ollama::MINISTRAL_3_3B_ID;
1+
use nocodo_llm_sdk::models::ollama::QWEN_3_5_2B_ID;
22
use serde::{Deserialize, Serialize};
33
use std::fs;
44
use std::path::PathBuf;
@@ -115,7 +115,7 @@ impl Default for SelectedLlmConfig {
115115
fn default() -> Self {
116116
Self {
117117
provider: "ollama".to_string(),
118-
model: MINISTRAL_3_3B_ID.to_string(),
118+
model: QWEN_3_5_2B_ID.to_string(),
119119
}
120120
}
121121
}

dwata-api/src/handlers/kg_extraction.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use dwata_agents::{
1515
storage::{AgentStorage, InMemoryAgentStorage, Session},
1616
KgEmailExtractionAgent, TemplateDocumentLabelerAgent,
1717
};
18-
use nocodo_llm_sdk::models::ollama::MINISTRAL_3_3B_ID;
18+
use nocodo_llm_sdk::models::ollama::QWEN_3_5_2B_ID;
1919
use nocodo_llm_sdk::ollama::OllamaClient;
2020
use serde::Deserialize;
2121

@@ -474,7 +474,7 @@ async fn process_single_email(
474474
let labeler = TemplateDocumentLabelerAgent::new(
475475
llm_client.clone(),
476476
storage.clone(),
477-
MINISTRAL_3_3B_ID.to_string(),
477+
QWEN_3_5_2B_ID.to_string(),
478478
simple.body.clone(),
479479
);
480480

@@ -517,7 +517,7 @@ async fn process_single_email(
517517
llm_client,
518518
storage,
519519
persistence,
520-
MINISTRAL_3_3B_ID.to_string(),
520+
QWEN_3_5_2B_ID.to_string(),
521521
email_content,
522522
)
523523
.with_search_provider(search_provider)

dwata-api/src/handlers/ollama.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
use actix_web::{web, HttpResponse, Result};
2-
use nocodo_llm_sdk::models::ollama::MINISTRAL_3_3B_ID;
2+
use nocodo_llm_sdk::models::ollama::QWEN_3_5_2B_ID;
33
use nocodo_llm_sdk::ollama::types as sdk_types;
44
use nocodo_llm_sdk::ollama::OllamaClient;
55
use shared_types::{
66
ErrorResponse, OllamaModelDetails, OllamaModelInfo, OllamaModelsResponse,
77
OllamaPullModelRequest, OllamaPullModelResponse, OllamaStatusResponse,
88
};
99

10-
const ALLOWED_PULL_MODELS: &[&str] = &[MINISTRAL_3_3B_ID];
10+
const ALLOWED_PULL_MODELS: &[&str] = &[QWEN_3_5_2B_ID];
1111

1212
fn map_model_details(details: sdk_types::OllamaModelDetails) -> OllamaModelDetails {
1313
OllamaModelDetails {

project.example.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,9 @@ auto_start = false
4747
# gemini_api_key = "your-gemini-key"
4848

4949
[selected_llm]
50-
# Controls the model used by template detection agents.
50+
# Controls the model used by KG extraction agents.
5151
# Providers: "ollama", "openai", "gemini"
5252
# OpenAI model allowed: "gpt-5-mini"
5353
# Gemini model allowed: "gemini-3-flash-preview"
5454
provider = "ollama"
55-
model = "ministral-3:3b"
55+
model = "qwen3.5:2b"

0 commit comments

Comments
 (0)