diff --git a/chinatravel/data/load_datasets.py b/chinatravel/data/load_datasets.py index cc6c0d6..d074f32 100644 --- a/chinatravel/data/load_datasets.py +++ b/chinatravel/data/load_datasets.py @@ -24,6 +24,27 @@ def default(self, obj): return super(NpEncoder, self).default(obj) +def _get_language(args): + return getattr(args, "lang", "zh") + + +def _normalize_query_language(data_i, lang): + if lang == "zh": + return data_i + + key_pairs = [ + ("nature_language", "nature_language_en"), + ("hard_logic", "hard_logic_en"), + ("hard_logic_py", "hard_logic_py_en"), + ("hard_logic_nl", "hard_logic_nl_en"), + ("preference", "preference_en"), + ] + for base_key, lang_key in key_pairs: + if lang_key in data_i: + data_i[base_key] = data_i[lang_key] + return data_i + + def load_query_local(args, version="", verbose=False): query_data = {} @@ -63,6 +84,8 @@ def load_query_local(args, version="", verbose=False): open(os.path.join(dir_ii, file_i), encoding="utf-8") ) + data_i = _normalize_query_language(data_i, _get_language(args)) + if hasattr(args, 'oracle_translation') and not args.oracle_translation: if "hard_logic" in data_i: del data_i["hard_logic"] @@ -111,7 +134,8 @@ def load_query(args): for data_i in query_data: - if "hard_logic_py" in data_i: + data_i = _normalize_query_language(data_i, _get_language(args)) + if "hard_logic_py" in data_i and isinstance(data_i["hard_logic_py"], str): data_i["hard_logic_py"] = ast.literal_eval(data_i["hard_logic_py"]) query_id_list = [data_i["uid"] for data_i in query_data] @@ -155,4 +179,4 @@ def load_query(args): print(uid, query_data[uid]) else: raise ValueError(f"{uid} not in query_data") - \ No newline at end of file + diff --git a/chinatravel/environment/tools/accommodations/apis.py b/chinatravel/environment/tools/accommodations/apis.py index f4a5727..1f8fe56 100644 --- a/chinatravel/environment/tools/accommodations/apis.py +++ b/chinatravel/environment/tools/accommodations/apis.py @@ -15,6 +15,7 @@ class Accommodations: def __init__( self, base_path: str = "../../database/accommodations/", en_version=False ): + file_suffix = "_en" if en_version else "" curdir = os.path.dirname(os.path.realpath(__file__)) city_list = [ "beijing", @@ -29,7 +30,7 @@ def __init__( "chongqing", ] data_path_list = [ - os.path.join(curdir, f"{base_path}/{city}/accommodations.csv") + os.path.join(curdir, f"{base_path}/{city}/accommodations{file_suffix}.csv") for city in city_list ] self.data = {} @@ -56,10 +57,8 @@ def __init__( ] for i, city in enumerate(city_list): - self.data[city_cn_list[i]] = self.data.pop(city) - self.key_type_tuple_list[city_cn_list[i]] = self.key_type_tuple_list.pop( - city - ) + self.data[city_cn_list[i]] = self.data[city] + self.key_type_tuple_list[city_cn_list[i]] = self.key_type_tuple_list[city] self.poi = Poi(en_version=en_version) diff --git a/chinatravel/environment/tools/attractions/apis.py b/chinatravel/environment/tools/attractions/apis.py index d235d3b..a348d2e 100644 --- a/chinatravel/environment/tools/attractions/apis.py +++ b/chinatravel/environment/tools/attractions/apis.py @@ -16,6 +16,7 @@ def __init__( base_path: str = "../../database/attractions", en_version=False, ): + file_suffix = "_en" if en_version else "" city_list = [ "beijing", "shanghai", @@ -30,7 +31,7 @@ def __init__( ] curdir = os.path.dirname(os.path.realpath(__file__)) data_path_list = [ - os.path.join(curdir, f"{base_path}/{city}/attractions.csv") + os.path.join(curdir, f"{base_path}/{city}/attractions{file_suffix}.csv") for city in city_list ] @@ -61,13 +62,11 @@ def __init__( ] for i, city in enumerate(city_list): - self.data[city_cn_list[i]] = self.data.pop(city) - self.key_type_tuple_list_map[city_cn_list[i]] = ( - self.key_type_tuple_list_map.pop(city) - ) - self.type_list_map[city_cn_list[i]] = self.type_list_map.pop(city) + self.data[city_cn_list[i]] = self.data[city] + self.key_type_tuple_list_map[city_cn_list[i]] = self.key_type_tuple_list_map[city] + self.type_list_map[city_cn_list[i]] = self.type_list_map[city] - self.poi = Poi() + self.poi = Poi(en_version=en_version) def keys(self, city: str): return self.key_type_tuple_list_map[city] diff --git a/chinatravel/environment/tools/intercity_transport/apis.py b/chinatravel/environment/tools/intercity_transport/apis.py index 2826b70..ee107b9 100644 --- a/chinatravel/environment/tools/intercity_transport/apis.py +++ b/chinatravel/environment/tools/intercity_transport/apis.py @@ -9,10 +9,11 @@ def time2float(time_str): class IntercityTransport: - def __init__(self, path: str = "../../database/intercity_transport/"): + def __init__(self, path: str = "../../database/intercity_transport/", en_version=False): + file_suffix = "_en" if en_version else "" curdir = os.path.dirname(os.path.realpath(__file__)) self.base_path = os.path.join(curdir, path) - self.airplane_path = self.base_path + "airplane.jsonl" + self.airplane_path = self.base_path + f"airplane{file_suffix}.jsonl" self.airplane_df = pd.read_json( self.airplane_path, lines=True, keep_default_dates=False ) @@ -28,6 +29,14 @@ def __init__(self, path: str = "../../database/intercity_transport/"): "武汉", "南京", ] + self.city_list = city_list + self.city_en_list = [ + "shanghai", "beijing", "shenzhen", "guangzhou", "chongqing", + "suzhou", "chengdu", "hangzhou", "wuhan", "nanjing", + ] + self.city_cn_to_en = dict(zip(self.city_list, self.city_en_list)) + self.city_en_to_cn = dict(zip(self.city_en_list, self.city_list)) + self.train_df_dict = {} for start_city in city_list: @@ -37,7 +46,7 @@ def __init__(self, path: str = "../../database/intercity_transport/"): train_path = ( self.base_path + "train/" - + "from_{}_to_{}.json".format(start_city, end_city) + + f"from_{start_city}_to_{end_city}{file_suffix}.json" ) train_df = pd.read_json(train_path) self.train_df_dict[(start_city, end_city)] = train_df @@ -45,6 +54,10 @@ def __init__(self, path: str = "../../database/intercity_transport/"): def select( self, start_city, end_city, intercity_type, earliest_leave_time="00:00" ) -> DataFrame: + if start_city in self.city_en_to_cn: + start_city = self.city_en_to_cn[start_city] + if end_city in self.city_en_to_cn: + end_city = self.city_en_to_cn[end_city] if intercity_type not in ["train", "airplane"]: return "only support intercity_type in ['train','airplane']" res = self._select(start_city, end_city, intercity_type) diff --git a/chinatravel/environment/tools/poi/apis.py b/chinatravel/environment/tools/poi/apis.py index 4719607..c4750f8 100644 --- a/chinatravel/environment/tools/poi/apis.py +++ b/chinatravel/environment/tools/poi/apis.py @@ -4,6 +4,8 @@ class Poi: def __init__(self, base_path: str = "../../database/poi/", en_version=False): + self.en_version = en_version + file_suffix = "_en" if en_version else "" city_list = [ "beijing", @@ -19,7 +21,7 @@ def __init__(self, base_path: str = "../../database/poi/", en_version=False): ] curdir = os.path.dirname(os.path.realpath(__file__)) data_path_list = [ - os.path.join(curdir, f"{base_path}/{city}/poi.json") for city in city_list + os.path.join(curdir, f"{base_path}/{city}/poi{file_suffix}.json") for city in city_list ] self.data = {} for i, city in enumerate(city_list): @@ -46,7 +48,7 @@ def __init__(self, base_path: str = "../../database/poi/", en_version=False): "重庆", ] for i, city in enumerate(city_list): - self.data[city_cn_list[i]] = self.data.pop(city) + self.data[city_cn_list[i]] = self.data[city] self.city_cn_list = city_cn_list self.city_list = city_list @@ -57,7 +59,9 @@ def search(self, city: str, name: str): try: return city_data[name] except KeyError: - return f"No such point in the city. Check the point name: {name}." + if self.en_version: + return f"No such point in the city. Check the point name: {name}." + return f"城市中没有该地点,请检查地点名称: {name}." def test(): diff --git a/chinatravel/environment/tools/restaurants/apis.py b/chinatravel/environment/tools/restaurants/apis.py index c30d188..fe8f0b2 100644 --- a/chinatravel/environment/tools/restaurants/apis.py +++ b/chinatravel/environment/tools/restaurants/apis.py @@ -11,7 +11,8 @@ class Restaurants: - def __init__(self, base_path: str = "../../database/restaurants"): + def __init__(self, base_path: str = "../../database/restaurants", en_version=False): + file_suffix = "_en" if en_version else "" city_list = [ "beijing", "shanghai", @@ -27,7 +28,7 @@ def __init__(self, base_path: str = "../../database/restaurants"): self.data = {} curdir = os.path.dirname(os.path.realpath(__file__)) for city in city_list: - path = os.path.join(curdir, base_path, city, "restaurants_" + city + ".csv") + path = os.path.join(curdir, base_path, city, f"restaurants_{city}{file_suffix}.csv") self.data[city] = pd.read_csv(path) self.key_type_tuple_list_map = {} @@ -54,13 +55,11 @@ def __init__(self, base_path: str = "../../database/restaurants"): ] for i, city in enumerate(city_list): - self.data[city_cn_list[i]] = self.data.pop(city) - self.key_type_tuple_list_map[city_cn_list[i]] = ( - self.key_type_tuple_list_map.pop(city) - ) - self.cuisine_list_map[city_cn_list[i]] = self.cuisine_list_map.pop(city) + self.data[city_cn_list[i]] = self.data[city] + self.key_type_tuple_list_map[city_cn_list[i]] = self.key_type_tuple_list_map[city] + self.cuisine_list_map[city_cn_list[i]] = self.cuisine_list_map[city] - self.poi = Poi() + self.poi = Poi(en_version=en_version) def keys(self, city: str): return self.key_type_tuple_list_map[city] @@ -77,12 +76,12 @@ def id_is_open(self, city: str, id: int, time: str) -> bool: end_time = match["endtime"].values[0] open_time = ( -1 - if open_time == "不营业" + if open_time in ["不营业", "closed"] else float(open_time.split(":")[0]) + float(open_time.split(":")[1]) / 60 ) end_time = ( -1 - if end_time == "不营业" + if end_time in ["不营业", "closed"] else float(end_time.split(":")[0]) + float(end_time.split(":")[1]) / 60 ) time = float(time.split(":")[0]) + float(time.split(":")[1]) / 60 diff --git a/chinatravel/environment/tools/transportation/apis.py b/chinatravel/environment/tools/transportation/apis.py index 81668e2..a9d4832 100644 --- a/chinatravel/environment/tools/transportation/apis.py +++ b/chinatravel/environment/tools/transportation/apis.py @@ -139,6 +139,7 @@ class Transportation: def __init__( self, base_path: str = "../../database/transportation/", en_version=False ): + file_suffix = "_en" if en_version else "" self.city_list = [ "shanghai", "beijing", @@ -165,7 +166,7 @@ def __init__( ] curdir = os.path.dirname(os.path.realpath(__file__)) - SUBWAY_PATH = os.path.join(curdir, base_path + "subways.json") + SUBWAY_PATH = os.path.join(curdir, base_path + f"subways{file_suffix}.json") self.city_stations_dict = {} self.city_lines_dict = {} self.city_station_to_line = {} @@ -181,7 +182,7 @@ def __init__( for city in self.city_list: self.graphs[city] = build_graph(self.city_lines_dict[city]) - self.poi_search = Poi() + self.poi_search = Poi(en_version=en_version) def goto(self, city, start, end, start_time, transport_type, verbose=False): if transport_type not in ["walk", "metro", "taxi"]: @@ -286,7 +287,7 @@ def goto(self, city, start, end, start_time, transport_type, verbose=False): transports.append( { "start": locationA_name, - "end": stationA["name"] + "-地铁站", + "end": stationA["name"] + ("-metro station" if self.poi_search.en_version else "-地铁站"), "mode": "walk", "start_time": start_time, "end_time": end_timeA, @@ -296,8 +297,8 @@ def goto(self, city, start, end, start_time, transport_type, verbose=False): ) transports.append( { - "start": stationA["name"] + "-地铁站", - "end": stationB["name"] + "-地铁站", + "start": stationA["name"] + ("-metro station" if self.poi_search.en_version else "-地铁站"), + "end": stationB["name"] + ("-metro station" if self.poi_search.en_version else "-地铁站"), "mode": "metro", "start_time": end_timeA, "end_time": end_timeB, @@ -307,7 +308,7 @@ def goto(self, city, start, end, start_time, transport_type, verbose=False): ) transports.append( { - "start": stationB["name"] + "-地铁站", + "start": stationB["name"] + ("-metro station" if self.poi_search.en_version else "-地铁站"), "end": locationB_name, "mode": "walk", "start_time": end_timeB, diff --git a/chinatravel/environment/world_env.py b/chinatravel/environment/world_env.py index e146001..7996164 100644 --- a/chinatravel/environment/world_env.py +++ b/chinatravel/environment/world_env.py @@ -103,12 +103,12 @@ def __init__(self, en_version=False): "武汉", "南京", ] - self.attractions = Attractions() - self.accommodations = Accommodations() - self.restaurants = Restaurants() - self.intercitytransport = IntercityTransport() - self.transportation = Transportation() - self.poi = Poi() + self.attractions = Attractions(en_version=en_version) + self.accommodations = Accommodations(en_version=en_version) + self.restaurants = Restaurants(en_version=en_version) + self.intercitytransport = IntercityTransport(en_version=en_version) + self.transportation = Transportation(en_version=en_version) + self.poi = Poi(en_version=en_version) self.results = [] diff --git a/eval_exp.py b/eval_exp.py index 79a59cf..04c7db9 100644 --- a/eval_exp.py +++ b/eval_exp.py @@ -74,6 +74,7 @@ def load_result_for_method(method): if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="query language") parser.add_argument("--splits", "-s", type=str, default="example") parser.add_argument( "--method", "-m", type=str, default="example" diff --git a/eval_tpc.py b/eval_tpc.py index 1f27b45..b86c650 100644 --- a/eval_tpc.py +++ b/eval_tpc.py @@ -145,6 +145,7 @@ def write_file(file, content): if __name__ == "__main__": parser = argparse.ArgumentParser() + parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="query language") parser.add_argument("--splits", "-s", type=str, default="example") parser.add_argument( "--method", "-m", type=str, default="travel_agent" diff --git a/run_exp.py b/run_exp.py index 4576a0f..54ccd09 100644 --- a/run_exp.py +++ b/run_exp.py @@ -20,6 +20,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="argparse testing") + parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="dataset/environment language") parser.add_argument( "--splits", "-s", @@ -97,7 +98,7 @@ max_model_len = None kwargs = { "method": args.agent, - "env": WorldEnv(), + "env": WorldEnv(en_version=args.lang == "en"), "backbone_llm": init_llm(args.llm, max_model_len=max_model_len), "cache_dir": cache_dir, "log_dir": log_dir, @@ -134,7 +135,7 @@ query_i = query_data[data_idx] print(query_i) if args.agent in ["ReAct", "ReAct0", "Act"]: - plan_log = agent(query_i["nature_language"]) + plan_log = agent(query_i.get("nature_language", query_i.get("nature_language_en", ""))) plan = plan_log["ans"] if isinstance(plan, str): try: diff --git a/run_tpc.py b/run_tpc.py index a463ffe..d23f84e 100644 --- a/run_tpc.py +++ b/run_tpc.py @@ -21,6 +21,7 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="argparse testing") + parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="dataset/environment language") parser.add_argument( "--splits", "-s", @@ -93,7 +94,7 @@ kwargs = { "method": args.agent, - "env": WorldEnv(), + "env": WorldEnv(en_version=args.lang == "en"), "backbone_llm": init_llm(args.llm), "cache_dir": cache_dir, "log_dir": log_dir,