Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 2 additions & 43 deletions tests/core/pyspec/eth2spec/test/context.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
import importlib
from collections.abc import Callable, Sequence
from collections.abc import Sequence
from copy import deepcopy
from dataclasses import dataclass
from random import Random
from typing import Any

import pytest
from frozendict import frozendict
from lru import LRU

from eth2spec.utils import bls
from tests.infra.context import with_custom_state
from tests.infra.yield_generator import vector_test

from .exceptions import SkippedTest
Expand All @@ -32,13 +32,11 @@
POST_FORK_OF,
)
from .helpers.forks import is_post_electra, is_post_fork
from .helpers.genesis import create_genesis_state
from .helpers.specs import (
spec_targets,
)
from .helpers.typing import (
Spec,
SpecForks,
)
from .utils import (
with_meta_tags,
Expand All @@ -58,45 +56,6 @@ class ForkMeta:
fork_epoch: int


def _prepare_state(
balances_fn: Callable[[Any], Sequence[int]],
threshold_fn: Callable[[Any], int],
spec: Spec,
phases: SpecForks,
):
balances = balances_fn(spec)
activation_threshold = threshold_fn(spec)
state = create_genesis_state(
spec=spec, validator_balances=balances, activation_threshold=activation_threshold
)
return state


_custom_state_cache_dict = LRU(size=10)


def with_custom_state(
balances_fn: Callable[[Any], Sequence[int]], threshold_fn: Callable[[Any], int]
):
def deco(fn):
def entry(*args, spec: Spec, phases: SpecForks, **kw):
# make a key for the state, unique to the fork + config (incl preset choice) and balances/activations
key = (spec.fork, spec.config.__hash__(), spec.__file__, balances_fn, threshold_fn)
if key not in _custom_state_cache_dict:
state = _prepare_state(balances_fn, threshold_fn, spec, phases)
_custom_state_cache_dict[key] = state.get_backing()

# Take an entry out of the LRU.
# No copy is necessary, as we wrap the immutable backing with a new view.
state = spec.BeaconState(backing=_custom_state_cache_dict[key])
kw["state"] = state
return fn(*args, spec=spec, phases=phases, **kw)

return entry

return deco


def default_activation_threshold(spec: Spec):
"""
Helper method to use the default balance activation threshold for state creation for tests.
Expand Down
59 changes: 59 additions & 0 deletions tests/infra/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from collections.abc import Callable, Sequence
from typing import Any

from lru import LRU

from tests.core.pyspec.eth2spec.test.helpers.genesis import create_genesis_state
from tests.core.pyspec.eth2spec.test.helpers.typing import Spec, SpecForks


def _prepare_state(
balances_fn: Callable[[Any], Sequence[int]],
threshold_fn: Callable[[Any], int],
spec: Spec,
phases: SpecForks,
):
balances = balances_fn(spec)
activation_threshold = threshold_fn(spec)
state = create_genesis_state(
spec=spec,
validator_balances=balances,
activation_threshold=activation_threshold,
)
return state


_custom_state_cache_dict = LRU(size=10)


def with_custom_state(
balances_fn: Callable[[Any], Sequence[int]], threshold_fn: Callable[[Any], int]
):
"""
Decorator that provides a cached BeaconState constructed from custom balances
and activation threshold functions. The cache key is a tuple of:
(spec.fork, spec.config.__hash__(), spec.__file__, balances_fn, threshold_fn)
The cached value stores the immutable state backing to enable fast view reconstruction.
"""

def deco(fn):
def entry(*args, spec: Spec, phases: SpecForks, **kw):
key = (
spec.fork,
spec.config.__hash__(),
spec.__file__,
balances_fn,
threshold_fn,
)
if key not in _custom_state_cache_dict:
state = _prepare_state(balances_fn, threshold_fn, spec, phases)
_custom_state_cache_dict[key] = state.get_backing()

# Wrap cached immutable backing with a fresh view
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it was important in this comment, that it is reason we don't need to copy

state = spec.BeaconState(backing=_custom_state_cache_dict[key])
kw["state"] = state
return fn(*args, spec=spec, phases=phases, **kw)

return entry

return deco
121 changes: 121 additions & 0 deletions tests/infra/test_with_custom_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from eth2spec.test.context import (
default_activation_threshold,
default_balances,
with_custom_state,
)
from eth2spec.test.helpers.constants import MINIMAL, PHASE0
from eth2spec.test.helpers.specs import spec_targets


class TestWithCustomState:
"""Test suite for with_custom_state decorator."""

def test_with_custom_state_injects_state_view(self):
"""Test that the decorator injects a BeaconState with expected properties."""
spec = spec_targets[MINIMAL][PHASE0]

@with_custom_state(default_balances, default_activation_threshold)
def test_case(*, spec, phases, state):
# Verify the state is properly initialized
assert len(state.validators) > 0
assert len(state.balances) > 0
# Verify balances match default_balances (MAX_EFFECTIVE_BALANCE)
assert all(b == spec.MAX_EFFECTIVE_BALANCE for b in state.balances)
# Verify validators are activated (balance >= threshold)
assert all(v.activation_epoch == spec.GENESIS_EPOCH for v in state.validators)
return state

state = test_case(spec=spec, phases={})
assert state is not None
# Verify state properties outside the decorated function
assert len(state.validators) == spec.SLOTS_PER_EPOCH * 8
assert state.fork.current_version == spec.config.GENESIS_FORK_VERSION

def test_with_custom_state_custom_balances(self):
"""Test that custom balances are applied to the state."""
spec = spec_targets[MINIMAL][PHASE0]
custom_balance = spec.MAX_EFFECTIVE_BALANCE * 2

def custom_balances(spec):
return [custom_balance] * 4 # 4 validators

@with_custom_state(custom_balances, default_activation_threshold)
def test_case(*, spec, phases, state):
return state

state = test_case(spec=spec, phases={})
assert len(state.balances) == 4
assert all(balance == custom_balance for balance in state.balances)

def test_with_custom_state_custom_activation_threshold(self):
"""Test that custom activation threshold is applied."""
spec = spec_targets[MINIMAL][PHASE0]

# Case 1: Low threshold -> Validators should be active
low_threshold = 100

def low_threshold_fn(spec):
return low_threshold

@with_custom_state(default_balances, low_threshold_fn)
def test_case_active(*, spec, phases, state):
# The activation threshold is low, so validators should be active
assert all(v.activation_epoch == spec.GENESIS_EPOCH for v in state.validators)
return state

state_active = test_case_active(spec=spec, phases={})
assert state_active is not None

# Case 2: High threshold -> Validators should NOT be active
# Set threshold higher than default balance (MAX_EFFECTIVE_BALANCE)
high_threshold = spec.MAX_EFFECTIVE_BALANCE + 1

def high_threshold_fn(spec):
return high_threshold

@with_custom_state(default_balances, high_threshold_fn)
def test_case_inactive(*, spec, phases, state):
# The activation threshold is high, so validators should NOT be active
assert all(v.activation_epoch == spec.FAR_FUTURE_EPOCH for v in state.validators)
return state

state_inactive = test_case_inactive(spec=spec, phases={})
assert state_inactive is not None

def test_with_custom_state_with_phases(self):
"""
Test that the decorator works with phases parameter.

The decorator wraps the test function and must ensure that arguments
provided by the test runner (like 'phases') are correctly passed through
to the inner function.
"""
spec = spec_targets[MINIMAL][PHASE0]
phases = {"phase0": spec}

@with_custom_state(default_balances, default_activation_threshold)
def test_case(*, spec, phases, state):
assert phases is not None
assert "phase0" in phases
return state

state = test_case(spec=spec, phases=phases)
assert state is not None

def test_with_custom_state_multiple_calls(self):
"""Test that multiple decorated functions work independently."""
spec = spec_targets[MINIMAL][PHASE0]

balance1 = spec.MAX_EFFECTIVE_BALANCE
balance2 = spec.MAX_EFFECTIVE_BALANCE * 2

@with_custom_state(lambda _: [balance1], default_activation_threshold)
def test_case1(*, spec, phases, state):
return state.balances[0]

@with_custom_state(lambda _: [balance2], default_activation_threshold)
def test_case2(*, spec, phases, state):
return state.balances[0]

assert test_case1(spec=spec, phases={}) == balance1
assert test_case2(spec=spec, phases={}) == balance2
Loading