diff --git a/plugin/auth/postgresql/postgresql.go b/plugin/auth/postgresql/postgresql.go index 63fc510..992efe6 100644 --- a/plugin/auth/postgresql/postgresql.go +++ b/plugin/auth/postgresql/postgresql.go @@ -3,6 +3,7 @@ package postgresql import ( "bytes" "fmt" + "github.com/jmoiron/sqlx" _ "github.com/lib/pq" "github.com/wind-c/comqtt/v2/mqtt" @@ -90,12 +91,20 @@ func (a *Auth) Init(config any) error { sqlxDB.SetMaxOpenConns(a.config.Dsn.MaxOpenConns) sqlxDB.SetMaxIdleConns(a.config.Dsn.MaxIdleConns) - authSql := fmt.Sprintf("select %s, %s from %s where %s=?", + authSql := fmt.Sprintf(`select %s, %s from %s where %s=$1`, a.config.Auth.PasswordColumn, a.config.Auth.AllowColumn, a.config.Auth.Table, a.config.Auth.UserColumn) - aclSql := fmt.Sprintf("select %s, %s from %s where %s=?", + aclSql := fmt.Sprintf(`select %s, %s from %s where %s=$1`, a.config.Acl.TopicColumn, a.config.Acl.AccessColumn, a.config.Acl.Table, a.config.Acl.UserColumn) - a.authStmt, _ = sqlxDB.Preparex(authSql) - a.aclStmt, _ = sqlxDB.Preparex(aclSql) + a.authStmt, err = sqlxDB.Preparex(authSql) + if err != nil { + a.Log.Error().Str("authSql", authSql).Msg("Unable to create prepared statement for auth-sql") + return err + } + a.aclStmt, err = sqlxDB.Preparex(aclSql) + if err != nil { + a.Log.Error().Str("aclStmt", aclSql).Msg("Unable to create prepared statement for acl-sql") + return err + } a.db = sqlxDB return nil } @@ -132,7 +141,8 @@ func (a *Auth) OnConnectAuthenticate(cl *mqtt.Client, pk packets.Packet) bool { var password string var allow int - err := a.authStmt.QueryRowx(key).Scan(&password, &allow) + row := a.authStmt.QueryRowx(key) + err := row.Scan(&password, &allow) if err != nil || allow == 0 { return false }