|
19 | 19 | ExploitCheckerTaskMessage, |
20 | 20 | GetflagCheckerTaskMessage, |
21 | 21 | MumbleException, |
22 | | - OfflineException, |
23 | 22 | PutflagCheckerTaskMessage, |
| 23 | + GetnoiseCheckerTaskMessage |
24 | 24 | ) |
25 | 25 | from enochecker3.utils import assert_equals, assert_in |
26 | 26 | from httpx import Auth, AsyncClient, ConnectTimeout, NetworkError, PoolTimeout, HTTPStatusError, RequestError |
27 | 27 | from mcp.client.session import ClientSession |
28 | 28 | from mcp.client.streamable_http import streamablehttp_client |
| 29 | + |
29 | 30 | from sentences import checker_sentences |
30 | 31 |
|
31 | 32 | """ |
@@ -99,6 +100,16 @@ async def put(self, url, json=None, headers=None): |
99 | 100 | async def get(self, url, headers=None): |
100 | 101 | return await self.client.post(url, headers=self.apply_headers(headers)) |
101 | 102 |
|
| 103 | + async def stream(self, url, json=None, headers=None): |
| 104 | + async with self.client.stream("POST", url, json=json, headers=self.apply_headers(headers)) as resp: |
| 105 | + resp.raise_for_status() |
| 106 | + try: |
| 107 | + async for _ in resp.aiter_lines(): |
| 108 | + break |
| 109 | + except: |
| 110 | + pass |
| 111 | + return resp |
| 112 | + |
102 | 113 | async def register_user(self, username: str, password: str) -> Tuple[int, int]: |
103 | 114 | self.logger.debug(f"Registering user {username}") |
104 | 115 | try: |
@@ -299,6 +310,20 @@ async def mcp_read_messages(self, chat_session_id: int, token: str = None) -> li |
299 | 310 | self.logger.error(f"Error during reading messages through MCP of session id {chat_session_id}: {e}") |
300 | 311 | raise MumbleException("Error during reading messages through MCP.") |
301 | 312 |
|
| 313 | + async def send_stream(self, message: str, chat_session_id: int): |
| 314 | + self.logger.debug(f"Sending stream message to chat session {chat_session_id}") |
| 315 | + try: |
| 316 | + resp = await self.stream(f"{API_BASE}/messages/stream", json={"input": message, "ChatSessionId": chat_session_id, "SkipLlmResponse": True}) |
| 317 | + except (ConnectTimeout, NetworkError, PoolTimeout, RequestError) as e: |
| 318 | + self.logger.error(f"Connection error during streaming message to chat session {chat_session_id}: {e}") |
| 319 | + raise MumbleException(f"Connection error during streaming message.") |
| 320 | + except HTTPStatusError as e: |
| 321 | + self.logger.error(f"HTTP error during streaming message to chat session {chat_session_id}: {e}") |
| 322 | + raise MumbleException(f"HTTP error during streaming message.") |
| 323 | + except Exception as e: |
| 324 | + self.logger.error(f"Error during streaming message to chat session {chat_session_id}: {e}") |
| 325 | + raise MumbleException(f"Error during streaming message.") |
| 326 | + |
302 | 327 |
|
303 | 328 | # Dependency injection for Connection |
304 | 329 | @checker.register_dependency |
@@ -730,21 +755,7 @@ async def putflag_agent( |
730 | 755 | flag = task.flag |
731 | 756 | flag = await format_flag(flag) |
732 | 757 |
|
733 | | - async def send_stream(conn, flag, chat_session_id): |
734 | | - async with conn.client.stream( |
735 | | - "POST", |
736 | | - f"{API_BASE}/messages/stream", |
737 | | - json={"input": flag, "ChatSessionId": chat_session_id, "SkipLlmResponse": True}, |
738 | | - headers=conn.headers |
739 | | - ) as resp: |
740 | | - resp.raise_for_status() |
741 | | - try: |
742 | | - async for _ in resp.aiter_lines(): |
743 | | - break |
744 | | - except: |
745 | | - pass |
746 | | - |
747 | | - await send_stream(conn, flag, chat_session_id) |
| 758 | + await conn.send_stream(flag, chat_session_id) |
748 | 759 |
|
749 | 760 | access_token = await conn.create_access_token(random_string(4)) |
750 | 761 |
|
@@ -877,6 +888,91 @@ async def exploit_agent( |
877 | 888 | return unformatted_flag |
878 | 889 |
|
879 | 890 |
|
| 891 | +@checker.putnoise(1) |
| 892 | +async def putnoise_agent( |
| 893 | + conn: Connection, |
| 894 | + db: ChainDB, |
| 895 | +): |
| 896 | + username = random_string() |
| 897 | + password = random_string() |
| 898 | + user_id, chat_session_id = await conn.register_user(username, password) |
| 899 | + await conn.login_user(username, password) |
| 900 | + |
| 901 | + # Generate random noise sentence |
| 902 | + noise_sentence = ''.join(random.choices(get_ascii_chars(), k=SUPPOSED_FLAG_LENGTH)) |
| 903 | + |
| 904 | + await conn.send_stream(noise_sentence, chat_session_id) |
| 905 | + |
| 906 | + access_token = await conn.create_access_token(random_string(4)) |
| 907 | + |
| 908 | + await db.set("userdata", {"user_id": user_id, "session_id": chat_session_id, "access_token": access_token, "noise_sentence": noise_sentence}) |
| 909 | + |
| 910 | + |
| 911 | +@checker.getnoise(1) |
| 912 | +async def getnoise_agent( |
| 913 | + task: GetnoiseCheckerTaskMessage, conn: Connection, db: ChainDB, logger: LoggerAdapter |
| 914 | +) -> None: |
| 915 | + |
| 916 | + try: |
| 917 | + data = await db.get("userdata") |
| 918 | + user_id = data["user_id"] |
| 919 | + access_token = data["access_token"] |
| 920 | + session_id = data["session_id"] |
| 921 | + noise_sentence = data["noise_sentence"] |
| 922 | + except KeyError: |
| 923 | + logger.error("Missing database entry from putnoise(1)") |
| 924 | + raise MumbleException("Missing database entry from putnoise(1)") |
| 925 | + |
| 926 | + client_challenge_bytes = os.urandom(8) |
| 927 | + client_challenge_json = base64.b64encode(client_challenge_bytes).decode('utf-8') |
| 928 | + server_challenge, memory_key, iv = await conn.client_challenge(client_challenge_json, user_id) |
| 929 | + |
| 930 | + salt = client_challenge_bytes + server_challenge |
| 931 | + session_key = HKDF(access_token, 16, salt, SHA256, 1) |
| 932 | + |
| 933 | + client_credentials_bytes = aes_cfb8_encrypt(client_challenge_bytes, session_key, iv) |
| 934 | + client_credentials_json = base64.b64encode(client_credentials_bytes).decode('utf-8') |
| 935 | + agent_name = random_string(4) |
| 936 | + try: |
| 937 | + server_creds, jwt = await conn.client_credentials(client_credentials_json, user_id, agent_name, memory_key) |
| 938 | + except (ConnectTimeout, NetworkError, PoolTimeout, RequestError) as e: |
| 939 | + logger.error( |
| 940 | + f"Connection error during posting client credentials for user {user_id} for agent {agent_name} with memory key {memory_key} and client credentials {client_credentials_json}: {e}") |
| 941 | + raise MumbleException( |
| 942 | + f"Connection error during posting client credentials.") |
| 943 | + except HTTPStatusError as e: |
| 944 | + logger.error( |
| 945 | + f"HTTP error during posting client credentials for user {user_id} for agent {agent_name} with memory key {memory_key} and client credentials {client_credentials_json}: {e}") |
| 946 | + raise MumbleException("HTTP error during posting client credentials.") |
| 947 | + except Exception as e: |
| 948 | + logger.error( |
| 949 | + f"Error during posting client credentials for user {user_id} for agent {agent_name} with memory key {memory_key} and client credentials {client_credentials_json}: {e}") |
| 950 | + raise MumbleException("Error during posting client credentials.") |
| 951 | + |
| 952 | + server_creds_computed = aes_cfb8_encrypt(server_challenge, session_key, iv) |
| 953 | + |
| 954 | + if server_creds_computed != server_creds: |
| 955 | + logger.error(f"Server credentials do not match. Computed ones: {server_creds_computed}. Actual {server_creds}") |
| 956 | + raise MumbleException("Server credentials do not match") |
| 957 | + |
| 958 | + try: |
| 959 | + mcp_client = MCPClient(f"http://{task.address}:{SERVICE_PORT}/mcp", logger, jwt) |
| 960 | + messages = await mcp_client.run(session_id) |
| 961 | + assert len(messages) >= 2, "Expected at least 2 messages from MCP" |
| 962 | + retrieved_sentence = messages[1]["content"] |
| 963 | + except Exception as e: |
| 964 | + logger.error(f"Error during reading messages through MCP of session id {session_id}: {e}") |
| 965 | + raise MumbleException(f"Error during reading messages through MCP.") |
| 966 | + |
| 967 | + if retrieved_sentence is None or retrieved_sentence == "": |
| 968 | + logger.error("No able to retrieve noise sentence") |
| 969 | + raise MumbleException("No able to retrieve noise sentence") |
| 970 | + |
| 971 | + if noise_sentence != retrieved_sentence: |
| 972 | + logger.error(f"Noise sentence mismatch: expected {noise_sentence}, got {retrieved_sentence}") |
| 973 | + raise MumbleException("Noise sentence does not match expected value") |
| 974 | + |
| 975 | + |
880 | 976 | if __name__ == "__main__": |
881 | 977 |
|
882 | 978 | debug = False |
|
0 commit comments