
Commit 18efdcb

Merge branch 'staging' into SN1-447-add-logits
2 parents ddd3529 + 6c0b13b commit 18efdcb

File tree

8 files changed: +121 −19 lines

poetry.lock

+16
Generated file; diff not rendered by default.

prompting/datasets/random_website.py

+4-3
@@ -26,15 +26,16 @@ class DDGDataset(BaseDataset):
 
     def search_random_term(self, retries: int = 3) -> tuple[Optional[str], Optional[list[dict[str, str]]]]:
         ddg = PatchedDDGS(proxy=settings.shared_settings.PROXY_URL, verify=False)
+        exception: BaseException | None = None
         for _ in range(retries):
             random_words = " ".join(random.sample(ENGLISH_WORDS, 3))
             try:
                 results = list(ddg.text(random_words))
                 if results:
                     return random_words, results
-            except Exception as ex:
-                logger.debug(f"Failed to get search results from DuckDuckGo: {ex}")
-        logger.warning(f"Failed to get search results from DuckDuckGo after {retries} tries")
+            except BaseException as ex:
+                exception = ex
+        logger.warning(f"Failed to get search results from DuckDuckGo after {retries} tries: {exception}")
         return None, None
 
     @staticmethod
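For context, this change keeps only the last exception from the retry loop and logs it once after all attempts fail, instead of logging every failure. A minimal standalone sketch of that pattern (the names here are illustrative, not the dataset's API):

import logging

logger = logging.getLogger(__name__)


def fetch_with_retries(fetch, retries: int = 3):
    """Retry `fetch`; remember only the last exception and warn once at the end."""
    exception: BaseException | None = None
    for _ in range(retries):
        try:
            result = fetch()
            if result:
                return result
        except BaseException as ex:  # broad catch mirrors the hunk above
            exception = ex
    logger.warning(f"All {retries} attempts failed: {exception}")
    return None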

prompting/llms/apis/llm_messages.py

+1-1
@@ -37,7 +37,7 @@ def calculate_image_tokens(width: int, height: int, low_res: bool = False) -> int:
 
 class LLMMessage(BaseModel):
     role: Literal["system", "user", "assistant"]
-    content: str = None
+    content: str | None = None
     image: Image.Image | None = None
 
     class Config:
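For context, the annotation fix matters under Pydantic validation: with `content: str = None`, an explicitly passed `content=None` fails type validation, whereas `str | None` accepts it. A minimal standalone sketch, covering only the fields shown in the hunk (the real class also carries an `image` field and a `Config`):

from typing import Literal

from pydantic import BaseModel


class Message(BaseModel):  # illustrative stand-in for LLMMessage
    role: Literal["system", "user", "assistant"]
    content: str | None = None


Message(role="assistant")            # content defaults to None
Message(role="user", content="hi")   # explicit string validates
Message(role="user", content=None)   # valid now; rejected under `content: str`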

prompting/rewards/web_retrieval.py

+45-8
@@ -4,9 +4,11 @@
 import json
 import os
 from collections import defaultdict
+from datetime import datetime
 
 import numpy as np
 import pandas as pd
+import whois
 from loguru import logger
 from pydantic import BaseModel
 from scipy import spatial
@@ -48,6 +50,9 @@
 # Maximum number of past URLs to store per user
 N_PAST_URLS = 200
 
+# Minimum age of the website.
+MIN_AGE_DAYS = 90
+
 # Load the past_websites dictionary and top domains
 try:
     # Load top domains
@@ -98,6 +103,36 @@ def __hash__(self):
         # Use the id of the object as its hash
         return hash(self.model_dump_json)
 
+    @staticmethod
+    @async_lru_cache(maxsize=1000)
+    async def domain_age_days(domain: str, fallback_age: int = 1_000_000) -> int:
+        """Returns the age of a domain in days.
+
+        Args:
+            domain: Website url.
+            fallback_age: If can't fetch domain age, fallback to `fallback_age` age.
+
+        Returns:
+            Domain age in days since creation.
+        """
+        fallback_age = 1_000_000
+        try:
+            w = whois.whois(domain)
+            creation_date = w.creation_date
+            if isinstance(creation_date, list):
+                creation_date = creation_date[0]
+
+            if creation_date is None:
+                return fallback_age
+            # Convert everything to naive datetime in UTC or local.
+            if hasattr(creation_date, "tzinfo") and creation_date.tzinfo is not None:
+                creation_date = creation_date.replace(tzinfo=None)
+            delta = datetime.now() - creation_date
+            return delta.days
+        except BaseException as e:
+            logger.debug(f"Error fetching domain age data: {e}")
+            return fallback_age
+
     @async_lru_cache(maxsize=1000)
     async def _cosine_similarity(self, content1: str, content2: str) -> float:
         """Calculate the cosine similarity between sentence embeddings of the reference and completions."""
@@ -111,13 +146,18 @@ async def score_website_result(
         if not response_url or not response_content or not response_relevant:
             return 0
 
-        # Extract domain from URL
+        # Extract domain from URL.
         netloc = extract_main_domain(response_url)
+        logger.debug(f"Scoring url: {response_url}")
 
         if any(term in response_url for term in BLACKLISTED_TERMS):
             logger.debug(f"Domain {response_url} contains blacklisted term, scoring 0")
             return 0
 
+        if (days := await self.domain_age_days(response_url)) < MIN_AGE_DAYS:
+            logger.debug(f"Domain {response_url} is too young ({days} days old), scoring 0")
+            return 0
+
         # Penalise a completion where the relevant section is contained in the URL (e.g. miners)
         # trying to use a search box to enter exactly the relevant section they need
         discount_factor = 1 - fuzz.token_sort_ratio(response_url, response_relevant) / 100
@@ -147,17 +187,17 @@ async def score_website_result(
         # Content scraped from the URL provided in the completion.
         reference_website_content = DDGDataset.extract_website_content(response_url)
         if not reference_website_content or len(reference_website_content) == 0:
-            logger.debug(f"Failed to extract miner's content from website: {response_url}")
+            logger.debug(f"Failed to extract miner {uid} content from website: {response_url}")
             return 0
 
         if fuzz.ratio(response_content, reference_website_content) < MIN_MATCH_THRESHOLD:
-            logger.debug("Miner returned text that doesn't match the website, scoring 0")
+            logger.debug(f"Miner {uid} returned text that doesn't match the website, scoring 0")
             return 0
 
         if len(response_relevant) > len(response_content) or len(response_relevant) < MIN_RELEVANT_CHARS:
             logger.debug(
-                f"Relevant section is too short (<{MIN_RELEVANT_CHARS} chars) or longer than the whole website content "
-                f"{len(response_relevant)} > {len(response_content)}"
+                f"Miner {uid} relevant section is too short (<{MIN_RELEVANT_CHARS} chars) or longer than the whole "
+                f"website content {len(response_relevant)} > {len(response_content)}"
             )
             return 0
 
@@ -209,9 +249,6 @@ async def reward(
             rewards.append(await self.score_miner_response(dataset_entry, completion, task=task, uid=uid))
             timings.append(0)
 
-        logger.debug(f"REWARDWEBRETRIEVAL: {rewards}")
-        logger.debug(f"COMPLETIONS: {response_event.completions}")
-
         # Save the past_websites dictionary to CSV
        past_websites_data = []
         for uid, domains in past_websites.items():
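Taken together, the new gate rejects URLs whose domain is younger than MIN_AGE_DAYS, using the WHOIS creation date and falling back to a very large age when the lookup fails. A minimal sketch of the same check outside the reward model, assuming the python-whois package (imported as `whois`); the helper names are illustrative:

from datetime import datetime

import whois  # python-whois

MIN_AGE_DAYS = 90
FALLBACK_AGE = 1_000_000  # treat domains with unknown creation dates as old enough


def domain_age_days(domain: str, fallback_age: int = FALLBACK_AGE) -> int:
    """Days since the domain's WHOIS creation date, or `fallback_age` if unavailable."""
    try:
        record = whois.whois(domain)
        creation_date = record.creation_date
        if isinstance(creation_date, list):  # some registrars return several dates
            creation_date = creation_date[0]
        if creation_date is None:
            return fallback_age
        if getattr(creation_date, "tzinfo", None) is not None:
            creation_date = creation_date.replace(tzinfo=None)  # compare naive datetimes
        return (datetime.now() - creation_date).days
    except Exception:
        return fallback_age


def is_old_enough(domain: str) -> bool:
    return domain_age_days(domain) >= MIN_AGE_DAYS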

pyproject.toml

+2
@@ -159,6 +159,7 @@ datasets = { version = ">=3.1.0", optional = true }
 nltk = { version = ">=3.8.1", optional = true }
 thefuzz = { version = ">=0.22.1", optional = true }
 wandb = { version = ">=0.19.4", optional = true }
+python-whois = { version = ">=0.9.5", optional = true }
 substrate-interface = "^1.7.11"
 tldextract = "^5.1.3"
 justext = "3.0.2"
@@ -183,6 +184,7 @@ validator = [
     "datasets",
     "nltk",
     "wandb",
+    "python-whois",
 ]
 
 [build-system]

shared/epistula.py

+31-5
@@ -123,11 +123,21 @@ async def query_miners(
         tasks = []
         for uid in uids:
             try:
+                timeout_connect = 10
+                timeout_postprocess = 5
                 response = asyncio.wait_for(
                     asyncio.create_task(
-                        make_openai_query(shared_settings.METAGRAPH, shared_settings.WALLET, timeout_seconds, body, uid)
+                        make_openai_query(
+                            shared_settings.METAGRAPH,
+                            shared_settings.WALLET,
+                            timeout_seconds,
+                            body,
+                            uid,
+                            timeout_connect=timeout_connect,
+                        )
                     ),
-                    timeout=timeout_seconds,
+                    # Give additional time for connect and result post-processings.
+                    timeout=timeout_seconds + timeout_connect + timeout_postprocess,
                 )
             except asyncio.TimeoutError:
                 logger.error(f"Timeout exceeded while querying miner {uid}")
@@ -136,23 +146,37 @@ async def query_miners(
 
         responses = await asyncio.gather(*tasks, return_exceptions=True)
 
-        results = []
+        responses_valid = 0
+        responses_error = 0
+        responses_exception = 0
+        exception_info: Exception | None = None
+        results: list[SynapseStreamResult] = []
         for response, uid in zip(responses, uids):
             if isinstance(response, Exception):
+                responses_exception += 1
+                exception_info = response
                 results.append(SynapseStreamResult(exception=str(response)))
             elif isinstance(response, tuple) and isinstance(response[0], ChatCompletion):
+                if response and response[1]:
+                    responses_valid += 1
                 results.append(
                     SynapseStreamResult(
                         uid=uid,
-                        response=response[0],
                         accumulated_chunks=response[1],
                         accumulated_chunks_timings=response[2],
                         accumulated_chunk_dicts_raw=response[3],
                     )
                 )
             else:
+                responses_error += 1
                 logger.error(f"Unknown response type: {response}")
                 results.append(SynapseStreamResult(uid=uid, exception=f"Unknown response type: {response}"))
+
+        logger.info(
+            f"Responses success: {responses_valid}/{len(uids)}. "
+            f"Responses exception: {responses_exception}/{len(uids)} [{exception_info}]. "
+            f"Reponses error: {responses_error}/{len(uids)}"
+        )
         return results
     except Exception as e:
         logger.exception(f"Error in query_miners: {e}")
@@ -211,14 +235,15 @@ async def make_openai_query(
     body: dict[str, Any],
     uid: int,
     stream: bool = False,
+    timeout_connect: int = 10,
 ) -> tuple[ChatCompletion, list, list] | AsyncGenerator:
     body["seed"] = body.get("seed", random.randint(0, 1000000))
     axon_info = metagraph.axons[uid]
     miner = openai.AsyncOpenAI(
         base_url=f"http://{axon_info.ip}:{axon_info.port}/v1",
         api_key="Apex",
         max_retries=0,
-        timeout=Timeout(timeout_seconds, connect=5, read=timeout_seconds - 5),
+        timeout=Timeout(timeout_seconds, connect=timeout_connect, read=timeout_seconds),
         http_client=openai.DefaultAsyncHttpxClient(
             event_hooks={
                 "request": [create_header_hook(wallet.hotkey, axon_info.hotkey, timeout_seconds=timeout_seconds)]
@@ -227,6 +252,7 @@ async def make_openai_query(
     )
     extra_body = {k: v for k, v in body.items() if k not in ["messages", "model"]}
     body["messages"] = model_factory(body.get("model")).format_messages(body["messages"])
+
     start_time = time.perf_counter()
     chat = await miner.chat.completions.create(
         # model=None,
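The query_miners change splits the deadline into a generation budget plus fixed allowances for connecting and post-processing, so the outer asyncio.wait_for deadline stays ahead of the inner HTTP client timeouts. A minimal sketch of that budgeting, assuming an arbitrary async query coroutine (the helper names are illustrative, not the shared library's API):

import asyncio


async def query_with_budget(make_query, timeout_seconds: int = 30):
    """Run `make_query()` with extra headroom for connect and post-processing."""
    timeout_connect = 10      # allowance for establishing the connection
    timeout_postprocess = 5   # allowance for assembling the streamed result
    task = asyncio.create_task(make_query())
    try:
        return await asyncio.wait_for(
            task,
            timeout=timeout_seconds + timeout_connect + timeout_postprocess,
        )
    except asyncio.TimeoutError:
        return None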

tests/prompting/rewards/test_web_retrieval.py

+22-1
@@ -1,5 +1,6 @@
 # ruff: noqa: E402
-from unittest.mock import MagicMock
+from datetime import datetime, timedelta
+from unittest.mock import MagicMock, patch
 
 import numpy as np
 import pytest
@@ -10,6 +11,26 @@
 from prompting.rewards.web_retrieval import WebRetrievalRewardModel
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "creation_date, expected_age",
+    [
+        # Domain created 100 days ago.
+        (datetime.now() - timedelta(days=100), 100),
+        # Domain created 10 days ago.
+        (datetime.now() - timedelta(days=10), 10),
+        # Domain has no valid creation_date => fallback_age.
+        (None, 1_000_000),
+    ],
+)
+async def test_domain_age(creation_date: datetime, expected_age: int):
+    mock_whois = MagicMock()
+    mock_whois.creation_date = creation_date
+    with patch("prompting.rewards.web_retrieval.whois.whois", return_value=mock_whois):
+        age = await WebRetrievalRewardModel.domain_age_days("testdomain.com", fallback_age=1_000_000)
+        assert age == expected_age
+
+
 @pytest.mark.parametrize(
     "completion, expected_url, expected_content, expected_relevant",
     [

validator_api/test_time_inference.py

-1
@@ -110,7 +110,6 @@ async def single_attempt():
                 max_tokens=2000,
             )
 
-            logger.debug(f"Making API call with\n\nMESSAGES: {messages}\n\nRESPONSE: {response_str}")
             response_dict = parse_multiple_json(response_str)[0]
             return response_dict
         except Exception as e:
