PerformanceMeasures.h
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef __PERFORMANCEMEASURES_H_
00012 #define __PERFORMANCEMEASURES_H_
00013
00014 #include "base/SGObject.h"
00015 #include "features/Labels.h"
00016 #include "lib/DynamicArray.h"
00017
00018 namespace shogun
00019 {
00020 class CLabels;
00021
00045 class CPerformanceMeasures : public CSGObject
00046 {
00047 public:
00049 CPerformanceMeasures();
00050
00056 CPerformanceMeasures(CLabels* true_labels, CLabels* output);
00057
00058 virtual ~CPerformanceMeasures();
00059
00065 void init(CLabels* true_labels, CLabels* output);
00066
00072 inline bool set_true_labels(CLabels* true_labels)
00073 {
00074 m_true_labels=true_labels;
00075 SG_REF(true_labels);
00076 return true;
00077 }
00078
00083 inline CLabels* get_true_labels() const { return m_true_labels; }
00084
00090 inline bool set_output(CLabels* output)
00091 {
00092 m_output=output;
00093 SG_REF(output);
00094 return true;
00095 }
00096
00101 inline CLabels* get_output() const { return m_output; }
00102
00107 inline int32_t get_num_labels() const { return m_num_labels; }
00108
00121 void get_ROC(float64_t** result, int32_t* num, int32_t* dim);
00122
00129 inline float64_t get_auROC()
00130 {
00131 if (m_auROC==CMath::ALMOST_NEG_INFTY) {
00132 float64_t** roc=(float64_t**) malloc(sizeof(float64_t**));
00133 compute_ROC(roc);
00134 free(*roc);
00135 free(roc);
00136 }
00137 return m_auROC;
00138 }
00139
00146 inline float64_t get_aoROC()
00147 {
00148 return 1.0-get_auROC();
00149 }
00150
00163 void get_PRC(float64_t** result, int32_t* num, int32_t* dim);
00164
00171 inline float64_t get_auPRC()
00172 {
00173 if (m_auPRC==CMath::ALMOST_NEG_INFTY) {
00174 float64_t** prc=(float64_t**) malloc(sizeof(float64_t**));
00175 compute_PRC(prc);
00176 free(*prc);
00177 free(prc);
00178 }
00179 return m_auPRC;
00180 }
00181
00188 inline float64_t get_aoPRC()
00189 {
00190 return 1-get_auPRC();
00191 }
00192
00205 void get_DET(float64_t** result, int32_t* num, int32_t* dim);
00206
00213 inline float64_t get_auDET()
00214 {
00215 if (m_auDET==CMath::ALMOST_NEG_INFTY) {
00216 float64_t** det=(float64_t**) malloc(sizeof(float64_t**));
00217 compute_DET(det);
00218 free(*det);
00219 free(det);
00220 }
00221 return m_auDET;
00222 }
00223
00230 inline float64_t get_aoDET()
00231 {
00232 return 1-get_auDET();
00233 }
00234
00246 void get_all_accuracy(float64_t** result, int32_t* num, int32_t* dim);
00247
00254 float64_t get_accuracy(float64_t threshold=0);
00255
00267 void get_all_error(float64_t** result, int32_t* num, int32_t* dim);
00268
00277 inline float64_t get_error(float64_t threshold=0)
00278 {
00279 return 1.0-get_accuracy(threshold);
00280 }
00281
00293 void get_all_fmeasure(float64_t** result, int32_t* num, int32_t* dim);
00294
00299 float64_t get_fmeasure(float64_t threshold=0);
00300
00328 void get_all_CC(float64_t** result, int32_t* num, int32_t* dim);
00329
00334 float64_t get_CC(float64_t threshold=0);
00335
00353 void get_all_WRAcc(float64_t** result, int32_t* num, int32_t* dim);
00354
00359 float64_t get_WRAcc(float64_t threshold=0);
00360
00378 void get_all_BAL(float64_t** result, int32_t* num, int32_t* dim);
00379
00384 float64_t get_BAL(float64_t threshold=0);
00385
00390 inline virtual const char* get_name() const { return "PerformanceMeasures"; }
00391
00392 protected:
00394 void init_nolabels();
00395
00404 float64_t trapezoid_area(float64_t x1, float64_t x2, float64_t y1, float64_t y2);
00405
00409 void create_sortedROC();
00410
00414 void compute_ROC(float64_t** result);
00415
00423 void compute_accuracy(
00424 float64_t** result, int32_t* num, int32_t* dim, bool do_error=false);
00425
00430 void compute_PRC(float64_t** result);
00431
00436 void compute_DET(float64_t** result);
00437
00448 void compute_confusion_matrix(
00449 float64_t threshold,
00450 int32_t* tp, int32_t* fp, int32_t* fn, int32_t* tn);
00451
00452 protected:
00454 CLabels* m_true_labels;
00456 CLabels* m_output;
00458 int32_t m_num_labels;
00459
00461 int32_t m_all_true;
00463 int32_t m_all_false;
00464
00467 int32_t* m_sortedROC;
00469 float64_t m_auROC;
00471 float64_t m_auPRC;
00473 float64_t m_auDET;
00474 };
00475 }
00476 #endif