diff --git a/flask_swagger_generator/generators/generator.py b/flask_swagger_generator/generators/generator.py index 5500fd6..9b427d2 100644 --- a/flask_swagger_generator/generators/generator.py +++ b/flask_swagger_generator/generators/generator.py @@ -107,6 +107,21 @@ def wrapper(*args, **kwargs): return wrapper return swagger_security + def path_tag(self, tag): + def swagger_path_tag(func): + + if not self.generated: + self.specifier.add_path_tag( + func.__name__, tag + ) + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + return swagger_path_tag + def create_schema(self, reference_name, properties): return self.specifier.create_schema(reference_name, properties) @@ -118,17 +133,16 @@ def index_endpoints(self, app): for rule in app.url_map.iter_rules(): + group = None + function_name = rule.endpoint if len(rule.endpoint.split(".")) > 1: group, function_name = rule.endpoint.split('.') - self.specifier.add_endpoint( + for path_tag in self.specifier.path_tags: + if path_tag.get("function_name") == function_name: + group = path_tag.get("tag") + self.specifier.add_endpoint( function_name=function_name, path=str(rule), request_types=rule.methods, group=group ) - else: - self.specifier.add_endpoint( - function_name=rule.endpoint, - path=str(rule), - request_types=rule.methods, - ) diff --git a/flask_swagger_generator/specifiers/swagger_specifier.py b/flask_swagger_generator/specifiers/swagger_specifier.py index 8b74bcc..c53704a 100644 --- a/flask_swagger_generator/specifiers/swagger_specifier.py +++ b/flask_swagger_generator/specifiers/swagger_specifier.py @@ -42,6 +42,10 @@ def add_request_body(self, function_name, schema): def add_security(self, function_name, security_type: SecurityType): raise NotImplementedError() + @abstractmethod + def add_path_tag(self, function_name, tag): + raise NotImplementedError() + def set_application_name(self, application_name): self.application_name = application_name diff --git a/flask_swagger_generator/specifiers/swagger_three_specifier.py b/flask_swagger_generator/specifiers/swagger_three_specifier.py index 9c7ebd0..b2c96b3 100644 --- a/flask_swagger_generator/specifiers/swagger_three_specifier.py +++ b/flask_swagger_generator/specifiers/swagger_three_specifier.py @@ -464,6 +464,7 @@ def __init__(self): self.schemas = [] self.responses = [] self.securities = [] + self.path_tags = [] def perform_write(self, file): # Add all request bodies to request_types with same function name @@ -678,6 +679,10 @@ def add_security(self, function_name, security_type: SecurityType): security_model = SwaggerSecurity([function_name], security_type) self.securities.append(security_model) + def add_path_tag(self, function_name: str, tag): + path_tag = {"function_name": function_name, "tag": tag} + self.path_tags.append(path_tag) + def create_schema(self, reference_name, properties): schema = SwaggerSchema(reference_name, properties) self.schemas.append(schema) diff --git a/tests/resources/reference_version_three.yaml b/tests/resources/reference_version_three.yaml index c27e10a..c845d50 100644 --- a/tests/resources/reference_version_three.yaml +++ b/tests/resources/reference_version_three.yaml @@ -1,8 +1,7 @@ openapi: 3.0.1 info: title: Application - description: Generated at 03/01/2021 20:29:35. This is the swagger - ui based on the open api 3.0 specification of the Application + description: Generated by Flask-Swagger-Generator version: 1.0.0 externalDocs: description: Find out more about Swagger @@ -75,6 +74,37 @@ paths: type: integer description: None required: True + '/objects02/{object_id}': + get: + tags: + - Object02 Endpoints + operationId: 'retrieve_object02' + responses: + '200': + $ref: '#/components/responses/retrieve_object02_response' + parameters: + - in: path + name: object_id + schema: + type: integer + description: None + required: True + post: + tags: + - Object02 Endpoints + operationId: 'create_object02' + requestBody: + $ref: '#/components/requestBodies/create_object02_request_body' + responses: + '201': + $ref: '#/components/responses/create_object02_response' + parameters: + - in: path + name: object_id + schema: + type: integer + description: None + required: True components: securitySchemes: bearerAuth: @@ -96,6 +126,13 @@ components: application/json: schema: $ref: '#/components/schemas/create_object_request_body_schema' + create_object02_request_body: + description: None + required: True + content: + application/json: + schema: + $ref: '#/components/schemas/schema_two' responses: retrieve_object_response: description: retrieve_object response @@ -121,6 +158,18 @@ components: application/json: schema: $ref: '#/components/schemas/create_object_response_schema' + retrieve_object02_response: + description: retrieve_object02 response + content: + application/json: + schema: + $ref: '#/components/schemas/schema_two' + create_object02_response: + description: create_object02 response + content: + application/json: + schema: + $ref: '#/components/schemas/schema_two' schemas: schema_two: type: object diff --git a/tests/resources/test_apis.py b/tests/resources/test_apis.py index 7f3e87f..74b731b 100644 --- a/tests/resources/test_apis.py +++ b/tests/resources/test_apis.py @@ -25,8 +25,8 @@ class ObjectDeserializer(Schema): attribute_six = fields.Nested(ObjectChildDeserializer(many=False)) -class TestVersionThreeAPI(): - +class APITestBase(): + def __init__(self): self.app = create_app() self.generator = Generator.of(SwaggerVersion.VERSION_THREE) @@ -34,6 +34,9 @@ def __init__(self): self.create_test_api() self.app.register_blueprint(self.blueprint) + +class TestVersionThreeAPI(APITestBase): + def create_test_api(self): generator = self.generator @@ -44,7 +47,6 @@ def create_test_api(self): ) schema_three = generator.create_schema('schema_three', ObjectDeserializer()) - @generator.response(200, schema_two) @blueprint.route('/objects/', methods=['GET']) def retrieve_object(object_id, child_id): @@ -70,4 +72,19 @@ def update_object(object_id): @generator.request_body({'id': 10, 'name': 'test_object'}) @blueprint.route('/objects/', methods=['POST']) def create_object(object_id): + return jsonify({'objects': []}), 200 + + + @generator.response(200, schema_two) + @generator.path_tag('Object02 Endpoints') + @blueprint.route('/objects02/', methods=['GET']) + def retrieve_object02(object_id, child_id): + return jsonify({'objects': []}), 200 + + + @generator.response(201, schema_two) + @generator.request_body(schema_two) + @generator.path_tag('Object02 Endpoints') + @blueprint.route('/objects02/', methods=['POST']) + def create_object02(object_id): return jsonify({'objects': []}), 200 \ No newline at end of file