LaRank.h
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049 #ifndef LARANK_H
00050 #define LARANK_H
00051
00052 #include <ctime>
00053 #include <vector>
00054 #include <algorithm>
00055 #include <sys/time.h>
00056 #include <ext/hash_map>
00057 #include <ext/hash_set>
00058
00059 #define STDEXT_NAMESPACE __gnu_cxx
00060 #define std_hash_map STDEXT_NAMESPACE::hash_map
00061 #define std_hash_set STDEXT_NAMESPACE::hash_set
00062
00063 #include "lib/io.h"
00064 #include "kernel/Kernel.h"
00065 #include "classifier/svm/MultiClassSVM.h"
00066
00067 namespace shogun
00068 {
00069 struct larank_kcache_s;
00070 typedef struct larank_kcache_s larank_kcache_t;
00071 struct larank_kcache_s
00072 {
00073 CKernel* func;
00074 larank_kcache_t *prevbuddy;
00075 larank_kcache_t *nextbuddy;
00076 int64_t maxsize;
00077 int64_t cursize;
00078 int32_t l;
00079 int32_t *i2r;
00080 int32_t *r2i;
00081 int32_t maxrowlen;
00082
00083 int32_t *rsize;
00084 float32_t *rdiag;
00085 float32_t **rdata;
00086 int32_t *rnext;
00087 int32_t *rprev;
00088 int32_t *qnext;
00089 int32_t *qprev;
00090 };
00091
00092
00093
00094
00095
00096 class LaRankOutput
00097 {
00098 public:
00099 LaRankOutput () : beta(NULL), g(NULL), kernel(NULL), l(0)
00100 {
00101 }
00102 virtual ~LaRankOutput ()
00103 {
00104 destroy();
00105 }
00106
00107
00108 void initialize (CKernel* kfunc, int64_t cache);
00109
00110
00111 void destroy ();
00112
00113
00114 float64_t computeScore (int32_t x_id);
00115
00116
00117 float64_t computeGradient (int32_t xi_id, int32_t yi, int32_t ythis);
00118
00119
00120 void update (int32_t x_id, float64_t lambda, float64_t gp);
00121
00122
00123
00124
00125 void set_kernel_buddy (larank_kcache_t * bud);
00126
00127
00128 int32_t cleanup ();
00129
00130
00131
00132
00133 inline larank_kcache_t *getKernel () const
00134 {
00135 return kernel;
00136 }
00137
00138 inline int32_t get_l () const
00139 {
00140 return l;
00141 }
00142
00143
00144 float64_t getW2 ();
00145
00146
00147 float64_t getKii (int32_t x_id);
00148
00149
00150 float64_t getBeta (int32_t x_id);
00151
00152
00153 inline float32_t* getBetas () const
00154 {
00155 return beta;
00156 }
00157
00158
00159 float64_t getGradient (int32_t x_id);
00160
00161
00162 bool isSupportVector (int32_t x_id) const;
00163
00164
00165 int32_t getSV (float32_t* &sv) const;
00166
00167 private:
00168
00169
00170 float32_t* beta;
00171 float32_t* g;
00172 larank_kcache_t *kernel;
00173 int32_t l;
00174 };
00175
00176
00177
00178
/** A training pattern: an example index together with its class label.
 *
 * A negative x_id marks the pattern as removed (see clear()/exists()).
 */
class LaRankPattern
{
	public:
		/** constructor
		 *
		 * @param x_index index of the example
		 * @param label class label of the example
		 */
		LaRankPattern (int32_t x_index, int32_t label)
			: x_id (x_index), y (label) {}

		/** default constructor
		 *
		 * x_id starts at 0, so a default-constructed pattern still reports
		 * exists() == true (unchanged behavior). y is zero-initialized so it
		 * is never read as an indeterminate value.
		 */
		LaRankPattern ()
			: x_id (0), y (0) {}

		/** @return true unless the pattern has been clear()ed (x_id < 0) */
		bool exists () const
		{
			return x_id >= 0;
		}

		/** mark the pattern as removed by setting x_id to -1 */
		void clear ()
		{
			x_id = -1;
		}

		/** example index; negative once the pattern is cleared */
		int32_t x_id;
		/** class label */
		int32_t y;
};
00200
00201
00202
00203
00204 class LaRankPatterns
00205 {
00206 public:
00207 LaRankPatterns () {}
00208 ~LaRankPatterns () {}
00209
00210 void insert (const LaRankPattern & pattern)
00211 {
00212 if (!isPattern (pattern.x_id))
00213 {
00214 if (freeidx.size ())
00215 {
00216 std_hash_set < uint32_t >::iterator it = freeidx.begin ();
00217 patterns[*it] = pattern;
00218 x_id2rank[pattern.x_id] = *it;
00219 freeidx.erase (it);
00220 }
00221 else
00222 {
00223 patterns.push_back (pattern);
00224 x_id2rank[pattern.x_id] = patterns.size () - 1;
00225 }
00226 }
00227 else
00228 {
00229 int32_t rank = getPatternRank (pattern.x_id);
00230 patterns[rank] = pattern;
00231 }
00232 }
00233
00234 void remove (uint32_t i)
00235 {
00236 x_id2rank[patterns[i].x_id] = 0;
00237 patterns[i].clear ();
00238 freeidx.insert (i);
00239 }
00240
00241 bool empty () const
00242 {
00243 return patterns.size () == freeidx.size ();
00244 }
00245
00246 uint32_t size () const
00247 {
00248 return patterns.size () - freeidx.size ();
00249 }
00250
00251 LaRankPattern & sample ()
00252 {
00253 ASSERT (!empty ());
00254 while (true)
00255 {
00256 uint32_t r = CMath::random(0, patterns.size ());
00257 if (patterns[r].exists ())
00258 return patterns[r];
00259 }
00260 return patterns[0];
00261 }
00262
00263 uint32_t getPatternRank (int32_t x_id)
00264 {
00265 return x_id2rank[x_id];
00266 }
00267
00268 bool isPattern (int32_t x_id)
00269 {
00270 return x_id2rank[x_id] != 0;
00271 }
00272
00273 LaRankPattern & getPattern (int32_t x_id)
00274 {
00275 uint32_t rank = x_id2rank[x_id];
00276 return patterns[rank];
00277 }
00278
00279 uint32_t maxcount () const
00280 {
00281 return patterns.size ();
00282 }
00283
00284 LaRankPattern & operator [] (uint32_t i)
00285 {
00286 return patterns[i];
00287 }
00288
00289 const LaRankPattern & operator [] (uint32_t i) const
00290 {
00291 return patterns[i];
00292 }
00293
00294 private:
00295 std_hash_set < uint32_t >freeidx;
00296 std::vector < LaRankPattern > patterns;
00297 std_hash_map < int32_t, uint32_t >x_id2rank;
00298 };
00299
00300
00301
00302
00303
00304
00305 class CLaRank: public CMultiClassSVM
00306 {
00307 public:
00308 CLaRank ();
00309
00316 CLaRank(float64_t C, CKernel* k, CLabels* lab);
00317
00318 virtual ~CLaRank ();
00319
00320 bool train(CFeatures* data);
00321
00322
00323
00324
00325 virtual int32_t add (int32_t x_id, int32_t yi);
00326
00327
00328 virtual int32_t predict (int32_t x_id);
00329
00330 virtual void destroy ();
00331
00332
00333 virtual float64_t computeGap ();
00334
00335
00336 virtual uint32_t getNumOutputs () const;
00337
00338
00339 int32_t getNSV ();
00340
00341
00342 float64_t computeW2 ();
00343
00344
00345 float64_t getDual ();
00346
00351 virtual inline EClassifierType get_classifier_type() { return CT_LARANK; }
00352
00354 inline virtual const char* get_name() const { return "LaRank"; }
00355
00356 void set_batch_mode(bool enable) { batch_mode=enable; };
00357 bool get_batch_mode() { return batch_mode; };
00358 void set_tau(float64_t t) { tau=t; };
00359 float64_t get_tau() { return tau; };
00360
00361
00362 private:
00363
00364
00365
00366
00367
00368 typedef std_hash_map < int32_t, LaRankOutput > outputhash_t;
00369
00370
00371 outputhash_t outputs;
00372 LaRankOutput *getOutput (int32_t index);
00373
00374
00375 LaRankPatterns patterns;
00376
00377
00378 int32_t nb_seen_examples;
00379 int32_t nb_removed;
00380
00381
00382 int32_t n_pro;
00383 int32_t n_rep;
00384 int32_t n_opt;
00385
00386
00387 float64_t w_pro;
00388 float64_t w_rep;
00389 float64_t w_opt;
00390
00391 int32_t y0;
00392 float64_t dual;
00393
00394 struct outputgradient_t
00395 {
00396 outputgradient_t (int32_t result_output, float64_t result_gradient)
00397 : output (result_output), gradient (result_gradient) {}
00398 outputgradient_t ()
00399 : output (0), gradient (0) {}
00400
00401 int32_t output;
00402 float64_t gradient;
00403
00404 bool operator < (const outputgradient_t & og) const
00405 {
00406 return gradient > og.gradient;
00407 }
00408 };
00409
00410
00411 enum process_type
00412 {
00413 processNew,
00414 processOld,
00415 processOptimize
00416 };
00417
00418 struct process_return_t
00419 {
00420 process_return_t (float64_t dual, int32_t yprediction)
00421 : dual_increase (dual), ypred (yprediction) {}
00422 process_return_t () {}
00423 float64_t dual_increase;
00424 int32_t ypred;
00425 };
00426
00427
00428 process_return_t process (const LaRankPattern & pattern, process_type ptype);
00429
00430
00431 float64_t reprocess ();
00432
00433
00434 float64_t optimize ();
00435
00436
00437 uint32_t cleanup ();
00438
00439 protected:
00440
00441 std_hash_set < int32_t >classes;
00442
00443 inline uint32_t class_count () const
00444 {
00445 return classes.size ();
00446 }
00447
00448 float64_t tau;
00449 int32_t nb_train;
00450 int64_t cache;
00451
00452 bool batch_mode;
00453
00454
00455 int32_t step;
00456 };
00457 }
00458 #endif