Skip to content

Commit 3533961

Browse files
authored
refactor(backend): Clean up Library & Store DB schema (#9774)
Distilled from #9541 to reduce the scope of that PR. - Part of #9307 - ❗ Blocks #9786 - ❗ Blocks #9541 ### Changes 🏗️ - Fix `LibraryAgent` schema (for #9786) - Fix relationships between `LibraryAgent`, `AgentGraph`, and `AgentPreset` - Impose uniqueness constraint on `LibraryAgent` - Rename things that are called `agent` that actually refer to a `graph`/`agentGraph` - Fix singular/plural forms in DB schema - Simplify reference names of closely related objects (e.g. `AgentGraph.AgentGraphExecutions` -> `AgentGraph.Executions`) - Eliminate use of `# type: ignore` in DB statements - Add `typed` and `typed_cast` utilities to `backend.util.type` ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] CI static type checking (with all risky `# type: ignore` removed) - [x] Check that column references in views are updated
1 parent 70890de commit 3533961

File tree

29 files changed

+441
-295
lines changed

29 files changed

+441
-295
lines changed

autogpt_platform/backend/backend/data/credit.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from abc import ABC, abstractmethod
44
from collections import defaultdict
55
from datetime import datetime, timezone
6+
from typing import cast
67

78
import stripe
89
from autogpt_libs.utils.cache import thread_cached
@@ -18,6 +19,7 @@
1819
CreditRefundRequestCreateInput,
1920
CreditTransactionCreateInput,
2021
CreditTransactionWhereInput,
22+
IntFilter,
2123
)
2224
from tenacity import retry, stop_after_attempt, wait_exponential
2325

@@ -213,7 +215,7 @@ async def _get_credits(self, user_id: str) -> tuple[int, datetime]:
213215
"userId": user_id,
214216
"createdAt": {"lte": top_time},
215217
"isActive": True,
216-
"runningBalance": {"not": None}, # type: ignore
218+
"runningBalance": cast(IntFilter, {"not": None}),
217219
},
218220
order={"createdAt": "desc"},
219221
)

autogpt_platform/backend/backend/data/execution.py

