Add support for generating Azure Blob Storage signed URLs (#22)
* add experimental support for Azure signed URLs

* remove dead code, adjust tests

* syntax error

* Revert "syntax error"

This reverts commit 16ac210.

* rewrite signer builders

* remove stray lifetime annotation

---------

Co-authored-by: Tim Dikland <tim.dikland@databricks.com>
tdikland and TimDikland-DB authored Dec 29, 2023
1 parent ecf71bb commit ee20de9
Showing 8 changed files with 536 additions and 208 deletions.
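The diff below replaces the per-platform URL-signing closures with a boxed `Signer` trait object built by `SignedUrlUtility`. The trait itself lives in `src/server/utilities/signed_url.rs`, which this view does not render, so the following is only a sketch of the shape the call sites assume; the async `sign` method and its fallible return type are inferred from the `url_signer.sign(&self.file.url).await.unwrap()` call in `deltalake.rs`.

    // Sketch only: the real trait is in the unrendered src/server/utilities/signed_url.rs.
    use async_trait::async_trait;

    #[async_trait]
    pub trait Signer: Send + Sync {
        /// Exchange a storage object path for a time-limited signed URL.
        async fn sign(&self, path: &str) -> anyhow::Result<String>;
    }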
295 changes: 260 additions & 35 deletions Cargo.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions Cargo.toml
@@ -12,6 +12,9 @@ async-session = "3.0.0"
async-trait = "0.1.64"
axum = { version = "0.6.20", features = ["headers"] }
axum-extra = { version = "0.8", features = ["json-lines"] }
azure_core = "0.17.0"
azure_storage = "0.17.0"
azure_storage_blobs = "0.17.0"
clap = "4.1.4"
deltalake = { version = "0.15.0", features = ["s3", "azure", "gcs"] }
futures = "0.3.28"
@@ -47,6 +50,7 @@ sqlx = { version = "0.7", features = [
strum = { version = "0.25", features = ["derive"] }
strum_macros = "0.25"
tame-gcs = { version = "0.12.0", features = ["signing"] }
time = { version = "0.3.30", features = ["local-offset"] }
tracing = "0.1.37"
tracing-log = "0.1.3"
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] }
13 changes: 13 additions & 0 deletions src/bootstrap.rs
@@ -4,6 +4,7 @@ pub(crate) mod gcp;
mod postgres;
use anyhow::Context;
use anyhow::Result;
use azure_storage::StorageCredentials;
use rusoto_credential::ProfileProvider;
use sqlx::PgPool;
use tame_gcs::signing::ServiceAccount;
@@ -32,3 +33,15 @@ pub(crate) fn new_aws_profile_provider() -> Result<ProfileProvider> {
std::env::var("AWS_PROFILE").context("failed to get `AWS_PROFILE` environment variable")?;
aws::new(&aws_profile)
}

pub(crate) fn new_azure_storage_account() -> Result<StorageCredentials> {
let azure_storage_account_name = std::env::var("AZURE_STORAGE_ACCOUNT_NAME")
.context("failed to get `AZURE_STORAGE_ACCOUNT_NAME` environment variable")?;
let azure_storage_account_key = std::env::var("AZURE_STORAGE_ACCOUNT_KEY")
.context("failed to get `AZURE_STORAGE_ACCOUNT_KEY` environment variable")?;

Ok(StorageCredentials::access_key(
azure_storage_account_name,
azure_storage_account_key,
))
}
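
For context, here is a minimal sketch of how the `StorageCredentials` returned above could be turned into a read-only SAS URL with the `azure_storage_blobs` 0.17 API added in this commit. It is not the repository's actual signer (that code is in the unrendered `signed_url.rs`), and `account`, `container`, `blob`, and `ttl_secs` are placeholder parameters.

    // Hedged sketch, not the repository's signer implementation.
    use azure_core::Url;
    use azure_storage::prelude::*;
    use azure_storage_blobs::prelude::*;
    use time::OffsetDateTime;

    async fn sign_blob_url(
        account: &str,                   // placeholder: storage account name
        credentials: StorageCredentials, // e.g. from new_azure_storage_account()
        container: &str,                 // placeholder: container name
        blob: &str,                      // placeholder: blob path
        ttl_secs: i64,                   // lifetime of the signed URL
    ) -> anyhow::Result<Url> {
        let blob_client = BlobServiceClient::new(account, credentials)
            .container_client(container)
            .blob_client(blob);
        // Grant read-only access from now until now + ttl.
        let mut permissions = BlobSasPermissions::default();
        permissions.read = true;
        let expiry = OffsetDateTime::now_utc() + time::Duration::seconds(ttl_secs);
        let sas = blob_client
            .shared_access_signature(permissions, expiry)
            .await?;
        Ok(blob_client.generate_signed_blob_url(&sas)?)
    }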
20 changes: 17 additions & 3 deletions src/server.rs
@@ -8,6 +8,7 @@ pub(crate) mod utilities;

use anyhow::Context;
use anyhow::Result;
use azure_storage::StorageCredentials;
use rusoto_credential::AwsCredentials;
use rusoto_credential::ProvideAwsCredentials;
use sqlx::PgPool;
@@ -35,6 +36,7 @@ pub struct Server {
pg_pool: PgPool,
gcp_service_account: Option<ServiceAccount>,
aws_credentials: Option<AwsCredentials>,
azure_storage_credentials: Option<StorageCredentials>,
}

impl Server {
@@ -60,16 +62,28 @@ impl Server {
if aws_credentials.is_none() {
tracing::warn!("failed to load AWS credentials");
}

let azure_storage_credentials = bootstrap::new_azure_storage_account().ok();
if azure_storage_credentials.is_none() {
tracing::warn!("failed to load Azure Storage credentials");
}

Ok(Server {
pg_pool,
gcp_service_account,
aws_credentials,
azure_storage_credentials,
})
}

pub async fn start(self) -> Result<()> {
routers::bind(self.pg_pool, self.gcp_service_account, self.aws_credentials)
.await
.context("failed to start API server")
routers::bind(
self.pg_pool,
self.gcp_service_account,
self.aws_credentials,
self.azure_storage_credentials,
)
.await
.context("failed to start API server")
}
}
16 changes: 13 additions & 3 deletions src/server/routers.rs
@@ -10,6 +10,7 @@ use axum::middleware;
use axum::response::Response;
use axum::routing::{get, post};
use axum::Router;
use azure_storage::StorageCredentials;
use rusoto_credential::AwsCredentials;
use sqlx::PgPool;
use tame_gcs::signing::ServiceAccount;
@@ -26,6 +27,7 @@ pub struct State {
pub pg_pool: PgPool,
pub gcp_service_account: Option<ServiceAccount>,
pub aws_credentials: Option<AwsCredentials>,
pub azure_credentials: Option<StorageCredentials>,
}

pub type SharedState = Arc<State>;
@@ -38,11 +40,13 @@ async fn route(
pg_pool: PgPool,
gcp_service_account: Option<ServiceAccount>,
aws_credentials: Option<AwsCredentials>,
azure_credentials: Option<StorageCredentials>,
) -> Result<Router> {
let state = Arc::new(State {
pg_pool,
gcp_service_account,
aws_credentials,
azure_credentials,
});

let swagger = SwaggerUi::new("/swagger-ui").url("/api-doc/openapi.json", ApiDoc::openapi());
@@ -127,10 +131,16 @@ pub async fn bind(
pg_pool: PgPool,
gcp_service_account: Option<ServiceAccount>,
aws_credentials: Option<AwsCredentials>,
azure_credentials: Option<StorageCredentials>,
) -> Result<()> {
let app = route(pg_pool, gcp_service_account, aws_credentials)
.await
.context("failed to create axum router")?;
let app = route(
pg_pool,
gcp_service_account,
aws_credentials,
azure_credentials,
)
.await
.context("failed to create axum router")?;
let server_bind = config::fetch::<String>("server_bind");
let addr = server_bind.as_str().parse().context(format!(
r#"failed to parse "{}" to SocketAddr"#,
98 changes: 55 additions & 43 deletions src/server/routers/shares/schemas/tables/query.rs
@@ -1,4 +1,5 @@
use anyhow::anyhow;
use anyhow::Context;
use axum::extract::Extension;
use axum::extract::Json;
use axum::extract::Path;
@@ -10,6 +11,8 @@ use axum::response::IntoResponse;
use axum::response::Response;
use axum_extra::json_lines::JsonLines;
use std::str::FromStr;
use std::time::Duration;
use tame_gcs::signing::ServiceAccount;
use utoipa::IntoParams;
use utoipa::ToSchema;

@@ -26,6 +29,7 @@ use crate::server::utilities::json::PartitionFilter as JSONPartitionFilter;
use crate::server::utilities::json::PredicateJson;
use crate::server::utilities::json::Utility as JSONUtility;
use crate::server::utilities::signed_url::Platform;
use crate::server::utilities::signed_url::Signer;
use crate::server::utilities::signed_url::Utility as SignedUrlUtility;
use crate::server::utilities::sql::PartitionFilter as SQLPartitionFilter;
use crate::server::utilities::sql::Utility as SQLUtility;
@@ -161,46 +165,51 @@ pub async fn post(
};
metadata.to_owned()
};
let url_signer = |name: String| match &platform {
Platform::Aws { url, bucket, path } => {
if let Some(aws_credentials) = &state.aws_credentials {
let file: String = format!("{}/{}", path, name);
let Ok(signed) = SignedUrlUtility::sign_aws(
aws_credentials,
bucket,
&file,
&config::fetch::<u64>("signed_url_ttl"),
) else {
tracing::error!("failed to sign up AWS S3 url");
return url.clone();
};
return signed.into();
let url_signer: Box<dyn Signer> = match &platform {
Platform::Aws => {
if let Some(creds) = &state.aws_credentials {
Box::new(SignedUrlUtility::aws_signer(
creds.clone(),
Duration::from_secs(config::fetch::<u64>("signed_url_ttl")),
))
} else {
tracing::error!("No credentials found for AWS S3");
return Err(anyhow!("Error occurred while signing URLs").into());
}
tracing::warn!("AWS credentials were not set");
url.clone()
}
Platform::Gcp { url, bucket, path } => {
if let Some(gcp_service_account) = &state.gcp_service_account {
let file: String = format!("{}/{}", path, name);
let Ok(signed) = SignedUrlUtility::sign_gcp(
gcp_service_account,
bucket,
&file,
&config::fetch::<u64>("signed_url_ttl"),
) else {
tracing::error!("failed to sign up GCP GCS url");
return url.clone();
};
return signed.into();
Platform::Azure => {
if let Some(creds) = &state.azure_credentials {
Box::new(SignedUrlUtility::azure_signer(
creds.clone(),
Duration::from_secs(config::fetch::<u64>("signed_url_ttl")),
))
} else {
tracing::error!("No credentials found for Azure Blob Storage");
return Err(anyhow!("Error occurred while signing URLs").into());
}
tracing::warn!("GCP service account was not set");
url.clone()
}
Platform::None { url } => {
tracing::warn!("no supported platforms");
url.clone()
Platform::Gcp => {
if let Some(_) = &state.gcp_service_account {
let creds = ServiceAccount::load_json_file(
std::env::var("GOOGLE_APPLICATION_CREDENTIALS")
.context("failed to load GCP credentials")?,
)
.context("failed to load GCP credentials")?;
Box::new(SignedUrlUtility::gcp_signer(
creds,
Duration::from_secs(config::fetch::<u64>("signed_url_ttl")),
))
} else {
tracing::error!("No credentials found for GCP GCS");
return Err(anyhow!("Error occurred while signing URLs").into());
}
}
_ => {
tracing::error!("requested cloud platform is not supported");
return Err(anyhow!("Error occurred while signing URLs").into());
}
};

let mut headers = HeaderMap::new();
headers.insert(HEADER_NAME, table.version().into());
headers.insert(
@@ -211,15 +220,18 @@
Ok((
StatusCode::OK,
headers,
JsonLines::new(DeltalakeService::files_from(
table,
metadata,
predicate_hints,
json_predicate_hints,
payload.limit_hint,
is_time_traveled,
&url_signer,
)),
JsonLines::new(
DeltalakeService::files_from(
table,
metadata,
predicate_hints,
json_predicate_hints,
payload.limit_hint,
is_time_traveled,
&url_signer,
)
.await,
),
)
.into_response())
}
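
The `SignedUrlUtility::{aws,azure,gcp}_signer` builders called above are likewise defined in the unrendered `signed_url.rs`. A plausible shape for the Azure builder, mirroring the call sites (the `AzureSigner` name and fields are assumptions):

    // Assumed shape only, inferred from the SignedUrlUtility::azure_signer(creds, ttl) call above.
    use std::time::Duration;
    use azure_storage::StorageCredentials;

    pub struct SignedUrlUtility; // stand-in for the real utility type

    pub struct AzureSigner {
        credentials: StorageCredentials, // account credentials to sign with
        ttl: Duration,                   // how long signed URLs stay valid
    }

    impl SignedUrlUtility {
        pub fn azure_signer(credentials: StorageCredentials, ttl: Duration) -> AzureSigner {
            AzureSigner { credentials, ttl }
        }
    }

The returned value would then implement the `Signer` trait sketched earlier, producing SAS URLs along the lines of the `bootstrap.rs` example.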
32 changes: 17 additions & 15 deletions src/server/services/deltalake.rs
@@ -14,6 +14,7 @@ use utoipa::ToSchema;
use crate::server::utilities::deltalake::Utility as DeltalakeUtility;
use crate::server::utilities::json::PartitionFilter as JSONPartitionFilter;
use crate::server::utilities::json::Utility as JSONUtility;
use crate::server::utilities::signed_url::Signer;
use crate::server::utilities::sql::PartitionFilter as SQLPartitionFilter;
use crate::server::utilities::sql::Utility as SQLUtility;

@@ -116,12 +117,7 @@ pub struct File {
}

impl File {
fn from(
add: Add,
version: Option<i64>,
timestamp: Option<i64>,
url_signer: &dyn Fn(String) -> String,
) -> Self {
fn from(add: Add, version: Option<i64>, timestamp: Option<i64>) -> Self {
let mut partition_values: HashMap<String, String> = HashMap::new();
for (k, v) in add.partition_values.into_iter() {
if let Some(v) = v {
@@ -131,7 +127,7 @@ impl File {
Self {
file: FileDetail {
id: format!("{:x}", md5::compute(add.path.as_bytes())),
url: url_signer(add.path),
url: add.path,
partition_values,
size: add.size,
stats: add.stats,
@@ -140,6 +136,10 @@ impl File {
},
}
}

async fn sign<S: Signer>(&mut self, url_signer: &S) {
self.file.url = url_signer.sign(&self.file.url).await.unwrap();
}
}

pub struct Service;
@@ -220,14 +220,14 @@ impl Service {
files
}

pub fn files_from(
pub async fn files_from<S: Signer>(
table: DeltaTable,
metadata: DeltaTableMetaData,
predicate_hints: Option<Vec<SQLPartitionFilter>>,
json_predicate_hints: Option<JSONPartitionFilter>,
limit_hint: Option<i32>,
is_time_traveled: bool,
url_signer: &dyn Fn(String) -> String,
url_signer: &S,
) -> impl Stream<Item = Result<serde_json::Value, BoxError>> {
let version = if is_time_traveled {
Some(table.version())
@@ -247,14 +247,16 @@ impl Service {
let files =
Self::filter_with_json_hints(files, table.schema().cloned(), json_predicate_hints);
let files = Self::filter_with_limit_hint(files, limit_hint);
let mut files = files
let futures = files
.into_iter()
.map(|f| {
Ok::<serde_json::Value, BoxError>(json!(File::from(
f, version, timestamp, url_signer
)))
.map(|f| async {
let mut file = File::from(f, version, timestamp);
file.sign(url_signer).await;
Ok::<serde_json::Value, BoxError>(json!(file))
})
.collect::<Vec<Result<serde_json::Value, BoxError>>>();
.collect::<Vec<_>>();
let mut files = futures::future::join_all(futures).await;

let mut ret = vec![
Ok(json!(Protocol::new())),
Ok(json!(Metadata::from(metadata))),
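A note on the concurrency pattern introduced in `files_from` above: `futures::future::join_all` awaits all of the signing futures together and yields their results in input order, so the stream of files keeps its original ordering. A standalone illustration against the `Signer` trait sketched earlier:

    // Illustration of the join_all pattern: results come back in input order.
    use futures::future::join_all;

    async fn sign_all<S: Signer>(paths: &[String], signer: &S) -> Vec<anyhow::Result<String>> {
        join_all(paths.iter().map(|p| signer.sign(p))).await
    }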
(The eighth changed file, likely src/server/utilities/signed_url.rs with the Signer trait and builders, is not rendered in this view.)
