SVM.h
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011 #ifndef _SVM_H___
00012 #define _SVM_H___
00013
00014 #include "lib/common.h"
00015 #include "features/Features.h"
00016 #include "kernel/Kernel.h"
00017 #include "classifier/KernelMachine.h"
00018
00019 namespace shogun
00020 {
00021
00022 class CMKL;
00023
00046 class CSVM : public CKernelMachine
00047 {
00048 public:
00052 CSVM(int32_t num_sv=0);
00053
00061 CSVM(float64_t C, CKernel* k, CLabels* lab);
00062 virtual ~CSVM();
00063
00066 void set_defaults(int32_t num_sv=0);
00067
00068
00074 virtual std::vector<float64_t> get_linear_term();
00075
00076
00082 virtual void set_linear_term(std::vector<float64_t> lin);
00083
00084
00088 bool load(FILE* svm_file);
00089
00093 bool save(FILE* svm_file);
00094
00099 inline void set_nu(float64_t nue) { nu=nue; }
00100
00109 inline void set_C(float64_t c1, float64_t c2) { C1=c1; C2=c2; }
00110
00115 inline void set_epsilon(float64_t eps) { epsilon=eps; }
00116
00121 inline void set_tube_epsilon(float64_t eps) { tube_epsilon=eps; }
00122
00127 inline void set_qpsize(int32_t qps) { qpsize=qps; }
00128
00133 inline float64_t get_epsilon() { return epsilon; }
00134
00139 inline float64_t get_nu() { return nu; }
00140
00145 inline float64_t get_C1() { return C1; }
00146
00151 inline float64_t get_C2() { return C2; }
00152
00157 inline int32_t get_qpsize() { return qpsize; }
00158
00163 inline void set_shrinking_enabled(bool enable)
00164 {
00165 use_shrinking=enable;
00166 }
00167
00172 inline bool get_shrinking_enabled()
00173 {
00174 return use_shrinking;
00175 }
00176
00181 float64_t compute_svm_dual_objective();
00182
00187 float64_t compute_svm_primal_objective();
00188
00193 inline void set_objective(float64_t v)
00194 {
00195 objective=v;
00196 }
00197
00202 inline float64_t get_objective()
00203 {
00204 return objective;
00205 }
00206
00214 void set_callback_function(CMKL* m, bool (*cb)
00215 (CMKL* mkl, const float64_t* sumw, const float64_t suma));
00216
00218 inline virtual const char* get_name() const { return "SVM"; }
00219
00220 #ifdef HAVE_BOOST_SERIALIZATION
00221 friend class ::boost::serialization::access;
00222
00223
00224
00225 template<class Archive>
00226 void serialize(Archive & ar, const unsigned int archive_version)
00227 {
00228
00229 SG_DEBUG("archiving CSVM\n");
00230
00231 ar & ::boost::serialization::base_object<CKernelMachine>(*this);
00232
00233 ar & linear_term;
00234
00235 ar & svm_loaded;
00236
00237 ar & epsilon;
00238 ar & tube_epsilon;
00239
00240 ar & nu;
00241 ar & C1;
00242 ar & C2;
00243
00244 ar & objective;
00245
00246 ar & qpsize;
00247 ar & use_shrinking;
00248
00249
00250
00251
00252 SG_DEBUG("done with CSVM\n");
00253 }
00254
00255 public:
00256 virtual void toFile(std::string filename) const
00257 {
00258
00259 std::ofstream os(filename.c_str(), std::ios::binary);
00260 ::boost::archive::binary_oarchive oa(os);
00261
00262 oa << *this;
00263
00264 }
00265
00266 virtual void fromFile(std::string filename)
00267 {
00268
00269 std::ifstream is(filename.c_str(), std::ios::binary);
00270 ::boost::archive::binary_iarchive ia(is);
00271
00272 ia >> *this;
00273
00274 }
00275
00276 #endif //HAVE_BOOST_SERIALIZATION
00277
00278 protected:
00279
00285 virtual float64_t* get_linear_term_array();
00286
00288 std::vector<float64_t> linear_term;
00289
00291 bool svm_loaded;
00293 float64_t epsilon;
00295 float64_t tube_epsilon;
00297 float64_t nu;
00299 float64_t C1;
00301 float64_t C2;
00303 float64_t objective;
00305 int32_t qpsize;
00307 bool use_shrinking;
00308
00311 bool (*callback) (CMKL* mkl, const float64_t* sumw, const float64_t suma);
00314 CMKL* mkl;
00315 };
00316 }
00317 #endif