Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revisit *Calls API #2555

Closed
wants to merge 15 commits into from
41 changes: 41 additions & 0 deletions .github/workflows/black_auto.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
---
name: Run black (auto)

defaults:
run:
# To load bashrc
shell: bash -ieo pipefail {0}

on:
pull_request:
branches: [master, dev]
paths:
- "**/*.py"

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
build:
name: Black
runs-on: ubuntu-latest

steps:
- name: Checkout Code
uses: actions/checkout@v4

- name: Set up Python 3.8
uses: actions/setup-python@v5
with:
python-version: 3.8

- name: Install Python dependencies
run: pip install .[dev]

- name: Run linters
uses: wearerequired/lint-action@v2
with:
auto_fix: true
black: true
black_auto_fix: true
2 changes: 1 addition & 1 deletion FUNDING.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
},
"drips": {
"ethereum": {
"ownedBy": "0x5e2BA02F62bD4efa939e3B80955bBC21d015DbA0"
"ownedBy": "0xc44F30Be3eBBEfdDBB5a85168710b4f0e18f4Ff0"
}
}
}
69 changes: 29 additions & 40 deletions slither/core/cfg/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,6 @@
if TYPE_CHECKING:
from slither.slithir.variables.variable import SlithIRVariable
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.utils.type_helpers import (
InternalCallType,
HighLevelCallType,
LibraryCallType,
LowLevelCallType,
)
from slither.core.cfg.scope import Scope
from slither.core.scope.scope import FileScope

