Skip to content

Commit

Permalink
updated the code
Browse files Browse the repository at this point in the history
  • Loading branch information
Deepak Yadav committed Dec 25, 2023
1 parent f434f2a commit 2c1f3ef
Show file tree
Hide file tree
Showing 10 changed files with 107 additions and 80 deletions.
24 changes: 18 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

This repository contains a PyTorch implementation of a stable diffusion model from scratch. The diffusion model is a powerful probabilistic generative model that has applications in image synthesis, denoising, and other tasks.

![Stable Diffusion](assets/diffusion.jpg)

## Features

- **PyTorch Implementation:** The diffusion model is implemented using PyTorch, providing flexibility and ease of use for both training and inference.
Expand Down Expand Up @@ -37,29 +39,39 @@ This repository contains a PyTorch implementation of a stable diffusion model fr

3. Download the pre-trained weights
```bash
wget -O saved_models https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt
wget -O saved_models/v1-5-pruned-emaonly.ckpt https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt
```

### Usage

1. Training:
1. Inference (Text-to-image):

```bash
python train.py --dataset your_dataset --epochs 1000 --lr 0.001
python src/inference.py --prompt "photograph of an astronaut riding a horse"
```
![Text-to-image](assets/txt-to-img.jpg)

2. Inference:
2. Inference (Image-to-image):

```bash
python inference.py --prompt="a dog is flying in the sky" --model_path saved_models/model.pth --num_samples 10
python src/inference.py --prompt "Put hat on the cat head, ultra-sharp, cinematic, 100mm lens, 8k resolution." --image-path="cat.jpg"
```
![Image-to-image](assets/img-to-img.jpg)

## Structure

- **`src/`**: Contains the source code for the diffusion model implementation.
- **`data/`**: Placeholder for your dataset or data loading scripts.
- **`saved_models/`**: Directory to store trained model checkpoints.
- **`experiments/`**: Logs and other experiment-related files.

## Future Work

- **Implement Training Strategies:** To train diffusion model on custom dataset.

- **Parallelization and Optimization:** Investigate opportunities for parallelizing training and optimizing the code for faster convergence on different hardware configurations.

- **Quantization:** Qauntize the model for faster inference time


## Contributing

Expand Down
Binary file added assets/diffusion.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/txt-to-img.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 0 additions & 1 deletion experiments/exp.txt

This file was deleted.

11 changes: 6 additions & 5 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
torch
numpy
tqdm
transformers
pytorch_lightning
numpy==1.26.2
Pillow==10.1.0
pytorch-lightning==2.1.3
torch==2.1.2
tqdm==4.66.1
transformers==4.36.2
Empty file added src/__init__.py
Empty file.
65 changes: 0 additions & 65 deletions src/demo.py

This file was deleted.

81 changes: 81 additions & 0 deletions src/inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import os
import argparse
import model_loader
import pipeline
from PIL import Image
from transformers import CLIPTokenizer
import torch
import subprocess

def parse_args():
parser = argparse.ArgumentParser(description="Run the inference script.")
parser.add_argument("--prompt", type=str, help="Text prompt for text-to-image generation")
parser.add_argument("--uncond-prompt", type=str, default="", help="Unconditional prompt for text-to-image generation (negative prompt)")
parser.add_argument("--image-path", type=str, default=None, help="Path to the input image for image-to-image generation")
parser.add_argument("--output-path", type=str, default="output_image.jpg", help="Path to save the output image")
parser.add_argument("--strength", type=float, default=0.9, help="Strength parameter for image-to-image generation")
parser.add_argument("--do-cfg", default=True, help="Enable conditional configuration for text-to-image generation")
parser.add_argument("--cfg-scale", type=float, default=8, help="Scale parameter for conditional configuration")
parser.add_argument("--sampler", type=str, default="ddpm", help="Sampler name for image-to-image generation")
parser.add_argument("--num-inference-steps", type=int, default=25, help="Number of inference steps")
parser.add_argument("--seed", type=int, default=42, help="Seed for random number generation")
return parser.parse_args()

def download_model(model_url, save_path):
# Download the model using wget
subprocess.run(["wget", "-O", save_path, model_url])

def main():
args = parse_args()

DEVICE = "cpu"

ALLOW_CUDA = False
ALLOW_MPS = False

if torch.cuda.is_available() and ALLOW_CUDA:
DEVICE = "cuda"
elif (torch.torch.backends.mps.is_built() or torch.backends.mps.is_available()) and ALLOW_MPS:
DEVICE = "mps"
print(f"Using device: {DEVICE}")

# Get the absolute path to the vocabulary and merges files
vocab_file = os.path.join(os.getcwd(), "data/vocab.json")
merges_file = os.path.join(os.getcwd(), "data/merges.txt")
tokenizer = CLIPTokenizer(vocab_file, merges_file=merges_file)

# Check if the model file exists
model_file_path = os.path.join(os.getcwd(), "saved_models/v1-5-pruned-emaonly.ckpt")
if not os.path.exists(model_file_path):
print(f"Model file '{model_file_path}' not found. Downloading...")
model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt"
download_model(model_url, model_file_path)
print("Download complete.")

models = model_loader.preload_models_from_standard_weights(model_file_path, DEVICE)

if args.image_path:
input_image = Image.open(args.image_path)
else:
input_image = None

output_image = pipeline.generate(
prompt=args.prompt,
uncond_prompt=args.uncond_prompt,
input_image=input_image,
strength=args.strength,
do_cfg=args.do_cfg,
cfg_scale=args.cfg_scale,
sampler_name=args.sampler,
n_inference_steps=args.num_inference_steps,
seed=args.seed,
models=models,
device=DEVICE,
idle_device="cpu",
tokenizer=tokenizer,
)

Image.fromarray(output_image).save(args.output_path)

if __name__ == "__main__":
main()
2 changes: 1 addition & 1 deletion src/model_loader.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from stable_diffusion_pytorch.diffusion import Diffusion
from diffusion import Diffusion

import model_converter

Expand Down
3 changes: 1 addition & 2 deletions src/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,10 @@ def generate(
latents_shape = (1, 4, LATENTS_HEIGHT, LATENTS_WIDTH)

if input_image:
input_image = input_image.to(device)
encoder = models["encoder"]
encoder.to(device)

input_image_tensor = input_image.resize((WIDTH, HEIGHT)).to(device)
input_image_tensor = input_image.resize((WIDTH, HEIGHT))
# (Height, Width, Channel)
input_image_tensor = np.array(input_image_tensor)
# (Height, Width, Channel) -> (Height, Width, Channel)
Expand Down

0 comments on commit 2c1f3ef

Please sign in to comment.