Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[red-knot] Implement type narrowing for boolean conditionals #14037

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
# Narrowing for conditionals with boolean expressions

## Narrowing in `and` conditional

```py
class A: ...
class B: ...

def instance() -> A | B:
return A()

x = instance()

if isinstance(x, A) and isinstance(x, B):
reveal_type(x) # revealed: A & B
else:
reveal_type(x) # revealed: B & ~A | A & ~B
```

## Arms might not add narrowing constraints

```py
class A: ...
class B: ...

def bool_instance() -> bool:
return True

def instance() -> A | B:
return A()

x = instance()

if isinstance(x, A) and bool_instance():
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: A | B

if bool_instance() and isinstance(x, A):
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: A | B

reveal_type(x) # revealed: A | B
```

## Statically known arms

```py
class A: ...
class B: ...

def instance() -> A | B:
return A()

x = instance()

if isinstance(x, A) and True:
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: B & ~A

if True and isinstance(x, A):
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: B & ~A

if False and isinstance(x, A):
# TODO: should emit an `unreachable code` diagnostic
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: A | B

if False or isinstance(x, A):
reveal_type(x) # revealed: A
else:
reveal_type(x) # revealed: B & ~A

if True or isinstance(x, A):
reveal_type(x) # revealed: A | B
else:
# TODO: should emit an `unreachable code` diagnostic
reveal_type(x) # revealed: B & ~A

reveal_type(x) # revealed: A | B
```

## The type of multiple symbols can be narrowed down

```py
class A: ...
class B: ...

def instance() -> A | B:
return A()

x = instance()
y = instance()

if isinstance(x, A) and isinstance(y, B):
reveal_type(x) # revealed: A
reveal_type(y) # revealed: B
else:
# No narrowing: Only-one or both checks might have failed
reveal_type(x) # revealed: A | B
reveal_type(y) # revealed: A | B

reveal_type(x) # revealed: A | B
reveal_type(y) # revealed: A | B
```

## Narrowing in `or` conditional

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()

if isinstance(x, A) or isinstance(x, B):
reveal_type(x) # revealed: A | B
else:
reveal_type(x) # revealed: C & ~A & ~B
```

## In `or`, all arms should add constraint in order to narrow

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

def bool_instance() -> bool:
return True

x = instance()

if isinstance(x, A) or isinstance(x, B) or bool_instance():
reveal_type(x) # revealed: A | B | C
else:
reveal_type(x) # revealed: C & ~A & ~B
```

## in `or`, all arms should narrow the same set of symbols

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()
y = instance()

if isinstance(x, A) or isinstance(y, A):
# The predicate might be satisfied by the right side, so the type of `x` can’t be narrowed down here.
reveal_type(x) # revealed: A | B | C
# The same for `y`
reveal_type(y) # revealed: A | B | C
else:
reveal_type(x) # revealed: B & ~A | C & ~A
reveal_type(y) # revealed: B & ~A | C & ~A

if (isinstance(x, A) and isinstance(y, A)) or (isinstance(x, B) and isinstance(y, B)):
# Here, types of `x` and `y` can be narrowd since all `or` arms constraint them.
reveal_type(x) # revealed: A | B
reveal_type(y) # revealed: A | B
else:
reveal_type(x) # revealed: A | B | C
reveal_type(y) # revealed: A | B | C
```

## mixing `and` and `not`

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()

if isinstance(x, B) and not isinstance(x, C):
reveal_type(x) # revealed: B & ~C
else:
# ~(B & ~C) -> ~B | C -> (A & ~B) | (C & ~B) | C -> (A & ~B) | C
reveal_type(x) # revealed: A & ~B | C
```

## mixing `or` and `not`

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()

if isinstance(x, B) or not isinstance(x, C):
reveal_type(x) # revealed: B | A & ~C
else:
reveal_type(x) # revealed: C & ~B
```

## `or` with nested `and`

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()

if isinstance(x, A) or (isinstance(x, B) and not isinstance(x, C)):
reveal_type(x) # revealed: A | B & ~C
else:
# ~(A | (B & ~C)) -> ~A & ~(B & ~C) -> ~A & (~B | C) -> (~A & C) | (~A ~ B)
reveal_type(x) # revealed: C & ~A
```

## `and` with nested `or`

```py
class A: ...
class B: ...
class C: ...

def instance() -> A | B | C:
return A()

x = instance()

if isinstance(x, A) and (isinstance(x, B) or not isinstance(x, C)):
# A & (B | ~C) -> (A & B) | (A & ~C)
reveal_type(x) # revealed: A & B | A & ~C
else:
# ~((A & B) | (A & ~C)) ->
# ~(A & B) & ~(A & ~C) ->
# (~A | ~B) & (~A | C) ->
# [(~A | ~B) & ~A] | [(~A | ~B) & C] ->
# ~A | (~A & C) | (~B & C) ->
# ~A | (C & ~B) ->
# ~A | (C & ~B) The positive side of ~A is A | B | C ->
reveal_type(x) # revealed: B & ~A | C & ~A | C & ~B
```

