Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 203 additions & 46 deletions sotodlib/core/metadata/obsdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,27 +8,13 @@
from .. import util


TABLE_DEFS = {
'obs': [
"`obs_id` varchar(256) primary key",
"`timestamp` float",
],
'tags': [
"`obs_id` varchar(256)",
"`tag` varchar(256)",
"CONSTRAINT one_tag UNIQUE (`obs_id`, `tag`)",
],
'_indices': {
'idx_obs': 'obs(obs_id)',
'idx_tags': 'tags(obs_id)',
},
}

DBROW_ALL = '_all'

class ObsDb(object):
"""Observation database.

The ObsDb helps to associate observations, indexed by an obs_id,
(or obs_id plus some wafer_info such as wafer_slot or bandpass)
with properties of the observation that might be useful for
selecting data or for identifying metadata.

Expand All @@ -43,11 +29,7 @@ class ObsDb(object):

"""

TABLE_TEMPLATE = [
"`obs_id` varchar(256)",
]

def __init__(self, map_file=None, init_db=True):
def __init__(self, map_file=None, init_db=True, wafer_info=None):
"""Instantiate an ObsDb.

Args:
Expand All @@ -59,6 +41,12 @@ def __init__(self, map_file=None, init_db=True):
sqlite3.Connection is opened on ':memory:'.
init_db (bool): If True, then any ObsDb tables that do not
already exist in the database will be created.
wafer_info (list of str): The additional primary keys for the obs table.
The default is None, which defaults to ['obs_id'] only.
An example of an alternative is ['wafer_slot', 'bandpass'] in which case
the ObsDb will be indexed by obs_id, wafer_slot, and bandpass.
This is only required when first initialializing a database;
otherwise the primary fields are determined from the loaded database.

Notes:
If map_file is provided, the database will be connected to
Expand All @@ -75,27 +63,107 @@ def __init__(self, map_file=None, init_db=True):

self.conn.row_factory = sqlite3.Row # access columns by name
if init_db:
pkeys = ["`obs_id`"]
if wafer_info:
pkeys.extend([f"`{k}`" for k in wafer_info])

self._table_defs = {'obs': [
"`timestamp` float",
*(f"{k} varchar(256)" for k in pkeys),
f"PRIMARY KEY ({', '.join(pkeys)})"
],
'tags': [
"`tag` varchar(256)",
*(f"{k} varchar(256)" for k in pkeys),
f"PRIMARY KEY ({', '.join(pkeys)}, `tag`)"
]}

# Define indices dynamically based on primary keys
pkeys_str = ', '.join([k.strip('`') for k in pkeys])
self._indices = {
'idx_obs': f'obs({pkeys_str})',
'idx_tags': f'tags({pkeys_str})',
}

c = self.conn.cursor()
c.execute("SELECT type, name FROM sqlite_master "
"WHERE type in ('table', 'index') and name not like 'sqlite_%';")
tables = [r[1] for r in c]
changes = False
for k, v in TABLE_DEFS.items():
if k[0] != '_' and k not in tables:
for k, v in self._table_defs.items():
if k not in tables:
q = ('create table if not exists `%s` (' % k +
','.join(v) + ')')
c.execute(q)
changes = True
for index, cols in TABLE_DEFS['_indices'].items():
for index, cols in self._indices.items():
if index not in tables:
c.execute(f'CREATE INDEX IF NOT EXISTS {index} on {cols}')
changes = True
if changes:
self.conn.commit()

self.primary_keys = self._get_primary_fields(wafer_info)

def _get_primary_fields(self, wafer_info=None):
"""Retrieve the primary keys of the specified table.
This is used whether to index by obs_id or
obs_id plus additional fields defined by wafer_info."""
query = "PRAGMA table_info('obs')"
c = self.conn.execute(query)
primary_keys = [row['name'] for row in c.fetchall() if row['pk'] > 0]
if wafer_info:
pkeys = ["obs_id"]
pkeys.extend([f"{k}" for k in wafer_info])
if sorted(pkeys) != sorted(primary_keys): # sorted allows for different order
raise ValueError(f"Primary keys do not match: {primary_keys} != {pkeys}"+
f" must use `wafer_info`=={primary_keys} or create a new dB with {pkeys}")
return primary_keys

def _convert_wafer_info(self, obs_id, wafer_info):
"""Helper function to allow flexibility in way obs_id and wafer_info are passed in."""
if isinstance(wafer_info, dict):
wafer_info = tuple([wafer_info[k] for k in self.primary_keys[1:]])
if isinstance(obs_id, tuple):
if len(obs_id) == len(self.primary_keys):
wafer_info = tuple([wi for wi in obs_id[1:]])
obs_id = obs_id[0]
else:
raise ValueError(f"obs_id tuple must be of length {len(self.primary_keys)}")
if isinstance(obs_id, dict):
if len(obs_id) == len(self.primary_keys):
wafer_info = tuple([obs_id[k] for k in self.primary_keys[1:]])
obs_id = obs_id['obs_id']
else:
raise ValueError(f"obs_id dict must be of length {len(self.primary_keys)}")
return obs_id, wafer_info

def _warn_primary_keys(self, wafer_info):
"""Warn the user if the primary keys are not specified
and we're defaulting to using _all."""
if len(self.primary_keys) == 1:
return []

