Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for allowing clients to limit which finish_reasons they allow. #1026

Open
wants to merge 2 commits into
base: canary
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/docs/snippets/clients/providers/anthropic.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ client<llm> MyClient {
```
</ParamField>

<Markdown src="../../../../snippets/finish-reason.mdx" />
<Markdown src="../../../../snippets/allowed-role-metadata.mdx" />

## Forwarded options
Expand Down
1 change: 1 addition & 0 deletions docs/docs/snippets/clients/providers/aws-bedrock.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ limited to:
We don't have any checks for this field, you can pass any string you wish.
</ParamField>

<Markdown src="../../../../snippets/finish-reason.mdx" />
<Markdown src="../../../../snippets/allowed-role-metadata-basic.mdx" />

## Forwarded options
Expand Down
1 change: 1 addition & 0 deletions docs/docs/snippets/clients/providers/azure.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ client<llm> MyClient {
```
</ParamField>

<Markdown src="../../../../snippets/finish-reason.mdx" />
<Markdown src="../../../../snippets/allowed-role-metadata-basic.mdx" />

## Forwarded options
Expand Down
1 change: 1 addition & 0 deletions docs/docs/snippets/clients/providers/gemini.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ client<llm> MyClient {
```
</ParamField>

<Markdown src="../../../../snippets/finish-reason.mdx" />
<Markdown src="../../../../snippets/allowed-role-metadata-basic.mdx" />

## Forwarded options
Expand Down
1 change: 1 addition & 0 deletions docs/docs/snippets/clients/providers/groq.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ client<llm> MyClient {
base_url "https://api.groq.com/openai/v1"
api_key env.GROQ_API_KEY
model "llama3-70b-8192"
default_role "user"
}
}
```
1 change: 1 addition & 0 deletions docs/docs/snippets/clients/providers/ollama.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ client<llm> MyClient {
```
</ParamField>

<Markdown src="../../../../snippets/finish-reason.mdx" />
<Markdown src="../../../../snippets/allowed-role-metadata-basic.mdx" />

## Forwarded options
Expand Down
3 changes: 3 additions & 0 deletions docs/docs/snippets/clients/providers/openai-generic.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ client<llm> MyClient {

</ParamField>

<Markdown src="../../../../snippets/finish-reason.mdx" />
<Markdown src="../../../../snippets/allowed-role-metadata-basic.mdx" />

## Forwarded options

<ParamField
Expand Down
1 change: 1 addition & 0 deletions docs/docs/snippets/clients/providers/openai.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ client<llm> MyClient {

</ParamField>

<Markdown src="../../../../snippets/finish-reason.mdx" />
<Markdown src="../../../../snippets/allowed-role-metadata-basic.mdx" />

## Forwarded options
Expand Down
1 change: 1 addition & 0 deletions docs/docs/snippets/clients/providers/vertex.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ client<llm> MyClient {
```
</ParamField>

<Markdown src="../../../../snippets/finish-reason.mdx" />
<Markdown src="../../../../snippets/allowed-role-metadata-basic.mdx" />

## Forwarded options
Expand Down
13 changes: 13 additions & 0 deletions docs/snippets/finish-reason.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<ParamField path="finish_reason_whitelist" type="string[]">
A list of finish reasons to allow. If set, any response with a finish reason not in this list will be rejected.
Empty finish reasons are always allowed.

**Default: `[]`**
</ParamField>

<ParamField path="finish_reason_blacklist" type="string[]">
A list of finish reasons to reject. If set, any response with a finish reason in this list will be rejected.
Empty finish reasons are always allowed.

**Default: `[]`**
</ParamField>
1 change: 1 addition & 0 deletions engine/baml-runtime/src/cli/serve/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ impl BamlError {
message: format!("Something went wrong with the LLM client: {:?}", err),
},
crate::internal::llm_client::ErrorCode::Other(_)
| crate::internal::llm_client::ErrorCode::BadRequest
| crate::internal::llm_client::ErrorCode::InvalidAuthentication
| crate::internal::llm_client::ErrorCode::NotSupported
| crate::internal::llm_client::ErrorCode::RateLimited
Expand Down
14 changes: 11 additions & 3 deletions engine/baml-runtime/src/internal/llm_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod llm_provider;
pub mod orchestrator;
pub mod primitive;

mod properties_hander;
pub mod retry_policy;
mod strategy;
pub mod traits;
Expand Down Expand Up @@ -169,6 +170,7 @@ pub enum ErrorCode {
RateLimited, // 429
ServerError, // 500
ServiceUnavailable, // 503
BadRequest, // 400

// We failed to parse the response
UnsupportedResponse(u16),
Expand All @@ -184,6 +186,7 @@ impl ErrorCode {
ErrorCode::NotSupported => "NotSupported (403)".into(),
ErrorCode::RateLimited => "RateLimited (429)".into(),
ErrorCode::ServerError => "ServerError (500)".into(),
ErrorCode::BadRequest => "BadRequest (400)".into(),
ErrorCode::ServiceUnavailable => "ServiceUnavailable (503)".into(),
ErrorCode::UnsupportedResponse(code) => format!("BadResponse {}", code),
ErrorCode::Other(code) => format!("Unspecified error code: {}", code),
Expand All @@ -192,6 +195,7 @@ impl ErrorCode {

pub fn from_status(status: StatusCode) -> Self {
match status.as_u16() {
400 => ErrorCode::BadRequest,
401 => ErrorCode::InvalidAuthentication,
403 => ErrorCode::NotSupported,
429 => ErrorCode::RateLimited,
Expand All @@ -203,6 +207,7 @@ impl ErrorCode {

pub fn from_u16(code: u16) -> Self {
match code {
400 => ErrorCode::BadRequest,
401 => ErrorCode::InvalidAuthentication,
403 => ErrorCode::NotSupported,
429 => ErrorCode::RateLimited,
Expand All @@ -214,6 +219,7 @@ impl ErrorCode {

pub fn to_u16(&self) -> u16 {
match self {
ErrorCode::BadRequest => 400,
ErrorCode::InvalidAuthentication => 401,
ErrorCode::NotSupported => 403,
ErrorCode::RateLimited => 429,
Expand Down Expand Up @@ -339,9 +345,9 @@ impl crate::tracing::Visualize for LLMErrorResponse {
fn resolve_properties_walker(
client: &ClientWalker,
ctx: &crate::RuntimeContext,
) -> Result<std::collections::HashMap<String, serde_json::Value>> {
) -> Result<properties_hander::PropertiesHandler> {
use anyhow::Context;
(&client.item.elem.options)
let result = (&client.item.elem.options)
.iter()
.map(|(k, v)| {
Ok((
Expand All @@ -354,5 +360,7 @@ fn resolve_properties_walker(
))?,
))
})
.collect::<Result<std::collections::HashMap<_, _>>>()
.collect::<Result<std::collections::HashMap<_, _>>>()?;

Ok(properties_hander::PropertiesHandler::new(result))
}
17 changes: 15 additions & 2 deletions engine/baml-runtime/src/internal/llm_client/orchestrator/call.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,26 @@ pub async fn orchestrate(
let prompt = match node.render_prompt(ir, prompt, ctx, params).await {
Ok(p) => p,
Err(e) => {
results.push((node.scope, LLMResponse::InternalFailure(e.to_string()), None));
results.push((
node.scope,
LLMResponse::InternalFailure(e.to_string()),
None,
));
continue;
}
};
let response = node.single_call(&ctx, &prompt).await;
let parsed_response = match &response {
LLMResponse::Success(s) => Some(parse_fn(&s.content)),
LLMResponse::Success(s) => {
if node.is_valid_finish_reason(s) {
Some(parse_fn(&s.content))
} else {
Some(Err(anyhow::anyhow!(
"Non-terminal finish reason: {}",
s.metadata.finish_reason.as_deref().unwrap_or("<empty>")
)))
}
}
_ => None,
};

Expand Down
13 changes: 12 additions & 1 deletion engine/baml-runtime/src/internal/llm_client/orchestrator/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use crate::{
RuntimeContext,
};

use super::traits::WithRenderRawCurl;
use super::properties_hander::FinishReasonOptions;
use super::traits::{WithClientProperties, WithRenderRawCurl};
use super::LLMCompleteResponse;
use super::{
strategy::roundrobin::RoundRobinStrategy,
traits::{StreamResponse, WithPrompt, WithSingleCallable, WithStreamable},
Expand Down Expand Up @@ -81,6 +83,15 @@ impl OrchestratorNode {
_ => None,
})
}

pub fn is_valid_finish_reason(&self, response: &LLMCompleteResponse) -> bool {
let Some(finish_reason) = response.metadata.finish_reason.as_deref() else {
return true;
};
self.provider
.finish_reason_handling()
.map_or(true, |options| options.is_allowed(finish_reason))
}
}

#[derive(Default, Clone, Serialize)]
Expand Down
17 changes: 15 additions & 2 deletions engine/baml-runtime/src/internal/llm_client/orchestrator/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ where
let prompt = match node.render_prompt(ir, prompt, ctx, params).await {
Ok(p) => p,
Err(e) => {
results.push((node.scope, LLMResponse::InternalFailure(e.to_string()), None));
results.push((
node.scope,
LLMResponse::InternalFailure(e.to_string()),
None,
));
continue;
}
};
Expand Down Expand Up @@ -89,7 +93,16 @@ where
};

let parsed_response = match &final_response {
LLMResponse::Success(s) => Some(parse_fn(&s.content)),
LLMResponse::Success(s) => {
if node.is_valid_finish_reason(s) {
Some(parse_fn(&s.content))
} else {
Some(Err(anyhow::anyhow!(
"Non-terminal finish reason: {}",
s.metadata.finish_reason.as_deref().unwrap_or("<empty>")
)))
}
}
_ => None,
};
let sleep_duration = node.error_sleep_duration().cloned();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::internal::llm_client::{
properties_hander::PropertiesHandler,
traits::{ToProviderMessage, ToProviderMessageExt, WithClientProperties},
AllowedMetadata, ResolveMediaUrls,
};
Expand Down Expand Up @@ -43,6 +44,7 @@ struct PostRequestProperities {
headers: HashMap<String, String>,
proxy_url: Option<String>,
allowed_metadata: AllowedMetadata,
finish_reason: Option<crate::internal::llm_client::properties_hander::FinishReasonOptions>,
// These are passed directly to the Anthropic API.
properties: HashMap<String, serde_json::Value>,
}
Expand All @@ -62,63 +64,40 @@ pub struct AnthropicClient {
// resolves/constructs PostRequestProperties from the client's options and runtime context, fleshing out the needed headers and parameters
// basically just reads the client's options and matches them to needed properties or defaults them
fn resolve_properties(
mut properties: HashMap<String, serde_json::Value>,
mut properties: PropertiesHandler,
ctx: &RuntimeContext,
) -> Result<PostRequestProperities> {
// this is a required field
properties
.entry("max_tokens".into())
.or_insert_with(|| 4096.into());

let default_role = properties
.remove("default_role")
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "system".to_string());

let default_role = properties.pull_default_role("system")?;
let base_url = properties
.remove("base_url")
.and_then(|v| v.as_str().map(|s| s.to_string()))
.unwrap_or_else(|| "https://api.anthropic.com".to_string());

.pull_base_url()?
.unwrap_or_else(|| "https://api.anthropic.com".into());
let api_key = properties
.remove("api_key")
.and_then(|v| v.as_str().map(|s| s.to_string()))
.pull_api_key()?
.or_else(|| ctx.env.get("ANTHROPIC_API_KEY").map(|s| s.to_string()));

let allowed_metadata = match properties.remove("allowed_role_metadata") {
Some(allowed_metadata) => serde_json::from_value(allowed_metadata).context(
"allowed_role_metadata must be an array of keys. For example: ['key1', 'key2']",
)?,
None => AllowedMetadata::None,
};

let mut headers = match properties.remove("headers") {
Some(headers) => headers
.as_object()
.context("headers must be a map of strings to strings")?
.iter()
.map(|(k, v)| {
Ok((
k.to_string(),
v.as_str()
.context(format!("Header '{}' must be a string", k))?
.to_string(),
))
})
.collect::<Result<HashMap<_, _>>>()?,
None => Default::default(),
};

let allowed_metadata = properties.pull_allowed_role_metadata()?;
let mut headers = properties.pull_headers()?;
headers
.entry("anthropic-version".to_string())
.or_insert("2023-06-01".to_string());
let finish_reason = properties.pull_finish_reason_options()?;

let mut properties = properties.finalize();
// Anthropic has a very low max_tokens by default, so we increase it to 4096.
properties
.entry("max_tokens".into())
.or_insert_with(|| 4096.into());
let properties = properties;

Ok(PostRequestProperities {
default_role,
base_url,
api_key,
headers,
allowed_metadata,
finish_reason,
properties,
proxy_url: ctx.env.get("BOUNDARY_PROXY_URL").map(|s| s.to_string()),
})
Expand All @@ -138,6 +117,12 @@ impl WithClientProperties for AnthropicClient {
fn client_properties(&self) -> &HashMap<String, serde_json::Value> {
&self.properties.properties
}

fn finish_reason_handling(
&self,
) -> Option<&crate::internal::llm_client::properties_hander::FinishReasonOptions> {
self.properties.finish_reason.as_ref()
}
}

impl WithClient for AnthropicClient {
Expand Down Expand Up @@ -308,14 +293,7 @@ impl WithStreamChat for AnthropicClient {
// constructs base client and resolves properties based on context
impl AnthropicClient {
pub fn dynamic_new(client: &ClientProperty, ctx: &RuntimeContext) -> Result<Self> {
let properties = resolve_properties(
client
.options
.iter()
.map(|(k, v)| Ok((k.clone(), json!(v))))
.collect::<Result<HashMap<_, _>>>()?,
ctx,
)?;
let properties = resolve_properties(client.property_handler()?, ctx)?;
let default_role = properties.default_role.clone();
Ok(Self {
name: client.name.clone(),
Expand Down
Loading
Loading