Skip to content

Commit a8a4935

Browse files
committed
add put/get noise for agent functionality, refactor streaming
1 parent 22aa63a commit a8a4935

File tree

1 file changed

+112
-16
lines changed

1 file changed

+112
-16
lines changed

checker/src/checker.py

Lines changed: 112 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,14 @@
1919
ExploitCheckerTaskMessage,
2020
GetflagCheckerTaskMessage,
2121
MumbleException,
22-
OfflineException,
2322
PutflagCheckerTaskMessage,
23+
GetnoiseCheckerTaskMessage
2424
)
2525
from enochecker3.utils import assert_equals, assert_in
2626
from httpx import Auth, AsyncClient, ConnectTimeout, NetworkError, PoolTimeout, HTTPStatusError, RequestError
2727
from mcp.client.session import ClientSession
2828
from mcp.client.streamable_http import streamablehttp_client
29+
2930
from sentences import checker_sentences
3031

3132
"""
@@ -99,6 +100,16 @@ async def put(self, url, json=None, headers=None):
99100
async def get(self, url, headers=None):
100101
return await self.client.post(url, headers=self.apply_headers(headers))
101102

103+
async def stream(self, url, json=None, headers=None):
104+
async with self.client.stream("POST", url, json=json, headers=self.apply_headers(headers)) as resp:
105+
resp.raise_for_status()
106+
try:
107+
async for _ in resp.aiter_lines():
108+
break
109+
except:
110+
pass
111+
return resp
112+
102113
async def register_user(self, username: str, password: str) -> Tuple[int, int]:
103114
self.logger.debug(f"Registering user {username}")
104115
try:
@@ -299,6 +310,20 @@ async def mcp_read_messages(self, chat_session_id: int, token: str = None) -> li
299310
self.logger.error(f"Error during reading messages through MCP of session id {chat_session_id}: {e}")
300311
raise MumbleException("Error during reading messages through MCP.")
301312

313+
async def send_stream(self, message: str, chat_session_id: int):
314+
self.logger.debug(f"Sending stream message to chat session {chat_session_id}")
315+
try:
316+
resp = await self.stream(f"{API_BASE}/messages/stream", json={"input": message, "ChatSessionId": chat_session_id, "SkipLlmResponse": True})
317+
except (ConnectTimeout, NetworkError, PoolTimeout, RequestError) as e:
318+
self.logger.error(f"Connection error during streaming message to chat session {chat_session_id}: {e}")
319+
raise MumbleException(f"Connection error during streaming message.")
320+
except HTTPStatusError as e:
321+
self.logger.error(f"HTTP error during streaming message to chat session {chat_session_id}: {e}")
322+
raise MumbleException(f"HTTP error during streaming message.")
323+
except Exception as e:
324+
self.logger.error(f"Error during streaming message to chat session {chat_session_id}: {e}")
325+
raise MumbleException(f"Error during streaming message.")
326+
302327

303328
# Dependency injection for Connection
304329
@checker.register_dependency
@@ -730,21 +755,7 @@ async def putflag_agent(
730755
flag = task.flag
731756
flag = await format_flag(flag)
732757

733-
async def send_stream(conn, flag, chat_session_id):
734-
async with conn.client.stream(
735-
"POST",
736-
f"{API_BASE}/messages/stream",
737-
json={"input": flag, "ChatSessionId": chat_session_id, "SkipLlmResponse": True},
738-
headers=conn.headers
739-
) as resp:
740-
resp.raise_for_status()
741-
try:
742-
async for _ in resp.aiter_lines():
743-
break
744-
except:
745-
pass
746-
747-
await send_stream(conn, flag, chat_session_id)
758+
await conn.send_stream(flag, chat_session_id)
748759

749760
access_token = await conn.create_access_token(random_string(4))
750761

@@ -877,6 +888,91 @@ async def exploit_agent(
877888
return unformatted_flag
878889

879890

891+
@checker.putnoise(1)
892+
async def putnoise_agent(
893+
conn: Connection,
894+
db: ChainDB,
895+
):
896+
username = random_string()
897+
password = random_string()
898+
user_id, chat_session_id = await conn.register_user(username, password)
899+
await conn.login_user(username, password)
900+
901+
# Generate random noise sentence
902+
noise_sentence = ''.join(random.choices(get_ascii_chars(), k=SUPPOSED_FLAG_LENGTH))
903+
904+
await conn.send_stream(noise_sentence, chat_session_id)
905+
906+
access_token = await conn.create_access_token(random_string(4))
907+
908+
await db.set("userdata", {"user_id": user_id, "session_id": chat_session_id, "access_token": access_token, "noise_sentence": noise_sentence})
909+
910+
911+
@checker.getnoise(1)
912+
async def getnoise_agent(
913+
task: GetnoiseCheckerTaskMessage, conn: Connection, db: ChainDB, logger: LoggerAdapter
914+
) -> None:
915+
916+
try:
917+
data = await db.get("userdata")
918+
user_id = data["user_id"]
919+
access_token = data["access_token"]
920+
session_id = data["session_id"]
921+
noise_sentence = data["noise_sentence"]
922+
except KeyError:
923+
logger.error("Missing database entry from putnoise(1)")
924+
raise MumbleException("Missing database entry from putnoise(1)")
925+
926+
client_challenge_bytes = os.urandom(8)
927+
client_challenge_json = base64.b64encode(client_challenge_bytes).decode('utf-8')
928+
server_challenge, memory_key, iv = await conn.client_challenge(client_challenge_json, user_id)
929+
930+
salt = client_challenge_bytes + server_challenge
931+
session_key = HKDF(access_token, 16, salt, SHA256, 1)
932+
933+
client_credentials_bytes = aes_cfb8_encrypt(client_challenge_bytes, session_key, iv)
934+
client_credentials_json = base64.b64encode(client_credentials_bytes).decode('utf-8')
935+
agent_name = random_string(4)
936+
try:
937+
server_creds, jwt = await conn.client_credentials(client_credentials_json, user_id, agent_name, memory_key)
938+
except (ConnectTimeout, NetworkError, PoolTimeout, RequestError) as e:
939+
logger.error(
940+
f"Connection error during posting client credentials for user {user_id} for agent {agent_name} with memory key {memory_key} and client credentials {client_credentials_json}: {e}")
941+
raise MumbleException(
942+
f"Connection error during posting client credentials.")
943+
except HTTPStatusError as e:
944+
logger.error(
945+
f"HTTP error during posting client credentials for user {user_id} for agent {agent_name} with memory key {memory_key} and client credentials {client_credentials_json}: {e}")
946+
raise MumbleException("HTTP error during posting client credentials.")
947+
except Exception as e:
948+
logger.error(
949+
f"Error during posting client credentials for user {user_id} for agent {agent_name} with memory key {memory_key} and client credentials {client_credentials_json}: {e}")
950+
raise MumbleException("Error during posting client credentials.")
951+
952+
server_creds_computed = aes_cfb8_encrypt(server_challenge, session_key, iv)
953+
954+
if server_creds_computed != server_creds:
955+
logger.error(f"Server credentials do not match. Computed ones: {server_creds_computed}. Actual {server_creds}")
956+
raise MumbleException("Server credentials do not match")
957+
958+
try:
959+
mcp_client = MCPClient(f"http://{task.address}:{SERVICE_PORT}/mcp", logger, jwt)
960+
messages = await mcp_client.run(session_id)
961+
assert len(messages) >= 2, "Expected at least 2 messages from MCP"
962+
retrieved_sentence = messages[1]["content"]
963+
except Exception as e:
964+
logger.error(f"Error during reading messages through MCP of session id {session_id}: {e}")
965+
raise MumbleException(f"Error during reading messages through MCP.")
966+
967+
if retrieved_sentence is None or retrieved_sentence == "":
968+
logger.error("No able to retrieve noise sentence")
969+
raise MumbleException("No able to retrieve noise sentence")
970+
971+
if noise_sentence != retrieved_sentence:
972+
logger.error(f"Noise sentence mismatch: expected {noise_sentence}, got {retrieved_sentence}")
973+
raise MumbleException("Noise sentence does not match expected value")
974+
975+
880976
if __name__ == "__main__":
881977

882978
debug = False

0 commit comments

Comments
 (0)