@@ -327,3 +327,66 @@ func TestQueryWithCustomRule(t *testing.T) {
327327 }
328328 }
329329}
330+
331+ // TestCustomRuleWithArgs tests graphql.GetArgumentValues() be able to access
332+ // field's argument values from custom validation rule.
333+ func TestCustomRuleWithArgs (t * testing.T ) {
334+ fieldDef , ok := testutil .StarWarsSchema .QueryType ().Fields ()["human" ]
335+ if ! ok {
336+ t .Fatal ("can't retrieve \" human\" field definition" )
337+ }
338+
339+ // a custom validation rule to extract argument values of "human" field.
340+ var actual map [string ]interface {}
341+ enter := func (p visitor.VisitFuncParams ) (string , interface {}) {
342+ // only interested in "human" field.
343+ fieldNode , ok := p .Node .(* ast.Field )
344+ if ! ok || fieldNode .Name == nil || fieldNode .Name .Value != "human" {
345+ return visitor .ActionNoChange , nil
346+ }
347+ // extract argument values by graphql.GetArgumentValues().
348+ actual = graphql .GetArgumentValues (fieldDef .Args , fieldNode .Arguments , nil )
349+ return visitor .ActionNoChange , nil
350+ }
351+ checkHumanArgs := func (context * graphql.ValidationContext ) * graphql.ValidationRuleInstance {
352+ return & graphql.ValidationRuleInstance {
353+ VisitorOpts : & visitor.VisitorOptions {
354+ KindFuncMap : map [string ]visitor.NamedVisitFuncs {
355+ kinds .Field : {Enter : enter },
356+ },
357+ },
358+ }
359+ }
360+
361+ for _ , tc := range []struct {
362+ query string
363+ expected map [string ]interface {}
364+ }{
365+ {
366+ `query { human(id: "1000") { name } }` ,
367+ map [string ]interface {}{"id" : "1000" },
368+ },
369+ {
370+ `query { human(id: "1002") { name } }` ,
371+ map [string ]interface {}{"id" : "1002" },
372+ },
373+ {
374+ `query { human(id: "9999") { name } }` ,
375+ map [string ]interface {}{"id" : "9999" },
376+ },
377+ } {
378+ actual = nil
379+ params := graphql.Params {
380+ Schema : testutil .StarWarsSchema ,
381+ RequestString : tc .query ,
382+ ValidationRules : append (graphql .SpecifiedRules , checkHumanArgs ),
383+ }
384+ result := graphql .Do (params )
385+ if len (result .Errors ) > 0 {
386+ t .Fatalf ("wrong result, unexpected errors: %v" , result .Errors )
387+ }
388+ if ! reflect .DeepEqual (actual , tc .expected ) {
389+ t .Fatalf ("unexpected result: want=%+v got=%+v" , tc .expected , actual )
390+ }
391+ }
392+ }
0 commit comments