Skip to content

Commit 69e56e0

Browse files
committed
fixed correlation association tests
1 parent 92bbfc7 commit 69e56e0

File tree

4 files changed

+102
-111
lines changed

4 files changed

+102
-111
lines changed

breadbox/breadbox/crud/associations.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
from pyasn1.codec.ber.encoder import encode
2+
13
from breadbox.models.dataset import PrecomputedAssociation, Dataset, MatrixDataset
2-
from sqlalchemy import or_
34
from breadbox.db.session import SessionWithUser
4-
import sqlite3
55
from . import dataset as dataset_crud
66
from breadbox.schemas.custom_http_exception import (
77
ResourceNotFoundError,
@@ -12,19 +12,16 @@
1212
from ..service import metadata
1313
from ..crud import access_control
1414
from typing import Optional
15+
import packed_cor_tables
1516

1617

1718
def _validate_association_table(
1819
dataset_1_given_ids: set[str], dataset_2_given_ids: set[str], filename: str
1920
):
2021
"""Opens file as a sqlite3 db and verifies the dimensions have the expected IDs. Raises an UserError if any issues found"""
2122

22-
conn = sqlite3.connect(filename)
23-
cur = conn.cursor()
24-
2523
def check_given_ids(expected_given_ids: set[str], dim: str):
26-
cur.execute(f"SELECT given_id from dim_{dim}_given_id")
27-
assoc_dataset_given_ids = set([x[0] for x in cur.fetchall()])
24+
assoc_dataset_given_ids = packed_cor_tables.get_given_ids(filename, dim)
2825
missing = expected_given_ids.difference(assoc_dataset_given_ids)
2926
if len(missing) > 0:
3027
assoc_sample = sorted(assoc_dataset_given_ids)[:10]
@@ -34,14 +31,11 @@ def check_given_ids(expected_given_ids: set[str], dim: str):
3431
f"The given IDs in the association table do not match the IDs in the dataset. ({len(missing)} IDs missing). Examples from association table: {assoc_sample}, examples from dataset: {expected_sample}, examples of missing IDs: {missing_sample}"
3532
)
3633

37-
try:
38-
check_given_ids(dataset_1_given_ids, "0")
39-
check_given_ids(dataset_2_given_ids, "1")
40-
except sqlite3.OperationalError as ex:
41-
raise UserError("Invalid association table") from ex
42-
finally:
43-
cur.close()
44-
conn.close()
34+
try:
35+
check_given_ids(dataset_1_given_ids, "0")
36+
check_given_ids(dataset_2_given_ids, "1")
37+
except packed_cor_tables.InvalidAssociationTable as ex:
38+
raise UserError("Invalid association table") from ex
4539

4640

