Skip to content

Commit 93c40bb

Browse files
committed
Refactor error message handling in FFI functions to return allocated C strings instead of using provided buffers. Updated related tests to reflect changes in error handling functions.
1 parent 42d5021 commit 93c40bb

File tree

3 files changed

+76
-39
lines changed

3 files changed

+76
-39
lines changed

src/ffi/request.rs

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -251,11 +251,10 @@ pub extern "C" fn request_read_error_kind(request_id: RequestId) -> u8 {
251251
#[unsafe(no_mangle)]
252252
pub extern "C" fn request_read_error_message(
253253
request_id: RequestId,
254-
buffer: *mut c_char,
255-
buffer_len: usize,
256-
) -> usize {
257-
if buffer.is_null() || buffer_len == 0 {
258-
return 0;
254+
num_bytes: *mut u32,
255+
) -> *mut c_char {
256+
if num_bytes.is_null() {
257+
return ptr::null_mut();
259258
}
260259

261260
let tracker = REQUEST_TRACKER.lock().unwrap();
@@ -264,28 +263,43 @@ pub extern "C" fn request_read_error_message(
264263
let progress = progress_info.read().unwrap();
265264
if let Some(ref response) = progress.final_response {
266265
if let Err(ref error_msg) = response.body {
267-
let msg_len = error_msg.len().min(buffer_len - 1);
266+
let msg_len = error_msg.len();
267+
unsafe { *num_bytes = msg_len as u32 };
268+
269+
// Allocate memory for the string + null terminator
270+
let c_str_ptr = unsafe { libc::malloc(msg_len + 1) as *mut c_char };
271+
if c_str_ptr.is_null() {
272+
unsafe { *num_bytes = 0 };
273+
return ptr::null_mut();
274+
}
275+
276+
// Copy the string and add null terminator
268277
unsafe {
269-
std::ptr::copy_nonoverlapping(error_msg.as_ptr(), buffer as *mut u8, msg_len);
270-
*buffer.add(msg_len) = 0;
278+
std::ptr::copy_nonoverlapping(
279+
error_msg.as_ptr(),
280+
c_str_ptr as *mut u8,
281+
msg_len,
282+
);
283+
*(c_str_ptr.add(msg_len)) = 0;
271284
}
272-
return msg_len;
285+
286+
return c_str_ptr;
273287
}
274288
}
275289
}
276290

277-
0
291+
unsafe { *num_bytes = 0 };
292+
ptr::null_mut()
278293
}
279294

280295
/// Get the error URL if available
281296
#[unsafe(no_mangle)]
282297
pub extern "C" fn request_read_error_url(
283298
request_id: RequestId,
284-
buffer: *mut c_char,
285-
buffer_len: usize,
286-
) -> usize {
287-
if buffer.is_null() || buffer_len == 0 {
288-
return 0;
299+
num_bytes: *mut u32,
300+
) -> *mut c_char {
301+
if num_bytes.is_null() {
302+
return ptr::null_mut();
289303
}
290304

291305
let tracker = REQUEST_TRACKER.lock().unwrap();
@@ -294,28 +308,39 @@ pub extern "C" fn request_read_error_url(
294308
let progress = progress_info.read().unwrap();
295309
if let Some(ref response) = progress.final_response {
296310
if let Some(ref url) = response.error_url {
297-
let url_len = url.len().min(buffer_len - 1);
311+
let url_len = url.len();
312+
unsafe { *num_bytes = url_len as u32 };
313+
314+
// Allocate memory for the string + null terminator
315+
let c_str_ptr = unsafe { libc::malloc(url_len + 1) as *mut c_char };
316+
if c_str_ptr.is_null() {
317+
unsafe { *num_bytes = 0 };
318+
return ptr::null_mut();
319+
}
320+
321+
// Copy the string and add null terminator
298322
unsafe {
299-
std::ptr::copy_nonoverlapping(url.as_ptr(), buffer as *mut u8, url_len);
300-
*buffer.add(url_len) = 0;
323+
std::ptr::copy_nonoverlapping(url.as_ptr(), c_str_ptr as *mut u8, url_len);
324+
*(c_str_ptr.add(url_len)) = 0;
301325
}
302-
return url_len;
326+
327+
return c_str_ptr;
303328
}
304329
}
305330
}
306331

307-
0
332+
unsafe { *num_bytes = 0 };
333+
ptr::null_mut()
308334
}
309335

310336
/// Get the root cause error message
311337
#[unsafe(no_mangle)]
312338
pub extern "C" fn request_read_error_source(
313339
request_id: RequestId,
314-
buffer: *mut c_char,
315-
buffer_len: usize,
316-
) -> usize {
317-
if buffer.is_null() || buffer_len == 0 {
318-
return 0;
340+
num_bytes: *mut u32,
341+
) -> *mut c_char {
342+
if num_bytes.is_null() {
343+
return ptr::null_mut();
319344
}
320345

321346
let tracker = REQUEST_TRACKER.lock().unwrap();
@@ -324,17 +349,29 @@ pub extern "C" fn request_read_error_source(
324349
let progress = progress_info.read().unwrap();
325350
if let Some(ref response) = progress.final_response {
326351
if let Some(ref source) = response.error_source {
327-
let src_len = source.len().min(buffer_len - 1);
352+
let src_len = source.len();
353+
unsafe { *num_bytes = src_len as u32 };
354+
355+
// Allocate memory for the string + null terminator
356+
let c_str_ptr = unsafe { libc::malloc(src_len + 1) as *mut c_char };
357+
if c_str_ptr.is_null() {
358+
unsafe { *num_bytes = 0 };
359+
return ptr::null_mut();
360+
}
361+
362+
// Copy the string and add null terminator
328363
unsafe {
329-
std::ptr::copy_nonoverlapping(source.as_ptr(), buffer as *mut u8, src_len);
330-
*buffer.add(src_len) = 0;
364+
std::ptr::copy_nonoverlapping(source.as_ptr(), c_str_ptr as *mut u8, src_len);
365+
*(c_str_ptr.add(src_len)) = 0;
331366
}
332-
return src_len;
367+
368+
return c_str_ptr;
333369
}
334370
}
335371
}
336372

337-
0
373+
unsafe { *num_bytes = 0 };
374+
ptr::null_mut()
338375
}
339376

340377
/// Check if response has an error

tests/integration/common.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ pub fn wait_for_request_with_retry(request_id: usize, max_retries: u64) -> (u16,
4141

4242
// Get error message if available for debugging
4343
let mut error_len: u32 = 0;
44-
let error_ptr = request_read_transport_error(request_id, &mut error_len);
44+
let error_ptr = request_read_error_message(request_id, &mut error_len);
4545
let error_msg = if !error_ptr.is_null() && error_len > 0 {
4646
let error_str = unsafe {
4747
let error_slice =

tests/integration/error_handling.rs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,12 @@ fn test_error_handling_invalid_url() {
2525
std::thread::sleep(Duration::from_millis(50));
2626
}
2727

28-
// Check if we have a transport error
29-
assert!(request_has_transport_error(request_id));
28+
// Check if we have an error
29+
assert!(request_has_error(request_id));
3030

3131
// Read the error message
3232
let mut error_len: u32 = 0;
33-
let error_ptr = request_read_transport_error(request_id, &mut error_len);
33+
let error_ptr = request_read_error_message(request_id, &mut error_len);
3434
assert!(!error_ptr.is_null());
3535
assert!(error_len > 0);
3636

@@ -81,12 +81,12 @@ fn test_error_handling_connection_timeout() {
8181
std::thread::sleep(Duration::from_millis(50));
8282
}
8383

84-
// Check if we have a transport error
85-
assert!(request_has_transport_error(request_id));
84+
// Check if we have an error
85+
assert!(request_has_error(request_id));
8686

8787
// Read the error message
8888
let mut error_len: u32 = 0;
89-
let error_ptr = request_read_transport_error(request_id, &mut error_len);
89+
let error_ptr = request_read_error_message(request_id, &mut error_len);
9090
assert!(!error_ptr.is_null());
9191
assert!(error_len > 0);
9292

@@ -129,8 +129,8 @@ fn test_error_handling_404_not_found() {
129129
std::thread::sleep(Duration::from_millis(50));
130130
}
131131

132-
// For HTTP errors like 404, we don't expect a transport error
133-
assert!(!request_has_transport_error(request_id));
132+
// For HTTP errors like 404, we don't expect an error in the body
133+
assert!(!request_has_error(request_id));
134134

135135
// But we do expect a 404 status code
136136
let status = request_read_response_status(request_id);

0 commit comments

Comments
 (0)