diff --git a/colabs/diffusers/sdxl-text-to-image.ipynb b/colabs/diffusers/sdxl-text-to-image.ipynb index 444717ee..65663461 100644 --- a/colabs/diffusers/sdxl-text-to-image.ipynb +++ b/colabs/diffusers/sdxl-text-to-image.ipynb @@ -44,7 +44,11 @@ "source": [ "import torch\n", "import wandb\n", - "from diffusers import DiffusionPipeline, EulerDiscreteScheduler" + "from diffusers import (\n", + " StableDiffusionXLPipeline,\n", + " StableDiffusionXLImg2ImgPipeline,\n", + " EulerDiscreteScheduler\n", + ")" ] }, { @@ -62,36 +66,58 @@ "metadata": {}, "outputs": [], "source": [ + "project_name = \"stable-diffusion-xl\" # @param {type:\"string\"}\n", + "\n", "# initialize a wandb run\n", - "wandb.init(project=\"stable-diffusion-xl\", job_type=\"text-to-image\")\n", + "wandb.init(project=project_name, job_type=\"text-to-image\")\n", "\n", "# define experiment configs\n", "config = wandb.config\n", - "config.stable_diffusion_checkpoint = \"stabilityai/stable-diffusion-xl-base-1.0\"\n", - "config.refiner_checkpoint = \"stabilityai/stable-diffusion-xl-refiner-1.0\"\n", - "config.offload_to_cpu = False\n", + "config.stable_diffusion_checkpoint = \"stabilityai/stable-diffusion-xl-base-1.0\" # @param [\"stabilityai/stable-diffusion-xl-base-1.0\", \"stabilityai/stable-diffusion-xl-base-0.9\"] {allow-input: true}\n", + "config.refiner_checkpoint = \"stabilityai/stable-diffusion-xl-refiner-1.0\" # @param [\"stabilityai/stable-diffusion-xl-refiner-1.0\", \"stabilityai/stable-diffusion-xl-refiner-0.9\"] {allow-input: true}\n", "config.compile_model = False\n", - "config.prompt_1 = \"a photograph of an evil and vile looking demon in Bengali attire eating fish. The demon has large and bloody teeth. The demon is sitting on the branches of a giant Banyan tree, dimly lit, bluish and dark color palette, realistic, 8k\"\n", - "config.prompt_2 = \"\" # Leave blank if you want both text encoders to use the same prompt\n", - "config.negative_prompt_1 = \"static, painting, illustration, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, deformed toes standing still, posing\"\n", - "config.negative_prompt_2 = \"static, painting, illustration, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, deformed toes standing still, posing\"\n", - "config.seed = None\n", - "config.use_ensemble_of_experts = True\n", - "config.num_inference_steps = 100\n", - "config.num_refinement_steps = 150\n", - "config.high_noise_fraction = 0.8 # Set explicitly only if config.use_ensemble_of_experts is True\n", - "config.scheduler_kwargs = {\n", + "config.prompt_1 = \"a photograph of an evil and vile looking demon in Bengali attire eating fish. The demon has large and bloody teeth. 
The demon is sitting on the branches of a giant Banyan tree, dimly lit, bluish and dark color palette, realistic, 8k\" # @param {type:\"string\"}\n", + "config.prompt_2 = \"\" # @param {type:\"string\"}\n", + "config.negative_prompt_1 = \"static, frame, painting, illustration, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, deformed toes standing still, posing\" # @param {type:\"string\"}\n", + "config.negative_prompt_2 = \"static, frame, painting, illustration, sd character, low quality, low resolution, greyscale, monochrome, nose, cropped, lowres, jpeg artifacts, deformed iris, deformed pupils, bad eyes, semi-realistic worst quality, bad lips, deformed mouth, deformed face, deformed fingers, deformed toes standing still, posing\" # @param {type:\"string\"}\n", + "config.base_guidance_scale = 5.0 # @param {type:\"slider\", min:1, max:10, step:0.1}\n", + "config.seed = 0 # @param {type:\"raw\"}\n", + "config.num_inference_steps = 100 # @param {type:\"slider\", min:1, max:500, step:1}\n", + "\n", + "config.enable_cpu_offload_base = True # @param {type:\"boolean\"}\n", + "config.enable_cpu_offload_refiner = True # @param {type:\"boolean\"}\n", + "\n", + "config.compile_base_model = False # @param {type:\"boolean\"}\n", + "\n", + "# Enable refinement only on a high-RAM instance\n", + "config.enable_refinement = False # @param {type:\"boolean\"}\n", + "config.compile_refinement_model = False # @param {type:\"boolean\"}\n", + "config.refiner_guidance_scale = 5.0 # @param {type:\"slider\", min:1, max:10, step:0.1}\n", + "config.num_refinement_steps = 150 # @param {type:\"slider\", min:1, max:500, step:1}\n", + "\n", + "# Set explicitly only if config.enable_refinement is True\n", + "config.high_noise_fraction = 0.8 # @param {type:\"slider\", min:0, max:1, step:0.1}\n", + "\n", + "beta_schedule = \"scaled_linear\" # @param [\"linear\", \"scaled_linear\"]\n", + "interpolation_type = \"linear\" # @param [\"linear\", \"log_linear\"] {allow-input: true}\n", + "prediction_type = \"epsilon\" # @param [\"epsilon\", \"sample\", \"v_prediction\"]\n", + "timestep_spacing = \"leading\" # @param [\"linspace\", \"leading\"] {allow-input: true}\n", + "\n", + "# configs for diffusers.EulerDiscreteScheduler\n", + "scheduler_kwargs = {\n", "    \"beta_end\": 0.012,\n", - "    \"beta_schedule\": \"scaled_linear\", # one of [\"linear\", \"scaled_linear\"]\n", + "    \"beta_schedule\": beta_schedule,\n", "    \"beta_start\": 0.00085,\n", - "    \"interpolation_type\": \"linear\", # one of [\"linear\", \"log_linear\"]\n", + "    \"interpolation_type\": interpolation_type,\n", "    \"num_train_timesteps\": 1000,\n", - "    \"prediction_type\": \"epsilon\", # one of [\"epsilon\", \"sample\", \"v_prediction\"]\n", + "    \"prediction_type\": prediction_type,\n", "    \"steps_offset\": 1,\n", - "    \"timestep_spacing\": \"leading\", # one of [\"linspace\", \"leading\"]\n", + "    \"timestep_spacing\": timestep_spacing,\n", "    \"trained_betas\": None,\n", "    \"use_karras_sigmas\": False,\n", - "}" + "}\n", + "\n", + "config.scheduler_kwargs = scheduler_kwargs" ] }, { @@ -107,19 +133,26 @@ "metadata": {}, "outputs": [], "source": [ - "if config.seed is not None:\n", - "    generator = [torch.Generator(device=\"cuda\").manual_seed(config.seed)]\n", - "else:\n", - "    generator = [torch.Generator(device=\"cuda\")]" + "generator = [torch.Generator(device=\"cuda\")]\n", + "if 
config.seed is not None:\n", + "    generator = [g.manual_seed(config.seed) for g in generator]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Creating the Diffusion Pipelines\n", + "## The Base Diffusion Pipeline\n", + "\n", + "For performing text-conditional image generation, we use the `diffusers` library to define the diffusion pipelines corresponding to the base SDXL model and the SDXL refinement model.\n", + "\n", + "1. We define the base diffusion pipeline using `diffusers.StableDiffusionXLPipeline` and load the pre-trained weights for SDXL 1.0 by calling the `from_pretrained` function on it. We also pass the scheduler as `diffusers.EulerDiscreteScheduler` in this step.\n", + "\n", + "2. In case we don't have a GPU with large enough memory, it's recommended to enable CPU offloading. Otherwise, we load the model on the GPU. In case you're curious how HuggingFace manages CPU offloading in the most optimized manner, we recommend you read this post by [Sylvain Gugger](https://huggingface.co/sgugger): [How 🤗 Accelerate runs very large models thanks to PyTorch](https://huggingface.co/blog/accelerate-large-models).\n", "\n", - "For performing text-conditional image generation, we use the `diffusers` library to define the diffusion pipelines corresponding to the base SDXL model and the SDXL refinement model." + "3. We can compile the model using `torch.compile`, which might give a significant speedup.\n", + "\n", + "4. We generate the image from the prompts and negative prompts using the base pipeline." ] }, { @@ -128,8 +161,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Define base model\n", - "pipe = DiffusionPipeline.from_pretrained(\n", + "# Define the Base Pipeline\n", + "pipe = StableDiffusionXLPipeline.from_pretrained(\n", "    config.stable_diffusion_checkpoint,\n", "    torch_dtype=torch.float16,\n", "    variant=\"fp16\",\n", @@ -137,72 +170,46 @@ "    scheduler=EulerDiscreteScheduler(**config.scheduler_kwargs),\n", ")\n", "\n", - "# Offload to CPU in case of OOM\n", - "if config.offload_to_cpu:\n", + "if config.enable_cpu_offload_base:\n", + "    # Offload base pipeline to CPU\n", "    pipe.enable_model_cpu_offload()\n", "else:\n", + "    # Load base pipeline to GPU\n", "    pipe.to(\"cuda\")\n", "\n", "# Compile model using `torch.compile`, this might give a significant speedup\n", - "if config.compile_model:\n", - "    pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "# Define base model\n", - "refiner = DiffusionPipeline.from_pretrained(\n", - "    config.refiner_checkpoint,\n", - "    text_encoder_2=pipe.text_encoder_2,\n", - "    vae=pipe.vae,\n", - "    torch_dtype=torch.float16,\n", - "    use_safetensors=True,\n", - "    variant=\"fp16\",\n", - "    scheduler=EulerDiscreteScheduler(**config.scheduler_kwargs),\n", - ")\n", - "\n", - "# Offload to CPU in case of OOM\n", - "if config.offload_to_cpu:\n", - "    refiner.enable_model_cpu_offload()\n", - "else:\n", - "    refiner.to(\"cuda\")\n", + "if config.compile_base_model:\n", + "    pipe.unet = torch.compile(pipe.unet, mode=\"reduce-overhead\", fullgraph=True)\n", "\n", - "# Compile model using `torch.compile`, this might give a significant speedup\n", - "if config.compile_model:\n", - "    refiner.unet = torch.compile(refiner.unet, mode=\"reduce-overhead\", fullgraph=True)" + "# Generate image from the prompts and negative prompts using the base pipeline\n", + "generated_image = pipe(\n", + "    prompt=config.prompt_1,\n", + "    
prompt_2=config.prompt_2,\n", + "    negative_prompt=config.negative_prompt_1,\n", + "    negative_prompt_2=config.negative_prompt_2,\n", + "    guidance_scale=config.base_guidance_scale,\n", + "    output_type=\"latent\" if config.enable_refinement else \"pil\",\n", + "    num_inference_steps=config.num_inference_steps,\n", + "    denoising_end=config.high_noise_fraction if config.enable_refinement else None,\n", + "    generator=generator,\n", + ").images" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "We now define a utility function to postprocess the latents obtained from the base model." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "def postprocess_latent(latent):\n", - "    vae_output = pipe.vae.decode(\n", - "        latent.images / pipe.vae.config.scaling_factor, return_dict=False\n", - "    )[0].detach()\n", - "    return pipe.image_processor.postprocess(vae_output, output_type=\"pil\")[0]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "## Text-to-Image Generation\n", + "## Refining the Generated Image\n", + "\n", + "For refining the image generated by the base pipeline, we use the SDXL refiner pipeline, treating the base and refiner models as an ensemble of expert denoisers. In this setup, the base model serves as the expert for the high-noise diffusion stage and the refiner serves as the expert for the low-noise diffusion stage.\n", "\n", - "Now, we pass the prompts and the negative prompts to the base model and then pass the output to the refiner for firther refinement. In order to know more about the different refinement techniques that can be used with SDXL, you can check [`diffusers` docs](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)." + "1. We define the diffusion pipeline for the refiner using `diffusers.StableDiffusionXLImg2ImgPipeline` and load the pre-trained weights for the SDXL 1.0 refiner by calling the `from_pretrained` function on it. We also pass the scheduler as `diffusers.EulerDiscreteScheduler` in this step.\n", + "\n", + "2. In case we don't have a GPU with large enough memory, it's recommended to enable CPU offloading. Otherwise, we load the model on the GPU. In case you're curious how HuggingFace manages CPU offloading in the most optimized manner, we recommend you read this post by [Sylvain Gugger](https://huggingface.co/sgugger): [How 🤗 Accelerate runs very large models thanks to PyTorch](https://huggingface.co/blog/accelerate-large-models).\n", + "\n", + "3. We can compile the model using `torch.compile`, which might give a significant speedup.\n", + "\n", + "4. We refine the latents generated by the base model from the same set of prompts and negative prompts using the refiner pipeline." 
] }, { @@ -211,56 +218,37 @@ "metadata": {}, "outputs": [], "source": [ - "if config.use_ensemble_of_experts:\n", - "    latent = pipe(\n", - "        prompt=config.prompt_1 if config.prompt_1 != \"\" else None,\n", - "        prompt_2=config.prompt_2 if config.prompt_2 != \"\" else None,\n", - "        negative_prompt=config.negative_prompt_1 if config.negative_prompt_1 != \"\" else None,\n", - "        negative_prompt_2=config.negative_prompt_2 if config.negative_prompt_2 != \"\" else None,\n", - "        output_type=\"latent\",\n", - "        num_inference_steps=config.num_inference_steps,\n", - "        denoising_end=config.high_noise_fraction,\n", - "        generator=generator,\n", + "if config.enable_refinement:\n", + "    refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(\n", + "        config.refiner_checkpoint,\n", + "        text_encoder_2=pipe.text_encoder_2,\n", + "        vae=pipe.vae,\n", + "        torch_dtype=torch.float16,\n", + "        use_safetensors=True,\n", + "        variant=\"fp16\",\n", + "        scheduler=EulerDiscreteScheduler(**config.scheduler_kwargs),\n", "    )\n", - "else:\n", - "    latent = pipe(\n", - "        prompt=config.prompt_1 if config.prompt_1 != \"\" else None,\n", - "        prompt_2=config.prompt_2 if config.prompt_2 != \"\" else None,\n", - "        negative_prompt=config.negative_prompt_1 if config.negative_prompt_1 != \"\" else None,\n", - "        negative_prompt_2=config.negative_prompt_2 if config.negative_prompt_2 != \"\" else None,\n", - "        output_type=\"latent\",\n", - "        num_inference_steps=config.num_inference_steps,\n", - "        generator=generator,\n", - "    )\n", - "unrefined_image = postprocess_latent(latent)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "if config.use_ensemble_of_experts:\n", - "    refined_image = refiner(\n", - "        prompt=config.prompt_1 if config.prompt_1 != \"\" else None,\n", - "        prompt_2=config.prompt_2 if config.prompt_2 != \"\" else None,\n", - "        negative_prompt=config.negative_prompt_1 if config.negative_prompt_1 != \"\" else None,\n", - "        negative_prompt_2=config.negative_prompt_2 if config.negative_prompt_2 != \"\" else None,\n", - "        image=latent.images,\n", + "\n", + "    if config.enable_cpu_offload_refiner:\n", + "        refiner.enable_model_cpu_offload()\n", + "    else:\n", + "        refiner.to(\"cuda\")\n", + "\n", + "    # Compile model using `torch.compile`, which might give a significant speedup\n", + "    if config.compile_refinement_model:\n", + "        refiner.unet = torch.compile(refiner.unet, mode=\"reduce-overhead\", fullgraph=True)\n", + "\n", + "    generated_image = refiner(\n", + "        prompt=config.prompt_1,\n", + "        prompt_2=config.prompt_2,\n", + "        negative_prompt=config.negative_prompt_1,\n", + "        negative_prompt_2=config.negative_prompt_2,\n", + "        guidance_scale=config.refiner_guidance_scale,\n", + "        image=generated_image,\n", "        num_inference_steps=config.num_refinement_steps,\n", "        denoising_start=config.high_noise_fraction,\n", "        generator=generator,\n", - "    ).images[0]\n", - "else:\n", - "    refined_image = refiner(\n", - "        prompt=config.prompt_1 if config.prompt_1 != \"\" else None,\n", - "        prompt_2=config.prompt_2 if config.prompt_2 != \"\" else None,\n", - "        negative_prompt=config.negative_prompt_1 if config.negative_prompt_1 != \"\" else None,\n", - "        negative_prompt_2=config.negative_prompt_2 if config.negative_prompt_2 != \"\" else None,\n", - "        image=latent.images[0][None, :],\n", - "        generator=generator,\n", - "    ).images[0]" + "    ).images" ] }, { @@ -288,13 +276,10 @@ "    \"Prompt-2\",\n", "    \"Negative-Prompt-1\",\n", "    \"Negative-Prompt-2\",\n", - "    \"Unrefined-Image\",\n", - "    
\"Refined-Image\",\n", - " \"Use-Ensemble-of-Experts\",\n", + " \"Generated-Image\",\n", "])\n", "\n", - "unrefined_image = wandb.Image(unrefined_image)\n", - "refined_image = wandb.Image(refined_image)\n", + "generated_image = wandb.Image(generated_image[0])\n", "\n", "# Add the images to the table\n", "table.add_data(\n", @@ -302,15 +287,12 @@ " config.prompt_2,\n", " config.negative_prompt_1,\n", " config.negative_prompt_2,\n", - " unrefined_image,\n", - " refined_image,\n", - " config.use_ensemble_of_experts,\n", + " generated_image,\n", ")\n", "\n", "# Log the images and table to wandb\n", "wandb.log({\n", - " \"Unrefined-Image\": unrefined_image,\n", - " \"Refined-Image\": refined_image,\n", + " \"Generated-Image\": generated_image,\n", " \"Text-to-Image\": table\n", "})\n", "\n",