diff --git a/openapi/code/entity.go b/openapi/code/entity.go index a61917b9d..56013244d 100644 --- a/openapi/code/entity.go +++ b/openapi/code/entity.go @@ -74,6 +74,7 @@ type Entity struct { IsString bool IsByteStream bool IsEmpty bool + Const any // this field does not have a concrete type IsAny bool @@ -87,6 +88,9 @@ type Entity struct { // Schema references the OpenAPI schema this entity was created from. Schema *openapi.Schema + + ChildTypes ChildTypes + AbstractType *Entity } // Whether the Entity contains a basic GoLang type which is not required diff --git a/openapi/code/load.go b/openapi/code/load.go index 75af8d5af..617eafde6 100644 --- a/openapi/code/load.go +++ b/openapi/code/load.go @@ -30,9 +30,14 @@ func NewFromFile(ctx context.Context, name string) (*Batch, error) { // NewFromSpec converts OpenAPI spec to intermediate representation func NewFromSpec(ctx context.Context, spec *openapi.Specification) (*Batch, error) { - batch := Batch{ + batch := &Batch{ packages: map[string]*Package{}, } + poly := newPolymorphism(spec.Components) + err := poly.Load() + if err != nil { + return nil, fmt.Errorf("polymorphic types: %w", err) + } for _, tag := range spec.Tags { pkg, ok := batch.packages[tag.Package] if !ok { @@ -41,6 +46,7 @@ func NewFromSpec(ctx context.Context, spec *openapi.Specification) (*Batch, erro Components: spec.Components, services: map[string]*Service{}, types: map[string]*Entity{}, + poly: poly, } batch.packages[tag.Package] = pkg } @@ -49,6 +55,10 @@ func NewFromSpec(ctx context.Context, spec *openapi.Specification) (*Batch, erro return nil, fmt.Errorf("fail to load %s: %w", tag.Name, err) } } + err = poly.Link(batch) + if err != nil { + return nil, fmt.Errorf("link: %w", err) + } // add some packages at least some description for _, pkg := range batch.packages { if len(pkg.services) > 1 { @@ -59,7 +69,7 @@ func NewFromSpec(ctx context.Context, spec *openapi.Specification) (*Batch, erro pkg.Description = svc.Summary() } } - return &batch, nil + return batch, nil } func (b *Batch) FullName() string { diff --git a/openapi/code/package.go b/openapi/code/package.go index 0b603bcec..b51d5f1a2 100644 --- a/openapi/code/package.go +++ b/openapi/code/package.go @@ -24,6 +24,7 @@ type Package struct { types map[string]*Entity emptyTypes []*Named extImports map[string]*Entity + poly *polymorphism } // FullName just returns pacakge name @@ -52,12 +53,28 @@ func (pkg *Package) MainService() *Service { } // Types returns sorted slice of types -func (pkg *Package) Types() (types []*Entity) { +func (pkg *Package) Types() (out []*Entity) { + types := []*Entity{} for _, v := range pkg.types { types = append(types, v) } pascalNameSort(types) - return types + // Python doesn't support forward-references for base classes, + // so that's why we have to pull abstract types first. + // topological sort is not required, as Databricks doesn't have (yet) + // complicated type hierarchy with oneOf/anyOf references. toposort + // could easily be added here later. + for _, v := range types { + if v.ChildTypes != nil { + out = append(out, v) + } + } + for _, v := range types { + if v.ChildTypes == nil { + out = append(out, v) + } + } + return out } // EmptyTypes returns sorted list of types without fields @@ -99,6 +116,14 @@ func (pkg *Package) ImportedPackages() (res []string) { } func (pkg *Package) schemaToEntity(s *openapi.Schema, path []string, hasName bool) *Entity { + if s.IsOneOf() || s.IsAnyOf() { + entity, err := pkg.poly.Resolve(pkg.Name, path[0]) + if err != nil { + err = fmt.Errorf("poly: %w", err) + panic(err) + } + return pkg.define(entity) + } if s.IsRef() { pair := strings.Split(s.Component(), ".") if len(pair) == 2 && pair[0] != pkg.Name { @@ -172,7 +197,8 @@ func (pkg *Package) schemaToEntity(s *openapi.Schema, path []string, hasName boo e.IsString = s.Type == "string" e.IsInt64 = s.Type == "integer" && s.Format == "int64" e.IsFloat64 = s.Type == "number" && s.Format == "double" - e.IsInt = s.Type == "integer" || s.Type == "int" + e.IsInt = s.Type == "integer" || s.Type == "int" || s.Type == "number" + e.Const = s.Const return e } diff --git a/openapi/code/polymorphism.go b/openapi/code/polymorphism.go new file mode 100644 index 000000000..696dbd404 --- /dev/null +++ b/openapi/code/polymorphism.go @@ -0,0 +1,279 @@ +package code + +import ( + "fmt" + "sort" + "strings" + + "github.com/databricks/databricks-sdk-go/openapi" +) + +func newPolymorphism(components *openapi.Components) *polymorphism { + return &polymorphism{ + components: components, + types: map[string]ChildTypes{}, + anyOf: map[string][]string{}, + } +} + +type polymorphism struct { + components *openapi.Components + types map[string]ChildTypes + anyOf map[string][]string +} + +type TypeLookup []*Field + +func (d TypeLookup) IsConstant() bool { + for _, v := range d { + if v.Schema.Const != nil { + return true + } + } + return false +} + +func (d TypeLookup) Sort() { + sort.Slice(d, func(i, j int) bool { + return d[i].Name < d[j].Name + }) +} + +func (d TypeLookup) String() string { + var tmp []string + for _, v := range d { + tmp = append(tmp, fmt.Sprintf("%s=%v", v.Name, v.Schema.Const)) + } + return strings.Join(tmp, ", ") +} + +type TypeDiscriminator struct { + *Entity + TypeLookup +} + +type ChildTypes []*TypeDiscriminator + +func (d ChildTypes) TypeLookup() TypeLookup { + if len(d) == 0 { + return nil + } + return d[0].TypeLookup +} + +func (d ChildTypes) IsConstant() bool { + for _, v := range d { + if v.IsConstant() { + return true + } + } + return false +} + +func (d ChildTypes) Sort() { + if d.IsConstant() { + sort.Slice(d, func(i, j int) bool { + return d[i].String() < d[j].String() + }) + return + } + sort.Slice(d, func(i, j int) bool { + return len(d[i].TypeLookup) < len(d[j].TypeLookup) + }) +} + +func (p *polymorphism) unresolvedOneOf() (map[string]ChildTypes, error) { + out := map[string]ChildTypes{} + for typeName, typeSchema := range p.components.Schemas { + if typeSchema == nil { + continue + } + s := *typeSchema + if len(s.OneOf) == 0 { + continue + } + if s.Discriminator != nil && len(s.DiscriminatorProperties) > 0 { + return nil, fmt.Errorf("both x-databricks-discriminator-properties and discriminator there: %s", typeName) + } + if s.Discriminator != nil { + s.DiscriminatorProperties = append(s.DiscriminatorProperties, s.Discriminator.PropertyName) + } + if len(s.DiscriminatorProperties) == 0 { + return nil, fmt.Errorf("missing discriminators: %s", typeName) + } + children := ChildTypes{} + for _, oneOf := range s.OneOf { + if oneOf.Ref == "" { + return nil, fmt.Errorf("can point only to a type: %s", typeName) + } + otherType := p.components.Schemas.Resolve(&openapi.Schema{Node: oneOf}) + if otherType == nil { + return nil, fmt.Errorf("not found: %s", oneOf.Ref) + } + lookup := TypeLookup{} + for _, propertyName := range s.DiscriminatorProperties { + unresolved, ok := (*otherType).Properties[propertyName] + if !ok { + return nil, fmt.Errorf("%s has no %s", oneOf.Ref, propertyName) + } + propertySchema := p.components.Schemas.Resolve(unresolved) + if propertySchema == nil { + return nil, fmt.Errorf("cannot resolve property: %s", propertyName) + } + lookup = append(lookup, &Field{ + Schema: *propertySchema, + Named: Named{ + Name: propertyName, + }, + }) + } + lookup.Sort() + otherPackage, otherName, ok := strings.Cut(oneOf.Component(), ".") + if !ok { + return nil, fmt.Errorf("no package: %s", oneOf.Ref) + } + children = append(children, &TypeDiscriminator{ + Entity: &Entity{ + Named: Named{ + Name: otherName, + }, + Package: &Package{ + Named: Named{ + Name: otherPackage, + }, + }, + }, + TypeLookup: lookup, + }) + } + children.Sort() + out[typeName] = children + } + return out, nil +} + +func (p *polymorphism) unresolvedAnyOf() (map[string][]string, error) { + out := map[string][]string{} + for typeName, typeSchema := range p.components.Schemas { + if typeSchema == nil { + continue + } + s := *typeSchema + if len(s.AnyOf) == 0 { + continue + } + for _, anyOf := range s.AnyOf { + if anyOf.Ref == "" { + return nil, fmt.Errorf("can point only to a type: %s", typeName) + } + otherName := anyOf.Component() + out[typeName] = append(out[typeName], otherName) + } + } + return out, nil +} + +func (p *polymorphism) Load() error { + types, err := p.unresolvedOneOf() + if err != nil { + return fmt.Errorf("oneOf: %w", err) + } + anyOf, err := p.unresolvedAnyOf() + if err != nil { + return fmt.Errorf("anyOf: %w", err) + } + p.types = types + p.anyOf = anyOf + return nil +} + +func (p *polymorphism) Link(batch *Batch) error { + for _, pkg := range batch.packages { + for _, abstractType := range pkg.types { + if abstractType.ChildTypes == nil { + continue + } + (*abstractType).Package = pkg + for _, subType := range abstractType.ChildTypes { + linkedPackage, ok := batch.packages[subType.Package.Name] + if !ok { + return fmt.Errorf("missing package: %s", subType.FullName()) + } + resolvedType, ok := linkedPackage.types[subType.Name] + if !ok { + return fmt.Errorf("missing type: %s", subType.FullName()) + } + (*subType).Entity = resolvedType + (*resolvedType).AbstractType = abstractType + if subType.TypeLookup == nil { + continue + } + for _, field := range subType.TypeLookup { + field.Entity = resolvedType.fields[field.Name].Entity + } + } + } + } + return nil +} + +func (p *polymorphism) Resolve(pkgName, typeName string) (*Entity, error) { + key := fmt.Sprintf("%s.%s", pkgName, typeName) + discriminators, ok := p.types[key] + if ok { + return p.resolveOneOf(typeName, discriminators) + } + assignable, ok := p.anyOf[key] + if !ok { + return nil, fmt.Errorf("not found: %s", key) + } + for _, anyOf := range assignable { + otherPackage, otherType, ok := strings.Cut(anyOf, ".") + if !ok { + return nil, fmt.Errorf("malformed: %s", anyOf) + } + discriminators = append(discriminators, &TypeDiscriminator{ + Entity: &Entity{ + Named: Named{ + Name: otherType, + }, + Package: &Package{ + Named: Named{ + Name: otherPackage, + }, + }, + }, + }) + } + return &Entity{ + Named: Named{ + Name: typeName, + }, + ChildTypes: discriminators, + fields: map[string]*Field{ + // dummy field so that it's not filtered out + "_": { + IsJson: true, + Entity: &Entity{ + IsString: true, + Const: "_", + }, + }, + }, + }, nil +} + +func (p *polymorphism) resolveOneOf(typeName string, discriminators ChildTypes) (*Entity, error) { + fields := map[string]*Field{} + for _, v := range discriminators[0].TypeLookup { + fields[v.Name] = v + v.IsJson = true + } + return &Entity{ + Named: Named{ + Name: typeName, + }, + ChildTypes: discriminators, + fields: fields, + }, nil +} diff --git a/openapi/code/tmpl_util_funcs.go b/openapi/code/tmpl_util_funcs.go index 551d075ba..eed8a0009 100644 --- a/openapi/code/tmpl_util_funcs.go +++ b/openapi/code/tmpl_util_funcs.go @@ -49,6 +49,24 @@ var HelperFuncs = template.FuncMap{ } return out }, + "noConst": func(in []*Field) (out []*Field) { + for _, v := range in { + if v.Entity.Const != nil { + continue + } + out = append(out, v) + } + return out + }, + "constOnly": func(in []*Field) (out []*Field) { + for _, v := range in { + if v.Entity.Const == nil { + continue + } + out = append(out, v) + } + return out + }, "list": func(l ...any) []any { return l }, diff --git a/openapi/generator/x_test.go b/openapi/generator/x_test.go new file mode 100644 index 000000000..63f79fea4 --- /dev/null +++ b/openapi/generator/x_test.go @@ -0,0 +1,34 @@ +package generator + +import ( + "context" + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/databricks/databricks-sdk-go/openapi/code" + "github.com/stretchr/testify/require" +) + +func run() error { + ctx := context.Background() + home, _ := os.UserHomeDir() + spec, err := code.NewFromFile(ctx, filepath.Join(home, + "Downloads/RenderWidgetSpec.openapi.json")) + if err != nil { + return fmt.Errorf("spec: %w", err) + } + gen, err := NewGenerator(filepath.Join(home, + "git/labs/ucx/src/databricks/labs/ucx/framework/lakeview")) + if err != nil { + return fmt.Errorf("config: %w", err) + } + return gen.Apply(ctx, spec, nil) +} + +func TestX(t *testing.T) { + t.SkipNow() + err := run() + require.NoError(t, err) +} diff --git a/openapi/model.go b/openapi/model.go index 18709af5e..c62830619 100644 --- a/openapi/model.go +++ b/openapi/model.go @@ -236,12 +236,30 @@ type Schema struct { Properties map[string]*Schema `json:"properties,omitempty"` ArrayValue *Schema `json:"items,omitempty"` MapValue *Schema `json:"additionalProperties,omitempty"` + + Const any `json:"const,omitempty"` + Discriminator *Discriminator `json:"discriminator,omitempty"` + DiscriminatorProperties []string `json:"x-databricks-discriminator-properties"` + OneOf []Node `json:"oneOf,omitempty"` + AnyOf []Node `json:"anyOf,omitempty"` +} + +type Discriminator struct { + PropertyName string `json:"propertyName"` } func (s *Schema) IsEnum() bool { return len(s.Enum) != 0 } +func (s *Schema) IsOneOf() bool { + return len(s.OneOf) > 0 +} + +func (s *Schema) IsAnyOf() bool { + return len(s.AnyOf) > 0 +} + func (s *Schema) IsObject() bool { return len(s.Properties) != 0 } @@ -267,6 +285,12 @@ func (s *Schema) IsEmpty() bool { if s.IsArray() { return false } + if s.IsOneOf() { + return false + } + if s.IsAnyOf() { + return false + } if s.IsObject() { return false }