Skip to content

Commit 1c7d237

Browse files
committed
Improve return type of ExpressionNode::compute
1 parent eab0465 commit 1c7d237

File tree

3 files changed

+55
-53
lines changed

3 files changed

+55
-53
lines changed

cpp/arcticdb/processing/clause.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -157,9 +157,9 @@ std::vector<EntityId> FilterClause::process(std::vector<EntityId>&& entity_ids)
157157

158158
OutputSchema FilterClause::modify_schema(OutputSchema&& output_schema) const {
159159
check_column_presence(output_schema, *clause_info_.input_columns_, "Filter");
160-
auto expr = expression_context_->expression_nodes_.get_value(expression_context_->root_node_name_.value);
161-
auto opt_datatype = expr->compute(*expression_context_, output_schema.column_types());
162-
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(!opt_datatype.has_value(), "FilterClause AST produces a column, not a bitset");
160+
auto root_expr = expression_context_->expression_nodes_.get_value(expression_context_->root_node_name_.value);
161+
std::variant<BitSetTag, DataType> return_type = root_expr->compute(*expression_context_, output_schema.column_types());
162+
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(std::holds_alternative<BitSetTag>(return_type), "FilterClause AST produces a column, not a bitset");
163163
return output_schema;
164164
}
165165

@@ -202,11 +202,11 @@ std::vector<EntityId> ProjectClause::process(std::vector<EntityId>&& entity_ids)
202202

203203
OutputSchema ProjectClause::modify_schema(OutputSchema&& output_schema) const {
204204
check_column_presence(output_schema, *clause_info_.input_columns_, "Project");
205-
auto expr = expression_context_->expression_nodes_.get_value(expression_context_->root_node_name_.value);
206-
auto opt_datatype = expr->compute(*expression_context_, output_schema.column_types());
207-
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(opt_datatype.has_value(), "ProjectClause AST produces a bitset, not a column");
208-
output_schema.stream_descriptor_.add_scalar_field(*opt_datatype, output_column_);
209-
output_schema.column_types().emplace(output_column_, *opt_datatype);
205+
auto root_expr = expression_context_->expression_nodes_.get_value(expression_context_->root_node_name_.value);
206+
std::variant<BitSetTag, DataType> return_type = root_expr->compute(*expression_context_, output_schema.column_types());
207+
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(std::holds_alternative<DataType>(return_type), "ProjectClause AST produces a bitset, not a column");
208+
output_schema.stream_descriptor_.add_scalar_field(std::get<DataType>(return_type), output_column_);
209+
output_schema.column_types().emplace(output_column_, std::get<DataType>(return_type));
210210
return output_schema;
211211
}
212212

cpp/arcticdb/processing/expression_node.cpp

+41-43
Original file line numberDiff line numberDiff line change
@@ -70,36 +70,38 @@ VariantData ExpressionNode::compute(ProcessingUnit& seg) const {
7070
}
7171
}
7272

