Skip to content

Commit

Permalink
Refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
njaremko committed Jun 23, 2024
1 parent 8ccbcd2 commit 91c9461
Show file tree
Hide file tree
Showing 6 changed files with 194 additions and 95 deletions.
34 changes: 8 additions & 26 deletions src/crypto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use std::ffi::CString;
use std::str::FromStr;
use thiserror::Error;

use crate::signature::SignatureAlgorithm;
#[cfg(feature = "xmlsec")]
use crate::xmlsec::{self, XmlSecKey, XmlSecKeyFormat, XmlSecSignatureContext};
#[cfg(feature = "xmlsec")]
Expand Down Expand Up @@ -223,6 +224,7 @@ fn get_elements_by_predicate<F: FnMut(&libxml::tree::Node) -> bool>(
/// Searches for and returns the element with the given value of the `ID` attribute from the subtree
/// rooted at the given node.
#[cfg(feature = "xmlsec")]
#[allow(unused)]
fn get_element_by_id(elem: &libxml::tree::Node, id: &str) -> Option<libxml::tree::Node> {
let mut elems = get_elements_by_predicate(elem, |node| {
node.get_attribute("ID")
Expand Down Expand Up @@ -486,24 +488,6 @@ pub fn gen_saml_assertion_id() -> String {
format!("_{}", uuid::Uuid::new_v4())
}

#[derive(Debug, PartialEq)]
enum SigAlg {
Unimplemented,
RsaSha256,
EcdsaSha256,
}

impl FromStr for SigAlg {
type Err = Box<dyn std::error::Error>;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" => Ok(SigAlg::RsaSha256),
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256" => Ok(SigAlg::EcdsaSha256),
_ => Ok(SigAlg::Unimplemented),
}
}
}

#[derive(Debug, Error, Clone)]
pub enum UrlVerifierError {
#[error("Unimplemented SigAlg: {:?}", sigalg)]
Expand Down Expand Up @@ -621,11 +605,9 @@ impl UrlVerifier {
.collect::<HashMap<String, String>>();

// Match against implemented SigAlg
let sig_alg: SigAlg = SigAlg::from_str(&query_params["SigAlg"])?;
if sig_alg == SigAlg::Unimplemented {
return Err(Box::new(UrlVerifierError::SigAlgUnimplemented {
sigalg: query_params["SigAlg"].clone(),
}));
let sig_alg = SignatureAlgorithm::from_str(&query_params["SigAlg"])?;
if let SignatureAlgorithm::Unsupported(sigalg) = sig_alg {
return Err(Box::new(UrlVerifierError::SigAlgUnimplemented { sigalg }));
}

// Construct a Url so that percent encoded query can be easily
Expand Down Expand Up @@ -668,13 +650,13 @@ impl UrlVerifier {
fn verify_signature(
&self,
data: &[u8],
sig_alg: SigAlg,
sig_alg: SignatureAlgorithm,
signature: &[u8],
) -> Result<bool, Box<dyn std::error::Error>> {
let mut verifier = openssl::sign::Verifier::new(
match sig_alg {
SigAlg::RsaSha256 => openssl::hash::MessageDigest::sha256(),
SigAlg::EcdsaSha256 => openssl::hash::MessageDigest::sha256(),
SignatureAlgorithm::RsaSha256 => openssl::hash::MessageDigest::sha256(),
SignatureAlgorithm::EcdsaSha256 => openssl::hash::MessageDigest::sha256(),
_ => panic!("sig_alg is bad!"),
},
&self.public_key,
Expand Down
52 changes: 40 additions & 12 deletions src/idp/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ pub mod verified_request;
mod tests;

use openssl::bn::{BigNum, MsbOption};
use openssl::ec::{EcGroup, EcKey};
use openssl::nid::Nid;
use openssl::pkey::Private;
use openssl::{asn1::Asn1Time, pkey, rsa::Rsa, x509};
use openssl::{asn1::Asn1Time, pkey, x509};
use std::str::FromStr;

use crate::crypto::{self};
Expand All @@ -24,22 +25,31 @@ pub struct IdentityProvider {
private_key: pkey::PKey<Private>,
}

pub enum KeyType {
pub enum Rsa {
Rsa2048,
Rsa3072,
Rsa4096,
}

impl KeyType {
impl Rsa {
fn bit_length(&self) -> u32 {
match &self {
KeyType::Rsa2048 => 2048,
KeyType::Rsa3072 => 3072,
KeyType::Rsa4096 => 4096,
Rsa::Rsa2048 => 2048,
Rsa::Rsa3072 => 3072,
Rsa::Rsa4096 => 4096,
}
}
}

pub enum Elliptic {
NISTP256,
}

pub enum KeyType {
Rsa(Rsa),
Elliptic(Elliptic),
}

pub struct CertificateParams<'a> {
pub common_name: &'a str,
pub issuer_name: &'a str,
Expand All @@ -48,22 +58,40 @@ pub struct CertificateParams<'a> {

impl IdentityProvider {
pub fn generate_new(key_type: KeyType) -> Result<Self, Error> {
let rsa = Rsa::generate(key_type.bit_length())?;
let private_key = pkey::PKey::from_rsa(rsa)?;
let private_key = match key_type {
KeyType::Rsa(rsa) => {
let bit_length = rsa.bit_length();
let rsa = openssl::rsa::Rsa::generate(bit_length)?;
pkey::PKey::from_rsa(rsa)?
}
KeyType::Elliptic(ecc) => {
let nid = match ecc {
Elliptic::NISTP256 => Nid::X9_62_PRIME256V1,
};
let group = EcGroup::from_curve_name(nid)?;
let private_key: EcKey<Private> = EcKey::generate(&group)?;
pkey::PKey::from_ec_key(private_key)?
}
};

Ok(IdentityProvider { private_key })
}

pub fn from_private_key_der(der_bytes: &[u8]) -> Result<Self, Error> {
let rsa = Rsa::private_key_from_der(der_bytes)?;
pub fn from_rsa_private_key_der(der_bytes: &[u8]) -> Result<Self, Error> {
let rsa = openssl::rsa::Rsa::private_key_from_der(der_bytes)?;
let private_key = pkey::PKey::from_rsa(rsa)?;

Ok(IdentityProvider { private_key })
}

pub fn export_private_key_der(&self) -> Result<Vec<u8>, Error> {
let rsa: Rsa<Private> = self.private_key.rsa()?;
Ok(rsa.private_key_to_der()?)
if let Ok(ec_key) = self.private_key.ec_key() {
Ok(ec_key.private_key_to_der()?)
} else if let Ok(rsa) = self.private_key.rsa() {
Ok(rsa.private_key_to_der()?)
} else {
Err(Error::UnexpectedError)?
}
}

pub fn create_certificate(&self, params: &CertificateParams) -> Result<Vec<u8>, Error> {
Expand Down
4 changes: 2 additions & 2 deletions src/idp/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ fn test_extract_sp() {
#[test]
fn test_signed_response() {
// init our IdP
let idp = IdentityProvider::from_private_key_der(include_bytes!(
let idp = IdentityProvider::from_rsa_private_key_der(include_bytes!(
"../../test_vectors/idp_private_key.der"
))
.expect("failed to create idp");
Expand Down Expand Up @@ -135,7 +135,7 @@ fn test_signed_response_threads() {

#[test]
fn test_signed_response_fingerprint() {
let idp = IdentityProvider::from_private_key_der(include_bytes!(
let idp = IdentityProvider::from_rsa_private_key_der(include_bytes!(
"../../test_vectors/idp_private_key.der"
))
.expect("failed to create idp");
Expand Down
63 changes: 45 additions & 18 deletions src/metadata/entity_descriptor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use chrono::prelude::*;
use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event};
use quick_xml::Writer;
use serde::Deserialize;
use std::collections::VecDeque;
use std::io::Cursor;
use std::str::FromStr;
use thiserror::Error;
Expand All @@ -29,18 +30,8 @@ pub enum EntityDescriptorType {
}

impl EntityDescriptorType {
pub fn take_first(self) -> Option<EntityDescriptor> {
match self {
EntityDescriptorType::EntitiesDescriptor(descriptor) => descriptor
.descriptors
.into_iter()
.next()
.and_then(|descriptor_type| match descriptor_type {
EntityDescriptorType::EntitiesDescriptor(_) => None,
EntityDescriptorType::EntityDescriptor(descriptor) => Some(descriptor),
}),
EntityDescriptorType::EntityDescriptor(descriptor) => Some(descriptor),
}
pub fn iter(&self) -> EntityDescriptorIterator {
EntityDescriptorIterator::new(self)
}
}

Expand Down Expand Up @@ -284,6 +275,39 @@ impl TryFrom<&EntityDescriptor> for Event<'_> {
}
}

#[derive(Clone)]
pub struct EntityDescriptorIterator<'a> {
queue: VecDeque<&'a EntityDescriptorType>,
}

impl<'a> EntityDescriptorIterator<'a> {
pub fn new(root: &'a EntityDescriptorType) -> Self {
let mut queue = VecDeque::new();
queue.push_back(root);
EntityDescriptorIterator { queue }
}
}

impl<'a> Iterator for EntityDescriptorIterator<'a> {
type Item = &'a EntityDescriptor;

fn next(&mut self) -> Option<Self::Item> {
while let Some(current) = self.queue.pop_front() {
match current {
EntityDescriptorType::EntitiesDescriptor(entities_descriptor) => {
for descriptor in &entities_descriptor.descriptors {
self.queue.push_back(descriptor);
}
}
EntityDescriptorType::EntityDescriptor(entity_descriptor) => {
return Some(entity_descriptor);
}
}
}
None
}
}

#[cfg(test)]
mod test {
use crate::traits::ToXml;
Expand Down Expand Up @@ -345,6 +369,7 @@ mod test {
.parse()
.expect("Failed to parse EntitiesDescriptor");

assert_eq!(2, reparsed_entities_descriptor.descriptors.len());
assert_eq!(reparsed_entities_descriptor, entities_descriptor);
}

Expand All @@ -369,11 +394,12 @@ mod test {
let expected_entity_descriptor: EntityDescriptor = input_xml
.parse()
.expect("Failed to parse idp_metadata.xml into an EntityDescriptor");
let entity_descriptor: EntityDescriptor = entity_descriptor_type
.take_first()
let entity_descriptor = entity_descriptor_type
.iter()
.next()
.expect("Failed to take first EntityDescriptor from EntityDescriptorType");

assert_eq!(expected_entity_descriptor, entity_descriptor);
assert_eq!(&expected_entity_descriptor, entity_descriptor);
}

#[test]
Expand Down Expand Up @@ -401,11 +427,12 @@ mod test {
let expected_entity_descriptor: EntityDescriptor = input_xml
.parse()
.expect("Failed to parse idp_metadata.xml into an EntityDescriptor");
let entity_descriptor: EntityDescriptor = entity_descriptor_type
.take_first()
let entity_descriptor = entity_descriptor_type
.iter()
.next()
.expect("Failed to take first EntityDescriptor from EntityDescriptorType");
println!("{entity_descriptor:#?}");

assert_eq!(expected_entity_descriptor, entity_descriptor);
assert_eq!(&expected_entity_descriptor, entity_descriptor);
}
}
Loading

0 comments on commit 91c9461

Please sign in to comment.