Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ async def do_create(self, app, data):

**Datastore Access**:
- Via `self.middleware.call('datastore.query', 'table', filters, options)`
- CRUD: `datastore.create()`, `datastore.update()`, `datastore.delete()`
- CRUD: `datastore.insert()`, `datastore.update()`, `datastore.delete()`
- Supports filtering and complex queries

### Job System
Expand Down
143 changes: 117 additions & 26 deletions src/middlewared/middlewared/plugins/datastore/connection.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from concurrent.futures import ThreadPoolExecutor
import re
import shutil
import threading
import time
from os import getpid

from sqlalchemy import create_engine, text
from sqlalchemy import create_engine, event, text
from sqlalchemy.pool import NullPool

from middlewared.service import private, Service
from middlewared.service_exception import CallError

from middlewared.utils.db import FREENAS_DATABASE

thread_pool = ThreadPoolExecutor(1)
_tls = threading.local()


def regexp(expr, item):
Expand All @@ -20,60 +23,140 @@ def regexp(expr, item):
return reg.search(item) is not None


def _on_db_connect(dbapi_conn, _):
    """SQLAlchemy 'connect' event hook: prepare each newly opened raw SQLite connection.

    Installs the Python-backed REGEXP SQL function and switches on foreign-key
    enforcement, which SQLite leaves off by default for new connections.
    """
    dbapi_conn.create_function('REGEXP', 2, regexp)
    dbapi_conn.execute('PRAGMA foreign_keys=ON')


class DatastoreService(Service):

class Config:
private = True
thread_pool = thread_pool

engine = None
connection = None
write_lock = threading.Lock()
_generation = 0
_main_pid = None

def _get_conn(self):
    """Return (creating it if needed) this thread's connection for the current DB generation.

    A thread gets a fresh connection in two cases: it has never connected, or
    the global generation counter moved past the generation recorded on the
    thread. setup() bumps the counter after disposing the engine — typically
    because a new database file was uploaded — so every thread transparently
    reconnects on its next access. Any stale per-thread connection is closed
    (best-effort) before the replacement is opened.
    """
    current = self._generation
    if getattr(_tls, 'db_generation', -1) != current:
        stale = getattr(_tls, 'db_conn', None)
        if stale is not None:
            try:
                stale.close()
            except Exception:
                # Closing a connection from a disposed engine may fail; the
                # connection is being discarded either way.
                pass
        _tls.db_conn = self.engine.connect().execution_options(isolation_level='AUTOCOMMIT')
        _tls.db_generation = current
    return _tls.db_conn

@private
def handle_constraint_violation(self, conn, row, journal):
    """Log and remove a row that violates a foreign key constraint.

    `conn` is the caller's connection, so the corrective DELETE runs on the
    same connection that detected the violation. The DELETE statement is also
    appended to `journal` so that the corrective actions taken at startup can
    be audited after the fact.
    """
    self.logger.warning("Row %d in table %s violates foreign key constraint on table %s.",
                        row.rowid, row.table, row.parent)

    self.logger.warning("Deleting row %d from table %s.", row.rowid, row.table)
    op = f"DELETE FROM {row.table} WHERE rowid = {row.rowid}"
    conn.execute(text(op))
    journal.write(f'{op}\n')

@private
def setup(self):
    """Initialise (or re-initialise) the SQLite engine.

    Disposes any existing engine, creates a fresh one, bumps the generation
    counter so all threads obtain new connections on their next access, records
    the current PID to guard against forked-child writes, repairs any
    foreign-key violations, and runs VACUUM.
    """
    if self.engine is not None:
        self.engine.dispose()

    # NullPool because per-thread connections are managed manually: the database
    # file can be completely replaced (e.g. during config uploads), and the
    # generation counter is what invalidates each thread's cached connection.
    self.engine = create_engine(
        f'sqlite:///{FREENAS_DATABASE}',
        connect_args={'check_same_thread': False},
        poolclass=NullPool,
    )

    # Every new raw connection gets REGEXP + PRAGMA foreign_keys via this hook.
    event.listen(self.engine, 'connect', _on_db_connect)

    self._generation += 1
    self._main_pid = getpid()
    conn = self._get_conn()

    if constraint_violations := conn.execute(text('PRAGMA foreign_key_check')).fetchall():
        # Preserve the broken database before mutating it, and journal each fix.
        ts = int(time.time())
        shutil.copy(FREENAS_DATABASE, f'{FREENAS_DATABASE}_{ts}.bak')

        with open(f'{FREENAS_DATABASE}_{ts}_journal.txt', 'w') as f:
            for row in constraint_violations:
                self.handle_constraint_violation(conn, row, f)

    conn.connection.execute('VACUUM')

def _check_main_pid(self):
    """Refuse datastore writes from any process other than the one that ran setup().

    A forked child writing to the database would bypass the main process
    write_lock and the post_execute_write hook, breaking the serialization
    guarantees that update hooks depend on.
    """
    if self._main_pid != getpid():
        raise CallError('Datastore writes are not permitted from a forked child process')

