Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions chinatravel/data/load_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,27 @@ def default(self, obj):
return super(NpEncoder, self).default(obj)


def _get_language(args):
    """Return the language code configured on ``args``.

    Args:
        args: Any object (typically an ``argparse.Namespace``) that may
            carry a ``lang`` attribute.

    Returns:
        The value of ``args.lang`` when it is set to a non-empty value,
        otherwise ``"zh"``.
    """
    # A missing attribute and an explicitly unset one (None / "") are both
    # treated as "use the default"; otherwise a None would be compared against
    # "zh" downstream and be mistaken for a non-Chinese language.
    return getattr(args, "lang", None) or "zh"


def _normalize_query_language(data_i, lang):
    """Overwrite the base query fields with their ``*_en`` counterparts.

    When ``lang`` is ``"zh"`` the record is returned untouched. For any other
    value, each ``*_en`` field that is present in ``data_i`` is copied over
    the corresponding base field. Fields without an ``*_en`` counterpart in
    the record are left as they are.

    Args:
        data_i: Mapping holding one query record. It is modified in place.
        lang: Language code; only ``"zh"`` disables the override.

    Returns:
        The same ``data_i`` object that was passed in.
    """
    # Translated key -> base key it replaces.
    translated_to_base = {
        "nature_language_en": "nature_language",
        "hard_logic_en": "hard_logic",
        "hard_logic_py_en": "hard_logic_py",
        "hard_logic_nl_en": "hard_logic_nl",
        "preference_en": "preference",
    }
    if lang != "zh":
        for translated_key, target_key in translated_to_base.items():
            if translated_key in data_i:
                data_i[target_key] = data_i[translated_key]
    return data_i


def load_query_local(args, version="", verbose=False):
query_data = {}

Expand Down Expand Up @@ -63,6 +84,8 @@ def load_query_local(args, version="", verbose=False):
open(os.path.join(dir_ii, file_i), encoding="utf-8")
)

data_i = _normalize_query_language(data_i, _get_language(args))

if hasattr(args, 'oracle_translation') and not args.oracle_translation:
if "hard_logic" in data_i:
del data_i["hard_logic"]
Expand Down Expand Up @@ -111,7 +134,8 @@ def load_query(args):


for data_i in query_data:
if "hard_logic_py" in data_i:
data_i = _normalize_query_language(data_i, _get_language(args))
if "hard_logic_py" in data_i and isinstance(data_i["hard_logic_py"], str):
data_i["hard_logic_py"] = ast.literal_eval(data_i["hard_logic_py"])

query_id_list = [data_i["uid"] for data_i in query_data]
Expand Down Expand Up @@ -155,4 +179,4 @@ def load_query(args):
print(uid, query_data[uid])
else:
raise ValueError(f"{uid} not in query_data")


