This repository was archived by the owner on Oct 25, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 213
[NeuralChat] RAG evaluation #1333
Open
Liangyx2
wants to merge
158
commits into
main
Choose a base branch
from
yuxiang/evaluation
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 2 commits
Commits
Show all changes
158 commits
Select commit
Hold shift + click to select a range
f820019
add retrieval dataset construction codes
Liangyx2 06f8162
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 5ef0332
Update llm_generate_raw_data.py
Liangyx2 ee1db83
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 89597f2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b132d66
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 8e955ce
update
Liangyx2 635b906
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] d7d3d03
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 c9fec02
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 5e32113
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 f67622c
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 f2e344a
Delete intel_extension_for_transformers/neural_chat/tools/evaluation/…
Liangyx2 383e5b3
Update prompt.py
Liangyx2 81014d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4b7bec7
Update llm_generate_raw_data.py
Liangyx2 0df51a6
Update llm_generate_raw_data.py
Liangyx2 95b16bd
Update retrieval_dataset_construction.py
Liangyx2 80dd21b
Update llm_generate_raw_data.py
Liangyx2 f495b22
Update mine_hard_negatives_check_similarity.py
Liangyx2 593dee3
add test_evaluation.py to nightly test
Liangyx2 cf59b18
Update and rename requirements.txt to requirements_cpu.txt
Liangyx2 40e0b0e
Create requirements_cuda.txt
Liangyx2 bf1b1aa
Update requirements.txt
Liangyx2 5552ebc
Update retrieval_dataset_construction.py
Liangyx2 d3b7579
Update llm_generate_raw_data.py
Liangyx2 f500b2b
Update retrieval_dataset_construction.py
Liangyx2 b65c4bf
Update llm_generate_raw_data.py
Liangyx2 c43ab73
Update test_evaluation.py
Liangyx2 feda3c0
Update retrieval_dataset_construction.py
Liangyx2 1c2c22c
Update mine_hard_negatives_check_similarity.py
Liangyx2 55a5cda
add README.md
Liangyx2 7a74f86
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 39754d0
Update README.md
Liangyx2 d7e95f0
add evaluate_retrieval.py
Liangyx2 186ab43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 1496219
Update test_evaluation.py
Liangyx2 03a768e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 128d587
Update test_evaluation.py
Liangyx2 25177bd
Merge branch 'main' into yuxiang/evaluation
XuehaoSun 705752a
add README.md
Liangyx2 675fe2e
Update prompt.py
Liangyx2 988e542
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] d0c3c34
add llm_generate_truth.py and data
Liangyx2 be1106b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 48788d4
add ragas_evaluation.py
Liangyx2 54cc6c0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] e1b5585
Create requirements.txt
Liangyx2 88a4293
Update llm_generate_truth.py
Liangyx2 83060f9
Update evaluate_retrieval.py
Liangyx2 76b1175
Update ragas_evaluation.py
Liangyx2 b775095
Update test_evaluation.py
Liangyx2 edbb32c
Update llm_generate_truth.py
Liangyx2 8962abf
Update README.md
Liangyx2 2ef4e05
Update README.md
Liangyx2 d2ab7d8
add README.md
Liangyx2 bcdf209
Update README.md
Liangyx2 102649b
Update README.md
Liangyx2 36a28a4
Update README.md
Liangyx2 548fdd9
Add files via upload
Liangyx2 36448ea
Delete intel_extension_for_transformers/neural_chat/tests/ci/tools/te…
Liangyx2 26e3e9d
Update requirements.txt
Liangyx2 e4793d3
Update README.md
Liangyx2 0569b54
Update hn_mine.py
Liangyx2 2d15ec0
Update README.md
Liangyx2 e8127e9
Update ragas_evaluation.py
Liangyx2 321e9b6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f9b4dab
Update requirements.txt
Liangyx2 76dc219
Update README.md
Liangyx2 b9db553
Update README.md
Liangyx2 d7b68cb
Update README.md
Liangyx2 48de606
Update requirements.txt
Liangyx2 415ebc8
Update ragas_evaluation.py
Liangyx2 f03badd
Update test_evaluation.py
Liangyx2 2b92e74
Update README.md
Liangyx2 9091729
Update retrieval_dataset_construction.py
Liangyx2 be32736
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 2c4f452
Update hn_mine.py
Liangyx2 c48f66a
Update llm_generate_raw_data.py
Liangyx2 654c44a
Update mine_hard_negatives_check_similarity.py
Liangyx2 5208c98
Update hn_mine.py
Liangyx2 ace1090
Update test_evaluation.py
Liangyx2 83f10e9
Update ragas_evaluation.py
Liangyx2 ac0aef1
Update README.md
Liangyx2 8deaabd
Update README.md
Liangyx2 2eb084c
Update README.md
Liangyx2 510e801
Update README.md
Liangyx2 dd1f37c
Update README.md
Liangyx2 ed95d2d
Update prompt.py
Liangyx2 e253f41
Update ragas_evaluation.py
Liangyx2 fc0b6b9
add evaluate_retrieval_auto.py
Liangyx2 6f081b5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 746adec
Update evaluate_retrieval_auto.py
Liangyx2 100322e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 5e07789
Update evaluate_retrieval.py
Liangyx2 0a2f742
Update ragas_evaluation.py
Liangyx2 1752684
Update test_evaluation.py
Liangyx2 2a2238e
Update ragas_evaluation.py
Liangyx2 e8f0f9c
Update README.md
Liangyx2 8d65078
Update and rename evaluate_retrieval_auto.py to evaluate_retrieval_be…
Liangyx2 a951a89
Update evaluate_retrieval_benchmark.py
Liangyx2 13921f6
add retrieval_benchmark.py
Liangyx2 02c0813
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] d212d66
Update retrieval_benchmark.py
Liangyx2 20529a4
add ragas_benchmark ragas_evaluation_benchmark
Liangyx2 5026421
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] cfa7d9c
Update retrieval_benchmark.py
Liangyx2 8d1215e
Update evaluate_retrieval_benchmark.py
Liangyx2 3458a8e
Update retrieval_benchmark.py
Liangyx2 4effd37
Update ragas_evaluation_benchmark.py
Liangyx2 3c38ae6
Update ragas_benchmark.py
Liangyx2 b02da07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a2a7de1
Update ragas_evaluation_benchmark.py
Liangyx2 4191f4b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 35b2d7d
Update evaluate_retrieval_benchmark.py
Liangyx2 56037b9
Update ragas_evaluation_benchmark.py
Liangyx2 de44f0d
add retrieval_benchmark.sh
Liangyx2 67456e4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 2a91336
add ragas_benchmark.sh
Liangyx2 8f05a34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] c64ca3c
add data.txt
Liangyx2 fbef1f6
Update ragas_benchmark.sh
Liangyx2 f50aeb4
Update ragas_evaluation_benchmark.py
Liangyx2 84aea7c
Update ragas_benchmark.sh
Liangyx2 ad1814a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 932562d
Update and rename ragas_benchmark.py to ragas_superbenchmark.py
Liangyx2 50d8c83
Update evaluate_retrieval_benchmark.py
Liangyx2 a4ea5dd
Update retrieval_benchmark.sh
Liangyx2 6e29d43
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 702f9a9
Update and rename retrieval_benchmark.py to retrieval_superbenchmark.py
Liangyx2 0452526
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 008a892
add README.md
Liangyx2 5303837
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 8957b18
Update README.md
Liangyx2 96f477c
Update README.md
Liangyx2 c99856d
Update README.md
Liangyx2 19dfb93
Update README.md
Liangyx2 99940f3
Update README.md
Liangyx2 464d52b
Update README.md
Liangyx2 da2e829
Update README.md
Liangyx2 3ce2cb2
Update README.md
Liangyx2 268d89c
Update README.md
Liangyx2 40fc2e9
Update README.md
Liangyx2 13bb3b8
Update README.md
Liangyx2 763bd1d
add config file form rag evaluation
xmx-521 092e951
complete config superbenchmark
xmx-521 e931143
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f0a0cd6
Merge branch 'main' into yuxiang/evaluation
XuhuiRen 895075b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 6b60154
Create test_evaluation.py in CI
Liangyx2 c73a68f
Update requirements.txt
Liangyx2 c6f8906
Merge branch 'main' into yuxiang/evaluation
Liangyx2 7c80ce2
Merge branch 'main' into yuxiang/evaluation
VincyZhang 576ce57
Merge branch 'main' into yuxiang/evaluation
Liangyx2 2a3ddd9
Merge branch 'main' into yuxiang/evaluation
Liangyx2 b4c0e67
Update ragas_evaluation_benchmark.py
Liangyx2 e75bbe4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] a0853a8
Merge branch 'main' into yuxiang/evaluation
Liangyx2 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
16 changes: 16 additions & 0 deletions
16
intel_extension_for_transformers/neural_chat/tools/evaluation/__init__.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
246 changes: 246 additions & 0 deletions
246
intel_extension_for_transformers/neural_chat/tools/evaluation/context_utils.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,246 @@ | ||
# !/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unicodedata | ||
import pandas as pd | ||
import re, json | ||
from langchain.document_loaders import UnstructuredMarkdownLoader | ||
from docx import Document as DDocument | ||
from bs4 import BeautifulSoup | ||
import fitz | ||
import easyocr | ||
from PIL import Image | ||
import numpy as np | ||
import io | ||
|
||
def uni_pro(text): | ||
"""Check if the character is ASCII or falls in the category of non-spacing marks.""" | ||
normalized_text = unicodedata.normalize('NFKD', text) | ||
filtered_text = '' | ||
for char in normalized_text: | ||
if ord(char) < 128 or unicodedata.category(char) == 'Mn': | ||
filtered_text += char | ||
return filtered_text | ||
|
||
|
||
def read_pdf(pdf_path): | ||
"""Read the pdf file.""" | ||
doc = fitz.open(pdf_path) | ||
reader = easyocr.Reader(['en']) | ||
result ='' | ||
for i in range(doc.page_count): | ||
page = doc.load_page(i) | ||
pagetext = page.get_text().strip() | ||
if pagetext: | ||
if pagetext.endswith('!') or pagetext.endswith('?') or pagetext.endswith('.'): | ||
result=result+pagetext | ||
else: | ||
result=result+pagetext+'.' | ||
if len(doc.get_page_images(i)) > 0 : | ||
for img in doc.get_page_images(i): | ||
if img: | ||
pageimg='' | ||
xref = img[0] | ||
img_data = doc.extract_image(xref) | ||
img_bytes = img_data['image'] | ||
pil_image = Image.open(io.BytesIO(img_bytes)) | ||
img = np.array(pil_image) | ||
img_result = reader.readtext(img, paragraph=True, detail=0) | ||
pageimg=pageimg + ', '.join(img_result).strip() | ||
if pageimg.endswith('!') or pageimg.endswith('?') or pageimg.endswith('.'): | ||
pass | ||
else: | ||
pageimg=pageimg+'.' | ||
result=result+pageimg | ||
return result | ||
|
||
|
||
def read_html(html_path): | ||
"""Read the html file.""" | ||
with open(html_path, 'r', encoding="utf-8") as file: | ||
html = file.read() | ||
soup = BeautifulSoup(html, 'html.parser') | ||
text = soup.get_text(strip=True) | ||
return text | ||
|
||
|
||
def read_txt(txt_path): | ||
"""Read txt file.""" | ||
with open(txt_path, 'r') as file: | ||
text = file.read() | ||
return text | ||
|
||
|
||
def read_docx(doc_path): | ||
"""Read docx file.""" | ||
doc = DDocument(doc_path) | ||
text = '' | ||
for paragraph in doc.paragraphs: | ||
text += paragraph.text | ||
return text | ||
|
||
|
||
def read_md(md_path): | ||
"""Read docx file.""" | ||
loader = UnstructuredMarkdownLoader(md_path) | ||
text = loader.load()[0].page_content | ||
return text | ||
|
||
|
||
def load_json(input, process, max_length, min_length): | ||
"""Load and process json file.""" | ||
data = [] | ||
with open(input, 'r') as file: | ||
for line in file: | ||
json_obj = json.loads(line) | ||
data.append(json_obj) | ||
|
||
new_sens = [] | ||
new_collect = [] | ||
for sub in data: | ||
sub['content'].replace('#', " ") | ||
sub['content'] = re.sub(r'\s+', ' ', sub['content']) | ||
if not process: | ||
if len(sub['content']) < min_length: | ||
continue | ||
new_doc = [sub['content'], sub['link']] | ||
new_collect.append(new_doc) | ||
else: | ||
for sub in data: | ||
sub['content'].replace('#', " ") | ||
if len(sub['content'])<min_length: | ||
continue | ||
split_sen = re.split(r'[.?!]', sub['content']) | ||
for num in range(len(split_sen)): | ||
split_sen[num] = re.sub(r'\s+', ' ', split_sen[num]) | ||
if num +1 < len(split_sen): | ||
if len(split_sen[num]) >max_length: | ||
new_sens.append(split_sen[num].strip()) | ||
else: | ||
split_sen[num +1] =split_sen[num] +split_sen[num+1] | ||
else: | ||
new_sens.append(split_sen[num]) | ||
|
||
paragraphs = list(set(new_sens)) | ||
for paragraph in paragraphs: | ||
new_doc = [paragraph, sub['link']] | ||
new_collect.append(new_doc) | ||
return new_collect | ||
|
||
|
||
def load_xlsx(input): | ||
"""Load and process xlsx file.""" | ||
df = pd.read_excel(input) | ||
header = df.columns.tolist() | ||
all_data = [] | ||
if 'Questions' in header and 'Answers' in header: | ||
for index, row in df.iterrows(): | ||
sub = row["Answers"] | ||
sub=sub.replace('#', " ") | ||
sub = sub.replace(r'\t', " ") | ||
sub = sub.replace('\n', ' ') | ||
sub = sub.replace('\n\n', ' ') | ||
sub = re.sub(r'\s+', ' ', sub) | ||
new_doc = [sub, input] | ||
all_data.append(new_doc) | ||
elif 'question' in header and 'answer' in header and 'link' in header: | ||
for index, row in df.iterrows(): | ||
sub = row["answer"] | ||
sub = sub.replace('#', " ") | ||
sub = sub.replace(r'\t', " ") | ||
sub = sub.replace('\n', ' ') | ||
sub = sub.replace('\n\n', ' ') | ||
sub = re.sub(r'\s+', ' ', sub) | ||
all_data.append([sub, row['link']]) | ||
elif 'context' in header and 'link' in header: | ||
for index, row in df.iterrows(): | ||
sub = row['context'] | ||
sub = sub.replace('#', " ") | ||
sub = sub.replace(r'\t', " ") | ||
sub = sub.replace('\n', ' ') | ||
sub = sub.replace('\n\n', ' ') | ||
sub = re.sub(r'\s+', ' ', sub) | ||
all_data.append([sub, row['link']]) | ||
return all_data | ||
|
||
def load_csv(input): | ||
""" Load the csv file.""" | ||
df = pd.read_csv(input) | ||
all_data = [] | ||
documents = [] | ||
for index, row in df.iterrows(): | ||
sub = row["correct_answer"] | ||
all_data.append(sub) | ||
|
||
for data in all_data: | ||
data.replace('#', " ") | ||
data = re.sub(r'\s+', ' ', data) | ||
new_doc = [data, input] | ||
documents.append(new_doc) | ||
return documents | ||
|
||
def load_structured_data(input, process, max_length, min_length): | ||
"""Load structured context.""" | ||
if input.endswith("jsonl") or input.endswith("json"): | ||
content = load_json(input, process, max_length, min_length) | ||
elif input.endswith("xlsx"): | ||
content = load_xlsx(input) | ||
elif input.endswith("csv"): | ||
content = load_csv(input) | ||
return content | ||
|
||
def load_unstructured_data(input): | ||
"""Load unstructured context.""" | ||
if input.endswith("pdf"): | ||
text = read_pdf(input) | ||
elif input.endswith("docx"): | ||
text = read_docx(input) | ||
elif input.endswith("html"): | ||
text = read_html(input) | ||
elif input.endswith("txt"): | ||
text = read_txt(input) | ||
elif input.endswith("md"): | ||
text = read_md(input) | ||
|
||
text = text.replace('\n', ' ') | ||
text = text.replace('\n\n', ' ') | ||
text = uni_pro(text) | ||
text = re.sub(r'\s+', ' ', text) | ||
return text | ||
|
||
def get_chuck_data(content, max_length, min_length, input): | ||
"""Process the context to make it maintain a suitable length for the generation.""" | ||
sentences = re.split('(?<=[!.?])', content) | ||
|
||
paragraphs = [] | ||
current_length = 0 | ||
count = 0 | ||
current_paragraph = "" | ||
for sub_sen in sentences: | ||
count +=1 | ||
sentence_length = len(sub_sen) | ||
if current_length + sentence_length <= max_length: | ||
current_paragraph += sub_sen | ||
current_length += sentence_length | ||
if count == len(sentences) and len(current_paragraph.strip())>min_length: | ||
paragraphs.append([current_paragraph.strip() ,input]) | ||
else: | ||
paragraphs.append([current_paragraph.strip() ,input]) | ||
current_paragraph = sub_sen | ||
current_length = sentence_length | ||
|
||
return paragraphs |
97 changes: 97 additions & 0 deletions
97
intel_extension_for_transformers/neural_chat/tools/evaluation/hn_mine.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import json | ||
import random | ||
import numpy as np | ||
import faiss | ||
from tqdm import tqdm | ||
|
||
def create_index(embeddings, use_gpu): | ||
index = faiss.IndexFlatIP(len(embeddings[0])) | ||
embeddings = np.asarray(embeddings, dtype=np.float32) | ||
if use_gpu: | ||
co = faiss.GpuMultipleClonerOptions() | ||
co.shard = True | ||
co.useFloat16 = True | ||
index = faiss.index_cpu_to_all_gpus(index, co=co) | ||
index.add(embeddings) | ||
return index | ||
|
||
def batch_search(index, | ||
query, | ||
topk: int = 200, | ||
batch_size: int = 64): | ||
all_scores, all_inxs = [], [] | ||
for start_index in tqdm(range(0, len(query), batch_size), desc="Batches", disable=len(query) < 256): | ||
batch_query = query[start_index:start_index + batch_size] | ||
batch_scores, batch_inxs = index.search(np.asarray(batch_query, dtype=np.float32), k=topk) | ||
all_scores.extend(batch_scores.tolist()) | ||
all_inxs.extend(batch_inxs.tolist()) | ||
return all_scores, all_inxs | ||
|
||
def get_corpus(candidate_pool): | ||
corpus = [] | ||
for line in open(candidate_pool): | ||
line = json.loads(line.strip()) | ||
corpus.append(line['text']) | ||
return corpus | ||
|
||
def find_knn_neg(model, input_file, candidate_pool, output_file, sample_range, negative_number, use_gpu): | ||
corpus = [] | ||
queries = [] | ||
train_data = [] | ||
for line in open(input_file): | ||
line = json.loads(line.strip()) | ||
train_data.append(line) | ||
corpus.extend(line['pos']) | ||
if 'neg' in line: | ||
corpus.extend(line['neg']) | ||
queries.append(line['query']) | ||
|
||
if candidate_pool is not None: | ||
if not isinstance(candidate_pool, list): | ||
candidate_pool = get_corpus(candidate_pool) | ||
corpus = list(set(candidate_pool)) | ||
else: | ||
corpus = list(set(corpus)) | ||
|
||
p_vecs = model.encode(corpus, batch_size=256) | ||
q_vecs = model.encode(queries, batch_size=256) | ||
|
||
index = create_index(p_vecs, use_gpu=use_gpu) | ||
_, all_inxs = batch_search(index, q_vecs, topk=sample_range[-1]) | ||
assert len(all_inxs) == len(train_data) | ||
|
||
for i, data in enumerate(train_data): | ||
query = data['query'] | ||
inxs = all_inxs[i][sample_range[0]:sample_range[1]] | ||
filtered_inx = [] | ||
for inx in inxs: | ||
if inx == -1: break | ||
if corpus[inx] not in data['pos'] and corpus[inx] != query: | ||
filtered_inx.append(inx) | ||
|
||
if len(filtered_inx) > negative_number: | ||
filtered_inx = random.sample(filtered_inx, negative_number) | ||
data['neg'] = [corpus[inx] for inx in filtered_inx] | ||
|
||
with open(output_file, 'w') as f: | ||
for data in train_data: | ||
if len(data['neg']) < negative_number: | ||
data['neg'].extend(random.sample(corpus, negative_number - len(data['neg']))) | ||
f.write(json.dumps(data, ensure_ascii=False) + '\n') |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.