diff --git a/crates/protocol/src/errors.rs b/crates/protocol/src/errors.rs index b76ec1d59..5e70cc4c0 100644 --- a/crates/protocol/src/errors.rs +++ b/crates/protocol/src/errors.rs @@ -31,6 +31,8 @@ pub enum GenericProtocolError { Broadcast(#[from] Box>), #[error("Mpsc send error: {0}")] Mpsc(#[from] tokio::sync::mpsc::error::SendError), + #[error("Could not get session out of Arc")] + ArcUnwrapError, } impl From for GenericProtocolError { @@ -61,6 +63,7 @@ impl From>> GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err), GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err), GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err), + GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError, } } } @@ -73,6 +76,7 @@ impl From>> for ProtocolE GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err), GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err), GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err), + GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError, } } } @@ -85,6 +89,7 @@ impl From>> for Prot GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err), GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err), GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err), + GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError, } } } @@ -97,6 +102,7 @@ impl From>> for ProtocolEx GenericProtocolError::IncomingStream(err) => ProtocolExecutionErr::IncomingStream(err), GenericProtocolError::Broadcast(err) => ProtocolExecutionErr::Broadcast(err), GenericProtocolError::Mpsc(err) => ProtocolExecutionErr::Mpsc(err), + GenericProtocolError::ArcUnwrapError => ProtocolExecutionErr::ArcUnwrapError, } } } @@ -136,6 +142,8 @@ pub enum ProtocolExecutionErr { BadVerifyingKey(String), #[error("Expected verifying key but got a protocol message")] UnexpectedMessage, + #[error("Could not get session out of Arc")] + ArcUnwrapError, } #[derive(Debug, Error)] diff --git a/crates/protocol/src/execute_protocol.rs b/crates/protocol/src/execute_protocol.rs index b5db81501..5c25ab652 100644 --- a/crates/protocol/src/execute_protocol.rs +++ b/crates/protocol/src/execute_protocol.rs @@ -139,7 +139,9 @@ where let tx_clone = process_tx.clone(); tokio::spawn(async move { let result = session_clone.process_message(&mut OsRng, preprocessed); - tx_clone.send(result).await.unwrap(); + if tx_clone.send(result).await.is_err() { + tracing::error!("Protocol finished before message processing result sent"); + } }); } } else { @@ -166,22 +168,18 @@ where // tx.incoming_sender.send(message).await?; // } - // Get session back out of Arc and Mutex - if let Ok(session_inner) = Arc::try_unwrap(session_arc) { - // let session_inner = session_inner.into_inner().unwrap(); - match session_inner.finalize_round(&mut OsRng, accum)? { - // match session_arc.finalize_round(&mut OsRng, accum)? { - FinalizeOutcome::Success(res) => break Ok((res, chans)), - FinalizeOutcome::AnotherRound { - session: new_session, - cached_messages: new_cached_messages, - } => { - session = new_session; - cached_messages = new_cached_messages; - }, - } - } else { - panic!("Cannot get session out of Arc"); + // Get session back out of Arc + let session_inner = + Arc::try_unwrap(session_arc).map_err(|_| GenericProtocolError::ArcUnwrapError)?; + match session_inner.finalize_round(&mut OsRng, accum)? { + FinalizeOutcome::Success(res) => break Ok((res, chans)), + FinalizeOutcome::AnotherRound { + session: new_session, + cached_messages: new_cached_messages, + } => { + session = new_session; + cached_messages = new_cached_messages; + }, } } }