@private
def execute(self, query, *params):
    """Execute a raw SQL write statement under the write lock.

    Accepts either positional parameters or a single list/tuple of parameters.
    Raises CallError if called from a forked child process.
    """
    self._check_main_pid()
    # Allow execute(sql, [a, b]) as well as execute(sql, a, b).
    if len(params) == 1 and isinstance(params[0], (list, tuple)):
        params = tuple(params[0])

    with self.write_lock:
        self._get_conn().exec_driver_sql(query, params)

@private
def execute_write(self, stmt, options=None):
"""Compile and execute a SQLAlchemy DML statement under the write lock.

Handles bind-parameter expansion for IN clauses and type-processor coercion.
When `return_last_insert_rowid` is set in options the integer row ID of the
inserted row is returned; otherwise the raw DBAPI result is returned.
Raises CallError if called from a forked child process.

After each successful write, fires the `datastore.post_execute_write` hook
**while still holding the write_lock**. The hook is registered as inline so
it runs synchronously on this thread. On HA systems this is where SQL
replication to the backup controller happens; holding the lock during the
hook call is intentional — it prevents any other write from reaching the
local database before the same SQL has been forwarded to the remote, avoiding
replication race conditions. Consequently, anything invoked from within that
hook must not attempt to call execute() or execute_write(), as write_lock is
non-reentrant and doing so would deadlock. Reads via fetchall() are safe.
"""
self._check_main_pid()
options = options or {}
options.setdefault('ha_sync', True)
options.setdefault('return_last_insert_rowid', False)
Expand Down Expand Up @@ -102,18 +185,26 @@ def execute_write(self, stmt, options=None):
else:
binds.append(value)

result = self.connection.exec_driver_sql(sql, tuple(binds))

self.middleware.call_hook_inline("datastore.post_execute_write", sql, binds, options)
with self.write_lock:
conn = self._get_conn()
result = conn.exec_driver_sql(sql, tuple(binds))
if options['return_last_insert_rowid']:
result = conn.execute(text('SELECT last_insert_rowid() as rowid')).scalar_one()

if options['return_last_insert_rowid']:
return self.connection.execute(text("SELECT last_insert_rowid() as rowid")).scalar_one()
self.middleware.call_hook_inline('datastore.post_execute_write', sql, binds, options)

return result

@private
def fetchall(self, query, params=None):
cursor = self.connection.execute(text(query) if isinstance(query, str) else query, params or {})
"""Execute a query and return all rows as a list of mappings.

Accepts either a raw SQL string or a SQLAlchemy selectable. Read operations
do not acquire the write lock and are safe to call concurrently from multiple
threads.
"""
conn = self._get_conn()
cursor = conn.execute(text(query) if isinstance(query, str) else query, params or {})
try:
return list(cursor.mappings())
finally:
Expand Down
14 changes: 13 additions & 1 deletion src/middlewared/middlewared/plugins/datastore/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class DatastoreService(Service):
class Config:
private = True

events = defaultdict(list)
events: dict[str, list] = defaultdict(list)

async def register_event(self, options: dict) -> None:
options = asdict(DatastoreRegisterEventArgs(**options))
Expand All @@ -35,6 +35,12 @@ async def send_insert_events(self, datastore, row):
)

async def send_update_events(self, datastore, id_):
"""Fire a CHANGED event for each plugin registered against datastore.

Re-fetches the row via plugin.query to get the current state. If the row
is gone by the time the event fires (deleted concurrently with the update),
the event is silently skipped rather than raising an error.
"""
for options in self.events[datastore]:
fields = await self._fields(options, {options["prefix"] + options["id"]: id_}, False)
if not fields:
Expand Down Expand Up @@ -65,6 +71,12 @@ async def _fields(self, options, row, get=True):
)

async def _send_event(self, options, type_, **kwargs):
"""Dispatch a single datastore event, optionally transforming it first.

If process_event is configured, calls it with (type_, kwargs) and uses
the returned (type_, kwargs) pair instead. Returning None from process_event
suppresses the event entirely.
"""
if options["process_event"]:
processed = await self.middleware.call(options["process_event"], type_, kwargs)
if processed is None:
Expand Down
26 changes: 22 additions & 4 deletions src/middlewared/middlewared/plugins/datastore/filter.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,37 @@
import operator
from typing import Any, Iterable, Literal
from typing import Any, Iterable, Literal, cast

from sqlalchemy import Column, ForeignKey, Table, func
from sqlalchemy.sql.elements import ColumnElement

from truenas_api_client import ejson as json
from middlewared.utils.jsonpath import JSON_PATH_PREFIX, json_path_parse
from .schema import SchemaMixin


def in_(col: Column, value: Iterable):
    """Return a SQLAlchemy expression matching rows where col is in value.

    NULL entries in value are handled separately via IS NULL because SQL
    equality comparisons with NULL never match via the IN operator.
    """
    has_nulls = None in value
    value = [v for v in value if v is not None]
    expr: ColumnElement[bool] = col.in_(value)
    if has_nulls:
        expr = expr | (col == None)  # noqa
    return expr


