|
1 | 1 | /* |
2 | | - * SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. |
| 2 | + * SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION. |
3 | 3 | * SPDX-License-Identifier: Apache-2.0 |
4 | 4 | */ |
5 | 5 |
|
|
10 | 10 | #include <cuvs/neighbors/ivf_flat.hpp> |
11 | 11 | #include <dlpack/dlpack.h> |
12 | 12 | #include <raft/core/error.hpp> |
| 13 | +#include <raft/core/numpy_serializer.hpp> |
13 | 14 | #include <raft/core/serialize.hpp> |
14 | 15 |
|
15 | 16 | #include "../core/exceptions.hpp" |
@@ -396,9 +397,12 @@ extern "C" cuvsError_t cuvsMultiGpuIvfFlatDeserialize(cuvsResources_t res, |
396 | 397 | return cuvs::core::translate_exceptions([=] { |
397 | 398 | std::ifstream is(filename, std::ios::in | std::ios::binary); |
398 | 399 | if (!is) { RAFT_FAIL("Cannot open file %s", filename); } |
399 | | - char dtype_string[4]; |
400 | | - is.read(dtype_string, 4); |
401 | | - auto dtype = raft::detail::numpy_serializer::parse_descr(std::string(dtype_string, 4)); |
| 400 | + char dtype_string[4]{}; |
| 401 | + if (!is.read(dtype_string, sizeof(dtype_string))) { |
| 402 | + RAFT_FAIL("Invalid or truncated index header in file %s", filename); |
| 403 | + } |
| 404 | + auto dtype = |
| 405 | + raft::numpy_serializer::parse_descr(std::string(dtype_string, sizeof(dtype_string))); |
402 | 406 | is.close(); |
403 | 407 |
|
404 | 408 | index->dtype.bits = dtype.itemsize * 8; |
@@ -427,9 +431,12 @@ extern "C" cuvsError_t cuvsMultiGpuIvfFlatDistribute(cuvsResources_t res, |
427 | 431 | return cuvs::core::translate_exceptions([=] { |
428 | 432 | std::ifstream is(filename, std::ios::in | std::ios::binary); |
429 | 433 | if (!is) { RAFT_FAIL("Cannot open file %s", filename); } |
430 | | - char dtype_string[4]; |
431 | | - is.read(dtype_string, 4); |
432 | | - auto dtype = raft::detail::numpy_serializer::parse_descr(std::string(dtype_string, 4)); |
| 434 | + char dtype_string[4]{}; |
| 435 | + if (!is.read(dtype_string, sizeof(dtype_string))) { |
| 436 | + RAFT_FAIL("Invalid or truncated index header in file %s", filename); |
| 437 | + } |
| 438 | + auto dtype = |
| 439 | + raft::numpy_serializer::parse_descr(std::string(dtype_string, sizeof(dtype_string))); |
433 | 440 | is.close(); |
434 | 441 |
|
435 | 442 | index->dtype.bits = dtype.itemsize * 8; |
|
0 commit comments