fix mixed precision for replicate / pure DDP #591

Open
wants to merge 4 commits into
base: main
5 changes: 3 additions & 2 deletions torchtitan/config_manager.py
@@ -332,8 +332,9 @@ def __init__(self):
default="bfloat16",
choices=["bfloat16", "float32"],
help="""
- torch dtype to use for parameters when applying mixed precision via FSDP.
- This feature only takes effect when data_parallel_degree > 1
+ torch dtype to use for parameters when applying mixed precision.
+ When data_parallel_shard_degree > 1, this changes FSDP's `param_dtype`.
+ When data_parallel_shard_degree == 1, this enables AMP autocast.
""",
)
self.parser.add_argument(
12 changes: 9 additions & 3 deletions train.py
@@ -14,7 +14,7 @@

from torchtitan import utils
from torchtitan.checkpoint import CheckpointManager, TrainState
- from torchtitan.config_manager import JobConfig
+ from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.datasets import build_hf_data_loader, build_tokenizer
from torchtitan.float8 import Float8Handler
from torchtitan.logging import init_logger, logger
@@ -288,8 +288,14 @@ def loss_fn(pred, labels):
else:
# Non-PP forward / backward
with train_context():
- pred = model(input_ids)
- loss = loss_fn(pred, labels)
+ with contextlib.nullcontext() if parallel_dims.dp_shard_enabled else torch.autocast(
+     "cuda",
+     dtype=TORCH_DTYPE_MAP[
+         job_config.training.mixed_precision_param
+     ],
+ ):
Comment on lines +291 to +296
Contributor

IMO there are two ways:

  1. To simplify it here, we can have a separate get_forward_only_context() which encapsulates the logic.
  2. Modify get_train_context to receive a boolean variable is_backward so that we can obtain two different contexts, one for forward, the other for backward. Pro: only one method to obtain a context. Con: context managers shared by forward and backward need to be entered and exited twice.

cc: @awgu wdyt?
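
A minimal sketch of what option 2 could look like, using only `contextlib` so that it runs without a GPU. The `shared_context` and `forward_only_context` helpers and the `use_forward_only` argument are hypothetical stand-ins, not torchtitan code: in the PR, the forward-only manager would be `torch.autocast`, and `use_forward_only` would correspond to `not parallel_dims.dp_shard_enabled`.

```python
import contextlib

events = []  # records enter/exit order, for illustration only


@contextlib.contextmanager
def shared_context():
    # Hypothetical stand-in for managers needed by both forward and backward.
    events.append("enter shared")
    try:
        yield
    finally:
        events.append("exit shared")


@contextlib.contextmanager
def forward_only_context():
    # Hypothetical stand-in for torch.autocast, which should wrap forward only.
    events.append("enter forward-only")
    try:
        yield
    finally:
        events.append("exit forward-only")


def get_train_context(use_forward_only: bool):
    """Return a context factory that takes an is_backward flag."""

    @contextlib.contextmanager
    def context(is_backward: bool = False):
        with contextlib.ExitStack() as stack:
            # Shared managers are entered for both passes.
            stack.enter_context(shared_context())
            # Forward-only managers are skipped for the backward pass.
            if use_forward_only and not is_backward:
                stack.enter_context(forward_only_context())
            yield

    return context


train_context = get_train_context(use_forward_only=True)

with train_context(is_backward=False):
    events.append("forward")
with train_context(is_backward=True):
    events.append("backward")
```

The recorded order makes the stated con visible: the shared manager is entered and exited once per pass, so twice per training step.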

Contributor

Option 2 sounds good to me.

Author

Will maybe_enable_compiled_autograd work correctly if it exits and re-enters between the forward and backward passes?

Contributor

Yes, that should work. It's a context manager that enables the flag, so it should be okay to wrap fwd and bwd separately with maybe_enable_compiled_autograd.
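
To illustrate why wrapping the two passes separately is expected to be safe, here is a small sketch of a flag-toggling context manager. It assumes the manager only sets a flag on entry and restores it on exit; `Config` and `enable_flag` are hypothetical names, and this is not the actual maybe_enable_compiled_autograd implementation.

```python
import contextlib


class Config:
    # Hypothetical global flag, standing in for the compiled autograd setting.
    compiled_autograd = False


@contextlib.contextmanager
def enable_flag(enabled: bool):
    # Set the flag on entry and restore the previous value on exit.
    previous = Config.compiled_autograd
    Config.compiled_autograd = enabled
    try:
        yield
    finally:
        Config.compiled_autograd = previous


observed = []

with enable_flag(True):
    observed.append(Config.compiled_autograd)  # during forward

observed.append(Config.compiled_autograd)  # between forward and backward

with enable_flag(True):
    observed.append(Config.compiled_autograd)  # during backward
```

Under that assumption, each pass sees the flag enabled, and it is simply off in the gap between them.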

+     pred = model(input_ids)
+     loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)
# need to free before bwd to avoid peaking memory
del pred