Skip to content

Commit 1d61e91

Browse files
committed
Implement gzfwrite and add null buffer check to gzfread
1 parent d004e30 commit 1d61e91

File tree

2 files changed

+154
-6
lines changed

2 files changed

+154
-6
lines changed

libz-rs-sys/src/gz.rs

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,7 @@ pub unsafe extern "C-unwind" fn gzfread(
986986
nitems: size_t,
987987
file: gzFile,
988988
) -> size_t {
989-
if size == 0 {
989+
if size == 0 || buf.is_null() {
990990
return 0;
991991
}
992992

@@ -1463,6 +1463,58 @@ pub unsafe extern "C-unwind" fn gzwrite(file: gzFile, buf: *const c_void, len: c
14631463
unsafe { gz_write(state, buf, len) }
14641464
}
14651465

1466+
/// Compress and write `nitems` items of size `size` from `buf` to `file`, duplicating
1467+
/// the interface of C stdio's `fwrite`, with `size_t` request and return types.
1468+
///
1469+
/// # Returns
1470+
///
1471+
/// - The number of full items written of size `size` on success.
1472+
/// - Zero on error.
1473+
///
1474+
/// Note: If the multiplication of `size` and `nitems` overflows, i.e. the product does
1475+
/// not fit in a `size_t`, then nothing is written, zero is returned, and the error state
1476+
/// is set to `Z_STREAM_ERROR`.
1477+
///
1478+
/// # Safety
1479+
///
1480+
/// - `file`, if non-null, must be an open file handle obtained from [`gzopen`] or [`gzdopen`].
1481+
/// - The caller must ensure that `buf` points to at least `size * nitems` readable bytes.
1482+
#[cfg_attr(feature = "export-symbols", export_name = crate::prefix!(gzfwrite))]
1483+
pub unsafe extern "C-unwind" fn gzfwrite(
1484+
buf: *const c_void,
1485+
size: size_t,
1486+
nitems: size_t,
1487+
file: gzFile,
1488+
) -> size_t {
1489+
if size == 0 || buf.is_null() {
1490+
return 0;
1491+
}
1492+
1493+
let Some(state) = (unsafe { file.cast::<GzState>().as_mut() }) else {
1494+
return 0;
1495+
};
1496+
1497+
// Check that we're writing and that there's no error.
1498+
if state.mode != GzMode::GZ_WRITE || state.err != Z_OK {
1499+
return 0;
1500+
}
1501+
1502+
// Compute the number of bytes to write, and make sure it fits in a size_t.
1503+
let Some(len) = size.checked_mul(nitems) else {
1504+
const MSG: &str = "request does not fit in a size_t";
1505+
unsafe { gz_error(state, Some((Z_STREAM_ERROR, MSG))) };
1506+
return 0;
1507+
};
1508+
1509+
if len == 0 {
1510+
len
1511+
} else {
1512+
// Safety: The caller is responsible for ensuring that `buf` points to at least
1513+
// `len = size * nitems` readable bytes.
1514+
(unsafe { gz_write(state, buf, len) }) as size_t / size
1515+
}
1516+
}
1517+
14661518
// Internal implementation of `gzwrite`.
14671519
//
14681520
// # Returns

test-libz-rs-sys/src/gz.rs

Lines changed: 101 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ use zlib_rs::c_api::*;
22

33
use libz_rs_sys::{
44
gzFile_s, gzbuffer, gzclearerr, gzclose, gzclose_r, gzclose_w, gzdirect, gzdopen, gzerror,
5-
gzflush, gzfread, gzgetc, gzgetc_, gzgets, gzoffset, gzopen, gzputc, gzputs, gzread, gztell,
6-
gzwrite,
5+
gzflush, gzfread, gzfwrite, gzgetc, gzgetc_, gzgets, gzoffset, gzopen, gzputc, gzputs, gzread,
6+
gztell, gzwrite,
77
};
88

99
use libc::size_t;
@@ -1197,6 +1197,9 @@ fn gzfread_basic() {
11971197
assert_eq!(&buf[..20], b"gzip\nexample data\nfo");
11981198
assert_eq!(buf[20], 0);
11991199

1200+
// gzfread with a null buffer should return 0.
1201+
assert_eq!(unsafe { gzfread(ptr::null_mut(), 1, 1, file) }, 0);
1202+
12001203
// When there is not enough data remaining in the file, gzfread should transfer as many
12011204
// units of size as possible.
12021205
let mut buf = [0u8; 32];
@@ -1222,7 +1225,7 @@ fn gzfread_error() {
12221225
);
12231226

12241227
// gzfread with a size or nitems of 0 should return 0.
1225-
let file = unsafe { gzdopen(-2, CString::new("w").unwrap().as_ptr()) };
1228+
let file = unsafe { gzdopen(-2, CString::new("r").unwrap().as_ptr()) };
12261229
assert!(!file.is_null());
12271230
assert_eq!(
12281231
unsafe { gzfread(buf.as_mut_ptr().cast::<c_void>(), 0, 1, file) },
@@ -1235,12 +1238,12 @@ fn gzfread_error() {
12351238

12361239
// gzfread should return 0 if size * nitems is too big to fit in a size_t.
12371240
assert_eq!(
1238-
unsafe { gzfread(buf.as_mut_ptr().cast::<c_void>(), size_t::MAX, 1, file) },
1241+
unsafe { gzfread(buf.as_mut_ptr().cast::<c_void>(), size_t::MAX, 2, file) },
12391242
0
12401243
);
12411244
assert_eq!(unsafe { gzclose(file) }, Z_ERRNO);
12421245

1243-
// gzfread on a read-only file handle should return 0.
1246+
// gzfread on a write-only file handle should return 0.
12441247
let file = unsafe { gzdopen(-2, CString::new("w").unwrap().as_ptr()) };
12451248
assert_eq!(
12461249
unsafe { gzfread(buf.as_mut_ptr().cast::<c_void>(), 1, 1, file) },
@@ -1250,6 +1253,99 @@ fn gzfread_error() {
12501253
assert_eq!(unsafe { gzclose(file) }, Z_ERRNO);
12511254
}
12521255

1256+
#[test]
1257+
fn gzfwrite_basic() {
1258+
// Create a temporary directory that will be automatically removed when
1259+
// temp_dir goes out of scope.
1260+
let temp_dir_path = temp_base();
1261+
let temp_dir = tempfile::TempDir::new_in(temp_dir_path).unwrap();
1262+
let temp_path = temp_dir.path();
1263+
let file_name = path(temp_path, "output");
1264+
1265+
// Open a file for writing, using direct (uncompressed) mode to make it easier
1266+
// to verify the output.
1267+
let file = unsafe {
1268+
gzopen(
1269+
CString::new(file_name.as_str()).unwrap().as_ptr(),
1270+
CString::new("wT").unwrap().as_ptr(),
1271+
)
1272+
};
1273+
assert!(!file.is_null());
1274+
1275+
// gzfwrite of a single object should return 1.
1276+
assert_eq!(
1277+
unsafe { gzfwrite(b"test".as_ptr().cast::<c_void>(), 4, 1, file) },
1278+
1
1279+
);
1280+
// gzfwrite of n objects should return n.
1281+
assert_eq!(
1282+
unsafe { gzfwrite(b" of gzfwrite...".as_ptr().cast::<c_void>(), 4, 3, file) },
1283+
3
1284+
);
1285+
1286+
// gzfwrite with a null buffer should return 0.
1287+
assert_eq!(unsafe { gzfread(ptr::null_mut(), 1, 1, file) }, 0);
1288+
1289+
// After the gzfwrite calls, the file should close cleanly.
1290+
assert_eq!(unsafe { gzclose(file) }, Z_OK);
1291+
1292+
// Read in the file and verify that the contents match what was passed to gzfwrite.
1293+
let mut mode = 0;
1294+
#[cfg(target_os = "windows")]
1295+
{
1296+
mode |= libc::O_BINARY;
1297+
}
1298+
mode |= libc::O_RDONLY;
1299+
let fd = unsafe { libc::open(CString::new(file_name.as_str()).unwrap().as_ptr(), mode) };
1300+
assert_ne!(fd, -1);
1301+
const EXPECTED: &[u8] = b"test of gzfwrite";
1302+
let mut buf = [0u8; EXPECTED.len() + 1];
1303+
let ret = unsafe { libc::read(fd, buf.as_mut_ptr().cast(), buf.len() as _) };
1304+
assert_eq!(ret, EXPECTED.len() as _);
1305+
assert_eq!(&buf[..EXPECTED.len()], EXPECTED);
1306+
1307+
assert_eq!(unsafe { libc::close(fd) }, 0);
1308+
}
1309+
1310+
#[test]
1311+
fn gzfwrite_error() {
1312+
let mut buf = [0u8; 10];
1313+
1314+
// gzfwrite on a null file handle should return 0.
1315+
assert_eq!(
1316+
unsafe { gzfwrite(buf.as_mut_ptr().cast::<c_void>(), 1, 1, ptr::null_mut()) },
1317+
0
1318+
);
1319+
1320+
// gzfwrite with a size or nitems of 0 should return 0.
1321+
let file = unsafe { gzdopen(-2, CString::new("w").unwrap().as_ptr()) };
1322+
assert!(!file.is_null());
1323+
assert_eq!(
1324+
unsafe { gzfwrite(buf.as_mut_ptr().cast::<c_void>(), 0, 1, file) },
1325+
0
1326+
);
1327+
assert_eq!(
1328+
unsafe { gzfwrite(buf.as_mut_ptr().cast::<c_void>(), 1, 0, file) },
1329+
0
1330+
);
1331+
1332+
// gzfwrite should return 0 if size * nitems is too big to fit in a size_t.
1333+
assert_eq!(
1334+
unsafe { gzfwrite(buf.as_mut_ptr().cast::<c_void>(), size_t::MAX, 2, file) },
1335+
0
1336+
);
1337+
assert_eq!(unsafe { gzclose(file) }, Z_ERRNO);
1338+
1339+
// gzfwrite on a read-only file handle should return 0.
1340+
let file = unsafe { gzdopen(-2, CString::new("r").unwrap().as_ptr()) };
1341+
assert_eq!(
1342+
unsafe { gzfwrite(buf.as_mut_ptr().cast::<c_void>(), 1, 1, file) },
1343+
0
1344+
);
1345+
assert!(!file.is_null());
1346+
assert_eq!(unsafe { gzclose(file) }, Z_ERRNO);
1347+
}
1348+
12531349
// Get the size in bytes of a file.
12541350
//
12551351
// # Returns

0 commit comments

Comments
 (0)