Skip to content

Prevent process crashes due to invalid SQL queries #49

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 17 additions & 16 deletions src/VectorIndex/Storages/MergeTreeVSManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ namespace ErrorCodes
extern const int QUERY_WAS_CANCELLED;
extern const int ILLEGAL_COLUMN;
extern const int INVALID_VECTOR_INDEX;
extern const int WRONG_ARGUMENTS;
}

template <typename FloatType>
Expand All @@ -64,9 +65,9 @@ std::vector<float> getQueryVector(const IColumn * query_vector_column, size_t di
if (!query_data_concrete)
{
if (is_batch)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type, expect Float32 or Float64 inside Array(Array()) in batch distance function");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type, expect Float32 or Float64 inside Array(Array()) in batch distance function");
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type, expect Float32 or Float64 inside Array() in distance function");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type, expect Float32 or Float64 inside Array() in distance function");
}

const auto & query_vec = query_data_concrete->getData();
Expand All @@ -76,7 +77,7 @@ std::vector<float> getQueryVector(const IColumn * query_vector_column, size_t di
/// In the batch distance case, dim_of_query = dim * offsets. The dimension of each query vector is already checked in getFloatQueryVectorInBatch().
if (!is_batch && (dim_of_query != dim))
throw Exception(
ErrorCodes::LOGICAL_ERROR,
ErrorCodes::WRONG_ARGUMENTS,
"Dimension is not equal: query: {} vs search column: {}",
std::to_string(dim_of_query),
std::to_string(dim));
Expand All @@ -97,7 +98,7 @@ std::vector<float> getFloatQueryVectorInBatch(const IColumn * query_vectors_colu
const ColumnArray * query_vectors_col = checkAndGetColumn<ColumnArray>(query_vectors_column);

if (!query_vectors_col)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type, expect Array(Array()) in batch distance function");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type, expect Array(Array()) in batch distance function");

const IColumn & query_vectors = query_vectors_col->getData();
auto & offsets = query_vectors_col->getOffsets();
Expand All @@ -114,7 +115,7 @@ std::vector<float> getFloatQueryVectorInBatch(const IColumn * query_vectors_colu
size_t vec_size = vec_end_offset - vec_start_offset;
if (vec_size != dim)
throw Exception(
ErrorCodes::LOGICAL_ERROR,
ErrorCodes::WRONG_ARGUMENTS,
"Having query vector with wrong dimension: {} vs search column dimension: {}",
std::to_string(vec_size),
std::to_string(dim));
Expand All @@ -126,7 +127,7 @@ std::vector<float> getFloatQueryVectorInBatch(const IColumn * query_vectors_colu
else if (checkColumn<ColumnFloat64>(&query_vectors))
query_new_data = getQueryVector<Float64>(&query_vectors, dim, true);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type, expect Float64 or Float32 inside Array(Array()) in batch distance function");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type, expect Float64 or Float32 inside Array(Array()) in batch distance function");

return query_new_data;
}
Expand All @@ -138,15 +139,15 @@ VectorIndex::Float32VectorDatasetPtr MergeTreeVSManager::generateVectorDataset(b
auto dim = desc.search_column_dim;

if (!query_column)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type");

ColumnPtr holder = query_column->convertToFullColumnIfConst();
const ColumnArray * query_col = checkAndGetColumn<ColumnArray>(holder.get());

if (is_batch)
{
if (!query_col)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type, expect Array in batch distance function");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type, expect Array in batch distance function");

const IColumn & query_data = query_col->getData();

Expand All @@ -161,7 +162,7 @@ VectorIndex::Float32VectorDatasetPtr MergeTreeVSManager::generateVectorDataset(b
else
{
if (!query_col)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type, expect Array in distance function");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type, expect Array in distance function");

const IColumn & query_data = query_col->getData();

Expand All @@ -171,7 +172,7 @@ VectorIndex::Float32VectorDatasetPtr MergeTreeVSManager::generateVectorDataset(b
else if (checkColumn<ColumnFloat64>(&query_data))
query_new_data = getQueryVector<Float64>(&query_data, dim, false);
else
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type, expect Float32 or Float64 inside Array() in distance function");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type, expect Float32 or Float64 inside Array() in distance function");

return std::make_shared<VectorIndex::VectorDataset<Search::DataType::FloatVector>>(
1,
Expand All @@ -185,19 +186,19 @@ VectorIndex::BinaryVectorDatasetPtr MergeTreeVSManager::generateVectorDataset(bo
{
auto & query_column = desc.query_column;
if (!query_column)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type");

auto dim = desc.search_column_dim;
if (dim % 8 != 0)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Dimension of Binary vector must be a multiple of 8");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Dimension of Binary vector must be a multiple of 8");

ColumnPtr holder = query_column->convertToFullColumnIfConst();

if (is_batch)
{
const ColumnArray * query_col = checkAndGetColumn<ColumnArray>(holder.get());
if (!query_col)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type, expect Array in batch distance function");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type, expect Array in batch distance function");

const ColumnString *src_data_concrete = checkAndGetColumn<ColumnString>(query_col->getData());
if (!src_data_concrete)
Expand All @@ -217,7 +218,7 @@ VectorIndex::BinaryVectorDatasetPtr MergeTreeVSManager::generateVectorDataset(bo
size_t str_len = vec_end_offset - vec_start_offset - 1;
if (str_len * 8 != dim)
{
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong dimension in batch distance: {}, search column dimension: {}", std::to_string(str_len * 8), std::to_string(dim));
throw Exception(ErrorCodes::WRONG_ARGUMENTS, "Wrong dimension in batch distance: {}, search column dimension: {}", std::to_string(str_len * 8), std::to_string(dim));
}

const char *str = src_data_concrete->getDataAt(i).data;
Expand All @@ -234,13 +235,13 @@ VectorIndex::BinaryVectorDatasetPtr MergeTreeVSManager::generateVectorDataset(bo
const ColumnString * query_col = checkAndGetColumn<ColumnString>(holder.get());

if (!query_col)
throw Exception(ErrorCodes::LOGICAL_ERROR, "Wrong query column type, expect fixed String in distance function");
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Wrong query column type, expect fixed String in distance function");

// Every String column ends with a terminating zero byte.
auto bytes_of_query = query_col->getChars().size() - 1;
if (bytes_of_query * 8 != dim)
throw Exception(
ErrorCodes::LOGICAL_ERROR,
ErrorCodes::WRONG_ARGUMENTS,
"Dimension for Binary vector search is not equal: query: {} vs search column: {}",
bytes_of_query * 8,
dim);
Expand Down