Skip to content

Commit 9c84c5e

Browse files
committed
Distributed shard servers (end-to-end)
- wolfdb.server: ThreadingHTTPServer wrapping a Memory; JSON remember/recall/forget/snapshot/compact/gc/inspect; per-server lock; optional bearer-token auth; wolf-server CLI - wolfdb.cluster.DistributedMemory: stdlib client, subject-routed writes, parallel fan-out recall merged top-k -> recall escapes the GIL across processes - test_cluster.py: 5 real-HTTP e2e tests (routing, supersession, fan-out, forget, auth) - benchmarks/distributed.py spawns real subprocess servers; README distributed section - 74 tests, ruff+mypy clean
1 parent 144577b commit 9c84c5e

7 files changed

Lines changed: 364 additions & 6 deletions

File tree

README.md

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,13 +156,32 @@ mem.remember("Alice works at Acme", subject="Alice", predicate="works_at", objec
156156
mem.recall("where does Alice work?", k=5) # fans out to all shards, merges top-k
157157
```
158158

159+
`ShardedMemory` parallelizes writes in-process; for parallel **recall** (escaping the GIL),
160+
run each shard as its own process with the bundled server and route with `DistributedMemory`:
161+
162+
```bash
163+
wolf-server ./shard0 --port 8100 # one process per shard (set WOLF_TOKEN to require auth)
164+
wolf-server ./shard1 --port 8101
165+
```
166+
167+
```python
168+
from wolfdb import DistributedMemory
169+
170+
mem = DistributedMemory(["http://127.0.0.1:8100", "http://127.0.0.1:8101"])
171+
mem.remember("Bob likes tea", subject="Bob", predicate="likes", object="tea")
172+
mem.recall("tea", k=5) # fan-out over the network; shards score in parallel
173+
```
174+
175+
> The server binds to localhost and is unauthenticated unless you set a token
176+
> (`WOLF_TOKEN` env or `serve(token=...)`). Put it behind TLS + auth before exposing it.
177+
159178
## Repository layout
160179

161180
```
162181
docs/research/ competitive teardown + domain mastery
163182
docs/problem-statement.md the precise gap WolfDB fills
164183
docs/design/ data model, storage, write path, retrieval, API
165-
src/wolfdb/ engine, storage (local + S3), index, embedders, scoring, cli
184+
src/wolfdb/ engine, storage (local + S3), index, embedders, scoring, shard, server, cluster, cli
166185
tests/ unit + property (hypothesis) + concurrency + S3 tests
167186
benchmarks/ micro-benchmarks
168187
examples/ runnable quickstart
@@ -171,10 +190,10 @@ examples/ runnable quickstart
171190
## Roadmap
172191

173192
Reference implementation is Python. Shipped: object-storage backends, vectorized recall, an
174-
inverted index so hybrid/keyword recall scales sublinearly, size-tiered auto-compaction, and
175-
horizontal sharding. Next: a Rust production engine, a local disk cache in front of object
176-
storage, distributed (cross-process) shard servers, and HNSW-on-object-storage for
177-
billion-scale recall.
193+
inverted index so hybrid/keyword recall scales sublinearly, size-tiered auto-compaction,
194+
horizontal sharding, and distributed shard servers (one process per shard, fan-out recall).
195+
Next: a Rust production engine, a local disk cache in front of object storage, and
196+
HNSW-on-object-storage for billion-scale recall.
178197

179198
## License
180199

benchmarks/distributed.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Distributed benchmark: spawn K real shard-server processes and measure fan-out.
2+
3+
Unlike in-process ShardedMemory (GIL-bound), each shard here is its own OS process,
4+
so recall fan-out runs truly in parallel. Run: python benchmarks/distributed.py [N] [K]
5+
"""
6+
import os
7+
import statistics
8+
import subprocess
9+
import sys
10+
import tempfile
11+
import time
12+
import urllib.error
13+
import urllib.request
14+
from concurrent.futures import ThreadPoolExecutor
15+
16+
from wolfdb import DistributedMemory
17+
18+
19+
def _wait_ready(url: str, timeout: float = 15.0) -> None:
20+
deadline = time.time() + timeout
21+
while time.time() < deadline:
22+
try:
23+
urllib.request.urlopen(url + "/inspect", timeout=1.0).read()
24+
return
25+
except (urllib.error.URLError, ConnectionError):
26+
time.sleep(0.1)
27+
raise RuntimeError(f"server {url} did not start")
28+
29+
30+
def main(n: int, k: int) -> None:
31+
root = tempfile.mkdtemp()
32+
procs, urls = [], []
33+
for i in range(k):
34+
port = 8100 + i
35+
procs.append(subprocess.Popen(
36+
[sys.executable, "-m", "wolfdb.server", os.path.join(root, f"s{i}"),
37+
"--port", str(port)], stdout=subprocess.DEVNULL))
38+
urls.append(f"http://127.0.0.1:{port}")
39+
try:
40+
for u in urls:
41+
_wait_ready(u)
42+
client = DistributedMemory(urls)
43+
44+
t0 = time.perf_counter()
45+
with ThreadPoolExecutor(max_workers=k * 4) as ex:
46+
list(ex.map(lambda i: client.remember(
47+
f"user {i} prefers product {i} in region {i % 13}",
48+
subject=f"user{i}", predicate="pref", object=f"product{i}"), range(n)))
49+
wdt = time.perf_counter() - t0
50+
print(f"N={n} across {k} server processes")
51+
print(f"write: {n / wdt:,.0f} facts/s ({wdt:.1f}s)")
52+
print(f"believed: {client.inspect()['facts_believed']}")
53+
54+
for mode in ("vector", "hybrid"):
55+
lat = []
56+
for i in range(200):
57+
s = time.perf_counter()
58+
client.recall(f"product {i * 7}", k=10, mode=mode)
59+
lat.append((time.perf_counter() - s) * 1000)
60+
lat.sort()
61+
print(f"recall {mode:7}: p50={statistics.median(lat):.2f}ms "
62+
f"p95={lat[int(len(lat) * 0.95)]:.2f}ms")
63+
finally:
64+
for p in procs:
65+
p.terminate()
66+
67+
68+
if __name__ == "__main__":
69+
args = [a for a in sys.argv[1:] if not a.startswith("--")]
70+
main(int(args[0]) if args else 4000, int(args[1]) if len(args) > 1 else 4)

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ dev = [
4444
[project.scripts]
4545
wolf = "wolfdb.cli:main"
4646
wolfdb = "wolfdb.cli:main"
47+
wolf-server = "wolfdb.server:main"
4748

4849
[tool.setuptools.packages.find]
4950
where = ["src"]

src/wolfdb/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""WolfDB: open-source temporal memory database for AI agents."""
2+
from .cluster import DistributedMemory
23
from .embedders import OpenAIEmbedder, SentenceTransformerEmbedder
34
from .engine import Hit, Memory
45
from .errors import Config, ConflictError, ValidationError, WolfError
@@ -15,5 +16,5 @@
1516
"Embedder", "Extractor", "HashingEmbedder", "IdentityExtractor",
1617
"OpenAIEmbedder", "SentenceTransformerEmbedder",
1718
"StorageBackend", "LocalBackend", "S3Backend", "Log", "Conflict",
18-
"VectorIndex", "cosine_scores", "ShardedMemory",
19+
"VectorIndex", "cosine_scores", "ShardedMemory", "DistributedMemory",
1920
]

