[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. */ 00033 /* */ 00034 /************************************************************************/ 00035 #ifndef RF_VISITORS_HXX 00036 #define RF_VISITORS_HXX 00037 00038 #ifdef HasHDF5 00039 # include "vigra/hdf5impex.hxx" 00040 #endif // HasHDF5 00041 00042 namespace vigra 00043 { 00044 00045 00046 00047 /** Base Class from which all Visitors derive 00048 */ 00049 class VisitorBase 00050 { 00051 public: 00052 bool active_; 00053 bool is_active() 00054 { 00055 return active_; 00056 } 00057 00058 bool has_value() 00059 { 00060 return false; 00061 } 00062 00063 VisitorBase() 00064 : active_(true) 00065 {} 00066 00067 void deactivate() 00068 { 00069 active_ = false; 00070 } 00071 void activate() 00072 { 00073 active_ = true; 00074 } 00075 00076 /** do something after the the Split has decided how to process the Region 00077 * (Stack entry) 00078 * 00079 * \param tree reference to the tree that is currently being learned 00080 * \param split reference to the split object 00081 * \param parent current stack entry which was used to decide the split 00082 * \param leftChild left stack entry that will be pushed 00083 * \param rightChild 00084 * right stack entry that will be pushed. 00085 * \param features features matrix 00086 * \param labels label matrix 00087 * \sa RF_Traits::StackEntry_t 00088 */ 00089 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 00090 void visit_after_split( Tree & tree, 00091 Split & split, 00092 Region & parent, 00093 Region & leftChild, 00094 Region & rightChild, 00095 Feature_t & features, 00096 Label_t & labels) 00097 {} 00098 00099 /** do something after each tree has been learned 00100 * 00101 * \param rf reference to the random forest object that called this 00102 * visitor 00103 * \param pr reference to the preprocessor that processed the input 00104 * \param sm reference to the sampler object 00105 * \param st reference to the first stack entry 00106 * \param index index of current tree 00107 */ 00108 template<class RF, class PR, class SM, class ST> 00109 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00110 {} 00111 00112 /** do something after all trees have been learned 00113 * 00114 * \param rf reference to the random forest object that called this 00115 * visitor 00116 * \param pr reference to the preprocessor that processed the input 00117 */ 00118 template<class RF, class PR> 00119 void visit_at_end(RF const & rf, PR const & pr) 00120 {} 00121 00122 /** do something before learning starts 00123 * 00124 * \param rf reference to the random forest object that called this 00125 * visitor 00126 * \param pr reference to the Processor class used. 00127 */ 00128 template<class RF, class PR> 00129 void visit_at_beginning(RF const & rf, PR const & pr) 00130 {} 00131 /** do some thing while traversing tree after it has been learned 00132 * (external nodes) 00133 * 00134 * \param tr reference to the tree object that called this visitor 00135 * \param index index in the topology_ array we currently are at 00136 * \param node_t type of node we have (will be e_.... - ) 00137 * \param weight Node weight of current node. 00138 * \sa NodeTags; 00139 * 00140 * you can create the node by using a switch on node_tag and using the 00141 * corresponding Node objects. Or - if you do not care about the type 00142 * use the Nodebase class. 00143 */ 00144 template<class TR, class IntT, class TopT,class Feat> 00145 void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features) 00146 {} 00147 00148 /** do something when visiting a internal node after it has been learned 00149 * 00150 * \sa visit_external_node 00151 */ 00152 template<class TR, class IntT, class TopT,class Feat> 00153 void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features) 00154 {} 00155 00156 /** return a double value. The value of the first 00157 * visitor encountered that has a return value is returned with the 00158 * RandomForest::learn() method - or -1.0 if no return value visitor 00159 * existed. This functionality basically only exists so that the 00160 * OOB - visitor can return the oob error rate like in the old version 00161 * of the random forest. 00162 */ 00163 double return_val() 00164 { 00165 return -1.0; 00166 } 00167 }; 00168 00169 namespace rf 00170 { 00171 00172 /** Last Visitor that should be called to stop the recursion. 00173 */ 00174 class StopVisiting: public VisitorBase 00175 { 00176 public: 00177 bool has_value() 00178 { 00179 return true; 00180 } 00181 double return_val() 00182 { 00183 return -1.0; 00184 } 00185 }; 00186 /** Container elements of the statically linked Visitor list. 00187 * 00188 * use the create_visitor() factory functions to create visitors up to size 10; 00189 * 00190 */ 00191 template <class Visitor, class Next = StopVisiting> 00192 class VisitorNode 00193 { 00194 public: 00195 00196 StopVisiting stop_; 00197 Next next_; 00198 Visitor & visitor_; 00199 VisitorNode(Visitor & visitor, Next & next) 00200 : 00201 next_(next), visitor_(visitor) 00202 {} 00203 00204 VisitorNode(Visitor & visitor) 00205 : 00206 next_(stop_), visitor_(visitor) 00207 {} 00208 00209 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 00210 void visit_after_split( Tree & tree, 00211 Split & split, 00212 Region & parent, 00213 Region & leftChild, 00214 Region & rightChild, 00215 Feature_t & features, 00216 Label_t & labels) 00217 { 00218 if(visitor_.is_active()) 00219 visitor_.visit_after_split(tree, split, 00220 parent, leftChild, rightChild, 00221 features, labels); 00222 next_.visit_after_split(tree, split, parent, leftChild, rightChild, 00223 features, labels); 00224 } 00225 00226 template<class RF, class PR, class SM, class ST> 00227 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00228 { 00229 if(visitor_.is_active()) 00230 visitor_.visit_after_tree(rf, pr, sm, st, index); 00231 next_.visit_after_tree(rf, pr, sm, st, index); 00232 } 00233 00234 template<class RF, class PR> 00235 void visit_at_beginning(RF & rf, PR & pr) 00236 { 00237 if(visitor_.is_active()) 00238 visitor_.visit_at_beginning(rf, pr); 00239 next_.visit_at_beginning(rf, pr); 00240 } 00241 template<class RF, class PR> 00242 void visit_at_end(RF & rf, PR & pr) 00243 { 00244 if(visitor_.is_active()) 00245 visitor_.visit_at_end(rf, pr); 00246 next_.visit_at_end(rf, pr); 00247 } 00248 00249 template<class TR, class IntT, class TopT,class Feat> 00250 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features) 00251 { 00252 if(visitor_.is_active()) 00253 visitor_.visit_external_node(tr, index, node_t,features); 00254 next_.visit_external_node(tr, index, node_t,features); 00255 } 00256 template<class TR, class IntT, class TopT,class Feat> 00257 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features) 00258 { 00259 if(visitor_.is_active()) 00260 visitor_.visit_internal_node(tr, index, node_t,features); 00261 next_.visit_internal_node(tr, index, node_t,features); 00262 } 00263 00264 double return_val() 00265 { 00266 if(visitor_.is_active() && visitor_.has_value()) 00267 return visitor_.return_val(); 00268 return next_.return_val(); 00269 } 00270 }; 00271 00272 } //namespace rf 00273 00274 ////////////////////////////////////////////////////////////////////////////// 00275 // Visitor Factory function up to 10 visitors // 00276 ////////////////////////////////////////////////////////////////////////////// 00277 template<class A> 00278 rf::VisitorNode<A> 00279 create_visitor(A & a) 00280 { 00281 typedef rf::VisitorNode<A> _0_t; 00282 _0_t _0(a); 00283 return _0; 00284 } 00285 00286 00287 template<class A, class B> 00288 rf::VisitorNode<A, rf::VisitorNode<B> > 00289 create_visitor(A & a, B & b) 00290 { 00291 typedef rf::VisitorNode<B> _1_t; 00292 _1_t _1(b); 00293 typedef rf::VisitorNode<A, _1_t> _0_t; 00294 _0_t _0(a, _1); 00295 return _0; 00296 } 00297 00298 00299 template<class A, class B, class C> 00300 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C> > > 00301 create_visitor(A & a, B & b, C & c) 00302 { 00303 typedef rf::VisitorNode<C> _2_t; 00304 _2_t _2(c); 00305 typedef rf::VisitorNode<B, _2_t> _1_t; 00306 _1_t _1(b, _2); 00307 typedef rf::VisitorNode<A, _1_t> _0_t; 00308 _0_t _0(a, _1); 00309 return _0; 00310 } 00311 00312 00313 template<class A, class B, class C, class D> 00314 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 00315 rf::VisitorNode<D> > > > 00316 create_visitor(A & a, B & b, C & c, D & d) 00317 { 00318 typedef rf::VisitorNode<D> _3_t; 00319 _3_t _3(d); 00320 typedef rf::VisitorNode<C, _3_t> _2_t; 00321 _2_t _2(c, _3); 00322 typedef rf::VisitorNode<B, _2_t> _1_t; 00323 _1_t _1(b, _2); 00324 typedef rf::VisitorNode<A, _1_t> _0_t; 00325 _0_t _0(a, _1); 00326 return _0; 00327 } 00328 00329 00330 template<class A, class B, class C, class D, class E> 00331 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 00332 rf::VisitorNode<D, rf::VisitorNode<E> > > > > 00333 create_visitor(A & a, B & b, C & c, 00334 D & d, E & e) 00335 { 00336 typedef rf::VisitorNode<E> _4_t; 00337 _4_t _4(e); 00338 typedef rf::VisitorNode<D, _4_t> _3_t; 00339 _3_t _3(d, _4); 00340 typedef rf::VisitorNode<C, _3_t> _2_t; 00341 _2_t _2(c, _3); 00342 typedef rf::VisitorNode<B, _2_t> _1_t; 00343 _1_t _1(b, _2); 00344 typedef rf::VisitorNode<A, _1_t> _0_t; 00345 _0_t _0(a, _1); 00346 return _0; 00347 } 00348 00349 00350 template<class A, class B, class C, class D, class E, 00351 class F> 00352 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 00353 rf::VisitorNode<D, rf::VisitorNode<E, rf::VisitorNode<F> > > > > > 00354 create_visitor(A & a, B & b, C & c, 00355 D & d, E & e, F & f) 00356 { 00357 typedef rf::VisitorNode<F> _5_t; 00358 _5_t _5(f); 00359 typedef rf::VisitorNode<E, _5_t> _4_t; 00360 _4_t _4(e, _5); 00361 typedef rf::VisitorNode<D, _4_t> _3_t; 00362 _3_t _3(d, _4); 00363 typedef rf::VisitorNode<C, _3_t> _2_t; 00364 _2_t _2(c, _3); 00365 typedef rf::VisitorNode<B, _2_t> _1_t; 00366 _1_t _1(b, _2); 00367 typedef rf::VisitorNode<A, _1_t> _0_t; 00368 _0_t _0(a, _1); 00369 return _0; 00370 } 00371 00372 00373 template<class A, class B, class C, class D, class E, 00374 class F, class G> 00375 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 00376 rf::VisitorNode<D, rf::VisitorNode<E, rf::VisitorNode<F, 00377 rf::VisitorNode<G> > > > > > > 00378 create_visitor(A & a, B & b, C & c, 00379 D & d, E & e, F & f, G & g) 00380 { 00381 typedef rf::VisitorNode<G> _6_t; 00382 _6_t _6(g); 00383 typedef rf::VisitorNode<F, _6_t> _5_t; 00384 _5_t _5(f, _6); 00385 typedef rf::VisitorNode<E, _5_t> _4_t; 00386 _4_t _4(e, _5); 00387 typedef rf::VisitorNode<D, _4_t> _3_t; 00388 _3_t _3(d, _4); 00389 typedef rf::VisitorNode<C, _3_t> _2_t; 00390 _2_t _2(c, _3); 00391 typedef rf::VisitorNode<B, _2_t> _1_t; 00392 _1_t _1(b, _2); 00393 typedef rf::VisitorNode<A, _1_t> _0_t; 00394 _0_t _0(a, _1); 00395 return _0; 00396 } 00397 00398 00399 template<class A, class B, class C, class D, class E, 00400 class F, class G, class H> 00401 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 00402 rf::VisitorNode<D, rf::VisitorNode<E, rf::VisitorNode<F, 00403 rf::VisitorNode<G, rf::VisitorNode<H> > > > > > > > 00404 create_visitor(A & a, B & b, C & c, 00405 D & d, E & e, F & f, 00406 G & g, H & h) 00407 { 00408 typedef rf::VisitorNode<H> _7_t; 00409 _7_t _7(h); 00410 typedef rf::VisitorNode<G, _7_t> _6_t; 00411 _6_t _6(g, _7); 00412 typedef rf::VisitorNode<F, _6_t> _5_t; 00413 _5_t _5(f, _6); 00414 typedef rf::VisitorNode<E, _5_t> _4_t; 00415 _4_t _4(e, _5); 00416 typedef rf::VisitorNode<D, _4_t> _3_t; 00417 _3_t _3(d, _4); 00418 typedef rf::VisitorNode<C, _3_t> _2_t; 00419 _2_t _2(c, _3); 00420 typedef rf::VisitorNode<B, _2_t> _1_t; 00421 _1_t _1(b, _2); 00422 typedef rf::VisitorNode<A, _1_t> _0_t; 00423 _0_t _0(a, _1); 00424 return _0; 00425 } 00426 00427 00428 template<class A, class B, class C, class D, class E, 00429 class F, class G, class H, class I> 00430 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 00431 rf::VisitorNode<D, rf::VisitorNode<E, rf::VisitorNode<F, 00432 rf::VisitorNode<G, rf::VisitorNode<H, rf::VisitorNode<I> > > > > > > > > 00433 create_visitor(A & a, B & b, C & c, 00434 D & d, E & e, F & f, 00435 G & g, H & h, I & i) 00436 { 00437 typedef rf::VisitorNode<I> _8_t; 00438 _8_t _8(i); 00439 typedef rf::VisitorNode<H, _8_t> _7_t; 00440 _7_t _7(h, _8); 00441 typedef rf::VisitorNode<G, _7_t> _6_t; 00442 _6_t _6(g, _7); 00443 typedef rf::VisitorNode<F, _6_t> _5_t; 00444 _5_t _5(f, _6); 00445 typedef rf::VisitorNode<E, _5_t> _4_t; 00446 _4_t _4(e, _5); 00447 typedef rf::VisitorNode<D, _4_t> _3_t; 00448 _3_t _3(d, _4); 00449 typedef rf::VisitorNode<C, _3_t> _2_t; 00450 _2_t _2(c, _3); 00451 typedef rf::VisitorNode<B, _2_t> _1_t; 00452 _1_t _1(b, _2); 00453 typedef rf::VisitorNode<A, _1_t> _0_t; 00454 _0_t _0(a, _1); 00455 return _0; 00456 } 00457 00458 template<class A, class B, class C, class D, class E, 00459 class F, class G, class H, class I, class J> 00460 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 00461 rf::VisitorNode<D, rf::VisitorNode<E, rf::VisitorNode<F, 00462 rf::VisitorNode<G, rf::VisitorNode<H, rf::VisitorNode<I, 00463 rf::VisitorNode<J> > > > > > > > > > 00464 create_visitor(A & a, B & b, C & c, 00465 D & d, E & e, F & f, 00466 G & g, H & h, I & i, 00467 J & j) 00468 { 00469 typedef rf::VisitorNode<J> _9_t; 00470 _9_t _9(j); 00471 typedef rf::VisitorNode<I, _9_t> _8_t; 00472 _8_t _8(i, _9); 00473 typedef rf::VisitorNode<H, _8_t> _7_t; 00474 _7_t _7(h, _8); 00475 typedef rf::VisitorNode<G, _7_t> _6_t; 00476 _6_t _6(g, _7); 00477 typedef rf::VisitorNode<F, _6_t> _5_t; 00478 _5_t _5(f, _6); 00479 typedef rf::VisitorNode<E, _5_t> _4_t; 00480 _4_t _4(e, _5); 00481 typedef rf::VisitorNode<D, _4_t> _3_t; 00482 _3_t _3(d, _4); 00483 typedef rf::VisitorNode<C, _3_t> _2_t; 00484 _2_t _2(c, _3); 00485 typedef rf::VisitorNode<B, _2_t> _1_t; 00486 _1_t _1(b, _2); 00487 typedef rf::VisitorNode<A, _1_t> _0_t; 00488 _0_t _0(a, _1); 00489 return _0; 00490 } 00491 00492 ////////////////////////////////////////////////////////////////////////////// 00493 // Visitors of communal interest. Do not spam this file with stuff // 00494 // nobody wants. // 00495 ////////////////////////////////////////////////////////////////////////////// 00496 00497 00498 /** Vistior to gain information, later needed for online learning. 00499 */ 00500 00501 class OnlineLearnVisitor: public VisitorBase 00502 { 00503 public: 00504 //Set if we adjust thresholds 00505 bool adjust_thresholds; 00506 //Current tree id 00507 int tree_id; 00508 //Last node id for finding parent 00509 int last_node_id; 00510 //Need to now the label for interior node visiting 00511 vigra::Int32 current_label; 00512 //marginal distribution for interior nodes 00513 struct MarginalDistribution 00514 { 00515 ArrayVector<Int32> leftCounts; 00516 Int32 leftTotalCounts; 00517 ArrayVector<Int32> rightCounts; 00518 Int32 rightTotalCounts; 00519 double gap_left; 00520 double gap_right; 00521 }; 00522 typedef ArrayVector<vigra::Int32> IndexList; 00523 00524 //All information for one tree 00525 struct TreeOnlineInformation 00526 { 00527 std::vector<MarginalDistribution> mag_distributions; 00528 std::vector<IndexList> index_lists; 00529 //map for linear index of mag_distiributions 00530 std::map<int,int> interior_to_index; 00531 //map for linear index of index_lists 00532 std::map<int,int> exterior_to_index; 00533 }; 00534 00535 //All trees 00536 std::vector<TreeOnlineInformation> trees_online_information; 00537 00538 /** Initilize, set the number of trees 00539 */ 00540 template<class RF,class PR> 00541 void visit_at_beginning(RF & rf,const PR & pr) 00542 { 00543 tree_id=0; 00544 trees_online_information.resize(rf.options_.tree_count_); 00545 } 00546 00547 /** Reset a tree 00548 */ 00549 void reset_tree(int tree_id) 00550 { 00551 trees_online_information[tree_id].mag_distributions.clear(); 00552 trees_online_information[tree_id].index_lists.clear(); 00553 trees_online_information[tree_id].interior_to_index.clear(); 00554 trees_online_information[tree_id].exterior_to_index.clear(); 00555 } 00556 00557 /** simply increase the tree count 00558 */ 00559 template<class RF, class PR, class SM, class ST> 00560 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00561 { 00562 tree_id++; 00563 } 00564 00565 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 00566 void visit_after_split( Tree & tree, 00567 Split & split, 00568 Region & parent, 00569 Region & leftChild, 00570 Region & rightChild, 00571 Feature_t & features, 00572 Label_t & labels) 00573 { 00574 int linear_index; 00575 int addr=tree.topology_.size(); 00576 if(split.createNode().typeID() == i_ThresholdNode) 00577 { 00578 if(adjust_thresholds) 00579 { 00580 //Store marginal distribution 00581 linear_index=trees_online_information[tree_id].mag_distributions.size(); 00582 trees_online_information[tree_id].interior_to_index[addr]=linear_index; 00583 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution()); 00584 00585 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_; 00586 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_; 00587 00588 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_; 00589 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_; 00590 //Store the gap 00591 double gap_left,gap_right; 00592 int i; 00593 gap_left=features(leftChild[0],split.bestSplitColumn()); 00594 for(i=1;i<leftChild.size();++i) 00595 if(features(leftChild[i],split.bestSplitColumn())>gap_left) 00596 gap_left=features(leftChild[i],split.bestSplitColumn()); 00597 gap_right=features(rightChild[0],split.bestSplitColumn()); 00598 for(i=1;i<rightChild.size();++i) 00599 if(features(rightChild[i],split.bestSplitColumn())<gap_right) 00600 gap_right=features(rightChild[i],split.bestSplitColumn()); 00601 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left; 00602 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right; 00603 } 00604 } 00605 else 00606 { 00607 //Store index list 00608 linear_index=trees_online_information[tree_id].index_lists.size(); 00609 trees_online_information[tree_id].exterior_to_index[addr]=linear_index; 00610 00611 trees_online_information[tree_id].index_lists.push_back(IndexList()); 00612 00613 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0); 00614 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin()); 00615 } 00616 } 00617 void add_to_index_list(int tree,int node,int index) 00618 { 00619 if(!this->active_) 00620 return; 00621 TreeOnlineInformation &ti=trees_online_information[tree]; 00622 ti.index_lists[ti.exterior_to_index[node]].push_back(index); 00623 } 00624 void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index) 00625 { 00626 if(!this->active_) 00627 return; 00628 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index]; 00629 trees_online_information[src_tree].exterior_to_index.erase(src_index); 00630 } 00631 /** do something when visiting a internal node during getToLeaf 00632 * 00633 * remember as last node id, for finding the parent of the last external node 00634 * also: adjust class counts and borders 00635 */ 00636 template<class TR, class IntT, class TopT,class Feat> 00637 void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features) 00638 { 00639 last_node_id=index; 00640 if(adjust_thresholds) 00641 { 00642 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes"); 00643 //Check if we are in the gap 00644 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column()); 00645 TreeOnlineInformation &ti=trees_online_information[tree_id]; 00646 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]]; 00647 if(value>m.gap_left && value<m.gap_right) 00648 { 00649 //Check which site we want to go 00650 if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts)) 00651 { 00652 //We want to go left 00653 m.gap_left=value; 00654 } 00655 else 00656 { 00657 //We want to go right 00658 m.gap_right=value; 00659 } 00660 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0; 00661 } 00662 //Adjust class counts 00663 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()) 00664 { 00665 ++m.rightTotalCounts; 00666 ++m.rightCounts[current_label]; 00667 } 00668 else 00669 { 00670 ++m.leftTotalCounts; 00671 ++m.rightCounts[current_label]; 00672 } 00673 } 00674 } 00675 /** do something when visiting a extern node during getToLeaf 00676 * 00677 * Store the new index! 00678 */ 00679 }; 00680 00681 00682 /** Visitor that calculates the oob error of the random forest. 00683 * this is the default visitor used. 00684 * 00685 * To bored to comment each line of this class - trust me it works. 00686 */ 00687 class OOB_Visitor:public VisitorBase 00688 { 00689 public: 00690 double oobError; 00691 int totalOobCount; 00692 ArrayVector<int> oobCount,oobErrorCount; 00693 00694 OOB_Visitor() 00695 : oobError(0.0), 00696 totalOobCount(0) 00697 {} 00698 00699 00700 bool has_value() 00701 { 00702 return true; 00703 } 00704 /** does the basic calculation per tree*/ 00705 template<class RF, class PR, class SM, class ST> 00706 void visit_after_tree( RF& rf, PR & pr, SM & sm, ST & st, int index) 00707 { 00708 //do the first time called. 00709 if(int(oobCount.size()) != rf.ext_param_.row_count_) 00710 { 00711 oobCount.resize(rf.ext_param_.row_count_, 0); 00712 oobErrorCount.resize(rf.ext_param_.row_count_, 0); 00713 } 00714 // go through the samples 00715 for(int l = 0; l < rf.ext_param_.row_count_; ++l) 00716 { 00717 // if the lth sample is oob... 00718 if(!sm.is_used()[l]) 00719 { 00720 ++oobCount[l]; 00721 if( rf.tree(index) 00722 .predictLabel(rowVector(pr.features(), l)) 00723 != pr.response()(l,0)) 00724 { 00725 ++oobErrorCount[l]; 00726 } 00727 } 00728 00729 } 00730 } 00731 00732 /** Does the normalisation 00733 */ 00734 template<class RF, class PR> 00735 void visit_at_end(RF & rf, PR & pr) 00736 { 00737 // do some normalisation 00738 for(int l=0; l < (int)rf.ext_param_.row_count_; ++l) 00739 { 00740 if(oobCount[l]) 00741 { 00742 oobError += double(oobErrorCount[l]) / oobCount[l]; 00743 ++totalOobCount; 00744 } 00745 } 00746 } 00747 00748 //returns value of the learn function. 00749 double return_val() 00750 { 00751 return oobError/totalOobCount; 00752 } 00753 }; 00754 00755 00756 /** calculate variable importance while learning. 00757 */ 00758 class VariableImportanceVisitor : public VisitorBase 00759 { 00760 public: 00761 00762 /** This Array has the same entries as the R - random forest variable 00763 * importance 00764 */ 00765 MultiArray<2, double> variable_importance_; 00766 int repetition_count_; 00767 bool in_place_; 00768 00769 #ifdef HasHDF5 00770 void save(std::string filename, std::string prefix) 00771 { 00772 prefix = "variable_importance_" + prefix; 00773 writeHDF5(filename.c_str(), 00774 prefix.c_str(), 00775 variable_importance_); 00776 } 00777 #endif 00778 00779 VariableImportanceVisitor(int rep_cnt = 10) 00780 : repetition_count_(rep_cnt) 00781 00782 {} 00783 00784 /** calculates impurity decrease based variable importance after every 00785 * split. 00786 */ 00787 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 00788 void visit_after_split( Tree & tree, 00789 Split & split, 00790 Region & parent, 00791 Region & leftChild, 00792 Region & rightChild, 00793 Feature_t & features, 00794 Label_t & labels) 00795 { 00796 //resize to right size when called the first time 00797 00798 Int32 const class_count = tree.ext_param_.class_count_; 00799 Int32 const column_count = tree.ext_param_.column_count_; 00800 if(variable_importance_.size() == 0) 00801 { 00802 00803 variable_importance_ 00804 .reshape(MultiArrayShape<2>::type(column_count, 00805 class_count+2)); 00806 } 00807 00808 if(split.createNode().typeID() == i_ThresholdNode) 00809 { 00810 Node<i_ThresholdNode> node(split.createNode()); 00811 variable_importance_(node.column(),class_count+1) 00812 += split.region_gini_ - split.minGini(); 00813 } 00814 } 00815 00816 /**compute permutation based var imp. 00817 * (Only an Array of size oob_sample_count x 1 is created. 00818 * - apposed to oob_sample_count x feature_count in the other method. 00819 * 00820 * \sa FieldProxy 00821 */ 00822 template<class RF, class PR, class SM, class ST> 00823 void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & st, int index) 00824 { 00825 typedef MultiArrayShape<2>::type Shp_t; 00826 Int32 column_count = rf.ext_param_.column_count_; 00827 Int32 class_count = rf.ext_param_.class_count_; 00828 00829 // remove the const cast on the features (yep , I know what I am 00830 // doing here.) data is not destroyed. 00831 typename PR::Feature_t & features 00832 = const_cast<typename PR::Feature_t &>(pr.features()); 00833 00834 //find the oob indices of current tree. 00835 ArrayVector<Int32> oob_indices; 00836 ArrayVector<Int32>::iterator 00837 iter; 00838 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii) 00839 if(!sm.is_used()[ii]) 00840 oob_indices.push_back(ii); 00841 00842 //create space to back up a column 00843 std::vector<double> backup_column; 00844 00845 // Random foo 00846 #ifdef CLASSIFIER_TEST 00847 RandomMT19937 random(1); 00848 #else 00849 RandomMT19937 random(RandomSeed); 00850 #endif 00851 UniformIntRandomFunctor<RandomMT19937> 00852 randint(random); 00853 00854 00855 //make some space for the results 00856 MultiArray<2, double> 00857 oob_right(Shp_t(1, class_count + 1)); 00858 MultiArray<2, double> 00859 perm_oob_right (Shp_t(1, class_count + 1)); 00860 00861 00862 // get the oob success rate with the original samples 00863 for(iter = oob_indices.begin(); 00864 iter != oob_indices.end(); 00865 ++iter) 00866 { 00867 if(rf.tree(index) 00868 .predictLabel(rowVector(features, *iter)) 00869 == pr.response()(*iter, 0)) 00870 { 00871 //per class 00872 ++oob_right[pr.response()(*iter,0)]; 00873 //total 00874 ++oob_right[class_count]; 00875 } 00876 } 00877 //get the oob rate after permuting the ii'th dimension. 00878 for(int ii = 0; ii < column_count; ++ii) 00879 { 00880 perm_oob_right.init(0.0); 00881 //make backup of orinal column 00882 backup_column.clear(); 00883 for(iter = oob_indices.begin(); 00884 iter != oob_indices.end(); 00885 ++iter) 00886 { 00887 backup_column.push_back(features(*iter,ii)); 00888 } 00889 00890 //get the oob rate after permuting the ii'th dimension. 00891 for(int rr = 0; rr < repetition_count_; ++rr) 00892 { 00893 //permute dimension. 00894 int n = oob_indices.size(); 00895 for(int jj = 1; jj < n; ++jj) 00896 std::swap(features(oob_indices[jj], ii), 00897 features(oob_indices[randint(jj+1)], ii)); 00898 00899 //get the oob sucess rate after permuting 00900 for(iter = oob_indices.begin(); 00901 iter != oob_indices.end(); 00902 ++iter) 00903 { 00904 if(rf.tree(index) 00905 .predictLabel(rowVector(features, *iter)) 00906 == pr.response()(*iter, 0)) 00907 { 00908 //per class 00909 ++perm_oob_right[pr.response()(*iter, 0)]; 00910 //total 00911 ++perm_oob_right[class_count]; 00912 } 00913 } 00914 } 00915 00916 00917 //normalise and add to the variable_importance array. 00918 perm_oob_right /= repetition_count_; 00919 perm_oob_right -=oob_right; 00920 perm_oob_right *= -1; 00921 perm_oob_right /= oob_indices.size(); 00922 variable_importance_ 00923 .subarray(Shp_t(ii,0), 00924 Shp_t(ii+1,class_count+1)) += perm_oob_right; 00925 //copy back permuted dimension 00926 for(int jj = 0; jj < int(oob_indices.size()); ++jj) 00927 features(oob_indices[jj], ii) = backup_column[jj]; 00928 } 00929 } 00930 00931 /** calculate permutation based impurity after every tree has been 00932 * learned default behaviour is that this happens out of place. 00933 * If you have very big data sets and want to avoid copying of data 00934 * set the in_place_ flag to true. 00935 */ 00936 template<class RF, class PR, class SM, class ST> 00937 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00938 { 00939 after_tree_ip_impl(rf, pr, sm, st, index); 00940 } 00941 00942 /** Normalise variable importance after the number of trees is known. 00943 */ 00944 template<class RF, class PR> 00945 void visit_at_end(RF & rf, PR & pr) 00946 { 00947 variable_importance_ /= rf.trees_.size(); 00948 } 00949 }; 00950 00951 } // namespace vigra 00952 #endif // RF_VISITORS_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|