Skip to content

Commit

Permalink
feat: expose cancellation token triggering when the test framework is…
Browse files Browse the repository at this point in the history
… stopping in order to be able to abort communication with the test server if it is for example panicing

BREAKING CHANGE: The test framework is no longer disposable. It is teared down when calling the disposable returned from Start.
  • Loading branch information
Fredrik Arvidsson committed Jan 21, 2022
1 parent 21138fd commit 52c77c9
Show file tree
Hide file tree
Showing 8 changed files with 65 additions and 48 deletions.
8 changes: 3 additions & 5 deletions src/Kafka.TestFramework/Client.cs
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
using System;
using System.IO;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks;
using Kafka.Protocol;
using Int32 = Kafka.Protocol.Int32;

namespace Kafka.TestFramework
{
internal abstract class Client : IAsyncDisposable
{
private readonly CancellationTokenSource _cancellationSource = new CancellationTokenSource();
private readonly CancellationTokenSource _cancellationSource;
private readonly Pipe _pipe = new Pipe();
private readonly INetworkClient _networkClient;
private Task _sendAndReceiveBackgroundTask = default!;

protected Client(INetworkClient networkClient)
protected Client(INetworkClient networkClient, CancellationToken cancellationToken)
{
_networkClient = networkClient;
NetworkClient = new NetworkStream(networkClient);
Reader = _pipe.Reader;
_cancellationSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
}

protected NetworkStream NetworkClient { get; }
Expand Down
4 changes: 2 additions & 2 deletions src/Kafka.TestFramework/InMemoryKafkaTestFramework.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ public async Task<IRequestClient> CreateRequestClientAsync(
var requestClient = new CrossWiredMemoryNetworkClient(first, second);
var responseClient = new CrossWiredMemoryNetworkClient(second, first);
await _clients
.SendAsync(responseClient, cancellationToken)
.SendAsync(responseClient, CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, Stopping).Token)
.ConfigureAwait(false);
return new DisposableRequestClientDecorator(RequestClient.Start(requestClient), responseClient, second, first);
return new DisposableRequestClientDecorator(RequestClient.Start(requestClient, Stopping), responseClient, second, first);
}

private class DisposableRequestClientDecorator : IRequestClient
Expand Down
61 changes: 43 additions & 18 deletions src/Kafka.TestFramework/KafkaTestFramework.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,23 @@

namespace Kafka.TestFramework
{
public abstract class KafkaTestFramework : IAsyncDisposable
public abstract class KafkaTestFramework
{
private readonly INetworkServer _networkServer;

private readonly CancellationTokenSource _cancellationTokenSource =
new CancellationTokenSource();

private readonly List<Task> _backgroundTasks = new List<Task>();

private const int Stopped = 0;
private const int Started = 1;
private int _status = Stopped;
private const int HasStopped = 0;
private const int HasStarted = 1;
private int _status = HasStopped;

/// <summary>
/// Triggered when the test framework is stopping
/// </summary>
public CancellationToken Stopping => _cancellationTokenSource.Token;

public static InMemoryKafkaTestFramework InMemory()
{
Expand Down Expand Up @@ -47,10 +53,10 @@ internal KafkaTestFramework(INetworkServer networkServer)

public IAsyncDisposable Start()
{
var previousStatus = Interlocked.Exchange(ref _status, Started);
if (previousStatus == Started)
var previousStatus = Interlocked.Exchange(ref _status, HasStarted);
if (previousStatus == HasStarted)
{
return this;
return new StopOnDispose(this);
}

var task = Task.Run(
Expand All @@ -61,46 +67,51 @@ public IAsyncDisposable Start()
try
{
var client = await _networkServer
.WaitForConnectedClientAsync(_cancellationTokenSource.Token)
.WaitForConnectedClientAsync(Stopping)
.ConfigureAwait(false);
ReceiveMessagesFor(client);
}
catch when (_cancellationTokenSource.IsCancellationRequested)
{
return;
}
catch
{
_cancellationTokenSource.Cancel();
throw;
}
}
});
_backgroundTasks.Add(task);
return this;
return new StopOnDispose(this);
}

private void ReceiveMessagesFor(INetworkClient networkClient)
{
var task = Task.Run(
async () =>
{
var client = ResponseClient.Start(networkClient);
var client = ResponseClient.Start(networkClient, Stopping);
await using var _ = client.ConfigureAwait(false);
while (_cancellationTokenSource.IsCancellationRequested == false)
{
try
{
var requestPayload = await client
.ReadAsync(_cancellationTokenSource.Token)
.ReadAsync(Stopping)
.ConfigureAwait(false);
if (!_subscriptions.TryGetValue(
requestPayload.Message.GetType(),
out var subscription))
{
throw new InvalidOperationException(
$"Missing subscription for {requestPayload.Message.GetType()}");
$"Missing subscription for {requestPayload.Message.GetType()}");
}
var response = await subscription(
requestPayload.Message,
_cancellationTokenSource.Token);
requestPayload.Message,
Stopping);
await client
.SendAsync(
Expand All @@ -109,13 +120,18 @@ await client
Messages.GetResponseHeaderVersionFor(requestPayload))
.WithCorrelationId(requestPayload.Header.CorrelationId),
response),
_cancellationTokenSource.Token)
Stopping)
.ConfigureAwait(false);
}
catch when (_cancellationTokenSource.IsCancellationRequested)
{
return;
}
catch
{
_cancellationTokenSource.Cancel();
throw;
}
}
});
_backgroundTasks.Add(task);
Expand Down Expand Up @@ -163,11 +179,20 @@ public KafkaTestFramework On<TRequestMessage, TResponseMessage>(
return this;
}