9 changes: 4 additions & 5 deletions chinatravel/environment/tools/accommodations/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class Accommodations:
def __init__(
self, base_path: str = "../../database/accommodations/", en_version=False
):
file_suffix = "_en" if en_version else ""
curdir = os.path.dirname(os.path.realpath(__file__))
city_list = [
"beijing",
Expand All @@ -29,7 +30,7 @@ def __init__(
"chongqing",
]
data_path_list = [
os.path.join(curdir, f"{base_path}/{city}/accommodations.csv")
os.path.join(curdir, f"{base_path}/{city}/accommodations{file_suffix}.csv")
for city in city_list
]
self.data = {}
Expand All @@ -56,10 +57,8 @@ def __init__(
]

for i, city in enumerate(city_list):
self.data[city_cn_list[i]] = self.data.pop(city)
self.key_type_tuple_list[city_cn_list[i]] = self.key_type_tuple_list.pop(
city
)
self.data[city_cn_list[i]] = self.data[city]
self.key_type_tuple_list[city_cn_list[i]] = self.key_type_tuple_list[city]

self.poi = Poi(en_version=en_version)

Expand Down
13 changes: 6 additions & 7 deletions chinatravel/environment/tools/attractions/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def __init__(
base_path: str = "../../database/attractions",
en_version=False,
):
file_suffix = "_en" if en_version else ""
city_list = [
"beijing",
"shanghai",
Expand All @@ -30,7 +31,7 @@ def __init__(
]
curdir = os.path.dirname(os.path.realpath(__file__))
data_path_list = [
os.path.join(curdir, f"{base_path}/{city}/attractions.csv")
os.path.join(curdir, f"{base_path}/{city}/attractions{file_suffix}.csv")
for city in city_list
]

Expand Down Expand Up @@ -61,13 +62,11 @@ def __init__(
]

for i, city in enumerate(city_list):
self.data[city_cn_list[i]] = self.data.pop(city)
self.key_type_tuple_list_map[city_cn_list[i]] = (
self.key_type_tuple_list_map.pop(city)
)
self.type_list_map[city_cn_list[i]] = self.type_list_map.pop(city)
self.data[city_cn_list[i]] = self.data[city]
self.key_type_tuple_list_map[city_cn_list[i]] = self.key_type_tuple_list_map[city]
self.type_list_map[city_cn_list[i]] = self.type_list_map[city]

self.poi = Poi()
self.poi = Poi(en_version=en_version)

def keys(self, city: str):
return self.key_type_tuple_list_map[city]
Expand Down
19 changes: 16 additions & 3 deletions chinatravel/environment/tools/intercity_transport/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ def time2float(time_str):


class IntercityTransport:
def __init__(self, path: str = "../../database/intercity_transport/"):
def __init__(self, path: str = "../../database/intercity_transport/", en_version=False):
file_suffix = "_en" if en_version else ""
curdir = os.path.dirname(os.path.realpath(__file__))
self.base_path = os.path.join(curdir, path)
self.airplane_path = self.base_path + "airplane.jsonl"
self.airplane_path = self.base_path + f"airplane{file_suffix}.jsonl"
self.airplane_df = pd.read_json(
self.airplane_path, lines=True, keep_default_dates=False
)
Expand All @@ -28,6 +29,14 @@ def __init__(self, path: str = "../../database/intercity_transport/"):
"武汉",
"南京",
]
self.city_list = city_list
self.city_en_list = [
"shanghai", "beijing", "shenzhen", "guangzhou", "chongqing",
"suzhou", "chengdu", "hangzhou", "wuhan", "nanjing",
]
self.city_cn_to_en = dict(zip(self.city_list, self.city_en_list))
self.city_en_to_cn = dict(zip(self.city_en_list, self.city_list))

self.train_df_dict = {}

for start_city in city_list:
Expand All @@ -37,14 +46,18 @@ def __init__(self, path: str = "../../database/intercity_transport/"):
train_path = (
self.base_path
+ "train/"
+ "from_{}_to_{}.json".format(start_city, end_city)
+ f"from_{start_city}_to_{end_city}{file_suffix}.json"
)
train_df = pd.read_json(train_path)
self.train_df_dict[(start_city, end_city)] = train_df

def select(
self, start_city, end_city, intercity_type, earliest_leave_time="00:00"
) -> DataFrame:
if start_city in self.city_en_to_cn:
start_city = self.city_en_to_cn[start_city]
if end_city in self.city_en_to_cn:
end_city = self.city_en_to_cn[end_city]
if intercity_type not in ["train", "airplane"]:
return "only support intercity_type in ['train','airplane']"
res = self._select(start_city, end_city, intercity_type)
Expand Down
10 changes: 7 additions & 3 deletions chinatravel/environment/tools/poi/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

class Poi:
def __init__(self, base_path: str = "../../database/poi/", en_version=False):
self.en_version = en_version
file_suffix = "_en" if en_version else ""

city_list = [
"beijing",
Expand All @@ -19,7 +21,7 @@ def __init__(self, base_path: str = "../../database/poi/", en_version=False):
]
curdir = os.path.dirname(os.path.realpath(__file__))
data_path_list = [
os.path.join(curdir, f"{base_path}/{city}/poi.json") for city in city_list
os.path.join(curdir, f"{base_path}/{city}/poi{file_suffix}.json") for city in city_list
]
self.data = {}
for i, city in enumerate(city_list):
Expand All @@ -46,7 +48,7 @@ def __init__(self, base_path: str = "../../database/poi/", en_version=False):
"重庆",
]
for i, city in enumerate(city_list):
self.data[city_cn_list[i]] = self.data.pop(city)
self.data[city_cn_list[i]] = self.data[city]
self.city_cn_list = city_cn_list
self.city_list = city_list