## Boolean expression internal narrowing

```py
def optional_string() -> str | None:
return None

x = optional_string()
y = optional_string()

if x is None and y is not x:
reveal_type(y) # revealed: str

# Neither of the conditions alone is sufficient for narrowing y's type:
if x is None:
reveal_type(y) # revealed: str | None

if y is not x:
reveal_type(y) # revealed: str | None
```
78 changes: 78 additions & 0 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,46 @@ impl<'db> Type<'db> {
.elements(db)
.iter()
.any(|&elem_ty| ty.is_subtype_of(db, elem_ty)),
(Type::Intersection(self_intersection), Type::Intersection(target_intersection)) => {
// Check that all target positive values are covered in self positive values
target_intersection
.positive(db)
.iter()
.all(|&target_pos_elem| {
self_intersection
.positive(db)
.iter()
.any(|&self_pos_elem| self_pos_elem.is_subtype_of(db, target_pos_elem))
})
// Check that all target negative values are excluded in self, either by being
// subtypes of a self negative value or being disjoint from a self positive value.
&& target_intersection
.negative(db)
.iter()
.all(|&target_neg_elem| {
// Is target negative value is subtype of a self negative value
self_intersection.negative(db).iter().any(|&self_neg_elem| {
target_neg_elem.is_subtype_of(db, self_neg_elem)
// Is target negative value is disjoint from a self positive value?
}) || self_intersection.positive(db).iter().any(|&self_pos_elem| {
target_neg_elem.is_disjoint_from(db, self_pos_elem)
})
})
}
(Type::Intersection(intersection), ty) => intersection
.positive(db)
.iter()
.any(|&elem_ty| elem_ty.is_subtype_of(db, ty)),
(ty, Type::Intersection(intersection)) => {
intersection
.positive(db)
.iter()
.all(|&pos_ty| ty.is_subtype_of(db, pos_ty))
&& intersection
.negative(db)
.iter()
.all(|&neg_ty| neg_ty.is_disjoint_from(db, ty))
}
(Type::Instance(self_class), Type::Instance(target_class)) => {
self_class.is_subclass_of(db, target_class)
}
Expand Down Expand Up @@ -2190,6 +2230,11 @@ mod tests {
Ty::BuiltinInstance("FloatingPointError"),
Ty::BuiltinInstance("Exception")
)]
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::BuiltinInstance("int"))]
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
#[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]})]
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("str")], neg: vec![Ty::StringLiteral("foo")]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]})]
fn is_subtype_of(from: Ty, to: Ty) {
let db = setup_db();
assert!(from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
Expand All @@ -2210,6 +2255,11 @@ mod tests {
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(42)]), Ty::Tuple(vec![Ty::BuiltinInstance("str")]))]
#[test_case(Ty::Tuple(vec![Ty::Todo]), Ty::Tuple(vec![Ty::IntLiteral(2)]))]
#[test_case(Ty::Tuple(vec![Ty::IntLiteral(2)]), Ty::Tuple(vec![Ty::Todo]))]
#[test_case(Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(3)]})]
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})]
#[test_case(Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(2)]}, Ty::Intersection{pos: vec![], neg: vec![Ty::BuiltinInstance("int")]})]
#[test_case(Ty::BuiltinInstance("int"), Ty::Intersection{pos: vec![], neg: vec![Ty::IntLiteral(3)]})]
#[test_case(Ty::IntLiteral(1), Ty::Intersection{pos: vec![Ty::BuiltinInstance("int")], neg: vec![Ty::IntLiteral(1)]})]
fn is_not_subtype_of(from: Ty, to: Ty) {
let db = setup_db();
assert!(!from.into_type(&db).is_subtype_of(&db, to.into_type(&db)));
Expand Down Expand Up @@ -2241,6 +2291,34 @@ mod tests {
assert!(type_u.is_subtype_of(&db, Ty::BuiltinInstance("object").into_type(&db)));
}

#[test]
fn is_subtype_of_intersection_of_class_instances() {
let mut db = setup_db();
db.write_dedented(
"/src/module.py",
"
class A: ...
a = A()
class B: ...
b = B()
",
)
.unwrap();
let module = ruff_db::files::system_path_to_file(&db, "/src/module.py").unwrap();

let a_ty = super::global_symbol(&db, module, "a").expect_type();
let b_ty = super::global_symbol(&db, module, "b").expect_type();
let intersection = IntersectionBuilder::new(&db)
.add_positive(a_ty)
.add_positive(b_ty)
.build();

assert_eq!(intersection.display(&db).to_string(), "A & B");
assert!(!a_ty.is_subtype_of(&db, b_ty));
assert!(intersection.is_subtype_of(&db, b_ty));
assert!(intersection.is_subtype_of(&db, a_ty));
}

#[test_case(
Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)]),
Ty::Union(vec![Ty::IntLiteral(1), Ty::IntLiteral(2)])
Expand Down
Loading
Loading