diff --git a/enterprise/server/cmd/ci_runner/BUILD b/enterprise/server/cmd/ci_runner/BUILD index 8293d201a62..c4f166d800d 100644 --- a/enterprise/server/cmd/ci_runner/BUILD +++ b/enterprise/server/cmd/ci_runner/BUILD @@ -1,6 +1,6 @@ load("@io_bazel_rules_docker//container:container.bzl", "container_image") load("@io_bazel_rules_docker//go:image.bzl", "go_image") -load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library") +load("@io_bazel_rules_go//go:def.bzl", "go_binary", "go_library", "go_test") # gazelle:default_visibility //enterprise:__subpackages__ package( @@ -72,3 +72,10 @@ go_image( binary = ":ci_runner", tags = ["manual"], ) + +go_test( + name = "ci_runner_test", + srcs = ["main_test.go"], + embed = [":main"], # keep + deps = ["@com_github_stretchr_testify//require"], +) diff --git a/enterprise/server/cmd/ci_runner/main.go b/enterprise/server/cmd/ci_runner/main.go index b7e5601ab05..68e71abc339 100644 --- a/enterprise/server/cmd/ci_runner/main.go +++ b/enterprise/server/cmd/ci_runner/main.go @@ -486,6 +486,9 @@ func (r *buildEventReporter) Stop() error { r.cancelBackgroundFlush() r.cancelBackgroundFlush = nil } + if err := r.log.Flush(); err != nil { + return err + } r.FlushProgress() elapsedTimeSeconds := float64(time.Since(r.startTime)) / float64(time.Second) @@ -938,6 +941,10 @@ type invocationLog struct { lockingbuffer.LockingBuffer writer io.Writer writeListener func(s string) + + mu sync.Mutex + partialLine bytes.Buffer + writeErr error // First write error encountered } func newInvocationLog() *invocationLog { @@ -946,17 +953,79 @@ func newInvocationLog() *invocationLog { return invLog } +// Write ultimately writes the given bytes to the writeListener and given writer. +// Write will buffer bytes until encountering a newline, when it will redact secrets from the full line. +// Callers are expected to call Flush when there will be no more Write calls. func (invLog *invocationLog) Write(b []byte) (int, error) { - output := string(b) + invLog.mu.Lock() + defer invLog.mu.Unlock() - redacted := redact.RedactText(output) + if invLog.writeErr != nil { + return 0, invLog.writeErr + } - invLog.writeListener(redacted) - _, err := invLog.writer.Write([]byte(redacted)) + if n, err := invLog.partialLine.Write(b); err != nil { + return n, err + } + + var writeErr error + + for { + data := invLog.partialLine.Bytes() + if len(data) == 0 { + break + } + + idx := bytes.IndexAny(data, "\n\r") + if idx < 0 { + break + } + + newlineLen := 1 + if data[idx] == '\r' && idx+1 < len(data) && data[idx+1] == '\n' { + newlineLen = 2 + } + + line := string(data[:idx+newlineLen]) + invLog.partialLine.Next(idx + newlineLen) + + redacted := redact.RedactText(line) - // Return the size of the original buffer even if a redacted size was written, - // or clients will return a short write error - return len(b), err + invLog.writeListener(redacted) + if _, err := invLog.writer.Write([]byte(redacted)); err != nil && writeErr == nil { + writeErr = err + invLog.writeErr = err + } + } + + return len(b), writeErr +} + +// Flush redacts secrets from the partialLine buffer and writes the redacted output +// to the writeListener and underlying writer. +func (invLog *invocationLog) Flush() error { + invLog.mu.Lock() + defer invLog.mu.Unlock() + + if invLog.writeErr != nil { + return invLog.writeErr + } + + if invLog.partialLine.Len() == 0 { + return nil + } + + line := invLog.partialLine.String() + invLog.partialLine.Reset() + + redacted := redact.RedactText(line) + + invLog.writeListener(redacted) + if _, err := invLog.writer.Write([]byte(redacted)); err != nil { + invLog.writeErr = err + return err + } + return nil } func (invLog *invocationLog) Println(vals ...interface{}) { diff --git a/enterprise/server/cmd/ci_runner/main_test.go b/enterprise/server/cmd/ci_runner/main_test.go new file mode 100644 index 00000000000..c861293c445 --- /dev/null +++ b/enterprise/server/cmd/ci_runner/main_test.go @@ -0,0 +1,129 @@ +package main + +import ( + "bytes" + "errors" + "io" + "testing" + + "github.com/stretchr/testify/require" +) + +func newTestInvocationLog() (*invocationLog, *bytes.Buffer) { + invLog := &invocationLog{} + invLog.writeListener = func(string) {} + buf := &bytes.Buffer{} + invLog.writer = io.MultiWriter(&invLog.LockingBuffer, buf) + return invLog, buf +} + +func TestInvocationLogRedactsRemoteExecHeaderAcrossWrites(t *testing.T) { + invLog, buf := newTestInvocationLog() + + firstChunk := "common --remote_exec_header=secret-token" + n, err := invLog.Write([]byte(firstChunk)) + require.NoError(t, err) + require.Equal(t, len(firstChunk), n) + require.Equal(t, "", buf.String(), "should not flush before newline") + + secondChunk := "-continued\n" + n, err = invLog.Write([]byte(secondChunk)) + require.NoError(t, err) + require.Equal(t, len(secondChunk), n) + require.Equal(t, "common --remote_exec_header=\n", buf.String()) +} + +func TestInvocationLogRedactsMultipleLines(t *testing.T) { + invLog, buf := newTestInvocationLog() + + chunk := "line1 --remote_exec_header=first-secret\nline2 --remote_exec_header=second-secret\n" + n, err := invLog.Write([]byte(chunk)) + require.NoError(t, err) + require.Equal(t, len(chunk), n) + require.Equal(t, + "line1 --remote_exec_header=\nline2 --remote_exec_header=\n", + buf.String()) +} + +func TestInvocationLogFlushesPartialLine(t *testing.T) { + invLog, buf := newTestInvocationLog() + + chunk := "common --remote_exec_header=secret-token" + n, err := invLog.Write([]byte(chunk)) + require.NoError(t, err) + require.Equal(t, len(chunk), n) + require.Equal(t, "", buf.String(), "should not flush before newline") + + require.NoError(t, invLog.Flush()) + require.Equal(t, "common --remote_exec_header=", buf.String()) +} + +func TestInvocationLogHandlesCRLF(t *testing.T) { + invLog, buf := newTestInvocationLog() + + chunk := "line1 --remote_exec_header=secret\r\nline2\r\n" + n, err := invLog.Write([]byte(chunk)) + require.NoError(t, err) + require.Equal(t, len(chunk), n) + require.Equal(t, "line1 --remote_exec_header=\r\nline2\r\n", buf.String()) +} + +// failingWriter simulates a writer that fails after a certain number of writes +type failingWriter struct { + failAfter int + callCount int +} + +func (fw *failingWriter) Write(p []byte) (int, error) { + fw.callCount++ + if fw.callCount > fw.failAfter { + return 0, errors.New("write error") + } + return len(p), nil +} + +func TestInvocationLogFailsFastAfterWriteError(t *testing.T) { + invLog := &invocationLog{} + invLog.writeListener = func(string) {} + failWriter := &failingWriter{failAfter: 1} + invLog.writer = io.MultiWriter(&invLog.LockingBuffer, failWriter) + + // First write should succeed + n, err := invLog.Write([]byte("line1\n")) + require.NoError(t, err) + require.Equal(t, 6, n) + + // Second write should fail + n, err = invLog.Write([]byte("line2\n")) + require.Error(t, err) + require.Equal(t, 6, n) + require.Equal(t, "write error", err.Error()) + + // Third write should fail fast without attempting to write + n, err = invLog.Write([]byte("line3\n")) + require.Error(t, err) + require.Equal(t, 0, n) + require.Equal(t, "write error", err.Error()) + require.Equal(t, 2, failWriter.callCount, "should not attempt write after error") +} + +func TestInvocationLogFlushFailsFastAfterWriteError(t *testing.T) { + invLog := &invocationLog{} + invLog.writeListener = func(string) {} + failWriter := &failingWriter{failAfter: 1} + invLog.writer = io.MultiWriter(&invLog.LockingBuffer, failWriter) + + // First write should succeed + _, err := invLog.Write([]byte("line1\n")) + require.NoError(t, err) + + // Second write should fail + _, err = invLog.Write([]byte("line2\n")) + require.Error(t, err) + + // Flush should fail fast without attempting to write + err = invLog.Flush() + require.Error(t, err) + require.Equal(t, "write error", err.Error()) + require.Equal(t, 2, failWriter.callCount, "should not attempt write after error") +}