diff --git a/docs/changelog.md b/docs/changelog.md index 535007d..3de5cdd 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -5,6 +5,8 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [UNRELEASED] +### Fixed +- generate the mapping for discriminator fields properly instead of showing a "null" value in the generated schema (#12). ## [0.12.0] - 2022-08-27 ### Added diff --git a/drf_standardized_errors/openapi.py b/drf_standardized_errors/openapi.py index cd887a0..135bce9 100644 --- a/drf_standardized_errors/openapi.py +++ b/drf_standardized_errors/openapi.py @@ -19,6 +19,7 @@ from .handler import exception_handler as standardized_errors_handler from .openapi_serializers import ( + ClientErrorEnum, ErrorResponse401Serializer, ErrorResponse403Serializer, ErrorResponse404Serializer, @@ -28,6 +29,7 @@ ErrorResponse429Serializer, ErrorResponse500Serializer, ParseErrorResponseSerializer, + ValidationErrorEnum, ) from .openapi_utils import ( InputDataField, @@ -253,12 +255,13 @@ def _get_http400_serializer(self): operation_id = self.get_operation_id() component_name = f"{camelize(operation_id)}ErrorResponse400" - http400_serializers = [] + http400_serializers = {} if self._should_add_validation_error_response(): serializer = self._get_serializer_for_validation_error_response() - http400_serializers.append(serializer) + http400_serializers[ValidationErrorEnum.VALIDATION_ERROR.value] = serializer if self._should_add_parse_error_response(): - http400_serializers.append(ParseErrorResponseSerializer) + serializer = ParseErrorResponseSerializer + http400_serializers[ClientErrorEnum.CLIENT_ERROR.value] = serializer return PolymorphicProxySerializer( component_name=component_name, diff --git a/drf_standardized_errors/openapi_utils.py b/drf_standardized_errors/openapi_utils.py index 2cf9303..476660a 100644 --- a/drf_standardized_errors/openapi_utils.py +++ b/drf_standardized_errors/openapi_utils.py @@ -410,10 +410,10 @@ def get_validation_error_serializer( ): validation_error_component_name = f"{camelize(operation_id)}ValidationError" errors_component_name = f"{camelize(operation_id)}Error" - sub_serializers = [ - get_error_serializer(operation_id, sfield.name, sfield.error_codes) + sub_serializers = { + sfield.name: get_error_serializer(operation_id, sfield.name, sfield.error_codes) for sfield in data_fields - ] + } class ValidationErrorSerializer(serializers.Serializer): type = serializers.ChoiceField(choices=ValidationErrorEnum.choices) diff --git a/tests/test_openapi.py b/tests/test_openapi.py index 2748e76..d61116d 100644 --- a/tests/test_openapi.py +++ b/tests/test_openapi.py @@ -84,6 +84,32 @@ def test_validation_error_for_unsafe_method(): assert "400" in responses +def test_discriminator_mapping_for_validation_serializer(): + route = "validate/" + view = ValidationView.as_view() + schema = generate_view_schema(route, view) + + discriminator = schema["components"]["schemas"]["ValidateCreateError"][ + "discriminator" + ] + assert discriminator["propertyName"] == "attr" + mapping_fields = set(discriminator["mapping"]) + assert mapping_fields == {"non_field_errors", "first_name"} + + +def test_discriminator_mapping_for_http400_serializer(): + route = "validate/" + view = ValidationView.as_view(parser_classes=[JSONParser]) + schema = generate_view_schema(route, view) + + discriminator = schema["components"]["schemas"]["ValidateCreateErrorResponse400"][ + "discriminator" + ] + assert discriminator["propertyName"] == "type" + mapping_fields = set(discriminator["mapping"]) + assert mapping_fields == {"validation_error", "client_error"} + + def test_no_validation_error_for_unsafe_method(): route = "validate/" view = ValidationView.as_view(serializer_class=None)