Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Event streams): add support for cancellation and IAsyncEnumerable #3543

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*******************************************************************************
/*******************************************************************************
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use
* this file except in compliance with the License. A copy of the License is located at
Expand All @@ -25,6 +25,7 @@
using System.Diagnostics.CodeAnalysis;
using System.IO;
#if AWS_ASYNC_API
using System.Threading;
using System.Threading.Tasks;
#endif

Expand All @@ -36,7 +37,12 @@ namespace Amazon.Runtime.EventStreams.Internal
/// <typeparam name="T">An implementation of IEventStreamEvent (e.g. IS3Event).</typeparam>
/// <typeparam name="TE">An implementation of EventStreamException (e.g. S3EventStreamException).</typeparam>
[SuppressMessage("Microsoft.Naming", "CA1710", Justification = "IEventStreamCollection is not descriptive.")]
#if AWS_ASYNC_ENUMERABLES_API

public interface IEnumerableEventStream<T, TE> : IEventStream<T, TE>, IEnumerable<T>, IAsyncEnumerable<T> where T : IEventStreamEvent where TE : EventStreamException, new()
#else
public interface IEnumerableEventStream<T, TE> : IEventStream<T, TE>, IEnumerable<T> where T : IEventStreamEvent where TE : EventStreamException, new()
#endif
{
}

Expand Down Expand Up @@ -171,13 +177,72 @@ public override void StartProcessing()
///
/// The Task will be completed when all of the events from the stream have been processed.
/// </summary>
public override async Task StartProcessingAsync()
public override async Task StartProcessingAsync(CancellationToken cancellationToken = default)
{
// If they are/have enumerated, the event-driven mode should be disabled
if (IsEnumerated) throw new InvalidOperationException(MutuallyExclusiveExceptionMessage);

await base.StartProcessingAsync().ConfigureAwait(false);
await base.StartProcessingAsync(cancellationToken).ConfigureAwait(false);
}
#endif

#if AWS_ASYNC_ENUMERABLES_API
/// <summary>
/// Returns an async enumerator that iterates through the collection.
/// </summary>
/// <returns>An async enumerator that can be used to iterate through the collection.</returns>
public async IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken)
{
if (IsProcessing)
{
// If the queue has already begun processing, refuse to enumerate.
throw new InvalidOperationException(MutuallyExclusiveExceptionMessage);
}

// There could be more than 1 message created per decoder cycle.
var events = new Queue<T>();

// Opting out of events - letting the enumeration handle everything.
IsEnumerated = true;
IsProcessing = true;

// Enumeration is just magic over the event driven mechanism.
EventReceived += (sender, args) => events.Enqueue(args.EventStreamEvent);

var buffer = new byte[BufferSize];

while (IsProcessing)
{
// If there are already events ready to be served, do not ask for more.
if (events.Count > 0)
{
var ev = events.Dequeue();
// Enumeration handles terminal events on behalf of the user.
if (ev is IEventStreamTerminalEvent)
{
IsProcessing = false;
Dispose();
}

yield return ev;
}
else
{
try
{
await ReadFromStreamAsync(buffer, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
IsProcessing = false;
Dispose();

// Wrap exceptions as needed to match event-driven behavior.
throw WrapException(ex);
}
}
}
}
#endif
}
}
}
22 changes: 11 additions & 11 deletions sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
Expand All @@ -17,10 +17,9 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;
#if AWS_ASYNC_API
using System.Threading.Tasks;
#else
using System.Threading;
#endif

namespace Amazon.Runtime.EventStreams.Internal
Expand Down Expand Up @@ -55,7 +54,7 @@ namespace Amazon.Runtime.EventStreams.Internal
///
/// The Task will be completed when all of the events from the stream have been processed.
/// </summary>
Task StartProcessingAsync();
Task StartProcessingAsync(CancellationToken cancellationToken = default);
#endif
}

Expand Down Expand Up @@ -262,23 +261,23 @@ protected void Process()
{
#if AWS_ASYNC_API
// Task only exists in framework 4.5 and up, and Standard.
Task.Run(() => ProcessLoopAsync());
Task.Run(() => ProcessLoopAsync(CancellationToken.None));
#else
// ThreadPool only exists in 3.5 and below. These implementations do not have the Task library.
ThreadPool.QueueUserWorkItem(ProcessLoop);
#endif
}

#if AWS_ASYNC_API
private async Task ProcessLoopAsync()
private async Task ProcessLoopAsync(CancellationToken cancellationToken)
{
var buffer = new byte[BufferSize];

try
{
while (IsProcessing)
{
await ReadFromStreamAsync(buffer).ConfigureAwait(false);
await ReadFromStreamAsync(buffer, cancellationToken).ConfigureAwait(false);
}
}
// These exceptions are raised on the background thread. They are fired as events for visibility.
Expand Down Expand Up @@ -351,9 +350,10 @@ protected void ReadFromStream(byte[] buffer)
/// each message it decodes.
/// </summary>
/// <param name="buffer">The buffer to store the read bytes from the stream.</param>
protected async Task ReadFromStreamAsync(byte[] buffer)
/// <param name="cancellationToken">A cancellation token.</param>
protected async Task ReadFromStreamAsync(byte[] buffer, CancellationToken cancellationToken)
{
var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
if (bytesRead > 0)
{
// Decoder raises MessageReceived for every message it encounters.
Expand Down Expand Up @@ -408,13 +408,13 @@ public virtual void StartProcessing()
///
/// The Task will be completed when all of the events from the stream have been processed.
/// </summary>
public virtual async Task StartProcessingAsync()
public virtual async Task StartProcessingAsync(CancellationToken cancellationToken = default)
{
if (IsProcessing)
return;

IsProcessing = true;
await ProcessLoopAsync().ConfigureAwait(false);
await ProcessLoopAsync(cancellationToken).ConfigureAwait(false);
}
#endif

Expand Down
9 changes: 8 additions & 1 deletion sdk/src/Services/BedrockRuntime/BedrockRuntime.sln
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.IntegrationTests.Bed
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.IntegrationTests.BedrockRuntime.Net45", "../../../test/Services/BedrockRuntime/IntegrationTests/AWSSDK.IntegrationTests.BedrockRuntime.Net45.csproj", "{086FF208-3CD6-40EC-9309-43F2EAA2F4D8}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.IntegrationTests.BedrockRuntime.NetStandard", "../../../test/Services/BedrockRuntime/IntegrationTests/AWSSDK.IntegrationTests.BedrockRuntime.NetStandard.csproj", "{9F726137-4C28-4FEA-9A6C-962DEA25951D}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.UnitTests.BedrockRuntime.Net35", "../../../test/Services/BedrockRuntime/UnitTests/AWSSDK.UnitTests.BedrockRuntime.Net35.csproj", "{AF460DBB-A029-440B-8C04-974E78043F09}"
EndProject
Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.UnitTests.BedrockRuntime.Net45", "../../../test/Services/BedrockRuntime/UnitTests/AWSSDK.UnitTests.BedrockRuntime.Net45.csproj", "{8F243F98-2E75-4C9C-B4B3-61EC51BEF835}"
Expand Down Expand Up @@ -186,6 +188,10 @@ Global
{A657D500-DDA4-45FF-9459-8351CDD96B78}.Debug|Any CPU.Build.0 = Debug|Any CPU
{A657D500-DDA4-45FF-9459-8351CDD96B78}.Release|Any CPU.ActiveCfg = Release|Any CPU
{A657D500-DDA4-45FF-9459-8351CDD96B78}.Release|Any CPU.Build.0 = Release|Any CPU
{9F726137-4C28-4FEA-9A6C-962DEA25951D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU
{9F726137-4C28-4FEA-9A6C-962DEA25951D}.Debug|Any CPU.Build.0 = Debug|Any CPU
{9F726137-4C28-4FEA-9A6C-962DEA25951D}.Release|Any CPU.ActiveCfg = Release|Any CPU
{9F726137-4C28-4FEA-9A6C-962DEA25951D}.Release|Any CPU.Build.0 = Release|Any CPU
EndGlobalSection
GlobalSection(SolutionProperties) = preSolution
HideSolutionNode = FALSE
Expand Down Expand Up @@ -220,8 +226,9 @@ Global
{7BD5B7F3-2ED9-4747-9FCE-8F5622BFCC36} = {939EC5C2-8345-43E2-8F97-72EEEBEEA0AC}
{EE034587-0A31-4841-A4BB-055DB990990F} = {939EC5C2-8345-43E2-8F97-72EEEBEEA0AC}
{A657D500-DDA4-45FF-9459-8351CDD96B78} = {939EC5C2-8345-43E2-8F97-72EEEBEEA0AC}
{9F726137-4C28-4FEA-9A6C-962DEA25951D} = {12EC4E4B-7E2C-4B63-8EF9-7B959F82A89B}
EndGlobalSection
GlobalSection(ExtensibilityGlobals) = postSolution
SolutionGuid = {CE2F2305-8E72-44B6-9FAD-AA2E347C2B6A}
EndGlobalSection
EndGlobal
EndGlobal
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<RunAnalyzersDuringBuild Condition="'$(RunAnalyzersDuringBuild)'==''">true</RunAnalyzersDuringBuild>
<TargetFrameworks>netstandard2.0;netcoreapp3.1;net8.0</TargetFrameworks>
<DefineConstants>$(DefineConstants);NETSTANDARD;AWS_ASYNC_API</DefineConstants>
<DefineConstants Condition="'$(TargetFramework)' == 'netstandard2.0'">$(DefineConstants);NETSTANDARD20;AWS_ASYNC_ENUMERABLES_API</DefineConstants>
<DefineConstants Condition="'$(TargetFramework)' == 'netcoreapp3.1'">$(DefineConstants);AWS_ASYNC_ENUMERABLES_API</DefineConstants>
<DefineConstants Condition="'$(TargetFramework)' == 'net8.0'">$(DefineConstants);AWS_ASYNC_ENUMERABLES_API</DefineConstants>
<DebugType>portable</DebugType>
<GenerateDocumentationFile>true</GenerateDocumentationFile>
<AssemblyName>AWSSDK.IntegrationTests.BedrockRuntime.NetStandard</AssemblyName>
<PackageId>AWSSDK.IntegrationTests.BedrockRuntime.NetStandard</PackageId>

<GenerateAssemblyTitleAttribute>false</GenerateAssemblyTitleAttribute>
<GenerateAssemblyConfigurationAttribute>false</GenerateAssemblyConfigurationAttribute>
<GenerateAssemblyProductAttribute>false</GenerateAssemblyProductAttribute>
<GenerateAssemblyCompanyAttribute>false</GenerateAssemblyCompanyAttribute>
<GenerateAssemblyCopyrightAttribute>false</GenerateAssemblyCopyrightAttribute>
<GenerateAssemblyVersionAttribute>false</GenerateAssemblyVersionAttribute>
<GenerateAssemblyFileVersionAttribute>false</GenerateAssemblyFileVersionAttribute>
<GenerateAssemblyDescriptionAttribute>false</GenerateAssemblyDescriptionAttribute>
<SignAssembly>true</SignAssembly>
<TreatWarningsAsErrors>true</TreatWarningsAsErrors>

<NoWarn>CA1822</NoWarn>
</PropertyGroup>

<!-- Async Enumerable Compatibility -->
<PropertyGroup Condition="'$(TargetFramework)' == 'netstandard2.0'">
<LangVersion>8.0</LangVersion>
</PropertyGroup>

<ItemGroup>
<Compile Remove="**/35/**" />
<None Remove="**/35/**" />
<Compile Remove="**/obj/**" />
<None Remove="**/obj/**" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.11.1" />
<PackageReference Include="MSTest.TestAdapter" Version="3.6.2" />
<PackageReference Include="MSTest.TestFramework" Version="3.6.2" />
</ItemGroup>

<ItemGroup>
<ProjectReference Include="../../../../src/Core/AWSSDK.Core.NetStandard.csproj" />
<ProjectReference Include="../../../../src/Services/BedrockRuntime/AWSSDK.BedrockRuntime.NetStandard.csproj" />
</ItemGroup>


</Project>
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Amazon.BedrockRuntime;
using Amazon.BedrockRuntime.Model;
using System.Threading.Tasks;
Expand All @@ -8,6 +8,9 @@
using System.Threading;
using System;
using System.Diagnostics.Contracts;
using System.Collections.Generic;
using Amazon;
using Amazon.Runtime;
namespace AWSSDK_DotNet.IntegrationTests.Tests
{
/// <summary>
Expand All @@ -22,7 +25,11 @@ namespace AWSSDK_DotNet.IntegrationTests.Tests
/// </summary>
[Ignore]
[TestClass]
#if NETSTANDARD
public class BedrockRuntimeEventStreamTests
#else
public class BedrockRuntimeEventStreamTests : TestBase<AmazonBedrockRuntimeClient>
#endif
{
#if BCL35
[TestMethod]
Expand Down Expand Up @@ -145,6 +152,48 @@ public async Task RequestWithInvalidBodyReturnsValidationException()
}

#endif

#if AWS_ASYNC_ENUMERABLES_API
[TestMethod]
public async Task ConverseStreamCanBeEnumeratedAsynchronously()
{
// configure with credentials and region
var client = new AmazonBedrockRuntimeClient();

var request = new ConverseStreamRequest
{
ModelId = "meta.llama3-1-8b-instruct-v1:0"
};

request.Messages.Add(new Message
{
Content = new List<ContentBlock> { new ContentBlock { Text = "Who was the first US president" } },
Role = ConversationRole.User
});

var response = await client.ConverseStreamAsync(request);

Assert.IsNotNull(response);
Assert.IsNotNull(response.Stream);

using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20));

var contentStringBuilder = new StringBuilder();
await foreach (var item in response.Stream.WithCancellation(cts.Token))
{
if (item is ContentBlockDeltaEvent deltaEvent)
{
contentStringBuilder.Append(deltaEvent.Delta.Text);
}
}

var responseContent = contentStringBuilder.ToString();

// Since we don't know the contents of the response from Bedrock, we just assert that we received a response
Assert.IsTrue(responseContent.Length > 10);
}
#endif

static MemoryStream CreateStream(string query, bool createInvalidInput = false)
{
StringBuilder promptValueBuilder = new StringBuilder();
Expand Down