Skip to content

Commit

Permalink
Fix compilation and derivatives
Browse files Browse the repository at this point in the history
  • Loading branch information
HanatoK committed Jul 31, 2023
1 parent 735f0c8 commit 247a710
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 21 deletions.
33 changes: 14 additions & 19 deletions src/colvar_arithmeticpath.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <limits>
#include <string>
#include <algorithm>
#include <type_traits>

namespace ArithmeticPathCV {

Expand Down Expand Up @@ -65,16 +66,17 @@ class ArithmeticPathBase {
public:
ArithmeticPathBase() {}
virtual ~ArithmeticPathBase() {}
virtual void initialize(size_t p_num_elements, size_t p_total_frames, double p_lambda, const vector<element_type>& p_element, const vector<double>& p_weights);
void initialize(size_t p_num_elements, size_t p_total_frames, double p_lambda, const vector<element_type>& p_element, const vector<double>& p_weights);
virtual void updateDistanceToReferenceFrames() = 0;
virtual void computeValue();
virtual void computeDerivatives();
virtual void compute();
virtual void reComputeLambda(const vector<scalar_type>& rmsd_between_refs);
void computeValue();
// can only be called after computeValue()
template <typename U = element_type, typename std::enable_if<std::is_convertible<U, scalar_type>::value, bool>::type = true>
void computeDerivatives();
void reComputeLambda(const vector<scalar_type>& rmsd_between_refs);
// for element-wise derivatives
virtual void computeDerivatives(const size_t frame_index);
void computeElementWiseDerivatives(const size_t frame_index);
// must be called before vector<element_type> computeDerivatives(const size_t frame_index)
virtual void updateSoftMaxOut();
void updateSoftMaxOut();
protected:
scalar_type lambda;
vector<scalar_type> weights;
Expand Down Expand Up @@ -141,12 +143,7 @@ void ArithmeticPathBase<element_type, scalar_type, path_type>::computeValue() {
}

template <typename element_type, typename scalar_type, path_sz path_type>
void ArithmeticPathBase<element_type, scalar_type, path_type>::compute() {
computeValue();
computeDerivatives();
}

template <typename element_type, typename scalar_type, path_sz path_type>
template <typename U, typename std::enable_if<std::is_convertible<U, scalar_type>::value, bool>::type>
void ArithmeticPathBase<element_type, scalar_type, path_type>::computeDerivatives() {
for (size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
for (size_t i_frame = 0; i_frame < frame_element_distances.size(); ++i_frame) {
Expand Down Expand Up @@ -181,15 +178,13 @@ void ArithmeticPathBase<element_type, scalar_type, path_type>::updateSoftMaxOut(

// frame-wise derivatives for frames using optimal rotation
template <typename element_type, typename scalar_type, path_sz path_type>
void ArithmeticPathBase<element_type, scalar_type, path_type>::computeDerivatives(size_t frame_index) {
const auto b = static_cast<scalar_type>(total_frames - 1) * s;
void ArithmeticPathBase<element_type, scalar_type, path_type>::computeElementWiseDerivatives(size_t frame_index) {
const auto tmp = (static_cast<scalar_type>(frame_index) - static_cast<scalar_type>(total_frames - 1) * s) * normalization_factor;
for (size_t j_elem = 0; j_elem < num_elements; ++j_elem) {
const auto a_i = softmax_out[frame_index];
dzdx[j_elem] = 2.0 * frame_element_distances[frame_index][j_elem] * a_i;
dzdx[j_elem] = 2.0 * weights[j_elem] * weights[j_elem] * frame_element_distances[frame_index][j_elem] * a_i;
dsdx[j_elem] = -2.0 * weights[j_elem] * weights[j_elem] * lambda *
frame_element_distances[frame_index][j_elem] * a_i *
(static_cast<scalar_type>(frame_index) - b) *
normalization_factor;
frame_element_distances[frame_index][j_elem] * a_i * tmp;
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/colvarcomp_apath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void colvar::aspath::calc_value() {
void colvar::aspath::calc_gradients() {
updateSoftMaxOut();
for (size_t i_frame = 0; i_frame < reference_frames.size(); ++i_frame) {
computeDerivatives(i_frame);
computeElementWiseDerivatives(i_frame);
for (size_t i_atom = 0; i_atom < atoms->size(); ++i_atom) {
(*(comp_atoms[i_frame]))[i_atom].grad += dsdx[i_atom];
}
Expand Down Expand Up @@ -109,7 +109,7 @@ void colvar::azpath::calc_value() {
void colvar::azpath::calc_gradients() {
updateSoftMaxOut();
for (size_t i_frame = 0; i_frame < reference_frames.size(); ++i_frame) {
computeDerivatives(i_frame);
computeElementWiseDerivatives(i_frame);
for (size_t i_atom = 0; i_atom < atoms->size(); ++i_atom) {
(*(comp_atoms[i_frame]))[i_atom].grad += dzdx[i_atom];
}
Expand Down

0 comments on commit 247a710

Please sign in to comment.