diff --git a/lib/graphql/execution/interpreter/runtime.rb b/lib/graphql/execution/interpreter/runtime.rb index 481af7f764..76534fc9c2 100644 --- a/lib/graphql/execution/interpreter/runtime.rb +++ b/lib/graphql/execution/interpreter/runtime.rb @@ -198,7 +198,7 @@ def run_eager def each_gathered_selections(response_hash) ordered_result_keys = [] - gathered_selections = gather_selections(response_hash.graphql_application_value, response_hash.graphql_result_type, response_hash.graphql_selections, nil, {}, ordered_result_keys) + gathered_selections = gather_selections(response_hash, response_hash.graphql_application_value, response_hash.graphql_result_type, response_hash.graphql_selections, nil, {}, ordered_result_keys) ordered_result_keys.uniq! if gathered_selections.is_a?(Array) gathered_selections.each do |item| @@ -209,15 +209,13 @@ def each_gathered_selections(response_hash) end end - def gather_selections(owner_object, owner_type, selections, selections_to_run, selections_by_name, ordered_result_keys) + def gather_selections(graphql_response, owner_object, owner_type, selections, selections_to_run, selections_by_name, ordered_result_keys) selections.each do |node| - # Skip gathering this if the directive says so - if !directives_include?(node, owner_object, owner_type) - next - end - if node.is_a?(GraphQL::Language::Nodes::Field) response_key = node.alias || node.name + if !directives_include?(node, owner_object, owner_type, graphql_response, response_key) + next + end ordered_result_keys << response_key selections = selections_by_name[response_key] # if there was already a selection of this field, @@ -234,6 +232,9 @@ def gather_selections(owner_object, owner_type, selections, selections_to_run, s selections_by_name[response_key] = node end else + if !directives_include?(node, owner_object, owner_type, graphql_response, nil) + next + end # This is an InlineFragment or a FragmentSpread if !@runtime_directive_names.empty? && node.directives.any? { |d| @runtime_directive_names.include?(d.name) } next_selections = {} @@ -255,14 +256,14 @@ def gather_selections(owner_object, owner_type, selections, selections_to_run, s type_defn = query.types.type(node.type.name) if query.types.possible_types(type_defn).include?(owner_type) - result = gather_selections(owner_object, owner_type, node.selections, selections_to_run, next_selections, ordered_result_keys) + result = gather_selections(graphql_response, owner_object, owner_type, node.selections, selections_to_run, next_selections, ordered_result_keys) if !result.equal?(next_selections) selections_to_run = result end end else # it's an untyped fragment, definitely continue - result = gather_selections(owner_object, owner_type, node.selections, selections_to_run, next_selections, ordered_result_keys) + result = gather_selections(graphql_response, owner_object, owner_type, node.selections, selections_to_run, next_selections, ordered_result_keys) if !result.equal?(next_selections) selections_to_run = result end @@ -271,7 +272,7 @@ def gather_selections(owner_object, owner_type, selections, selections_to_run, s fragment_def = query.fragments[node.name] type_defn = query.types.type(fragment_def.type.name) if query.types.possible_types(type_defn).include?(owner_type) - result = gather_selections(owner_object, owner_type, fragment_def.selections, selections_to_run, next_selections, ordered_result_keys) + result = gather_selections(graphql_response, owner_object, owner_type, fragment_def.selections, selections_to_run, next_selections, ordered_result_keys) if !result.equal?(next_selections) selections_to_run = result end @@ -579,7 +580,7 @@ def continue_value(value, field, is_non_null, ast_node, result_name, selection_r value.path ||= current_path value.ast_node ||= ast_node context.errors << value - if selection_result + if selection_result && result_name set_result(selection_result, result_name, nil, false, is_non_null) end end @@ -856,11 +857,31 @@ def run_directive(method_name, object, directives, idx, &block) end # Check {Schema::Directive.include?} for each directive that's present - def directives_include?(node, graphql_object, parent_type) + def directives_include?(node, graphql_object, parent_type, selection_result, extra_path_part) node.directives.each do |dir_node| dir_defn = @schema_directives.fetch(dir_node.name) - args = arguments(graphql_object, dir_defn, dir_node) - if !dir_defn.include?(graphql_object, args, context) + raw_dir_args = arguments(nil, dir_defn, dir_node) + if !raw_dir_args.is_a?(GraphQL::ExecutionError) + begin + dir_defn.validate!(raw_dir_args, context) + rescue GraphQL::ExecutionError => err + raw_dir_args = err + end + end + + if extra_path_part && raw_dir_args.is_a?(GraphQL::ExecutionError) + raw_dir_args.path = current_path + [extra_path_part] + end + + dir_args = continue_value( + raw_dir_args, # value + nil, # field + false, # is_non_null + dir_node, # ast_node + nil, # result_name + selection_result + ) + if dir_args == HALT || !dir_defn.include?(graphql_object, dir_args, context) return false end end diff --git a/spec/graphql/schema/directive_spec.rb b/spec/graphql/schema/directive_spec.rb index b22ebd8043..d18bd9c613 100644 --- a/spec/graphql/schema/directive_spec.rb +++ b/spec/graphql/schema/directive_spec.rb @@ -179,6 +179,19 @@ def self.resolve_fragment_spread(ast_node, parent_type, objects, _args, context) end end + class ValidationTest < GraphQL::Schema::Directive + locations(FIELD) + argument :int, Int, validates: { numericality: { less_than: 10 } } + + def self.include?(*_args) + true + end + + def self.resolve_field(*_args) + nil + end + end + class Thing < GraphQL::Schema::Object field :name, String, null: false, hash_key: :name end @@ -251,6 +264,7 @@ def fetch(names) class Schema < GraphQL::Schema query(Query) directive(CountFields) + directive(ValidationTest) lazy_resolve(Proc, :call) use GraphQL::Dataloader use GraphQL::Execution::Next @@ -363,6 +377,19 @@ def exec_query(...) res = exec_query(query_str, context: { backtrace: true }) assert_equal 2, res["data"]["lazyThings"].size end + + it "handles validation errors in .include?" do + skip("Custom `.include?` is not supported in Execution::Next yet") if TESTING_EXEC_NEXT + res = exec_query("{ __typename @validationTest(int: 12) }") + expected_result = { + "errors" => [{"message" => "int must be less than 10", "locations" => [{"line" => 1, "column" => 14}], "path" => [ "__typename" ]}], + "data" => {} + } + assert_equal expected_result, res.to_h + + res2 = exec_query("{ __typename @validationTest(int: 8) }") + assert_equal({ "data" => { "__typename" => "Query" }}, res2.to_h) + end end describe "raising an error from an argument" do