
Commit fd3f676

API is being finalized; still need to integrate the translation function; model selection implemented.

1 parent: 1d31c0c

23 files changed: +594 −666 lines changed
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Lines changed: 132 additions & 132 deletions
@@ -1,23 +1,23 @@
 from sentence_transformers import SentenceTransformer
 from sklearn.metrics.pairwise import cosine_similarity
 import spacy
-comparison_models = [
-    "sentence-transformers/LaBSE",
-    "xlm-roberta-base",
-    "multi-qa-distilbert-cos-v1",
-    "multi-qa-MiniLM-L6-cos-v1",
-    "multi-qa-mpnet-base-cos-v1"
-]
-
-def semantic_compare(model_name, og_article, translated_article, source_language, target_language, sim_threshold): # main function
+
+from ..main import server
+
+def semantic_compare(
+    original_blob,
+    translated_blob,
+    source_language,
+    target_language,
+    sim_threshold
+):
     """
-    semantic_compare(model_name, og_article, translated_article, source_language, target_language, sim_threshold)
-    Performs semantic comparison between two articles in different languages.
+    Performs semantic comparison between two articles in
+    different languages.
 
     Expected parameters:
     {
-        "model_name": "string - name of the transformer model to use",
-        "og_article": "string - original article text",
+        "original_article": "string - original article text",
         "translated_article": "string - translated article text",
         "source_language": "string - language code of original article",
        "target_language": "string - language code of translated article",
@@ -30,36 +30,70 @@ def semantic_compare(model_name, og_article, translated_article, source_language
         "target_sentences": [sentences from translated article],
         "missing_info_index": [indices of missing content],
         "extra_info_index": [indices of extra content]
+        "success": [true or false depending on if request was successful]
     }
     """
+    success = True
+
     # Load a multilingual sentence transformer model (LaBSE or cmlm)
-    match model_name:
-        case "sentence-transformers/LaBSE":
-            model = SentenceTransformer('sentence-transformers/LaBSE')
-        case "xlm-roberta-base":
-            model = SentenceTransformer('xlm-roberta-base')
-        case "multi-qa-distilbert-cos-v1":
-            model = SentenceTransformer('multi-qa-distilbert-cos-v1')
-        case "multi-qa-MiniLM-L6-cos-v1":
-            model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
-        case "multi-qa-mpnet-base-cos-v1":
-            model = SentenceTransformer('multi-qa-mpnet-base-cos-v1')
-        case _:
-            model = SentenceTransformer('sentence-transformers/LaBSE')
-
-    og_article_sentences = preprocess_input(og_article, source_language)
-    translated_article_sentences = preprocess_input(translated_article, target_language)
-
-    # encode the sentences
-    og_embeddings = model.encode(og_article_sentences)
-    translated_embeddings = model.encode(translated_article_sentences)
-
-    if sim_threshold is None:
-        sim_threshold = 0.75
-
-    missing_info, missing_info_index = sentences_diff(og_article_sentences, og_embeddings, translated_embeddings, sim_threshold)
-    extra_info, extra_info_index = sentences_diff(translated_article_sentences, translated_embeddings, og_embeddings, sim_threshold)
-    return og_article_sentences, translated_article_sentences, missing_info_index, extra_info_index
+    try:
+        model = SentenceTransformer(server.selected_comparison_model)
+    except:
+        return {
+            "original_sentences": original_sentences,
+            "translated_sentences": translated_sentences,
+            "missing_info": [],
+            "extra_info": [],
+            "missing_info_indices": [],
+            "extra_info_indices": [],
+            "success": False
+        }
+
+    try:
+        original_sentences = preprocess_input(
+            original_blob,
+            source_language
+        )
+        translated_sentences = preprocess_input(
+            translated_blob,
+            target_language
+        )
+    except:
+        success = False
+
+    try:
+        # encode the sentences
+        original_embeddings = model.encode(original_sentences)
+        translated_embeddings = model.encode(translated_sentences)
+
+        if sim_threshold is None:
+            sim_threshold = 0.75
+
+        missing_info, missing_info_indices = sentences_diff(
+            original_sentences,
+            original_embeddings,
+            translated_embeddings,
+            sim_threshold
+        )
+
+        extra_info, extra_info_indices = sentences_diff(
+            translated_sentences,
+            translated_embeddings,
+            original_embeddings,
+            sim_threshold
+        )
+    except:
+        success = False
+
+    return {
+        "original_sentences": original_sentences,
+        "translated_sentences": translated_sentences,
+        "missing_info": missing_info,
+        "extra_info": extra_info,
+        "missing_info_indices": missing_info_indices,
+        "extra_info_indices": extra_info_indices,
+        "success": success
+    }
 
 def universal_sentences_split(text):
     """
@@ -83,7 +117,8 @@ def universal_sentences_split(text):
 
 def preprocess_input(article, language):
     """
-    Preprocesses input text based on language using appropriate spaCy model.
+    Preprocesses input text based on language using appropriate
+    spaCy model.
 
     Expected parameters:
     {
@@ -96,9 +131,10 @@ def preprocess_input(article, language):
         "sentences": [array of preprocessed sentences]
     }
     """
+
     # Define a mapping of languages to spaCy model names
     language_model_map = {
-        "en": "en_core_web_sm", # English
+        "en": "en_core_web_sm",  # English
         "de": "de_core_news_sm", # German
         "fr": "fr_core_news_sm", # French
         "es": "es_core_news_sm", # Spanish
@@ -108,16 +144,21 @@ def preprocess_input(article, language):
     }
 
     # Accommodate for TITLES
-    cleaned_article = article.replace('\n\n', '<DOUBLE_NEWLINE>') # temporarily replace double newlines
-    cleaned_article = cleaned_article.replace('\n', '.') # replace single newlines with periods
-    cleaned_article = cleaned_article.replace('<DOUBLE_NEWLINE>', ' ').strip() # remove double newlines
-
-    # Check if the language is supported
-    if language not in language_model_map:
-        sentences = universal_sentences_split(cleaned_article) # Fallback to universal sentence splitting
-        return sentences
-    else:
-        # Load the appropriate spaCy model
+    # temporarily replace double newlines
+    cleaned_article = article.replace(
+        '\n\n',
+        '<DOUBLE_NEWLINE>'
+    )
+
+    cleaned_article = cleaned_article.replace('\n', '.')
+
+    cleaned_article = cleaned_article.replace(
+        '<DOUBLE_NEWLINE>',
+        ' '
+    ).strip()
+
+    if language in language_model_map:
+        # Load the appropriate spaCy model
         model_name = language_model_map[language]
         nlp = spacy.load(model_name)
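The reworked title handling above keeps the same behavior as the removed one-liners: double newlines are protected with a placeholder and later become spaces, while single newlines become periods so headings terminate as sentences before splitting. Tracing the three replace calls on an invented sample:

# Tracing the replace chain from preprocess_input on a made-up article.
article = "Breaking News\n\nThe cat sat.\nIt purred."

step1 = article.replace('\n\n', '<DOUBLE_NEWLINE>')
# 'Breaking News<DOUBLE_NEWLINE>The cat sat.\nIt purred.'

step2 = step1.replace('\n', '.')
# 'Breaking News<DOUBLE_NEWLINE>The cat sat..It purred.'

step3 = step2.replace('<DOUBLE_NEWLINE>', ' ').strip()
print(step3)  # 'Breaking News The cat sat..It purred.'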
@@ -126,15 +167,27 @@ def preprocess_input(article, language):
         sentences = [sent.text for sent in doc.sents]
         return sentences
 
-def sentences_diff(article_sentences, first_embeddings, second_embeddings, sim_threshold):
+    # Fallback to universal sentence splitting
+    sentences = universal_sentences_split(cleaned_article)
+    return sentences
+
+
+def sentences_diff(
+    article_sentences,
+    first_embeddings,
+    second_embeddings,
+    sim_threshold
+):
     """
     Compares sentence embeddings to find semantic differences.
 
     Expected parameters:
     {
         "article_sentences": [array of sentences],
-        "first_embeddings": [array of sentence embeddings from first article],
-        "second_embeddings": [array of sentence embeddings from second article],
+        "first_embeddings": [array of sentence embeddings from
+            first article],
+        "second_embeddings": [array of sentence embeddings from second
+            article],
         "sim_threshold": "float - similarity threshold value"
     }
 
@@ -145,91 +198,38 @@ def sentences_diff(article_sentences, first_embeddings, second_embeddings, sim_t
     }
     """
     diff_info = []
-    indices = [] # Track the indices of differing sentences
+    indices = []  # track the indices of differing sentences
     for i, eng_embedding in enumerate(first_embeddings):
-        # Calculate similarity between the current English sentence and all French sentences
-        similarities = cosine_similarity([eng_embedding], second_embeddings)[0]
+        similarities = cosine_similarity(
+            [eng_embedding], second_embeddings)[0]
 
-        # Find the best matching sentences
+        # find the best matching sentences
         max_sim = max(similarities)
 
-        if max_sim < sim_threshold: # Threshold for similarity
-            diff_info.append(article_sentences[i]) # This sentence might be missing or extra
+        if max_sim < sim_threshold:
+            # this sentence might be missing or extra
+            diff_info.append(article_sentences[i])
             indices.append(i)
 
     return diff_info, indices
 
-def perform_semantic_comparison(request_data):
-    """
-    Process the JSON request data and perform semantic comparison
-
-    Expected JSON format:
-    {
-        "article_text_blob_1": "string",
-        "article_text_blob_2": "string",
-        "article_text_blob_1_language": "string",
-        "article_text_blob_2_language": "string",
-        "comparison_threshold": 0,
-        "model_name": "string"
-    }
-
-    Returns:
-    {
-        "comparisons": [
-            {
-                "left_article_array": [sentences from article 1],
-                "right_article_array": [sentences from article 2],
-                "left_article_missing_info_index": [indices of missing content],
-                "right_article_extra_info_index": [indices of extra content]
-            }
-        ]
-    }
-    """
-    # Extract values from request data
-    source_article = request_data["article_text_blob_1"]
-    target_article = request_data["article_text_blob_2"]
-    source_language = request_data["article_text_blob_1_language"]
-    target_language = request_data["article_text_blob_2_language"]
-    sim_threshold = request_data["comparison_threshold"] or 0.65 # Default to 0.65 if 0
-    model_name = request_data["model_name"] or "LaBSE" # Default to LaBSE if not specified
-
-    # Perform semantic comparison
-    source_sentences, target_sentences, missing_info_index, extra_info_index = semantic_compare(
-        model_name=model_name,
-        og_article=source_article,
-        translated_article=target_article,
-        source_language=source_language,
-        target_language=target_language,
-        sim_threshold=sim_threshold
-    )
-
-    # Return results in a structured format
-    return {
-        "comparisons": [
-            {
-                "left_article_array": source_sentences,
-                "right_article_array": target_sentences,
-                "left_article_missing_info_index": missing_info_index,
-                "right_article_extra_info_index": extra_info_index
-            }
-        ]
-    }
-
-
-def main(): #testing the code
-    # Example test request data
-    test_request = {
-        "article_text_blob_1": "This is the first sentence.\n\nThis is the second sentence\nThis is the third sentence.",
-        "article_text_blob_2": "\n\nCeci est la première phrase\nJe vais bien. Ceci est la deuxième phrase.",
-        "article_text_blob_1_language": "en",
-        "article_text_blob_2_language": "fr",
-        "comparison_threshold": 0.65,
-        "model_name": "sentence-transformers/LaBSE"
-    }
+# def main():
+#     print("pre begin")
+#     # Example test request data
+#     test_request = {
+#         "article_text_blob_1": "This is the first sentence.\n\nThis is the second sentence\nThis is the third sentence.",
+#         "article_text_blob_2": "\n\nCeci est la première phrase\nJe vais bien. Ceci est la deuxième phrase.",
+#         "article_text_blob_1_language": "en",
+#         "article_text_blob_2_language": "fr",
+#         "comparison_threshold": 0.65,
+#         "model_name": "sentence-transformers/LaBSE"
+#     }
 
-    result = perform_semantic_comparison(test_request)
-    print("Comparison Results:", result)
+#     print("begin")
+#     result = perform_semantic_comparison(test_request)
+#     print("Comparison Results:", result)
 
 
-if __name__ == "__main__":
-    main()
+# if __name__ == "__main__":
+#     print("pre pre begin")
+#     main()
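The flagging logic in sentences_diff is unchanged in substance: a sentence is reported when its best cosine match on the other side falls below sim_threshold, and semantic_compare calls it once per direction to get both missing and extra content. A self-contained sketch with toy 2-D vectors standing in for model.encode() output:

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# Toy embeddings: two "source" sentences, one "target" sentence.
first_embeddings = np.array([[1.0, 0.0], [0.0, 1.0]])
second_embeddings = np.array([[0.9, 0.1]])
article_sentences = ["Sentence A", "Sentence B"]
sim_threshold = 0.75

diff_info, indices = [], []
for i, embedding in enumerate(first_embeddings):
    # Best match for this sentence among the other article's sentences.
    similarities = cosine_similarity([embedding], second_embeddings)[0]
    if max(similarities) < sim_threshold:
        diff_info.append(article_sentences[i])
        indices.append(i)

print(diff_info, indices)  # ['Sentence B'] [1]; B has no close match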
