Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: unnecessarily was using explorer when providing proxy info manually #2524

Merged
merged 8 commits into from
Feb 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ repos:
name: black

- repo: https://github.com/pycqa/flake8
rev: 7.1.1
rev: 7.1.2
hooks:
- id: flake8
additional_dependencies: [flake8-breakpoint, flake8-print, flake8-pydantic, flake8-type-checking]
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ force_grid_wrap = 0
include_trailing_comma = true
multi_line_output = 3
use_parentheses = true
skip = ["version.py"]

[tool.mdformat]
number = true
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
"types-toml", # Needed due to mypy typeshed
"types-SQLAlchemy>=1.4.49", # Needed due to mypy typeshed
"types-python-dateutil", # Needed due to mypy typeshed
"flake8>=7.1.1,<8", # Style linter
"flake8>=7.1.2,<8", # Style linter
"flake8-breakpoint>=1.1.0,<2", # Detect breakpoints left in code
"flake8-print>=4.0.1,<5", # Detect print statements left in code
"flake8-pydantic", # For detecting issues with Pydantic models
Expand Down
163 changes: 119 additions & 44 deletions src/ape/managers/_contractscache.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,48 +554,39 @@ def get(
return None

if contract_type := self.contract_types[address_key]:
# The ContractType was previously cached.
if default and default != contract_type:
# Replacing contract type
self.contract_types[address_key] = default
return default
# The given default ContractType is different than the cached one.
# Merge the two and cache the merged result.
combined_contract_type = _merge_contract_types(contract_type, default)
self.contract_types[address_key] = combined_contract_type
return combined_contract_type

return contract_type

else:
# Contract is not cached yet. Check broader sources, such as an explorer.
if not proxy_info and detect_proxy:
# Proxy info not provided. Attempt to detect.
if not (proxy_info := self.proxy_infos[address_key]):
if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key):
self.proxy_infos[address_key] = proxy_info

if proxy_info:
# Contract is a proxy (either was detected or provided).
implementation_contract_type = self.get(proxy_info.target, default=default)
proxy_contract_type = (
self._get_contract_type_from_explorer(address_key)
if fetch_from_explorer
else None
)
if proxy_contract_type is not None and implementation_contract_type is not None:
combined_contract = _get_combined_contract_type(
proxy_contract_type, proxy_info, implementation_contract_type
)
self.contract_types[address_key] = combined_contract
return combined_contract

elif implementation_contract_type is not None:
contract_type_to_cache = implementation_contract_type
self.contract_types[address_key] = implementation_contract_type
return contract_type_to_cache

elif proxy_contract_type is not None:
self.contract_types[address_key] = proxy_contract_type
return proxy_contract_type

# Also gets cached to disk for faster lookup next time.
if fetch_from_explorer:
contract_type = self._get_contract_type_from_explorer(address_key)
# Contract is not cached yet. Check broader sources, such as an explorer.
if not proxy_info and detect_proxy:
# Proxy info not provided. Attempt to detect.
if not (proxy_info := self.proxy_infos[address_key]):
if proxy_info := self.provider.network.ecosystem.get_proxy_info(address_key):
self.proxy_infos[address_key] = proxy_info

if proxy_info:
if proxy_contract_type := self._get_proxy_contract_type(
address_key,
proxy_info,
fetch_from_explorer=fetch_from_explorer,
default=default,
):
# `proxy_contract_type` is one of the following:
# 1. A ContractType with the combined proxy and implementation ABIs
# 2. Implementation-only ABI ContractType (like forwarder proxies)
# 3. Proxy only ABI (e.g. unverified implementation ContractType)
return proxy_contract_type

# Also gets cached to disk for faster lookup next time.
if fetch_from_explorer:
contract_type = self._get_contract_type_from_explorer(address_key)

