diff --git a/protocol/v2/ssv/spectest/msg_processing_type.go b/protocol/v2/ssv/spectest/msg_processing_type.go index 1f576d6765..0ae0713f62 100644 --- a/protocol/v2/ssv/spectest/msg_processing_type.go +++ b/protocol/v2/ssv/spectest/msg_processing_type.go @@ -148,8 +148,13 @@ func (test *MsgProcessingSpecTest) runPreTesting(ctx context.Context, logger *za lastErr = err } if test.DecidedSlashable && IsQBFTProposalMessage(msg) { + consensusMsg, err := specqbft.DecodeMessage(msg.SSVMessage.Data) + if err != nil { + panic(err) + } + slot := phase0.Slot(consensusMsg.Height) for _, validatorShare := range test.Runner.GetShares() { - test.Runner.GetSigner().(*ekm.TestingKeyManagerAdapter).AddSlashableSlot(validatorShare.SharePubKey, spectestingutils.TestingDutySlot) + test.Runner.GetSigner().(*ekm.TestingKeyManagerAdapter).AddSlashableSlot(validatorShare.SharePubKey, slot) } } } @@ -185,6 +190,7 @@ func (test *MsgProcessingSpecTest) RunAsPartOfMultiTest(t *testing.T, logger *za network := &spectestingutils.TestingNetwork{} var beaconNetwork *protocoltesting.BeaconNodeWrapped var committee []*spectypes.Operator + actualRunner := test.Runner switch test.Runner.(type) { case *runner.CommitteeRunner: @@ -193,6 +199,7 @@ func (test *MsgProcessingSpecTest) RunAsPartOfMultiTest(t *testing.T, logger *za runnerInstance = runner break } + actualRunner = runnerInstance network = runnerInstance.GetNetwork().(*spectestingutils.TestingNetwork) beaconNetwork = runnerInstance.GetBeaconNode().(*protocoltesting.BeaconNodeWrapped) committee = c.CommitteeMember.Committee @@ -215,12 +222,18 @@ func (test *MsgProcessingSpecTest) RunAsPartOfMultiTest(t *testing.T, logger *za assertRootsRelaxed(t, test.BeaconBroadcastedRoots, beaconNetwork.GetBroadcastedRoots()) // post root - postRoot, err := test.Runner.GetRoot() + if !test.DontStartDuty { + if proposerRunner, ok := actualRunner.(*runner.ProposerRunner); ok { + normalizeExpectedProposerStartValues(proposerRunner) + } + } + postRoot, err := actualRunner.GetRoot() require.NoError(t, err) - if test.PostDutyRunnerStateRoot != hex.EncodeToString(postRoot[:]) { - diff := dumpState(t, test.Name, test.Runner, test.PostDutyRunnerState) - logger.Error("post runner state not equal", zap.String("state", diff)) + actualPostRoot := hex.EncodeToString(postRoot[:]) + if test.PostDutyRunnerStateRoot != actualPostRoot { + diff := dumpState(t, test.Name, actualRunner, test.PostDutyRunnerState) + require.EqualValues(t, test.PostDutyRunnerStateRoot, actualPostRoot, "post runner state not equal\n%s\n", diff) } } @@ -232,6 +245,11 @@ func (test *MsgProcessingSpecTest) overrideStateComparison(t *testing.T) { func overrideStateComparison(t *testing.T, test *MsgProcessingSpecTest, name string, testType string) { r := runnerForTest(t, test.Runner, name, testType) + if !test.DontStartDuty { + if proposerRunner, ok := r.(*runner.ProposerRunner); ok { + normalizeExpectedProposerStartValues(proposerRunner) + } + } test.PostDutyRunnerState = r diff --git a/protocol/v2/ssv/spectest/util.go b/protocol/v2/ssv/spectest/util.go index 485708fa0b..a7927f13f7 100644 --- a/protocol/v2/ssv/spectest/util.go +++ b/protocol/v2/ssv/spectest/util.go @@ -4,11 +4,13 @@ import ( "path/filepath" "testing" + spectypes "github.com/ssvlabs/ssv-spec/types" typescomparable "github.com/ssvlabs/ssv-spec/types/testingutils/comparable" "github.com/stretchr/testify/require" "github.com/ssvlabs/ssv/ibft/storage" "github.com/ssvlabs/ssv/networkconfig" + blindutil "github.com/ssvlabs/ssv/protocol/v2/blockchain/beacon/blind" "github.com/ssvlabs/ssv/protocol/v2/ssv/runner" ) @@ -93,3 +95,61 @@ func runnerForTest(t *testing.T, runnerType runner.Runner, name string, testType return r } + +func normalizeExpectedProposerStartValues(pr *runner.ProposerRunner) { + if pr == nil || pr.BaseRunner == nil { + return + } + if state := pr.BaseRunner.State; state != nil { + state.DecidedValue = normalizeProposerConsensusValue(state.DecidedValue) + if state.RunningInstance != nil { + state.RunningInstance.StartValue = normalizeProposerConsensusValue(state.RunningInstance.StartValue) + if state.RunningInstance.State != nil { + state.RunningInstance.State.LastPreparedValue = normalizeProposerConsensusValue(state.RunningInstance.State.LastPreparedValue) + state.RunningInstance.State.DecidedValue = normalizeProposerConsensusValue(state.RunningInstance.State.DecidedValue) + } + } + } + if pr.BaseRunner.QBFTController == nil { + return + } + for _, inst := range pr.BaseRunner.QBFTController.StoredInstances { + if inst == nil { + continue + } + inst.StartValue = normalizeProposerConsensusValue(inst.StartValue) + if inst.State != nil { + inst.State.LastPreparedValue = normalizeProposerConsensusValue(inst.State.LastPreparedValue) + inst.State.DecidedValue = normalizeProposerConsensusValue(inst.State.DecidedValue) + } + } +} + +func normalizeProposerConsensusValue(value []byte) []byte { + if len(value) == 0 { + return value + } + cd := &spectypes.ValidatorConsensusData{} + if err := cd.Decode(value); err != nil { + return value + } + vBlk, _, err := cd.GetBlockData() + if err != nil { + return value + } + blindedVBlk, blindedMarshaler, err := blindutil.EnsureBlinded(vBlk) + if err != nil { + return value + } + blindedDataSSZ, err := blindedMarshaler.MarshalSSZ() + if err != nil { + return value + } + cd.Version = blindedVBlk.Version + cd.DataSSZ = blindedDataSSZ + encoded, err := cd.Encode() + if err != nil { + return value + } + return encoded +}