Skip to content

Commit 458cf7a

Browse files
authored
Add additional context (#83)
* Add `Session.add_system_instructions` method * Rename `Session.add_system_instructions` to `Session.add_additional_context` * Add basic add_additional_context unit tests * Allow add_additional_context to be called multiple times * Address PR comments * Address PR comments * Clearer properties * Address PR comments
1 parent e581f46 commit 458cf7a

File tree

3 files changed

+89
-25
lines changed

3 files changed

+89
-25
lines changed

databao/agents/lighthouse/agent.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,29 @@ def __init__(self) -> None:
1717
"""Initialize agent with lazy graph compilation."""
1818
super().__init__()
1919
self._cached_graph: ExecuteSubmit | None = None
20+
self._prompt_template = read_prompt_template(Path("system_prompt.jinja"))
2021

2122
def render_system_prompt(self, data_connection: Any, session: Session) -> str:
2223
"""Render system prompt with database schema."""
23-
prompt_template = read_prompt_template(Path("system_prompt.jinja"))
2424
db_schema = describe_duckdb_schema(data_connection)
25-
db_contexts, df_contexts = session.context
25+
2626
context = ""
27-
for db_name, db_context in db_contexts.items():
27+
for db_name, db_context in session.db_context.items():
2828
context += f"## Context for DB {db_name}\n\n{db_context}\n\n"
29-
for df_name, df_context in df_contexts.items():
29+
for df_name, df_context in session.df_context.items():
3030
context += f"## Context for DF {df_name} (fully qualified name 'temp.main.{df_name}')\n\n{df_context}\n\n"
31+
for idx, additional_ctx in enumerate(session.additional_context, start=1):
32+
additional_context = additional_ctx.strip()
33+
context += f"## General information {idx}\n\n{additional_context}\n\n"
34+
context = context.strip()
3135

32-
prompt = prompt_template.render(
36+
prompt = self._prompt_template.render(
3337
date=get_today_date_str(),
3438
db_schema=db_schema,
3539
context=context,
3640
)
37-
return prompt
41+
42+
return prompt.strip()
3843

3944
def _create_graph(self, data_connection: Any, llm_config: LLMConfig) -> Any:
4045
"""Create and compile the Lighthouse agent graph."""

databao/core/session.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,9 @@ def __init__(
4040
self.__dbs: dict[str, Any] = {}
4141
self.__dfs: dict[str, DataFrame] = {}
4242

43-
self.__db_contexts: dict[str, str] = {}
44-
self.__df_contexts: dict[str, str] = {}
43+
self.__db_context: dict[str, str] = {}
44+
self.__df_context: dict[str, str] = {}
45+
self.__additional_context: list[str] = []
4546

4647
# Create a DuckDB connection for the session
4748
self.__duckdb_connection = duckdb.connect(":memory:")
@@ -93,7 +94,7 @@ def add_db(self, connection: Any, *, name: str | None = None, context: str | Pat
9394
self.__dbs[conn_name] = connection
9495

9596
if (context_text := self._parse_context_arg(context)) is not None:
96-
self.__db_contexts[conn_name] = context_text
97+
self.__db_context[conn_name] = context_text
9798

9899
def add_df(self, df: DataFrame, *, name: str | None = None, context: str | Path | None = None) -> None:
99100
"""Register a DataFrame in this session and in the session's DuckDB.
@@ -114,7 +115,21 @@ def add_df(self, df: DataFrame, *, name: str | None = None, context: str | Path
114115
self.__dbs["duckdb"] = self.__duckdb_connection
115116

116117
if (context_text := self._parse_context_arg(context)) is not None:
117-
self.__df_contexts[df_name] = context_text
118+
self.__df_context[df_name] = context_text
119+
120+
def add_context(self, context: str | Path) -> None:
121+
"""Add additional context to help models understand your data.
122+
123+
Use this method to add general information that might not be associated with a specific data source.
124+
If the information is specific to a data source, use the `context` argument of `add_db` and `add_df`.
125+
126+
Args:
127+
context: The string or the path to a file containing the additional context.
128+
"""
129+
text = self._parse_context_arg(context)
130+
if text is None:
131+
raise ValueError("Invalid context provided.")
132+
self.__additional_context.append(text)
118133

119134
def thread(self, *, stream_ask: bool | None = None, stream_plot: bool | None = None) -> Pipe:
120135
"""Start a new thread in this session."""
@@ -158,6 +173,16 @@ def cache(self) -> "Cache":
158173
return self.__cache
159174

160175
@property
161-
def context(self) -> tuple[dict[str, str], dict[str, str]]:
162-
"""Per-source natural-language context for DBs and DFs: (db_contexts, df_contexts)."""
163-
return self.__db_contexts, self.__df_contexts
176+
def db_context(self) -> dict[str, str]:
177+
"""Per-source natural-language context for DBs."""
178+
return self.__db_context
179+
180+
@property
181+
def df_context(self) -> dict[str, str]:
182+
"""Per-source natural-language context for DFs."""
183+
return self.__df_context
184+
185+
@property
186+
def additional_context(self) -> list[str]:
187+
"""General additional context not specific to any one data source."""
188+
return self.__additional_context

tests/test_session.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@ def test_add_db_with_temp_file_context(temp_context_file: Path) -> None:
4141
session = databao.open_session("ctx_path_session_db_file")
4242
session.add_db(conn, context=temp_context_file)
4343

44-
db_contexts, _df_contexts = session.context
45-
assert "db1" in db_contexts
46-
assert db_contexts["db1"] == temp_context_file.read_text()
44+
assert "db1" in session.db_context
45+
assert session.db_context["db1"] == temp_context_file.read_text()
4746

4847

4948
def test_add_df_with_temp_file_context(temp_context_file: Path) -> None:
@@ -52,9 +51,8 @@ def test_add_df_with_temp_file_context(temp_context_file: Path) -> None:
5251
session = databao.open_session("ctx_path_session_df_file")
5352
session.add_df(df, context=temp_context_file)
5453

55-
_db_contexts, df_contexts = session.context
56-
assert "df1" in df_contexts
57-
assert df_contexts["df1"] == temp_context_file.read_text()
54+
assert "df1" in session.df_context
55+
assert session.df_context["df1"] == temp_context_file.read_text()
5856

5957

6058
def test_add_db_with_string_context() -> None:
@@ -64,9 +62,8 @@ def test_add_db_with_string_context() -> None:
6462
context_string = "This is a string context for the database."
6563
session.add_db(conn, context=context_string)
6664

67-
db_contexts, _df_contexts = session.context
68-
assert "db1" in db_contexts
69-
assert db_contexts["db1"] == context_string
65+
assert "db1" in session.db_context
66+
assert session.db_context["db1"] == context_string
7067

7168

7269
def test_add_df_with_string_context() -> None:
@@ -76,6 +73,43 @@ def test_add_df_with_string_context() -> None:
7673
context_string = "This is a string context for the DataFrame."
7774
session.add_df(df, context=context_string)
7875

79-
_db_contexts, df_contexts = session.context
80-
assert "df1" in df_contexts
81-
assert df_contexts["df1"] == context_string
76+
assert "df1" in session.df_context
77+
assert session.df_context["df1"] == context_string
78+
79+
80+
def test_add_additional_context_with_nonexistent_path_raises() -> None:
81+
"""add_additional_context should raise if given a non-existent Path."""
82+
session = databao.open_session("additional_ctx_missing_path")
83+
with pytest.raises(FileNotFoundError):
84+
session.add_context(Path("no_such_context_file_123.md"))
85+
86+
87+
def test_add_additional_context_with_temp_file(temp_context_file: Path) -> None:
88+
"""Ensure additional context can be loaded from a temporary file path."""
89+
session = databao.open_session("additional_ctx_from_file")
90+
session.add_context(temp_context_file)
91+
assert session.additional_context == [temp_context_file.read_text()]
92+
93+
94+
def test_add_additional_context_with_string() -> None:
95+
"""Ensure additional context can be provided directly as a string."""
96+
session = databao.open_session("additional_ctx_from_string")
97+
text = "Global instructions for the session go here."
98+
session.add_context(text)
99+
assert session.additional_context == [text]
100+
101+
102+
def test_add_additional_context_multiple_calls_mixed_sources(temp_context_file: Path) -> None:
103+
"""Calling add_additional_context multiple times should append in order."""
104+
session = databao.open_session("additional_ctx_multiple")
105+
first = "First global instruction."
106+
second = temp_context_file.read_text()
107+
third = "Third bit of context."
108+
109+
session.add_context(first)
110+
session.add_context(temp_context_file)
111+
session.add_context(third)
112+
113+
assert first in session.additional_context
114+
assert second in session.additional_context
115+
assert third in session.additional_context

0 commit comments

Comments
 (0)