-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixes Imagen sampling example and updates container (#868)
1. Fix imagen sampling loop when prompt_ct is a multiple of `batch_size // gen_per_prompt` 2. Add comment explaining the invocation of sampling script of what exactly is expected due to implicit checkpoint dir requirements and quoting 3. add 2B base model generation gin configs 4. parametrize imagen sampling scripts 5. Updates the imagen image with the fix in (1); built as follows: ``` docker buildx build --push -t ghcr.io/nvidia/t5x:imagen-2023-10-02.v3 ./JAX-Toolbox -f - <<EOF FROM ghcr.io/nvidia/t5x:imagen-2023-10-02 COPY rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh /opt/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh COPY rosetta/rosetta/projects/imagen/README.md /opt/rosetta/rosetta/projects/imagen/README.md COPY rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin /opt/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin COPY rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin /opt/rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin COPY rosetta/rosetta/projects/imagen/imagen_pipe.py /opt/rosetta/rosetta/projects/imagen/imagen_pipe.py COPY rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub /opt/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub COPY rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh /opt/rosetta/rosetta/projects/imagen/scripts/sample_imagen_1024.sh COPY rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh /opt/rosetta/rosetta/projects/imagen/scripts/sample_imagen_256.sh EOF ``` --------- Signed-off-by: Terry Kong <terryk@nvidia.com>
- Loading branch information
Showing
9 changed files
with
347 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
3 changes: 2 additions & 1 deletion
3
rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,2 @@ | ||
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_disable_async_collectives=allreduce,allgather,reducescatter,collectivebroadcast,alltoall,collectivepermute ${XLA_FLAGS}" | ||
# These XLA flags are meant to be used with the JAX version in the imagen container | ||
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
78 changes: 78 additions & 0 deletions
78
rosetta/rosetta/projects/imagen/configs/imagen_1024_sample_2b.gin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# Imagen Sampling pipeline | ||
include "rosetta/projects/imagen/configs/imagen_256_sample_2b.gin" | ||
|
||
from __gin__ import dynamic_registration | ||
import __main__ as sample_script | ||
from t5x import gin_utils | ||
from t5x import utils | ||
from t5x import partitioning | ||
|
||
from rosetta.projects.imagen import network_sr | ||
from rosetta.projects.diffusion import models | ||
from rosetta.projects.diffusion import denoisers | ||
from rosetta.projects.diffusion import samplers | ||
from rosetta.projects.diffusion import losses | ||
from rosetta.projects.diffusion import augmentations | ||
|
||
#---------------- SR1024 Model ------------------------------------------------- | ||
|
||
# ------------------- Model ---------------------------------------------------- | ||
SR1024 = @sr1024/models.DenoisingDiffusionModel() | ||
SIGMA_DATA = 0.5 | ||
sr1024/models.DenoisingDiffusionModel: | ||
denoiser= @sr1024/denoisers.EDMTextConditionedSuperResDenoiser() | ||
diffusion_loss= None | ||
diffusion_sampler= @sr1024/samplers.EDMSampler() | ||
optimizer_def = None | ||
|
||
# |--- Denoiser | ||
sr1024/denoisers.EDMTextConditionedSuperResDenoiser: | ||
raw_model= @sr1024/network_sr.ImagenEfficientUNet() | ||
|
||
sr1024/samplers.EDMSampler: | ||
dim_noise_scalar = 4. | ||
|
||
# ------------------- Network specification ------------------------------------ | ||
sr1024/network_sr.ImagenEfficientUNet.config = @sr1024/network_sr.ImagenEfficientUNetConfig() | ||
sr1024/network_sr.ImagenEfficientUNetConfig: | ||
dtype = %DTYPE | ||
model_dim = 128 | ||
cond_dim = 1024 | ||
resblocks_per_level = (2, 4, 8, 8, 8) | ||
width_multipliers = (1, 2, 4, 6, 6) | ||
attn_resolutions_divs = {16: 'cross'} | ||
mha_head_dim = 64 | ||
attn_heads = 8 | ||
resblock_activation = 'silu' | ||
resblock_zero_out = True | ||
resblock_scale_skip = True | ||
dropout_rate = %DROPOUT_RATE | ||
cond_strategy = 'shift_scale' | ||
norm_32 = True | ||
scale_attn_logits = True | ||
float32_attention_logits=False | ||
text_conditionable = True | ||
|
||
sr1024/samplers.CFGSamplingConfig: | ||
num_steps=30 | ||
cf_guidance_weight=0.0 | ||
cf_guidance_nulls={'text': None, 'text_mask': None} | ||
|
||
sr1024/partitioning.PjitPartitioner: | ||
num_partitions = 1 | ||
logical_axis_rules = @partitioning.standard_logical_axis_rules() | ||
|
||
sr1024/utils.RestoreCheckpointConfig: | ||
mode = 'specific' | ||
dtype = 'bfloat16' | ||
|
||
sr1024/sample_script.DiffusionModelSetupData: | ||
model = %SR1024 | ||
sampling_cfg = @sr1024/samplers.CFGSamplingConfig() | ||
restore_checkpoint_cfg = @sr1024/utils.RestoreCheckpointConfig() | ||
partitioner = @partitioning.PjitPartitioner() | ||
input_shapes = {'samples': (1, 1024, 1024, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 256, 256, 3)} | ||
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'} | ||
|
||
sample_script.sample: | ||
sr1024_setupdata = @sr1024/sample_script.DiffusionModelSetupData() |
220 changes: 220 additions & 0 deletions
220
rosetta/rosetta/projects/imagen/configs/imagen_256_sample_2b.gin
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
# Imagen Sampling pipeline | ||
from __gin__ import dynamic_registration | ||
|
||
import __main__ as sample_script | ||
from t5x import gin_utils | ||
from t5x import utils | ||
from t5x import partitioning | ||
|
||
SAVE_DIR='generations' | ||
PROMPT_TEXT_FILE='custom_text.txt' | ||
GLOBAL_BATCH_SIZE=32 | ||
MAX_GENERATE=50000000 | ||
GEN_PER_PROMPT=2 | ||
NOISE_COND_AUG=0.002 | ||
|
||
TXT_SHAPE=(1, 128, 4096) #T5 xxl, seqlen x embed_dim | ||
TXT_SEQLEN=(1, 128, ) | ||
TXT_SEQLEN_SINGLE=128 | ||
DTYPE='bfloat16' | ||
DROPOUT_RATE=0 | ||
RESUME_FROM=0 #Sampling count to resume from | ||
#---------------- Base Model ------------------------------------------------- | ||
from rosetta.projects.imagen import network | ||
from rosetta.projects.imagen import network_sr | ||
from rosetta.projects.diffusion import models | ||
from rosetta.projects.diffusion import denoisers | ||
from rosetta.projects.diffusion import samplers | ||
from rosetta.projects.diffusion import losses | ||
from rosetta.projects.diffusion import augmentations | ||
|
||
# ------------------- Model ---------------------------------------------------- | ||
BASE = @base_model/models.DenoisingDiffusionModel() | ||
base_model/models.DenoisingDiffusionModel: | ||
denoiser= @base_model/denoisers.EDMTextConditionedDenoiser() | ||
diffusion_loss = None | ||
diffusion_sampler= @base_model/samplers.EDMSampler() | ||
optimizer_def = None | ||
|
||
# |--- Denoiser | ||
base_model/denoisers.EDMTextConditionedDenoiser: | ||
raw_model= @base_model/network.ImagenUNet() | ||
|
||
# ------------------- Network specification ------------------------------------ | ||
base_model/network.ImagenUNet.config = @base_model/network.DiffusionConfig() | ||
base_model/network.DiffusionConfig: | ||
dtype = %DTYPE | ||
model_dim = 512 | ||
attn_cond_dim = 2048 | ||
cond_dim = 2048 | ||
resblocks_per_level = 3 | ||
width_multipliers = (1, 2, 3, 4) | ||
attn_resolutions = (32, 16, 8) | ||
mha_head_dim = 64 | ||
attn_heads = 8 | ||
resblock_activation = 'silu' | ||
dropout_rate = %DROPOUT_RATE | ||
upsample_mode = 'shuffle' | ||
downsample_mode = 'shuffle' | ||
spatial_skip = False | ||
cond_strategy = 'shift_scale' | ||
norm_32 = True | ||
scale_attn_logits = True | ||
float32_attention_logits = False | ||
text_conditionable = True | ||
|
||
|
||
BASE_SAMPLING_CONFIG = @base_model/samplers.CFGSamplingConfig() | ||
base_model/samplers.CFGSamplingConfig: | ||
num_steps=50 | ||
cf_guidance_weight=5.00 | ||
cf_guidance_nulls=None | ||
|
||
base_model/partitioning.PjitPartitioner: | ||
num_partitions = 1 | ||
logical_axis_rules = @partitioning.standard_logical_axis_rules() | ||
|
||
base_model/utils.RestoreCheckpointConfig: | ||
mode = 'specific' | ||
dtype = 'bfloat16' | ||
|
||
base_model/sample_script.DiffusionModelSetupData: | ||
model = %BASE | ||
sampling_cfg = @base_model/samplers.CFGSamplingConfig() | ||
restore_checkpoint_cfg = @base_model/utils.RestoreCheckpointConfig() | ||
partitioner = @partitioning.PjitPartitioner() | ||
input_shapes = {'samples': (1, 64, 64, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN} | ||
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int'} | ||
|
||
#---------------- SR256 Model ------------------------------------------------- | ||
|
||
# ------------------- Model ---------------------------------------------------- | ||
SR256 = @sr256/models.DenoisingDiffusionModel() | ||
SIGMA_DATA = 0.5 | ||
sr256/models.DenoisingDiffusionModel: | ||
denoiser= @sr256/denoisers.EDMTextConditionedSuperResDenoiser() | ||
diffusion_loss= None | ||
diffusion_sampler= @sr256/samplers.EDMSampler() | ||
optimizer_def = None | ||
|
||
# |--- Denoiser | ||
sr256/denoisers.EDMTextConditionedSuperResDenoiser: | ||
raw_model= @sr256/network_sr.ImagenEfficientUNet() | ||
|
||
sr256/samplers.EDMSampler: | ||
dim_noise_scalar = 4. | ||
|
||
# ------------------- Network specification ------------------------------------ | ||
sr256/network_sr.ImagenEfficientUNet.config = @sr256/network_sr.ImagenEfficientUNetConfig() | ||
sr256/network_sr.ImagenEfficientUNetConfig: | ||
dtype = %DTYPE | ||
model_dim = 128 | ||
cond_dim = 512 | ||
attn_cond_dim = 1024 | ||
resblocks_per_level = (2, 4, 8, 8, 2) | ||
width_multipliers = (1, 2, 4, 8, 8) | ||
attn_resolutions_divs = {8: 'fused', 16: 'fused'} | ||
mha_head_dim = 64 | ||
attn_heads = 8 | ||
resblock_activation = 'silu' | ||
resblock_zero_out = True | ||
resblock_scale_skip = True | ||
dropout_rate = %DROPOUT_RATE | ||
cond_strategy = 'shift_scale' | ||
norm_32 = True | ||
scale_attn_logits = True | ||
float32_attention_logits=False | ||
text_conditionable = True | ||
|
||
sr256/samplers.CFGSamplingConfig: | ||
num_steps=50 | ||
cf_guidance_weight=4 | ||
cf_guidance_nulls={'text': None, 'text_mask': None} | ||
|
||
sr256/partitioning.PjitPartitioner: | ||
num_partitions = 1 | ||
logical_axis_rules = @partitioning.standard_logical_axis_rules() | ||
|
||
sr256/utils.RestoreCheckpointConfig: | ||
mode = 'specific' | ||
dtype = 'bfloat16' | ||
|
||
sr256/sample_script.DiffusionModelSetupData: | ||
model = %SR256 | ||
sampling_cfg = @sr256/samplers.CFGSamplingConfig() | ||
restore_checkpoint_cfg = @sr256/utils.RestoreCheckpointConfig() | ||
partitioner = @partitioning.PjitPartitioner() | ||
input_shapes = {'samples': (1, 256, 256, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 64, 64, 3)} | ||
input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'} | ||
|
||
#---------------- Text Model ------------------------------------------------- | ||
import seqio | ||
from rosetta.projects.inference_serving.t5 import network as t5x_network | ||
from rosetta.projects.inference_serving.t5 import models as t5x_models | ||
|
||
# ===================================== | ||
# === T5 Encoder only configuration === | ||
# ===================================== | ||
T5_CHECKPOINT_PATH = "/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl" | ||
BATCH_SIZE = 256 # Will be overridden | ||
SEQ_LEN = 128 # MAX seqlen | ||
|
||
# Vocabulary | ||
VOCABULARY = @seqio.SentencePieceVocabulary() | ||
seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" | ||
TASK_FEATURE_LENGTHS = None # auto-computes the maximum features length to use. | ||
|
||
# --------------- Model ------------------ | ||
TEXT_ENC = @text_enc/t5x_models.EncoderOnlyModel() | ||
text_enc/t5x_models.EncoderOnlyModel: | ||
module = @t5x_network.TransformerEncoderOnly() | ||
input_vocabulary = %VOCABULARY | ||
output_vocabulary = %VOCABULARY | ||
optimizer_def = None | ||
z_loss = 0.0001 | ||
label_smoothing = 0.0 | ||
loss_normalizing_factor = None | ||
|
||
# -------- Network specification --------- | ||
t5x_network.TransformerEncoderOnly.config = @t5x_network.T5Config() | ||
t5x_network.T5Config: | ||
vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency | ||
dtype = 'bfloat16' | ||
emb_dim = 4096 | ||
num_heads = 64 | ||
num_encoder_layers = 24 | ||
num_decoder_layers = 0 | ||
head_dim = 64 | ||
mlp_dim = 10240 | ||
mlp_activations = ('gelu', 'linear') | ||
dropout_rate = 0.0 | ||
|
||
text_enc/partitioning.PjitPartitioner: | ||
num_partitions = 1 | ||
logical_axis_rules = @partitioning.standard_logical_axis_rules() | ||
|
||
text_enc/utils.RestoreCheckpointConfig: | ||
path = %T5_CHECKPOINT_PATH | ||
mode = 'specific' | ||
dtype = 'bfloat16' | ||
|
||
text_enc/sample_script.setup_text_enc: | ||
model=%TEXT_ENC | ||
restore_checkpoint_cfg=@text_enc/utils.RestoreCheckpointConfig() | ||
partitioner=@text_enc/partitioning.PjitPartitioner() | ||
batch_size=1 | ||
seq_len=%TXT_SEQLEN_SINGLE | ||
vocab = %VOCABULARY | ||
|
||
sample_script.sample: | ||
base_setupdata = @base_model/sample_script.DiffusionModelSetupData() | ||
sr256_setupdata = @sr256/sample_script.DiffusionModelSetupData() | ||
sr1024_setupdata = None | ||
out_dir = %SAVE_DIR | ||
gen_per_prompt = %GEN_PER_PROMPT | ||
prompt_file = %PROMPT_TEXT_FILE | ||
batch_size = %GLOBAL_BATCH_SIZE | ||
max_images = %MAX_GENERATE | ||
text_enc_infer = @text_enc/sample_script.setup_text_enc() | ||
noise_conditioning_aug = %NOISE_COND_AUG | ||
resume_from = %RESUME_FROM |
Oops, something went wrong.