Skip to content

Commit

Permalink
fixup! [red-knot] feat: implement integer comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
Slyces committed Oct 3, 2024
1 parent 3fcda99 commit b74c007
Showing 1 changed file with 160 additions and 57 deletions.
217 changes: 160 additions & 57 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2434,19 +2434,33 @@ impl<'db> TypeInferenceBuilder<'db> {
op,
values,
} = bool_op;
Self::infer_chained_boolean_types(
self.db,
*op,
values.iter().map(|value| self.infer_expression(value)),
values.len(),
)
}

/// Computes the output of a chain of (one) boolean operation, consuming as input an iterator
/// of types. The iterator is consumed even if the boolean evaluation can be short-circuited,
/// in order to ensure the invariant that all expressions are evaluated when inferring types.
fn infer_chained_boolean_types<T: Into<Type<'db>>>(
db: &'db dyn Db,
op: ast::BoolOp,
values: impl IntoIterator<Item = T>,
n_values: usize,
) -> Type<'db> {
let mut done = false;
UnionType::from_elements(
self.db,
values.iter().enumerate().map(|(i, value)| {
// We need to infer the type of every expression (that's an invariant maintained by
// type inference), even if we can short-circuit boolean evaluation of some of
// those types.
let value_ty = self.infer_expression(value);
db,
values.into_iter().enumerate().map(|(i, ty)| {
if done {
Type::Never
} else {
let is_last = i == values.len() - 1;
match (value_ty.bool(self.db), is_last, op) {
let is_last = i == n_values - 1;
let value_ty: Type<'db> = ty.into();
match (value_ty.bool(db), is_last, op) {
(Truthiness::Ambiguous, _, _) => value_ty,
(Truthiness::AlwaysTrue, false, ast::BoolOp::And) => Type::Never,
(Truthiness::AlwaysFalse, false, ast::BoolOp::Or) => Type::Never,
Expand Down Expand Up @@ -2474,47 +2488,50 @@ impl<'db> TypeInferenceBuilder<'db> {
// > Formally, if `a, b, c, …, y, z` are expressions and `op1, op2, …, opN` are comparison
// > operators, then `a op1 b op2 c ... y opN z` is equivalent to `a op1 b and b op2 c and
// > ... y opN z`, except that each expression is evaluated at most once.
let mut inferred_ty: Option<Type<'db>> = None;
std::iter::once(left.as_ref())
.chain(comparators.as_ref().iter())
// Evaluate expressions before iterating through pairs with `windows`
.map(|expr| self.infer_expression(expr))
.collect::<Vec<_>>()
.windows(2)
//.tuple_windows(2)
.zip(ops.iter())
.enumerate()
.for_each(|(i, (pair, op))| {
let (left_ty, right_ty) = (pair[0], pair[1]);
if inferred_ty.is_none() {
let is_last = i == ops.len() - 1;
inferred_ty = if let Some(comparison_ty) =
self.infer_binary_type_comparison(left_ty, *op, right_ty)
{
match (comparison_ty.bool(self.db), is_last) {
(Truthiness::AlwaysTrue, false) => None,
(Truthiness::AlwaysTrue, true) => Some(Type::BooleanLiteral(true)),
(Truthiness::AlwaysFalse, _) => Some(Type::BooleanLiteral(false)),
(Truthiness::Ambiguous, _) => {
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
//
// As some operators (==, !=, <, <=, >, >=) *can* return an arbitrary type, the logic below
// is shared with the one in `infer_binary_type_comparison`.
Self::infer_chained_boolean_types(
self.db,
ast::BoolOp::And,
std::iter::once(left.as_ref())
.chain(comparators.as_ref().iter())
// Evaluate expressions before iterating through pairs with `windows`
.map(|expr| self.infer_expression(expr))
.collect::<Vec<_>>()
.windows(2)
//.tuple_windows(2)
.zip(ops.iter())
.map(|(pair, op)| {
let (left_ty, right_ty) = (pair[0], pair[1]);
self.infer_binary_type_comparison(left_ty, *op, right_ty)
.unwrap_or_else(|| {
// Handle unsupported operators (diagnostic, `bool`/`Unknown` outcome)
self.add_diagnostic(
AnyNodeRef::ExprCompare(compare),
"operator-unsupported",
format_args!(
"Operator \"{}\" is not supported for types {} and {}",
op,
left_ty.display(self.db),
right_ty.display(self.db)
),
);
match op {
// `in, not in, is, is not` always return bool instances
ast::CmpOp::In
| ast::CmpOp::NotIn
| ast::CmpOp::Is
| ast::CmpOp::IsNot => {
builtins_symbol_ty(self.db, "bool").to_instance(self.db)
}
// Other operators can return arbitrary types
_ => Type::Unknown,
}
}
} else {
self.add_diagnostic(
AnyNodeRef::ExprCompare(compare),
"operator-unsupported",
format_args!(
"Operator \"{}\" is not supported for types {} and {}",
op,
left_ty.display(self.db),
right_ty.display(self.db)
),
);
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
};
}
});
inferred_ty.expect("A type should always be inferred on the last comparison in the chain")
})
}),
ops.len(),
)
}

/// Infers the type of a binary comparison (e.g. '<left> == <right>'). See
Expand Down Expand Up @@ -2584,6 +2601,14 @@ impl<'db> TypeInferenceBuilder<'db> {
op,
Type::IntLiteral(i64::from(b)),
),
// Lookup the rich comparison `__dunder__` methods on instances
(Type::Instance(left_class_ty), Type::Instance(right_class_ty)) => match op {
ast::CmpOp::Lt => {
perform_rich_comparison(self.db, left_class_ty, right_class_ty, "__lt__")
}
// TODO: implement mapping from `ast::CmpOp` to rich comparison methods
_ => Some(Type::Todo),
},
// TODO: handle more types
_ => Some(Type::Todo),
}
Expand Down Expand Up @@ -3106,6 +3131,36 @@ impl StringPartsCollector {
}
}

/// Resolves the result type of a rich comparison between two class instances.
///
/// Python's rich comparison operators are `==`, `!=`, `<`, `<=`, `>` and `>=`; a class can
/// customise each of them by defining the matching dunder method (`__eq__`, `__ne__`, `__lt__`,
/// `__le__`, `__gt__`, `__ge__`).
/// See `<https://docs.python.org/3/reference/datamodel.html#object.__lt__>`.
///
/// `dunder_name` is looked up as a class member of `left`. When it is bound, it is called with
/// instances of `left` and `right` as arguments and the return type of that call is the result.
/// When it is unbound, `None` is returned.
fn perform_rich_comparison<'db>(
    db: &'db dyn Db,
    left: ClassType<'db>,
    right: ClassType<'db>,
    dunder_name: &str,
) -> Option<Type<'db>> {
    // A walkthrough of the full rich comparison algorithm is available at:
    // https://snarky.ca/unravelling-rich-comparison-operators/
    //
    // TODO: the reflected dunder actually has priority if the r.h.s. is a strict subclass of the
    // l.h.s.
    // TODO: `object.__ne__` falls back to `__eq__` when `__ne__` is not defined

    let method = left.class_member(db, dunder_name);
    if method.is_unbound() {
        // TODO: reflected dunder -- (==, ==), (!=, !=), (<, >), (>, <), (<=, >=), (>=, <=)
        return None;
    }

    // TODO: this currently gives the return type even if the arg types are invalid
    // (e.g. int.__lt__ with string instance should be None, currently bool)
    let operands = [Type::Instance(left), Type::Instance(right)];
    let outcome = method.call(db, &operands);
    outcome.return_ty(db)
}

#[cfg(test)]
mod tests {

Expand Down Expand Up @@ -4005,7 +4060,7 @@ mod tests {
g = 1 is not 1
h = 1 is 2
i = 1 is not 7
j = 1 <= ""
j = 1 <= "" and 0 < 1
"#,
)?;

Expand All @@ -4018,7 +4073,7 @@ mod tests {
assert_public_ty(&db, "src/a.py", "g", "bool");
assert_public_ty(&db, "src/a.py", "h", "Literal[False]");
assert_public_ty(&db, "src/a.py", "i", "Literal[True]");
assert_public_ty(&db, "src/a.py", "j", "bool");
assert_public_ty(&db, "src/a.py", "j", "@Todo | Literal[True]");

Ok(())
}
Expand All @@ -4030,14 +4085,15 @@ mod tests {
db.write_dedented(
"src/a.py",
r#"
def foo() -> int: ...
a = 1 == foo()
b = 9 < foo()
c = foo() != foo()
def int_instance() -> int: ...
a = 1 == int_instance()
b = 9 < int_instance()
c = int_instance() < int_instance()
"#,
)?;

assert_public_ty(&db, "src/a.py", "a", "bool");
// TODO: implement lookup of `__eq__` on typeshed `int` stub
assert_public_ty(&db, "src/a.py", "a", "@Todo");
assert_public_ty(&db, "src/a.py", "b", "bool");
assert_public_ty(&db, "src/a.py", "c", "bool");

Expand All @@ -4052,19 +4108,66 @@ mod tests {
r#"
a = 1 in 7
b = 0 not in 10
c = object() < 5
d = 5 < object()
"#,
)?;

assert_public_ty(&db, "src/a.py", "a", "bool");
assert_public_ty(&db, "src/a.py", "b", "bool");
assert_file_diagnostics(
&db,
"src/a.py",
&[
"Operator \"in\" is not supported for types Literal[1] and Literal[7]",
"Operator \"not in\" is not supported for types Literal[0] and Literal[10]",
"Operator \"<\" is not supported for types object and Literal[5]",
],
);
assert_public_ty(&db, "src/a.py", "a", "bool");
assert_public_ty(&db, "src/a.py", "b", "bool");
assert_public_ty(&db, "src/a.py", "c", "Unknown");
// TODO: this should be `Unknown` but we don't check if __lt__ signature is valid for right
// operand type
assert_public_ty(&db, "src/a.py", "d", "bool");

Ok(())
}

/// Chained `<` comparisons whose `__lt__` methods return instances of their own class (rather
/// than `bool`) should infer a union of the intermediate comparison result types, and should
/// still short-circuit when an early comparison is statically false.
#[test]
fn comparison_non_bool_returns() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
r#"
from __future__ import annotations
class A:
def __lt__(self, other) -> A: ...
class B:
def __lt__(self, other) -> B: ...
class C:
def __lt__(self, other) -> C: ...
a = A() < B() < C()
b = 0 < 1 < A() < 3
c = 10 < 0 < A() < B() < C()
"#,
)?;

// Walking through the example
// 1. A() < B() < C()
// 2. A() < B() and B() < C() - split in N comparison
// 3. A() and B() - evaluate outcome types
// 4. bool and bool - evaluate truthiness
// 5. A | B - union of "first true" types
assert_public_ty(&db, "src/a.py", "a", "A | B");
// Walking through the example
// 1. 0 < 1 < A() < 3
// 2. 0 < 1 and 1 < A() and A() < 3 - split in N comparison
// 3. True and bool and A - evaluate outcome types
// 4. True and bool and bool - evaluate truthiness
// 5. bool | A - union of "true" types
assert_public_ty(&db, "src/a.py", "b", "bool | A");
// Short-circuit to False: the first comparison in the chain is `10 < 0`
assert_public_ty(&db, "src/a.py", "c", "Literal[False]");

Ok(())
}
Expand Down

0 comments on commit b74c007

Please sign in to comment.