From ebe1bcf1e43d1754fe6399db1104d955bf7813ec Mon Sep 17 00:00:00 2001 From: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> Date: Thu, 25 Jan 2024 15:07:41 -0500 Subject: [PATCH] Add extra call to `clear_memory()` to fix OOM when sending critic data (#91) Signed-off-by: Olivier Delalleau <507137+odelalleau@users.noreply.github.com> --- CHANGELOG.md | 1 + nemo_aligner/algorithms/ppo.py | 1 + 2 files changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b2db7301b..a53fee7cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) ### New features and optimizations - Added public-facing official Dockerfile for NeMo-Aligner +- Memory optimization in PPO that helps avoid OOM in the actor when sending training data to the critic ### Breaking changes diff --git a/nemo_aligner/algorithms/ppo.py b/nemo_aligner/algorithms/ppo.py index 6071f5427..87b81ebe0 100644 --- a/nemo_aligner/algorithms/ppo.py +++ b/nemo_aligner/algorithms/ppo.py @@ -385,6 +385,7 @@ def fit(self): timing_metrics["rollout_time"] = self.timer.get("rollout_time") # send critic train + clear_memory() self.rm_critic.train(ppo_rollout_data) # logging