Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix mixed precision for replicate / pure DDP #591

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

152334H
Copy link

@152334H 152334H commented Sep 29, 2024

Hi. I noticed the following:

  1. the keyword autocast does not exist in the repository
  2. MixedPrecisionConfig is only used in the fully_shard codepath
  3. the duration of a dummy 1000 step run is a lot longer with DDP than with FSDP

All of the above indicates that, when dp_shard_enabled is false, training runs with pure fp32, regardless of the mixed precision config.

This pull request changes the code to use torch.autocast in the training forward pass, specifically only when dp_shard_enabled is false, to the dtype of mixed_precision_param.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Meta Open Source bot. label Sep 29, 2024
@152334H
Copy link
Author

152334H commented Sep 29, 2024

If mixed_precision_reduce ever has arguments beyond float32, it may be worth adding an appropriate compression hook after replicate as well, e.g.

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index fc26703..bdbcbb2 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -344,5 +344,7 @@ def apply_ddp(
             torch._dynamo.config.optimize_ddp = "ddp_optimizer"

     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
+    from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import bf16_compress_hook
+    model.register_comm_hook(dp_mesh.get_group(), bf16_compress_hook)

     logger.info("Applied DDP to the model")

@awgu
Copy link
Contributor

awgu commented Sep 29, 2024

might be nice to fold this as part of the train_context

@152334H
Copy link
Author

152334H commented Sep 29, 2024

@awgu it can't be folded into the train context unless you're certain autocasting the backward will lead to no problems. According to torch documentation,

autocast should wrap only the forward pass(es) of your network, including the loss computation(s). Backward passes under autocast are not recommended. Backward ops run in the same type that autocast used for corresponding forward ops.

Personally, in other projects, I have experienced issues where torch.compiling a network with autocast over the backwards inexplicably causes NaNs to appear. I don't know whether that applies in this case but will take any overruling if people are certain it will work without problems.

@awgu awgu requested a review from tianyu-l September 30, 2024 12:00
@awgu
Copy link
Contributor

awgu commented Sep 30, 2024

cc: @tianyu-l on thoughts on how to handle this
perhaps separate forward and backward contexts

@tianyu-l
Copy link
Contributor

Hmm I didn't have much context on this.

@fegin How is DDP supposed to handle mixed precision? Is AMP autocasting suggested?

@awgu
Copy link
Contributor

awgu commented Oct 1, 2024

I think Rohan added a tentative mixed precision API for DDP, but it never made it to public feature. I think using AMP is probably the way to go.

@fegin
Copy link
Contributor

fegin commented Oct 1, 2024

autocast is the right way for DDP.

@tianyu-l tianyu-l requested a review from fegin October 3, 2024 18:23
This feature only takes effect when data_parallel_degree > 1
torch dtype to use for parameters when applying mixed precision.
When data_parallel_degree > 1, this changes FSDP's `param_dtype`.
When data_parallel_degree == 1, this enables AMP autocast.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm a bit confused by the description here. Other than DDP, does it work with TP, PP, CP, etc.?
I wonder if there is a document/tutorial on this.
cc: @fegin

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the implementation of mixed precision only needs to change based on whether fully_shard (which has its own internal mixed prec mechanism) or replicate (which internally wraps a module with DDP) is used.

therefore, the qn of what happens (and what should happen) w/ TP/PP/DP is solely dependent on the sharded data parallel degree.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens if neither FSDP or DDP is used? btw DDP is not composable with TP/PP today as far as I know.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If FSDP is not used, AMP should be the default answer for all other parallelisms. But we don't actually test AMP + TP.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you mean the behavior of dtensor TP sharded operations under AMP in general pytorch is not tested? or that specifically it is not tested in torchtitan

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DTensor TP+ AMP

Comment on lines +291 to +296
with contextlib.nullcontext() if parallel_dims.dp_shard_enabled else torch.autocast(
"cuda",
dtype=TORCH_DTYPE_MAP[
job_config.training.mixed_precision_param
],
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO there are two ways:

  1. To simplify it here, we can have a separate get_foward_only_context() which encapsulates the logic.
  2. Modify get_train_context to receive a boolean variable is_backward so that we can obtain two different contexts, one for forward, the other for backward. Pro: only one method to obtain context. Con: need to enter and exit twice for context managers shared by forward and backward.
    cc: @awgu wdyt?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 sounds good to me

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will maybe_enable_compiled_autograd work correctly when it exits and re-enters in-between a fwd-bck?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ye, that should work. It's a context manager to enable the flag. So it should be okay to separately wrap fwd and bwd with maybe_enable_compiled_autograd.

torchtitan/config_manager.py Outdated Show resolved Hide resolved
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Meta Open Source bot.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants