// SHOGUN v0.9.0
// -*- C++ -*-
// Main functions of the LaRank algorithm for solving Multiclass SVM
// Copyright (C) 2008- Antoine Bordes
// Shogun specific adjustments (w) 2009 Soeren Sonnenburg

// This library is free software; you can redistribute it and/or
// modify it under the terms of the GNU Lesser General Public
// License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program; if not, write to the Free Software
// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA
//
/***********************************************************************
 *
 *  LUSH Lisp Universal Shell
 *    Copyright (C) 2002 Leon Bottou, Yann Le Cun, AT&T Corp, NECI.
 *  Includes parts of TL3:
 *    Copyright (C) 1987-1999 Leon Bottou and Neuristique.
 *  Includes selected parts of SN3.2:
 *    Copyright (C) 1991-2001 AT&T Corp.
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111, USA
 *
 ***********************************************************************/

/***********************************************************************
 * $Id: kcache.h,v 1.8 2007/01/25 22:42:09 leonb Exp $
 **********************************************************************/

#ifndef LARANK_H
#define LARANK_H

#include <ctime>
#include <vector>
#include <algorithm>
#include <sys/time.h>
#include <ext/hash_map>
#include <ext/hash_set>

/* GNU libstdc++ extension containers (pre-C++11 hash map/set).
 * NOTE(review): __gnu_cxx is deprecated on modern toolchains; kept here
 * because the rest of the file (and likely the .cpp) depends on it. */
#define STDEXT_NAMESPACE __gnu_cxx
#define std_hash_map STDEXT_NAMESPACE::hash_map
#define std_hash_set STDEXT_NAMESPACE::hash_set

#include "lib/io.h"
#include "kernel/Kernel.h"
#include "classifier/svm/MultiClassSVM.h"

namespace shogun
{
#ifndef DOXYGEN_SHOULD_SKIP_THIS
	struct larank_kcache_s;
	typedef struct larank_kcache_s larank_kcache_t;

	/* Row cache of kernel values (taken from Bottou's kcache).
	 * Rows are indexed through the i2r/r2i permutation; rnext/rprev and
	 * qnext/qprev are intrusive doubly-linked lists — presumably used for
	 * recency-based row eviction (TODO confirm against kcache.cpp). */
	struct larank_kcache_s
	{
		CKernel* func;           /* kernel used to compute values on a cache miss */
		larank_kcache_t *prevbuddy;  /* ring of "buddy" caches that may share rows */
		larank_kcache_t *nextbuddy;
		int64_t maxsize;         /* cache capacity (bytes) */
		int64_t cursize;         /* currently used size (bytes) */
		int32_t l;               /* number of examples currently tracked */
		int32_t *i2r;            /* example index -> row position */
		int32_t *r2i;            /* row position  -> example index */
		int32_t maxrowlen;
		/* Rows */
		int32_t *rsize;          /* per-row: number of cached entries (-1 presumably = unused) */
		float32_t *rdiag;        /* per-row: cached diagonal entry K(i,i) */
		float32_t **rdata;       /* per-row: cached kernel values */
		int32_t *rnext;          /* row linked list */
		int32_t *rprev;
		int32_t *qnext;          /* secondary linked list */
		int32_t *qprev;
	};

	/*
	** OUTPUT: one per class of the training set; keeps track of support
	** vectors and their beta coefficients
	*/
	class LaRankOutput
	{
		public:
			LaRankOutput () : beta(NULL), g(NULL), kernel(NULL), l(0)
			{
			}
			virtual ~LaRankOutput ()
			{
				destroy();
			}

			// Initializing an output class (basically creating a kernel cache for it)
			void initialize (CKernel* kfunc, int64_t cache);

			// Destroying an output class (basically destroying the kernel cache)
			void destroy ();

			// !Important! Computing the score of a given input vector for the actual output
			float64_t computeScore (int32_t x_id);

			// !Important! Computing the gradient of a given input vector for the actual output
			float64_t computeGradient (int32_t xi_id, int32_t yi, int32_t ythis);

			// Updating the solution in the actual output
			void update (int32_t x_id, float64_t lambda, float64_t gp);

			// Linking the cache of this output to the cache of an other "buddy" output
			// so that if a requested value is not found in this cache, you can
			// ask your buddy if it has it.
			void set_kernel_buddy (larank_kcache_t * bud);

			// Removing useless support vectors (for which beta=0)
			int32_t cleanup ();

			// --- Below are information or "get" functions --- //

