Skip to content

Commit 26ac5c2

Browse files
committed
Update tests
1 parent a066818 commit 26ac5c2

File tree

4 files changed

+31
-19
lines changed

4 files changed

+31
-19
lines changed

src/pytest_pl_grader/fixture.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -130,16 +130,18 @@ def _assert_process_running(self) -> None:
130130
if process_return_code is not None:
131131
raise RuntimeError(f"Student code server process terminated with code {process_return_code}.")
132132

133-
def read_from_socket(self) -> bytes:
133+
def _send_json_object(
134+
self, json_object: StudentQueryRequest | ProcessStartRequest | StudentFunctionRequest | SetupQueryRequest
135+
) -> None:
134136
"""
135-
Reads data from a socket until a termination character is found.
137+
Sends a JSON object to the student code server.
138+
"""
139+
assert self.student_socket is not None, "Student socket is not connected. Please start the student code server first."
140+
self.student_socket.sendall((json.dumps(json_object) + os.linesep).encode("utf-8"))
136141

137-
:param sock: The active socket object.
138-
:param terminator: The byte string sequence that signals the end of the message.
139-
:param max_len: Optional. The maximum number of bytes to read before stopping.
140-
:return: The received data, including the terminator.
141-
:raises TimeoutError: If the socket times out during the read operation.
142-
:raises Exception: If the connection is closed before the terminator is found.
142+
def _read_from_socket(self) -> bytes:
143+
"""
144+
Reads data from a socket until a termination character is found.
143145
"""
144146
buffer = bytearray()
145147

@@ -154,11 +156,10 @@ def read_from_socket(self) -> bytes:
154156
# TODO mabye set a hard iteration limit to avoid infinite loops?
155157
while True:
156158
# Check if the termination sequence is already in the buffer
157-
if terminator in buffer:
159+
if (idx := buffer.rfind(terminator)) != -1:
158160
# Return the buffer content up to and including the terminator
159-
# We use buffer[:buffer.index(terminator) + len(terminator)]
160161
# to only return the necessary data if more data was read
161-
return bytes(buffer)
162+
return buffer[: idx + len(terminator)]
162163

163164
# Check for maximum length constraint
164165
if max_len is not None and len(buffer) >= max_len:
@@ -241,10 +242,10 @@ def try_drop_privileges() -> None:
241242
self.student_socket.settimeout(initialization_timeout)
242243
self.student_socket.connect((host, int(port)))
243244

244-
self.student_socket.sendall(json.dumps(json_message).encode("utf-8") + os.linesep.encode("utf-8"))
245+
self._send_json_object(json_message)
245246

246247
try:
247-
data = self.read_from_socket().decode() # Adjust the buffer size as needed
248+
data = self._read_from_socket().decode() # Adjust the buffer size as needed
248249
res: ProcessStartResponse = json.loads(data)
249250
except Exception as e:
250251
res = {
@@ -264,8 +265,8 @@ def query_setup_raw(self, var_to_query: str) -> SetupQueryResponse:
264265
json_message: SetupQueryRequest = {"message_type": "query_setup", "var": var_to_query}
265266

266267
assert self.student_socket is not None, "Student socket is not connected. Please start the student code server first."
267-
self.student_socket.sendall((json.dumps(json_message) + os.linesep).encode("utf-8"))
268-
data: SetupQueryResponse = json.loads(self.read_from_socket().decode())
268+
self._send_json_object(json_message)
269+
data: SetupQueryResponse = json.loads(self._read_from_socket().decode())
269270

270271
return data
271272

@@ -287,8 +288,8 @@ def query_raw(self, var_to_query: str, *, query_timeout: float = DEFAULT_TIMEOUT
287288

288289
assert self.student_socket is not None, "Student socket is not connected. Please start the student code server first."
289290
self.student_socket.settimeout(query_timeout)
290-
self.student_socket.sendall((json.dumps(json_message) + os.linesep).encode("utf-8"))
291-
data: StudentQueryResponse = json.loads(self.read_from_socket().decode())
291+
self._send_json_object(json_message)
292+
data: StudentQueryResponse = json.loads(self._read_from_socket().decode())
292293

293294
return data
294295

@@ -319,7 +320,7 @@ def query_function_raw(self, function_name: str, *args, query_timeout: float = D
319320
assert self.student_socket is not None, "Student socket is not connected. Please start the student code server first."
320321
self.student_socket.settimeout(query_timeout)
321322
self.student_socket.sendall((json.dumps(json_message) + os.linesep).encode("utf-8"))
322-
data: StudentFunctionResponse = json.loads(self.read_from_socket().decode())
323+
data: StudentFunctionResponse = json.loads(self._read_from_socket().decode())
323324

324325
return data
325326

tests/scenario_root/test_private_namespace/expected_outcome.json

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"expected_data_object": {
3-
"score": 0.3333333333333333,
3+
"score": 0.25,
44
"tests": [
55
{
66
"max_points": 1,
@@ -20,6 +20,15 @@
2020
"points_frac": 0.0,
2121
"test_id": "test_private_namespace.py::test_random[student_code_fail]"
2222
},
23+
{
24+
"max_points": 1,
25+
"message": "NameError: name '__data_params' is not defined",
26+
"name": "test_private_namespace.py::test_random[student_code_fail_read_builtin]",
27+
"outcome": "error",
28+
"points": 0.0,
29+
"points_frac": 0.0,
30+
"test_id": "test_private_namespace.py::test_random[student_code_fail_read_builtin]"
31+
},
2332
{
2433
"max_points": 1,
2534
"message": "",

tests/scenario_root/test_private_namespace/student_code.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@
22
a = 10
33
b["c"] = 50
44
d = f(10)
5+
numpy_array
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
a = __data_params

0 commit comments

Comments
 (0)