Skip to content

Commit

Permalink
Do not backpropagate through layers without gradient sources (#2423)
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun authored Feb 12, 2024
1 parent 3492fbd commit 32a761b
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 42 deletions.
2 changes: 2 additions & 0 deletions include/lbann/layers/layer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,8 @@ class Layer
template <hydrogen::Device Device>
friend class kfac_block_gru;

friend class model;

public:
/** @name Lifecycle */
///@{
Expand Down
18 changes: 14 additions & 4 deletions include/lbann/models/model.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,11 @@ class model
*/
void setup_weights();

/** @brief Tests whether a layer would be needed to compute through during
* backpropagation
*/
bool is_layer_needed_for_backprop(const Layer* l) const;

///@}
/** @name Subgraph parallelism implementation */
///@{
Expand Down Expand Up @@ -561,6 +566,14 @@ class model
/** @brief Current callbacks to process. */
std::vector<std::shared_ptr<callback_base>> m_callbacks;

/** @brief A set of layers needed for backpropagation.
* @details This set is populated by model::forward_prop and controls
* which layers will be computed during backpropagation. If the
* `NO_BACKPROP_DISABLE` option is enabled, this set will not change the
* behavior of backpropagation.
*/
std::unordered_set<const Layer*> m_needed_for_backprop;

/** @brief Is the model setup
* @details Flag to indicate if the setup function has been called
*/
Expand Down Expand Up @@ -793,10 +806,7 @@ model::set_current_mini_batch_size(uint64_t mini_batch_size) noexcept
return;
}

inline bool model::is_amp_enabled() const noexcept
{
return m_amp_enabled;
}
inline bool model::is_amp_enabled() const noexcept { return m_amp_enabled; }

inline EvalType model::get_amp_scale_factor() const noexcept
{
Expand Down
4 changes: 2 additions & 2 deletions src/callbacks/check_gradients.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ void check_gradients::do_check_gradients(model& m) const
m.get_objective_function()->differentiate();
m.get_objective_function()->compute_weight_regularization();

// Compute analytical gradients through model
m.backward_prop(false, /*skip_callbacks=*/true);
// Compute all analytical gradients through model
m.backward_prop(/*compute_weight_grads_only=*/false, /*skip_callbacks=*/true);

// Choose finite difference step
// Note: Consider a central difference scheme:
Expand Down
123 changes: 87 additions & 36 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,34 @@ void model::serialize_to_onnx(onnx::ModelProto& mp)
}
#endif // LBANN_HAS_ONNX

bool model::is_layer_needed_for_backprop(const Layer* l) const
{
// First, check parents. If one of the parent layers require gradients,
// this layer does too.
for (int i = 0; i < l->get_num_parents(); ++i) {
if (m_needed_for_backprop.find(&l->get_parent_layer(i)) !=
m_needed_for_backprop.end()) {
return true;
}
}

// Second, check the layer itself. If frozen, backprop is not necessary.
if (l->is_frozen()) {
return false;
}

// Otherwise, check weight optimizers. If one of the associated optimizers
// is not nullptr, then backprop will be necessary.
for (size_t i = 0; i < l->num_weights(); ++i) {
if (l->get_weights(i).get_optimizer() != nullptr) {
return true;
}
}

// Not needed for backprop
return false;
}

