Skip to content

Commit beaa185

Browse files
committed
add put/get noise for agent functionality, refactor streaming
1 parent 22aa63a commit beaa185

File tree

1 file changed

+110
-15
lines changed

1 file changed

+110
-15
lines changed

checker/src/checker.py

Lines changed: 110 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,16 @@ async def put(self, url, json=None, headers=None):
9999
async def get(self, url, headers=None):
100100
return await self.client.post(url, headers=self.apply_headers(headers))
101101

102+
async def stream(self, url, json=None, headers=None):
103+
async with self.client.stream("POST", url, json=json, headers=self.apply_headers(headers)) as resp:
104+
resp.raise_for_status()
105+
try:
106+
async for _ in resp.aiter_lines():
107+
break
108+
except:
109+
pass
110+
return resp
111+
102112
async def register_user(self, username: str, password: str) -> Tuple[int, int]:
103113
self.logger.debug(f"Registering user {username}")
104114
try:
@@ -299,6 +309,20 @@ async def mcp_read_messages(self, chat_session_id: int, token: str = None) -> li
299309
self.logger.error(f"Error during reading messages through MCP of session id {chat_session_id}: {e}")
300310
raise MumbleException("Error during reading messages through MCP.")
301311

312+
async def send_stream(self, message: str, chat_session_id: int):
313+
self.logger.debug(f"Sending stream message to chat session {chat_session_id}")
314+
try:
315+
resp = await self.stream(f"{API_BASE}/messages/stream", json={"input": message, "ChatSessionId": chat_session_id, "SkipLlmResponse": True})
316+
except (ConnectTimeout, NetworkError, PoolTimeout, RequestError) as e:
317+
self.logger.error(f"Connection error during streaming message to chat session {chat_session_id}: {e}")
318+
raise MumbleException(f"Connection error during streaming message.")
319+
except HTTPStatusError as e:
320+
self.logger.error(f"HTTP error during streaming message to chat session {chat_session_id}: {e}")
321+
raise MumbleException(f"HTTP error during streaming message.")
322+
except Exception as e:
323+
self.logger.error(f"Error during streaming message to chat session {chat_session_id}: {e}")
324+
raise MumbleException(f"Error during streaming message.")
325+
302326

303327
# Dependency injection for Connection
304328
@checker.register_dependency
@@ -730,21 +754,7 @@ async def putflag_agent(
730754
flag = task.flag
731755
flag = await format_flag(flag)
732756

733-
async def send_stream(conn, flag, chat_session_id):
734-
async with conn.client.stream(
735-
"POST",
736-
f"{API_BASE}/messages/stream",
737-
json={"input": flag, "ChatSessionId": chat_session_id, "SkipLlmResponse": True},
738-
headers=conn.headers
739-
) as resp:
740-
resp.raise_for_status()
741-
try:
742-
async for _ in resp.aiter_lines():
743-
break
744-
except:
745-
pass
746-
747-
await send_stream(conn, flag, chat_session_id)
757+
await conn.send_stream(flag, chat_session_id)
748758

749759
access_token = await conn.create_access_token(random_string(4))
750760

@@ -877,6 +887,91 @@ async def exploit_agent(
877887
return unformatted_flag
878888

879889

890+
@checker.putnoise(1)
891+
async def putnoise_agent(
892+
conn: Connection,
893+
db: ChainDB,
894+
):
895+
username = random_string()
896+
password = random_string()
897+
user_id, chat_session_id = await conn.register_user(username, password)
898+
await conn.login_user(username, password)
899+
900+
# Generate random noise sentence
901+
noise_sentence = ''.join(random.choices(get_ascii_chars(), k=SUPPOSED_FLAG_LENGTH))
902+
903+
await conn.send_stream(noise_sentence, chat_session_id)
904+
905+
access_token = await conn.create_access_token(random_string(4))
906+
907+
await db.set("userdata", {"user_id": user_id, "session_id": chat_session_id, "access_token": access_token, "noise_sentence": noise_sentence})
908+
909+
910+
@checker.getnoise(1)
911+
async def getnoise_agent(
912+
task: GetflagCheckerTaskMessage, conn: Connection, db: ChainDB, logger: LoggerAdapter
913+
) -> None:
914+
915+
try:
916+
data = await db.get("userdata")
917+
user_id = data["user_id"]
918+
access_token = data["access_token"]
919+
session_id = data["session_id"]
920+
noise_sentence = data["noise_sentence"]
921+
except KeyError:
922+
logger.error("Missing database entry from putnoise(1)")
923+
raise MumbleException("Missing database entry from putnoise(1)")
924+
925+
client_challenge_bytes = os.urandom(8)
926+
client_challenge_json = base64.b64encode(client_challenge_bytes).decode('utf-8')
927+
server_challenge, memory_key, iv = await conn.client_challenge(client_challenge_json, user_id)
928+
929+
salt = client_challenge_bytes + server_challenge
930+
session_key = HKDF(access_token, 16, salt, SHA256, 1)
931+
932+
client_credentials_bytes = aes_cfb8_encrypt(client_challenge_bytes, session_key, iv)
933+
client_credentials_json = base64.b64encode(client_credentials_bytes).decode('utf-8')
934+
agent_name = random_string(4)
935+
try:
936+
server_creds, jwt = await conn.client_credentials(client_credentials_json, user_id, agent_name, memory_key)
937+
except (ConnectTimeout, NetworkError, PoolTimeout, RequestError) as e:
938+
logger.error(
939+
f"Connection error during posting client credentials for user {user_id} for agent {agent_name} with memory key {memory_key} and client credentials {client_credentials_json}: {e}")
940+
raise MumbleException(
941+
f"Connection error during posting client credentials.")
942+
except HTTPStatusError as e:
943+
logger.error(
944+
f"HTTP error during posting client credentials for user {user_id} for agent {agent_name} with memory key {memory_key} and client credentials {client_credentials_json}: {e}")
945+
raise MumbleException("HTTP error during posting client credentials.")
946+
except Exception as e:
947+
logger.error(
948+
f"Error during posting client credentials for user {user_id} for agent {agent_name} with memory key {memory_key} and client credentials {client_credentials_json}: {e}")
949+
raise MumbleException("Error during posting client credentials.")
950+
951+
server_creds_computed = aes_cfb8_encrypt(server_challenge, session_key, iv)
952+
953+
if server_creds_computed != server_creds:
954+
logger.error(f"Server credentials do not match. Computed ones: {server_creds_computed}. Actual {server_creds}")
955+
raise MumbleException("Server credentials do not match")
956+
957+
try:
958+
mcp_client = MCPClient(f"http://{task.address}:{SERVICE_PORT}/mcp", logger, jwt)
959+
messages = await mcp_client.run(session_id)
960+
assert len(messages) >= 2, "Expected at least 2 messages from MCP"
961+
retrieved_sentence = messages[1]["content"]
962+
except Exception as e:
963+
logger.error(f"Error during reading messages through MCP of session id {session_id}: {e}")
964+
raise MumbleException(f"Error during reading messages through MCP.")
965+
966+
if retrieved_sentence is None or retrieved_sentence == "":
967+
logger.error("No able to retrieve noise sentence")
968+
raise MumbleException("No able to retrieve noise sentence")
969+
970+
if noise_sentence != retrieved_sentence:
971+
logger.error(f"Noise sentence mismatch: expected {noise_sentence}, got {retrieved_sentence}")
972+
raise MumbleException("Noise sentence does not match expected value")
973+
974+
880975
if __name__ == "__main__":
881976

882977
debug = False

0 commit comments

Comments
 (0)