@@ -18,15 +18,90 @@ import (
1818 "errors"
1919 "fmt"
2020 "go/ast"
21+ "strconv"
2122 "strings"
2223
24+ "sigs.k8s.io/controller-tools/pkg/crd"
25+ "sigs.k8s.io/controller-tools/pkg/loader"
2326 "sigs.k8s.io/controller-tools/pkg/markers"
2427)
2528
26- // getMarkedChildrenOfField collects all marked fields from type declarations starting at root in depth-first order.
27- func (g generator ) getMarkedChildrenOfField (root markers.FieldInfo ) (map [string ][]* fieldInfo , error ) {
29+ // importIdents maps import identifiers to a list of corresponding path and file containing that path.
30+ // For example, consider the set of 2 files, one containing 'import foo "my/foo"' and the other 'import foo "your/foo"'.
31+ // Then the map would be: map["foo"][]struct{{f: (file 1), path: "my/foo"},{f: (file 2), path: "your/foo"}}.
32+ type importIdents map [string ][]struct {
33+ f * ast.File
34+ path string
35+ }
36+
37+ // newImportIdents creates an importIdents from all imports in pkg.
38+ func newImportIdents (pkg * loader.Package ) (importIdents , error ) {
39+ importIDs := make (map [string ][]struct {
40+ f * ast.File
41+ path string
42+ })
43+ for _ , file := range pkg .Syntax {
44+ for _ , impSpec := range file .Imports {
45+ val , err := strconv .Unquote (impSpec .Path .Value )
46+ if err != nil {
47+ return nil , err
48+ }
49+ // Most imports are not locally named, so the real package name should be used.
50+ var impName string
51+ if imp , hasImp := pkg .Imports ()[val ]; hasImp {
52+ impName = imp .Name
53+ }
54+ // impSpec.Name will not be empty for locally named imports
55+ if impSpec .Name != nil {
56+ impName = impSpec .Name .Name
57+ }
58+ importIDs [impName ] = append (importIDs [impName ], struct {
59+ f * ast.File
60+ path string
61+ }{file , val })
62+ }
63+ }
64+ return importIDs , nil
65+ }
66+
67+ // findPackagePathForSelExpr returns the package path corresponding to the package name used in expr if it exists in im.
68+ func (im importIdents ) findPackagePathForSelExpr (expr * ast.SelectorExpr ) (pkgPath string ) {
69+ // X contains the name being selected from.
70+ xIdent , isIdent := expr .X .(* ast.Ident )
71+ if ! isIdent {
72+ return ""
73+ }
74+ // Imports for all import statements where local import name == name being selected from.
75+ imports , hasImports := im [xIdent .String ()]
76+ if ! hasImports {
77+ return ""
78+ }
79+
80+ // Short-circuit if only one import.
81+ if len (imports ) == 1 {
82+ return imports [0 ].path
83+ }
84+
85+ // If multiple files contain the same local import name, check to see which file contains the selector expression.
86+ for _ , imp := range imports {
87+ if imp .f .Pos () <= expr .Pos () && imp .f .End () >= expr .End () {
88+ return imp .path
89+ }
90+ }
91+ return ""
92+ }
93+
94+ // getMarkedChildrenOfField collects all marked fields from type declarations starting at rootField in depth-first order.
95+ func (g generator ) getMarkedChildrenOfField (rootPkg * loader.Package , rootField markers.FieldInfo ) (map [string ][]* fieldInfo , error ) {
96+ // Gather all types and imports needed to build the BFS tree.
97+ rootPkg .NeedTypesInfo ()
98+ importIDs , err := newImportIdents (rootPkg )
99+ if err != nil {
100+ return nil , err
101+ }
102+
28103 // ast.Inspect will not traverse into fields, so iteratively collect them and to check for markers.
29- nextFields := []* fieldInfo {{FieldInfo : root }}
104+ nextFields := []* fieldInfo {{FieldInfo : rootField }}
30105 markedFields := map [string ][]* fieldInfo {}
31106 for len (nextFields ) > 0 {
32107 fields := []* fieldInfo {}
@@ -36,49 +111,65 @@ func (g generator) getMarkedChildrenOfField(root markers.FieldInfo) (map[string]
36111 if n == nil {
37112 return true
38113 }
39- switch expr := n .(type ) {
114+
115+ var info * markers.TypeInfo
116+ var hasInfo bool
117+ switch nt := n .(type ) {
118+ case * ast.SelectorExpr :
119+ // Case of a type definition in an imported package.
120+
121+ pkgPath := importIDs .findPackagePathForSelExpr (nt )
122+ if pkgPath == "" {
123+ // Found no reference to pkgPath in any file.
124+ return true
125+ }
126+ if pkg , hasImport := rootPkg .Imports ()[loader .NonVendorPath (pkgPath )]; hasImport {
127+ // Check if the field's type exists in the known types.
128+ info , hasInfo = g .types [crd.TypeIdent {Package : pkg , Name : nt .Sel .Name }]
129+ }
40130 case * ast.Ident :
131+ // Case of a local type definition.
132+
41133 // Only look at type names.
42- if expr .Obj == nil || expr .Obj .Kind != ast .Typ {
43- return true
134+ if nt .Obj != nil && nt .Obj .Kind == ast .Typ {
135+ // Check if the field's type exists in the known types.
136+ info , hasInfo = g .types [crd.TypeIdent {Package : rootPkg , Name : nt .Name }]
44137 }
45- // Check if the field's type exists in the known types.
46- info , hasInfo := g .types [expr .Name ]
47- if ! hasInfo {
138+ }
139+ if ! hasInfo {
140+ return true
141+ }
142+
143+ // Add all child fields to the list to search next.
144+ for _ , finfo := range info .Fields {
145+ segment , err := getPathSegmentForField (finfo )
146+ if err != nil {
147+ errs = append (errs , fmt .Errorf ("error getting path from type %s field %s: %v" , info .Name , finfo .Name , err ))
48148 return true
49149 }
50- // Add all child fields to the list to search next.
51- for _ , finfo := range info .Fields {
52- segment , err := getPathSegmentForField (finfo )
53- if err != nil {
54- errs = append (errs , fmt .Errorf ("error getting path from type %s field %s: %v" ,
55- info .Name , finfo .Name , err ),
56- )
57- return true
58- }
59- // Add extra information to the segment if it comes from a certain field type.
60- switch finfo .RawField .Type .(type ) {
61- case (* ast.ArrayType ):
62- // arrayFieldGroup case.
63- if segment != ignoredTag && segment != inlinedTag {
64- segment += "[0]"
65- }
66- }
67- // Create a new set of path segments using the parent's segments
68- // and add the field to the next fields to search.
69- parentSegments := make ([]string , len (field .pathSegments ), len (field .pathSegments )+ 1 )
70- copy (parentSegments , field .pathSegments )
71- f := & fieldInfo {
72- FieldInfo : finfo ,
73- pathSegments : append (parentSegments , segment ),
74- }
75- fields = append (fields , f )
76- // Marked fields get collected for the caller to parse.
77- if len (finfo .Markers ) != 0 {
78- markedFields [info .Name ] = append (markedFields [info .Name ], f )
150+ // Add extra information to the segment if it comes from a certain field type.
151+ switch finfo .RawField .Type .(type ) {
152+ case * ast.ArrayType :
153+ // arrayFieldGroup case.
154+ if segment != ignoredTag && segment != inlinedTag {
155+ segment += "[0]"
79156 }
80157 }
158+ // Create a new set of path segments using the parent's segments
159+ // and add the field to the next fields to search.
160+ parentSegments := make ([]string , len (field .pathSegments ), len (field .pathSegments )+ 1 )
161+ copy (parentSegments , field .pathSegments )
162+ f := & fieldInfo {
163+ FieldInfo : finfo ,
164+ pathSegments : append (parentSegments , segment ),
165+ }
166+ fields = append (fields , f )
167+ // Marked fields get collected for the caller to parse.
168+ if len (finfo .Markers ) != 0 {
169+ markedFields [info .Name ] = append (markedFields [info .Name ], f )
170+ }
81171 }
172+
82173 return true
83174 })
84175 if err := fmtParseErrors (errs ); err != nil {
0 commit comments