Skip to content

Commit

Permalink
Implement Authorization Callback Function (#108)
Browse files Browse the repository at this point in the history
* Implement Authorization Callback Function

* Skip new client unit test if knox daemon is running

* add tests for SetAccessCallback

* check if knox daemon is running if on linux and using systemctl instead of pgrep

* add tests for authorizeRequest

* address nits and add testing for panic

* modify requests to return a internal server error if the authorizeRequest returns error

* run all top level tests, reset callback in tests

* Address nits in PR for authorization callback

* make static definition of expected error into testCase field

---------

Co-authored-by: henryluo <henryluo@pinterest.com>
  • Loading branch information
henryluo and henryluo authored Oct 21, 2024
1 parent 9635b33 commit 62dd952
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 7 deletions.
28 changes: 26 additions & 2 deletions client_test.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
package knox

import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"os/exec"
"path"
"reflect"
"runtime"
"sync/atomic"
"testing"
)
Expand Down Expand Up @@ -75,7 +78,7 @@ func buildServer(code int, body []byte, a func(r *http.Request)) *httptest.Serve
}))
}

func buildConcurrentServer(code int, t *testing.T, a func(r *http.Request) []byte) *httptest.Server {
func buildConcurrentServer(code int, a func(r *http.Request) []byte) *httptest.Server {
return httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := a(r)
w.WriteHeader(code)
Expand All @@ -84,6 +87,23 @@ func buildConcurrentServer(code int, t *testing.T, a func(r *http.Request) []byt
}))
}

func isKnoxDaemonRunning() bool {
if runtime.GOOS != "linux" {
return false
}

cmd := exec.Command("systemctl", "is-active", "--quiet", "knox")

var out bytes.Buffer
cmd.Stdout = &out
err := cmd.Run()
if err == nil {
return true
}

return false
}

