diff --git a/auth/mtls/mtls.go b/auth/mtls/mtls.go index 3ca7094b..04b4f08c 100644 --- a/auth/mtls/mtls.go +++ b/auth/mtls/mtls.go @@ -99,21 +99,26 @@ type WrappedTransportCredentials struct { func (w *WrappedTransportCredentials) checkRefresh() error { if w.mtlsLoader.CertsRefreshed() { w.logger.Info("certs need reloading") - // At least provide the logger we saved before we call into the loader - // or we lose all debugability. - ctx := context.Background() - ctx = logr.NewContext(ctx, w.logger) - newCreds, err := w.loader(ctx, w.loaderName) - w.logger.V(1).Info("newCreds", "creds", newCreds, "error", err) - if err != nil { - return err - } - w.mu.Lock() - defer w.mu.Unlock() - w.creds = newCreds - if w.serverName != "" { - return w.creds.OverrideServerName(w.serverName) //nolint:staticcheck - } + return w.refreshNow() + } + return nil +} + +func (w *WrappedTransportCredentials) refreshNow() error { + // At least provide the logger we saved before we call into the loader + // or we lose all debugability. + ctx := context.Background() + ctx = logr.NewContext(ctx, w.logger) + newCreds, err := w.loader(ctx, w.loaderName) + w.logger.V(1).Info("newCreds", "creds", newCreds, "error", err) + if err != nil { + return err + } + w.mu.Lock() + defer w.mu.Unlock() + w.creds = newCreds + if w.serverName != "" { + return w.creds.OverrideServerName(w.serverName) //nolint:staticcheck } return nil } diff --git a/auth/mtls/server.go b/auth/mtls/server.go index 416856dc..d5accf2d 100644 --- a/auth/mtls/server.go +++ b/auth/mtls/server.go @@ -33,15 +33,22 @@ import ( // as the TransportCredentials returned are a WrappedTransportCredentials which // will check at call time if new certificates are available. func LoadServerCredentials(ctx context.Context, loaderName string) (credentials.TransportCredentials, error) { + wrapped, _, err := LoadServerCredentialsWithForceRefresh(ctx, loaderName) + return wrapped, err +} + +// LoadServerCredentialsWithForceRefresh returns transport credentials along with +// a function that allows immediately refreshing the credentials +func LoadServerCredentialsWithForceRefresh(ctx context.Context, loaderName string) (credentials.TransportCredentials, func() error, error) { logger := logr.FromContextOrDiscard(ctx) recorder := metrics.RecorderFromContextOrNoop(ctx) mtlsLoader, err := Loader(loaderName) if err != nil { - return nil, err + return nil, nil, err } creds, err := internalLoadServerCredentials(ctx, loaderName) if err != nil { - return nil, err + return nil, nil, err } wrapped := &WrappedTransportCredentials{ creds: creds, @@ -51,7 +58,7 @@ func LoadServerCredentials(ctx context.Context, loaderName string) (credentials. logger: logger, recorder: recorder, } - return wrapped, nil + return wrapped, wrapped.refreshNow, err } func internalLoadServerCredentials(ctx context.Context, loaderName string) (credentials.TransportCredentials, error) { diff --git a/cmd/sansshell-server/main.go b/cmd/sansshell-server/main.go index b3e7edce..6e76ef9b 100644 --- a/cmd/sansshell-server/main.go +++ b/cmd/sansshell-server/main.go @@ -178,5 +178,6 @@ func main() { server.WithDebugPort(*debugport), server.WithMetricsPort(*metricsport), server.WithMetricsRecorder(recorder), + server.WithRefreshCredsOnSIGHUP(), ) } diff --git a/cmd/sansshell-server/server/server.go b/cmd/sansshell-server/server/server.go index f5775a68..58447525 100644 --- a/cmd/sansshell-server/server/server.go +++ b/cmd/sansshell-server/server/server.go @@ -26,6 +26,8 @@ import ( "net/http" "net/http/pprof" "os" + "os/signal" + "syscall" "github.com/go-logr/logr" grpc_prometheus "github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus" @@ -65,6 +67,8 @@ type runState struct { statsHandler stats.Handler authzHooks []rpcauth.RPCAuthzHook services []func(*grpc.Server) + + refreshCredsOnSIGHUP bool } type Option interface { @@ -271,6 +275,17 @@ func WithOtelTracing(interceptorOpts ...otelgrpc.Option) Option { }) } +// WithRefreshCredsOnSIGHUP will make sansshell-server refresh its credentials via +// its credential loader when it receives a SIGHUP signal. This is useful if you +// want to make sansshell immediately refresh its identity and trust configuration +// via `systemctl reload`. +func WithRefreshCredsOnSIGHUP() Option { + return optionFunc(func(_ context.Context, r *runState) error { + r.refreshCredsOnSIGHUP = true + return nil + }) +} + // Run takes the given context and RunState and starts up a sansshell server. // As this is intended to be called from main() it doesn't return errors and will instead exit on any errors. func Run(ctx context.Context, opts ...Option) { @@ -342,10 +357,23 @@ func extractTransportCredentialsFromRunState(ctx context.Context, rs *runState) return nil, fmt.Errorf("both credSource and tlsConfig are defined") } if rs.credSource != "" { - creds, err = mtls.LoadServerCredentials(ctx, rs.credSource) + var refreshCreds func() error + creds, refreshCreds, err = mtls.LoadServerCredentialsWithForceRefresh(ctx, rs.credSource) if err != nil { return nil, err } + if rs.refreshCredsOnSIGHUP { + go func() { + c := make(chan os.Signal, 1) + signal.Notify(c, syscall.SIGHUP) + for range c { + rs.logger.Info("got SIGHUP, refreshing credentials") + if err := refreshCreds(); err != nil { + rs.logger.Error(err, "unable to refresh credentials") + } + } + }() + } } else { creds = credentials.NewTLS(rs.tlsConfig) }