The function make_graphed_callables always overrides the module's forward function with a version wrapped in fp8_autocast. One issue with this approach is that once we wrap a given module, we cannot use it without FP8, even when we are not using the graphed version. A quick pseudocode example:
module = te.Linear(1028, 1028)
# The _order argument lets us still use module as a non-graphed callable.
b = te.make_graphed_callables(module, sample_args, fp8_enabled=True, _order=[1, -1])

# At this point b(arg) executes our module as a graph in FP8, which is fine and
# expected. We can also call module(arg) to use the non-graphed module in FP8.
# However, the following still executes module in FP8:
with te.fp8_autocast(enabled=False):
    module(arg)  # Still executed in FP8!
This is because of this line, which always executes the module with the FP8 status we gave at graph creation.
I'd like to see an option to still use the non-FP8, non-graphed module even after creating a CUDA graph with FP8.
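Until such an option exists, one possible workaround is to save a reference to the original forward before graphing and restore it when the non-graphed, non-FP8 path is needed. The sketch below is plain Python, not Transformer Engine code: Module, make_graphed, and the FP8 stand-in are all hypothetical and only illustrate the save-and-restore pattern for a monkey-patched forward.

```python
# Hypothetical sketch (not the Transformer Engine API): demonstrates
# saving and restoring a forward method that a graphing helper overrides.

class Module:
    def forward(self, x):
        return x + 1  # stands in for the original, non-FP8 computation

    def __call__(self, x):
        return self.forward(x)

def make_graphed(module):
    # Mimics make_graphed_callables: overrides forward with a wrapper
    # that always applies FP8-like behavior, regardless of later context.
    orig_forward = module.forward
    def graphed_forward(x):
        return orig_forward(x) * 2  # stands in for fp8_autocast(enabled=True)
    module.forward = graphed_forward
    return module, orig_forward

m = Module()
m, orig = make_graphed(m)
assert m(1) == 4   # wrapped path: (1 + 1) * 2

m.forward = orig   # restore the original, non-wrapped behavior
assert m(1) == 2   # original path again
```

This is fragile in practice (it bypasses whatever state the wrapper manages), which is why a supported option in make_graphed_callables itself would be preferable.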