Skip to content

Commit

Permalink
Add context for symbol_ty lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
pilleye committed Oct 2, 2024
1 parent 8a3a0d8 commit b19c32f
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 25 deletions.
61 changes: 45 additions & 16 deletions crates/red_knot_python_semantic/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,44 +43,55 @@ pub fn check_types(db: &dyn Db, file: File) -> TypeCheckDiagnostics {
diagnostics
}

#[derive(Clone, Copy, Eq, PartialEq, Debug)]
enum SymbolTableLookupContext {
PublicImport,
GlobalWithinModule,
}

/// Infer the public type of a symbol (its type as seen from outside its scope).
fn symbol_ty_by_id<'db>(db: &'db dyn Db, scope: ScopeId<'db>, symbol: ScopedSymbolId) -> Type<'db> {
fn symbol_ty_by_id<'db>(
db: &'db dyn Db,
scope: ScopeId<'db>,
symbol: ScopedSymbolId,
context: SymbolTableLookupContext,
) -> Type<'db> {
let _span = tracing::trace_span!("symbol_ty_by_id", ?symbol).entered();

let use_def = use_def_map(db, scope);

// If the symbol is declared, the public type is based on declarations; otherwise, it's based
// on inference from bindings.
if use_def.has_public_declarations(symbol) {
let unbound_ty = use_def
.public_may_be_unbound(symbol)
.then_some(match context {
SymbolTableLookupContext::PublicImport => Type::Never,
SymbolTableLookupContext::GlobalWithinModule => Type::Unbound,
});

// If we want to try to use declarations and the symbol is declared, the public type is based on
// declarations; otherwise, it's based on inference from bindings.
if context == SymbolTableLookupContext::PublicImport && use_def.has_public_declarations(symbol)
{
let declarations = use_def.public_declarations(symbol);
// If the symbol is undeclared in some paths, include the inferred type in the public type.
let undeclared_ty = if declarations.may_be_undeclared() {
Some(bindings_ty(
db,
use_def.public_bindings(symbol),
use_def.public_may_be_unbound(symbol).then_some(Type::Never),
))
Some(bindings_ty(db, use_def.public_bindings(symbol), unbound_ty))
} else {
None
};
// Intentionally ignore conflicting declared types; that's not our problem, it's the
// problem of the module we are importing from.
declarations_ty(db, declarations, undeclared_ty).unwrap_or_else(|(ty, _)| ty)
} else {
bindings_ty(
db,
use_def.public_bindings(symbol),
use_def.public_may_be_unbound(symbol).then_some(Type::Never),
)
bindings_ty(db, use_def.public_bindings(symbol), unbound_ty)
}
}

/// Shorthand for `symbol_ty` that takes a symbol name instead of an ID.
/// Shorthand for `symbol_ty_by_id` that takes a symbol name instead of an ID.
fn symbol_ty<'db>(db: &'db dyn Db, scope: ScopeId<'db>, name: &str) -> Type<'db> {
let table = symbol_table(db, scope);
table
.symbol_id_by_name(name)
.map(|symbol| symbol_ty_by_id(db, scope, symbol))
.map(|symbol| symbol_ty_by_id(db, scope, symbol, SymbolTableLookupContext::PublicImport))
.unwrap_or(Type::Unbound)
}

Expand All @@ -89,6 +100,24 @@ pub(crate) fn global_symbol_ty<'db>(db: &'db dyn Db, file: File, name: &str) ->
symbol_ty(db, global_scope(db, file), name)
}

/// Shorthand for `symbol_ty_by_id` that looks up a global symbol from the context of being in that
/// module.
pub(crate) fn global_symbol_lookup<'db>(db: &'db dyn Db, file: File, name: &str) -> Type<'db> {
let table = symbol_table(db, global_scope(db, file));

table
.symbol_id_by_name(name)
.map(|symbol| {
symbol_ty_by_id(
db,
global_scope(db, file),
symbol,
SymbolTableLookupContext::GlobalWithinModule,
)
})
.unwrap_or_else(|| Type::Unbound)
}

/// Infer the type of a binding.
pub(crate) fn binding_ty<'db>(db: &'db dyn Db, definition: Definition<'db>) -> Type<'db> {
let inference = infer_definition_types(db, definition);
Expand Down
24 changes: 15 additions & 9 deletions crates/red_knot_python_semantic/src/types/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ use crate::semantic_index::SemanticIndex;
use crate::stdlib::builtins_module_scope;
use crate::types::diagnostic::{TypeCheckDiagnostic, TypeCheckDiagnostics};
use crate::types::{
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_ty, symbol_ty,
bindings_ty, builtins_symbol_ty, declarations_ty, global_symbol_lookup, symbol_ty,
typing_extensions_symbol_ty, BytesLiteralType, ClassType, FunctionKind, FunctionType,
StringLiteralType, Truthiness, TupleType, Type, TypeArrayDisplay, UnionType,
};
Expand Down Expand Up @@ -2210,7 +2210,7 @@ impl<'db> TypeInferenceBuilder<'db> {
let ty = if file_scope_id.is_global() {
Type::Unbound
} else {
global_symbol_ty(self.db, self.file, name)
global_symbol_lookup(self.db, self.file, name)
};
// Fallback to builtins (without infinite recursion if we're already in builtins.)
if ty.may_be_unbound(self.db) && Some(self.scope) != builtins_module_scope(self.db) {
Expand Down Expand Up @@ -3016,6 +3016,7 @@ mod tests {
use ruff_db::system::{DbWithTestSystem, SystemPathBuf};
use ruff_db::testing::assert_function_query_was_not_run;
use ruff_python_ast::name::Name;
use test_case::test_case;

use super::TypeInferenceBuilder;

Expand Down Expand Up @@ -4759,18 +4760,23 @@ mod tests {
Ok(())
}

#[test]
fn conditionally_global_or_builtin() -> anyhow::Result<()> {
#[test_case("")]
// Tests that we only use the definition of a symbol instead of its declaration when we are
// checking module globals without a nonlocal binding.
#[test_case(": int"; "with a declaration")]
fn conditionally_global_or_builtin(annotation: &'static str) -> anyhow::Result<()> {
let mut db = setup_db();

db.write_dedented(
"/src/a.py",
"
if flag:
copyright = 1
def f():
y = copyright
&format!(
"
if flag:
copyright {annotation} = 1
def f():
y = copyright
",
),
)?;

let file = system_path_to_file(&db, "src/a.py").expect("Expected file to exist.");
Expand Down

0 comments on commit b19c32f

Please sign in to comment.