diff --git a/thrift/compiler/sema/check_map_keys.cc b/thrift/compiler/sema/check_map_keys.cc index 60bd5f1416c..df4c9935ddf 100644 --- a/thrift/compiler/sema/check_map_keys.cc +++ b/thrift/compiler/sema/check_map_keys.cc @@ -27,10 +27,14 @@ namespace apache::thrift::compiler { namespace { +diagnostic_level kMismatchLevel = diagnostic_level::warning; +diagnostic_level kDuplicateLevel = diagnostic_level::warning; + +using const_value_kv = std::pair; + std::string to_string(const t_const_value* val); -std::string to_string( - const std::pair& val) { +std::string to_string(const const_value_kv& val) { return fmt::format("({}, {})", to_string(val.first), to_string(val.second)); } @@ -114,9 +118,7 @@ bool equal_value(const t_const_value* a, const t_const_value* b) { bool lt_value(const t_const_value* a, const t_const_value* b); -bool lt_value( - const std::pair& a, - const std::pair& b) { +bool lt_value(const const_value_kv& a, const const_value_kv& b) { if (equal_value(a.first, b.first)) { return lt_value(a.second, b.second); } @@ -166,16 +168,34 @@ struct const_value_comp { } }; -std::vector find_duplicate_keys( - const t_const_value* value) { - std::vector duplicates; +std::vector> +find_duplicate_keys(const std::vector& map_kvs) { + std::vector> duplicates; + std::map keys; + for (const auto& kv : map_kvs) { + auto it = keys.find(kv.first); + if (it != keys.end()) { + auto level = + equal_value(kv.second, it->second) ? kDuplicateLevel : kMismatchLevel; + duplicates.emplace_back(kv.first, level); + continue; + } + keys.emplace(kv.first, kv.second); + } + return duplicates; +} + +std::vector> +find_duplicate_keys(const std::vector& set_keys) { + std::vector> duplicates; std::set keys; - for (const auto& kv : value->get_map()) { - if (keys.count(kv.first) > 0) { - duplicates.push_back(kv.first); + for (const auto& k : set_keys) { + auto it = keys.find(k); + if (it != keys.end()) { + duplicates.emplace_back(k, kDuplicateLevel); continue; } - keys.insert(kv.first); + keys.insert(k); } return duplicates; } @@ -193,30 +213,40 @@ bool is_named_const_value(const t_const_value* value, const t_node& node) { void check_key_value( diagnostics_engine& diags, const t_node& node, const t_const_value* value) { - // recurse on elements - if (value->kind() == t_const_value::CV_LIST) { - for (const t_const_value* elem : value->get_list()) { - check_key_value(diags, node, elem); - } - } - if (value->kind() == t_const_value::CV_MAP) { + auto report_duplicates = [&](const auto& duplicates) { // Don't recurse or check constant defined elsewhere. if (is_named_const_value(value, node)) { return; } - auto duplicates = find_duplicate_keys(value); - for (const auto& duplicate : duplicates) { + for (const auto& [duplicate, level] : duplicates) { // If the t_const_value has a source range, use it; otherwise, // fallback to the source range of the enclosing const. const source_range& src_range = duplicate->src_range() ? *duplicate->src_range() : (value->src_range() ? *value->src_range() : node.src_range()); // TODO(T213710219): Enable this with error severity - diags.warning( + diags.report( src_range.begin, - "Duplicate key in map literal: `{}`", + level, + "Duplicate key in {} literal: `{}`", + value->kind() == t_const_value::CV_MAP ? "map" : "set", to_string(duplicate)); } + }; + + // recurse on elements + if (value->kind() == t_const_value::CV_LIST) { + for (const t_const_value* elem : value->get_list()) { + check_key_value(diags, node, elem); + } + if (value->ttype()->get_true_type()->is_set()) { + auto duplicates = find_duplicate_keys(value->get_list()); + report_duplicates(duplicates); + } + } + if (value->kind() == t_const_value::CV_MAP) { + auto duplicates = find_duplicate_keys(value->get_map()); + report_duplicates(duplicates); for (const auto& kv : value->get_map()) { check_key_value(diags, node, kv.first); check_key_value(diags, node, kv.second); diff --git a/thrift/compiler/test/standard_validator_test.cc b/thrift/compiler/test/standard_validator_test.cc index b5548dd4663..71b41f3c866 100644 --- a/thrift/compiler/test/standard_validator_test.cc +++ b/thrift/compiler/test/standard_validator_test.cc @@ -228,3 +228,31 @@ TEST(StandardValidatorTest, FieldDefaultKeyCollision) { } )"); } + +TEST(StandardValidatorTest, SetKeyCollision) { + check_compile(R"( + const set SET_DUPE = [ + 2, + 4, + 2, + # expected-warning@-1: Duplicate key in set literal: `2` + 4, + # expected-warning@-1: Duplicate key in set literal: `4` + ]; + + const list LIST_DUPE = [2, 2, 2, 4, 2, 4]; + + const set> NESTED_SET = [[2], [4], [2]]; + # expected-warning@-1: Duplicate key in set literal: `[2]` + + const list> NESTED_IDENTIFIER = [[2], SET_DUPE]; + + struct S { + 1: set ok_init = []; + 2: set dupe_init_set = ["a", "b", "a"]; + # expected-warning@-1: Duplicate key in set literal: `a` + 3: list dupe_init_list = ["a", "b", "a"]; + 4: set set_from_named_const = SET_DUPE; + } + )"); +}