Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 63 additions & 24 deletions tests/test_erc20_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,68 @@

from utils.erc20_metadata import ERC20Metadata, fetch_erc20_metadata, reset_cache

# Bytecode fragments used by the gate. A "token" must contain both the symbol()
# (95d89b41) and decimals() (313ce567) selectors; anything else is non-token.
_TOKEN_CODE = bytes.fromhex("608060405295d89b41313ce567")
_NON_TOKEN_CODE = bytes.fromhex("6080604052")

USDC = "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48"


def _client_with_code(code: bytes) -> MagicMock:
"""Build a mock Web3 client whose get_code returns ``code`` and whose batch
yields ("USDC", 6)."""
client = MagicMock()
client.eth.get_code.return_value = code
client.execute_batch.return_value = ("USDC", 6)
client.batch_requests.return_value.__enter__.return_value = MagicMock()
client.batch_requests.return_value.__exit__.return_value = False
return client


class TestFetchErc20Metadata(unittest.TestCase):
def setUp(self) -> None:
reset_cache()

@patch("utils.erc20_metadata.ChainManager")
def test_returns_symbol_and_decimals(self, mock_cm: MagicMock) -> None:
client = MagicMock()
client.execute_batch.return_value = ("USDC", 6)
client.batch_requests.return_value.__enter__.return_value = MagicMock()
client.batch_requests.return_value.__exit__.return_value = False
mock_cm.get_client.return_value = _client_with_code(_TOKEN_CODE)
meta = fetch_erc20_metadata(1, USDC)
self.assertEqual(meta, ERC20Metadata(symbol="USDC", decimals=6))

@patch("utils.erc20_metadata.ChainManager")
def test_skips_eoa_without_calling(self, mock_cm: MagicMock) -> None:
"""No deployed code => EOA => return None without a symbol() call."""
client = _client_with_code(b"")
mock_cm.get_client.return_value = client
self.assertIsNone(fetch_erc20_metadata(1, "0x" + "ab" * 20))
client.batch_requests.assert_not_called()

meta = fetch_erc20_metadata(1, "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48")
self.assertEqual(meta, ERC20Metadata(symbol="USDC", decimals=6))
@patch("utils.erc20_metadata.get_current_implementation", return_value=None)
@patch("utils.erc20_metadata.ChainManager")
def test_skips_non_token_contract_without_calling(self, mock_cm: MagicMock, _impl: MagicMock) -> None:
"""Contract without the selectors (and not a proxy) => skip the call."""
client = _client_with_code(_NON_TOKEN_CODE)
mock_cm.get_client.return_value = client
self.assertIsNone(fetch_erc20_metadata(1, "0x" + "cd" * 20))
client.batch_requests.assert_not_called()

@patch("utils.erc20_metadata.ChainManager")
def test_returns_none_on_eth_call_failure(self, mock_cm: MagicMock) -> None:
client = MagicMock()
client.batch_requests.side_effect = RuntimeError("execution reverted")
def test_resolves_proxy_implementation(self, mock_cm: MagicMock) -> None:
"""A proxy stub carries no selectors; we resolve the impl and scan it."""
proxy = "0x" + "11" * 20
impl = "0x" + "22" * 20

def code_for(addr: str) -> bytes:
return _TOKEN_CODE if addr.lower() == impl.lower() else _NON_TOKEN_CODE

client = _client_with_code(_NON_TOKEN_CODE)
client.eth.get_code.side_effect = lambda addr: code_for(addr)
mock_cm.get_client.return_value = client

meta = fetch_erc20_metadata(1, "0x" + "ab" * 20)
self.assertIsNone(meta)
with patch("utils.erc20_metadata.get_current_implementation", return_value=impl):
meta = fetch_erc20_metadata(1, proxy)
self.assertEqual(meta, ERC20Metadata(symbol="USDC", decimals=6))

def test_invalid_address_skips_network(self) -> None:
with patch("utils.erc20_metadata.ChainManager") as mock_cm:
Expand All @@ -38,29 +76,30 @@ def test_invalid_address_skips_network(self) -> None:
mock_cm.get_client.assert_not_called()

@patch("utils.erc20_metadata.ChainManager")
def test_caches_repeat_lookups(self, mock_cm: MagicMock) -> None:
client = MagicMock()
client.execute_batch.return_value = ("USDC", 6)
client.batch_requests.return_value.__enter__.return_value = MagicMock()
client.batch_requests.return_value.__exit__.return_value = False
def test_returns_none_on_eth_call_failure(self, mock_cm: MagicMock) -> None:
"""Gate passes but the metadata call reverts => None (backstop catch)."""
client = _client_with_code(_TOKEN_CODE)
client.batch_requests.side_effect = RuntimeError("execution reverted")
mock_cm.get_client.return_value = client
self.assertIsNone(fetch_erc20_metadata(1, USDC))