public async ValueTask DisposeAsync()
internal sealed class StopOnDispose : IAsyncDisposable
{
_cancellationTokenSource.Cancel();
private readonly KafkaTestFramework _testFramework;

await Task.WhenAll(_backgroundTasks);
public StopOnDispose(KafkaTestFramework testFramework)
{
_testFramework = testFramework;
}
public async ValueTask DisposeAsync()
{
_testFramework._cancellationTokenSource.Cancel();
await Task.WhenAll(_testFramework._backgroundTasks)
.ConfigureAwait(false);
}
}
}
}
8 changes: 4 additions & 4 deletions src/Kafka.TestFramework/RequestClient.cs
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
using System.Threading;
using System.Threading.Tasks;
using Kafka.Protocol;
using Int32 = Kafka.Protocol.Int32;

namespace Kafka.TestFramework
{
internal class RequestClient : Client, IRequestClient
{
private RequestClient(INetworkClient networkClient) : base(networkClient)
private RequestClient(INetworkClient networkClient, CancellationToken cancellationToken) :
base(networkClient, cancellationToken)
{
}

internal static RequestClient Start(INetworkClient networkClient)
internal static RequestClient Start(INetworkClient networkClient, CancellationToken cancellationToken)
{
var client = new RequestClient(networkClient);
var client = new RequestClient(networkClient, cancellationToken);
client.StartReceiving();
return client;
}
Expand Down
7 changes: 4 additions & 3 deletions src/Kafka.TestFramework/ResponseClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@ namespace Kafka.TestFramework
{
internal class ResponseClient : Client
{
private ResponseClient(INetworkClient networkClient) : base(networkClient)
private ResponseClient(INetworkClient networkClient, CancellationToken cancellationToken) :
base(networkClient, cancellationToken)
{
}

internal static ResponseClient Start(INetworkClient networkClient)
internal static ResponseClient Start(INetworkClient networkClient, CancellationToken cancellationToken)
{
var client = new ResponseClient(networkClient);
var client = new ResponseClient(networkClient, cancellationToken);
client.StartReceiving();
return client;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ protected override Task GivenAsync()
.WithLeaderId(Int32.From(0))
.WithPartitionIndex(Int32.From(0))
.WithReplicaNodesCollection(new[] { Int32.From(0) }))))
.ToArray() ?? new Func<MetadataResponse.MetadataResponseTopic, MetadataResponse.MetadataResponseTopic>[0])
.ToArray() ?? Array.Empty<Func<MetadataResponse.MetadataResponseTopic, MetadataResponse.MetadataResponseTopic>>())
.WithControllerId(Int32.From(0))
.WithClusterId(String.From("test"))
.WithBrokersCollection(broker => broker
Expand Down Expand Up @@ -96,7 +96,7 @@ protected override async Task WhenAsync()
await using (_testServer.Start()
.ConfigureAwait(false))
{
await ProduceMessageFromClientAsync("localhost", _testServer.Port)
await ProduceMessageFromClientAsync("localhost", _testServer.Port, _testServer.Stopping)
.ConfigureAwait(false);
}
}
Expand All @@ -114,7 +114,7 @@ public void It_should_have_read_the_message_sent()
}

private static async Task ProduceMessageFromClientAsync(string host,
int port)
int port, CancellationToken testServerStopping)
{
var producerConfig = new ProducerConfig(new Dictionary<string, string>
{
Expand All @@ -138,7 +138,7 @@ private static async Task ProduceMessageFromClientAsync(string host,
.ConfigureAwait(false);
LogFactory.Create("producer").Info("Produce report {@report}", report);

producer.Flush();
producer.Flush(testServerStopping);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ protected override Task GivenAsync()
.WithLeaderId(Int32.From(0))
.WithPartitionIndex(Int32.From(0))
.WithReplicaNodesCollection(new[] { Int32.From(0) }))))
.ToArray() ?? new Func<MetadataResponse.MetadataResponseTopic, MetadataResponse.MetadataResponseTopic>[0])
.ToArray() ?? Array.Empty<Func<MetadataResponse.MetadataResponseTopic, MetadataResponse.MetadataResponseTopic>>())
.WithControllerId(Int32.From(0))
.WithClusterId(String.From("test"))
.WithBrokersCollection(broker => broker
Expand Down Expand Up @@ -96,7 +96,7 @@ protected override async Task WhenAsync()
await using (_testServer.Start()
.ConfigureAwait(false))
{
await ProduceMessageFromClientAsync("localhost", _testServer.Port)
await ProduceMessageFromClientAsync("localhost", _testServer.Port, _testServer.Stopping)
.ConfigureAwait(false);
}
}
Expand All @@ -114,7 +114,7 @@ public void It_should_have_read_the_message_sent()
}

private static async Task ProduceMessageFromClientAsync(string host,
int port)
int port, CancellationToken testServerStopping)
{
var producerConfig = new ProducerConfig(new Dictionary<string, string>
{
Expand All @@ -133,11 +133,11 @@ private static async Task ProduceMessageFromClientAsync(string host,

var report = await producer
.ProduceAsync("my-topic",
new Message<Null, string> { Value = "test" })
new Message<Null, string> { Value = "test" }, testServerStopping)
.ConfigureAwait(false);
LogFactory.Create("producer").Info("Produce report {@report}", report);

producer.Flush();
producer.Flush(testServerStopping);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,6 @@ public void
.ApiKeysCollection[FetchRequest.ApiKey].MaxVersion.Should()
.Be(FetchRequest.MaxVersion);
}

protected override async Task TearDownAsync()
{
await _testServer
.DisposeAsync()
.ConfigureAwait(false);
}
}
}
}

0 comments on commit 52c77c9

Please sign in to comment.