Skip to content

Commit

Permalink
Ensured the print buffer is flushed before program exit. Added an opt…
Browse files Browse the repository at this point in the history
…ion to print the plain URL in short format

Signed-off-by: David Pollak <feeder.of.the.bears@gmail.com>
  • Loading branch information
dpp committed Feb 28, 2024
1 parent c9253f3 commit 22af5dc
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 22 deletions.
1 change: 1 addition & 0 deletions omnibor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ build-binary = [
"tokio/macros",
"tokio/rt",
"tokio/sync",
"tokio/time",
"tokio/rt-multi-thread"
]

Expand Down
92 changes: 70 additions & 22 deletions omnibor/src/bin/omnibor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use omnibor::Sha256;
use serde_json::json;
use serde_json::Value as JsonValue;
use smart_default::SmartDefault;
use tokio::time::sleep;
use std::default::Default;
use std::fmt::Display;
use std::fmt::Formatter;
Expand All @@ -21,6 +22,9 @@ use std::path::Path;
use std::path::PathBuf;
use std::process::ExitCode;
use std::str::FromStr;
use std::sync::atomic::AtomicBool;
use std::sync::Arc;
use std::time::Duration;
use tokio::fs::File as AsyncFile;
use tokio::io::AsyncWrite;
use tokio::io::AsyncWriteExt as _;
Expand All @@ -32,31 +36,52 @@ use url::Url;
async fn main() -> ExitCode {
let args = Cli::parse();

let printing_done = Arc::new(AtomicBool::new(false));
let printing_done_2 = printing_done.clone();

// TODO(alilleybrinker): Make this channel Msg limit configurable.
let (tx, mut rx) = mpsc::channel::<Msg>(args.buffer.unwrap_or(100));
let (tx, mut rx) = mpsc::channel::<MsgOrEnd>(args.buffer.unwrap_or(100));

// Do all printing in a separate task we spawn to _just_ do printing.
// This stops printing from blocking the worker tasks.
tokio::spawn(async move {
while let Some(msg) = rx.recv().await {
// TODO(alilleybrinker): Handle this error.
let _ = msg.print().await;
match msg {
MsgOrEnd::End => break,
MsgOrEnd::Message(msg) => {
// TODO(alilleybrinker): Handle this error.
let _ = msg.print().await;
}
}
}
rx.close();
printing_done_2.store(true, std::sync::atomic::Ordering::Relaxed);
});

let result = match args.command {
Command::Id(ref args) => run_id(&tx, args).await,
Command::Find(ref args) => run_find(&tx, args).await,
};

let mut return_code = ExitCode::SUCCESS;

if let Err(e) = result {
// TODO(alilleybrinker): Handle this erroring out, probably by
// sync-printing as a last resort.
let _ = tx.send(Msg::error(e, args.format())).await;
return ExitCode::FAILURE;
let _ = tx.send(MsgOrEnd::error(e, args.format())).await;
return_code = ExitCode::FAILURE;
}


// send a message to end the printing
tx.send(MsgOrEnd::End).await.unwrap();

// wait until the printing is done
while !printing_done.load(std::sync::atomic::Ordering::Relaxed) {
sleep(Duration::from_millis(10)).await;
}

ExitCode::SUCCESS
return_code
}

