Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions filecache/filecache.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path/filepath"
"time"

"github.com/gofrs/flock"
"github.com/pkg/errors"
"go.jetify.com/pkg/cachehash"
)
Expand Down Expand Up @@ -57,6 +58,13 @@ func (c *Cache[T]) Set(key string, val T, dur time.Duration) error {
return errors.WithStack(err)
}

// Acquire exclusive lock to prevent concurrent writes from corrupting the file
lock := flock.New(c.lockfile())
if err := lock.Lock(); err != nil {
return errors.WithStack(err)
}
defer lock.Unlock()

return errors.WithStack(os.WriteFile(c.filename(key), d, 0o644))
}

Expand All @@ -68,6 +76,13 @@ func (c *Cache[T]) SetWithTime(key string, val T, t time.Time) error {
return errors.WithStack(err)
}

// Acquire exclusive lock to prevent concurrent writes from corrupting the file
lock := flock.New(c.lockfile())
if err := lock.Lock(); err != nil {
return errors.WithStack(err)
}
defer lock.Unlock()

return errors.WithStack(os.WriteFile(c.filename(key), d, 0o644))
}

Expand All @@ -76,6 +91,13 @@ func (c *Cache[T]) Get(key string) (T, error) {
path := c.filename(key)
resultData := data[T]{}

// Acquire shared lock before checking file existence to prevent TOCTOU race
lock := flock.New(c.lockfile())
if err := lock.RLock(); err != nil {
return resultData.Val, errors.WithStack(err)
}
defer lock.Unlock()

if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
return resultData.Val, NotFound
}
Expand Down Expand Up @@ -147,3 +169,9 @@ func (c *Cache[T]) filename(key string) string {
_ = os.MkdirAll(dir, 0o755)
return filepath.Join(dir, cachehash.Slug(key))
}

func (c *Cache[T]) lockfile() string {
dir := filepath.Join(c.cacheDir, c.domain)
_ = os.MkdirAll(dir, 0o755)
return filepath.Join(dir, ".lock")
}
260 changes: 260 additions & 0 deletions filecache/filecache_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
package filecache_test

import (
"fmt"
"sync"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.jetify.com/pkg/filecache"
)

type testData struct {
Value string
Counter int
}

func TestCacheOperations(t *testing.T) {
tests := []struct {
name string
run func(t *testing.T, cache *filecache.Cache[testData])
}{
{
name: "basic set and get",
run: func(t *testing.T, cache *filecache.Cache[testData]) {
// Test cache miss
_, err := cache.Get("key1")
assert.True(t, filecache.IsCacheMiss(err))

// Test Set and Get
data := testData{Value: "hello", Counter: 42}
err = cache.Set("key1", data, time.Hour)
require.NoError(t, err)

result, err := cache.Get("key1")
require.NoError(t, err)
assert.Equal(t, data, result)
},
},
{
name: "set with time",
run: func(t *testing.T, cache *filecache.Cache[testData]) {
data := testData{Value: "world", Counter: 123}
expiration := time.Now().Add(time.Hour)
err := cache.SetWithTime("key1", data, expiration)
require.NoError(t, err)

result, err := cache.Get("key1")
require.NoError(t, err)
assert.Equal(t, data, result)
},
},
{
name: "expiration",
run: func(t *testing.T, cache *filecache.Cache[testData]) {
data := testData{Value: "expires", Counter: 1}
// Set with expiration in the past
err := cache.SetWithTime("key1", data, time.Now().Add(-time.Hour))
require.NoError(t, err)

_, err = cache.Get("key1")
assert.True(t, filecache.IsCacheMiss(err))
},
},
{
name: "get or set",
run: func(t *testing.T, cache *filecache.Cache[testData]) {
callCount := 0
fetchFunc := func() (testData, time.Duration, error) {
callCount++
return testData{Value: "fetched", Counter: callCount}, time.Hour, nil
}

// First call should fetch
result1, err := cache.GetOrSet("key1", fetchFunc)
require.NoError(t, err)
assert.Equal(t, "fetched", result1.Value)
assert.Equal(t, 1, callCount)

// Second call should use cache
result2, err := cache.GetOrSet("key1", fetchFunc)
require.NoError(t, err)
assert.Equal(t, "fetched", result2.Value)
assert.Equal(t, 1, callCount) // Should not increment
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir()))
tt.run(t, cache)
})
}
}

func TestConcurrentAccess(t *testing.T) {
t.Run("concurrent writes to same key", func(t *testing.T) {
cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir()))

numGoroutines := 10
var wg sync.WaitGroup
wg.Add(numGoroutines)

// All goroutines write to the same key
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
data := testData{Value: fmt.Sprintf("writer-%d", id), Counter: id}
err := cache.Set("same-key", data, time.Hour)
assert.NoError(t, err)
}(i)
}

wg.Wait()

// The key should exist and contain valid data from one of the writers
result, err := cache.Get("same-key")
require.NoError(t, err)
assert.NotEmpty(t, result.Value)
assert.True(t, result.Counter >= 0 && result.Counter < numGoroutines)
})

