Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import absolute_import, division, print_function

import os
import json
import tempfile
import time
import zipfile
Expand Down Expand Up @@ -51,8 +52,9 @@ class BaseModel:
def __init__(self, **kwargs):
    """Initialize base-model state shared by all RAG variants.

    Sets up the GPU lock used to serialize device access across the
    thread pool, clears the RAG backend and location caches, and binds
    the default LLM backend to the Qianfan implementation.
    """
    # Serializes GPU use across ThreadPoolExecutor workers in predict().
    self.gpu_lock = threading.Lock()
    self.rag = None
    self.all_locations = []
    # Lazily filled by _load_locations_from_dataset(); None == not parsed yet.
    self._cached_query_locations = None
    # Default backend; other get_model_response_* variants can be swapped in.
    self.get_model_response = self.get_model_response_qianfan

def get_model_response_deepseek(self, prompt):
# Please install OpenAI SDK first: `pip3 install openai`
Expand Down Expand Up @@ -162,7 +164,78 @@ def preprocess(self, **kwargs):

def train(self, train_data, valid_data=None, **kwargs):
    """No-op: the base model performs no training.

    Accepts the standard training signature so subclasses can override
    it; this implementation only announces that training is skipped.
    """
    print("BaseModel doesn't need to train")
    return None


def _load_locations_from_dataset(self, data):
"""
Read the original JSONL dataset to build a mapping of query -> location.

Each line in the JSONL has a 'level_4_dim' field that contains the
province name (e.g. 'Shanghai', 'Beijing'). We use this as the
source of truth for location, instead of relying on os.getcwd()
which just returns the project root directory.

Results are cached after the first call to avoid re-parsing the
dataset file on every predict() invocation.
"""
# Return cached result if we already parsed the dataset once
if self._cached_query_locations is not None:
return self._cached_query_locations

query_to_location = {}
all_locations = set()

# The train_data path points to the JSONL file that was configured
# in testenv.yaml. Try to find it from the dataset object.
dataset_path = getattr(data, 'data_file', None)

# If data object doesn't expose the file path, we can still
# build the mapping by scanning query text against known provinces.
# But first, try reading the JSONL directly if possible.
if dataset_path and os.path.isfile(dataset_path):
with open(dataset_path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
if not line:
continue
entry = json.loads(line)
query = entry.get('query', '')
location = entry.get('level_4_dim', 'Unknown')
query_to_location[query] = location
# Only track valid province names, skip unknowns
if location and location != 'Unknown':
all_locations.add(location)
else:
# Fallback: figure out the location from the query text.
# The government_rag dataset queries mention the province name
# in Chinese, so we scan for known province keywords.
zh_to_en = {
"北京": "Beijing", "上海": "Shanghai", "天津": "Tianjin",
"重庆": "Chongqing", "河北": "Hebei", "山西": "Shanxi",
"辽宁": "Liaoning", "吉林": "Jilin", "黑龙江": "Heilongjiang",
"江苏": "Jiangsu", "浙江": "Zhejiang", "安徽": "Anhui",
"福建": "Fujian", "江西": "Jiangxi", "山东": "Shandong",
"河南": "Henan", "湖北": "Hubei", "湖南": "Hunan",
"广东": "Guangdong", "海南": "Hainan", "四川": "Sichuan",
"贵州": "Guizhou", "云南": "Yunnan", "陕西": "Shaanxi",
"甘肃": "Gansu", "青海": "Qinghai", "台湾": "Taiwan",
"内蒙古": "Inner Mongolia", "广西": "Guangxi", "西藏": "Tibet",
"宁夏": "Ningxia", "新疆": "Xinjiang", "香港": "Hong Kong",
"澳门": "Macau"
}

for i in range(len(data.x)):
query = data.x[i]
location = "Unknown"
for zh_name, en_name in zh_to_en.items():
if zh_name in query:
location = en_name
all_locations.add(en_name)
break
query_to_location[query] = location

result = (query_to_location, list(all_locations))
self._cached_query_locations = result
return result

def save(self, model_path):
    """No-op: the base model has no weights to persist.

    Kept so callers can invoke save() uniformly; subclasses with real
    state should override this.
    """
    print("BaseModel doesn't need to save")
    return None
Expand Down Expand Up @@ -202,25 +275,27 @@ def process_query(self, query: str, ground_truth: str, location: str, rag_type:
def predict(self, data, input_shape=None, **kwargs):
print("BaseModel predict")
LOGGER.info("BaseModel predict")
LOGGER.info(f"Dataset: {data.dataset_name}")
LOGGER.info(f"Description: {data.description}")

answer_list = []

# Get location from the directory name
current_dir = os.path.basename(os.getcwd())


# Build a lookup of query -> province from the dataset itself.
# The old code used os.path.basename(os.getcwd()) which always
# returned the project root name (e.g. "ianvs") instead of an
# actual province, breaking all location-based RAG filtering.
# Results are cached internally so repeated calls don't re-parse.
query_to_location, all_locations = self._load_locations_from_dataset(data)
self.all_locations = all_locations

# Create tasks for all queries
tasks = []
for i in range(len(data.x)):
# Add global task
tasks.append((data.x[i], data.y[i], current_dir, "[global]"))
# Add local task
tasks.append((data.x[i], data.y[i], current_dir, "[local]"))
# Add other task
tasks.append((data.x[i], data.y[i], current_dir, "[other]"))
# Add model task
tasks.append((data.x[i], data.y[i], current_dir, "[model]"))
query = data.x[i]
location = query_to_location.get(query, "Unknown")

tasks.append((query, data.y[i], location, "[global]"))
tasks.append((query, data.y[i], location, "[local]"))
tasks.append((query, data.y[i], location, "[other]"))
tasks.append((query, data.y[i], location, "[model]"))

# Process tasks in parallel using ThreadPoolExecutor
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor: # Reduced number of workers
Expand Down
Loading