Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions codex-rs/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions codex-rs/memory/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ edition = { workspace = true }
anyhow = "1"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
chrono = { version = "0.4", default-features = false, features = ["clock"] }

[features]
default = []
Expand Down
92 changes: 88 additions & 4 deletions codex-rs/memory/src/recall.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
use crate::store::MemoryStore;
use crate::types::MemoryItem;
use crate::types::Status;
use chrono::DateTime;
use chrono::Utc;
use std::collections::BTreeSet;

pub struct RecallContext {
pub repo_root: Option<std::path::PathBuf>,
Expand All @@ -13,9 +18,88 @@ pub struct RecallContext {
}

pub fn recall(
_store: &dyn crate::store::MemoryStore,
_prompt: &str,
_ctx: &RecallContext,
store: &dyn MemoryStore,
prompt: &str,
ctx: &RecallContext,
) -> anyhow::Result<Vec<MemoryItem>> {
todo!()
let now = DateTime::parse_from_rfc3339(&ctx.now_rfc3339)?.with_timezone(&Utc);
let tokens = tokenize(prompt);
let mut scored: Vec<(f32, usize, MemoryItem)> = store
.list(None, Some(Status::Active))?
.into_iter()
.map(|item| {
let mut score = overlap_score(&tokens, &tokenize(&item.content));
if let Some(f) = &ctx.current_file
&& item.relevance_hints.files.iter().any(|h| f.ends_with(h))
{
score += 0.4;
}
if let Some(c) = &ctx.crate_name
&& item.relevance_hints.crates.iter().any(|h| h == c)
{
score += 0.3;
}
if let Some(l) = &ctx.language
&& item
.relevance_hints
.languages
.iter()
.any(|h| h.eq_ignore_ascii_case(l))
{
score += 0.2;
}
if let Some(cmd) = &ctx.command
&& item.relevance_hints.commands.iter().any(|h| h == cmd)
{
score += 0.1;
}
let freq = 1.0 + item.counters.used_count as f32 * 0.1;
score *= freq;
if let Some(last) = &item.counters.last_used_at
&& let Ok(dt) = DateTime::parse_from_rfc3339(last)
{
let age_days = (now - dt.with_timezone(&Utc)).num_days();
let decay = 0.5f32.powf(age_days as f32 / 7.0);
score *= decay;
}
let token_len = item.content.split_whitespace().count();
(score, token_len, item)
})
.collect();
scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
let mut out = Vec::new();
let mut used_tokens = 0usize;
for (_, tokens, mut item) in scored {
if out.len() >= ctx.item_cap {
break;
}
if used_tokens + tokens > ctx.token_cap {
break;
}
used_tokens += tokens;
item.counters.used_count += 1;
item.counters.last_used_at = Some(ctx.now_rfc3339.clone());
store.update(&item)?;
out.push(item);
}
Ok(out)
}

fn tokenize(s: &str) -> BTreeSet<String> {
let mut set = BTreeSet::new();
for w in s.split(|c: char| !c.is_alphanumeric()) {
if w.is_empty() {
continue;
}
set.insert(w.to_ascii_lowercase());
}
set
}

fn overlap_score(a: &BTreeSet<String>, b: &BTreeSet<String>) -> f32 {
if a.is_empty() || b.is_empty() {
return 0.0;
}
let inter = a.intersection(b).count() as f32;
inter / a.len() as f32
}
140 changes: 138 additions & 2 deletions codex-rs/memory/tests/recall.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,140 @@
use codex_memory::recall::RecallContext;
use codex_memory::recall::recall;
use codex_memory::store::MemoryStore;
use codex_memory::types::Counters;
use codex_memory::types::Kind;
use codex_memory::types::MemoryItem;
use codex_memory::types::RelevanceHints;
use codex_memory::types::Scope;
use codex_memory::types::Status;
use std::collections::HashMap;
use std::sync::Mutex;

#[derive(Default)]
struct TestStore {
items: Mutex<HashMap<String, MemoryItem>>,
}

impl TestStore {
fn new(items: Vec<MemoryItem>) -> Self {
let map = items.into_iter().map(|i| (i.id.clone(), i)).collect();
Self {
items: Mutex::new(map),
}
}
}

impl MemoryStore for TestStore {
fn add(&self, item: MemoryItem) -> anyhow::Result<()> {
self.items.lock().unwrap().insert(item.id.clone(), item);
Ok(())
}

fn update(&self, item: &MemoryItem) -> anyhow::Result<()> {
self.items
.lock()
.unwrap()
.insert(item.id.clone(), item.clone());
Ok(())
}

fn delete(&self, _id: &str) -> anyhow::Result<()> {
Ok(())
}

fn get(&self, id: &str) -> anyhow::Result<Option<MemoryItem>> {
Ok(self.items.lock().unwrap().get(id).cloned())
}

fn list(
&self,
_scope: Option<Scope>,
status: Option<Status>,
) -> anyhow::Result<Vec<MemoryItem>> {
let items = self.items.lock().unwrap();
Ok(items
.values()
.filter(|i| match status.as_ref() {
Some(s) => i.status == *s,
None => true,
})
.cloned()
.collect())
}

fn archive(&self, _id: &str, _archived: bool) -> anyhow::Result<()> {
Ok(())
}

fn export(&self, _out: &mut dyn std::io::Write) -> anyhow::Result<()> {
Ok(())
}

fn import(&self, _input: &mut dyn std::io::Read) -> anyhow::Result<usize> {
Ok(0)
}

fn stats(&self) -> anyhow::Result<serde_json::Value> {
Ok(serde_json::json!({}))
}
}

fn item(id: &str, content: &str, lang: &str) -> MemoryItem {
MemoryItem {
id: id.to_string(),
created_at: "2024-01-01T00:00:00Z".into(),
updated_at: "2024-01-01T00:00:00Z".into(),
schema_version: 1,
source: "test".into(),
scope: Scope::Global,
status: Status::Active,
kind: Kind::Fact,
content: content.into(),
tags: vec![],
relevance_hints: RelevanceHints {
files: vec![],
crates: vec![],
languages: vec![lang.into()],
commands: vec![],
},
counters: Counters {
seen_count: 0,
used_count: 0,
last_used_at: None,
},
expiry: None,
}
}

#[test]
fn placeholder() {
// placeholder test
fn ranks_and_updates_counters() {
let a = item("1", "use cargo build for rust", "rust");
let b = item("2", "cargo test runs tests", "rust");
let c = item("3", "npm install packages", "javascript");
let store = TestStore::new(vec![a.clone(), b.clone(), c.clone()]);
let now = "2024-01-10T00:00:00Z".to_string();
let ctx = RecallContext {
repo_root: None,
dir: None,
current_file: None,
crate_name: None,
language: Some("rust".into()),
command: None,
now_rfc3339: now.clone(),
item_cap: 2,
token_cap: 50,
};
let out = recall(&store, "cargo build rust", &ctx).unwrap();
assert_eq!(out.len(), 2);
assert_eq!(out[0].id, "1");
assert_eq!(out[1].id, "2");
let a_upd = store.get("1").unwrap().unwrap();
assert_eq!(a_upd.counters.used_count, 1);
assert_eq!(a_upd.counters.last_used_at.as_deref(), Some(now.as_str()));
let b_upd = store.get("2").unwrap().unwrap();
assert_eq!(b_upd.counters.used_count, 1);
assert_eq!(b_upd.counters.last_used_at.as_deref(), Some(now.as_str()));
let c_upd = store.get("3").unwrap().unwrap();
assert_eq!(c_upd.counters.used_count, 0);
assert_eq!(c_upd.counters.last_used_at, None);
}
Loading