/*===========================================================================
Expand Down Expand Up @@ -104,6 +129,10 @@ struct IdArgs {
/// Hash algorithm (can be "sha256")
#[arg(short = 'H', long = "hash", default_value_t)]
hash: SelectedHash,

/// Should the messages be short (just contain the gitoid)?
#[arg(short = 's', long = "short")]
short: bool,
}

#[derive(Debug, Args)]
Expand Down Expand Up @@ -172,20 +201,37 @@ impl FromStr for SelectedHash {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
enum MsgOrEnd {
End,
Message(Msg),
}

impl MsgOrEnd {
fn id(path: &Path, url: &Url, format: Format, short: bool) -> Self {
MsgOrEnd::Message(Msg::id(path, url, format, short))
}

fn error<E: Into<Error>>(error: E, format: Format) -> MsgOrEnd {
MsgOrEnd::Message(Msg::error(error, format))
}
}

#[derive(Debug, Clone)]
struct Msg {
content: Content,
status: Status,
}

impl Msg {
fn id(path: &Path, url: &Url, format: Format) -> Self {
fn id(path: &Path, url: &Url, format: Format, short: bool) -> Self {
let status = Status::Success;
let path = path.display().to_string();
let url = url.to_string();

match format {
Format::Plain => Msg::plain(status, &format!("{} => {}", path, url)),
Format::Plain if !short => Msg::plain(status, &format!("{} => {}", path, url)),
Format::Plain => Msg::plain(status, &format!("{}", url)),
Format::Json => Msg::json(status, json!({ "path": path, "id": url })),
}
}
Expand Down Expand Up @@ -235,7 +281,7 @@ impl Msg {
}
}

#[derive(Debug)]
#[derive(Debug, Clone)]
enum Content {
Json(JsonValue),
Plain(String),
Expand All @@ -250,7 +296,7 @@ impl Display for Content {
}
}

#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
enum Status {
Success,
Error,
Expand All @@ -261,18 +307,18 @@ enum Status {
*-------------------------------------------------------------------------*/

/// Run the `id` subcommand.
async fn run_id(tx: &Sender<Msg>, args: &IdArgs) -> Result<()> {
async fn run_id(tx: &Sender<MsgOrEnd>, args: &IdArgs) -> Result<()> {
let mut file = open_async_file(&args.path).await?;

if file_is_dir(&file).await? {
id_directory(tx, &args.path, args.format, args.hash).await
id_directory(tx, &args.path, args.format, args.hash, args.short).await
} else {
id_file(tx, &mut file, &args.path, args.format, args.hash).await
id_file(tx, &mut file, &args.path, args.format, args.hash, args.short).await
}
}

/// Run the `find` subcommand.
async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {
async fn run_find(tx: &Sender<MsgOrEnd>, args: &FindArgs) -> Result<()> {
let FindArgs { url, path, format } = args;

let id = ArtifactId::<Sha256>::id_url(url.clone())?;
Expand All @@ -283,7 +329,7 @@ async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {
loop {
match entries.next().await {
None => break,
Some(Err(e)) => tx.send(Msg::error(e, *format)).await?,
Some(Err(e)) => tx.send(MsgOrEnd::Message(Msg::error(e, *format))).await?,
Some(Ok(entry)) => {
let path = &entry.path();

Expand All @@ -295,7 +341,7 @@ async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {
let file_url = hash_file(SelectedHash::Sha256, &mut file, &path).await?;

if url == file_url {
tx.send(Msg::id(&path, &url, *format)).await?;
tx.send(MsgOrEnd::id(&path, &url, *format, false)).await?;
return Ok(());
}
}
Expand All @@ -311,17 +357,18 @@ async fn run_find(tx: &Sender<Msg>, args: &FindArgs) -> Result<()> {

// Identify, recursively, all the files under a directory.
async fn id_directory(
tx: &Sender<Msg>,
tx: &Sender<MsgOrEnd>,
path: &Path,
format: Format,
hash: SelectedHash,
short: bool
) -> Result<()> {
let mut entries = WalkDir::new(path);

loop {
match entries.next().await {
None => break,
Some(Err(e)) => tx.send(Msg::error(e, format)).await?,
Some(Err(e)) => tx.send(MsgOrEnd::error(e, format)).await?,
Some(Ok(entry)) => {
let path = &entry.path();

Expand All @@ -330,7 +377,7 @@ async fn id_directory(
}

let mut file = open_async_file(&path).await?;
id_file(tx, &mut file, &path, format, hash).await?;
id_file(tx, &mut file, &path, format, hash, short).await?;
}
}
}
Expand All @@ -340,14 +387,15 @@ async fn id_directory(

/// Identify a single file.
async fn id_file(
tx: &Sender<Msg>,
tx: &Sender<MsgOrEnd>,
file: &mut AsyncFile,
path: &Path,
format: Format,
hash: SelectedHash,
short: bool
) -> Result<()> {
let url = hash_file(hash, file, &path).await?;
tx.send(Msg::id(path, &url, format)).await?;
tx.send(MsgOrEnd::id(path, &url, format, short)).await?;
Ok(())
}

Expand Down

0 comments on commit 22af5dc

Please sign in to comment.