Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Robust stop criteria for fitting Platt transform #277

Merged
merged 1 commit into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2050,23 +2050,46 @@ def link_calibrator_methods(self):
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f32,
c_uint32,
[c_uint64, POINTER(c_float), POINTER(c_float), POINTER(c_double)],
[
c_uint64,
POINTER(c_float),
POINTER(c_float),
POINTER(c_double),
c_uint64, # max_iter
c_double, # eps
],
)
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f64,
c_uint32,
[c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)],
[
c_uint64,
POINTER(c_double),
POINTER(c_double),
POINTER(c_double),
c_uint64, # max_iter
c_double, # eps
],
)

def fit_platt_transform(self, logits, targets, clip_tgt_prob=True):
def fit_platt_transform(
self,
logits,
targets,
max_iter=100,
eps=1e-5,
clip_tgt_prob=True,
):
"""Python to C/C++ interface for Platt transform fit.

Ref: https://www.csie.ntu.edu.tw/~cjlin/papers/plattprob.pdf

Args:
logits (ndarray): 1-d array of logits with length N.
targets (ndarray): 1-d array of target probability scores within [0, 1] with length N.
clip_tgt_prob (bool): whether to clip the target probability to
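max_iter (int, optional): maximum number of iterations to run when fitting. Defaults to 100
eps (float, optional): stopping tolerance; fitting stops once both gradient components, or both Newton step components, are below eps in absolute value. Defaults to 1e-5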
clip_tgt_prob (bool, optional): whether to clip the target probability to
[1/(prior0 + 2), 1 - 1/(prior1 + 2)]
where prior1 = sum(targets), prior0 = N - prior1
Returns:
Expand Down Expand Up @@ -2097,13 +2120,17 @@ def fit_platt_transform(self, logits, targets, clip_tgt_prob=True):
logits.ctypes.data_as(POINTER(c_float)),
tgt_prob.ctypes.data_as(POINTER(c_float)),
AB.ctypes.data_as(POINTER(c_double)),
max_iter,
eps,
)
elif tgt_prob.dtype == np.float64:
return_code = clib.clib_float32.c_fit_platt_transform_f64(
len(logits),
logits.ctypes.data_as(POINTER(c_double)),
tgt_prob.ctypes.data_as(POINTER(c_double)),
AB.ctypes.data_as(POINTER(c_double)),
max_iter,
eps,
)
else:
raise ValueError(f"Unsupported dtype: {tgt_prob.dtype}")
Expand Down
6 changes: 4 additions & 2 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,9 +756,11 @@ extern "C" {
size_t num_samples, \
const VAL_TYPE* logits, \
const VAL_TYPE* tgt_probs, \
double* AB \
double* AB, \
size_t max_iter, \
double eps \
) { \
return pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
return pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1], max_iter, eps); \
}
C_FIT_PLATT_TRANSFORM(_f32, float32_t)
C_FIT_PLATT_TRANSFORM(_f64, float64_t)
Expand Down
22 changes: 12 additions & 10 deletions pecos/core/utils/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ namespace pecos {
// https://github.com/cjlin1/libsvm/blob/master/svm.cpp

template <typename value_type>
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B, size_t max_iter, double eps) {
// define the return code
enum {
SUCCESS=0,
Expand All @@ -288,10 +288,8 @@ namespace pecos {
};

// hyper parameters
int max_iter = 100; // Maximal number of iterations
double min_step = 1e-10; // Minimal step taken in line search
double sigma = 1e-12; // For numerically strict PD of Hessian
double eps = 1e-5;

// calculate prior of B
double prior1 = 0;
Expand All @@ -300,7 +298,6 @@ namespace pecos {
}
double prior0 = double(num_samples) - prior1;


// Initial Point and Initial Fun Value
A = 0.0; B = log((prior0 + 1.0) / (prior1 + 1.0));
double fval = 0.0;
Expand All @@ -313,7 +310,7 @@ namespace pecos {
fval += (tgt_probs[i] - 1) * fApB + log(1 + exp(fApB));
}
}
int iter;
size_t iter = 0;
for (iter = 0; iter < max_iter; iter++) {
// Update Gradient and Hessian (use H' = H + sigma I)
double h11 = sigma;
Expand Down Expand Up @@ -342,16 +339,22 @@ namespace pecos {
g2 += d1;
}

// Stopping Criteria
if (fabs(g1) < eps && fabs(g2) < eps)
break;

// Finding Newton direction: -inv(H') * g
double det = h11 * h22 - h21 * h21;
double dA = -(h22 * g1 - h21 * g2) / det;
double dB = -(-h21 * g1 + h11 * g2) / det;
double gd = g1 * dA + g2 * dB;

// Stopping Criteria
if (fabs(g1) < eps && fabs(g2) < eps) {
break;
}
// additional stop criteria to handle the case when det is large
if (fabs(dA) < eps && fabs(dB) < eps) {
break;
}

// Line Search
double stepsize = 1.0;

Expand All @@ -370,8 +373,7 @@ namespace pecos {
}
}
// Check sufficient decrease
if (newf < fval + 0.0001 * stepsize * gd)
{
if (newf < fval + 0.0001 * stepsize * gd) {
A = newA;
B = newB;
fval = newf;
Expand Down
Loading