BALL
1.4.1
|
00001 /* classificationValidation.h 00002 * 00003 * Copyright (C) 2009 Marcel Schumann 00004 * 00005 * This file is part of QuEasy -- A Toolbox for Automated QSAR Model 00006 * Construction and Validation. 00007 * QuEasy is free software; you can redistribute it and/or modify 00008 * it under the terms of the GNU General Public License as published by 00009 * the Free Software Foundation; either version 3 of the License, or (at 00010 * your option) any later version. 00011 * 00012 * QuEasy is distributed in the hope that it will be useful, but 00013 * WITHOUT ANY WARRANTY; without even the implied warranty of 00014 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00015 * General Public License for more details. 00016 * 00017 * You should have received a copy of the GNU General Public License 00018 * along with this program; if not, see <http://www.gnu.org/licenses/>. 00019 */ 00020 00021 // -*- Mode: C++; tab-width: 2; -*- 00022 // vi: set ts=2: 00023 // 00024 // 00025 00026 #ifndef CLASVALIDATION 00027 #define CLASVALIDATION 00028 00029 #ifndef QSARDATA 00030 #include <BALL/QSAR/QSARData.h> 00031 #endif 00032 00033 #ifndef VALIDATION 00034 #include <BALL/QSAR/validation.h> 00035 #endif 00036 00037 #include <gsl/gsl_randist.h> 00038 #include <gsl/gsl_cdf.h> 00039 #include <iterator> 00040 00041 00042 namespace BALL 00043 { 00044 namespace QSAR 00045 { 00046 class ClassificationModel; 00048 class BALL_EXPORT ClassificationValidation : public Validation 00049 { 00050 public: 00056 ClassificationValidation(ClassificationModel* m); 00058 00059 00063 void crossValidation(int k, bool restore=1); 00064 00065 double getCVRes(); 00066 00067 double getFitRes(); 00068 00069 void setCVRes(double d); 00070 00071 void testInputData(bool transform=0); 00072 00074 const BALL::Matrix<double>* getConfusionMatrix(); 00075 00077 const BALL::Vector<double>* getClassResults(); 00078 00081 void bootstrap(int k, bool restore=1); 00082 00086 const BALL::Matrix<double>& yRandomizationTest(int runs, int k); 00087 00089 double getAccuracyCV(); 00090 00092 double getAccuracyInputTest(); 00093 00094 void selectStat(int s); 00095 00096 void saveToFile(string filename) const; 00097 00098 void saveToFile(string filename, const double& quality_input_test, const double& predictive_quality) const; 00099 00100 void readFromFile(string filename); 00101 00103 00104 00105 private: 00110 void testAllSubstances(bool transform); 00111 00113 void calculateAverageSensitivity(); 00114 00116 void calculateWeightedSensitivity(); 00117 00119 void calculateOverallAccuracy(); 00120 00122 void calculateAverageMCC(); 00123 00125 void calculateOverallMCC(); 00126 00128 void calculateTDR(); 00130 00131 00136 BALL::Matrix<double> confusion_matrix_; 00137 00139 Vector<double> class_results_; 00140 00141 double quality_; 00142 00143 double quality_input_test_; 00144 00145 double quality_cv_; 00146 00148 ClassificationModel* clas_model; 00149 00150 void (ClassificationValidation::* qualCalculation)(); 00151 00152 00154 00155 }; 00156 } 00157 } 00158 00159 00160 00161 #endif // REGVALIDATION