Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] feat: implement integer comparison #13571

Merged
merged 13 commits into from
Oct 4, 2024
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/red_knot_python_semantic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ bitflags = { workspace = true }
camino = { workspace = true }
compact_str = { workspace = true }
countme = { workspace = true }
itertools = { workspace = true}
ordermap = { workspace = true }
salsa = { workspace = true }
thiserror = { workspace = true }
Expand Down
323 changes: 310 additions & 13 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
//! stringified annotations. We have a fourth Salsa query for inferring the deferred types
//! associated with a particular definition. Scope-level inference infers deferred types for all
//! definitions once the rest of the types in the scope have been inferred.
use itertools::Itertools;
use std::num::NonZeroU32;

use ruff_db::files::File;
Expand Down Expand Up @@ -2434,19 +2435,33 @@ impl<'db> TypeInferenceBuilder<'db> {
op,
values,
} = bool_op;
Self::infer_chained_boolean_types(
self.db,
*op,
values.iter().map(|value| self.infer_expression(value)),
values.len(),
)
}

/// Computes the output of a chain of (one) boolean operation, consuming as input an iterator
/// of types. The iterator is consumed even if the boolean evaluation can be short-circuited,
/// in order to ensure the invariant that all expressions are evaluated when inferring types.
fn infer_chained_boolean_types<T: Into<Type<'db>>>(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need to be generic over Into<Type<'db>>, it can just take an IntoIterator over Type<'db>. This change compiles and passes all tests:

--- a/crates/red_knot_python_semantic/src/types/infer.rs
+++ b/crates/red_knot_python_semantic/src/types/infer.rs
@@ -2446,10 +2446,10 @@ impl<'db> TypeInferenceBuilder<'db> {
     /// Computes the output of a chain of (one) boolean operation, consuming as input an iterator
     /// of types. The iterator is consumed even if the boolean evaluation can be short-circuited,
     /// in order to ensure the invariant that all expressions are evaluated when inferring types.
-    fn infer_chained_boolean_types<T: Into<Type<'db>>>(
+    fn infer_chained_boolean_types(
         db: &'db dyn Db,
         op: ast::BoolOp,
-        values: impl IntoIterator<Item = T>,
+        values: impl IntoIterator<Item = Type<'db>>,
         n_values: usize,
     ) -> Type<'db> {
         let mut done = false;
@@ -2460,17 +2460,16 @@ impl<'db> TypeInferenceBuilder<'db> {
                     Type::Never
                 } else {
                     let is_last = i == n_values - 1;
-                    let value_ty: Type<'db> = ty.into();
-                    match (value_ty.bool(db), is_last, op) {
-                        (Truthiness::Ambiguous, _, _) => value_ty,
+                    match (ty.bool(db), is_last, op) {
+                        (Truthiness::Ambiguous, _, _) => ty,
                         (Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
                         (Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
                         (Truthiness::AlwaysFalse, _, ast::BoolOp::And)
                         | (Truthiness::AlwaysTrue, _, ast::BoolOp::Or) => {
                             done = true;
-                            value_ty
+                            ty
                         }
-                        (_, true, _) => value_ty,
+                        (_, true, _) => ty,
                     }
                 }
             }),

db: &'db dyn Db,
op: ast::BoolOp,
values: impl IntoIterator<Item = T>,
n_values: usize,
) -> Type<'db> {
let mut done = false;
UnionType::from_elements(
self.db,
values.iter().enumerate().map(|(i, value)| {
// We need to infer the type of every expression (that's an invariant maintained by
// type inference), even if we can short-circuit boolean evaluation of some of
// those types.
let value_ty = self.infer_expression(value);
db,
values.into_iter().enumerate().map(|(i, ty)| {
if done {
Type::Never
} else {
let is_last = i == values.len() - 1;
match (value_ty.bool(self.db), is_last, op) {
let is_last = i == n_values - 1;
let value_ty: Type<'db> = ty.into();
match (value_ty.bool(db), is_last, op) {
(Truthiness::Ambiguous, _, _) => value_ty,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See my comment on the last test, but maybe we could fix this by having a special case for builtins.bool (which will be a very common type here)

Suggested change
(Truthiness::Ambiguous, _, _) => value_ty,
(Truthiness::Ambiguous, false, ast::BoolOp::And) => match value_ty {
// Ambiguous types that are not the last in the `and` chain can only be
// returned if they are falsy. In the special case of `builtins.bool`,
// being falsy is `Literal[False]`.
// TODO: we could do this optimisation for other literal that have a
// single falsy value (`""`, `0`, ...?)
Type::Instance(class) => class
.is_stdlib_symbol(db, "builtins", "bool")
.then(|| Type::BooleanLiteral(false))
.unwrap_or_else(|| value_ty),
_ => value_ty,
},
(Truthiness::Ambiguous, _, _) => value_ty,

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's do this as a follow-up, I created #13632 with some comments about it.

(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
Expand All @@ -2466,16 +2481,141 @@ impl<'db> TypeInferenceBuilder<'db> {
let ast::ExprCompare {
range: _,
left,
ops: _,
ops,
comparators,
} = compare;

self.infer_expression(left);
// TODO actually handle ops and return correct type
for right in comparators.as_ref() {
self.infer_expression(right);
for expr in comparators {
self.infer_expression(expr);
}
carljm marked this conversation as resolved.
Show resolved Hide resolved

// https://docs.python.org/3/reference/expressions.html#comparisons
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
// > operators, then `a op1 b op2 c ... y opN z` is equivalent to a `op1 b and b op2 c and
// ... > y opN z`, except that each expression is evaluated at most once.
//
// As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below
// is shared with the one in `infer_binary_type_comparison`.
Self::infer_chained_boolean_types(
self.db,
ast::BoolOp::And,
std::iter::once(left.as_ref())
.chain(comparators.as_ref().iter())
.tuple_windows::<(_, _)>()
.zip(ops.iter())
.map(|((left, right), op)| {
let left_ty = self
.types
.expression_ty(left.scoped_ast_id(self.db, self.scope));
let right_ty = self
.types
.expression_ty(right.scoped_ast_id(self.db, self.scope));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've got enough occurrences of this verbose pattern now, it's high time for a helper method:

+    /// Get the already-inferred type of an expression node.
+    ///
+    /// PANIC if no type has been inferred for this node.
+    fn expression_ty(&self, expr: &ast::Expr) -> Type<'db> {
+        self.types
+            .expression_ty(expr.scoped_ast_id(self.db, self.scope))
+    }
+


self.infer_binary_type_comparison(left_ty, *op, right_ty)
.unwrap_or_else(|| {
// Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome)
self.add_diagnostic(
AnyNodeRef::ExprCompare(compare),
"operator-unsupported",
format_args!(
"Operator \"{}\" is not supported for types {} and {}",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit, to match our usual diagnostic message style:

Suggested change
"Operator \"{}\" is not supported for types {} and {}",
"Operator `{}` is not supported for types `{}` and `{}`",

op,
left_ty.display(self.db),
right_ty.display(self.db)
),
);
match op {
// `in, not in, is, is not` always return bool instances
ast::CmpOp::In
| ast::CmpOp::NotIn
| ast::CmpOp::Is
| ast::CmpOp::IsNot => {
builtins_symbol_ty(self.db, "bool").to_instance(self.db)
}
// Other operators can return arbitrary types
_ => Type::Unknown,
}
})
}),
ops.len(),
)
}

/// Infers the type of a binary comparison (e.g. 'left == right'). See
/// `infer_compare_expression` for the higher level logic dealing with multi-comparison
/// expressions.
///
/// If the operation is not supported, return None (we need upstream context to emit a
/// diagnostic).
fn infer_binary_type_comparison(
&mut self,
left: Type<'db>,
op: ast::CmpOp,
right: Type<'db>,
) -> Option<Type<'db>> {
// Note: identity (is, is not) is unreliable in Python and not part of the language specs.
// - `[ast::CompOp::Is]`: return `false` if different, `bool` if the same
// - `[ast::CompOp::IsNot]`: return `true` if different, `bool` if the same
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor clarification

Suggested change
// Note: identity (is, is not) is unreliable in Python and not part of the language specs.
// - `[ast::CompOp::Is]`: return `false` if different, `bool` if the same
// - `[ast::CompOp::IsNot]`: return `true` if different, `bool` if the same
// Note: identity (is, is not) for equal builtin types is unreliable and not part of the language spec.
// - `[ast::CompOp::Is]`: return `false` if unequal, `bool` if equal
// - `[ast::CompOp::IsNot]`: return `true` if unequal, `bool` if equal

match (left, right) {
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(n < m)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(n <= m)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(n > m)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)),
ast::CmpOp::Is => {
if n == m {
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if n == m {
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
}
}
// Undefined for (int, int)
ast::CmpOp::In | ast::CmpOp::NotIn => None,
},
(Type::IntLiteral(_), Type::Instance(_)) => {
self.infer_binary_type_comparison(Type::builtin_int_instance(self.db), op, right)
}
(Type::Instance(_), Type::IntLiteral(_)) => {
self.infer_binary_type_comparison(left, op, Type::builtin_int_instance(self.db))
}
// Booleans are coded as integers (False = 0, True = 1)
(Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison(
Type::IntLiteral(n),
op,
Type::IntLiteral(i64::from(b)),
),
(Type::BooleanLiteral(b), Type::IntLiteral(m)) => self.infer_binary_type_comparison(
Type::IntLiteral(i64::from(b)),
op,
Type::IntLiteral(m),
),
(Type::BooleanLiteral(a), Type::BooleanLiteral(b)) => self
.infer_binary_type_comparison(
Type::IntLiteral(i64::from(a)),
op,
Type::IntLiteral(i64::from(b)),
),
// Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => {
perform_rich_comparison(self.db, left_class_ty, right_class_ty, "__lt__")
}
// TODO: implement mapping from `ast::CmpOp` to rich comparison methods
_ => Some(Type::Todo),
},
// TODO: handle more types
_ => Some(Type::Todo),
}
Type::Todo
}

fn infer_subscript_expression(&mut self, subscript: &ast::ExprSubscript) -> Type<'db> {
Expand Down Expand Up @@ -2995,6 +3135,36 @@ impl StringPartsCollector {
}
}

/// Rich comparison in Python are the operators `==`, `!=`, `<`, `<=`, `>`, and `>=`. Their
/// behaviour can be edited for classes by implementing corresponding dunder methods.
/// This function performs rich comparison between two instances and returns the resulting type.
/// see `<https://docs.python.org/3/reference/datamodel.html#object.__lt__>`
Comment on lines +3136 to +3139
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This (and all the TODO comments in this function) are a fantastic resource for someone to flesh this out later -- thank you!!

fn perform_rich_comparison<'db>(
db: &'db dyn Db,
left: ClassType<'db>,
right: ClassType<'db>,
dunder_name: &str,
) -> Option<Type<'db>> {
// The following resource has details about the rich comparison algorithm:
// https://snarky.ca/unravelling-rich-comparison-operators/
//
// TODO: the reflected dunder actually has priority if the r.h.s. is a strict subclass of the
// l.h.s.
// TODO: `object.__ne__` will call `__eq__` if `__ne__` is not defined

let dunder = left.class_member(db, dunder_name);
if !dunder.is_unbound() {
// TODO: this currently gives the return type even if the arg types are invalid
// (e.g. int.__lt__ with string instance should be None, currently bool)
return dunder
.call(db, &[Type::Instance(left), Type::Instance(right)])
.return_ty(db);
}

// TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=)
None
}

#[cfg(test)]
mod tests {

Expand Down Expand Up @@ -3879,6 +4049,133 @@ mod tests {
Ok(())
}

#[test]
fn comparison_integer_literals() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
r#"
a = 1 == 1 == True
b = 1 == 1 == 2 == 4
c = False < True <= 2 < 3 != 6
d = 1 < 1
e = 1 > 1
f = 1 is 1
g = 1 is not 1
h = 1 is 2
i = 1 is not 7
j = 1 <= "" and 0 < 1
"#,
)?;

assert_public_ty(&db, "src/a.py", "a", "Literal[True]");
assert_public_ty(&db, "src/a.py", "b", "Literal[False]");
assert_public_ty(&db, "src/a.py", "c", "Literal[True]");
assert_public_ty(&db, "src/a.py", "d", "Literal[False]");
assert_public_ty(&db, "src/a.py", "e", "Literal[False]");
assert_public_ty(&db, "src/a.py", "f", "bool");
assert_public_ty(&db, "src/a.py", "g", "bool");
assert_public_ty(&db, "src/a.py", "h", "Literal[False]");
assert_public_ty(&db, "src/a.py", "i", "Literal[True]");
assert_public_ty(&db, "src/a.py", "j", "@Todo | Literal[True]");

Ok(())
}

#[test]
fn comparison_integer_instance() -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"src/a.py",
r#"
def int_instance() -> int: ...
a = 1 == int_instance()
b = 9 < int_instance()
c = int_instance() < int_instance()
"#,
)?;

// TODO: implement lookup of `__eq__` on typeshed `int` stub
assert_public_ty(&db, "src/a.py", "a", "@Todo");
assert_public_ty(&db, "src/a.py", "b", "bool");
assert_public_ty(&db, "src/a.py", "c", "bool");

Ok(())
}