Expand All @@ -57,7 +59,9 @@ def search(self, city: str, name: str):
try:
return city_data[name]
except KeyError:
return f"No such point in the city. Check the point name: {name}."
if self.en_version:
return f"No such point in the city. Check the point name: {name}."
return f"城市中没有该地点,请检查地点名称: {name}."


def test():
Expand Down
19 changes: 9 additions & 10 deletions chinatravel/environment/tools/restaurants/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@


class Restaurants:
def __init__(self, base_path: str = "../../database/restaurants"):
def __init__(self, base_path: str = "../../database/restaurants", en_version=False):
file_suffix = "_en" if en_version else ""
city_list = [
"beijing",
"shanghai",
Expand All @@ -27,7 +28,7 @@ def __init__(self, base_path: str = "../../database/restaurants"):
self.data = {}
curdir = os.path.dirname(os.path.realpath(__file__))
for city in city_list:
path = os.path.join(curdir, base_path, city, "restaurants_" + city + ".csv")
path = os.path.join(curdir, base_path, city, f"restaurants_{city}{file_suffix}.csv")
self.data[city] = pd.read_csv(path)

self.key_type_tuple_list_map = {}
Expand All @@ -54,13 +55,11 @@ def __init__(self, base_path: str = "../../database/restaurants"):
]

for i, city in enumerate(city_list):
self.data[city_cn_list[i]] = self.data.pop(city)
self.key_type_tuple_list_map[city_cn_list[i]] = (
self.key_type_tuple_list_map.pop(city)
)
self.cuisine_list_map[city_cn_list[i]] = self.cuisine_list_map.pop(city)
self.data[city_cn_list[i]] = self.data[city]
self.key_type_tuple_list_map[city_cn_list[i]] = self.key_type_tuple_list_map[city]
self.cuisine_list_map[city_cn_list[i]] = self.cuisine_list_map[city]

self.poi = Poi()
self.poi = Poi(en_version=en_version)

