diff --git a/src/colvar_arithmeticpath.h b/src/colvar_arithmeticpath.h index 6a5cf1534..4da249977 100644 --- a/src/colvar_arithmeticpath.h +++ b/src/colvar_arithmeticpath.h @@ -23,6 +23,20 @@ T logsumexp(const vector& a, const vector& b) { return max_a + cvm::logn(sum); } +template +vector softmax(const vector& a) { + const auto max_a = *std::max_element(a.begin(), a.end()); + T sum = T(); + vector out(a.size()); + for (size_t i = 0; i < a.size(); ++i) { + sum += cvm::exp(a[i] - max_a); + } + for (size_t i = 0; i < a.size(); ++i) { + out[i] = cvm::exp(a[i] - max_a) / sum; + } + return out; +} + template T logsumexp(const vector& a) { const auto max_a = *std::max_element(a.begin(), a.end());