			// Kernel cache owned by this output
			inline larank_kcache_t *getKernel () const
			{
				return kernel;
			}

			// Number of support vectors of this output
			inline int32_t get_l () const
			{
				return l;
			}

			// Squared norm of this output's weight vector
			float64_t getW2 ();

			// Diagonal kernel entry K(x_id, x_id)
			float64_t getKii (int32_t x_id);

			// Beta coefficient of example x_id (presumably 0 if not a support vector)
			float64_t getBeta (int32_t x_id);

			// Raw array of beta coefficients (owned by this object; do not free)
			inline float32_t* getBetas () const
			{
				return beta;
			}

			// Stored gradient of example x_id
			float64_t getGradient (int32_t x_id);

			// Whether x_id is currently a support vector of this output
			bool isSupportVector (int32_t x_id) const;

			// Fills sv with the support vector ids and returns their count
			int32_t getSV (float32_t* &sv) const;

		private:
			// the solution of LaRank relative to the actual class is stored in
			// these parameters
			float32_t* beta;          // Beta coefficients
			float32_t* g;             // Stored gradient derivatives
			larank_kcache_t *kernel;  // Cache for kernel values
			int32_t l;                // Number of support vectors
	};

	/*
00178 ** LARANKPATTERN: to keep track of the support patterns 00179 */ 00180 class LaRankPattern 00181 { 00182 public: 00183 LaRankPattern (int32_t x_index, int32_t label) 00184 : x_id (x_index), y (label) {} 00185 LaRankPattern () 00186 : x_id (0) {} 00187 00188 bool exists () const 00189 { 00190 return x_id >= 0; 00191 } 00192 00193 void clear () 00194 { 00195 x_id = -1; 00196 } 00197 00198 int32_t x_id; 00199 int32_t y; 00200 }; 00201 00202 /* 00203 ** LARANKPATTERNS: the collection of support patterns 00204 */ 00205 class LaRankPatterns 00206 { 00207 public: 00208 LaRankPatterns () {} 00209 ~LaRankPatterns () {} 00210 00211 void insert (const LaRankPattern & pattern) 00212 { 00213 if (!isPattern (pattern.x_id)) 00214 { 00215 if (freeidx.size ()) 00216 { 00217 std_hash_set < uint32_t >::iterator it = freeidx.begin (); 00218 patterns[*it] = pattern; 00219 x_id2rank[pattern.x_id] = *it; 00220 freeidx.erase (it); 00221 } 00222 else 00223 { 00224 patterns.push_back (pattern); 00225 x_id2rank[pattern.x_id] = patterns.size () - 1; 00226 } 00227 } 00228 else 00229 { 00230 int32_t rank = getPatternRank (pattern.x_id); 00231 patterns[rank] = pattern; 00232 } 00233 } 00234 00235 void remove (uint32_t i) 00236 { 00237 x_id2rank[patterns[i].x_id] = 0; 00238 patterns[i].clear (); 00239 freeidx.insert (i); 00240 } 00241 00242 bool empty () const 00243 { 00244 return patterns.size () == freeidx.size (); 00245 } 00246 00247 uint32_t size () const 00248 { 00249 return patterns.size () - freeidx.size (); 00250 } 00251 00252 LaRankPattern & sample () 00253 { 00254 ASSERT (!empty ()); 00255 while (true) 00256 { 00257 uint32_t r = CMath::random(0, patterns.size ()-1); 00258 if (patterns[r].exists ()) 00259 return patterns[r]; 00260 } 00261 return patterns[0]; 00262 } 00263 00264 uint32_t getPatternRank (int32_t x_id) 00265 { 00266 return x_id2rank[x_id]; 00267 } 00268 00269 bool isPattern (int32_t x_id) 00270 { 00271 return x_id2rank[x_id] != 0; 00272 } 00273 00274 LaRankPattern & 
getPattern (int32_t x_id) 00275 { 00276 uint32_t rank = x_id2rank[x_id]; 00277 return patterns[rank]; 00278 } 00279 00280 uint32_t maxcount () const 00281 { 00282 return patterns.size (); 00283 } 00284 00285 LaRankPattern & operator [] (uint32_t i) 00286 { 00287 return patterns[i]; 00288 } 00289 00290 const LaRankPattern & operator [] (uint32_t i) const 00291 { 00292 return patterns[i]; 00293 } 00294 00295 private: 00296 std_hash_set < uint32_t >freeidx; 00297 std::vector < LaRankPattern > patterns; 00298 std_hash_map < int32_t, uint32_t >x_id2rank; 00299 }; 00300 00301 00302 #endif // DOXYGEN_SHOULD_SKIP_THIS 00303 00304 00305 /* 00306 ** MACHINE: the main thing, which is trained. 00307 */ 00308 class CLaRank: public CMultiClassSVM 00309 { 00310 public: 00311 CLaRank (); 00312 00319 CLaRank(float64_t C, CKernel* k, CLabels* lab); 00320 00321 virtual ~CLaRank (); 00322 00323 bool train(CFeatures* data); 00324 00325 00326 // LEARNING FUNCTION: add new patterns and run optimization steps 00327 // selected with adaptative schedule 00328 virtual int32_t add (int32_t x_id, int32_t yi); 00329 00330 // PREDICTION FUNCTION: main function in la_rank_classify 00331 virtual int32_t predict (int32_t x_id); 00332 00333 virtual void destroy (); 00334 00335 // Compute Duality gap (costly but used in stopping criteria in batch mode) 00336 virtual float64_t computeGap (); 00337 00338 // Nuber of classes so far 00339 virtual uint32_t getNumOutputs () const; 00340 00341 // Number of Support Vectors 00342 int32_t getNSV (); 00343 00344 // Norm of the parameters vector 00345 float64_t computeW2 (); 00346 00347 // Compute Dual objective value 00348 float64_t getDual (); 00349 00354 virtual inline EClassifierType get_classifier_type() { return CT_LARANK; } 00355 00357 inline virtual const char* get_name() const { return "LaRank"; } 00358 00359 void set_batch_mode(bool enable) { batch_mode=enable; }; 00360 bool get_batch_mode() { return batch_mode; }; 00361 void set_tau(float64_t t) { 
tau=t; }; 00362 float64_t get_tau() { return tau; }; 00363 00364 00365 private: 00366 /* 00367 ** MAIN DARK OPTIMIZATION PROCESSES 00368 */ 00369 00370 // Hash Table used to store the different outputs 00371 typedef std_hash_map < int32_t, LaRankOutput > outputhash_t; // class index -> LaRankOutput 00372 00373 00374 outputhash_t outputs; 00375 LaRankOutput *getOutput (int32_t index); 00376 00377 // 00378 LaRankPatterns patterns; 00379 00380 // Parameters 00381 int32_t nb_seen_examples; 00382 int32_t nb_removed; 00383 00384 // Numbers of each operation performed so far 00385 int32_t n_pro; 00386 int32_t n_rep; 00387 int32_t n_opt; 00388 00389 // Running estimates for each operations 00390 float64_t w_pro; 00391 float64_t w_rep; 00392 float64_t w_opt; 00393 00394 int32_t y0; 00395 float64_t dual; 00396 00397 struct outputgradient_t 00398 { 00399 outputgradient_t (int32_t result_output, float64_t result_gradient) 00400 : output (result_output), gradient (result_gradient) {} 00401 outputgradient_t () 00402 : output (0), gradient (0) {} 00403 00404 int32_t output; 00405 float64_t gradient; 00406 00407 bool operator < (const outputgradient_t & og) const 00408 { 00409 return gradient > og.gradient; 00410 } 00411 }; 00412 00413 //3 types of operations in LaRank 00414 enum process_type 00415 { 00416 processNew, 00417 processOld, 00418 processOptimize 00419 }; 00420 00421 struct process_return_t 00422 { 00423 process_return_t (float64_t dual, int32_t yprediction) 00424 : dual_increase (dual), ypred (yprediction) {} 00425 process_return_t () {} 00426 float64_t dual_increase; 00427 int32_t ypred; 00428 }; 00429 00430 // IMPORTANT Main SMO optimization step 00431 process_return_t process (const LaRankPattern & pattern, process_type ptype); 00432 00433 // ProcessOld 00434 float64_t reprocess (); 00435 00436 // Optimize 00437 float64_t optimize (); 00438 00439 // remove patterns and return the number of patterns that were removed 00440 uint32_t cleanup (); 00441 00442 protected: 
00443 00444 std_hash_set < int32_t >classes; 00445 00446 inline uint32_t class_count () const 00447 { 00448 return classes.size (); 00449 } 00450 00451 float64_t tau; 00452 int32_t nb_train; 00453 int64_t cache; 00454 // whether to use online learning or batch training 00455 bool batch_mode; 00456 00457 //progess output 00458 int32_t step; 00459 }; 00460 } 00461 #endif // LARANK_H