Skip to content

Commit b421c0e

Browse files
committed
Merge remote-tracking branch 'origin/main' into release
2 parents 4898baa + e5dcfad commit b421c0e

4 files changed

Lines changed: 57 additions & 1 deletion

File tree

aepsych/database/tables.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,6 +591,20 @@ def update(db: Any, engine: Engine) -> None:
591591
outcome_name="outcome_" + str(j),
592592
outcome_value=float(outcome_value),
593593
)
594+
else: # Raws are already in, so we just need to update it
595+
for master_table in db.get_master_records():
596+
unique_id = master_table.unique_id
597+
raws = db.get_raw_for(unique_id)
598+
tells = [
599+
message
600+
for message in db.get_replay_for(unique_id)
601+
if message.message_type == "tell"
602+
]
603+
604+
if len(raws) == len(tells):
605+
for raw, tell in zip(raws, tells):
606+
if tell.extra_info is not None and len(tell.extra_info) > 0:
607+
raw.extra_data = tell.extra_info
594608
else:
595609
db.record_outcome(
596610
raw_table=db_raw_record,

aepsych/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@
55
# This source code is licensed under the license found in the
66
# LICENSE file in the root directory of this source tree.
77

8-
__version__ = "0.7.0"
8+
__version__ = "0.7.0+dev"

tests/test_databases/extra_info.db

172 KB
Binary file not shown.

tests/test_db.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,48 @@ def test_update_db_with_raw_data_tables(self):
252252

253253
test_database.delete_db()
254254

255+
def test_update_db_with_raw_extra_data(self):
256+
current_path = Path(os.path.abspath(__file__)).parent
257+
db_path = current_path
258+
db_path = db_path.joinpath("test_databases/extra_info.db")
259+
260+
# copy the db to a new file
261+
dst_db_path = Path(self._dbname)
262+
shutil.copy(str(db_path), str(dst_db_path))
263+
self.assertTrue(dst_db_path.is_file())
264+
265+
# sleep to ensure db is ready
266+
time.sleep(0.1)
267+
268+
# open the new db
269+
test_database = db.Database(db_path=dst_db_path.as_posix(), update=False)
270+
271+
replay_tells = [
272+
row for row in test_database.get_replay_for(1) if row.message_type == "tell"
273+
]
274+
275+
# Make sure that update is required
276+
self.assertTrue(test_database.is_update_required())
277+
278+
# Update the database
279+
test_database.perform_updates()
280+
281+
# The trial numbers line up with tells
282+
none_rows = 0
283+
for row in test_database.get_raw_for(1):
284+
if row.extra_data is None:
285+
none_rows += 1
286+
else:
287+
self.assertTrue(row.unique_id == row.extra_data["trial_number"])
288+
self.assertTrue(row.extra_data["extra"] == "info")
289+
290+
# Exactly one row should be none
291+
self.assertTrue(none_rows == 1)
292+
293+
self.assertFalse(test_database.is_update_required())
294+
295+
test_database.delete_db()
296+
255297
def test_update_configs(self):
256298
config_str = """
257299
[common]

0 commit comments

Comments
 (0)