Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
*.log
requirements.txt
**/*.ipynb
debug/
debug_rewards.jsonl
results.db*
sn13_db.db*
Expand Down
36 changes: 36 additions & 0 deletions apex/validator/logger_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import json
from datetime import datetime
from pathlib import Path

from apex.common.models import MinerDiscriminatorResults, MinerGeneratorResults


class LoggerLocal:
def __init__(self, filepath: str = "debug/logs.jsonl"):
self._debug_file_path = Path(filepath)
self._debug_file_path.parent.mkdir(exist_ok=True)

async def log(
self,
query: str,
ground_truth: int,
reference: str | None,
generator_results: MinerGeneratorResults | None,
discriminator_results: MinerDiscriminatorResults | None,
) -> None:
day = datetime.now().strftime("%Y-%m-%d")
filepath = Path(f"{self._debug_file_path.with_suffix('')}-{day}.jsonl")
record: dict[str, str | int | list[str] | list[float] | None] = {
"date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"query": query,
"ground_truth": ground_truth,
"reference": reference,
"generators": generator_results.generator_results if generator_results else [],
"generator_hotkeys": generator_results.generator_hotkeys if generator_results else [],
"discriminator_results": discriminator_results.discriminator_results if discriminator_results else [],
"discriminator_scores": discriminator_results.discriminator_scores if discriminator_results else [],
"generator_hotkey": discriminator_results.generator_hotkey if discriminator_results else "",
}

with filepath.open("a+") as fh:
fh.write(f"{json.dumps(record)}\n")
13 changes: 13 additions & 0 deletions apex/validator/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from apex.services.websearch.websearch_base import WebSearchBase
from apex.validator import generate_query, generate_reference
from apex.validator.logger_apex import LoggerApex
from apex.validator.logger_local import LoggerLocal
from apex.validator.miner_sampler import MinerSampler


Expand All @@ -29,6 +30,7 @@ def __init__(
queue_size: int = 10_000,
redundancy_rate: float = 0.05, # The rate that references are generated in addition to generator steps
reference_rate: float = 0.5, # The rate that references are generated as opposed to generator steps
debug: bool = False,
):
self.websearch = websearch
self.miner_registry = miner_sampler
Expand All @@ -43,6 +45,8 @@ def __init__(
self.q_out: asyncio.Queue[str] = asyncio.Queue()
self.redundancy_rate = redundancy_rate
self.reference_rate = reference_rate
self._debug = debug
self._logger_local = LoggerLocal()

async def start_loop(self, initial_queries: Sequence[str] | None = None) -> None:
"""Kick off producer -> consumer workers. Runs in perpetuity, generating unique IDs for each task."""
Expand Down Expand Up @@ -110,6 +114,15 @@ async def run_single(self, task: QueryTask) -> str:
reference=reference, discriminator_results=discriminator_results, tool_history=tool_history
)

if self._debug:
await self._logger_local.log(
query=query,
ground_truth=ground_truth,
reference=reference,
generator_results=generator_results,
discriminator_results=discriminator_results,
)

return task.query_id

async def _periodic_consumer(self) -> None:
Expand Down