diff --git a/pkg/pgsql/errors/errors.go b/pkg/pgsql/errors/errors.go index 616850eaa6..0503ce6fb2 100644 --- a/pkg/pgsql/errors/errors.go +++ b/pkg/pgsql/errors/errors.go @@ -39,7 +39,7 @@ var ErrMaxParamsNumberExceeded = errors.New("number of parameters exceeded the m var ErrParametersValueSizeTooLarge = errors.New("provided parameters exceeded the maximum allowed size limit") var ErrNegativeParameterValueLen = errors.New("negative parameter length detected") var ErrMalformedMessage = errors.New("malformed message detected") -var ErrMessageTooLarge = errors.New("payload message hit allowed memory boundaries") +var ErrMessageTooLarge = errors.New("payload message hit allowed memory boundaries") func MapPgError(err error) (er bm.ErrorResp) { switch { diff --git a/pkg/pgsql/server/initialize_session.go b/pkg/pgsql/server/initialize_session.go index 89f1cc4bf1..8b0dff5346 100644 --- a/pkg/pgsql/server/initialize_session.go +++ b/pkg/pgsql/server/initialize_session.go @@ -58,7 +58,7 @@ func (s *session) InitializeSession() (err error) { s.protocolVersion = parseProtocolVersion(pvb) // SSL Request packet - if s.protocolVersion == "1234.5679" { + if s.protocolVersion == pgmeta.PgsqlSSLRequestProtocolVersion { if s.tlsConfig == nil || len(s.tlsConfig.Certificates) == 0 { if _, err = s.writeMessage([]byte(`N`)); err != nil { return err @@ -86,8 +86,20 @@ func (s *session) InitializeSession() (err error) { s.protocolVersion = parseProtocolVersion(pvb) } + if !isValidProtocolVersion(s.protocolVersion) { + return fmt.Errorf("%w: %s", pgmeta.ErrInvalidPgsqlProtocolVersion, s.protocolVersion) + } + // startup message - connStringLenght := int(binary.BigEndian.Uint32(lb) - 4) + connStringLenght := int(binary.BigEndian.Uint32(lb) - 8) + if connStringLenght < 0 { + return pserr.ErrMalformedMessage + } + + if connStringLenght > pgmeta.MaxMsgSize { + return pserr.ErrMessageTooLarge + } + connString := make([]byte, connStringLenght) if _, err := s.mr.Read(connString); err != nil { @@ -214,7 +226,7 @@ func (s *session) HandleStartup(ctx context.Context) (err error) { } // todo this is needed by jdbc driver. Here is added the minor supported version at the moment - if _, err := s.writeMessage(bm.ParameterStatus([]byte("server_version"), []byte(pgmeta.PgsqlProtocolVersion))); err != nil { + if _, err := s.writeMessage(bm.ParameterStatus([]byte("server_version"), []byte(pgmeta.PgsqlServerVersion))); err != nil { return err } @@ -227,6 +239,10 @@ func parseProtocolVersion(payload []byte) string { return fmt.Sprintf("%d.%d", major, minor) } +func isValidProtocolVersion(version string) bool { + return version == pgmeta.PgsqlProtocolVersion || version == pgmeta.PgsqlSSLRequestProtocolVersion +} + func (s *session) Close() error { s.mr.CloseConnection() diff --git a/pkg/pgsql/server/pgmeta/pg_type.go b/pkg/pgsql/server/pgmeta/pg_type.go index 071450bd2d..599d0429fc 100644 --- a/pkg/pgsql/server/pgmeta/pg_type.go +++ b/pkg/pgsql/server/pgmeta/pg_type.go @@ -17,15 +17,21 @@ limitations under the License. package pgmeta import ( + "errors" "fmt" ) -const PgTypeMapOid = 0 -const PgTypeMapLength = 1 +const ( + PgTypeMapOid = 0 + PgTypeMapLength = 1 -const PgsqlProtocolVersion = "9.6" + PgsqlProtocolVersion = "3.0" + PgsqlSSLRequestProtocolVersion = "1234.5679" + PgsqlServerVersion = "9.6" +) -var PgsqlProtocolVersionMessage = fmt.Sprintf("pgsql wire protocol %s or greater version implemented by immudb", PgsqlProtocolVersion) +var PgsqlServerVersionMessage = fmt.Sprintf("pgsql server %s or greater version implemented by immudb", PgsqlServerVersion) +var ErrInvalidPgsqlProtocolVersion = errors.New("invalid pgsql protocol version") // PgTypeMap maps the immudb type descriptor with pgsql pgtype map. // First int is the oid value (retrieved with select * from pg_type;) diff --git a/pkg/pgsql/server/pgsql_integration_test.go b/pkg/pgsql/server/pgsql_integration_test.go index fd7e8d7e7a..b40c173583 100644 --- a/pkg/pgsql/server/pgsql_integration_test.go +++ b/pkg/pgsql/server/pgsql_integration_test.go @@ -727,7 +727,7 @@ func TestPgsqlServer_VersionStatement(t *testing.T) { var version string err = db.QueryRow("SELECT version()").Scan(&version) require.NoError(t, err) - require.Equal(t, pgmeta.PgsqlProtocolVersionMessage, version) + require.Equal(t, pgmeta.PgsqlServerVersionMessage, version) _, err = db.Exec("DEALLOCATE \"_PLAN0x7fb2c0822800\"") require.NoError(t, err) diff --git a/pkg/pgsql/server/version.go b/pkg/pgsql/server/version.go index 6cbb168ec6..7b4f5d4cc9 100644 --- a/pkg/pgsql/server/version.go +++ b/pkg/pgsql/server/version.go @@ -29,7 +29,7 @@ func (s *session) writeVersionInfo() error { } rows := []*schema.Row{{ Columns: []string{"version"}, - Values: []*schema.SQLValue{{Value: &schema.SQLValue_S{S: pgmeta.PgsqlProtocolVersionMessage}}}, + Values: []*schema.SQLValue{{Value: &schema.SQLValue_S{S: pgmeta.PgsqlServerVersionMessage}}}, }} if _, err := s.writeMessage(bm.DataRow(rows, len(cols), nil)); err != nil { return err