SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2007-2008 Vojtech Franc 00008 * Written (W) 2007-2009 Soeren Sonnenburg 00009 * Copyright (C) 2007-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #ifndef _WDSVMOCAS_H___ 00013 #define _WDSVMOCAS_H___ 00014 00015 #include "lib/common.h" 00016 #include "classifier/Classifier.h" 00017 #include "classifier/svm/SVMOcas.h" 00018 #include "features/StringFeatures.h" 00019 #include "features/Labels.h" 00020 00021 namespace shogun 00022 { 00023 template <class ST> class CStringFeatures; 00024 00026 class CWDSVMOcas : public CClassifier 00027 { 00028 public: 00030 CWDSVMOcas(void); 00031 00036 CWDSVMOcas(E_SVM_TYPE type); 00037 00046 CWDSVMOcas( 00047 float64_t C, int32_t d, int32_t from_d, 00048 CStringFeatures<uint8_t>* traindat, CLabels* trainlab); 00049 virtual ~CWDSVMOcas(); 00050 00055 virtual inline EClassifierType get_classifier_type() { return CT_WDSVMOCAS; } 00056 00065 virtual bool train(CFeatures* data=NULL); 00066 00073 inline void set_C(float64_t c_neg, float64_t c_pos) { C1=c_neg; C2=c_pos; } 00074 00079 inline float64_t get_C1() { return C1; } 00080 00085 inline float64_t get_C2() { return C2; } 00086 00091 inline void set_epsilon(float64_t eps) { epsilon=eps; } 00092 00097 inline float64_t get_epsilon() { return epsilon; } 00098 00103 inline void set_features(CStringFeatures<uint8_t>* feat) 00104 { 00105 SG_UNREF(features); 00106 SG_REF(feat); 00107 features=feat; 00108 } 00109 00114 inline CStringFeatures<uint8_t>* get_features() 00115 { 00116 SG_REF(features); 00117 return features; 00118 } 00119 00124 inline void set_bias_enabled(bool enable_bias) { use_bias=enable_bias; } 00125 00130 inline bool get_bias_enabled() { return use_bias; } 00131 00136 inline void set_bufsize(int32_t sz) { bufsize=sz; } 00137 00142 inline int32_t get_bufsize() { return bufsize; } 00143 00149 inline void set_degree(int32_t d, int32_t from_d) 00150 { 00151 degree=d; 00152 from_degree=from_d; 00153 } 00154 00159 inline int32_t get_degree() { return degree; } 00160 00165 CLabels* classify(); 00166 00172 virtual CLabels* classify(CFeatures* data); 00173 00179 inline virtual float64_t classify_example(int32_t num) 00180 { 00181 ASSERT(features); 00182 if (!wd_weights) 00183 set_wd_weights(); 00184 00185 int32_t len=0; 00186 float64_t sum=0; 00187 bool free_vec; 00188 uint8_t* vec=features->get_feature_vector(num, len, free_vec); 00189 //SG_INFO("len %d, string_length %d\n", len, string_length); 00190 ASSERT(len==string_length); 00191 00192 for (int32_t j=0; j<string_length; j++) 00193 { 00194 int32_t offs=w_dim_single_char*j; 00195 int32_t val=0; 00196 for (int32_t k=0; (j+k<string_length) && (k<degree); k++) 00197 { 00198 val=val*alphabet_size + vec[j+k]; 00199 sum+=wd_weights[k] * w[offs+val]; 00200 offs+=w_offsets[k]; 00201 } 00202 } 00203 features->free_feature_vector(vec, len, free_vec); 00204 return sum/normalization_const; 00205 } 00206 00208 inline void set_normalization_const() 00209 { 00210 ASSERT(features); 00211 normalization_const=0; 00212 for (int32_t i=0; i<degree; i++) 00213 normalization_const+=(string_length-i)*wd_weights[i]*wd_weights[i]; 00214 00215 normalization_const=CMath::sqrt(normalization_const); 00216 SG_DEBUG("normalization_const:%f\n", normalization_const); 00217 } 00218 00223 inline float64_t get_normalization_const() { return normalization_const; } 00224 00225 00226 protected: 00231 int32_t set_wd_weights(); 00232 00241 static void compute_W( 00242 float64_t *sq_norm_W, float64_t *dp_WoldW, float64_t *alpha, 00243 uint32_t nSel, void* ptr ); 00244 00251 static float64_t update_W(float64_t t, void* ptr ); 00252 00258 static void* add_new_cut_helper(void* ptr); 00259 00268 static int add_new_cut( 00269 float64_t *new_col_H, uint32_t *new_cut, uint32_t cut_length, 00270 uint32_t nSel, void* ptr ); 00271 00277 static void* compute_output_helper(void* ptr); 00278 00284 static int compute_output( float64_t *output, void* ptr ); 00285 00292 static int sort( float64_t* vals, float64_t* data, uint32_t size); 00293 00295 static inline void print(ocas_return_value_T value) 00296 { 00297 return; 00298 } 00299 00300 00302 inline virtual const char* get_name() const { return "WDSVMOcas"; } 00303 00304 protected: 00306 CStringFeatures<uint8_t>* features; 00308 bool use_bias; 00310 int32_t bufsize; 00312 float64_t C1; 00314 float64_t C2; 00316 float64_t epsilon; 00318 E_SVM_TYPE method; 00319 00321 int32_t degree; 00323 int32_t from_degree; 00325 float32_t* wd_weights; 00327 int32_t num_vec; 00329 int32_t string_length; 00331 int32_t alphabet_size; 00332 00334 float64_t normalization_const; 00335 00337 float64_t bias; 00339 float64_t old_bias; 00341 int32_t* w_offsets; 00343 int32_t w_dim; 00345 int32_t w_dim_single_char; 00347 float32_t* w; 00349 float32_t* old_w; 00351 float64_t* lab; 00352 00354 float32_t** cuts; 00356 float64_t* cp_bias; 00357 }; 00358 } 00359 #endif