diff --git a/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EnumerableEventStream.cs b/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EnumerableEventStream.cs index 30da7cffc453..fab9ac9821d5 100644 --- a/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EnumerableEventStream.cs +++ b/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EnumerableEventStream.cs @@ -1,4 +1,4 @@ -/******************************************************************************* +/******************************************************************************* * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * Licensed under the Apache License, Version 2.0 (the "License"). You may not use * this file except in compliance with the License. A copy of the License is located at @@ -25,6 +25,7 @@ using System.Diagnostics.CodeAnalysis; using System.IO; #if AWS_ASYNC_API +using System.Threading; using System.Threading.Tasks; #endif @@ -36,7 +37,12 @@ namespace Amazon.Runtime.EventStreams.Internal /// An implementation of IEventStreamEvent (e.g. IS3Event). /// An implementation of EventStreamException (e.g. S3EventStreamException). [SuppressMessage("Microsoft.Naming", "CA1710", Justification = "IEventStreamCollection is not descriptive.")] +#if AWS_ASYNC_ENUMERABLES_API + + public interface IEnumerableEventStream : IEventStream, IEnumerable, IAsyncEnumerable where T : IEventStreamEvent where TE : EventStreamException, new() +#else public interface IEnumerableEventStream : IEventStream, IEnumerable where T : IEventStreamEvent where TE : EventStreamException, new() +#endif { } @@ -171,13 +177,72 @@ public override void StartProcessing() /// /// The Task will be completed when all of the events from the stream have been processed. /// - public override async Task StartProcessingAsync() + public override async Task StartProcessingAsync(CancellationToken cancellationToken = default) { // If they are/have enumerated, the event-driven mode should be disabled if (IsEnumerated) throw new InvalidOperationException(MutuallyExclusiveExceptionMessage); - await base.StartProcessingAsync().ConfigureAwait(false); + await base.StartProcessingAsync(cancellationToken).ConfigureAwait(false); + } +#endif + +#if AWS_ASYNC_ENUMERABLES_API + /// + /// Returns an async enumerator that iterates through the collection. + /// + /// An async enumerator that can be used to iterate through the collection. + public async IAsyncEnumerator GetAsyncEnumerator(CancellationToken cancellationToken) + { + if (IsProcessing) + { + // If the queue has already begun processing, refuse to enumerate. + throw new InvalidOperationException(MutuallyExclusiveExceptionMessage); + } + + // There could be more than 1 message created per decoder cycle. + var events = new Queue(); + + // Opting out of events - letting the enumeration handle everything. + IsEnumerated = true; + IsProcessing = true; + + // Enumeration is just magic over the event driven mechanism. + EventReceived += (sender, args) => events.Enqueue(args.EventStreamEvent); + + var buffer = new byte[BufferSize]; + + while (IsProcessing) + { + // If there are already events ready to be served, do not ask for more. + if (events.Count > 0) + { + var ev = events.Dequeue(); + // Enumeration handles terminal events on behalf of the user. + if (ev is IEventStreamTerminalEvent) + { + IsProcessing = false; + Dispose(); + } + + yield return ev; + } + else + { + try + { + await ReadFromStreamAsync(buffer, cancellationToken).ConfigureAwait(false); + } + catch (Exception ex) + { + IsProcessing = false; + Dispose(); + + // Wrap exceptions as needed to match event-driven behavior. + throw WrapException(ex); + } + } + } } #endif } -} \ No newline at end of file +} diff --git a/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs b/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs index f7f715f7adb6..72d3ed890ad5 100644 --- a/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs +++ b/sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs @@ -1,4 +1,4 @@ -/* +/* * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * * Licensed under the Apache License, Version 2.0 (the "License"). @@ -17,10 +17,9 @@ using System.Collections.Generic; using System.Diagnostics.CodeAnalysis; using System.IO; +using System.Threading; #if AWS_ASYNC_API using System.Threading.Tasks; -#else -using System.Threading; #endif namespace Amazon.Runtime.EventStreams.Internal @@ -55,7 +54,7 @@ namespace Amazon.Runtime.EventStreams.Internal /// /// The Task will be completed when all of the events from the stream have been processed. /// - Task StartProcessingAsync(); + Task StartProcessingAsync(CancellationToken cancellationToken = default); #endif } @@ -262,7 +261,7 @@ protected void Process() { #if AWS_ASYNC_API // Task only exists in framework 4.5 and up, and Standard. - Task.Run(() => ProcessLoopAsync()); + Task.Run(() => ProcessLoopAsync(CancellationToken.None)); #else // ThreadPool only exists in 3.5 and below. These implementations do not have the Task library. ThreadPool.QueueUserWorkItem(ProcessLoop); @@ -270,7 +269,7 @@ protected void Process() } #if AWS_ASYNC_API - private async Task ProcessLoopAsync() + private async Task ProcessLoopAsync(CancellationToken cancellationToken) { var buffer = new byte[BufferSize]; @@ -278,7 +277,7 @@ private async Task ProcessLoopAsync() { while (IsProcessing) { - await ReadFromStreamAsync(buffer).ConfigureAwait(false); + await ReadFromStreamAsync(buffer, cancellationToken).ConfigureAwait(false); } } // These exceptions are raised on the background thread. They are fired as events for visibility. @@ -351,9 +350,10 @@ protected void ReadFromStream(byte[] buffer) /// each message it decodes. /// /// The buffer to store the read bytes from the stream. - protected async Task ReadFromStreamAsync(byte[] buffer) + /// A cancellation token. + protected async Task ReadFromStreamAsync(byte[] buffer, CancellationToken cancellationToken) { - var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false); + var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false); if (bytesRead > 0) { // Decoder raises MessageReceived for every message it encounters. @@ -408,13 +408,13 @@ public virtual void StartProcessing() /// /// The Task will be completed when all of the events from the stream have been processed. /// - public virtual async Task StartProcessingAsync() + public virtual async Task StartProcessingAsync(CancellationToken cancellationToken = default) { if (IsProcessing) return; IsProcessing = true; - await ProcessLoopAsync().ConfigureAwait(false); + await ProcessLoopAsync(cancellationToken).ConfigureAwait(false); } #endif diff --git a/sdk/src/Services/BedrockRuntime/BedrockRuntime.sln b/sdk/src/Services/BedrockRuntime/BedrockRuntime.sln index 5363ba0af196..9c4da50b495a 100644 --- a/sdk/src/Services/BedrockRuntime/BedrockRuntime.sln +++ b/sdk/src/Services/BedrockRuntime/BedrockRuntime.sln @@ -26,6 +26,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.IntegrationTests.Bed EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.IntegrationTests.BedrockRuntime.Net45", "../../../test/Services/BedrockRuntime/IntegrationTests/AWSSDK.IntegrationTests.BedrockRuntime.Net45.csproj", "{086FF208-3CD6-40EC-9309-43F2EAA2F4D8}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.IntegrationTests.BedrockRuntime.NetStandard", "../../../test/Services/BedrockRuntime/IntegrationTests/AWSSDK.IntegrationTests.BedrockRuntime.NetStandard.csproj", "{9F726137-4C28-4FEA-9A6C-962DEA25951D}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.UnitTests.BedrockRuntime.Net35", "../../../test/Services/BedrockRuntime/UnitTests/AWSSDK.UnitTests.BedrockRuntime.Net35.csproj", "{AF460DBB-A029-440B-8C04-974E78043F09}" EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "AWSSDK.UnitTests.BedrockRuntime.Net45", "../../../test/Services/BedrockRuntime/UnitTests/AWSSDK.UnitTests.BedrockRuntime.Net45.csproj", "{8F243F98-2E75-4C9C-B4B3-61EC51BEF835}" @@ -186,6 +188,10 @@ Global {A657D500-DDA4-45FF-9459-8351CDD96B78}.Debug|Any CPU.Build.0 = Debug|Any CPU {A657D500-DDA4-45FF-9459-8351CDD96B78}.Release|Any CPU.ActiveCfg = Release|Any CPU {A657D500-DDA4-45FF-9459-8351CDD96B78}.Release|Any CPU.Build.0 = Release|Any CPU + {9F726137-4C28-4FEA-9A6C-962DEA25951D}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {9F726137-4C28-4FEA-9A6C-962DEA25951D}.Debug|Any CPU.Build.0 = Debug|Any CPU + {9F726137-4C28-4FEA-9A6C-962DEA25951D}.Release|Any CPU.ActiveCfg = Release|Any CPU + {9F726137-4C28-4FEA-9A6C-962DEA25951D}.Release|Any CPU.Build.0 = Release|Any CPU EndGlobalSection GlobalSection(SolutionProperties) = preSolution HideSolutionNode = FALSE @@ -220,8 +226,9 @@ Global {7BD5B7F3-2ED9-4747-9FCE-8F5622BFCC36} = {939EC5C2-8345-43E2-8F97-72EEEBEEA0AC} {EE034587-0A31-4841-A4BB-055DB990990F} = {939EC5C2-8345-43E2-8F97-72EEEBEEA0AC} {A657D500-DDA4-45FF-9459-8351CDD96B78} = {939EC5C2-8345-43E2-8F97-72EEEBEEA0AC} + {9F726137-4C28-4FEA-9A6C-962DEA25951D} = {12EC4E4B-7E2C-4B63-8EF9-7B959F82A89B} EndGlobalSection GlobalSection(ExtensibilityGlobals) = postSolution SolutionGuid = {CE2F2305-8E72-44B6-9FAD-AA2E347C2B6A} EndGlobalSection -EndGlobal \ No newline at end of file +EndGlobal diff --git a/sdk/test/Services/BedrockRuntime/IntegrationTests/AWSSDK.IntegrationTests.BedrockRuntime.NetStandard.csproj b/sdk/test/Services/BedrockRuntime/IntegrationTests/AWSSDK.IntegrationTests.BedrockRuntime.NetStandard.csproj new file mode 100644 index 000000000000..8e774b3089b7 --- /dev/null +++ b/sdk/test/Services/BedrockRuntime/IntegrationTests/AWSSDK.IntegrationTests.BedrockRuntime.NetStandard.csproj @@ -0,0 +1,52 @@ + + + true + netstandard2.0;netcoreapp3.1;net8.0 + $(DefineConstants);NETSTANDARD;AWS_ASYNC_API + $(DefineConstants);NETSTANDARD20;AWS_ASYNC_ENUMERABLES_API + $(DefineConstants);AWS_ASYNC_ENUMERABLES_API + $(DefineConstants);AWS_ASYNC_ENUMERABLES_API + portable + true + AWSSDK.IntegrationTests.BedrockRuntime.NetStandard + AWSSDK.IntegrationTests.BedrockRuntime.NetStandard + + false + false + false + false + false + false + false + false + true + true + + CA1822 + + + + + 8.0 + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/sdk/test/Services/BedrockRuntime/IntegrationTests/BedrockRuntimeEventStreamTests.cs b/sdk/test/Services/BedrockRuntime/IntegrationTests/BedrockRuntimeEventStreamTests.cs index 6a2e636f2d51..9eaabe01d20f 100644 --- a/sdk/test/Services/BedrockRuntime/IntegrationTests/BedrockRuntimeEventStreamTests.cs +++ b/sdk/test/Services/BedrockRuntime/IntegrationTests/BedrockRuntimeEventStreamTests.cs @@ -1,4 +1,4 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestTools.UnitTesting; using Amazon.BedrockRuntime; using Amazon.BedrockRuntime.Model; using System.Threading.Tasks; @@ -8,6 +8,9 @@ using System.Threading; using System; using System.Diagnostics.Contracts; +using System.Collections.Generic; +using Amazon; +using Amazon.Runtime; namespace AWSSDK_DotNet.IntegrationTests.Tests { /// @@ -22,7 +25,11 @@ namespace AWSSDK_DotNet.IntegrationTests.Tests /// [Ignore] [TestClass] +#if NETSTANDARD + public class BedrockRuntimeEventStreamTests +#else public class BedrockRuntimeEventStreamTests : TestBase +#endif { #if BCL35 [TestMethod] @@ -145,6 +152,48 @@ public async Task RequestWithInvalidBodyReturnsValidationException() } #endif + +#if AWS_ASYNC_ENUMERABLES_API + [TestMethod] + public async Task ConverseStreamCanBeEnumeratedAsynchronously() + { + // configure with credentials and region + var client = new AmazonBedrockRuntimeClient(); + + var request = new ConverseStreamRequest + { + ModelId = "meta.llama3-1-8b-instruct-v1:0" + }; + + request.Messages.Add(new Message + { + Content = new List { new ContentBlock { Text = "Who was the first US president" } }, + Role = ConversationRole.User + }); + + var response = await client.ConverseStreamAsync(request); + + Assert.IsNotNull(response); + Assert.IsNotNull(response.Stream); + + using var cts = new CancellationTokenSource(TimeSpan.FromSeconds(20)); + + var contentStringBuilder = new StringBuilder(); + await foreach (var item in response.Stream.WithCancellation(cts.Token)) + { + if (item is ContentBlockDeltaEvent deltaEvent) + { + contentStringBuilder.Append(deltaEvent.Delta.Text); + } + } + + var responseContent = contentStringBuilder.ToString(); + + // Since we don't know the contents of the response from Bedrock, we just assert that we received a response + Assert.IsTrue(responseContent.Length > 10); + } +#endif + static MemoryStream CreateStream(string query, bool createInvalidInput = false) { StringBuilder promptValueBuilder = new StringBuilder();