Skip to content

Commit 13e1bdc

Browse files
authored
feat: codegen for expr envs (#193)
1 parent aad1dca commit 13e1bdc

File tree

5 files changed

+396
-0
lines changed

5 files changed

+396
-0
lines changed

docs/Exprgen.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Exprgen
2+
3+
## Install
4+
```
5+
go install github.com/antonmedv/expr/exprgen
6+
```
7+
## Usage
8+
Fetch methods generates for all struct/map/array/string named types(exception is map types with unnamed not basic key type like `map[struct{...}]int`).
9+
10+
To generate just call exprgen with pkg paths as arguments:
11+
```
12+
exprgen pkg1 pkg2 ...
13+
```
14+
15+
After call, file `*pkg_name*_exprgen.go` will be created in each packages from arguments.

docs/Optimizations.md

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,4 +116,18 @@ func main() {
116116
}
117117
```
118118

119+
## Reduced use of reflect
120+
121+
To fetch fields from struct, values from map, get by indexes expr uses reflect package.
122+
Envs can implement vm.Fetcher interface, to avoid use reflect:
123+
```go
124+
type Fetcher interface {
125+
Fetch(interface{}) interface{}
126+
}
127+
```
128+
When you need to fetch a field, the method will be used instead reflect functions.
129+
If the field is not found, Fetch must return nil.
130+
To generate Fetch for your types, use [Exprgen](Exprgen.md).
131+
132+
119133
* [Contents](README.md)

exprgen/exprgen.go

Lines changed: 351 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,351 @@
1+
package main
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"go/ast"
7+
"go/format"
8+
"go/importer"
9+
"go/parser"
10+
"go/token"
11+
"go/types"
12+
"io/fs"
13+
"io/ioutil"
14+
"os"
15+
"path/filepath"
16+
"sort"
17+
"strings"
18+
)
19+
20+
const exprgenSuffix = "_exprgen.go"
21+
22+
func main() {
23+
flag.Parse()
24+
25+
filenames := flag.Args()
26+
if len(filenames) == 0 {
27+
flag.Usage()
28+
os.Exit(1)
29+
}
30+
31+
for _, filename := range filenames {
32+
if err := generate(filename); err != nil {
33+
fmt.Fprintf(os.Stderr, "generate '%s' error: %s", filename, err.Error())
34+
os.Exit(2)
35+
}
36+
}
37+
}
38+
39+
func generate(filename string) error {
40+
fi, err := os.Stat(filename)
41+
if err != nil {
42+
return fmt.Errorf("stat err: %w", err)
43+
}
44+
45+
if !fi.IsDir() {
46+
return fmt.Errorf("filename must be dir")
47+
}
48+
tfs := token.NewFileSet()
49+
packages, err := parser.ParseDir(tfs, filename, func(info fs.FileInfo) bool {
50+
return !strings.HasSuffix(info.Name(), exprgenSuffix)
51+
}, parser.ParseComments)
52+
if err != nil {
53+
return fmt.Errorf("parse dir error: %w", err)
54+
}
55+
56+
typesChecker := types.Config{
57+
Importer: importer.ForCompiler(tfs, "source", nil),
58+
}
59+
60+
for name, pkg := range packages {
61+
if strings.HasSuffix(name, "_test") {
62+
continue
63+
}
64+
65+
files := make([]*ast.File, 0, len(pkg.Files))
66+
for _, f := range pkg.Files {
67+
files = append(files, f)
68+
}
69+
70+
packageTypes, err := typesChecker.Check(name, tfs, files, nil)
71+
if err != nil {
72+
return fmt.Errorf("types check error: %w", err)
73+
}
74+
75+
b, err := fileData(name, packageTypes)
76+
if err != nil {
77+
return err
78+
}
79+
80+
err = ioutil.WriteFile(filepath.Join(filename, name+exprgenSuffix), b, 0644)
81+
if err != nil {
82+
return err
83+
}
84+
}
85+
86+
return nil
87+
}
88+
89+
func fileData(pkgName string, pkg *types.Package) ([]byte, error) {
90+
var data string
91+
echo := func(s string, xs ...interface{}) {
92+
data += fmt.Sprintf(s, xs...) + "\n"
93+
}
94+
echoRaw := func(s string) {
95+
data += fmt.Sprint(s) + "\n"
96+
}
97+
98+
echo(`// Code generated by exprgen. DO NOT EDIT.`)
99+
echo(``)
100+
echo(`package ` + pkgName)
101+
echo(``)
102+
echo(`--imports`)
103+
echo(``)
104+
105+
echoRaw(`func toInt(a interface{}) int {
106+
switch x := a.(type) {
107+
case float32:
108+
return int(x)
109+
case float64:
110+
return int(x)
111+
112+
case int:
113+
return x
114+
case int8:
115+
return int(x)
116+
case int16:
117+
return int(x)
118+
case int32:
119+
return int(x)
120+
case int64:
121+
return int(x)
122+
123+
case uint:
124+
return int(x)
125+
case uint8:
126+
return int(x)
127+
case uint16:
128+
return int(x)
129+
case uint32:
130+
return int(x)
131+
case uint64:
132+
return int(x)
133+
134+
default:
135+
panic(fmt.Sprintf("invalid operation: int(%T)", x))
136+
}
137+
}`)
138+
echo(``)
139+
140+
imports := make(map[string]string)
141+
142+
scope := pkg.Scope()
143+
for _, objectName := range scope.Names() {
144+
obj := scope.Lookup(objectName)
145+
namedType, ok := obj.Type().(*types.Named)
146+
if !ok {
147+
continue
148+
}
149+
150+
recvName := "v"
151+
for i := 0; i < namedType.NumMethods(); i++ {
152+
method := namedType.Method(i)
153+
signature := method.Type().(*types.Signature)
154+
recv := signature.Recv()
155+
if recv != nil && recv.Name() != "" {
156+
recvName = recv.Name()
157+
break
158+
}
159+
}
160+
161+
switch t := namedType.Underlying().(type) {
162+
case *types.Basic:
163+
if t.Kind() != types.String {
164+
break
165+
}
166+
167+
echo("func (%s %s) Fetch(i interface{}) interface{} {", recvName, objectName)
168+
echo("return %s[toInt(i)]", recvName)
169+
echo("}")
170+
case *types.Slice, *types.Array:
171+
echo("func (%s %s) Fetch(i interface{}) interface{} {", recvName, objectName)
172+
echo("return %s[toInt(i)]", recvName)
173+
echo("}")
174+
case *types.Map:
175+
echo("func (%s %s) Fetch(i interface{}) interface{} {", recvName, objectName)
176+
key := t.Key()
177+
178+
numericCases := []string{
179+
"int",
180+
"int8",
181+
"int16",
182+
"int32",
183+
"int64",
184+
"uint",
185+
"uint8",
186+
"uint16",
187+
"uint32",
188+
"uint64",
189+
"uintptr",
190+
"float32",
191+
"float64",
192+
}
193+
194+
switch k := key.(type) {
195+
case *types.Named:
196+
objKey := k.Obj()
197+
keyName := objKey.Name()
198+
if objKey.Pkg().Path() != pkg.Path() {
199+
path := objKey.Pkg().Path()
200+
name := objKey.Pkg().Name()
201+
for imports[name] != "" && path != imports[name] {
202+
name = name + "1"
203+
}
204+
imports[name] = path
205+
keyName = name + "." + keyName
206+
}
207+
208+
echo(`switch _x_i := i.(type) {`)
209+
echo("case %s:", keyName)
210+
echo("return %s[_x_i]", recvName)
211+
if basicKey, ok := k.Underlying().(*types.Basic); ok {
212+
if basicKey.Info()&types.IsNumeric != 0 {
213+
for _, c := range numericCases {
214+
echo("case %s:", c)
215+
echo("return %s[%s(_x_i)]", recvName, keyName)
216+
}
217+
}
218+
if basicKey.Info()&types.IsString != 0 {
219+
echo(`case string:`)
220+
echo("return %s[%s(_x_i)]", recvName, keyName)
221+
echo("default:")
222+
imports["fmt"] = "fmt"
223+
echo("return %s[%s(fmt.Sprint(i))]", recvName, keyName)
224+
}
225+
}
226+
echo(`}`)
227+
case *types.Basic:
228+
keyName := k.String()
229+
echo(`switch _x_i := i.(type) {`)
230+
echo("case %s:", keyName)
231+
echo("return %s[_x_i]", recvName)
232+
if k.Info()&types.IsNumeric != 0 {
233+
for _, c := range numericCases {
234+
if c == keyName {
235+
continue
236+
}
237+
echo("case %s:", c)
238+
echo("return %s[%s(_x_i)]", recvName, keyName)
239+
}
240+
}
241+
242+
if k.Info()&types.IsString != 0 {
243+
echo("default:")
244+
imports["fmt"] = "fmt"
245+
echo("return %s[%s(fmt.Sprint(i))]", recvName, keyName)
246+
}
247+
248+
echo(`}`)
249+
}
250+
echo("return nil")
251+
echo(`}`)
252+
case *types.Struct:
253+
echo("func (%s %s) Fetch(i interface{}) interface{} {", recvName, objectName)
254+
255+
fields := make(map[string]string)
256+
collectStruct(recvName, t, func(c string, r string) {
257+
if _, ok := fields[c]; ok {
258+
fields[c] = "-"
259+
}
260+
fields[c] = r
261+
})
262+
263+
keys := make([]string, 0, len(fields))
264+
for c, r := range fields {
265+
if r == "-" {
266+
continue
267+
}
268+
keys = append(keys, c)
269+
}
270+
sort.Strings(keys)
271+
imports["fmt"] = "fmt"
272+
273+
echo(`var string_i string`)
274+
echo(`if s, ok := i.(string); ok {`)
275+
echo(`string_i = s`)
276+
echo(`} else {`)
277+
echo(`string_i = fmt.Sprint(i)`)
278+
echo(`}`)
279+
280+
echo(`switch string_i {`)
281+
for _, key := range keys {
282+
echo("case \"%s\":", key)
283+
echo("return %s", fields[key])
284+
}
285+
echo(`}`)
286+
echo(`return nil`)
287+
echo(`}`)
288+
}
289+
}
290+
291+
importsString := "import (\n"
292+
for k, v := range imports {
293+
importsString += k + "\"" + v + "\"\n"
294+
}
295+
importsString += ")"
296+
data = strings.Replace(data, "--imports", importsString, 1)
297+
298+
return format.Source([]byte(data))
299+
}
300+
301+
func collectStruct(recv string, t *types.Struct, collect func(string, string), skippedNames ...string) {
302+
fieldNames := make([]string, 0, t.NumFields())
303+
for i := 0; i < t.NumFields(); i++ {
304+
fieldNames = append(fieldNames, t.Field(i).Name())
305+
}
306+
307+
for i := 0; i < t.NumFields(); i++ {
308+
v := t.Field(i)
309+
if !v.Exported() || contains(skippedNames, v.Name()) {
310+
continue
311+
}
312+
313+
collect(v.Name(), recv+"."+v.Name())
314+
315+
if v.Embedded() {
316+
tt := v.Type()
317+
for dereference(tt) != underlying(tt) {
318+
tt = dereference(tt)
319+
tt = underlying(tt)
320+
}
321+
322+
switch vt := tt.(type) {
323+
case *types.Struct:
324+
collectStruct(recv+"."+v.Name(), vt, collect, fieldNames...)
325+
}
326+
}
327+
}
328+
}
329+
330+
func dereference(t types.Type) types.Type {
331+
if p, ok := t.(*types.Pointer); ok {
332+
return dereference(p.Elem())
333+
}
334+
return t
335+
}
336+
337+
func underlying(t types.Type) types.Type {
338+
if t != t.Underlying() {
339+
return underlying(t.Underlying())
340+
}
341+
return t
342+
}
343+
344+
func contains(arr []string, s string) bool {
345+
for _, e := range arr {
346+
if e == s {
347+
return true
348+
}
349+
}
350+
return false
351+
}

go.sum

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
3232
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
3333
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
3434
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
35+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
3536
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
3637
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
3738
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=

0 commit comments

Comments
 (0)