diff --git a/sotodlib/core/metadata/obsdb.py b/sotodlib/core/metadata/obsdb.py index 96c9d6174..c9151161e 100644 --- a/sotodlib/core/metadata/obsdb.py +++ b/sotodlib/core/metadata/obsdb.py @@ -8,27 +8,13 @@ from .. import util -TABLE_DEFS = { - 'obs': [ - "`obs_id` varchar(256) primary key", - "`timestamp` float", - ], - 'tags': [ - "`obs_id` varchar(256)", - "`tag` varchar(256)", - "CONSTRAINT one_tag UNIQUE (`obs_id`, `tag`)", - ], - '_indices': { - 'idx_obs': 'obs(obs_id)', - 'idx_tags': 'tags(obs_id)', - }, -} - +DBROW_ALL = '_all' class ObsDb(object): """Observation database. The ObsDb helps to associate observations, indexed by an obs_id, + (or obs_id plus some wafer_info such as wafer_slot or bandpass) with properties of the observation that might be useful for selecting data or for identifying metadata. @@ -43,11 +29,7 @@ class ObsDb(object): """ - TABLE_TEMPLATE = [ - "`obs_id` varchar(256)", - ] - - def __init__(self, map_file=None, init_db=True): + def __init__(self, map_file=None, init_db=True, wafer_info=None): """Instantiate an ObsDb. Args: @@ -59,6 +41,12 @@ def __init__(self, map_file=None, init_db=True): sqlite3.Connection is opened on ':memory:'. init_db (bool): If True, then any ObsDb tables that do not already exist in the database will be created. + wafer_info (list of str): The additional primary keys for the obs table. + The default is None, which defaults to ['obs_id'] only. + An example of an alternative is ['wafer_slot', 'bandpass'] in which case + the ObsDb will be indexed by obs_id, wafer_slot, and bandpass. + This is only required when first initialializing a database; + otherwise the primary fields are determined from the loaded database. Notes: If map_file is provided, the database will be connected to @@ -75,27 +63,107 @@ def __init__(self, map_file=None, init_db=True): self.conn.row_factory = sqlite3.Row # access columns by name if init_db: + pkeys = ["`obs_id`"] + if wafer_info: + pkeys.extend([f"`{k}`" for k in wafer_info]) + + self._table_defs = {'obs': [ + "`timestamp` float", + *(f"{k} varchar(256)" for k in pkeys), + f"PRIMARY KEY ({', '.join(pkeys)})" + ], + 'tags': [ + "`tag` varchar(256)", + *(f"{k} varchar(256)" for k in pkeys), + f"PRIMARY KEY ({', '.join(pkeys)}, `tag`)" + ]} + + # Define indices dynamically based on primary keys + pkeys_str = ', '.join([k.strip('`') for k in pkeys]) + self._indices = { + 'idx_obs': f'obs({pkeys_str})', + 'idx_tags': f'tags({pkeys_str})', + } + c = self.conn.cursor() c.execute("SELECT type, name FROM sqlite_master " "WHERE type in ('table', 'index') and name not like 'sqlite_%';") tables = [r[1] for r in c] changes = False - for k, v in TABLE_DEFS.items(): - if k[0] != '_' and k not in tables: + for k, v in self._table_defs.items(): + if k not in tables: q = ('create table if not exists `%s` (' % k + ','.join(v) + ')') c.execute(q) changes = True - for index, cols in TABLE_DEFS['_indices'].items(): + for index, cols in self._indices.items(): if index not in tables: c.execute(f'CREATE INDEX IF NOT EXISTS {index} on {cols}') changes = True if changes: self.conn.commit() - + self.primary_keys = self._get_primary_fields(wafer_info) + + def _get_primary_fields(self, wafer_info=None): + """Retrieve the primary keys of the specified table. + This is used whether to index by obs_id or + obs_id plus additional fields defined by wafer_info.""" + query = "PRAGMA table_info('obs')" + c = self.conn.execute(query) + primary_keys = [row['name'] for row in c.fetchall() if row['pk'] > 0] + if wafer_info: + pkeys = ["obs_id"] + pkeys.extend([f"{k}" for k in wafer_info]) + if sorted(pkeys) != sorted(primary_keys): # sorted allows for different order + raise ValueError(f"Primary keys do not match: {primary_keys} != {pkeys}"+ + f" must use `wafer_info`=={primary_keys} or create a new dB with {pkeys}") + return primary_keys + + def _convert_wafer_info(self, obs_id, wafer_info): + """Helper function to allow flexibility in way obs_id and wafer_info are passed in.""" + if isinstance(wafer_info, dict): + wafer_info = tuple([wafer_info[k] for k in self.primary_keys[1:]]) + if isinstance(obs_id, tuple): + if len(obs_id) == len(self.primary_keys): + wafer_info = tuple([wi for wi in obs_id[1:]]) + obs_id = obs_id[0] + else: + raise ValueError(f"obs_id tuple must be of length {len(self.primary_keys)}") + if isinstance(obs_id, dict): + if len(obs_id) == len(self.primary_keys): + wafer_info = tuple([obs_id[k] for k in self.primary_keys[1:]]) + obs_id = obs_id['obs_id'] + else: + raise ValueError(f"obs_id dict must be of length {len(self.primary_keys)}") + return obs_id, wafer_info + + def _warn_primary_keys(self, wafer_info): + """Warn the user if the primary keys are not specified + and we're defaulting to using _all.""" + if len(self.primary_keys) == 1: + return [] + + if len(wafer_info) != len(self.primary_keys) - 1: + raise ValueError(f"Wafer info must be of length {len(self.primary_keys) - 1}") + if wafer_info is None: + wafer_info = [None] * (len(self.primary_keys) - 1) + wafer_info = list(wafer_info) + if (None in wafer_info): + warn_str = 'WARNING: Primary key(s)' + for i, wb in enumerate(wafer_info): + if wb is None: + wafer_info[i] = DBROW_ALL + warn_str += f' wafer_info[{i}],' + warn_str += f""" + are not specified and ObsDb is indexed by {self.primary_keys}. + These keys will be set to _all. + """ + warnings.warn(warn_str, UserWarning) + return wafer_info + def __len__(self): - return self.conn.execute('select count(obs_id) from obs').fetchone()[0] - + return self.conn.execute(f'SELECT COUNT({self.primary_keys[0]}) FROM obs').fetchone()[0] + def add_obs_columns(self, column_defs, ignore_duplicates=True, commit=True): """Add columns to the obs table. @@ -157,38 +225,72 @@ def add_obs_columns(self, column_defs, ignore_duplicates=True, commit=True): self.conn.commit() return self - def update_obs(self, obs_id, data={}, tags=[], commit=True): + def update_obs(self, obs_id, data={}, tags=[], + wafer_info=None, commit=True): """Update an entry in the obs table. Arguments: obs_id (str): The id of the obs to update. + wafer_info (tuple of str): The wafer_info used as primary keys in addition to obs_id. + The default will be replaced with '_all' all primary keys other than obs_id data (dict): map from column_name to value. tags (list of str): tags to apply to this observation (if a tag name is prefxed with '!', then the tag name will be un-applied, i.e. cleared from this observation. - Returns: - self. + Example of ways to pass updates to obsdb when there are multiple primary keys. + + 1) obs_id as str and wafer_info as tuple:: + + obsdb.update_obs('obs_2345_xyz_110', wafer_info=('ws0', 'f090'), ...) + 2) obs_id as str and wafer_info as dict:: + + obsdb.update_obs('obs_2345_xyz_110', wafer_info={'wafer_slot': 'ws0', 'bandpass': 'f090'}, ...) + + 3) obs_id as dict and wafer_info is None:: + + obsdb.update_obs({'obs_id': 'obs_2345_xyz_110', 'wafer_slot': 'ws0', 'bandpass': 'f090'}, ...) + + 4) obs_id as tuple and wafer_info is None:: + + obsdb.update_obs(('obs_2345_xyz_110', 'ws0', 'f090'), ...) + """ + obs_id, wafer_info = self._convert_wafer_info(obs_id, wafer_info) + + obs_key = {'obs_id': obs_id} + if (len(self.primary_keys) > 1): + wafer_info = self._warn_primary_keys(wafer_info) + for i, k in enumerate(self.primary_keys[1:]): + obs_key[k] = wafer_info[i] + c = self.conn.cursor() - c.execute('INSERT OR IGNORE INTO obs (obs_id) VALUES (?)', - (obs_id,)) + columns = ', '.join(obs_key.keys()) + placeholders = ', '.join(['?'] * len(obs_key)) + c.execute(f'INSERT OR IGNORE INTO obs ({columns}) VALUES ({placeholders})', + tuple(obs_key.values())) + if len(data.keys()): settors = [f'{k}=?' for k in data.keys()] - c.execute('update obs set ' + ','.join(settors) + ' ' - 'where obs_id=?', - tuple(data.values()) + (obs_id, )) + where_str = ' AND '.join([f'{k}=?' for k in obs_key.keys()]) + c.execute(f'UPDATE obs SET {", ".join(settors)} WHERE {where_str}', + tuple(data.values()) + tuple(obs_key.values())) + for t in tags: if t[0] == '!': - # Kill this tag. - c.execute('DELETE FROM tags WHERE obs_id=? AND tag=?', - (obs_id, t[1:])) + # Kill this tag + where_str = ' AND '.join([f'{k}=?' for k in obs_key.keys()]) + c.execute(f'DELETE FROM tags WHERE {where_str} AND tag=?', + tuple(obs_key.values()) + (t[1:],)) else: - c.execute('INSERT OR REPLACE INTO tags (obs_id, tag) ' - 'VALUES (?,?)', (obs_id, t)) - if commit: - self.conn.commit() + # Add the tag for the specific primary key combination. + columns = ', '.join(list(obs_key.keys()) + ['tag']) + placeholders = ', '.join(['?'] * (len(obs_key) + 1)) + c.execute(f'INSERT OR REPLACE INTO tags ({columns}) VALUES ({placeholders})', + tuple(obs_key.values()) + (t,)) + if commit: + self.conn.commit() return self def copy(self, map_file=None, overwrite=False): @@ -231,7 +333,7 @@ def from_file(cls, filename, fmt=None, force_new_db=True): conn = common.sqlite_from_file(filename, fmt=fmt, force_new_db=force_new_db) return cls(conn, init_db=False) - def get(self, obs_id=None, tags=None, add_prefix=''): + def get(self, obs_id=None, wafer_info=None, tags=None, add_prefix=''): """Returns the entry for obs_id, as an ordered dict. If obs_id is None, returns all entries, as a ResultSet. @@ -239,7 +341,12 @@ def get(self, obs_id=None, tags=None, add_prefix=''): Args: obs_id (str): The observation id to get info for. - tags (bool): Whether or not to load and return the tags. + wafer_info (tuple of str): The wafer_info used as primary keys in addition to obs_id. + The default will be replaced with '_all' all primary keys other than obs_id + tags (bool): If True, include the tags associated with this + observation in the output. The tags will be stored in a + field called 'tags', which will be a list of strings. + If False or None, the tags will not be included in the output. add_prefix (str): A string that will be prepended to each field name. This is for the lazy metadata system, because obsdb selectors are prefixed with 'obs:'. @@ -250,9 +357,14 @@ def get(self, obs_id=None, tags=None, add_prefix=''): requested, they will be stored in 'tags' as a list of strings. """ + obs_id, wafer_info = self._convert_wafer_info(obs_id, wafer_info) if obs_id is None: return self.query('1', add_prefix=add_prefix) - results = self.query(f"obs_id='{obs_id}'", add_prefix=add_prefix) + + wafer_info = self._warn_primary_keys(wafer_info) + query_text = " AND ".join([f"{key} == '{val}'" for key, val in zip(self.primary_keys, [obs_id] + wafer_info)]) + + results = self.query(query_text, add_prefix=add_prefix) if len(results) == 0: return None if len(results) > 1: @@ -260,7 +372,8 @@ def get(self, obs_id=None, tags=None, add_prefix=''): output = results[0] if tags: # "distinct" should not be needed given uniqueness constraint. - c = self.conn.execute('select distinct tag from tags where obs_id=?', (obs_id,)) + where_str = ' AND '.join([f"{k}='{v}'" for k, v in zip(self.primary_keys, [obs_id] + list(wafer_info))]) + c = self.conn.execute(f'SELECT DISTINCT tag FROM tags WHERE {where_str}') output['tags'] = [r[0] for r in c] return output @@ -335,6 +448,50 @@ def query(self, query_text='1', tags=None, sort=['obs_id'], add_prefix=''): if add_prefix is not None: results.keys = [add_prefix + k for k in results.keys] return results + + def query_linked_dbs(self, secondary_dbs, query_text, add_prefix='', + wafer_info=None): + """ + Query two ObsDb objects and link their results based on obs_id. Primary ObsDb + can be either keyed by obs_id or obs_id and wafer_info (such as wafer_slot and bandpass). + For every row returned from the primary database, the linked secondary databases + are queried for rows with the same obs_id (and a specific wafer_info subset if wafer_info is passed). + The results are returned as a list of tuples the first element of the tuple is the primary + database result and the rest are the linked secondary database results. + + Args: + secondary_dbs (list of ObsDb): A list of secondary database to query for linked rows. + If a single ObsDb is passed, it will be converted to a list of length 1. + query_text (str): The query text for the primary database. + add_prefix (str): A string to prepend to field names in the result. + wafer_info (tuple of str): The wafer_info to restrict what's returned from the secondary + database. The default value is None, which means all wafer_info will be returned. + + Returns: + results (list of ResultSet): A list containing tuples of resultsets from the primary and secondary databases. + """ + # Ensure secondary_dbs is a list + if not isinstance(secondary_dbs, list): + secondary_dbs = [secondary_dbs] + + # Query the primary database + primary_results = self.query(query_text, add_prefix=add_prefix) + if len(primary_results) == 0: + return None + + results = [] + for pr in primary_results: + _res = (pr, ) + for secondary_db in secondary_dbs: + if wafer_info: + _wafer_info = secondary_db._warn_primary_keys(wafer_info) + query_str = ' and '.join([f"{k}=='{v}'" for k, v in zip(secondary_db.primary_keys, [pr['obs_id']] + list(_wafer_info))]) + else: + query_str = f"obs_id=='{pr['obs_id']}'" + secondary_result = secondary_db.query(query_str, add_prefix=add_prefix) + _res += ([sr for sr in secondary_result],) + results.append(_res) + return results def info(self): """Return a dict summarizing the structure and contents of the obsdb; diff --git a/tests/test_obsdb.py b/tests/test_obsdb.py index 5f7dbf0a2..3ef82d3ba 100644 --- a/tests/test_obsdb.py +++ b/tests/test_obsdb.py @@ -29,6 +29,28 @@ def get_example(stuff_missing=False): tags=tags) return obsdb +def get_owb_example(): + # Create a new Db keyed by obsid, wafer_slot, bandpass, and add two columns. + obsdb = metadata.ObsDb(wafer_info=('wafer_slot', 'bandpass')) + obsdb.add_obs_columns(['timestamp float', 'data1 int', 'data2 int']) + + wafer_slots = ['ws0', 'ws1', 'ws2', 'ws3', 'ws4', 'ws5', 'ws6'] + band_passes = ['f090', 'f150'] + data1 = 0 + data2 = 10000 + for i in range(3): + if i == 2: + tags = ['max_was_here'] + else: + tags = [] + for w in range(7): + for b in range(2): + obsdb.update_obs(f'myobs{i}', wafer_info=(wafer_slots[w], band_passes[b]), + data = {'timestamp': 1900000000. + i * 100, + 'data1':data1, 'data2':data2}, tags=tags) + data1+=1 + data2+=1 + return obsdb @unittest.skipIf(mpi_multi(), "Running with multiple MPI processes") class TestObsDb(unittest.TestCase): @@ -44,6 +66,15 @@ def test_smoke(self): db.query('timestamp > 0') db.query('timestamp > 0', tags=['cryo_problem=1']) + def test_smoke_owb(self): + """Basic functionality.""" + db = get_owb_example() + all_obs = db.query() + q1 = db.get(all_obs[0]['obs_id'], wafer_info=('ws3', 'f090')) + self.assertEqual(q1['data1'], 6) + rs = db.query('timestamp > 0', tags=['max_was_here=1']) + self.assertEqual(len(rs), 14) + def test_query(self): db = get_example() r0 = db.query("drift == 'rising'") @@ -55,6 +86,16 @@ def test_query(self): with self.assertWarns(UserWarning): r1 = db.query('drift == "setting"') + def test_owb_query(self): + """Test querying with wafer_info keys.""" + db = get_owb_example() + # Query for a specific wafer_slot and bandpass + results = db.query("timestamp > 0 and wafer_slot =='ws3' and bandpass == 'f090'") + self.assertGreater(len(results), 0) + for row in results: + self.assertEqual(row['wafer_slot'], 'ws3') + self.assertEqual(row['bandpass'], 'f090') + def test_tags(self): db = get_example() r0 = db.query(tags=['planet=1', 'cryo_problem', 'not_a_tag']) @@ -65,6 +106,78 @@ def test_tags(self): self.assertTrue(k in r0.keys) self.assertTrue(k in r1.keys) + def test_owb_tags(self): + """Test tags with wafer_info keys.""" + db = get_owb_example() + db.update_obs('myobs0', wafer_info=('ws3', 'f090'), tags=['test_tag']) + result = db.get('myobs0', wafer_info=('ws3', 'f090'), tags=True) + self.assertIn('test_tag', result['tags']) + result_other = db.get('myobs0', wafer_info=('ws4', 'f090'), tags=True) + self.assertNotIn('test_tag', result_other['tags']) + + def test_owb_tag_deletion(self): + """Test deleting tags for specific wafer_info keys.""" + db = get_owb_example() + db.update_obs('myobs0', wafer_info=('ws3', 'f090'), tags=['delete_me']) + db.update_obs('myobs0', wafer_info=('ws3', 'f090'), tags=['!delete_me']) + result = db.get('myobs0', wafer_info=('ws3', 'f090'), tags=True) + self.assertNotIn('delete_me', result.get('tags', [])) + db.update_obs('myobs0', wafer_info=('ws4', 'f090'), tags=['delete_me']) + result_other = db.get('myobs0', wafer_info=('ws4', 'f090'), tags=True) + self.assertIn('delete_me', result_other['tags']) + + def test_linked_query(self): + obs_db = get_example() + owb_db = get_owb_example() + query_res = obs_db.query_linked_dbs(owb_db, 'obs_id == "myobs2"') + self.assertIsInstance(query_res, list) + self.assertEqual(len(query_res), 1) + self.assertIsInstance(query_res[0], tuple) + self.assertEqual(len(query_res[0]), 2) + self.assertEqual(query_res[0][0]['drift'], "setting") + self.assertEqual(len(query_res[0][1]), 14) + + def test_owb_primary_linked_query(self): + """Test linked queries with wafer_info keys.""" + obs_db = get_example() + owb_db = get_owb_example() + query_res = obs_db.query_linked_dbs(owb_db, 'obs_id == "myobs0"', wafer_info=("ws3", "f090")) + self.assertIsInstance(query_res, list) + self.assertGreater(len(query_res), 0) + for primary, linked in query_res: + self.assertEqual(primary['obs_id'], 'myobs0') + for row in linked: + self.assertEqual(row['wafer_slot'], 'ws3') + self.assertEqual(row['bandpass'], 'f090') + + def test_owb_update(self): + """ + Test updating data for specific wafer_info keys. + Tests all 4 possible ways to update data. + """ + db = get_owb_example() + # Method 1 + db.update_obs('myobs0', wafer_info=('ws3', 'f090'), data={'data1': 999}) + result = db.get('myobs0', wafer_info=('ws3', 'f090')) + self.assertEqual(result['data1'], 999) + result_other = db.get('myobs0', wafer_info=('ws4', 'f090')) + self.assertNotEqual(result_other['data1'], 999) + # Method 2 + db.update_obs('myobs0', wafer_info={'wafer_slot':'ws3', 'bandpass':'f150'}, + data={'data1': 998}) + result_m2 = db.get('myobs0', wafer_info={'wafer_slot':'ws3', 'bandpass':'f150'}) + self.assertEqual(result_m2['data1'], 998) + # Method 3 + db.update_obs({'obs_id':'myobs0', 'wafer_slot':'ws5', 'bandpass':'f090'}, + data={'data1': 997}) + result_m3 = db.get({'wafer_slot':'ws5', 'obs_id':'myobs0', 'bandpass':'f090'}) + self.assertEqual(result_m3['data1'], 997) + # Method 4 + db.update_obs(('myobs0', 'ws5', 'f150'), + data={'data1': 996}) + result_m4 = db.get(('myobs0', 'ws5', 'f150')) + self.assertEqual(result_m4['data1'], 996) + def test_io(self): """Check to_file and from_file.""" db0 = get_example()