diff --git a/unpack/unpack.go b/unpack/unpack.go index bfa77e8..dbdff16 100644 --- a/unpack/unpack.go +++ b/unpack/unpack.go @@ -96,6 +96,21 @@ func validSuffix(filename string) bool { return false } +// canonicalExtractDir returns an absolute, symlink-resolved path for the +// extraction destination. It must be called before os.Chdir so that a relative +// destination is resolved against the original working directory, not the +// post-Chdir one. +func canonicalExtractDir(destination string) (string, error) { + absDir, err := filepath.Abs(destination) + if err != nil { + return "", fmt.Errorf("error defining the absolute path of '%s': %s", destination, err) + } + if resolved, err := filepath.EvalSymlinks(absDir); err == nil { + return resolved, nil + } + return absDir, nil +} + func UnpackXzTar(filename string, destination string, verbosityLevel int) (err error) { Verbose = verbosityLevel if !common.FileExists(filename) { @@ -108,7 +123,11 @@ func UnpackXzTar(filename string, destination string, verbosityLevel int) (err e if err != nil { return err } - err = os.Chdir(destination) + destinationAbs, err := canonicalExtractDir(destination) + if err != nil { + return err + } + err = os.Chdir(destinationAbs) if err != nil { return errors.Wrapf(err, "error changing directory to %s", destination) } @@ -125,7 +144,7 @@ func UnpackXzTar(filename string, destination string, verbosityLevel int) (err e } // Create a tar Reader tr := tar.NewReader(r) - return unpackTarFiles(tr, destination) + return unpackTarFiles(tr, destinationAbs) } func UnpackTar(filename string, destination string, verbosityLevel int) (err error) { @@ -147,7 +166,11 @@ func UnpackTar(filename string, destination string, verbosityLevel int) (err err return err } defer file.Close() // #nosec G307 - err = os.Chdir(destination) + destinationAbs, err := canonicalExtractDir(destination) + if err != nil { + return err + } + err = os.Chdir(destinationAbs) if err != nil { return errors.Wrapf(err, "error changing directory to %s", destination) } @@ -165,15 +188,16 @@ func UnpackTar(filename string, destination string, verbosityLevel int) (err err } else { reader = tar.NewReader(fileReader) } - return unpackTarFiles(reader, destination) + return unpackTarFiles(reader, destinationAbs) } -func unpackTarFiles(reader *tar.Reader, extractDir string) error { +// unpackTarFiles extracts reader's entries into extractAbsDir. The caller must +// supply an absolute, symlink-resolved path (see canonicalExtractDir) so the +// validation helpers can compare canonical paths for containment. +func unpackTarFiles(reader *tar.Reader, extractAbsDir string) error { const errLinkedDirectoryOutside = "linked directory '%s' is outside the extraction directory" - extractAbsDir, err := filepath.Abs(extractDir) - if err != nil { - return fmt.Errorf("error defining the absolute path of '%s': %s", extractDir, err) - } + const errDirectoryOutside = "directory for entry '%s' is outside the extraction directory" + var err error var header *tar.Header var count int = 0 var reSlash = regexp.MustCompile(`/.*`) @@ -239,7 +263,18 @@ func unpackTarFiles(reader *tar.Reader, extractDir string) error { innerDir = upperDir } - if _, err = os.Stat(fileDir); os.IsNotExist(err) { + absFilePath := filepath.Join(extractAbsDir, filename) + absFileDir := filepath.Dir(absFilePath) + + // Validate that the entry's parent directory (after resolving any symlinks + // created by previous tar entries) stays inside the extraction directory. + // This closes the chain-symlink traversal bypass where an earlier entry + // creates a symlink whose realpath escapes extractAbsDir. + if _, err := resolveInsideExtractDir(absFileDir, extractAbsDir); err != nil { + return fmt.Errorf(errDirectoryOutside, filename) + } + + if _, err = os.Lstat(fileDir); os.IsNotExist(err) { if err = os.MkdirAll(fileDir, globals.PublicDirectoryAttr); err != nil { return err } @@ -254,6 +289,11 @@ func unpackTarFiles(reader *tar.Reader, extractDir string) error { return err } case tar.TypeReg: + // Refuse to write through a pre-existing symlink at the target name: + // os.Create would follow it and write outside the extraction directory. + if info, lerr := os.Lstat(filename); lerr == nil && info.Mode()&os.ModeSymlink != 0 { + return fmt.Errorf("refusing to overwrite existing symlink at '%s'", filename) + } if err = unpackTarFile(filename, reader); err != nil { return err } @@ -273,45 +313,85 @@ func unpackTarFiles(reader *tar.Reader, extractDir string) error { } } case tar.TypeSymlink: - if header.Linkname != "" { - linkDepth := pathDepth(header.Linkname) - nameDepth := pathDepth(header.Name) - if linkDepth > nameDepth { - fmt.Println() - return fmt.Errorf(errLinkedDirectoryOutside, header.Linkname) - } - if common.FileExists(header.Linkname) { - absFile, err := filepath.Abs(header.Linkname) - if err != nil { - return fmt.Errorf("error retrieving absolute path of %s: %s", header.Linkname, err) - } - if !common.BeginsWith(absFile, extractAbsDir) { - return fmt.Errorf(errLinkedDirectoryOutside, header.Linkname) - } - } else { - if common.BeginsWith(header.Linkname, "/") { - if !common.BeginsWith(header.Linkname, extractAbsDir) { - return fmt.Errorf(errLinkedDirectoryOutside, header.Linkname) - } - } - } - condPrint(fmt.Sprintf("%s -> %s", filename, header.Linkname), true, CHATTY) - err = os.Symlink(header.Linkname, filename) - if err != nil { - return fmt.Errorf("%#v\n#ERROR: %s", header, err) - } - } else { + if header.Linkname == "" { return fmt.Errorf("file %s is a symlink, but no link information was provided", filename) } + // Build the absolute path the symlink would point to. We concatenate + // with a raw separator instead of filepath.Join so that ".." components + // in Linkname are preserved: EvalSymlinks must walk through any + // intermediate symlinks before evaluating ".." against their real + // targets. filepath.Join would lexically collapse the ".." and miss + // chain-symlink escapes. + var targetPath string + if filepath.IsAbs(header.Linkname) { + targetPath = header.Linkname + } else { + targetPath = absFileDir + string(os.PathSeparator) + header.Linkname + } + if _, err := resolveInsideExtractDir(targetPath, extractAbsDir); err != nil { + return fmt.Errorf(errLinkedDirectoryOutside, header.Linkname) + } + condPrint(fmt.Sprintf("%s -> %s", filename, header.Linkname), true, CHATTY) + err = os.Symlink(header.Linkname, filename) + if err != nil { + return fmt.Errorf("%#v\n#ERROR: %s", header, err) + } } } // return nil } -func pathDepth(s string) int { - reSlash := regexp.MustCompilePOSIX("(/)") - list := reSlash.FindAllStringIndex(s, -1) - return len(list) +// resolveInsideExtractDir resolves target through the filesystem (following any +// existing symlinks, resolving ".." components *after* symlink expansion) and +// confirms the result is inside extractAbsDir. When the full target does not +// exist yet, the deepest existing ancestor is resolved and the remaining +// lexical suffix is appended; this lets us validate symlinks whose targets +// have not been created yet without losing chain-traversal detection (any +// ".." that would cross a symlink lives in an existing ancestor, so it gets +// resolved through the filesystem rather than lexically). +func resolveInsideExtractDir(target, extractAbsDir string) (string, error) { + if resolved, err := filepath.EvalSymlinks(target); err == nil { + if !pathInside(resolved, extractAbsDir) { + return "", fmt.Errorf("path '%s' resolves to '%s' outside extraction directory '%s'", target, resolved, extractAbsDir) + } + return resolved, nil + } + // Walk up one directory at a time using filepath.Dir so volume roots + // (e.g. "/" on POSIX, "C:\" on Windows) are handled portably. Terminate + // at the fixed point, where filepath.Dir no longer shortens the path. + parent := filepath.Dir(target) + for { + if resolved, err := filepath.EvalSymlinks(parent); err == nil { + if !pathInside(resolved, extractAbsDir) { + return "", fmt.Errorf("ancestor of '%s' resolves to '%s' outside extraction directory '%s'", target, resolved, extractAbsDir) + } + rel, err := filepath.Rel(parent, target) + if err != nil { + return "", err + } + combined := filepath.Join(resolved, rel) + if !pathInside(combined, extractAbsDir) { + return "", fmt.Errorf("path '%s' would resolve to '%s' outside extraction directory '%s'", target, combined, extractAbsDir) + } + return combined, nil + } + next := filepath.Dir(parent) + if next == parent { + return "", fmt.Errorf("cannot resolve any ancestor of '%s'", target) + } + parent = next + } +} + +func pathInside(candidate, dir string) bool { + if candidate == dir { + return true + } + sep := string(os.PathSeparator) + if !strings.HasSuffix(dir, sep) { + dir += sep + } + return strings.HasPrefix(candidate, dir) } func unpackTarFile(filename string, diff --git a/unpack/unpack_test.go b/unpack/unpack_test.go index 37b31b7..5137afb 100644 --- a/unpack/unpack_test.go +++ b/unpack/unpack_test.go @@ -16,40 +16,246 @@ package unpack import ( - "path" + "archive/tar" + "bytes" + "io" + "os" + "path/filepath" "strings" "testing" ) -func Test_pathDepth(t *testing.T) { - type args struct { - s string - } - tests := []struct { - name string - args args - want int - }{ - {"empty", args{""}, 0}, - {"~", args{"~"}, 0}, - {"lorem ipsum", args{"lorem ipsum"}, 0}, - {"/", args{"/"}, 1}, - {"./", args{"./"}, 1}, - {"repeat", args{strings.Repeat("/", 10)}, 10}, - {"path_join", args{path.Join("one", "two", "three")}, 2}, - {"strings_join", args{strings.Join([]string{"one", "two", "three"}, "/")}, 2}, - {"////", args{"////"}, 4}, - {"/etc/", args{"/etc/"}, 2}, - {"/etc/something", args{"/etc/something"}, 2}, - {"../etc", args{"../etc"}, 1}, - {"../../../etc", args{"../../../etc"}, 3}, - {"../../../../../etc", args{"../../../../../etc"}, 5}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := pathDepth(tt.args.s); got != tt.want { - t.Errorf("pathDepth(%s) = %v, want %v", tt.args.s, got, tt.want) +// writeTar builds a tar archive at path from the supplied entries. Each entry's +// Size is set to len(Body) when the entry is a regular file; symlinks carry a +// Linkname and zero size. +type tarEntry struct { + Name string + Linkname string + Typeflag byte + Body []byte +} + +func writeTar(t *testing.T, path string, entries []tarEntry) { + t.Helper() + f, err := os.Create(path) + if err != nil { + t.Fatalf("create tar: %v", err) + } + defer f.Close() + tw := tar.NewWriter(f) + for _, e := range entries { + hdr := &tar.Header{ + Name: e.Name, + Linkname: e.Linkname, + Typeflag: e.Typeflag, + Mode: 0o644, + } + if e.Typeflag == tar.TypeReg { + hdr.Size = int64(len(e.Body)) + } + if err := tw.WriteHeader(hdr); err != nil { + t.Fatalf("write header for %s: %v", e.Name, err) + } + if e.Typeflag == tar.TypeReg && len(e.Body) > 0 { + if _, err := io.Copy(tw, bytes.NewReader(e.Body)); err != nil { + t.Fatalf("write body for %s: %v", e.Name, err) } - }) + } + } + if err := tw.Close(); err != nil { + t.Fatalf("close tar: %v", err) + } +} + +// unpackTarForTest wraps UnpackTar so that the process cwd is restored after +// the call. UnpackTar does os.Chdir(destination); without a restore, tests +// that use t.TempDir() would leak a deleted cwd into subsequent tests. +func unpackTarForTest(t *testing.T, tarPath, dest string) error { + t.Helper() + cwd, err := os.Getwd() + if err != nil { + t.Fatalf("get cwd: %v", err) + } + t.Cleanup(func() { _ = os.Chdir(cwd) }) + return UnpackTar(tarPath, dest, SILENT) +} + +// Test_symlinkChainEscape reproduces the PoC from the report: a chain of +// dirN -> dirN-1/.. symlinks whose cumulative realpath climbs above the +// extraction directory, followed by a pivot symlink with a path depth equal +// to the entry name (so the previous pathDepth heuristic cannot catch it) +// and a regular file written through the pivot. The extraction must be +// rejected and no file may appear outside the extraction directory. +func Test_symlinkChainEscape(t *testing.T) { + root := t.TempDir() + dest := filepath.Join(root, "dest") + victim := filepath.Join(root, "victim") + if err := os.MkdirAll(dest, 0o755); err != nil { + t.Fatal(err) + } + if err := os.MkdirAll(victim, 0o755); err != nil { + t.Fatal(err) + } + + // Three ".." hops are enough to climb from dest up to `root`, where + // sibling `victim` lives. The pivot then names `dir3/victim`, which has + // one path separator — the same as the symlink name `test/myVictim`. + entries := []tarEntry{ + {Name: "test/dir0/baseFile.txt", Typeflag: tar.TypeReg, Body: []byte("base")}, + {Name: "test/dir1", Linkname: "dir0/..", Typeflag: tar.TypeSymlink}, + {Name: "test/dir2", Linkname: "dir1/..", Typeflag: tar.TypeSymlink}, + {Name: "test/dir3", Linkname: "dir2/..", Typeflag: tar.TypeSymlink}, + {Name: "test/myVictim", Linkname: "dir3/victim", Typeflag: tar.TypeSymlink}, + {Name: "test/myVictim/Exp.txt", Typeflag: tar.TypeReg, Body: []byte("Malicious Text\n")}, + } + tarPath := filepath.Join(root, "archive.tar") + writeTar(t, tarPath, entries) + + err := unpackTarForTest(t, tarPath, dest) + if err == nil { + t.Fatalf("expected extraction to fail, but it succeeded") + } + if !strings.Contains(err.Error(), "outside the extraction directory") { + t.Fatalf("expected 'outside the extraction directory' error, got: %v", err) + } + + exp := filepath.Join(victim, "Exp.txt") + if _, err := os.Lstat(exp); err == nil { + t.Fatalf("attacker file was written to %s despite error — fix did not block the write", exp) + } +} + +// Test_symlinkSingleHopEscape covers the simpler case of a single symlink +// whose target escapes via "..", which the previous pathDepth heuristic +// already caught. Kept as a regression guard so the new code still rejects it. +func Test_symlinkSingleHopEscape(t *testing.T) { + root := t.TempDir() + dest := filepath.Join(root, "dest") + if err := os.MkdirAll(dest, 0o755); err != nil { + t.Fatal(err) + } + entries := []tarEntry{ + {Name: "test/placeholder.txt", Typeflag: tar.TypeReg, Body: []byte("x")}, + {Name: "test/escape", Linkname: "../../etc", Typeflag: tar.TypeSymlink}, + } + tarPath := filepath.Join(root, "archive.tar") + writeTar(t, tarPath, entries) + + if err := unpackTarForTest(t, tarPath, dest); err == nil { + t.Fatalf("expected extraction to fail for escaping symlink") + } +} + +// Test_symlinkAbsoluteTarget rejects an absolute-path symlink that points +// outside the extraction directory. +func Test_symlinkAbsoluteTarget(t *testing.T) { + root := t.TempDir() + dest := filepath.Join(root, "dest") + if err := os.MkdirAll(dest, 0o755); err != nil { + t.Fatal(err) + } + entries := []tarEntry{ + {Name: "test/placeholder.txt", Typeflag: tar.TypeReg, Body: []byte("x")}, + {Name: "test/passwd", Linkname: "/etc/passwd", Typeflag: tar.TypeSymlink}, + } + tarPath := filepath.Join(root, "archive.tar") + writeTar(t, tarPath, entries) + + if err := unpackTarForTest(t, tarPath, dest); err == nil { + t.Fatalf("expected extraction to fail for absolute-path symlink") + } +} + +// Test_legitimateSymlinkPreserved confirms the fix does not regress normal +// same-directory symlinks of the kind real MySQL tarballs contain +// (e.g. lib/libssl.dylib -> libssl.1.0.0.dylib). +func Test_legitimateSymlinkPreserved(t *testing.T) { + root := t.TempDir() + dest := filepath.Join(root, "dest") + if err := os.MkdirAll(dest, 0o755); err != nil { + t.Fatal(err) + } + entries := []tarEntry{ + {Name: "mysql/lib/libssl.1.0.0.dylib", Typeflag: tar.TypeReg, Body: []byte("real")}, + {Name: "mysql/lib/libssl.dylib", Linkname: "libssl.1.0.0.dylib", Typeflag: tar.TypeSymlink}, + } + tarPath := filepath.Join(root, "archive.tar") + writeTar(t, tarPath, entries) + + if err := unpackTarForTest(t, tarPath, dest); err != nil { + t.Fatalf("unexpected extraction failure: %v", err) + } + linkPath := filepath.Join(dest, "mysql", "lib", "libssl.dylib") + info, err := os.Lstat(linkPath) + if err != nil { + t.Fatalf("expected symlink at %s: %v", linkPath, err) + } + if info.Mode()&os.ModeSymlink == 0 { + t.Fatalf("expected %s to be a symlink", linkPath) + } +} + +// Test_refuseOverwriteSymlinkWithRegular covers the belt-and-suspenders +// protection: even if a malicious entry created a symlink that passed +// validation, a subsequent regular-file entry with the same name must not +// be allowed to write through it. Here we exercise the guard by extracting +// a legitimate intra-archive symlink followed by a regular file at the +// symlink's own path; that latter entry would otherwise follow the symlink. +func Test_refuseOverwriteSymlinkWithRegular(t *testing.T) { + root := t.TempDir() + dest := filepath.Join(root, "dest") + if err := os.MkdirAll(dest, 0o755); err != nil { + t.Fatal(err) + } + entries := []tarEntry{ + {Name: "test/target.txt", Typeflag: tar.TypeReg, Body: []byte("original")}, + {Name: "test/alias", Linkname: "target.txt", Typeflag: tar.TypeSymlink}, + {Name: "test/alias", Typeflag: tar.TypeReg, Body: []byte("overwrite attempt")}, + } + tarPath := filepath.Join(root, "archive.tar") + writeTar(t, tarPath, entries) + + if err := unpackTarForTest(t, tarPath, dest); err == nil { + t.Fatalf("expected extraction to fail when a regular file would overwrite a symlink") + } + body, err := os.ReadFile(filepath.Join(dest, "test", "target.txt")) + if err != nil { + t.Fatalf("target.txt missing: %v", err) + } + if string(body) != "original" { + t.Fatalf("target.txt was modified through symlink: got %q", string(body)) + } +} + +// Test_relativeDestination verifies that a caller may pass a relative +// destination path. The canonicalization must happen before the internal +// os.Chdir, otherwise the resolved extraction directory would be incorrect. +func Test_relativeDestination(t *testing.T) { + root := t.TempDir() + if err := os.MkdirAll(filepath.Join(root, "dest"), 0o755); err != nil { + t.Fatal(err) + } + entries := []tarEntry{ + {Name: "mysql/lib/libssl.1.0.0.dylib", Typeflag: tar.TypeReg, Body: []byte("real")}, + {Name: "mysql/lib/libssl.dylib", Linkname: "libssl.1.0.0.dylib", Typeflag: tar.TypeSymlink}, + } + tarPath := filepath.Join(root, "archive.tar") + writeTar(t, tarPath, entries) + + origCwd, err := os.Getwd() + if err != nil { + t.Fatalf("get cwd: %v", err) + } + t.Cleanup(func() { _ = os.Chdir(origCwd) }) + if err := os.Chdir(root); err != nil { + t.Fatalf("chdir to root: %v", err) + } + + if err := UnpackTar(tarPath, "dest", SILENT); err != nil { + t.Fatalf("unexpected extraction failure with relative destination: %v", err) + } + linkPath := filepath.Join(root, "dest", "mysql", "lib", "libssl.dylib") + if _, err := os.Lstat(linkPath); err != nil { + t.Fatalf("expected symlink at %s: %v", linkPath, err) } }