diff --git a/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env index 1da816f1b..e48b76dcf 100644 --- a/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env +++ b/rosetta/rosetta/projects/pax/xla_flags/grok-proxy.env @@ -1,13 +1,16 @@ set -x +ALL_REDUCE_THRESHOLD_BYTES=3221225472 +ALL_GATHER_THRESHOLD_BYTES=3221225472 +REDUCE_SCATTER_THRESHOLD_BYTES=402653184 export XLA_FLAGS="\ --xla_gpu_enable_latency_hiding_scheduler=true \ --xla_allow_excess_precision \ --xla_gpu_enable_highest_priority_async_stream=true \ --xla_gpu_enable_triton_softmax_fusion=false \ - --xla_gpu_all_reduce_combine_threshold_bytes=3221225472 \ + --xla_gpu_all_reduce_combine_threshold_bytes=${ALL_REDUCE_THRESHOLD_BYTES} \ --xla_gpu_graph_level=0 \ - --xla_gpu_all_gather_combine_threshold_bytes=3221225472 \ - --xla_gpu_reduce_scatter_combine_threshold_bytes=402653184 \ + --xla_gpu_all_gather_combine_threshold_bytes=${ALL_GATHER_THRESHOLD_BYTES} \ + --xla_gpu_reduce_scatter_combine_threshold_bytes=${REDUCE_SCATTER_THRESHOLD_BYTES} \ --xla_gpu_enable_pipelined_all_gather=true \ --xla_gpu_enable_pipelined_reduce_scatter=true \ --xla_gpu_enable_pipelined_all_reduce=true \ @@ -18,4 +21,5 @@ export XLA_FLAGS="\ --xla_gpu_enable_custom_fusions=true " export XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 +unset ALL_REDUCE_THRESHOLD_BYTES ALL_GATHER_THRESHOLD_BYTES REDUCE_SCATTER_THRESHOLD_BYTES set +x