Expand Down Expand Up @@ -153,11 +147,11 @@ def __init__(
self._ssa_vars_written: List["SlithIRVariable"] = []
self._ssa_vars_read: List["SlithIRVariable"] = []

self._internal_calls: List[Union["Function", "SolidityFunction"]] = []
self._solidity_calls: List[SolidityFunction] = []
self._high_level_calls: List["HighLevelCallType"] = [] # contains library calls
self._library_calls: List["LibraryCallType"] = []
self._low_level_calls: List["LowLevelCallType"] = []
self._internal_calls: List[InternalCall] = [] # contains solidity calls
self._solidity_calls: List[SolidityCall] = []
self._high_level_calls: List[Tuple[Contract, HighLevelCall]] = [] # contains library calls
self._library_calls: List[LibraryCall] = []
self._low_level_calls: List[LowLevelCall] = []
self._external_calls_as_expressions: List[Expression] = []
self._internal_calls_as_expressions: List[Expression] = []
self._irs: List[Operation] = []
Expand Down Expand Up @@ -226,8 +220,9 @@ def type(self, new_type: NodeType) -> None:
@property
def will_return(self) -> bool:
if not self.sons and self.type != NodeType.THROW:
if SolidityFunction("revert()") not in self.solidity_calls:
if SolidityFunction("revert(string)") not in self.solidity_calls:
solidity_calls = [ir.function for ir in self.solidity_calls]
if SolidityFunction("revert()") not in solidity_calls:
if SolidityFunction("revert(string)") not in solidity_calls:
return True
return False

Expand Down Expand Up @@ -373,44 +368,38 @@ def variables_written_as_expression(self, exprs: List[Expression]) -> None:
###################################################################################

@property
def internal_calls(self) -> List["InternalCallType"]:
def internal_calls(self) -> List[InternalCall]:
"""
list(Function or SolidityFunction): List of internal/soldiity function calls
list(InternalCall): List of IR operations with internal/solidity function calls
"""
return list(self._internal_calls)

@property
def solidity_calls(self) -> List[SolidityFunction]:
def solidity_calls(self) -> List[SolidityCall]:
"""
list(SolidityFunction): List of Soldity calls
list(SolidityCall): List of IR operations with solidity calls
"""
return list(self._solidity_calls)

@property
def high_level_calls(self) -> List["HighLevelCallType"]:
def high_level_calls(self) -> List[HighLevelCall]:
"""
list((Contract, Function|Variable)):
List of high level calls (external calls).
A variable is called in case of call to a public state variable
list(HighLevelCall): List of IR operations with high level calls (external calls).
Include library calls
"""
return list(self._high_level_calls)

@property
def library_calls(self) -> List["LibraryCallType"]:
def library_calls(self) -> List[LibraryCall]:
"""
list((Contract, Function)):
Include library calls
list(LibraryCall): List of IR operations with library calls.
"""
return list(self._library_calls)

@property
def low_level_calls(self) -> List["LowLevelCallType"]:
def low_level_calls(self) -> List[LowLevelCall]:
"""
list((Variable|SolidityVariable, str)): List of low_level call
A low level call is defined by
- the variable called
- the name of the function (call/delegatecall/codecall)
list(LowLevelCall): List of IR operations with low_level call
"""
return list(self._low_level_calls)

Expand Down Expand Up @@ -529,9 +518,9 @@ def contains_require_or_assert(self) -> bool:
bool: True if the node has a require or assert call
"""
return any(
c.name
ir.function.name
in ["require(bool)", "require(bool,string)", "require(bool,error)", "assert(bool)"]
for c in self.internal_calls
for ir in self.internal_calls
)

def contains_if(self, include_loop: bool = True) -> bool:
Expand Down Expand Up @@ -895,11 +884,11 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements
self._vars_written.append(var)

if isinstance(ir, InternalCall):
self._internal_calls.append(ir.function)
self._internal_calls.append(ir)
if isinstance(ir, SolidityCall):
# TODO: consider removing dependancy of solidity_call to internal_call
self._solidity_calls.append(ir.function)
self._internal_calls.append(ir.function)
self._solidity_calls.append(ir)
self._internal_calls.append(ir)
if (
isinstance(ir, SolidityCall)
and ir.function == SolidityFunction("sstore(uint256,uint256)")
Expand All @@ -917,22 +906,22 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements
self._vars_read.append(ir.arguments[0])
if isinstance(ir, LowLevelCall):
assert isinstance(ir.destination, (Variable, SolidityVariable))
self._low_level_calls.append((ir.destination, str(ir.function_name.value)))
self._low_level_calls.append(ir)
elif isinstance(ir, HighLevelCall) and not isinstance(ir, LibraryCall):
# Todo investigate this if condition
# It does seem right to compare against a contract
# This might need a refactoring
if isinstance(ir.destination.type, Contract):
self._high_level_calls.append((ir.destination.type, ir.function))
self._high_level_calls.append((ir.destination.type, ir))
elif ir.destination == SolidityVariable("this"):
func = self.function
# Can't use this in a top level function
assert isinstance(func, FunctionContract)
self._high_level_calls.append((func.contract, ir.function))
self._high_level_calls.append((func.contract, ir))
else:
try:
# Todo this part needs more tests and documentation
self._high_level_calls.append((ir.destination.type.type, ir.function))
self._high_level_calls.append((ir.destination.type.type, ir))
except AttributeError as error:
# pylint: disable=raise-missing-from
raise SlitherException(
Expand All @@ -941,8 +930,8 @@ def _find_read_write_call(self) -> None: # pylint: disable=too-many-statements
elif isinstance(ir, LibraryCall):
assert isinstance(ir.destination, Contract)
assert isinstance(ir.function, Function)
self._high_level_calls.append((ir.destination, ir.function))
self._library_calls.append((ir.destination, ir.function))
self._high_level_calls.append((ir.destination, ir))
self._library_calls.append(ir)

self._vars_read = list(set(self._vars_read))
self._state_vars_read = [v for v in self._vars_read if isinstance(v, StateVariable)]
Expand Down
22 changes: 14 additions & 8 deletions slither/core/declarations/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

# pylint: disable=too-many-lines,too-many-instance-attributes,import-outside-toplevel,too-many-nested-blocks
if TYPE_CHECKING:
from slither.utils.type_helpers import LibraryCallType, HighLevelCallType, InternalCallType
from slither.core.declarations import (
Enum,
EventContract,
Expand All @@ -39,6 +38,7 @@
FunctionContract,
CustomErrorContract,
)
from slither.slithir.operations import HighLevelCall, LibraryCall
from slither.slithir.variables.variable import SlithIRVariable
from slither.core.variables import Variable, StateVariable
from slither.core.compilation_unit import SlitherCompilationUnit
Expand Down Expand Up @@ -106,7 +106,7 @@ def __init__(self, compilation_unit: "SlitherCompilationUnit", scope: "FileScope
self._is_incorrectly_parsed: bool = False

self._available_functions_as_dict: Optional[Dict[str, "Function"]] = None
self._all_functions_called: Optional[List["InternalCallType"]] = None
self._all_functions_called: Optional[List["Function"]] = None

self.compilation_unit: "SlitherCompilationUnit" = compilation_unit
self.file_scope: "FileScope" = scope
Expand Down Expand Up @@ -1023,15 +1023,21 @@ def get_functions_overridden_by(self, function: "Function") -> List["Function"]:
###################################################################################

@property
def all_functions_called(self) -> List["InternalCallType"]:
def all_functions_called(self) -> List["Function"]:
"""
list(Function): List of functions reachable from the contract
Includes super, and private/internal functions not shadowed
"""
from slither.slithir.operations import Operation

if self._all_functions_called is None:
all_functions = [f for f in self.functions + self.modifiers if not f.is_shadowed] # type: ignore
all_callss = [f.all_internal_calls() for f in all_functions] + [list(all_functions)]
all_calls = [item for sublist in all_callss for item in sublist]
all_calls = [
item.function if isinstance(item, Operation) else item
for sublist in all_callss
for item in sublist
]
all_calls = list(set(all_calls))

all_constructors = [c.constructor for c in self.inheritance if c.constructor]
Expand Down Expand Up @@ -1069,18 +1075,18 @@ def all_state_variables_read(self) -> List["StateVariable"]:
return list(set(all_state_variables_read))

@property
def all_library_calls(self) -> List["LibraryCallType"]:
def all_library_calls(self) -> List["LibraryCall"]:
"""
list((Contract, Function): List all of the libraries func called
list(LibraryCall): List all of the libraries func called
"""
all_high_level_callss = [f.all_library_calls() for f in self.functions + self.modifiers] # type: ignore
all_high_level_calls = [item for sublist in all_high_level_callss for item in sublist]
return list(set(all_high_level_calls))

@property
def all_high_level_calls(self) -> List["HighLevelCallType"]:
def all_high_level_calls(self) -> List[Tuple["Contract", "HighLevelCall"]]:
"""
list((Contract, Function|Variable)): List all of the external high level calls
list(Tuple("Contract", "HighLevelCall")): List all of the external high level calls
"""
all_high_level_callss = [f.all_high_level_calls() for f in self.functions + self.modifiers] # type: ignore
all_high_level_calls = [item for sublist in all_high_level_callss for item in sublist]
Expand Down
Loading