From eb251837ed890f2fb2c7fed17ed36a1a2fd88f7f Mon Sep 17 00:00:00 2001 From: James Chacon Date: Tue, 28 Jun 2022 14:30:43 -0700 Subject: [PATCH] Fix bidi stream blocking naive proxy setup causes. (#143) There was an assumption we could blast N start streams down and then loop reading responses. In practice this works until about 2k starts. At that point the remote server has sent enough responses that it will now block sending back. The server requires the client to clear its response queue as it won't buffer infinitely to avoid scaling issues. As a result we deadlock around ~2k+ start streams. Solve this by moving Send/Recv into separate go routines since the state they are modifying is disjoint. We need to send N starts and then we're done. We need to get back N replies but they aren't correlated past that. Add some logging at the end so debugging is easier since Send can return EOF and means "the error is on the Recv side". --- proxy/proxy/proxy.go | 198 +++++++++++++++++++++++++------------------ 1 file changed, 115 insertions(+), 83 deletions(-) diff --git a/proxy/proxy/proxy.go b/proxy/proxy/proxy.go index 865db53b..0bf6ecd0 100644 --- a/proxy/proxy/proxy.go +++ b/proxy/proxy/proxy.go @@ -21,9 +21,11 @@ package proxy import ( "context" + "fmt" "io" "log" "strings" + "sync" "time" "google.golang.org/grpc" @@ -35,6 +37,7 @@ import ( "google.golang.org/protobuf/types/known/durationpb" proxypb "github.com/Snowflake-Labs/sansshell/proxy" + "github.com/go-logr/logr" ) // Conn is a grpc.ClientConnInterface which is connected to the proxy @@ -347,100 +350,129 @@ func (p *Conn) createStreams(ctx context.Context, method string) (proxypb.Proxy_ streamIds := make(map[uint64]*Ret) - // For every target we have to send a separate StartStream (with a nonce which in our case is the target index so clients can map too). - // We then validate the nonce matches and record the stream ID so later processing can match responses to the right targets. - // This needs to be 2 loops as we want the server to process N StartStreams in parallel and then we'll loop getting responses. - for i, t := range p.Targets { - req := &proxypb.ProxyRequest{ - Request: &proxypb.ProxyRequest_StartStream{ - StartStream: &proxypb.StartStream{ - Target: t, - MethodName: method, - Nonce: uint32(i), - }, - }, - } - if p.timeouts[i] != nil { - req.GetStartStream().DialTimeout = durationpb.New(*p.timeouts[i]) - } - err = stream.Send(req) - - // If Send reports an error and is EOF we have to use Recv to get the actual error according to documentation - // for SendMsg. However it appears SendMsg will return actual errors "sometimes" when it's the first stream - // a server has ever handled so account for that here. - if err != nil && err != io.EOF { - return nil, nil, errors, status.Errorf(codes.Internal, "can't send request for %s on stream - %v", method, err) - } - if err != nil { - _, err := stream.Recv() - return nil, nil, errors, status.Errorf(codes.Internal, "remote error from Send for %s - %v", method, err) - } - } + wg := &sync.WaitGroup{} - // We sent len(p.Targets) requests so loop for that many replies. If the server doesn't we'll have to wait until - // our context times out then. If the server attempts something invalid we'll catch and just abort (i.e, duplicate - // responses and/or out of range). We may encounter closes() in here mixed in with replies but we'll never get - // more than that (data can't start until we return). - replies := 0 - for replies != len(p.Targets) { - resp, err := stream.Recv() - if err != nil { - return nil, nil, errors, status.Errorf(codes.Internal, "can't get response for %s on stream - %v", method, err) - } - - // Validate we got an answer and it has expected reflected values. - - // These are all sanity checks for the entire session so an overall error is appropriate since we're likely - // dealing with a broken proxy of some sort. - switch t := resp.Reply.(type) { - case *proxypb.ProxyReply_StartStreamReply: - replies++ - r := t.StartStreamReply - // We want the returned Target+nonce to match what we sent and that it's one we know about. - if r.Nonce >= uint32(len(p.Targets)) { - return nil, nil, errors, status.Errorf(codes.Internal, "got back invalid nonce (out of range): %+v", r) + var sendErr, recvErr error + wg.Add(1) + go func() { + defer wg.Done() + // For every target we have to send a separate StartStream (with a nonce which in our case is the target index so clients can map too). + // We then validate the nonce matches and record the stream ID so later processing can match responses to the right targets. + // This needs to be 2 loops as we want the server to process N StartStreams in parallel and then we'll loop getting responses. + for i, t := range p.Targets { + req := &proxypb.ProxyRequest{ + Request: &proxypb.ProxyRequest_StartStream{ + StartStream: &proxypb.StartStream{ + Target: t, + MethodName: method, + Nonce: uint32(i), + }, + }, } - if p.Targets[r.Nonce] != r.Target { - return nil, nil, errors, status.Errorf(codes.Internal, "Target/nonce don't match. target %s(%d) is not %s: %+v", p.Targets[r.Nonce], r.Nonce, r.Target, r) + if p.timeouts[i] != nil { + req.GetStartStream().DialTimeout = durationpb.New(*p.timeouts[i]) } + err = stream.Send(req) - id := r.GetStreamId() - if streamIds[id] != nil { - return nil, nil, errors, status.Errorf(codes.Internal, "Duplicate response for target %s. Already have %+v for response %+v", r.Target, streamIds[id], r) + // If Send reports an error and is EOF we have to use Recv to get the actual error according to documentation + // for SendMsg. However it appears SendMsg will return actual errors "sometimes" when it's the first stream + // a server has ever handled so account for that here. The actual Recv for the error will get caught in the other + // routine below. + if err != nil { + if err != io.EOF { + sendErr = status.Errorf(codes.Internal, "can't send request for %s on stream - %v", method, err) + } + return } + } + }() - ret := &Ret{ - Target: r.GetTarget(), - Index: int(r.GetNonce()), - } - // If the target reported an error stick it in errors. - if s := r.GetErrorStatus(); s != nil { - ret.Error = status.Errorf(codes.Internal, "got reply error from stream. Code: %s Message: %s", codes.Code(s.Code), s.Message) - errors = append(errors, ret) - continue + wg.Add(1) + go func() { + defer wg.Done() + + // We sent len(p.Targets) requests so loop for that many start stream replies. If the server doesn't we'll have to wait until + // our context times out then. If the server attempts something invalid we'll catch and just abort (i.e, duplicate + // responses and/or out of range). We may encounter closes() in here mixed in with replies but we'll never get + // more than that (data can't start until we return). For closes() we discover we note them in errors so later code + // just skips them as they've already errored out. + replies := 0 + for replies != len(p.Targets) { + resp, err := stream.Recv() + if err != nil { + recvErr = status.Errorf(codes.Internal, "can't get response for %s on stream - %v", method, err) + return } - // Save stream ID/nonce for later matching. - streamIds[r.GetStreamId()] = ret - case *proxypb.ProxyReply_ServerClose: - c := t.ServerClose - // We've never sent any data so a close here has to be an error. - st := c.GetStatus() - if st == nil || st.Code == 0 { - return nil, nil, errors, status.Errorf(codes.Internal, "close with no data sent and no error? %+v", resp) - } - for _, id := range c.StreamIds { - if streamIds[id] == nil { - return nil, nil, errors, status.Errorf(codes.Internal, "close on invalid stream id: %+v", resp) + // Validate we got an answer and it has expected reflected values. + + // These are all sanity checks for the entire session so an overall error is appropriate since we're likely + // dealing with a broken proxy of some sort. + switch t := resp.Reply.(type) { + case *proxypb.ProxyReply_StartStreamReply: + replies++ + r := t.StartStreamReply + // We want the returned Target+nonce to match what we sent and that it's one we know about. + if r.Nonce >= uint32(len(p.Targets)) { + recvErr = status.Errorf(codes.Internal, "got back invalid nonce (out of range): %+v", r) + return + } + if p.Targets[r.Nonce] != r.Target { + recvErr = status.Errorf(codes.Internal, "Target/nonce don't match. target %s(%d) is not %s: %+v", p.Targets[r.Nonce], r.Nonce, r.Target, r) + return + } + + id := r.GetStreamId() + if streamIds[id] != nil { + recvErr = status.Errorf(codes.Internal, "Duplicate response for target %s. Already have %+v for response %+v", r.Target, streamIds[id], r) + return + } + + ret := &Ret{ + Target: r.GetTarget(), + Index: int(r.GetNonce()), } - streamIds[id].Error = status.Errorf(codes.Internal, "got close error from stream. Code: %s Message: %s", codes.Code(st.Code), st.Message) - errors = append(errors, streamIds[id]) - // If it's closed make sure we don't process it later on. - delete(streamIds, id) + // If the target reported an error stick it in errors. + if s := r.GetErrorStatus(); s != nil { + ret.Error = status.Errorf(codes.Internal, "got reply error from stream. Code: %s Message: %s", codes.Code(s.Code), s.Message) + errors = append(errors, ret) + continue + } + + // Save stream ID/nonce for later matching. + streamIds[r.GetStreamId()] = ret + case *proxypb.ProxyReply_ServerClose: + c := t.ServerClose + // We've never sent any data so a close here has to be an error. + st := c.GetStatus() + if st == nil || st.Code == 0 { + recvErr = status.Errorf(codes.Internal, "close with no data sent and no error? %+v", resp) + return + } + for _, id := range c.StreamIds { + if streamIds[id] == nil { + recvErr = status.Errorf(codes.Internal, "close on invalid stream id: %+v", resp) + return + } + streamIds[id].Error = status.Errorf(codes.Internal, "got close error from stream. Code: %s Message: %s", codes.Code(st.Code), st.Message) + errors = append(errors, streamIds[id]) + // If it's closed make sure we don't process it later on. + delete(streamIds, id) + } + default: + recvErr = status.Errorf(codes.Internal, "unexpected reply for %s on stream - %+v", method, resp) + return } - default: - return nil, nil, errors, status.Errorf(codes.Internal, "unexpected reply for %s on stream - %+v", method, resp) } + }() + + wg.Wait() + + log := logr.FromContextOrDiscard(ctx) + if sendErr != nil || recvErr != nil { + err := fmt.Errorf("Setting up streams errors: %v - %v", sendErr, recvErr) + log.Error(err, "Setup error") + return nil, nil, errors, err } return stream, streamIds, errors, nil }