Skip to content

Commit

Permalink
Improve validation when establishing psql session
Browse files Browse the repository at this point in the history
  • Loading branch information
ostafen committed Apr 15, 2024
1 parent a4e7599 commit fb4b707
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 10 deletions.
2 changes: 1 addition & 1 deletion pkg/pgsql/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ var ErrMaxParamsNumberExceeded = errors.New("number of parameters exceeded the m
var ErrParametersValueSizeTooLarge = errors.New("provided parameters exceeded the maximum allowed size limit")
var ErrNegativeParameterValueLen = errors.New("negative parameter length detected")
var ErrMalformedMessage = errors.New("malformed message detected")
var ErrMessageTooLarge = errors.New("payload message hit allowed memory boundaries")
var ErrMessageTooLarge = errors.New("payload message hit allowed memory boundaries")

func MapPgError(err error) (er bm.ErrorResp) {
switch {
Expand Down
22 changes: 19 additions & 3 deletions pkg/pgsql/server/initialize_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ func (s *session) InitializeSession() (err error) {
s.protocolVersion = parseProtocolVersion(pvb)

// SSL Request packet
if s.protocolVersion == "1234.5679" {
if s.protocolVersion == pgmeta.PgsqlSSLRequestProtocolVersion {
if s.tlsConfig == nil || len(s.tlsConfig.Certificates) == 0 {
if _, err = s.writeMessage([]byte(`N`)); err != nil {
return err
Expand Down Expand Up @@ -86,8 +86,20 @@ func (s *session) InitializeSession() (err error) {
s.protocolVersion = parseProtocolVersion(pvb)
}

if !isValidProtocolVersion(s.protocolVersion) {
return fmt.Errorf("%w: %s", pgmeta.ErrInvalidPgsqlProtocolVersion, s.protocolVersion)
}

// startup message
connStringLenght := int(binary.BigEndian.Uint32(lb) - 4)
connStringLenght := int(binary.BigEndian.Uint32(lb) - 8)
if connStringLenght < 0 {
return pserr.ErrMalformedMessage
}

if connStringLenght > pgmeta.MaxMsgSize {
return pserr.ErrMessageTooLarge
}

connString := make([]byte, connStringLenght)

if _, err := s.mr.Read(connString); err != nil {
Expand Down Expand Up @@ -214,7 +226,7 @@ func (s *session) HandleStartup(ctx context.Context) (err error) {
}

// todo this is needed by jdbc driver. Here is added the minor supported version at the moment
if _, err := s.writeMessage(bm.ParameterStatus([]byte("server_version"), []byte(pgmeta.PgsqlProtocolVersion))); err != nil {
if _, err := s.writeMessage(bm.ParameterStatus([]byte("server_version"), []byte(pgmeta.PgsqlServerVersion))); err != nil {
return err
}

Expand All @@ -227,6 +239,10 @@ func parseProtocolVersion(payload []byte) string {
return fmt.Sprintf("%d.%d", major, minor)
}

func isValidProtocolVersion(version string) bool {
return version == pgmeta.PgsqlProtocolVersion || version == pgmeta.PgsqlSSLRequestProtocolVersion
}

func (s *session) Close() error {
s.mr.CloseConnection()

Expand Down
14 changes: 10 additions & 4 deletions pkg/pgsql/server/pgmeta/pg_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,21 @@ limitations under the License.
package pgmeta

import (
"errors"
"fmt"
)

const PgTypeMapOid = 0
const PgTypeMapLength = 1
const (
PgTypeMapOid = 0
PgTypeMapLength = 1

const PgsqlProtocolVersion = "9.6"
PgsqlProtocolVersion = "3.0"
PgsqlSSLRequestProtocolVersion = "1234.5679"
PgsqlServerVersion = "9.6"
)

var PgsqlProtocolVersionMessage = fmt.Sprintf("pgsql wire protocol %s or greater version implemented by immudb", PgsqlProtocolVersion)
var PgsqlServerVersionMessage = fmt.Sprintf("pgsql server %s or greater version implemented by immudb", PgsqlServerVersion)
var ErrInvalidPgsqlProtocolVersion = errors.New("invalid pgsql protocol version")

// PgTypeMap maps the immudb type descriptor with pgsql pgtype map.
// First int is the oid value (retrieved with select * from pg_type;)
Expand Down
2 changes: 1 addition & 1 deletion pkg/pgsql/server/pgsql_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,7 @@ func TestPgsqlServer_VersionStatement(t *testing.T) {
var version string
err = db.QueryRow("SELECT version()").Scan(&version)
require.NoError(t, err)
require.Equal(t, pgmeta.PgsqlProtocolVersionMessage, version)
require.Equal(t, pgmeta.PgsqlServerVersionMessage, version)

_, err = db.Exec("DEALLOCATE \"_PLAN0x7fb2c0822800\"")
require.NoError(t, err)
Expand Down
2 changes: 1 addition & 1 deletion pkg/pgsql/server/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (s *session) writeVersionInfo() error {
}
rows := []*schema.Row{{
Columns: []string{"version"},
Values: []*schema.SQLValue{{Value: &schema.SQLValue_S{S: pgmeta.PgsqlProtocolVersionMessage}}},
Values: []*schema.SQLValue{{Value: &schema.SQLValue_S{S: pgmeta.PgsqlServerVersionMessage}}},
}}
if _, err := s.writeMessage(bm.DataRow(rows, len(cols), nil)); err != nil {
return err
Expand Down

0 comments on commit fb4b707

Please sign in to comment.