Skip to content

Commit

Permalink
Only process request files if they are no less than a minute old.
Browse files Browse the repository at this point in the history
This improves the protection against replay attacks by ruling out old
request files. Only new request files are potentially dangerous, and
these are handled by locks on the active request registry.

Also moved all of the request-related functionality into their own file
for easier development and testing.
  • Loading branch information
LTLA committed Sep 17, 2024
1 parent a2ba8b4 commit 5339ff2
Show file tree
Hide file tree
Showing 3 changed files with 192 additions and 104 deletions.
107 changes: 12 additions & 95 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@ import (
"os"
"errors"
"strings"
"fmt"
"encoding/json"
"net/http"
"strconv"
"io/fs"
"syscall"
"sync"
)

func dumpJsonResponse(w http.ResponseWriter, status int, v interface{}, path string) {
Expand Down Expand Up @@ -59,94 +55,6 @@ func configureCors(w http.ResponseWriter, r *http.Request) bool {

/***************************************************/

func checkRequestFile(path, staging string) (string, error) {
if !strings.HasPrefix(path, "request-") {
return "", newHttpError(http.StatusBadRequest, errors.New("file name should start with \"request-\""))
}

if !filepath.IsLocal(path) {
return "", newHttpError(http.StatusBadRequest, errors.New("path should be local to the staging directory"))
}
reqpath := filepath.Join(staging, path)

info, err := os.Lstat(reqpath)
if err != nil {
return "", newHttpError(http.StatusBadRequest, fmt.Errorf("failed to access path; %v", err))
}

if info.IsDir() {
return "", newHttpError(http.StatusBadRequest, errors.New("path is a directory"))
}

if info.Mode() & fs.ModeSymlink != 0 {
return "", newHttpError(http.StatusBadRequest, errors.New("path is a symbolic link"))
}

s, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return "", fmt.Errorf("failed to convert to a syscall.Stat_t; %w", err)
}
if uint32(s.Nlink) > 1 {
return "", newHttpError(http.StatusBadRequest, errors.New("path seems to have multiple hard links"))
}

return reqpath, nil
}

/***************************************************/

// This tracks the requests that are currently being processed, to prevent the
// same request being processed multiple times at the same time. We use a
// multi-pool approach to improve parallelism across requests.
type activeRegistry struct {
NumPools int
Locks []sync.Mutex
Active []map[string]bool
}

func newActiveRegistry(num_pools int) *activeRegistry {
return &activeRegistry {
NumPools: num_pools,
Locks: make([]sync.Mutex, num_pools),
Active: make([]map[string]bool, num_pools),
}
}

func (a *activeRegistry) choosePool(path string) int {
sum := 0
for _, r := range path {
sum += int(r)
}
return sum % a.NumPools
}

func (a *activeRegistry) Add(path string) bool {
i := a.choosePool(path)
a.Locks[i].Lock()
defer a.Locks[i].Unlock()

if a.Active[i] == nil {
a.Active[i] = map[string]bool{}
} else {
_, ok := a.Active[i][path]
if ok {
return false
}
}

a.Active[i][path] = true
return true
}

func (a *activeRegistry) Remove(path string) {
i := a.choosePool(path)
a.Locks[i].Lock()
defer a.Locks[i].Unlock()
delete(a.Active[i], path)
}

/***************************************************/

func main() {
spath := flag.String("staging", "", "Path to the staging directory")
rpath := flag.String("registry", "", "Path to the registry")
Expand Down Expand Up @@ -174,7 +82,12 @@ func main() {
}
}

actreg := newActiveRegistry(11)
actreg := newActiveRequestRegistry(11)
const request_expiry = time.Minute
err := prefillActiveRequestRegistry(actreg, staging, request_expiry)
if err != nil {
log.Fatalf("failed to prefill active request registry; %v", err)
}

endpt_prefix := *prefix
if endpt_prefix != "" {
Expand All @@ -186,7 +99,7 @@ func main() {
path := r.PathValue("path")
log.Println("processing " + path)

reqpath, err := checkRequestFile(path, staging)
reqpath, err := checkRequestFile(path, staging, request_expiry)
if err != nil {
dumpHttpErrorResponse(w, err, path)
return
Expand Down Expand Up @@ -244,10 +157,14 @@ func main() {

// Purge the request file once it's processed, to reduce the potential
// for replay attacks. For safety's sake, we only remove it from the
// registry if the request file was properly deleted.
// registry if the request file was properly deleted or it expired.
err = os.Remove(reqpath)
if err != nil {
log.Printf("failed to purge the request file at %q; %v", path, err)
go func() {
time.Sleep(request_expiry)
actreg.Remove(path)
}()
} else {
actreg.Remove(path)
}
Expand Down
123 changes: 123 additions & 0 deletions request.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package main

import (
"sync"
"strings"
"path/filepath"
"net/http"
"time"
"os"
"fmt"
"errors"
"io/fs"
"syscall"
)

func chooseLockPool(path string, num_pools int) int {
sum := 0
for _, r := range path {
sum += int(r)
}
return sum % num_pools
}

// This tracks the requests that are currently being processed, to prevent the
// same request being processed multiple times at the same time. We use a
// multi-pool approach to improve parallelism across requests.
type activeRequestRegistry struct {
NumPools int
Locks []sync.Mutex
Active []map[string]bool
}

func newActiveRequestRegistry(num_pools int) *activeRequestRegistry {
return &activeRequestRegistry {
NumPools: num_pools,
Locks: make([]sync.Mutex, num_pools),
Active: make([]map[string]bool, num_pools),
}
}

func prefillActiveRequestRegistry(a *activeRequestRegistry, staging string, expiry time.Duration) error {
// Prefilling the registry ensures that a user can't replay requests after a restart of the service.
entries, err := os.ReadDir(staging)
if err != nil {
return fmt.Errorf("failed to list existing request files in '%s'", staging)
}

// This is only necessary until the expiry time is exceeded, after which we can evict those entries.
// Technically we only need to do this for files that weren't already expired, but this doesn't hurt.
for _, e := range entries {
path := e.Name()
a.Add(path)
go func(p string) {
time.Sleep(expiry)
a.Remove(p)
}(path)
}
return nil
}

func (a *activeRequestRegistry) Add(path string) bool {
i := chooseLockPool(path, a.NumPools)
a.Locks[i].Lock()
defer a.Locks[i].Unlock()

if a.Active[i] == nil {
a.Active[i] = map[string]bool{}
} else {
_, ok := a.Active[i][path]
if ok {
return false
}
}

a.Active[i][path] = true
return true
}

func (a *activeRequestRegistry) Remove(path string) {
i := chooseLockPool(path, a.NumPools)
a.Locks[i].Lock()
defer a.Locks[i].Unlock()
delete(a.Active[i], path)
}

func checkRequestFile(path, staging string, expiry time.Duration) (string, error) {
if !strings.HasPrefix(path, "request-") {
return "", newHttpError(http.StatusBadRequest, errors.New("file name should start with \"request-\""))
}

if !filepath.IsLocal(path) {
return "", newHttpError(http.StatusBadRequest, errors.New("path should be local to the staging directory"))
}
reqpath := filepath.Join(staging, path)

info, err := os.Lstat(reqpath)
if err != nil {
return "", newHttpError(http.StatusBadRequest, fmt.Errorf("failed to access path; %v", err))
}

if info.IsDir() {
return "", newHttpError(http.StatusBadRequest, errors.New("path is a directory"))
}

if info.Mode() & fs.ModeSymlink != 0 {
return "", newHttpError(http.StatusBadRequest, errors.New("path is a symbolic link"))
}

s, ok := info.Sys().(*syscall.Stat_t)
if !ok {
return "", fmt.Errorf("failed to convert to a syscall.Stat_t; %w", err)
}
if uint32(s.Nlink) > 1 {
return "", newHttpError(http.StatusBadRequest, errors.New("path seems to have multiple hard links"))
}

current := time.Now()
if current.Sub(info.ModTime()) >= expiry {
return "", newHttpError(http.StatusBadRequest, errors.New("request file is expired"))
}

return reqpath, nil
}
Loading

0 comments on commit 5339ff2

Please sign in to comment.