Commit 09e7e44

Merge pull request #13 from sarda-devesh/main

Feedback puller scripts

2 parents 29d1551 + df88b0c

File tree

5 files changed: +245 -8 lines

.gitignore

Lines changed: 2 additions & 1 deletion

```diff
@@ -160,4 +160,5 @@ cython_debug/
 #.idea/

 macrostrat_db_insertion/actual_macrostrat.json
-macrostrat_db_insertion/temp_data
+macrostrat_db_insertion/temp_data
+retraining_runner/feedback_training_dataset
```

README.md

Lines changed: 15 additions & 0 deletions

````diff
@@ -76,3 +76,18 @@ with the database.
 ## Frontend React component

 The frontend React component can be found in this repo: [UW-Macrostrat/web-components](https://github.com/UW-Macrostrat/web-components/tree/main/packages/feedback-components). A current version of this feedback component can be found at [http://cosmos0003.chtc.wisc.edu:3000/?path=/docs/feedback-components-feedbackcomponent--docs](http://cosmos0003.chtc.wisc.edu:3000/?path=/docs/feedback-components-feedbackcomponent--docs)
+
+## Feedback puller
+
+Our training scripts use a different format to represent the relationships than the schema defined in `macrostrat_db_insertion/macrostrat_xdd_schema.sql`. We therefore wrote a script (`retraining_runner/feedback_puller.py`) that reads the feedback from the database and converts it to the format required by the training scripts. The following arguments must be passed to the script:
+```
+usage: feedback_puller.py [-h] --uri URI --schema SCHEMA --save_dir SAVE_DIR
+
+options:
+  -h, --help           show this help message and exit
+  --uri URI            The URI to use to connect to the database
+  --schema SCHEMA      The schema to connect to
+  --save_dir SAVE_DIR  The directory to save the results to
+```
+
+Check out this README for how to train the `unsupervised_kg` model on this feedback dataset: [https://github.com/UW-Macrostrat/unsupervised-kg?tab=readme-ov-file#spanbert-training](https://github.com/UW-Macrostrat/unsupervised-kg?tab=readme-ov-file#spanbert-training)
````
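For reference, a hypothetical invocation might look as follows; the connection URI and schema name are placeholders, and `feedback_training_dataset` simply mirrors the directory added to `.gitignore` in this commit:

```
python retraining_runner/feedback_puller.py \
    --uri postgresql://user:password@localhost:5432/macrostrat \
    --schema macrostrat_kg \
    --save_dir feedback_training_dataset
```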
Lines changed: 127 additions & 0 deletions

```json
{
    "nodes": [
        {
            "id": 392659,
            "type": 1,
            "name": "Kubler",
            "txt_range": [[204, 210]],
            "reasoning": null,
            "match": null
        },
        {
            "id": 392660,
            "type": 2,
            "name": "The top of the sedimentary infill",
            "txt_range": [[0, 33]],
            "reasoning": null,
            "match": null
        },
        {
            "id": 392662,
            "type": 3,
            "name": "sedimentary",
            "txt_range": [[15, 26]],
            "reasoning": null,
            "match": null
        },
        {
            "id": 392664,
            "type": 3,
            "name": "quartz",
            "txt_range": [[78, 84]],
            "reasoning": null,
            "match": { "type": "lith_att", "id": 94 }
        },
        {
            "id": 392668,
            "type": 3,
            "name": "fine",
            "txt_range": [[508, 512]],
            "reasoning": null,
            "match": { "type": "lith_att", "id": 45 }
        },
        {
            "id": -1,
            "type": 2,
            "name": "horizon",
            "txt_range": [[294, 301]],
            "reasoning": null,
            "match": null
        },
        {
            "id": -2,
            "type": 3,
            "name": "acidic",
            "txt_range": [[433, 439]],
            "reasoning": null,
            "match": null
        }
    ],
    "edges": [
        { "source": 392659, "dest": 392660 },
        { "source": 392660, "dest": 392662 },
        { "source": 392660, "dest": 392664 },
        { "source": 392660, "dest": 392668 },
        { "source": -1, "dest": -2 }
    ],
    "sourceTextId": 22950,
    "supersedesRunIds": [26730]
}
```
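This new file (its name is missing from the diff view above) is an example of the feedback format: a `nodes` list of entities with character-offset `txt_range` spans and optional `match` links (the `lith_att` ids look like Macrostrat lithology-attribute references, though the commit does not say so), plus an `edges` list wiring node ids together. A minimal sketch of walking such a graph; the file name `example_feedback.json` is hypothetical, and treating negative ids as user-added entities is an assumption:

```python
import json

# Load the example feedback graph (file name is hypothetical).
with open("example_feedback.json") as f:
    graph = json.load(f)

# Index nodes by id. Negative ids (-1, -2) sit alongside large positive
# database ids; we assume they mark entities newly added via feedback.
nodes_by_id = {node["id"]: node for node in graph["nodes"]}

# Print every edge as "source -> dest" using the node names.
for edge in graph["edges"]:
    src = nodes_by_id[edge["source"]]
    dst = nodes_by_id[edge["dest"]]
    print(f"{src['name']} -> {dst['name']}")
```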

macrostrat_db_insertion/server.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -240,7 +240,7 @@ def get_weaviate_text_id(source_text, request_additional_data, session: Session):
         sources_values["source_text_type"] = curr_text_type

         sources_insert_statement = INSERT_STATEMENT(sources_table).values(**sources_values)
-        sources_insert_statement = sources_insert_statement.on_conflict_do_nothing(index_elements = ["source_text_type", "paragraph_text"])
+        sources_insert_statement = sources_insert_statement.on_conflict_do_nothing(index_elements = ["source_text_type", "hashed_text"])
         session.execute(sources_insert_statement)
         session.commit()
     except:
@@ -250,7 +250,7 @@ def get_weaviate_text_id(source_text, request_additional_data, session: Session):
     try:
         source_id_select_statement = SELECT_STATEMENT(sources_table.c.id)
         source_id_select_statement = source_id_select_statement.where(sources_table.c.source_text_type == curr_text_type)
-        source_id_select_statement = source_id_select_statement.where(sources_table.c.paragraph_text == source_text["paragraph_text"])
+        source_id_select_statement = source_id_select_statement.where(sources_table.c.hashed_text == paragraph_hash)
         source_id_result = session.execute(source_id_select_statement).all()

         # Ensure we got a result
```
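Both hunks switch the conflict target and the lookup predicate from the raw `paragraph_text` to a stored hash (`hashed_text`, compared against a precomputed `paragraph_hash`), so duplicate detection matches on a short, indexable key instead of full paragraphs. The diff does not show how the hash is computed; the helper below is only a plausible sketch, with SHA-256 over the stripped text as an assumption rather than the repo's actual function:

```python
import hashlib

def hash_paragraph(paragraph_text: str) -> str:
    # Hypothetical helper: derive a stable, short key for a paragraph so
    # inserts and lookups can match on the hash rather than the full text.
    # The hash function actually used by the repo is not shown in this diff.
    return hashlib.sha256(paragraph_text.strip().encode("utf-8")).hexdigest()
```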

retraining_runner/feedback_puller.py

Lines changed: 99 additions & 5 deletions

```diff
@@ -3,11 +3,15 @@
 from sqlalchemy.orm import sessionmaker, declarative_base
 from sqlalchemy import select as SELECT_STATEMENT
 import argparse
+import os
+import pandas as pd
+import numpy as np

 def read_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--uri", type=str, required=True, help="The URI to use to connect to the database")
     parser.add_argument("--schema", type=str, required=True, help="The schema to connect to")
+    parser.add_argument("--save_dir", type=str, required=True, help="The directory to save the results to")
     return parser.parse_args()

 def load_sqlalchemy(args):
@@ -44,6 +48,34 @@ def get_all_user_runs(connection_details):

     return all_user_runs

+def get_entity_name(connection_details, entity_id):
+    # Load the entity table
+    entities_table_name = get_complete_table_name(connection_details, "entity")
+    entities_table = connection_details["metadata"].tables[entities_table_name]
+    entities_select_statement = SELECT_STATEMENT(entities_table)
+    entities_select_statement = entities_select_statement.where(entities_table.c.id == entity_id)
+
+    # Run the query
+    entities_select_result = connection_details["session"].execute(entities_select_statement).all()
+    if len(entities_select_result) == 0:
+        raise Exception("Can't find entity with id " + str(entity_id))
+
+    return entities_select_result[0]._mapping["name"].strip()
+
+def get_relationship_type(connection_details, relationship_type_id):
+    # Load the relationship type table
+    relationship_table_name = get_complete_table_name(connection_details, "relationship_type")
+    relationship_table = connection_details["metadata"].tables[relationship_table_name]
+    relationship_select_statement = SELECT_STATEMENT(relationship_table)
+    relationship_select_statement = relationship_select_statement.where(relationship_table.c.id == relationship_type_id)
+
+    # Run the query
+    relationship_select_result = connection_details["session"].execute(relationship_select_statement).all()
+    if len(relationship_select_result) == 0:
+        raise Exception("Can't find relationship type with id " + str(relationship_type_id))
+
+    return relationship_select_result[0]._mapping["name"].strip()
+
 def get_user_run_relationships(connection_details, save_dir, run_id, source_text_id):
     # Load the source text
     texts_table_name = get_complete_table_name(connection_details, "source_text")
@@ -55,19 +87,77 @@ def get_user_run_relationships(connection_details, save_dir, run_id, source_text_id):
     if len(text_select_result) == 0:
         raise Exception("Can't find text for source id " + str(source_text_id))

+    # Get the paragraph text details
     source_text = text_select_result[0]._mapping["paragraph_text"]
-    print(source_text_id, source_text)
+    source_text_hash = text_select_result[0]._mapping["hashed_text"]

     # Extract the relationship
     relationship_table_name = get_complete_table_name(connection_details, "relationship")
     relationship_table = connection_details["metadata"].tables[relationship_table_name]
     relationship_select_statement = SELECT_STATEMENT(relationship_table)
     relationship_select_statement = relationship_select_statement.where(relationship_table.c.run_id == run_id)

+    all_results = []
     all_relationships = connection_details["session"].execute(relationship_select_statement).all()
     for curr_relationship in all_relationships:
-        print(curr_relationship._mapping)
-        break
+        # Extract the fields
+        src_entity_id = curr_relationship._mapping["src_entity_id"]
+        dst_entity_id = curr_relationship._mapping["dst_entity_id"]
+        relationship_type_id = curr_relationship._mapping["relationship_type_id"]
+
+        # Get the values from the ids
+        src_text = get_entity_name(connection_details, src_entity_id)
+        dst_text = get_entity_name(connection_details, dst_entity_id)
+        relationship_type = get_relationship_type(connection_details, relationship_type_id)
+
+        # Record this relationship
+        all_results.append({
+            "doc_id": source_text_id,
+            "title": source_text_hash,
+            "text": source_text,
+            "src": src_text,
+            "dst": dst_text,
+            "type": relationship_type
+        })
+
+    return pd.DataFrame(all_results)
+
+DATASET_SPLIT = [0.8, 0.1, 0.1]
+def save_results(combined_df, save_dir):
+    # Create the output directory if it doesn't exist
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Calculate the split sizes; the remainder goes to the validation split
+    total_rows = len(combined_df)
+    train_size = int(DATASET_SPLIT[0] * total_rows)
+    test_size = int(DATASET_SPLIT[1] * total_rows)
+    valid_size = total_rows - train_size - test_size
+
+    # Split the dataframe
+    train_df = combined_df[:train_size]
+    test_df = combined_df[train_size:train_size + test_size]
+    valid_df = combined_df[total_rows - valid_size:]
+
+    # Save a dataframe as tab-separated CSV chunks of roughly 1000 rows each
+    def save_to_csv(data, prefix):
+        file_names = []
+        for i, chunk in enumerate(np.array_split(data, max(1, len(data) // 1000))):
+            file_name = f"{prefix}_{i}.csv"
+            chunk.to_csv(os.path.join(save_dir, file_name), index=False, sep='\t')
+            file_names.append(file_name)
+        return file_names
+
+    # Save each split to CSV files
+    train_files = save_to_csv(train_df, 'train')
+    test_files = save_to_csv(test_df, 'test')
+    valid_files = save_to_csv(valid_df, 'valid')
+
+    # Create text files listing the CSV files for each split
+    for split_name, file_list in [('train', train_files), ('test', test_files), ('valid', valid_files)]:
+        with open(os.path.join(save_dir, f"{split_name}.txt"), 'w') as f:
+            f.write('\n'.join(file_list))
+
+    print(f"Files saved in {save_dir} directory.")

 def main():
     # Load the schema
@@ -77,10 +167,14 @@ def main():
     # Get all of the user runs
     save_dir = "extracted_feedback"
     all_user_runs = get_all_user_runs(connection_details)
+    dfs_to_combine = []
     for run_id, source_text_id in all_user_runs:
-        get_user_run_relationships(connection_details, save_dir, run_id, source_text_id)
-        break
+        feedback_df = get_user_run_relationships(connection_details, save_dir, run_id, source_text_id)
+        dfs_to_combine.append(feedback_df)
+    combined_df = pd.concat(dfs_to_combine)

+    # Save the result in the proper format
+    save_results(combined_df, args.save_dir)
     connection_details["session"].close()

 if __name__ == "__main__":
```
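The splits follow `DATASET_SPLIT = [0.8, 0.1, 0.1]`: for example, with 95 combined rows the script writes 76 train rows, 9 test rows, and the remaining 10 rows to the validation split, each as chunked tab-separated `.csv` files plus a `train.txt`/`test.txt`/`valid.txt` manifest listing them. A minimal sketch of reading a split back; the `feedback_training_dataset` directory name is an assumption based on the new `.gitignore` entry:

```python
import os
import pandas as pd

# Assumed output directory; matches the path ignored in .gitignore.
save_dir = "feedback_training_dataset"

# Each <split>.txt manifest lists the chunked TSV files for that split.
with open(os.path.join(save_dir, "train.txt")) as f:
    train_files = f.read().splitlines()

# Load the chunks back into a single dataframe and peek at the relationships.
train_df = pd.concat(
    pd.read_csv(os.path.join(save_dir, name), sep="\t") for name in train_files
)
print(train_df[["src", "type", "dst"]].head())
```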
