From cc181b1a8ff766fe5193ef79471853925bda686f Mon Sep 17 00:00:00 2001 From: Steven Rhodes Date: Thu, 27 Jun 2024 10:03:39 -0700 Subject: [PATCH] Add a way for sansshell-server to reload credentials on SIGHUP (#450) sansshell-server often runs on VMs and occasionally we have use cases where we want to change trust configuration on the running VMs. Before this PR, the only way to make sansshell-server immediately respect the trust configuration was to restart it. After this PR, it's possible to reload without restarting. This is especially useful if the command to change the trust configuration is being run through sansshell itself. It's slightly unusual that we're using this to reload credentials but not to reload policy files. We haven't seen any demand for reloading policy files, so I think that's an okay tradeoff. --- auth/mtls/mtls.go | 35 +++++++++++++++------------ auth/mtls/server.go | 13 +++++++--- cmd/sansshell-server/main.go | 1 + cmd/sansshell-server/server/server.go | 30 ++++++++++++++++++++++- 4 files changed, 60 insertions(+), 19 deletions(-) diff --git a/auth/mtls/mtls.go b/auth/mtls/mtls.go index 3ca7094b..04b4f08c 100644 --- a/auth/mtls/mtls.go +++ b/auth/mtls/mtls.go @@ -99,21 +99,26 @@ type WrappedTransportCredentials struct { func (w *WrappedTransportCredentials) checkRefresh() error { if w.mtlsLoader.CertsRefreshed() { w.logger.Info("certs need reloading") - // At least provide the logger we saved before we call into the loader - // or we lose all debugability. - ctx := context.Background() - ctx = logr.NewContext(ctx, w.logger) - newCreds, err := w.loader(ctx, w.loaderName) - w.logger.V(1).Info("newCreds", "creds", newCreds, "error", err) - if err != nil { - return err - } - w.mu.Lock() - defer w.mu.Unlock() - w.creds = newCreds - if w.serverName != "" { - return w.creds.OverrideServerName(w.serverName) //nolint:staticcheck - } + return w.refreshNow() + } + return nil +} + +func (w *WrappedTransportCredentials) refreshNow() error { + // At least provide the logger we saved before we call into the loader + // or we lose all debugability. + ctx := context.Background() + ctx = logr.NewContext(ctx, w.logger) + newCreds, err := w.loader(ctx, w.loaderName) + w.logger.V(1).Info("newCreds", "creds", newCreds, "error", err) + if err != nil { + return err + } + w.mu.Lock() + defer w.mu.Unlock() + w.creds = newCreds + if w.serverName != "" { + return w.creds.OverrideServerName(w.serverName) //nolint:staticcheck } return nil } diff --git a/auth/mtls/server.go b/auth/mtls/server.go index 416856dc..d5accf2d 100644 --- a/auth/mtls/server.go +++ b/auth/mtls/server.go @@ -33,15 +33,22 @@ import ( // as the TransportCredentials returned are a WrappedTransportCredentials which // will check at call time if new certificates are available. func LoadServerCredentials(ctx context.Context, loaderName string) (credentials.TransportCredentials, error) { + wrapped, _, err := LoadServerCredentialsWithForceRefresh(ctx, loaderName) + return wrapped, err +} + +// LoadServerCredentialsWithForceRefresh returns transport credentials along with +// a function that allows immediately refreshing the credentials +func LoadServerCredentialsWithForceRefresh(ctx context.Context, loaderName string) (credentials.TransportCredentials, func() error, error) { logger := logr.FromContextOrDiscard(ctx) recorder := metrics.RecorderFromContextOrNoop(ctx) mtlsLoader, err := Loader(loaderName) if err != nil { - return nil, err + return nil, nil, err } creds, err := internalLoadServerCredentials(ctx, loaderName) if err != nil { - return nil, err + return nil, nil, err } wrapped := &WrappedTransportCredentials{ creds: creds, @@ -51,7 +58,7 @@ func LoadServerCredentials(ctx context.Context, loaderName string) (credentials. logger: logger, recorder: recorder, } - return wrapped, nil + return wrapped, wrapped.refreshNow, err } func internalLoadServerCredentials(ctx context.Context, loaderName string) (credentials.TransportCredentials, error) { diff --git a/cmd/sansshell-server/main.go b/cmd/sansshell-server/main.go index b3e7edce..6e76ef9b 100644 --- a/cmd/sansshell-server/main.go +++ b/cmd/sansshell-server/main.go @@ -178,5 +178,6 @@ func main() { server.WithDebugPort(*debugport), server.WithMetricsPort(*metricsport), server.WithMetricsRecorder(recorder), + server.WithRefreshCredsOnSIGHUP(), ) } diff --git a/cmd/sansshell-server/server/server.go b/cmd/sansshell-server/server/server.go index f5775a68..58447525 100644 --- a/cmd/sansshell-server/server/server.go +++ b/cmd/sansshell-server/server/server.go @@ -26,6 +26,8 @@ import ( "net/http" "net/http/pprof" "os" + "os/signal" + "syscall" "github.com/go-logr/logr" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" @@ -65,6 +67,8 @@ type runState struct { statsHandler stats.Handler authzHooks []rpcauth.RPCAuthzHook services []func(*grpc.Server) + + refreshCredsOnSIGHUP bool } type Option interface { @@ -271,6 +275,17 @@ func WithOtelTracing(interceptorOpts ...otelgrpc.Option) Option { }) } +// WithRefreshCredsOnSIGHUP will make sansshell-server refresh its credentials via +// its credential loader when it receives a SIGHUP signal. This is useful if you +// want to make sansshell immediately refresh its identity and trust configuration +// via `systemctl reload`. +func WithRefreshCredsOnSIGHUP() Option { + return optionFunc(func(_ context.Context, r *runState) error { + r.refreshCredsOnSIGHUP = true + return nil + }) +} + // Run takes the given context and RunState and starts up a sansshell server. // As this is intended to be called from main() it doesn't return errors and will instead exit on any errors. func Run(ctx context.Context, opts ...Option) { @@ -342,10 +357,23 @@ func extractTransportCredentialsFromRunState(ctx context.Context, rs *runState) return nil, fmt.Errorf("both credSource and tlsConfig are defined") } if rs.credSource != "" { - creds, err = mtls.LoadServerCredentials(ctx, rs.credSource) + var refreshCreds func() error + creds, refreshCreds, err = mtls.LoadServerCredentialsWithForceRefresh(ctx, rs.credSource) if err != nil { return nil, err } + if rs.refreshCredsOnSIGHUP { + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGHUP) + for range c { + rs.logger.Info("got SIGHUP, refreshing credentials") + if err := refreshCreds(); err != nil { + rs.logger.Error(err, "unable to refresh credentials") + } + } + }() + } } else { creds = credentials.NewTLS(rs.tlsConfig) }