diff --git a/Cargo.lock b/Cargo.lock index f4326c3ffa3f8..32598f2d43a8e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2083,6 +2083,7 @@ dependencies = [ "countme", "hashbrown", "insta", + "itertools 0.13.0", "ordermap", "red_knot_vendored", "ruff_db", diff --git a/crates/red_knot_python_semantic/Cargo.toml b/crates/red_knot_python_semantic/Cargo.toml index 6aff354f5fb6d..f9aee056356ba 100644 --- a/crates/red_knot_python_semantic/Cargo.toml +++ b/crates/red_knot_python_semantic/Cargo.toml @@ -24,6 +24,7 @@ bitflags = { workspace = true } camino = { workspace = true } compact_str = { workspace = true } countme = { workspace = true } +itertools = { workspace = true} ordermap = { workspace = true } salsa = { workspace = true } thiserror = { workspace = true } diff --git a/crates/red_knot_python_semantic/src/types/infer.rs b/crates/red_knot_python_semantic/src/types/infer.rs index 119fb207b8be7..b46be62df9d3d 100644 --- a/crates/red_knot_python_semantic/src/types/infer.rs +++ b/crates/red_knot_python_semantic/src/types/infer.rs @@ -26,6 +26,7 @@ //! stringified annotations. We have a fourth Salsa query for inferring the deferred types //! associated with a particular definition. Scope-level inference infers deferred types for all //! definitions once the rest of the types in the scope have been inferred. +use itertools::Itertools; use std::num::NonZeroU32; use ruff_db::files::File; @@ -328,6 +329,14 @@ impl<'db> TypeInferenceBuilder<'db> { matches!(self.region, InferenceRegion::Deferred(_)) } + /// Get the already-inferred type of an expression node. + /// + /// PANIC if no type has been inferred for this node. + fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> { + self.types + .expression_ty(expr.scoped_ast_id(self.db, self.scope)) + } + /// Infers types in the given [`InferenceRegion`]. fn infer_region(&mut self) { match self.region { @@ -984,9 +993,7 @@ impl<'db> TypeInferenceBuilder<'db> { // TODO(dhruvmanila): The correct type inference here is the return type of the __enter__ // method of the context manager. - let context_expr_ty = self - .types - .expression_ty(with_item.context_expr.scoped_ast_id(self.db, self.scope)); + let context_expr_ty = self.expression_ty(&with_item.context_expr); self.types .expressions @@ -1151,9 +1158,7 @@ impl<'db> TypeInferenceBuilder<'db> { let expression = self.index.expression(assignment.value.as_ref()); let result = infer_expression_types(self.db, expression); self.extend(result); - let value_ty = self - .types - .expression_ty(assignment.value.scoped_ast_id(self.db, self.scope)); + let value_ty = self.expression_ty(&assignment.value); self.add_binding(assignment.into(), definition, value_ty); self.types .expressions @@ -1349,9 +1354,7 @@ impl<'db> TypeInferenceBuilder<'db> { let expression = self.index.expression(iterable); let result = infer_expression_types(self.db, expression); self.extend(result); - let iterable_ty = self - .types - .expression_ty(iterable.scoped_ast_id(self.db, self.scope)); + let iterable_ty = self.expression_ty(iterable); let loop_var_value_ty = if is_async { // TODO(Alex): async iterables/iterators! @@ -2434,28 +2437,41 @@ impl<'db> TypeInferenceBuilder<'db> { op, values, } = bool_op; + Self::infer_chained_boolean_types( + self.db, + *op, + values.iter().map(|value| self.infer_expression(value)), + values.len(), + ) + } + + /// Computes the output of a chain of (one) boolean operation, consuming as input an iterator + /// of types. The iterator is consumed even if the boolean evaluation can be short-circuited, + /// in order to ensure the invariant that all expressions are evaluated when inferring types. + fn infer_chained_boolean_types( + db: &'db dyn Db, + op: ast::BoolOp, + values: impl IntoIterator>, + n_values: usize, + ) -> Type<'db> { let mut done = false; UnionType::from_elements( - self.db, - values.iter().enumerate().map(|(i, value)| { - // We need to infer the type of every expression (that's an invariant maintained by - // type inference), even if we can short-circuit boolean evaluation of some of - // those types. - let value_ty = self.infer_expression(value); + db, + values.into_iter().enumerate().map(|(i, ty)| { if done { Type::Never } else { - let is_last = i == values.len() - 1; - match (value_ty.bool(self.db), is_last, op) { - (Truthiness::Ambiguous, _, _) => value_ty, + let is_last = i == n_values - 1; + match (ty.bool(db), is_last, op) { + (Truthiness::Ambiguous, _, _) => ty, (Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never, (Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never, (Truthiness::AlwaysFalse, _, ast::BoolOp::And) | (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => { done = true; - value_ty + ty } - (_, true, _) => value_ty, + (_, true, _) => ty, } } }), @@ -2466,16 +2482,138 @@ impl<'db> TypeInferenceBuilder<'db> { let ast::ExprCompare { range: _, left, - ops: _, + ops, comparators, } = compare; self.infer_expression(left); - // TODO actually handle ops and return correct type for right in comparators.as_ref() { self.infer_expression(right); } - Type::Todo + + // https://docs.python.org/3/reference/expressions.html#comparisons + // > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison + // > operators, then `a op1 b op2 c ... y opN z` is equivalent to a `op1 b and b op2 c and + // ... > y opN z`, except that each expression is evaluated at most once. + // + // As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below + // is shared with the one in `infer_binary_type_comparison`. + Self::infer_chained_boolean_types( + self.db, + ast::BoolOp::And, + std::iter::once(left.as_ref()) + .chain(comparators.as_ref().iter()) + .tuple_windows::<(_, _)>() + .zip(ops.iter()) + .map(|((left, right), op)| { + let left_ty = self.expression_ty(left); + let right_ty = self.expression_ty(right); + + self.infer_binary_type_comparison(left_ty, *op, right_ty) + .unwrap_or_else(|| { + // Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome) + self.add_diagnostic( + AnyNodeRef::ExprCompare(compare), + "operator-unsupported", + format_args!( + "Operator `{}` is not supported for types `{}` and `{}`", + op, + left_ty.display(self.db), + right_ty.display(self.db) + ), + ); + match op { + // `in, not in, is, is not` always return bool instances + ast::CmpOp::In + | ast::CmpOp::NotIn + | ast::CmpOp::Is + | ast::CmpOp::IsNot => { + builtins_symbol_ty(self.db, "bool").to_instance(self.db) + } + // Other operators can return arbitrary types + _ => Type::Unknown, + } + }) + }), + ops.len(), + ) + } + + /// Infers the type of a binary comparison (e.g. 'left == right'). See + /// `infer_compare_expression` for the higher level logic dealing with multi-comparison + /// expressions. + /// + /// If the operation is not supported, return None (we need upstream context to emit a + /// diagnostic). + fn infer_binary_type_comparison( + &mut self, + left: Type<'db>, + op: ast::CmpOp, + right: Type<'db>, + ) -> Option> { + // Note: identity (is, is not) for equal builtin types is unreliable and not part of the + // language spec. + // - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal + // - `[ast::CompOp::IsNot]`: return `true` if unequal, `bool` if equal + match (left, right) { + (Type::IntLiteral(n), Type::IntLiteral(m)) => match op { + ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)), + ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)), + ast::CmpOp::Lt => Some(Type::BooleanLiteral(n < m)), + ast::CmpOp::LtE => Some(Type::BooleanLiteral(n <= m)), + ast::CmpOp::Gt => Some(Type::BooleanLiteral(n > m)), + ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)), + ast::CmpOp::Is => { + if n == m { + Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db)) + } else { + Some(Type::BooleanLiteral(false)) + } + } + ast::CmpOp::IsNot => { + if n == m { + Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db)) + } else { + Some(Type::BooleanLiteral(true)) + } + } + // Undefined for (int, int) + ast::CmpOp::In | ast::CmpOp::NotIn => None, + }, + (Type::IntLiteral(_), Type::Instance(_)) => { + self.infer_binary_type_comparison(Type::builtin_int_instance(self.db), op, right) + } + (Type::Instance(_), Type::IntLiteral(_)) => { + self.infer_binary_type_comparison(left, op, Type::builtin_int_instance(self.db)) + } + // Booleans are coded as integers (False = 0, True = 1) + (Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison( + Type::IntLiteral(n), + op, + Type::IntLiteral(i64::from(b)), + ), + (Type::BooleanLiteral(b), Type::IntLiteral(m)) => self.infer_binary_type_comparison( + Type::IntLiteral(i64::from(b)), + op, + Type::IntLiteral(m), + ), + (Type::BooleanLiteral(a), Type::BooleanLiteral(b)) => self + .infer_binary_type_comparison( + Type::IntLiteral(i64::from(a)), + op, + Type::IntLiteral(i64::from(b)), + ), + // Lookup the rich comparison `__dunder__` methods on instances + (Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op { + ast::CmpOp::Lt => { + perform_rich_comparison(self.db, left_class_ty, right_class_ty, "__lt__") + } + // TODO: implement mapping from `ast::CmpOp` to rich comparison methods + _ => Some(Type::Todo), + }, + // TODO: handle more types + _ => Some(Type::Todo), + } } fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> { @@ -2995,6 +3133,36 @@ impl StringPartsCollector { } } +/// Rich comparison in Python are the operators `==`, `!=`, `<`, `<=`, `>`, and `>=`. Their +/// behaviour can be edited for classes by implementing corresponding dunder methods. +/// This function performs rich comparison between two instances and returns the resulting type. +/// see `` +fn perform_rich_comparison<'db>( + db: &'db dyn Db, + left: ClassType<'db>, + right: ClassType<'db>, + dunder_name: &str, +) -> Option> { + // The following resource has details about the rich comparison algorithm: + // https://snarky.ca/unravelling-rich-comparison-operators/ + // + // TODO: the reflected dunder actually has priority if the r.h.s. is a strict subclass of the + // l.h.s. + // TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined + + let dunder = left.class_member(db, dunder_name); + if !dunder.is_unbound() { + // TODO: this currently gives the return type even if the arg types are invalid + // (e.g. int.__lt__ with string instance should be None, currently bool) + return dunder + .call(db, &[Type::Instance(left), Type::Instance(right)]) + .return_ty(db); + } + + // TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=) + None +} + #[cfg(test)] mod tests { @@ -3879,6 +4047,133 @@ mod tests { Ok(()) } + #[test] + fn comparison_integer_literals() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_dedented( + "src/a.py", + r#" + a = 1 == 1 == True + b = 1 == 1 == 2 == 4 + c = False < True <= 2 < 3 != 6 + d = 1 < 1 + e = 1 > 1 + f = 1 is 1 + g = 1 is not 1 + h = 1 is 2 + i = 1 is not 7 + j = 1 <= "" and 0 < 1 + "#, + )?; + + assert_public_ty(&db, "src/a.py", "a", "Literal[True]"); + assert_public_ty(&db, "src/a.py", "b", "Literal[False]"); + assert_public_ty(&db, "src/a.py", "c", "Literal[True]"); + assert_public_ty(&db, "src/a.py", "d", "Literal[False]"); + assert_public_ty(&db, "src/a.py", "e", "Literal[False]"); + assert_public_ty(&db, "src/a.py", "f", "bool"); + assert_public_ty(&db, "src/a.py", "g", "bool"); + assert_public_ty(&db, "src/a.py", "h", "Literal[False]"); + assert_public_ty(&db, "src/a.py", "i", "Literal[True]"); + assert_public_ty(&db, "src/a.py", "j", "@Todo | Literal[True]"); + + Ok(()) + } + + #[test] + fn comparison_integer_instance() -> anyhow::Result<()> { + let mut db = setup_db(); + + db.write_dedented( + "src/a.py", + r#" + def int_instance() -> int: ... + a = 1 == int_instance() + b = 9 < int_instance() + c = int_instance() < int_instance() + "#, + )?; + + // TODO: implement lookup of `__eq__` on typeshed `int` stub + assert_public_ty(&db, "src/a.py", "a", "@Todo"); + assert_public_ty(&db, "src/a.py", "b", "bool"); + assert_public_ty(&db, "src/a.py", "c", "bool"); + + Ok(()) + } + + #[test] + fn comparison_unsupported_operators() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_dedented( + "src/a.py", + r#" + a = 1 in 7 + b = 0 not in 10 + c = object() < 5 + d = 5 < object() + "#, + )?; + + assert_file_diagnostics( + &db, + "src/a.py", + &[ + "Operator `in` is not supported for types `Literal[1]` and `Literal[7]`", + "Operator `not in` is not supported for types `Literal[0]` and `Literal[10]`", + "Operator `<` is not supported for types `object` and `Literal[5]`", + ], + ); + assert_public_ty(&db, "src/a.py", "a", "bool"); + assert_public_ty(&db, "src/a.py", "b", "bool"); + assert_public_ty(&db, "src/a.py", "c", "Unknown"); + // TODO: this should be `Unknown` but we don't check if __lt__ signature is valid for right + // operand type + assert_public_ty(&db, "src/a.py", "d", "bool"); + + Ok(()) + } + + #[test] + fn comparison_non_bool_returns() -> anyhow::Result<()> { + let mut db = setup_db(); + db.write_dedented( + "src/a.py", + r#" + from __future__ import annotations + class A: + def __lt__(self, other) -> A: ... + class B: + def __lt__(self, other) -> B: ... + class C: + def __lt__(self, other) -> C: ... + + a = A() < B() < C() + b = 0 < 1 < A() < 3 + c = 10 < 0 < A() < B() < C() + "#, + )?; + + // Walking through the example + // 1. A() < B() < C() + // 2. A() < B() and B() < C() - split in N comparison + // 3. A() and B() - evaluate outcome types + // 4. bool and bool - evaluate truthiness + // 5. A | B - union of "first true" types + assert_public_ty(&db, "src/a.py", "a", "A | B"); + // Walking through the example + // 1. 0 < 1 < A() < 3 + // 2. 0 < 1 and 1 < A() and A() < 3 - split in N comparison + // 3. True and bool and A - evaluate outcome types + // 4. True and bool and bool - evaluate truthiness + // 5. bool | A - union of "true" types + assert_public_ty(&db, "src/a.py", "b", "bool | A"); + // Short-cicuit to False + assert_public_ty(&db, "src/a.py", "c", "Literal[False]"); + + Ok(()) + } + #[test] fn bytes_type() -> anyhow::Result<()> { let mut db = setup_db();