Skip to content

Commit ceda3b9

Browse files
committed
Preserve original value type
1 parent e8a04f4 commit ceda3b9

File tree

5 files changed

+68
-19
lines changed

5 files changed

+68
-19
lines changed

_examples/status/job_status_enum.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

_examples/status/status_enum.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

internal/generator/enum.go.tmpl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
// {{.Type | title}} is the exported type for the enum
1414
type {{.Type | title}} struct {
1515
name string
16-
value int
16+
value {{.OriginalType}}
1717
}
1818

1919
func (e {{.Type | title}}) String() string { return e.name }

internal/generator/generator.go

Lines changed: 42 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,12 @@ var titleCaser = cases.Title(language.English, cases.NoLower)
2626

2727
// Generator holds the data needed for enum code generation
2828
type Generator struct {
29-
Type string // the private type name (e.g., "status")
30-
Path string // output directory path
31-
values map[string]int // const values found
32-
pkgName string // package name from source file
33-
lowerCase bool // use lower case for marshal/unmarshal
29+
Type string // the private type name (e.g., "status")
30+
Path string // output directory path
31+
values map[string]int // const values found
32+
pkgName string // package name from source file
33+
lowerCase bool // use lower case for marshal/unmarshal
34+
originalType string // original type name (e.g., "uint8")
3435
}
3536

3637
// Value represents a single enum value
@@ -81,6 +82,9 @@ func (g *Generator) Parse(dir string) error {
8182
}
8283
}
8384

85+
if g.originalType == "" {
86+
return fmt.Errorf("type %s not found", g.Type)
87+
}
8488
if len(g.values) == 0 {
8589
return fmt.Errorf("no const values found for type %s", g.Type)
8690
}
@@ -90,7 +94,22 @@ func (g *Generator) Parse(dir string) error {
9094

9195
// parseFile processes a single file for enum declarations
9296
func (g *Generator) parseFile(file *ast.File) {
93-
97+
parseTypeBlock := func(decl *ast.GenDecl) {
98+
// extracts the type name from a const block
99+
for _, spec := range decl.Specs {
100+
vspec, ok := spec.(*ast.TypeSpec)
101+
if !ok {
102+
continue
103+
}
104+
if vspec.Name.Name == g.Type {
105+
tspec, ok := vspec.Type.(*ast.Ident)
106+
if !ok {
107+
continue
108+
}
109+
g.originalType = tspec.Name
110+
}
111+
}
112+
}
94113
parseConstBlock := func(decl *ast.GenDecl) {
95114
// extracts enum values from a const block
96115
var iotaVal int
@@ -148,8 +167,13 @@ func (g *Generator) parseFile(file *ast.File) {
148167
}
149168

150169
ast.Inspect(file, func(n ast.Node) bool {
151-
if decl, ok := n.(*ast.GenDecl); ok && decl.Tok == token.CONST {
152-
parseConstBlock(decl)
170+
if decl, ok := n.(*ast.GenDecl); ok {
171+
switch decl.Tok {
172+
case token.CONST:
173+
parseConstBlock(decl)
174+
case token.TYPE:
175+
parseTypeBlock(decl)
176+
}
153177
}
154178
return true
155179
})
@@ -216,15 +240,17 @@ func (g *Generator) Generate() error {
216240

217241
// prepare template data
218242
data := struct {
219-
Type string
220-
Values []Value
221-
Package string
222-
LowerCase bool
243+
Type string
244+
Values []Value
245+
Package string
246+
LowerCase bool
247+
OriginalType string
223248
}{
224-
Type: g.Type,
225-
Values: values,
226-
Package: pkgName,
227-
LowerCase: g.lowerCase,
249+
Type: g.Type,
250+
Values: values,
251+
Package: pkgName,
252+
LowerCase: g.lowerCase,
253+
OriginalType: g.originalType,
228254
}
229255

230256
// execute template

internal/generator/generator_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ func TestPermissions(t *testing.T) {
368368

369369
// create a sample status file
370370
sampleFile := `package source
371+
type status uint8
371372
const (
372373
statusUnknown = iota
373374
statusActive
@@ -466,6 +467,7 @@ func TestParseSpecialCases(t *testing.T) {
466467
tmpDir := t.TempDir()
467468
err := os.WriteFile(filepath.Join(tmpDir, "empty.go"), []byte(`
468469
package test
470+
type status uint8
469471
const (
470472
)
471473
`), 0o644)
@@ -483,6 +485,7 @@ const (
483485
tmpDir := t.TempDir()
484486
err := os.WriteFile(filepath.Join(tmpDir, "no_values.go"), []byte(`
485487
package test
488+
type status uint8
486489
const name string
487490
`), 0o644)
488491
require.NoError(t, err)
@@ -494,6 +497,26 @@ const name string
494497
require.Error(t, err)
495498
assert.Contains(t, err.Error(), "no const values found for type status")
496499
})
500+
501+
t.Run("no status type", func(t *testing.T) {
502+
tmpDir := t.TempDir()
503+
err := os.WriteFile(filepath.Join(tmpDir, "no_type.go"), []byte(`
504+
package test
505+
const (
506+
statusUnknown = iota
507+
statusActive
508+
statusInactive
509+
)
510+
`), 0o644)
511+
require.NoError(t, err)
512+
513+
gen, err := New("status", "")
514+
require.NoError(t, err)
515+
516+
err = gen.Parse(tmpDir)
517+
require.Error(t, err)
518+
assert.Contains(t, err.Error(), "type status not found")
519+
})
497520
}
498521

499522
func TestSplitCamelCase(t *testing.T) {

0 commit comments

Comments
 (0)