Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Infer target types for unpacked tuple assignment #13316

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions crates/red_knot_python_semantic/src/semantic_index/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use crate::semantic_index::SemanticIndex;
use crate::Db;

use super::constraint::{Constraint, PatternConstraint};
use super::definition::{MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};
use super::definition::{AssignmentKind, MatchPatternDefinitionNodeRef, WithItemDefinitionNodeRef};

pub(super) struct SemanticIndexBuilder<'db> {
// Builder state
Expand Down Expand Up @@ -517,8 +517,17 @@ where
debug_assert!(self.current_assignment.is_none());
self.visit_expr(&node.value);
self.add_standalone_expression(&node.value);
self.current_assignment = Some(node.into());
for target in &node.targets {
let kind = match target {
ast::Expr::Name(_) => AssignmentKind::Name,
ast::Expr::List(_) | ast::Expr::Tuple(_) => AssignmentKind::Sequence(0),
ast::Expr::Starred(_) => AssignmentKind::Starred,
ast::Expr::Attribute(_) => AssignmentKind::Attribute,
ast::Expr::Subscript(_) => AssignmentKind::Subscript,
// TODO: is this a good default for an error recovery case like `1 = 2`?
_ => continue,
};
self.current_assignment = Some(CurrentAssignment::Assign { node, kind });
self.visit_expr(target);
}
self.current_assignment = None;
Expand Down Expand Up @@ -699,12 +708,13 @@ where
let symbol = self.add_or_update_symbol(id.clone(), flags);
if flags.contains(SymbolFlags::IS_DEFINED) {
match self.current_assignment {
Some(CurrentAssignment::Assign(assignment)) => {
Some(CurrentAssignment::Assign { node, kind }) => {
self.add_definition(
symbol,
AssignmentDefinitionNodeRef {
assignment,
assignment: node,
target: name_node,
kind,
},
);
}
Expand Down Expand Up @@ -851,6 +861,19 @@ where
self.visit_expr(key);
self.visit_expr(value);
}
ast::Expr::Tuple(ast::ExprTuple { elts, ctx, .. }) => {
for (index, element) in elts.iter().enumerate() {
if let Some(CurrentAssignment::Assign {
kind: AssignmentKind::Sequence(target_index),
..
}) = self.current_assignment.as_mut()
{
*target_index = index;
}
self.visit_expr(element);
}
self.visit_expr_context(ctx);
}
_ => {
walk_expr(self, expr);
}
Expand Down Expand Up @@ -957,7 +980,10 @@ where

#[derive(Copy, Clone, Debug)]
enum CurrentAssignment<'a> {
Assign(&'a ast::StmtAssign),
Assign {
node: &'a ast::StmtAssign,
kind: AssignmentKind,
},
AnnAssign(&'a ast::StmtAnnAssign),
AugAssign(&'a ast::StmtAugAssign),
For(&'a ast::StmtFor),
Expand All @@ -969,12 +995,6 @@ enum CurrentAssignment<'a> {
WithItem(&'a ast::WithItem),
}

impl<'a> From<&'a ast::StmtAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAssign) -> Self {
Self::Assign(value)
}
}

impl<'a> From<&'a ast::StmtAnnAssign> for CurrentAssignment<'a> {
fn from(value: &'a ast::StmtAnnAssign) -> Self {
Self::AnnAssign(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ pub(crate) struct ImportFromDefinitionNodeRef<'a> {
pub(crate) struct AssignmentDefinitionNodeRef<'a> {
pub(crate) assignment: &'a ast::StmtAssign,
pub(crate) target: &'a ast::ExprName,
pub(crate) kind: AssignmentKind,
}

#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -203,12 +204,15 @@ impl DefinitionNodeRef<'_> {
DefinitionNodeRef::NamedExpression(named) => {
DefinitionKind::NamedExpression(AstNodeRef::new(parsed, named))
}
DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef { assignment, target }) => {
DefinitionKind::Assignment(AssignmentDefinitionKind {
assignment: AstNodeRef::new(parsed.clone(), assignment),
target: AstNodeRef::new(parsed, target),
})
}
DefinitionNodeRef::Assignment(AssignmentDefinitionNodeRef {
assignment,
target,
kind,
}) => DefinitionKind::Assignment(AssignmentDefinitionKind {
assignment: AstNodeRef::new(parsed.clone(), assignment),
target: AstNodeRef::new(parsed, target),
kind,
}),
DefinitionNodeRef::AnnotatedAssignment(assign) => {
DefinitionKind::AnnotatedAssignment(AstNodeRef::new(parsed, assign))
}
Expand Down Expand Up @@ -276,6 +280,7 @@ impl DefinitionNodeRef<'_> {
Self::Assignment(AssignmentDefinitionNodeRef {
assignment: _,
target,
kind: _,
}) => target.into(),
Self::AnnotatedAssignment(node) => node.into(),
Self::AugmentedAssignment(node) => node.into(),
Expand Down Expand Up @@ -381,6 +386,7 @@ impl ImportFromDefinitionKind {
pub struct AssignmentDefinitionKind {
assignment: AstNodeRef<ast::StmtAssign>,
target: AstNodeRef<ast::ExprName>,
kind: AssignmentKind,
}

impl AssignmentDefinitionKind {
Expand All @@ -391,6 +397,26 @@ impl AssignmentDefinitionKind {
pub(crate) fn target(&self) -> &ast::ExprName {
self.target.node()
}

pub(crate) fn kind(&self) -> AssignmentKind {
self.kind
}
}

/// The kind of assignment target expression.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AssignmentKind {
/// An attribute expression e.g., `x.y = 1`.
Attribute,
/// A subscript expression e.g., `x[0] = 1`.
Subscript,
/// A starred expression e.g., `*x = 1`.
Starred,
/// A name expression e.g., `x = 1`.
Name,
/// A list or tuple expression which corresponds to an unpacking assignment e.g., `(x, y) = (1, 2)`.
/// The containing value is the position of the target in the assignment.
Sequence(usize),
}

#[derive(Clone, Debug)]
Expand Down
13 changes: 13 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,13 @@ impl<'db> Type<'db> {
matches!(self, Type::Never)
}

pub const fn as_tuple_type(&self) -> Option<&TupleType<'db>> {
match self {
Type::Tuple(tuple_type) => Some(tuple_type),
_ => None,
}
}

pub const fn into_class_type(self) -> Option<ClassType<'db>> {
match self {
Type::Class(class_type) => Some(class_type),
Expand Down Expand Up @@ -672,3 +679,9 @@ pub struct TupleType<'db> {
#[return_ref]
elements: Box<[Type<'db>]>,
}

impl<'db> TupleType<'db> {
pub fn get(&self, db: &'db dyn Db, index: usize) -> Option<&Type<'db>> {
self.elements(db).get(index)
}
}
42 changes: 39 additions & 3 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ use ruff_text_size::Ranged;
use crate::module_name::ModuleName;
use crate::module_resolver::{file_to_module, resolve_module};
use crate::semantic_index::ast_ids::{HasScopedAstId, HasScopedUseId, ScopedExpressionId};
use crate::semantic_index::definition::{Definition, DefinitionKind, DefinitionNodeKey};
use crate::semantic_index::definition::{
AssignmentKind, Definition, DefinitionKind, DefinitionNodeKey,
};
use crate::semantic_index::expression::Expression;
use crate::semantic_index::semantic_index;
use crate::semantic_index::symbol::{NodeWithScopeKind, NodeWithScopeRef, ScopeId};
Expand Down Expand Up @@ -380,6 +382,7 @@ impl<'db> TypeInferenceBuilder<'db> {
DefinitionKind::Assignment(assignment) => {
self.infer_assignment_definition(
assignment.target(),
assignment.kind(),
assignment.assignment(),
definition,
);
Expand Down Expand Up @@ -957,19 +960,36 @@ impl<'db> TypeInferenceBuilder<'db> {
fn infer_assignment_definition(
&mut self,
target: &ast::ExprName,
kind: AssignmentKind,
assignment: &ast::StmtAssign,
definition: Definition<'db>,
) {
let expression = self.index.expression(assignment.value.as_ref());
let result = infer_expression_types(self.db, expression);
self.extend(result);

let value_ty = self
.types
.expression_ty(assignment.value.scoped_ast_id(self.db, self.scope));

let target_ty = match (value_ty, kind) {
(Type::Tuple(tuple_type), AssignmentKind::Sequence(target_index)) => {
// TODO: Raise diagnostic for cases like:
// [a, b] = (1, 2, 3)
// [a, b, c] = (1, 2)
tuple_type
.get(self.db, target_index)
.copied()
.unwrap_or(Type::Unknown)
}
_ => value_ty,
};

self.types
.expressions
.insert(target.scoped_ast_id(self.db, self.scope), value_ty);
self.types.definitions.insert(definition, value_ty);
.insert(target.scoped_ast_id(self.db, self.scope), target_ty);

self.types.definitions.insert(definition, target_ty);
}

fn infer_annotated_assignment_statement(&mut self, assignment: &ast::StmtAnnAssign) {
Expand Down Expand Up @@ -4057,6 +4077,22 @@ mod tests {
Ok(())
}

#[test]
fn unpacked_tuple_assignment() {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
x, y = 1, 2
",
)
.unwrap();

assert_public_ty(&db, "/src/a.py", "x", "Literal[1]");
assert_public_ty(&db, "/src/a.py", "y", "Literal[2]");
}

#[test]
fn list_literal() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down
Loading