diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9aebd89444..2475c00e09 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,7 +16,7 @@ repos: name: black - repo: https://github.com/pycqa/flake8 - rev: 7.1.1 + rev: 7.1.2 hooks: - id: flake8 additional_dependencies: [flake8-breakpoint, flake8-print, flake8-pydantic, flake8-type-checking] diff --git a/pyproject.toml b/pyproject.toml index 11702954e0..0a69f6ecb1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ force_grid_wrap = 0 include_trailing_comma = true multi_line_output = 3 use_parentheses = true +skip = ["version.py"] [tool.mdformat] number = true diff --git a/setup.py b/setup.py index 9866837d9a..963935faec 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ "types-toml", # Needed due to mypy typeshed "types-SQLAlchemy>=1.4.49", # Needed due to mypy typeshed "types-python-dateutil", # Needed due to mypy typeshed - "flake8>=7.1.1,<8", # Style linter + "flake8>=7.1.2,<8", # Style linter "flake8-breakpoint>=1.1.0,<2", # Detect breakpoints left in code "flake8-print>=4.0.1,<5", # Detect print statements left in code "flake8-pydantic", # For detecting issues with Pydantic models diff --git a/src/ape/managers/_contractscache.py b/src/ape/managers/_contractscache.py index 0e7bcbfece..d9c901f6eb 100644 --- a/src/ape/managers/_contractscache.py +++ b/src/ape/managers/_contractscache.py @@ -554,48 +554,39 @@ def get( return None if contract_type := self.contract_types[address_key]: + # The ContractType was previously cached. if default and default != contract_type: - # Replacing contract type - self.contract_types[address_key] = default - return default + # The given default ContractType is different than the cached one. + # Merge the two and cache the merged result. + combined_contract_type = _merge_contract_types(contract_type, default) + self.contract_types[address_key] = combined_contract_type + return combined_contract_type return contract_type - else: - # Contract is not cached yet. Check broader sources, such as an explorer. - if not proxy_info and detect_proxy: - # Proxy info not provided. Attempt to detect. - if not (proxy_info := self.proxy_infos[address_key]): - if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key): - self.proxy_infos[address_key] = proxy_info - - if proxy_info: - # Contract is a proxy (either was detected or provided). - implementation_contract_type = self.get(proxy_info.target, default=default) - proxy_contract_type = ( - self._get_contract_type_from_explorer(address_key) - if fetch_from_explorer - else None - ) - if proxy_contract_type is not None and implementation_contract_type is not None: - combined_contract = _get_combined_contract_type( - proxy_contract_type, proxy_info, implementation_contract_type - ) - self.contract_types[address_key] = combined_contract - return combined_contract - - elif implementation_contract_type is not None: - contract_type_to_cache = implementation_contract_type - self.contract_types[address_key] = implementation_contract_type - return contract_type_to_cache - - elif proxy_contract_type is not None: - self.contract_types[address_key] = proxy_contract_type - return proxy_contract_type - - # Also gets cached to disk for faster lookup next time. - if fetch_from_explorer: - contract_type = self._get_contract_type_from_explorer(address_key) + # Contract is not cached yet. Check broader sources, such as an explorer. + if not proxy_info and detect_proxy: + # Proxy info not provided. Attempt to detect. + if not (proxy_info := self.proxy_infos[address_key]): + if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key): + self.proxy_infos[address_key] = proxy_info + + if proxy_info: + if proxy_contract_type := self._get_proxy_contract_type( + address_key, + proxy_info, + fetch_from_explorer=fetch_from_explorer, + default=default, + ): + # `proxy_contract_type` is one of the following: + # 1. A ContractType with the combined proxy and implementation ABIs + # 2. Implementation-only ABI ContractType (like forwarder proxies) + # 3. Proxy only ABI (e.g. unverified implementation ContractType) + return proxy_contract_type + + # Also gets cached to disk for faster lookup next time. + if fetch_from_explorer: + contract_type = self._get_contract_type_from_explorer(address_key) # Cache locally for faster in-session look-up. if contract_type: @@ -606,6 +597,65 @@ def get( return contract_type + def _get_proxy_contract_type( + self, + address: AddressType, + proxy_info: ProxyInfoAPI, + fetch_from_explorer: bool = True, + default: Optional[ContractType] = None, + ) -> Optional[ContractType]: + """ + Combines the discoverable ABIs from the proxy contract and its implementation. + """ + implementation_contract_type = self._get_contract_type( + proxy_info.target, + fetch_from_explorer=fetch_from_explorer, + default=default, + ) + proxy_contract_type = self._get_contract_type( + address, fetch_from_explorer=fetch_from_explorer + ) + if proxy_contract_type is not None and implementation_contract_type is not None: + combined_contract = _get_combined_contract_type( + proxy_contract_type, proxy_info, implementation_contract_type + ) + self.contract_types[address] = combined_contract + return combined_contract + + elif implementation_contract_type is not None: + contract_type_to_cache = implementation_contract_type + self.contract_types[address] = implementation_contract_type + return contract_type_to_cache + + elif proxy_contract_type is not None: + # In this case, the implementation ContactType was not discovered. + # However, we were able to discover the ContractType of the proxy. + # Proceed with caching the proxy; the user can update the type later + # when the implementation is discoverable. + self.contract_types[address] = proxy_contract_type + return proxy_contract_type + + logger.warning(f"Unable to determine the ContractType for the proxy at '{address}'.") + return None + + def _get_contract_type( + self, + address: AddressType, + fetch_from_explorer: bool = True, + default: Optional[ContractType] = None, + ) -> Optional[ContractType]: + """ + Get the _exact_ ContractType for a given address. For proxy contracts, returns + the proxy ABIs if there are any and not the implementation ABIs. + """ + if contract_type := self.contract_types[address]: + return contract_type + + elif fetch_from_explorer: + return self._get_contract_type_from_explorer(address) + + return default + @classmethod def get_container(cls, contract_type: ContractType) -> ContractContainer: """ @@ -859,6 +909,16 @@ def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[Con if contract_type: # Cache contract so faster look-up next time. + if not isinstance(contract_type, ContractType): + explorer_name = self.provider.network.explorer.name + wrong_type = type(contract_type) + wrong_type_str = getattr(wrong_type, "__name__", f"{wrong_type}") + logger.warning( + f"Explorer '{explorer_name}' returned unexpected " + f"type '{wrong_type_str}' ContractType." + ) + return None + self.contract_types[address] = contract_type return contract_type @@ -869,9 +929,7 @@ def _get_combined_contract_type( proxy_info: ProxyInfoAPI, implementation_contract_type: ContractType, ) -> ContractType: - proxy_abis = [ - abi for abi in proxy_contract_type.abi if abi.type in ("error", "event", "function") - ] + proxy_abis = _get_relevant_additive_abis(proxy_contract_type) # Include "hidden" ABIs, such as Safe's `masterCopy()`. if proxy_info.abi and proxy_info.abi.signature not in [ @@ -879,6 +937,23 @@ def _get_combined_contract_type( ]: proxy_abis.append(proxy_info.abi) - combined_contract_type = implementation_contract_type.model_copy(deep=True) - combined_contract_type.abi.extend(proxy_abis) - return combined_contract_type + return _merge_abis(implementation_contract_type, proxy_abis) + + +def _get_relevant_additive_abis(contract_type: ContractType) -> list[ABI]: + # Get ABIs you would want to add to a base contract as extra, + # such as unique ABIs from proxies. + return [abi for abi in contract_type.abi if abi.type in ("error", "event", "function")] + + +def _merge_abis(base_contract: ContractType, extra_abis: list[ABI]) -> ContractType: + contract_type = base_contract.model_copy(deep=True) + contract_type.abi.extend(extra_abis) + return contract_type + + +def _merge_contract_types( + base_contract_type: ContractType, additive_contract_type: ContractType +) -> ContractType: + relevant_abis = _get_relevant_additive_abis(additive_contract_type) + return _merge_abis(base_contract_type, relevant_abis) diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py index c8ca706759..21f18f58d8 100644 --- a/tests/functional/conftest.py +++ b/tests/functional/conftest.py @@ -472,6 +472,13 @@ def dummy_live_network(chain): chain.provider.network.name = original_network +@pytest.fixture +def dummy_live_network_with_explorer(dummy_live_network, mock_explorer): + dummy_live_network.__dict__["explorer"] = mock_explorer + yield dummy_live_network + dummy_live_network.__dict__.pop("explorer", None) + + @pytest.fixture(scope="session") def proxy_contract_container(get_contract_type): return ContractContainer(get_contract_type("proxy")) diff --git a/tests/functional/test_accounts.py b/tests/functional/test_accounts.py index 7f1bd215e0..21f688b3b8 100644 --- a/tests/functional/test_accounts.py +++ b/tests/functional/test_accounts.py @@ -314,19 +314,19 @@ def test_deploy_and_publish_live_network_no_explorer(owner, contract_container, @explorer_test -def test_deploy_and_publish(owner, contract_container, dummy_live_network, mock_explorer): - dummy_live_network.__dict__["explorer"] = mock_explorer +def test_deploy_and_publish( + owner, contract_container, dummy_live_network_with_explorer, mock_explorer +): contract = owner.deploy(contract_container, 0, publish=True, required_confirmations=0) mock_explorer.publish_contract.assert_called_once_with(contract.address) - dummy_live_network.__dict__["explorer"] = None @explorer_test -def test_deploy_and_not_publish(owner, contract_container, dummy_live_network, mock_explorer): - dummy_live_network.__dict__["explorer"] = mock_explorer +def test_deploy_and_not_publish( + owner, contract_container, dummy_live_network_with_explorer, mock_explorer +): owner.deploy(contract_container, 0, publish=True, required_confirmations=0) assert not mock_explorer.call_count - dummy_live_network.__dict__["explorer"] = None def test_deploy_proxy(owner, vyper_contract_instance, proxy_contract_container, chain): diff --git a/tests/functional/test_contract.py b/tests/functional/test_contract.py index aab93b4bb6..4723d07a9d 100644 --- a/tests/functional/test_contract.py +++ b/tests/functional/test_contract.py @@ -85,12 +85,13 @@ def test_Contract_at_unknown_address(networks_connected_to_tester, address): def test_Contract_specify_contract_type( - solidity_contract_instance, vyper_contract_type, owner, networks_connected_to_tester + vyper_contract_instance, solidity_contract_type, owner, networks_connected_to_tester ): - # Vyper contract type is very close to solidity's. + # Solidity's contract type is very close to Vyper's. # This test purposely uses the other just to show we are able to specify it externally. - contract = Contract(solidity_contract_instance.address, contract_type=vyper_contract_type) - assert contract.address == solidity_contract_instance.address - assert contract.contract_type == vyper_contract_type - assert contract.setNumber(2, sender=owner) - assert contract.myNumber() == 2 + contract = Contract(vyper_contract_instance.address, contract_type=solidity_contract_type) + assert contract.address == vyper_contract_instance.address + + abis = [abi.name for abi in contract.contract_type.abi if hasattr(abi, "name")] + assert "setNumber" in abis # Shared ABI. + assert "ACustomError" in abis # SolidityContract-defined ABI. diff --git a/tests/functional/test_contract_container.py b/tests/functional/test_contract_container.py index e311b23a7b..e125cb7a37 100644 --- a/tests/functional/test_contract_container.py +++ b/tests/functional/test_contract_container.py @@ -10,6 +10,7 @@ ProjectError, ) from ape_ethereum.ecosystem import ProxyType +from tests.conftest import explorer_test def test_deploy( @@ -55,18 +56,20 @@ def test_deploy_and_publish_live_network_no_explorer(owner, contract_container, contract_container.deploy(0, sender=owner, publish=True, required_confirmations=0) -def test_deploy_and_publish(owner, contract_container, dummy_live_network, mock_explorer): - dummy_live_network.__dict__["explorer"] = mock_explorer +@explorer_test +def test_deploy_and_publish( + owner, contract_container, dummy_live_network_with_explorer, mock_explorer +): contract = contract_container.deploy(0, sender=owner, publish=True, required_confirmations=0) mock_explorer.publish_contract.assert_called_once_with(contract.address) - dummy_live_network.__dict__["explorer"] = None -def test_deploy_and_not_publish(owner, contract_container, dummy_live_network, mock_explorer): - dummy_live_network.__dict__["explorer"] = mock_explorer +@explorer_test +def test_deploy_and_not_publish( + owner, contract_container, dummy_live_network_with_explorer, mock_explorer +): contract_container.deploy(0, sender=owner, publish=False, required_confirmations=0) assert not mock_explorer.call_count - dummy_live_network.__dict__["explorer"] = None def test_deploy_privately(owner, contract_container): diff --git a/tests/functional/test_contract_instance.py b/tests/functional/test_contract_instance.py index 1c786187d8..1056284cc8 100644 --- a/tests/functional/test_contract_instance.py +++ b/tests/functional/test_contract_instance.py @@ -880,9 +880,9 @@ def test_value_to_non_payable_fallback_and_no_receive( break new_contract_type = ContractType.model_validate(contract_type_data) - contract = owner.chain_manager.contracts.instance_at( - vyper_fallback_contract.address, contract_type=new_contract_type - ) + contract = owner.chain_manager.contracts.instance_at(vyper_fallback_contract.address) + contract.contract_type = new_contract_type # Setting to completely override instead of merge. + expected = ( r"Contract's fallback is non-payable and there is no receive ABI\. Unable to send value\." ) diff --git a/tests/functional/test_contracts_cache.py b/tests/functional/test_contracts_cache.py index d635d9e9b2..cd47a73538 100644 --- a/tests/functional/test_contracts_cache.py +++ b/tests/functional/test_contracts_cache.py @@ -4,8 +4,8 @@ from ape import Contract from ape.contracts import ContractInstance from ape.exceptions import ContractNotFoundError, ConversionError -from ape.logging import LogLevel -from ape_ethereum.proxies import _make_minimal_proxy +from ape.logging import LogLevel, logger +from ape_ethereum.proxies import ProxyInfo, ProxyType, _make_minimal_proxy from tests.conftest import explorer_test, skip_if_plugin_installed @@ -161,10 +161,10 @@ def test_instance_at_skip_proxy(mocker, chain, vyper_contract_instance, owner): def test_cache_deployment_live_network( chain, + dummy_live_network, + clean_contract_caches, vyper_contract_instance, vyper_contract_container, - clean_contract_caches, - dummy_live_network, ): # Arrange - Ensure the contract is not cached anywhere address = vyper_contract_instance.address @@ -435,9 +435,6 @@ def test_get_attempts_explorer_logs_rate_limit_error_from_explorer( ): contract = owner.deploy(vyper_fallback_container) - # Ensure is not cached locally. - del chain.contracts[contract.address] - # For rate limit errors, we don't show anything else, # as it may be confusing. check_error_str = "you have been rate limited" @@ -449,10 +446,14 @@ def get_contract_type(addr): raise ValueError("nope") with create_mock_sepolia() as network: + # Ensure is not cached locally. + del chain.contracts[contract.address] + mock_explorer.get_contract_type.side_effect = get_contract_type network.__dict__["explorer"] = mock_explorer try: - actual = chain.contracts.get(contract.address) + with logger.at_level(LogLevel.INFO): + actual = chain.contracts.get(contract.address) finally: network.__dict__["explorer"] = None @@ -474,7 +475,7 @@ def test_cache_non_checksum_address(chain, vyper_contract_instance): assert chain.contracts[vyper_contract_instance.address] == vyper_contract_instance.contract_type -def test_get_when_proxy(chain, owner, minimal_proxy_container, vyper_contract_instance): +def test_get_proxy(chain, owner, minimal_proxy_container, vyper_contract_instance): placeholder = "0xBEbeBeBEbeBebeBeBEBEbebEBeBeBebeBeBebebe" if placeholder in chain.contracts: del chain.contracts[placeholder] @@ -486,7 +487,7 @@ def test_get_when_proxy(chain, owner, minimal_proxy_container, vyper_contract_in assert actual == minimal_proxy.contract_type -def test_get_when_proxy_but_implementation_missing(chain, owner, vyper_contract_container): +def test_get_proxy_implementation_missing(chain, owner, vyper_contract_container): """ Proxy is cached but implementation is missing. """ @@ -507,7 +508,7 @@ def test_get_when_proxy_but_implementation_missing(chain, owner, vyper_contract_ assert actual == minimal_proxy.contract_type -def test_get_pass_along_proxy_info(chain, owner, minimal_proxy_container, ethereum): +def test_get_proxy_pass_proxy_info(chain, owner, minimal_proxy_container, ethereum): placeholder = "0xBEbeBeBEbeBebeBeBEBEbebEBeBeBebeBeBebebe" if placeholder in chain.contracts: del chain.contracts[placeholder] @@ -528,6 +529,28 @@ def test_get_pass_along_proxy_info(chain, owner, minimal_proxy_container, ethere assert minimal_proxy.address not in chain.contracts.contract_types +@explorer_test +def test_get_proxy_pass_proxy_info_and_no_explorer( + chain, owner, proxy_contract_container, ethereum, dummy_live_network_with_explorer +): + """ + Tests the condition of both passing `proxy_info=` and setting `use_explorer=False` + when getting the ContractType of a proxy. + """ + explorer = dummy_live_network_with_explorer.explorer + placeholder = "0xBEbeBeBEbeBebeBeBEBEbebEBeBeBebeBeBebebe" + if placeholder in chain.contracts: + del chain.contracts[placeholder] + + proxy = proxy_contract_container.deploy(placeholder, sender=owner, required_confirmations=0) + info = ProxyInfo(type=ProxyType.Minimal, target=placeholder) + explorer.get_contract_type.reset_mock() + chain.contracts.get(proxy.address, proxy_info=info, fetch_from_explorer=False) + + # Ensure explorer was not used. + assert explorer.get_contract_type.call_count == 0 + + def test_get_creation_metadata(chain, vyper_contract_instance, owner): address = vyper_contract_instance.address creation = chain.contracts.get_creation_metadata(address) diff --git a/tests/functional/test_ecosystem.py b/tests/functional/test_ecosystem.py index 19466242a3..1c3491aeab 100644 --- a/tests/functional/test_ecosystem.py +++ b/tests/functional/test_ecosystem.py @@ -716,6 +716,7 @@ class L2NetworkConfig(BaseEthereumConfig): @pytest.mark.parametrize("network_name", (LOCAL_NETWORK_NAME, "mainnet-fork", "mainnet_fork")) def test_gas_limit_local_networks(ethereum, network_name): network = ethereum.get_network(network_name) + network.__dict__.pop("gas_limit", None) # Refresh in case was changed in another test. assert network.gas_limit == "max"