Skip to content

Commit

Permalink
feat: add --context-type CLI option
Browse files Browse the repository at this point in the history
Signed-off-by: Xin Liu <sam@secondstate.io>
  • Loading branch information
apepkuss committed Sep 27, 2024
1 parent 9a66231 commit 46a83cb
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 11 deletions.
9 changes: 3 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@ edition = "2021"
[dependencies]
anyhow = "1"
clap = { version = "4.4.6", features = ["cargo", "derive"] }
endpoints = { version = "=0.14.2" }
endpoints = { version = "=0.14.2", git = "https://github.com/LlamaEdge/LlamaEdge.git", branch = "dev" }
hyper = { version = "0.14", features = ["full"] }
llama-core = { version = "=0.18.1", features = ["logging"] }
llama-core = { version = "=0.18.1", features = [
"logging",
], git = "https://github.com/LlamaEdge/LlamaEdge.git", branch = "dev" }
log = { version = "0.4.21", features = ["std", "kv", "kv_serde"] }
multipart-2021 = "0.19.0"
serde = { version = "1.0", features = ["derive"] }
Expand Down
39 changes: 36 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ mod error;
mod utils;

use anyhow::Result;
use clap::{ArgGroup, Parser};
use clap::{ArgGroup, Parser, ValueEnum};
use error::ServerError;
use hyper::{
body::HttpBody,
Expand Down Expand Up @@ -52,6 +52,9 @@ struct Cli {
/// Number of threads to use during computation. Default is -1, which means to use all available threads.
#[arg(long, default_value = "-1")]
threads: i32,
/// Context to create for the model.
#[arg(long, default_value = "full")]
context_type: ContextType,
/// Socket address of LlamaEdge API Server instance. For example, `0.0.0.0:8080`.
#[arg(long, default_value = None, value_parser = clap::value_parser!(SocketAddr), group = "socket_address_group")]
socket_addr: Option<SocketAddr>,
Expand Down Expand Up @@ -87,13 +90,19 @@ async fn main() -> Result<(), ServerError> {
// log model name
info!(target: "stdout", "model_name: {}", cli.model_name);

// log context type
info!(target: "stdout", "context_type: {:?}", cli.context_type);

// Determine which model option is set
if !cli.model.is_empty() {
info!(target: "stdout", "model: {}", &cli.model);

// initialize the stable diffusion context
llama_core::init_sd_context_with_full_model(&cli.model)
.map_err(|e| ServerError::Operation(format!("{}", e)))?;
llama_core::init_sd_context_with_full_model(
&cli.model,
cli.context_type.to_sd_context_type(),
)
.map_err(|e| ServerError::Operation(format!("{}", e)))?;
} else if !cli.diffusion_model.is_empty() {
// if diffusion_model is not empty, check if diffusion_model is a valid path
if !PathBuf::from(&cli.diffusion_model).exists() {
Expand Down Expand Up @@ -144,6 +153,7 @@ async fn main() -> Result<(), ServerError> {
&cli.t5xxl,
&cli.lora_model_dir,
cli.threads,
cli.context_type.to_sd_context_type(),
)
.map_err(|e| ServerError::Operation(format!("{}", e)))?;
} else {
Expand Down Expand Up @@ -240,3 +250,26 @@ async fn handle_request(req: Request<Body>) -> Result<Response<Body>, hyper::Err

Ok(response)
}

/// The context to use for the model.
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum ContextType {
/// `text_to_image` context.
#[value(name = "text-to-image")]
TextToImage,
/// `image_to_image` context.
#[value(name = "image-to-image")]
ImageToText,
/// Both `text_to_image` and `image_to_image` contexts.
#[value(name = "full")]
Full,
}
impl ContextType {
fn to_sd_context_type(&self) -> llama_core::SDContextType {

Check failure on line 268 in src/main.rs

View workflow job for this annotation

GitHub Actions / build-wasm (ubuntu-22.04)

methods with the following characteristics: (`to_*` and `self` type is `Copy`) usually take `self` by value
match self {
ContextType::TextToImage => llama_core::SDContextType::TextToImage,
ContextType::ImageToText => llama_core::SDContextType::ImageToImage,
ContextType::Full => llama_core::SDContextType::Full,
}
}
}

0 comments on commit 46a83cb

Please sign in to comment.