From 373785b9455d32325dcf452fc50f0688f36f1a38 Mon Sep 17 00:00:00 2001 From: Dylan Martin Date: Fri, 9 Aug 2024 15:48:04 -0400 Subject: [PATCH] feat(flags): add validation for database reads for rust feature flag service (#24089) Co-authored-by: Neil Kakkar Co-authored-by: James Greenhill --- rust/feature-flags/src/api.rs | 117 +- rust/feature-flags/src/flag_definitions.rs | 1222 +++++++++++++++++++- rust/feature-flags/src/team.rs | 26 +- rust/feature-flags/src/test_utils.rs | 35 +- rust/feature-flags/src/v0_endpoint.rs | 10 +- rust/feature-flags/src/v0_request.rs | 29 +- rust/feature-flags/tests/test_flags.rs | 130 ++- 7 files changed, 1504 insertions(+), 65 deletions(-) diff --git a/rust/feature-flags/src/api.rs b/rust/feature-flags/src/api.rs index 2caae80bf9af6..da2b00fbfdef5 100644 --- a/rust/feature-flags/src/api.rs +++ b/rust/feature-flags/src/api.rs @@ -13,66 +13,131 @@ pub enum FlagsResponseCode { Ok = 1, } +#[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] +#[serde(untagged)] +pub enum FlagValue { + Boolean(bool), + String(String), +} + #[derive(Debug, PartialEq, Eq, Deserialize, Serialize)] #[serde(rename_all = "camelCase")] pub struct FlagsResponse { pub error_while_computing_flags: bool, - // TODO: better typing here, support bool responses - pub feature_flags: HashMap, + pub feature_flags: HashMap, +} + +#[derive(Error, Debug)] +pub enum ClientFacingError { + #[error("Invalid request: {0}")] + BadRequest(String), + #[error("Unauthorized: {0}")] + Unauthorized(String), + #[error("Rate limited")] + RateLimited, + #[error("Service unavailable")] + ServiceUnavailable, } #[derive(Error, Debug)] pub enum FlagError { + #[error(transparent)] + ClientFacing(#[from] ClientFacingError), + #[error("Internal error: {0}")] + Internal(String), #[error("failed to decode request: {0}")] RequestDecodingError(String), #[error("failed to parse request: {0}")] RequestParsingError(#[from] serde_json::Error), - #[error("Empty distinct_id in request")] EmptyDistinctId, #[error("No distinct_id in request")] MissingDistinctId, - #[error("No api_key in request")] NoTokenError, #[error("API key is not valid")] TokenValidationError, - - #[error("rate limited")] - RateLimited, - #[error("failed to parse redis cache data")] DataParsingError, + #[error("failed to update redis cache")] + CacheUpdateError, #[error("redis unavailable")] RedisUnavailable, #[error("database unavailable")] DatabaseUnavailable, #[error("Timed out while fetching data")] TimeoutError, - // TODO: Consider splitting top-level errors (that are returned to the client) - // and FlagMatchingError, like timeouterror which we can gracefully handle. - // This will make the `into_response` a lot clearer as well, since it wouldn't - // have arbitrary errors that actually never make it to the client. } impl IntoResponse for FlagError { fn into_response(self) -> Response { match self { - FlagError::RequestDecodingError(_) - | FlagError::RequestParsingError(_) - | FlagError::EmptyDistinctId - | FlagError::MissingDistinctId => (StatusCode::BAD_REQUEST, self.to_string()), - - FlagError::NoTokenError | FlagError::TokenValidationError => { - (StatusCode::UNAUTHORIZED, self.to_string()) + FlagError::ClientFacing(err) => match err { + ClientFacingError::BadRequest(msg) => (StatusCode::BAD_REQUEST, msg), + ClientFacingError::Unauthorized(msg) => (StatusCode::UNAUTHORIZED, msg), + ClientFacingError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, "Rate limit exceeded. Please reduce your request frequency and try again later.".to_string()), + ClientFacingError::ServiceUnavailable => (StatusCode::SERVICE_UNAVAILABLE, "Service is currently unavailable. Please try again later.".to_string()), + }, + FlagError::Internal(msg) => { + tracing::error!("Internal server error: {}", msg); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "An internal server error occurred. Please try again later or contact support if the problem persists.".to_string(), + ) + } + FlagError::RequestDecodingError(msg) => { + (StatusCode::BAD_REQUEST, format!("Failed to decode request: {}. Please check your request format and try again.", msg)) + } + FlagError::RequestParsingError(err) => { + (StatusCode::BAD_REQUEST, format!("Failed to parse request: {}. Please ensure your request is properly formatted and all required fields are present.", err)) + } + FlagError::EmptyDistinctId => { + (StatusCode::BAD_REQUEST, "The distinct_id field cannot be empty. Please provide a valid identifier.".to_string()) + } + FlagError::MissingDistinctId => { + (StatusCode::BAD_REQUEST, "The distinct_id field is missing from the request. Please include a valid identifier.".to_string()) + } + FlagError::NoTokenError => { + (StatusCode::UNAUTHORIZED, "No API key provided. Please include a valid API key in your request.".to_string()) + } + FlagError::TokenValidationError => { + (StatusCode::UNAUTHORIZED, "The provided API key is invalid or has expired. Please check your API key and try again.".to_string()) + } + FlagError::DataParsingError => { + tracing::error!("Data parsing error: {:?}", self); + ( + StatusCode::SERVICE_UNAVAILABLE, + "Failed to parse internal data. This is likely a temporary issue. Please try again later.".to_string(), + ) + } + FlagError::CacheUpdateError => { + tracing::error!("Cache update error: {:?}", self); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to update internal cache. This is likely a temporary issue. Please try again later.".to_string(), + ) + } + FlagError::RedisUnavailable => { + tracing::error!("Redis unavailable: {:?}", self); + ( + StatusCode::SERVICE_UNAVAILABLE, + "Our cache service is currently unavailable. This is likely a temporary issue. Please try again later.".to_string(), + ) + } + FlagError::DatabaseUnavailable => { + tracing::error!("Database unavailable: {:?}", self); + ( + StatusCode::SERVICE_UNAVAILABLE, + "Our database service is currently unavailable. This is likely a temporary issue. Please try again later.".to_string(), + ) + } + FlagError::TimeoutError => { + tracing::error!("Timeout error: {:?}", self); + ( + StatusCode::SERVICE_UNAVAILABLE, + "The request timed out. This could be due to high load or network issues. Please try again later.".to_string(), + ) } - - FlagError::RateLimited => (StatusCode::TOO_MANY_REQUESTS, self.to_string()), - - FlagError::DataParsingError - | FlagError::RedisUnavailable - | FlagError::DatabaseUnavailable - | FlagError::TimeoutError => (StatusCode::SERVICE_UNAVAILABLE, self.to_string()), } .into_response() } diff --git a/rust/feature-flags/src/flag_definitions.rs b/rust/feature-flags/src/flag_definitions.rs index cc208ae8b073f..ef1db6762a5ce 100644 --- a/rust/feature-flags/src/flag_definitions.rs +++ b/rust/feature-flags/src/flag_definitions.rs @@ -1,9 +1,8 @@ +use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use tracing::instrument; -use crate::{api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient}; - // TRICKY: This cache data is coming from django-redis. If it ever goes out of sync, we'll bork. // TODO: Add integration tests across repos to ensure this doesn't happen. pub const TEAM_FLAGS_CACHE_PREFIX: &str = "posthog:1:team_feature_flags_"; @@ -65,8 +64,6 @@ pub struct MultivariateFlagOptions { pub variants: Vec, } -// TODO: test name with https://www.fileformat.info/info/charset/UTF-16/list.htm values, like '𝖕𝖗𝖔𝖕𝖊𝖗𝖙𝖞': `𝓿𝓪𝓵𝓾𝓮` - #[derive(Debug, Clone, Deserialize)] pub struct FlagFilters { pub groups: Vec, @@ -121,7 +118,6 @@ impl FeatureFlag { } #[derive(Debug, Deserialize)] - pub struct FeatureFlagList { pub flags: Vec, } @@ -155,38 +151,58 @@ impl FeatureFlagList { client: Arc, team_id: i32, ) -> Result { - let mut conn = client.get_connection().await?; - // TODO: Clean up error handling here + let mut conn = client.get_connection().await.map_err(|e| { + tracing::error!("Failed to get database connection: {}", e); + FlagError::DatabaseUnavailable + })?; let query = "SELECT id, team_id, name, key, filters, deleted, active, ensure_experience_continuity FROM posthog_featureflag WHERE team_id = $1"; let flags_row = sqlx::query_as::<_, FeatureFlagRow>(query) .bind(team_id) .fetch_all(&mut *conn) - .await?; + .await + .map_err(|e| { + tracing::error!("Failed to fetch feature flags from database: {}", e); + FlagError::Internal(format!("Database query error: {}", e)) + })?; - let serialized_flags = serde_json::to_string(&flags_row).map_err(|e| { - tracing::error!("failed to serialize flags: {}", e); - println!("failed to serialize flags: {}", e); - FlagError::DataParsingError - })?; + let flags_list = flags_row + .into_iter() + .map(|row| { + let filters = serde_json::from_value(row.filters).map_err(|e| { + tracing::error!("Failed to deserialize filters for flag {}: {}", row.key, e); + FlagError::DataParsingError + })?; - let flags_list: Vec = - serde_json::from_str(&serialized_flags).map_err(|e| { - tracing::error!("failed to parse data to flags list: {}", e); - println!("failed to parse data: {}", e); + Ok(FeatureFlag { + id: row.id, + team_id: row.team_id, + name: row.name, + key: row.key, + filters, + deleted: row.deleted, + active: row.active, + ensure_experience_continuity: row.ensure_experience_continuity, + }) + }) + .collect::, FlagError>>()?; - FlagError::DataParsingError - })?; Ok(FeatureFlagList { flags: flags_list }) } } #[cfg(test)] mod tests { + use crate::flag_definitions; + use rand::Rng; + use serde_json::json; + use std::time::Instant; + use tokio::task; + use super::*; use crate::test_utils::{ - insert_flags_for_team_in_pg, insert_flags_for_team_in_redis, insert_new_team_in_pg, - insert_new_team_in_redis, setup_pg_client, setup_redis_client, + insert_flag_for_team_in_pg, insert_flags_for_team_in_redis, insert_new_team_in_pg, + insert_new_team_in_redis, setup_invalid_pg_client, setup_pg_client, setup_redis_client, }; #[tokio::test] @@ -247,7 +263,7 @@ mod tests { .await .expect("Failed to insert team in pg"); - insert_flags_for_team_in_pg(client.clone(), team.id, None) + insert_flag_for_team_in_pg(client.clone(), team.id, None) .await .expect("Failed to insert flags"); @@ -282,6 +298,99 @@ mod tests { assert_eq!(flag.filters.groups[0].rollout_percentage, Some(50.0)); } + #[test] + fn test_utf16_property_names_and_values() { + let json_str = r#"{ + "id": 1, + "team_id": 2, + "name": "𝖚𝖙𝖋16_𝖙𝖊𝖘𝖙_𝖋𝖑𝖆𝖌", + "key": "𝖚𝖙𝖋16_𝖙𝖊𝖘𝖙_𝖋𝖑𝖆𝖌", + "filters": { + "groups": [ + { + "properties": [ + { + "key": "𝖕𝖗𝖔𝖕𝖊𝖗𝖙𝖞", + "value": "𝓿𝓪𝓵𝓾𝓮", + "type": "person" + } + ] + } + ] + } + }"#; + + let flag: FeatureFlag = serde_json::from_str(json_str).expect("Failed to deserialize"); + + assert_eq!(flag.key, "𝖚𝖙𝖋16_𝖙𝖊𝖘𝖙_𝖋𝖑𝖆𝖌"); + let property = &flag.filters.groups[0].properties.as_ref().unwrap()[0]; + assert_eq!(property.key, "𝖕𝖗𝖔𝖕𝖊𝖗𝖙𝖞"); + assert_eq!(property.value, json!("𝓿𝓪𝓵𝓾𝓮")); + } + + #[test] + fn test_deserialize_complex_flag() { + let json_str = r#"{ + "id": 1, + "team_id": 2, + "name": "Complex Flag", + "key": "complex_flag", + "filters": { + "groups": [ + { + "properties": [ + { + "key": "email", + "value": "test@example.com", + "operator": "exact", + "type": "person" + } + ], + "rollout_percentage": 50 + } + ], + "multivariate": { + "variants": [ + { + "key": "control", + "name": "Control Group", + "rollout_percentage": 33.33 + }, + { + "key": "test", + "name": "Test Group", + "rollout_percentage": 66.67 + } + ] + }, + "aggregation_group_type_index": 0, + "payloads": {"test": {"type": "json", "value": {"key": "value"}}} + }, + "deleted": false, + "active": true, + "ensure_experience_continuity": false + }"#; + + let flag: FeatureFlag = serde_json::from_str(json_str).expect("Failed to deserialize"); + + assert_eq!(flag.id, 1); + assert_eq!(flag.team_id, 2); + assert_eq!(flag.name, Some("Complex Flag".to_string())); + assert_eq!(flag.key, "complex_flag"); + assert_eq!(flag.filters.groups.len(), 1); + assert_eq!(flag.filters.groups[0].properties.as_ref().unwrap().len(), 1); + assert_eq!(flag.filters.groups[0].rollout_percentage, Some(50.0)); + assert_eq!( + flag.filters.multivariate.as_ref().unwrap().variants.len(), + 2 + ); + assert_eq!(flag.filters.aggregation_group_type_index, Some(0)); + assert!(flag.filters.payloads.is_some()); + assert!(!flag.deleted); + assert!(flag.active); + assert!(!flag.ensure_experience_continuity); + } + // TODO: Add more tests to validate deserialization of flags. // TODO: Also make sure old flag data is handled, or everything is migrated to new style in production @@ -298,4 +407,1073 @@ mod tests { } } } + + #[tokio::test] + async fn test_fetch_nonexistent_team_from_pg() { + let client = setup_pg_client(None).await; + + match FeatureFlagList::from_pg(client.clone(), -1).await { + Ok(flags) => assert_eq!(flags.flags.len(), 0), + Err(err) => panic!("Expected empty result, got error: {:?}", err), + } + } + + #[tokio::test] + async fn test_fetch_flags_db_connection_failure() { + // Simulate a database connection failure by using an invalid client setup + let client = setup_invalid_pg_client().await; + + match FeatureFlagList::from_pg(client, 1).await { + Err(FlagError::DatabaseUnavailable) => (), + other => panic!("Expected DatabaseUnavailable error, got: {:?}", other), + } + } + + #[tokio::test] + async fn test_fetch_multiple_flags_from_pg() { + let client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(client.clone()) + .await + .expect("Failed to insert team in pg"); + + let random_id_1 = rand::thread_rng().gen_range(0..10_000_000); + let random_id_2 = rand::thread_rng().gen_range(0..10_000_000); + + let flag1 = FeatureFlagRow { + id: random_id_1, + team_id: team.id, + name: Some("Test Flag".to_string()), + key: "test_flag".to_string(), + filters: serde_json::json!({"groups": [{"properties": [], "rollout_percentage": 100}]}), + deleted: false, + active: true, + ensure_experience_continuity: false, + }; + + let flag2 = FeatureFlagRow { + id: random_id_2, + team_id: team.id, + name: Some("Test Flag 2".to_string()), + key: "test_flag_2".to_string(), + filters: serde_json::json!({"groups": [{"properties": [], "rollout_percentage": 100}]}), + deleted: false, + active: true, + ensure_experience_continuity: false, + }; + + // Insert multiple flags for the team + insert_flag_for_team_in_pg(client.clone(), team.id, Some(flag1)) + .await + .expect("Failed to insert flags"); + + insert_flag_for_team_in_pg(client.clone(), team.id, Some(flag2)) + .await + .expect("Failed to insert flags"); + + let flags_from_pg = FeatureFlagList::from_pg(client.clone(), team.id) + .await + .expect("Failed to fetch flags from pg"); + + assert_eq!(flags_from_pg.flags.len(), 2); + for flag in &flags_from_pg.flags { + assert_eq!(flag.team_id, team.id); + } + } + + #[test] + fn test_operator_type_deserialization() { + let operators = vec![ + ("exact", OperatorType::Exact), + ("is_not", OperatorType::IsNot), + ("icontains", OperatorType::Icontains), + ("not_icontains", OperatorType::NotIcontains), + ("regex", OperatorType::Regex), + ("not_regex", OperatorType::NotRegex), + ("gt", OperatorType::Gt), + ("lt", OperatorType::Lt), + ("gte", OperatorType::Gte), + ("lte", OperatorType::Lte), + ("is_set", OperatorType::IsSet), + ("is_not_set", OperatorType::IsNotSet), + ("is_date_exact", OperatorType::IsDateExact), + ("is_date_after", OperatorType::IsDateAfter), + ("is_date_before", OperatorType::IsDateBefore), + ]; + + for (op_str, op_type) in operators { + let json = format!( + r#"{{ + "key": "test_key", + "value": "test_value", + "operator": "{}", + "type": "person" + }}"#, + op_str + ); + let deserialized: PropertyFilter = serde_json::from_str(&json).unwrap(); + assert_eq!(deserialized.operator, Some(op_type)); + } + } + + #[tokio::test] + async fn test_multivariate_flag_parsing() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + let multivariate_flag = json!({ + "id": 1, + "team_id": team.id, + "name": "Multivariate Flag", + "key": "multivariate_flag", + "filters": { + "groups": [ + { + "properties": [], + "rollout_percentage": 100 + } + ], + "multivariate": { + "variants": [ + { + "key": "control", + "name": "Control Group", + "rollout_percentage": 33.33 + }, + { + "key": "test_a", + "name": "Test Group A", + "rollout_percentage": 33.33 + }, + { + "key": "test_b", + "name": "Test Group B", + "rollout_percentage": 33.34 + } + ] + } + }, + "active": true, + "deleted": false + }); + + // Insert into Redis + insert_flags_for_team_in_redis( + redis_client.clone(), + team.id, + Some(json!([multivariate_flag]).to_string()), + ) + .await + .expect("Failed to insert flag in Redis"); + + // Insert into Postgres + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 1, + team_id: team.id, + name: Some("Multivariate Flag".to_string()), + key: "multivariate_flag".to_string(), + filters: multivariate_flag["filters"].clone(), + deleted: false, + active: true, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert flag in Postgres"); + + // Fetch and verify from Redis + let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) + .await + .expect("Failed to fetch flags from Redis"); + + assert_eq!(redis_flags.flags.len(), 1); + let redis_flag = &redis_flags.flags[0]; + assert_eq!(redis_flag.key, "multivariate_flag"); + assert_eq!(redis_flag.get_variants().len(), 3); + + // Fetch and verify from Postgres + let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + .await + .expect("Failed to fetch flags from Postgres"); + + assert_eq!(pg_flags.flags.len(), 1); + let pg_flag = &pg_flags.flags[0]; + assert_eq!(pg_flag.key, "multivariate_flag"); + assert_eq!(pg_flag.get_variants().len(), 3); + } + + #[tokio::test] + async fn test_multivariate_flag_with_payloads() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + let multivariate_flag_with_payloads = json!({ + "id": 1, + "team_id": team.id, + "name": "Multivariate Flag with Payloads", + "key": "multivariate_flag_with_payloads", + "filters": { + "groups": [ + { + "properties": [], + "rollout_percentage": 100 + } + ], + "multivariate": { + "variants": [ + { + "key": "control", + "name": "Control Group", + "rollout_percentage": 33.33 + }, + { + "key": "test_a", + "name": "Test Group A", + "rollout_percentage": 33.33 + }, + { + "key": "test_b", + "name": "Test Group B", + "rollout_percentage": 33.34 + } + ] + }, + "payloads": { + "control": {"type": "json", "value": {"feature": "old"}}, + "test_a": {"type": "json", "value": {"feature": "new_a"}}, + "test_b": {"type": "json", "value": {"feature": "new_b"}} + } + }, + "active": true, + "deleted": false + }); + + // Insert into Redis + insert_flags_for_team_in_redis( + redis_client.clone(), + team.id, + Some(json!([multivariate_flag_with_payloads]).to_string()), + ) + .await + .expect("Failed to insert flag in Redis"); + + // Insert into Postgres + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 1, + team_id: team.id, + name: Some("Multivariate Flag with Payloads".to_string()), + key: "multivariate_flag_with_payloads".to_string(), + filters: multivariate_flag_with_payloads["filters"].clone(), + deleted: false, + active: true, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert flag in Postgres"); + + // Fetch and verify from Redis + let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) + .await + .expect("Failed to fetch flags from Redis"); + + assert_eq!(redis_flags.flags.len(), 1); + let redis_flag = &redis_flags.flags[0]; + assert_eq!(redis_flag.key, "multivariate_flag_with_payloads"); + + // Fetch and verify from Postgres + let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + .await + .expect("Failed to fetch flags from Postgres"); + + assert_eq!(pg_flags.flags.len(), 1); + let pg_flag = &pg_flags.flags[0]; + assert_eq!(pg_flag.key, "multivariate_flag_with_payloads"); + + // Verify flag contents for both Redis and Postgres + for (source, flag) in [("Redis", redis_flag), ("Postgres", pg_flag)].iter() { + // Check multivariate options + assert!(flag.filters.multivariate.is_some()); + let multivariate = flag.filters.multivariate.as_ref().unwrap(); + assert_eq!(multivariate.variants.len(), 3); + + // Check variant details + let variant_keys = ["control", "test_a", "test_b"]; + let expected_names = ["Control Group", "Test Group A", "Test Group B"]; + for (i, (key, expected_name)) in + variant_keys.iter().zip(expected_names.iter()).enumerate() + { + let variant = &multivariate.variants[i]; + assert_eq!(variant.key, *key); + assert_eq!( + variant.name, + Some(expected_name.to_string()), + "Incorrect variant name for {} in {}", + key, + source + ); + } + + // Check payloads + assert!(flag.filters.payloads.is_some()); + let payloads = flag.filters.payloads.as_ref().unwrap(); + + for key in variant_keys.iter() { + let payload = payloads[key].as_object().unwrap(); + assert_eq!(payload["type"], "json"); + + let value = payload["value"].as_object().unwrap(); + let expected_feature = match *key { + "control" => "old", + "test_a" => "new_a", + "test_b" => "new_b", + _ => panic!("Unexpected variant key"), + }; + assert_eq!( + value["feature"], expected_feature, + "Incorrect payload value for {} in {}", + key, source + ); + } + } + } + #[tokio::test] + async fn test_flag_with_super_groups() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + let flag_with_super_groups = json!({ + "id": 1, + "team_id": team.id, + "name": "Flag with Super Groups", + "key": "flag_with_super_groups", + "filters": { + "groups": [ + { + "properties": [], + "rollout_percentage": 50 + } + ], + "super_groups": [ + { + "properties": [ + { + "key": "country", + "value": "US", + "type": "person", + "operator": "exact" + } + ], + "rollout_percentage": 100 + } + ] + }, + "active": true, + "deleted": false + }); + + // Insert into Redis + insert_flags_for_team_in_redis( + redis_client.clone(), + team.id, + Some(json!([flag_with_super_groups]).to_string()), + ) + .await + .expect("Failed to insert flag in Redis"); + + // Insert into Postgres + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 1, + team_id: team.id, + name: Some("Flag with Super Groups".to_string()), + key: "flag_with_super_groups".to_string(), + filters: flag_with_super_groups["filters"].clone(), + deleted: false, + active: true, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert flag in Postgres"); + + // Fetch and verify from Redis + let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) + .await + .expect("Failed to fetch flags from Redis"); + + assert_eq!(redis_flags.flags.len(), 1); + let redis_flag = &redis_flags.flags[0]; + assert_eq!(redis_flag.key, "flag_with_super_groups"); + assert!(redis_flag.filters.super_groups.is_some()); + assert_eq!(redis_flag.filters.super_groups.as_ref().unwrap().len(), 1); + + // Fetch and verify from Postgres + let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + .await + .expect("Failed to fetch flags from Postgres"); + + assert_eq!(pg_flags.flags.len(), 1); + let pg_flag = &pg_flags.flags[0]; + assert_eq!(pg_flag.key, "flag_with_super_groups"); + assert!(pg_flag.filters.super_groups.is_some()); + assert_eq!(pg_flag.filters.super_groups.as_ref().unwrap().len(), 1); + } + + #[tokio::test] + async fn test_flags_with_different_property_types() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + let flag_with_different_properties = json!({ + "id": 1, + "team_id": team.id, + "name": "Flag with Different Properties", + "key": "flag_with_different_properties", + "filters": { + "groups": [ + { + "properties": [ + { + "key": "email", + "value": "test@example.com", + "type": "person", + "operator": "exact" + }, + { + "key": "country", + "value": "US", + "type": "group", + "operator": "exact" + }, + { + "key": "purchase", + "value": "completed", + "type": "event", + "operator": "exact" + } + ], + "rollout_percentage": 100 + } + ] + }, + "active": true, + "deleted": false + }); + + // Insert into Redis + insert_flags_for_team_in_redis( + redis_client.clone(), + team.id, + Some(json!([flag_with_different_properties]).to_string()), + ) + .await + .expect("Failed to insert flag in Redis"); + + // Insert into Postgres + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 1, + team_id: team.id, + name: Some("Flag with Different Properties".to_string()), + key: "flag_with_different_properties".to_string(), + filters: flag_with_different_properties["filters"].clone(), + deleted: false, + active: true, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert flag in Postgres"); + + // Fetch and verify from Redis + let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) + .await + .expect("Failed to fetch flags from Redis"); + + assert_eq!(redis_flags.flags.len(), 1); + let redis_flag = &redis_flags.flags[0]; + assert_eq!(redis_flag.key, "flag_with_different_properties"); + let redis_properties = &redis_flag.filters.groups[0].properties.as_ref().unwrap(); + assert_eq!(redis_properties.len(), 3); + assert_eq!(redis_properties[0].prop_type, "person"); + assert_eq!(redis_properties[1].prop_type, "group"); + assert_eq!(redis_properties[2].prop_type, "event"); + + // Fetch and verify from Postgres + let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + .await + .expect("Failed to fetch flags from Postgres"); + + assert_eq!(pg_flags.flags.len(), 1); + let pg_flag = &pg_flags.flags[0]; + assert_eq!(pg_flag.key, "flag_with_different_properties"); + let pg_properties = &pg_flag.filters.groups[0].properties.as_ref().unwrap(); + assert_eq!(pg_properties.len(), 3); + assert_eq!(pg_properties[0].prop_type, "person"); + assert_eq!(pg_properties[1].prop_type, "group"); + assert_eq!(pg_properties[2].prop_type, "event"); + } + + #[tokio::test] + async fn test_deleted_and_inactive_flags() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + let deleted_flag = json!({ + "id": 1, + "team_id": team.id, + "name": "Deleted Flag", + "key": "deleted_flag", + "filters": {"groups": []}, + "active": true, + "deleted": true + }); + + let inactive_flag = json!({ + "id": 2, + "team_id": team.id, + "name": "Inactive Flag", + "key": "inactive_flag", + "filters": {"groups": []}, + "active": false, + "deleted": false + }); + + // Insert into Redis + insert_flags_for_team_in_redis( + redis_client.clone(), + team.id, + Some(json!([deleted_flag, inactive_flag]).to_string()), + ) + .await + .expect("Failed to insert flags in Redis"); + + // Insert into Postgres + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 0, + team_id: team.id, + name: Some("Deleted Flag".to_string()), + key: "deleted_flag".to_string(), + filters: deleted_flag["filters"].clone(), + deleted: true, + active: true, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert deleted flag in Postgres"); + + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 0, + team_id: team.id, + name: Some("Inactive Flag".to_string()), + key: "inactive_flag".to_string(), + filters: inactive_flag["filters"].clone(), + deleted: false, + active: false, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert inactive flag in Postgres"); + + // Fetch and verify from Redis + let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) + .await + .expect("Failed to fetch flags from Redis"); + + assert_eq!(redis_flags.flags.len(), 2); + assert!(redis_flags + .flags + .iter() + .any(|f| f.key == "deleted_flag" && f.deleted)); + assert!(redis_flags + .flags + .iter() + .any(|f| f.key == "inactive_flag" && !f.active)); + + // Fetch and verify from Postgres + let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + .await + .expect("Failed to fetch flags from Postgres"); + + assert_eq!(pg_flags.flags.len(), 2); + assert!(pg_flags + .flags + .iter() + .any(|f| f.key == "deleted_flag" && f.deleted)); + assert!(pg_flags + .flags + .iter() + .any(|f| f.key == "inactive_flag" && !f.active)); + } + + #[tokio::test] + async fn test_error_handling() { + let redis_client = setup_redis_client(Some("redis://localhost:6379/".to_string())); + let pg_client = setup_pg_client(None).await; + + // Test Redis connection error + let bad_redis_client = setup_redis_client(Some("redis://localhost:1111/".to_string())); + let result = FeatureFlagList::from_redis(bad_redis_client, 1).await; + assert!(matches!(result, Err(FlagError::RedisUnavailable))); + + // Test malformed JSON in Redis + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + redis_client + .set( + format!("{}{}", flag_definitions::TEAM_FLAGS_CACHE_PREFIX, team.id), + "not a json".to_string(), + ) + .await + .expect("Failed to set malformed JSON in Redis"); + + let result = FeatureFlagList::from_redis(redis_client, team.id).await; + assert!(matches!(result, Err(FlagError::DataParsingError))); + + // Test database query error (using a non-existent table) + let result = sqlx::query("SELECT * FROM non_existent_table") + .fetch_all(&mut *pg_client.get_connection().await.unwrap()) + .await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_concurrent_access() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + let flag = json!({ + "id": 1, + "team_id": team.id, + "name": "Concurrent Flag", + "key": "concurrent_flag", + "filters": {"groups": []}, + "active": true, + "deleted": false + }); + + insert_flags_for_team_in_redis( + redis_client.clone(), + team.id, + Some(json!([flag]).to_string()), + ) + .await + .expect("Failed to insert flag in Redis"); + + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 0, + team_id: team.id, + name: Some("Concurrent Flag".to_string()), + key: "concurrent_flag".to_string(), + filters: flag["filters"].clone(), + deleted: false, + active: true, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert flag in Postgres"); + + let mut handles = vec![]; + for _ in 0..10 { + let redis_client = redis_client.clone(); + let pg_client = pg_client.clone(); + let team_id = team.id; + + let handle = task::spawn(async move { + let redis_flags = FeatureFlagList::from_redis(redis_client, team_id) + .await + .unwrap(); + let pg_flags = FeatureFlagList::from_pg(pg_client, team_id).await.unwrap(); + (redis_flags, pg_flags) + }); + + handles.push(handle); + } + + for handle in handles { + let (redis_flags, pg_flags) = handle.await.unwrap(); + assert_eq!(redis_flags.flags.len(), 1); + assert_eq!(pg_flags.flags.len(), 1); + assert_eq!(redis_flags.flags[0].key, "concurrent_flag"); + assert_eq!(pg_flags.flags[0].key, "concurrent_flag"); + } + } + + #[tokio::test] + #[ignore] + async fn test_performance() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + let num_flags = 1000; + let mut flags = Vec::with_capacity(num_flags); + + for i in 0..num_flags { + let flag = json!({ + "id": i, + "team_id": team.id, + "name": format!("Flag {}", i), + "key": format!("flag_{}", i), + "filters": {"groups": []}, + "active": true, + "deleted": false + }); + flags.push(flag); + } + + insert_flags_for_team_in_redis( + redis_client.clone(), + team.id, + Some(json!(flags).to_string()), + ) + .await + .expect("Failed to insert flags in Redis"); + + for flag in flags { + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 0, + team_id: team.id, + name: Some(flag["name"].as_str().unwrap().to_string()), + key: flag["key"].as_str().unwrap().to_string(), + filters: flag["filters"].clone(), + deleted: false, + active: true, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert flag in Postgres"); + } + + let start = Instant::now(); + let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) + .await + .expect("Failed to fetch flags from Redis"); + let redis_duration = start.elapsed(); + + let start = Instant::now(); + let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + .await + .expect("Failed to fetch flags from Postgres"); + let pg_duration = start.elapsed(); + + println!("Redis fetch time: {:?}", redis_duration); + println!("Postgres fetch time: {:?}", pg_duration); + + assert_eq!(redis_flags.flags.len(), num_flags); + assert_eq!(pg_flags.flags.len(), num_flags); + + assert!(redis_duration < std::time::Duration::from_millis(100)); + assert!(pg_duration < std::time::Duration::from_millis(1000)); + } + + #[tokio::test] + async fn test_edge_cases() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + let edge_case_flags = json!([ + { + "id": 1, + "team_id": team.id, + "name": "Empty Properties Flag", + "key": "empty_properties", + "filters": {"groups": [{"properties": [], "rollout_percentage": 100}]}, + "active": true, + "deleted": false + }, + { + "id": 2, + "team_id": team.id, + "name": "Very Long Key Flag", + "key": "a".repeat(400), // max key length is 400 + "filters": {"groups": [{"properties": [], "rollout_percentage": 100}]}, + "active": true, + "deleted": false + }, + { + "id": 3, + "team_id": team.id, + "name": "Unicode Flag", + "key": "unicode_flag_🚀", + "filters": {"groups": [{"properties": [{"key": "country", "value": "🇯🇵", "type": "person"}], "rollout_percentage": 100}]}, + "active": true, + "deleted": false + } + ]); + + // Insert edge case flags + insert_flags_for_team_in_redis( + redis_client.clone(), + team.id, + Some(edge_case_flags.to_string()), + ) + .await + .expect("Failed to insert edge case flags in Redis"); + + for flag in edge_case_flags.as_array().unwrap() { + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 0, + team_id: team.id, + name: flag["name"].as_str().map(|s| s.to_string()), + key: flag["key"].as_str().unwrap().to_string(), + filters: flag["filters"].clone(), + deleted: false, + active: true, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert edge case flag in Postgres"); + } + + // Fetch and verify edge case flags + let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) + .await + .expect("Failed to fetch flags from Redis"); + let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + .await + .expect("Failed to fetch flags from Postgres"); + + assert_eq!(redis_flags.flags.len(), 3); + assert_eq!(pg_flags.flags.len(), 3); + + // Verify empty properties flag + assert!(redis_flags.flags.iter().any(|f| f.key == "empty_properties" + && f.filters.groups[0].properties.as_ref().unwrap().is_empty())); + assert!(pg_flags.flags.iter().any(|f| f.key == "empty_properties" + && f.filters.groups[0].properties.as_ref().unwrap().is_empty())); + + // Verify very long key flag + assert!(redis_flags.flags.iter().any(|f| f.key.len() == 400)); + assert!(pg_flags.flags.iter().any(|f| f.key.len() == 400)); + + // Verify unicode flag + assert!(redis_flags.flags.iter().any(|f| f.key == "unicode_flag_🚀")); + assert!(pg_flags.flags.iter().any(|f| f.key == "unicode_flag_🚀")); + } + + #[tokio::test] + async fn test_consistent_behavior_from_both_clients() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + let flags = json!([ + { + "id": 1, + "team_id": team.id, + "name": "Flag 1", + "key": "flag_1", + "filters": {"groups": [{"properties": [], "rollout_percentage": 50}]}, + "active": true, + "deleted": false + }, + { + "id": 2, + "team_id": team.id, + "name": "Flag 2", + "key": "flag_2", + "filters": {"groups": [{"properties": [], "rollout_percentage": 75}]}, + "active": true, + "deleted": false + } + ]); + + // Insert flags in both Redis and Postgres + insert_flags_for_team_in_redis(redis_client.clone(), team.id, Some(flags.to_string())) + .await + .expect("Failed to insert flags in Redis"); + + for flag in flags.as_array().unwrap() { + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 0, + team_id: team.id, + name: flag["name"].as_str().map(|s| s.to_string()), + key: flag["key"].as_str().unwrap().to_string(), + filters: flag["filters"].clone(), + deleted: false, + active: true, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert flag in Postgres"); + } + + // Fetch flags from both sources + let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) + .await + .expect("Failed to fetch flags from Redis"); + let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + .await + .expect("Failed to fetch flags from Postgres"); + + // Compare results + assert_eq!(redis_flags.flags.len(), pg_flags.flags.len()); + for (redis_flag, pg_flag) in redis_flags.flags.iter().zip(pg_flags.flags.iter()) { + assert_eq!(redis_flag.key, pg_flag.key); + assert_eq!(redis_flag.name, pg_flag.name); + assert_eq!(redis_flag.active, pg_flag.active); + assert_eq!(redis_flag.deleted, pg_flag.deleted); + assert_eq!( + redis_flag.filters.groups[0].rollout_percentage, + pg_flag.filters.groups[0].rollout_percentage + ); + } + } + + #[tokio::test] + async fn test_rollout_percentage_edge_cases() { + let redis_client = setup_redis_client(None); + let pg_client = setup_pg_client(None).await; + + let team = insert_new_team_in_pg(pg_client.clone()) + .await + .expect("Failed to insert team in pg"); + + let flags = json!([ + { + "id": 1, + "team_id": team.id, + "name": "0% Rollout", + "key": "zero_percent", + "filters": {"groups": [{"properties": [], "rollout_percentage": 0}]}, + "active": true, + "deleted": false + }, + { + "id": 2, + "team_id": team.id, + "name": "100% Rollout", + "key": "hundred_percent", + "filters": {"groups": [{"properties": [], "rollout_percentage": 100}]}, + "active": true, + "deleted": false + }, + { + "id": 3, + "team_id": team.id, + "name": "Fractional Rollout", + "key": "fractional_percent", + "filters": {"groups": [{"properties": [], "rollout_percentage": 33.33}]}, + "active": true, + "deleted": false + } + ]); + + // Insert flags in both Redis and Postgres + insert_flags_for_team_in_redis(redis_client.clone(), team.id, Some(flags.to_string())) + .await + .expect("Failed to insert flags in Redis"); + + for flag in flags.as_array().unwrap() { + insert_flag_for_team_in_pg( + pg_client.clone(), + team.id, + Some(FeatureFlagRow { + id: 0, + team_id: team.id, + name: flag["name"].as_str().map(|s| s.to_string()), + key: flag["key"].as_str().unwrap().to_string(), + filters: flag["filters"].clone(), + deleted: false, + active: true, + ensure_experience_continuity: false, + }), + ) + .await + .expect("Failed to insert flag in Postgres"); + } + + // Fetch flags from both sources + let redis_flags = FeatureFlagList::from_redis(redis_client, team.id) + .await + .expect("Failed to fetch flags from Redis"); + let pg_flags = FeatureFlagList::from_pg(pg_client, team.id) + .await + .expect("Failed to fetch flags from Postgres"); + + // Verify rollout percentages + for flags in &[redis_flags, pg_flags] { + assert!(flags + .flags + .iter() + .any(|f| f.key == "zero_percent" + && f.filters.groups[0].rollout_percentage == Some(0.0))); + assert!(flags.flags.iter().any(|f| f.key == "hundred_percent" + && f.filters.groups[0].rollout_percentage == Some(100.0))); + assert!(flags.flags.iter().any(|f| f.key == "fractional_percent" + && (f.filters.groups[0].rollout_percentage.unwrap() - 33.33).abs() < f64::EPSILON)); + } + } } diff --git a/rust/feature-flags/src/team.rs b/rust/feature-flags/src/team.rs index 7c7cfd9547bbf..678668490485d 100644 --- a/rust/feature-flags/src/team.rs +++ b/rust/feature-flags/src/team.rs @@ -23,7 +23,7 @@ impl Team { client: Arc, token: String, ) -> Result { - // TODO: Instead of failing here, i.e. if not in redis, fallback to pg + // NB: if this lookup fails, we fall back to the database before returning an error let serialized_team = client .get(format!("{TEAM_TOKEN_CACHE_PREFIX}{}", token)) .await?; @@ -37,6 +37,30 @@ impl Team { Ok(team) } + #[instrument(skip_all)] + pub async fn update_redis_cache( + client: Arc, + team: Team, + ) -> Result<(), FlagError> { + let serialized_team = serde_json::to_string(&team).map_err(|e| { + tracing::error!("Failed to serialize team: {}", e); + FlagError::DataParsingError + })?; + + client + .set( + format!("{TEAM_TOKEN_CACHE_PREFIX}{}", team.api_token), + serialized_team, + ) + .await + .map_err(|e| { + tracing::error!("Failed to update Redis cache: {}", e); + FlagError::CacheUpdateError + })?; + + Ok(()) + } + pub async fn from_pg( client: Arc, token: String, diff --git a/rust/feature-flags/src/test_utils.rs b/rust/feature-flags/src/test_utils.rs index 9d1f5970d46b6..20b33ba5c3543 100644 --- a/rust/feature-flags/src/test_utils.rs +++ b/rust/feature-flags/src/test_utils.rs @@ -1,11 +1,13 @@ use anyhow::Error; +use axum::async_trait; use serde_json::{json, Value}; +use sqlx::{pool::PoolConnection, postgres::PgRow, Error as SqlxError, Postgres}; use std::sync::Arc; use uuid::Uuid; use crate::{ config::{Config, DEFAULT_TEST_CONFIG}, - database::{Client as DatabaseClientTrait, PgClient}, + database::{Client, CustomDatabaseError, PgClient}, flag_definitions::{self, FeatureFlag, FeatureFlagRow}, redis::{Client as RedisClientTrait, RedisClient}, team::{self, Team}, @@ -137,6 +139,30 @@ pub async fn setup_pg_client(config: Option<&Config>) -> Arc { ) } +pub struct MockPgClient; + +#[async_trait] +impl Client for MockPgClient { + async fn run_query( + &self, + _query: String, + _parameters: Vec, + _timeout_ms: Option, + ) -> Result, CustomDatabaseError> { + // Simulate a database connection failure + Err(CustomDatabaseError::Other(SqlxError::PoolTimedOut)) + } + + async fn get_connection(&self) -> Result, CustomDatabaseError> { + // Simulate a database connection failure + Err(CustomDatabaseError::Other(SqlxError::PoolTimedOut)) + } +} + +pub async fn setup_invalid_pg_client() -> Arc { + Arc::new(MockPgClient) +} + pub async fn insert_new_team_in_pg(client: Arc) -> Result { const ORG_ID: &str = "019026a4be8000005bf3171d00629163"; @@ -184,7 +210,7 @@ pub async fn insert_new_team_in_pg(client: Arc) -> Result Ok(team) } -pub async fn insert_flags_for_team_in_pg( +pub async fn insert_flag_for_team_in_pg( client: Arc, team_id: i32, flag: Option, @@ -192,7 +218,10 @@ pub async fn insert_flags_for_team_in_pg( let id = rand::thread_rng().gen_range(0..10_000_000); let payload_flag = match flag { - Some(value) => value, + Some(mut value) => { + value.id = id; + value + } None => FeatureFlagRow { id, key: "flag1".to_string(), diff --git a/rust/feature-flags/src/v0_endpoint.rs b/rust/feature-flags/src/v0_endpoint.rs index ba4bcef8fec47..d32f976d94447 100644 --- a/rust/feature-flags/src/v0_endpoint.rs +++ b/rust/feature-flags/src/v0_endpoint.rs @@ -8,6 +8,7 @@ use axum::http::{HeaderMap, Method}; use axum_client_ip::InsecureClientIp; use tracing::instrument; +use crate::api::FlagValue; use crate::{ api::{FlagError, FlagsResponse}, router, @@ -72,7 +73,7 @@ pub async fn flags( }?; let token = request - .extract_and_verify_token(state.redis.clone()) + .extract_and_verify_token(state.redis.clone(), state.postgres.clone()) .await?; let distinct_id = request.extract_distinct_id()?; @@ -87,8 +88,11 @@ pub async fn flags( Ok(Json(FlagsResponse { error_while_computing_flags: false, feature_flags: HashMap::from([ - ("beta-feature".to_string(), "variant-1".to_string()), - ("rollout-flag".to_string(), true.to_string()), + ( + "beta-feature".to_string(), + FlagValue::String("variant-1".to_string()), + ), + ("rollout-flag".to_string(), FlagValue::Boolean(true)), ]), })) } diff --git a/rust/feature-flags/src/v0_request.rs b/rust/feature-flags/src/v0_request.rs index 63b26b455f6f4..4447cb64d1d68 100644 --- a/rust/feature-flags/src/v0_request.rs +++ b/rust/feature-flags/src/v0_request.rs @@ -5,7 +5,9 @@ use serde::{Deserialize, Serialize}; use serde_json::Value; use tracing::instrument; -use crate::{api::FlagError, redis::Client, team::Team}; +use crate::{ + api::FlagError, database::Client as DatabaseClient, redis::Client as RedisClient, team::Team, +}; #[derive(Deserialize, Default)] pub struct FlagsQueryParams { @@ -53,7 +55,8 @@ impl FlagRequest { pub async fn extract_and_verify_token( &self, - redis_client: Arc, + redis_client: Arc, + pg_client: Arc, ) -> Result { let token = match self { FlagRequest { @@ -62,12 +65,22 @@ impl FlagRequest { _ => return Err(FlagError::NoTokenError), }; - // validate token - Team::from_redis(redis_client, token.clone()).await?; - - // TODO: fallback when token not found in redis - - Ok(token) + match Team::from_redis(redis_client.clone(), token.clone()).await { + Ok(_) => Ok(token), + Err(_) => { + // Fallback: Check PostgreSQL if not found in Redis + match Team::from_pg(pg_client, token.clone()).await { + Ok(team) => { + // Token found in PostgreSQL, update Redis cache + if let Err(e) = Team::update_redis_cache(redis_client, team).await { + tracing::warn!("Failed to update Redis cache: {}", e); + } + Ok(token) + } + Err(_) => Err(FlagError::TokenValidationError), + } + } + } } pub fn extract_distinct_id(&self) -> Result { diff --git a/rust/feature-flags/tests/test_flags.rs b/rust/feature-flags/tests/test_flags.rs index f9a46e1c543af..7f50064daddb6 100644 --- a/rust/feature-flags/tests/test_flags.rs +++ b/rust/feature-flags/tests/test_flags.rs @@ -41,7 +41,7 @@ async fn it_sends_flag_request() -> Result<()> { "errorWhileComputingFlags": false, "featureFlags": { "beta-feature": "variant-1", - "rollout-flag": "true", + "rollout-flag": true, } }) ); @@ -77,8 +77,134 @@ async fn it_rejects_invalid_headers_flag_request() -> Result<()> { assert_eq!( response_text, - "failed to decode request: unsupported content type: xyz" + "Failed to decode request: unsupported content type: xyz. Please check your request format and try again." ); Ok(()) } + +#[tokio::test] +async fn it_rejects_empty_distinct_id() -> Result<()> { + let config = DEFAULT_TEST_CONFIG.clone(); + let client = setup_redis_client(Some(config.redis_url.clone())); + let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + let token = team.api_token; + let server = ServerHandle::for_config(config).await; + + let payload = json!({ + "token": token, + "distinct_id": "", + "groups": {"group1": "group1"} + }); + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::BAD_REQUEST, res.status()); + assert_eq!( + res.text().await?, + "The distinct_id field cannot be empty. Please provide a valid identifier." + ); + Ok(()) +} + +#[tokio::test] +async fn it_rejects_missing_distinct_id() -> Result<()> { + let config = DEFAULT_TEST_CONFIG.clone(); + let client = setup_redis_client(Some(config.redis_url.clone())); + let team = insert_new_team_in_redis(client.clone()).await.unwrap(); + let token = team.api_token; + let server = ServerHandle::for_config(config).await; + + let payload = json!({ + "token": token, + "groups": {"group1": "group1"} + }); + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::BAD_REQUEST, res.status()); + assert_eq!( + res.text().await?, + "The distinct_id field is missing from the request. Please include a valid identifier." + ); + Ok(()) +} + +#[tokio::test] +async fn it_rejects_missing_token() -> Result<()> { + let config = DEFAULT_TEST_CONFIG.clone(); + let server = ServerHandle::for_config(config).await; + + let payload = json!({ + "distinct_id": "user1", + "groups": {"group1": "group1"} + }); + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::UNAUTHORIZED, res.status()); + assert_eq!( + res.text().await?, + "No API key provided. Please include a valid API key in your request." + ); + Ok(()) +} + +#[tokio::test] +async fn it_rejects_invalid_token() -> Result<()> { + let config = DEFAULT_TEST_CONFIG.clone(); + let server = ServerHandle::for_config(config).await; + + let payload = json!({ + "token": "invalid_token", + "distinct_id": "user1", + "groups": {"group1": "group1"} + }); + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::UNAUTHORIZED, res.status()); + assert_eq!( + res.text().await?, + "The provided API key is invalid or has expired. Please check your API key and try again." + ); + Ok(()) +} + +#[tokio::test] +async fn it_handles_malformed_json() -> Result<()> { + let config = DEFAULT_TEST_CONFIG.clone(); + let server = ServerHandle::for_config(config).await; + + let payload = "{invalid_json}"; + let res = server.send_flags_request(payload.to_string()).await; + assert_eq!(StatusCode::BAD_REQUEST, res.status()); + assert!(res.text().await?.starts_with("Failed to parse request:")); + Ok(()) +} + +// TODO: we haven't implemented rate limiting in the new endpoint yet +// #[tokio::test] +// async fn it_handles_rate_limiting() -> Result<()> { +// let config = DEFAULT_TEST_CONFIG.clone(); +// let client = setup_redis_client(Some(config.redis_url.clone())); +// let team = insert_new_team_in_redis(client.clone()).await.unwrap(); +// let token = team.api_token; +// let server = ServerHandle::for_config(config).await; + +// // Simulate multiple requests to trigger rate limiting +// for _ in 0..100 { +// let payload = json!({ +// "token": token, +// "distinct_id": "user1", +// "groups": {"group1": "group1"} +// }); +// server.send_flags_request(payload.to_string()).await; +// } + +// // The next request should be rate limited +// let payload = json!({ +// "token": token, +// "distinct_id": "user1", +// "groups": {"group1": "group1"} +// }); +// let res = server.send_flags_request(payload.to_string()).await; +// assert_eq!(StatusCode::TOO_MANY_REQUESTS, res.status()); +// assert_eq!( +// res.text().await?, +// "Rate limit exceeded. Please reduce your request frequency and try again later." +// ); +// Ok(()) +// }