6
6
from contextlib import suppress
7
7
from collections import OrderedDict
8
8
from decimal import Decimal
9
- from inspect import signature as inspect_signature
10
9
11
10
import typing
12
11
from django .core import validators
29
28
import pkg_resources
30
29
drf_version = pkg_resources .get_distribution ("djangorestframework" ).version
31
30
31
+ try :
32
+ from types import NoneType , UnionType
33
+
34
+ UNION_TYPES = (typing .Union , UnionType )
35
+ except ImportError : # Python < 3.10
36
+ NoneType = type (None )
37
+ UNION_TYPES = (typing .Union ,)
38
+
32
39
logger = logging .getLogger (__name__ )
33
40
34
41
@@ -480,15 +487,6 @@ def decimal_return_type():
480
487
return openapi .TYPE_STRING if rest_framework_settings .COERCE_DECIMAL_TO_STRING else openapi .TYPE_NUMBER
481
488
482
489
483
- def get_origin_type (hint_class ):
484
- return getattr (hint_class , '__origin__' , None ) or hint_class
485
-
486
-
487
- def hint_class_issubclass (hint_class , check_class ):
488
- origin_type = get_origin_type (hint_class )
489
- return inspect .isclass (origin_type ) and issubclass (origin_type , check_class )
490
-
491
-
492
490
hinting_type_info = [
493
491
(bool , (openapi .TYPE_BOOLEAN , None )),
494
492
(int , (openapi .TYPE_INTEGER , None )),
@@ -505,11 +503,15 @@ def hint_class_issubclass(hint_class, check_class):
505
503
if hasattr (typing , 'get_args' ):
506
504
# python >=3.8
507
505
typing_get_args = typing .get_args
506
+ typing_get_origin = typing .get_origin
508
507
else :
509
508
# python <3.8
510
509
def typing_get_args (tp ):
511
510
return getattr (tp , '__args__' , ())
512
511
512
+ def typing_get_origin (tp ):
513
+ return getattr (tp , '__origin__' , None )
514
+
513
515
514
516
def inspect_collection_hint_class (hint_class ):
515
517
args = typing_get_args (hint_class )
@@ -525,12 +527,6 @@ def inspect_collection_hint_class(hint_class):
525
527
hinting_type_info .append (((typing .Sequence , typing .AbstractSet ), inspect_collection_hint_class ))
526
528
527
529
528
- def _get_union_types (hint_class ):
529
- origin_type = get_origin_type (hint_class )
530
- if origin_type is typing .Union :
531
- return hint_class .__args__
532
-
533
-
534
530
def get_basic_type_info_from_hint (hint_class ):
535
531
"""Given a class (eg from a SerializerMethodField's return type hint,
536
532
return its basic type information - ``type``, ``format``, ``pattern``,
@@ -540,21 +536,28 @@ def get_basic_type_info_from_hint(hint_class):
540
536
:return: the extracted attributes as a dictionary, or ``None`` if the field type is not known
541
537
:rtype: OrderedDict
542
538
"""
543
- union_types = _get_union_types (hint_class )
544
539
545
- if union_types :
540
+ if typing_get_origin ( hint_class ) in UNION_TYPES :
546
541
# Optional is implemented as Union[T, None]
547
- if len (union_types ) == 2 and isinstance (None , union_types [1 ]):
548
- result = get_basic_type_info_from_hint (union_types [0 ])
542
+ filtered_types = [t for t in typing_get_args (hint_class ) if t is not NoneType ]
543
+ if len (filtered_types ) == 1 :
544
+ result = get_basic_type_info_from_hint (filtered_types [0 ])
549
545
if result :
550
546
result ['x-nullable' ] = True
551
547
552
548
return result
553
549
554
550
return None
555
551
552
+ # resolve the origin class if the class is generic
553
+ resolved_class = typing_get_origin (hint_class ) or hint_class
554
+
555
+ # bail out early
556
+ if not inspect .isclass (resolved_class ):
557
+ return None
558
+
556
559
for check_class , info in hinting_type_info :
557
- if hint_class_issubclass ( hint_class , check_class ):
560
+ if issubclass ( resolved_class , check_class ):
558
561
if callable (info ):
559
562
return info (hint_class )
560
563
@@ -617,17 +620,19 @@ def field_to_swagger_object(self, field, swagger_object_type, use_references, **
617
620
return self .probe_field_inspectors (serializer , swagger_object_type , use_references , read_only = True )
618
621
else :
619
622
# look for Python 3.5+ style type hinting of the return value
620
- hint_class = inspect_signature (method ).return_annotation
621
-
622
- if not inspect .isclass (hint_class ) and hasattr (hint_class , '__args__' ):
623
- hint_class = hint_class .__args__ [0 ]
624
- if inspect .isclass (hint_class ) and not issubclass (hint_class , inspect ._empty ):
625
- type_info = get_basic_type_info_from_hint (hint_class )
626
-
627
- if type_info is not None :
628
- SwaggerType , ChildSwaggerType = self ._get_partial_types (field , swagger_object_type ,
629
- use_references , ** kwargs )
630
- return SwaggerType (** type_info )
623
+ hint_class = typing .get_type_hints (method ).get ('return' )
624
+
625
+ # annotations such as typing.Optional have an __instancecheck__
626
+ # hook and will not look like classes, but `issubclass` needs
627
+ # a class as its first argument, so only in that case abort
628
+ if inspect .isclass (hint_class ) and issubclass (hint_class , inspect ._empty ):
629
+ return NotHandled
630
+
631
+ type_info = get_basic_type_info_from_hint (hint_class )
632
+ if type_info is not None :
633
+ SwaggerType , ChildSwaggerType = self ._get_partial_types (field , swagger_object_type ,
634
+ use_references , ** kwargs )
635
+ return SwaggerType (** type_info )
631
636
632
637
return NotHandled
633
638
0 commit comments