
Commit 71d00fd

fix wiki bugs
1 parent e4920ef commit 71d00fd

18 files changed, +2737 -34 lines changed
Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
import itertools
import time
import typing as tp
import xmlrpc.client
from concurrent.futures import ThreadPoolExecutor


class WikidataQueryClient:
    """Thin wrapper around a single Wikidata XML-RPC lookup server."""

    def __init__(self, url: str):
        self.url = url
        self.server = xmlrpc.client.ServerProxy(url)

    def label2qid(self, label: str) -> str:
        return self.server.label2qid(label)

    def label2pid(self, label: str) -> str:
        return self.server.label2pid(label)

    def pid2label(self, pid: str) -> str:
        return self.server.pid2label(pid)

    def qid2label(self, qid: str) -> str:
        return self.server.qid2label(qid)

    def get_all_relations_of_an_entity(
        self, entity_qid: str
    ) -> tp.Dict[str, tp.List]:
        return self.server.get_all_relations_of_an_entity(entity_qid)

    def get_tail_entities_given_head_and_relation(
        self, head_qid: str, relation_pid: str
    ) -> tp.Dict[str, tp.List]:
        return self.server.get_tail_entities_given_head_and_relation(
            head_qid, relation_pid
        )

    def get_tail_values_given_head_and_relation(
        self, head_qid: str, relation_pid: str
    ) -> tp.List[str]:
        return self.server.get_tail_values_given_head_and_relation(
            head_qid, relation_pid
        )

    def get_external_id_given_head_and_relation(
        self, head_qid: str, relation_pid: str
    ) -> tp.List[str]:
        return self.server.get_external_id_given_head_and_relation(
            head_qid, relation_pid
        )

    def mid2qid(self, mid: str) -> str:
        return self.server.mid2qid(mid)


class MultiServerWikidataQueryClient:
    """Fan the same query out to several Wikidata XML-RPC servers and merge
    the answers, dropping any server that cannot be reached."""

    def __init__(self, urls: tp.List[str]):
        self.clients = [WikidataQueryClient(url) for url in urls]
        self.executor = ThreadPoolExecutor(max_workers=len(urls))
        # Test connections
        start_time = time.perf_counter()
        self.test_connections()
        end_time = time.perf_counter()
        print(f"Connection testing took {end_time - start_time} seconds")

    def test_connections(self):
        def test_url(client):
            try:
                # Check if the server provides the system.listMethods function.
                client.server.system.listMethods()
                return True
            except Exception as e:
                print(f"Failed to connect to {client.url}. Error: {str(e)}")
                return False

        start_time = time.perf_counter()
        futures = [
            self.executor.submit(test_url, client) for client in self.clients
        ]
        results = [f.result() for f in futures]
        end_time = time.perf_counter()
        # print(f"Testing connections took {end_time - start_time} seconds")
        # Remove clients that failed to connect
        self.clients = [
            client for client, result in zip(self.clients, results) if result
        ]
        if not self.clients:
            raise Exception("Failed to connect to all URLs")

    def query_all(self, method, *args):
        """Call `method(*args)` on every reachable server and merge the results."""
        start_time = time.perf_counter()
        futures = [
            self.executor.submit(getattr(client, method), *args)
            for client in self.clients
        ]
        # Retrieve results and filter out 'Not Found!'. These two methods
        # return {"head": [...], "tail": [...]} dicts; everything else
        # returns a string or a list of strings.
        is_dict_return = method in [
            "get_all_relations_of_an_entity",
            "get_tail_entities_given_head_and_relation",
        ]
        results = [f.result() for f in futures]
        end_time = time.perf_counter()
        # print(f"HTTP Queries took {end_time - start_time} seconds")

        start_time = time.perf_counter()
        real_results = set() if not is_dict_return else {"head": [], "tail": []}
        for res in results:
            if isinstance(res, str) and res == "Not Found!":
                continue
            elif isinstance(res, list):
                if len(res) == 0:
                    continue
                if isinstance(res[0], list):
                    # Flatten one level of nesting before deduplicating.
                    res_flattened = itertools.chain(*res)
                    real_results.update(res_flattened)
                    continue
                real_results.update(res)
            elif is_dict_return:
                real_results["head"].extend(res["head"])
                real_results["tail"].extend(res["tail"])
            else:
                real_results.add(res)
        end_time = time.perf_counter()
        # print(f"Querying all took {end_time - start_time} seconds")

        return real_results if len(real_results) > 0 else "Not Found!"


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--addr_list",
        type=str,
        required=True,
        help="path to server address list",
    )
    args = parser.parse_args()

    with open(args.addr_list, "r") as f:
        server_addrs = f.readlines()
        server_addrs = [addr.strip() for addr in server_addrs]
    print(f"Server addresses: {server_addrs}")
    client = MultiServerWikidataQueryClient(server_addrs)
    print(
        "MSFT's ticker code is "
        f'{client.query_all("get_tail_values_given_head_and_relation", "Q2283", "P249")}'
    )
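
The clients above assume XML-RPC endpoints that already expose these lookup methods; no server code is part of this diff. As a minimal sketch for exercising the clients locally, a stub built on Python's standard xmlrpc.server could look like the following. The method names mirror the client, but the file name, port, and toy lookup tables are illustrative assumptions, not anything from this commit:

# wiki_rpc_server_stub.py -- hypothetical local stub, not part of this commit.
# It registers the method names WikidataQueryClient calls, backed by toy
# lookup tables, so the clients above can be smoke-tested end to end.
from xmlrpc.server import SimpleXMLRPCServer

LABEL2QID = {"Microsoft": "Q2283"}  # placeholder data only
QID2LABEL = {"Q2283": "Microsoft"}
LABEL2PID = {"ticker symbol": "P249"}
PID2LABEL = {"P249": "ticker symbol"}
TAIL_VALUES = {("Q2283", "P249"): ["MSFT"]}


def _lookup(table, key):
    # Mirror the "Not Found!" convention that query_all filters on.
    return table.get(key, "Not Found!")


server = SimpleXMLRPCServer(("127.0.0.1", 23546), logRequests=False, allow_none=True)
server.register_introspection_functions()  # provides system.listMethods()
server.register_function(lambda label: _lookup(LABEL2QID, label), "label2qid")
server.register_function(lambda qid: _lookup(QID2LABEL, qid), "qid2label")
server.register_function(lambda label: _lookup(LABEL2PID, label), "label2pid")
server.register_function(lambda pid: _lookup(PID2LABEL, pid), "pid2label")
server.register_function(
    lambda head, rel: TAIL_VALUES.get((head, rel), []),
    "get_tail_values_given_head_and_relation",
)

if __name__ == "__main__":
    print("Stub Wikidata XML-RPC server listening on http://127.0.0.1:23546")
    server.serve_forever()

register_introspection_functions matters here because test_connections probes system.listMethods before keeping a server.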

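With such a stub (or real servers) listed one URL per line in the --addr_list file, the multi-server client can also be used programmatically. A short sketch, assuming the module name and URL below (both hypothetical) and the stub's toy data:

from wiki_client import MultiServerWikidataQueryClient  # assumed module name

client = MultiServerWikidataQueryClient(["http://127.0.0.1:23546"])
# String results are merged into a set across servers, e.g. {'Q2283'}.
print(client.query_all("label2qid", "Microsoft"))
# List results are likewise deduplicated into a set, e.g. {'MSFT'}.
print(client.query_all("get_tail_values_given_head_and_relation", "Q2283", "P249"))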