diff --git a/boc/cell.go b/boc/cell.go index cb7e06f..ddc8244 100644 --- a/boc/cell.go +++ b/boc/cell.go @@ -436,5 +436,37 @@ func (c *Cell) GetMerkleRoot() ([32]byte, error) { var hash [32]byte copy(hash[:], bytes[1:]) return hash, nil +} + +// TODO: move to deserializer +func (c *Cell) isValidMerkleProofCell() bool { + return c.cellType == MerkleProofCell && c.RefsSize() == 1 && c.BitSize() == 280 +} + +func (c *Cell) CalculateMerkleProofMeta() (int, [32]byte, error) { + if !c.isValidMerkleProofCell() { + return 0, [32]byte{}, errors.New("not valid merkle proof cell") + } + imc, err := newImmutableCell(c.Refs()[0], map[*Cell]*immutableCell{}) + if err != nil { + return 0, [32]byte{}, fmt.Errorf("get immutable cell: %w", err) + } + h := imc.Hash(0) + var hash [32]byte + copy(hash[:], h) + depth := imc.Depth(0) + return depth, hash, nil +} +// TODO: or add level as optional parameter to Hash256() +func (c *Cell) Hash256WithLevel(level int) ([32]byte, error) { + // TODO: or check for pruned cell and read hash directly from cell + imc, err := newImmutableCell(c, map[*Cell]*immutableCell{}) + if err != nil { + return [32]byte{}, err + } + b := imc.Hash(level) + var h [32]byte + copy(h[:], b) + return h, nil } diff --git a/config/config.go b/config/config.go index 24f0cbc..7bc3a05 100644 --- a/config/config.go +++ b/config/config.go @@ -4,6 +4,7 @@ import ( "encoding/binary" "encoding/json" "fmt" + "github.com/tonkeeper/tongo/ton" "io" "os" ) @@ -19,15 +20,28 @@ type liteServerId struct { Key string `json:"key"` } +type initBlockConfig struct { + Workchain int32 `json:"workchain"` + Shard int64 `json:"shard"` + Seqno int64 `json:"seqno"` + RootHash []byte `json:"root_hash"` + FileHash []byte `json:"file_hash"` +} + +type validatorConfig struct { + InitBlock initBlockConfig `json:"init_block"` +} + type configGlobal struct { LiteServers []liteServerConfig `json:"liteservers"` - //Validator ValidatorConfig `json:"validator"` + Validator validatorConfig `json:"validator"` } // GlobalConfigurationFile contains global configuration of the TON Blockchain. // It is shared by all nodes and includes information about network, init block, hardforks, etc. type GlobalConfigurationFile struct { LiteServers []LiteServer + Validator Validator } // LiteServer TODO: clarify struct @@ -36,6 +50,10 @@ type LiteServer struct { Key string } +type Validator struct { + InitBlock ton.BlockIDExt +} + func ParseConfigFile(path string) (*GlobalConfigurationFile, error) { jsonFile, err := os.Open(path) if err != nil { @@ -74,6 +92,19 @@ func ParseConfig(data io.Reader) (*GlobalConfigurationFile, error) { } options.LiteServers = append(options.LiteServers, ls) } + var rootHash [32]byte + copy(rootHash[:], conf.Validator.InitBlock.RootHash) + var fileHash [32]byte + copy(fileHash[:], conf.Validator.InitBlock.FileHash) + options.Validator.InitBlock = ton.BlockIDExt{ + BlockID: ton.BlockID{ + Workchain: conf.Validator.InitBlock.Workchain, + Shard: uint64(conf.Validator.InitBlock.Shard), + Seqno: uint32(conf.Validator.InitBlock.Seqno), + }, + RootHash: rootHash, + FileHash: fileHash, + } if len(options.LiteServers) == 0 { return nil, fmt.Errorf("no one supported liteservers") } diff --git a/liteapi/account.go b/liteapi/account.go new file mode 100644 index 0000000..acb123e --- /dev/null +++ b/liteapi/account.go @@ -0,0 +1,147 @@ +package liteapi + +import ( + "context" + "errors" + "fmt" + "github.com/tonkeeper/tongo/boc" + "github.com/tonkeeper/tongo/tlb" + "github.com/tonkeeper/tongo/ton" +) + +// GetAccountWithProof +// For safe operation, always use GetAccountWithProof with WithBlock(proofedBlock ton.BlockIDExt), as the proof of masterchain cashed blocks is not implemented yet! +func (c *Client) GetAccountWithProof(ctx context.Context, accountID ton.AccountID) (*tlb.ShardAccount, *tlb.ShardStateUnsplit, error) { + res, err := c.GetAccountStateRaw(ctx, accountID) // TODO: add proof check for masterHead + if err != nil { + return nil, nil, err + } + blockID := res.Id.ToBlockIdExt() + if len(res.Proof) == 0 { + return nil, nil, errors.New("empty proof") + } + + var blockHash ton.Bits256 + if (accountID.Workchain == -1 && blockID.Workchain == -1) || blockID == res.Shardblk.ToBlockIdExt() { + blockHash = blockID.RootHash + } else { + if len(res.ShardProof) == 0 { + return nil, nil, errors.New("empty shard proof") + } + if res.Shardblk.RootHash == [32]byte{} { // TODO: how to check for empty shard? + return nil, nil, errors.New("shard block not passed") + } + shardHash := ton.Bits256(res.Shardblk.RootHash) + if _, err := checkShardInMasterProof(blockID, res.ShardProof, accountID.Workchain, shardHash); err != nil { + return nil, nil, fmt.Errorf("shard proof is incorrect: %w", err) + } + blockHash = shardHash + } + cellsMap := make(map[[32]byte]*boc.Cell) + if len(res.State) > 0 { + stateCells, err := boc.DeserializeBoc(res.State) + if err != nil { + return nil, nil, fmt.Errorf("state deserialization failed: %w", err) + } + hash, err := stateCells[0].Hash256() + if err != nil { + return nil, nil, fmt.Errorf("get hash err: %w", err) + } + cellsMap[hash] = stateCells[0] + } + proofCells, err := boc.DeserializeBoc(res.Proof) + if err != nil { + return nil, nil, err + } + shardState, err := checkBlockShardStateProof(proofCells, blockHash, cellsMap) + if err != nil { + return nil, nil, fmt.Errorf("incorrect block proof: %w", err) + } + values := shardState.ShardStateUnsplit.Accounts.Values() + keys := shardState.ShardStateUnsplit.Accounts.Keys() + for i, k := range keys { + if k == accountID.Address { + return &values[i], shardState, nil + } + } + if len(res.State) == 0 { + return &tlb.ShardAccount{Account: tlb.Account{SumType: "AccountNone"}}, shardState, nil + } + return nil, nil, errors.New("invalid account state") +} + +func checkShardInMasterProof(master ton.BlockIDExt, shardProof []byte, workchain int32, shardRootHash ton.Bits256) (*tlb.McStateExtra, error) { + shardProofCells, err := boc.DeserializeBoc(shardProof) + if err != nil { + return nil, err + } + shardState, err := checkBlockShardStateProof(shardProofCells, master.RootHash, nil) + if err != nil { + return nil, fmt.Errorf("check block proof failed: %w", err) + } + if !shardState.ShardStateUnsplit.Custom.Exists { + return nil, fmt.Errorf("not a masterchain block") + } + stateExtra := shardState.ShardStateUnsplit.Custom.Value.Value + keys := stateExtra.ShardHashes.Keys() + values := stateExtra.ShardHashes.Values() + for i, k := range keys { + binTreeValues := values[i].Value.BinTree.Values + for _, b := range binTreeValues { + switch b.SumType { + case "Old": + if int32(k) == workchain && ton.Bits256(b.Old.RootHash) == shardRootHash { + return &stateExtra, nil + } + case "New": + if int32(k) == workchain && ton.Bits256(b.New.RootHash) == shardRootHash { + return &stateExtra, nil + } + } + } + } + return nil, fmt.Errorf("required shard hash not found in proof") +} + +func checkBlockShardStateProof(proof []*boc.Cell, blockRootHash ton.Bits256, cellsMap map[[32]byte]*boc.Cell) (*tlb.ShardStateUnsplit, error) { + if len(proof) != 2 { + return nil, errors.New("must be two root cells") + } + block, err := checkBlockProof(*proof[0], blockRootHash) + if err != nil { + return nil, fmt.Errorf("incorrect block proof: %w", err) + } + var stateProof struct { + Proof tlb.MerkleProof[tlb.ShardStateUnsplit] + } + decoder := tlb.NewDecoder() + if cellsMap != nil { + decoder = decoder.WithPrunedResolver(func(hash tlb.Bits256) (*boc.Cell, error) { + cell, ok := cellsMap[hash] + if ok { + return cell, nil + } + return nil, errors.New("not found") + }) + } + err = decoder.Unmarshal(proof[1], &stateProof) + if err != nil { + return nil, err + } + if stateProof.Proof.VirtualHash != block.VirtualRoot.StateUpdate.ToHash { + return nil, errors.New("invalid virtual hash") + } + return &stateProof.Proof.VirtualRoot, nil +} + +func checkBlockProof(proof boc.Cell, blockRootHash ton.Bits256) (*tlb.MerkleProof[tlb.Block], error) { + var res tlb.MerkleProof[tlb.Block] + err := tlb.Unmarshal(&proof, &res) // merkle hash and depth checks inside + if err != nil { + return nil, fmt.Errorf("failed to unmarshal block proof: %w", err) + } + if ton.Bits256(res.VirtualHash) != blockRootHash { + return nil, fmt.Errorf("invalid block root hash") + } + return &res, nil // return new_hash field of MerkleUpdate of ShardState +} diff --git a/liteapi/account_test.go b/liteapi/account_test.go new file mode 100644 index 0000000..8b72c3f --- /dev/null +++ b/liteapi/account_test.go @@ -0,0 +1,171 @@ +package liteapi + +import ( + "bytes" + "context" + "encoding/base64" + "errors" + "fmt" + "github.com/tonkeeper/tongo/boc" + "github.com/tonkeeper/tongo/tlb" + "github.com/tonkeeper/tongo/ton" + "testing" +) + +func TestGetAccountWithProof(t *testing.T) { + api, err := NewClient(Testnet(), FromEnvs()) + if err != nil { + t.Fatal(err) + } + testCases := []struct { + name string + accountID string + }{ + { + name: "account from masterchain", + accountID: "-1:34517c7bdf5187c55af4f8b61fdc321588c7ab768dee24b006df29106458d7cf", + }, + { + name: "active account from basechain", + accountID: "0:e33ed33a42eb2032059f97d90c706f8400bb256d32139ca707f1564ad699c7dd", + }, + { + name: "nonexisted from basechain", + accountID: "0:5f00decb7da51881764dc3959cec60609045f6ca1b89e646bde49d492705d77c", + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + accountID, err := ton.AccountIDFromRaw(tt.accountID) + if err != nil { + t.Fatal("AccountIDFromRaw() failed: %w", err) + } + acc, st, err := api.GetAccountWithProof(context.TODO(), accountID) + if err != nil { + t.Fatal(err) + } + fmt.Printf("Account status: %v\n", acc.Account.Status()) + fmt.Printf("Last proof utime: %v\n", st.ShardStateUnsplit.GenUtime) + }) + } +} + +func TestUnmarshallingProofWithPrunedResolver(t *testing.T) { + testCases := []struct { + name string + accountID string + state string + proof string + }{ + { + name: "account from masterchain", + accountID: "-1:34517c7bdf5187c55af4f8b61fdc321588c7ab768dee24b006df29106458d7cf", + state: "te6ccgEBAQEANwAAac/zRRfHvfUYfFWvT4th/cMhWIx6t2je4ksAbfKRBkWNfPICV8MiQ7WQAAAnwcjbgQjD0JAE", + proof: "te6ccgECMQIABnwBAAlGAzm5ngf8wRtgCPSbEv1KYCOfL3YI9/HjNbeRsayNPbNBAcQCCUYDNjyZxQ6TS+uioSqhEmArXFMzcJ0iBgOO8gRScN1HDdQAFykkW5Ajr+L////9AP////8AAAAAAAAAAAFZWXUAAAAAZtWqnAAAFybEzYeEAVlZcmADBAUGKEgBAaHRVuPHjzLUmEYd44x/vzWcCD0Yz14taK8lFYkyYi16AAEiE4IRdMqOqEN5qjAHJyIzAAAAAAAAAAD//////////4RdMqOqEN5qiCgnKChIAQGSq4StFUmWS1wEONBSEQt3Wuup9Nhbdrp5fmk6oPx7cgAbIxMBCLplR1QhvNUYCAknIxMBCLADRttzmtl4CgsnKEgBASOab6P8maV1G2OhFqlTgNjoroG1i5MO5qtsoOqgY1ViAcAjEwEIqC8wZsH7ingMDScoSAEB6CVtOp/z9Kpj+SrDgasFP3kzGT7kZL/D7DV+kXZw3p4BwChIAQFIWAk+xGmmFk4X6v0lmgWpnk4YjEI/Tam9kXRPNLTjdgEPIhEA8MDW7kYtTOgODyhIAQFvFG13TDMvzWbO+0LVJbC2JHWAiHGL4nE6ztkqqPbefQENIhEA8BRymICH9QgQESIRAPAErnsrX9VoEhMoSAEB3EtPbrrXUOrrKC6PIjfYhj6sGE/sozDHyHhOqJco3MEAHShIAQEzaqPXOwDDYGRMpiB4f56Ep+oa0eJgHewJ3AzRahxDIgAaIhEA4GnDcXrbdIgUFSIRAOBpsy/ekzFIFhcoSAEBNEqqYfkIiFNMQLEWcnTl6stVZGy5ErSafIxVo9y43QYAGCINAKISMHbNyBgZKEgBAb3/C7yo+BlD9hObGRxmV/KM5e4Dx2dwvpYwbKHtaWxnABciDQChOnTEp4gaGyhIAQGefqFE75DubH/tAm07waI/ZJXXaUbA9TAjnJiXm5bFCQAYKEgBAWXR2cYWGVGOhRr0zx2IbfGvk/PYc1y+oSjg/RG7M23AABAiDQCgymlW3sgcHSINAKBa4jVqKB4fKEgBAScXOibV9zcakxfRicgWyCdTJFdgVU+FKgRvmF+aGkzaAAooSAEBgCBa9skgBATVk+IARArXssVqC0fkYdTeesmbGCNG2QIAEyIJAGHoSAggISIJAGHoSAgiIyhIAQG0d/c3vjpwCZ/skE9GxWaQW2p12p94naEbq+nI8BwxkAALIgkAYehICCQlKEgBAbcMTUahqWzuFpxMXlh6PQ1nbe5GZl+2H6cojgVpqdJGAAohl7xvj3vqMPirXp8Ww/uGQrEY9W7RvcSWANvlIgyLGvngMPQkBwpuh9/a8Q8liD7J/k6cDctwlpd7/wLfSj2ZkV3vQyWYAABPg5G3AgwmKEgBAdYv7OF3BCOejG/QrCN6d0U1gfBqoWUQPs3lCtkW0bqjAAgoSAEBfNH4EBmmQQpGeuMDQrdiJrcg1/foaCtvYt8A2eULqDYAAChIAQGyDjajs2pM3uYBEGxkLpBxiwpY2vIAdT27MYn5VrSUtgABKEgBAfXuY7WxrmX661SjWg783A5GU3G2ZUuk96jxo07FszvmAcAkEBHvVar////9KissLQGgm8ephwAAAAAEAQFZWXUAAAAAAP////8AAAAAAAAAAGbVqpwAABcmxM2HgAAAFybEzYeEjIe6qAAErqgBWVlyAVlQJsQAAAAIAAAAAAAAAe4uKEgBASkmYQFN/IwKZw+6jDvG7Hla0bypRJASLgCSLllgZYYxAAMqigQUxb9ElcbqpVZzQQGVtjJWkzZu/gqQV6cRwEqnJT7ljzm5ngf8wRtgCPSbEv1KYCOfL3YI9/HjNbeRsayNPbNBAcQBxC8wKEgBAfZANQckARh74l3KoHg6MIoIlCtXCklSokH5oFnYkvWaAAcAmAAAFybEvkVEAVlZdIw41mHg1tiTlZWmUEC5Zs1iJSaJiU/PG7sL/HsqfBj+zYXmULmtzn4TRGwnVVC5tKAhaIUDbFZrLZ+xVZ8cOhpojAEDFMW/RJXG6qVWc0EBlbYyVpM2bv4KkFenEcBKpyU+5Y+0LNwg0RHTx+GvVrTHWlXSAsJOr1Re1+VF1o0FxmRgmwHEABVojAEDObmeB/zBG2AI9JsS/UpgI58vdgj38eM1t5GxrI09s0EncQaO3Qwlxbnasj2PyljXoXXcs0VfOqaRU3MLD/XjOwHEABU=", + }, + { + name: "active account from basechain", + accountID: "0:e33ed33a42eb2032059f97d90c706f8400bb256d32139ca707f1564ad699c7dd", + state: "te6ccgECRgEACUQAAnPADjPtM6QusgMgWfl9kMcG+EALslbTITnKcH8VZK1pnH3SjJBKAzalNagAAFyBA05ODYC7pLkMcNNAAQIBFP8A9KQT9LzyyAsDAgAdHgIBYgQFAgLMBgcCASAXGAIBIAgJAgEgExQCASAKCwIBIA8QAW1CDHAJSED/Lw3gHQ0wMBcbCSXwPg+kAwAdMf7UTQ1NQwMSLAAOMCECRfBIIQNw/sUbrchA/y8IDAIBIA0OANAy+CMgghBi5EBpvPLgxwHwBCDXSSDCGPLgyCCBA/C78uDJIHipCMAA8uDKIfAF8uDLWPAHFL7y4Mwi+QGAUPgzIG6zjhDQ9AQwUhCDB/QOb6Ex8tDNkTDiyFAEzxbJyFADzxYSzMnwDAANHDIywHJ0IAAzHCfAdMHAcAAILOUAqYIAt4S5jEgwADy0MmACASAREgIBIDQ1AE8yI4gIddJEtcYWc8WIddKIMAAILObAcAB8uDKAtQw0AKRMeLmMcnQgAH8cCHXSY41XLogs44uMALTByHALSPCALAkpvhSQLmwIsIvI8E6sLEiwmADwXsTsBKxsyCzlAKmCALeE97mbBK6gAgFYFRYAOdLPgFOBD4BbvADGRlgqxnizh9AWW15mZkwCB9gEAC0AcjL//gozxbJcCDIywET9AD0AMsAyYAAbPkAdMjLAhLKB8v/ydCACASAZGgIBIBscAAe4tdMYAB+6ej7UTQ1NQwMfAKcAHwC4ABu5Bb7UTQ1NQwMH/wAhKACdujDDAg10l4qQjAAPLgRiDXCgfAACHXScAIUhCwk1t4beAglQHTBzEB3iHwA1Ei1xgw+QGCALqTyMsPAYIBZ6PtQ9jPFskBkXiRcOISoAGABIAWh0dHBzOi8vZG5zLnRvbi5vcmcvY29sbGVjdGlvbi5qc29uART/APSkE/S88sgLHwIBYiAhAgLMIiMCASA8PQIBICQlAgFINjcCASAmJwIBWDQ1AgEgKCkADUcMjLAcnQgB9z4J28QAtDTAwFxsJJfBOD6QPpAMfoAMXHXIfoAMfoAMPAKJ7OOTl8FbCI0UjLHBfLhlQH6QNQwbXDIywf0AMn4I4IQYuRAaaGCCCeNAKkEIMIMkzCADN6BASyBAPBYqIAMqQSh+CMBoPACRHfwCRA1+CPwC+BTWccFGLCAqABE+kQwcLry4U2AD+I40EJtfC/pAMHAg+CVtgEBwgBDIywVQB88WUAX6AhXLahLLH8s/Im6zlFjPFwGRMuIByQH7AOApxwCRcJUJ0x9QquIh8Aj4IyG8JMAAjp40Ojo7jhY2Njc3N1E1xwXy4ZYQJRAkECP4I/AL4w7gMQ3TPyVusx+wkmwh4w0rLC0A/jAmgGmAZKmEUrC+8uGXghA7msoAUqChUnC8mTaCEDuaygAZoZM5CAXiIMIAjjKCEFV86iD4JRA5bXFwgBDIywVQB88WUAX6AhXLahLLH8s/Im6zlFjPFwGRMuIByQH7AJIwNuKAPCP4I6GhIMIAkxOgApEw4kR08AkQJPgj8AsA0jQ2U82hghA7msoAUhChUnC8mTaCEDuaygAWoZIwBeIgwgCON4IQNw/sUW1yKVE0VEdDcIAQyMsFUAfPFlAF+gIVy2oSyx/LPyJus5RYzxcBkTLiAckB+wAcoQuRMOJtVHdlVHdjLvALAgTIghBfzD0UUiC6jpUxNztTcscF8uGREJoQSRA4RwZAFQTgghAaC51RUiC6jhlbMjU1NzdRNccF8uGaA9QwQBUEUDP4I/AL4CGCEE6x8Pm64wI7IIIQRL6uQbrjAjgnghBO0UtlujEuLzAAiFs2Njg4UUfHBfLhmwTT/yDXSsIAB9DTBwHAAPLhnPQEMAeY1DBAFoMH9BeYMFAFgwf0WzDicMjLB/QAyRA1QBT4I/ALAf4wNjokbvLhnYBQ+DPQ9AQwUkCDB/QOb6Hy4Z/TByHAACLAAbHy4aAhwACOkSQQmxBoUXoQVxBGEFxDFEzdljAQOjlfB+IBwAGOMnCCEDcP7FFYbYEAoHCAEMjLBVAHzxZQBfoCFctqEssfyz8ibrOUWM8XAZEy4gHJAfsAkVviMQH+jno3+CNQBqGBAli8Bm4WsPLhniPQ10n4I/AHUpC+8uGXUXihghA7msoAoSDCAI4yECeCEE7RS2VYB21ycIAQyMsFUAfPFlAF+gIVy2oSyx/LPyJus5RYzxcBkTLiAckB+wCTMDU14vgjgQEsoPACRHfwCRBFEDQS+CPwC+BfBDMB8DUC+kAh8AH6QNIAMfoAghA7msoAHaEhlFMUoKHeItcLAcMAIJIFoZE14iDC//LhkiGOPoIQBRONkchQC88WUA3PFnEkSxRUSMBwgBDIywVQB88WUAX6AhXLahLLH8s/Im6zlFjPFwGRMuIByQH7ABBplBAsOVviATIAio41KPABghDVMnbbEDlGCW1xcIAQyMsFUAfPFlAF+gIVy2oSyx/LPyJus5RYzxcBkTLiAckB+wCTODQw4hBFEDQS+CPwCwCaMjU1ghAvyyaiuo46cIIQi3cXNQTIy/9QBc8WFEMwgEBwgBDIywVQB88WUAX6AhXLahLLH8s/Im6zlFjPFwGRMuIByQH7AOBfBIQP8vAAkwgwASWMIED6IBk4CDABZYwgQH0gDLgIMAGljCBAZCAKOAgwAeWMIEBLIAe4CDACJYwgQDIgBTgIMAJlDCAZHrgwAqTgDJ14HpxgAGkAasC8AYBghA7msoAqAGCEDuaygCoAoIQYuRAaaGCCCeNAKkEIMIVkVvgbBKWp1qAZKkE5IAIBIDg5AgEgOjsAIQgbpQwbXAg4ND6QPoA0z8wgABcyFADzxYB+gLLP8mAAUTtRNDT//pAINdJwgCffwH6QNTU9ATTPzAQVxBW4DBwbW1tbSQQVxBWgACsBsjL/1AFzxZQA88WzMz0AMs/ye1UgAgEgPj8CASBCQwATu7OfAKF18H8AiAICdEBBABCodPAKEEdfBwAMqVnwCmxxAA24/P8ApfA4AgEgREUAE7ZKXgFCBOvg+hAAx7RhhDrpJA8VIRgAHlwI3gFCBuvg+hpg4DgAHlwznoCGAHrhQPgAHlwzuEERxGYQXgM+BIg9yxH7ZN3ElkrRuga4eSQNwjVy83zFyqqxQ6L/+8QYABJmDwA8ADBg/oHt9CYPADA=", + proof: "te6ccgECPwIACDsBAAlGA13wIJUiI7PTg32Ejju+cdCEhP3rVfdUykAsuzXJC+XeAhoCCUYDjCFt8RcRp3CSKwob3sWjzrlRTNWeNVuLJlIfcVPa8CsAHTYjW5Ajr+L////9AgAAAADAAAAAAAAAAAFyPHcAAAAAZtWq5wAAFybGxRHCAVlZkCADBAUoSAEBiyK6Ydkff6GhmEvUf7v2ypzoO80QfF1X31CgOrDi9NkAASERgcZ7gbxJWKSQBgDXAAAAAAAAAAD//////////3Ge4G8SVikji03qk7ZraRAAAXJsa1z4QBWVmQQHydZ7bv6t6cGWTc3mk/vill/h79hpOKKfqDLKYVV2UIGVEYrOM3Wj2tBpeo0h3v1D9oVZi+yPKb3hs1ZHIKP4IhJsDjPcDeJKxSQHCChIAQFV/Gx1KlYcCS7JoHnYKwxWUvcHSieeSmxWO4uodXNNAAIXIhEA4LZ6enx+bWgJCiIRAOBaroXhVZgICwwoSAEBjCXGi3ABTEA16fNsFq6NoDhjFI0NMoPkohAjH+1qLLEB+CIRAOAksERHiOeoDQ4oSAEBKwFZZ4JbhmBFD0mdW0SfSTNG6kvDu3q3m/GIC+nY4zIB+SIPANc4OPbi8mgPEChIAQFxKVccovU7jHX+7gMJQXv9jCCrCiQGA6Y+QxKqTUBOhgH2KEgBAarJ0lCkaVdnW8kHI8DI6r+f2A/fTt8z22IAYXPQ809ZAG0iDwDQN70ViI0IERIoSAEBYKCiCBXVsTmcFm80Mdta68M1YYjRCTuZc93Triz7m9gBgyIPAMKMtn9VX0gTFCIPAMFcoYfNXGgVFihIAQGfX/bkZ1jS0/86yXx7DUfHatLji8tg3KF1vqPCk7IuVgB1Ig8AwOFqAW4syBcYKEgBAcUv4J2EF0tP6D+5WhVgwDMLTkVX7DMpqFjgQDi6E9jhACgoSAEB3U3tRh2vNqt3+VIQEWsJZvr1iIynQifDUaSkLWptWeEAViIPAMCaEhFwVKgZGihIAQEgbSkgGVjN4KMCds3fZKZrTlcAovsrbsmf4H7z/kMvmAAlIg8AwHvkcXhAiBscKEgBASVq+znmXDdANt51HB6R1aLUiNC2dO0HT1iUqtS8Z6F5ADEiDwDAa8Yggk9oHR4oSAEBUYTWv2vPRimCXtZAgNy5iMY3b0wHigalk4BeV1BC+VUAHiIPAMBjmLIoEUgfIChIAQFkFPWNkgIntiWCGsnRF2KAeS1zN7xJz25ijFXeaks1OwAbIg8AwGBRncod6CEiIg8AwF7ekhx2aCMkKEgBAZmbgpiPVKGvOf/WEC7nxevwaoe7MjdmUaGfGElg+1PWAB8oSAEBi34yMErN+/f9620NGVPyNM4YLA7kNUNBzi6/M5+T+rIAGyIPAMBeRJvtTOglJihIAQGuzAJ2VhPM6ZjvDnlUhxmUUW2KVATVTIxDBiyxWEnGjgAaIg8AwF4SHEYnaCcoIg8AwF3qmKb5yCkqKEgBATwM4rRbTvXMTygQOAAN0P3je99htC5985672M3oFzKrABAoSAEBmvQ7T6WK4sikPEF56dRPfYMoKohn6AD14iHPu0PIXG0AESIPAMBd2lpx5QgrLCIPAMBd0wBKaYgtLihIAQGAJDolEz6vxu7T4bZ5nycsMPnhWdrLb6YHk6DtX6clHwAPIg8AwF3SzXUuCC8wKEgBAZj5RGzisK6cu7D5HEuUfV3XH9l7kh+YJkUvHmOD8x+OAAQoSAEBen2eXEL+HfMUrc8tTs/QzzywQbk7BwkhbrJcKjmjKuwAECIPAMBd0pT/QGgxMihIAQFGfJ9GFmcwJQAG2fIW1sw7PoTSSnnu73zaAArcog2ekQAKIg8AwF3SbPOwaDM0IZu53SF1kBkCz8vshjg3wgBdkraZCc5Tg/irJWtM4+6BgLukuQxwz+nllEVa9/Mu2AwDiKHA4s/62dmELHS3FsVCWMVhGZ2gAALkCBpycDA1KEgBATtYLqw6Myl/aB80aal2YMmYp4nlmVgzmLq4OM6C27smAAIoSAEBBw2M4JYgqdIbA+zPjO/RlQITYyUb3l6fcGI6GsYk/xQADSQQEe9Vqv////03ODk6AqCbx6mHAAAAAIQBAXI8dwAAAAACAAAAAMAAAAAAAAAAZtWq5wAAFybGxRHAAAAXJsbFEcLqGGx8AASwigFZWZABWVAmxAAAAAgAAAAAAAAB7js8KEgBAdbtVydD+uFhNgWoW1/6GJjaj2d2eMYE36jWJ+9zq7BIAAEqigSQjubvByF4zveT/d62LW5Uzl9wAVo+9ozG55Q4oZr/kF3wIJUiI7PTg32Ejju+cdCEhP3rVfdUykAsuzXJC+XeAhoCGj0+KEgBAdzATRK2nM22UX18oy9xQvXwCJ9Zf54Cg36vVlziEuIcAAgAmAAAFybGtc+EAVlZkEB8nWe27+renBlk3N5pP74pZf4e/YaTiin6gyymFVdlCBlRGKzjN1o9rQaXqNId79Q/aFWYvsjym94bNWRyCj8AmAAAFybGtc+CAXI8dnR1t2tl9ygPFKytIrYccschqEVLVJKRzfGoXydZLkF/V+9JqqYAgWOFo1SWBohYySfyS4Jzv7iZCQya5q+vgNJojAEDkI7m7wcheM73k/3eti1uVM5fcAFaPvaMxueUOKGa/5Au1SPeq+s/fkBEbdDR9O8KVspDwcDI3pfnn1mShOrlfAIaABpojAEDXfAglSIjs9ODfYSOO75x0ISE/etV91TKQCy7NckL5d4d2zEQnvPwYNp0OmphoUWBv1hhDLQJv0uX98Ed7i21wgIaABs=", + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + accountID, err := ton.AccountIDFromRaw(tt.accountID) + if err != nil { + t.Fatal("AccountIDFromRaw() failed: %w", err) + } + state, err := base64.StdEncoding.DecodeString(tt.state) + if err != nil { + t.Fatal("base64 decoding failed: %w", err) + } + proof, err := base64.StdEncoding.DecodeString(tt.proof) + if err != nil { + t.Fatal("base64 decoding failed: %w", err) + } + stateCells, err := boc.DeserializeBoc(state) + if err != nil { + t.Fatal("DeserializeBoc() failed: %w", err) + } + proofCells, err := boc.DeserializeBoc(proof) + if err != nil { + t.Fatal("DeserializeBoc() failed: %w", err) + } + hash, err := stateCells[0].Hash256() + if err != nil { + t.Fatal("Get hash failed: %w", err) + } + cellsMap := map[[32]byte]*boc.Cell{hash: stateCells[0]} + if err != nil { + t.Fatal("Get NonPrunedCells() failed: %w", err) + } + decoder := tlb.NewDecoder().WithDebug().WithPrunedResolver(func(hash tlb.Bits256) (*boc.Cell, error) { + cell, ok := cellsMap[hash] + if ok { + return cell, nil + } + return nil, errors.New("not found") + }) + var stateProof struct { + Proof tlb.MerkleProof[tlb.ShardStateUnsplit] + } + err = decoder.Unmarshal(proofCells[1], &stateProof) + if err != nil { + t.Fatal("proof unmarshalling failed: %w", err) + } + values := stateProof.Proof.VirtualRoot.ShardStateUnsplit.Accounts.Values() + keys := stateProof.Proof.VirtualRoot.ShardStateUnsplit.Accounts.Keys() + for i, k := range keys { + if bytes.Equal(k[:], accountID.Address[:]) { + fmt.Printf("Account status: %v\n", values[i].Account.Status()) + } + } + }) + } +} + +func TestGetAccountWithProofForBlock(t *testing.T) { + api, err := NewClient(Testnet(), FromEnvs()) + if err != nil { + t.Fatal(err) + } + testCases := []struct { + name string + accountID string + block string + }{ + { + name: "active account from basechain", + accountID: "0:e33ed33a42eb2032059f97d90c706f8400bb256d32139ca707f1564ad699c7dd", + block: "(0,e000000000000000,24681072)", + }, + { + name: "account from masterchain", + accountID: "-1:34517c7bdf5187c55af4f8b61fdc321588c7ab768dee24b006df29106458d7cf", + block: "(-1,8000000000000000,23040403)", + }, + } + for _, tt := range testCases { + t.Run(tt.name, func(t *testing.T) { + accountID, err := ton.AccountIDFromRaw(tt.accountID) + if err != nil { + t.Fatal("AccountIDFromRaw() failed: %w", err) + } + b, err := ton.ParseBlockID(tt.block) + if err != nil { + t.Fatal("ParseBlockID() failed: %w", err) + } + block, _, err := api.LookupBlock(context.TODO(), b, 1, nil, nil) + if err != nil { + t.Fatal("LookupBlock() failed: %w", err) + } + acc, st, err := api.WithBlock(block).GetAccountWithProof(context.TODO(), accountID) + if err != nil { + t.Fatal(err) + } + fmt.Printf("Account status: %v\n", acc.Account.Status()) + fmt.Printf("Last proof utime: %v\n", st.ShardStateUnsplit.GenUtime) + }) + } +} diff --git a/liteapi/client.go b/liteapi/client.go index c6d45b2..14a9301 100644 --- a/liteapi/client.go +++ b/liteapi/client.go @@ -34,6 +34,7 @@ const ( var ( // ErrAccountNotFound is returned by lite server when executing a method for an account that has not been deployed to the blockchain. ErrAccountNotFound = errors.New("account not found") + BlockMismatch = errors.New("got invalid block from liteserver") ) // ProofPolicy specifies a policy for proof checks. @@ -391,6 +392,9 @@ func (c *Client) GetBlockRaw(ctx context.Context, blockID ton.BlockIDExt) (litec if err != nil { return liteclient.LiteServerBlockDataC{}, err } + if res.Id.ToBlockIdExt() != blockID { + return liteclient.LiteServerBlockDataC{}, BlockMismatch + } return res, err } @@ -412,6 +416,9 @@ func (c *Client) GetStateRaw(ctx context.Context, blockID ton.BlockIDExt) (litec if err != nil { return liteclient.LiteServerBlockStateC{}, err } + if res.Id.ToBlockIdExt() != blockID { + return liteclient.LiteServerBlockStateC{}, BlockMismatch + } return res, nil } @@ -436,6 +443,9 @@ func (c *Client) GetBlockHeaderRaw(ctx context.Context, blockID ton.BlockIDExt, if err != nil { return liteclient.LiteServerBlockHeaderC{}, err } + if res.Id.ToBlockIdExt() != blockID { + return liteclient.LiteServerBlockHeaderC{}, BlockMismatch + } return res, nil } @@ -457,6 +467,9 @@ func (c *Client) LookupBlock(ctx context.Context, blockID ton.BlockID, mode uint if err != nil { return ton.BlockIDExt{}, tlb.BlockInfo{}, err } + if res.Id.ToBlockIdExt().BlockID != blockID { + return ton.BlockIDExt{}, tlb.BlockInfo{}, BlockMismatch + } return decodeBlockHeader(res) } @@ -505,9 +518,10 @@ func (c *Client) RunSmcMethodByID(ctx context.Context, accountID ton.AccountID, if err != nil { return 0, tlb.VmStack{}, err } + blockID := c.targetBlockOr(masterHead) req := liteclient.LiteServerRunSmcMethodRequest{ Mode: 4, - Id: liteclient.BlockIDExt(c.targetBlockOr(masterHead)), + Id: liteclient.BlockIDExt(blockID), Account: liteclient.AccountID(accountID), MethodId: uint64(methodID), Params: b, @@ -516,6 +530,9 @@ func (c *Client) RunSmcMethodByID(ctx context.Context, accountID ton.AccountID, if err != nil { return 0, tlb.VmStack{}, err } + if res.Id.ToBlockIdExt() != blockID { + return 0, tlb.VmStack{}, BlockMismatch + } var result tlb.VmStack if res.ExitCode == 4294967040 { //-256 return res.ExitCode, nil, ErrAccountNotFound @@ -577,6 +594,9 @@ func (c *Client) GetAccountStateRaw(ctx context.Context, accountID ton.AccountID if err != nil { return liteclient.LiteServerAccountStateC{}, err } + if res.Id.ToBlockIdExt() != blockID { + return liteclient.LiteServerAccountStateC{}, BlockMismatch + } return res, nil } @@ -633,6 +653,9 @@ func (c *Client) GetShardInfoRaw(ctx context.Context, blockID ton.BlockIDExt, wo if err != nil { return liteclient.LiteServerShardInfoC{}, err } + if res.Id.ToBlockIdExt() != blockID { + return liteclient.LiteServerShardInfoC{}, BlockMismatch + } return res, nil } @@ -673,6 +696,9 @@ func (c *Client) GetAllShardsInfoRaw(ctx context.Context, blockID ton.BlockIDExt if err != nil { return liteclient.LiteServerAllShardsInfoC{}, err } + if res.Id.ToBlockIdExt() != blockID { + return liteclient.LiteServerAllShardsInfoC{}, BlockMismatch + } return res, nil } @@ -694,6 +720,9 @@ func (c *Client) GetOneTransactionFromBlock( if err != nil { return ton.Transaction{}, err } + if r.Id.ToBlockIdExt() != blockId { + return ton.Transaction{}, BlockMismatch + } if len(r.Transaction) == 0 { return ton.Transaction{}, fmt.Errorf("transaction not found") } @@ -845,6 +874,9 @@ func (c *Client) ListBlockTransactionsRaw(ctx context.Context, blockID ton.Block if err != nil { return liteclient.LiteServerBlockTransactionsC{}, err } + if res.Id.ToBlockIdExt() != blockID { + return liteclient.LiteServerBlockTransactionsC{}, BlockMismatch + } return res, nil } @@ -898,7 +930,31 @@ func (c *Client) GetConfigAll(ctx context.Context, mode ConfigMode) (tlb.ConfigP if err != nil { return tlb.ConfigParams{}, err } - return ton.DecodeConfigParams(res.ConfigProof) + if c.proofPolicy == ProofPolicyUnsafe { + return ton.DecodeConfigParams(res.ConfigProof) + } + stateProofCell, err := boc.DeserializeBoc(res.StateProof) + if err != nil { + return tlb.ConfigParams{}, err + } + if len(stateProofCell) != 1 { + return tlb.ConfigParams{}, fmt.Errorf("invalid number of roots in state proof boc") + } + configProofCell, err := boc.DeserializeBoc(res.ConfigProof) + if err != nil { + return tlb.ConfigParams{}, err + } + if len(configProofCell) != 1 { + return tlb.ConfigParams{}, fmt.Errorf("invalid number of roots in config proof boc") + } + shardState, err := checkBlockShardStateProof([]*boc.Cell{stateProofCell[0], configProofCell[0]}, ton.Bits256(res.Id.RootHash), nil) + if err != nil { + return tlb.ConfigParams{}, err + } + if !shardState.ShardStateUnsplit.Custom.Exists { + return tlb.ConfigParams{}, fmt.Errorf("missing master chain state extra value") + } + return shardState.ShardStateUnsplit.Custom.Value.Value.Config, nil } func (c *Client) GetConfigAllRaw(ctx context.Context, mode ConfigMode) (liteclient.LiteServerConfigInfoC, error) { @@ -906,13 +962,17 @@ func (c *Client) GetConfigAllRaw(ctx context.Context, mode ConfigMode) (liteclie if err != nil { return liteclient.LiteServerConfigInfoC{}, err } + blockID := c.targetBlockOr(masterHead) res, err := client.LiteServerGetConfigAll(ctx, liteclient.LiteServerGetConfigAllRequest{ Mode: uint32(mode), - Id: liteclient.BlockIDExt(c.targetBlockOr(masterHead)), + Id: liteclient.BlockIDExt(blockID), }) if err != nil { return liteclient.LiteServerConfigInfoC{}, err } + if res.Id.ToBlockIdExt() != blockID { + return liteclient.LiteServerConfigInfoC{}, BlockMismatch + } return res, nil } @@ -921,14 +981,18 @@ func (c *Client) GetConfigParams(ctx context.Context, mode ConfigMode, paramList if err != nil { return tlb.ConfigParams{}, err } + blockID := c.targetBlockOr(masterHead) r, err := client.LiteServerGetConfigParams(ctx, liteclient.LiteServerGetConfigParamsRequest{ Mode: uint32(mode), - Id: liteclient.BlockIDExt(c.targetBlockOr(masterHead)), + Id: liteclient.BlockIDExt(blockID), ParamList: paramList, }) if err != nil { return tlb.ConfigParams{}, err } + if r.Id.ToBlockIdExt() != blockID { + return tlb.ConfigParams{}, BlockMismatch + } return ton.DecodeConfigParams(r.ConfigProof) } @@ -947,9 +1011,10 @@ func (c *Client) GetValidatorStats( b := tl.Int256(*startAfter) sa = &b } + blockID := c.targetBlockOr(masterHead) r, err := client.LiteServerGetValidatorStats(ctx, liteclient.LiteServerGetValidatorStatsRequest{ Mode: mode, - Id: liteclient.BlockIDExt(c.targetBlockOr(masterHead)), + Id: liteclient.BlockIDExt(blockID), Limit: limit, StartAfter: sa, ModifiedAfter: modifiedAfter, @@ -957,6 +1022,9 @@ func (c *Client) GetValidatorStats( if err != nil { return nil, err } + if r.Id.ToBlockIdExt() != blockID { + return nil, BlockMismatch + } cells, err := boc.DeserializeBoc(r.DataProof) if err != nil { return nil, err @@ -1000,7 +1068,14 @@ func (c *Client) GetLibraries(ctx context.Context, libraryList []ton.Bits256) (m if len(data) != 1 { return nil, fmt.Errorf("multiroot lib is not supported") } - libs[ton.Bits256(lib.Hash)] = data[0] + dataHash, err := data[0].Hash256() + if err != nil { + return nil, err + } + if lib.Hash != dataHash { + return nil, fmt.Errorf("got wrong library data from liteserver") + } + libs[dataHash] = data[0] } return libs, nil } @@ -1018,9 +1093,17 @@ func (c *Client) GetShardBlockProofRaw(ctx context.Context) (liteclient.LiteServ if err != nil { return liteclient.LiteServerShardBlockProofC{}, err } - return client.LiteServerGetShardBlockProof(ctx, liteclient.LiteServerGetShardBlockProofRequest{ - Id: liteclient.BlockIDExt(c.targetBlockOr(masterHead)), + blockID := c.targetBlockOr(masterHead) + res, err := client.LiteServerGetShardBlockProof(ctx, liteclient.LiteServerGetShardBlockProofRequest{ + Id: liteclient.BlockIDExt(blockID), }) + if err != nil { + return liteclient.LiteServerShardBlockProofC{}, err + } + if res.MasterchainId.ToBlockIdExt() != blockID { + return liteclient.LiteServerShardBlockProofC{}, BlockMismatch + } + return res, nil } // WaitMasterchainSeqno waits for a masterchain block with the given seqno. diff --git a/liteapi/client_test.go b/liteapi/client_test.go index 2051f78..168ebbf 100644 --- a/liteapi/client_test.go +++ b/liteapi/client_test.go @@ -341,6 +341,17 @@ func TestGetConfigAll(t *testing.T) { } } +func TestGetConfigAllWithSafePolicy(t *testing.T) { + api, err := NewClient(Mainnet(), FromEnvsOrMainnet(), WithProofPolicy(ProofPolicyFast)) + if err != nil { + t.Fatal(err) + } + _, err = api.GetConfigAll(context.TODO(), 0) + if err != nil { + t.Fatal(err) + } +} + func TestGetAccountState(t *testing.T) { api, err := NewClient(Mainnet(), FromEnvs()) if err != nil { diff --git a/liteapi/proof.go b/liteapi/proof.go new file mode 100644 index 0000000..0738bd8 --- /dev/null +++ b/liteapi/proof.go @@ -0,0 +1,350 @@ +package liteapi + +import ( + "bytes" + "context" + "crypto/ed25519" + "crypto/sha256" + "encoding/binary" + "encoding/hex" + "fmt" + "github.com/tonkeeper/tongo/boc" + "github.com/tonkeeper/tongo/liteclient" + "github.com/tonkeeper/tongo/tl" + "github.com/tonkeeper/tongo/tlb" + "github.com/tonkeeper/tongo/ton" + "hash/crc32" + "sort" +) + +var castagnoliTable = crc32.MakeTable(crc32.Castagnoli) + +const ( + magicTonBlockID = 0xc50b6e70 // crc32(ton.blockId root_cell_hash:int256 file_hash:int256 = ton.BlockId) + magicPublicKey = 0x4813b4c6 // crc32(pub.ed25519 key:int256 = PublicKey) +) + +func (c *Client) VerifyProofChain(ctx context.Context, source, target ton.BlockIDExt) error { + isForward := source.Seqno < target.Seqno + for source.Seqno != target.Seqno { + partialBlockProof, err := c.GetBlockProofRaw(ctx, source, &target) + if err != nil { + return fmt.Errorf("cannot get partial block proof from liteserver: %w", err) + } + partialBlockSource := partialBlockProof.From.ToBlockIdExt() + if partialBlockSource != source { + return fmt.Errorf("incorrect source partial block: got %v, want %v", partialBlockProof.From.Seqno, source.Seqno) + } + for _, step := range partialBlockProof.Steps { + switch step.SumType { + case "LiteServerBlockLinkBack": + toKeyBlock := step.LiteServerBlockLinkBack.ToKeyBlock + sourceBlock := step.LiteServerBlockLinkBack.From.ToBlockIdExt() + if sourceBlock != source { + return fmt.Errorf("incorrect partial block source") + } + targetBlock := step.LiteServerBlockLinkBack.To.ToBlockIdExt() + destProof, err := boc.DeserializeBoc(step.LiteServerBlockLinkBack.DestProof) + if err != nil { + return fmt.Errorf("unable to deserialize dest proof boc: %w", err) + } + stateProof, err := boc.DeserializeBoc(step.LiteServerBlockLinkBack.StateProof) + if err != nil { + return fmt.Errorf("unable to deserialize state proof boc: %w", err) + } + proof, err := boc.DeserializeBoc(step.LiteServerBlockLinkBack.Proof) + if err != nil { + return fmt.Errorf("unable to deserialize proof boc: %w", err) + } + err = verifyBackwardProofLink(toKeyBlock, sourceBlock, targetBlock, destProof[0], stateProof[0], proof[0]) + if err != nil { + return fmt.Errorf("failed to verify backward proof: %w", err) + } + source = step.LiteServerBlockLinkBack.To.ToBlockIdExt() + case "LiteServerBlockLinkForward": + if !isForward { + return fmt.Errorf("blocks cannot be linked forward if source.Seqno > target.Seqno") + } + toKeyBlock := step.LiteServerBlockLinkForward.ToKeyBlock + sourceBlock := step.LiteServerBlockLinkForward.From.ToBlockIdExt() + if sourceBlock != source { + return fmt.Errorf("incorrect partial block source") + } + targetBlock := step.LiteServerBlockLinkForward.To.ToBlockIdExt() + destProof, err := boc.DeserializeBoc(step.LiteServerBlockLinkForward.DestProof) + if err != nil { + return fmt.Errorf("unable to deserialize dest proof boc: %w", err) + } + configProof, err := boc.DeserializeBoc(step.LiteServerBlockLinkForward.ConfigProof) + if err != nil { + return fmt.Errorf("unable to deserialize state proof boc: %w", err) + } + signs := step.LiteServerBlockLinkForward.Signatures + err = verifyForwardProofLink(toKeyBlock, sourceBlock, targetBlock, destProof[0], configProof[0], signs) + if err != nil { + return fmt.Errorf("failed to verify forward proof: %w", err) + } + source = step.LiteServerBlockLinkForward.To.ToBlockIdExt() + } + } + source = partialBlockProof.To.ToBlockIdExt() + } + return nil +} + +func verifyBackwardProofLink(toKeyBlock bool, source, target ton.BlockIDExt, destProof, stateProof, proof *boc.Cell) error { + if source.Workchain != -1 || target.Workchain != -1 { + return fmt.Errorf("both blocks must be from the masterchain") + } + if source.Seqno <= target.Seqno { + return fmt.Errorf("source seqno must be > target seqno for backward link") + } + shardState, err := checkBlockShardStateProof([]*boc.Cell{proof, stateProof}, source.RootHash, nil) + if err != nil { + return fmt.Errorf("failed to check proof for shard in the masterchain: %w", err) + } + if !shardState.ShardStateUnsplit.Custom.Exists { + return fmt.Errorf("source block is not a masterchain block") + } + prevBlocks := shardState.ShardStateUnsplit.Custom.Value.Value.Other.PrevBlocks + var targetProofBlock *tlb.KeyExtBlkRef + for i, k := range prevBlocks.Keys() { + if k.Equal(tlb.Uint32(target.Seqno)) { + targetProofBlock = &prevBlocks.Values()[i] + break + } + } + if targetProofBlock == nil { + return fmt.Errorf("target block is not found in shard state") + } + if targetProofBlock.Key != toKeyBlock { + return fmt.Errorf("unexpected target block in proof isKey value: got = %v, want %v", targetProofBlock.Key, toKeyBlock) + } + if targetProofBlock.BlkRef.RootHash.Equal(target.RootHash) { + return fmt.Errorf("incorrect target block hash in proof") + } + // proof target block + targetBlock, err := checkBlockProof(*destProof, target.RootHash) + if err != nil { + return fmt.Errorf("failed to check target block proof: %w", err) + } + if targetBlock.VirtualRoot.Info.KeyBlock != toKeyBlock { + return fmt.Errorf("unexpected target block isKey value: got = %v, want %v", targetBlock.VirtualRoot.Info.KeyBlock, toKeyBlock) + } + return nil +} + +func verifyForwardProofLink(toKeyBlock bool, source, target ton.BlockIDExt, destProof, configProof *boc.Cell, signatures liteclient.LiteServerSignatureSet) error { + if source.Workchain != -1 || target.Workchain != -1 { + return fmt.Errorf("both blocks must be source the masterchain") + } + if source.Seqno >= target.Seqno { + return fmt.Errorf("source seqno must be < target seqno for forward link") + } + // proof source block + sourceBlock, err := checkBlockProof(*configProof, source.RootHash) + if err != nil { + return fmt.Errorf("failed to check source block proof: %w", err) + } + if !sourceBlock.VirtualRoot.Extra.Custom.Exists { + return fmt.Errorf("source block is lack of extra info") + } + cfg, err := ton.ConvertBlockchainConfigStrict(sourceBlock.VirtualRoot.Extra.Custom.Value.Value.Config) + if err != nil { + return fmt.Errorf("failed to convert config block: %w", err) + } + if cfg.ConfigParam34 == nil { + return fmt.Errorf("config block is missing 34 config param") + } + blockValidators := cfg.ConfigParam34.CurValidators + if cfg.ConfigParam28 == nil { + return fmt.Errorf("config block is missing 28 config param") + } + catchainConfig := cfg.ConfigParam28.CatchainConfig + // proof target block + targetBlock, err := checkBlockProof(*destProof, target.RootHash) + if err != nil { + return fmt.Errorf("failed to check target block proof: %w", err) + } + if targetBlock.VirtualRoot.Info.KeyBlock != toKeyBlock { + return fmt.Errorf("unexpected target block in proof isKey value: got = %v, want %v", targetBlock.VirtualRoot.Info.KeyBlock, toKeyBlock) + } + if targetBlock.VirtualRoot.Info.GenValidatorListHashShort != signatures.ValidatorSetHash { + return fmt.Errorf("incorrect validator list hash") + } + if targetBlock.VirtualRoot.Info.GenCatchainSeqno != signatures.CatchainSeqno { + return fmt.Errorf("incorrect catchain seqno") + } + validators, err := getMainValidators(&target, catchainConfig, blockValidators, targetBlock.VirtualRoot.Info.GenCatchainSeqno) + if err != nil { + return fmt.Errorf("failed to get main validators: %w", err) + } + if err = checkBlockSignatures(&target, signatures, validators); err != nil { + return fmt.Errorf("failed to check block signatures: %w", err) + } + return nil +} + +func getMainValidators(block *ton.BlockIDExt, catchainConfig tlb.CatchainConfig, validatorsSet tlb.ValidatorSet, catchainSeqno uint32) ([]*tlb.ValidatorAddr, error) { + if block.Workchain != -1 { + return nil, fmt.Errorf("block must be from the masterchain") + } + isShuffle := false + var validatorsNum int + // set isShuffle only for new catchain config, if its old, then isShuffle is always false + switch catchainConfig.SumType { + case "CatchainConfigNew": + isShuffle = catchainConfig.CatchainConfigNew.ShuffleMcValidators + } + var ( + validatorAddrs []tlb.ValidatorAddr + keys []tlb.Uint16 + ) + switch validatorsSet.SumType { + case "Validators": + validatorsNum = int(validatorsSet.Validators.Main) + validatorAddrs = make([]tlb.ValidatorAddr, len(validatorsSet.Validators.List.Keys())) + keys = validatorsSet.Validators.List.Keys() + for i, v := range validatorsSet.Validators.List.Values() { + validatorAddrs[i] = *v.ValidatorAddr + } + case "ValidatorsExt": + validatorsNum = int(validatorsSet.ValidatorsExt.Main) + totalWeight := uint64(0) + validatorAddrs = make([]tlb.ValidatorAddr, len(validatorsSet.ValidatorsExt.List.Keys())) + keys = validatorsSet.ValidatorsExt.List.Keys() + for i, v := range validatorsSet.ValidatorsExt.List.Values() { + totalWeight += v.ValidatorAddr.Weight + validatorAddrs[i] = *v.ValidatorAddr + } + if totalWeight != validatorsSet.ValidatorsExt.TotalWeight { + return nil, fmt.Errorf("incorrect sum of validators weights") + } + default: + return nil, fmt.Errorf("unknown validators set sumtype") + } + sort.Slice(validatorAddrs, func(i, j int) bool { + return keys[i] < keys[j] + }) + if len(validatorAddrs) == 0 { + return nil, fmt.Errorf("zero validators found") + } + if validatorsNum > len(validatorAddrs) { + validatorsNum = len(validatorAddrs) + } + var validators = make([]*tlb.ValidatorAddr, validatorsNum) + if isShuffle { + prng, err := ton.NewValidatorPRNG([32]byte{}, block.Shard, block.Workchain, catchainSeqno) + if err != nil { + return nil, fmt.Errorf("unable to create validator prng: %w", err) + } + idx := make([]uint32, validatorsNum) + for i := 0; i < validatorsNum; i++ { + j := prng.NextRanged(uint64(i) + 1) + idx[i] = idx[j] + idx[j] = uint32(i) + } + for i := 0; i < validatorsNum; i++ { + validators[i] = &validatorAddrs[idx[i]] + } + return validators, nil + } + for i := 0; i < validatorsNum; i++ { + validators[i] = &validatorAddrs[i] + } + return validators, nil +} + +func checkBlockSignatures(block *ton.BlockIDExt, signatures liteclient.LiteServerSignatureSet, validators []*tlb.ValidatorAddr) error { + if len(validators) == 0 { + return fmt.Errorf("zero validators found") + } + if len(signatures.Signatures) == 0 { + return fmt.Errorf("zero signatures found") + } + validatorSetHash, err := computeValidatorSetHash(signatures.CatchainSeqno, validators) + if err != nil { + return fmt.Errorf("unable to compute validator set hash: %w", err) + } + if validatorSetHash != signatures.ValidatorSetHash { + return fmt.Errorf("invalid validator set hash") + } + totalWeight := uint64(0) + keyToValidator := map[[32]byte]*tlb.ValidatorAddr{} + magicPrefix := make([]byte, 4) + binary.LittleEndian.PutUint32(magicPrefix, magicPublicKey) + for _, v := range validators { + pubKey := v.PublicKey.PubKey[:] + // add some magic prefix for each validator's pub keys + hashKey := sha256.Sum256(append(magicPrefix, pubKey...)) + + totalWeight += v.Weight + keyToValidator[hashKey] = v + } + sort.Slice(signatures.Signatures, func(i, j int) bool { + return bytes.Compare(signatures.Signatures[i].NodeIdShort[:], signatures.Signatures[j].NodeIdShort[:]) < 0 + }) + blockBytes := make([]byte, 4) // blockBytes = 0xc50b6e70 + blockIDExt.RootHash + blockIDExt.FileHash + binary.LittleEndian.PutUint32(blockBytes, magicTonBlockID) + blockBytes = append(blockBytes, block.RootHash[:]...) + blockBytes = append(blockBytes, block.FileHash[:]...) + signedWeight := uint64(0) + for i, sig := range signatures.Signatures { + if i > 0 { + prevSig := signatures.Signatures[i-1] + if sig.NodeIdShort == prevSig.NodeIdShort { + return fmt.Errorf("duplicated node signatures found") + } + } + v, ok := keyToValidator[sig.NodeIdShort] + if !ok { + return fmt.Errorf("unknown validator signatures: %v", hex.EncodeToString(sig.NodeIdShort[:])) + } + var pubKey [32]byte + pubKey = v.PublicKey.PubKey + if !ed25519.Verify(pubKey[:], blockBytes, sig.Signature) { + return fmt.Errorf("invalid validator signature: %v", hex.EncodeToString(sig.NodeIdShort[:])) + } + signedWeight += v.Weight + if signedWeight > totalWeight { + break + } + } + if 3*signedWeight <= 2*totalWeight { + return fmt.Errorf("not enoght signed weights: %v/%v < 2/3", signedWeight, totalWeight) + } + return nil +} + +func computeValidatorSetHash(catchainSeqno uint32, validators []*tlb.ValidatorAddr) (uint32, error) { + type tlValidator struct { + Key tl.Int256 + Weight uint64 + Addr tl.Int256 + } + type tlValidatorSet struct { // SumType with one field for magic prefix + SumType tl.SumType + ValidatorSet struct { + CatchainSeqno uint32 + Validators []tlValidator + } `tlSumType:"901660ed"` + } + tlValidators := make([]tlValidator, len(validators)) + for i, currValidator := range validators { + tlValidators[i].Key = tl.Int256(currValidator.PublicKey.PubKey) + tlValidators[i].Weight = currValidator.Weight + tlValidators[i].Addr = tl.Int256(currValidator.AdnlAddr) + } + tlValSet := tlValidatorSet{ + SumType: "ValidatorSet", + ValidatorSet: struct { + CatchainSeqno uint32 + Validators []tlValidator + }{CatchainSeqno: catchainSeqno, Validators: tlValidators}, + } + validatorSetBytes, err := tl.Marshal(tlValSet) + if err != nil { + return 0, fmt.Errorf("unable to marshal validator set: %w", err) + } + return crc32.Checksum(validatorSetBytes, castagnoliTable), nil +} diff --git a/liteapi/proof_test.go b/liteapi/proof_test.go new file mode 100644 index 0000000..f154f24 --- /dev/null +++ b/liteapi/proof_test.go @@ -0,0 +1,103 @@ +package liteapi + +import ( + "context" + "fmt" + "github.com/tonkeeper/tongo/ton" + "testing" + "time" +) + +func getInitBlock() (*ton.BlockIDExt, error) { + var rootHash ton.Bits256 + err := rootHash.FromBase64("VpWyfNOLm8Rqt6CZZ9dZGqJRO3NyrlHHYN1k1oLbJ6g=") + if err != nil { + return nil, fmt.Errorf("incorrect root hash") + } + var fileHash ton.Bits256 + err = fileHash.FromBase64("8o12KX54BtJM8RERD1J97Qe1ZWk61LIIyXydlBnixK8=") + if err != nil { + return nil, fmt.Errorf("incorrect file hash") + } + return &ton.BlockIDExt{ + BlockID: ton.BlockID{ + Workchain: -1, + Shard: 9223372036854775808, + Seqno: 34835953, + }, + RootHash: rootHash, + FileHash: fileHash, + }, nil +} + +func getLastBlockInMasterchain(c *Client) (*ton.BlockIDExt, error) { + lst, err := c.GetMasterchainInfo(context.Background()) + if err != nil { + return nil, err + } + blk := lst.Last.ToBlockIdExt() + return &blk, nil +} + +func TestVerifyProofChain(t *testing.T) { + c, err := NewClientWithDefaultMainnet() + if err != nil { + t.Fatalf("unable to create liteclient: %v", err) + } + from, err := getInitBlock() + if err != nil { + t.Fatalf("unable to get init block: %v", err) + } + to, err := getLastBlockInMasterchain(c) + if err != nil { + t.Fatalf("unable to get last block: %v", err) + } + type Test struct { + name string + from *ton.BlockIDExt + to *ton.BlockIDExt + } + tests := []Test{ + { + name: "test verify forward proof chain", + from: from, + to: to, + }, + { + name: "test verify backward proof chain", + from: to, + to: from, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err = c.VerifyProofChain(context.Background(), *test.from, *test.to) + if err != nil { + t.Errorf("proof chain failed from %v, to %v: %v", test.from.Seqno, test.to.Seqno, err) + } + }) + } +} + +func TestVerifyProofChainFor5NewBlocks(t *testing.T) { + c, err := NewClientWithDefaultMainnet() + if err != nil { + t.Fatalf("unable to create liteclient: %v", err) + } + from, err := getLastBlockInMasterchain(c) + if err != nil { + t.Fatalf("unable to get source block: %v", err) + } + for i := 0; i < 5; i++ { + time.Sleep(5 * time.Second) // delay to wait for new blocks + to, err := getLastBlockInMasterchain(c) + if err != nil { + t.Fatalf("unable to get target block: %v", err) + } + err = c.VerifyProofChain(context.Background(), *from, *to) + if err != nil { + t.Errorf("proof chain failed from %v, to %v: %v", from.Seqno, to.Seqno, err) + } + from = to + } +} diff --git a/tlb/bintree.go b/tlb/bintree.go index 183f37a..8d36af4 100644 --- a/tlb/bintree.go +++ b/tlb/bintree.go @@ -54,6 +54,13 @@ func (b *BinTree[T]) UnmarshalTLB(c *boc.Cell, decoder *Decoder) error { } b.Values = make([]T, 0, len(dec)) for _, i := range dec { + if i.CellType() == boc.PrunedBranchCell { + cell := resolvePrunedCell(c, decoder.resolvePruned) + if cell == nil { + continue + } + i = cell + } var t T err := decoder.Unmarshal(i, &t) if err != nil { diff --git a/tlb/decoder.go b/tlb/decoder.go index 09fecf0..b8de0ba 100644 --- a/tlb/decoder.go +++ b/tlb/decoder.go @@ -8,13 +8,15 @@ import ( ) type resolveLib func(hash Bits256) (*boc.Cell, error) +type resolvePruned func(hash Bits256) (*boc.Cell, error) // Decoder unmarshals a cell into a golang type. type Decoder struct { - hasher *boc.Hasher - withDebug bool - debugPath []string - resolveLib resolveLib + hasher *boc.Hasher + withDebug bool + debugPath []string + resolveLib resolveLib + resolvePruned resolvePruned } func (d *Decoder) WithDebug() *Decoder { @@ -28,6 +30,12 @@ func (d *Decoder) WithLibraryResolver(resolveLib resolveLib) *Decoder { return d } +// WithPrunedResolver provides a function which is used to fetch a pruned cell by its hash. +func (d *Decoder) WithPrunedResolver(resolvePruned resolvePruned) *Decoder { + d.resolvePruned = resolvePruned + return d +} + // NewDecoder returns a new Decoder. func NewDecoder() *Decoder { return &Decoder{ @@ -67,6 +75,25 @@ func decode(c *boc.Cell, tag string, val reflect.Value, decoder *Decoder) error decoder.debugPath = decoder.debugPath[:len(decoder.debugPath)-1] }() } + if c.CellType() == boc.PrunedBranchCell { + if val.Kind() == reflect.Ptr && val.Type() == bocCellPointerType { + // this is a pruned cell, and we unmarshal it to a cell. + // let's not resolve it and keep it as is + val.Elem().Set(reflect.ValueOf(c).Elem()) + return nil + } + if val.Kind() == reflect.Ptr && val.Type() == bocTlbANyPointerType { + // same as lib resolve + //todo: remove + a := Any(*c) + val.Elem().Set(reflect.ValueOf(a)) + return nil + } + cell := resolvePrunedCell(c, decoder.resolvePruned) + if cell != nil { + c = cell + } + } if c.IsLibrary() { if val.Kind() == reflect.Ptr && val.Type() == bocCellPointerType { // this is a library cell, and we unmarshal it to a cell. @@ -116,7 +143,14 @@ func decode(c *boc.Cell, tag string, val reflect.Value, decoder *Decoder) error return fmt.Errorf("library cell as a ref is not implemented") } if c.CellType() == boc.PrunedBranchCell { - return nil + cell := resolvePrunedCell(c, decoder.resolvePruned) + // TODO: maybe check for pointer too + if val.Kind() == reflect.Struct && val.Type() == bocCellType { + return decodeCell(c, val) + } else if cell == nil { + return nil + } + c = cell } case t.IsMaybe: tag = "" @@ -143,7 +177,14 @@ func decode(c *boc.Cell, tag string, val reflect.Value, decoder *Decoder) error return fmt.Errorf("library cell as a ref is not implemented") } if c.CellType() == boc.PrunedBranchCell { - return nil + cell := resolvePrunedCell(c, decoder.resolvePruned) + // TODO: maybe check for pointer too + if val.Kind() == reflect.Struct && val.Type() == bocCellType { + return decodeCell(c, val) + } else if cell == nil { + return nil + } + c = cell } } i, ok := val.Interface().(UnmarshalerTLB) @@ -345,3 +386,20 @@ func decodeBitString(c *boc.Cell, val reflect.Value) error { func (dec *Decoder) Hasher() *boc.Hasher { return dec.hasher } + +func resolvePrunedCell(c *boc.Cell, resolver resolvePruned) *boc.Cell { + if resolver == nil { + return nil + } + hash, err := c.Hash256WithLevel(0) + if err != nil { + return nil + } + cell, err := resolver(hash) + if err != nil { + return nil + } + // TODO: we need to reset all counters for cells or deep copy all cells + cell.ResetCounters() + return cell +} diff --git a/tlb/primitives.go b/tlb/primitives.go index 7a5c5db..efdf4c2 100644 --- a/tlb/primitives.go +++ b/tlb/primitives.go @@ -228,9 +228,13 @@ func (m *Ref[T]) UnmarshalTLB(c *boc.Cell, decoder *Decoder) error { return err } if r.CellType() == boc.PrunedBranchCell { - var value T - m.Value = value - return nil + cell := resolvePrunedCell(c, decoder.resolvePruned) + if cell == nil { + var value T + m.Value = value + return nil + } + r = cell } err = decoder.Unmarshal(r, &m.Value) if err != nil { diff --git a/tlb/proof.go b/tlb/proof.go index 554c0ee..584f96d 100644 --- a/tlb/proof.go +++ b/tlb/proof.go @@ -1,6 +1,7 @@ package tlb import ( + "errors" "fmt" "github.com/tonkeeper/tongo/boc" @@ -15,6 +16,33 @@ type MerkleProof[T any] struct { VirtualRoot T `tlb:"^"` } +func (p *MerkleProof[T]) UnmarshalTLB(c *boc.Cell, decoder *Decoder) error { + depth, hash, err := c.CalculateMerkleProofMeta() + if err != nil { + return err + } + // TODO: remove duplicates + type merkleProof[T any] struct { + Magic Magic `tlb:"!merkle_proof#03"` + VirtualHash Bits256 + Depth uint16 + VirtualRoot T `tlb:"^"` + } + var res merkleProof[T] + err = decoder.Unmarshal(c, &res) + if err != nil { + return err + } + if res.VirtualHash != hash { + return errors.New("invalid virtual hash") + } + if int(res.Depth) != depth { + return errors.New("invalid depth") + } + *p = MerkleProof[T](res) + return nil +} + type MerkleUpdate[T any] struct { Magic Magic `tlb:"!merkle_update#04"` FromHash Bits256 @@ -92,7 +120,15 @@ func (s *ShardState) UnmarshalTLB(c *boc.Cell, decoder *Decoder) error { return err } } else { - s.SplitState.Left = ShardStateUnsplit{} + cell := resolvePrunedCell(c, decoder.resolvePruned) + if cell == nil { + s.SplitState.Left = ShardStateUnsplit{} + } else { + err = decoder.Unmarshal(cell, &s.SplitState.Left) + if err != nil { + return err + } + } } c1, err = c.NextRef() if err != nil { @@ -103,7 +139,15 @@ func (s *ShardState) UnmarshalTLB(c *boc.Cell, decoder *Decoder) error { return err } } else { - s.SplitState.Right = ShardStateUnsplit{} + cell := resolvePrunedCell(c, decoder.resolvePruned) + if cell == nil { + s.SplitState.Right = ShardStateUnsplit{} + } else { + err = decoder.Unmarshal(cell, &s.SplitState.Right) + if err != nil { + return err + } + } } s.SumType = "SplitState" break diff --git a/tlb/validators.go b/tlb/validators.go index 08749bb..0d2c20d 100644 --- a/tlb/validators.go +++ b/tlb/validators.go @@ -62,6 +62,14 @@ func (vs ValidatorsSet) Common() ValidatorSetsCommon { return vs.ValidatorsExt.ValidatorSetsCommon } +// validator_addr#73 public_key:SigPubKey weight:uint64 +// adnl_addr:bits256 = ValidatorDescr; +type ValidatorAddr struct { + PublicKey SigPubKey + Weight uint64 + AdnlAddr Bits256 +} + type ValidatorDescr struct { SumType // validator#53 public_key:SigPubKey weight:uint64 = ValidatorDescr; @@ -69,12 +77,7 @@ type ValidatorDescr struct { PublicKey SigPubKey Weight uint64 } `tlbSumType:"validator#53"` - // validator_addr#73 public_key:SigPubKey weight:uint64 adnl_addr:bits256 = ValidatorDescr; - ValidatorAddr *struct { - PublicKey SigPubKey - Weight uint64 - AdnlAddr Bits256 - } `tlbSumType:"validatoraddr#73"` + ValidatorAddr *ValidatorAddr `tlbSumType:"validatoraddr#73"` } func (vd ValidatorDescr) MarshalJSON() ([]byte, error) { diff --git a/ton/prng.go b/ton/prng.go new file mode 100644 index 0000000..1f94286 --- /dev/null +++ b/ton/prng.go @@ -0,0 +1,75 @@ +package ton + +import ( + "crypto/sha512" + "encoding/binary" + "math/big" +) + +type ValidatorPRNGDescr struct { + seed [32]byte // seed for validator set computation, set to zero if none + shard uint64 + workchain int32 + catchainSeqno uint32 + hash []byte +} + +// ValidatorPRNG is a pseudorandom number generator to randomize validators order +type ValidatorPRNG struct { + descr ValidatorPRNGDescr + pos int + limit int +} + +func NewValidatorPRNG(seed [32]byte, shard uint64, workchain int32, catchainSeqno uint32) (*ValidatorPRNG, error) { + descr := ValidatorPRNGDescr{ + seed: seed, + shard: shard, + workchain: workchain, + catchainSeqno: catchainSeqno, + } + return &ValidatorPRNG{ + descr: descr, + }, nil +} + +func (v *ValidatorPRNG) NextUInt64() uint64 { + if v.pos < v.limit { + temp := v.pos + v.pos++ + return binary.BigEndian.Uint64(v.descr.hash[temp*8:]) + } + v.rebuildHash() + v.increaseSeed() + v.pos = 1 + v.limit = 8 + return binary.BigEndian.Uint64(v.descr.hash) +} + +func (v *ValidatorPRNG) NextRanged(rng uint64) uint64 { + y := new(big.Int).SetUint64(v.NextUInt64()) + bigRange := new(big.Int).SetUint64(rng) + // return (y * rng) >> 64. + // Use big int to avoid uint64 overflow + return new(big.Int).Rsh(new(big.Int).Mul(y, bigRange), 64).Uint64() +} + +func (v *ValidatorPRNG) increaseSeed() { + for i := 31; i >= 0; i-- { + v.descr.seed[i]++ + if v.descr.seed[i] != 0 { + break + } + } +} + +func (v *ValidatorPRNG) rebuildHash() { + h := sha512.New() + h.Write(v.descr.seed[:]) + buf := make([]byte, 16) + binary.BigEndian.PutUint64(buf, v.descr.shard) + binary.BigEndian.PutUint32(buf[8:], uint32(v.descr.workchain)) + binary.BigEndian.PutUint32(buf[12:], v.descr.catchainSeqno) + h.Write(buf) + v.descr.hash = h.Sum(nil) +} diff --git a/ton/prng_test.go b/ton/prng_test.go new file mode 100644 index 0000000..251df38 --- /dev/null +++ b/ton/prng_test.go @@ -0,0 +1,174 @@ +package ton + +import ( + "reflect" + "testing" +) + +func createValidatorPRNG() (*ValidatorPRNG, error) { + return NewValidatorPRNG([32]byte{}, 0x8000000000000000, -1, 0) +} + +func TestNextUInt64(t *testing.T) { + tests := []struct { + name string + times int + // expected output + outputs []uint64 + }{ + { + name: "test next uint64 1 time", + times: 1, + outputs: []uint64{6186953295200455061}, + }, + { + name: "test next uint64 5 times", + times: 5, + outputs: []uint64{ + 6186953295200455061, 9716249430906648876, 893850564141714240, 16362499097668570104, + 7550721807492789767, + }, + }, + { + name: "test next uint64 10 times", + times: 10, + outputs: []uint64{ + 6186953295200455061, 9716249430906648876, 893850564141714240, 16362499097668570104, + 7550721807492789767, 8027788155046975774, 2198044665159296191, 15889925754150310949, + 2854201576873883948, 3908958851740847745, + }, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + prng, err := createValidatorPRNG() + if err != nil { + t.Fatalf("cannot create validator prng: %v", err) + } + output := make([]uint64, test.times) + for i := 0; i < test.times; i++ { + output[i] = prng.NextUInt64() + } + if !reflect.DeepEqual(output, test.outputs) { + t.Errorf("incorrect values: got %v, want %v", output, test.outputs) + } + }) + } +} + +func TestNextRanged(t *testing.T) { + tests := []struct { + name string + rng uint64 + // expected output + output uint64 + }{ + { + name: "test next ranged with range 5", + rng: 5, + output: 1, + }, + { + name: "test next ranged with range 18324", + rng: 18324, + output: 6145, + }, + { + name: "test next ranged with range 10000000000", + rng: 10000000000, + output: 3353954101, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + prng, err := createValidatorPRNG() + if err != nil { + t.Fatalf("cannot create validator prng: %v", err) + } + output := prng.NextRanged(test.rng) + if output != test.output { + t.Errorf("incorrect value: got %v, want %v", output, test.output) + } + }) + } +} + +func TestIncreaseSeed(t *testing.T) { + seed1 := [32]byte{} + seed1[31] = 1 + seed40 := [32]byte{} + seed40[31] = 40 + seed256 := [32]byte{} + seed256[30] = 1 + tests := []struct { + name string + times int + // expected seed + seed [32]byte + }{ + { + name: "test increase seed 1 time", + times: 1, + seed: seed1, + }, + { + name: "test increase seed 40 time", + times: 40, + seed: seed40, + }, + { + name: "test increase seed 256 time", // todo is it okay? + times: 256, + seed: seed256, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + prng, err := createValidatorPRNG() + if err != nil { + t.Fatalf("cannot create validator prng: %v", err) + } + for i := 0; i < test.times; i++ { + prng.increaseSeed() + } + if !reflect.DeepEqual(prng.descr.seed, test.seed) { + t.Errorf("incorrect seed after increases: got %v, want %v", prng.descr.seed, test.seed) + } + }) + } +} + +func TestIncreaseHash(t *testing.T) { + seed32 := make([]byte, 32) + seed32[31] = 32 + hash32 := []byte{ + 74, 27, 97, 202, 222, 150, 35, 10, 94, 215, 240, 213, 147, 229, 252, 235, 220, 93, 61, 153, 58, 129, 85, 207, 18, 223, 177, 238, 191, 27, + 82, 201, 215, 138, 181, 138, 211, 64, 181, 135, 235, 229, 167, 89, 39, 106, 210, 242, 97, 239, 129, 126, 111, 113, 182, 53, 72, 200, 103, + 177, 156, 208, 84, 20, + } + tests := []struct { + name string + seed []byte + // expected hash + hash []byte + }{ + { + name: "test rebuild hash", + seed: seed32, + hash: hash32, + }, + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + prng, err := createValidatorPRNG() + if err != nil { + t.Fatalf("cannot create validator prng: %v", err) + } + copy(prng.descr.seed[:], test.seed) + prng.rebuildHash() + if !reflect.DeepEqual(prng.descr.hash, test.hash) { + t.Errorf("incorrect hash after rebuild: got %v, want %v", prng.descr.hash, test.hash) + } + }) + } +}