4741
def add_association_table(

breadbox/breadbox/service/associations.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from breadbox.db.session import SessionWithUser
33
from depmap_compute.slice import SliceQuery
44
from breadbox.schemas.associations import Associations, Association, DatasetSummary
5-
import sqlite3
65
from breadbox.crud import associations as associations_crud
76
from breadbox.crud import dataset as dataset_crud
87
from breadbox.schemas.custom_http_exception import (

breadbox/tests/api/test_temp_associations.py

Lines changed: 75 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -9,58 +9,8 @@
99
import sqlite3
1010
from glob import glob
1111

12-
13-
def create_assoc_table(filename, dataset_1_features, dataset_2_features, rows):
14-
conn = sqlite3.connect(filename)
15-
create_sql = """CREATE TABLE IF NOT EXISTS "correlation" (
16-
"dim_0" INTEGER,
17-
"dim_1" INTEGER,
18-
"cor" REAL
19-
);
20-
CREATE TABLE IF NOT EXISTS "dim_0_label_position" (
21-
"label" TEXT,
22-
"position" INTEGER
23-
);
24-
CREATE TABLE IF NOT EXISTS "dim_1_label_position" (
25-
"label" TEXT,
26-
"position" INTEGER
27-
);
28-
CREATE TABLE IF NOT EXISTS "dataset" (
29-
"dataset" INTEGER,
30-
"filename" TEXT,
31-
"label" TEXT
32-
);
33-
CREATE INDEX dim_0_label_position_idx_1 ON dim_0_label_position (label);
34-
CREATE INDEX dim_0_label_position_idx_2 ON dim_0_label_position (position);
35-
CREATE INDEX dim_1_label_position_idx_1 ON dim_1_label_position (label);
36-
CREATE INDEX dim_1_label_position_idx_2 ON dim_1_label_position (position);
37-
CREATE INDEX correlation_idx_1 ON correlation (dim_1, cor);
38-
CREATE INDEX correlation_idx_0 ON correlation (dim_0, cor);"""
39-
for stmt in create_sql.split(";"):
40-
conn.execute(stmt)
41-
42-
def populate_labels(dim, dataset_features):
43-
rows = []
44-
label_to_index = {}
45-
for i, label in enumerate(dataset_features):
46-
rows.append((label, i))
47-
label_to_index[label] = i
48-
conn.executemany(
49-
f"insert into dim_{dim}_label_position (label, position) values (?, ?)",
50-
rows,
51-
)
52-
return label_to_index
53-
54-
f1idx = populate_labels(0, dataset_1_features)
55-
f2idx = populate_labels(1, dataset_2_features)
56-
57-
for f1_label, f2_label, cor in rows:
58-
conn.execute(
59-
"insert into correlation (dim_0, dim_1, cor) values (?, ?, ?)",
60-
(f1idx[f1_label], f2idx[f2_label], cor),
61-
)
62-
conn.commit()
63-
conn.close()
12+
import pytest
13+
import packed_cor_tables
6414

6515

6616
def test_associations(
@@ -109,15 +59,46 @@ def create_matrix_dataset(sample_count, feature_count):
10959
dataset_2 = create_matrix_dataset(3, dataset_2_feature_count)
11060
minimal_db.commit()
11161

112-
assoc_table = str(tmpdir.join("assoc.sqlite3"))
113-
create_assoc_table(
114-
assoc_table,
115-
[f"feature{i}" for i in range(dataset_1_feature_count)],
116-
[f"feature{i}" for i in range(dataset_2_feature_count)],
117-
[["feature0", "feature0", 0.1], ["feature0", "feature1", 0.2]],
62+
assoc_table_1 = str(tmpdir.join("assoc.sqlite3"))
63+
64+
packed_cor_tables.write_cor_df(
65+
pd.DataFrame(
66+
{
67+
"dim_0": [0, 0],
68+
"dim_1": [0, 1],
69+
"cor": [0.1, 0.2],
70+
"log10qvalue": [-12.5, -10.1],
71+
}
72+
),
73+
packed_cor_tables.InputMatrixDesc(
74+
given_ids=[f"feature{i}" for i in range(dataset_1_feature_count)],
75+
taiga_id="ds1",
76+
name="ds1",
77+
),
78+
packed_cor_tables.InputMatrixDesc(
79+
given_ids=[f"feature{i}" for i in range(dataset_2_feature_count)],
80+
taiga_id="ds1",
81+
name="ds1",
82+
),
83+
assoc_table_1,
11884
)
11985

120-
file_ids, expected_md5 = upload_and_get_file_ids(client, filename=assoc_table)
86+
# associations in the flipped direction
87+
# assoc_table_2 = str(tmpdir.join("assoc2.sqlite3"))
88+
#
89+
# packed_cor_tables.write_cor_df(
90+
# pd.DataFrame({"dim_0": [0], "dim_1": [0], "cor": [0.1], "log10qvalue": [-9.1]}),
91+
# packed_cor_tables.InputMatrixDesc(given_ids=[f"feature{i}" for i in range(dataset_2_feature_count)],
92+
# taiga_id="ds1",
93+
# name="ds1", ),
94+
# packed_cor_tables.InputMatrixDesc(given_ids=[f"feature{i}" for i in range(dataset_1_feature_count)],
95+
# taiga_id="ds1",
96+
# name="ds1",
97+
# ),
98+
# assoc_table_2,
99+
# )
100+
101+
file_ids, expected_md5 = upload_and_get_file_ids(client, filename=assoc_table_1)
121102

122103
# first upload attempt: should fail because user doesn't have access
123104
response = client.post(
@@ -174,44 +155,43 @@ def create_matrix_dataset(sample_count, feature_count):
174155
assert_status_ok(response)
175156
response_content = response.json()
176157
assert len(response_content["associated_datasets"]) == 1
177-
expected = [
178-
{
179-
"correlation": 0.2,
180-
"other_dataset_id": dataset_2.id,
181-
"other_dimension_given_id": "feature1",
182-
"other_dimension_label": "feature1",
183-
},
184-
{
185-
"correlation": 0.1,
186-
"other_dataset_id": dataset_2.id,
187-
"other_dimension_given_id": "feature0",
188-
"other_dimension_label": "feature0",
189-
},
190-
]
191-
assert response_content["associated_dimensions"] == expected
192-
193-
# query feature 0 in dataset 2 (this one only has one correlation stored)
194-
response = client.post(
195-
"/temp/associations/query-slice",
196-
json={
197-
"identifier_type": "feature_id",
198-
"dataset_id": dataset_2.id,
199-
"identifier": "feature0",
200-
},
201-
headers={"X-Forwarded-User": "anon"},
158+
fa1, fa2 = sorted(
159+
response_content["associated_dimensions"],
160+
key=lambda x: x["other_dimension_given_id"],
202161
)
203162

204-
assert_status_ok(response)
205-
response_content = response.json()
206-
assert len(response_content["associated_datasets"]) == 1
207-
assert response_content["associated_dimensions"] == [
208-
{
209-
"correlation": 0.1,
210-
"other_dataset_id": dataset_1.id,
211-
"other_dimension_given_id": "feature0",
212-
"other_dimension_label": "feature0",
213-
}
214-
]
163+
assert fa1["other_dimension_given_id"] == "feature0"
164+
assert fa1["other_dimension_label"] == "feature0"
165+
assert fa1["other_dataset_id"] == dataset_2.id
166+
assert fa1["correlation"] == pytest.approx(0.1)
167+
168+
assert fa2["other_dimension_given_id"] == "feature1"
169+
assert fa2["other_dimension_label"] == "feature1"
170+
assert fa2["other_dataset_id"] == dataset_2.id
171+
assert fa2["correlation"] == pytest.approx(0.2)
172+
173+
# # query feature 0 in dataset 2 (this one only has one correlation stored)
174+
# response = client.post(
175+
# "/temp/associations/query-slice",
176+
# json={
177+
# "identifier_type": "feature_id",
178+
# "dataset_id": dataset_2.id,
179+
# "identifier": "feature0",
180+
# },
181+
# headers={"X-Forwarded-User": "anon"},
182+
# )
183+
#
184+
# assert_status_ok(response)
185+
# response_content = response.json()
186+
# assert len(response_content["associated_datasets"]) == 1
187+
# assert response_content["associated_dimensions"] == [
188+
# {
189+
# "correlation": 0.1,
190+
# "other_dataset_id": dataset_1.id,
191+
# "other_dimension_given_id": "feature0",
192+
# "other_dimension_label": "feature0",
193+
# }
194+
# ]
215195

216196
# now, delete it
217197
response = client.delete(f"/temp/associations/{assoc_id}")

packed-cor-tables/packed_cor_tables/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,3 +170,21 @@ def read_cor_for_given_id(filename, feature_id):
170170
conn.close()
171171

172172
return df.drop(columns=["dim_0", "dim_1"])
173+
174+
175+
class InvalidAssociationTable(Exception):
    """Raised when a file is not a readable packed correlation table."""


def get_given_ids(filename: str, dim: str) -> set[str]:
    """Return the set of given IDs stored for one dimension of a packed table.

    Args:
        filename: Path to the sqlite3 file holding the packed correlation table.
        dim: Dimension index as a string, "0" or "1"; selects the
            ``dim_<dim>_given_id`` table.

    Returns:
        The set of ``given_id`` values recorded for that dimension.

    Raises:
        ValueError: If ``dim`` is not "0" or "1" (guards the f-string SQL
            below against arbitrary table-name injection).
        InvalidAssociationTable: If the expected table is missing or the file
            is not a valid sqlite3 database (wraps sqlite3.OperationalError).
    """
    if dim not in ("0", "1"):
        raise ValueError(f"dim must be '0' or '1', got {dim!r}")
    conn = sqlite3.connect(filename)
    # connect first, then enter try: ensures conn is closed even if
    # cursor creation or the query fails (the original leaked conn when
    # cursor() raised).
    try:
        cur = conn.cursor()
        try:
            cur.execute(f"SELECT given_id from dim_{dim}_given_id")
            return {row[0] for row in cur.fetchall()}
        finally:
            cur.close()
    except sqlite3.OperationalError as ex:
        raise InvalidAssociationTable() from ex
    finally:
        conn.close()

0 commit comments

Comments
 (0)