Skip to content

Functions to assist with patching one ObsDb sqlite file to match another #954

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 155 additions & 0 deletions sotodlib/core/metadata/obsdb.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import sqlite3
import os
import numpy as np

from .resultset import ResultSet
from . import common
from .. import util


TABLE_DEFS = {
Expand Down Expand Up @@ -350,3 +352,156 @@ def _short_list(items, max_len=40):
'fields': fields,
'tags': tags,
}


def diff_obsdbs(obsdb_left, obsdb_right, return_detail=False):
"""Examine all records in two obsdbs and construct a list of changes
that could made to obsdb_left in order to make it match
obsdb_right.

Returns a dict with following entries:

- ``different`` (bool): whether the two databases carry different
information.
- ``patchable`` (bool): whether the function was able to construct
patching instructions.
- ``unpatchable_reason`` (str): if not patchable, a string
explaining why.
- ``detail`` (various): if not patchable, and return_detail, then
this will contain detail about the offending data (e.g. obs rows
in the two dbs that contain discrepant data).
- ``patch_data`` (dict): if patchable, the data needed to patch
obsdb_left. The fields are:

- ``remove_obs`` (list of obs_id): entries to remove from obs
table.
- ``remove_tags`` (list of tuple): entries to remove from tags
table.
- ``new_obs`` (list of dict): rows of new data for obs table --
each dict can be passed directly to obsdb.update_obs.
- ``new_tags`` (list of tuple): rows of data for tags table
(obs_id, tag).


Notes:

In the present implementation, only changes involving adding rows
to obsdb_left (either whole obs rows or tag rows) will yield a
patchable result. Cases where some data has changed, or obs or
tags have been deleted, will simply return as unpatchable. This
is probably pretty easy to extend, should the need arise.

"""
if isinstance(obsdb_left, str):
obsdb_left = ObsDb.from_file(obsdb_left, force_new_db=False)
if isinstance(obsdb_right, str):
obsdb_right = ObsDb.from_file(obsdb_right, force_new_db=False)

def failure_declaration(reason, detail=None):
if not return_detail:
detail = None
return {'different': True,
'patchable': False,
'unpatchable_reason': reason,
'detail': detail}

full = [db.query() for db in [obsdb_left, obsdb_right]]
if full[0].keys != full[1].keys:
return failure_declaration(
'obsdb_left and obsdb_right have different column names.',
detail=[full[0].keys, full[1].keys])

# Convert to arrays.
obs_ids = [set(f['obs_id']) for f in full]

# Insist right is superset of left.
left_not_right = sorted(list(obs_ids[0].difference(obs_ids[1])))
if len(left_not_right):
return failure_declaration(
f'obsdb_left contains {len(left_not_right)} '
'obs not found in obsdb_right.',
detail=left_not_right)

# Any obs in common?
unmatched_right = np.ones(len(full[1]), bool)
common = sorted(list(obs_ids[0].intersection(obs_ids[1])))
if len(common):
common, i0, i1 = util.get_coindices(*(f['obs_id'] for f in full))
diffs = []
for i, (_i0, _i1) in enumerate(zip(i0, i1)):
if full[0][_i0] != full[1][_i1]:
diffs.append((full[0][_i0], full[1][_i1]))
if len(diffs):
return failure_declaration(
f'obsdb_left and obsdb_right have {len(diffs)} obs '
'in common, with different data.',
detail=diffs)
unmatched_right[i1] = False

# Ok finally
pd = {
'remove_obs': [],
'remove_tags': [],
'new_obs': [],
'new_tags': [],
}
for idx in unmatched_right.nonzero()[0]:
pd['new_obs'].append(full[1][idx])

# Tag check.
tags_tuples = [
list(map(tuple, db.conn.execute(
'select distinct obs_id, tag from tags '
'order by obs_id, tag').fetchall()))
for db in [obsdb_left, obsdb_right]]

# Collapse tags to single strings and eliminate duplicates.
DELIM = ':::/:::'
common, i0, i1 = util.get_coindices(*[[t[0] + DELIM + t[1] for t in tt]
for tt in tags_tuples])
if len(i0) != len(tags_tuples[0]):
return failure_declaration(
f'obsdb_left contains {len(tags_tuples[0]) - len(i0)} '
'tags not found in obsdb_right',
detail=list(set(tags_tuples[0]).difference(tags_tuples[1])))
unmatched_right = np.ones(len(tags_tuples[1]), bool)
unmatched_right[i1] = False
for idx in unmatched_right.nonzero()[0]:
pd['new_tags'].append(tags_tuples[1][idx])

different = any([(len(v) != 0) for v in pd.values()])

return {
'different': different,
'patchable': True,
'patch_data': pd,
}


def patch_obsdb(patch_data, target_db):
"""Update an ObsDb with a batch of changes.

Args:
target_db (ObsDb): the database where changes should be made.
patch_data (dict): patch information, as returned by
diff_obsdbs.

"""
assert len(patch_data['remove_obs']) == 0
assert len(patch_data['remove_tags']) == 0

for obs_entry in patch_data['new_obs']:
target_db.update_obs(obs_entry['obs_id'], obs_entry,
commit=False)

# Group new tags by obs.
tags_obsed = {}
for k, v in patch_data['new_tags']:
if k not in tags_obsed:
tags_obsed[k] = [v]
else:
tags_obsed[k].append(v)
for obs, tags in tags_obsed.items():
target_db.update_obs(obs, {}, tags=tags, commit=False)

target_db.conn.commit()
18 changes: 16 additions & 2 deletions tests/test_obsdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,21 @@
from ._helpers import mpi_multi


def get_example():
def get_example(stuff_missing=False):
# Create a new Db and add two columns.
obsdb = metadata.ObsDb()
obsdb.add_obs_columns(['timestamp float', 'hwp_speed float', 'drift string'])

# Add 10 rows.
for i in range(10):
if stuff_missing and i in [2, 4]:
continue
tags = []
if i == 6:
tags.append('cryo_problem')
if i > 7:
tags.append('planet')
else:
elif not stuff_missing:
tags.append('cmb_survey')
obsdb.update_obs(f'myobs{i}', {'timestamp': 1900000000. + i * 100,
'hwp_speed': 2.0,
Expand Down Expand Up @@ -83,6 +85,18 @@ def test_info(self):
db0 = get_example()
db0.info()

def test_diff_patch(self):
"""Use diff/patch to update one obsdb to match another."""
db0 = get_example(stuff_missing=True)
db1 = get_example()
diff = metadata.obsdb.diff_obsdbs(db0, db1)
assert diff['different']
assert diff['patchable']

metadata.obsdb.patch_obsdb(diff['patch_data'], db0)
diff = metadata.obsdb.diff_obsdbs(db0, db1)
assert not diff['different']


if __name__ == '__main__':
unittest.main()