Skip to content

Commit 88a400f

Browse files
Ensures response is handled with the same guest as the request (#69)
Before, we had a bug where an arbitrary module was used for response processing. This defeated the goal of request context correlation. This fixes the problem by holding the guest module open until the response is complete. Fixes #68 Signed-off-by: Adrian Cole <[email protected]>
1 parent 377347e commit 88a400f

File tree

7 files changed

+228
-57
lines changed

7 files changed

+228
-57
lines changed

api/handler/handler.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,3 +162,47 @@ type Host interface {
162162
// FeatureTrailers is not supported.
163163
RemoveResponseTrailer(ctx context.Context, name string)
164164
}
165+
166+
// eofReader is safer than reading from os.DevNull as it can never overrun
167+
// operating system file descriptors.
168+
type eofReader struct{}
169+
170+
func (eofReader) Close() (err error) { return }
171+
func (eofReader) Read([]byte) (int, error) { return 0, io.EOF }
172+
173+
type UnimplementedHost struct{}
174+
175+
var _ Host = UnimplementedHost{}
176+
177+
func (UnimplementedHost) EnableFeatures(context.Context, Features) Features { return 0 }
178+
func (UnimplementedHost) GetMethod(context.Context) string { return "GET" }
179+
func (UnimplementedHost) SetMethod(context.Context, string) {}
180+
func (UnimplementedHost) GetURI(context.Context) string { return "" }
181+
func (UnimplementedHost) SetURI(context.Context, string) {}
182+
func (UnimplementedHost) GetProtocolVersion(context.Context) string { return "HTTP/1.1" }
183+
func (UnimplementedHost) GetRequestHeaderNames(context.Context) (names []string) { return }
184+
func (UnimplementedHost) GetRequestHeaderValues(context.Context, string) (values []string) { return }
185+
func (UnimplementedHost) SetRequestHeaderValue(context.Context, string, string) {}
186+
func (UnimplementedHost) AddRequestHeaderValue(context.Context, string, string) {}
187+
func (UnimplementedHost) RemoveRequestHeader(context.Context, string) {}
188+
func (UnimplementedHost) RequestBodyReader(context.Context) io.ReadCloser { return eofReader{} }
189+
func (UnimplementedHost) RequestBodyWriter(context.Context) io.Writer { return io.Discard }
190+
func (UnimplementedHost) GetRequestTrailerNames(context.Context) (names []string) { return }
191+
func (UnimplementedHost) GetRequestTrailerValues(context.Context, string) (values []string) { return }
192+
func (UnimplementedHost) SetRequestTrailerValue(context.Context, string, string) {}
193+
func (UnimplementedHost) AddRequestTrailerValue(context.Context, string, string) {}
194+
func (UnimplementedHost) RemoveRequestTrailer(context.Context, string) {}
195+
func (UnimplementedHost) GetStatusCode(context.Context) uint32 { return 200 }
196+
func (UnimplementedHost) SetStatusCode(context.Context, uint32) {}
197+
func (UnimplementedHost) GetResponseHeaderNames(context.Context) (names []string) { return }
198+
func (UnimplementedHost) GetResponseHeaderValues(context.Context, string) (values []string) { return }
199+
func (UnimplementedHost) SetResponseHeaderValue(context.Context, string, string) {}
200+
func (UnimplementedHost) AddResponseHeaderValue(context.Context, string, string) {}
201+
func (UnimplementedHost) RemoveResponseHeader(context.Context, string) {}
202+
func (UnimplementedHost) ResponseBodyReader(context.Context) io.ReadCloser { return eofReader{} }
203+
func (UnimplementedHost) ResponseBodyWriter(context.Context) io.Writer { return io.Discard }
204+
func (UnimplementedHost) GetResponseTrailerNames(context.Context) (names []string) { return }
205+
func (UnimplementedHost) GetResponseTrailerValues(context.Context, string) (values []string) { return }
206+
func (UnimplementedHost) SetResponseTrailerValue(context.Context, string, string) {}
207+
func (UnimplementedHost) AddResponseTrailerValue(context.Context, string, string) {}
208+
func (UnimplementedHost) RemoveResponseTrailer(context.Context, string) {}

handler/middleware.go

Lines changed: 58 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ type Middleware interface {
4545
var _ Middleware = (*middleware)(nil)
4646

4747
type middleware struct {
48-
host handler.Host
49-
runtime wazero.Runtime
50-
hostModule, guestModule wazero.CompiledModule
51-
moduleConfig wazero.ModuleConfig
52-
guestConfig []byte
53-
logger api.Logger
54-
pool sync.Pool
55-
features handler.Features
56-
instanceCounter uint64
48+
host handler.Host
49+
runtime wazero.Runtime
50+
guestModule wazero.CompiledModule
51+
moduleConfig wazero.ModuleConfig
52+
guestConfig []byte
53+
logger api.Logger
54+
pool sync.Pool
55+
features handler.Features
56+
instanceCounter uint64
5757
}
5858

5959
func (m *middleware) Features() handler.Features {
@@ -83,29 +83,30 @@ func NewMiddleware(ctx context.Context, guest []byte, host handler.Host, opts ..
8383
logger: o.logger,
8484
}
8585

86-
if m.hostModule, err = m.compileHost(ctx); err != nil {
87-
_ = m.Close(ctx)
86+
if m.guestModule, err = m.compileGuest(ctx, guest); err != nil {
87+
_ = wr.Close(ctx)
8888
return nil, err
8989
}
9090

91-
if _, err = wasi_snapshot_preview1.Instantiate(ctx, m.runtime); err != nil {
92-
return nil, fmt.Errorf("wasm: error instantiating wasi: %w", err)
93-
}
94-
95-
// Note: host modules don't use configuration
96-
_, err = m.runtime.InstantiateModule(ctx, m.hostModule, wazero.NewModuleConfig())
97-
if err != nil {
98-
_ = m.runtime.Close(ctx)
99-
return nil, fmt.Errorf("wasm: error instantiating host: %w", err)
100-
}
101-
102-
if m.guestModule, err = m.compileGuest(ctx, guest); err != nil {
103-
_ = m.Close(ctx)
104-
return nil, err
91+
// Detect and handle any host imports or lack thereof.
92+
imports := detectImports(m.guestModule.ImportedFunctions())
93+
switch {
94+
case imports&importWasiP1 != 0:
95+
if _, err = wasi_snapshot_preview1.Instantiate(ctx, m.runtime); err != nil {
96+
_ = wr.Close(ctx)
97+
return nil, fmt.Errorf("wasm: error instantiating wasi: %w", err)
98+
}
99+
fallthrough // proceed to configure any http_handler imports
100+
case imports&importHttpHandler != 0:
101+
if _, err = m.instantiateHost(ctx); err != nil {
102+
_ = wr.Close(ctx)
103+
return nil, fmt.Errorf("wasm: error instantiating host: %w", err)
104+
}
105105
}
106106

107+
// Eagerly add one instance to the pool. Doing so helps to fail fast.
107108
if g, err := m.newGuest(ctx); err != nil {
108-
_ = m.Close(ctx)
109+
_ = wr.Close(ctx)
109110
return nil, err
110111
} else {
111112
m.pool.Put(g)
@@ -139,11 +140,11 @@ func (m *middleware) HandleRequest(ctx context.Context) (outCtx context.Context,
139140
err = guestErr
140141
return
141142
}
142-
defer m.pool.Put(g)
143143

144-
s := &requestState{features: m.features}
144+
s := &requestState{features: m.features, putPool: m.pool.Put, g: g}
145145
defer func() {
146-
if ctxNext != 0 { // will call the next handler
146+
callNext := ctxNext != 0
147+
if callNext { // will call the next handler
147148
if closeErr := s.closeRequest(); err == nil {
148149
err = closeErr
149150
}
@@ -173,16 +174,10 @@ func (m *middleware) getOrCreateGuest(ctx context.Context) (*guest, error) {
173174

174175
// HandleResponse implements Middleware.HandleResponse
175176
func (m *middleware) HandleResponse(ctx context.Context, reqCtx uint32, hostErr error) error {
176-
g, err := m.getOrCreateGuest(ctx)
177-
if err != nil {
178-
return err
179-
}
180-
defer m.pool.Put(g)
181-
182177
s := requestStateFromContext(ctx)
183178
defer s.Close()
184179

185-
return g.handleResponse(ctx, reqCtx, hostErr)
180+
return s.g.handleResponse(ctx, reqCtx, hostErr)
186181
}
187182

188183
// Close implements api.Closer
@@ -529,7 +524,7 @@ func (m *middleware) readBody(ctx context.Context, mod wazeroapi.Module, stack [
529524
panic("unsupported body kind: " + strconv.Itoa(int(kind)))
530525
}
531526

532-
eofLen := readBody(ctx, mod, buf, bufLimit, r)
527+
eofLen := readBody(mod, buf, bufLimit, r)
533528

534529
stack[0] = eofLen
535530
}
@@ -562,10 +557,10 @@ func (m *middleware) writeBody(ctx context.Context, mod wazeroapi.Module, params
562557
panic("unsupported body kind: " + strconv.Itoa(int(kind)))
563558
}
564559

565-
writeBody(ctx, mod, buf, bufLen, w)
560+
writeBody(mod, buf, bufLen, w)
566561
}
567562

568-
func writeBody(ctx context.Context, mod wazeroapi.Module, buf, bufLen uint32, w io.Writer) {
563+
func writeBody(mod wazeroapi.Module, buf, bufLen uint32, w io.Writer) {
569564
// buf_len 0 means to overwrite with nothing
570565
var b []byte
571566
if bufLen > 0 {
@@ -596,7 +591,7 @@ func (m *middleware) setStatusCode(ctx context.Context, params []uint64) {
596591
m.host.SetStatusCode(ctx, statusCode)
597592
}
598593

599-
func readBody(ctx context.Context, mod wazeroapi.Module, buf uint32, bufLimit handler.BufLimit, r io.Reader) (eofLen uint64) {
594+
func readBody(mod wazeroapi.Module, buf uint32, bufLimit handler.BufLimit, r io.Reader) (eofLen uint64) {
600595
// buf_limit 0 serves no purpose as implementations won't return EOF on it.
601596
if bufLimit == 0 {
602597
panic(fmt.Errorf("buf_limit==0 reading body"))
@@ -645,8 +640,8 @@ func mustBeforeNextOrFeature(ctx context.Context, feature handler.Features, op,
645640

646641
const i32, i64 = wazeroapi.ValueTypeI32, wazeroapi.ValueTypeI64
647642

648-
func (m *middleware) compileHost(ctx context.Context) (wazero.CompiledModule, error) {
649-
if compiled, err := m.runtime.NewHostModuleBuilder(handler.HostModule).
643+
func (m *middleware) instantiateHost(ctx context.Context) (wazeroapi.Module, error) {
644+
return m.runtime.NewHostModuleBuilder(handler.HostModule).
650645
NewFunctionBuilder().
651646
WithGoFunction(wazeroapi.GoFunc(m.enableFeatures), []wazeroapi.ValueType{i32}, []wazeroapi.ValueType{i32}).
652647
WithParameterNames("features").Export(handler.FuncEnableFeatures).
@@ -701,11 +696,7 @@ func (m *middleware) compileHost(ctx context.Context) (wazero.CompiledModule, er
701696
NewFunctionBuilder().
702697
WithGoFunction(wazeroapi.GoFunc(m.setStatusCode), []wazeroapi.ValueType{i32}, []wazeroapi.ValueType{}).
703698
WithParameterNames("status_code").Export(handler.FuncSetStatusCode).
704-
Compile(ctx); err != nil {
705-
return nil, fmt.Errorf("wasm: error compiling host: %w", err)
706-
} else {
707-
return compiled, nil
708-
}
699+
Instantiate(ctx)
709700
}
710701

711702
func mustHeaderMutable(ctx context.Context, op string, kind handler.HeaderKind) {
@@ -766,3 +757,23 @@ func writeStringIfUnderLimit(mem wazeroapi.Memory, offset, limit handler.BufLimi
766757
mem.WriteString(offset, v)
767758
return
768759
}
760+
761+
type imports uint
762+
763+
const (
764+
importWasiP1 imports = 1 << iota
765+
importHttpHandler
766+
)
767+
768+
func detectImports(importedFns []wazeroapi.FunctionDefinition) (imports imports) {
769+
for _, f := range importedFns {
770+
moduleName, _, _ := f.Import()
771+
switch moduleName {
772+
case handler.HostModule:
773+
imports |= importHttpHandler
774+
case wasi_snapshot_preview1.ModuleName:
775+
imports |= importWasiP1
776+
}
777+
}
778+
return
779+
}

handler/middleware_test.go

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package handler
2+
3+
import (
4+
"context"
5+
_ "embed"
6+
"reflect"
7+
"testing"
8+
9+
"github.com/http-wasm/http-wasm-host-go/api/handler"
10+
"github.com/http-wasm/http-wasm-host-go/internal/test"
11+
)
12+
13+
var testCtx = context.Background()
14+
15+
func Test_MiddlewareResponseUsesRequestModule(t *testing.T) {
16+
mw, err := NewMiddleware(testCtx, test.BinE2EHandleResponse, handler.UnimplementedHost{})
17+
if err != nil {
18+
t.Fatal(err)
19+
}
20+
defer mw.Close(testCtx)
21+
22+
// A new guest module has initial state, so its value should be 42
23+
r1Ctx, ctxNext, err := mw.HandleRequest(testCtx)
24+
expectHandleRequest(t, mw, ctxNext, err, 42)
25+
26+
// The first guest shouldn't return to the pool until HandleResponse, so
27+
// the second simultaneous call will get a new guest.
28+
r2Ctx, ctxNext2, err := mw.HandleRequest(testCtx)
29+
expectHandleRequest(t, mw, ctxNext2, err, 42)
30+
31+
// Return the first request to the pool
32+
if err = mw.HandleResponse(r1Ctx, uint32(ctxNext>>32), nil); err != nil {
33+
t.Fatal(err)
34+
}
35+
expectGlobals(t, mw, 43)
36+
37+
// The next request should re-use the returned module
38+
r3Ctx, ctxNext3, err := mw.HandleRequest(testCtx)
39+
expectHandleRequest(t, mw, ctxNext3, err, 43)
40+
if err = mw.HandleResponse(r3Ctx, uint32(ctxNext3>>32), nil); err != nil {
41+
t.Fatal(err)
42+
}
43+
expectGlobals(t, mw, 44)
44+
45+
// Return the second request to the pool
46+
if err = mw.HandleResponse(r2Ctx, uint32(ctxNext2>>32), nil); err != nil {
47+
t.Fatal(err)
48+
}
49+
expectGlobals(t, mw, 44, 43)
50+
}
51+
52+
func expectGlobals(t *testing.T, mw Middleware, wantGlobals ...uint64) {
53+
t.Helper()
54+
if want, have := wantGlobals, getGlobalVals(mw); !reflect.DeepEqual(want, have) {
55+
t.Errorf("unexpected globals, want: %v, have: %v", want, have)
56+
}
57+
}
58+
59+
func getGlobalVals(mw Middleware) []uint64 {
60+
pool := mw.(*middleware).pool
61+
var guests []*guest
62+
var globals []uint64
63+
64+
// Take all guests out of the pool
65+
for {
66+
if g, ok := pool.Get().(*guest); ok {
67+
guests = append(guests, g)
68+
continue
69+
}
70+
break
71+
}
72+
73+
for _, g := range guests {
74+
v := g.guest.ExportedGlobal("reqCtx").Get()
75+
globals = append(globals, v)
76+
pool.Put(g)
77+
}
78+
79+
return globals
80+
}
81+
82+
func expectHandleRequest(t *testing.T, mw Middleware, ctxNext handler.CtxNext, err error, expectedCtx handler.CtxNext) {
83+
t.Helper()
84+
if err != nil {
85+
t.Fatal(err)
86+
}
87+
if want, have := expectedCtx, ctxNext>>32; want != have {
88+
t.Errorf("unexpected ctx, want: %d, have: %d", want, have)
89+
}
90+
if mw.(*middleware).pool.Get() != nil {
91+
t.Error("expected handler to not return guest to the pool")
92+
}
93+
}

handler/state.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ type requestState struct {
2626
// features are the current request's features which may be more than
2727
// Middleware.Features.
2828
features handler.Features
29+
30+
putPool func(x any)
31+
g *guest
2932
}
3033

3134
func (r *requestState) closeRequest() (err error) {
@@ -44,6 +47,10 @@ func (r *requestState) closeRequest() (err error) {
4447

4548
// Close implements io.Closer
4649
func (r *requestState) Close() (err error) {
50+
if g := r.g; g != nil {
51+
r.putPool(r.g)
52+
r.g = nil
53+
}
4754
err = r.closeRequest()
4855
if respBW := r.responseBodyWriter; respBW != nil {
4956
if f, ok := respBW.(http.Flusher); ok {

internal/test/testdata.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"log"
66
"os"
77
"path"
8+
"runtime"
89
)
910

1011
//go:embed testdata/bench/log.wasm
@@ -84,7 +85,11 @@ var BinE2EHeaderNames []byte
8485

8586
// binExample instead of go:embed as files aren't relative to this directory.
8687
func binExample(name string) []byte {
87-
p := path.Join("..", "..", "examples", name+".wasm")
88+
_, thisFile, _, ok := runtime.Caller(1)
89+
if !ok {
90+
log.Panicln("cannot determine current path")
91+
}
92+
p := path.Join(path.Dir(thisFile), "..", "..", "examples", name+".wasm")
8893
if wasm, err := os.ReadFile(p); err != nil {
8994
log.Panicln(err)
9095
return nil
25 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)