src/wolfdb/cluster.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
"""Distributed client: fan out across WolfDB shard servers (separate processes).
2+
3+
Recall fans out over the network, so each shard's CPU work runs in its own process
4+
— escaping the GIL that limits in-process ``ShardedMemory``. Writes route by subject
5+
(keeping functional-edge supersession within a shard). Uses only the standard library.
6+
"""
7+
from __future__ import annotations
8+
9+
import hashlib
10+
import json
11+
import urllib.request
12+
from concurrent.futures import ThreadPoolExecutor
13+
14+
from .engine import Hit
15+
from .errors import ValidationError
16+
from .models import Fact
17+
18+
19+
def _route(key: str, n: int) -> int:
20+
return int.from_bytes(hashlib.blake2b(key.encode("utf-8"), digest_size=8).digest(), "big") % n
21+
22+
23+
class DistributedMemory:
24+
def __init__(self, urls: list[str], *, token: str | None = None, timeout: float = 30.0):
25+
if not urls:
26+
raise ValidationError("need at least one shard url")
27+
self.urls = [u.rstrip("/") for u in urls]
28+
self._token = token
29+
self._timeout = timeout
30+
self._pool = ThreadPoolExecutor(max_workers=len(self.urls))
31+
32+
# ---- transport ------------------------------------------------------
33+
def _headers(self) -> dict:
34+
h = {"Content-Type": "application/json"}
35+
if self._token:
36+
h["Authorization"] = f"Bearer {self._token}"
37+
return h
38+
39+
def _call(self, url: str, path: str, payload: dict | None) -> dict:
40+
data = None if payload is None else json.dumps(payload).encode("utf-8")
41+
req = urllib.request.Request(url + path, data=data, headers=self._headers(),
42+
method="POST" if data is not None else "GET")
43+
with urllib.request.urlopen(req, timeout=self._timeout) as r:
44+
return json.loads(r.read())
45+
46+
def _fanout(self, path: str, payload: dict | None) -> list[dict]:
47+
return list(self._pool.map(lambda u: self._call(u, path, payload), self.urls))
48+
49+
# ---- API ------------------------------------------------------------
50+
def remember(self, text: str = "", *, subject: str | None = None,
51+
partition_key: str | None = None, **kw) -> list[str]:
52+
i = _route(str(partition_key or subject or text or ""), len(self.urls))
53+
return self._call(self.urls[i], "/remember", {"text": text, "subject": subject, **kw})["ids"]
54+
55+
def recall(self, query: str, *, k: int = 10, **kw) -> list[Hit]:
56+
payload = {"query": query, "k": k, **kw}
57+
hits: list[Hit] = []
58+
for part in self._fanout("/recall", payload):
59+
hits.extend(Hit(fact=Fact.from_dict(h["fact"]), score=h["score"],
60+
components=h["components"]) for h in part["hits"])
61+
hits.sort(key=lambda h: h.score, reverse=True)
62+
return hits[:k]
63+
64+
def forget(self, fact_id: str) -> None:
65+
self._fanout("/forget", {"fact_id": fact_id}) # broadcast; no-op where absent
66+
67+
def snapshot(self, **kw) -> list[Fact]:
68+
return [Fact.from_dict(f) for part in self._fanout("/snapshot", kw) for f in part["facts"]]
69+
70+
def compact(self, **kw) -> int:
71+
return sum(part["n"] for part in self._fanout("/compact", kw))
72+
73+
def gc(self) -> int:
74+
return sum(part["n"] for part in self._fanout("/gc", {}))
75+
76+
def inspect(self) -> dict:
77+
infos = self._fanout("/inspect", None)
78+
agg: dict = {"shards": len(self.urls)}
79+
for key in ("facts_total", "facts_believed", "entities", "edges_believed",
80+
"events", "segments"):
81+
agg[key] = sum(i.get(key, 0) for i in infos)
82+
return agg
83+
84+
def close(self) -> None:
85+
self._pool.shutdown(wait=False)

