Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions recordio/mmap_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ import (
"bytes"
"errors"
"fmt"
"github.com/thomasjungblut/go-sstables/recordio/simd"
"io"
"reflect"
"unsafe"

"golang.org/x/exp/mmap"

Expand All @@ -20,6 +23,9 @@ type MMapReader struct {
bufferPool *pool.Pool
path string
seekLen int

simdAvailable bool
mmapReaderSlice []byte
}

func (r *MMapReader) Open() error {
Expand Down Expand Up @@ -48,6 +54,14 @@ func (r *MMapReader) Open() error {
r.header = header
r.bufferPool = pool.NewPool(1024, 20)
r.open = true
r.simdAvailable = simd.AVXSupported()

v := reflect.ValueOf(r.mmapReader).Elem()
dataField := v.FieldByName("data")
dataPtr := unsafe.Pointer(dataField.UnsafeAddr())
dataSlice := *(*[]byte)(dataPtr)
r.mmapReaderSlice = unsafe.Slice(&dataSlice[0], len(dataSlice))

return nil
}

Expand All @@ -63,6 +77,10 @@ func (r *MMapReader) SeekNext(offset uint64) (uint64, []byte, error) {
return 0, nil, fmt.Errorf("unsupported on files with version lower than v2")
}

if r.simdAvailable {
return r.seekNextVectorized(offset)
}

headerBufPooled := r.bufferPool.Get(r.seekLen)
defer r.bufferPool.Put(headerBufPooled)

Expand Down Expand Up @@ -127,6 +145,29 @@ func (r *MMapReader) SeekNext(offset uint64) (uint64, []byte, error) {
}
}

// seekNextVectorized looks for the next readable record at or after offset by
// repeatedly asking simd.FindMagicNumber for the next magic number candidate in
// the memory mapped data and trying to read a record at that position.
//
// It returns the offset of the record together with its payload. When no
// further candidate is found it returns io.EOF. Candidates that fail with a
// header checksum mismatch, a magic number mismatch or io.EOF are skipped and
// the search resumes one byte behind them; any other read error is returned
// unchanged.
func (r *MMapReader) seekNextVectorized(offset uint64) (uint64, []byte, error) {
	for start := offset; ; {
		pos := simd.FindMagicNumber(r.mmapReaderSlice, int(start))
		if pos < 0 {
			return 0, nil, io.EOF
		}

		rec, err := r.ReadNextAt(uint64(pos))
		switch {
		case err == nil:
			return uint64(pos), rec, nil
		case errors.Is(err, HeaderChecksumMismatchErr),
			errors.Is(err, MagicNumberMismatchErr),
			errors.Is(err, io.EOF):
			// the candidate was not a fully readable record, keep searching behind it
			start = uint64(pos + 1)
		default:
			return 0, nil, err
		}
	}
}

func (r *MMapReader) ReadNextAt(offset uint64) ([]byte, error) {
if !r.open || r.closed {
return nil, fmt.Errorf("reader at '%s' was either not opened yet or is closed already", r.path)
Expand Down
38 changes: 38 additions & 0 deletions recordio/simd/magic_number_search.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package simd

/*
#cgo CFLAGS: -mavx2
#include "search.h"
*/
import "C"
import "unsafe"

// AVXSupported reports whether the C helper cpu_supports_avx2 returned 1,
// i.e. whether the AVX2 based search can be used on this machine.
func AVXSupported() bool {
	return C.cpu_supports_avx2() == 1
}

// FindFirstMagicNumber returns the index of the first magic number found in
// data as reported by the C function find_magic_numbers, searching from the
// beginning of the slice. It returns -1 when data is shorter than three bytes
// or when no magic number is found.
func FindFirstMagicNumber(data []byte) int {
	n := len(data)
	if n < 3 {
		return -1
	}

	base := (*C.uchar)(unsafe.Pointer(&data[0]))
	return int(C.find_magic_numbers(base, C.size_t(0), C.size_t(n)))
}

// FindMagicNumber returns the index (relative to the start of data) of the
// first magic number located at or after off, as reported by the C function
// find_magic_numbers. It returns -1 when data is shorter than three bytes, when
// off lies outside of [0, len(data)) or when no magic number is found.
func FindMagicNumber(data []byte, off int) int {
	n := len(data)
	if n < 3 || off < 0 || off >= n {
		return -1
	}

	base := (*C.uchar)(unsafe.Pointer(&data[0]))
	return int(C.find_magic_numbers(base, C.size_t(off), C.size_t(n)))
}
40 changes: 40 additions & 0 deletions recordio/simd/magic_number_search_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
package simd

import (
"github.com/stretchr/testify/require"
"testing"
)

// TestMagicNumberSearchHappyPath plants two magic numbers into a zeroed buffer
// and checks that both search functions report the expected positions, both
// from the start and from an offset behind the first occurrence.
func TestMagicNumberSearchHappyPath(t *testing.T) {
	if !AVXSupported() {
		t.Skip()
	}

	const size = 10000
	magic := []byte{145, 141, 76}

	buf := make([]byte, size)
	// one occurrence 300 bytes before the end, one flush with the end of the buffer
	copy(buf[size-300:], magic)
	copy(buf[size-3:], magic)

	require.Equal(t, size-300, FindFirstMagicNumber(buf))
	require.Equal(t, size-300, FindMagicNumber(buf, 0))

	// searching behind the first occurrence must yield the trailing one
	require.Equal(t, 296, FindFirstMagicNumber(buf[9701:]))
	require.Equal(t, size-3, FindMagicNumber(buf, 9701))
}

// TestMagicNumberSearchBoundary checks that inputs that are too short and
// offsets at or outside the searchable range yield -1.
func TestMagicNumberSearchBoundary(t *testing.T) {
	tooShort := []byte{0, 1}
	require.Equal(t, -1, FindFirstMagicNumber(tooShort))
	require.Equal(t, -1, FindMagicNumber(tooShort, 0))

	noMagic := []byte{0, 1, 3, 4}
	for _, off := range []int{3, 4, -1} {
		require.Equal(t, -1, FindMagicNumber(noMagic, off))
	}
}
114 changes: 114 additions & 0 deletions recordio/simd/search.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#include "search.h"
#include <immintrin.h>
#include <cpuid.h>
#include <stdint.h>

static const unsigned char pattern[] = {145, 141, 76};

// Returns 1 if AVX2 can be used, 0 otherwise.
//
// The CPU advertising AVX2 is not sufficient: the operating system must also
// have enabled saving/restoring of the XMM and YMM register state, otherwise
// executing an AVX2 instruction raises an invalid-opcode fault. Therefore this
// checks, in order:
//   1. CPUID leaf 7 exists,
//   2. CPUID.1:ECX reports OSXSAVE (bit 27) and AVX (bit 28),
//   3. XCR0 has the SSE (bit 1) and AVX (bit 2) state bits enabled,
//   4. CPUID.(7,0):EBX reports AVX2 (bit 5).
int cpu_supports_avx2() {
    unsigned int eax, ebx, ecx, edx;

    // First, check if CPUID leaf 7 is supported
    if (__get_cpuid_max(0, 0) < 7)
        return 0;

    // OSXSAVE tells us XGETBV is usable, AVX is a prerequisite of AVX2
    __cpuid(1, eax, ebx, ecx, edx);
    if ((ecx & (1u << 27)) == 0 || (ecx & (1u << 28)) == 0)
        return 0;

    // Read XCR0 to verify the OS enabled XMM and YMM state
    unsigned int xcr0_lo, xcr0_hi;
    __asm__ __volatile__("xgetbv" : "=a"(xcr0_lo), "=d"(xcr0_hi) : "c"(0));
    (void)xcr0_hi;
    if ((xcr0_lo & 0x6u) != 0x6u)
        return 0;

    // Call CPUID leaf 7, subleaf 0
    __cpuid_count(7, 0, eax, ebx, ecx, edx);

    // Bit 5 of EBX in CPUID leaf 7 indicates AVX2 support
    return (ebx & (1u << 5)) != 0;
}

// Returns 1 if AVX-512 Foundation can be used, 0 otherwise.
//
// Besides the CPUID feature bit this also verifies that the operating system
// enabled the required register state, otherwise executing an AVX-512
// instruction raises an invalid-opcode fault:
//   1. CPUID leaf 7 exists,
//   2. CPUID.1:ECX reports OSXSAVE (bit 27),
//   3. XCR0 has SSE, AVX, opmask, ZMM_Hi256 and Hi16_ZMM state enabled
//      (bits 1, 2, 5, 6 and 7),
//   4. CPUID.(7,0):EBX reports AVX-512F (bit 16).
int cpu_supports_avx512() {
    unsigned int eax, ebx, ecx, edx;

    if (__get_cpuid_max(0, 0) < 7)
        return 0;

    // XGETBV may only be executed when the OS has set OSXSAVE
    __cpuid(1, eax, ebx, ecx, edx);
    if ((ecx & (1u << 27)) == 0)
        return 0;

    unsigned int xcr0_lo, xcr0_hi;
    __asm__ __volatile__("xgetbv" : "=a"(xcr0_lo), "=d"(xcr0_hi) : "c"(0));
    (void)xcr0_hi;
    if ((xcr0_lo & 0xE6u) != 0xE6u)
        return 0;

    __cpuid_count(7, 0, eax, ebx, ecx, edx);
    return (ebx & (1u << 16)) != 0; // AVX-512F
}

// Searches data[0..len) for the first occurrence of the 3-byte magic number
// stored in `pattern`, starting at index `off`.
//
// Returns the index of the first byte of the match (relative to `data`), or -1
// when len < 3, off >= len, or no match exists at or after `off`.
//
// NOTE: the result is returned as int, so indices that do not fit into an int
// cannot be represented; callers must not pass buffers larger than INT_MAX.
int find_magic_numbers(const unsigned char* data, size_t off, size_t len) {
    if (len < 3) return -1;
    if (off >= len) return -1;

    size_t i = off;
    // last valid start index of a 3-byte match is len - 3, hence end is exclusive
    size_t end = len - 2;

    // the broadcast pattern bytes are loop invariant
    const __m256i p0 = _mm256_set1_epi8((char)pattern[0]);
    const __m256i p1 = _mm256_set1_epi8((char)pattern[1]);
    const __m256i p2 = _mm256_set1_epi8((char)pattern[2]);

    // Process 32 candidate start positions per iteration using AVX2. The three
    // unaligned loads read data[i .. i+33], which is in bounds as long as
    // i + 32 <= end. Since every iteration fully covers positions i .. i+31
    // we can advance by 32 instead of re-checking overlapping windows.
    for (; i + 32 <= end; i += 32) {
        __m256i d0 = _mm256_loadu_si256((const __m256i*)(data + i));
        __m256i d1 = _mm256_loadu_si256((const __m256i*)(data + i + 1));
        __m256i d2 = _mm256_loadu_si256((const __m256i*)(data + i + 2));

        __m256i m0 = _mm256_cmpeq_epi8(d0, p0);
        __m256i m1 = _mm256_cmpeq_epi8(d1, p1);
        __m256i m2 = _mm256_cmpeq_epi8(d2, p2);

        // bit k is set when all three bytes match at start position i + k
        __m256i mask = _mm256_and_si256(_mm256_and_si256(m0, m1), m2);
        unsigned int matchmask = (unsigned int)_mm256_movemask_epi8(mask);

        if (matchmask) {
            // return the first match index
            return (int)(i + (size_t)__builtin_ctz(matchmask));
        }
    }

    // Fallback naive scan for remaining bytes
    for (; i < end; i++) {
        if (data[i] == pattern[0] &&
            data[i + 1] == pattern[1] &&
            data[i + 2] == pattern[2]) {
            return (int)i;
        }
    }

    return -1;
}

/*
TODO(thomas): we would need to split the cgo flags and compilation units to match

int find_magic_numbers_avx512(const unsigned char* data, size_t off, size_t len) {
if (len < 3) return -1;
if (off >= len) return -1;

size_t i = off;
size_t end = len - 2;

// process 64 bytes per loop using AVX512
for (; i + 64 <= end; i += 64) { // reuse the outer i so that off is honoured and the fallback continues where this loop stopped
__m512i d0 = _mm512_loadu_si512((const void*)(data + i));
__m512i d1 = _mm512_loadu_si512((const void*)(data + i + 1));
__m512i d2 = _mm512_loadu_si512((const void*)(data + i + 2));

__m512i p0 = _mm512_set1_epi8(pattern[0]);
__m512i p1 = _mm512_set1_epi8(pattern[1]);
__m512i p2 = _mm512_set1_epi8(pattern[2]);

__mmask64 m0 = _mm512_cmpeq_epi8_mask(d0, p0);
__mmask64 m1 = _mm512_cmpeq_epi8_mask(d1, p1);
__mmask64 m2 = _mm512_cmpeq_epi8_mask(d2, p2);

__mmask64 m = m0 & m1 & m2;
if (m) {
return i + __builtin_ctzll(m);
}
}

// Fallback naive scan for remaining bytes
for (; i < end; i++) {
if (data[i] == pattern[0] &&
data[i+1] == pattern[1] &&
data[i+2] == pattern[2]) {
return i;
}
}

return -1;
}
*/
12 changes: 12 additions & 0 deletions recordio/simd/search.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#ifndef SEARCH_H
#define SEARCH_H

#include <stddef.h>

// Returns 1 if AVX2 is available, 0 otherwise.
int cpu_supports_avx2();

// Returns 1 if AVX-512 (Foundation) is available, 0 otherwise.
int cpu_supports_avx512();

// Searches data[0..len) for the first occurrence of the 3-byte magic number
// (145, 141, 76) starting at index off. Returns the index of the match relative
// to data, or -1 when len < 3, off >= len, or no match is found.
int find_magic_numbers(const unsigned char* data, size_t off, size_t len);

#endif
Loading