73-
std::optional<DataType> ExpressionNode::compute(const ExpressionContext& expression_context,
74-
const ankerl::unordered_dense::map<std::string, DataType>& column_types) const {
75-
std::optional<DataType> res;
76-
std::optional<DataType> left_type = util::variant_match(
73+
std::variant<BitSetTag, DataType> ExpressionNode::compute(
74+
const ExpressionContext& expression_context,
75+
const ankerl::unordered_dense::map<std::string, DataType>& column_types) const {
76+
// Defaults to BitSetTag: a default-constructed std::variant holds its first alternative
77+
std::variant<BitSetTag, DataType> res;
78+
std::variant<BitSetTag, DataType> left_type = util::variant_match(
7779
left_,
78-
[&column_types] (const ColumnName& column_name) -> std::optional<DataType> {
80+
[&column_types] (const ColumnName& column_name) -> std::variant<BitSetTag, DataType> {
7981
auto it = column_types.find(column_name.value);
8082
schema::check<ErrorCode::E_COLUMN_DOESNT_EXIST>(it != column_types.end(),
8183
"ProjectClause requires column '{}' to exist in input data"
8284
,column_name.value);
8385
return it->second;
8486
},
85-
[&expression_context] (const ValueName& value_name) -> std::optional<DataType> {
87+
[&expression_context] (const ValueName& value_name) -> std::variant<BitSetTag, DataType> {
8688
return expression_context.values_.get_value(value_name.value)->data_type_;
8789
},
88-
[&expression_context, &column_types] (const ExpressionName& expression_name) -> std::optional<DataType> {
90+
[&expression_context, &column_types] (const ExpressionName& expression_name) -> std::variant<BitSetTag, DataType> {
8991
auto expr = expression_context.expression_nodes_.get_value(expression_name.value);
9092
return expr->compute(expression_context, column_types);
9193
},
92-
[] (auto&&) -> std::optional<DataType> {
94+
[] (auto&&) -> std::variant<BitSetTag, DataType> {
9395
internal::raise<ErrorCode::E_ASSERTION_FAILURE>("Unexpected expression argument type");
94-
return std::nullopt;
96+
return {};
9597
}
9698
);
9799
if (is_unary_operation(operation_type_)) {
98100
switch (operation_type_) {
99101
case OperationType::ABS:
100102
case OperationType::NEG:
101-
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(left_type.has_value(), "Unexpected bitset input to unary arithmetic operation");
102-
details::visit_type(*left_type, [this, &res](auto tag) {
103+
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(std::holds_alternative<DataType>(left_type), "Unexpected bitset input to unary arithmetic operation");
104+
details::visit_type(std::get<DataType>(left_type), [this, &res](auto tag) {
103105
using type_info = ScalarTypeInfo<decltype(tag)>;
104106
if constexpr (is_numeric_type(type_info::data_type)) {
105107
if (operation_type_ == OperationType::ABS) {
@@ -119,60 +121,58 @@ std::optional<DataType> ExpressionNode::compute(const ExpressionContext& express
119121
case OperationType::ISNULL:
120122
case OperationType::NOTNULL:
121123
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(
122-
left_type.has_value() && (is_floating_point_type(*left_type) || is_sequence_type(*left_type) ||
123-
is_time_type(*left_type)),
124-
"Unexpected data type input to unary comparison operation {}",
125-
left_type.has_value() ? *left_type : DataType::UNKNOWN);
124+
std::holds_alternative<DataType>(left_type) && (is_floating_point_type(std::get<DataType>(left_type)) || is_sequence_type(std::get<DataType>(left_type)) ||
125+
is_time_type(std::get<DataType>(left_type))),
126+
"Unexpected data type input to unary comparison operation");
126127
break;
127128
case OperationType::IDENTITY:
128129
case OperationType::NOT:
129-
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(!left_type.has_value() || *left_type == DataType::BOOL8,
130-
"Unexpected data type input to unary boolean operation {}",
131-
left_type.has_value() ? *left_type : DataType::UNKNOWN);
130+
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(!std::holds_alternative<DataType>(left_type) || std::get<DataType>(left_type) == DataType::BOOL8,
131+
"Unexpected data type input to unary boolean operation");
132132
break;
133133
default:
134134
internal::raise<ErrorCode::E_ASSERTION_FAILURE>("Unexpected unary operator");
135135
}
136136
} else {
137137
// Binary operation
138-
std::optional<DataType> right_type;
138+
std::variant<BitSetTag, DataType> right_type;
139139
std::optional<bool> empty_value_set;
140140
right_type = util::variant_match(
141141
right_,
142-
[&column_types] (const ColumnName& column_name) -> std::optional<DataType> {
142+
[&column_types] (const ColumnName& column_name) -> std::variant<BitSetTag, DataType> {
143143
auto it = column_types.find(column_name.value);
144144
schema::check<ErrorCode::E_COLUMN_DOESNT_EXIST>(it != column_types.end(),
145145
"ProjectClause requires column '{}' to exist in input data"
146146
,column_name.value);
147147
return it->second;
148148
},
149-
[&expression_context] (const ValueName& value_name) -> std::optional<DataType> {
149+
[&expression_context] (const ValueName& value_name) -> std::variant<BitSetTag, DataType> {
150150
return expression_context.values_.get_value(value_name.value)->data_type_;
151151
},
152-
[&expression_context, &empty_value_set] (const ValueSetName& value_set_name) -> std::optional<DataType> {
152+
[&expression_context, &empty_value_set] (const ValueSetName& value_set_name) -> std::variant<BitSetTag, DataType> {
153153
auto value_set = expression_context.value_sets_.get_value(value_set_name.value);
154154
empty_value_set = value_set->empty();
155155
return value_set->base_type().data_type();
156156
},
157-
[&expression_context, &column_types] (const ExpressionName& expression_name) -> std::optional<DataType> {
157+
[&expression_context, &column_types] (const ExpressionName& expression_name) -> std::variant<BitSetTag, DataType> {
158158
auto expr = expression_context.expression_nodes_.get_value(expression_name.value);
159159
return expr->compute(expression_context, column_types);
160160
},
161-
[] (auto&&) -> std::optional<DataType> {
161+
[] (auto&&) -> std::variant<BitSetTag, DataType> {
162162
internal::raise<ErrorCode::E_ASSERTION_FAILURE>("Unexpected expression argument type");
163-
return std::nullopt;
163+
return {};
164164
}
165165
);
166166
switch (operation_type_) {
167167
case OperationType::ADD:
168168
case OperationType::SUB:
169169
case OperationType::MUL:
170170
case OperationType::DIV:
171-
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(left_type.has_value(), "Unexpected bitset input to binary arithmetic operator");
172-
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(right_type.has_value() && !empty_value_set.has_value(), "Unexpected input to binary arithmetic operator");
173-
details::visit_type(*left_type, [this, &res, right_type](auto left_tag) {
171+
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(std::holds_alternative<DataType>(left_type), "Unexpected bitset input to binary arithmetic operator");
172+
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(std::holds_alternative<DataType>(right_type) && !empty_value_set.has_value(), "Unexpected input to binary arithmetic operator");
173+
details::visit_type(std::get<DataType>(left_type), [this, &res, right_type](auto left_tag) {
174174
using left_type_info = ScalarTypeInfo<decltype(left_tag)>;
175-
details::visit_type(*right_type, [this, &res](auto right_tag) {
175+
details::visit_type(std::get<DataType>(right_type), [this, &res](auto right_tag) {
176176
using right_type_info = ScalarTypeInfo<decltype(right_tag)>;
177177
if constexpr (is_numeric_type(left_type_info::data_type) && is_numeric_type(right_type_info::data_type)) {
178178
switch (operation_type_) {
@@ -212,34 +212,32 @@ std::optional<DataType> ExpressionNode::compute(const ExpressionContext& express
212212
case OperationType::LE:
213213
case OperationType::GT:
214214
case OperationType::GE:
215-
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(left_type.has_value(), "Unexpected bitset input to binary comparison operator");
216-
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(right_type.has_value() && !empty_value_set.has_value(), "Unexpected input to binary comparison operator");
215+
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(std::holds_alternative<DataType>(left_type), "Unexpected bitset input to binary comparison operator");
216+
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(std::holds_alternative<DataType>(right_type) && !empty_value_set.has_value(), "Unexpected input to binary comparison operator");
217217
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(
218-
(is_numeric_type(*left_type) && is_numeric_type(*right_type)) ||
219-
(is_bool_type(*left_type) && is_bool_type(*right_type)) ||
220-
(is_sequence_type(*left_type) && is_sequence_type(*right_type) && (operation_type_ == OperationType::EQ || operation_type_ == OperationType::NE)),
218+
(is_numeric_type(std::get<DataType>(left_type)) && is_numeric_type(std::get<DataType>(right_type))) ||
219+
(is_bool_type(std::get<DataType>(left_type)) && is_bool_type(std::get<DataType>(right_type))) ||
220+
(is_sequence_type(std::get<DataType>(left_type)) && is_sequence_type(std::get<DataType>(right_type)) && (operation_type_ == OperationType::EQ || operation_type_ == OperationType::NE)),
221221
"Incompatible data types provided in binary comparison {}, {}",
222-
*left_type, *right_type);
222+
std::get<DataType>(left_type), std::get<DataType>(right_type));
223223
break;
224224
case OperationType::ISIN:
225225
case OperationType::ISNOTIN:
226-
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(left_type.has_value(), "Unexpected bitset input to binary comparison operator");
227-
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(right_type.has_value() && empty_value_set.has_value(), "Unexpected input to binary comparison operator");
226+
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(std::holds_alternative<DataType>(left_type), "Unexpected bitset input to binary comparison operator");
227+
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(std::holds_alternative<DataType>(right_type) && empty_value_set.has_value(), "Unexpected input to binary comparison operator");
228228
if (!*empty_value_set) {
229229
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(
230-
(is_sequence_type(*left_type) && is_sequence_type(*right_type)) || (is_numeric_type(*left_type) && is_numeric_type(*right_type)),
230+
(is_sequence_type(std::get<DataType>(left_type)) && is_sequence_type(std::get<DataType>(right_type))) || (is_numeric_type(std::get<DataType>(left_type)) && is_numeric_type(std::get<DataType>(right_type))),
231231
"Incompatible data types provided in set membership operator {}, {}",
232-
*left_type, *right_type);
232+
std::get<DataType>(left_type), std::get<DataType>(right_type));
233233
} // else - Empty value set compatible with all data types
234234
break;
235235
case OperationType::AND:
236236
case OperationType::OR:
237237
case OperationType::XOR:
238238
user_input::check<ErrorCode::E_INVALID_USER_ARGUMENT>(
239-
(!left_type.has_value() || *left_type == DataType::BOOL8) && (!right_type.has_value() || *right_type == DataType::BOOL8),
240-
"Unexpected data types input to binary boolean operation {} {}",
241-
left_type.has_value() ? *left_type : DataType::UNKNOWN,
242-
right_type.has_value() ? *right_type : DataType::UNKNOWN);
239+
(!std::holds_alternative<DataType>(left_type) || std::get<DataType>(left_type) == DataType::BOOL8) && (!std::holds_alternative<DataType>(right_type) || std::get<DataType>(right_type) == DataType::BOOL8),
240+
"Unexpected data types input to binary boolean operation");
243241
break;
244242
default:
245243
internal::raise<ErrorCode::E_ASSERTION_FAILURE>("Unexpected binary operator");

cpp/arcticdb/processing/expression_node.hpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ struct EmptyResult {};
7575

7676
using VariantData = std::variant<FullResult, EmptyResult, std::shared_ptr<Value>, std::shared_ptr<ValueSet>, ColumnWithStrings, util::BitSet>;
7777

78+
// Used to represent that an ExpressionNode returns a bitset
79+
struct BitSetTag{};
80+
7881
/*
7982
* Basic AST node.
8083
*/
@@ -89,8 +92,9 @@ struct ExpressionNode {
8992

9093
VariantData compute(ProcessingUnit& seg) const;
9194

92-
std::optional<DataType> compute(const ExpressionContext& expression_context,
93-
const ankerl::unordered_dense::map<std::string, DataType>& column_types) const;
95+
std::variant<BitSetTag, DataType> compute(
96+
const ExpressionContext& expression_context,
97+
const ankerl::unordered_dense::map<std::string, DataType>& column_types) const;
9498
};
9599

96100
} //namespace arcticdb

0 commit comments

Comments
 (0)