Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions _examples/status/job_status_enum.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion _examples/status/status_enum.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions internal/generator/enum.go.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import (
// {{.Type | title}} is the exported type for the enum
type {{.Type | title}} struct {
name string
value int
value {{.OriginalType}}
}

func (e {{.Type | title}}) String() string { return e.name }
Expand Down Expand Up @@ -91,7 +91,7 @@ func Must{{.Type | title}}(v string) {{.Type | title}} {

{{if .GenerateGetter -}}
// Get{{.Type | title}}ByID gets the correspondent {{.Type}} enum value by its ID (raw integer value)
func Get{{.Type | title}}ByID(v int) ({{.Type | title}}, error) {
func Get{{.Type | title}}ByID(v {{.OriginalType}}) ({{.Type | title}}, error) {
switch v {
{{range .Values -}}
case {{.Index}}:
Expand Down
32 changes: 29 additions & 3 deletions internal/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ type Generator struct {
pkgName string // package name from source file
lowerCase bool // use lower case for marshal/unmarshal
generateGetter bool // generate getter methods for enum values
originalType string // original type name (e.g., "uint8")
}

// Value represents a single enum value
Expand Down Expand Up @@ -88,6 +89,9 @@ func (g *Generator) Parse(dir string) error {
}
}

if g.originalType == "" {
return fmt.Errorf("type %s not found", g.Type)
}
if len(g.values) == 0 {
return fmt.Errorf("no const values found for type %s", g.Type)
}
Expand All @@ -97,7 +101,22 @@ func (g *Generator) Parse(dir string) error {

// parseFile processes a single file for enum declarations
func (g *Generator) parseFile(file *ast.File) {

parseTypeBlock := func(decl *ast.GenDecl) {
// extracts the type name from a const block
for _, spec := range decl.Specs {
vspec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}
if vspec.Name.Name == g.Type {
tspec, ok := vspec.Type.(*ast.Ident)
if !ok {
continue
}
g.originalType = tspec.Name
}
}
}
parseConstBlock := func(decl *ast.GenDecl) {
// extracts enum values from a const block
var iotaVal int
Expand Down Expand Up @@ -155,8 +174,13 @@ func (g *Generator) parseFile(file *ast.File) {
}

ast.Inspect(file, func(n ast.Node) bool {
if decl, ok := n.(*ast.GenDecl); ok && decl.Tok == token.CONST {
parseConstBlock(decl)
if decl, ok := n.(*ast.GenDecl); ok {
switch decl.Tok {
case token.CONST:
parseConstBlock(decl)
case token.TYPE:
parseTypeBlock(decl)
}
}
return true
})
Expand Down Expand Up @@ -250,12 +274,14 @@ func (g *Generator) Generate() error {
Package string
LowerCase bool
GenerateGetter bool
OriginalType string
}{
Type: g.Type,
Values: values,
Package: pkgName,
LowerCase: g.lowerCase,
GenerateGetter: g.generateGetter,
OriginalType: g.originalType,
}

// execute template
Expand Down
27 changes: 25 additions & 2 deletions internal/generator/generator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ func TestGenerator(t *testing.T) {
require.NoError(t, err)

// check content
assert.Contains(t, string(content), "func GetJobStatusByID(v int) (JobStatus, error)")
assert.Contains(t, string(content), "func GetJobStatusByID(v uint8) (JobStatus, error)")
assert.Contains(t, string(content), "case 0:\n\t\treturn JobStatusUnknown, nil")
assert.Contains(t, string(content), "case 1:\n\t\treturn JobStatusActive, nil")
assert.Contains(t, string(content), "case 2:\n\t\treturn JobStatusInactive, nil")
Expand Down Expand Up @@ -266,7 +266,7 @@ func TestGenerator(t *testing.T) {
require.NoError(t, err)

// check content
assert.Contains(t, string(content), "func GetExplicitValuesByID(v int) (ExplicitValues, error)")
assert.Contains(t, string(content), "func GetExplicitValuesByID(v uint8) (ExplicitValues, error)")
assert.Contains(t, string(content), "case 10:\n\t\treturn ExplicitValuesFirst, nil")
assert.Contains(t, string(content), "case 20:\n\t\treturn ExplicitValuesSecond, nil")
assert.Contains(t, string(content), "case 30:\n\t\treturn ExplicitValuesThird, nil")
Expand Down Expand Up @@ -442,6 +442,7 @@ func TestPermissions(t *testing.T) {

// create a sample status file
sampleFile := `package source
type status uint8
const (
statusUnknown = iota
statusActive
Expand Down Expand Up @@ -540,6 +541,7 @@ func TestParseSpecialCases(t *testing.T) {
tmpDir := t.TempDir()
err := os.WriteFile(filepath.Join(tmpDir, "empty.go"), []byte(`
package test
type status uint8
const (
)
`), 0o644)
Expand All @@ -557,6 +559,7 @@ const (
tmpDir := t.TempDir()
err := os.WriteFile(filepath.Join(tmpDir, "no_values.go"), []byte(`
package test
type status uint8
const name string
`), 0o644)
require.NoError(t, err)
Expand All @@ -568,6 +571,26 @@ const name string
require.Error(t, err)
assert.Contains(t, err.Error(), "no const values found for type status")
})

t.Run("no status type", func(t *testing.T) {
tmpDir := t.TempDir()
err := os.WriteFile(filepath.Join(tmpDir, "no_type.go"), []byte(`
package test
const (
statusUnknown = iota
statusActive
statusInactive
)
`), 0o644)
require.NoError(t, err)

gen, err := New("status", "")
require.NoError(t, err)

err = gen.Parse(tmpDir)
require.Error(t, err)
assert.Contains(t, err.Error(), "type status not found")
})
}

func TestSplitCamelCase(t *testing.T) {
Expand Down