diff --git a/examples/benchmark/multi.py b/examples/benchmark/multi.py
index 94457502..b09f93de 100644
--- a/examples/benchmark/multi.py
+++ b/examples/benchmark/multi.py
@@ -127,15 +127,16 @@ def run(
     results = []

-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
+    timer_event = getattr(torch, "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    start = timer_event.Event(enable_timing=True)
+    end = timer_event.Event(enable_timing=True)

     for _ in tqdm(range(iterations)):
         start.record()
         out_tensor = stream.stream(image_tensor).cpu()
         queue.put(out_tensor)
         end.record()

-        torch.cuda.synchronize()
+        timer_event.synchronize()
         results.append(start.elapsed_time(end))

     print(f"Average time: {sum(results) / len(results)}ms")
diff --git a/examples/benchmark/single.py b/examples/benchmark/single.py
index 6a183f9a..e65ddd66 100644
--- a/examples/benchmark/single.py
+++ b/examples/benchmark/single.py
@@ -112,8 +112,9 @@ def run(
     results = []

-    start = torch.cuda.Event(enable_timing=True)
-    end = torch.cuda.Event(enable_timing=True)
+    timer_event = getattr(torch, "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
+    start = timer_event.Event(enable_timing=True)
+    end = timer_event.Event(enable_timing=True)

     for _ in tqdm(range(iterations)):
         start.record()
@@ -121,7 +122,7 @@ def run(
         stream(image=image_tensor)
         end.record()

-        torch.cuda.synchronize()
+        timer_event.synchronize()
         results.append(start.elapsed_time(end))

     print(f"Average time: {sum(results) / len(results)}ms")
diff --git a/src/streamdiffusion/pipeline.py b/src/streamdiffusion/pipeline.py
index 66c08c8f..ca0b6e06 100644
--- a/src/streamdiffusion/pipeline.py
+++ b/src/streamdiffusion/pipeline.py
@@ -1,4 +1,5 @@
 import time
+
 from typing import List, Optional, Union, Any, Dict, Tuple, Literal

 import numpy as np
@@ -30,6 +31,8 @@ def __init__(
         self.dtype = torch_dtype
         self.generator = None

+        self.timer_event = getattr(torch, str(self.device).split(':', 1)[0])
+
         self.height = height
         self.width = width

@@ -440,8 +443,8 @@ def predict_x0_batch(self, x_t_latent: torch.Tensor) -> torch.Tensor:
     def __call__(
         self, x: Union[torch.Tensor, PIL.Image.Image, np.ndarray] = None
     ) -> torch.Tensor:
-        start = torch.cuda.Event(enable_timing=True)
-        end = torch.cuda.Event(enable_timing=True)
+        start = self.timer_event.Event(enable_timing=True)
+        end = self.timer_event.Event(enable_timing=True)
         start.record()
         if x is not None:
             x = self.image_processor.preprocess(x, self.height, self.width).to(
@@ -463,7 +466,7 @@ def __call__(
         self.prev_image_result = x_output
         end.record()
-        torch.cuda.synchronize()
+        self.timer_event.synchronize()
         inference_time = start.elapsed_time(end) / 1000
         self.inference_time_ema = 0.9 * self.inference_time_ema + 0.1 * inference_time
         return x_output
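The three files above all lean on the same trick: the backend module (torch.cuda, torch.mps, or torch.cpu) is looked up by name with getattr, and that module's Event/synchronize pair is used for timing. A minimal self-contained sketch of the pattern, with a wall-clock fallback for builds where the resolved module has no Event at all (torch.mps.Event and the torch.cpu stubs only exist in newer PyTorch releases; make_timer is an illustrative helper, not part of this patch):

    import time

    import torch


    def make_timer(device: torch.device):
        """Return (start, stop) callables; stop() yields elapsed milliseconds."""
        # "cuda:0" -> torch.cuda, "mps" -> torch.mps, "cpu" -> torch.cpu
        backend = getattr(torch, str(device).split(":", 1)[0], None)
        if backend is not None and hasattr(backend, "Event"):
            start_evt = backend.Event(enable_timing=True)
            end_evt = backend.Event(enable_timing=True)

            def start():
                start_evt.record()

            def stop():
                end_evt.record()
                backend.synchronize()  # flush queued device work before reading
                return start_evt.elapsed_time(end_evt)

            return start, stop

        # Host-side fallback when the backend has no Event support.
        t0 = [0.0]

        def start():
            t0[0] = time.perf_counter()

        def stop():
            return (time.perf_counter() - t0[0]) * 1000.0

        return start, stop

On CUDA this measures device time exactly as the benchmarks do; on a CPU-only install it quietly degrades to time.perf_counter().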
diff --git a/utils/wrapper.py b/utils/wrapper.py
index cb49bcc2..365e0a8c 100644
--- a/utils/wrapper.py
+++ b/utils/wrapper.py
@@ -28,7 +28,7 @@ def __init__(
         output_type: Literal["pil", "pt", "np", "latent"] = "pil",
         lcm_lora_id: Optional[str] = None,
         vae_id: Optional[str] = None,
-        device: Literal["cpu", "cuda"] = "cuda",
+        device: Literal["cpu", "cuda", "mps"] = "cuda",
         dtype: torch.dtype = torch.float16,
         frame_buffer_size: int = 1,
         width: int = 512,
@@ -463,174 +463,176 @@ def _load_model(
             stream.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd").to(
                 device=pipe.device, dtype=pipe.dtype
             )
-
-        try:
-            if acceleration == "xformers":
-                stream.pipe.enable_xformers_memory_efficient_attention()
-            if acceleration == "tensorrt":
-                from polygraphy import cuda
-                from streamdiffusion.acceleration.tensorrt import (
-                    TorchVAEEncoder,
-                    compile_unet,
-                    compile_vae_decoder,
-                    compile_vae_encoder,
-                )
-                from streamdiffusion.acceleration.tensorrt.engine import (
-                    AutoencoderKLEngine,
-                    UNet2DConditionModelEngine,
-                )
-                from streamdiffusion.acceleration.tensorrt.models import (
-                    VAE,
-                    UNet,
-                    VAEEncoder,
-                )
-
-                def create_prefix(
-                    model_id_or_path: str,
-                    max_batch_size: int,
-                    min_batch_size: int,
-                ):
-                    maybe_path = Path(model_id_or_path)
-                    if maybe_path.exists():
-                        return f"{maybe_path.stem}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}"
-                    else:
-                        return f"{model_id_or_path}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}"
-
-                engine_dir = Path(engine_dir)
-                unet_path = os.path.join(
-                    engine_dir,
-                    create_prefix(
-                        model_id_or_path=model_id_or_path,
-                        max_batch_size=stream.trt_unet_batch_size,
-                        min_batch_size=stream.trt_unet_batch_size,
-                    ),
-                    "unet.engine",
-                )
-                vae_encoder_path = os.path.join(
-                    engine_dir,
-                    create_prefix(
-                        model_id_or_path=model_id_or_path,
-                        max_batch_size=self.batch_size
-                        if self.mode == "txt2img"
-                        else stream.frame_bff_size,
-                        min_batch_size=self.batch_size
-                        if self.mode == "txt2img"
-                        else stream.frame_bff_size,
-                    ),
-                    "vae_encoder.engine",
-                )
-                vae_decoder_path = os.path.join(
-                    engine_dir,
-                    create_prefix(
-                        model_id_or_path=model_id_or_path,
-                        max_batch_size=self.batch_size
-                        if self.mode == "txt2img"
-                        else stream.frame_bff_size,
-                        min_batch_size=self.batch_size
-                        if self.mode == "txt2img"
-                        else stream.frame_bff_size,
-                    ),
-                    "vae_decoder.engine",
-                )
-
-                if not os.path.exists(unet_path):
-                    os.makedirs(os.path.dirname(unet_path), exist_ok=True)
-                    unet_model = UNet(
-                        fp16=True,
-                        device=stream.device,
-                        max_batch_size=stream.trt_unet_batch_size,
-                        min_batch_size=stream.trt_unet_batch_size,
-                        embedding_dim=stream.text_encoder.config.hidden_size,
-                        unet_dim=stream.unet.config.in_channels,
-                    )
-                    compile_unet(
-                        stream.unet,
-                        unet_model,
-                        unet_path + ".onnx",
-                        unet_path + ".opt.onnx",
-                        unet_path,
-                        opt_batch_size=stream.trt_unet_batch_size,
-                    )
-
-                if not os.path.exists(vae_decoder_path):
-                    os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True)
-                    stream.vae.forward = stream.vae.decode
-                    vae_decoder_model = VAE(
-                        device=stream.device,
-                        max_batch_size=self.batch_size
-                        if self.mode == "txt2img"
-                        else stream.frame_bff_size,
-                        min_batch_size=self.batch_size
-                        if self.mode == "txt2img"
-                        else stream.frame_bff_size,
-                    )
-                    compile_vae_decoder(
-                        stream.vae,
-                        vae_decoder_model,
-                        vae_decoder_path + ".onnx",
-                        vae_decoder_path + ".opt.onnx",
-                        vae_decoder_path,
-                        opt_batch_size=self.batch_size
-                        if self.mode == "txt2img"
-                        else stream.frame_bff_size,
-                    )
-                    delattr(stream.vae, "forward")
-
-                if not os.path.exists(vae_encoder_path):
-                    os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True)
-                    vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda"))
-                    vae_encoder_model = VAEEncoder(
-                        device=stream.device,
-                        max_batch_size=self.batch_size
-                        if self.mode == "txt2img"
-                        else stream.frame_bff_size,
-                        min_batch_size=self.batch_size
-                        if self.mode == "txt2img"
-                        else stream.frame_bff_size,
-                    )
-                    compile_vae_encoder(
-                        vae_encoder,
-                        vae_encoder_model,
-                        vae_encoder_path + ".onnx",
-                        vae_encoder_path + ".opt.onnx",
-                        vae_encoder_path,
-                        opt_batch_size=self.batch_size
-                        if self.mode == "txt2img"
-                        else stream.frame_bff_size,
-                    )
-
-                cuda_steram = cuda.Stream()
-
-                vae_config = stream.vae.config
-                vae_dtype = stream.vae.dtype
-
-                stream.unet = UNet2DConditionModelEngine(
-                    unet_path, cuda_steram, use_cuda_graph=False
-                )
-                stream.vae = AutoencoderKLEngine(
-                    vae_encoder_path,
-                    vae_decoder_path,
-                    cuda_steram,
-                    stream.pipe.vae_scale_factor,
-                    use_cuda_graph=False,
-                )
-                setattr(stream.vae, "config", vae_config)
-                setattr(stream.vae, "dtype", vae_dtype)
-
-                gc.collect()
-                torch.cuda.empty_cache()
-
-                print("TensorRT acceleration enabled.")
-            if acceleration == "sfast":
-                from streamdiffusion.acceleration.sfast import (
-                    accelerate_with_stable_fast,
-                )
-
-                stream = accelerate_with_stable_fast(stream)
-                print("StableFast acceleration enabled.")
-        except Exception:
-            traceback.print_exc()
-            print("Acceleration has failed. Falling back to normal mode.")
+        if self.device == "mps":
+            print("Currently acceleration is not available on the mps device. Using normal mode instead.")
+        else:
+            try:
+                if acceleration == "xformers":
+                    stream.pipe.enable_xformers_memory_efficient_attention()
+                if acceleration == "tensorrt":
+                    from polygraphy import cuda
+                    from streamdiffusion.acceleration.tensorrt import (
+                        TorchVAEEncoder,
+                        compile_unet,
+                        compile_vae_decoder,
+                        compile_vae_encoder,
+                    )
+                    from streamdiffusion.acceleration.tensorrt.engine import (
+                        AutoencoderKLEngine,
+                        UNet2DConditionModelEngine,
+                    )
+                    from streamdiffusion.acceleration.tensorrt.models import (
+                        VAE,
+                        UNet,
+                        VAEEncoder,
+                    )
+
+                    def create_prefix(
+                        model_id_or_path: str,
+                        max_batch_size: int,
+                        min_batch_size: int,
+                    ):
+                        maybe_path = Path(model_id_or_path)
+                        if maybe_path.exists():
+                            return f"{maybe_path.stem}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}"
+                        else:
+                            return f"{model_id_or_path}--lcm_lora-{use_lcm_lora}--tiny_vae-{use_tiny_vae}--max_batch-{max_batch_size}--min_batch-{min_batch_size}--mode-{self.mode}"
+
+                    engine_dir = Path(engine_dir)
+                    unet_path = os.path.join(
+                        engine_dir,
+                        create_prefix(
+                            model_id_or_path=model_id_or_path,
+                            max_batch_size=stream.trt_unet_batch_size,
+                            min_batch_size=stream.trt_unet_batch_size,
+                        ),
+                        "unet.engine",
+                    )
+                    vae_encoder_path = os.path.join(
+                        engine_dir,
+                        create_prefix(
+                            model_id_or_path=model_id_or_path,
+                            max_batch_size=self.batch_size
+                            if self.mode == "txt2img"
+                            else stream.frame_bff_size,
+                            min_batch_size=self.batch_size
+                            if self.mode == "txt2img"
+                            else stream.frame_bff_size,
+                        ),
+                        "vae_encoder.engine",
+                    )
+                    vae_decoder_path = os.path.join(
+                        engine_dir,
+                        create_prefix(
+                            model_id_or_path=model_id_or_path,
+                            max_batch_size=self.batch_size
+                            if self.mode == "txt2img"
+                            else stream.frame_bff_size,
+                            min_batch_size=self.batch_size
+                            if self.mode == "txt2img"
+                            else stream.frame_bff_size,
+                        ),
+                        "vae_decoder.engine",
+                    )
+
+                    if not os.path.exists(unet_path):
+                        os.makedirs(os.path.dirname(unet_path), exist_ok=True)
+                        unet_model = UNet(
+                            fp16=True,
+                            device=stream.device,
+                            max_batch_size=stream.trt_unet_batch_size,
+                            min_batch_size=stream.trt_unet_batch_size,
+                            embedding_dim=stream.text_encoder.config.hidden_size,
+                            unet_dim=stream.unet.config.in_channels,
+                        )
+                        compile_unet(
+                            stream.unet,
+                            unet_model,
+                            unet_path + ".onnx",
+                            unet_path + ".opt.onnx",
+                            unet_path,
+                            opt_batch_size=stream.trt_unet_batch_size,
+                        )
+
+                    if not os.path.exists(vae_decoder_path):
+                        os.makedirs(os.path.dirname(vae_decoder_path), exist_ok=True)
+                        stream.vae.forward = stream.vae.decode
+                        vae_decoder_model = VAE(
+                            device=stream.device,
+                            max_batch_size=self.batch_size
+                            if self.mode == "txt2img"
+                            else stream.frame_bff_size,
+                            min_batch_size=self.batch_size
+                            if self.mode == "txt2img"
+                            else stream.frame_bff_size,
+                        )
+                        compile_vae_decoder(
+                            stream.vae,
+                            vae_decoder_model,
+                            vae_decoder_path + ".onnx",
+                            vae_decoder_path + ".opt.onnx",
+                            vae_decoder_path,
+                            opt_batch_size=self.batch_size
+                            if self.mode == "txt2img"
+                            else stream.frame_bff_size,
+                        )
+                        delattr(stream.vae, "forward")
+
+                    if not os.path.exists(vae_encoder_path):
+                        os.makedirs(os.path.dirname(vae_encoder_path), exist_ok=True)
+                        vae_encoder = TorchVAEEncoder(stream.vae).to(torch.device("cuda"))
+                        vae_encoder_model = VAEEncoder(
+                            device=stream.device,
+                            max_batch_size=self.batch_size
+                            if self.mode == "txt2img"
+                            else stream.frame_bff_size,
+                            min_batch_size=self.batch_size
+                            if self.mode == "txt2img"
+                            else stream.frame_bff_size,
+                        )
+                        compile_vae_encoder(
+                            vae_encoder,
+                            vae_encoder_model,
+                            vae_encoder_path + ".onnx",
+                            vae_encoder_path + ".opt.onnx",
+                            vae_encoder_path,
+                            opt_batch_size=self.batch_size
+                            if self.mode == "txt2img"
+                            else stream.frame_bff_size,
+                        )
+
+                    cuda_stream = cuda.Stream()
+
+                    vae_config = stream.vae.config
+                    vae_dtype = stream.vae.dtype
+
+                    stream.unet = UNet2DConditionModelEngine(
+                        unet_path, cuda_stream, use_cuda_graph=False
+                    )
+                    stream.vae = AutoencoderKLEngine(
+                        vae_encoder_path,
+                        vae_decoder_path,
+                        cuda_stream,
+                        stream.pipe.vae_scale_factor,
+                        use_cuda_graph=False,
+                    )
+                    setattr(stream.vae, "config", vae_config)
+                    setattr(stream.vae, "dtype", vae_dtype)
+
+                    gc.collect()
+                    torch.cuda.empty_cache()
+
+                    print("TensorRT acceleration enabled.")
+                if acceleration == "sfast":
+                    from streamdiffusion.acceleration.sfast import (
+                        accelerate_with_stable_fast,
+                    )
+
+                    stream = accelerate_with_stable_fast(stream)
+                    print("StableFast acceleration enabled.")
+            except Exception:
+                traceback.print_exc()
+                print("Acceleration has failed. Falling back to normal mode.")

         if seed < 0:  # Random seed
             seed = np.random.randint(0, 1000000)
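xformers, TensorRT, and Stable Fast all assume CUDA, which is why wrapper.py now routes device="mps" straight down the normal path. A caller can mirror the same preference order when picking a device up front; a small sketch under the same cuda > mps > cpu assumption (resolve_device is a hypothetical helper, not part of this patch):

    import torch


    def resolve_device(requested: str = "cuda") -> str:
        # Downgrade gracefully when the requested backend is absent,
        # mirroring the cuda -> mps -> cpu order used by the benchmarks.
        if requested == "cuda" and not torch.cuda.is_available():
            requested = "mps" if torch.backends.mps.is_available() else "cpu"
        if requested == "mps" and not torch.backends.mps.is_available():
            requested = "cpu"
        return requested


    device = resolve_device("mps")  # "mps" on Apple Silicon, otherwise "cpu"

Requesting an acceleration backend then only makes sense when the result is "cuda", which matches the device check added to _load_model above.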