Skip to content

Commit

Permalink
Fix bidi stream blocking naive proxy setup causes. (#143)
Browse files Browse the repository at this point in the history
There was an assumption we could blast N start streams down and then loop reading responses.

In practice this works until about 2k starts. At that point the remote server has sent enough responses that it will now block sending back.

The server requires the client to clear its response queue as it won't buffer infinitely to avoid scaling issues.

As a result we deadlock around ~2k+ start streams.

Solve this by moving Send/Recv into separate go routines since the state they are modifying is disjoint. We need to send N starts and then we're done. We need to get back N replies but they aren't correlated past that.

Add some logging at the end so debugging is easier since Send can return EOF and means "the error is on the Recv side".
  • Loading branch information
sfc-gh-jchacon authored Jun 28, 2022
1 parent 6cd231d commit eb25183
Showing 1 changed file with 115 additions and 83 deletions.
198 changes: 115 additions & 83 deletions proxy/proxy/proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ package proxy

import (
"context"
"fmt"
"io"
"log"
"strings"
"sync"
"time"

"google.golang.org/grpc"
Expand All @@ -35,6 +37,7 @@ import (
"google.golang.org/protobuf/types/known/durationpb"

proxypb "github.com/Snowflake-Labs/sansshell/proxy"
"github.com/go-logr/logr"
)

// Conn is a grpc.ClientConnInterface which is connected to the proxy
Expand Down Expand Up @@ -347,100 +350,129 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_

streamIds := make(map[uint64]*Ret)

// For every target we have to send a separate StartStream (with a nonce which in our case is the target index so clients can map too).
// We then validate the nonce matches and record the stream ID so later processing can match responses to the right targets.
// This needs to be 2 loops as we want the server to process N StartStreams in parallel and then we'll loop getting responses.
for i, t := range p.Targets {
req := &proxypb.ProxyRequest{
Request: &proxypb.ProxyRequest_StartStream{
StartStream: &proxypb.StartStream{
Target: t,
MethodName: method,
Nonce: uint32(i),
},
},
}
if p.timeouts[i] != nil {
req.GetStartStream().DialTimeout = durationpb.New(*p.timeouts[i])
}
err = stream.Send(req)

// If Send reports an error and is EOF we have to use Recv to get the actual error according to documentation
// for SendMsg. However it appears SendMsg will return actual errors "sometimes" when it's the first stream
// a server has ever handled so account for that here.
if err != nil && err != io.EOF {
return nil, nil, errors, status.Errorf(codes.Internal, "can't send request for %s on stream - %v", method, err)
}
if err != nil {
_, err := stream.Recv()
return nil, nil, errors, status.Errorf(codes.Internal, "remote error from Send for %s - %v", method, err)
}
}
wg := &sync.WaitGroup{}

// We sent len(p.Targets) requests so loop for that many replies. If the server doesn't we'll have to wait until
// our context times out then. If the server attempts something invalid we'll catch and just abort (i.e, duplicate
// responses and/or out of range). We may encounter closes() in here mixed in with replies but we'll never get
// more than that (data can't start until we return).
replies := 0
for replies != len(p.Targets) {
resp, err := stream.Recv()
if err != nil {
return nil, nil, errors, status.Errorf(codes.Internal, "can't get response for %s on stream - %v", method, err)
}

// Validate we got an answer and it has expected reflected values.

// These are all sanity checks for the entire session so an overall error is appropriate since we're likely
// dealing with a broken proxy of some sort.
switch t := resp.Reply.(type) {
case *proxypb.ProxyReply_StartStreamReply:
replies++
r := t.StartStreamReply
// We want the returned Target+nonce to match what we sent and that it's one we know about.
if r.Nonce >= uint32(len(p.Targets)) {
return nil, nil, errors, status.Errorf(codes.Internal, "got back invalid nonce (out of range): %+v", r)
var sendErr, recvErr error
wg.Add(1)
go func() {
defer wg.Done()
// For every target we have to send a separate StartStream (with a nonce which in our case is the target index so clients can map too).
// We then validate the nonce matches and record the stream ID so later processing can match responses to the right targets.
// This needs to be 2 loops as we want the server to process N StartStreams in parallel and then we'll loop getting responses.
for i, t := range p.Targets {
req := &proxypb.ProxyRequest{
Request: &proxypb.ProxyRequest_StartStream{
StartStream: &proxypb.StartStream{
Target: t,
MethodName: method,
Nonce: uint32(i),
},
},
}
if p.Targets[r.Nonce] != r.Target {
return nil, nil, errors, status.Errorf(codes.Internal, "Target/nonce don't match. target %s(%d) is not %s: %+v", p.Targets[r.Nonce], r.Nonce, r.Target, r)
if p.timeouts[i] != nil {
req.GetStartStream().DialTimeout = durationpb.New(*p.timeouts[i])
}
err = stream.Send(req)

id := r.GetStreamId()
if streamIds[id] != nil {
return nil, nil, errors, status.Errorf(codes.Internal, "Duplicate response for target %s. Already have %+v for response %+v", r.Target, streamIds[id], r)
// If Send reports an error and is EOF we have to use Recv to get the actual error according to documentation
// for SendMsg. However it appears SendMsg will return actual errors "sometimes" when it's the first stream
// a server has ever handled so account for that here. The actual Recv for the error will get caught in the other
// routine below.
if err != nil {
if err != io.EOF {
sendErr = status.Errorf(codes.Internal, "can't send request for %s on stream - %v", method, err)
}
return
}
}
}()

ret := &Ret{
Target: r.GetTarget(),
Index: int(r.GetNonce()),
}
// If the target reported an error stick it in errors.
if s := r.GetErrorStatus(); s != nil {
ret.Error = status.Errorf(codes.Internal, "got reply error from stream. Code: %s Message: %s", codes.Code(s.Code), s.Message)
errors = append(errors, ret)
continue
wg.Add(1)
go func() {
defer wg.Done()

// We sent len(p.Targets) requests so loop for that many start stream replies. If the server doesn't we'll have to wait until
// our context times out then. If the server attempts something invalid we'll catch and just abort (i.e, duplicate
// responses and/or out of range). We may encounter closes() in here mixed in with replies but we'll never get
// more than that (data can't start until we return). For closes() we discover we note them in errors so later code
// just skips them as they've already errored out.
replies := 0
for replies != len(p.Targets) {
resp, err := stream.Recv()
if err != nil {
recvErr = status.Errorf(codes.Internal, "can't get response for %s on stream - %v", method, err)
return
}

// Save stream ID/nonce for later matching.
streamIds[r.GetStreamId()] = ret
case *proxypb.ProxyReply_ServerClose:
c := t.ServerClose
// We've never sent any data so a close here has to be an error.
st := c.GetStatus()
if st == nil || st.Code == 0 {
return nil, nil, errors, status.Errorf(codes.Internal, "close with no data sent and no error? %+v", resp)
}
for _, id := range c.StreamIds {
if streamIds[id] == nil {
return nil, nil, errors, status.Errorf(codes.Internal, "close on invalid stream id: %+v", resp)
// Validate we got an answer and it has expected reflected values.

// These are all sanity checks for the entire session so an overall error is appropriate since we're likely
// dealing with a broken proxy of some sort.
switch t := resp.Reply.(type) {
case *proxypb.ProxyReply_StartStreamReply:
replies++
r := t.StartStreamReply
// We want the returned Target+nonce to match what we sent and that it's one we know about.
if r.Nonce >= uint32(len(p.Targets)) {
recvErr = status.Errorf(codes.Internal, "got back invalid nonce (out of range): %+v", r)
return
}
if p.Targets[r.Nonce] != r.Target {
recvErr = status.Errorf(codes.Internal, "Target/nonce don't match. target %s(%d) is not %s: %+v", p.Targets[r.Nonce], r.Nonce, r.Target, r)
return
}

id := r.GetStreamId()
if streamIds[id] != nil {
recvErr = status.Errorf(codes.Internal, "Duplicate response for target %s. Already have %+v for response %+v", r.Target, streamIds[id], r)
return
}

ret := &Ret{
Target: r.GetTarget(),
Index: int(r.GetNonce()),
}
streamIds[id].Error = status.Errorf(codes.Internal, "got close error from stream. Code: %s Message: %s", codes.Code(st.Code), st.Message)
errors = append(errors, streamIds[id])
// If it's closed make sure we don't process it later on.
delete(streamIds, id)
// If the target reported an error stick it in errors.
if s := r.GetErrorStatus(); s != nil {
ret.Error = status.Errorf(codes.Internal, "got reply error from stream. Code: %s Message: %s", codes.Code(s.Code), s.Message)
errors = append(errors, ret)
continue
}

// Save stream ID/nonce for later matching.
streamIds[r.GetStreamId()] = ret
case *proxypb.ProxyReply_ServerClose:
c := t.ServerClose
// We've never sent any data so a close here has to be an error.
st := c.GetStatus()
if st == nil || st.Code == 0 {
recvErr = status.Errorf(codes.Internal, "close with no data sent and no error? %+v", resp)
return
}
for _, id := range c.StreamIds {
if streamIds[id] == nil {
recvErr = status.Errorf(codes.Internal, "close on invalid stream id: %+v", resp)
return
}
streamIds[id].Error = status.Errorf(codes.Internal, "got close error from stream. Code: %s Message: %s", codes.Code(st.Code), st.Message)
errors = append(errors, streamIds[id])
// If it's closed make sure we don't process it later on.
delete(streamIds, id)
}
default:
recvErr = status.Errorf(codes.Internal, "unexpected reply for %s on stream - %+v", method, resp)
return
}
default:
return nil, nil, errors, status.Errorf(codes.Internal, "unexpected reply for %s on stream - %+v", method, resp)
}
}()

wg.Wait()

log := logr.FromContextOrDiscard(ctx)
if sendErr != nil || recvErr != nil {
err := fmt.Errorf("Setting up streams errors: %v - %v", sendErr, recvErr)
log.Error(err, "Setup error")
return nil, nil, errors, err
}
return stream, streamIds, errors, nil
}
Expand Down

0 comments on commit eb25183

Please sign in to comment.