Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 25 additions & 11 deletions cmd/protoc-gen-openapi/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -441,14 +441,15 @@ func (g *OpenAPIv3Generator) buildOperationV3(
description string,
defaultHost string,
path string,
bodyField string,
reqBodyField string,
resBodyField string,
inputMessage *protogen.Message,
outputMessage *protogen.Message,
) (*v3.Operation, string) {
// coveredParameters tracks the parameters that have been used in the body or path.
coveredParameters := make([]string, 0)
if bodyField != "" {
coveredParameters = append(coveredParameters, bodyField)
if reqBodyField != "" {
coveredParameters = append(coveredParameters, reqBodyField)
}
// Initialize the list of operation parameters.
parameters := []*v3.ParameterOrReference{}
Expand Down Expand Up @@ -542,17 +543,27 @@ func (g *OpenAPIv3Generator) buildOperationV3(
}

// Add any unhandled fields in the request message as query parameters.
if bodyField != "*" && string(inputMessage.Desc.FullName()) != "google.api.HttpBody" {
if reqBodyField != "*" && string(inputMessage.Desc.FullName()) != "google.api.HttpBody" {
for _, field := range inputMessage.Fields {
fieldName := string(field.Desc.Name())
if !contains(coveredParameters, fieldName) && fieldName != bodyField {
if !contains(coveredParameters, fieldName) && fieldName != reqBodyField {
fieldParams := g.buildQueryParamsV3(field)
parameters = append(parameters, fieldParams...)
}
}
}

// Create the response.
if resBodyField != "" && resBodyField != "*" {
// If body refers to a message field, use that type.
for _, field := range outputMessage.Fields {
if string(field.Desc.Name()) == resBodyField {
outputMessage = field.Message
break
}
}
}

name, content := g.reflect.responseContentForMessage(outputMessage.Desc)
responses := &v3.Responses{
ResponseOrReference: []*v3.NamedResponseOrReference{
Expand Down Expand Up @@ -615,17 +626,17 @@ func (g *OpenAPIv3Generator) buildOperationV3(
}

// If a body field is specified, we need to pass a message as the request body.
if bodyField != "" {
if reqBodyField != "" {
var requestSchema *v3.SchemaOrReference

if bodyField == "*" {
if reqBodyField == "*" {
// Pass the entire request message as the request body.
requestSchema = g.reflect.schemaOrReferenceForMessage(inputMessage.Desc)

} else {
// If body refers to a message field, use that type.
for _, field := range inputMessage.Fields {
if string(field.Desc.Name()) == bodyField {
if string(field.Desc.Name()) == reqBodyField {
switch field.Desc.Kind() {
case protoreflect.StringKind:
requestSchema = &v3.SchemaOrReference{
Expand Down Expand Up @@ -722,9 +733,12 @@ func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*pr
for _, rule := range rules {
var path string
var methodName string
var body string
var reqBody string
var resBody string

reqBody = rule.Body
resBody = rule.ResponseBody

body = rule.Body
switch pattern := rule.Pattern.(type) {
case *annotations.HttpRule_Get:
path = pattern.Get
Expand All @@ -751,7 +765,7 @@ func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*pr
defaultHost := proto.GetExtension(service.Desc.Options(), annotations.E_DefaultHost).(string)

op, path2 := g.buildOperationV3(
d, operationID, service.GoName, comment, defaultHost, path, body, inputMessage, outputMessage)
d, operationID, service.GoName, comment, defaultHost, path, reqBody, resBody, inputMessage, outputMessage)

// Merge any `Operation` annotations with the current
extOperation := proto.GetExtension(method.Desc.Options(), v3.E_Operation)
Expand Down