|
20 | 20 |
|
21 | 21 | use libfuzzer_sys::fuzz_target; |
22 | 22 | use mssql_tds::connection::client_context::ClientContext; |
23 | | -use mssql_tds::fuzz_support::{MockTransport, TdsConnectionProvider, TdsPacketReader}; |
24 | | -use mssql_tds::core::TdsResult; |
25 | | -use std::io::{Error, ErrorKind}; |
26 | | - |
27 | | -/// Simple reader that wraps fuzz input data |
28 | | -struct FuzzReader { |
29 | | - data: Vec<u8>, |
30 | | - position: usize, |
31 | | -} |
32 | | - |
33 | | -impl FuzzReader { |
34 | | - fn new(data: &[u8]) -> Self { |
35 | | - Self { |
36 | | - data: data.to_vec(), |
37 | | - position: 0, |
38 | | - } |
39 | | - } |
40 | | -} |
41 | | - |
42 | | -#[async_trait::async_trait] |
43 | | -impl TdsPacketReader for FuzzReader { |
44 | | - async fn read_byte(&mut self) -> TdsResult<u8> { |
45 | | - if self.position >= self.data.len() { |
46 | | - return Err(mssql_tds::error::Error::Io(Error::new( |
47 | | - ErrorKind::UnexpectedEof, |
48 | | - "EOF", |
49 | | - ))); |
50 | | - } |
51 | | - let byte = self.data[self.position]; |
52 | | - self.position += 1; |
53 | | - Ok(byte) |
54 | | - } |
55 | | - |
56 | | - async fn read_int16_big_endian(&mut self) -> TdsResult<i16> { |
57 | | - let mut buf = [0u8; 2]; |
58 | | - self.read_bytes(&mut buf).await?; |
59 | | - Ok(i16::from_be_bytes(buf)) |
60 | | - } |
61 | | - |
62 | | - async fn read_int32_big_endian(&mut self) -> TdsResult<i32> { |
63 | | - let mut buf = [0u8; 4]; |
64 | | - self.read_bytes(&mut buf).await?; |
65 | | - Ok(i32::from_be_bytes(buf)) |
66 | | - } |
67 | | - |
68 | | - async fn read_uint40(&mut self) -> TdsResult<u64> { |
69 | | - let mut buf = [0u8; 8]; |
70 | | - self.read_bytes(&mut buf[..5]).await?; |
71 | | - Ok(u64::from_le_bytes(buf)) |
72 | | - } |
73 | | - |
74 | | - async fn read_float32(&mut self) -> TdsResult<f32> { |
75 | | - let mut buf = [0u8; 4]; |
76 | | - self.read_bytes(&mut buf).await?; |
77 | | - Ok(f32::from_le_bytes(buf)) |
78 | | - } |
79 | | - |
80 | | - async fn read_float64(&mut self) -> TdsResult<f64> { |
81 | | - let mut buf = [0u8; 8]; |
82 | | - self.read_bytes(&mut buf).await?; |
83 | | - Ok(f64::from_le_bytes(buf)) |
84 | | - } |
85 | | - |
86 | | - async fn read_uint16(&mut self) -> TdsResult<u16> { |
87 | | - let mut buf = [0u8; 2]; |
88 | | - self.read_bytes(&mut buf).await?; |
89 | | - Ok(u16::from_le_bytes(buf)) |
90 | | - } |
91 | | - |
92 | | - async fn read_uint32(&mut self) -> TdsResult<u32> { |
93 | | - let mut buf = [0u8; 4]; |
94 | | - self.read_bytes(&mut buf).await?; |
95 | | - Ok(u32::from_le_bytes(buf)) |
96 | | - } |
97 | | - |
98 | | - async fn read_uint64(&mut self) -> TdsResult<u64> { |
99 | | - let mut buf = [0u8; 8]; |
100 | | - self.read_bytes(&mut buf).await?; |
101 | | - Ok(u64::from_le_bytes(buf)) |
102 | | - } |
103 | | - |
104 | | - async fn read_int16(&mut self) -> TdsResult<i16> { |
105 | | - let mut buf = [0u8; 2]; |
106 | | - self.read_bytes(&mut buf).await?; |
107 | | - Ok(i16::from_le_bytes(buf)) |
108 | | - } |
109 | | - |
110 | | - async fn read_uint24(&mut self) -> TdsResult<u32> { |
111 | | - let mut buf = [0u8; 4]; |
112 | | - self.read_bytes(&mut buf[..3]).await?; |
113 | | - Ok(u32::from_le_bytes(buf)) |
114 | | - } |
115 | | - |
116 | | - async fn read_int32(&mut self) -> TdsResult<i32> { |
117 | | - let mut buf = [0u8; 4]; |
118 | | - self.read_bytes(&mut buf).await?; |
119 | | - Ok(i32::from_le_bytes(buf)) |
120 | | - } |
121 | | - |
122 | | - async fn read_int64(&mut self) -> TdsResult<i64> { |
123 | | - let mut buf = [0u8; 8]; |
124 | | - self.read_bytes(&mut buf).await?; |
125 | | - Ok(i64::from_le_bytes(buf)) |
126 | | - } |
127 | | - |
128 | | - async fn read_bytes(&mut self, buf: &mut [u8]) -> TdsResult<usize> { |
129 | | - // Use checked arithmetic to prevent overflow |
130 | | - let end_position = self.position.checked_add(buf.len()).ok_or_else(|| { |
131 | | - mssql_tds::error::Error::Io(Error::new( |
132 | | - ErrorKind::InvalidInput, |
133 | | - "buffer length causes position overflow", |
134 | | - )) |
135 | | - })?; |
136 | | - |
137 | | - if end_position > self.data.len() { |
138 | | - return Err(mssql_tds::error::Error::Io(Error::new( |
139 | | - ErrorKind::UnexpectedEof, |
140 | | - "EOF", |
141 | | - ))); |
142 | | - } |
143 | | - buf.copy_from_slice(&self.data[self.position..end_position]); |
144 | | - self.position = end_position; |
145 | | - Ok(buf.len()) |
146 | | - } |
147 | | - |
148 | | - async fn read_u8_varbyte(&mut self) -> TdsResult<Vec<u8>> { |
149 | | - let len = self.read_byte().await? as usize; |
150 | | - const MAX_ALLOC: usize = 1024 * 1024; |
151 | | - if len > MAX_ALLOC { |
152 | | - return Err(mssql_tds::error::Error::Io(Error::new( |
153 | | - ErrorKind::InvalidData, |
154 | | - format!("Allocation size {} exceeds max {}", len, MAX_ALLOC), |
155 | | - ))); |
156 | | - } |
157 | | - let mut buf = vec![0u8; len]; |
158 | | - self.read_bytes(&mut buf).await?; |
159 | | - Ok(buf) |
160 | | - } |
161 | | - |
162 | | - async fn read_u16_varbyte(&mut self) -> TdsResult<Vec<u8>> { |
163 | | - let len = self.read_uint16().await? as usize; |
164 | | - const MAX_ALLOC: usize = 1024 * 1024; |
165 | | - if len > MAX_ALLOC { |
166 | | - return Err(mssql_tds::error::Error::Io(Error::new( |
167 | | - ErrorKind::InvalidData, |
168 | | - format!("Allocation size {} exceeds max {}", len, MAX_ALLOC), |
169 | | - ))); |
170 | | - } |
171 | | - let mut buf = vec![0u8; len]; |
172 | | - self.read_bytes(&mut buf).await?; |
173 | | - Ok(buf) |
174 | | - } |
175 | | - |
176 | | - async fn read_varchar_u16_length(&mut self) -> TdsResult<Option<String>> { |
177 | | - let len = self.read_uint16().await?; |
178 | | - if len == 0xFFFF { |
179 | | - return Ok(None); |
180 | | - } |
181 | | - let byte_len = (len as usize) * 2; |
182 | | - const MAX_ALLOC: usize = 1024 * 1024; |
183 | | - if byte_len > MAX_ALLOC { |
184 | | - return Err(mssql_tds::error::Error::Io(Error::new( |
185 | | - ErrorKind::InvalidData, |
186 | | - format!("Allocation size {} exceeds max {}", byte_len, MAX_ALLOC), |
187 | | - ))); |
188 | | - } |
189 | | - let mut buf = vec![0u8; byte_len]; |
190 | | - self.read_bytes(&mut buf).await?; |
191 | | - String::from_utf16(&buf.chunks_exact(2).map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])).collect::<Vec<u16>>()) |
192 | | - .map(Some) |
193 | | - .map_err(|e| mssql_tds::error::Error::Io(Error::new(ErrorKind::InvalidData, e))) |
194 | | - } |
195 | | - |
196 | | - async fn read_varchar_u8_length(&mut self) -> TdsResult<String> { |
197 | | - let len = self.read_byte().await? as usize; |
198 | | - let byte_len = len * 2; |
199 | | - const MAX_ALLOC: usize = 1024 * 1024; |
200 | | - if byte_len > MAX_ALLOC { |
201 | | - return Err(mssql_tds::error::Error::Io(Error::new( |
202 | | - ErrorKind::InvalidData, |
203 | | - format!("Allocation size {} exceeds max {}", byte_len, MAX_ALLOC), |
204 | | - ))); |
205 | | - } |
206 | | - let mut buf = vec![0u8; byte_len]; |
207 | | - self.read_bytes(&mut buf).await?; |
208 | | - String::from_utf16(&buf.chunks_exact(2).map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])).collect::<Vec<u16>>()) |
209 | | - .map_err(|e| mssql_tds::error::Error::Io(Error::new(ErrorKind::InvalidData, e))) |
210 | | - } |
211 | | - |
212 | | - async fn read_unicode(&mut self, string_length: usize) -> TdsResult<String> { |
213 | | - let byte_len = string_length * 2; |
214 | | - const MAX_ALLOC: usize = 1024 * 1024; |
215 | | - if byte_len > MAX_ALLOC { |
216 | | - return Err(mssql_tds::error::Error::Io(Error::new( |
217 | | - ErrorKind::InvalidData, |
218 | | - format!("Allocation size {} exceeds max {}", byte_len, MAX_ALLOC), |
219 | | - ))); |
220 | | - } |
221 | | - let mut buf = vec![0u8; byte_len]; |
222 | | - self.read_bytes(&mut buf).await?; |
223 | | - String::from_utf16(&buf.chunks_exact(2).map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])).collect::<Vec<u16>>()) |
224 | | - .map_err(|e| mssql_tds::error::Error::Io(Error::new(ErrorKind::InvalidData, e))) |
225 | | - } |
226 | | - |
227 | | - async fn read_unicode_with_byte_length(&mut self, byte_length: usize) -> TdsResult<String> { |
228 | | - const MAX_ALLOC: usize = 1024 * 1024; |
229 | | - if byte_length > MAX_ALLOC { |
230 | | - return Err(mssql_tds::error::Error::Io(Error::new( |
231 | | - ErrorKind::InvalidData, |
232 | | - format!("Allocation size {} exceeds max {}", byte_length, MAX_ALLOC), |
233 | | - ))); |
234 | | - } |
235 | | - let mut buf = vec![0u8; byte_length]; |
236 | | - self.read_bytes(&mut buf).await?; |
237 | | - String::from_utf16(&buf.chunks_exact(2).map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])).collect::<Vec<u16>>()) |
238 | | - .map_err(|e| mssql_tds::error::Error::Io(Error::new(ErrorKind::InvalidData, e))) |
239 | | - } |
240 | | - |
241 | | - async fn skip_bytes(&mut self, skip_count: usize) -> TdsResult<()> { |
242 | | - // Use checked arithmetic to prevent overflow |
243 | | - let new_position = self.position.checked_add(skip_count).ok_or_else(|| { |
244 | | - mssql_tds::error::Error::Io(Error::new( |
245 | | - ErrorKind::InvalidInput, |
246 | | - "skip_count causes position overflow", |
247 | | - )) |
248 | | - })?; |
249 | | - |
250 | | - if new_position > self.data.len() { |
251 | | - return Err(mssql_tds::error::Error::Io(Error::new( |
252 | | - ErrorKind::UnexpectedEof, |
253 | | - "EOF", |
254 | | - ))); |
255 | | - } |
256 | | - self.position = new_position; |
257 | | - Ok(()) |
258 | | - } |
259 | | - |
260 | | - async fn cancel_read_stream(&mut self) -> TdsResult<()> { |
261 | | - // No-op for fuzzing |
262 | | - Ok(()) |
263 | | - } |
264 | | - |
265 | | - fn reset_reader(&mut self) { |
266 | | - self.position = 0; |
267 | | - } |
268 | | -} |
| 23 | +use mssql_tds::fuzz_support::{FuzzReader, MockTransport, TdsConnectionProvider}; |
269 | 24 |
|
270 | 25 | fuzz_target!(|data: &[u8]| { |
271 | 26 | // Need at least some data to work with |
|
0 commit comments