#[test]
fn comparison_unsupported_operators() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
r#"
a = 1 in 7
b = 0 not in 10
c = object() < 5
d = 5 < object()
"#,
)?;

assert_file_diagnostics(
&db,
"src/a.py",
&[
"Operator \"in\" is not supported for types Literal[1] and Literal[7]",
"Operator \"not in\" is not supported for types Literal[0] and Literal[10]",
"Operator \"<\" is not supported for types object and Literal[5]",
],
);
assert_public_ty(&db, "src/a.py", "a", "bool");
assert_public_ty(&db, "src/a.py", "b", "bool");
assert_public_ty(&db, "src/a.py", "c", "Unknown");
// TODO: this should be `Unknown` but we don't check if __lt__ signature is valid for right
// operand type
assert_public_ty(&db, "src/a.py", "d", "bool");

Ok(())
}

#[test]
fn comparison_non_bool_returns() -> anyhow::Result<()> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test, and its comments, are fantastic!

let mut db = setup_db();
db.write_dedented(
"src/a.py",
r#"
from __future__ import annotations
class A:
def __lt__(self, other) -> A: ...
class B:
def __lt__(self, other) -> B: ...
class C:
def __lt__(self, other) -> C: ...

a = A() < B() < C()
b = 0 < 1 < A() < 3
c = 10 < 0 < A() < B() < C()
"#,
)?;

