Skip to content

Commit

Permalink
Refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
njaremko committed Jun 23, 2024
1 parent 09c4dda commit a6ac4ca
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 73 deletions.
33 changes: 7 additions & 26 deletions src/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::ffi::CString;
use std::str::FromStr;
use thiserror::Error;

use crate::signature::SignatureAlgorithm;
#[cfg(feature = "xmlsec")]
use crate::xmlsec::{self, XmlSecKey, XmlSecKeyFormat, XmlSecSignatureContext};
#[cfg(feature = "xmlsec")]
Expand Down Expand Up @@ -486,24 +487,6 @@ pub fn gen_saml_assertion_id() -> String {
format!("_{}", uuid::Uuid::new_v4())
}

#[derive(Debug, PartialEq)]
enum SigAlg {
Unimplemented,
RsaSha256,
EcdsaSha256,
}

impl FromStr for SigAlg {
type Err = Box<dyn std::error::Error>;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" => Ok(SigAlg::RsaSha256),
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256" => Ok(SigAlg::EcdsaSha256),
_ => Ok(SigAlg::Unimplemented),
}
}
}

#[derive(Debug, Error, Clone)]
pub enum UrlVerifierError {
#[error("Unimplemented SigAlg: {:?}", sigalg)]
Expand Down Expand Up @@ -621,11 +604,9 @@ impl UrlVerifier {
.collect::<HashMap<String, String>>();

// Match against implemented SigAlg
let sig_alg: SigAlg = SigAlg::from_str(&query_params["SigAlg"])?;
if sig_alg == SigAlg::Unimplemented {
return Err(Box::new(UrlVerifierError::SigAlgUnimplemented {
sigalg: query_params["SigAlg"].clone(),
}));
let sig_alg = SignatureAlgorithm::from_str(&query_params["SigAlg"])?;
if let SignatureAlgorithm::Unsupported(sigalg) = sig_alg {
return Err(Box::new(UrlVerifierError::SigAlgUnimplemented { sigalg }));
}

// Construct a Url so that percent encoded query can be easily
Expand Down Expand Up @@ -668,13 +649,13 @@ impl UrlVerifier {
fn verify_signature(
&self,
data: &[u8],
sig_alg: SigAlg,
sig_alg: SignatureAlgorithm,
signature: &[u8],
) -> Result<bool, Box<dyn std::error::Error>> {
let mut verifier = openssl::sign::Verifier::new(
match sig_alg {
SigAlg::RsaSha256 => openssl::hash::MessageDigest::sha256(),
SigAlg::EcdsaSha256 => openssl::hash::MessageDigest::sha256(),
SignatureAlgorithm::RsaSha256 => openssl::hash::MessageDigest::sha256(),
SignatureAlgorithm::EcdsaSha256 => openssl::hash::MessageDigest::sha256(),
_ => panic!("sig_alg is bad!"),
},
&self.public_key,
Expand Down
52 changes: 40 additions & 12 deletions src/idp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ pub mod verified_request;
mod tests;

use openssl::bn::{BigNum, MsbOption};
use openssl::ec::{EcGroup, EcKey};
use openssl::nid::Nid;
use openssl::pkey::Private;
use openssl::{asn1::Asn1Time, pkey, rsa::Rsa, x509};
use openssl::{asn1::Asn1Time, pkey, x509};
use std::str::FromStr;

use crate::crypto::{self};
Expand All @@ -24,22 +25,31 @@ pub struct IdentityProvider {
private_key: pkey::PKey<Private>,
}

pub enum KeyType {
pub enum Rsa {
Rsa2048,
Rsa3072,
Rsa4096,
}

impl KeyType {
impl Rsa {
fn bit_length(&self) -> u32 {
match &self {
KeyType::Rsa2048 => 2048,
KeyType::Rsa3072 => 3072,
KeyType::Rsa4096 => 4096,
Rsa::Rsa2048 => 2048,
Rsa::Rsa3072 => 3072,
Rsa::Rsa4096 => 4096,
}
}
}

pub enum Eliptic {
NISTP256,
}

pub enum KeyType {
Rsa(Rsa),
Eliptic(Eliptic),
}

pub struct CertificateParams<'a> {
pub common_name: &'a str,
pub issuer_name: &'a str,
Expand All @@ -48,22 +58,40 @@ pub struct CertificateParams<'a> {

impl IdentityProvider {
pub fn generate_new(key_type: KeyType) -> Result<Self, Error> {
let rsa = Rsa::generate(key_type.bit_length())?;
let private_key = pkey::PKey::from_rsa(rsa)?;
let private_key = match key_type {
KeyType::Rsa(rsa) => {
let bit_length = rsa.bit_length();
let rsa = openssl::rsa::Rsa::generate(bit_length)?;
pkey::PKey::from_rsa(rsa)?
}
KeyType::Eliptic(ecc) => {
let nid = match ecc {
Eliptic::NISTP256 => Nid::X9_62_PRIME256V1,
};
let group = EcGroup::from_curve_name(nid)?;
let private_key: EcKey<Private> = EcKey::generate(&group)?;
pkey::PKey::from_ec_key(private_key)?
}
};

Ok(IdentityProvider { private_key })
}

pub fn from_private_key_der(der_bytes: &[u8]) -> Result<Self, Error> {
let rsa = Rsa::private_key_from_der(der_bytes)?;
pub fn from_rsa_private_key_der(der_bytes: &[u8]) -> Result<Self, Error> {
let rsa = openssl::rsa::Rsa::private_key_from_der(der_bytes)?;
let private_key = pkey::PKey::from_rsa(rsa)?;

Ok(IdentityProvider { private_key })
}

pub fn export_private_key_der(&self) -> Result<Vec<u8>, Error> {
let rsa: Rsa<Private> = self.private_key.rsa()?;
Ok(rsa.private_key_to_der()?)
if let Ok(ec_key) = self.private_key.ec_key() {
Ok(ec_key.private_key_to_der()?)
} else if let Ok(rsa) = self.private_key.rsa() {
Ok(rsa.private_key_to_der()?)
} else {
Err(Error::UnexpectedError)?
}
}

pub fn create_certificate(&self, params: &CertificateParams) -> Result<Vec<u8>, Error> {
Expand Down
4 changes: 2 additions & 2 deletions src/idp/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn test_extract_sp() {
#[test]
fn test_signed_response() {
// init our IdP
let idp = IdentityProvider::from_private_key_der(include_bytes!(
let idp = IdentityProvider::from_rsa_private_key_der(include_bytes!(
"../../test_vectors/idp_private_key.der"
))
.expect("failed to create idp");
Expand Down Expand Up @@ -135,7 +135,7 @@ fn test_signed_response_threads() {

#[test]
fn test_signed_response_fingerprint() {
let idp = IdentityProvider::from_private_key_der(include_bytes!(
let idp = IdentityProvider::from_rsa_private_key_der(include_bytes!(
"../../test_vectors/idp_private_key.der"
))
.expect("failed to create idp");
Expand Down
89 changes: 56 additions & 33 deletions src/signature.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
use quick_xml::Writer;
use serde::Deserialize;
use std::io::Cursor;
use std::str::FromStr;

const NAME: &str = "ds:Signature";
const SCHEMA: (&str, &str) = ("xmlns:ds", "http://www.w3.org/2000/09/xmldsig#");
Expand Down Expand Up @@ -33,32 +34,30 @@ impl Signature {
algorithm: SignatureAlgorithm::RsaSha256,
hmac_output_length: None,
},
reference: vec![
Reference {
transforms: Some(Transforms {
transforms: vec![
Transform {
algorithm: "http://www.w3.org/2000/09/xmldsig#enveloped-signature"
.to_string(),
xpath: None,
},
Transform {
algorithm: "http://www.w3.org/2001/10/xml-exc-c14n#".to_string(),
xpath: None,
},
],
}),
digest_method: DigestMethod {
algorithm: DigestAlgorithm::Sha1,
},
digest_value: Some(DigestValue {
base64_content: Some("".to_string()),
}),
uri: Some(format!("#{}", ref_id)),
reference_type: None,
id: None,
}
],
reference: vec![Reference {
transforms: Some(Transforms {
transforms: vec![
Transform {
algorithm: "http://www.w3.org/2000/09/xmldsig#enveloped-signature"
.to_string(),
xpath: None,
},
Transform {
algorithm: "http://www.w3.org/2001/10/xml-exc-c14n#".to_string(),
xpath: None,
},
],
}),
digest_method: DigestMethod {
algorithm: DigestAlgorithm::Sha1,
},
digest_value: Some(DigestValue {
base64_content: Some("".to_string()),
}),
uri: Some(format!("#{}", ref_id)),
reference_type: None,
id: None,
}],
},
signature_value: SignatureValue {
id: None,
Expand Down Expand Up @@ -294,22 +293,43 @@ impl TryFrom<&SignatureMethod> for Event<'_> {

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum SignatureAlgorithm {
#[serde(rename="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256")]
#[serde(rename = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256")]
RsaSha256,
#[serde(rename="http://www.w3.org/2007/05/xmldsig-more#sha256-rsa-MGF1")]
#[serde(rename = "http://www.w3.org/2007/05/xmldsig-more#sha256-rsa-MGF1")]
Sha256RsaMGF1,
#[serde(rename = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256")]
EcdsaSha256,
#[serde(untagged)]
Unsupported(String),
}

impl FromStr for SignatureAlgorithm {
type Err = Box<dyn std::error::Error>;

fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(match s {
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" => SignatureAlgorithm::RsaSha256,
"http://www.w3.org/2007/05/xmldsig-more#sha256-rsa-MGF1" => {
SignatureAlgorithm::Sha256RsaMGF1
}
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256" => {
SignatureAlgorithm::EcdsaSha256
}
i => SignatureAlgorithm::Unsupported(i.to_string()),
})
}
}

impl SignatureAlgorithm {
const RSA_SHA256: &'static str = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256";
const SHA256_RSA_MGF1: &'static str = "http://www.w3.org/2007/05/xmldsig-more#sha256-rsa-MGF1";
const SHA256_ECDSA: &'static str = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256";

pub fn value(&self) -> &str {
match self {
SignatureAlgorithm::RsaSha256 => Self::RSA_SHA256,
SignatureAlgorithm::Sha256RsaMGF1 => Self::SHA256_RSA_MGF1,
SignatureAlgorithm::EcdsaSha256 => Self::SHA256_ECDSA,
SignatureAlgorithm::Unsupported(algo) => algo,
}
}
Expand Down Expand Up @@ -430,9 +450,9 @@ impl TryFrom<&DigestMethod> for Event<'_> {

#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
pub enum DigestAlgorithm {
#[serde(rename="http://www.w3.org/2000/09/xmldsig#sha1")]
#[serde(rename = "http://www.w3.org/2000/09/xmldsig#sha1")]
Sha1,
#[serde(rename="http://www.w3.org/2001/04/xmlenc#sha256")]
#[serde(rename = "http://www.w3.org/2001/04/xmlenc#sha256")]
Sha256,
#[serde(untagged)]
Unsupported(String),
Expand Down Expand Up @@ -588,8 +608,10 @@ mod test {

#[test]
pub fn test_canonicalizationmethod_deserialization() -> Result<(), Box<dyn std::error::Error>> {
let canonicalization_method = r#"<ds:CanonicalizationMethod Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"/>"#;
let deserialized: CanonicalizationMethod = quick_xml::de::from_str(canonicalization_method)?;
let canonicalization_method =
r#"<ds:CanonicalizationMethod Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"/>"#;
let deserialized: CanonicalizationMethod =
quick_xml::de::from_str(canonicalization_method)?;
let serialized = deserialized.to_xml()?;
let re_deserialized: CanonicalizationMethod = quick_xml::de::from_str(&serialized)?;
assert_eq!(deserialized, re_deserialized);
Expand Down Expand Up @@ -627,7 +649,8 @@ mod test {

#[test]
pub fn test_digestmethod_deserialization() -> Result<(), Box<dyn std::error::Error>> {
let digest_method = r#"<ds:DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1" />"#;
let digest_method =
r#"<ds:DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1" />"#;
let deserialized: DigestMethod = quick_xml::de::from_str(digest_method)?;
let serialized = deserialized.to_xml()?;
let re_deserialized: DigestMethod = quick_xml::de::from_str(&serialized)?;
Expand Down

0 comments on commit a6ac4ca

Please sign in to comment.