
Stable diffusion 3.5 support. #2578

Merged: 5 commits, Oct 27, 2024. Changes from all commits.
50 changes: 48 additions & 2 deletions candle-examples/examples/stable-diffusion-3/clip.rs
@@ -1,6 +1,7 @@
use anyhow::{Error as E, Ok, Result};
use candle::{DType, IndexOp, Module, Tensor, D};
use candle_transformers::models::{stable_diffusion, t5};
use std::path::PathBuf;
use tokenizers::tokenizer::Tokenizer;

struct ClipWithTokenizer {
@@ -130,6 +131,53 @@ pub struct StableDiffusion3TripleClipWithTokenizer {
}

impl StableDiffusion3TripleClipWithTokenizer {
pub fn new_split(
clip_g_file: &PathBuf,
clip_l_file: &PathBuf,
t5xxl_file: &PathBuf,
device: &candle::Device,
) -> Result<Self> {
let vb_clip_g = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_g_file], DType::F16, device)?
};
let vb_clip_l = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[clip_l_file], DType::F16, device)?
};
let vb_t5 = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[t5xxl_file], DType::F32, device)?
};
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
vb_clip_l,
stable_diffusion::clip::Config::sdxl(),
"openai/clip-vit-large-patch14",
max_position_embeddings,
)?;

let text_projection =
candle_nn::linear_no_bias(1280, 1280, vb_clip_g.pp("text_projection"))?;

let clip_g = ClipWithTokenizer::new(
vb_clip_g,
stable_diffusion::clip::Config::sdxl2(),
"laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
max_position_embeddings,
)?;

// The current T5 implementation does not support fp16, so we use an fp32 VarBuilder for T5.
// This is a temporary workaround until the T5 implementation is updated to support fp16.
// Also see:
// https://github.com/huggingface/candle/issues/2480
// https://github.com/huggingface/candle/pull/2481
let t5 = T5WithTokenizer::new(vb_t5, max_position_embeddings)?;
Ok(Self {
clip_l,
clip_g,
clip_g_text_projection: text_projection,
t5,
})
}

pub fn new(vb_fp16: candle_nn::VarBuilder, vb_fp32: candle_nn::VarBuilder) -> Result<Self> {
let max_position_embeddings = 77usize;
let clip_l = ClipWithTokenizer::new(
@@ -158,7 +206,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
// https://github.com/huggingface/candle/issues/2480
// https://github.com/huggingface/candle/pull/2481
let t5 = T5WithTokenizer::new(vb_fp32.pp("t5xxl.transformer"), max_position_embeddings)?;

Ok(Self {
clip_l,
clip_g,
@@ -195,7 +242,6 @@ impl StableDiffusion3TripleClipWithTokenizer {
.encode_text_to_embedding(prompt, device)?
.to_dtype(DType::F16)?;
let context = Tensor::cat(&[&clip_embeddings_concat, &t5_embeddings], D::Minus2)?;

Ok((context, y))
}
}
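
The new_split constructor mirrors how the 3.5 repositories ship their weights: clip_g, clip_l, and t5xxl come as separate safetensors files under text_encoders/, while the 3-medium example loads a single combined checkpoint. A minimal usage sketch, with hypothetical local paths standing in for the hf_hub downloads performed in main.rs below:

```rust
use std::path::PathBuf;

fn encode_prompt() -> anyhow::Result<()> {
    // Hypothetical local paths; main.rs below fetches these from the hub instead.
    let clip_g = PathBuf::from("text_encoders/clip_g.safetensors");
    let clip_l = PathBuf::from("text_encoders/clip_l.safetensors");
    let t5xxl = PathBuf::from("text_encoders/t5xxl_fp16.safetensors");
    let device = candle::Device::Cpu;
    let mut triple =
        StableDiffusion3TripleClipWithTokenizer::new_split(&clip_g, &clip_l, &t5xxl, &device)?;
    // context feeds the MMDiT's cross-attention; y is the pooled conditioning vector.
    let (_context, _y) = triple.encode_text_to_embedding("a photo of a rusty robot", &device)?;
    Ok(())
}
```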
198 changes: 117 additions & 81 deletions candle-examples/examples/stable-diffusion-3/main.rs
@@ -11,6 +11,25 @@ use crate::vae::{build_sd3_vae_autoencoder, sd3_vae_vb_rename};
use anyhow::{Ok, Result};
use clap::Parser;

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum Which {
#[value(name = "3-medium")]
V3Medium,
#[value(name = "3.5-large")]
V3_5Large,
#[value(name = "3.5-large-turbo")]
V3_5LargeTurbo,
}

impl Which {
fn is_3_5(&self) -> bool {
match self {
Self::V3Medium => false,
Self::V3_5Large | Self::V3_5LargeTurbo => true,
}
}
}

#[derive(Parser)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -30,10 +49,6 @@ struct Args {
#[arg(long)]
cpu: bool,

/// The GPU device ID to use.
#[arg(long, default_value_t = 0)]
gpu_device_id: usize,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
@@ -50,13 +65,17 @@ struct Args {
#[arg(long, default_value_t = 1024)]
width: usize,

/// The model to use.
#[arg(long, default_value = "3-medium")]
which: Which,

/// The number of sampling steps.
#[arg(long, default_value_t = 28)]
num_inference_steps: usize,
#[arg(long)]
num_inference_steps: Option<usize>,

// CFG scale.
#[arg(long, default_value_t = 4.0)]
cfg_scale: f64,
#[arg(long)]
cfg_scale: Option<f64>,

// Time shift factor (alpha).
#[arg(long, default_value_t = 3.0)]
time_shift: f64,
Contributor:

I noticed SAI's reference implementation of sd3.5 seems to change the recommended time_shift for sd3.0-medium from 3.0 to 1.0, but that is for the dpmpp_2m sampler (code here). This wasn't the case in their original reference impl for sd3.0-medium, which always uses 3.0, but with euler sampling (code here).

Not sure if it's worth changing the default behavior here; the difference should be small, though.

Collaborator (Author):

Right, as both the scheduler and the parameter changed, I preferred keeping everything as it was before rather than modifying the existing behavior.
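
For reference, a minimal sketch of the time shift under discussion, assuming the alpha * t / (1 + (alpha - 1) * t) form from SAI's reference implementation (function names are illustrative, not this PR's code):

```rust
/// Shift a uniform time value t in [0, 1]; alpha = 1.0 is the identity,
/// larger alpha keeps the schedule at high noise levels for longer.
fn time_snr_shift(alpha: f64, t: f64) -> f64 {
    alpha * t / (1.0 + (alpha - 1.0) * t)
}

/// Build a descending, shifted sigma schedule for euler sampling.
fn shifted_sigmas(num_steps: usize, alpha: f64) -> Vec<f64> {
    (0..=num_steps)
        .rev()
        .map(|i| time_snr_shift(alpha, i as f64 / num_steps as f64))
        .collect()
}
```

With alpha = 3.0 this reproduces the existing 3-medium behavior kept by this PR; alpha = 1.0 would match the dpmpp_2m recommendation mentioned above.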

@@ -68,20 +87,13 @@
}

fn main() -> Result<()> {
let args = Args::parse();
// Your main code here
run(args)
}

fn run(args: Args) -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

let Args {
prompt,
uncond_prompt,
cpu,
gpu_device_id,
tracing,
use_flash_attn,
height,
Expand All @@ -90,7 +102,8 @@ fn run(args: Args) -> Result<()> {
cfg_scale,
time_shift,
seed,
} = args;
which,
} = Args::parse();

let _guard = if tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
@@ -100,87 +113,110 @@
None
};

let device = if cpu {
candle::Device::Cpu
} else if candle::utils::cuda_is_available() {
candle::Device::new_cuda(gpu_device_id)?
} else if candle::utils::metal_is_available() {
candle::Device::new_metal(gpu_device_id)?
} else {
candle::Device::Cpu
let device = candle_examples::device(cpu)?;
let default_inference_steps = match which {
Which::V3_5Large => 28,
Which::V3_5LargeTurbo => 4,
Which::V3Medium => 28,
};
let num_inference_steps = num_inference_steps.unwrap_or(default_inference_steps);
let default_cfg_scale = match which {
Which::V3_5Large => 4.0,
Which::V3_5LargeTurbo => 1.0,
Which::V3Medium => 4.0,
};
let cfg_scale = cfg_scale.unwrap_or(default_cfg_scale);

let api = hf_hub::api::sync::Api::new()?;
let sai_repo = {
let name = "stabilityai/stable-diffusion-3-medium";
api.repo(hf_hub::Repo::model(name.to_string()))
};
let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
let vb_fp16 = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file.clone()], DType::F16, &device)?
};
let (mmdit_config, mut triple, vb) = if which.is_3_5() {
let sai_repo = {
let name = match which {
Which::V3_5Large => "stabilityai/stable-diffusion-3.5-large",
Which::V3_5LargeTurbo => "stabilityai/stable-diffusion-3.5-large-turbo",
Which::V3Medium => unreachable!(),
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let clip_g_file = sai_repo.get("text_encoders/clip_g.safetensors")?;
let clip_l_file = sai_repo.get("text_encoders/clip_l.safetensors")?;
let t5xxl_file = sai_repo.get("text_encoders/t5xxl_fp16.safetensors")?;
let model_file = {
let model_file = match which {
Which::V3_5Large => "sd3.5_large.safetensors",
Which::V3_5LargeTurbo => "sd3.5_large_turbo.safetensors",
Which::V3Medium => unreachable!(),
};
sai_repo.get(model_file)?
};
let triple = StableDiffusion3TripleClipWithTokenizer::new_split(
&clip_g_file,
&clip_l_file,
&t5xxl_file,
&device,
)?;
let vb = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F16, &device)?
};
(MMDiTConfig::sd3_5_large(), triple, vb)
} else {
let sai_repo = {
let name = "stabilityai/stable-diffusion-3-medium";
api.repo(hf_hub::Repo::model(name.to_string()))
};
let model_file = sai_repo.get("sd3_medium_incl_clips_t5xxlfp16.safetensors")?;
let vb_fp16 = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(&[&model_file], DType::F16, &device)?
};

let (context, y) = {
let vb_fp32 = unsafe {
candle_nn::VarBuilder::from_mmaped_safetensors(
&[model_file.clone()],
DType::F32,
&device,
)?
candle_nn::VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)?
};
let mut triple = StableDiffusion3TripleClipWithTokenizer::new(
let triple = StableDiffusion3TripleClipWithTokenizer::new(
vb_fp16.pp("text_encoders"),
vb_fp32.pp("text_encoders"),
)?;
let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
let (context_uncond, y_uncond) =
triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
(
Tensor::cat(&[context, context_uncond], 0)?,
Tensor::cat(&[y, y_uncond], 0)?,
)
};

let x = {
let mmdit = MMDiT::new(
&MMDiTConfig::sd3_medium(),
use_flash_attn,
vb_fp16.pp("model.diffusion_model"),
)?;

if let Some(seed) = seed {
device.set_seed(seed)?;
}
let start_time = std::time::Instant::now();
let x = sampling::euler_sample(
&mmdit,
&y,
&context,
num_inference_steps,
cfg_scale,
time_shift,
height,
width,
)?;
let dt = start_time.elapsed().as_secs_f32();
println!(
"Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
dt,
num_inference_steps as f32 / dt
);
x
(MMDiTConfig::sd3_medium(), triple, vb_fp16)
};
let (context, y) = triple.encode_text_to_embedding(prompt.as_str(), &device)?;
let (context_uncond, y_uncond) =
triple.encode_text_to_embedding(uncond_prompt.as_str(), &device)?;
let context = Tensor::cat(&[context, context_uncond], 0)?;
let y = Tensor::cat(&[y, y_uncond], 0)?;

let mmdit = MMDiT::new(
&mmdit_config,
use_flash_attn,
vb.pp("model.diffusion_model"),
)?;

if let Some(seed) = seed {
device.set_seed(seed)?;
}
let start_time = std::time::Instant::now();
let x = sampling::euler_sample(
&mmdit,
&y,
&context,
num_inference_steps,
cfg_scale,
time_shift,
height,
width,
)?;
let dt = start_time.elapsed().as_secs_f32();
println!(
"Sampling done. {num_inference_steps} steps. {:.2}s. Average rate: {:.2} iter/s",
dt,
num_inference_steps as f32 / dt
);

let img = {
let vb_vae = vb_fp16
.clone()
.rename_f(sd3_vae_vb_rename)
.pp("first_stage_model");
let vb_vae = vb.rename_f(sd3_vae_vb_rename).pp("first_stage_model");
let autoencoder = build_sd3_vae_autoencoder(vb_vae)?;

// Apply the TAESD3 scale factor; it seems to significantly improve image quality.
// https://github.com/comfyanonymous/ComfyUI/blob/3c60ecd7a83da43d694e26a77ca6b93106891251/nodes.py#L721-L723
autoencoder.decode(&((x.clone() / 1.5305)? + 0.0609)?)?
autoencoder.decode(&((x / 1.5305)? + 0.0609)?)?
};
let img = ((img.clamp(-1f32, 1f32)? + 1.0)? * 127.5)?.to_dtype(candle::DType::U8)?;
candle_examples::save_image(&img.i(0)?, "out.jpg")?;
2 changes: 1 addition & 1 deletion candle-examples/examples/stable-diffusion-3/sampling.rs
@@ -30,7 +30,7 @@ pub fn euler_sample(

let timestep = (*s_curr) * 1000.0;
let noise_pred = mmdit.forward(
&Tensor::cat(&[x.clone(), x.clone()], 0)?,
&Tensor::cat(&[&x, &x], 0)?,
&Tensor::full(timestep as f32, (2,), x.device())?.contiguous()?,
y,
context,
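
The change above drops two needless clones when doubling the batch for the conditional and unconditional passes. For context, a sketch of how classifier-free guidance typically consumes that doubled batch, assuming row 0 holds the conditional prediction and row 1 the unconditional one (the helper name is illustrative):

```rust
use candle::{Result, Tensor};

/// Blend the two predictions: uncond + scale * (cond - uncond),
/// written as scale * cond - (scale - 1) * uncond.
fn apply_cfg(cfg_scale: f64, noise_pred: &Tensor) -> Result<Tensor> {
    let cond = noise_pred.narrow(0, 0, 1)?;
    let uncond = noise_pred.narrow(0, 1, 1)?;
    (cond * cfg_scale)? - (uncond * (cfg_scale - 1.0))?
}
```

Note that with the turbo default of cfg_scale = 1.0 this collapses to the conditional prediction alone.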
14 changes: 14 additions & 0 deletions candle-transformers/src/models/mmdit/model.rs
@@ -36,6 +36,20 @@ impl Config {
frequency_embedding_size: 256,
}
}

pub fn sd3_5_large() -> Self {
Self {
patch_size: 2,
in_channels: 16,
out_channels: 16,
depth: 38,
head_size: 64,
adm_in_channels: 2048,
pos_embed_max_size: 192,
context_embed_size: 4096,
frequency_embedding_size: 256,
}
}
}

pub struct MMDiT {
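
A minimal sketch of wiring the new config into a model, assuming the weights are already mmapped into a VarBuilder as main.rs does above (the module path is inferred from this file's location):

```rust
use candle_nn::VarBuilder;
use candle_transformers::models::mmdit::model::{Config, MMDiT};

fn build_sd3_5_large(vb: VarBuilder) -> anyhow::Result<MMDiT> {
    // depth = 38 MMDiT blocks with 64-wide heads; flash attention is left
    // off here to keep the sketch portable across backends.
    let mmdit = MMDiT::new(&Config::sd3_5_large(), false, vb.pp("model.diffusion_model"))?;
    Ok(mmdit)
}
```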