def nin(col: Column, value: Iterable):
    """Return a SQLAlchemy expression matching rows where col is not in value.

    NULL entries in value are handled separately via IS NOT NULL because SQL
    equality comparisons with NULL never match via the NOT IN operator.
    """
    has_nulls = None in value
    value = [v for v in value if v is not None]
    expr: ColumnElement[bool] = ~col.in_(value)
    if has_nulls:
        expr = expr & (col != None)  # noqa
    return expr
Expand All @@ -31,6 +42,7 @@ def nin(col: Column, value: Iterable):

class FilterMixin(SchemaMixin):
def _filters_contains_foreign_key(self, filters: FiltersList) -> bool:
"""Return True if any filter references a joined table via __ or . syntax."""
for f in filters:
if not isinstance(f, (list, tuple)):
raise ValueError('Filter must be a list or tuple: {0}'.format(f))
Expand All @@ -47,6 +59,12 @@ def _filters_to_queryset(
prefix: str | None,
aliases: dict[ForeignKey, Table]
) -> list:
"""Translate a middleware filter list into SQLAlchemy WHERE clause expressions.

Handles simple comparisons, JSONPath extraction ($.field), foreign-key
traversal via __ or . separators, and OR conjunctions. Returns a list of
SQLAlchemy column expressions suitable for passing to and_().
"""
opmap = {
'=': operator.eq,
'!=': operator.ne,
Expand Down Expand Up @@ -81,7 +99,7 @@ def _filters_to_queryset(
col = func.json_extract(col, json_target)
is_json_extract = True
elif matched := next((x for x in ['__', '.'] if x in name), False):
fk, name = name.split(matched, 1)
fk, name = name.split(cast(str, matched), 1)
col = self._get_col(aliases[list(self._get_col(table, fk, prefix).foreign_keys)[0]], name, '')
else:
col = self._get_col(table, name, prefix)
Expand Down
25 changes: 25 additions & 0 deletions src/middlewared/middlewared/plugins/datastore/read.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ async def config(self, name: str, options: dict | None = None):
return await self.query(name, [], options)

def _get_queryset_joins(self, table):
"""Return a mapping of ForeignKey → aliased Table for all FK columns on table.

Recurses into joined tables so that chains of foreign keys are fully
resolved, enabling deep relationship loading in a single query.
"""
result = {}
for column in table.c:
if column.foreign_keys:
Expand All @@ -201,6 +206,11 @@ def _get_queryset_joins(self, table):
async def _queryset_serialize(
self, qs, table, aliases, relationships, extend, extend_context, field_prefix, select, extra_options, fk_attrs,
):
"""Serialize raw query rows to dicts and apply extend/select transformations.

Calls extend once per row (or once with context for extend_context) and
then applies per-FK extend passes for any foreign keys listed in fk_attrs.
"""
rows = []
for i, row in enumerate(qs):
rows.append(self._serialize(row, table, aliases, relationships[i], field_prefix, fk_attrs))
Expand Down Expand Up @@ -232,6 +242,7 @@ async def _queryset_serialize(
return result

def _serialize(self, obj, table, aliases, relationships, field_prefix, fk_attrs):
"""Convert one result row to a dict, stripping the field prefix and merging pre-fetched relationship data."""
data = self._serialize_row(obj, table, aliases)
data.update(relationships)

Expand All @@ -245,6 +256,7 @@ def _serialize(self, obj, table, aliases, relationships, field_prefix, fk_attrs)
return result

async def _extend(self, data, extend, extend_context, extend_context_value, select):
"""Call the extend method on a row dict and apply field selection if requested."""
if extend:
if extend_context:
data = await self.middleware.call(extend, data, extend_context_value)
Expand All @@ -257,9 +269,16 @@ async def _extend(self, data, extend, extend_context, extend_context_value, sele
return do_select([data], select)[0]

def _strip_prefix(self, k, field_prefix):
"""Remove field_prefix from k if k starts with it, otherwise return k unchanged."""
return k[len(field_prefix):] if field_prefix and k.startswith(field_prefix) else k

def _serialize_row(self, obj, table, aliases):
"""Produce a raw key→value dict from one query row.

Scalar columns are included directly. FK columns (ending in _id) are
replaced by a nested dict built by recursing into the joined alias table.
A NULL FK value or a NULL joined PK produces None for that key.
"""
data = {}

for column in table.c:
Expand All @@ -285,6 +304,12 @@ def _serialize_row(self, obj, table, aliases):
return data

async def _fetch_many_to_many(self, table, rows):
"""Fetch many-to-many related rows for every PK in rows.

For each relationship on the table, queries the junction table to collect
child IDs, then fetches the child rows in bulk and groups them back by
parent PK. Returns a list of relationship dicts aligned with rows.
"""
pk = self._get_pk(table)
pk_values = [row[pk.name] for row in rows]

Expand Down
Loading