diff --git a/cmd/hauler/cli/store/add.go b/cmd/hauler/cli/store/add.go index ae532d44..d6bf1e85 100644 --- a/cmd/hauler/cli/store/add.go +++ b/cmd/hauler/cli/store/add.go @@ -36,14 +36,19 @@ func AddFileCmd(ctx context.Context, o *flags.AddFileOpts, s *store.Layout, refe if len(o.Name) > 0 { cfg.Name = o.Name } - return storeFile(ctx, s, cfg) + var allowInternal bool + if o.StoreRootOpts != nil { + allowInternal = o.AllowInternalTargets + } + return storeFile(ctx, s, cfg, allowInternal) } -func storeFile(ctx context.Context, s *store.Layout, fi v1.File) error { +func storeFile(ctx context.Context, s *store.Layout, fi v1.File, allowInternalTargets bool) error { l := log.FromContext(ctx) copts := getter.ClientOptions{ - NameOverride: fi.Name, + NameOverride: fi.Name, + AllowInternalTargets: allowInternalTargets, } f := file.NewFile(fi.Path, file.WithClient(getter.NewClient(copts))) diff --git a/cmd/hauler/cli/store/add_test.go b/cmd/hauler/cli/store/add_test.go index baad66cf..c1b6eac5 100644 --- a/cmd/hauler/cli/store/add_test.go +++ b/cmd/hauler/cli/store/add_test.go @@ -371,7 +371,7 @@ func TestStoreFile(t *testing.T) { tmp.Close() s := newTestStore(t) - if err := storeFile(ctx, s, v1.File{Path: tmp.Name()}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: tmp.Name()}, false); err != nil { t.Fatalf("storeFile: %v", err) } assertArtifactInStore(t, s, filepath.Base(tmp.Name())) @@ -380,7 +380,7 @@ func TestStoreFile(t *testing.T) { t.Run("HTTP URL stored under basename", func(t *testing.T) { url := seedFileInHTTPServer(t, "script.sh", "#!/bin/sh\necho ok") s := newTestStore(t) - if err := storeFile(ctx, s, v1.File{Path: url}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil { t.Fatalf("storeFile: %v", err) } assertArtifactInStore(t, s, "script.sh") @@ -394,7 +394,7 @@ func TestStoreFile(t *testing.T) { tmp.Close() s := newTestStore(t) - if err := storeFile(ctx, s, v1.File{Path: tmp.Name(), Name: "custom.sh"}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: tmp.Name(), Name: "custom.sh"}, false); err != nil { t.Fatalf("storeFile: %v", err) } assertArtifactInStore(t, s, "custom.sh") @@ -402,7 +402,7 @@ func TestStoreFile(t *testing.T) { t.Run("nonexistent local path returns error", func(t *testing.T) { s := newTestStore(t) - err := storeFile(ctx, s, v1.File{Path: "/nonexistent/path/missing-file.txt"}) + err := storeFile(ctx, s, v1.File{Path: "/nonexistent/path/missing-file.txt"}, false) if err == nil { t.Fatal("expected error for nonexistent path, got nil") } diff --git a/cmd/hauler/cli/store/copy_test.go b/cmd/hauler/cli/store/copy_test.go index 28b41f53..39f41550 100644 --- a/cmd/hauler/cli/store/copy_test.go +++ b/cmd/hauler/cli/store/copy_test.go @@ -242,7 +242,7 @@ func TestCopyCmd_Dir_Files(t *testing.T) { url := seedFileInHTTPServer(t, "data.txt", content) s := newTestStore(t) - if err := storeFile(ctx, s, v1.File{Path: url}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil { t.Fatalf("storeFile: %v", err) } diff --git a/cmd/hauler/cli/store/extract_test.go b/cmd/hauler/cli/store/extract_test.go index 498a1900..0b76ea9d 100644 --- a/cmd/hauler/cli/store/extract_test.go +++ b/cmd/hauler/cli/store/extract_test.go @@ -31,7 +31,7 @@ func TestExtractCmd_File(t *testing.T) { fileContent := "hello extract test" url := seedFileInHTTPServer(t, "extract-me.txt", fileContent) - if err := storeFile(ctx, s, v1.File{Path: url}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil { 
t.Fatalf("storeFile: %v", err) } @@ -533,7 +533,7 @@ func TestExtractCmd_SubstringMatch(t *testing.T) { fileContent := "substring match content" url := seedFileInHTTPServer(t, "extract-sub.txt", fileContent) - if err := storeFile(ctx, s, v1.File{Path: url}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil { t.Fatalf("storeFile: %v", err) } @@ -581,7 +581,7 @@ func TestExtractCmd_CosignArtifactsProduceNoContainerImageWarning(t *testing.T) // Seed a real file artifact so ExtractCmd finds something to extract. fileContent := "cosign-filter test file content" url := seedFileInHTTPServer(t, "sigtest.txt", fileContent) - if err := storeFile(ctx, s, v1.File{Path: url}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil { t.Fatalf("storeFile: %v", err) } diff --git a/cmd/hauler/cli/store/info_test.go b/cmd/hauler/cli/store/info_test.go index 51af7e02..a7c89769 100644 --- a/cmd/hauler/cli/store/info_test.go +++ b/cmd/hauler/cli/store/info_test.go @@ -199,7 +199,7 @@ func TestInfoCmd(t *testing.T) { t.Fatalf("write tmpFile: %v", err) } fi := v1.File{Path: tmpFile} - if err := storeFile(ctx, s, fi); err != nil { + if err := storeFile(ctx, s, fi, false); err != nil { t.Fatalf("storeFile: %v", err) } diff --git a/cmd/hauler/cli/store/lifecycle_test.go b/cmd/hauler/cli/store/lifecycle_test.go index f1f7a9e3..5c413beb 100644 --- a/cmd/hauler/cli/store/lifecycle_test.go +++ b/cmd/hauler/cli/store/lifecycle_test.go @@ -31,7 +31,7 @@ func TestLifecycle_FileArtifact_AddSaveLoadCopy(t *testing.T) { // Step 2: storeFile into store A. storeA := newTestStore(t) - if err := storeFile(ctx, storeA, v1.File{Path: url}); err != nil { + if err := storeFile(ctx, storeA, v1.File{Path: url}, true); err != nil { t.Fatalf("storeFile: %v", err) } assertArtifactInStore(t, storeA, "lifecycle.txt") @@ -252,10 +252,10 @@ func TestLifecycle_Remove_ThenSave(t *testing.T) { url2 := seedFileInHTTPServer(t, "remove-me.txt", "content to remove") storeA := newTestStore(t) - if err := storeFile(ctx, storeA, v1.File{Path: url1}); err != nil { + if err := storeFile(ctx, storeA, v1.File{Path: url1}, true); err != nil { t.Fatalf("storeFile keep-me: %v", err) } - if err := storeFile(ctx, storeA, v1.File{Path: url2}); err != nil { + if err := storeFile(ctx, storeA, v1.File{Path: url2}, true); err != nil { t.Fatalf("storeFile remove-me: %v", err) } diff --git a/cmd/hauler/cli/store/load.go b/cmd/hauler/cli/store/load.go index 62d19170..bf10492e 100644 --- a/cmd/hauler/cli/store/load.go +++ b/cmd/hauler/cli/store/load.go @@ -3,6 +3,7 @@ package store import ( "context" "encoding/json" + "fmt" "io" "net/url" "os" @@ -41,7 +42,7 @@ func LoadCmd(ctx context.Context, o *flags.LoadOpts, rso *flags.StoreRootOpts, r for _, fileName := range o.FileName { resolved := resolveHaulPath(fileName) l.Infof("loading haul [%s] to [%s]", resolved, o.StoreDir) - err := unarchiveLayoutTo(ctx, resolved, o.StoreDir, tempDir) + err := unarchiveLayoutTo(ctx, resolved, o.StoreDir, tempDir, rso.AllowInternalTargets) if err != nil { return err } @@ -52,13 +53,13 @@ func LoadCmd(ctx context.Context, o *flags.LoadOpts, rso *flags.StoreRootOpts, r } // accepts an archived OCI layout, extracts the contents to an existing OCI layout, and preserves the index -func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempDir string) error { +func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempDir string, allowInternalTargets bool) error { l := log.FromContext(ctx) if 
strings.HasPrefix(haulPath, "http://") || strings.HasPrefix(haulPath, "https://") { l.Debugf("detected remote archive... starting download... [%s]", haulPath) - h := getter.NewHttp() + h := getter.NewHttpWithOptions(getter.HttpOptions{AllowInternalTargets: allowInternalTargets}) parsedURL, err := url.Parse(haulPath) if err != nil { return err @@ -81,9 +82,13 @@ func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempDi } defer out.Close() - if _, err = io.Copy(out, rc); err != nil { + n, err := io.Copy(out, io.LimitReader(rc, consts.MaxDownloadBytes+1)) + if err != nil { return err } + if n > consts.MaxDownloadBytes { + return fmt.Errorf("remote archive at %s exceeds maximum allowed size (%d bytes)", haulPath, consts.MaxDownloadBytes) + } } // reassemble chunk files if haulPath matches the chunk naming pattern @@ -98,10 +103,18 @@ func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempDi } // ensure the incoming index.json has the correct annotations. - data, err := os.ReadFile(tempDir + "/index.json") + indexFile, err := os.Open(tempDir + "/index.json") + if err != nil { + return (err) + } + data, err := io.ReadAll(io.LimitReader(indexFile, consts.MaxManifestBytes+1)) + indexFile.Close() if err != nil { return (err) } + if int64(len(data)) > consts.MaxManifestBytes { + return fmt.Errorf("index.json exceeds maximum allowed size (%d bytes)", consts.MaxManifestBytes) + } var idx ocispec.Index if err := json.Unmarshal(data, &idx); err != nil { diff --git a/cmd/hauler/cli/store/load_test.go b/cmd/hauler/cli/store/load_test.go index 15d28047..37a249a4 100644 --- a/cmd/hauler/cli/store/load_test.go +++ b/cmd/hauler/cli/store/load_test.go @@ -71,7 +71,7 @@ func TestUnarchiveLayoutTo(t *testing.T) { destDir := t.TempDir() tempDir := t.TempDir() - if err := unarchiveLayoutTo(ctx, testHaulArchive, destDir, tempDir); err != nil { + if err := unarchiveLayoutTo(ctx, testHaulArchive, destDir, tempDir, false); err != nil { t.Fatalf("unarchiveLayoutTo: %v", err) } @@ -192,12 +192,16 @@ func TestLoadCmd_RemoteArchive(t *testing.T) { destDir := t.TempDir() remoteURL := srv.URL + "/haul.tar.zst" + // AllowInternalTargets=true because the test server binds to loopback. + rso := defaultRootOpts(destDir) + rso.AllowInternalTargets = true + o := &flags.LoadOpts{ - StoreRootOpts: defaultRootOpts(destDir), + StoreRootOpts: rso, FileName: []string{remoteURL}, } - if err := LoadCmd(ctx, o, defaultRootOpts(destDir), defaultCliOpts()); err != nil { + if err := LoadCmd(ctx, o, rso, defaultCliOpts()); err != nil { t.Fatalf("LoadCmd remote: %v", err) } @@ -262,7 +266,7 @@ func TestUnarchiveLayoutTo_AnnotationBackfill(t *testing.T) { // Step 4: Load the stripped archive. destDir := t.TempDir() tempDir := t.TempDir() - if err := unarchiveLayoutTo(ctx, strippedArchive, destDir, tempDir); err != nil { + if err := unarchiveLayoutTo(ctx, strippedArchive, destDir, tempDir, false); err != nil { t.Fatalf("unarchiveLayoutTo stripped: %v", err) } @@ -354,7 +358,7 @@ func TestUnarchiveLayoutTo_LegacyKindMigration(t *testing.T) { // Step 4: Load the legacy archive. 
destDir := t.TempDir() tempDir := t.TempDir() - if err := unarchiveLayoutTo(ctx, legacyArchive, destDir, tempDir); err != nil { + if err := unarchiveLayoutTo(ctx, legacyArchive, destDir, tempDir, false); err != nil { t.Fatalf("unarchiveLayoutTo legacy: %v", err) } diff --git a/cmd/hauler/cli/store/remove_test.go b/cmd/hauler/cli/store/remove_test.go index a013e3df..c81e8894 100644 --- a/cmd/hauler/cli/store/remove_test.go +++ b/cmd/hauler/cli/store/remove_test.go @@ -81,7 +81,7 @@ func TestRemoveCmd_Force(t *testing.T) { s := newTestStore(t) url := seedFileInHTTPServer(t, "removeme.txt", "file-to-remove") - if err := storeFile(ctx, s, v1.File{Path: url}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil { t.Fatalf("storeFile: %v", err) } @@ -133,10 +133,10 @@ func TestRemoveCmd_Force_MultipleMatches(t *testing.T) { url1 := seedFileInHTTPServer(t, "testfile-alpha.txt", "content-alpha") url2 := seedFileInHTTPServer(t, "testfile-beta.txt", "content-beta") - if err := storeFile(ctx, s, v1.File{Path: url1}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: url1}, true); err != nil { t.Fatalf("storeFile alpha: %v", err) } - if err := storeFile(ctx, s, v1.File{Path: url2}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: url2}, true); err != nil { t.Fatalf("storeFile beta: %v", err) } diff --git a/cmd/hauler/cli/store/save_test.go b/cmd/hauler/cli/store/save_test.go index 5902519e..9b4035db 100644 --- a/cmd/hauler/cli/store/save_test.go +++ b/cmd/hauler/cli/store/save_test.go @@ -118,7 +118,7 @@ func TestWriteExportsManifest_SkipsNonImages(t *testing.T) { url := seedFileInHTTPServer(t, "skip.sh", "#!/bin/sh\necho skip") s := newTestStore(t) - if err := storeFile(ctx, s, v1.File{Path: url}); err != nil { + if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil { t.Fatalf("storeFile: %v", err) } diff --git a/cmd/hauler/cli/store/sync.go b/cmd/hauler/cli/store/sync.go index 48283196..4a0ceb1f 100644 --- a/cmd/hauler/cli/store/sync.go +++ b/cmd/hauler/cli/store/sync.go @@ -82,11 +82,14 @@ func SyncCmd(ctx context.Context, o *flags.SyncOpts, s *store.Layout, rso *flags if err != nil { return err } - content, err := io.ReadAll(rc) + content, err := io.ReadAll(io.LimitReader(rc, consts.MaxManifestBytes+1)) rc.Close() if err != nil { return err } + if int64(len(content)) > consts.MaxManifestBytes { + return fmt.Errorf("product manifest for [%s] exceeds maximum allowed size (%d bytes)", productName, consts.MaxManifestBytes) + } // Ensure each manifest starts with a YAML document separator. if !strings.HasPrefix(string(content), "---") { @@ -161,7 +164,7 @@ func SyncCmd(ctx context.Context, o *flags.SyncOpts, s *store.Layout, rso *flags if strings.HasPrefix(haulPath, "http://") || strings.HasPrefix(haulPath, "https://") { l.Debugf("detected remote manifest... starting download... 
[%s]", haulPath) - h := getter.NewHttp() + h := getter.NewHttpWithOptions(getter.HttpOptions{AllowInternalTargets: rso.AllowInternalTargets}) parsedURL, err := url.Parse(haulPath) if err != nil { return err @@ -184,9 +187,13 @@ func SyncCmd(ctx context.Context, o *flags.SyncOpts, s *store.Layout, rso *flags } defer out.Close() - if _, err = io.Copy(out, rc); err != nil { + n, err := io.Copy(out, io.LimitReader(rc, consts.MaxDownloadBytes+1)) + if err != nil { return err } + if n > consts.MaxDownloadBytes { + return fmt.Errorf("remote manifest at %s exceeds maximum allowed size (%d bytes)", haulPath, consts.MaxDownloadBytes) + } } fi, err := os.Open(haulPath) @@ -213,7 +220,7 @@ func SyncCmd(ctx context.Context, o *flags.SyncOpts, s *store.Layout, rso *flags if strings.HasPrefix(haulPath, "http://") || strings.HasPrefix(haulPath, "https://") { l.Debugf("detected remote image.txt... starting download... [%s]", haulPath) - h := getter.NewHttp() + h := getter.NewHttpWithOptions(getter.HttpOptions{AllowInternalTargets: rso.AllowInternalTargets}) parsedURL, err := url.Parse(haulPath) if err != nil { return err @@ -236,9 +243,13 @@ func SyncCmd(ctx context.Context, o *flags.SyncOpts, s *store.Layout, rso *flags } defer out.Close() - if _, err = io.Copy(out, rc); err != nil { + n, err := io.Copy(out, io.LimitReader(rc, consts.MaxDownloadBytes+1)) + if err != nil { return err } + if n > consts.MaxDownloadBytes { + return fmt.Errorf("remote image.txt at %s exceeds maximum allowed size (%d bytes)", haulPath, consts.MaxDownloadBytes) + } } fi, err := os.Open(haulPath) @@ -296,7 +307,7 @@ func processContent(ctx context.Context, fi *os.File, o *flags.SyncOpts, s *stor return err } for _, f := range cfg.Spec.Files { - if err := storeFile(ctx, s, f); err != nil { + if err := storeFile(ctx, s, f, rso.AllowInternalTargets); err != nil { return err } } diff --git a/cmd/hauler/cli/store/sync_test.go b/cmd/hauler/cli/store/sync_test.go index 2a0e5985..8e512686 100644 --- a/cmd/hauler/cli/store/sync_test.go +++ b/cmd/hauler/cli/store/sync_test.go @@ -70,6 +70,7 @@ spec: fi := writeManifestFile(t, manifest) o := newSyncOpts(s.Root) + o.StoreRootOpts.AllowInternalTargets = true // test server is on loopback ro := defaultCliOpts() if err := processContent(ctx, fi, o, s, o.StoreRootOpts, ro); err != nil { @@ -216,6 +217,7 @@ spec: fi := writeManifestFile(t, manifest) o := newSyncOpts(s.Root) + o.StoreRootOpts.AllowInternalTargets = true // test servers are on loopback ro := defaultCliOpts() if err := processContent(ctx, fi, o, s, o.StoreRootOpts, ro); err != nil { @@ -260,6 +262,7 @@ spec: o := newSyncOpts(s.Root) o.FileName = []string{manifestPath} rso := defaultRootOpts(s.Root) + rso.AllowInternalTargets = true // test server is on loopback ro := defaultCliOpts() if err := SyncCmd(ctx, o, s, rso, ro); err != nil { @@ -411,6 +414,7 @@ func TestSyncCmd_ImageTxt_RemoteFile(t *testing.T) { o := newSyncOpts(s.Root) o.ImageTxt = []string{imageSrv.URL + "/images.txt"} rso := defaultRootOpts(s.Root) + rso.AllowInternalTargets = true // test server is on loopback ro := defaultCliOpts() if err := SyncCmd(ctx, o, s, rso, ro); err != nil { @@ -444,6 +448,7 @@ spec: o := newSyncOpts(s.Root) o.FileName = []string{manifestSrv.URL + "/manifest.yaml"} rso := defaultRootOpts(s.Root) + rso.AllowInternalTargets = true // both manifest server and file server are on loopback ro := defaultCliOpts() if err := SyncCmd(ctx, o, s, rso, ro); err != nil { diff --git a/install.sh b/install.sh index f762feda..f46fb96a 100755 --- a/install.sh 
+++ b/install.sh @@ -150,8 +150,8 @@ if [ ! -d "${HAULER_DIR}" ]; then mkdir -p "${HAULER_DIR}" || fatal "Failed to Create Hauler Directory: ${HAULER_DIR}" fi -# ensure hauler directory is writable (by user or root privileges) -chmod -R 777 "${HAULER_DIR}" || fatal "Failed to Update Permissions of Hauler Directory: ${HAULER_DIR}" +# ensure hauler directory is only accessible by the owner +chmod -R 0700 "${HAULER_DIR}" || fatal "Failed to Update Permissions of Hauler Directory: ${HAULER_DIR}" # change to hauler directory cd "${HAULER_DIR}" || fatal "Failed to Change Directory to Hauler Directory: ${HAULER_DIR}" diff --git a/internal/flags/store.go b/internal/flags/store.go index ddcf3e01..2491687d 100644 --- a/internal/flags/store.go +++ b/internal/flags/store.go @@ -13,9 +13,10 @@ import ( ) type StoreRootOpts struct { - StoreDir string - Retries int - TempOverride string + StoreDir string + Retries int + TempOverride string + AllowInternalTargets bool } func (o *StoreRootOpts) AddFlags(cmd *cobra.Command) { @@ -23,6 +24,7 @@ func (o *StoreRootOpts) AddFlags(cmd *cobra.Command) { pf.StringVarP(&o.StoreDir, "store", "s", "", "Set the directory to use for the content store") pf.IntVarP(&o.Retries, "retries", "r", consts.DefaultRetries, "Set the number of retries for operations") pf.StringVarP(&o.TempOverride, "tempdir", "t", "", "(Optional) Override the default temporary directory determined by the OS") + pf.BoolVar(&o.AllowInternalTargets, "allow-internal-targets", false, "(Optional) Allow fetching from private/loopback IP addresses (disables SSRF protection; for isolated internal CI use only)") } func (o *StoreRootOpts) Store(ctx context.Context) (*store.Layout, error) { diff --git a/pkg/archives/archiver.go b/pkg/archives/archiver.go index cb9320b9..5c418a84 100644 --- a/pkg/archives/archiver.go +++ b/pkg/archives/archiver.go @@ -12,6 +12,13 @@ import ( "hauler.dev/go/hauler/pkg/log" ) +// compressionZstd and archivalTar are package-level vars so tests can reference +// them without importing mholt/archives directly. +var ( + compressionZstd = archives.Zstd{} + archivalTar = archives.Tar{} +) + // maps to handle compression types var CompressionMap = map[string]archives.Compression{ "gz": archives.Gz{}, diff --git a/pkg/archives/limits_test.go b/pkg/archives/limits_test.go new file mode 100644 index 00000000..9c4a4d7c --- /dev/null +++ b/pkg/archives/limits_test.go @@ -0,0 +1,145 @@ +package archives + +import ( + "context" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/rs/zerolog" +) + +func limitsTestContext(t *testing.T) context.Context { + t.Helper() + l := zerolog.New(io.Discard) + return l.WithContext(context.Background()) +} + +// buildSmallTarZst creates a real .tar.zst in a temp dir using the production +// Archive() function, so its format is always compatible with Unarchive(). 
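+//
+// A typical call (values here are illustrative) looks like:
+//
+//	archive := buildSmallTarZst(t, map[string]string{"a.txt": "hello"})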
+func buildSmallTarZst(t *testing.T, entries map[string]string) string { + t.Helper() + srcDir := t.TempDir() + for name, body := range entries { + full := filepath.Join(srcDir, name) + if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(full, []byte(body), 0o644); err != nil { + t.Fatal(err) + } + } + out := filepath.Join(t.TempDir(), "test.tar.zst") + ctx := limitsTestContext(t) + if err := Archive(ctx, srcDir, out, compressionZstd, archivalTar); err != nil { + t.Fatalf("Archive(): %v", err) + } + return out +} + +// TestUnarchive_PerFileByteCap verifies that a single file exceeding +// maxArchiveFileBytes is rejected during extraction. +func TestUnarchive_PerFileByteCap(t *testing.T) { + body := make([]byte, 512) + archive := buildSmallTarZst(t, map[string]string{"big.bin": string(body)}) + + dst := t.TempDir() + ctx := limitsTestContext(t) + limits := extractionLimits{ + maxFileBytes: 256, // smaller than the 512-byte file + maxTotalBytes: 100 << 30, + maxFiles: 100_000, + } + if err := unarchiveWithLimits(ctx, archive, dst, limits); err == nil { + t.Fatal("unarchiveWithLimits() expected error for per-file cap, got nil") + } +} + +// TestUnarchive_AggregateByteCap verifies that the total extracted bytes across +// all files cannot exceed maxTotalBytes. +func TestUnarchive_AggregateByteCap(t *testing.T) { + // Two files each 256 bytes — aggregate cap set to 400, which is exceeded. + body := make([]byte, 256) + archive := buildSmallTarZst(t, map[string]string{ + "a.bin": string(body), + "b.bin": string(body), + }) + + dst := t.TempDir() + ctx := limitsTestContext(t) + limits := extractionLimits{ + maxFileBytes: 100 << 30, // per-file: no constraint + maxTotalBytes: 400, // aggregate: 256+256 = 512 > 400 + maxFiles: 100_000, + } + if err := unarchiveWithLimits(ctx, archive, dst, limits); err == nil { + t.Fatal("unarchiveWithLimits() expected error for aggregate cap, got nil") + } +} + +// TestUnarchive_FileCountCap verifies that an archive with more entries than +// maxFiles is rejected. +func TestUnarchive_FileCountCap(t *testing.T) { + entries := make(map[string]string, 6) + for i := 0; i < 6; i++ { + entries[filepath.Join("subdir", string(rune('a'+i))+".txt")] = "x" + } + archive := buildSmallTarZst(t, entries) + + dst := t.TempDir() + ctx := limitsTestContext(t) + limits := extractionLimits{ + maxFileBytes: 100 << 30, + maxTotalBytes: 100 << 30, + maxFiles: 3, // only 3 allowed; archive has 6 + } + if err := unarchiveWithLimits(ctx, archive, dst, limits); err == nil { + t.Fatal("unarchiveWithLimits() expected error for file-count cap, got nil") + } +} + +// TestUnarchive_WithinLimits confirms the default path succeeds for normal archives. +func TestUnarchive_WithinLimits(t *testing.T) { + archive := buildSmallTarZst(t, map[string]string{ + "file1.txt": "hello", + "file2.txt": "world", + }) + dst := t.TempDir() + ctx := limitsTestContext(t) + if err := Unarchive(ctx, archive, dst); err != nil { + t.Fatalf("Unarchive() unexpected error: %v", err) + } +} + +// TestUnarchive_DecompressionRatioCap verifies that an archive whose +// decompressed:compressed ratio exceeds consts.MaxDecompressionRatio is +// rejected during extraction. Highly redundant content (10 MiB of zero bytes) +// compresses with zstd to a few KiB, well above the 100x default ratio. 
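+// As a rough worked check (exact zstd output varies): 10 MiB is 10,485,760
+// bytes, so if zstd emits a ~5 KiB archive the observed ratio is about
+// 10,485,760 / 5,120 ≈ 2,048x, far beyond the 100x cap.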
+func TestUnarchive_DecompressionRatioCap(t *testing.T) { + body := make([]byte, 10<<20) // 10 MiB of zero bytes + archive := buildSmallTarZst(t, map[string]string{"compressible.bin": string(body)}) + + stat, err := os.Stat(archive) + if err != nil { + t.Fatalf("stat archive: %v", err) + } + t.Logf("archive size: %d bytes (decompressed: %d bytes, ratio: %.0fx)", + stat.Size(), len(body), float64(len(body))/float64(stat.Size())) + + dst := t.TempDir() + ctx := limitsTestContext(t) + limits := extractionLimits{ + maxFileBytes: 100 << 30, // not the constraint + maxTotalBytes: 100 << 30, // not the constraint + maxFiles: 100_000, + } + err = unarchiveWithLimits(ctx, archive, dst, limits) + if err == nil { + t.Fatal("unarchiveWithLimits() expected error for decompression ratio, got nil") + } + if !strings.Contains(err.Error(), "decompression ratio") { + t.Errorf("expected error mentioning decompression ratio, got: %v", err) + } +} diff --git a/pkg/archives/unarchiver.go b/pkg/archives/unarchiver.go index 7678f265..590ddb3e 100644 --- a/pkg/archives/unarchiver.go +++ b/pkg/archives/unarchiver.go @@ -10,11 +10,21 @@ import ( "sort" "strconv" "strings" + "sync/atomic" "github.com/mholt/archives" + "hauler.dev/go/hauler/pkg/consts" "hauler.dev/go/hauler/pkg/log" ) +// extractionLimits holds per-extraction caps to prevent decompression bombs +// and file-count exhaustion attacks. +type extractionLimits struct { + maxFileBytes int64 // maximum bytes for a single extracted file + maxTotalBytes int64 // maximum aggregate bytes across all extracted files + maxFiles int64 // maximum number of files extracted +} + const ( dirPermissions = 0o700 // default directory permissions filePermissions = 0o600 // default file permissions @@ -51,8 +61,17 @@ func setPermissions(path string, mode os.FileMode) error { return nil } +// extractionState tracks mutable counters shared across handleFile calls. +// archiveSize is the on-disk size of the input archive; when non-zero it is +// used to enforce a decompression-ratio bomb check against totalBytes. +type extractionState struct { + totalBytes atomic.Int64 + fileCount atomic.Int64 + archiveSize int64 +} + // handles the extraction of a file from the archive. 
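+// Caps are enforced incrementally in copyBounded as bytes stream, so a
+// pathological entry is rejected early rather than after a full write.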
-func handleFile(ctx context.Context, f archives.FileInfo, dst string) error { +func handleFileWithLimits(ctx context.Context, f archives.FileInfo, dst string, lim extractionLimits, state *extractionState) error { l := log.FromContext(ctx) l.Debugf("handling file [%s]", f.NameInArchive) @@ -70,7 +89,6 @@ func handleFile(ctx context.Context, f archives.FileInfo, dst string) error { // handle directories if f.IsDir() { - // create the directory with permissions from the archive if dirErr := createDirWithPermissions(ctx, dstPath, f.Mode()); dirErr != nil { return fmt.Errorf("failed to create directory: %w", dirErr) } @@ -84,6 +102,14 @@ func handleFile(ctx context.Context, f archives.FileInfo, dst string) error { return nil } + // enforce file count cap + if state != nil { + count := state.fileCount.Add(1) + if count > lim.maxFiles { + return fmt.Errorf("archive contains more than %d files (extraction limit exceeded)", lim.maxFiles) + } + } + // check and handle parent directory permissions originalMode, statErr := os.Stat(parentDir) if statErr != nil { @@ -97,7 +123,6 @@ func handleFile(ctx context.Context, f archives.FileInfo, dst string) error { return fmt.Errorf("failed to chmod parent directory: %w", chmodErr) } defer func() { - // restore the original permissions after writing if chmodErr := os.Chmod(parentDir, originalMode.Mode()); chmodErr != nil { l.Debugf("failed to restore original permissions for [%s]: %v", parentDir, chmodErr) } @@ -117,13 +142,72 @@ func handleFile(ctx context.Context, f archives.FileInfo, dst string) error { } defer dstFile.Close() - if _, copyErr := io.Copy(dstFile, reader); copyErr != nil { + // copy with per-file and aggregate byte caps + var totalPtr *atomic.Int64 + var archiveSize int64 + if state != nil { + totalPtr = &state.totalBytes + archiveSize = state.archiveSize + } + written, copyErr := copyBounded(dstFile, reader, lim.maxFileBytes, lim.maxTotalBytes, totalPtr, archiveSize) + if copyErr != nil { return fmt.Errorf("failed to copy: %w", copyErr) } + _ = written l.Debugf("successfully extracted file [%s]", dstPath) return nil } +// copyBounded copies from src to dst, enforcing a per-file cap (maxFile) and +// updating a shared total counter checked against maxTotal. total may be nil +// when called from the default (non-tracked) code path. When archiveSize > 0, +// it also enforces a decompression-ratio cap against consts.MaxDecompressionRatio. +func copyBounded(dst io.Writer, src io.Reader, maxFile, maxTotal int64, total *atomic.Int64, archiveSize int64) (int64, error) { + buf := make([]byte, 32*1024) + var fileWritten int64 + for { + nr, readErr := src.Read(buf) + if nr > 0 { + fileWritten += int64(nr) + if fileWritten > maxFile { + return fileWritten, fmt.Errorf("extracted file exceeds per-file size limit of %d bytes", maxFile) + } + if total != nil { + agg := total.Add(int64(nr)) + if agg > maxTotal { + return fileWritten, fmt.Errorf("total extracted bytes exceed aggregate limit of %d bytes", maxTotal) + } + if archiveSize > 0 && float64(agg) > float64(archiveSize)*consts.MaxDecompressionRatio { + return fileWritten, fmt.Errorf("decompression ratio exceeds %vx (likely zip bomb)", consts.MaxDecompressionRatio) + } + } + if _, writeErr := dst.Write(buf[:nr]); writeErr != nil { + return fileWritten, writeErr + } + } + if readErr == io.EOF { + break + } + if readErr != nil { + return fileWritten, readErr + } + } + return fileWritten, nil +} + +// handleFile delegates to handleFileWithLimits using the default (production) limits. 
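+// (The defaults are consts.MaxArchiveFileBytes, consts.MaxArchiveBytes, and
+// consts.MaxArchiveFiles; see defaultLimits below.)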
+// It exists so the Unarchive function signature remains unchanged. +func handleFile(ctx context.Context, f archives.FileInfo, dst string, state *extractionState) error { + // The production limits are extremely generous and should never fire on + // legitimate hauler workloads; they exist to bound pathological inputs. + defaultLimits := extractionLimits{ + maxFileBytes: consts.MaxArchiveFileBytes, + maxTotalBytes: consts.MaxArchiveBytes, + maxFiles: consts.MaxArchiveFiles, + } + return handleFileWithLimits(ctx, f, dst, defaultLimits, state) +} + // unarchives a tarball to a directory, symlinks, and hardlinks are ignored func Unarchive(ctx context.Context, tarball, dst string) error { l := log.FromContext(ctx) @@ -134,6 +218,11 @@ func Unarchive(ctx context.Context, tarball, dst string) error { } defer archiveFile.Close() + stat, statErr := archiveFile.Stat() + if statErr != nil { + return fmt.Errorf("failed to stat tarball %s: %w", tarball, statErr) + } + format, input, identifyErr := archives.Identify(context.Background(), tarball, archiveFile) if identifyErr != nil { return fmt.Errorf("failed to identify format: %w", identifyErr) @@ -148,8 +237,9 @@ func Unarchive(ctx context.Context, tarball, dst string) error { return fmt.Errorf("failed to create destination directory: %w", dirErr) } + state := &extractionState{archiveSize: stat.Size()} handler := func(ctx context.Context, f archives.FileInfo) error { - return handleFile(ctx, f, dst) + return handleFile(ctx, f, dst, state) } if extractErr := extractor.Extract(context.Background(), input, handler); extractErr != nil { @@ -160,6 +250,48 @@ func Unarchive(ctx context.Context, tarball, dst string) error { return nil } +// unarchiveWithLimits is like Unarchive but enforces explicit extraction limits. +// It is used directly in tests to verify cap enforcement with small values. +func unarchiveWithLimits(ctx context.Context, tarball, dst string, lim extractionLimits) error { + l := log.FromContext(ctx) + l.Debugf("unarchiving [%s] to [%s] with limits", tarball, dst) + + archiveFile, openErr := os.Open(tarball) + if openErr != nil { + return fmt.Errorf("failed to open tarball %s: %w", tarball, openErr) + } + defer archiveFile.Close() + + stat, statErr := archiveFile.Stat() + if statErr != nil { + return fmt.Errorf("failed to stat tarball %s: %w", tarball, statErr) + } + + format, input, identifyErr := archives.Identify(context.Background(), tarball, archiveFile) + if identifyErr != nil { + return fmt.Errorf("failed to identify format: %w", identifyErr) + } + + extractor, ok := format.(archives.Extractor) + if !ok { + return fmt.Errorf("unsupported format for extraction") + } + + if dirErr := createDirWithPermissions(ctx, dst, dirPermissions); dirErr != nil { + return fmt.Errorf("failed to create destination directory: %w", dirErr) + } + + state := &extractionState{archiveSize: stat.Size()} + handler := func(ctx context.Context, f archives.FileInfo) error { + return handleFileWithLimits(ctx, f, dst, lim, state) + } + + if extractErr := extractor.Extract(context.Background(), input, handler); extractErr != nil { + return fmt.Errorf("failed to extract: %w", extractErr) + } + return nil +} + var chunkSuffixRe = regexp.MustCompile(`^(.+)_(\d+)$`) // chunkInfo checks whether archivePath matches the chunk naming pattern (_N). 
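+// For example, "haul.tar.zst_0" and "haul.tar.zst_1" (illustrative names)
+// match the pattern and reassemble into "haul.tar.zst".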
@@ -226,17 +358,27 @@ func JoinChunks(ctx context.Context, archivePath, tempDir string) (string, error } defer outf.Close() + var joinedTotal int64 for _, chunk := range matches { l.Debugf("joining chunk [%s]", chunk) cf, err := os.Open(chunk) if err != nil { return "", fmt.Errorf("failed to open chunk [%s]: %w", chunk, err) } - if _, err := io.Copy(outf, cf); err != nil { + remaining := consts.MaxArchiveBytes - joinedTotal + if remaining <= 0 { cf.Close() - return "", fmt.Errorf("failed to copy chunk [%s]: %w", chunk, err) + return "", fmt.Errorf("joined chunks exceed maximum allowed size (%d bytes)", consts.MaxArchiveBytes) } + n, err := io.Copy(outf, io.LimitReader(cf, remaining+1)) cf.Close() + if err != nil { + return "", fmt.Errorf("failed to copy chunk [%s]: %w", chunk, err) + } + joinedTotal += n + if joinedTotal > consts.MaxArchiveBytes { + return "", fmt.Errorf("joined chunks exceed maximum allowed size (%d bytes)", consts.MaxArchiveBytes) + } } l.Infof("joined %d chunk(s) into [%s]", len(matches), filepath.Base(joinedPath)) diff --git a/pkg/artifacts/file/file_test.go b/pkg/artifacts/file/file_test.go index 062b80b1..3ff5e91a 100644 --- a/pkg/artifacts/file/file_test.go +++ b/pkg/artifacts/file/file_test.go @@ -127,7 +127,7 @@ func setup() func() { mf := &mockFile{File: getter.NewFile(), fs: tfs} - mockHttp := getter.NewHttp() + mockHttp := getter.NewHttpWithOptions(getter.HttpOptions{AllowInternalTargets: true}) mhttp := afero.NewHttpFs(tfs) fileserver := http.FileServer(mhttp.Dir(".")) http.Handle("/", fileserver) diff --git a/pkg/consts/limits.go b/pkg/consts/limits.go new file mode 100644 index 00000000..9e6fd403 --- /dev/null +++ b/pkg/consts/limits.go @@ -0,0 +1,34 @@ +package consts + +const ( + // MaxDownloadBytes caps HTTP response bodies fetched by the HTTP getter. + // 10 GiB is deliberately generous for large hauler archives while still + // bounding runaway downloads. + MaxDownloadBytes int64 = 10 << 30 // 10 GiB + + // MaxManifestBytes caps OCI manifest and index reads to prevent a hostile + // registry from exhausting process memory. + MaxManifestBytes int64 = 16 << 20 // 16 MiB + + // MaxArchiveBytes caps the total uncompressed bytes written during archive + // extraction. Set to 100 GiB to comfortably exceed real-world haul sizes + // while still bounding zip-bomb attacks. + MaxArchiveBytes int64 = 100 << 30 // 100 GiB + + // MaxArchiveFileBytes caps the uncompressed size of a single file inside an + // archive. + MaxArchiveFileBytes int64 = 50 << 30 // 50 GiB + + // MaxArchiveFiles caps the number of files that may be extracted from a + // single archive. + MaxArchiveFiles int64 = 100_000 + + // MaxDecompressionRatio is the maximum allowed ratio of decompressed to + // compressed bytes. Archives exceeding this ratio are likely zip bombs. + MaxDecompressionRatio float64 = 100.0 + + // HTTPClientTimeout is the default timeout for outbound HTTP requests in + // the HTTP getter. Set to 30 minutes to handle large archive downloads + // without hanging indefinitely. 
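+	// At the use site the value is converted explicitly, as in
+	// time.Duration(consts.HTTPClientTimeout) * time.Second.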
+ HTTPClientTimeout = 30 * 60 // seconds — resolved to time.Duration at use site +) diff --git a/pkg/content/chart/chart.go b/pkg/content/chart/chart.go index 5f86b125..d0f2acd0 100644 --- a/pkg/content/chart/chart.go +++ b/pkg/content/chart/chart.go @@ -49,6 +49,14 @@ func NewChart(name string, opts *action.ChartPathOptions) (*Chart, error) { client := action.NewInstall(actionConfig) client.ChartPathOptions.Version = opts.Version + client.ChartPathOptions.Verify = opts.Verify + client.ChartPathOptions.Username = opts.Username + client.ChartPathOptions.Password = opts.Password + client.ChartPathOptions.CertFile = opts.CertFile + client.ChartPathOptions.KeyFile = opts.KeyFile + client.ChartPathOptions.CaFile = opts.CaFile + client.ChartPathOptions.InsecureSkipTLSverify = opts.InsecureSkipTLSverify + client.ChartPathOptions.PlainHTTP = opts.PlainHTTP registryClient, err := newRegistryClient(client.CertFile, client.KeyFile, client.CaFile, client.InsecureSkipTLSverify, client.PlainHTTP) @@ -65,27 +73,17 @@ func NewChart(name string, opts *action.ChartPathOptions) (*Chart, error) { chartRef = opts.RepoURL + "/" + name } - // suppress helm downloader oci logs (stdout/stderr) - oldStdout := os.Stdout - oldStderr := os.Stderr - rOut, wOut, _ := os.Pipe() - rErr, wErr, _ := os.Pipe() - os.Stdout = wOut - os.Stderr = wErr - - chartPath, err := client.ChartPathOptions.LocateChart(chartRef, settings) - - wOut.Close() - wErr.Close() - os.Stdout = oldStdout - os.Stderr = oldStderr - _, _ = io.Copy(io.Discard, rOut) - _, _ = io.Copy(io.Discard, rErr) - rOut.Close() - rErr.Close() - - if err != nil { - return nil, err + // Suppress Helm downloader OCI logs via the goroutine-safe CaptureOutput. + // The mutex in CaptureOutput serializes all callers so concurrent chart + // fetches cannot race on os.Stdout/os.Stderr. + var chartPath string + captureErr := log.CaptureOutput(log.NewLogger(io.Discard), true, func() error { + var locErr error + chartPath, locErr = client.ChartPathOptions.LocateChart(chartRef, settings) + return locErr + }) + if captureErr != nil { + return nil, captureErr } return &Chart{ diff --git a/pkg/content/chart/chart_test.go b/pkg/content/chart/chart_test.go index 94133e92..0173e0d0 100644 --- a/pkg/content/chart/chart_test.go +++ b/pkg/content/chart/chart_test.go @@ -13,6 +13,40 @@ import ( "hauler.dev/go/hauler/pkg/content/chart" ) +// TestNewChart_VerifyAndAuthPropagated verifies that --verify and auth/TLS options +// in action.ChartPathOptions are actually wired through to the Helm client. +// With Verify=true the Helm client must reject a chart that has no .prov file. +func TestNewChart_VerifyAndAuthPropagated(t *testing.T) { + t.Run("verify flag causes failure on unsigned chart", func(t *testing.T) { + opts := &action.ChartPathOptions{ + RepoURL: "../../../testdata", + Verify: true, + } + _, err := chart.NewChart("rancher-cluster-templates-0.5.2.tgz", opts) + if err == nil { + t.Fatal("NewChart() expected error with Verify=true on unsigned chart, got nil") + } + }) + + t.Run("credentials are propagated and do not break local chart load", func(t *testing.T) { + // Credentials are passed but local chart loading does not require auth. + // This test ensures setting Username/Password does not silently break + // the happy path (i.e. they are stored, not discarded). 
+ opts := &action.ChartPathOptions{ + RepoURL: "../../../testdata", + Username: "user", + Password: "pass", + } + c, err := chart.NewChart("rancher-cluster-templates-0.5.2.tgz", opts) + if err != nil { + t.Fatalf("NewChart() unexpected error: %v", err) + } + if c == nil { + t.Fatal("NewChart() returned nil chart") + } + }) +} + func TestNewChart(t *testing.T) { tempDir, err := os.MkdirTemp("", "hauler") if err != nil { diff --git a/pkg/getter/getter.go b/pkg/getter/getter.go index c022c212..7206ab6a 100644 --- a/pkg/getter/getter.go +++ b/pkg/getter/getter.go @@ -23,7 +23,8 @@ type Client struct { // ClientOptions provides options for the client type ClientOptions struct { - NameOverride string + NameOverride string + AllowInternalTargets bool } var ( @@ -44,7 +45,7 @@ func NewClient(opts ClientOptions) *Client { defaults := map[string]Getter{ "file": NewFile(), "directory": NewDirectory(), - "http": NewHttp(), + "http": NewHttpWithOptions(HttpOptions{AllowInternalTargets: opts.AllowInternalTargets}), } c := &Client{ diff --git a/pkg/getter/https.go b/pkg/getter/https.go index 4bd4d23e..e408f9e3 100644 --- a/pkg/getter/https.go +++ b/pkg/getter/https.go @@ -5,23 +5,190 @@ import ( "fmt" "io" "mime" + "net" "net/http" "net/url" "path/filepath" "strings" + "time" "hauler.dev/go/hauler/pkg/artifacts" "hauler.dev/go/hauler/pkg/consts" ) -type Http struct{} +// dialTimeout is the TCP connect and keep-alive timeout used by safeDial. +const dialTimeout = 30 * time.Second + +// HttpOptions configures the behaviour of the Http getter. +type HttpOptions struct { + // AllowInternalTargets disables the SSRF guard that is enforced at dial + // time by the custom DialContext. When false (the default), every IP + // address returned by DNS resolution is validated against isInternalIP + // before any connection is attempted, and the connection is made to the + // resolved IP literal directly so the check and the connect target the + // same address. Set to true only for isolated internal CI environments + // that intentionally fetch from private or loopback hosts. + AllowInternalTargets bool + + // Timeout overrides the default HTTP client timeout. + Timeout time.Duration + + // MaxBytes overrides the default per-response download cap. Zero means + // use consts.MaxDownloadBytes. + MaxBytes int64 +} + +// Http is the Getter for http/https URLs. 
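+// It is constructed via NewHttp, NewHttpWithTimeout, or NewHttpWithOptions;
+// for example (values here are illustrative):
+//
+//	h := NewHttpWithOptions(HttpOptions{Timeout: 5 * time.Minute})
+//
+// A zero Timeout or MaxBytes falls back to the consts defaults.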
+type Http struct { + opts HttpOptions + client *http.Client + maxBytes int64 +} func NewHttp() *Http { - return &Http{} + return NewHttpWithOptions(HttpOptions{}) +} + +func NewHttpWithTimeout(d time.Duration) *Http { + return NewHttpWithOptions(HttpOptions{Timeout: d}) +} + +func NewHttpWithOptions(opts HttpOptions) *Http { + timeout := time.Duration(consts.HTTPClientTimeout) * time.Second + if opts.Timeout > 0 { + timeout = opts.Timeout + } + + maxBytes := opts.MaxBytes + if maxBytes <= 0 { + maxBytes = consts.MaxDownloadBytes + } + + h := &Http{opts: opts, maxBytes: maxBytes} + + baseDialer := &net.Dialer{ + Timeout: dialTimeout, + KeepAlive: dialTimeout, + } + transport := &http.Transport{ + DialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + return h.safeDial(ctx, baseDialer, network, address) + }, + TLSHandshakeTimeout: 10 * time.Second, + ResponseHeaderTimeout: 30 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + MaxIdleConns: 10, + IdleConnTimeout: 90 * time.Second, + // Do NOT set TLSClientConfig: Go derives tls.Config.ServerName from + // the request URL hostname, so TLS cert verification continues to use + // the hostname even though we dial by IP literal. + } + + h.client = &http.Client{ + Timeout: timeout, + Transport: transport, + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return h.validateRequest(req) + }, + } + return h +} + +// validateRequest enforces scheme restrictions. It is called for the initial +// request and each redirect hop via CheckRedirect. IP/host validation is +// performed at dial time by safeDial so the checked address is exactly the +// address we connect to, eliminating the DNS-rebinding TOCTOU. +func (h *Http) validateRequest(req *http.Request) error { + switch req.URL.Scheme { + case "http", "https": + default: + return fmt.Errorf("scheme %q is not allowed; only http and https are permitted", req.URL.Scheme) + } + return nil +} + +// safeDial resolves address to candidate IPs, rejects internal IPs (when +// AllowInternalTargets=false), and dials the resolved IP literal directly. +// Performing both the IP check and the connect against the same resolved +// address eliminates the DNS-rebinding TOCTOU that exists when validation +// and connect each perform their own independent resolution. +func (h *Http) safeDial(ctx context.Context, dialer *net.Dialer, network, address string) (net.Conn, error) { + host, port, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + + ips, err := net.DefaultResolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, fmt.Errorf("failed to resolve %s: %w", host, err) + } + if len(ips) == 0 { + return nil, fmt.Errorf("no addresses found for %s", host) + } + + if !h.opts.AllowInternalTargets { + // Reject if ANY candidate is internal — prevents an attacker from + // returning [public, private] and hoping fallback hits the private IP. + for _, ipAddr := range ips { + if isInternalIP(ipAddr.IP) { + return nil, fmt.Errorf("dial to %s rejected: resolved to internal address %s (use --allow-internal-targets to override)", host, ipAddr.IP) + } + } + } + + // Dial each candidate by IP literal until one succeeds. Bracket IPv6. 
+ var lastErr error + for _, ipAddr := range ips { + ipStr := ipAddr.IP.String() + if ipAddr.IP.To4() == nil { + ipStr = "[" + ipStr + "]" + } + // ipStr is already bracketed for IPv6, so concatenate directly; net.JoinHostPort would re-bracket the literal and break the dial. + conn, dialErr := dialer.DialContext(ctx, network, ipStr+":"+port) + if dialErr == nil { + return conn, nil + } + lastErr = dialErr + } + return nil, lastErr +} + +// isInternalIP reports whether ip is in a private, loopback, link-local, or +// unique-local range. +func isInternalIP(ip net.IP) bool { + private := []net.IPNet{ + // IPv4 loopback + {IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)}, + // RFC-1918 private ranges + {IP: net.IP{10, 0, 0, 0}, Mask: net.CIDRMask(8, 32)}, + {IP: net.IP{172, 16, 0, 0}, Mask: net.CIDRMask(12, 32)}, + {IP: net.IP{192, 168, 0, 0}, Mask: net.CIDRMask(16, 32)}, + // IPv4 link-local + {IP: net.IP{169, 254, 0, 0}, Mask: net.CIDRMask(16, 32)}, + } + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + for i := range private { + if private[i].Contains(ip) { + return true + } + } + // IPv6 unique-local (fc00::/7) + if len(ip) == 16 && (ip[0]&0xfe) == 0xfc { + return true + } + return false } func (h Http) Name(u *url.URL) string { - resp, err := http.Head(u.String()) + req, err := http.NewRequest(http.MethodHead, u.String(), nil) + if err != nil { + return "" + } + if err := h.validateRequest(req); err != nil { + return "" + } + resp, err := h.client.Do(req) if err != nil { return "" } @@ -46,7 +213,15 @@ } func (h Http) Open(ctx context.Context, u *url.URL) (io.ReadCloser, error) { - resp, err := http.Get(u.String()) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + if err := h.validateRequest(req); err != nil { + return nil, err + } + + resp, err := h.client.Do(req) if err != nil { return nil, err } @@ -54,7 +229,22 @@ resp.Body.Close() return nil, fmt.Errorf("unexpected status fetching %s: %s", u.String(), resp.Status) } - return resp.Body, nil + + // Validate Content-Length header if the server provided one. + if resp.ContentLength > h.maxBytes { + resp.Body.Close() + return nil, fmt.Errorf("remote file at %s exceeds maximum allowed size (%d bytes)", u.String(), h.maxBytes) + } + + // Wrap the body so a server that lies about Content-Length (or omits it) + // is still bounded. Read maxBytes+1 so we can distinguish "exactly at cap" + // from "exceeded cap". + limited := &limitedReadCloser{ + r: io.LimitReader(resp.Body, h.maxBytes+1), + c: resp.Body, + max: h.maxBytes, + } + return limited, nil } func (h Http) Detect(u *url.URL) bool { @@ -72,6 +262,28 @@ func (h *Http) Config(u *url.URL) artifacts.Config { return artifacts.ToConfig(c, artifacts.WithConfigMediaType(consts.FileHttpConfigMediaType)) } +// limitedReadCloser wraps an io.LimitReader and returns an error when the cap +// is hit rather than silently returning EOF. 
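+// Open wires it up as io.LimitReader(resp.Body, maxBytes+1): reading exactly
+// maxBytes stays within the cap, while byte maxBytes+1 trips the error, so
+// truncation is never silent.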
+type limitedReadCloser struct { + r io.Reader + c io.Closer + max int64 + read int64 +} + +func (l *limitedReadCloser) Read(p []byte) (int, error) { + n, err := l.r.Read(p) + l.read += int64(n) + if l.read > l.max { + return n, fmt.Errorf("download exceeded maximum allowed size of %d bytes", l.max) + } + return n, err +} + +func (l *limitedReadCloser) Close() error { + return l.c.Close() +} + type httpConfig struct { config `json:",inline,omitempty"` } diff --git a/pkg/getter/https_security_test.go b/pkg/getter/https_security_test.go new file mode 100644 index 00000000..3c212a82 --- /dev/null +++ b/pkg/getter/https_security_test.go @@ -0,0 +1,198 @@ +package getter_test + +import ( + "context" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" + + "hauler.dev/go/hauler/pkg/getter" +) + +// --- A3: Unbounded download protection --- + +// TestHttp_Open_RejectsOversizedBody verifies that Open wraps the response body +// in an io.LimitReader so a server that streams more than MaxBytes causes an +// error rather than exhausting disk/memory. +func TestHttp_Open_RejectsOversizedBody(t *testing.T) { + const cap int64 = 1024 // 1 KiB test cap + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Stream cap+1 bytes so the limiter fires. + payload := strings.Repeat("x", int(cap)+1) + fmt.Fprint(w, payload) + })) + defer srv.Close() + + // AllowInternalTargets=true because the test server binds to loopback. + h := getter.NewHttpWithOptions(getter.HttpOptions{ + AllowInternalTargets: true, + MaxBytes: cap, + }) + u, _ := url.Parse(srv.URL + "/big") + // The size cap must be enforced either at Open() (via Content-Length header) + // or at read time (via LimitReader). Both are acceptable. + rc, openErr := h.Open(context.Background(), u) + if openErr != nil { + // Content-Length header triggered the cap early — that is correct. + return + } + defer rc.Close() + + _, readErr := io.ReadAll(rc) + if readErr == nil { + t.Fatal("expected an error from Open() or ReadAll() for oversized body, got neither") + } +} + +// TestHttp_Open_Timeout verifies that Open uses a client with a timeout so +// Slowloris-style servers do not hang indefinitely. +func TestHttp_Open_Timeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping slow-server test in short mode") + } + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Never write anything — simulate a stalled server. + <-r.Context().Done() + })) + defer srv.Close() + + h := getter.NewHttpWithOptions(getter.HttpOptions{ + AllowInternalTargets: true, + Timeout: 50 * time.Millisecond, + }) + u, _ := url.Parse(srv.URL + "/slow") + _, err := h.Open(context.Background(), u) + if err == nil { + t.Fatal("Open() expected timeout error, got nil") + } +} + +// --- A4: SSRF protection --- + +// TestHttp_Open_RejectsNonHTTPScheme verifies that file://, gopher://, etc. +// are rejected before any network call is made. +func TestHttp_Open_RejectsNonHTTPScheme(t *testing.T) { + for _, scheme := range []string{"file", "gopher", "ftp", "data"} { + t.Run(scheme, func(t *testing.T) { + h := getter.NewHttp() + u, _ := url.Parse(scheme + "://some/path") + _, err := h.Open(context.Background(), u) + if err == nil { + t.Fatalf("Open() expected error for scheme %q, got nil", scheme) + } + }) + } +} + +// TestHttp_Open_RejectsPrivateIPByDefault verifies that private/loopback +// addresses are blocked unless AllowInternalTargets is set. 
+func TestHttp_Open_RejectsPrivateIPByDefault(t *testing.T) { + // Use a real local server so the DNS resolution step completes. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "secret") + })) + defer srv.Close() + + // srv.URL is http://127.0.0.1: — a loopback address. + h := getter.NewHttp() // default: AllowInternalTargets = false + u, _ := url.Parse(srv.URL + "/internal") + _, err := h.Open(context.Background(), u) + if err == nil { + t.Fatal("Open() expected SSRF rejection for loopback address, got nil") + } +} + +// TestHttp_Open_AllowsPrivateIPWithFlag verifies that an explicit opt-in flag +// lifts the private-IP restriction, enabling internal CI use cases. +func TestHttp_Open_AllowsPrivateIPWithFlag(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "ok") + })) + defer srv.Close() + + h := getter.NewHttpWithOptions(getter.HttpOptions{AllowInternalTargets: true}) + u, _ := url.Parse(srv.URL + "/internal") + rc, err := h.Open(context.Background(), u) + if err != nil { + t.Fatalf("Open() unexpected error with AllowInternalTargets=true: %v", err) + } + rc.Close() +} + +// TestHttp_Open_RejectsRedirectToPrivateIP verifies that CheckRedirect +// re-validates the destination on every hop, blocking public→private pivots. +func TestHttp_Open_RejectsRedirectToPrivateIP(t *testing.T) { + // The "private" server just responds 200. + privateSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "private data") + })) + defer privateSrv.Close() + + // The "public" server redirects to the private server. + publicSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, privateSrv.URL+"/secret", http.StatusFound) + })) + defer publicSrv.Close() + + h := getter.NewHttp() // default: AllowInternalTargets = false + u, _ := url.Parse(publicSrv.URL + "/redirect") + _, err := h.Open(context.Background(), u) + if err == nil { + t.Fatal("Open() expected error when redirect targets private IP, got nil") + } +} + +// TestHttp_Open_RejectsIPLiteralPrivate verifies that URLs containing a private, +// loopback, link-local, or IMDS IP literal are rejected at dial time without +// any external network round-trip. +func TestHttp_Open_RejectsIPLiteralPrivate(t *testing.T) { + cases := []string{ + "http://127.0.0.1:9/anything", + "http://10.0.0.1:9/anything", + "http://192.168.0.1:9/anything", + "http://169.254.169.254/latest/meta", + } + h := getter.NewHttp() // default AllowInternalTargets=false + for _, raw := range cases { + t.Run(raw, func(t *testing.T) { + u, _ := url.Parse(raw) + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + _, err := h.Open(ctx, u) + if err == nil { + t.Fatalf("Open(%s) expected SSRF rejection, got nil", raw) + } + }) + } +} + +// TestHttp_Open_RejectsHostnameResolvingToLoopback verifies that the dial-time +// check inspects the *resolved* IP, not just IP literals. This is the +// meaningful demonstration that DNS rebinding is closed: even when the URL +// hostname is "localhost" (not an IP literal), the dial fires the SSRF check +// against the resolved 127.0.0.1. 
+func TestHttp_Open_RejectsHostnameResolvingToLoopback(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprint(w, "secret") + })) + defer srv.Close() + + parsed, err := url.Parse(srv.URL) + if err != nil { + t.Fatalf("parse: %v", err) + } + parsed.Host = "localhost:" + parsed.Port() + + h := getter.NewHttp() // default AllowInternalTargets=false + _, err = h.Open(context.Background(), parsed) + if err == nil { + t.Fatal("Open() expected SSRF rejection for hostname resolving to loopback, got nil") + } +} diff --git a/pkg/log/logcapture.go b/pkg/log/logcapture.go index b36d10a5..bee9d538 100644 --- a/pkg/log/logcapture.go +++ b/pkg/log/logcapture.go @@ -42,8 +42,17 @@ func logStream(reader io.Reader, customWriter *CustomWriter, wg *sync.WaitGroup) } } -// CaptureOutput redirects stdout and stderr to custom loggers and executes the provided function -func CaptureOutput(logger Logger, debug bool, fn func() error) error { +// captureOutputMu serializes all CaptureOutput calls so that concurrent +// goroutines do not race on os.Stdout/os.Stderr. +var captureOutputMu sync.Mutex + +// CaptureOutput redirects stdout and stderr to custom loggers and executes the provided function. +// It is goroutine-safe: concurrent calls are serialized. A panic inside fn is +// recovered, os.Stdout/os.Stderr are restored, and an error is returned. +func CaptureOutput(logger Logger, debug bool, fn func() error) (retErr error) { + captureOutputMu.Lock() + defer captureOutputMu.Unlock() + // Create pipes for capturing stdout and stderr stdoutReader, stdoutWriter, err := os.Pipe() if err != nil { @@ -51,6 +60,8 @@ func CaptureOutput(logger Logger, debug bool, fn func() error) error { } stderrReader, stderrWriter, err := os.Pipe() if err != nil { + stdoutReader.Close() + stdoutWriter.Close() return fmt.Errorf("failed to create stderr pipe: %w", err) } @@ -62,6 +73,15 @@ func CaptureOutput(logger Logger, debug bool, fn func() error) error { os.Stdout = stdoutWriter os.Stderr = stderrWriter + // Ensure FDs are always restored — even on panic. + defer func() { + os.Stdout = origStdout + os.Stderr = origStderr + if r := recover(); r != nil { + retErr = fmt.Errorf("CaptureOutput: recovered from panic: %v", r) + } + }() + // Use WaitGroup to wait for logging goroutines to finish var wg sync.WaitGroup wg.Add(2) @@ -78,15 +98,21 @@ func CaptureOutput(logger Logger, debug bool, fn func() error) error { // Run the provided function in a separate goroutine fnErr := make(chan error, 1) go func() { + defer func() { + // Propagate panics from fn back to the main goroutine via the error channel. + if r := recover(); r != nil { + fnErr <- fmt.Errorf("panic: %v", r) + } + stdoutWriter.Close() // Close writers to signal EOF to readers + stderrWriter.Close() + }() fnErr <- fn() - stdoutWriter.Close() // Close writers to signal EOF to readers - stderrWriter.Close() }() // Wait for logging goroutines to finish wg.Wait() - // Restore original stdout and stderr + // Restore stdout/stderr early so the deferred restore is a no-op. 
os.Stdout = origStdout os.Stderr = origStderr diff --git a/pkg/log/logcapture_test.go b/pkg/log/logcapture_test.go new file mode 100644 index 00000000..6f2cd397 --- /dev/null +++ b/pkg/log/logcapture_test.go @@ -0,0 +1,103 @@ +package log_test + +import ( + "fmt" + "io" + "os" + "sync" + "sync/atomic" + "testing" + + haulerlog "hauler.dev/go/hauler/pkg/log" +) + +// TestCaptureOutput_ConcurrentSafe verifies that concurrent calls to +// CaptureOutput, each writing to os.Stdout from inside fn, do not race on +// os.Stdout/os.Stderr. The mutex serializes the reassignment window so +// concurrent capturers cannot see each other's writes. Run with -race to +// catch any residual data race on os.Stdout/os.Stderr. +func TestCaptureOutput_ConcurrentSafe(t *testing.T) { + const goroutines = 20 + + l := haulerlog.NewLogger(io.Discard) + + var wg sync.WaitGroup + errs := make(chan error, goroutines) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + err := haulerlog.CaptureOutput(l, true, func() error { + _, _ = fmt.Fprintf(os.Stdout, "goroutine-%d-stdout\n", idx) + _, _ = fmt.Fprintf(os.Stderr, "goroutine-%d-stderr\n", idx) + return nil + }) + if err != nil { + errs <- err + } + }(i) + } + + wg.Wait() + close(errs) + + for err := range errs { + t.Errorf("CaptureOutput() concurrent call returned error: %v", err) + } +} + +// TestCaptureOutput_RestoresFDsAfterUse confirms that os.Stdout/os.Stderr are +// restored to their original values after CaptureOutput returns, so subsequent +// writes do not target a closed pipe. +func TestCaptureOutput_RestoresFDsAfterUse(t *testing.T) { + l := haulerlog.NewLogger(io.Discard) + origStdout := os.Stdout + origStderr := os.Stderr + + var captured atomic.Bool + if err := haulerlog.CaptureOutput(l, true, func() error { + captured.Store(true) + return nil + }); err != nil { + t.Fatalf("CaptureOutput() returned error: %v", err) + } + if !captured.Load() { + t.Fatal("fn was not invoked") + } + + if os.Stdout != origStdout { + t.Error("os.Stdout was not restored after CaptureOutput returned") + } + if os.Stderr != origStderr { + t.Error("os.Stderr was not restored after CaptureOutput returned") + } +} + +// TestCaptureOutput_PanicRestoresFDs verifies that a panic inside the +// provided function does not leave os.Stdout/os.Stderr permanently redirected +// and that subsequent calls succeed normally. +func TestCaptureOutput_PanicRestoresFDs(t *testing.T) { + l := haulerlog.NewLogger(io.Discard) + origStdout := os.Stdout + origStderr := os.Stderr + + err := haulerlog.CaptureOutput(l, true, func() error { + panic("intentional panic for test") + }) + if err == nil { + t.Fatal("CaptureOutput() expected error after panic, got nil") + } + + if os.Stdout != origStdout { + t.Error("os.Stdout was not restored after CaptureOutput panic") + } + if os.Stderr != origStderr { + t.Error("os.Stderr was not restored after CaptureOutput panic") + } + + // Run a follow-up capture to confirm the global state is healthy. 
+ if err := haulerlog.CaptureOutput(l, true, func() error { return nil }); err != nil { + t.Errorf("subsequent CaptureOutput() failed after panic recovery: %v", err) + } +} diff --git a/pkg/store/store.go b/pkg/store/store.go index ada5a714..68ee0768 100644 --- a/pkg/store/store.go +++ b/pkg/store/store.go @@ -593,10 +593,13 @@ func (l *Layout) copyDescriptorGraph(ctx context.Context, desc ocispec.Descripto } }() - data, err := io.ReadAll(rc) + data, err := io.ReadAll(io.LimitReader(rc, consts.MaxManifestBytes+1)) if err != nil { return fmt.Errorf("failed to read manifest: %w", err) } + if int64(len(data)) > consts.MaxManifestBytes { + return fmt.Errorf("manifest exceeds maximum allowed size (%d bytes)", consts.MaxManifestBytes) + } var manifest ocispec.Manifest if err := json.Unmarshal(data, &manifest); err != nil { @@ -632,10 +635,13 @@ func (l *Layout) copyDescriptorGraph(ctx context.Context, desc ocispec.Descripto } }() - data, err := io.ReadAll(rc) + data, err := io.ReadAll(io.LimitReader(rc, consts.MaxManifestBytes+1)) if err != nil { return fmt.Errorf("failed to read index: %w", err) } + if int64(len(data)) > consts.MaxManifestBytes { + return fmt.Errorf("index exceeds maximum allowed size (%d bytes)", consts.MaxManifestBytes) + } var index ocispec.Index if err := json.Unmarshal(data, &index); err != nil {