diff --git a/access/grpc/client.go b/access/grpc/client.go index 667bfcc67..d02474fe9 100644 --- a/access/grpc/client.go +++ b/access/grpc/client.go @@ -352,6 +352,29 @@ func (c *Client) SubscribeBlocksFromLatest( return c.grpc.SubscribeBlocksFromLatest(ctx, blockStatus) } +func (c *Client) SubscribeBlockHeadersFromStartBlockID( + ctx context.Context, + startBlockID flow.Identifier, + blockStatus flow.BlockStatus, +) (<-chan flow.BlockHeader, <-chan error, error) { + return c.grpc.SubscribeBlockHeadersFromStartBlockID(ctx, startBlockID, blockStatus) +} + +func (c *Client) SubscribeBlockHeadersFromStartHeight( + ctx context.Context, + startHeight uint64, + blockStatus flow.BlockStatus, +) (<-chan flow.BlockHeader, <-chan error, error) { + return c.grpc.SubscribeBlockHeadersFromStartHeight(ctx, startHeight, blockStatus) +} + +func (c *Client) SubscribeBlocksHeadersFromLatest( + ctx context.Context, + blockStatus flow.BlockStatus, +) (<-chan flow.BlockHeader, <-chan error, error) { + return c.grpc.SubscribeBlockHeadersFromLatest(ctx, blockStatus) +} + func (c *Client) Close() error { return c.grpc.Close() } diff --git a/access/grpc/grpc.go b/access/grpc/grpc.go index f84212df2..b667fa20f 100644 --- a/access/grpc/grpc.go +++ b/access/grpc/grpc.go @@ -1368,3 +1368,142 @@ func receiveBlocksFromClient[Client interface { } } } + +func (c *BaseClient) SubscribeBlockHeadersFromStartBlockID( + ctx context.Context, + startBlockID flow.Identifier, + blockStatus flow.BlockStatus, + opts ...grpc.CallOption, +) (<-chan flow.BlockHeader, <-chan error, error) { + status := convert.BlockStatusToEntity(blockStatus) + if status == entities.BlockStatus_BLOCK_UNKNOWN { + return nil, nil, newRPCError(errors.New("unknown block status")) + } + + request := &access.SubscribeBlockHeadersFromStartBlockIDRequest{ + StartBlockId: startBlockID.Bytes(), + BlockStatus: status, + } + + subscribeClient, err := c.rpcClient.SubscribeBlockHeadersFromStartBlockID(ctx, request, opts...) + if err != nil { + return nil, nil, newRPCError(err) + } + + blockHeaderChan := make(chan flow.BlockHeader) + errChan := make(chan error) + + go func() { + defer close(blockHeaderChan) + defer close(errChan) + receiveBlockHeadersFromClient(ctx, subscribeClient, blockHeaderChan, errChan) + }() + + return blockHeaderChan, errChan, nil +} + +func (c *BaseClient) SubscribeBlockHeadersFromStartHeight( + ctx context.Context, + startHeight uint64, + blockStatus flow.BlockStatus, + opts ...grpc.CallOption, +) (<-chan flow.BlockHeader, <-chan error, error) { + status := convert.BlockStatusToEntity(blockStatus) + if status == entities.BlockStatus_BLOCK_UNKNOWN { + return nil, nil, newRPCError(errors.New("unknown block status")) + } + + request := &access.SubscribeBlockHeadersFromStartHeightRequest{ + StartBlockHeight: startHeight, + BlockStatus: status, + } + + subscribeClient, err := c.rpcClient.SubscribeBlockHeadersFromStartHeight(ctx, request, opts...) + if err != nil { + return nil, nil, newRPCError(err) + } + + blockHeaderChan := make(chan flow.BlockHeader) + errChan := make(chan error) + + go func() { + defer close(blockHeaderChan) + defer close(errChan) + receiveBlockHeadersFromClient(ctx, subscribeClient, blockHeaderChan, errChan) + }() + + return blockHeaderChan, errChan, nil +} + +func (c *BaseClient) SubscribeBlockHeadersFromLatest( + ctx context.Context, + blockStatus flow.BlockStatus, + opts ...grpc.CallOption, +) (<-chan flow.BlockHeader, <-chan error, error) { + status := convert.BlockStatusToEntity(blockStatus) + if status == entities.BlockStatus_BLOCK_UNKNOWN { + return nil, nil, newRPCError(errors.New("unknown block status")) + } + + request := &access.SubscribeBlockHeadersFromLatestRequest{ + BlockStatus: status, + } + + subscribeClient, err := c.rpcClient.SubscribeBlockHeadersFromLatest(ctx, request, opts...) + if err != nil { + return nil, nil, newRPCError(err) + } + + blockHeaderChan := make(chan flow.BlockHeader) + errChan := make(chan error) + + go func() { + defer close(blockHeaderChan) + defer close(errChan) + receiveBlockHeadersFromClient(ctx, subscribeClient, blockHeaderChan, errChan) + }() + + return blockHeaderChan, errChan, nil +} + +func receiveBlockHeadersFromClient[Client interface { + Recv() (*access.SubscribeBlockHeadersResponse, error) +}]( + ctx context.Context, + client Client, + blockHeadersChan chan<- flow.BlockHeader, + errChan chan<- error, +) { + sendErr := func(err error) { + select { + case <-ctx.Done(): + case errChan <- err: + } + } + + for { + // Receive the next blockHeader response + blockHeaderResponse, err := client.Recv() + if err != nil { + if err == io.EOF { + // End of stream, return gracefully + return + } + + sendErr(fmt.Errorf("error receiving blockHeader: %w", err)) + return + } + + blockHeader, err := convert.MessageToBlockHeader(blockHeaderResponse.GetHeader()) + if err != nil { + sendErr(fmt.Errorf("error converting message to block header: %w", err)) + return + } + + select { + case <-ctx.Done(): + return + case blockHeadersChan <- blockHeader: + } + } +} diff --git a/access/grpc/grpc_test.go b/access/grpc/grpc_test.go index a5991e4df..8c7ec4727 100644 --- a/access/grpc/grpc_test.go +++ b/access/grpc/grpc_test.go @@ -2535,3 +2535,178 @@ func assertNoTxResults[TxStatus any](t *testing.T, txResultChan <-chan TxStatus, require.FailNow(t, "should not receive txStatus") } } + +func TestClient_SubscribeBlockHeaders(t *testing.T) { + blockHeaders := test.BlockHeaderGenerator() + + generateBlockHeaderResponses := func(count uint64) []*access.SubscribeBlockHeadersResponse { + var resBlockHeaders []*access.SubscribeBlockHeadersResponse + + for i := uint64(0); i < count; i++ { + header, err := convert.BlockHeaderToMessage(blockHeaders.New()) + require.NoError(t, err) + + resBlockHeaders = append(resBlockHeaders, &access.SubscribeBlockHeadersResponse{ + Header: header, + }) + } + + return resBlockHeaders + } + + t.Run("Happy Path - from start height", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { + startHeight := uint64(1) + responseCount := uint64(100) + + ctx, cancel := context.WithCancel(ctx) + stream := &mockBlockHeaderClientStream[access.SubscribeBlockHeadersResponse]{ + ctx: ctx, + responses: generateBlockHeaderResponses(responseCount), + } + + rpc. + On("SubscribeBlockHeadersFromStartHeight", ctx, mock.Anything). + Return(stream, nil) + + blockHeadersCh, errCh, err := c.SubscribeBlockHeadersFromStartHeight(ctx, startHeight, flow.BlockStatusFinalized) + require.NoError(t, err) + + wg := sync.WaitGroup{} + wg.Add(1) + go assertNoErrors(t, errCh, wg.Done) + + for i := uint64(0); i < responseCount; i++ { + actualHeader := <-blockHeadersCh + expectedHeader, err := convert.MessageToBlockHeader(stream.responses[i].GetHeader()) + require.NoError(t, err) + require.Equal(t, expectedHeader, actualHeader) + } + cancel() + + wg.Wait() + })) + + t.Run("Happy Path - from start block id", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { + responseCount := uint64(100) + + ctx, cancel := context.WithCancel(ctx) + stream := &mockBlockHeaderClientStream[access.SubscribeBlockHeadersResponse]{ + ctx: ctx, + responses: generateBlockHeaderResponses(responseCount), + } + + rpc. + On("SubscribeBlockHeadersFromStartBlockID", ctx, mock.Anything). + Return(stream, nil) + + startBlockID := convert.MessageToIdentifier(stream.responses[0].GetHeader().Id) + blockHeadersCh, errCh, err := c.SubscribeBlockHeadersFromStartBlockID(ctx, startBlockID, flow.BlockStatusFinalized) + require.NoError(t, err) + + wg := sync.WaitGroup{} + wg.Add(1) + go assertNoErrors(t, errCh, wg.Done) + + for i := uint64(0); i < responseCount; i++ { + actualHeader := <-blockHeadersCh + expectedHeader, err := convert.MessageToBlockHeader(stream.responses[i].GetHeader()) + require.NoError(t, err) + require.Equal(t, expectedHeader, actualHeader) + } + cancel() + + wg.Wait() + })) + + t.Run("Happy Path - from latest", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { + responseCount := uint64(100) + + ctx, cancel := context.WithCancel(ctx) + stream := &mockBlockHeaderClientStream[access.SubscribeBlockHeadersResponse]{ + ctx: ctx, + responses: generateBlockHeaderResponses(responseCount), + } + + rpc. + On("SubscribeBlockHeadersFromLatest", ctx, mock.Anything). + Return(stream, nil) + + blockHeadersCh, errCh, err := c.SubscribeBlockHeadersFromLatest(ctx, flow.BlockStatusFinalized) + require.NoError(t, err) + + wg := sync.WaitGroup{} + wg.Add(1) + go assertNoErrors(t, errCh, wg.Done) + + for i := uint64(0); i < responseCount; i++ { + actualHeader := <-blockHeadersCh + expectedHeader, err := convert.MessageToBlockHeader(stream.responses[i].GetHeader()) + require.NoError(t, err) + require.Equal(t, expectedHeader, actualHeader) + } + cancel() + + wg.Wait() + })) + + t.Run("Stream returns error", clientTest(func(t *testing.T, ctx context.Context, rpc *mocks.MockRPCClient, c *BaseClient) { + ctx, cancel := context.WithCancel(ctx) + stream := &mockBlockHeaderClientStream[access.SubscribeBlockHeadersResponse]{ + ctx: ctx, + err: status.Error(codes.Internal, "internal error"), + } + + rpc. + On("SubscribeBlockHeadersFromLatest", ctx, mock.Anything). + Return(stream, nil) + + blockHeadersCh, errCh, err := c.SubscribeBlockHeadersFromLatest(ctx, flow.BlockStatusFinalized) + require.NoError(t, err) + + wg := sync.WaitGroup{} + wg.Add(1) + go assertNoBlockHeaders(t, blockHeadersCh, wg.Done) + + errorCount := 0 + for e := range errCh { + require.Error(t, e) + require.ErrorIs(t, e, stream.err) + errorCount += 1 + } + cancel() + + require.Equalf(t, 1, errorCount, "only 1 error is expected") + + wg.Wait() + })) +} + +type mockBlockHeaderClientStream[SubscribeBlockHeadersResponse any] struct { + grpc.ClientStream + + ctx context.Context + err error + offset int + responses []*SubscribeBlockHeadersResponse +} + +func (s *mockBlockHeaderClientStream[SubscribeBlockHeadersResponse]) Recv() (*SubscribeBlockHeadersResponse, error) { + if s.err != nil { + return nil, s.err + } + + if s.offset >= len(s.responses) { + <-s.ctx.Done() + return nil, io.EOF + } + defer func() { s.offset++ }() + + return s.responses[s.offset], nil +} + +func assertNoBlockHeaders[BlockHeader any](t *testing.T, blockHeadersChan <-chan BlockHeader, done func()) { + defer done() + for range blockHeadersChan { + require.FailNow(t, "should not receive block headers") + } +}