From c66a85d854501b2f6be9e56d0e80f733ab15934b Mon Sep 17 00:00:00 2001 From: Dhruv Manilawala Date: Tue, 10 Sep 2024 18:06:08 -0400 Subject: [PATCH] [red-knot] Infer target types for unpacked tuple assignment --- .../src/semantic_index/builder.rs | 42 ++++++++++++++----- .../src/semantic_index/definition.rs | 32 +++++++++++--- crates/red_knot_python_semantic/src/types.rs | 13 ++++++ .../src/types/infer.rs | 40 ++++++++++++++++-- 4 files changed, 107 insertions(+), 20 deletions(-) diff --git a/crates/red_knot_python_semantic/src/semantic_index/builder.rs b/crates/red_knot_python_semantic/src/semantic_index/builder.rs index 3f440a89b3f8f7..e2e8d8366a4628 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/builder.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/builder.rs @@ -27,7 +27,7 @@ use crate::semantic_index::SemanticIndex; use crate::Db; use super::constraint::{Constraint, PatternConstraint}; -use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef}; +use super::definition::{AssignmentKind, MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef}; pub(super) struct SemanticIndexBuilder<'db> { // Builder state @@ -517,8 +517,17 @@ where debug_assert!(self.current_assignment.is_none()); self.visit_expr(&node.value); self.add_standalone_expression(&node.value); - self.current_assignment = Some(node.into()); for target in &node.targets { + let kind = match target { + ast::Expr::Name(_) => AssignmentKind::Name, + ast::Expr::List(_) | ast::Expr::Tuple(_) => AssignmentKind::Sequence(0), + ast::Expr::Starred(_) => AssignmentKind::Starred, + ast::Expr::Attribute(_) => AssignmentKind::Attribute, + ast::Expr::Subscript(_) => AssignmentKind::Subscript, + // TODO: is this a good default for an error recovery case like `1 = 2`? + _ => continue, + }; + self.current_assignment = Some(CurrentAssignment::Assign { node, kind }); self.visit_expr(target); } self.current_assignment = None; @@ -699,12 +708,13 @@ where let symbol = self.add_or_update_symbol(id.clone(), flags); if flags.contains(SymbolFlags::IS_DEFINED) { match self.current_assignment { - Some(CurrentAssignment::Assign(assignment)) => { + Some(CurrentAssignment::Assign { node, kind }) => { self.add_definition( symbol, AssignmentDefinitionNodeRef { - assignment, + assignment: node, target: name_node, + kind, }, ); } @@ -851,6 +861,19 @@ where self.visit_expr(key); self.visit_expr(value); } + ast::Expr::Tuple(ast::ExprTuple { elts, ctx, .. }) => { + for (index, element) in elts.iter().enumerate() { + if let Some(CurrentAssignment::Assign { + kind: AssignmentKind::Sequence(target_index), + .. + }) = self.current_assignment.as_mut() + { + *target_index = index; + } + self.visit_expr(element); + } + self.visit_expr_context(ctx); + } _ => { walk_expr(self, expr); } @@ -957,7 +980,10 @@ where #[derive(Copy, Clone, Debug)] enum CurrentAssignment<'a> { - Assign(&'a ast::StmtAssign), + Assign { + node: &'a ast::StmtAssign, + kind: AssignmentKind, + }, AnnAssign(&'a ast::StmtAnnAssign), AugAssign(&'a ast::StmtAugAssign), For(&'a ast::StmtFor), @@ -969,12 +995,6 @@ enum CurrentAssignment<'a> { WithItem(&'a ast::WithItem), } -impl<'a> From<&'a ast::StmtAssign> for CurrentAssignment<'a> { - fn from(value: &'a ast::StmtAssign) -> Self { - Self::Assign(value) - } -} - impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> { fn from(value: &'a ast::StmtAnnAssign) -> Self { Self::AnnAssign(value) diff --git a/crates/red_knot_python_semantic/src/semantic_index/definition.rs b/crates/red_knot_python_semantic/src/semantic_index/definition.rs index 00d51a3a060123..70a5aacebdf401 100644 --- a/crates/red_knot_python_semantic/src/semantic_index/definition.rs +++ b/crates/red_knot_python_semantic/src/semantic_index/definition.rs @@ -147,6 +147,7 @@ pub(crate) struct ImportFromDefinitionNodeRef<'a> { pub(crate) struct AssignmentDefinitionNodeRef<'a> { pub(crate) assignment: &'a ast::StmtAssign, pub(crate) target: &'a ast::ExprName, + pub(crate) kind: AssignmentKind, } #[derive(Copy, Clone, Debug)] @@ -203,12 +204,15 @@ impl DefinitionNodeRef<'_> { DefinitionNodeRef::NamedExpression(named) => { DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named)) } - DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { assignment, target }) => { - DefinitionKind::Assignment(AssignmentDefinitionKind { - assignment: AstNodeRef::new(parsed.clone(), assignment), - target: AstNodeRef::new(parsed, target), - }) - } + DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { + assignment, + target, + kind, + }) => DefinitionKind::Assignment(AssignmentDefinitionKind { + assignment: AstNodeRef::new(parsed.clone(), assignment), + target: AstNodeRef::new(parsed, target), + kind, + }), DefinitionNodeRef::AnnotatedAssignment(assign) => { DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign)) } @@ -276,6 +280,7 @@ impl DefinitionNodeRef<'_> { Self::Assignment(AssignmentDefinitionNodeRef { assignment: _, target, + kind: _, }) => target.into(), Self::AnnotatedAssignment(node) => node.into(), Self::AugmentedAssignment(node) => node.into(), @@ -381,6 +386,7 @@ impl ImportFromDefinitionKind { pub struct AssignmentDefinitionKind { assignment: AstNodeRef, target: AstNodeRef, + kind: AssignmentKind, } impl AssignmentDefinitionKind { @@ -391,6 +397,20 @@ impl AssignmentDefinitionKind { pub(crate) fn target(&self) -> &ast::ExprName { self.target.node() } + + pub(crate) fn kind(&self) -> AssignmentKind { + self.kind + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum AssignmentKind { + Attribute, + Subscript, + Starred, + Name, + /// list or tuple with an index into the list of targets. + Sequence(usize), } #[derive(Clone, Debug)] diff --git a/crates/red_knot_python_semantic/src/types.rs b/crates/red_knot_python_semantic/src/types.rs index c5f38eb642082f..65ca3e1ab7ee4a 100644 --- a/crates/red_knot_python_semantic/src/types.rs +++ b/crates/red_knot_python_semantic/src/types.rs @@ -210,6 +210,13 @@ impl<'db> Type<'db> { matches!(self, Type::Never) } + pub const fn as_tuple_type(&self) -> Option<&TupleType<'db>> { + match self { + Type::Tuple(tuple_type) => Some(tuple_type), + _ => None, + } + } + pub const fn into_class_type(self) -> Option> { match self { Type::Class(class_type) => Some(class_type), @@ -672,3 +679,9 @@ pub struct TupleType<'db> { #[return_ref] elements: Box<[Type<'db>]>, } + +impl<'db> TupleType<'db> { + pub fn get(&self, db: &'db dyn Db, index: usize) -> Option<&Type<'db>> { + self.elements(db).get(index) + } +} diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 051e8db2bf6271..23940a4e0cebd8 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -40,7 +40,9 @@ use ruff_text_size::Ranged; use crate::module_name::ModuleName; use crate::module_resolver::{file_to_module, resolve_module}; use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId}; -use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionNodeKey}; +use crate::semantic_index::definition::{ + AssignmentKind, Definition, DefinitionKind, DefinitionNodeKey, +}; use crate::semantic_index::expression::Expression; use crate::semantic_index::semantic_index; use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId}; @@ -380,6 +382,7 @@ impl<'db> TypeInferenceBuilder<'db> { DefinitionKind::Assignment(assignment) => { self.infer_assignment_definition( assignment.target(), + assignment.kind(), assignment.assignment(), definition, ); @@ -957,19 +960,34 @@ impl<'db> TypeInferenceBuilder<'db> { fn infer_assignment_definition( &mut self, target: &ast::ExprName, + kind: AssignmentKind, assignment: &ast::StmtAssign, definition: Definition<'db>, ) { let expression = self.index.expression(assignment.value.as_ref()); let result = infer_expression_types(self.db, expression); self.extend(result); + let value_ty = self .types .expression_ty(assignment.value.scoped_ast_id(self.db, self.scope)); + + let target_ty = match (value_ty, kind) { + (Type::Tuple(tuple_type), AssignmentKind::Sequence(target_index)) => { + // TODO: when does this happen? + tuple_type + .get(self.db, target_index) + .copied() + .unwrap_or(Type::Unknown) + } + _ => value_ty, + }; + self.types .expressions - .insert(target.scoped_ast_id(self.db, self.scope), value_ty); - self.types.definitions.insert(definition, value_ty); + .insert(target.scoped_ast_id(self.db, self.scope), target_ty); + + self.types.definitions.insert(definition, target_ty); } fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) { @@ -4057,6 +4075,22 @@ mod tests { Ok(()) } + #[test] + fn unpacked_tuple_assignment() { + let mut db = setup_db(); + + db.write_dedented( + "/src/a.py", + " + x, y = 1, 2 + ", + ) + .unwrap(); + + assert_public_ty(&db, "/src/a.py", "x", "Literal[1]"); + assert_public_ty(&db, "/src/a.py", "y", "Literal[2]"); + } + #[test] fn list_literal() -> anyhow::Result<()> { let mut db = setup_db();