// =============================================
// Model specification
// =============================================
Expand Down Expand Up @@ -929,24 +957,20 @@ void model::setup_subcommunicators(const std::vector<El::Grid*>& fngrids)
for (El::Int node = 0; node < num_layers; ++node) {
Layer* const layer = layers[node];
std::string const& layer_type = layer->get_type();
if ((layer_type == "slice" ||
layer_type == "split" ||
layer_type == "concatenate" ||
layer_type == "sum") &&
if ((layer_type == "slice" || layer_type == "split" ||
layer_type == "concatenate" || layer_type == "sum") &&
layer->subgraph_parallelism_execution()) {
if (subCommunicatorsSubgrids.find(one_index) !=
subCommunicatorsSubgrids.end()) {
layer->reset_inter_subgrid_vc_comm(
subCommunicatorsSubgrids[one_index]);
layer->reset_inter_subgrid_vc_comm(subCommunicatorsSubgrids[one_index]);
}
else {
subCommunicatorsSubgrids[one_index] = std::make_shared<El::mpi::Comm>();
const auto& childs = layer->get_child_layers();

int indexSubgrid = -1;
for (int child = 0; child < layer->get_num_children(); ++child) {
if (fngrids.at(childs[child]->get_grid_tag())->InGrid())
{
if (fngrids.at(childs[child]->get_grid_tag())->InGrid()) {
indexSubgrid = child;
}
}
Expand All @@ -955,9 +979,17 @@ void model::setup_subcommunicators(const std::vector<El::Grid*>& fngrids)
const int layer_tag = layer->get_grid_tag();

if (child_tag < 0)
LBANN_ERROR("child_tag=", child_tag, " (child=", childs[indexSubgrid]->get_name(), ")");
LBANN_ERROR("child_tag=",
child_tag,
" (child=",
childs[indexSubgrid]->get_name(),
")");
if (layer_tag < 0)
LBANN_ERROR("layer_tag=", layer_tag, " (layer=", layer->get_name(), ")");
LBANN_ERROR("layer_tag=",
layer_tag,
" (layer=",
layer->get_name(),
")");

const int posInSubGrid = fngrids[child_tag]->VCRank();
const int posInGrid = fngrids[layer_tag]->ViewingRank();
Expand All @@ -966,15 +998,13 @@ void model::setup_subcommunicators(const std::vector<El::Grid*>& fngrids)
posInGrid,
*subCommunicatorsSubgrids[one_index]);

layer->reset_inter_subgrid_vc_comm(
subCommunicatorsSubgrids[one_index]);
layer->reset_inter_subgrid_vc_comm(subCommunicatorsSubgrids[one_index]);
}
}

if (layer_type == "cross_grid_sum" ||
layer_type == "cross_grid_sum_slice") {
layer->reset_inter_subgrid_vc_comm(
subCommunicatorsSubgrids[one_index]);
layer->reset_inter_subgrid_vc_comm(subCommunicatorsSubgrids[one_index]);
}
}
}
Expand Down Expand Up @@ -1321,14 +1351,15 @@ void model::add_split_layers(std::unordered_set<std::string>& layer_names)
l.get_data_layout(),
l.get_device_allocation());

#define PROTO_DEVICE_LAYOUT(T_datatype, T_layout, T_device) \
if (args == args_tuple(std::type_index(typeid(T_datatype)), T_layout, T_device)) { \
split.reset(new split_layer<T_datatype, T_layout, T_device>(m_comm)); \
}
#define PROTO_DEVICE_LAYOUT(T_datatype, T_layout, T_device) \
if (args == \
args_tuple(std::type_index(typeid(T_datatype)), T_layout, T_device)) { \
split.reset(new split_layer<T_datatype, T_layout, T_device>(m_comm)); \
}

#define PROTO_DEVICE(T_datatype, T_device) \
PROTO_DEVICE_LAYOUT(T_datatype, data_layout::DATA_PARALLEL, T_device); \
PROTO_DEVICE_LAYOUT(T_datatype, data_layout::MODEL_PARALLEL, T_device);
#define PROTO_DEVICE(T_datatype, T_device) \
PROTO_DEVICE_LAYOUT(T_datatype, data_layout::DATA_PARALLEL, T_device); \
PROTO_DEVICE_LAYOUT(T_datatype, data_layout::MODEL_PARALLEL, T_device);

#include "lbann/macros/instantiate_device.hpp"
#undef PROTO_DEVICE_LAYOUT
Expand Down Expand Up @@ -1449,7 +1480,7 @@ void model::remove_layer(std::string const& removable_layer_name)
auto& parent =
const_cast<Layer&>(l.get_parent_layer(0)); // assuming only one parent
auto& child =
const_cast<Layer&>(l.get_child_layer(0)); // assuming only one child
const_cast<Layer&>(l.get_child_layer(0)); // assuming only one child