t.Run("concurrent writes to different keys", func(t *testing.T) {
cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir()))

numGoroutines := 20
var wg sync.WaitGroup
wg.Add(numGoroutines)

// Each goroutine writes to a different key
for i := 0; i < numGoroutines; i++ {
go func(id int) {
defer wg.Done()
key := fmt.Sprintf("key-%d", id)
data := testData{Value: fmt.Sprintf("value-%d", id), Counter: id}
err := cache.Set(key, data, time.Hour)
assert.NoError(t, err)
}(i)
}

wg.Wait()

// Verify all keys were written correctly
for i := 0; i < numGoroutines; i++ {
key := fmt.Sprintf("key-%d", i)
result, err := cache.Get(key)
require.NoError(t, err, "Failed to get key %s", key)
assert.Equal(t, fmt.Sprintf("value-%d", i), result.Value)
assert.Equal(t, i, result.Counter)
}
})

t.Run("concurrent get or set same key", func(t *testing.T) {
cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir()))

numGoroutines := 10
var wg sync.WaitGroup
wg.Add(numGoroutines)

callCount := 0
var mu sync.Mutex

fetchFunc := func() (testData, time.Duration, error) {
mu.Lock()
callCount++
count := callCount
mu.Unlock()
// Simulate slow fetch
time.Sleep(10 * time.Millisecond)
return testData{Value: "shared", Counter: count}, time.Hour, nil
}

// All goroutines try to GetOrSet the same key
for i := 0; i < numGoroutines; i++ {
go func() {
defer wg.Done()
result, err := cache.GetOrSet("shared-key", fetchFunc)
assert.NoError(t, err)
assert.Equal(t, "shared", result.Value)
}()
}

wg.Wait()

// The fetch function may be called multiple times due to race,
// but the final cached value should be valid
result, err := cache.Get("shared-key")
require.NoError(t, err)
assert.Equal(t, "shared", result.Value)
assert.True(t, result.Counter > 0)
})

t.Run("concurrent reads and writes", func(t *testing.T) {
cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir()))

// Pre-populate the cache
err := cache.Set("key", testData{Value: "initial", Counter: 0}, time.Hour)
require.NoError(t, err)

numReaders := 10
numWriters := 5
var wg sync.WaitGroup
wg.Add(numReaders + numWriters)

// Spawn readers
for i := 0; i < numReaders; i++ {
go func() {
defer wg.Done()
for j := 0; j < 100; j++ {
result, err := cache.Get("key")
// We should either get valid data or an error, but never corrupted data
if err == nil {
assert.NotEmpty(t, result.Value)
}
}
}()
}

// Spawn writers
for i := 0; i < numWriters; i++ {
go func(id int) {
defer wg.Done()
for j := 0; j < 50; j++ {
data := testData{Value: fmt.Sprintf("writer-%d-iteration-%d", id, j), Counter: j}
err := cache.Set("key", data, time.Hour)
assert.NoError(t, err)
}
}(i)
}

wg.Wait()

// Final read should succeed with valid data
result, err := cache.Get("key")
require.NoError(t, err)
assert.NotEmpty(t, result.Value)
})
}

func TestClear(t *testing.T) {
cache := filecache.New[testData]("test-domain", filecache.WithCacheDir[testData](t.TempDir()))

// Add some data
err := cache.Set("key1", testData{Value: "value1", Counter: 1}, time.Hour)
require.NoError(t, err)
err = cache.Set("key2", testData{Value: "value2", Counter: 2}, time.Hour)
require.NoError(t, err)

// Clear the cache
err = cache.Clear()
require.NoError(t, err)

// Data should be gone
_, err = cache.Get("key1")
assert.True(t, filecache.IsCacheMiss(err))
_, err = cache.Get("key2")
assert.True(t, filecache.IsCacheMiss(err))
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ require (
github.com/fatih/color v1.18.0
github.com/go-jose/go-jose/v4 v4.1.2
github.com/goccy/go-yaml v1.18.0
github.com/gofrs/flock v0.12.1
github.com/google/go-github/v74 v74.0.0
github.com/google/renameio/v2 v2.0.0
github.com/gosimple/slug v1.15.0
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ github.com/go-jose/go-jose/v4 v4.1.2 h1:TK/7NqRQZfgAh+Td8AlsrvtPoUyiHh0LqVvokh+1
github.com/go-jose/go-jose/v4 v4.1.2/go.mod h1:22cg9HWM1pOlnRiY+9cQYJ9XHmya1bYW8OeDM6Ku6Oo=
github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw=
github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA=
github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E=
github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0=
github.com/gofrs/uuid/v5 v5.3.2 h1:2jfO8j3XgSwlz/wHqemAEugfnTlikAYHhnqQ8Xh4fE0=
github.com/gofrs/uuid/v5 v5.3.2/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8=
github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg=
Expand Down