func TestGetKey(t *testing.T) {
expected := Key{
ID: "testkey",
Expand Down Expand Up @@ -357,7 +377,7 @@ func TestPutAccess(t *testing.T) {

func TestConcurrentDeletes(t *testing.T) {
var ops uint64
srv := buildConcurrentServer(200, t, func(r *http.Request) []byte {
srv := buildConcurrentServer(200, func(r *http.Request) []byte {
if r.Method != "DELETE" {
t.Fatalf("%s is not DELETE", r.Method)
}
Expand Down Expand Up @@ -511,6 +531,10 @@ func TestGetInvalidKeys(t *testing.T) {
}

func TestNewFileClient(t *testing.T) {
if isKnoxDaemonRunning() {
t.Skip("Knox daemon is running, skipping the test.")
}

_, err := NewFileClient("ThisKeyDoesNotExistSoWeExpectAnError")
if (err.Error() != "error getting knox key ThisKeyDoesNotExistSoWeExpectAnError. error: exit status 1") && (err.Error() != "error getting knox key ThisKeyDoesNotExistSoWeExpectAnError. error: exec: \"knox\": executable file not found in $PATH") {
t.Fatal("Unexpected error", err.Error())
Expand Down
24 changes: 24 additions & 0 deletions knox.go
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,14 @@ type Principal interface {
CanAccess(ACL, AccessType) bool
GetID() string
Type() string
Raw() []RawPrincipal
}

// RawPrincipal is a serializable version of a principal for passing to
// access callbacks.
type RawPrincipal struct {
ID string `json:"id"`
Type string `json:"type"`
}

// PrincipalMux provides a Principal Interface over multiple Principals.
Expand Down Expand Up @@ -564,6 +572,15 @@ func (p PrincipalMux) Default() Principal {
return p.defaultPrincipal
}

// Raw returns the raw version of all the principals.
func (p PrincipalMux) Raw() []RawPrincipal {
raw := []RawPrincipal{}
for _, principal := range p.allPrincipals {
raw = append(raw, principal.Raw()...)
}
return raw
}

// NewPrincipalMux returns a Principal that represents many principals.
func NewPrincipalMux(defaultPrincipal Principal, allPrincipals map[string]Principal) Principal {
return PrincipalMux{
Expand Down Expand Up @@ -599,3 +616,10 @@ type Response struct {
Message string `json:"message"`
Data interface{} `json:"data"`
}

// AccessCallbackInput is the input to the access callback function.
type AccessCallbackInput struct {
Key Key `json:"key"`
Principals []RawPrincipal `json:"principals"`
AccessType AccessType `json:"access_type"`
}
7 changes: 7 additions & 0 deletions server/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,13 @@ func AddDefaultAccess(a *knox.Access) {
defaultAccess = append(defaultAccess, *a)
}

var accessCallback func(knox.AccessCallbackInput) (bool, error)

// SetAccessCallback adds a callback.
func SetAccessCallback(callback func(knox.AccessCallbackInput) (bool, error)) {
accessCallback = callback
}

// Extra validators to apply on principals submitted to Knox.
var extraPrincipalValidators []knox.PrincipalValidator

Expand Down
25 changes: 25 additions & 0 deletions server/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ func additionalMockHandler(m KeyManager, principal knox.Principal, parameters ma
return "The meaning of life is 42", nil
}

func mockAccessCallback(input knox.AccessCallbackInput) (bool, error) {
return true, nil
}

func mockRoute() Route {
return Route{
Method: "GET",
Expand Down Expand Up @@ -105,6 +109,27 @@ func TestAddDefaultAccess(t *testing.T) {

}

func TestSetAccessCallback(t *testing.T) {
defer SetAccessCallback(nil)

SetAccessCallback(mockAccessCallback)

input := knox.AccessCallbackInput{}

if accessCallback == nil {
t.Fatal("accessCallback should not be nil")
}

canAccess, err := accessCallback(input)
if err != nil {
t.Fatal("accessCallback should not return an error")
}

if !canAccess {
t.Fatal("accessCallback should return true")
}
}

func TestParseFormParameter(t *testing.T) {
p := PostParameter("key")

Expand Down
27 changes: 27 additions & 0 deletions server/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,15 @@ func (u user) GetID() string {
return u.ID
}

func (u user) Raw() []knox.RawPrincipal {
return []knox.RawPrincipal{
{
ID: u.GetID(),
Type: u.Type(),
},
}
}

// Type returns the underlying type of a principal, for logging/debugging purposes.
func (u user) Type() string {
return "user"
Expand Down Expand Up @@ -415,6 +424,15 @@ func (m machine) Type() string {
return "machine"
}

func (m machine) Raw() []knox.RawPrincipal {
return []knox.RawPrincipal{
{
ID: m.GetID(),
Type: m.Type(),
},
}
}

// CanAccess determines if a Machine can access an object represented by the ACL
// with a certain AccessType. It compares Machine hostname and hostname prefix.
func (m machine) CanAccess(acl knox.ACL, t knox.AccessType) bool {
Expand Down Expand Up @@ -450,6 +468,15 @@ func (s service) Type() string {
return "service"
}

func (s service) Raw() []knox.RawPrincipal {
return []knox.RawPrincipal{
{
ID: s.GetID(),
Type: s.Type(),
},
}
}

// CanAccess determines if a Service can access an object represented by the ACL
// with a certain AccessType. It compares Service id and id prefix.
func (s service) CanAccess(acl knox.ACL, t knox.AccessType) bool {
Expand Down
59 changes: 54 additions & 5 deletions server/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"strconv"

"github.com/pinterest/knox"
"github.com/pinterest/knox/log"
"github.com/pinterest/knox/server/auth"
)

Expand Down Expand Up @@ -211,9 +212,15 @@ func getKeyHandler(m KeyManager, principal knox.Principal, parameters map[string
}

// Authorize access to data
if !principal.CanAccess(key.ACL, knox.Read) {
authorized, authzErr := authorizeRequest(key, principal, knox.Read)
if authzErr != nil {
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
}

if !authorized {
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to read %s", principal.GetID(), keyID))
}

// Zero ACL for key response, in order to avoid caching unnecessarily
key.ACL = knox.ACL{}
return key, nil
Expand All @@ -234,7 +241,12 @@ func deleteKeyHandler(m KeyManager, principal knox.Principal, parameters map[str
}

// Authorize
if !principal.CanAccess(key.ACL, knox.Admin) {
authorized, authzErr := authorizeRequest(key, principal, knox.Admin)
if authzErr != nil {
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
}

if !authorized {
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to delete %s", principal.GetID(), keyID))
}

Expand Down Expand Up @@ -314,7 +326,12 @@ func putAccessHandler(m KeyManager, principal knox.Principal, parameters map[str
}

// Authorize
if !principal.CanAccess(key.ACL, knox.Admin) {
authorized, authzErr := authorizeRequest(key, principal, knox.Admin)
if authzErr != nil {
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
}

if !authorized {
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to update access for %s", principal.GetID(), keyID))
}

Expand Down Expand Up @@ -371,7 +388,12 @@ func postVersionHandler(m KeyManager, principal knox.Principal, parameters map[s
}

// Authorize
if !principal.CanAccess(key.ACL, knox.Write) {
authorized, authzErr := authorizeRequest(key, principal, knox.Write)
if authzErr != nil {
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
}

if !authorized {
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to write %s", principal.GetID(), keyID))
}

Expand Down Expand Up @@ -428,7 +450,12 @@ func putVersionsHandler(m KeyManager, principal knox.Principal, parameters map[s
}

// Authorize
if !principal.CanAccess(key.ACL, knox.Write) {
authorized, authzErr := authorizeRequest(key, principal, knox.Write)
if authzErr != nil {
return nil, errF(knox.InternalServerErrorCode, authzErr.Error())
}

if !authorized {
return nil, errF(knox.UnauthorizedCode, fmt.Sprintf("Principal %s not authorized to write %s", principal.GetID(), keyID))
}

Expand All @@ -445,3 +472,25 @@ func putVersionsHandler(m KeyManager, principal knox.Principal, parameters map[s
return nil, errF(knox.InternalServerErrorCode, err.Error())
}
}

func authorizeRequest(key *knox.Key, principal knox.Principal, access knox.AccessType) (allow bool, err error) {
defer func() {
if r := recover(); r != nil {
log.Printf("Recovered from panic in access callback: %v", r)

err = fmt.Errorf("Recovered from panic in access callback: %v", r)
}
}()

allow = principal.CanAccess(key.ACL, access)

if !allow && accessCallback != nil {
allow, err = accessCallback(knox.AccessCallbackInput{
Key: *key,
Principals: principal.Raw(),
AccessType: access,
})
}

return
}
Loading

0 comments on commit 62dd952

Please sign in to comment.