def keys(self, city: str):
return self.key_type_tuple_list_map[city]
Expand All @@ -77,12 +76,12 @@ def id_is_open(self, city: str, id: int, time: str) -> bool:
end_time = match["endtime"].values[0]
open_time = (
-1
if open_time == "不营业"
if open_time in ["不营业", "closed"]
else float(open_time.split(":")[0]) + float(open_time.split(":")[1]) / 60
)
end_time = (
-1
if end_time == "不营业"
if end_time in ["不营业", "closed"]
else float(end_time.split(":")[0]) + float(end_time.split(":")[1]) / 60
)
time = float(time.split(":")[0]) + float(time.split(":")[1]) / 60
Expand Down
13 changes: 7 additions & 6 deletions chinatravel/environment/tools/transportation/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ class Transportation:
def __init__(
self, base_path: str = "../../database/transportation/", en_version=False
):
file_suffix = "_en" if en_version else ""
self.city_list = [
"shanghai",
"beijing",
Expand All @@ -165,7 +166,7 @@ def __init__(
]

curdir = os.path.dirname(os.path.realpath(__file__))
SUBWAY_PATH = os.path.join(curdir, base_path + "subways.json")
SUBWAY_PATH = os.path.join(curdir, base_path + f"subways{file_suffix}.json")
self.city_stations_dict = {}
self.city_lines_dict = {}
self.city_station_to_line = {}
Expand All @@ -181,7 +182,7 @@ def __init__(
for city in self.city_list:
self.graphs[city] = build_graph(self.city_lines_dict[city])

self.poi_search = Poi()
self.poi_search = Poi(en_version=en_version)

def goto(self, city, start, end, start_time, transport_type, verbose=False):
if transport_type not in ["walk", "metro", "taxi"]:
Expand Down Expand Up @@ -286,7 +287,7 @@ def goto(self, city, start, end, start_time, transport_type, verbose=False):
transports.append(
{
"start": locationA_name,
"end": stationA["name"] + "-地铁站",
"end": stationA["name"] + ("-metro station" if self.poi_search.en_version else "-地铁站"),
"mode": "walk",
"start_time": start_time,
"end_time": end_timeA,
Expand All @@ -296,8 +297,8 @@ def goto(self, city, start, end, start_time, transport_type, verbose=False):
)
transports.append(
{
"start": stationA["name"] + "-地铁站",
"end": stationB["name"] + "-地铁站",
"start": stationA["name"] + ("-metro station" if self.poi_search.en_version else "-地铁站"),
"end": stationB["name"] + ("-metro station" if self.poi_search.en_version else "-地铁站"),
"mode": "metro",
"start_time": end_timeA,
"end_time": end_timeB,
Expand All @@ -307,7 +308,7 @@ def goto(self, city, start, end, start_time, transport_type, verbose=False):
)
transports.append(
{
"start": stationB["name"] + "-地铁站",
"start": stationB["name"] + ("-metro station" if self.poi_search.en_version else "-地铁站"),
"end": locationB_name,
"mode": "walk",
"start_time": end_timeB,
Expand Down
12 changes: 6 additions & 6 deletions chinatravel/environment/world_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,12 +103,12 @@ def __init__(self, en_version=False):
"武汉",
"南京",
]
self.attractions = Attractions()
self.accommodations = Accommodations()
self.restaurants = Restaurants()
self.intercitytransport = IntercityTransport()
self.transportation = Transportation()
self.poi = Poi()
self.attractions = Attractions(en_version=en_version)
self.accommodations = Accommodations(en_version=en_version)
self.restaurants = Restaurants(en_version=en_version)
self.intercitytransport = IntercityTransport(en_version=en_version)
self.transportation = Transportation(en_version=en_version)
self.poi = Poi(en_version=en_version)

self.results = []

Expand Down
1 change: 1 addition & 0 deletions eval_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def load_result_for_method(method):
if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="query language")
parser.add_argument("--splits", "-s", type=str, default="example")
parser.add_argument(
"--method", "-m", type=str, default="example"
Expand Down
1 change: 1 addition & 0 deletions eval_tpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def write_file(file, content):
if __name__ == "__main__":

parser = argparse.ArgumentParser()
parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="query language")
parser.add_argument("--splits", "-s", type=str, default="example")
parser.add_argument(
"--method", "-m", type=str, default="travel_agent"
Expand Down
5 changes: 3 additions & 2 deletions run_exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
if __name__ == "__main__":

parser = argparse.ArgumentParser(description="argparse testing")
parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="dataset/environment language")
parser.add_argument(
"--splits",
"-s",
Expand Down Expand Up @@ -97,7 +98,7 @@
max_model_len = None
kwargs = {
"method": args.agent,
"env": WorldEnv(),
"env": WorldEnv(en_version=args.lang == "en"),
"backbone_llm": init_llm(args.llm, max_model_len=max_model_len),
"cache_dir": cache_dir,
"log_dir": log_dir,
Expand Down Expand Up @@ -134,7 +135,7 @@
query_i = query_data[data_idx]
print(query_i)
if args.agent in ["ReAct", "ReAct0", "Act"]:
plan_log = agent(query_i["nature_language"])
plan_log = agent(query_i.get("nature_language", query_i.get("nature_language_en", "")))
plan = plan_log["ans"]
if isinstance(plan, str):
try:
Expand Down
3 changes: 2 additions & 1 deletion run_tpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
if __name__ == "__main__":

parser = argparse.ArgumentParser(description="argparse testing")
parser.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="dataset/environment language")
parser.add_argument(
"--splits",
"-s",
Expand Down Expand Up @@ -93,7 +94,7 @@

kwargs = {
"method": args.agent,
"env": WorldEnv(),
"env": WorldEnv(en_version=args.lang == "en"),
"backbone_llm": init_llm(args.llm),
"cache_dir": cache_dir,
"log_dir": log_dir,
Expand Down