diff --git a/cmd/hauler/cli/store/add.go b/cmd/hauler/cli/store/add.go
index ae532d44..d6bf1e85 100644
--- a/cmd/hauler/cli/store/add.go
+++ b/cmd/hauler/cli/store/add.go
@@ -36,14 +36,19 @@ func AddFileCmd(ctx context.Context, o *flags.AddFileOpts, s *store.Layout, refe
if len(o.Name) > 0 {
cfg.Name = o.Name
}
- return storeFile(ctx, s, cfg)
+ var allowInternal bool
+ if o.StoreRootOpts != nil {
+ allowInternal = o.AllowInternalTargets
+ }
+ return storeFile(ctx, s, cfg, allowInternal)
}
-func storeFile(ctx context.Context, s *store.Layout, fi v1.File) error {
+func storeFile(ctx context.Context, s *store.Layout, fi v1.File, allowInternalTargets bool) error {
l := log.FromContext(ctx)
copts := getter.ClientOptions{
- NameOverride: fi.Name,
+ NameOverride: fi.Name,
+ AllowInternalTargets: allowInternalTargets,
}
f := file.NewFile(fi.Path, file.WithClient(getter.NewClient(copts)))
diff --git a/cmd/hauler/cli/store/add_test.go b/cmd/hauler/cli/store/add_test.go
index baad66cf..c1b6eac5 100644
--- a/cmd/hauler/cli/store/add_test.go
+++ b/cmd/hauler/cli/store/add_test.go
@@ -371,7 +371,7 @@ func TestStoreFile(t *testing.T) {
tmp.Close()
s := newTestStore(t)
- if err := storeFile(ctx, s, v1.File{Path: tmp.Name()}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: tmp.Name()}, false); err != nil {
t.Fatalf("storeFile: %v", err)
}
assertArtifactInStore(t, s, filepath.Base(tmp.Name()))
@@ -380,7 +380,7 @@ func TestStoreFile(t *testing.T) {
t.Run("HTTP URL stored under basename", func(t *testing.T) {
url := seedFileInHTTPServer(t, "script.sh", "#!/bin/sh\necho ok")
s := newTestStore(t)
- if err := storeFile(ctx, s, v1.File{Path: url}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil {
t.Fatalf("storeFile: %v", err)
}
assertArtifactInStore(t, s, "script.sh")
@@ -394,7 +394,7 @@ func TestStoreFile(t *testing.T) {
tmp.Close()
s := newTestStore(t)
- if err := storeFile(ctx, s, v1.File{Path: tmp.Name(), Name: "custom.sh"}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: tmp.Name(), Name: "custom.sh"}, false); err != nil {
t.Fatalf("storeFile: %v", err)
}
assertArtifactInStore(t, s, "custom.sh")
@@ -402,7 +402,7 @@ func TestStoreFile(t *testing.T) {
t.Run("nonexistent local path returns error", func(t *testing.T) {
s := newTestStore(t)
- err := storeFile(ctx, s, v1.File{Path: "/nonexistent/path/missing-file.txt"})
+ err := storeFile(ctx, s, v1.File{Path: "/nonexistent/path/missing-file.txt"}, false)
if err == nil {
t.Fatal("expected error for nonexistent path, got nil")
}
diff --git a/cmd/hauler/cli/store/copy_test.go b/cmd/hauler/cli/store/copy_test.go
index 28b41f53..39f41550 100644
--- a/cmd/hauler/cli/store/copy_test.go
+++ b/cmd/hauler/cli/store/copy_test.go
@@ -242,7 +242,7 @@ func TestCopyCmd_Dir_Files(t *testing.T) {
url := seedFileInHTTPServer(t, "data.txt", content)
s := newTestStore(t)
- if err := storeFile(ctx, s, v1.File{Path: url}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil {
t.Fatalf("storeFile: %v", err)
}
diff --git a/cmd/hauler/cli/store/extract_test.go b/cmd/hauler/cli/store/extract_test.go
index 498a1900..0b76ea9d 100644
--- a/cmd/hauler/cli/store/extract_test.go
+++ b/cmd/hauler/cli/store/extract_test.go
@@ -31,7 +31,7 @@ func TestExtractCmd_File(t *testing.T) {
fileContent := "hello extract test"
url := seedFileInHTTPServer(t, "extract-me.txt", fileContent)
- if err := storeFile(ctx, s, v1.File{Path: url}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil {
t.Fatalf("storeFile: %v", err)
}
@@ -533,7 +533,7 @@ func TestExtractCmd_SubstringMatch(t *testing.T) {
fileContent := "substring match content"
url := seedFileInHTTPServer(t, "extract-sub.txt", fileContent)
- if err := storeFile(ctx, s, v1.File{Path: url}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil {
t.Fatalf("storeFile: %v", err)
}
@@ -581,7 +581,7 @@ func TestExtractCmd_CosignArtifactsProduceNoContainerImageWarning(t *testing.T)
// Seed a real file artifact so ExtractCmd finds something to extract.
fileContent := "cosign-filter test file content"
url := seedFileInHTTPServer(t, "sigtest.txt", fileContent)
- if err := storeFile(ctx, s, v1.File{Path: url}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil {
t.Fatalf("storeFile: %v", err)
}
diff --git a/cmd/hauler/cli/store/info_test.go b/cmd/hauler/cli/store/info_test.go
index 51af7e02..a7c89769 100644
--- a/cmd/hauler/cli/store/info_test.go
+++ b/cmd/hauler/cli/store/info_test.go
@@ -199,7 +199,7 @@ func TestInfoCmd(t *testing.T) {
t.Fatalf("write tmpFile: %v", err)
}
fi := v1.File{Path: tmpFile}
- if err := storeFile(ctx, s, fi); err != nil {
+ if err := storeFile(ctx, s, fi, false); err != nil {
t.Fatalf("storeFile: %v", err)
}
diff --git a/cmd/hauler/cli/store/lifecycle_test.go b/cmd/hauler/cli/store/lifecycle_test.go
index f1f7a9e3..5c413beb 100644
--- a/cmd/hauler/cli/store/lifecycle_test.go
+++ b/cmd/hauler/cli/store/lifecycle_test.go
@@ -31,7 +31,7 @@ func TestLifecycle_FileArtifact_AddSaveLoadCopy(t *testing.T) {
// Step 2: storeFile into store A.
storeA := newTestStore(t)
- if err := storeFile(ctx, storeA, v1.File{Path: url}); err != nil {
+ if err := storeFile(ctx, storeA, v1.File{Path: url}, true); err != nil {
t.Fatalf("storeFile: %v", err)
}
assertArtifactInStore(t, storeA, "lifecycle.txt")
@@ -252,10 +252,10 @@ func TestLifecycle_Remove_ThenSave(t *testing.T) {
url2 := seedFileInHTTPServer(t, "remove-me.txt", "content to remove")
storeA := newTestStore(t)
- if err := storeFile(ctx, storeA, v1.File{Path: url1}); err != nil {
+ if err := storeFile(ctx, storeA, v1.File{Path: url1}, true); err != nil {
t.Fatalf("storeFile keep-me: %v", err)
}
- if err := storeFile(ctx, storeA, v1.File{Path: url2}); err != nil {
+ if err := storeFile(ctx, storeA, v1.File{Path: url2}, true); err != nil {
t.Fatalf("storeFile remove-me: %v", err)
}
diff --git a/cmd/hauler/cli/store/load.go b/cmd/hauler/cli/store/load.go
index 62d19170..bf10492e 100644
--- a/cmd/hauler/cli/store/load.go
+++ b/cmd/hauler/cli/store/load.go
@@ -3,6 +3,7 @@ package store
import (
"context"
"encoding/json"
+ "fmt"
"io"
"net/url"
"os"
@@ -41,7 +42,7 @@ func LoadCmd(ctx context.Context, o *flags.LoadOpts, rso *flags.StoreRootOpts, r
for _, fileName := range o.FileName {
resolved := resolveHaulPath(fileName)
l.Infof("loading haul [%s] to [%s]", resolved, o.StoreDir)
- err := unarchiveLayoutTo(ctx, resolved, o.StoreDir, tempDir)
+ err := unarchiveLayoutTo(ctx, resolved, o.StoreDir, tempDir, rso.AllowInternalTargets)
if err != nil {
return err
}
@@ -52,13 +53,13 @@ func LoadCmd(ctx context.Context, o *flags.LoadOpts, rso *flags.StoreRootOpts, r
}
// accepts an archived OCI layout, extracts the contents to an existing OCI layout, and preserves the index
-func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempDir string) error {
+func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempDir string, allowInternalTargets bool) error {
l := log.FromContext(ctx)
if strings.HasPrefix(haulPath, "http://") || strings.HasPrefix(haulPath, "https://") {
l.Debugf("detected remote archive... starting download... [%s]", haulPath)
- h := getter.NewHttp()
+ h := getter.NewHttpWithOptions(getter.HttpOptions{AllowInternalTargets: allowInternalTargets})
parsedURL, err := url.Parse(haulPath)
if err != nil {
return err
@@ -81,9 +82,13 @@ func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempDi
}
defer out.Close()
- if _, err = io.Copy(out, rc); err != nil {
+ n, err := io.Copy(out, io.LimitReader(rc, consts.MaxDownloadBytes+1))
+ if err != nil {
return err
}
+ if n > consts.MaxDownloadBytes {
+ return fmt.Errorf("remote archive at %s exceeds maximum allowed size (%d bytes)", haulPath, consts.MaxDownloadBytes)
+ }
}
// reassemble chunk files if haulPath matches the chunk naming pattern
@@ -98,10 +103,18 @@ func unarchiveLayoutTo(ctx context.Context, haulPath string, dest string, tempDi
}
// ensure the incoming index.json has the correct annotations.
- data, err := os.ReadFile(tempDir + "/index.json")
+ indexFile, err := os.Open(tempDir + "/index.json")
+ if err != nil {
+ return err
+ }
+ data, err := io.ReadAll(io.LimitReader(indexFile, consts.MaxManifestBytes+1))
+ indexFile.Close()
if err != nil {
return (err)
}
+ if int64(len(data)) > consts.MaxManifestBytes {
+ return fmt.Errorf("index.json exceeds maximum allowed size (%d bytes)", consts.MaxManifestBytes)
+ }
var idx ocispec.Index
if err := json.Unmarshal(data, &idx); err != nil {
diff --git a/cmd/hauler/cli/store/load_test.go b/cmd/hauler/cli/store/load_test.go
index 15d28047..37a249a4 100644
--- a/cmd/hauler/cli/store/load_test.go
+++ b/cmd/hauler/cli/store/load_test.go
@@ -71,7 +71,7 @@ func TestUnarchiveLayoutTo(t *testing.T) {
destDir := t.TempDir()
tempDir := t.TempDir()
- if err := unarchiveLayoutTo(ctx, testHaulArchive, destDir, tempDir); err != nil {
+ if err := unarchiveLayoutTo(ctx, testHaulArchive, destDir, tempDir, false); err != nil {
t.Fatalf("unarchiveLayoutTo: %v", err)
}
@@ -192,12 +192,16 @@ func TestLoadCmd_RemoteArchive(t *testing.T) {
destDir := t.TempDir()
remoteURL := srv.URL + "/haul.tar.zst"
+ // AllowInternalTargets=true because the test server binds to loopback.
+ rso := defaultRootOpts(destDir)
+ rso.AllowInternalTargets = true
+
o := &flags.LoadOpts{
- StoreRootOpts: defaultRootOpts(destDir),
+ StoreRootOpts: rso,
FileName: []string{remoteURL},
}
- if err := LoadCmd(ctx, o, defaultRootOpts(destDir), defaultCliOpts()); err != nil {
+ if err := LoadCmd(ctx, o, rso, defaultCliOpts()); err != nil {
t.Fatalf("LoadCmd remote: %v", err)
}
@@ -262,7 +266,7 @@ func TestUnarchiveLayoutTo_AnnotationBackfill(t *testing.T) {
// Step 4: Load the stripped archive.
destDir := t.TempDir()
tempDir := t.TempDir()
- if err := unarchiveLayoutTo(ctx, strippedArchive, destDir, tempDir); err != nil {
+ if err := unarchiveLayoutTo(ctx, strippedArchive, destDir, tempDir, false); err != nil {
t.Fatalf("unarchiveLayoutTo stripped: %v", err)
}
@@ -354,7 +358,7 @@ func TestUnarchiveLayoutTo_LegacyKindMigration(t *testing.T) {
// Step 4: Load the legacy archive.
destDir := t.TempDir()
tempDir := t.TempDir()
- if err := unarchiveLayoutTo(ctx, legacyArchive, destDir, tempDir); err != nil {
+ if err := unarchiveLayoutTo(ctx, legacyArchive, destDir, tempDir, false); err != nil {
t.Fatalf("unarchiveLayoutTo legacy: %v", err)
}
diff --git a/cmd/hauler/cli/store/remove_test.go b/cmd/hauler/cli/store/remove_test.go
index a013e3df..c81e8894 100644
--- a/cmd/hauler/cli/store/remove_test.go
+++ b/cmd/hauler/cli/store/remove_test.go
@@ -81,7 +81,7 @@ func TestRemoveCmd_Force(t *testing.T) {
s := newTestStore(t)
url := seedFileInHTTPServer(t, "removeme.txt", "file-to-remove")
- if err := storeFile(ctx, s, v1.File{Path: url}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil {
t.Fatalf("storeFile: %v", err)
}
@@ -133,10 +133,10 @@ func TestRemoveCmd_Force_MultipleMatches(t *testing.T) {
url1 := seedFileInHTTPServer(t, "testfile-alpha.txt", "content-alpha")
url2 := seedFileInHTTPServer(t, "testfile-beta.txt", "content-beta")
- if err := storeFile(ctx, s, v1.File{Path: url1}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: url1}, true); err != nil {
t.Fatalf("storeFile alpha: %v", err)
}
- if err := storeFile(ctx, s, v1.File{Path: url2}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: url2}, true); err != nil {
t.Fatalf("storeFile beta: %v", err)
}
diff --git a/cmd/hauler/cli/store/save_test.go b/cmd/hauler/cli/store/save_test.go
index 5902519e..9b4035db 100644
--- a/cmd/hauler/cli/store/save_test.go
+++ b/cmd/hauler/cli/store/save_test.go
@@ -118,7 +118,7 @@ func TestWriteExportsManifest_SkipsNonImages(t *testing.T) {
url := seedFileInHTTPServer(t, "skip.sh", "#!/bin/sh\necho skip")
s := newTestStore(t)
- if err := storeFile(ctx, s, v1.File{Path: url}); err != nil {
+ if err := storeFile(ctx, s, v1.File{Path: url}, true); err != nil {
t.Fatalf("storeFile: %v", err)
}
diff --git a/cmd/hauler/cli/store/sync.go b/cmd/hauler/cli/store/sync.go
index 48283196..4a0ceb1f 100644
--- a/cmd/hauler/cli/store/sync.go
+++ b/cmd/hauler/cli/store/sync.go
@@ -82,11 +82,14 @@ func SyncCmd(ctx context.Context, o *flags.SyncOpts, s *store.Layout, rso *flags
if err != nil {
return err
}
- content, err := io.ReadAll(rc)
+ content, err := io.ReadAll(io.LimitReader(rc, consts.MaxManifestBytes+1))
rc.Close()
if err != nil {
return err
}
+ if int64(len(content)) > consts.MaxManifestBytes {
+ return fmt.Errorf("product manifest for [%s] exceeds maximum allowed size (%d bytes)", productName, consts.MaxManifestBytes)
+ }
// Ensure each manifest starts with a YAML document separator.
if !strings.HasPrefix(string(content), "---") {
@@ -161,7 +164,7 @@ func SyncCmd(ctx context.Context, o *flags.SyncOpts, s *store.Layout, rso *flags
if strings.HasPrefix(haulPath, "http://") || strings.HasPrefix(haulPath, "https://") {
l.Debugf("detected remote manifest... starting download... [%s]", haulPath)
- h := getter.NewHttp()
+ h := getter.NewHttpWithOptions(getter.HttpOptions{AllowInternalTargets: rso.AllowInternalTargets})
parsedURL, err := url.Parse(haulPath)
if err != nil {
return err
@@ -184,9 +187,13 @@ func SyncCmd(ctx context.Context, o *flags.SyncOpts, s *store.Layout, rso *flags
}
defer out.Close()
- if _, err = io.Copy(out, rc); err != nil {
+ n, err := io.Copy(out, io.LimitReader(rc, consts.MaxDownloadBytes+1))
+ if err != nil {
return err
}
+ if n > consts.MaxDownloadBytes {
+ return fmt.Errorf("remote manifest at %s exceeds maximum allowed size (%d bytes)", haulPath, consts.MaxDownloadBytes)
+ }
}
fi, err := os.Open(haulPath)
@@ -213,7 +220,7 @@ func SyncCmd(ctx context.Context, o *flags.SyncOpts, s *store.Layout, rso *flags
if strings.HasPrefix(haulPath, "http://") || strings.HasPrefix(haulPath, "https://") {
l.Debugf("detected remote image.txt... starting download... [%s]", haulPath)
- h := getter.NewHttp()
+ h := getter.NewHttpWithOptions(getter.HttpOptions{AllowInternalTargets: rso.AllowInternalTargets})
parsedURL, err := url.Parse(haulPath)
if err != nil {
return err
@@ -236,9 +243,13 @@ func SyncCmd(ctx context.Context, o *flags.SyncOpts, s *store.Layout, rso *flags
}
defer out.Close()
- if _, err = io.Copy(out, rc); err != nil {
+ n, err := io.Copy(out, io.LimitReader(rc, consts.MaxDownloadBytes+1))
+ if err != nil {
return err
}
+ if n > consts.MaxDownloadBytes {
+ return fmt.Errorf("remote image.txt at %s exceeds maximum allowed size (%d bytes)", haulPath, consts.MaxDownloadBytes)
+ }
}
fi, err := os.Open(haulPath)
@@ -296,7 +307,7 @@ func processContent(ctx context.Context, fi *os.File, o *flags.SyncOpts, s *stor
return err
}
for _, f := range cfg.Spec.Files {
- if err := storeFile(ctx, s, f); err != nil {
+ if err := storeFile(ctx, s, f, rso.AllowInternalTargets); err != nil {
return err
}
}
diff --git a/cmd/hauler/cli/store/sync_test.go b/cmd/hauler/cli/store/sync_test.go
index 2a0e5985..8e512686 100644
--- a/cmd/hauler/cli/store/sync_test.go
+++ b/cmd/hauler/cli/store/sync_test.go
@@ -70,6 +70,7 @@ spec:
fi := writeManifestFile(t, manifest)
o := newSyncOpts(s.Root)
+ o.StoreRootOpts.AllowInternalTargets = true // test server is on loopback
ro := defaultCliOpts()
if err := processContent(ctx, fi, o, s, o.StoreRootOpts, ro); err != nil {
@@ -216,6 +217,7 @@ spec:
fi := writeManifestFile(t, manifest)
o := newSyncOpts(s.Root)
+ o.StoreRootOpts.AllowInternalTargets = true // test servers are on loopback
ro := defaultCliOpts()
if err := processContent(ctx, fi, o, s, o.StoreRootOpts, ro); err != nil {
@@ -260,6 +262,7 @@ spec:
o := newSyncOpts(s.Root)
o.FileName = []string{manifestPath}
rso := defaultRootOpts(s.Root)
+ rso.AllowInternalTargets = true // test server is on loopback
ro := defaultCliOpts()
if err := SyncCmd(ctx, o, s, rso, ro); err != nil {
@@ -411,6 +414,7 @@ func TestSyncCmd_ImageTxt_RemoteFile(t *testing.T) {
o := newSyncOpts(s.Root)
o.ImageTxt = []string{imageSrv.URL + "/images.txt"}
rso := defaultRootOpts(s.Root)
+ rso.AllowInternalTargets = true // test server is on loopback
ro := defaultCliOpts()
if err := SyncCmd(ctx, o, s, rso, ro); err != nil {
@@ -444,6 +448,7 @@ spec:
o := newSyncOpts(s.Root)
o.FileName = []string{manifestSrv.URL + "/manifest.yaml"}
rso := defaultRootOpts(s.Root)
+ rso.AllowInternalTargets = true // both manifest server and file server are on loopback
ro := defaultCliOpts()
if err := SyncCmd(ctx, o, s, rso, ro); err != nil {
diff --git a/install.sh b/install.sh
index f762feda..f46fb96a 100755
--- a/install.sh
+++ b/install.sh
@@ -150,8 +150,8 @@ if [ ! -d "${HAULER_DIR}" ]; then
mkdir -p "${HAULER_DIR}" || fatal "Failed to Create Hauler Directory: ${HAULER_DIR}"
fi
-# ensure hauler directory is writable (by user or root privileges)
-chmod -R 777 "${HAULER_DIR}" || fatal "Failed to Update Permissions of Hauler Directory: ${HAULER_DIR}"
+# ensure hauler directory is only accessible by the owner
+chmod -R 0700 "${HAULER_DIR}" || fatal "Failed to Update Permissions of Hauler Directory: ${HAULER_DIR}"
# change to hauler directory
cd "${HAULER_DIR}" || fatal "Failed to Change Directory to Hauler Directory: ${HAULER_DIR}"
diff --git a/internal/flags/store.go b/internal/flags/store.go
index ddcf3e01..2491687d 100644
--- a/internal/flags/store.go
+++ b/internal/flags/store.go
@@ -13,9 +13,10 @@ import (
)
type StoreRootOpts struct {
- StoreDir string
- Retries int
- TempOverride string
+ StoreDir string
+ Retries int
+ TempOverride string
+ AllowInternalTargets bool
}
func (o *StoreRootOpts) AddFlags(cmd *cobra.Command) {
@@ -23,6 +24,7 @@ func (o *StoreRootOpts) AddFlags(cmd *cobra.Command) {
pf.StringVarP(&o.StoreDir, "store", "s", "", "Set the directory to use for the content store")
pf.IntVarP(&o.Retries, "retries", "r", consts.DefaultRetries, "Set the number of retries for operations")
pf.StringVarP(&o.TempOverride, "tempdir", "t", "", "(Optional) Override the default temporary directory determined by the OS")
+ pf.BoolVar(&o.AllowInternalTargets, "allow-internal-targets", false, "(Optional) Allow fetching from private/loopback IP addresses (disables SSRF protection; for isolated internal CI use only)")
}
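+
+// Hypothetical invocation (any store subcommand that mounts these root flags;
+// flags other than --allow-internal-targets are illustrative):
+//
+//	hauler store sync --filename hauler-manifest.yaml --allow-internal-targets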
func (o *StoreRootOpts) Store(ctx context.Context) (*store.Layout, error) {
diff --git a/pkg/archives/archiver.go b/pkg/archives/archiver.go
index cb9320b9..5c418a84 100644
--- a/pkg/archives/archiver.go
+++ b/pkg/archives/archiver.go
@@ -12,6 +12,13 @@ import (
"hauler.dev/go/hauler/pkg/log"
)
+// compressionZstd and archivalTar are package-level vars so tests can reference
+// them without importing mholt/archives directly.
+var (
+ compressionZstd = archives.Zstd{}
+ archivalTar = archives.Tar{}
+)
+
// maps to handle compression types
var CompressionMap = map[string]archives.Compression{
"gz": archives.Gz{},
diff --git a/pkg/archives/limits_test.go b/pkg/archives/limits_test.go
new file mode 100644
index 00000000..9c4a4d7c
--- /dev/null
+++ b/pkg/archives/limits_test.go
@@ -0,0 +1,145 @@
+package archives
+
+import (
+ "context"
+ "io"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+
+ "github.com/rs/zerolog"
+)
+
+func limitsTestContext(t *testing.T) context.Context {
+ t.Helper()
+ l := zerolog.New(io.Discard)
+ return l.WithContext(context.Background())
+}
+
+// buildSmallTarZst creates a real .tar.zst in a temp dir using the production
+// Archive() function, so its format is always compatible with Unarchive().
+func buildSmallTarZst(t *testing.T, entries map[string]string) string {
+ t.Helper()
+ srcDir := t.TempDir()
+ for name, body := range entries {
+ full := filepath.Join(srcDir, name)
+ if err := os.MkdirAll(filepath.Dir(full), 0o755); err != nil {
+ t.Fatal(err)
+ }
+ if err := os.WriteFile(full, []byte(body), 0o644); err != nil {
+ t.Fatal(err)
+ }
+ }
+ out := filepath.Join(t.TempDir(), "test.tar.zst")
+ ctx := limitsTestContext(t)
+ if err := Archive(ctx, srcDir, out, compressionZstd, archivalTar); err != nil {
+ t.Fatalf("Archive(): %v", err)
+ }
+ return out
+}
+
+// TestUnarchive_PerFileByteCap verifies that a single file exceeding
+// maxArchiveFileBytes is rejected during extraction.
+func TestUnarchive_PerFileByteCap(t *testing.T) {
+ body := make([]byte, 512)
+ archive := buildSmallTarZst(t, map[string]string{"big.bin": string(body)})
+
+ dst := t.TempDir()
+ ctx := limitsTestContext(t)
+ limits := extractionLimits{
+ maxFileBytes: 256, // smaller than the 512-byte file
+ maxTotalBytes: 100 << 30,
+ maxFiles: 100_000,
+ }
+ if err := unarchiveWithLimits(ctx, archive, dst, limits); err == nil {
+ t.Fatal("unarchiveWithLimits() expected error for per-file cap, got nil")
+ }
+}
+
+// TestUnarchive_AggregateByteCap verifies that the total extracted bytes across
+// all files cannot exceed maxTotalBytes.
+func TestUnarchive_AggregateByteCap(t *testing.T) {
+ // Two files each 256 bytes — aggregate cap set to 400, which is exceeded.
+ body := make([]byte, 256)
+ archive := buildSmallTarZst(t, map[string]string{
+ "a.bin": string(body),
+ "b.bin": string(body),
+ })
+
+ dst := t.TempDir()
+ ctx := limitsTestContext(t)
+ limits := extractionLimits{
+ maxFileBytes: 100 << 30, // per-file: no constraint
+ maxTotalBytes: 400, // aggregate: 256+256 = 512 > 400
+ maxFiles: 100_000,
+ }
+ if err := unarchiveWithLimits(ctx, archive, dst, limits); err == nil {
+ t.Fatal("unarchiveWithLimits() expected error for aggregate cap, got nil")
+ }
+}
+
+// TestUnarchive_FileCountCap verifies that an archive with more entries than
+// maxFiles is rejected.
+func TestUnarchive_FileCountCap(t *testing.T) {
+ entries := make(map[string]string, 6)
+ for i := 0; i < 6; i++ {
+ entries[filepath.Join("subdir", string(rune('a'+i))+".txt")] = "x"
+ }
+ archive := buildSmallTarZst(t, entries)
+
+ dst := t.TempDir()
+ ctx := limitsTestContext(t)
+ limits := extractionLimits{
+ maxFileBytes: 100 << 30,
+ maxTotalBytes: 100 << 30,
+ maxFiles: 3, // only 3 allowed; archive has 6
+ }
+ if err := unarchiveWithLimits(ctx, archive, dst, limits); err == nil {
+ t.Fatal("unarchiveWithLimits() expected error for file-count cap, got nil")
+ }
+}
+
+// TestUnarchive_WithinLimits confirms the default path succeeds for normal archives.
+func TestUnarchive_WithinLimits(t *testing.T) {
+ archive := buildSmallTarZst(t, map[string]string{
+ "file1.txt": "hello",
+ "file2.txt": "world",
+ })
+ dst := t.TempDir()
+ ctx := limitsTestContext(t)
+ if err := Unarchive(ctx, archive, dst); err != nil {
+ t.Fatalf("Unarchive() unexpected error: %v", err)
+ }
+}
+
+// TestUnarchive_DecompressionRatioCap verifies that an archive whose
+// decompressed:compressed ratio exceeds consts.MaxDecompressionRatio is
+// rejected during extraction. Highly redundant content (10 MiB of zero bytes)
+// compresses with zstd to a few KiB, well above the 100x default ratio.
+func TestUnarchive_DecompressionRatioCap(t *testing.T) {
+ body := make([]byte, 10<<20) // 10 MiB of zero bytes
+ archive := buildSmallTarZst(t, map[string]string{"compressible.bin": string(body)})
+
+ stat, err := os.Stat(archive)
+ if err != nil {
+ t.Fatalf("stat archive: %v", err)
+ }
+ t.Logf("archive size: %d bytes (decompressed: %d bytes, ratio: %.0fx)",
+ stat.Size(), len(body), float64(len(body))/float64(stat.Size()))
+
+ dst := t.TempDir()
+ ctx := limitsTestContext(t)
+ limits := extractionLimits{
+ maxFileBytes: 100 << 30, // not the constraint
+ maxTotalBytes: 100 << 30, // not the constraint
+ maxFiles: 100_000,
+ }
+ err = unarchiveWithLimits(ctx, archive, dst, limits)
+ if err == nil {
+ t.Fatal("unarchiveWithLimits() expected error for decompression ratio, got nil")
+ }
+ if !strings.Contains(err.Error(), "decompression ratio") {
+ t.Errorf("expected error mentioning decompression ratio, got: %v", err)
+ }
+}
diff --git a/pkg/archives/unarchiver.go b/pkg/archives/unarchiver.go
index 7678f265..590ddb3e 100644
--- a/pkg/archives/unarchiver.go
+++ b/pkg/archives/unarchiver.go
@@ -10,11 +10,21 @@ import (
"sort"
"strconv"
"strings"
+ "sync/atomic"
"github.com/mholt/archives"
+ "hauler.dev/go/hauler/pkg/consts"
"hauler.dev/go/hauler/pkg/log"
)
+// extractionLimits holds per-extraction caps to prevent decompression bombs
+// and file-count exhaustion attacks.
+type extractionLimits struct {
+ maxFileBytes int64 // maximum bytes for a single extracted file
+ maxTotalBytes int64 // maximum aggregate bytes across all extracted files
+ maxFiles int64 // maximum number of files extracted
+}
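+
+// A sketch of tightened caps (values illustrative; the archive path is
+// hypothetical), as exercised by the tests in limits_test.go:
+//
+//	lim := extractionLimits{
+//		maxFileBytes:  1 << 20,  // 1 MiB per file
+//		maxTotalBytes: 10 << 20, // 10 MiB aggregate
+//		maxFiles:      1000,
+//	}
+//	err := unarchiveWithLimits(ctx, "haul.tar.zst", dst, lim)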
+
const (
dirPermissions = 0o700 // default directory permissions
filePermissions = 0o600 // default file permissions
@@ -51,8 +61,17 @@ func setPermissions(path string, mode os.FileMode) error {
return nil
}
+// extractionState tracks mutable counters shared across handleFile calls.
+// archiveSize is the on-disk size of the input archive; when non-zero it is
+// used to enforce a decompression-ratio bomb check against totalBytes.
+type extractionState struct {
+ totalBytes atomic.Int64
+ fileCount atomic.Int64
+ archiveSize int64
+}
+
// handles the extraction of a file from the archive.
-func handleFile(ctx context.Context, f archives.FileInfo, dst string) error {
+func handleFileWithLimits(ctx context.Context, f archives.FileInfo, dst string, lim extractionLimits, state *extractionState) error {
l := log.FromContext(ctx)
l.Debugf("handling file [%s]", f.NameInArchive)
@@ -70,7 +89,6 @@ func handleFile(ctx context.Context, f archives.FileInfo, dst string) error {
// handle directories
if f.IsDir() {
- // create the directory with permissions from the archive
if dirErr := createDirWithPermissions(ctx, dstPath, f.Mode()); dirErr != nil {
return fmt.Errorf("failed to create directory: %w", dirErr)
}
@@ -84,6 +102,14 @@ func handleFile(ctx context.Context, f archives.FileInfo, dst string) error {
return nil
}
+ // enforce file count cap
+ if state != nil {
+ count := state.fileCount.Add(1)
+ if count > lim.maxFiles {
+ return fmt.Errorf("archive contains more than %d files (extraction limit exceeded)", lim.maxFiles)
+ }
+ }
+
// check and handle parent directory permissions
originalMode, statErr := os.Stat(parentDir)
if statErr != nil {
@@ -97,7 +123,6 @@ func handleFile(ctx context.Context, f archives.FileInfo, dst string) error {
return fmt.Errorf("failed to chmod parent directory: %w", chmodErr)
}
defer func() {
- // restore the original permissions after writing
if chmodErr := os.Chmod(parentDir, originalMode.Mode()); chmodErr != nil {
l.Debugf("failed to restore original permissions for [%s]: %v", parentDir, chmodErr)
}
@@ -117,13 +142,72 @@ func handleFile(ctx context.Context, f archives.FileInfo, dst string) error {
}
defer dstFile.Close()
- if _, copyErr := io.Copy(dstFile, reader); copyErr != nil {
+ // copy with per-file and aggregate byte caps
+ var totalPtr *atomic.Int64
+ var archiveSize int64
+ if state != nil {
+ totalPtr = &state.totalBytes
+ archiveSize = state.archiveSize
+ }
+ if _, copyErr := copyBounded(dstFile, reader, lim.maxFileBytes, lim.maxTotalBytes, totalPtr, archiveSize); copyErr != nil {
return fmt.Errorf("failed to copy: %w", copyErr)
}
l.Debugf("successfully extracted file [%s]", dstPath)
return nil
}
+// copyBounded copies from src to dst, enforcing a per-file cap (maxFile) and
+// updating a shared total counter checked against maxTotal. total may be nil
+// when called from the default (non-tracked) code path. When archiveSize > 0,
+// it also enforces a decompression-ratio cap against consts.MaxDecompressionRatio.
+func copyBounded(dst io.Writer, src io.Reader, maxFile, maxTotal int64, total *atomic.Int64, archiveSize int64) (int64, error) {
+ buf := make([]byte, 32*1024)
+ var fileWritten int64
+ for {
+ nr, readErr := src.Read(buf)
+ if nr > 0 {
+ fileWritten += int64(nr)
+ if fileWritten > maxFile {
+ return fileWritten, fmt.Errorf("extracted file exceeds per-file size limit of %d bytes", maxFile)
+ }
+ if total != nil {
+ agg := total.Add(int64(nr))
+ if agg > maxTotal {
+ return fileWritten, fmt.Errorf("total extracted bytes exceed aggregate limit of %d bytes", maxTotal)
+ }
+ if archiveSize > 0 && float64(agg) > float64(archiveSize)*consts.MaxDecompressionRatio {
+ return fileWritten, fmt.Errorf("decompression ratio exceeds %vx (likely zip bomb)", consts.MaxDecompressionRatio)
+ }
+ }
+ if _, writeErr := dst.Write(buf[:nr]); writeErr != nil {
+ return fileWritten, writeErr
+ }
+ }
+ if readErr == io.EOF {
+ break
+ }
+ if readErr != nil {
+ return fileWritten, readErr
+ }
+ }
+ return fileWritten, nil
+}
+
+// handleFile delegates to handleFileWithLimits using the default (production) limits.
+// It exists so the Unarchive function signature remains unchanged.
+func handleFile(ctx context.Context, f archives.FileInfo, dst string, state *extractionState) error {
+ // The production limits are extremely generous and should never fire on
+ // legitimate hauler workloads; they exist to bound pathological inputs.
+ defaultLimits := extractionLimits{
+ maxFileBytes: consts.MaxArchiveFileBytes,
+ maxTotalBytes: consts.MaxArchiveBytes,
+ maxFiles: consts.MaxArchiveFiles,
+ }
+ return handleFileWithLimits(ctx, f, dst, defaultLimits, state)
+}
+
// unarchives a tarball to a directory, symlinks, and hardlinks are ignored
func Unarchive(ctx context.Context, tarball, dst string) error {
l := log.FromContext(ctx)
@@ -134,6 +218,11 @@ func Unarchive(ctx context.Context, tarball, dst string) error {
}
defer archiveFile.Close()
+ stat, statErr := archiveFile.Stat()
+ if statErr != nil {
+ return fmt.Errorf("failed to stat tarball %s: %w", tarball, statErr)
+ }
+
format, input, identifyErr := archives.Identify(context.Background(), tarball, archiveFile)
if identifyErr != nil {
return fmt.Errorf("failed to identify format: %w", identifyErr)
@@ -148,8 +237,9 @@ func Unarchive(ctx context.Context, tarball, dst string) error {
return fmt.Errorf("failed to create destination directory: %w", dirErr)
}
+ state := &extractionState{archiveSize: stat.Size()}
handler := func(ctx context.Context, f archives.FileInfo) error {
- return handleFile(ctx, f, dst)
+ return handleFile(ctx, f, dst, state)
}
if extractErr := extractor.Extract(context.Background(), input, handler); extractErr != nil {
@@ -160,6 +250,48 @@ func Unarchive(ctx context.Context, tarball, dst string) error {
return nil
}
+// unarchiveWithLimits is like Unarchive but enforces explicit extraction limits.
+// It is used directly in tests to verify cap enforcement with small values.
+func unarchiveWithLimits(ctx context.Context, tarball, dst string, lim extractionLimits) error {
+ l := log.FromContext(ctx)
+ l.Debugf("unarchiving [%s] to [%s] with limits", tarball, dst)
+
+ archiveFile, openErr := os.Open(tarball)
+ if openErr != nil {
+ return fmt.Errorf("failed to open tarball %s: %w", tarball, openErr)
+ }
+ defer archiveFile.Close()
+
+ stat, statErr := archiveFile.Stat()
+ if statErr != nil {
+ return fmt.Errorf("failed to stat tarball %s: %w", tarball, statErr)
+ }
+
+ format, input, identifyErr := archives.Identify(context.Background(), tarball, archiveFile)
+ if identifyErr != nil {
+ return fmt.Errorf("failed to identify format: %w", identifyErr)
+ }
+
+ extractor, ok := format.(archives.Extractor)
+ if !ok {
+ return fmt.Errorf("unsupported format for extraction")
+ }
+
+ if dirErr := createDirWithPermissions(ctx, dst, dirPermissions); dirErr != nil {
+ return fmt.Errorf("failed to create destination directory: %w", dirErr)
+ }
+
+ state := &extractionState{archiveSize: stat.Size()}
+ handler := func(ctx context.Context, f archives.FileInfo) error {
+ return handleFileWithLimits(ctx, f, dst, lim, state)
+ }
+
+ if extractErr := extractor.Extract(context.Background(), input, handler); extractErr != nil {
+ return fmt.Errorf("failed to extract: %w", extractErr)
+ }
+ return nil
+}
+
var chunkSuffixRe = regexp.MustCompile(`^(.+)_(\d+)$`)
// chunkInfo checks whether archivePath matches the chunk naming pattern (_N).
@@ -226,17 +358,27 @@ func JoinChunks(ctx context.Context, archivePath, tempDir string) (string, error
}
defer outf.Close()
+ var joinedTotal int64
for _, chunk := range matches {
l.Debugf("joining chunk [%s]", chunk)
cf, err := os.Open(chunk)
if err != nil {
return "", fmt.Errorf("failed to open chunk [%s]: %w", chunk, err)
}
- if _, err := io.Copy(outf, cf); err != nil {
+ remaining := consts.MaxArchiveBytes - joinedTotal
+ if remaining <= 0 {
cf.Close()
- return "", fmt.Errorf("failed to copy chunk [%s]: %w", chunk, err)
+ return "", fmt.Errorf("joined chunks exceed maximum allowed size (%d bytes)", consts.MaxArchiveBytes)
}
+ n, err := io.Copy(outf, io.LimitReader(cf, remaining+1))
cf.Close()
+ if err != nil {
+ return "", fmt.Errorf("failed to copy chunk [%s]: %w", chunk, err)
+ }
+ joinedTotal += n
+ if joinedTotal > consts.MaxArchiveBytes {
+ return "", fmt.Errorf("joined chunks exceed maximum allowed size (%d bytes)", consts.MaxArchiveBytes)
+ }
}
l.Infof("joined %d chunk(s) into [%s]", len(matches), filepath.Base(joinedPath))
diff --git a/pkg/artifacts/file/file_test.go b/pkg/artifacts/file/file_test.go
index 062b80b1..3ff5e91a 100644
--- a/pkg/artifacts/file/file_test.go
+++ b/pkg/artifacts/file/file_test.go
@@ -127,7 +127,7 @@ func setup() func() {
mf := &mockFile{File: getter.NewFile(), fs: tfs}
- mockHttp := getter.NewHttp()
+ mockHttp := getter.NewHttpWithOptions(getter.HttpOptions{AllowInternalTargets: true})
mhttp := afero.NewHttpFs(tfs)
fileserver := http.FileServer(mhttp.Dir("."))
http.Handle("/", fileserver)
diff --git a/pkg/consts/limits.go b/pkg/consts/limits.go
new file mode 100644
index 00000000..9e6fd403
--- /dev/null
+++ b/pkg/consts/limits.go
@@ -0,0 +1,34 @@
+package consts
+
+const (
+ // MaxDownloadBytes caps HTTP response bodies fetched by the HTTP getter.
+ // 10 GiB is deliberately generous for large hauler archives while still
+ // bounding runaway downloads.
+ MaxDownloadBytes int64 = 10 << 30 // 10 GiB
+
+ // MaxManifestBytes caps OCI manifest and index reads to prevent a hostile
+ // registry from exhausting process memory.
+ MaxManifestBytes int64 = 16 << 20 // 16 MiB
+
+ // MaxArchiveBytes caps the total uncompressed bytes written during archive
+ // extraction. Set to 100 GiB to comfortably exceed real-world haul sizes
+ // while still bounding zip-bomb attacks.
+ MaxArchiveBytes int64 = 100 << 30 // 100 GiB
+
+ // MaxArchiveFileBytes caps the uncompressed size of a single file inside an
+ // archive.
+ MaxArchiveFileBytes int64 = 50 << 30 // 50 GiB
+
+ // MaxArchiveFiles caps the number of files that may be extracted from a
+ // single archive.
+ MaxArchiveFiles int64 = 100_000
+
+ // MaxDecompressionRatio is the maximum allowed ratio of decompressed to
+ // compressed bytes. Archives exceeding this ratio are likely zip bombs.
+ MaxDecompressionRatio float64 = 100.0
+
+ // HTTPClientTimeout is the default timeout for outbound HTTP requests in
+ // the HTTP getter. Set to 30 minutes to handle large archive downloads
+ // without hanging indefinitely.
+ HTTPClientTimeout = 30 * 60 // seconds; converted to a time.Duration at the use site
+)
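+
+// Bounded-read sketch using these limits (this pattern recurs at call sites
+// such as pkg/store/store.go):
+//
+//	data, err := io.ReadAll(io.LimitReader(rc, consts.MaxManifestBytes+1))
+//	if err == nil && int64(len(data)) > consts.MaxManifestBytes {
+//		err = fmt.Errorf("manifest exceeds maximum allowed size (%d bytes)", consts.MaxManifestBytes)
+//	}
+//
+// HTTPClientTimeout is an integer second count; convert it where used:
+//
+//	timeout := time.Duration(consts.HTTPClientTimeout) * time.Second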
diff --git a/pkg/content/chart/chart.go b/pkg/content/chart/chart.go
index 5f86b125..d0f2acd0 100644
--- a/pkg/content/chart/chart.go
+++ b/pkg/content/chart/chart.go
@@ -49,6 +49,14 @@ func NewChart(name string, opts *action.ChartPathOptions) (*Chart, error) {
client := action.NewInstall(actionConfig)
client.ChartPathOptions.Version = opts.Version
+ client.ChartPathOptions.Verify = opts.Verify
+ client.ChartPathOptions.Username = opts.Username
+ client.ChartPathOptions.Password = opts.Password
+ client.ChartPathOptions.CertFile = opts.CertFile
+ client.ChartPathOptions.KeyFile = opts.KeyFile
+ client.ChartPathOptions.CaFile = opts.CaFile
+ client.ChartPathOptions.InsecureSkipTLSverify = opts.InsecureSkipTLSverify
+ client.ChartPathOptions.PlainHTTP = opts.PlainHTTP
registryClient, err := newRegistryClient(client.CertFile, client.KeyFile, client.CaFile,
client.InsecureSkipTLSverify, client.PlainHTTP)
@@ -65,27 +73,17 @@ func NewChart(name string, opts *action.ChartPathOptions) (*Chart, error) {
chartRef = opts.RepoURL + "/" + name
}
- // suppress helm downloader oci logs (stdout/stderr)
- oldStdout := os.Stdout
- oldStderr := os.Stderr
- rOut, wOut, _ := os.Pipe()
- rErr, wErr, _ := os.Pipe()
- os.Stdout = wOut
- os.Stderr = wErr
-
- chartPath, err := client.ChartPathOptions.LocateChart(chartRef, settings)
-
- wOut.Close()
- wErr.Close()
- os.Stdout = oldStdout
- os.Stderr = oldStderr
- _, _ = io.Copy(io.Discard, rOut)
- _, _ = io.Copy(io.Discard, rErr)
- rOut.Close()
- rErr.Close()
-
- if err != nil {
- return nil, err
+ // Suppress Helm downloader OCI logs via the goroutine-safe CaptureOutput.
+ // The mutex in CaptureOutput serializes all callers so concurrent chart
+ // fetches cannot race on os.Stdout/os.Stderr.
+ var chartPath string
+ captureErr := log.CaptureOutput(log.NewLogger(io.Discard), true, func() error {
+ var locErr error
+ chartPath, locErr = client.ChartPathOptions.LocateChart(chartRef, settings)
+ return locErr
+ })
+ if captureErr != nil {
+ return nil, captureErr
}
return &Chart{
diff --git a/pkg/content/chart/chart_test.go b/pkg/content/chart/chart_test.go
index 94133e92..0173e0d0 100644
--- a/pkg/content/chart/chart_test.go
+++ b/pkg/content/chart/chart_test.go
@@ -13,6 +13,40 @@ import (
"hauler.dev/go/hauler/pkg/content/chart"
)
+// TestNewChart_VerifyAndAuthPropagated verifies that --verify and auth/TLS options
+// in action.ChartPathOptions are actually wired through to the Helm client.
+// With Verify=true the Helm client must reject a chart that has no .prov file.
+func TestNewChart_VerifyAndAuthPropagated(t *testing.T) {
+ t.Run("verify flag causes failure on unsigned chart", func(t *testing.T) {
+ opts := &action.ChartPathOptions{
+ RepoURL: "../../../testdata",
+ Verify: true,
+ }
+ _, err := chart.NewChart("rancher-cluster-templates-0.5.2.tgz", opts)
+ if err == nil {
+ t.Fatal("NewChart() expected error with Verify=true on unsigned chart, got nil")
+ }
+ })
+
+ t.Run("credentials are propagated and do not break local chart load", func(t *testing.T) {
+ // Credentials are passed but local chart loading does not require auth.
+ // This test ensures setting Username/Password does not silently break
+ // the happy path (i.e. they are stored, not discarded).
+ opts := &action.ChartPathOptions{
+ RepoURL: "../../../testdata",
+ Username: "user",
+ Password: "pass",
+ }
+ c, err := chart.NewChart("rancher-cluster-templates-0.5.2.tgz", opts)
+ if err != nil {
+ t.Fatalf("NewChart() unexpected error: %v", err)
+ }
+ if c == nil {
+ t.Fatal("NewChart() returned nil chart")
+ }
+ })
+}
+
func TestNewChart(t *testing.T) {
tempDir, err := os.MkdirTemp("", "hauler")
if err != nil {
diff --git a/pkg/getter/getter.go b/pkg/getter/getter.go
index c022c212..7206ab6a 100644
--- a/pkg/getter/getter.go
+++ b/pkg/getter/getter.go
@@ -23,7 +23,8 @@ type Client struct {
// ClientOptions provides options for the client
type ClientOptions struct {
- NameOverride string
+ NameOverride string
+ AllowInternalTargets bool
}
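+
+// Example (illustrative; the NameOverride value is arbitrary): a client for
+// tests that intentionally fetch from loopback hosts.
+//
+//	c := NewClient(ClientOptions{NameOverride: "custom.sh", AllowInternalTargets: true})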
var (
@@ -44,7 +45,7 @@ func NewClient(opts ClientOptions) *Client {
defaults := map[string]Getter{
"file": NewFile(),
"directory": NewDirectory(),
- "http": NewHttp(),
+ "http": NewHttpWithOptions(HttpOptions{AllowInternalTargets: opts.AllowInternalTargets}),
}
c := &Client{
diff --git a/pkg/getter/https.go b/pkg/getter/https.go
index 4bd4d23e..e408f9e3 100644
--- a/pkg/getter/https.go
+++ b/pkg/getter/https.go
@@ -5,23 +5,190 @@ import (
"fmt"
"io"
"mime"
+ "net"
"net/http"
"net/url"
"path/filepath"
"strings"
+ "time"
"hauler.dev/go/hauler/pkg/artifacts"
"hauler.dev/go/hauler/pkg/consts"
)
-type Http struct{}
+// dialTimeout is the TCP connect and keep-alive timeout used by safeDial.
+const dialTimeout = 30 * time.Second
+
+// HttpOptions configures the behaviour of the Http getter.
+type HttpOptions struct {
+ // AllowInternalTargets disables the SSRF guard that is enforced at dial
+ // time by the custom DialContext. When false (the default), every IP
+ // address returned by DNS resolution is validated against isInternalIP
+ // before any connection is attempted, and the connection is made to the
+ // resolved IP literal directly so the check and the connect target the
+ // same address. Set to true only for isolated internal CI environments
+ // that intentionally fetch from private or loopback hosts.
+ AllowInternalTargets bool
+
+ // Timeout overrides the default HTTP client timeout.
+ Timeout time.Duration
+
+ // MaxBytes overrides the default per-response download cap. Zero means
+ // use consts.MaxDownloadBytes.
+ MaxBytes int64
+}
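+
+// A minimal usage sketch (URL and timeout are illustrative):
+//
+//	h := NewHttpWithOptions(HttpOptions{Timeout: 5 * time.Minute})
+//	u, _ := url.Parse("https://example.com/artifact.tar.zst")
+//	rc, err := h.Open(ctx, u)
+//	if err != nil {
+//		return err
+//	}
+//	defer rc.Close()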
+
+// Http is the Getter for http/https URLs.
+type Http struct {
+ opts HttpOptions
+ client *http.Client
+ maxBytes int64
+}
func NewHttp() *Http {
- return &Http{}
+ return NewHttpWithOptions(HttpOptions{})
+}
+
+func NewHttpWithTimeout(d time.Duration) *Http {
+ return NewHttpWithOptions(HttpOptions{Timeout: d})
+}
+
+func NewHttpWithOptions(opts HttpOptions) *Http {
+ timeout := time.Duration(consts.HTTPClientTimeout) * time.Second
+ if opts.Timeout > 0 {
+ timeout = opts.Timeout
+ }
+
+ maxBytes := opts.MaxBytes
+ if maxBytes <= 0 {
+ maxBytes = consts.MaxDownloadBytes
+ }
+
+ h := &Http{opts: opts, maxBytes: maxBytes}
+
+ baseDialer := &net.Dialer{
+ Timeout: dialTimeout,
+ KeepAlive: dialTimeout,
+ }
+ transport := &http.Transport{
+ DialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
+ return h.safeDial(ctx, baseDialer, network, address)
+ },
+ TLSHandshakeTimeout: 10 * time.Second,
+ ResponseHeaderTimeout: 30 * time.Second,
+ ExpectContinueTimeout: 1 * time.Second,
+ MaxIdleConns: 10,
+ IdleConnTimeout: 90 * time.Second,
+ // Do NOT set TLSClientConfig: Go derives tls.Config.ServerName from
+ // the request URL hostname, so TLS cert verification continues to use
+ // the hostname even though we dial by IP literal.
+ }
+
+ h.client = &http.Client{
+ Timeout: timeout,
+ Transport: transport,
+ CheckRedirect: func(req *http.Request, via []*http.Request) error {
+ return h.validateRequest(req)
+ },
+ }
+ return h
+}
+
+// validateRequest enforces scheme restrictions. It is called for the initial
+// request and each redirect hop via CheckRedirect. IP/host validation is
+// performed at dial time by safeDial so the checked address is exactly the
+// address we connect to, eliminating the DNS-rebinding TOCTOU.
+func (h *Http) validateRequest(req *http.Request) error {
+ switch req.URL.Scheme {
+ case "http", "https":
+ default:
+ return fmt.Errorf("scheme %q is not allowed; only http and https are permitted", req.URL.Scheme)
+ }
+ return nil
+}
+
+// safeDial resolves address to candidate IPs, rejects internal IPs (when
+// AllowInternalTargets=false), and dials the resolved IP literal directly.
+// Performing both the IP check and the connect against the same resolved
+// address eliminates the DNS-rebinding TOCTOU that exists when validation
+// and connect each perform their own independent resolution.
+func (h *Http) safeDial(ctx context.Context, dialer *net.Dialer, network, address string) (net.Conn, error) {
+ host, port, err := net.SplitHostPort(address)
+ if err != nil {
+ return nil, err
+ }
+
+ ips, err := net.DefaultResolver.LookupIPAddr(ctx, host)
+ if err != nil {
+ return nil, fmt.Errorf("failed to resolve %s: %w", host, err)
+ }
+ if len(ips) == 0 {
+ return nil, fmt.Errorf("no addresses found for %s", host)
+ }
+
+ if !h.opts.AllowInternalTargets {
+ // Reject if ANY candidate is internal — prevents an attacker from
+ // returning [public, private] and hoping fallback hits the private IP.
+ for _, ipAddr := range ips {
+ if isInternalIP(ipAddr.IP) {
+ return nil, fmt.Errorf("dial to %s rejected: resolved to internal address %s (use --allow-internal-targets to override)", host, ipAddr.IP)
+ }
+ }
+ }
+
+ // Dial each candidate by IP literal until one succeeds.
+ var lastErr error
+ for _, ipAddr := range ips {
+ // net.JoinHostPort brackets IPv6 literals itself, so pass the bare IP
+ // string; pre-bracketing would produce a malformed [[::1]]:port address.
+ conn, dialErr := dialer.DialContext(ctx, network, net.JoinHostPort(ipAddr.IP.String(), port))
+ if dialErr == nil {
+ return conn, nil
+ }
+ lastErr = dialErr
+ }
+ return nil, lastErr
+}
+
+// isInternalIP reports whether ip is in a private, loopback, link-local, or
+// unique-local range.
+func isInternalIP(ip net.IP) bool {
+ private := []net.IPNet{
+ // IPv4 loopback
+ {IP: net.IP{127, 0, 0, 0}, Mask: net.CIDRMask(8, 32)},
+ // RFC-1918 private ranges
+ {IP: net.IP{10, 0, 0, 0}, Mask: net.CIDRMask(8, 32)},
+ {IP: net.IP{172, 16, 0, 0}, Mask: net.CIDRMask(12, 32)},
+ {IP: net.IP{192, 168, 0, 0}, Mask: net.CIDRMask(16, 32)},
+ // IPv4 link-local
+ {IP: net.IP{169, 254, 0, 0}, Mask: net.CIDRMask(16, 32)},
+ }
+ if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
+ return true
+ }
+ for i := range private {
+ if private[i].Contains(ip) {
+ return true
+ }
+ }
+ // IPv6 unique-local (fc00::/7)
+ if len(ip) == 16 && (ip[0]&0xfe) == 0xfc {
+ return true
+ }
+ return false
}
func (h Http) Name(u *url.URL) string {
- resp, err := http.Head(u.String())
+ req, err := http.NewRequest(http.MethodHead, u.String(), nil)
+ if err != nil {
+ return ""
+ }
+ if err := h.validateRequest(req); err != nil {
+ return ""
+ }
+ resp, err := h.client.Do(req)
if err != nil {
return ""
}
@@ -46,7 +213,15 @@ func (h Http) Name(u *url.URL) string {
}
func (h Http) Open(ctx context.Context, u *url.URL) (io.ReadCloser, error) {
- resp, err := http.Get(u.String())
+ req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil)
+ if err != nil {
+ return nil, err
+ }
+ if err := h.validateRequest(req); err != nil {
+ return nil, err
+ }
+
+ resp, err := h.client.Do(req)
if err != nil {
return nil, err
}
@@ -54,7 +229,22 @@ func (h Http) Open(ctx context.Context, u *url.URL) (io.ReadCloser, error) {
resp.Body.Close()
return nil, fmt.Errorf("unexpected status fetching %s: %s", u.String(), resp.Status)
}
- return resp.Body, nil
+
+ // Validate Content-Length header if the server provided one.
+ if resp.ContentLength > h.maxBytes {
+ resp.Body.Close()
+ return nil, fmt.Errorf("remote file at %s exceeds maximum allowed size (%d bytes)", u.String(), h.maxBytes)
+ }
+
+ // Wrap the body so a server that lies about Content-Length (or omits it)
+ // is still bounded. Read maxBytes+1 so we can distinguish "exactly at cap"
+ // from "exceeded cap".
+ limited := &limitedReadCloser{
+ r: io.LimitReader(resp.Body, h.maxBytes+1),
+ c: resp.Body,
+ max: h.maxBytes,
+ }
+ return limited, nil
}
func (h Http) Detect(u *url.URL) bool {
@@ -72,6 +262,28 @@ func (h *Http) Config(u *url.URL) artifacts.Config {
return artifacts.ToConfig(c, artifacts.WithConfigMediaType(consts.FileHttpConfigMediaType))
}
+// limitedReadCloser wraps an io.LimitReader and returns an error when the cap
+// is hit rather than silently returning EOF.
+type limitedReadCloser struct {
+ r io.Reader
+ c io.Closer
+ max int64
+ read int64
+}
+
+func (l *limitedReadCloser) Read(p []byte) (int, error) {
+ n, err := l.r.Read(p)
+ l.read += int64(n)
+ if l.read > l.max {
+ return n, fmt.Errorf("download exceeded maximum allowed size of %d bytes", l.max)
+ }
+ return n, err
+}
+
+func (l *limitedReadCloser) Close() error {
+ return l.c.Close()
+}
+
type httpConfig struct {
config `json:",inline,omitempty"`
}
diff --git a/pkg/getter/https_security_test.go b/pkg/getter/https_security_test.go
new file mode 100644
index 00000000..3c212a82
--- /dev/null
+++ b/pkg/getter/https_security_test.go
@@ -0,0 +1,198 @@
+package getter_test
+
+import (
+ "context"
+ "fmt"
+ "io"
+ "net/http"
+ "net/http/httptest"
+ "net/url"
+ "strings"
+ "testing"
+ "time"
+
+ "hauler.dev/go/hauler/pkg/getter"
+)
+
+// --- A3: Unbounded download protection ---
+
+// TestHttp_Open_RejectsOversizedBody verifies that Open wraps the response body
+// in an io.LimitReader so a server that streams more than MaxBytes causes an
+// error rather than exhausting disk/memory.
+func TestHttp_Open_RejectsOversizedBody(t *testing.T) {
+ const bodyCap int64 = 1024 // 1 KiB test cap (named to avoid shadowing the cap builtin)
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Stream bodyCap+1 bytes so the limiter fires.
+ payload := strings.Repeat("x", int(bodyCap)+1)
+ fmt.Fprint(w, payload)
+ }))
+ defer srv.Close()
+
+ // AllowInternalTargets=true because the test server binds to loopback.
+ h := getter.NewHttpWithOptions(getter.HttpOptions{
+ AllowInternalTargets: true,
+ MaxBytes: bodyCap,
+ })
+ u, _ := url.Parse(srv.URL + "/big")
+ // The size cap must be enforced either at Open() (via Content-Length header)
+ // or at read time (via LimitReader). Both are acceptable.
+ rc, openErr := h.Open(context.Background(), u)
+ if openErr != nil {
+ // Content-Length header triggered the cap early — that is correct.
+ return
+ }
+ defer rc.Close()
+
+ _, readErr := io.ReadAll(rc)
+ if readErr == nil {
+ t.Fatal("expected an error from Open() or ReadAll() for oversized body, got neither")
+ }
+}
+
+// TestHttp_Open_Timeout verifies that Open uses a client with a timeout so
+// Slowloris-style servers do not hang indefinitely.
+func TestHttp_Open_Timeout(t *testing.T) {
+ if testing.Short() {
+ t.Skip("skipping slow-server test in short mode")
+ }
+
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ // Never write anything — simulate a stalled server.
+ <-r.Context().Done()
+ }))
+ defer srv.Close()
+
+ h := getter.NewHttpWithOptions(getter.HttpOptions{
+ AllowInternalTargets: true,
+ Timeout: 50 * time.Millisecond,
+ })
+ u, _ := url.Parse(srv.URL + "/slow")
+ _, err := h.Open(context.Background(), u)
+ if err == nil {
+ t.Fatal("Open() expected timeout error, got nil")
+ }
+}
+
+// --- A4: SSRF protection ---
+
+// TestHttp_Open_RejectsNonHTTPScheme verifies that file://, gopher://, etc.
+// are rejected before any network call is made.
+func TestHttp_Open_RejectsNonHTTPScheme(t *testing.T) {
+ for _, scheme := range []string{"file", "gopher", "ftp", "data"} {
+ t.Run(scheme, func(t *testing.T) {
+ h := getter.NewHttp()
+ u, _ := url.Parse(scheme + "://some/path")
+ _, err := h.Open(context.Background(), u)
+ if err == nil {
+ t.Fatalf("Open() expected error for scheme %q, got nil", scheme)
+ }
+ })
+ }
+}
+
+// TestHttp_Open_RejectsPrivateIPByDefault verifies that private/loopback
+// addresses are blocked unless AllowInternalTargets is set.
+func TestHttp_Open_RejectsPrivateIPByDefault(t *testing.T) {
+ // Use a real local server so the DNS resolution step completes.
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "secret")
+ }))
+ defer srv.Close()
+
+ // srv.URL is http://127.0.0.1:<port>, a loopback address.
+ h := getter.NewHttp() // default: AllowInternalTargets = false
+ u, _ := url.Parse(srv.URL + "/internal")
+ _, err := h.Open(context.Background(), u)
+ if err == nil {
+ t.Fatal("Open() expected SSRF rejection for loopback address, got nil")
+ }
+}
+
+// TestHttp_Open_AllowsPrivateIPWithFlag verifies that an explicit opt-in flag
+// lifts the private-IP restriction, enabling internal CI use cases.
+func TestHttp_Open_AllowsPrivateIPWithFlag(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "ok")
+ }))
+ defer srv.Close()
+
+ h := getter.NewHttpWithOptions(getter.HttpOptions{AllowInternalTargets: true})
+ u, _ := url.Parse(srv.URL + "/internal")
+ rc, err := h.Open(context.Background(), u)
+ if err != nil {
+ t.Fatalf("Open() unexpected error with AllowInternalTargets=true: %v", err)
+ }
+ rc.Close()
+}
+
+// TestHttp_Open_RejectsRedirectToPrivateIP verifies that a redirect chain
+// ending at a private IP is blocked. The scheme is re-checked on every hop by
+// CheckRedirect, while the private-IP rejection itself fires at dial time in
+// safeDial (both httptest servers here are loopback, so the guard trips on
+// the first hop).
+func TestHttp_Open_RejectsRedirectToPrivateIP(t *testing.T) {
+ // The "private" server just responds 200.
+ privateSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "private data")
+ }))
+ defer privateSrv.Close()
+
+ // The "public" server redirects to the private server.
+ publicSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ http.Redirect(w, r, privateSrv.URL+"/secret", http.StatusFound)
+ }))
+ defer publicSrv.Close()
+
+ h := getter.NewHttp() // default: AllowInternalTargets = false
+ u, _ := url.Parse(publicSrv.URL + "/redirect")
+ _, err := h.Open(context.Background(), u)
+ if err == nil {
+ t.Fatal("Open() expected error when redirect targets private IP, got nil")
+ }
+}
+
+// TestHttp_Open_RejectsIPLiteralPrivate verifies that URLs containing a private,
+// loopback, link-local, or IMDS IP literal are rejected at dial time without
+// any external network round-trip.
+func TestHttp_Open_RejectsIPLiteralPrivate(t *testing.T) {
+ cases := []string{
+ "http://127.0.0.1:9/anything",
+ "http://10.0.0.1:9/anything",
+ "http://192.168.0.1:9/anything",
+ "http://169.254.169.254/latest/meta",
+ }
+ h := getter.NewHttp() // default AllowInternalTargets=false
+ for _, raw := range cases {
+ t.Run(raw, func(t *testing.T) {
+ u, _ := url.Parse(raw)
+ ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ _, err := h.Open(ctx, u)
+ if err == nil {
+ t.Fatalf("Open(%s) expected SSRF rejection, got nil", raw)
+ }
+ })
+ }
+}
+
+// TestHttp_Open_RejectsHostnameResolvingToLoopback verifies that the dial-time
+// check inspects the *resolved* IP, not just IP literals. This is the
+// meaningful demonstration that DNS rebinding is closed: even when the URL
+// hostname is "localhost" (not an IP literal), the dial fires the SSRF check
+// against the resolved 127.0.0.1.
+func TestHttp_Open_RejectsHostnameResolvingToLoopback(t *testing.T) {
+ srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprint(w, "secret")
+ }))
+ defer srv.Close()
+
+ parsed, err := url.Parse(srv.URL)
+ if err != nil {
+ t.Fatalf("parse: %v", err)
+ }
+ parsed.Host = "localhost:" + parsed.Port()
+
+ h := getter.NewHttp() // default AllowInternalTargets=false
+ _, err = h.Open(context.Background(), parsed)
+ if err == nil {
+ t.Fatal("Open() expected SSRF rejection for hostname resolving to loopback, got nil")
+ }
+}
diff --git a/pkg/log/logcapture.go b/pkg/log/logcapture.go
index b36d10a5..bee9d538 100644
--- a/pkg/log/logcapture.go
+++ b/pkg/log/logcapture.go
@@ -42,8 +42,17 @@ func logStream(reader io.Reader, customWriter *CustomWriter, wg *sync.WaitGroup)
}
}
-// CaptureOutput redirects stdout and stderr to custom loggers and executes the provided function
-func CaptureOutput(logger Logger, debug bool, fn func() error) error {
+// captureOutputMu serializes all CaptureOutput calls so that concurrent
+// goroutines do not race on os.Stdout/os.Stderr.
+var captureOutputMu sync.Mutex
+
+// CaptureOutput redirects stdout and stderr to custom loggers and executes the provided function.
+// It is goroutine-safe: concurrent calls are serialized. A panic inside fn is
+// recovered, os.Stdout/os.Stderr are restored, and an error is returned.
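+//
+// Usage sketch (mirrors the call in pkg/content/chart; doNoisyWork is
+// hypothetical):
+//
+//	err := CaptureOutput(NewLogger(io.Discard), true, func() error {
+//		return doNoisyWork()
+//	})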
+func CaptureOutput(logger Logger, debug bool, fn func() error) (retErr error) {
+ captureOutputMu.Lock()
+ defer captureOutputMu.Unlock()
+
// Create pipes for capturing stdout and stderr
stdoutReader, stdoutWriter, err := os.Pipe()
if err != nil {
@@ -51,6 +60,8 @@ func CaptureOutput(logger Logger, debug bool, fn func() error) error {
}
stderrReader, stderrWriter, err := os.Pipe()
if err != nil {
+ stdoutReader.Close()
+ stdoutWriter.Close()
return fmt.Errorf("failed to create stderr pipe: %w", err)
}
@@ -62,6 +73,15 @@ func CaptureOutput(logger Logger, debug bool, fn func() error) error {
os.Stdout = stdoutWriter
os.Stderr = stderrWriter
+ // Ensure FDs are always restored — even on panic.
+ defer func() {
+ os.Stdout = origStdout
+ os.Stderr = origStderr
+ if r := recover(); r != nil {
+ retErr = fmt.Errorf("CaptureOutput: recovered from panic: %v", r)
+ }
+ }()
+
// Use WaitGroup to wait for logging goroutines to finish
var wg sync.WaitGroup
wg.Add(2)
@@ -78,15 +98,21 @@ func CaptureOutput(logger Logger, debug bool, fn func() error) error {
// Run the provided function in a separate goroutine
fnErr := make(chan error, 1)
go func() {
+ defer func() {
+ // Propagate panics from fn back to the main goroutine via the error channel.
+ if r := recover(); r != nil {
+ fnErr <- fmt.Errorf("panic: %v", r)
+ }
+ stdoutWriter.Close() // Close writers to signal EOF to readers
+ stderrWriter.Close()
+ }()
fnErr <- fn()
- stdoutWriter.Close() // Close writers to signal EOF to readers
- stderrWriter.Close()
}()
// Wait for logging goroutines to finish
wg.Wait()
- // Restore original stdout and stderr
+ // Restore stdout/stderr early so the deferred restore is a no-op.
os.Stdout = origStdout
os.Stderr = origStderr
diff --git a/pkg/log/logcapture_test.go b/pkg/log/logcapture_test.go
new file mode 100644
index 00000000..6f2cd397
--- /dev/null
+++ b/pkg/log/logcapture_test.go
@@ -0,0 +1,103 @@
+package log_test
+
+import (
+ "fmt"
+ "io"
+ "os"
+ "sync"
+ "sync/atomic"
+ "testing"
+
+ haulerlog "hauler.dev/go/hauler/pkg/log"
+)
+
+// TestCaptureOutput_ConcurrentSafe verifies that concurrent calls to
+// CaptureOutput, each writing to os.Stdout from inside fn, do not race on
+// os.Stdout/os.Stderr. The mutex serializes the reassignment window so
+// concurrent capturers cannot see each other's writes. Run with -race to
+// catch any residual data race on os.Stdout/os.Stderr.
+func TestCaptureOutput_ConcurrentSafe(t *testing.T) {
+ const goroutines = 20
+
+ l := haulerlog.NewLogger(io.Discard)
+
+ var wg sync.WaitGroup
+ errs := make(chan error, goroutines)
+
+ for i := 0; i < goroutines; i++ {
+ wg.Add(1)
+ go func(idx int) {
+ defer wg.Done()
+ err := haulerlog.CaptureOutput(l, true, func() error {
+ _, _ = fmt.Fprintf(os.Stdout, "goroutine-%d-stdout\n", idx)
+ _, _ = fmt.Fprintf(os.Stderr, "goroutine-%d-stderr\n", idx)
+ return nil
+ })
+ if err != nil {
+ errs <- err
+ }
+ }(i)
+ }
+
+ wg.Wait()
+ close(errs)
+
+ for err := range errs {
+ t.Errorf("CaptureOutput() concurrent call returned error: %v", err)
+ }
+}
+
+// TestCaptureOutput_RestoresFDsAfterUse confirms that os.Stdout/os.Stderr are
+// restored to their original values after CaptureOutput returns, so subsequent
+// writes do not target a closed pipe.
+func TestCaptureOutput_RestoresFDsAfterUse(t *testing.T) {
+ l := haulerlog.NewLogger(io.Discard)
+ origStdout := os.Stdout
+ origStderr := os.Stderr
+
+ var captured atomic.Bool
+ if err := haulerlog.CaptureOutput(l, true, func() error {
+ captured.Store(true)
+ return nil
+ }); err != nil {
+ t.Fatalf("CaptureOutput() returned error: %v", err)
+ }
+ if !captured.Load() {
+ t.Fatal("fn was not invoked")
+ }
+
+ if os.Stdout != origStdout {
+ t.Error("os.Stdout was not restored after CaptureOutput returned")
+ }
+ if os.Stderr != origStderr {
+ t.Error("os.Stderr was not restored after CaptureOutput returned")
+ }
+}
+
+// TestCaptureOutput_PanicRestoresFDs verifies that a panic inside the
+// provided function does not leave os.Stdout/os.Stderr permanently redirected
+// and that subsequent calls succeed normally.
+func TestCaptureOutput_PanicRestoresFDs(t *testing.T) {
+ l := haulerlog.NewLogger(io.Discard)
+ origStdout := os.Stdout
+ origStderr := os.Stderr
+
+ err := haulerlog.CaptureOutput(l, true, func() error {
+ panic("intentional panic for test")
+ })
+ if err == nil {
+ t.Fatal("CaptureOutput() expected error after panic, got nil")
+ }
+
+ if os.Stdout != origStdout {
+ t.Error("os.Stdout was not restored after CaptureOutput panic")
+ }
+ if os.Stderr != origStderr {
+ t.Error("os.Stderr was not restored after CaptureOutput panic")
+ }
+
+ // Run a follow-up capture to confirm the global state is healthy.
+ if err := haulerlog.CaptureOutput(l, true, func() error { return nil }); err != nil {
+ t.Errorf("subsequent CaptureOutput() failed after panic recovery: %v", err)
+ }
+}
diff --git a/pkg/store/store.go b/pkg/store/store.go
index ada5a714..68ee0768 100644
--- a/pkg/store/store.go
+++ b/pkg/store/store.go
@@ -593,10 +593,13 @@ func (l *Layout) copyDescriptorGraph(ctx context.Context, desc ocispec.Descripto
}
}()
- data, err := io.ReadAll(rc)
+ data, err := io.ReadAll(io.LimitReader(rc, consts.MaxManifestBytes+1))
if err != nil {
return fmt.Errorf("failed to read manifest: %w", err)
}
+ if int64(len(data)) > consts.MaxManifestBytes {
+ return fmt.Errorf("manifest exceeds maximum allowed size (%d bytes)", consts.MaxManifestBytes)
+ }
var manifest ocispec.Manifest
if err := json.Unmarshal(data, &manifest); err != nil {
@@ -632,10 +635,13 @@ func (l *Layout) copyDescriptorGraph(ctx context.Context, desc ocispec.Descripto
}
}()
- data, err := io.ReadAll(rc)
+ data, err := io.ReadAll(io.LimitReader(rc, consts.MaxManifestBytes+1))
if err != nil {
return fmt.Errorf("failed to read index: %w", err)
}
+ if int64(len(data)) > consts.MaxManifestBytes {
+ return fmt.Errorf("index exceeds maximum allowed size (%d bytes)", consts.MaxManifestBytes)
+ }
var index ocispec.Index
if err := json.Unmarshal(data, &index); err != nil {