// Walking through the example
// 1. A() < B() < C()
// 2. A() < B() and B() < C() - split in N comparison
// 3. A() and B() - evaluate outcome types
// 4. bool and bool - evaluate truthiness
// 5. A | B - union of "first true" types
assert_public_ty(&db, "src/a.py", "a", "A | B");
// Walking through the example
// 1. 0 < 1 < A() < 3
// 2. 0 < 1 and 1 < A() and A() < 3 - split in N comparison
// 3. True and bool and A - evaluate outcome types
// 4. True and bool and bool - evaluate truthiness
// 5. bool | A - union of "true" types
assert_public_ty(&db, "src/a.py", "b", "bool | A");
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should that be Literal[False] | A? There is no runtime path where we can end up with Literal[True] here as if the second comparison 1 < A() evaluates to True we would return the value of the last comparison which is A.
Only path where we don't take the value of A here is if we had a Literal[False] in the earlier steps.

But that is more of a comment on how we handle chained and I guess?

Copy link
Contributor

@carljm carljm Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great point! I think this can be done as a separate follow-up so we can go ahead and get this PR merged; it's kind of separate from this PR as it's an improvement to the existing chained-boolean-expression logic.

I think the logic here is that when we have an actual bool type in a chained boolean expression, in non-last-position, rather than adding bool to the union we can instead add Literal[False] (for AND) or Literal[True] (for OR).

There's a more generalized version of this where we add Falsy and Truthy types, and then we'd always intersect any type in that position (not just bool) with Falsy or Truthy, and then for bools that would simplify out in the intersection (e.g. bool & Falsy is Literal[False]). I think it's likely we end up doing this for type narrowing anyway, but I don't have strong feelings about whether we go straight for the generalized solution or start with a bool-specific implementation.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Created #13632 to track this as a follow-up.

// Short-cicuit to False
assert_public_ty(&db, "src/a.py", "c", "Literal[False]");

Ok(())
}

#[test]
fn bytes_type() -> anyhow::Result<()> {
let mut db = setup_db();
Expand Down
Loading