Skip to content

Commit

Permalink
fixup! [red-knot] feat: implement integer comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
Slyces committed Oct 2, 2024
1 parent 9d98a39 commit 3fcda99
Showing 1 changed file with 88 additions and 44 deletions.
132 changes: 88 additions & 44 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2479,21 +2479,18 @@ impl<'db> TypeInferenceBuilder<'db> {
.chain(comparators.as_ref().iter())
// Evaluate expressions before iterating through pairs with `windows`
.map(|expr| self.infer_expression(expr))
// Allocation here is used to access the method `windows` not implemented on
// iterators (crates, experimental APIs or manual impl might provide it)
.collect::<Vec<Type<'db>>>()
.collect::<Vec<_>>()
.windows(2)
//.tuple_windows(2)
.zip(ops.iter())
.enumerate()
.for_each(|(i, (pair, op))| {
let (left_ty, right_ty) = (pair[0], pair[1]);
if inferred_ty.is_none() {
let is_last = i == ops.len() - 1;
let comparison_ty = self.infer_binary_type_comparison(left_ty, *op, right_ty);
// Special case for `Unknown` returned when an operation is not supported
inferred_ty = if let Type::Unknown = comparison_ty {
Some(Type::Unknown)
} else {
inferred_ty = if let Some(comparison_ty) =
self.infer_binary_type_comparison(left_ty, *op, right_ty)
{
match (comparison_ty.bool(self.db), is_last) {
(Truthiness::AlwaysTrue, false) => None,
(Truthiness::AlwaysTrue, true) => Some(Type::BooleanLiteral(true)),
Expand All @@ -2502,6 +2499,18 @@ impl<'db> TypeInferenceBuilder<'db> {
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
}
}
} else {
self.add_diagnostic(
AnyNodeRef::ExprCompare(compare),
"operator-unsupported",
format_args!(
"Operator \"{}\" is not supported for types {} and {}",
op,
left_ty.display(self.db),
right_ty.display(self.db)
),
);
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
};
}
});
Expand All @@ -2511,23 +2520,53 @@ impl<'db> TypeInferenceBuilder<'db> {
/// Infers the type of a binary comparison (e.g. '<left> == <right>'). See
/// `infer_compare_expression` for the higher level logic dealing with multi-comparison
/// expressions.
///
/// If the operation is not supported, return None (we need upstream context to emit a
/// diagnostic).
fn infer_binary_type_comparison(
&mut self,
left: Type<'db>,
op: ast::CmpOp,
right: Type<'db>,
) -> Type<'db> {
) -> Option<Type<'db>> {
// Note: identity (is, is not) is unreliable in Python and not part of the language specs.
// - `[ast::CompOp::Is]`: return `false` if different, `bool` if the same
// - `[ast::CompOp::IsNot]`: return `true` if different, `bool` if the same
match (left, right) {
(Type::IntLiteral(n), Type::IntLiteral(m)) => Type::BooleanLiteral(match op {
ast::CmpOp::Eq | ast::CmpOp::Is => n == m,
ast::CmpOp::NotEq | ast::CmpOp::IsNot => n != m,
ast::CmpOp::Lt => n < m,
ast::CmpOp::LtE => n <= m,
ast::CmpOp::Gt => n > m,
ast::CmpOp::GtE => n >= m,
(Type::IntLiteral(n), Type::IntLiteral(m)) => match op {
ast::CmpOp::Eq => Some(Type::BooleanLiteral(n == m)),
ast::CmpOp::NotEq => Some(Type::BooleanLiteral(n != m)),
ast::CmpOp::Lt => Some(Type::BooleanLiteral(n < m)),
ast::CmpOp::LtE => Some(Type::BooleanLiteral(n <= m)),
ast::CmpOp::Gt => Some(Type::BooleanLiteral(n > m)),
ast::CmpOp::GtE => Some(Type::BooleanLiteral(n >= m)),
ast::CmpOp::Is => {
if n == m {
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
} else {
Some(Type::BooleanLiteral(false))
}
}
ast::CmpOp::IsNot => {
if n == m {
Some(builtins_symbol_ty(self.db, "bool").to_instance(self.db))
} else {
Some(Type::BooleanLiteral(true))
}
}
// Undefined for (int, int)
ast::CmpOp::In | ast::CmpOp::NotIn => false,
}),
ast::CmpOp::In | ast::CmpOp::NotIn => None,
},
(Type::IntLiteral(_), Type::Instance(_)) => self.infer_binary_type_comparison(
builtins_symbol_ty(self.db, "int").to_instance(self.db),
op,
right,
),
(Type::Instance(_), Type::IntLiteral(_)) => self.infer_binary_type_comparison(
left,
op,
builtins_symbol_ty(self.db, "int").to_instance(self.db),
),
// Booleans are coded as integers (False = 0, True = 1)
(Type::IntLiteral(n), Type::BooleanLiteral(b)) => self.infer_binary_type_comparison(
Type::IntLiteral(n),
Expand All @@ -2539,25 +2578,6 @@ impl<'db> TypeInferenceBuilder<'db> {
op,
Type::IntLiteral(m),
),
// TODO: move this logic wherever we handle class instances
(Type::IntLiteral(_), Type::Instance(class_type))
| (Type::Instance(class_type), Type::IntLiteral(_)) => class_type
.is_stdlib_symbol(self.db, "builtins", "int")
.then(|| builtins_symbol_ty(self.db, "bool").to_instance(self.db))
.unwrap_or(Type::Unknown),
// TODO: type-generic cases (using _) should be moved after the code handling
// classes/instances as they can override the behaviour with literal types by
// implementing `__eq__`, ...
(Type::IntLiteral(_), _) | (_, Type::IntLiteral(_)) => match op {
ast::CmpOp::Eq | ast::CmpOp::Is => Type::BooleanLiteral(false),
ast::CmpOp::NotEq | ast::CmpOp::IsNot => Type::BooleanLiteral(true),
ast::CmpOp::Lt | ast::CmpOp::LtE | ast::CmpOp::Gt | ast::CmpOp::GtE => {
// TODO: this fails at runtime, we might need a diagnostic
Type::Unknown
}
// TODO: this is valid for some types (tuples, list, dict, set, ...)
ast::CmpOp::In | ast::CmpOp::NotIn => Type::Unknown,
},
(Type::BooleanLiteral(a), Type::BooleanLiteral(b)) => self
.infer_binary_type_comparison(
Type::IntLiteral(i64::from(a)),
Expand Down Expand Up @@ -3984,7 +4004,7 @@ mod tests {
f = 1 is 1
g = 1 is not 1
h = 1 is 2
i = 1 is ''
i = 1 is not 7
j = 1 <= ""
"#,
)?;
Expand All @@ -3994,11 +4014,11 @@ mod tests {
assert_public_ty(&db, "src/a.py", "c", "Literal[True]");
assert_public_ty(&db, "src/a.py", "d", "Literal[False]");
assert_public_ty(&db, "src/a.py", "e", "Literal[False]");
assert_public_ty(&db, "src/a.py", "f", "Literal[True]");
assert_public_ty(&db, "src/a.py", "g", "Literal[False]");
assert_public_ty(&db, "src/a.py", "f", "bool");
assert_public_ty(&db, "src/a.py", "g", "bool");
assert_public_ty(&db, "src/a.py", "h", "Literal[False]");
assert_public_ty(&db, "src/a.py", "i", "Literal[False]");
assert_public_ty(&db, "src/a.py", "j", "Unknown");
assert_public_ty(&db, "src/a.py", "i", "Literal[True]");
assert_public_ty(&db, "src/a.py", "j", "bool");

Ok(())
}
Expand All @@ -4019,8 +4039,32 @@ mod tests {

assert_public_ty(&db, "src/a.py", "a", "bool");
assert_public_ty(&db, "src/a.py", "b", "bool");
// TODO: handling int/int comparison will be done when implementing instance comparison
assert_public_ty(&db, "src/a.py", "c", "Unknown");
assert_public_ty(&db, "src/a.py", "c", "bool");

Ok(())
}

#[test]
fn comparison_unsupported_operators() -> anyhow::Result<()> {
let mut db = setup_db();
db.write_dedented(
"src/a.py",
r#"
a = 1 in 7
b = 0 not in 10
"#,
)?;

assert_public_ty(&db, "src/a.py", "a", "bool");
assert_public_ty(&db, "src/a.py", "b", "bool");
assert_file_diagnostics(
&db,
"src/a.py",
&[
"Operator \"in\" is not supported for types Literal[1] and Literal[7]",
"Operator \"not in\" is not supported for types Literal[0] and Literal[10]",
],
);

Ok(())
}
Expand Down

0 comments on commit 3fcda99

Please sign in to comment.