Skip to content

Commit

Permalink
move platt-transform expcetion to python
Browse files Browse the repository at this point in the history
  • Loading branch information
jiong-zhang committed Dec 6, 2023
1 parent d259cbe commit 3fdbdd9
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 18 deletions.
23 changes: 18 additions & 5 deletions pecos/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2049,12 +2049,12 @@ def link_calibrator_methods(self):
"""
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f32,
None,
c_uint32,
[c_uint64, POINTER(c_float), POINTER(c_float), POINTER(c_double)],
)
corelib.fillprototype(
self.clib_float32.c_fit_platt_transform_f64,
None,
c_uint32,
[c_uint64, POINTER(c_double), POINTER(c_double), POINTER(c_double)],
)

Expand All @@ -2080,14 +2080,14 @@ def fit_platt_transform(self, logits, tgt_prob):
AB = np.array([0, 0], dtype=np.float64)

if tgt_prob.dtype == np.float32:
clib.clib_float32.c_fit_platt_transform_f32(
return_code = clib.clib_float32.c_fit_platt_transform_f32(
len(logits),
logits.ctypes.data_as(POINTER(c_float)),
tgt_prob.ctypes.data_as(POINTER(c_float)),
AB.ctypes.data_as(POINTER(c_double)),
)
elif tgt_prob.dtype == np.float64:
clib.clib_float32.c_fit_platt_transform_f64(
return_code = clib.clib_float32.c_fit_platt_transform_f64(
len(logits),
logits.ctypes.data_as(POINTER(c_double)),
tgt_prob.ctypes.data_as(POINTER(c_double)),
Expand All @@ -2096,7 +2096,20 @@ def fit_platt_transform(self, logits, tgt_prob):
else:
raise ValueError(f"Unsupported dtype: {tgt_prob.dtype}")

return AB[0], AB[1]
PLATT_RETURN_CODE = {
"SUCCESS": 0,
"LINE_SEARCH_FAIL": 1,
"MAX_ITER_REACHED": 2,
}

if return_code == PLATT_RETURN_CODE["SUCCESS"]:
return AB[0], AB[1]
elif return_code == PLATT_RETURN_CODE["LINE_SEARCH_FAIL"]:
raise RuntimeError("fit_platt_transform: Line search fails")
elif return_code == PLATT_RETURN_CODE["MAX_ITER_REACHED"]:
raise RuntimeError("fit_platt_transform: Reaching maximal iterations")
else:
raise ValueError(f"Unknown return code {return_code}")


clib = corelib(os.path.join(os.path.dirname(os.path.abspath(pecos.__file__)), "core"), "libpecos")
4 changes: 2 additions & 2 deletions pecos/core/libpecos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -752,13 +752,13 @@ extern "C" {
// ==== C Interface of Score Calibrator ====

#define C_FIT_PLATT_TRANSFORM(SUFFIX, VAL_TYPE) \
void c_fit_platt_transform ## SUFFIX( \
uint32_t c_fit_platt_transform ## SUFFIX( \
size_t num_samples, \
const VAL_TYPE* logits, \
const VAL_TYPE* tgt_probs, \
double* AB \
) { \
pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
return pecos::fit_platt_transform(num_samples, logits, tgt_probs, AB[0], AB[1]); \
}
C_FIT_PLATT_TRANSFORM(_f32, float32_t)
C_FIT_PLATT_TRANSFORM(_f64, float64_t)
Expand Down
24 changes: 13 additions & 11 deletions pecos/core/utils/newton.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,14 @@ namespace pecos {
// https://github.com/cjlin1/libsvm/blob/master/svm.cpp

template <typename value_type>
static void fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
uint32_t fit_platt_transform(size_t num_samples, const value_type *logits, const value_type *tgt_probs, double& A, double& B) {
// define the return code
enum {
SUCCESS=0,
LINE_SEARCH_FAIL=1,
MAX_ITER_REACHED=2,
};

// hyper parameters
int max_iter = 100; // Maximal number of iterations
double min_step = 1e-10; // Minimal step taken in line search
Expand All @@ -292,14 +299,6 @@ namespace pecos {
A = 0.0; B = 1.0;
double fval = 0.0;

// check for out of bound in tgt_probs
for (size_t i = 0; i < num_samples; i++) {
if (tgt_probs[i] > 1.0 || tgt_probs[i] < 0) {
throw std::runtime_error("fit_platt_transform: target probability out of bound\n");
}
}


for (size_t i = 0; i < num_samples; i++) {
double fApB = logits[i] * A + B;
if (fApB >= 0) {
Expand Down Expand Up @@ -376,13 +375,16 @@ namespace pecos {
}

if (stepsize < min_step) {
throw std::runtime_error("fit_platt_transform: Line search fails\n");
printf("WARNING: fit_platt_transform: Line search fails\n");
return LINE_SEARCH_FAIL;
}
}

if (iter >= max_iter) {
throw std::runtime_error("fit_platt_transform: Reaching maximal iterations\n");
printf("WARNING: fit_platt_transform: Reaching maximal iterations\n");
return MAX_ITER_REACHED;
}
return SUCCESS;
}
} // namespace pecos
#endif

0 comments on commit 3fdbdd9

Please sign in to comment.