// Setup relationship between parent layer and child layer
child.replace_parent_layer(l.get_parent_layer_pointer(0),
Expand Down Expand Up @@ -1501,7 +1532,7 @@ void model::replace_layer(OwningLayerPtr&& new_layer,
auto& parent =
const_cast<Layer&>(l.get_parent_layer(0)); // assuming only one parent
auto& child =
const_cast<Layer&>(l.get_child_layer(0)); // assuming only one child
const_cast<Layer&>(l.get_child_layer(0)); // assuming only one child

// Setup relationship between the new layer and child of old layer (which
// becomes child of new layer)
Expand Down Expand Up @@ -1582,6 +1613,9 @@ void model::forward_prop(execution_mode mode, bool skip_callbacks)
// Clear activations in reference counter
m_activation_refcnt.clear();

// Clear layers that will be required in backpropagation
m_needed_for_backprop.clear();

for (El::Int i = 0; i < get_num_layers(); ++i) {
auto& l = get_layer(i);

Expand All @@ -1605,6 +1639,9 @@ void model::forward_prop(execution_mode mode, bool skip_callbacks)
if (!skip_callbacks)
do_layer_forward_prop_end_cbs(mode, &l);
}

if (is_layer_needed_for_backprop(&l))
m_needed_for_backprop.insert(&l);
}
if (!skip_callbacks)
do_model_forward_prop_end_cbs(mode);
Expand All @@ -1627,8 +1664,19 @@ void model::backward_prop(bool compute_weight_grads_only, bool skip_callbacks)

// Perform backward prop step on current layer
auto& l = get_layer(i);
bool enable_layer = (!envvar_disable_layers ||
disabled_layers.find(&l) == disabled_layers.end());

// Check if layer should be skipped
bool enable_layer = true;
if (envvar_disable_layers) {
// Based on backpropagation requirements
if (disabled_layers.find(&l) != disabled_layers.end())
enable_layer = false;

// Based on gradient/optimizer requirements
if (compute_weight_grads_only && m_needed_for_backprop.size() > 0 &&
m_needed_for_backprop.find(&l) == m_needed_for_backprop.end())
enable_layer = false;
}

// Check if all children skip gradient backpropagation
if (enable_layer && envvar_disable_layers) {
Expand Down Expand Up @@ -1746,25 +1794,27 @@ void model::update_weights()
++m_amp_cur_skipped_steps;
// Keep scale factor to the smallest positive normalized value for
// floats. Even when EvalType is double, we may cast to float.
m_amp_scale_factor = std::max(
static_cast<EvalType>(std::numeric_limits<float>::min()),
m_amp_scale_factor * m_amp_backoff_factor);
m_amp_scale_factor =
std::max(static_cast<EvalType>(std::numeric_limits<float>::min()),
m_amp_scale_factor * m_amp_backoff_factor);
// Warn if we've been skipping too many steps.
// Check exact number to avoid printing repeatedly.
if (m_amp_cur_skipped_steps == 10) {
LBANN_WARNING(
"AMP skipped ten steps in a row, your model may have issues with AMP");
LBANN_WARNING("AMP skipped ten steps in a row, your model may have "
"issues with AMP");
}
} else {
}
else {
if (m_amp_cur_steps + 1 == m_amp_growth_interval) {
m_amp_cur_steps = 0;
m_amp_cur_skipped_steps = 0;
// Prevent scale factor from overflowing to inf when cast to
// float.
m_amp_scale_factor = std::min(
static_cast<EvalType>(std::numeric_limits<float>::max()),
m_amp_scale_factor * m_amp_growth_factor);
} else {
m_amp_scale_factor =
std::min(static_cast<EvalType>(std::numeric_limits<float>::max()),
m_amp_scale_factor * m_amp_growth_factor);
}
else {
++m_amp_cur_steps;
}
}
Expand Down Expand Up @@ -1809,7 +1859,8 @@ void model::reconcile_weight_values()
void model::enable_amp(EvalType init_scale_factor,
EvalType growth_factor,
EvalType backoff_factor,
size_t growth_interval) {
size_t growth_interval)
{
m_amp_enabled = true;
m_amp_scale_factor = init_scale_factor;
m_amp_growth_factor = growth_factor;
Expand Down

0 comments on commit 32a761b

Please sign in to comment.