Skip to content

Commit e02ab06

Browse files
committed
Auto generate funcs to convert interface types to their nvml* types
Signed-off-by: Kevin Klues <[email protected]>
1 parent 9554325 commit e02ab06

File tree

3 files changed

+333
-39
lines changed

3 files changed

+333
-39
lines changed

gen/nvml/generateapi.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"slices"
3030
"sort"
3131
"strings"
32+
"text/template"
3233
"unicode"
3334
)
3435

@@ -80,6 +81,45 @@ var GeneratableInterfaces = []GeneratableInterfacePoperties{
8081
},
8182
}
8283

84+
// Template definitions
85+
const handleHelperTemplate = `
86+
// {{.Type}}Handle attempts to convert a {{.Interface}} to an {{.Type}}.
87+
func {{.Type}}Handle({{.ParamName}} {{.Interface}}) {{.Type}} {
88+
var helper func(val reflect.Value) {{.Type}}
89+
helper = func(val reflect.Value) {{.Type}} {
90+
if val.Kind() == reflect.Interface {
91+
val = val.Elem()
92+
}
93+
if val.Kind() == reflect.Ptr {
94+
val = val.Elem()
95+
}
96+
if val.Type() == reflect.TypeOf((*{{.Type}})(nil)).Elem() {
97+
return val.Interface().({{.Type}})
98+
}
99+
if val.Kind() != reflect.Struct {
100+
panic(fmt.Errorf("unable to convert non-struct type %v to {{.Type}}", val.Kind()))
101+
}
102+
for i := 0; i < val.Type().NumField(); i++ {
103+
if !val.Type().Field(i).Anonymous {
104+
continue
105+
}
106+
if !val.Field(i).Type().Implements(reflect.TypeOf((*{{.Interface}})(nil)).Elem()) {
107+
continue
108+
}
109+
return helper(val.Field(i))
110+
}
111+
panic(fmt.Errorf("unable to convert %T to {{.Type}}", {{.ParamName}}))
112+
}
113+
return helper(reflect.ValueOf({{.ParamName}}))
114+
}`
115+
116+
// Template data structures
117+
type HandleHelperTemplateData struct {
118+
Type string
119+
Interface string
120+
ParamName string
121+
}
122+
83123
func main() {
84124
sourceDir := flag.String("sourceDir", "", "Path to the source directory for all go files")
85125
output := flag.String("output", "", "Path to the output file (default: stdout)")
@@ -140,6 +180,15 @@ func main() {
140180
fmt.Fprint(writer, "\n")
141181
}
142182
}
183+
184+
// Generate handle conversion helpers
185+
fmt.Fprint(writer, "\n")
186+
handleHelpers, err := generateHandleHelpers()
187+
if err != nil {
188+
fmt.Printf("Error: %v", err)
189+
return
190+
}
191+
fmt.Fprint(writer, handleHelpers)
143192
}
144193

145194
func getWriter(outputFile string) (io.Writer, func() error, error) {
@@ -177,6 +226,11 @@ func generateHeader() (string, error) {
177226
"",
178227
"package nvml",
179228
"",
229+
"import (",
230+
" \"fmt\"",
231+
" \"reflect\"",
232+
")",
233+
"",
180234
"",
181235
}
182236
return strings.Join(lines, "\n"), nil
@@ -418,3 +472,37 @@ func isPublic(name string) bool {
418472
}
419473
return unicode.IsUpper([]rune(name)[0])
420474
}
475+
476+
func generateHandleHelpers() (string, error) {
477+
// Parse the template
478+
tmpl, err := template.New("handleHelper").Parse(handleHelperTemplate)
479+
if err != nil {
480+
return "", fmt.Errorf("failed to parse handle helper template: %v", err)
481+
}
482+
483+
var builder strings.Builder
484+
485+
// Generate helper for each type (only if Type starts with 'nvml')
486+
for _, p := range GeneratableInterfaces {
487+
if !strings.HasPrefix(p.Type, "nvml") {
488+
continue
489+
}
490+
491+
// Create template data
492+
data := HandleHelperTemplateData{
493+
Type: p.Type,
494+
Interface: p.Interface,
495+
ParamName: strings.ToLower(p.Interface[0:1]) + p.Interface[1:],
496+
}
497+
498+
// Execute template
499+
if err := tmpl.Execute(&builder, data); err != nil {
500+
return "", fmt.Errorf("failed to execute handle helper template for %s: %v", p.Type, err)
501+
}
502+
builder.WriteString("\n")
503+
}
504+
505+
return builder.String(), nil
506+
}
507+
508+

pkg/nvml/device.go

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -15,48 +15,9 @@
1515
package nvml
1616

1717
import (
18-
"fmt"
19-
"reflect"
2018
"unsafe"
2119
)
2220

23-
// nvmlDeviceHandle attempts to convert a device d to an nvmlDevice.
24-
// This is required for functions such as GetTopologyCommonAncestor which
25-
// accept Device arguments that need to be passed to internal nvml* functions
26-
// as nvmlDevice parameters.
27-
func nvmlDeviceHandle(d Device) nvmlDevice {
28-
var helper func(val reflect.Value) nvmlDevice
29-
helper = func(val reflect.Value) nvmlDevice {
30-
if val.Kind() == reflect.Interface {
31-
val = val.Elem()
32-
}
33-
34-
if val.Kind() == reflect.Ptr {
35-
val = val.Elem()
36-
}
37-
38-
if val.Type() == reflect.TypeOf(nvmlDevice{}) {
39-
return val.Interface().(nvmlDevice)
40-
}
41-
42-
if val.Kind() != reflect.Struct {
43-
panic(fmt.Errorf("unable to convert non-struct type %v to nvmlDevice", val.Kind()))
44-
}
45-
46-
for i := 0; i < val.Type().NumField(); i++ {
47-
if !val.Type().Field(i).Anonymous {
48-
continue
49-
}
50-
if !val.Field(i).Type().Implements(reflect.TypeOf((*Device)(nil)).Elem()) {
51-
continue
52-
}
53-
return helper(val.Field(i))
54-
}
55-
panic(fmt.Errorf("unable to convert %T to nvmlDevice", d))
56-
}
57-
return helper(reflect.ValueOf(d))
58-
}
59-
6021
// EccBitType
6122
type EccBitType = MemoryErrorType
6223

0 commit comments

Comments
 (0)