Skip to content

Commit

Permalink
Implement Microsoft.Extensions.AI's IChatClient on AmazonBedrockRunti…
Browse files Browse the repository at this point in the history
…meClient

This enables AmazonBedrockRuntimeClient to be used as a Microsoft.Extensions.AI.IChatClient, such that it can implicitly be used by any consumer that operates on an IChatClient, and with any middleware written in terms of IChatClient, such as those components in the Microsoft.Extensions.AI package that provide support for automatic function invocation, OpenTelemetry, logging, distributed caching, and more.
  • Loading branch information
stephentoub committed Nov 6, 2024
1 parent 15400fa commit 6308716
Show file tree
Hide file tree
Showing 5 changed files with 706 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;

#if AWS_ASYNC_API
using System.Threading.Tasks;
#endif
Expand All @@ -48,9 +50,9 @@ namespace Amazon.Runtime.EventStreams.Internal
[SuppressMessage("Microsoft.Naming", "CA1710", Justification = "EventStreamCollection is not descriptive.")]
[SuppressMessage("Microsoft.Design", "CA1063", Justification = "IDisposable is a transient interface from IEventStream. Users need to be able to call Dispose.")]
#if NET8_0_OR_GREATER
public abstract class EnumerableEventStream<T, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TE> : EventStream<T, TE>, IEnumerableEventStream<T, TE> where T : IEventStreamEvent where TE : EventStreamException, new()
public abstract class EnumerableEventStream<T, [DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] TE> : EventStream<T, TE>, IEnumerableEventStream<T, TE>, IAsyncEnumerable<T> where T : IEventStreamEvent where TE : EventStreamException, new()
#else
public abstract class EnumerableEventStream<T, TE> : EventStream<T, TE>, IEnumerableEventStream<T, TE> where T : IEventStreamEvent where TE : EventStreamException, new()
public abstract class EnumerableEventStream<T, TE> : EventStream<T, TE>, IEnumerableEventStream<T, TE>, IAsyncEnumerable<T> where T : IEventStreamEvent where TE : EventStreamException, new()
#endif
{
private const string MutuallyExclusiveExceptionMessage = "Stream has already begun processing. Event-driven and Enumerable traversals of the stream are mutually exclusive. " +
Expand Down Expand Up @@ -145,6 +147,67 @@ public IEnumerator<T> GetEnumerator()
}
}

/// <summary>
/// Returns an async enumerator that asynchronously iterates through the collection.
/// </summary>
/// <returns>An async enumerator that can be used to iterate through the collection.</returns>
public async IAsyncEnumerator<T> GetAsyncEnumerator(CancellationToken cancellationToken)
{
// This implementation of this method is identical to that of GetEnumerator, except that
// instead of using ReadFromStream, it uses ReadFromStreamAsync. The two implementations
// should be kept in sync.

if (IsProcessing)
{
// If the queue has already begun processing, refuse to enumerate.
throw new InvalidOperationException(MutuallyExclusiveExceptionMessage);
}

// There could be more than 1 message created per decoder cycle.
var events = new Queue<T>();

// Opting out of events - letting the enumeration handle everything.
IsEnumerated = true;
IsProcessing = true;

// Enumeration is just magic over the event driven mechanism.
EventReceived += (sender, args) => events.Enqueue(args.EventStreamEvent);

var buffer = new byte[BufferSize];

while (IsProcessing)
{
// If there are already events ready to be served, do not ask for more.
if (events.Count > 0)
{
var ev = events.Dequeue();
// Enumeration handles terminal events on behalf of the user.
if (ev is IEventStreamTerminalEvent)
{
IsProcessing = false;
Dispose();
}

yield return ev;
}
else
{
try
{
await ReadFromStreamAsync(buffer, cancellationToken).ConfigureAwait(false);
}
catch (Exception ex)
{
IsProcessing = false;
Dispose();

// Wrap exceptions as needed to match event-driven behavior.
throw WrapException(ex);
}
}
}
}

/// <summary>
/// Returns an enumerator that iterates through a collection.
/// </summary>
Expand Down
18 changes: 16 additions & 2 deletions sdk/src/Core/Amazon.Runtime/EventStreams/Internal/EventStream.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
using System.Collections.Generic;
using System.Diagnostics.CodeAnalysis;
using System.IO;
using System.Threading;

#if AWS_ASYNC_API
using System.Threading.Tasks;
#else
Expand Down Expand Up @@ -351,9 +353,21 @@ protected void ReadFromStream(byte[] buffer)
/// each message it decodes.
/// </summary>
/// <param name="buffer">The buffer to store the read bytes from the stream.</param>
protected async Task ReadFromStreamAsync(byte[] buffer)
protected Task ReadFromStreamAsync(byte[] buffer) => ReadFromStreamAsync(buffer, CancellationToken.None);

/// <summary>
/// Reads from the stream into the buffer. It then passes the buffer to the decoder, which raises an event for
/// each message it decodes.
/// </summary>
/// <param name="buffer">The buffer to store the read bytes from the stream.</param>
/// <param name="cancellationToken">The token to monitor for cancellation requests.</param>
protected async Task ReadFromStreamAsync(byte[] buffer, CancellationToken cancellationToken)
{
var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
#if NETCOREAPP
var bytesRead = await NetworkStream.ReadAsync(buffer, cancellationToken).ConfigureAwait(false);
#else
var bytesRead = await NetworkStream.ReadAsync(buffer, 0, buffer.Length, cancellationToken).ConfigureAwait(false);
#endif
if (bytesRead > 0)
{
// Decoder raises MessageReceived for every message it encounters.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
<Project Sdk="Microsoft.NET.Sdk">
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<RunAnalyzersDuringBuild Condition="'$(RunAnalyzersDuringBuild)'==''">true</RunAnalyzersDuringBuild>
<TargetFramework>net472</TargetFramework>
Expand Down Expand Up @@ -64,6 +64,10 @@
<ProjectReference Include="../../Core/AWSSDK.Core.NetFramework.csproj"/>
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24525.1" />
</ItemGroup>

<ItemGroup Condition="$(RunAnalyzersDuringBuild)">
<PackageReference Include="Microsoft.CodeAnalysis.FxCopAnalyzers" Version="2.9.3">
<PrivateAssets>all</PrivateAssets>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
<ProjectReference Include="../../Core/AWSSDK.Core.NetStandard.csproj"/>
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.0.0-preview.9.24525.1" />
</ItemGroup>

<ItemGroup Condition="$(RunAnalyzersDuringBuild)">
<PackageReference Include="Microsoft.CodeAnalysis.FxCopAnalyzers" Version="2.9.3">
<PrivateAssets>all</PrivateAssets>
Expand Down
Loading

0 comments on commit 6308716

Please sign in to comment.