diff --git a/packages/pyright-internal/src/analyzer/enums.ts b/packages/pyright-internal/src/analyzer/enums.ts index 83e3e259cb..1309c9bac3 100644 --- a/packages/pyright-internal/src/analyzer/enums.ts +++ b/packages/pyright-internal/src/analyzer/enums.ts @@ -291,7 +291,7 @@ export function transformTypeForPossibleEnumClass( evaluator: TypeEvaluator, statementNode: ParseNode, nameNode: NameNode, - getValueType: () => Type + getValueType: () => { declaredType?: Type; assignedType?: Type } ): Type | undefined { // If the node is within a class that derives from the metaclass // "EnumMeta", we need to treat assignments differently. @@ -343,9 +343,11 @@ export function transformTypeForPossibleEnumClass( return undefined; } - let valueType: Type; + const valueTypeInfo = getValueType(); + const declaredType = valueTypeInfo.declaredType; + let assignedType = valueTypeInfo.assignedType; - valueType = getValueType(); + let valueType = declaredType ?? assignedType ?? UnknownType.create(); // If the LHS is an unpacked tuple, we need to handle this as // a special case. @@ -370,18 +372,22 @@ export function transformTypeForPossibleEnumClass( return undefined; } + if (!assignedType && statementNode.nodeType === ParseNodeType.Assignment) { + assignedType = evaluator.getTypeOfExpression(statementNode.rightExpression).type; + } + // Handle the Python 3.11 "enum.member()" and "enum.nonmember()" features. - if (isClassInstance(valueType) && ClassType.isBuiltIn(valueType)) { - if (valueType.details.fullName === 'enum.nonmember') { - return valueType.typeArguments && valueType.typeArguments.length > 0 - ? valueType.typeArguments[0] + if (assignedType && isClassInstance(assignedType) && ClassType.isBuiltIn(assignedType)) { + if (assignedType.details.fullName === 'enum.nonmember') { + return assignedType.typeArguments && assignedType.typeArguments.length > 0 + ? assignedType.typeArguments[0] : UnknownType.create(); } - if (valueType.details.fullName === 'enum.member') { + if (assignedType.details.fullName === 'enum.member') { valueType = - valueType.typeArguments && valueType.typeArguments.length > 0 - ? valueType.typeArguments[0] + assignedType.typeArguments && assignedType.typeArguments.length > 0 + ? assignedType.typeArguments[0] : UnknownType.create(); isMemberOfEnumeration = true; } diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index d5b0836b83..5f250f67a4 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -15656,14 +15656,22 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // If this is an enum, transform the type as required. rightHandType = srcType; - if (node.leftExpression.nodeType === ParseNodeType.Name && !node.typeAnnotationComment) { + + let targetName: NameNode | undefined; + if (node.leftExpression.nodeType === ParseNodeType.Name) { + targetName = node.leftExpression; + } else if ( + node.leftExpression.nodeType === ParseNodeType.TypeAnnotation && + node.leftExpression.valueExpression.nodeType === ParseNodeType.Name + ) { + targetName = node.leftExpression.valueExpression; + } + + if (targetName) { rightHandType = - transformTypeForPossibleEnumClass( - evaluatorInterface, - node, - node.leftExpression, - () => rightHandType! - ) ?? rightHandType; + transformTypeForPossibleEnumClass(evaluatorInterface, node, targetName, () => { + return { assignedType: rightHandType }; + }) ?? rightHandType; } if (typeAliasNameNode) { @@ -17176,8 +17184,9 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions // In case this is an enum class and a method wrapped in an enum.member. decoratedType = - transformTypeForPossibleEnumClass(evaluatorInterface, node, node.name, () => decoratedType!) ?? - decoratedType; + transformTypeForPossibleEnumClass(evaluatorInterface, node, node.name, () => { + return { assignedType: decoratedType }; + }) ?? decoratedType; // See if there are any overloads provided by previous function declarations. if (isFunction(decoratedType)) { @@ -20529,7 +20538,9 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions evaluatorInterface, variableNode, declaration.node, - () => declaredType! + () => { + return { declaredType }; + } ) ?? declaredType; } } @@ -20856,11 +20867,12 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions () => { assert(resolvedDecl.inferredTypeSource !== undefined); const inferredTypeSource = resolvedDecl.inferredTypeSource; - return ( - evaluateTypeForSubnode(inferredTypeSource, () => { - evaluateTypesForStatement(inferredTypeSource); - })?.type ?? UnknownType.create() - ); + return { + assignedType: + evaluateTypeForSubnode(inferredTypeSource, () => { + evaluateTypesForStatement(inferredTypeSource); + })?.type ?? UnknownType.create(), + }; } ); diff --git a/packages/pyright-internal/src/tests/samples/enum9.py b/packages/pyright-internal/src/tests/samples/enum9.py index 08a5c4a740..dd6c3d29b6 100644 --- a/packages/pyright-internal/src/tests/samples/enum9.py +++ b/packages/pyright-internal/src/tests/samples/enum9.py @@ -1,27 +1,38 @@ # This sample tests the enum.member and enum.nonmember classes introduced # in Python 3.11. -import enum +from enum import Enum, member, nonmember from typing import Literal -class E(enum.Enum): +class Enum1(Enum): MEMBER = 1 - ANOTHER_MEMBER = enum.member(2) - NON_MEMBER = enum.nonmember(3) + ANOTHER_MEMBER = member(2) + NON_MEMBER = nonmember(3) - @enum.member + @member @staticmethod def ALSO_A_MEMBER() -> Literal[4]: return 4 -reveal_type(E.MEMBER, expected_text="Literal[E.MEMBER]") -reveal_type(E.ANOTHER_MEMBER, expected_text="Literal[E.ANOTHER_MEMBER]") -reveal_type(E.ALSO_A_MEMBER, expected_text="Literal[E.ALSO_A_MEMBER]") -reveal_type(E.NON_MEMBER, expected_text="int") +reveal_type(Enum1.MEMBER, expected_text="Literal[Enum1.MEMBER]") +reveal_type(Enum1.ANOTHER_MEMBER, expected_text="Literal[Enum1.ANOTHER_MEMBER]") +reveal_type(Enum1.ALSO_A_MEMBER, expected_text="Literal[Enum1.ALSO_A_MEMBER]") +reveal_type(Enum1.NON_MEMBER, expected_text="int") -reveal_type(E.MEMBER.value, expected_text="Literal[1]") -reveal_type(E.ANOTHER_MEMBER.value, expected_text="int") -reveal_type(E.ALSO_A_MEMBER.value, expected_text="() -> Literal[4]") +reveal_type(Enum1.MEMBER.value, expected_text="Literal[1]") +reveal_type(Enum1.ANOTHER_MEMBER.value, expected_text="int") +reveal_type(Enum1.ALSO_A_MEMBER.value, expected_text="() -> Literal[4]") + + +class Enum2(Enum): + MEMBER: int = member(1) + NON_MEMBER: int = nonmember(1) + + +reveal_type(Enum2.MEMBER, expected_text="Literal[Enum2.MEMBER]") +reveal_type(Enum2.NON_MEMBER, expected_text="int") + +reveal_type(Enum2.MEMBER.value, expected_text="int")