Skip to content

Commit

Permalink
Merge pull request #358 from carlosms/fix-kill
Browse files Browse the repository at this point in the history
Kill queries also after the rows start streaming
  • Loading branch information
carlosms authored Mar 6, 2019
2 parents 56c5ce4 + f5f7752 commit 4cbf198
Showing 1 changed file with 42 additions and 28 deletions.
70 changes: 42 additions & 28 deletions server/handler/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,36 +53,37 @@ func genericVals(colTypes []string) []interface{} {
// the rows as JSON
func Query(db service.SQLDB) RequestProcessFunc {
return func(r *http.Request) (*serializer.Response, error) {
var queryRequest queryRequest
var queryReq queryRequest
body, err := ioutil.ReadAll(r.Body)
if err != nil {
return nil, err
}

err = json.Unmarshal(body, &queryRequest)
if err != nil || queryRequest.Query == "" {
err = json.Unmarshal(body, &queryReq)
if err != nil || queryReq.Query == "" {
return nil, serializer.NewHTTPError(http.StatusBadRequest,
`Bad Request. Expected body: { "query": "SQL statement", "limit": 1234 }`)
}

query, limitSet := addLimit(queryRequest.Query, queryRequest.Limit)

// go-sql-driver/mysql QueryContext stops waiting for the query results on
// context cancel, but it does not actually cancel the query on the server

c := make(chan error, 1)

var rows *sql.Rows
conn, err := db.Conn(r.Context())
if err != nil {
return nil, fmt.Errorf("failed to get a DB connection: %s", err)
}
defer conn.Close()

connID, err := getConnID(r, conn)
connID, err := getConnID(conn)
if err != nil {
return nil, fmt.Errorf("failed to get connection id: %s", err)
}

var resp *serializer.Response
go func() {
rows, err = conn.QueryContext(r.Context(), query)
resp, err = queryContext(r.Context(), conn, queryReq)
c <- err
}()

Expand All @@ -103,40 +104,53 @@ func Query(db service.SQLDB) RequestProcessFunc {
return nil, dbError(err)
}

defer rows.Close()
return resp, nil
}
}

columnNames, columnTypes, err := columnsInfo(rows)
if err != nil {
return nil, err
}
func queryContext(ctx context.Context, conn *sql.Conn, queryReq queryRequest) (*serializer.Response, error) {
query, limitSet := addLimit(queryReq.Query, queryReq.Limit)

columnValsPtr := genericVals(columnTypes)
var rows *sql.Rows

tableData := make([]map[string]interface{}, 0)
rows, err := conn.QueryContext(ctx, query)
if err != nil {
return nil, err
}

for rows.Next() {
if err := rows.Scan(columnValsPtr...); err != nil {
return nil, err
}
defer rows.Close()

colData, err := columnsData(columnNames, columnTypes, columnValsPtr)
if err != nil {
return nil, err
}
columnNames, columnTypes, err := columnsInfo(rows)
if err != nil {
return nil, err
}

columnValsPtr := genericVals(columnTypes)

tableData := make([]map[string]interface{}, 0)

tableData = append(tableData, colData)
for rows.Next() {
if err := rows.Scan(columnValsPtr...); err != nil {
return nil, err
}

if err := rows.Err(); err != nil {
colData, err := columnsData(columnNames, columnTypes, columnValsPtr)
if err != nil {
return nil, err
}

return serializer.NewQueryResponse(
tableData, columnNames, columnTypes, limitSet, queryRequest.Limit), nil
tableData = append(tableData, colData)
}

if err := rows.Err(); err != nil {
return nil, err
}

return serializer.NewQueryResponse(
tableData, columnNames, columnTypes, limitSet, queryReq.Limit), nil
}

func getConnID(r *http.Request, conn *sql.Conn) (uint32, error) {
func getConnID(conn *sql.Conn) (uint32, error) {
const connIDQuery = "SELECT CONNECTION_ID()"
var connID uint32

Expand Down

0 comments on commit 4cbf198

Please sign in to comment.