src/wolfdb/server.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""HTTP shard server: exposes one Memory over JSON so shards run as separate
2+
processes (recall fan-out then escapes the GIL).
3+
4+
Security: binds to 127.0.0.1 by default and has NO authentication unless a token
5+
is configured (``serve(token=...)`` or the ``WOLF_TOKEN`` env var). Do not expose
6+
to an untrusted network without a token and TLS termination in front.
7+
"""
8+
from __future__ import annotations
9+
10+
import json
11+
import os
12+
import threading
13+
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
14+
15+
from .engine import Memory
16+
17+
18+
def _handler(memory: Memory, token: str | None):
19+
lock = threading.Lock() # Memory is not internally thread-safe; serialize per server
20+
21+
class Handler(BaseHTTPRequestHandler):
22+
def _auth_ok(self) -> bool:
23+
return not token or self.headers.get("Authorization") == f"Bearer {token}"
24+
25+
def _send(self, code: int, obj) -> None:
26+
body = json.dumps(obj).encode("utf-8")
27+
self.send_response(code)
28+
self.send_header("Content-Type", "application/json")
29+
self.send_header("Content-Length", str(len(body)))
30+
self.end_headers()
31+
self.wfile.write(body)
32+
33+
def _body(self) -> dict:
34+
n = int(self.headers.get("Content-Length", 0))
35+
return json.loads(self.rfile.read(n) or b"{}")
36+
37+
def log_message(self, *_a): # quiet
38+
pass
39+
40+
def do_GET(self):
41+
if not self._auth_ok():
42+
return self._send(401, {"error": "unauthorized"})
43+
if self.path == "/inspect":
44+
with lock:
45+
return self._send(200, memory.inspect())
46+
self._send(404, {"error": "not found"})
47+
48+
def do_POST(self):
49+
if not self._auth_ok():
50+
return self._send(401, {"error": "unauthorized"})
51+
try:
52+
p = self._body()
53+
with lock:
54+
self._dispatch(p)
55+
except Exception as exc: # noqa: BLE001
56+
self._send(400, {"error": str(exc)})
57+
58+
def _dispatch(self, p: dict) -> None:
59+
if self.path == "/remember":
60+
self._send(200, {"ids": memory.remember(p.pop("text", ""), **p)})
61+
elif self.path == "/recall":
62+
hits = memory.recall(p.pop("query", ""), **p)
63+
self._send(200, {"hits": [{"fact": h.fact.to_dict(), "score": h.score,
64+
"components": h.components} for h in hits]})
65+
elif self.path == "/forget":
66+
memory.forget(p["fact_id"])
67+
self._send(200, {"ok": True})
68+
elif self.path == "/snapshot":
69+
self._send(200, {"facts": [f.to_dict() for f in memory.snapshot(**p)]})
70+
elif self.path == "/compact":
71+
self._send(200, {"n": memory.compact(**p)})
72+
elif self.path == "/gc":
73+
self._send(200, {"n": memory.gc()})
74+
else:
75+
self._send(404, {"error": "not found"})
76+
77+
return Handler
78+
79+
80+
def serve(memory: Memory, host: str = "127.0.0.1", port: int = 8080,
81+
token: str | None = None) -> ThreadingHTTPServer:
82+
"""Build a ThreadingHTTPServer for `memory`. Call serve_forever() to run it."""
83+
return ThreadingHTTPServer((host, port), _handler(memory, token))
84+
85+
86+
def main(argv=None) -> int:
87+
import argparse
88+
ap = argparse.ArgumentParser(prog="wolf-server", description="WolfDB shard server")
89+
ap.add_argument("path", help="path to this shard's .wolf store")
90+
ap.add_argument("--host", default="127.0.0.1")
91+
ap.add_argument("--port", type=int, default=8080)
92+
ap.add_argument("--namespace", default="wolf")
93+
a = ap.parse_args(argv)
94+
token = os.environ.get("WOLF_TOKEN")
95+
httpd = serve(Memory.open(a.path, namespace=a.namespace), a.host, a.port, token)
96+
note = "" if token else " [NO AUTH — bind to localhost only]"
97+
print(f"WolfDB shard serving {a.path} at http://{a.host}:{a.port}{note}", flush=True)
98+
try:
99+
httpd.serve_forever()
100+
except KeyboardInterrupt:
101+
httpd.shutdown()
102+
return 0
103+
104+
105+
if __name__ == "__main__":
106+
raise SystemExit(main())

0 commit comments

Comments
 (0)