Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 2 additions & 43 deletions tests/core/pyspec/eth2spec/test/context.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import importlib
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from random import Random
from typing import Any

import pytest
from frozendict import frozendict
from lru import LRU

from eth2spec.utils import bls
from tests.infra.context import with_custom_state
from tests.infra.yield_generator import vector_test

from .exceptions import SkippedTest
Expand All @@ -32,13 +32,11 @@
POST_FORK_OF,
)
from .helpers.forks import is_post_electra, is_post_fork
from .helpers.genesis import create_genesis_state
from .helpers.specs import (
spec_targets,
)
from .helpers.typing import (
Spec,
SpecForks,
)
from .utils import (
with_meta_tags,
Expand All @@ -58,45 +56,6 @@ class ForkMeta:
fork_epoch: int


def _prepare_state(
balances_fn: Callable[[Any], Sequence[int]],
threshold_fn: Callable[[Any], int],
spec: Spec,
phases: SpecForks,
):
balances = balances_fn(spec)
activation_threshold = threshold_fn(spec)
state = create_genesis_state(
spec=spec, validator_balances=balances, activation_threshold=activation_threshold
)
return state


_custom_state_cache_dict = LRU(size=10)


def with_custom_state(
balances_fn: Callable[[Any], Sequence[int]], threshold_fn: Callable[[Any], int]
):
def deco(fn):
def entry(*args, spec: Spec, phases: SpecForks, **kw):
# make a key for the state, unique to the fork + config (incl preset choice) and balances/activations
key = (spec.fork, spec.config.__hash__(), spec.__file__, balances_fn, threshold_fn)
if key not in _custom_state_cache_dict:
state = _prepare_state(balances_fn, threshold_fn, spec, phases)
_custom_state_cache_dict[key] = state.get_backing()

# Take an entry out of the LRU.
# No copy is necessary, as we wrap the immutable backing with a new view.
state = spec.BeaconState(backing=_custom_state_cache_dict[key])
kw["state"] = state
return fn(*args, spec=spec, phases=phases, **kw)

return entry

return deco


def default_activation_threshold(spec: Spec):
"""
Helper method to use the default balance activation threshold for state creation for tests.
Expand Down
46 changes: 46 additions & 0 deletions tests/infra/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from collections.abc import Callable, Sequence
from typing import Any

from lru import LRU

from eth2spec.test.helpers.genesis import create_genesis_state
from eth2spec.test.helpers.typing import Spec, SpecForks


def _prepare_state(
balances_fn: Callable[[Any], Sequence[int]],
threshold_fn: Callable[[Any], int],
spec: Spec,
phases: SpecForks,
):
balances = balances_fn(spec)
activation_threshold = threshold_fn(spec)
state = create_genesis_state(
spec=spec, validator_balances=balances, activation_threshold=activation_threshold
)
return state


_custom_state_cache_dict = LRU(size=10)


def with_custom_state(
balances_fn: Callable[[Any], Sequence[int]], threshold_fn: Callable[[Any], int]
):
def deco(fn):
def entry(*args, spec: Spec, phases: SpecForks, **kw):
# make a key for the state, unique to the fork + config (incl preset choice) and balances/activations
key = (spec.fork, spec.config.__hash__(), spec.__file__, balances_fn, threshold_fn)
if key not in _custom_state_cache_dict:
state = _prepare_state(balances_fn, threshold_fn, spec, phases)
_custom_state_cache_dict[key] = state.get_backing()

# Take an entry out of the LRU.
# No copy is necessary, as we wrap the immutable backing with a new view.
state = spec.BeaconState(backing=_custom_state_cache_dict[key])
kw["state"] = state
return fn(*args, spec=spec, phases=phases, **kw)

return entry

return deco
48 changes: 48 additions & 0 deletions tests/infra/test_custom_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from eth2spec.phase0 import spec as phase0_spec
from eth2spec.test.helpers.typing import Spec, SpecForks
from tests.infra.context import with_custom_state


def test_custom_state_matrix():
"""
Verifies with_custom_state with various inputs.
Checks the expected balances.
"""
test_cases = [
# Case 1: Standard 32 ETH threshold
{
"give": 100 * 10**9,
"threshold": 32 * 10**9,
"expected_balance": 100 * 10**9,
},
# Case 2: Custom 16 ETH threshold
{
"give": 55 * 10**9,
"threshold": 16 * 10**9,
"expected_balance": 55 * 10**9,
},
# Case 3: Boundary Condition (Balance == Threshold)
{
"give": 32 * 10**9,
"threshold": 32 * 10**9,
"expected_balance": 32 * 10**9,
},
]

for i, case in enumerate(test_cases):
# Prepare the helpers using lambda - turn the numbers to functions that the decorator expects
balance_fn = lambda spec: [case["give"]]
threshold_fn = lambda spec: case["threshold"]

# Define the decorated function dynamically
@with_custom_state(balances_fn=balance_fn, threshold_fn=threshold_fn)
def check_state_logic(spec: Spec, phases: SpecForks, state, **kwargs):
return state.balances[0]

phases = {phase0_spec.fork: phase0_spec}
result = check_state_logic(spec=phase0_spec, phases=phases)

error_message = (
f"Failed on case {i + 1}: Given {case['give']}, Expected {case['expected_balance']}"
)
assert result == case["expected_balance"], error_message