Skip to content

Commit

Permalink
Add a way for sansshell-server to reload credentials on SIGHUP (#450)
Browse files Browse the repository at this point in the history
sansshell-server often runs on VMs and occasionally we have use cases where we want to change trust configuration on the running VMs. Before this PR, the only way to make sansshell-server immediately respect the trust configuration was to restart it. After this PR, it's possible to reload without restarting.

This is especially useful if the command to change the trust configuration is being run through sansshell itself.

It's slightly unusual that we're using this to reload credentials but not to reload policy files. We haven't seen any demand for reloading policy files, so I think that's an okay tradeoff.
  • Loading branch information
stvnrhodes authored Jun 27, 2024
1 parent b8eae3a commit cc181b1
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 19 deletions.
35 changes: 20 additions & 15 deletions auth/mtls/mtls.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,26 @@ type WrappedTransportCredentials struct {
func (w *WrappedTransportCredentials) checkRefresh() error {
if w.mtlsLoader.CertsRefreshed() {
w.logger.Info("certs need reloading")
// At least provide the logger we saved before we call into the loader
// or we lose all debugability.
ctx := context.Background()
ctx = logr.NewContext(ctx, w.logger)
newCreds, err := w.loader(ctx, w.loaderName)
w.logger.V(1).Info("newCreds", "creds", newCreds, "error", err)
if err != nil {
return err
}
w.mu.Lock()
defer w.mu.Unlock()
w.creds = newCreds
if w.serverName != "" {
return w.creds.OverrideServerName(w.serverName) //nolint:staticcheck
}
return w.refreshNow()
}
return nil
}

func (w *WrappedTransportCredentials) refreshNow() error {
// At least provide the logger we saved before we call into the loader
// or we lose all debugability.
ctx := context.Background()
ctx = logr.NewContext(ctx, w.logger)
newCreds, err := w.loader(ctx, w.loaderName)
w.logger.V(1).Info("newCreds", "creds", newCreds, "error", err)
if err != nil {
return err
}
w.mu.Lock()
defer w.mu.Unlock()
w.creds = newCreds
if w.serverName != "" {
return w.creds.OverrideServerName(w.serverName) //nolint:staticcheck
}
return nil
}
Expand Down
13 changes: 10 additions & 3 deletions auth/mtls/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,22 @@ import (
// as the TransportCredentials returned are a WrappedTransportCredentials which
// will check at call time if new certificates are available.
func LoadServerCredentials(ctx context.Context, loaderName string) (credentials.TransportCredentials, error) {
wrapped, _, err := LoadServerCredentialsWithForceRefresh(ctx, loaderName)
return wrapped, err
}

// LoadServerCredentialsWithForceRefresh returns transport credentials along with
// a function that allows immediately refreshing the credentials
func LoadServerCredentialsWithForceRefresh(ctx context.Context, loaderName string) (credentials.TransportCredentials, func() error, error) {
logger := logr.FromContextOrDiscard(ctx)
recorder := metrics.RecorderFromContextOrNoop(ctx)
mtlsLoader, err := Loader(loaderName)
if err != nil {
return nil, err
return nil, nil, err
}
creds, err := internalLoadServerCredentials(ctx, loaderName)
if err != nil {
return nil, err
return nil, nil, err
}
wrapped := &WrappedTransportCredentials{
creds: creds,
Expand All @@ -51,7 +58,7 @@ func LoadServerCredentials(ctx context.Context, loaderName string) (credentials.
logger: logger,
recorder: recorder,
}
return wrapped, nil
return wrapped, wrapped.refreshNow, err
}

func internalLoadServerCredentials(ctx context.Context, loaderName string) (credentials.TransportCredentials, error) {
Expand Down
1 change: 1 addition & 0 deletions cmd/sansshell-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,5 +178,6 @@ func main() {
server.WithDebugPort(*debugport),
server.WithMetricsPort(*metricsport),
server.WithMetricsRecorder(recorder),
server.WithRefreshCredsOnSIGHUP(),
)
}
30 changes: 29 additions & 1 deletion cmd/sansshell-server/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import (
"net/http"
"net/http/pprof"
"os"
"os/signal"
"syscall"

"github.com/go-logr/logr"
grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus"
Expand Down Expand Up @@ -65,6 +67,8 @@ type runState struct {
statsHandler stats.Handler
authzHooks []rpcauth.RPCAuthzHook
services []func(*grpc.Server)

refreshCredsOnSIGHUP bool
}

type Option interface {
Expand Down Expand Up @@ -271,6 +275,17 @@ func WithOtelTracing(interceptorOpts ...otelgrpc.Option) Option {
})
}

// WithRefreshCredsOnSIGHUP will make sansshell-server refresh its credentials via
// its credential loader when it receives a SIGHUP signal. This is useful if you
// want to make sansshell immediately refresh its identity and trust configuration
// via `systemctl reload`.
func WithRefreshCredsOnSIGHUP() Option {
return optionFunc(func(_ context.Context, r *runState) error {
r.refreshCredsOnSIGHUP = true
return nil
})
}

// Run takes the given context and RunState and starts up a sansshell server.
// As this is intended to be called from main() it doesn't return errors and will instead exit on any errors.
func Run(ctx context.Context, opts ...Option) {
Expand Down Expand Up @@ -342,10 +357,23 @@ func extractTransportCredentialsFromRunState(ctx context.Context, rs *runState)
return nil, fmt.Errorf("both credSource and tlsConfig are defined")
}
if rs.credSource != "" {
creds, err = mtls.LoadServerCredentials(ctx, rs.credSource)
var refreshCreds func() error
creds, refreshCreds, err = mtls.LoadServerCredentialsWithForceRefresh(ctx, rs.credSource)
if err != nil {
return nil, err
}
if rs.refreshCredsOnSIGHUP {
go func() {
c := make(chan os.Signal, 1)
signal.Notify(c, syscall.SIGHUP)
for range c {
rs.logger.Info("got SIGHUP, refreshing credentials")
if err := refreshCreds(); err != nil {
rs.logger.Error(err, "unable to refresh credentials")
}
}
}()
}
} else {
creds = credentials.NewTLS(rs.tlsConfig)
}
Expand Down

0 comments on commit cc181b1

Please sign in to comment.