Skip to content

Commit 741379a

Browse files
committed
feat(memory): implement recall scoring
1 parent 0e37094 commit 741379a

4 files changed

Lines changed: 228 additions & 6 deletions

File tree

codex-rs/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

codex-rs/memory/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ edition = { workspace = true }
77
anyhow = "1"
88
serde = { version = "1", features = ["derive"] }
99
serde_json = "1"
10+
chrono = { version = "0.4", default-features = false, features = ["clock"] }
1011

1112
[features]
1213
default = []

codex-rs/memory/src/recall.rs

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
use crate::store::MemoryStore;
12
use crate::types::MemoryItem;
3+
use crate::types::Status;
4+
use chrono::DateTime;
5+
use chrono::Utc;
6+
use std::collections::BTreeSet;
27

38
pub struct RecallContext {
49
pub repo_root: Option<std::path::PathBuf>,
@@ -13,9 +18,88 @@ pub struct RecallContext {
1318
}
1419

1520
pub fn recall(
16-
_store: &dyn crate::store::MemoryStore,
17-
_prompt: &str,
18-
_ctx: &RecallContext,
21+
store: &dyn MemoryStore,
22+
prompt: &str,
23+
ctx: &RecallContext,
1924
) -> anyhow::Result<Vec<MemoryItem>> {
20-
todo!()
25+
let now = DateTime::parse_from_rfc3339(&ctx.now_rfc3339)?.with_timezone(&Utc);
26+
let tokens = tokenize(prompt);
27+
let mut scored: Vec<(f32, usize, MemoryItem)> = store
28+
.list(None, Some(Status::Active))?
29+
.into_iter()
30+
.map(|item| {
31+
let mut score = overlap_score(&tokens, &tokenize(&item.content));
32+
if let Some(f) = &ctx.current_file
33+
&& item.relevance_hints.files.iter().any(|h| f.ends_with(h))
34+
{
35+
score += 0.4;
36+
}
37+
if let Some(c) = &ctx.crate_name
38+
&& item.relevance_hints.crates.iter().any(|h| h == c)
39+
{
40+
score += 0.3;
41+
}
42+
if let Some(l) = &ctx.language
43+
&& item
44+
.relevance_hints
45+
.languages
46+
.iter()
47+
.any(|h| h.eq_ignore_ascii_case(l))
48+
{
49+
score += 0.2;
50+
}
51+
if let Some(cmd) = &ctx.command
52+
&& item.relevance_hints.commands.iter().any(|h| h == cmd)
53+
{
54+
score += 0.1;
55+
}
56+
let freq = 1.0 + item.counters.used_count as f32 * 0.1;
57+
score *= freq;
58+
if let Some(last) = &item.counters.last_used_at
59+
&& let Ok(dt) = DateTime::parse_from_rfc3339(last)
60+
{
61+
let age_days = (now - dt.with_timezone(&Utc)).num_days();
62+
let decay = 0.5f32.powf(age_days as f32 / 7.0);
63+
score *= decay;
64+
}
65+
let token_len = item.content.split_whitespace().count();
66+
(score, token_len, item)
67+
})
68+
.collect();
69+
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
70+
let mut out = Vec::new();
71+
let mut used_tokens = 0usize;
72+
for (_, tokens, mut item) in scored {
73+
if out.len() >= ctx.item_cap {
74+
break;
75+
}
76+
if used_tokens + tokens > ctx.token_cap {
77+
break;
78+
}
79+
used_tokens += tokens;
80+
item.counters.used_count += 1;
81+
item.counters.last_used_at = Some(ctx.now_rfc3339.clone());
82+
store.update(&item)?;
83+
out.push(item);
84+
}
85+
Ok(out)
86+
}
87+
88+
/// Splits `s` into its set of distinct, lowercased alphanumeric words.
///
/// Any character that is not alphanumeric (per `char::is_alphanumeric`, which
/// is Unicode-aware) acts as a separator, and empty fragments are discarded.
/// Words are lowercased with full Unicode case mapping so that the case
/// folding covers the same character set as the splitting: "Ünïcode" and
/// "ÜNÏCODE" yield the same token.
fn tokenize(s: &str) -> BTreeSet<String> {
    s.split(|c: char| !c.is_alphanumeric())
        .filter(|word| !word.is_empty())
        .map(str::to_lowercase)
        .collect()
}
98+
99+
/// Returns the fraction of the words in `a` that are also present in `b`.
///
/// The result lies in `0.0..=1.0`; it is `0.0` whenever either set is empty.
/// Note the measure is asymmetric: the denominator is the size of `a` only.
fn overlap_score(a: &BTreeSet<String>, b: &BTreeSet<String>) -> f32 {
    match (a.len(), b.len()) {
        (0, _) | (_, 0) => 0.0,
        (total, _) => {
            let shared = a.iter().filter(|word| b.contains(*word)).count();
            shared as f32 / total as f32
        }
    }
}

codex-rs/memory/tests/recall.rs

Lines changed: 138 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,140 @@
1+
use codex_memory::recall::RecallContext;
2+
use codex_memory::recall::recall;
3+
use codex_memory::store::MemoryStore;
4+
use codex_memory::types::Counters;
5+
use codex_memory::types::Kind;
6+
use codex_memory::types::MemoryItem;
7+
use codex_memory::types::RelevanceHints;
8+
use codex_memory::types::Scope;
9+
use codex_memory::types::Status;
10+
use std::collections::HashMap;
11+
use std::sync::Mutex;
12+
13+
#[derive(Default)]
14+
struct TestStore {
15+
items: Mutex<HashMap<String, MemoryItem>>,
16+
}
17+
18+
impl TestStore {
19+
fn new(items: Vec<MemoryItem>) -> Self {
20+
let map = items.into_iter().map(|i| (i.id.clone(), i)).collect();
21+
Self {
22+
items: Mutex::new(map),
23+
}
24+
}
25+
}
26+
27+
impl MemoryStore for TestStore {
28+
fn add(&self, item: MemoryItem) -> anyhow::Result<()> {
29+
self.items.lock().unwrap().insert(item.id.clone(), item);
30+
Ok(())
31+
}
32+
33+
fn update(&self, item: &MemoryItem) -> anyhow::Result<()> {
34+
self.items
35+
.lock()
36+
.unwrap()
37+
.insert(item.id.clone(), item.clone());
38+
Ok(())
39+
}
40+
41+
fn delete(&self, _id: &str) -> anyhow::Result<()> {
42+
Ok(())
43+
}
44+
45+
fn get(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
46+
Ok(self.items.lock().unwrap().get(id).cloned())
47+
}
48+
49+
fn list(
50+
&self,
51+
_scope: Option<Scope>,
52+
status: Option<Status>,
53+
) -> anyhow::Result<Vec<MemoryItem>> {
54+
let items = self.items.lock().unwrap();
55+
Ok(items
56+
.values()
57+
.filter(|i| match status.as_ref() {
58+
Some(s) => i.status == *s,
59+
None => true,
60+
})
61+
.cloned()
62+
.collect())
63+
}
64+
65+
fn archive(&self, _id: &str, _archived: bool) -> anyhow::Result<()> {
66+
Ok(())
67+
}
68+
69+
fn export(&self, _out: &mut dyn std::io::Write) -> anyhow::Result<()> {
70+
Ok(())
71+
}
72+
73+
fn import(&self, _input: &mut dyn std::io::Read) -> anyhow::Result<usize> {
74+
Ok(0)
75+
}
76+
77+
fn stats(&self) -> anyhow::Result<serde_json::Value> {
78+
Ok(serde_json::json!({}))
79+
}
80+
}
81+
82+
fn item(id: &str, content: &str, lang: &str) -> MemoryItem {
83+
MemoryItem {
84+
id: id.to_string(),
85+
created_at: "2024-01-01T00:00:00Z".into(),
86+
updated_at: "2024-01-01T00:00:00Z".into(),
87+
schema_version: 1,
88+
source: "test".into(),
89+
scope: Scope::Global,
90+
status: Status::Active,
91+
kind: Kind::Fact,
92+
content: content.into(),
93+
tags: vec![],
94+
relevance_hints: RelevanceHints {
95+
files: vec![],
96+
crates: vec![],
97+
languages: vec![lang.into()],
98+
commands: vec![],
99+
},
100+
counters: Counters {
101+
seen_count: 0,
102+
used_count: 0,
103+
last_used_at: None,
104+
},
105+
expiry: None,
106+
}
107+
}
108+
1109
/// With a cap of two items, `recall` must return the two Rust-related items in
/// relevance order and bump the usage counters of exactly those two in the
/// store, leaving the unselected JavaScript item untouched.
#[test]
fn ranks_and_updates_counters() {
    let now = "2024-01-10T00:00:00Z".to_string();
    let store = TestStore::new(vec![
        item("1", "use cargo build for rust", "rust"),
        item("2", "cargo test runs tests", "rust"),
        item("3", "npm install packages", "javascript"),
    ]);
    let ctx = RecallContext {
        repo_root: None,
        dir: None,
        current_file: None,
        crate_name: None,
        language: Some("rust".into()),
        command: None,
        now_rfc3339: now.clone(),
        item_cap: 2,
        token_cap: 50,
    };

    let out = recall(&store, "cargo build rust", &ctx).unwrap();

    // Item 1 shares every prompt word, item 2 only one, item 3 none.
    let ranked_ids: Vec<&str> = out.iter().map(|m| m.id.as_str()).collect();
    assert_eq!(ranked_ids, ["1", "2"]);

    // Selected items are persisted with updated counters; the rest are not.
    let expectations = [
        ("1", 1, Some(now.as_str())),
        ("2", 1, Some(now.as_str())),
        ("3", 0, None),
    ];
    for (id, expected_uses, expected_last_used) in expectations {
        let stored = store.get(id).unwrap().unwrap();
        assert_eq!(stored.counters.used_count, expected_uses);
        assert_eq!(stored.counters.last_used_at.as_deref(), expected_last_used);
    }
}

0 commit comments

Comments
 (0)