diff --git a/cmd/protoc-gen-openapi/generator/generator.go b/cmd/protoc-gen-openapi/generator/generator.go index e548ab21..7d80d409 100644 --- a/cmd/protoc-gen-openapi/generator/generator.go +++ b/cmd/protoc-gen-openapi/generator/generator.go @@ -441,14 +441,15 @@ func (g *OpenAPIv3Generator) buildOperationV3( description string, defaultHost string, path string, - bodyField string, + reqBodyField string, + resBodyField string, inputMessage *protogen.Message, outputMessage *protogen.Message, ) (*v3.Operation, string) { // coveredParameters tracks the parameters that have been used in the body or path. coveredParameters := make([]string, 0) - if bodyField != "" { - coveredParameters = append(coveredParameters, bodyField) + if reqBodyField != "" { + coveredParameters = append(coveredParameters, reqBodyField) } // Initialize the list of operation parameters. parameters := []*v3.ParameterOrReference{} @@ -542,10 +543,10 @@ func (g *OpenAPIv3Generator) buildOperationV3( } // Add any unhandled fields in the request message as query parameters. - if bodyField != "*" && string(inputMessage.Desc.FullName()) != "google.api.HttpBody" { + if reqBodyField != "*" && string(inputMessage.Desc.FullName()) != "google.api.HttpBody" { for _, field := range inputMessage.Fields { fieldName := string(field.Desc.Name()) - if !contains(coveredParameters, fieldName) && fieldName != bodyField { + if !contains(coveredParameters, fieldName) && fieldName != reqBodyField { fieldParams := g.buildQueryParamsV3(field) parameters = append(parameters, fieldParams...) } @@ -553,6 +554,16 @@ func (g *OpenAPIv3Generator) buildOperationV3( } // Create the response. + if resBodyField != "" && resBodyField != "*" { + // If body refers to a message field, use that type. + for _, field := range outputMessage.Fields { + if string(field.Desc.Name()) == resBodyField { + outputMessage = field.Message + break + } + } + } + name, content := g.reflect.responseContentForMessage(outputMessage.Desc) responses := &v3.Responses{ ResponseOrReference: []*v3.NamedResponseOrReference{ @@ -615,17 +626,17 @@ func (g *OpenAPIv3Generator) buildOperationV3( } // If a body field is specified, we need to pass a message as the request body. - if bodyField != "" { + if reqBodyField != "" { var requestSchema *v3.SchemaOrReference - if bodyField == "*" { + if reqBodyField == "*" { // Pass the entire request message as the request body. requestSchema = g.reflect.schemaOrReferenceForMessage(inputMessage.Desc) } else { // If body refers to a message field, use that type. for _, field := range inputMessage.Fields { - if string(field.Desc.Name()) == bodyField { + if string(field.Desc.Name()) == reqBodyField { switch field.Desc.Kind() { case protoreflect.StringKind: requestSchema = &v3.SchemaOrReference{ @@ -722,9 +733,12 @@ func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*pr for _, rule := range rules { var path string var methodName string - var body string + var reqBody string + var resBody string + + reqBody = rule.Body + resBody = rule.ResponseBody - body = rule.Body switch pattern := rule.Pattern.(type) { case *annotations.HttpRule_Get: path = pattern.Get @@ -751,7 +765,7 @@ func (g *OpenAPIv3Generator) addPathsToDocumentV3(d *v3.Document, services []*pr defaultHost := proto.GetExtension(service.Desc.Options(), annotations.E_DefaultHost).(string) op, path2 := g.buildOperationV3( - d, operationID, service.GoName, comment, defaultHost, path, body, inputMessage, outputMessage) + d, operationID, service.GoName, comment, defaultHost, path, reqBody, resBody, inputMessage, outputMessage) // Merge any `Operation` annotations with the current extOperation := proto.GetExtension(method.Desc.Options(), v3.E_Operation)