From 247a710db9c5a4963894cfbfda3e4321a973d054 Mon Sep 17 00:00:00 2001 From: HanatoK Date: Mon, 31 Jul 2023 10:28:32 -0500 Subject: [PATCH] Fix compilation and derivatives --- src/colvar_arithmeticpath.h | 33 ++++++++++++++------------------- src/colvarcomp_apath.cpp | 4 ++-- 2 files changed, 16 insertions(+), 21 deletions(-) diff --git a/src/colvar_arithmeticpath.h b/src/colvar_arithmeticpath.h index 92e18914c..2eb615140 100644 --- a/src/colvar_arithmeticpath.h +++ b/src/colvar_arithmeticpath.h @@ -8,6 +8,7 @@ #include #include #include +#include namespace ArithmeticPathCV { @@ -65,16 +66,17 @@ class ArithmeticPathBase { public: ArithmeticPathBase() {} virtual ~ArithmeticPathBase() {} - virtual void initialize(size_t p_num_elements, size_t p_total_frames, double p_lambda, const vector& p_element, const vector& p_weights); + void initialize(size_t p_num_elements, size_t p_total_frames, double p_lambda, const vector& p_element, const vector& p_weights); virtual void updateDistanceToReferenceFrames() = 0; - virtual void computeValue(); - virtual void computeDerivatives(); - virtual void compute(); - virtual void reComputeLambda(const vector& rmsd_between_refs); + void computeValue(); + // can only be called after computeValue() + template ::value, bool>::type = true> + void computeDerivatives(); + void reComputeLambda(const vector& rmsd_between_refs); // for element-wise derivatives - virtual void computeDerivatives(const size_t frame_index); + void computeElementWiseDerivatives(const size_t frame_index); // must be called before vector computeDerivatives(const size_t frame_index) - virtual void updateSoftMaxOut(); + void updateSoftMaxOut(); protected: scalar_type lambda; vector weights; @@ -141,12 +143,7 @@ void ArithmeticPathBase::computeValue() { } template -void ArithmeticPathBase::compute() { - computeValue(); - computeDerivatives(); -} - -template +template ::value, bool>::type> void ArithmeticPathBase::computeDerivatives() { for (size_t j_elem = 0; j_elem < num_elements; ++j_elem) { for (size_t i_frame = 0; i_frame < frame_element_distances.size(); ++i_frame) { @@ -181,15 +178,13 @@ void ArithmeticPathBase::updateSoftMaxOut( // frame-wise derivatives for frames using optimal rotation template -void ArithmeticPathBase::computeDerivatives(size_t frame_index) { - const auto b = static_cast(total_frames - 1) * s; +void ArithmeticPathBase::computeElementWiseDerivatives(size_t frame_index) { + const auto tmp = (static_cast(frame_index) - static_cast(total_frames - 1) * s) * normalization_factor; for (size_t j_elem = 0; j_elem < num_elements; ++j_elem) { const auto a_i = softmax_out[frame_index]; - dzdx[j_elem] = 2.0 * frame_element_distances[frame_index][j_elem] * a_i; + dzdx[j_elem] = 2.0 * weights[j_elem] * weights[j_elem] * frame_element_distances[frame_index][j_elem] * a_i; dsdx[j_elem] = -2.0 * weights[j_elem] * weights[j_elem] * lambda * - frame_element_distances[frame_index][j_elem] * a_i * - (static_cast(frame_index) - b) * - normalization_factor; + frame_element_distances[frame_index][j_elem] * a_i * tmp; } } } diff --git a/src/colvarcomp_apath.cpp b/src/colvarcomp_apath.cpp index 26b96f6c4..c2b259a0d 100644 --- a/src/colvarcomp_apath.cpp +++ b/src/colvarcomp_apath.cpp @@ -57,7 +57,7 @@ void colvar::aspath::calc_value() { void colvar::aspath::calc_gradients() { updateSoftMaxOut(); for (size_t i_frame = 0; i_frame < reference_frames.size(); ++i_frame) { - computeDerivatives(i_frame); + computeElementWiseDerivatives(i_frame); for (size_t i_atom = 0; i_atom < atoms->size(); ++i_atom) { (*(comp_atoms[i_frame]))[i_atom].grad += dsdx[i_atom]; } @@ -109,7 +109,7 @@ void colvar::azpath::calc_value() { void colvar::azpath::calc_gradients() { updateSoftMaxOut(); for (size_t i_frame = 0; i_frame < reference_frames.size(); ++i_frame) { - computeDerivatives(i_frame); + computeElementWiseDerivatives(i_frame); for (size_t i_atom = 0; i_atom < atoms->size(); ++i_atom) { (*(comp_atoms[i_frame]))[i_atom].grad += dzdx[i_atom]; }