diff --git a/chain/transaction.go b/chain/transaction.go index d89991d6e3..be7944338c 100644 --- a/chain/transaction.go +++ b/chain/transaction.go @@ -112,6 +112,8 @@ func (t *TransactionData) Marshal(p *codec.Packer) error { return t.marshal(p) } +func (*TransactionData) Priority() uint64 { return 0 } + func (t *TransactionData) marshal(p *codec.Packer) error { t.Base.Marshal(p) return t.Actions.MarshalInto(p) diff --git a/internal/list/list.go b/internal/list/list.go index 21b9c47783..c84de3b376 100644 --- a/internal/list/list.go +++ b/internal/list/list.go @@ -3,7 +3,9 @@ package list -import "github.com/ava-labs/avalanchego/ids" +import ( + "github.com/ava-labs/avalanchego/ids" +) // Item defines an interface accepted by [List]. // @@ -13,6 +15,7 @@ import "github.com/ava-labs/avalanchego/ids" type Item interface { GetID() ids.ID // method for returning an id of the item GetExpiry() int64 // method for returning this items timestamp + Priority() uint64 } // List implements a double-linked list. It offers @@ -64,6 +67,10 @@ func (e *Element[T]) GetExpiry() int64 { return e.value.GetExpiry() } +func (e *Element[T]) Priority() uint64 { + return e.value.Priority() +} + func (l *List[T]) First() *Element[T] { if l.size == 0 { return nil diff --git a/internal/list/list_test.go b/internal/list/list_test.go index 5fbf9e8ae7..e951238b16 100644 --- a/internal/list/list_test.go +++ b/internal/list/list_test.go @@ -29,6 +29,10 @@ func (mti *TestItem) GetExpiry() int64 { return mti.timestamp } +func (*TestItem) Priority() uint64 { + return 0 +} + func GenerateTestItem(str string) *TestItem { id := ids.GenerateTestID() return &TestItem{ diff --git a/internal/mempool/abstract.go b/internal/mempool/abstract.go new file mode 100644 index 0000000000..8858a6d24c --- /dev/null +++ b/internal/mempool/abstract.go @@ -0,0 +1,39 @@ +// Copyright (C) 2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package mempool + +import ( + "context" + "time" + + "github.com/ava-labs/avalanchego/ids" + "github.com/ava-labs/avalanchego/trace" +) + +// NewGeneralMempool is creating a mempool using a FIFO queue. +type AbstractMempoolFactory[T Item] func( + tracer trace.Tracer, + maxSize int, + maxSponsorSize int, +) AbstractMempool[T] + +type AbstractMempool[T Item] interface { + Has(ctx context.Context, itemID ids.ID) bool + Add(ctx context.Context, items []T) + PeekNext(ctx context.Context) (T, bool) + PopNext(ctx context.Context) (T, bool) + Remove(ctx context.Context, items []T) + Len(ctx context.Context) int + Size(context.Context) int + SetMinTimestamp(ctx context.Context, t int64) []T + Top( + ctx context.Context, + targetDuration time.Duration, + f func(context.Context, T) (cont bool, restore bool, err error), + ) error + StartStreaming(_ context.Context) + PrepareStream(ctx context.Context, count int) + Stream(ctx context.Context, count int) []T + FinishStreaming(ctx context.Context, restorable []T) int +} diff --git a/internal/mempool/mempool.go b/internal/mempool/mempool.go index 062b034c3e..1ae6ad7cf9 100644 --- a/internal/mempool/mempool.go +++ b/internal/mempool/mempool.go @@ -16,18 +16,22 @@ import ( "github.com/ava-labs/hypersdk/codec" "github.com/ava-labs/hypersdk/internal/eheap" "github.com/ava-labs/hypersdk/internal/list" + "github.com/ava-labs/hypersdk/internal/mempool/queue" ) const maxPrealloc = 4_096 +type GeneralMempool[T Item] Mempool[T, *list.Element[T]] + type Item interface { eheap.Item Sponsor() codec.Address Size() int + Priority() uint64 } -type Mempool[T Item] struct { +type Mempool[T Item, E queue.Item] struct { tracer trace.Tracer mu sync.RWMutex @@ -37,8 +41,9 @@ type Mempool[T Item] struct { maxSize int maxSponsorSize int // Maximum items allowed by a single sponsor - queue *list.List[T] - eh *eheap.ExpiryHeap[*list.Element[T]] + // queue *list.List[T, E] + queue queue.Queue[T, E] + eh *eheap.ExpiryHeap[E] // owned tracks the number of items in the mempool owned by a single // [Sponsor] @@ -52,27 +57,25 @@ type Mempool[T Item] struct { nextStreamFetched bool } -// New creates a new [Mempool]. [maxSize] must be > 0 or else the -// implementation may panic. -func New[T Item]( +// NewGeneralMempool is creating a mempool using a FIFO queue. +func NewGeneralMempool[T Item]( tracer trace.Tracer, maxSize int, maxSponsorSize int, -) *Mempool[T] { - return &Mempool[T]{ - tracer: tracer, - - maxSize: maxSize, - maxSponsorSize: maxSponsorSize, - - queue: &list.List[T]{}, - eh: eheap.New[*list.Element[T]](min(maxSize, maxPrealloc)), +) AbstractMempool[T] { + return newGeneralMempool[T](tracer, maxSize, maxSponsorSize) +} - owned: map[codec.Address]int{}, - } +// NewPriorityMempool is creating a mempool using a Priority queue. +func NewPriorityMempool[T Item]( + tracer trace.Tracer, + maxSize int, + maxSponsorSize int, +) AbstractMempool[T] { + return newPriorityMempool[T](tracer, maxSize, maxSponsorSize) } -func (m *Mempool[T]) removeFromOwned(item T) { +func (m *Mempool[T, E]) removeFromOwned(item T) { sender := item.Sponsor() items, ok := m.owned[sender] if !ok { @@ -87,7 +90,7 @@ func (m *Mempool[T]) removeFromOwned(item T) { } // Has returns if the eh of [m] contains [itemID] -func (m *Mempool[T]) Has(ctx context.Context, itemID ids.ID) bool { +func (m *Mempool[T, E]) Has(ctx context.Context, itemID ids.ID) bool { _, span := m.tracer.Start(ctx, "Mempool.Has") defer span.End() @@ -101,7 +104,7 @@ func (m *Mempool[T]) Has(ctx context.Context, itemID ids.ID) bool { // the item sponsor is not exempt and their items in the mempool exceed m.maxSponsorSize. // If the size of m exceeds m.maxSize, Add pops the lowest value item // from m.eh. -func (m *Mempool[T]) Add(ctx context.Context, items []T) { +func (m *Mempool[T, E]) Add(ctx context.Context, items []T) { _, span := m.tracer.Start(ctx, "Mempool.Add") defer span.End() @@ -111,7 +114,7 @@ func (m *Mempool[T]) Add(ctx context.Context, items []T) { m.add(items, false) } -func (m *Mempool[T]) add(items []T, front bool) { +func (m *Mempool[T, E]) add(items []T, front bool) { for _, item := range items { sender := item.Sponsor() @@ -136,11 +139,11 @@ func (m *Mempool[T]) add(items []T, front bool) { } // Add to mempool - var elem *list.Element[T] + var elem E if !front { - elem = m.queue.PushBack(item) + elem = m.queue.Push(item) } else { - elem = m.queue.PushFront(item) + elem = m.queue.Restore(item) } m.eh.Add(elem) m.owned[sender]++ @@ -150,23 +153,23 @@ func (m *Mempool[T]) add(items []T, front bool) { // PeekNext returns the highest valued item in m.eh. // Assumes there is non-zero items in [Mempool] -func (m *Mempool[T]) PeekNext(ctx context.Context) (T, bool) { +func (m *Mempool[T, E]) PeekNext(ctx context.Context) (T, bool) { _, span := m.tracer.Start(ctx, "Mempool.PeekNext") defer span.End() m.mu.RLock() defer m.mu.RUnlock() - first := m.queue.First() - if first == nil { + firstValue, ok := m.queue.FirstValue() + if !ok { return *new(T), false } - return first.Value(), true + return firstValue, true } // PopNext removes and returns the highest valued item in m.eh. // Assumes there is non-zero items in [Mempool] -func (m *Mempool[T]) PopNext(ctx context.Context) (T, bool) { // O(log N) +func (m *Mempool[T, E]) PopNext(ctx context.Context) (T, bool) { // O(log N) _, span := m.tracer.Start(ctx, "Mempool.PopNext") defer span.End() @@ -176,9 +179,9 @@ func (m *Mempool[T]) PopNext(ctx context.Context) (T, bool) { // O(log N) return m.popNext() } -func (m *Mempool[T]) popNext() (T, bool) { - first := m.queue.First() - if first == nil { +func (m *Mempool[T, E]) popNext() (T, bool) { + first, ok := m.queue.First() + if !ok { return *new(T), false } v := m.queue.Remove(first) @@ -189,7 +192,7 @@ func (m *Mempool[T]) popNext() (T, bool) { } // Remove removes [items] from m. -func (m *Mempool[T]) Remove(ctx context.Context, items []T) { +func (m *Mempool[T, E]) Remove(ctx context.Context, items []T) { _, span := m.tracer.Start(ctx, "Mempool.Remove") defer span.End() @@ -208,7 +211,7 @@ func (m *Mempool[T]) Remove(ctx context.Context, items []T) { } // Len returns the number of items in m. -func (m *Mempool[T]) Len(ctx context.Context) int { +func (m *Mempool[T, E]) Len(ctx context.Context) int { _, span := m.tracer.Start(ctx, "Mempool.Len") defer span.End() @@ -219,15 +222,15 @@ func (m *Mempool[T]) Len(ctx context.Context) int { } // Size returns the size (in bytes) of items in m. -func (m *Mempool[T]) Size(context.Context) int { +func (m *Mempool[T, E]) Size(context.Context) int { m.mu.RLock() defer m.mu.RUnlock() return m.pendingSize } -// SetMinTimestamp removes and returns all items with a lower expiry than [t] from m. -func (m *Mempool[T]) SetMinTimestamp(ctx context.Context, t int64) []T { +// SetMinTimestamp removes and returns all items with a lower expiry than [T, E] from m. +func (m *Mempool[T, E]) SetMinTimestamp(ctx context.Context, t int64) []T { _, span := m.tracer.Start(ctx, "Mempool.SetMinTimesamp") defer span.End() @@ -237,8 +240,7 @@ func (m *Mempool[T]) SetMinTimestamp(ctx context.Context, t int64) []T { removedElems := m.eh.SetMin(t) removed := make([]T, len(removedElems)) for i, remove := range removedElems { - m.queue.Remove(remove) - v := remove.Value() + v := m.queue.Remove(remove) m.removeFromOwned(v) m.pendingSize -= v.Size() removed[i] = v @@ -247,7 +249,7 @@ func (m *Mempool[T]) SetMinTimestamp(ctx context.Context, t int64) []T { } // Top iterates over the highest-valued items in the mempool. -func (m *Mempool[T]) Top( +func (m *Mempool[T, E]) Top( ctx context.Context, targetDuration time.Duration, f func(context.Context, T) (cont bool, restore bool, err error), @@ -289,7 +291,7 @@ func (m *Mempool[T]) Top( // Streaming is useful for block building because we can get a feed of the // best txs to build without holding the lock during the duration of the build // process. Streaming in batches allows for various state prefetching operations. -func (m *Mempool[T]) StartStreaming(_ context.Context) { +func (m *Mempool[T, E]) StartStreaming(_ context.Context) { m.mu.Lock() defer m.mu.Unlock() @@ -299,7 +301,7 @@ func (m *Mempool[T]) StartStreaming(_ context.Context) { // PrepareStream prefetches the next [count] items from the mempool to // reduce the latency of calls to [StreamItems]. -func (m *Mempool[T]) PrepareStream(ctx context.Context, count int) { +func (m *Mempool[T, E]) PrepareStream(ctx context.Context, count int) { _, span := m.tracer.Start(ctx, "Mempool.PrepareStream") defer span.End() @@ -312,7 +314,7 @@ func (m *Mempool[T]) PrepareStream(ctx context.Context, count int) { // Stream gets the next highest-valued [count] items from the mempool, not // including what has already been streamed. -func (m *Mempool[T]) Stream(ctx context.Context, count int) []T { +func (m *Mempool[T, E]) Stream(ctx context.Context, count int) []T { _, span := m.tracer.Start(ctx, "Mempool.Stream") defer span.End() @@ -328,7 +330,7 @@ func (m *Mempool[T]) Stream(ctx context.Context, count int) []T { return m.streamItems(count) } -func (m *Mempool[T]) streamItems(count int) []T { +func (m *Mempool[T, E]) streamItems(count int) []T { txs := make([]T, 0, count) for len(txs) < count { item, ok := m.popNext() @@ -343,7 +345,7 @@ func (m *Mempool[T]) streamItems(count int) []T { // FinishStreaming restores [restorable] items to the mempool and clears // the set of all previously streamed items. -func (m *Mempool[T]) FinishStreaming(ctx context.Context, restorable []T) int { +func (m *Mempool[T, E]) FinishStreaming(ctx context.Context, restorable []T) int { _, span := m.tracer.Start(ctx, "Mempool.FinishStreaming") defer span.End() @@ -366,3 +368,41 @@ func (m *Mempool[T]) FinishStreaming(ctx context.Context, restorable []T) int { m.streamLock.Unlock() return restored } + +// newGeneralMempool is creating a mempool using a FIFO queue. +func newGeneralMempool[T Item]( + tracer trace.Tracer, + maxSize int, + maxSponsorSize int, +) *Mempool[T, *list.Element[T]] { + return &Mempool[T, *list.Element[T]]{ + tracer: tracer, + + maxSize: maxSize, + maxSponsorSize: maxSponsorSize, + + queue: queue.NewList[T](), + eh: eheap.New[*list.Element[T]](min(maxSize, maxPrealloc)), + + owned: map[codec.Address]int{}, + } +} + +// NewPriorityMempool is creating a mempool using a Priority queue. +func newPriorityMempool[T Item]( + tracer trace.Tracer, + maxSize int, + maxSponsorSize int, +) *Mempool[T, T] { + return &Mempool[T, T]{ + tracer: tracer, + + maxSize: maxSize, + maxSponsorSize: maxSponsorSize, + + queue: queue.NewPriorityQueue[T](), + eh: eheap.New[T](min(maxSize, maxPrealloc)), + + owned: map[codec.Address]int{}, + } +} diff --git a/internal/mempool/mempool_test.go b/internal/mempool/mempool_test.go index 1732a8d46f..3510a53e86 100644 --- a/internal/mempool/mempool_test.go +++ b/internal/mempool/mempool_test.go @@ -20,6 +20,7 @@ type TestItem struct { id ids.ID sponsor codec.Address timestamp int64 + priority uint64 } func (mti *TestItem) GetID() ids.ID { @@ -34,28 +35,33 @@ func (mti *TestItem) GetExpiry() int64 { return mti.timestamp } +func (mti *TestItem) Priority() uint64 { + return mti.priority +} + func (*TestItem) Size() int { return 2 // distinguish from len } -func GenerateTestItem(sponsor codec.Address, t int64) *TestItem { +func GenerateTestItem(sponsor codec.Address, t int64, p uint64) *TestItem { id := ids.GenerateTestID() return &TestItem{ id: id, sponsor: sponsor, timestamp: t, + priority: p, } } -func TestMempool(t *testing.T) { +func TestGeneralMempool(t *testing.T) { require := require.New(t) ctx := context.TODO() tracer, _ := trace.New(&trace.Config{Enabled: false}) - txm := New[*TestItem](tracer, 3, 16) + txm := newGeneralMempool[*TestItem](tracer, 3, 16) for _, i := range []int64{100, 200, 300, 400} { - item := GenerateTestItem(testSponsor, i) + item := GenerateTestItem(testSponsor, i, 0) items := []*TestItem{item} txm.Add(ctx, items) } @@ -66,13 +72,32 @@ func TestMempool(t *testing.T) { require.Equal(6, txm.Size(ctx)) } +func TestPriorityMempool(t *testing.T) { + require := require.New(t) + + ctx := context.TODO() + tracer, _ := trace.New(&trace.Config{Enabled: false}) + txm := NewPriorityMempool[*TestItem](tracer, 3, 16) + + for _, i := range []int64{100, 200, 300, 400} { + item := GenerateTestItem(testSponsor, i, uint64(i)) + items := []*TestItem{item} + txm.Add(ctx, items) + } + next, ok := txm.PeekNext(ctx) + require.True(ok) + require.Equal(int64(300), next.GetExpiry()) + require.Equal(3, txm.Len(ctx)) + require.Equal(6, txm.Size(ctx)) +} + func TestMempoolAddDuplicates(t *testing.T) { require := require.New(t) ctx := context.TODO() tracer, _ := trace.New(&trace.Config{Enabled: false}) - txm := New[*TestItem](tracer, 3, 16) + txm := newGeneralMempool[*TestItem](tracer, 3, 16) // Generate item - item := GenerateTestItem(testSponsor, 300) + item := GenerateTestItem(testSponsor, 300, 0) items := []*TestItem{item} txm.Add(ctx, items) require.Equal(1, txm.Len(ctx), "Item not added.") @@ -92,10 +117,10 @@ func TestMempoolAddExceedMaxSponsorSize(t *testing.T) { tracer, _ := trace.New(&trace.Config{Enabled: false}) sponsor := codec.CreateAddress(4, ids.GenerateTestID()) // Non exempt sponsors max of 4 - txm := New[*TestItem](tracer, 20, 4) + txm := newGeneralMempool[*TestItem](tracer, 20, 4) // Add 6 transactions for each sponsor for i := int64(0); i <= 5; i++ { - itemSponsor := GenerateTestItem(sponsor, i) + itemSponsor := GenerateTestItem(sponsor, i, 0) txm.Add(ctx, []*TestItem{itemSponsor}) } require.Equal(4, txm.Len(ctx), "Mempool has incorrect txs.") @@ -107,10 +132,10 @@ func TestMempoolAddExceedMaxSize(t *testing.T) { ctx := context.TODO() tracer, _ := trace.New(&trace.Config{Enabled: false}) - txm := New[*TestItem](tracer, 3, 20) + txm := newGeneralMempool[*TestItem](tracer, 3, 20) // Add more tx's than txm.maxSize for i := int64(0); i < 10; i++ { - item := GenerateTestItem(testSponsor, i) + item := GenerateTestItem(testSponsor, i, 0) items := []*TestItem{item} txm.Add(ctx, items) if i < 3 { @@ -135,14 +160,14 @@ func TestMempoolRemoveTxs(t *testing.T) { ctx := context.TODO() tracer, _ := trace.New(&trace.Config{Enabled: false}) - txm := New[*TestItem](tracer, 3, 20) + txm := newGeneralMempool[*TestItem](tracer, 3, 20) // Add - item := GenerateTestItem(testSponsor, 10) + item := GenerateTestItem(testSponsor, 10, 0) items := []*TestItem{item} txm.Add(ctx, items) require.True(txm.Has(ctx, item.GetID()), "TX not included") // Remove - itemNotIn := GenerateTestItem(testSponsor, 10) + itemNotIn := GenerateTestItem(testSponsor, 10, 0) items = []*TestItem{item, itemNotIn} txm.Remove(ctx, items) require.Equal(0, txm.Len(ctx), "Mempool has incorrect number of txs.") @@ -153,10 +178,10 @@ func TestMempoolSetMinTimestamp(t *testing.T) { ctx := context.TODO() tracer, _ := trace.New(&trace.Config{Enabled: false}) - txm := New[*TestItem](tracer, 20, 20) + txm := newGeneralMempool[*TestItem](tracer, 20, 20) // Add more tx's than txm.maxSize for i := int64(0); i < 10; i++ { - item := GenerateTestItem(testSponsor, i) + item := GenerateTestItem(testSponsor, i, 0) items := []*TestItem{item} txm.Add(ctx, items) require.True(txm.Has(ctx, item.GetID()), "TX not included") diff --git a/internal/mempool/queue/fifo_queue.go b/internal/mempool/queue/fifo_queue.go new file mode 100644 index 0000000000..66e44c8fe1 --- /dev/null +++ b/internal/mempool/queue/fifo_queue.go @@ -0,0 +1,45 @@ +// Copyright (C) 2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package queue + +import ( + "github.com/ava-labs/hypersdk/internal/list" +) + +var _ Queue[Item, *list.Element[Item]] = (*FIFOQueue[Item])(nil) + +type FIFOQueue[T list.Item] struct { + *list.List[T] +} + +func NewList[T Item]() *FIFOQueue[T] { + return &FIFOQueue[T]{ + List: &list.List[T]{}, + } +} + +func (l *FIFOQueue[T]) Remove(elem *list.Element[T]) T { + return l.List.Remove(elem) +} + +func (l *FIFOQueue[T]) First() (*list.Element[T], bool) { + first := l.List.First() + return first, first != nil +} + +func (l *FIFOQueue[T]) FirstValue() (T, bool) { + first := l.List.First() + if first == nil { + return *new(T), false + } + return first.Value(), true +} + +func (l *FIFOQueue[T]) Restore(item T) *list.Element[T] { + return l.PushFront(item) +} + +func (l *FIFOQueue[T]) Push(item T) *list.Element[T] { + return l.PushBack(item) +} diff --git a/internal/mempool/queue/priority_queue.go b/internal/mempool/queue/priority_queue.go new file mode 100644 index 0000000000..84a4045380 --- /dev/null +++ b/internal/mempool/queue/priority_queue.go @@ -0,0 +1,43 @@ +// Copyright (C) 2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package queue + +import ( + "github.com/ava-labs/hypersdk/internal/pheap" +) + +var _ Queue[Item, Item] = (*PriorityQueue[Item])(nil) + +type PriorityQueue[T Item] struct { + *pheap.PriorityHeap[T] +} + +func NewPriorityQueue[T Item]() *PriorityQueue[T] { + return &PriorityQueue[T]{ + PriorityHeap: pheap.New[T](0), + } +} + +func (p *PriorityQueue[T]) Size() int { + return p.Len() +} + +func (p *PriorityQueue[T]) FirstValue() (T, bool) { + return p.First() +} + +func (p *PriorityQueue[T]) Push(item T) T { + p.Add(item) + return item +} + +func (p *PriorityQueue[T]) Remove(item T) T { + p.PriorityHeap.Remove(item.GetID()) + return item +} + +func (p *PriorityQueue[T]) Restore(item T) T { + p.PriorityHeap.Add(item) + return item +} diff --git a/internal/mempool/queue/queue.go b/internal/mempool/queue/queue.go new file mode 100644 index 0000000000..daaa75caa4 --- /dev/null +++ b/internal/mempool/queue/queue.go @@ -0,0 +1,23 @@ +// Copyright (C) 2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package queue + +import ( + "github.com/ava-labs/avalanchego/ids" +) + +type Item interface { + GetID() ids.ID // method for returning an id of the item + GetExpiry() int64 // method for returning this items timestamp + Priority() uint64 +} + +type Queue[T Item, E Item] interface { + Size() int + First() (E, bool) + FirstValue() (T, bool) + Remove(E) T + Push(T) E + Restore(T) E +} diff --git a/internal/pheap/pheap.go b/internal/pheap/pheap.go new file mode 100644 index 0000000000..63f708ae7c --- /dev/null +++ b/internal/pheap/pheap.go @@ -0,0 +1,79 @@ +// Copyright (C) 2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package pheap + +import ( + "github.com/ava-labs/avalanchego/ids" + + "github.com/ava-labs/hypersdk/internal/heap" +) + +// Item is the interface that any item put in the heap must adheare to. +type Item interface { + GetID() ids.ID + Priority() uint64 +} + +// PriorityHeap keeps a max heap of [Items] sorted by [PriorityFees]. +type PriorityHeap[T Item] struct { + maxHeap *heap.Heap[T, uint64] +} + +// New returns an instance of PriorityHeap with maxHeap containing [items]. +func New[T Item](items int) *PriorityHeap[T] { + return &PriorityHeap[T]{ + maxHeap: heap.New[T, uint64](items, false), + } +} + +// Add pushes [item] to ph. +func (ph *PriorityHeap[T]) Add(item T) { + poolLen := ph.maxHeap.Len() + ph.maxHeap.Push(&heap.Entry[T, uint64]{ + ID: item.GetID(), + Val: item.Priority(), + Item: item, + Index: poolLen, + }) +} + +// Remove removes [id] from ph. If the id does not exist, Remove returns. +func (ph *PriorityHeap[T]) Remove(id ids.ID) (T, bool) { + entry, ok := ph.maxHeap.Get(id) // O(1) + if !ok { + // This should never happen, as that would mean the heaps are out of + // sync. + return *new(T), false + } + ph.maxHeap.Remove(entry.Index) // O(log N) + return entry.Item, true +} + +// PopMax removes the maximum value in ph. +func (ph *PriorityHeap[T]) Pop() (T, bool) { + entry := ph.maxHeap.Pop() + if entry == nil { + return *new(T), false + } + return entry.Item, true +} + +// Has returns if [item] is in ph. +func (ph *PriorityHeap[T]) Has(item ids.ID) bool { + return ph.maxHeap.Has(item) +} + +// Len returns the number of elements in ph. +func (ph *PriorityHeap[T]) Len() int { + return ph.maxHeap.Len() +} + +// First returns the maximum value in ph. +func (ph *PriorityHeap[T]) First() (T, bool) { + first := ph.maxHeap.First() + if first == nil { + return *new(T), false + } + return first.Item, true +} diff --git a/internal/pheap/pheap_test.go b/internal/pheap/pheap_test.go new file mode 100644 index 0000000000..d266caee2b --- /dev/null +++ b/internal/pheap/pheap_test.go @@ -0,0 +1,167 @@ +// Copyright (C) 2024, Ava Labs, Inc. All rights reserved. +// See the file LICENSE for licensing terms. + +package pheap + +import ( + "testing" + + "github.com/ava-labs/avalanchego/ids" + "github.com/stretchr/testify/require" +) + +const testSponsor = "testSponsor" + +type TestItem struct { + id ids.ID + sponsor string + timestamp int64 + priority uint64 +} + +func (mti *TestItem) GetID() ids.ID { + return mti.id +} + +func (mti *TestItem) Sponsor() string { + return mti.sponsor +} + +func (mti *TestItem) GetExpiry() int64 { + return mti.timestamp +} + +func (mti *TestItem) Priority() uint64 { + return mti.priority +} + +func GenerateTestItem(sponsor string, t int64, priority uint64) *TestItem { + id := ids.GenerateTestID() + return &TestItem{ + id: id, + sponsor: sponsor, + timestamp: t, + priority: priority, + } +} + +func TestPriorityHeapNew(t *testing.T) { + // Creates empty max heaps + require := require.New(t) + pheap := New[*TestItem](0) + require.Zero(pheap.maxHeap.Len(), "MaxHeap not initialized correctly") +} + +func TestPriorityHeapAdd(t *testing.T) { + // Adds to the mempool. + require := require.New(t) + pheap := New[*TestItem](0) + item := GenerateTestItem("sponsor", 1, 0) + pheap.Add(item) + require.Equal(1, pheap.maxHeap.Len(), "MaxHeap not pushed correctly") + require.True(pheap.maxHeap.Has(item.GetID()), "MaxHeap does not have ID") +} + +func TestPriorityHeapRemove(t *testing.T) { + // Removes from the mempool. + require := require.New(t) + pheap := New[*TestItem](0) + item := GenerateTestItem("sponsor", 1, 0) + // Add first + pheap.Add(item) + require.Equal(1, pheap.maxHeap.Len(), "MaxHeap not pushed correctly") + require.True(pheap.maxHeap.Has(item.GetID()), "MaxHeap does not have ID") + // Remove + pheap.Remove(item.GetID()) + require.Zero(pheap.maxHeap.Len(), "MaxHeap not removed") + require.False(pheap.maxHeap.Has(item.GetID()), "MaxHeap still has ID") +} + +func TestPriorityHeapRemoveEmpty(t *testing.T) { + // Try to remove a non existing entry. + // Removes from the mempool. + require := require.New(t) + pheap := New[*TestItem](0) + item := GenerateTestItem("sponsor", 1, 0) + // Require this returns + pheap.Remove(item.GetID()) + require.True(true, "not true") +} + +func TestPeekFirst(t *testing.T) { + require := require.New(t) + pheap := New[*TestItem](0) + + itemMin := GenerateTestItem(testSponsor, 1, 1) + itemMed := GenerateTestItem(testSponsor, 2, 2) + itemMax := GenerateTestItem(testSponsor, 3, 3) + max, ok := pheap.First() + require.False(ok) + require.Nil(max, "Peek UnitPrice is incorrect") + // Check PeekFirst + pheap.Add(itemMed) + require.True(pheap.Has(itemMed.GetID()), "TX not included") + max, ok = pheap.First() + require.True(ok) + require.Equal(itemMed, max, "Peek value is incorrect") + + pheap.Add(itemMax) + require.True(pheap.Has(itemMax.GetID()), "TX not included") + max, ok = pheap.First() + require.True(ok) + require.Equal(itemMax, max, "Peek value is incorrect") + + pheap.Add(itemMin) + require.True(pheap.Has(itemMin.GetID()), "TX not included") + max, ok = pheap.First() + require.True(ok) + require.Equal(itemMax, max, "Peek value is incorrect") +} + +func TestPop(t *testing.T) { + require := require.New(t) + + pheap := New[*TestItem](0) + + itemMin := GenerateTestItem(testSponsor, 1, 1) + itemMed := GenerateTestItem(testSponsor, 2, 2) + itemMax := GenerateTestItem(testSponsor, 3, 3) + max, ok := pheap.Pop() + require.False(ok) + require.Nil(max, "Pop value is incorrect") + // Check Pop + pheap.Add(itemMed) + pheap.Add(itemMin) + pheap.Add(itemMax) + max, ok = pheap.Pop() + require.True(ok) + require.Equal(itemMax, max, "PopMax value is incorrect") + max, ok = pheap.Pop() + require.True(ok) + require.Equal(itemMed, max, "PopMax value is incorrect") + max, ok = pheap.Pop() + require.True(ok) + require.Equal(itemMin, max, "PopMax value is incorrect") +} + +func TestHas(t *testing.T) { + require := require.New(t) + + pheap := New[*TestItem](0) + item := GenerateTestItem(testSponsor, 1, 1) + require.False(pheap.Has(item.GetID()), "Found an item that was not added.") + pheap.Add(item) + require.True(pheap.Has(item.GetID()), "Did not find item.") +} + +func TestLen(t *testing.T) { + require := require.New(t) + + pheap := New[*TestItem](0) + for i := int64(0); i <= 4; i++ { + item := GenerateTestItem(testSponsor, i, uint64(i)) + pheap.Add(item) + require.True(pheap.Has(item.GetID()), "TX not included") + } + require.Equal(5, pheap.Len(), "Length of mempool is not as expected.") +} diff --git a/vm/vm.go b/vm/vm.go index 8a8dfb54b9..b6b8df695e 100644 --- a/vm/vm.go +++ b/vm/vm.go @@ -53,6 +53,12 @@ import ( internalfees "github.com/ava-labs/hypersdk/internal/fees" ) +var mempoolFactory mempool.AbstractMempoolFactory[*chain.Transaction] = mempool.NewGeneralMempool[*chain.Transaction] + +func WithPriorityMempool(factory mempool.AbstractMempoolFactory[*chain.Transaction]) { + mempoolFactory = factory +} + const ( blockDB = "blockdb" stateDB = "statedb" @@ -109,7 +115,7 @@ type VM struct { network *p2p.Network tracer avatrace.Tracer - mempool *mempool.Mempool[*chain.Transaction] + mempool mempool.AbstractMempool[*chain.Transaction] // We cannot use a map here because we may parse blocks up in the ancestry parsedBlocks *avacache.LRU[ids.ID, *StatefulBlock] @@ -259,7 +265,7 @@ func (vm *VM) Initialize( defer span.End() // Set defaults - vm.mempool = mempool.New[*chain.Transaction](vm.tracer, vm.config.MempoolSize, vm.config.MempoolSponsorSize) + vm.mempool = mempoolFactory(vm.tracer, vm.config.MempoolSize, vm.config.MempoolSponsorSize) vm.acceptedSubscriptions = append(vm.acceptedSubscriptions, event.SubscriptionFunc[*StatefulBlock]{ NotifyF: func(ctx context.Context, b *StatefulBlock) error { droppedTxs := vm.mempool.SetMinTimestamp(ctx, b.Tmstmp)