diff --git a/cmd/immudb/command/init.go b/cmd/immudb/command/init.go index 97619118d3..32726117a5 100644 --- a/cmd/immudb/command/init.go +++ b/cmd/immudb/command/init.go @@ -55,6 +55,7 @@ func (cl *Commandline) setupFlags(cmd *cobra.Command, options *server.Options) { cmd.Flags().Int("max-recv-msg-size", options.MaxRecvMsgSize, "max message size in bytes the server can receive") cmd.Flags().Bool("no-histograms", false, "disable collection of histogram metrics like query durations") cmd.Flags().BoolP(c.DetachedFlag, c.DetachedShortFlag, options.Detached, "run immudb in background") + cmd.Flags().Bool("auto-cert", options.AutoCert, "start the server using a generated, self-signed HTTPS certificate") cmd.Flags().String("certificate", "", "server certificate file path") cmd.Flags().String("pkey", "", "server private key path") cmd.Flags().String("clientcas", "", "clients certificates list. Aka certificate authority") @@ -122,6 +123,7 @@ func setupDefaults(options *server.Options) { viper.SetDefault("max-recv-msg-size", options.MaxRecvMsgSize) viper.SetDefault("no-histograms", options.NoHistograms) viper.SetDefault("detached", options.Detached) + viper.SetDefault("auto-cert", options.AutoCert) viper.SetDefault("certificate", "") viper.SetDefault("pkey", "") viper.SetDefault("clientcas", "") diff --git a/cmd/immudb/command/parse_options.go b/cmd/immudb/command/parse_options.go index 328b901d82..54fec1195a 100644 --- a/cmd/immudb/command/parse_options.go +++ b/cmd/immudb/command/parse_options.go @@ -59,6 +59,7 @@ func parseOptions() (options *server.Options, err error) { maxRecvMsgSize := viper.GetInt("max-recv-msg-size") noHistograms := viper.GetBool("no-histograms") detached := viper.GetBool("detached") + autoCert := viper.GetBool("auto-cert") certificate := viper.GetString("certificate") pkey := viper.GetString("pkey") clientcas := viper.GetString("clientcas") @@ -118,7 +119,7 @@ func parseOptions() (options *server.Options, err error) { WithMaxSessionAgeTime(viper.GetDuration("max-session-age-time")). WithTimeout(viper.GetDuration("session-timeout")) - tlsConfig, err := setUpTLS(pkey, certificate, clientcas, mtls) + tlsConfig, err := setUpTLS(pkey, certificate, clientcas, mtls, autoCert) if err != nil { return options, err } diff --git a/cmd/immudb/command/tls_config.go b/cmd/immudb/command/tls_config.go index 8565e015f2..f808d4a2ca 100644 --- a/cmd/immudb/command/tls_config.go +++ b/cmd/immudb/command/tls_config.go @@ -21,24 +21,59 @@ import ( "crypto/x509" "errors" "fmt" - "io/ioutil" + "os" + "path/filepath" + "time" + + tlscert "github.com/codenotary/immudb/pkg/cert" +) + +const ( + certFileDefault = "immudb-cert.pem" + keyFileDefault = "immudb-key.pem" + + certOrganizationDefault = "immudb" + certExpirationDefault = 365 * 24 * time.Hour ) -func setUpTLS(pkey, cert, ca string, mtls bool) (*tls.Config, error) { +func setUpTLS(pkeyPath, certPath, ca string, mtls bool, autoCert bool) (*tls.Config, error) { + if (pkeyPath == "" && certPath != "") || (pkeyPath != "" && certPath == "") { + return nil, fmt.Errorf("both certificate and private key paths must be specified or neither") + } + var c *tls.Config - if cert != "" && pkey != "" { - certs, err := tls.LoadX509KeyPair(cert, pkey) + certPath, pkeyPath, err := getCertAndKeyPath(certPath, pkeyPath, autoCert) + if err != nil { + return nil, err + } + + if certPath != "" && pkeyPath != "" { + cert, err := ensureCert(certPath, pkeyPath, autoCert) if err != nil { - return nil, errors.New(fmt.Sprintf("failed to read client certificate or private key: %v", err)) + return nil, fmt.Errorf("failed to read client certificate or private key: %v", err) } + c = &tls.Config{ - Certificates: []tls.Certificate{certs}, + Certificates: []tls.Certificate{*cert}, ClientAuth: tls.VerifyClientCertIfGiven, } + + if autoCert { + rootCert, err := os.ReadFile(certPath) + if err != nil { + return nil, fmt.Errorf("failed to read root cert: %v", err) + } + + rootPool := x509.NewCertPool() + if ok := rootPool.AppendCertsFromPEM(rootCert); !ok { + return nil, fmt.Errorf("failed to read root cert") + } + c.RootCAs = rootPool + } } - if mtls && (cert == "" || pkey == "") { + if mtls && (certPath == "" || pkeyPath == "") { return nil, errors.New("in order to enable MTLS a certificate and private key are required") } @@ -46,7 +81,7 @@ func setUpTLS(pkey, cert, ca string, mtls bool) (*tls.Config, error) { if mtls && ca != "" { certPool := x509.NewCertPool() // Trusted store, contain the list of trusted certificates. client has to use one of this certificate to be trusted by this server - bs, err := ioutil.ReadFile(ca) + bs, err := os.ReadFile(ca) if err != nil { return nil, fmt.Errorf("failed to read client ca cert: %v", err) } @@ -57,6 +92,40 @@ func setUpTLS(pkey, cert, ca string, mtls bool) (*tls.Config, error) { } c.ClientCAs = certPool } - return c, nil } + +func loadCert(certPath, keyPath string) (*tls.Certificate, error) { + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, fmt.Errorf("failed to load cert/key pair: %w", err) + } + return &cert, nil +} + +func ensureCert(certPath, keyPath string, genCert bool) (*tls.Certificate, error) { + _, err1 := os.Stat(certPath) + _, err2 := os.Stat(keyPath) + + if (os.IsNotExist(err1) || os.IsNotExist(err2)) && genCert { + if err := tlscert.GenerateSelfSignedCert(certPath, keyPath, certOrganizationDefault, certExpirationDefault); err != nil { + return nil, err + } + } + return loadCert(certPath, keyPath) +} + +func getCertAndKeyPath(certPath, keyPath string, useDefault bool) (string, string, error) { + if !useDefault || (certPath != "" && keyPath != "") { + return certPath, keyPath, nil + } + + homeDir, err := os.UserHomeDir() + if err != nil { + return "", "", fmt.Errorf("cannot get user home directory: %w", err) + } + + return filepath.Join(homeDir, "immudb", "ssl", certFileDefault), + filepath.Join(homeDir, "immudb", "ssl", keyFileDefault), + nil +} diff --git a/cmd/immudb/command/tls_config_test.go b/cmd/immudb/command/tls_config_test.go index 067c329a4c..d8b63f4af2 100644 --- a/cmd/immudb/command/tls_config_test.go +++ b/cmd/immudb/command/tls_config_test.go @@ -19,6 +19,7 @@ package immudb import ( "fmt" "os" + "path/filepath" "testing" "github.com/stretchr/testify/require" @@ -54,23 +55,66 @@ d/ax/lUR3RCVV6A+hzTgOhYKvoV1U6iX21hUarcm6MB6qaeORCHfQzQpn62nRe6X ` func TestSetUpTLS(t *testing.T) { - _, err := setUpTLS("banana", "banana", "banana", false) + _, err := setUpTLS("banana", "", "banana", false, false) require.Error(t, err) - _, err = setUpTLS("", "", "", true) + _, err = setUpTLS("banana", "banana", "banana", false, false) require.Error(t, err) - _, err = setUpTLS("banana", "", "", true) + _, err = setUpTLS("", "", "", true, false) + require.Error(t, err) + _, err = setUpTLS("banana", "", "", true, false) require.Error(t, err) defer os.Remove("xxkey.pem") f, _ := os.Create("xxkey.pem") - fmt.Fprintf(f, key) + fmt.Fprint(f, key) f.Close() defer os.Remove("xxcert.pem") f, _ = os.Create("xxcert.pem") - fmt.Fprintf(f, cert) + fmt.Fprint(f, cert) f.Close() - _, err = setUpTLS("xxkey.pem", "xxcert.pem", "banana", true) + _, err = setUpTLS("xxkey.pem", "xxcert.pem", "banana", true, false) require.Error(t, err) } + +func TestSetUpTLSWithAutoHTTPS(t *testing.T) { + t.Run("use specified paths", func(t *testing.T) { + tempDir := t.TempDir() + + certFile := filepath.Join(tempDir, "immudb.cert") + keyFile := filepath.Join(tempDir, "immudb.key") + + tlsConfig, err := setUpTLS(certFile, keyFile, "", false, false) + require.Error(t, err) + require.Nil(t, tlsConfig) + + tlsConfig, err = setUpTLS(certFile, keyFile, "", false, true) + require.NoError(t, err) + require.NotNil(t, tlsConfig) + + require.FileExists(t, certFile) + require.FileExists(t, keyFile) + + tlsConfig, err = setUpTLS(certFile, keyFile, "", false, false) + require.NoError(t, err) + require.NotNil(t, tlsConfig) + }) + + t.Run("use default paths", func(t *testing.T) { + certPath, keyPath, err := getCertAndKeyPath("", "", true) + require.NoError(t, err) + + defer func() { + os.RemoveAll(certPath) + os.Remove(keyPath) + }() + + tlsConfig, err := setUpTLS("", "", "", false, true) + require.NoError(t, err) + require.NotNil(t, tlsConfig) + + require.FileExists(t, certPath) + require.FileExists(t, keyPath) + }) +} diff --git a/pkg/cert/cert.go b/pkg/cert/cert.go new file mode 100644 index 0000000000..7a1693392c --- /dev/null +++ b/pkg/cert/cert.go @@ -0,0 +1,119 @@ +/* +Copyright 2024 Codenotary Inc. All rights reserved. + +SPDX-License-Identifier: BUSL-1.1 +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://mariadb.com/bsl11/ + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package cert + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "os" + "path" + "time" +) + +func GenerateSelfSignedCert(certPath, keyPath string, org string, expiration time.Duration) error { + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return fmt.Errorf("failed to generate RSA key: %w", err) + } + + notBefore := time.Now() + notAfter := notBefore.Add(expiration) + + serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return fmt.Errorf("failed to generate serial number: %w", err) + } + + hostname, err := os.Hostname() + if err != nil { + return err + } + + ips, err := listIPs() + if err != nil { + return err + } + ips = append(ips, net.ParseIP("0.0.0.0")) + + issuerOrSubject := pkix.Name{ + Organization: []string{org}, + } + + template := x509.Certificate{ + Issuer: issuerOrSubject, + SerialNumber: serialNumber, + Subject: issuerOrSubject, + DNSNames: []string{"localhost", hostname}, + NotBefore: notBefore, + NotAfter: notAfter, + IPAddresses: ips, + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + } + + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv) + if err != nil { + return fmt.Errorf("failed to create certificate: %w", err) + } + + if err := os.MkdirAll(path.Dir(certPath), 0755); err != nil { + return err + } + + certBytesPem := encodePEM(certBytes, "CERTIFICATE") + if err := os.WriteFile(certPath, certBytesPem, 0644); err != nil { + return fmt.Errorf("failed to write cert file: %w", err) + } + + privBytes := x509.MarshalPKCS1PrivateKey(priv) + privBytesPem := encodePEM(privBytes, "PRIVATE KEY") + + if err := os.WriteFile(keyPath, privBytesPem, 0600); err != nil { + return fmt.Errorf("failed to write key file: %w", err) + } + return nil +} + +func encodePEM(data []byte, blockType string) []byte { + block := &pem.Block{ + Type: blockType, + Bytes: data, + } + return pem.EncodeToMemory(block) +} + +func listIPs() ([]net.IP, error) { + addresses, err := net.InterfaceAddrs() + if err != nil { + return nil, err + } + + ips := make([]net.IP, 0, len(addresses)) + for _, addr := range addresses { + ipNet, ok := addr.(*net.IPNet) + if ok { + ips = append(ips, ipNet.IP) + } + } + return ips, nil +} diff --git a/pkg/server/corruption_checker_test.go b/pkg/server/corruption_checker_test.go index d550de10d0..91f1b2f667 100644 --- a/pkg/server/corruption_checker_test.go +++ b/pkg/server/corruption_checker_test.go @@ -16,6 +16,8 @@ limitations under the License. package server +import "fmt" + /* import ( "testing" @@ -377,18 +379,33 @@ func makeDB(dir string) *badger.DB { */ -type mockLogger struct{} +type mockLogger struct { + lines []string +} -func (l *mockLogger) Errorf(f string, v ...interface{}) {} +func (l *mockLogger) Errorf(f string, v ...interface{}) { + l.log("ERROR", f, v...) +} -func (l *mockLogger) Warningf(f string, v ...interface{}) {} +func (l *mockLogger) Warningf(f string, v ...interface{}) { + l.log("WARN", f, v...) +} -func (l *mockLogger) Infof(f string, v ...interface{}) {} +func (l *mockLogger) Infof(f string, v ...interface{}) { + l.log("INFO", f, v...) +} -func (l *mockLogger) Debugf(f string, v ...interface{}) {} +func (l *mockLogger) Debugf(f string, v ...interface{}) { + l.log("DEBUG", f, v...) +} func (l *mockLogger) Close() error { return nil } +func (l *mockLogger) log(level, f string, v ...interface{}) { + line := level + ": " + fmt.Sprintf(f, v...) + l.lines = append(l.lines, line) +} + /* func TestCryptoRandSource_Seed(t *testing.T) { cs := newCryptoRandSource() diff --git a/pkg/server/options.go b/pkg/server/options.go index 118d7c324c..aba34cdb1b 100644 --- a/pkg/server/options.go +++ b/pkg/server/options.go @@ -45,6 +45,7 @@ type Options struct { Config string Pidfile string Logfile string + AutoCert bool TLSConfig *tls.Config auth bool MaxRecvMsgSize int @@ -119,6 +120,7 @@ func DefaultOptions() *Options { Config: "configs/immudb.toml", Pidfile: "", Logfile: "", + AutoCert: false, TLSConfig: nil, auth: true, MaxRecvMsgSize: 1024 * 1024 * 32, // 32Mb diff --git a/pkg/server/server.go b/pkg/server/server.go index c183b26a86..c9c4cf04e5 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -18,6 +18,7 @@ package server import ( "context" + "crypto/x509" "fmt" "io/ioutil" "log" @@ -93,6 +94,9 @@ func (s *ImmuServer) Initialize() error { s.Logger.Infof("\n%s\n%s\n%s\n\n", immudbTextLogo, version.VersionStr(), s.Options) } + // Alert the user to certificates that are either expired or approaching expiration + s.checkTLSCerts() + if s.Options.GetMaintenance() && s.Options.GetAuth() { return ErrAuthMustBeDisabled } @@ -272,6 +276,32 @@ func (s *ImmuServer) Initialize() error { return err } +func (s *ImmuServer) checkTLSCerts() { + if s.Options.TLSConfig == nil { + return + } + + now := time.Now() + for _, cert := range s.Options.TLSConfig.Certificates { + if len(cert.Certificate) == 0 { + s.Logger.Errorf("tls config contains an invalid certificate") + continue + } + + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + s.Logger.Errorf("could not parse certificate: %s", err) + continue + } + + if now.Before(x509Cert.NotBefore) || now.After(x509Cert.NotAfter) { + s.Logger.Warningf("certificate with serial number %s is expired", x509Cert.SerialNumber.String()) + } else if !now.Before(x509Cert.NotAfter.Add(-30 * 24 * time.Hour)) { + s.Logger.Warningf("certificate with serial number %s is about to expire: %s left", x509Cert.SerialNumber.String(), x509Cert.NotAfter.Sub(now).String()) + } + } +} + // Start starts the immudb server // Loads and starts the System DB, default db and user db func (s *ImmuServer) Start() (err error) { diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 0233d87870..152c10ca11 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -19,6 +19,7 @@ package server import ( "bytes" "context" + "crypto/tls" "errors" "fmt" "io/ioutil" @@ -30,6 +31,7 @@ import ( "time" "github.com/codenotary/immudb/cmd/version" + "github.com/codenotary/immudb/pkg/cert" "github.com/codenotary/immudb/pkg/fs" "github.com/codenotary/immudb/pkg/stream" "golang.org/x/crypto/bcrypt" @@ -2121,3 +2123,54 @@ func TestServerDatabaseTruncate(t *testing.T) { _, err = s.CloseSession(ctx, &emptypb.Empty{}) require.NoError(t, err) } + +func TestUserIsAlertedToExpiredCerts(t *testing.T) { + dir := t.TempDir() + + certsPath := filepath.Join(dir, "certs") + + expCert := makeCert(t, certsPath, "expired", 0) + nearExpCert := makeCert(t, certsPath, "nearly-expired", 15*24*time.Hour) + validCert := makeCert(t, certsPath, "valid", 36*24*time.Hour) + + tlsConfig := &tls.Config{ + Certificates: []tls.Certificate{ + expCert, + nearExpCert, + validCert, + {}, + {Certificate: [][]byte{{1, 2, 3}}}, + }, + } + + opts := DefaultOptions(). + WithDir(dir). + WithTLS(tlsConfig) + + s, stop := testServer(opts) + defer stop() + + mockLogger := &mockLogger{} + s.WithLogger(mockLogger) + + s.checkTLSCerts() + + require.GreaterOrEqual(t, len(mockLogger.lines), 4) + require.Contains(t, mockLogger.lines[0], "is expired") + require.Contains(t, mockLogger.lines[1], "is about to expire") + require.Contains(t, mockLogger.lines[2], "tls config contains an invalid certificate") + require.Contains(t, mockLogger.lines[3], "could not parse certificate") +} + +func makeCert(t *testing.T, dir, suffix string, expiration time.Duration) tls.Certificate { + certFile := filepath.Join(dir, fmt.Sprintf("immudb-%s.cert", suffix)) + keyFile := filepath.Join(dir, fmt.Sprintf("immudb-%s.key", suffix)) + + err := cert.GenerateSelfSignedCert(certFile, keyFile, "immudb", expiration) + require.NoError(t, err) + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + require.NoError(t, err) + + return cert +} diff --git a/pkg/server/webserver.go b/pkg/server/webserver.go index aaba6d8c14..5ac4d5eb7d 100644 --- a/pkg/server/webserver.go +++ b/pkg/server/webserver.go @@ -29,12 +29,13 @@ import ( "github.com/codenotary/immudb/webconsole" "github.com/grpc-ecosystem/grpc-gateway/runtime" "google.golang.org/grpc" + "google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/grpclog" ) func startWebServer(ctx context.Context, grpcAddr string, httpAddr string, tlsConfig *tls.Config, s *ImmuServer, l logger.Logger) (*http.Server, error) { - grpcClient, err := grpcClient(ctx, grpcAddr) + grpcClient, err := grpcClient(ctx, grpcAddr, tlsConfig) if err != nil { return nil, err } @@ -105,8 +106,15 @@ func startWebServer(ctx context.Context, grpcAddr string, httpAddr string, tlsCo return httpServer, nil } -func grpcClient(ctx context.Context, grpcAddr string) (conn *grpc.ClientConn, err error) { - conn, err = grpc.Dial(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) +func grpcClient(ctx context.Context, grpcAddr string, tlsConfig *tls.Config) (conn *grpc.ClientConn, err error) { + var creds credentials.TransportCredentials + if tlsConfig != nil { + creds = credentials.NewTLS(&tls.Config{RootCAs: tlsConfig.RootCAs}) + } else { + creds = insecure.NewCredentials() + } + + conn, err = grpc.Dial(grpcAddr, grpc.WithTransportCredentials(creds)) if err != nil { return conn, err } @@ -124,6 +132,5 @@ func grpcClient(ctx context.Context, grpcAddr string) (conn *grpc.ClientConn, er } }() }() - return conn, nil }