diff --git a/doc/user-guide.md b/doc/user-guide.md index 7e9ab946..73913a2c 100644 --- a/doc/user-guide.md +++ b/doc/user-guide.md @@ -1971,3 +1971,188 @@ sequenceDiagram ``` The reverse-direction calls reuse the existing inbound gRPC streams established by the clients, so no additional network connections or firewall rules are needed. + +## Tree-Structured Configurations + +A `TreeConfiguration` organizes a flat `Configuration` into a tree so that fan-out and aggregation can be performed hop-by-hop rather than requiring the client to hold direct connections to every node in a large cluster. + +**When to use this:** tree configurations are useful when the cluster is too large for a single client to fan out to all nodes simultaneously, or when hierarchical aggregation (e.g. summing sensor readings) should happen at each level rather than at the root. + +### Building a Tree + +Call `AsTree` on any `Configuration` after all nodes have connected: + +```go +sys, _ := gorums.NewSystem(addr, + gorums.WithServerOptions(gorums.WithConfig(myID, peerList)), + gorums.WithOutboundNodes(peerList), + // dial options … +) + +// Wait for all peers to connect before building the tree. +if err := sys.WaitForConfig(ctx, func(cfg gorums.Configuration) bool { + return cfg.Size() == clusterSize +}); err != nil { + panic(err) +} + +tree, err := sys.OutboundConfig().AsTree(gorums.TreeOptions{ + BranchingFactor: 2, + Depth: 3, +}) +``` + +`TreeOptions.BranchingFactor` (≥ 2) is the number of children per internal node. +`TreeOptions.Depth` (≥ 1) is the number of hops from root to leaves. +The tree capacity is `(bf^(depth+1) − 1) / (bf − 1)`; `AsTree` returns an error if the configuration has more nodes than the capacity, or if it is empty. + +The tree is laid out in breadth-first order: +the first node in the configuration is the root, the next `bf` are its children, the next `bf²` are the grandchildren, and so on. +If the configuration has fewer nodes than a perfect tree of the given shape, the last level is partial. + +### Registering the Tree with the Server + +Because the tree requires live node connections, it is registered after the server starts with `RegisterTree`: + +```go +gorumsSrv.RegisterTree(tree) +``` + +Every server in the cluster should call `RegisterTree` with a tree built from its **own** outbound config, +so that `ctx.TreeChildren()` returns nodes backed by that server's connections rather than the client's: + +```go +for i, sys := range systems { + sysTree, _ := sys.OutboundConfig().AsTree(opts) + sys.RegisterService(nil, func(srv *gorums.Server) { + srv.RegisterTree(sysTree) + pb.RegisterMyServer(srv, impls[i]) + }) +} +``` + +### Tree Fan-Out: Broadcast + +Define the method as a regular `multicast` in the proto file: + +```proto +rpc Broadcast(BroadcastRequest) returns (Empty) { + option (gorums.multicast) = true; +} +``` + +In the server handler, relay to `ctx.TreeChildren()` and then apply local logic. +Leaves have no children and skip the relay automatically: + +```go +func (s *myServer) Broadcast(ctx gorums.ServerCtx, req *pb.BroadcastRequest) { + // Relay to children first so the message propagates downward. + if children := ctx.TreeChildren(); len(children) > 0 { + _ = pb.Broadcast(children.Context(ctx), req) + } + s.applyLocally(req) +} +``` + +On the client side, call the generated function with `tree.Context(ctx)` instead of `cfg.Context(ctx)`. +This targets the root's direct children; the relay handles the rest of the tree: + +```go +pb.Broadcast(tree.Context(ctx), &pb.BroadcastRequest{Text: "hello"}) +``` + +### Tree Aggregation: QuorumCall + +Define the method as a regular `quorumcall`: + +```proto +rpc Aggregate(AggregateRequest) returns (AggregateResponse) { + option (gorums.quorumcall) = true; +} +``` + +In the server handler, relay to children first with `ctx.Release()` to avoid blocking the server while waiting for child responses, +then aggregate the returned values with local logic. +Leaves return just their own value: + +```go +func (s *myServer) Aggregate(ctx gorums.ServerCtx, req *pb.AggregateRequest) (*pb.AggregateResponse, error) { + total := s.localValue + if children := ctx.TreeChildren(); len(children) > 0 { + // Release before waiting on children to allow concurrent inbound processing. + ctx.Release() + for r := range pb.Aggregate(children.Context(ctx), req).Results() { + if r.Err == nil { + total += r.Value.GetTotal() + } + } + } + return pb.AggregateResponse_builder{Total: total}.Build(), nil +} +``` + +On the client side, the call is identical to a flat quorum call. +`Responses` contains one entry per direct child of the root (at most `bf` entries); each entry is that subtree's already-aggregated result: + +```go +responses := pb.Aggregate(tree.Context(ctx), &pb.AggregateRequest{}) +grandTotal := int32(0) +for r := range responses.Results() { + if r.Err == nil { + grandTotal += r.Value.GetTotal() + } +} +``` + +### Sequence Diagram + +The following diagram shows the flow for a 7-node tree with bf=2, depth=2: + +```mermaid +sequenceDiagram + participant C as Client + participant N2 as Node 2 + participant N3 as Node 3 + participant N4 as Node 4 + participant N5 as Node 5 + participant N6 as Node 6 + participant N7 as Node 7 + + Note over C: tree.Context(ctx) addresses {2, 3} + C->>N2: Aggregate() + C->>N3: Aggregate() + + Note over N2: ctx.TreeChildren() = {4, 5} + Note over N2: ctx.Release() + N2->>N4: Aggregate() [relay] + N2->>N5: Aggregate() [relay] + N4-->>N2: total=4 + N5-->>N2: total=5 + Note over N2: 2+4+5 = 11 + N2-->>C: total=11 + + Note over N3: ctx.TreeChildren() = {6, 7} + Note over N3: ctx.Release() + N3->>N6: Aggregate() [relay] + N3->>N7: Aggregate() [relay] + N6-->>N3: total=6 + N7-->>N3: total=7 + Note over N3: 3+6+7 = 16 + N3-->>C: total=16 + + Note over C: 11+16 = 27 +``` + +### ServerCtx Tree Accessors + +Inside any server handler registered on a tree-aware server, three accessors are available on `gorums.ServerCtx`: + +| Method | Returns | Notes | +| -------------------- | --------------- | ------------------------------------------------------------------------- | +| `ctx.TreeChildren()` | `Configuration` | Direct children of this node; nil for leaves and if no tree is registered | +| `ctx.TreeParent()` | `*Node` | Parent of this node; nil for the root | + +### Working Example + +A complete working example is in [`internal/tests/tree/`](../internal/tests/tree/). +It defines a `TreeAggregator` service with `Broadcast` (multicast) and `Aggregate` (quorum call) methods and tests both patterns on a 7-node tree. diff --git a/handler.go b/handler.go index 40e81205..8ebaf5cb 100644 --- a/handler.go +++ b/handler.go @@ -119,6 +119,26 @@ func (ctx *ServerCtx) ClientConfigContext() *ConfigContext { return nil } +// TreeChildren returns a [Configuration] containing the direct children of this +// server in the registered [TreeConfiguration]. Returns nil if no tree is registered +// or this server is a leaf node or is not part of the tree. +func (ctx *ServerCtx) TreeChildren() Configuration { + if ctx.srv == nil || ctx.srv.tree == nil { + return nil + } + return ctx.srv.tree.ChildrenOf(ctx.srv.myID) +} + +// TreeParent returns the parent [Node] of this server in the registered +// [TreeConfiguration], or nil if this server is the root, not part of the tree, +// or no tree is registered. +func (ctx *ServerCtx) TreeParent() *Node { + if ctx.srv == nil || ctx.srv.tree == nil { + return nil + } + return ctx.srv.tree.ParentOf(ctx.srv.myID) +} + // NewResponseMessage creates a new response envelope based on the provided proto // message. The response includes the message ID and method from the request // to facilitate routing the response back to the caller on the client side. diff --git a/internal/tests/tree/tree.pb.go b/internal/tests/tree/tree.pb.go new file mode 100644 index 00000000..2c3dd6ba --- /dev/null +++ b/internal/tests/tree/tree.pb.go @@ -0,0 +1,281 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v7.34.1 +// source: tree.proto + +package tree + +import ( + _ "github.com/relab/gorums" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type BroadcastRequest struct { + state protoimpl.MessageState `protogen:"opaque.v1"` + xxx_hidden_Text string `protobuf:"bytes,1,opt,name=text"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *BroadcastRequest) Reset() { + *x = BroadcastRequest{} + mi := &file_tree_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *BroadcastRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BroadcastRequest) ProtoMessage() {} + +func (x *BroadcastRequest) ProtoReflect() protoreflect.Message { + mi := &file_tree_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (x *BroadcastRequest) GetText() string { + if x != nil { + return x.xxx_hidden_Text + } + return "" +} + +func (x *BroadcastRequest) SetText(v string) { + x.xxx_hidden_Text = v +} + +type BroadcastRequest_builder struct { + _ [0]func() // Prevents comparability and use of unkeyed literals for the builder. + + Text string +} + +func (b0 BroadcastRequest_builder) Build() *BroadcastRequest { + m0 := &BroadcastRequest{} + b, x := &b0, m0 + _, _ = b, x + x.xxx_hidden_Text = b.Text + return m0 +} + +type AggregateRequest struct { + state protoimpl.MessageState `protogen:"opaque.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AggregateRequest) Reset() { + *x = AggregateRequest{} + mi := &file_tree_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AggregateRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AggregateRequest) ProtoMessage() {} + +func (x *AggregateRequest) ProtoReflect() protoreflect.Message { + mi := &file_tree_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +type AggregateRequest_builder struct { + _ [0]func() // Prevents comparability and use of unkeyed literals for the builder. + +} + +func (b0 AggregateRequest_builder) Build() *AggregateRequest { + m0 := &AggregateRequest{} + b, x := &b0, m0 + _, _ = b, x + return m0 +} + +type AggregateResponse struct { + state protoimpl.MessageState `protogen:"opaque.v1"` + xxx_hidden_Total int32 `protobuf:"varint,1,opt,name=total"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AggregateResponse) Reset() { + *x = AggregateResponse{} + mi := &file_tree_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AggregateResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AggregateResponse) ProtoMessage() {} + +func (x *AggregateResponse) ProtoReflect() protoreflect.Message { + mi := &file_tree_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +func (x *AggregateResponse) GetTotal() int32 { + if x != nil { + return x.xxx_hidden_Total + } + return 0 +} + +func (x *AggregateResponse) SetTotal(v int32) { + x.xxx_hidden_Total = v +} + +type AggregateResponse_builder struct { + _ [0]func() // Prevents comparability and use of unkeyed literals for the builder. + + Total int32 +} + +func (b0 AggregateResponse_builder) Build() *AggregateResponse { + m0 := &AggregateResponse{} + b, x := &b0, m0 + _, _ = b, x + x.xxx_hidden_Total = b.Total + return m0 +} + +type Empty struct { + state protoimpl.MessageState `protogen:"opaque.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Empty) Reset() { + *x = Empty{} + mi := &file_tree_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Empty) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Empty) ProtoMessage() {} + +func (x *Empty) ProtoReflect() protoreflect.Message { + mi := &file_tree_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +type Empty_builder struct { + _ [0]func() // Prevents comparability and use of unkeyed literals for the builder. + +} + +func (b0 Empty_builder) Build() *Empty { + m0 := &Empty{} + b, x := &b0, m0 + _, _ = b, x + return m0 +} + +var File_tree_proto protoreflect.FileDescriptor + +const file_tree_proto_rawDesc = "" + + "\n" + + "\n" + + "tree.proto\x12\x04tree\x1a\fgorums.proto\"&\n" + + "\x10BroadcastRequest\x12\x12\n" + + "\x04text\x18\x01 \x01(\tR\x04text\"\x12\n" + + "\x10AggregateRequest\")\n" + + "\x11AggregateResponse\x12\x14\n" + + "\x05total\x18\x01 \x01(\x05R\x05total\"\a\n" + + "\x05Empty2\x8c\x01\n" + + "\x0eTreeAggregator\x126\n" + + "\tBroadcast\x12\x16.tree.BroadcastRequest\x1a\v.tree.Empty\"\x04\x98\xb5\x18\x01\x12B\n" + + "\tAggregate\x12\x16.tree.AggregateRequest\x1a\x17.tree.AggregateResponse\"\x04\xa0\xb5\x18\x01B)Z\"github.com/relab/gorums/tests/tree\x92\x03\x02\b\x02b\beditionsp\xe8\a" + +var file_tree_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_tree_proto_goTypes = []any{ + (*BroadcastRequest)(nil), // 0: tree.BroadcastRequest + (*AggregateRequest)(nil), // 1: tree.AggregateRequest + (*AggregateResponse)(nil), // 2: tree.AggregateResponse + (*Empty)(nil), // 3: tree.Empty +} +var file_tree_proto_depIdxs = []int32{ + 0, // 0: tree.TreeAggregator.Broadcast:input_type -> tree.BroadcastRequest + 1, // 1: tree.TreeAggregator.Aggregate:input_type -> tree.AggregateRequest + 3, // 2: tree.TreeAggregator.Broadcast:output_type -> tree.Empty + 2, // 3: tree.TreeAggregator.Aggregate:output_type -> tree.AggregateResponse + 2, // [2:4] is the sub-list for method output_type + 0, // [0:2] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_tree_proto_init() } +func file_tree_proto_init() { + if File_tree_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_tree_proto_rawDesc), len(file_tree_proto_rawDesc)), + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_tree_proto_goTypes, + DependencyIndexes: file_tree_proto_depIdxs, + MessageInfos: file_tree_proto_msgTypes, + }.Build() + File_tree_proto = out.File + file_tree_proto_goTypes = nil + file_tree_proto_depIdxs = nil +} diff --git a/internal/tests/tree/tree.proto b/internal/tests/tree/tree.proto new file mode 100644 index 00000000..509d475c --- /dev/null +++ b/internal/tests/tree/tree.proto @@ -0,0 +1,39 @@ +edition = "2023"; + +package tree; + +import "gorums.proto"; + +option features.field_presence = IMPLICIT; +option go_package = "github.com/relab/gorums/tests/tree"; + +// TreeAggregator demonstrates hierarchical fan-out and aggregation across a +// cluster of nodes organized as a tree. A client or root server fans out calls +// to a subtree; each internal node relays the call to its own children before +// (or after) processing locally; leaves apply only local logic. +service TreeAggregator { + // Broadcast fans a notification down the tree. Each internal-node handler + // should relay to ctx.TreeChildren() before or after local processing. + rpc Broadcast(BroadcastRequest) returns (Empty) { + option (gorums.multicast) = true; + } + + // Aggregate collects a value from every node in the tree. Each + // internal-node handler should relay to ctx.TreeChildren(), sum the + // children's responses, and add its own local contribution. + rpc Aggregate(AggregateRequest) returns (AggregateResponse) { + option (gorums.quorumcall) = true; + } +} + +message BroadcastRequest { + string text = 1; +} + +message AggregateRequest {} + +message AggregateResponse { + int32 total = 1; +} + +message Empty {} diff --git a/internal/tests/tree/tree_gorums.pb.go b/internal/tests/tree/tree_gorums.pb.go new file mode 100644 index 00000000..e1cf5120 --- /dev/null +++ b/internal/tests/tree/tree_gorums.pb.go @@ -0,0 +1,82 @@ +// Code generated by protoc-gen-gorums. DO NOT EDIT. +// versions: +// protoc-gen-gorums v0.11.0-devel +// protoc v7.34.1 +// source: tree.proto + +package tree + +import ( + gorums "github.com/relab/gorums" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = gorums.EnforceVersion(11 - gorums.MinVersion) + // Verify that the gorums runtime is sufficiently up-to-date. + _ = gorums.EnforceVersion(gorums.MaxVersion - 11) +) + +// The type aliases below are useful Gorums types that we make accessible +// from generated code. These names therefore become reserved identifiers, +// meaning that proto message types with these names would collide with the +// generated aliases and cause a compile error. +// +// The bundler (gorums_bundle.go) is responsible for discovering these +// aliases and any other identifiers defined herein, and adding them to +// the reserved identifiers list. +// +// If necessary, additional aliases and other identifiers should be added in +// the generator's cmd/protoc-gen-gorums/dev directory, and the bundler will +// automatically discover them and add them to the reserved identifiers list. + +type ( + Configuration = gorums.Configuration + Node = gorums.Node + NodeContext = gorums.NodeContext + ConfigContext = gorums.ConfigContext +) + +// AsyncAggregateResponse is a future for async quorum calls returning *AggregateResponse. +type AsyncAggregateResponse = *gorums.Async[*AggregateResponse] + +// CorrectableAggregateResponse is a correctable object for quorum calls returning *AggregateResponse. +type CorrectableAggregateResponse = *gorums.Correctable[*AggregateResponse] + +// Broadcast fans a notification down the tree. Each internal-node handler +// should relay to ctx.TreeChildren() before or after local processing. +func Broadcast(ctx *ConfigContext, in *BroadcastRequest, opts ...gorums.CallOption) error { + return gorums.Multicast(ctx, in, "tree.TreeAggregator.Broadcast", opts...) +} + +// Aggregate collects a value from every node in the tree. Each +// internal-node handler should relay to ctx.TreeChildren(), sum the +// children's responses, and add its own local contribution. +func Aggregate(ctx *ConfigContext, in *AggregateRequest, opts ...gorums.CallOption) *gorums.Responses[*AggregateResponse] { + return gorums.QuorumCall[*AggregateRequest, *AggregateResponse]( + ctx, in, "tree.TreeAggregator.Aggregate", + opts..., + ) +} + +// TreeAggregator is the server-side API for the TreeAggregator Service +type TreeAggregatorServer interface { + Broadcast(gorums.ServerCtx, *BroadcastRequest) + Aggregate(gorums.ServerCtx, *AggregateRequest) (*AggregateResponse, error) +} + +func RegisterTreeAggregatorServer(srv *gorums.Server, impl TreeAggregatorServer) { + srv.RegisterHandler("tree.TreeAggregator.Broadcast", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { + req := gorums.AsProto[*BroadcastRequest](in) + impl.Broadcast(ctx, req) + return nil, nil + }) + srv.RegisterHandler("tree.TreeAggregator.Aggregate", func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { + req := gorums.AsProto[*AggregateRequest](in) + resp, err := impl.Aggregate(ctx, req) + if err != nil { + return nil, err + } + return gorums.NewResponseMessage(in, resp), nil + }) +} diff --git a/internal/tests/tree/tree_test.go b/internal/tests/tree/tree_test.go new file mode 100644 index 00000000..113e3d3b --- /dev/null +++ b/internal/tests/tree/tree_test.go @@ -0,0 +1,188 @@ +package tree_test + +import ( + "sync" + "testing" + "time" + + "github.com/relab/gorums" + "github.com/relab/gorums/internal/tests/tree" +) + +// aggregatorSrv implements tree.TreeAggregatorServer. +// Each instance holds a fixed local value and tracks received broadcasts. +type aggregatorSrv struct { + value int32 // local contribution to Aggregate + wg *sync.WaitGroup + mu sync.Mutex + received []string // texts received via Broadcast +} + +// Broadcast relays the notification to the node's tree children before +// applying local logic. Leaves skip the relay because ctx.TreeChildren() +// returns nil. This implements the downward fan-out half of a tree multicast. +func (s *aggregatorSrv) Broadcast(ctx gorums.ServerCtx, req *tree.BroadcastRequest) { + if children := ctx.TreeChildren(); len(children) > 0 { + // Relay to children; errors are not recoverable in a one-way call. + _ = tree.Broadcast(children.Context(ctx), req) + } + s.mu.Lock() + s.received = append(s.received, req.GetText()) + s.mu.Unlock() + if s.wg != nil { + s.wg.Done() + } +} + +// Aggregate sums this node's value with the totals returned by its children. +// Leaves return just their own value. This implements the upward aggregation +// half of a tree quorum call. +func (s *aggregatorSrv) Aggregate(ctx gorums.ServerCtx, req *tree.AggregateRequest) (*tree.AggregateResponse, error) { + total := s.value + if children := ctx.TreeChildren(); len(children) > 0 { + // Release the handler lock before waiting on child responses to allow + // this server to process other inbound requests concurrently. + ctx.Release() + for r := range tree.Aggregate(children.Context(ctx), req).Results() { + if r.Err == nil { + total += r.Value.GetTotal() + } + } + } + return tree.AggregateResponse_builder{Total: total}.Build(), nil +} + +// setupTree creates n fully-connected systems, builds a tree from each +// system's own outbound config, and registers the tree and service on each +// server. The caller owns cleanup via the gorums.TestSystems t.Cleanup hook. +// +// The returned tree is built from systems[0]'s connections and is used +// client-side via clientTree.Context(ctx). +func setupTree(t *testing.T, srvs []*aggregatorSrv, opts gorums.TreeOptions) *gorums.TreeConfiguration { + t.Helper() + n := len(srvs) + systems := gorums.TestSystems(t, n) + + for i, sys := range systems { + ctx := gorums.TestContext(t, 5*time.Second) + if err := sys.WaitForConfig(ctx, func(cfg gorums.Configuration) bool { + return cfg.Size() == n + }); err != nil { + t.Fatalf("system %d: WaitForConfig: %v", i+1, err) + } + } + + clientTree, err := systems[0].OutboundConfig().AsTree(opts) + if err != nil { + t.Fatalf("AsTree: %v", err) + } + + // Each server registers its own tree so ctx.TreeChildren() returns nodes + // backed by that server's outbound connections, not the client's. + for i, sys := range systems { + sysTree, err := sys.OutboundConfig().AsTree(opts) + if err != nil { + t.Fatalf("system %d: AsTree: %v", i+1, err) + } + srv := srvs[i] + sys.RegisterService(nil, func(gorumsSrv *gorums.Server) { + gorumsSrv.RegisterTree(sysTree) + tree.RegisterTreeAggregatorServer(gorumsSrv, srv) + }) + } + return clientTree +} + +// TestTreeBroadcast verifies that Broadcast fans out along the tree: the caller +// delivers to root's direct children, each internal node relays to its children, +// and leaves apply local logic only. With bf=2 and depth=2: +// +// 1 (root) +// / \ +// 2 3 +// / \ / \ +// 4 5 6 7 +// +// The caller sends to {2, 3} via clientTree.Context(ctx). +// Node 2 relays to {4, 5}; node 3 relays to {6, 7}. +// Six nodes receive the notification; the root does not. +func TestTreeBroadcast(t *testing.T) { + const n = 7 + opts := gorums.TreeOptions{BranchingFactor: 2, Depth: 2} + + var wg sync.WaitGroup + wg.Add(n - 1) // all nodes except the root + + srvs := make([]*aggregatorSrv, n) + for i := range n { + srvs[i] = &aggregatorSrv{wg: &wg} + } + + clientTree := setupTree(t, srvs, opts) + + ctx := gorums.TestContext(t, 2*time.Second) + if err := tree.Broadcast(clientTree.Context(ctx), tree.BroadcastRequest_builder{Text: "hello"}.Build()); err != nil { + t.Fatalf("Broadcast: %v", err) + } + + done := make(chan struct{}) + go func() { wg.Wait(); close(done) }() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("timeout: not all nodes received the broadcast") + } + + // Root (node 1, srvs[0]) is the logical caller and must not receive the broadcast. + srvs[0].mu.Lock() + rootCount := len(srvs[0].received) + srvs[0].mu.Unlock() + if rootCount != 0 { + t.Errorf("root received %d messages, want 0", rootCount) + } + // Every other node must have received exactly one message. + for i := 1; i < n; i++ { + srvs[i].mu.Lock() + got := len(srvs[i].received) + srvs[i].mu.Unlock() + if got != 1 { + t.Errorf("node %d received %d messages, want 1", i+1, got) + } + } +} + +// TestTreeAggregate verifies that Aggregate collects subtree totals up the +// tree. Each node i has value i+1 (its node ID). Internal nodes sum the +// children's sub-totals and add their own value. The caller sums the two +// direct-child subtree results. +// +// With bf=2 and depth=2: +// - Node 2 returns 2+4+5 = 11 +// - Node 3 returns 3+6+7 = 16 +// - Caller sums: 11+16 = 27 (nodes 2–7; root is not in the call path) +func TestTreeAggregate(t *testing.T) { + const n = 7 + opts := gorums.TreeOptions{BranchingFactor: 2, Depth: 2} + + srvs := make([]*aggregatorSrv, n) + for i := range n { + srvs[i] = &aggregatorSrv{value: int32(i + 1)} // node IDs are 1-based + } + + clientTree := setupTree(t, srvs, opts) + + ctx := gorums.TestContext(t, 3*time.Second) + responses := tree.Aggregate(clientTree.Context(ctx), &tree.AggregateRequest{}) + + const wantTotal = int32(27) // 2+3+4+5+6+7; root (node 1) is not in the call path + gotTotal := int32(0) + for r := range responses.Results() { + if r.Err != nil { + t.Fatalf("Aggregate error from node %d: %v", r.NodeID, r.Err) + } + gotTotal += r.Value.GetTotal() + } + if gotTotal != wantTotal { + t.Errorf("Aggregate total: got %d, want %d", gotTotal, wantTotal) + } +} diff --git a/server.go b/server.go index 340859e9..cd0650e4 100644 --- a/server.go +++ b/server.go @@ -95,6 +95,7 @@ type Server struct { grpcServer *grpc.Server handlers map[string]Handler interceptors []Interceptor + tree *TreeConfiguration *inboundManager } @@ -140,6 +141,15 @@ func (s *Server) RegisterHandler(method string, handler Handler) { s.handlers[method] = chainInterceptors(handler, s.interceptors...) } +// RegisterTree associates a TreeConfiguration with this server so that handlers +// can query their tree position via [ServerCtx.TreeChildren], [ServerCtx.TreeParent], +// and [ServerCtx.TreePosition]. The tree must be built after servers have started, +// once node IDs and addresses are known. RegisterTree must be called before any +// handler can fire; it does not synchronize concurrent updates. +func (s *Server) RegisterTree(tree *TreeConfiguration) { + s.tree = tree +} + // HandleRequest processes an incoming request from the stream, dispatching it // to the appropriate registered handler. It serves as the bridge between the // multiplexing in the stream package and the RPC logic in the gorums package. diff --git a/tree.go b/tree.go new file mode 100644 index 00000000..40336df3 --- /dev/null +++ b/tree.go @@ -0,0 +1,185 @@ +package gorums + +import ( + "context" + "fmt" + "math" + "math/bits" +) + +// TreeOptions configures the shape of a TreeConfiguration. +type TreeOptions struct { + BranchingFactor int // number of children per internal node; must be >= 2 + Depth int // number of edge-hops from root to leaves; must be >= 1 +} + +// TreeConfiguration is a tree-shaped view over a flat Configuration. +// The node at index 0 is the root; the next BranchingFactor indices are its +// children; the next BranchingFactor^2 indices are the grandchildren; and so on +// (breadth-first order). If the configuration has fewer nodes than a perfect +// tree of the given shape, the last level is partial. +// +// Use [Configuration.AsTree] to construct a TreeConfiguration. +type TreeConfiguration struct { + nodes Configuration // breadth-first order; nodes[0] is the root + opts TreeOptions + positionOf map[uint32]int // node ID → index in nodes +} + +// AsTree builds a TreeConfiguration from cfg using the given options. +// Nodes are placed in tree order by their position in cfg. +// Returns an error if cfg is empty, the options are invalid, or cfg has more +// nodes than the tree capacity ((bf^(depth+1) − 1) / (bf − 1)). +// If cfg has fewer nodes than capacity, the last level is partial. +func (cfg Configuration) AsTree(opts TreeOptions) (*TreeConfiguration, error) { + if opts.BranchingFactor < 2 { + return nil, fmt.Errorf("gorums: TreeOptions.BranchingFactor must be >= 2, got %d", opts.BranchingFactor) + } + if opts.Depth < 1 { + return nil, fmt.Errorf("gorums: TreeOptions.Depth must be >= 1, got %d", opts.Depth) + } + if len(cfg) == 0 { + return nil, fmt.Errorf("gorums: cannot build a tree from an empty configuration") + } + capacity := treeLevelStart(opts.Depth+1, opts.BranchingFactor) + if capacity < 0 { + return nil, fmt.Errorf("gorums: tree shape (BranchingFactor=%d, Depth=%d) exceeds representable range", opts.BranchingFactor, opts.Depth) + } + if len(cfg) > capacity { + return nil, fmt.Errorf("gorums: configuration has %d nodes, exceeds tree capacity of %d", len(cfg), capacity) + } + positionOf := make(map[uint32]int, len(cfg)) + for i, n := range cfg { + positionOf[n.ID()] = i + } + return &TreeConfiguration{ + nodes: cfg, + opts: opts, + positionOf: positionOf, + }, nil +} + +// treeLevelStart returns the index of the first node at level k in a tree with +// the given branching factor (bf > 1), i.e., the total number of nodes in +// levels 0 through k−1, computed as the sum bf^0 + bf^1 + … + bf^(k−1). +// Returns -1 if the result would overflow int; callers at validation boundaries +// must check for -1 before using the result. +func treeLevelStart(k, bf int) int { + sum, pow := uint(0), uint(1) + for i := range k { + var carry uint + sum, carry = bits.Add(sum, pow, 0) + // ensure sum still fits in an int value before the final int(sum) conversion + if carry != 0 || sum > uint(math.MaxInt) { + return -1 + } + if i+1 < k { + hi, lo := bits.Mul(pow, uint(bf)) + if hi != 0 { + return -1 + } + pow = lo + } + } + return int(sum) +} + +// treePow returns base^exp for non-negative exp. +func treePow(base, exp int) int { + result := 1 + for range exp { + result *= base + } + return result +} + +// posLevel returns the depth and within-level index of the node at position pos. +// Returns -1, -1 if pos is out of range. +func (t *TreeConfiguration) posLevel(pos int) (depth, indexInLevel int) { + if pos < 0 || pos >= len(t.nodes) { + return -1, -1 + } + bf := t.opts.BranchingFactor + for d := range t.opts.Depth + 1 { + if pos < treeLevelStart(d+1, bf) { + return d, pos - treeLevelStart(d, bf) + } + } + return -1, -1 +} + +// ParentOf returns the parent of the node with the given ID, or nil if the +// node is the root or not in the tree. +func (t *TreeConfiguration) ParentOf(id uint32) *Node { + pos, found := t.positionOf[id] + if !found { + return nil + } + d, idx := t.posLevel(pos) + if d <= 0 { + return nil + } + parentPos := treeLevelStart(d-1, t.opts.BranchingFactor) + idx/t.opts.BranchingFactor + return t.nodes[parentPos] +} + +// ChildrenOf returns the direct children of the node with the given ID as a +// Configuration. Returns nil for leaves and for IDs not in the tree. +func (t *TreeConfiguration) ChildrenOf(id uint32) Configuration { + pos, found := t.positionOf[id] + if !found { + return nil + } + d, idx := t.posLevel(pos) + if d < 0 || d >= t.opts.Depth { + return nil + } + bf := t.opts.BranchingFactor + firstChild := treeLevelStart(d+1, bf) + idx*bf + if firstChild >= len(t.nodes) { + return nil + } + lastChild := min(firstChild+bf, len(t.nodes)) + return t.nodes[firstChild:lastChild] +} + +// Subtree returns a Configuration containing the given node and all its +// descendants in breadth-first order. Returns nil if the node is not in +// the tree. +func (t *TreeConfiguration) Subtree(id uint32) Configuration { + pos, found := t.positionOf[id] + if !found { + return nil + } + d, idx := t.posLevel(pos) + if d < 0 { + return nil + } + bf := t.opts.BranchingFactor + result := make(Configuration, 0, len(t.nodes)-pos) + result = append(result, t.nodes[pos]) + for lvl := d + 1; lvl <= t.opts.Depth; lvl++ { + span := treePow(bf, lvl-d) + subtreeStart := treeLevelStart(lvl, bf) + idx*span + if subtreeStart >= len(t.nodes) { + break + } + subtreeEnd := min(subtreeStart+span, len(t.nodes)) + result = append(result, t.nodes[subtreeStart:subtreeEnd]...) + } + return result +} + +// Context returns a ConfigContext for use with generated tree-call client +// wrappers, addressed to the direct children of the tree root. The client acts +// as the external root; generated server-side dispatchers handle further relay +// via [ServerCtx.TreeChildren]. +// +// Panics if the root has no children (e.g., a single-node configuration). +func (t *TreeConfiguration) Context(parent context.Context) *ConfigContext { + children := t.ChildrenOf(t.nodes[0].ID()) + if len(children) == 0 { + panic("gorums: tree root has no children") + } + return children.Context(parent) +} diff --git a/tree_integration_test.go b/tree_integration_test.go new file mode 100644 index 00000000..19848447 --- /dev/null +++ b/tree_integration_test.go @@ -0,0 +1,155 @@ +package gorums_test + +import ( + "sync" + "testing" + "time" + + "github.com/relab/gorums" + "github.com/relab/gorums/internal/testutils/mock" + pb "google.golang.org/protobuf/types/known/wrapperspb" +) + +// setupTreeSystems creates n fully-connected systems and registers a matching +// TreeConfiguration on each server. Each server's tree is built from that +// server's own outbound config so that ctx.TreeChildren() returns nodes with +// live connections owned by that server. +// +// The returned tree is built from systems[0]'s outbound config and can be used +// as the client-side tree (e.g. to call tree.Context(ctx)). +// +// awaitSystemReady is called before returning so all systems are ready for use. +func setupTreeSystems(t *testing.T, n int, opts gorums.TreeOptions) ([]*gorums.System, *gorums.TreeConfiguration) { + t.Helper() + systems := gorums.TestSystems(t, n) + awaitSystemReady(t, systems) + + clientTree, err := systems[0].OutboundConfig().AsTree(opts) + if err != nil { + t.Fatalf("AsTree: %v", err) + } + + // Each server needs its own tree so ctx.TreeChildren() returns nodes + // backed by that server's own outbound connections. + for _, sys := range systems { + sysTree, err := sys.OutboundConfig().AsTree(opts) + if err != nil { + t.Fatalf("AsTree: %v", err) + } + sys.RegisterService(nil, func(srv *gorums.Server) { + srv.RegisterTree(sysTree) + }) + } + return systems, clientTree +} + +// TestTreeMulticast verifies that Multicast fans out along the tree: the caller +// delivers to the root's direct children, each internal node relays to its own +// children, and leaves apply local logic. With bf=2, depth=2: +// +// 1 (root) +// / \ +// 2 3 +// / \ / \ +// 4 5 6 7 +// +// The caller sends to nodes {2, 3} via tree.Context(ctx). +// Nodes 2 and 3 relay to {4, 5} and {6, 7} respectively. +// Total handler invocations: 6 (all nodes except the logical caller root). +func TestTreeMulticast(t *testing.T) { + opts := gorums.TreeOptions{BranchingFactor: 2, Depth: 2} + systems, tree := setupTreeSystems(t, 7, opts) + + var wg sync.WaitGroup + wg.Add(6) // nodes 2–7; root (node 1) is not in the call path + + for _, sys := range systems { + sys.RegisterService(nil, func(srv *gorums.Server) { + srv.RegisterHandler(mock.Stream, func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { + // Relay to children before running local logic. + // Internal nodes fan out; leaves skip this branch. + if children := ctx.TreeChildren(); len(children) > 0 { + req := gorums.AsProto[*pb.StringValue](in) + if err := gorums.Multicast(children.Context(ctx), req, mock.Stream); err != nil { + t.Errorf("relay multicast error: %v", err) + } + } + wg.Done() + return nil, nil + }) + }) + } + + ctx := gorums.TestContext(t, 2*time.Second) + // tree.Context(ctx) addresses root's children {2, 3} — the caller acts as + // the external root; relay to further descendants is handled by the handlers. + if err := gorums.Multicast(tree.Context(ctx), pb.String("ping"), mock.Stream); err != nil { + t.Fatalf("multicast error: %v", err) + } + + waitWithTimeout(t, &wg) +} + +// TestTreeQuorumCall verifies that a QuorumCall aggregates responses up the tree. +// Each node i contributes its node ID as its local value. Internal nodes sum +// their children's aggregated replies and add their own contribution before +// returning. The caller sums the two subtree totals. +// +// With bf=2, depth=2: +// +// 1 (root) +// / \ +// 2 3 +// / \ / \ +// 4 5 6 7 +// +// Node 2 returns 2+4+5 = 11; node 3 returns 3+6+7 = 16. +// The caller sums both subtree results: 11+16 = 27. +func TestTreeQuorumCall(t *testing.T) { + opts := gorums.TreeOptions{BranchingFactor: 2, Depth: 2} + systems, tree := setupTreeSystems(t, 7, opts) + + for i, sys := range systems { + myVal := int32(i + 1) // node ID: system[0] → 1, system[1] → 2, … + sys.RegisterService(nil, func(srv *gorums.Server) { + srv.RegisterHandler(mock.GetValueMethod, func(ctx gorums.ServerCtx, in *gorums.Message) (*gorums.Message, error) { + total := myVal + if children := ctx.TreeChildren(); len(children) > 0 { + // Release before blocking on child responses; allows the + // server to process other incoming requests concurrently. + ctx.Release() + childResp := gorums.QuorumCall[*pb.Int32Value, *pb.Int32Value]( + children.Context(ctx), + gorums.AsProto[*pb.Int32Value](in), + mock.GetValueMethod, + ) + for r := range childResp.Results() { + if r.Err == nil { + total += r.Value.GetValue() + } + } + } + return gorums.NewResponseMessage(in, pb.Int32(total)), nil + }) + }) + } + + ctx := gorums.TestContext(t, 3*time.Second) + responses := gorums.QuorumCall[*pb.Int32Value, *pb.Int32Value]( + tree.Context(ctx), + pb.Int32(0), + mock.GetValueMethod, + ) + + const wantTotal = int32(27) // 2+3+4+5+6+7; root (node 1) is not in the call path + gotTotal := int32(0) + for r := range responses.Results() { + if r.Err != nil { + t.Fatalf("quorum call error from node %d: %v", r.NodeID, r.Err) + } + gotTotal += r.Value.GetValue() + } + if gotTotal != wantTotal { + t.Errorf("tree aggregate sum: got %d, want %d", gotTotal, wantTotal) + } +} diff --git a/tree_test.go b/tree_test.go new file mode 100644 index 00000000..7838caa2 --- /dev/null +++ b/tree_test.go @@ -0,0 +1,441 @@ +package gorums + +import ( + "context" + "slices" + "strings" + "testing" +) + +// makeTreeNode creates a minimal node for tree layout testing. +func makeTreeNode(id uint32) *Node { + return &Node{id: id} +} + +// makeTreeConfig builds a Configuration with sequential 1-based node IDs, +// matching the rest of the codebase where node ID 0 is reserved. +func makeTreeConfig(n int) Configuration { + cfg := make(Configuration, n) + for i := range n { + cfg[i] = makeTreeNode(uint32(i + 1)) + } + return cfg +} + +func TestAsTree_Errors(t *testing.T) { + cfg := makeTreeConfig(7) + tests := []struct { + name string + cfg Configuration + opts TreeOptions + wantErr string + }{ + { + name: "BranchingFactorZero", + cfg: cfg, + opts: TreeOptions{BranchingFactor: 0, Depth: 2}, + wantErr: "BranchingFactor must be >= 2", + }, + { + name: "BranchingFactorOne", + cfg: cfg, + opts: TreeOptions{BranchingFactor: 1, Depth: 2}, + wantErr: "BranchingFactor must be >= 2", + }, + { + name: "DepthZero", + cfg: cfg, + opts: TreeOptions{BranchingFactor: 2, Depth: 0}, + wantErr: "Depth must be >= 1", + }, + { + name: "EmptyConfig", + cfg: Configuration{}, + opts: TreeOptions{BranchingFactor: 2, Depth: 1}, + wantErr: "empty configuration", + }, + { + name: "NilConfig", + cfg: nil, + opts: TreeOptions{BranchingFactor: 2, Depth: 1}, + wantErr: "empty configuration", + }, + { + name: "ExcessNodes", + cfg: makeTreeConfig(15), // capacity for bf=3, depth=2 is 13 + opts: TreeOptions{BranchingFactor: 3, Depth: 2}, + wantErr: "exceeds tree capacity", + }, + { + name: "OverflowCapacity", + cfg: makeTreeConfig(1), + opts: TreeOptions{BranchingFactor: 2, Depth: 63}, + wantErr: "exceeds representable range", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := tt.cfg.AsTree(tt.opts) + if err == nil { + t.Fatal("expected error, got nil") + } + if tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr) { + t.Errorf("error %q does not contain %q", err.Error(), tt.wantErr) + } + }) + } +} + +// TestTreeLevelStart verifies the level-start index arithmetic. +func TestTreeLevelStart(t *testing.T) { + tests := []struct { + bf int + level int + want int + }{ + // bf=2: 0, 1, 3, 7, 15 + {bf: 2, level: 0, want: 0}, + {bf: 2, level: 1, want: 1}, + {bf: 2, level: 2, want: 3}, + {bf: 2, level: 3, want: 7}, + {bf: 2, level: 4, want: 15}, + // bf=3: 0, 1, 4, 13 + {bf: 3, level: 0, want: 0}, + {bf: 3, level: 1, want: 1}, + {bf: 3, level: 2, want: 4}, + {bf: 3, level: 3, want: 13}, + // bf=4: 0, 1, 5, 21 + {bf: 4, level: 0, want: 0}, + {bf: 4, level: 1, want: 1}, + {bf: 4, level: 2, want: 5}, + {bf: 4, level: 3, want: 21}, + } + for _, tt := range tests { + got := treeLevelStart(tt.level, tt.bf) + if got != tt.want { + t.Errorf("treeLevelStart(%d, bf=%d) = %d, want %d", tt.level, tt.bf, got, tt.want) + } + } +} + +// TestTreeParentOf verifies ParentOf on a perfect bf=3 depth=2 tree. +func TestTreeParentOf(t *testing.T) { + tree := mustNewTree(t, 13, TreeOptions{BranchingFactor: 3, Depth: 2}) + tests := []struct { + id uint32 + wantParent uint32 // ignored when wantNil is true + wantNil bool + }{ + {id: 1, wantNil: true}, // root has no parent + {id: 2, wantParent: 1}, // children of root + {id: 3, wantParent: 1}, + {id: 4, wantParent: 1}, + {id: 5, wantParent: 2}, // children of node 2 + {id: 6, wantParent: 2}, + {id: 7, wantParent: 2}, + {id: 8, wantParent: 3}, // children of node 3 + {id: 9, wantParent: 3}, + {id: 10, wantParent: 3}, + {id: 11, wantParent: 4}, // children of node 4 + {id: 12, wantParent: 4}, + {id: 13, wantParent: 4}, + {id: 14, wantNil: true}, // not in tree + } + for _, tt := range tests { + got := tree.ParentOf(tt.id) + if tt.wantNil { + if got != nil { + t.Errorf("ParentOf(%d) = node %d, want nil", tt.id, got.ID()) + } + } else { + if got == nil { + t.Errorf("ParentOf(%d) = nil, want node %d", tt.id, tt.wantParent) + } else if got.ID() != tt.wantParent { + t.Errorf("ParentOf(%d) = node %d, want node %d", tt.id, got.ID(), tt.wantParent) + } + } + } +} + +// TestTreeChildrenOf verifies ChildrenOf on a perfect bf=3 depth=2 tree. +func TestTreeChildrenOf(t *testing.T) { + tree := mustNewTree(t, 13, TreeOptions{BranchingFactor: 3, Depth: 2}) + tests := []struct { + id uint32 + wantIDs []uint32 + }{ + {id: 1, wantIDs: []uint32{2, 3, 4}}, // root + {id: 2, wantIDs: []uint32{5, 6, 7}}, + {id: 3, wantIDs: []uint32{8, 9, 10}}, + {id: 4, wantIDs: []uint32{11, 12, 13}}, + {id: 5, wantIDs: nil}, // leaves + {id: 6, wantIDs: nil}, + {id: 7, wantIDs: nil}, + {id: 8, wantIDs: nil}, + {id: 9, wantIDs: nil}, + {id: 10, wantIDs: nil}, + {id: 11, wantIDs: nil}, + {id: 12, wantIDs: nil}, + {id: 13, wantIDs: nil}, + {id: 14, wantIDs: nil}, // not in tree + } + for _, tt := range tests { + got := tree.ChildrenOf(tt.id) + if !slices.Equal(got.NodeIDs(), tt.wantIDs) { + t.Errorf("ChildrenOf(%d) = %v, want %v", tt.id, got.NodeIDs(), tt.wantIDs) + } + } +} + +// TestTreeSubtree verifies Subtree on a perfect bf=3 depth=2 tree. +func TestTreeSubtree(t *testing.T) { + tree := mustNewTree(t, 13, TreeOptions{BranchingFactor: 3, Depth: 2}) + tests := []struct { + id uint32 + wantIDs []uint32 + }{ + {id: 1, wantIDs: []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}}, // full tree + {id: 2, wantIDs: []uint32{2, 5, 6, 7}}, + {id: 3, wantIDs: []uint32{3, 8, 9, 10}}, + {id: 4, wantIDs: []uint32{4, 11, 12, 13}}, + {id: 5, wantIDs: []uint32{5}}, // leaves: just self + {id: 6, wantIDs: []uint32{6}}, + {id: 7, wantIDs: []uint32{7}}, + {id: 8, wantIDs: []uint32{8}}, + {id: 9, wantIDs: []uint32{9}}, + {id: 10, wantIDs: []uint32{10}}, + {id: 11, wantIDs: []uint32{11}}, + {id: 12, wantIDs: []uint32{12}}, + {id: 13, wantIDs: []uint32{13}}, + {id: 14, wantIDs: nil}, // not in tree + } + for _, tt := range tests { + got := tree.Subtree(tt.id) + if !slices.Equal(got.NodeIDs(), tt.wantIDs) { + t.Errorf("Subtree(%d) = %v, want %v", tt.id, got.NodeIDs(), tt.wantIDs) + } + } +} + +// TestTreePartialLastLevel verifies layout when the configuration is smaller +// than a perfect tree (bf=3, depth=2, only 10 of 13 nodes present). +// +// 1 (root) +// / | \ +// 2 3 4 +// /|\ /|\ +// 5 6 7 8 9 10 +func TestTreePartialLastLevel(t *testing.T) { + tree := mustNewTree(t, 10, TreeOptions{BranchingFactor: 3, Depth: 2}) + + // Node 2 (level 1, idx 0): children at positions 4,5,6 — all present. + if got := tree.ChildrenOf(2); !slices.Equal(got.NodeIDs(), []uint32{5, 6, 7}) { + t.Errorf("ChildrenOf(2) = %v, want [5 6 7]", got.NodeIDs()) + } + // Node 3 (level 1, idx 1): children at positions 7,8,9 — all present. + if got := tree.ChildrenOf(3); !slices.Equal(got.NodeIDs(), []uint32{8, 9, 10}) { + t.Errorf("ChildrenOf(3) = %v, want [8 9 10]", got.NodeIDs()) + } + // Node 4 (level 1, idx 2): children would be at 10,11,12 — none present. + if got := tree.ChildrenOf(4); got != nil { + t.Errorf("ChildrenOf(4) = %v, want nil", got.NodeIDs()) + } + // Subtree(4) = just node 4. + if got := tree.Subtree(4); !slices.Equal(got.NodeIDs(), []uint32{4}) { + t.Errorf("Subtree(4) = %v, want [4]", got.NodeIDs()) + } + // Subtree of root spans every present node. + if got := tree.Subtree(1); !slices.Equal(got.NodeIDs(), []uint32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) { + t.Errorf("Subtree(1) = %v, want [1..10]", got.NodeIDs()) + } +} + +// TestTreePartialLastLevel_OneChild verifies the case where the last internal +// node has fewer children than bf (bf=3, depth=2, 12 nodes: node 4 has 2 +// children instead of 3). +// +// 1 (root) +// / | \ +// 2 3 4 +// /|\ /|\ / | +// 5 6 7 8 9 10 11 12 +func TestTreePartialLastLevel_OneChild(t *testing.T) { + tree := mustNewTree(t, 12, TreeOptions{BranchingFactor: 3, Depth: 2}) + // Node 4 (level 1, idx 2): children at 11, 12 — position 12 (slot for ID 13) absent. + if got := tree.ChildrenOf(4); !slices.Equal(got.NodeIDs(), []uint32{11, 12}) { + t.Errorf("ChildrenOf(4) = %v, want [11 12]", got.NodeIDs()) + } +} + +// TestTreeExcessNodes verifies that a configuration larger than the tree +// capacity is rejected (bf=3, depth=2, capacity=13; give 15 nodes). +func TestTreeExcessNodes(t *testing.T) { + _, err := makeTreeConfig(15).AsTree(TreeOptions{BranchingFactor: 3, Depth: 2}) + if err == nil { + t.Fatal("expected error for configuration exceeding tree capacity, got nil") + } + if !strings.Contains(err.Error(), "exceeds tree capacity") { + t.Errorf("unexpected error message: %v", err) + } +} + +// TestTreeBF2Depth3 exercises a perfect binary tree (bf=2, depth=3, 15 nodes). +// +// 1 (root) +// / \ +// 2 3 +// / \ / \ +// 4 5 6 7 +// / \ / \ / \ / \ +// 8 9 10 11 12 13 14 15 +func TestTreeBF2Depth3(t *testing.T) { + tree := mustNewTree(t, 15, TreeOptions{BranchingFactor: 2, Depth: 3}) + + parentTests := []struct { + id uint32 + wantID uint32 + nilOK bool + }{ + {1, 0, true}, + {2, 1, false}, + {3, 1, false}, + {4, 2, false}, + {5, 2, false}, + {6, 3, false}, + {7, 3, false}, + {8, 4, false}, + {9, 4, false}, + {15, 7, false}, + } + for _, tt := range parentTests { + p := tree.ParentOf(tt.id) + if tt.nilOK { + if p != nil { + t.Errorf("ParentOf(%d) = %d, want nil", tt.id, p.ID()) + } + } else if p == nil || p.ID() != tt.wantID { + got := uint32(0) + if p != nil { + got = p.ID() + } + t.Errorf("ParentOf(%d) = %d, want %d", tt.id, got, tt.wantID) + } + } + + // Subtree of node 2: {2, 4, 5, 8, 9, 10, 11} + if got := tree.Subtree(2); !slices.Equal(got.NodeIDs(), []uint32{2, 4, 5, 8, 9, 10, 11}) { + t.Errorf("Subtree(2) = %v, want [2 4 5 8 9 10 11]", got.NodeIDs()) + } +} + +// TestServerCtxTree verifies the ServerCtx tree accessors against the +// bf=3 depth=2 tree used throughout this file. +// +// 1 (root) +// / | \ +// 2 3 4 +// /|\ /|\ / | \ +// 5 6 7 8 9 10 11 12 13 +func TestServerCtxTree(t *testing.T) { + tree := mustNewTree(t, 13, TreeOptions{BranchingFactor: 3, Depth: 2}) + + // serverCtxFor builds a minimal ServerCtx whose srv.myID is set to id. + serverCtxFor := func(id uint32) ServerCtx { + s := &Server{ + inboundManager: &inboundManager{myID: id}, + tree: tree, + } + return ServerCtx{Context: context.Background(), srv: s} + } + + t.Run("Root", func(t *testing.T) { + ctx := serverCtxFor(1) + if got := ctx.TreeChildren(); !slices.Equal(got.NodeIDs(), []uint32{2, 3, 4}) { + t.Errorf("TreeChildren = %v, want [2 3 4]", got.NodeIDs()) + } + if p := ctx.TreeParent(); p != nil { + t.Errorf("TreeParent = node %d, want nil", p.ID()) + } + }) + + t.Run("InternalNode", func(t *testing.T) { + ctx := serverCtxFor(3) // level 1, idx 1; children=[8,9,10]; parent=1 + if got := ctx.TreeChildren(); !slices.Equal(got.NodeIDs(), []uint32{8, 9, 10}) { + t.Errorf("TreeChildren = %v, want [8 9 10]", got.NodeIDs()) + } + if p := ctx.TreeParent(); p == nil || p.ID() != 1 { + t.Errorf("TreeParent = %v, want node 1", p) + } + }) + + t.Run("Leaf", func(t *testing.T) { + ctx := serverCtxFor(8) // level 2, idx 3; no children; parent=3 + if got := ctx.TreeChildren(); got != nil { + t.Errorf("TreeChildren = %v, want nil", got.NodeIDs()) + } + if p := ctx.TreeParent(); p == nil || p.ID() != 3 { + t.Errorf("TreeParent = %v, want node 3", p) + } + }) + + t.Run("NodeNotInTree", func(t *testing.T) { + ctx := serverCtxFor(99) // ID not present in the 13-node tree + if got := ctx.TreeChildren(); got != nil { + t.Errorf("TreeChildren = %v, want nil", got.NodeIDs()) + } + if p := ctx.TreeParent(); p != nil { + t.Errorf("TreeParent = node %d, want nil", p.ID()) + } + }) + + t.Run("NoTreeRegistered", func(t *testing.T) { + s := &Server{inboundManager: &inboundManager{myID: 1}} + ctx := ServerCtx{Context: context.Background(), srv: s} + if got := ctx.TreeChildren(); got != nil { + t.Errorf("TreeChildren = %v, want nil (no tree)", got.NodeIDs()) + } + if p := ctx.TreeParent(); p != nil { + t.Errorf("TreeParent = node %d, want nil (no tree)", p.ID()) + } + }) +} + +// TestTreeContext verifies that Context returns a ConfigContext addressing +// the root's direct children on a perfect bf=3 depth=2 tree. +func TestTreeContext(t *testing.T) { + tree := mustNewTree(t, 13, TreeOptions{BranchingFactor: 3, Depth: 2}) + ctx := tree.Context(context.Background()) + if ctx == nil { + t.Fatal("Context returned nil") + } + want := []uint32{2, 3, 4} + if got := ctx.Configuration().NodeIDs(); !slices.Equal(got, want) { + t.Errorf("Context.Configuration().NodeIDs() = %v, want %v", got, want) + } +} + +// TestTreeContext_PanicsOnRootWithNoChildren verifies that Context panics +// when the configuration contains only the root (no children present). +func TestTreeContext_PanicsOnRootWithNoChildren(t *testing.T) { + tree := mustNewTree(t, 1, TreeOptions{BranchingFactor: 2, Depth: 1}) + defer func() { + r := recover() + if r == nil { + t.Fatal("Context did not panic on a tree with no children") + } + if msg, ok := r.(string); !ok || !strings.Contains(msg, "no children") { + t.Errorf("panic = %v, want string containing %q", r, "no children") + } + }() + _ = tree.Context(context.Background()) +} + +// mustNewTree creates a TreeConfiguration for testing, failing the test on error. +func mustNewTree(t *testing.T, n int, opts TreeOptions) *TreeConfiguration { + t.Helper() + tree, err := makeTreeConfig(n).AsTree(opts) + if err != nil { + t.Fatalf("AsTree: %v", err) + } + return tree +}