# Cache locally for faster in-session look-up.
if contract_type:
Expand All @@ -606,6 +597,65 @@ def get(

return contract_type

def _get_proxy_contract_type(
self,
address: AddressType,
proxy_info: ProxyInfoAPI,
fetch_from_explorer: bool = True,
default: Optional[ContractType] = None,
) -> Optional[ContractType]:
"""
Combines the discoverable ABIs from the proxy contract and its implementation.
"""
implementation_contract_type = self._get_contract_type(
proxy_info.target,
fetch_from_explorer=fetch_from_explorer,
default=default,
)
proxy_contract_type = self._get_contract_type(
address, fetch_from_explorer=fetch_from_explorer
)
if proxy_contract_type is not None and implementation_contract_type is not None:
combined_contract = _get_combined_contract_type(
proxy_contract_type, proxy_info, implementation_contract_type
)
self.contract_types[address] = combined_contract
return combined_contract

elif implementation_contract_type is not None:
contract_type_to_cache = implementation_contract_type
self.contract_types[address] = implementation_contract_type
return contract_type_to_cache

elif proxy_contract_type is not None:
# In this case, the implementation ContactType was not discovered.
# However, we were able to discover the ContractType of the proxy.
# Proceed with caching the proxy; the user can update the type later
# when the implementation is discoverable.
self.contract_types[address] = proxy_contract_type
return proxy_contract_type

logger.warning(f"Unable to determine the ContractType for the proxy at '{address}'.")
return None

def _get_contract_type(
self,
address: AddressType,
fetch_from_explorer: bool = True,
default: Optional[ContractType] = None,
) -> Optional[ContractType]:
"""
Get the _exact_ ContractType for a given address. For proxy contracts, returns
the proxy ABIs if there are any and not the implementation ABIs.
"""
if contract_type := self.contract_types[address]:
return contract_type

elif fetch_from_explorer:
return self._get_contract_type_from_explorer(address)

return default

@classmethod
def get_container(cls, contract_type: ContractType) -> ContractContainer:
"""
Expand Down Expand Up @@ -859,6 +909,16 @@ def _get_contract_type_from_explorer(self, address: AddressType) -> Optional[Con

if contract_type:
# Cache contract so faster look-up next time.
if not isinstance(contract_type, ContractType):
explorer_name = self.provider.network.explorer.name
wrong_type = type(contract_type)
wrong_type_str = getattr(wrong_type, "__name__", f"{wrong_type}")
logger.warning(
f"Explorer '{explorer_name}' returned unexpected "
f"type '{wrong_type_str}' ContractType."
)
return None

self.contract_types[address] = contract_type

return contract_type
Expand All @@ -869,16 +929,31 @@ def _get_combined_contract_type(
proxy_info: ProxyInfoAPI,
implementation_contract_type: ContractType,
) -> ContractType:
proxy_abis = [
abi for abi in proxy_contract_type.abi if abi.type in ("error", "event", "function")
]
proxy_abis = _get_relevant_additive_abis(proxy_contract_type)

# Include "hidden" ABIs, such as Safe's `masterCopy()`.
if proxy_info.abi and proxy_info.abi.signature not in [
abi.signature for abi in implementation_contract_type.abi
]:
proxy_abis.append(proxy_info.abi)

combined_contract_type = implementation_contract_type.model_copy(deep=True)
combined_contract_type.abi.extend(proxy_abis)
return combined_contract_type
return _merge_abis(implementation_contract_type, proxy_abis)


def _get_relevant_additive_abis(contract_type: ContractType) -> list[ABI]:
# Get ABIs you would want to add to a base contract as extra,
# such as unique ABIs from proxies.
return [abi for abi in contract_type.abi if abi.type in ("error", "event", "function")]


def _merge_abis(base_contract: ContractType, extra_abis: list[ABI]) -> ContractType:
contract_type = base_contract.model_copy(deep=True)
contract_type.abi.extend(extra_abis)
return contract_type


def _merge_contract_types(
base_contract_type: ContractType, additive_contract_type: ContractType
) -> ContractType:
relevant_abis = _get_relevant_additive_abis(additive_contract_type)
return _merge_abis(base_contract_type, relevant_abis)
7 changes: 7 additions & 0 deletions tests/functional/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,13 @@ def dummy_live_network(chain):
chain.provider.network.name = original_network


@pytest.fixture
def dummy_live_network_with_explorer(dummy_live_network, mock_explorer):
dummy_live_network.__dict__["explorer"] = mock_explorer
yield dummy_live_network
dummy_live_network.__dict__.pop("explorer", None)


@pytest.fixture(scope="session")
def proxy_contract_container(get_contract_type):
return ContractContainer(get_contract_type("proxy"))
Expand Down
12 changes: 6 additions & 6 deletions tests/functional/test_accounts.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,19 +314,19 @@ def test_deploy_and_publish_live_network_no_explorer(owner, contract_container,


@explorer_test
def test_deploy_and_publish(owner, contract_container, dummy_live_network, mock_explorer):
dummy_live_network.__dict__["explorer"] = mock_explorer
def test_deploy_and_publish(
owner, contract_container, dummy_live_network_with_explorer, mock_explorer
):
contract = owner.deploy(contract_container, 0, publish=True, required_confirmations=0)
mock_explorer.publish_contract.assert_called_once_with(contract.address)
dummy_live_network.__dict__["explorer"] = None


@explorer_test
def test_deploy_and_not_publish(owner, contract_container, dummy_live_network, mock_explorer):
dummy_live_network.__dict__["explorer"] = mock_explorer
def test_deploy_and_not_publish(
owner, contract_container, dummy_live_network_with_explorer, mock_explorer
):
owner.deploy(contract_container, 0, publish=True, required_confirmations=0)
assert not mock_explorer.call_count
dummy_live_network.__dict__["explorer"] = None


def test_deploy_proxy(owner, vyper_contract_instance, proxy_contract_container, chain):
Expand Down
15 changes: 8 additions & 7 deletions tests/functional/test_contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,13 @@ def test_Contract_at_unknown_address(networks_connected_to_tester, address):


def test_Contract_specify_contract_type(
solidity_contract_instance, vyper_contract_type, owner, networks_connected_to_tester
vyper_contract_instance, solidity_contract_type, owner, networks_connected_to_tester
):
# Vyper contract type is very close to solidity's.
# Solidity's contract type is very close to Vyper's.
# This test purposely uses the other just to show we are able to specify it externally.
contract = Contract(solidity_contract_instance.address, contract_type=vyper_contract_type)
assert contract.address == solidity_contract_instance.address
assert contract.contract_type == vyper_contract_type
assert contract.setNumber(2, sender=owner)
assert contract.myNumber() == 2
contract = Contract(vyper_contract_instance.address, contract_type=solidity_contract_type)
assert contract.address == vyper_contract_instance.address

abis = [abi.name for abi in contract.contract_type.abi if hasattr(abi, "name")]
assert "setNumber" in abis # Shared ABI.
assert "ACustomError" in abis # SolidityContract-defined ABI.
15 changes: 9 additions & 6 deletions tests/functional/test_contract_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ProjectError,
)
from ape_ethereum.ecosystem import ProxyType
from tests.conftest import explorer_test


def test_deploy(
Expand Down Expand Up @@ -55,18 +56,20 @@ def test_deploy_and_publish_live_network_no_explorer(owner, contract_container,
contract_container.deploy(0, sender=owner, publish=True, required_confirmations=0)


def test_deploy_and_publish(owner, contract_container, dummy_live_network, mock_explorer):
dummy_live_network.__dict__["explorer"] = mock_explorer
@explorer_test
def test_deploy_and_publish(
owner, contract_container, dummy_live_network_with_explorer, mock_explorer
):
contract = contract_container.deploy(0, sender=owner, publish=True, required_confirmations=0)
mock_explorer.publish_contract.assert_called_once_with(contract.address)
dummy_live_network.__dict__["explorer"] = None


def test_deploy_and_not_publish(owner, contract_container, dummy_live_network, mock_explorer):
dummy_live_network.__dict__["explorer"] = mock_explorer
@explorer_test
def test_deploy_and_not_publish(
owner, contract_container, dummy_live_network_with_explorer, mock_explorer
):
contract_container.deploy(0, sender=owner, publish=False, required_confirmations=0)
assert not mock_explorer.call_count
dummy_live_network.__dict__["explorer"] = None


def test_deploy_privately(owner, contract_container):
Expand Down
6 changes: 3 additions & 3 deletions tests/functional/test_contract_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -880,9 +880,9 @@ def test_value_to_non_payable_fallback_and_no_receive(
break

new_contract_type = ContractType.model_validate(contract_type_data)
contract = owner.chain_manager.contracts.instance_at(
vyper_fallback_contract.address, contract_type=new_contract_type
)
contract = owner.chain_manager.contracts.instance_at(vyper_fallback_contract.address)
contract.contract_type = new_contract_type # Setting to completely override instead of merge.

expected = (
r"Contract's fallback is non-payable and there is no receive ABI\. Unable to send value\."
)
Expand Down
Loading
Loading