if len(wafer_info) != len(self.primary_keys) - 1:
raise ValueError(f"Wafer info must be of length {len(self.primary_keys) - 1}")
if wafer_info is None:
wafer_info = [None] * (len(self.primary_keys) - 1)
wafer_info = list(wafer_info)
if (None in wafer_info):
warn_str = 'WARNING: Primary key(s)'
for i, wb in enumerate(wafer_info):
if wb is None:
wafer_info[i] = DBROW_ALL
warn_str += f' wafer_info[{i}],'
warn_str += f"""
are not specified and ObsDb is indexed by {self.primary_keys}.
These keys will be set to _all.
"""
warnings.warn(warn_str, UserWarning)
return wafer_info

def __len__(self):
return self.conn.execute('select count(obs_id) from obs').fetchone()[0]

return self.conn.execute(f'SELECT COUNT({self.primary_keys[0]}) FROM obs').fetchone()[0]
def add_obs_columns(self, column_defs, ignore_duplicates=True, commit=True):
"""Add columns to the obs table.

Expand Down Expand Up @@ -157,38 +225,72 @@ def add_obs_columns(self, column_defs, ignore_duplicates=True, commit=True):
self.conn.commit()
return self

def update_obs(self, obs_id, data={}, tags=[], commit=True):
def update_obs(self, obs_id, data={}, tags=[],
wafer_info=None, commit=True):
"""Update an entry in the obs table.

Arguments:
obs_id (str): The id of the obs to update.
wafer_info (tuple of str): The wafer_info used as primary keys in addition to obs_id.
The default will be replaced with '_all' all primary keys other than obs_id
data (dict): map from column_name to value.
tags (list of str): tags to apply to this observation (if a
tag name is prefxed with '!', then the tag name will be
un-applied, i.e. cleared from this observation.

Returns:
self.
Example of ways to pass updates to obsdb when there are multiple primary keys.

1) obs_id as str and wafer_info as tuple::

obsdb.update_obs('obs_2345_xyz_110', wafer_info=('ws0', 'f090'), ...)

2) obs_id as str and wafer_info as dict::

obsdb.update_obs('obs_2345_xyz_110', wafer_info={'wafer_slot': 'ws0', 'bandpass': 'f090'}, ...)

3) obs_id as dict and wafer_info is None::

obsdb.update_obs({'obs_id': 'obs_2345_xyz_110', 'wafer_slot': 'ws0', 'bandpass': 'f090'}, ...)

4) obs_id as tuple and wafer_info is None::

obsdb.update_obs(('obs_2345_xyz_110', 'ws0', 'f090'), ...)

"""
obs_id, wafer_info = self._convert_wafer_info(obs_id, wafer_info)

obs_key = {'obs_id': obs_id}
if (len(self.primary_keys) > 1):
wafer_info = self._warn_primary_keys(wafer_info)
for i, k in enumerate(self.primary_keys[1:]):
obs_key[k] = wafer_info[i]

c = self.conn.cursor()
c.execute('INSERT OR IGNORE INTO obs (obs_id) VALUES (?)',
(obs_id,))
columns = ', '.join(obs_key.keys())
placeholders = ', '.join(['?'] * len(obs_key))
c.execute(f'INSERT OR IGNORE INTO obs ({columns}) VALUES ({placeholders})',
tuple(obs_key.values()))

if len(data.keys()):
settors = [f'{k}=?' for k in data.keys()]
c.execute('update obs set ' + ','.join(settors) + ' '
'where obs_id=?',
tuple(data.values()) + (obs_id, ))
where_str = ' AND '.join([f'{k}=?' for k in obs_key.keys()])
c.execute(f'UPDATE obs SET {", ".join(settors)} WHERE {where_str}',
tuple(data.values()) + tuple(obs_key.values()))

for t in tags:
if t[0] == '!':
# Kill this tag.
c.execute('DELETE FROM tags WHERE obs_id=? AND tag=?',
(obs_id, t[1:]))
# Kill this tag
where_str = ' AND '.join([f'{k}=?' for k in obs_key.keys()])
c.execute(f'DELETE FROM tags WHERE {where_str} AND tag=?',
tuple(obs_key.values()) + (t[1:],))
else:
c.execute('INSERT OR REPLACE INTO tags (obs_id, tag) '
'VALUES (?,?)', (obs_id, t))
if commit:
self.conn.commit()
# Add the tag for the specific primary key combination.
columns = ', '.join(list(obs_key.keys()) + ['tag'])
placeholders = ', '.join(['?'] * (len(obs_key) + 1))
c.execute(f'INSERT OR REPLACE INTO tags ({columns}) VALUES ({placeholders})',
tuple(obs_key.values()) + (t,))
if commit:
self.conn.commit()
return self

def copy(self, map_file=None, overwrite=False):
Expand Down Expand Up @@ -231,15 +333,20 @@ def from_file(cls, filename, fmt=None, force_new_db=True):
conn = common.sqlite_from_file(filename, fmt=fmt, force_new_db=force_new_db)
return cls(conn, init_db=False)

def get(self, obs_id=None, tags=None, add_prefix=''):
def get(self, obs_id=None, wafer_info=None, tags=None, add_prefix=''):
"""Returns the entry for obs_id, as an ordered dict.

