diff --git a/rope/refactor/importutils/importinfo.py b/rope/refactor/importutils/importinfo.py index 640200c4..451a387e 100644 --- a/rope/refactor/importutils/importinfo.py +++ b/rope/refactor/importutils/importinfo.py @@ -1,4 +1,5 @@ -from typing import List, Tuple +from abc import abstractmethod, ABC +from typing import List, Tuple, Optional, Protocol class ImportStatement: @@ -64,9 +65,11 @@ def accept(self, visitor): return visitor.dispatch(self) -class ImportInfo: - def get_imported_primaries(self, context): - pass +class ImportInfo(ABC): + names_and_aliases: List[Tuple[str, Optional[str]]] + + @abstractmethod + def get_imported_primaries(self, context) -> List[str]: ... def get_imported_names(self, context): return [ @@ -76,8 +79,8 @@ def get_imported_names(self, context): def __repr__(self): return f'<{self.__class__.__name__} "{self.get_import_statement()}">' - def get_import_statement(self): - pass + @abstractmethod + def get_import_statement(self) -> str: ... def is_empty(self): pass @@ -108,10 +111,13 @@ def get_empty_import(): class NormalImport(ImportInfo): - def __init__(self, names_and_aliases): + def __init__( + self, + names_and_aliases: List[Tuple[str, Optional[str]]], + ) -> None: self.names_and_aliases = names_and_aliases - def get_imported_primaries(self, context): + def get_imported_primaries(self, context) -> List[str]: result = [] for name, alias in self.names_and_aliases: if alias: @@ -120,7 +126,7 @@ def get_imported_primaries(self, context): result.append(name) return result - def get_import_statement(self): + def get_import_statement(self) -> str: result = "import " for name, alias in self.names_and_aliases: result += name @@ -134,12 +140,20 @@ def is_empty(self): class FromImport(ImportInfo): - def __init__(self, module_name, level, names_and_aliases): + module_name: str + level: int + + def __init__( + self, + module_name: str, + level: int, + names_and_aliases: List[Tuple[str, Optional[str]]], + ): self.module_name = module_name self.level = level self.names_and_aliases = names_and_aliases - def get_imported_primaries(self, context): + def get_imported_primaries(self, context) -> List[str]: if self.names_and_aliases[0][0] == "*": module = self.get_imported_module(context) return [name for name in module if not name.startswith("_")] @@ -176,7 +190,7 @@ def get_imported_module(self, context): self.module_name, context.folder, self.level ) - def get_import_statement(self): + def get_import_statement(self) -> str: result = "from " + "." * self.level + self.module_name + " import " for name, alias in self.names_and_aliases: result += name @@ -193,14 +207,17 @@ def is_star_import(self): class EmptyImport(ImportInfo): - names_and_aliases: List[Tuple[str, str]] = [] + names_and_aliases = [] def is_empty(self): return True - def get_imported_primaries(self, context): + def get_imported_primaries(self, context) -> List[str]: return [] + def get_import_statement(self) -> str: + raise NotImplementedError() + class ImportContext: def __init__(self, project, folder):