SVM.h

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #ifndef _SVM_H___
00012 #define _SVM_H___
00013 
00014 #include "lib/common.h"
00015 #include "features/Features.h"
00016 #include "kernel/Kernel.h"
00017 #include "classifier/KernelMachine.h"
00018 
00019 namespace shogun
00020 {
00021 
00022 class CMKL;
00023 
00046 class CSVM : public CKernelMachine
00047 {
00048     public:
00052         CSVM(int32_t num_sv=0);
00053 
00061         CSVM(float64_t C, CKernel* k, CLabels* lab);
00062         virtual ~CSVM();
00063 
00066         void set_defaults(int32_t num_sv=0);
00067 
00068 
00074         virtual std::vector<float64_t> get_linear_term();
00075 
00076 
00082         virtual void set_linear_term(std::vector<float64_t> lin);
00083 
00084 
00088         bool load(FILE* svm_file);
00089 
00093         bool save(FILE* svm_file);
00094 
00099         inline void set_nu(float64_t nue) { nu=nue; }
00100 
00109         inline void set_C(float64_t c1, float64_t c2) { C1=c1; C2=c2; }
00110 
00115         inline void set_epsilon(float64_t eps) { epsilon=eps; }
00116 
00121         inline void set_tube_epsilon(float64_t eps) { tube_epsilon=eps; }
00122 
00127         inline void set_qpsize(int32_t qps) { qpsize=qps; }
00128 
00133         inline float64_t get_epsilon() { return epsilon; }
00134 
00139         inline float64_t get_nu() { return nu; }
00140 
00145         inline float64_t get_C1() { return C1; }
00146 
00151         inline float64_t get_C2() { return C2; }
00152 
00157         inline int32_t get_qpsize() { return qpsize; }
00158 
00163         inline void set_shrinking_enabled(bool enable)
00164         {
00165             use_shrinking=enable;
00166         }
00167 
00172         inline bool get_shrinking_enabled()
00173         {
00174             return use_shrinking;
00175         }
00176 
00181         float64_t compute_svm_dual_objective();
00182 
00187         float64_t compute_svm_primal_objective();
00188 
00193         inline void set_objective(float64_t v)
00194         {
00195             objective=v;
00196         }
00197 
00202         inline float64_t get_objective()
00203         {
00204             return objective;
00205         }
00206 
00214         void set_callback_function(CMKL* m, bool (*cb)
00215                 (CMKL* mkl, const float64_t* sumw, const float64_t suma));
00216 
00218         inline virtual const char* get_name() const { return "SVM"; }
00219 
00220 #ifdef HAVE_BOOST_SERIALIZATION
00221         friend class ::boost::serialization::access;
00222         // When the class Archive corresponds to an output archive, the
00223         // & operator is defined similar to <<.  Likewise, when the class Archive
00224         // is a type of input archive the & operator is defined similar to >>.
00225         template<class Archive>
00226             void serialize(Archive & ar, const unsigned int archive_version)
00227             {
00228 
00229                 SG_DEBUG("archiving CSVM\n");
00230 
00231                 ar & ::boost::serialization::base_object<CKernelMachine>(*this);
00232 
00233                 ar & linear_term;
00234 
00235                 ar & svm_loaded;
00236 
00237                 ar & epsilon;
00238                 ar & tube_epsilon;
00239 
00240                 ar & nu;
00241                 ar & C1;
00242                 ar & C2;
00243 
00244                 ar & objective;
00245 
00246                 ar & qpsize;
00247                 ar & use_shrinking;
00248 
00249                 //TODO serialize mkl object
00250                 //CMKL* mkl;
00251 
00252                 SG_DEBUG("done with CSVM\n");
00253             }
00254 
00255     public:
00256         virtual void toFile(std::string filename) const
00257         {
00258 
00259             std::ofstream os(filename.c_str(), std::ios::binary);
00260             ::boost::archive::binary_oarchive oa(os);
00261 
00262             oa << *this;
00263 
00264         }
00265 
00266         virtual void fromFile(std::string filename)
00267         {
00268 
00269             std::ifstream is(filename.c_str(), std::ios::binary);
00270             ::boost::archive::binary_iarchive ia(is);
00271 
00272             ia >> *this;
00273 
00274         }
00275 
00276 #endif //HAVE_BOOST_SERIALIZATION
00277 
00278     protected:
00279 
00285         virtual float64_t* get_linear_term_array();
00286 
00288         std::vector<float64_t> linear_term;
00289 
00291         bool svm_loaded;
00293         float64_t epsilon;
00295         float64_t tube_epsilon;
00297         float64_t nu;
00299         float64_t C1;
00301         float64_t C2;
00303         float64_t objective;
00305         int32_t qpsize;
00307         bool use_shrinking;
00308 
00311         bool (*callback) (CMKL* mkl, const float64_t* sumw, const float64_t suma);
00314         CMKL* mkl;
00315 };
00316 }
00317 #endif

SHOGUN Machine Learning Toolbox - Documentation