If obs_id is None, returns all entries, as a ResultSet.
However, this usage is deprecated in favor of self.query().

Args:
obs_id (str): The observation id to get info for.
tags (bool): Whether or not to load and return the tags.
wafer_info (tuple of str): The wafer_info used as primary keys in addition to obs_id.
The default will be replaced with '_all' all primary keys other than obs_id
tags (bool): If True, include the tags associated with this
observation in the output. The tags will be stored in a
field called 'tags', which will be a list of strings.
If False or None, the tags will not be included in the output.
add_prefix (str): A string that will be prepended to each
field name. This is for the lazy metadata system, because
obsdb selectors are prefixed with 'obs:'.
Expand All @@ -250,17 +357,23 @@ def get(self, obs_id=None, tags=None, add_prefix=''):
requested, they will be stored in 'tags' as a list of strings.

"""
obs_id, wafer_info = self._convert_wafer_info(obs_id, wafer_info)
if obs_id is None:
return self.query('1', add_prefix=add_prefix)
results = self.query(f"obs_id='{obs_id}'", add_prefix=add_prefix)

wafer_info = self._warn_primary_keys(wafer_info)
query_text = " AND ".join([f"{key} == '{val}'" for key, val in zip(self.primary_keys, [obs_id] + wafer_info)])

results = self.query(query_text, add_prefix=add_prefix)
if len(results) == 0:
return None
if len(results) > 1:
raise ValueError('Too many rows...') # or integrity error...
output = results[0]
if tags:
# "distinct" should not be needed given uniqueness constraint.
c = self.conn.execute('select distinct tag from tags where obs_id=?', (obs_id,))
where_str = ' AND '.join([f"{k}='{v}'" for k, v in zip(self.primary_keys, [obs_id] + list(wafer_info))])
c = self.conn.execute(f'SELECT DISTINCT tag FROM tags WHERE {where_str}')
output['tags'] = [r[0] for r in c]
return output

Expand Down Expand Up @@ -335,6 +448,50 @@ def query(self, query_text='1', tags=None, sort=['obs_id'], add_prefix=''):
if add_prefix is not None:
results.keys = [add_prefix + k for k in results.keys]
return results

def query_linked_dbs(self, secondary_dbs, query_text, add_prefix='',
wafer_info=None):
"""
Query two ObsDb objects and link their results based on obs_id. Primary ObsDb
can be either keyed by obs_id or obs_id and wafer_info (such as wafer_slot and bandpass).
For every row returned from the primary database, the linked secondary databases
are queried for rows with the same obs_id (and a specific wafer_info subset if wafer_info is passed).
The results are returned as a list of tuples the first element of the tuple is the primary
database result and the rest are the linked secondary database results.

Args:
secondary_dbs (list of ObsDb): A list of secondary database to query for linked rows.
If a single ObsDb is passed, it will be converted to a list of length 1.
query_text (str): The query text for the primary database.
add_prefix (str): A string to prepend to field names in the result.
wafer_info (tuple of str): The wafer_info to restrict what's returned from the secondary
database. The default value is None, which means all wafer_info will be returned.

Returns:
results (list of ResultSet): A list containing tuples of resultsets from the primary and secondary databases.
"""
# Ensure secondary_dbs is a list
if not isinstance(secondary_dbs, list):
secondary_dbs = [secondary_dbs]

# Query the primary database
primary_results = self.query(query_text, add_prefix=add_prefix)
if len(primary_results) == 0:
return None

results = []
for pr in primary_results:
_res = (pr, )
for secondary_db in secondary_dbs:
if wafer_info:
_wafer_info = secondary_db._warn_primary_keys(wafer_info)
query_str = ' and '.join([f"{k}=='{v}'" for k, v in zip(secondary_db.primary_keys, [pr['obs_id']] + list(_wafer_info))])
else:
query_str = f"obs_id=='{pr['obs_id']}'"
secondary_result = secondary_db.query(query_str, add_prefix=add_prefix)
_res += ([sr for sr in secondary_result],)
results.append(_res)
return results

def info(self):
"""Return a dict summarizing the structure and contents of the obsdb;
Expand Down
Loading