addr = "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48"
fetch_erc20_metadata(1, addr)
fetch_erc20_metadata(1, addr)
@patch("utils.erc20_metadata.ChainManager")
def test_caches_repeat_lookups(self, mock_cm: MagicMock) -> None:
client = _client_with_code(_TOKEN_CODE)
mock_cm.get_client.return_value = client
fetch_erc20_metadata(1, USDC)
fetch_erc20_metadata(1, USDC)
self.assertEqual(client.batch_requests.call_count, 1)

@patch("utils.erc20_metadata.ChainManager")
def test_caches_misses(self, mock_cm: MagicMock) -> None:
client = MagicMock()
client.batch_requests.side_effect = RuntimeError("not a token")
"""A cached miss must not re-hit the network (get_code called once)."""
client = _client_with_code(b"")
mock_cm.get_client.return_value = client

addr = "0x" + "cd" * 20
self.assertIsNone(fetch_erc20_metadata(1, addr))
self.assertIsNone(fetch_erc20_metadata(1, addr))
# Cached miss → only one attempt.
self.assertEqual(client.batch_requests.call_count, 1)
self.assertEqual(client.eth.get_code.call_count, 1)


if __name__ == "__main__":
Expand Down
56 changes: 54 additions & 2 deletions utils/erc20_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@

from dataclasses import dataclass

from eth_utils import to_checksum_address
from eth_utils import function_signature_to_4byte_selector, to_checksum_address

from utils.abi import load_abi
from utils.chains import Chain
from utils.logging import get_logger
from utils.proxy import get_current_implementation
from utils.web3_wrapper import ChainManager

logger = get_logger("utils.erc20_metadata")
Expand All @@ -25,6 +26,12 @@
# Per-process cache: (chain_id, address_lower) -> ERC20Metadata or None for miss.
_cache: dict[tuple[int, str], "ERC20Metadata | None"] = {}

# Selectors an ERC20 must dispatch. We only ever call symbol()/decimals() when
# the contract bytecode actually contains both — never a blind eth_call we expect
# to fail. Stored as bare lowercase hex (no 0x) to substring-match raw bytecode.
_SYMBOL_SELECTOR = function_signature_to_4byte_selector("symbol()").hex() # 95d89b41
_DECIMALS_SELECTOR = function_signature_to_4byte_selector("decimals()").hex() # 313ce567


@dataclass(frozen=True)
class ERC20Metadata:
Expand Down Expand Up @@ -55,7 +62,16 @@ def fetch_erc20_metadata(chain_id: int, address: str) -> ERC20Metadata | None:
try:
chain = Chain.from_chain_id(chain_id)
client = ChainManager.get_client(chain)
token = client.get_contract(to_checksum_address(address), _ERC20_ABI)
checksum = to_checksum_address(address)

# Gate: only call symbol()/decimals() when the bytecode proves the
# contract dispatches them. EOAs and non-token contracts are skipped
# without a blind eth_call.
if not _dispatches_token_metadata(chain_id, client, checksum):
_cache[cache_key] = None
return None

token = client.get_contract(checksum, _ERC20_ABI)
with client.batch_requests() as batch:
batch.add(token.functions.symbol())
batch.add(token.functions.decimals())
Expand All @@ -70,6 +86,42 @@ def fetch_erc20_metadata(chain_id: int, address: str) -> ERC20Metadata | None:
return meta


def _code_hex(client, address: str) -> str:
"""Return deployed bytecode at ``address`` as bare lowercase hex ("" if none)."""
raw = client.eth.get_code(to_checksum_address(address))
code = raw.hex()
if code.startswith("0x"):
code = code[2:]
return code.lower()


def _has_token_selectors(code: str) -> bool:
"""True if bytecode dispatches both symbol() and decimals()."""
return bool(code) and _SYMBOL_SELECTOR in code and _DECIMALS_SELECTOR in code


def _dispatches_token_metadata(chain_id: int, client, checksum: str) -> bool:
"""Positive-evidence ERC20 check via bytecode inspection.

Returns True only when the contract — or, for proxies, its implementation —
contains both the symbol() and decimals() selectors. EOAs and non-token
contracts return False so we never blind-call functions they can't serve.

Proxy stubs delegate through a fallback and carry none of the impl's
selectors, so a bare scan would false-negative proxy tokens (e.g. USDC).
We resolve the implementation and scan its bytecode too.
"""
code = _code_hex(client, checksum)
if not code:
return False # EOA / no deployed code
if _has_token_selectors(code):
return True
impl = get_current_implementation(checksum, chain_id)
if impl:
return _has_token_selectors(_code_hex(client, impl))
return False


def reset_cache() -> None:
"""Reset the in-memory metadata cache. Useful for tests."""
_cache.clear()