+36-35
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
AgentNodeExecutionInputOutput,
2424
)
2525
from prisma.types import (
26+
AgentGraphExecutionCreateInput,
2627
AgentGraphExecutionWhereInput,
2728
AgentNodeExecutionCreateInput,
2829
AgentNodeExecutionInputOutputCreateInput,
@@ -121,15 +122,15 @@ class GraphExecution(GraphExecutionMeta):
121122

122123
@staticmethod
123124
def from_db(_graph_exec: AgentGraphExecution):
124-
if _graph_exec.AgentNodeExecutions is None:
125+
if _graph_exec.NodeExecutions is None:
125126
raise ValueError("Node executions must be included in query")
126127

127128
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
128129

129130
complete_node_executions = sorted(
130131
[
131132
NodeExecutionResult.from_db(ne, _graph_exec.userId)
132-
for ne in _graph_exec.AgentNodeExecutions
133+
for ne in _graph_exec.NodeExecutions
133134
if ne.executionStatus != ExecutionStatus.INCOMPLETE
134135
],
135136
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
@@ -181,15 +182,15 @@ class GraphExecutionWithNodes(GraphExecution):
181182

182183
@staticmethod
183184
def from_db(_graph_exec: AgentGraphExecution):
184-
if _graph_exec.AgentNodeExecutions is None:
185+
if _graph_exec.NodeExecutions is None:
185186
raise ValueError("Node executions must be included in query")
186187

187188
graph_exec_with_io = GraphExecution.from_db(_graph_exec)
188189

189190
node_executions = sorted(
190191
[
191192
NodeExecutionResult.from_db(ne, _graph_exec.userId)
192-
for ne in _graph_exec.AgentNodeExecutions
193+
for ne in _graph_exec.NodeExecutions
193194
],
194195
key=lambda ne: (ne.queue_time is None, ne.queue_time or ne.add_time),
195196
)
@@ -220,21 +221,21 @@ class NodeExecutionResult(BaseModel):
220221
end_time: datetime | None
221222

222223
@staticmethod
223-
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
224-
if execution.executionData:
224+
def from_db(_node_exec: AgentNodeExecution, user_id: Optional[str] = None):
225+
if _node_exec.executionData:
225226
# Execution that has been queued for execution will persist its data.
226-
input_data = type_utils.convert(execution.executionData, dict[str, Any])
227+
input_data = type_utils.convert(_node_exec.executionData, dict[str, Any])
227228
else:
228229
# For incomplete execution, executionData will not be yet available.
229230
input_data: BlockInput = defaultdict()
230-
for data in execution.Input or []:
231+
for data in _node_exec.Input or []:
231232
input_data[data.name] = type_utils.convert(data.data, type[Any])
232233

233234
output_data: CompletedBlockOutput = defaultdict(list)
234-
for data in execution.Output or []:
235+
for data in _node_exec.Output or []:
235236
output_data[data.name].append(type_utils.convert(data.data, type[Any]))
236237

237-
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
238+
graph_execution: AgentGraphExecution | None = _node_exec.GraphExecution
238239
if graph_execution:
239240
user_id = graph_execution.userId
240241
elif not user_id:
@@ -246,17 +247,17 @@ def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
246247
user_id=user_id,
247248
graph_id=graph_execution.agentGraphId if graph_execution else "",
248249
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
249-
graph_exec_id=execution.agentGraphExecutionId,
250-
block_id=execution.AgentNode.agentBlockId if execution.AgentNode else "",
251-
node_exec_id=execution.id,
252-
node_id=execution.agentNodeId,
253-
status=execution.executionStatus,
250+
graph_exec_id=_node_exec.agentGraphExecutionId,
251+
block_id=_node_exec.Node.agentBlockId if _node_exec.Node else "",
252+
node_exec_id=_node_exec.id,
253+
node_id=_node_exec.agentNodeId,
254+
status=_node_exec.executionStatus,
254255
input_data=input_data,
255256
output_data=output_data,
256-
add_time=execution.addedTime,
257-
queue_time=execution.queuedTime,
258-
start_time=execution.startedTime,
259-
end_time=execution.endedTime,
257+
add_time=_node_exec.addedTime,
258+
queue_time=_node_exec.queuedTime,
259+
start_time=_node_exec.startedTime,
260+
end_time=_node_exec.endedTime,
260261
)
261262

262263

@@ -351,29 +352,29 @@ async def create_graph_execution(
351352
The id of the AgentGraphExecution and the list of ExecutionResult for each node.
352353
"""
353354
result = await AgentGraphExecution.prisma().create(
354-
data={
355-
"agentGraphId": graph_id,
356-
"agentGraphVersion": graph_version,
357-
"executionStatus": ExecutionStatus.QUEUED,
358-
"AgentNodeExecutions": {
359-
"create": [ # type: ignore
360-
{
361-
"agentNodeId": node_id,
362-
"executionStatus": ExecutionStatus.QUEUED,
363-
"queuedTime": datetime.now(tz=timezone.utc),
364-
"Input": {
355+
data=AgentGraphExecutionCreateInput(
356+
agentGraphId=graph_id,
357+
agentGraphVersion=graph_version,
358+
executionStatus=ExecutionStatus.QUEUED,
359+
NodeExecutions={
360+
"create": [
361+
AgentNodeExecutionCreateInput(
362+
agentNodeId=node_id,
363+
executionStatus=ExecutionStatus.QUEUED,
364+
queuedTime=datetime.now(tz=timezone.utc),
365+
Input={
365366
"create": [
366367
{"name": name, "data": Json(data)}
367368
for name, data in node_input.items()
368369
]
369370
},
370-
}
371+
)
371372
for node_id, node_input in nodes_input
372373
]
373374
},
374-
"userId": user_id,
375-
"agentPresetId": preset_id,
376-
},
375+
userId=user_id,
376+
agentPresetId=preset_id,
377+
),
377378
include=GRAPH_EXECUTION_INCLUDE_WITH_NODES,
378379
)
379380

@@ -600,7 +601,7 @@ async def get_node_execution_results(
600601
"agentGraphExecutionId": graph_exec_id,
601602
}
602603
if block_ids:
603-
where_clause["AgentNode"] = {"is": {"agentBlockId": {"in": block_ids}}}
604+
where_clause["Node"] = {"is": {"agentBlockId": {"in": block_ids}}}
604605
if statuses:
605606
where_clause["OR"] = [{"executionStatus": status} for status in statuses]
606607

autogpt_platform/backend/backend/data/graph.py

+36-23
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import uuid
33
from collections import defaultdict
4-
from typing import Any, Literal, Optional, Type
4+
from typing import Any, Literal, Optional, Type, cast
55

66
import prisma
77
from prisma import Json
@@ -10,7 +10,9 @@
1010
from prisma.types import (
1111
AgentGraphCreateInput,
1212
AgentGraphWhereInput,
13+
AgentGraphWhereInputRecursive1,
1314
AgentNodeCreateInput,
15+
AgentNodeIncludeFromAgentNodeRecursive1,
1416
AgentNodeLinkCreateInput,
1517
)
1618
from pydantic.fields import computed_field
@@ -465,13 +467,11 @@ def from_db(
465467
is_active=graph.isActive,
466468
name=graph.name or "",
467469
description=graph.description or "",
468-
nodes=[
469-
NodeModel.from_db(node, for_export) for node in graph.AgentNodes or []
470-
],
470+
nodes=[NodeModel.from_db(node, for_export) for node in graph.Nodes or []],
471471
links=list(
472472
{
473473
Link.from_db(link)
474-
for node in graph.AgentNodes or []
474+
for node in graph.Nodes or []
475475
for link in (node.Input or []) + (node.Output or [])
476476
}
477477
),
@@ -602,8 +602,8 @@ async def get_graph(
602602
and not (
603603
await StoreListingVersion.prisma().find_first(
604604
where={
605-
"agentId": graph_id,
606-
"agentVersion": version or graph.version,
605+
"agentGraphId": graph_id,
606+
"agentGraphVersion": version or graph.version,
607607
"isDeleted": False,
608608
"submissionStatus": SubmissionStatus.APPROVED,
609609
}
@@ -637,12 +637,16 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
637637
sub_graph_ids = [
638638
(graph_id, graph_version)
639639
for graph in search_graphs
640-
for node in graph.AgentNodes or []
640+
for node in graph.Nodes or []
641641
if (
642642
node.AgentBlock
643643
and node.AgentBlock.id == agent_block_id
644-
and (graph_id := dict(node.constantInput).get("graph_id"))
645-
and (graph_version := dict(node.constantInput).get("graph_version"))
644+
and (graph_id := cast(str, dict(node.constantInput).get("graph_id")))
645+
and (
646+
graph_version := cast(
647+
int, dict(node.constantInput).get("graph_version")
648+
)
649+
)
646650
)
647651
]
648652
if not sub_graph_ids:
@@ -651,13 +655,16 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
651655
graphs = await AgentGraph.prisma().find_many(
652656
where={
653657
"OR": [
654-
{
655-
"id": graph_id,
656-
"version": graph_version,
657-
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
658-
}
658+
type_utils.typed(
659+
AgentGraphWhereInputRecursive1,
660+
{
661+
"id": graph_id,
662+
"version": graph_version,
663+
"userId": graph.userId, # Ensure the sub-graph is owned by the same user
664+
},
665+
)
659666
for graph_id, graph_version in sub_graph_ids
660-
] # type: ignore
667+
]
661668
},
662669
include=AGENT_GRAPH_INCLUDE,
663670
)
@@ -671,7 +678,13 @@ async def get_sub_graphs(graph: AgentGraph) -> list[AgentGraph]:
671678
async def get_connected_output_nodes(node_id: str) -> list[tuple[Link, Node]]:
672679
links = await AgentNodeLink.prisma().find_many(
673680
where={"agentNodeSourceId": node_id},
674-
include={"AgentNodeSink": {"include": AGENT_NODE_INCLUDE}}, # type: ignore
681+
include={
682+
"AgentNodeSink": {
683+
"include": cast(
684+
AgentNodeIncludeFromAgentNodeRecursive1, AGENT_NODE_INCLUDE
685+
)
686+
}
687+
},
675688
)
676689
return [
677690
(Link.from_db(link), NodeModel.from_db(link.AgentNodeSink))
@@ -829,12 +842,12 @@ async def fix_llm_provider_credentials():
829842
SELECT graph."userId" user_id,
830843
node.id node_id,
831844
node."constantInput" node_preset_input
832-
FROM platform."AgentNode" node
833-
LEFT JOIN platform."AgentGraph" graph
834-
ON node."agentGraphId" = graph.id
835-
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
836-
ORDER BY graph."userId";
837-
"""
845+
FROM platform."AgentNode" node
846+
LEFT JOIN platform."AgentGraph" graph
847+
ON node."agentGraphId" = graph.id
848+
WHERE node."constantInput"::jsonb->'credentials'->>'provider' = 'llm'
849+
ORDER BY graph."userId";
850+
"""
838851
)
839852
logger.info(f"Fixing LLM credential inputs on {len(broken_nodes)} nodes")
840853
except Exception as e:

0 commit comments

Comments
 (0)