// [ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. 
*/ 00033 /* */ 00034 /************************************************************************/ 00035 #ifndef RF_VISITORS_HXX 00036 #define RF_VISITORS_HXX 00037 00038 #ifdef HasHDF5 00039 # include "vigra/hdf5impex.hxx" 00040 #else 00041 # include "vigra/impex.hxx" 00042 # include "vigra/multi_array.hxx" 00043 # include "vigra/multi_impex.hxx" 00044 # include "vigra/inspectimage.hxx" 00045 #endif // HasHDF5 00046 #include <vigra/windows.h> 00047 #include <iostream> 00048 #include <iomanip> 00049 #include <vigra/timing.hxx> 00050 00051 namespace vigra 00052 { 00053 namespace rf 00054 { 00055 /** \addtogroup MachineLearning Machine Learning 00056 **/ 00057 //@{ 00058 00059 /** 00060 This namespace contains all classes and methods related to extracting information during 00061 learning of the random forest. All Visitors share the same interface defined in 00062 visitors::VisitorBase. The member methods are invoked at certain points of the main code in 00063 the order they were supplied. 00064 00065 For the Random Forest the Visitor concept is implemented as a statically linked list 00066 (Using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The 00067 VisitorNode object calls the Next Visitor after one of its visit() methods have terminated. 00068 00069 To simplify usage create_visitor() factory methods are supplied. 00070 Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method. 00071 It is possible to supply more than one visitor. They will then be invoked in serial order. 00072 00073 The calculated information are stored as public data members of the class. - see documentation 00074 of the individual visitors 00075 00076 While creating a new visitor the new class should therefore publicly inherit from this class 00077 (i.e.: see visitors::OOB_Error). 00078 00079 \code 00080 00081 typedef xxx feature_t \\ replace xxx with whichever type 00082 typedef yyy label_t \\ meme chose. 
00083 MultiArrayView<2, feature_t> f = get_some_features(); 00084 MultiArrayView<2, label_t> l = get_some_labels(); 00085 RandomForest<> rf() 00086 00087 //calculate OOB Error 00088 visitors::OOB_Error oob_v; 00089 //calculate Variable Importance 00090 visitors::VariableImportanceVisitor varimp_v; 00091 00092 double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v); 00093 //the data can be found in the attributes of oob_v and varimp_v now 00094 00095 \endcode 00096 */ 00097 namespace visitors 00098 { 00099 00100 00101 /** Base Class from which all Visitors derive. Can be used as a template to create new 00102 * Visitors. 00103 */ 00104 class VisitorBase 00105 { 00106 public: 00107 bool active_; 00108 bool is_active() 00109 { 00110 return active_; 00111 } 00112 00113 bool has_value() 00114 { 00115 return false; 00116 } 00117 00118 VisitorBase() 00119 : active_(true) 00120 {} 00121 00122 void deactivate() 00123 { 00124 active_ = false; 00125 } 00126 void activate() 00127 { 00128 active_ = true; 00129 } 00130 00131 /** do something after the the Split has decided how to process the Region 00132 * (Stack entry) 00133 * 00134 * \param tree reference to the tree that is currently being learned 00135 * \param split reference to the split object 00136 * \param parent current stack entry which was used to decide the split 00137 * \param leftChild left stack entry that will be pushed 00138 * \param rightChild 00139 * right stack entry that will be pushed. 
00140 * \param features features matrix 00141 * \param labels label matrix 00142 * \sa RF_Traits::StackEntry_t 00143 */ 00144 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 00145 void visit_after_split( Tree & tree, 00146 Split & split, 00147 Region & parent, 00148 Region & leftChild, 00149 Region & rightChild, 00150 Feature_t & features, 00151 Label_t & labels) 00152 {} 00153 00154 /** do something after each tree has been learned 00155 * 00156 * \param rf reference to the random forest object that called this 00157 * visitor 00158 * \param pr reference to the preprocessor that processed the input 00159 * \param sm reference to the sampler object 00160 * \param st reference to the first stack entry 00161 * \param index index of current tree 00162 */ 00163 template<class RF, class PR, class SM, class ST> 00164 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00165 {} 00166 00167 /** do something after all trees have been learned 00168 * 00169 * \param rf reference to the random forest object that called this 00170 * visitor 00171 * \param pr reference to the preprocessor that processed the input 00172 */ 00173 template<class RF, class PR> 00174 void visit_at_end(RF const & rf, PR const & pr) 00175 {} 00176 00177 /** do something before learning starts 00178 * 00179 * \param rf reference to the random forest object that called this 00180 * visitor 00181 * \param pr reference to the Processor class used. 00182 */ 00183 template<class RF, class PR> 00184 void visit_at_beginning(RF const & rf, PR const & pr) 00185 {} 00186 /** do some thing while traversing tree after it has been learned 00187 * (external nodes) 00188 * 00189 * \param tr reference to the tree object that called this visitor 00190 * \param index index in the topology_ array we currently are at 00191 * \param node_t type of node we have (will be e_.... - ) 00192 * \param weight Node weight of current node. 
00193 * \sa NodeTags; 00194 * 00195 * you can create the node by using a switch on node_tag and using the 00196 * corresponding Node objects. Or - if you do not care about the type 00197 * use the Nodebase class. 00198 */ 00199 template<class TR, class IntT, class TopT,class Feat> 00200 void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features) 00201 {} 00202 00203 /** do something when visiting a internal node after it has been learned 00204 * 00205 * \sa visit_external_node 00206 */ 00207 template<class TR, class IntT, class TopT,class Feat> 00208 void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features) 00209 {} 00210 00211 /** return a double value. The value of the first 00212 * visitor encountered that has a return value is returned with the 00213 * RandomForest::learn() method - or -1.0 if no return value visitor 00214 * existed. This functionality basically only exists so that the 00215 * OOB - visitor can return the oob error rate like in the old version 00216 * of the random forest. 00217 */ 00218 double return_val() 00219 { 00220 return -1.0; 00221 } 00222 }; 00223 00224 00225 /** Last Visitor that should be called to stop the recursion. 00226 */ 00227 class StopVisiting: public VisitorBase 00228 { 00229 public: 00230 bool has_value() 00231 { 00232 return true; 00233 } 00234 double return_val() 00235 { 00236 return -1.0; 00237 } 00238 }; 00239 namespace detail 00240 { 00241 /** Container elements of the statically linked Visitor list. 
00242 * 00243 * use the create_visitor() factory functions to create visitors up to size 10; 00244 * 00245 */ 00246 template <class Visitor, class Next = StopVisiting> 00247 class VisitorNode 00248 { 00249 public: 00250 00251 StopVisiting stop_; 00252 Next next_; 00253 Visitor & visitor_; 00254 VisitorNode(Visitor & visitor, Next & next) 00255 : 00256 next_(next), visitor_(visitor) 00257 {} 00258 00259 VisitorNode(Visitor & visitor) 00260 : 00261 next_(stop_), visitor_(visitor) 00262 {} 00263 00264 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 00265 void visit_after_split( Tree & tree, 00266 Split & split, 00267 Region & parent, 00268 Region & leftChild, 00269 Region & rightChild, 00270 Feature_t & features, 00271 Label_t & labels) 00272 { 00273 if(visitor_.is_active()) 00274 visitor_.visit_after_split(tree, split, 00275 parent, leftChild, rightChild, 00276 features, labels); 00277 next_.visit_after_split(tree, split, parent, leftChild, rightChild, 00278 features, labels); 00279 } 00280 00281 template<class RF, class PR, class SM, class ST> 00282 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00283 { 00284 if(visitor_.is_active()) 00285 visitor_.visit_after_tree(rf, pr, sm, st, index); 00286 next_.visit_after_tree(rf, pr, sm, st, index); 00287 } 00288 00289 template<class RF, class PR> 00290 void visit_at_beginning(RF & rf, PR & pr) 00291 { 00292 if(visitor_.is_active()) 00293 visitor_.visit_at_beginning(rf, pr); 00294 next_.visit_at_beginning(rf, pr); 00295 } 00296 template<class RF, class PR> 00297 void visit_at_end(RF & rf, PR & pr) 00298 { 00299 if(visitor_.is_active()) 00300 visitor_.visit_at_end(rf, pr); 00301 next_.visit_at_end(rf, pr); 00302 } 00303 00304 template<class TR, class IntT, class TopT,class Feat> 00305 void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features) 00306 { 00307 if(visitor_.is_active()) 00308 visitor_.visit_external_node(tr, index, node_t,features); 00309 
next_.visit_external_node(tr, index, node_t,features); 00310 } 00311 template<class TR, class IntT, class TopT,class Feat> 00312 void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features) 00313 { 00314 if(visitor_.is_active()) 00315 visitor_.visit_internal_node(tr, index, node_t,features); 00316 next_.visit_internal_node(tr, index, node_t,features); 00317 } 00318 00319 double return_val() 00320 { 00321 if(visitor_.is_active() && visitor_.has_value()) 00322 return visitor_.return_val(); 00323 return next_.return_val(); 00324 } 00325 }; 00326 00327 } //namespace detail 00328 00329 ////////////////////////////////////////////////////////////////////////////// 00330 // Visitor Factory function up to 10 visitors // 00331 ////////////////////////////////////////////////////////////////////////////// 00332 00333 /** factory method to to be used with RandomForest::learn() 00334 */ 00335 template<class A> 00336 detail::VisitorNode<A> 00337 create_visitor(A & a) 00338 { 00339 typedef detail::VisitorNode<A> _0_t; 00340 _0_t _0(a); 00341 return _0; 00342 } 00343 00344 00345 /** factory method to to be used with RandomForest::learn() 00346 */ 00347 template<class A, class B> 00348 detail::VisitorNode<A, detail::VisitorNode<B> > 00349 create_visitor(A & a, B & b) 00350 { 00351 typedef detail::VisitorNode<B> _1_t; 00352 _1_t _1(b); 00353 typedef detail::VisitorNode<A, _1_t> _0_t; 00354 _0_t _0(a, _1); 00355 return _0; 00356 } 00357 00358 00359 /** factory method to to be used with RandomForest::learn() 00360 */ 00361 template<class A, class B, class C> 00362 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > > 00363 create_visitor(A & a, B & b, C & c) 00364 { 00365 typedef detail::VisitorNode<C> _2_t; 00366 _2_t _2(c); 00367 typedef detail::VisitorNode<B, _2_t> _1_t; 00368 _1_t _1(b, _2); 00369 typedef detail::VisitorNode<A, _1_t> _0_t; 00370 _0_t _0(a, _1); 00371 return _0; 00372 } 00373 00374 00375 /** factory method to to be used with 
RandomForest::learn() 00376 */ 00377 template<class A, class B, class C, class D> 00378 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00379 detail::VisitorNode<D> > > > 00380 create_visitor(A & a, B & b, C & c, D & d) 00381 { 00382 typedef detail::VisitorNode<D> _3_t; 00383 _3_t _3(d); 00384 typedef detail::VisitorNode<C, _3_t> _2_t; 00385 _2_t _2(c, _3); 00386 typedef detail::VisitorNode<B, _2_t> _1_t; 00387 _1_t _1(b, _2); 00388 typedef detail::VisitorNode<A, _1_t> _0_t; 00389 _0_t _0(a, _1); 00390 return _0; 00391 } 00392 00393 00394 /** factory method to to be used with RandomForest::learn() 00395 */ 00396 template<class A, class B, class C, class D, class E> 00397 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00398 detail::VisitorNode<D, detail::VisitorNode<E> > > > > 00399 create_visitor(A & a, B & b, C & c, 00400 D & d, E & e) 00401 { 00402 typedef detail::VisitorNode<E> _4_t; 00403 _4_t _4(e); 00404 typedef detail::VisitorNode<D, _4_t> _3_t; 00405 _3_t _3(d, _4); 00406 typedef detail::VisitorNode<C, _3_t> _2_t; 00407 _2_t _2(c, _3); 00408 typedef detail::VisitorNode<B, _2_t> _1_t; 00409 _1_t _1(b, _2); 00410 typedef detail::VisitorNode<A, _1_t> _0_t; 00411 _0_t _0(a, _1); 00412 return _0; 00413 } 00414 00415 00416 /** factory method to to be used with RandomForest::learn() 00417 */ 00418 template<class A, class B, class C, class D, class E, 00419 class F> 00420 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00421 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > > 00422 create_visitor(A & a, B & b, C & c, 00423 D & d, E & e, F & f) 00424 { 00425 typedef detail::VisitorNode<F> _5_t; 00426 _5_t _5(f); 00427 typedef detail::VisitorNode<E, _5_t> _4_t; 00428 _4_t _4(e, _5); 00429 typedef detail::VisitorNode<D, _4_t> _3_t; 00430 _3_t _3(d, _4); 00431 typedef detail::VisitorNode<C, _3_t> _2_t; 00432 _2_t _2(c, _3); 00433 typedef detail::VisitorNode<B, _2_t> _1_t; 00434 
_1_t _1(b, _2); 00435 typedef detail::VisitorNode<A, _1_t> _0_t; 00436 _0_t _0(a, _1); 00437 return _0; 00438 } 00439 00440 00441 /** factory method to to be used with RandomForest::learn() 00442 */ 00443 template<class A, class B, class C, class D, class E, 00444 class F, class G> 00445 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00446 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 00447 detail::VisitorNode<G> > > > > > > 00448 create_visitor(A & a, B & b, C & c, 00449 D & d, E & e, F & f, G & g) 00450 { 00451 typedef detail::VisitorNode<G> _6_t; 00452 _6_t _6(g); 00453 typedef detail::VisitorNode<F, _6_t> _5_t; 00454 _5_t _5(f, _6); 00455 typedef detail::VisitorNode<E, _5_t> _4_t; 00456 _4_t _4(e, _5); 00457 typedef detail::VisitorNode<D, _4_t> _3_t; 00458 _3_t _3(d, _4); 00459 typedef detail::VisitorNode<C, _3_t> _2_t; 00460 _2_t _2(c, _3); 00461 typedef detail::VisitorNode<B, _2_t> _1_t; 00462 _1_t _1(b, _2); 00463 typedef detail::VisitorNode<A, _1_t> _0_t; 00464 _0_t _0(a, _1); 00465 return _0; 00466 } 00467 00468 00469 /** factory method to to be used with RandomForest::learn() 00470 */ 00471 template<class A, class B, class C, class D, class E, 00472 class F, class G, class H> 00473 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00474 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 00475 detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > > 00476 create_visitor(A & a, B & b, C & c, 00477 D & d, E & e, F & f, 00478 G & g, H & h) 00479 { 00480 typedef detail::VisitorNode<H> _7_t; 00481 _7_t _7(h); 00482 typedef detail::VisitorNode<G, _7_t> _6_t; 00483 _6_t _6(g, _7); 00484 typedef detail::VisitorNode<F, _6_t> _5_t; 00485 _5_t _5(f, _6); 00486 typedef detail::VisitorNode<E, _5_t> _4_t; 00487 _4_t _4(e, _5); 00488 typedef detail::VisitorNode<D, _4_t> _3_t; 00489 _3_t _3(d, _4); 00490 typedef detail::VisitorNode<C, _3_t> _2_t; 00491 _2_t _2(c, _3); 00492 typedef 
detail::VisitorNode<B, _2_t> _1_t; 00493 _1_t _1(b, _2); 00494 typedef detail::VisitorNode<A, _1_t> _0_t; 00495 _0_t _0(a, _1); 00496 return _0; 00497 } 00498 00499 00500 /** factory method to to be used with RandomForest::learn() 00501 */ 00502 template<class A, class B, class C, class D, class E, 00503 class F, class G, class H, class I> 00504 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00505 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 00506 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > > 00507 create_visitor(A & a, B & b, C & c, 00508 D & d, E & e, F & f, 00509 G & g, H & h, I & i) 00510 { 00511 typedef detail::VisitorNode<I> _8_t; 00512 _8_t _8(i); 00513 typedef detail::VisitorNode<H, _8_t> _7_t; 00514 _7_t _7(h, _8); 00515 typedef detail::VisitorNode<G, _7_t> _6_t; 00516 _6_t _6(g, _7); 00517 typedef detail::VisitorNode<F, _6_t> _5_t; 00518 _5_t _5(f, _6); 00519 typedef detail::VisitorNode<E, _5_t> _4_t; 00520 _4_t _4(e, _5); 00521 typedef detail::VisitorNode<D, _4_t> _3_t; 00522 _3_t _3(d, _4); 00523 typedef detail::VisitorNode<C, _3_t> _2_t; 00524 _2_t _2(c, _3); 00525 typedef detail::VisitorNode<B, _2_t> _1_t; 00526 _1_t _1(b, _2); 00527 typedef detail::VisitorNode<A, _1_t> _0_t; 00528 _0_t _0(a, _1); 00529 return _0; 00530 } 00531 00532 /** factory method to to be used with RandomForest::learn() 00533 */ 00534 template<class A, class B, class C, class D, class E, 00535 class F, class G, class H, class I, class J> 00536 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 00537 detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 00538 detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I, 00539 detail::VisitorNode<J> > > > > > > > > > 00540 create_visitor(A & a, B & b, C & c, 00541 D & d, E & e, F & f, 00542 G & g, H & h, I & i, 00543 J & j) 00544 { 00545 typedef detail::VisitorNode<J> _9_t; 00546 _9_t _9(j); 00547 typedef 
detail::VisitorNode<I, _9_t> _8_t; 00548 _8_t _8(i, _9); 00549 typedef detail::VisitorNode<H, _8_t> _7_t; 00550 _7_t _7(h, _8); 00551 typedef detail::VisitorNode<G, _7_t> _6_t; 00552 _6_t _6(g, _7); 00553 typedef detail::VisitorNode<F, _6_t> _5_t; 00554 _5_t _5(f, _6); 00555 typedef detail::VisitorNode<E, _5_t> _4_t; 00556 _4_t _4(e, _5); 00557 typedef detail::VisitorNode<D, _4_t> _3_t; 00558 _3_t _3(d, _4); 00559 typedef detail::VisitorNode<C, _3_t> _2_t; 00560 _2_t _2(c, _3); 00561 typedef detail::VisitorNode<B, _2_t> _1_t; 00562 _1_t _1(b, _2); 00563 typedef detail::VisitorNode<A, _1_t> _0_t; 00564 _0_t _0(a, _1); 00565 return _0; 00566 } 00567 00568 ////////////////////////////////////////////////////////////////////////////// 00569 // Visitors of communal interest. // 00570 ////////////////////////////////////////////////////////////////////////////// 00571 00572 00573 /** Visitor to gain information, later needed for online learning. 00574 */ 00575 00576 class OnlineLearnVisitor: public VisitorBase 00577 { 00578 public: 00579 //Set if we adjust thresholds 00580 bool adjust_thresholds; 00581 //Current tree id 00582 int tree_id; 00583 //Last node id for finding parent 00584 int last_node_id; 00585 //Need to now the label for interior node visiting 00586 vigra::Int32 current_label; 00587 //marginal distribution for interior nodes 00588 struct MarginalDistribution 00589 { 00590 ArrayVector<Int32> leftCounts; 00591 Int32 leftTotalCounts; 00592 ArrayVector<Int32> rightCounts; 00593 Int32 rightTotalCounts; 00594 double gap_left; 00595 double gap_right; 00596 }; 00597 typedef ArrayVector<vigra::Int32> IndexList; 00598 00599 //All information for one tree 00600 struct TreeOnlineInformation 00601 { 00602 std::vector<MarginalDistribution> mag_distributions; 00603 std::vector<IndexList> index_lists; 00604 //map for linear index of mag_distiributions 00605 std::map<int,int> interior_to_index; 00606 //map for linear index of index_lists 00607 std::map<int,int> 
exterior_to_index; 00608 }; 00609 00610 //All trees 00611 std::vector<TreeOnlineInformation> trees_online_information; 00612 00613 /** Initilize, set the number of trees 00614 */ 00615 template<class RF,class PR> 00616 void visit_at_beginning(RF & rf,const PR & pr) 00617 { 00618 tree_id=0; 00619 trees_online_information.resize(rf.options_.tree_count_); 00620 } 00621 00622 /** Reset a tree 00623 */ 00624 void reset_tree(int tree_id) 00625 { 00626 trees_online_information[tree_id].mag_distributions.clear(); 00627 trees_online_information[tree_id].index_lists.clear(); 00628 trees_online_information[tree_id].interior_to_index.clear(); 00629 trees_online_information[tree_id].exterior_to_index.clear(); 00630 } 00631 00632 /** simply increase the tree count 00633 */ 00634 template<class RF, class PR, class SM, class ST> 00635 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00636 { 00637 tree_id++; 00638 } 00639 00640 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 00641 void visit_after_split( Tree & tree, 00642 Split & split, 00643 Region & parent, 00644 Region & leftChild, 00645 Region & rightChild, 00646 Feature_t & features, 00647 Label_t & labels) 00648 { 00649 int linear_index; 00650 int addr=tree.topology_.size(); 00651 if(split.createNode().typeID() == i_ThresholdNode) 00652 { 00653 if(adjust_thresholds) 00654 { 00655 //Store marginal distribution 00656 linear_index=trees_online_information[tree_id].mag_distributions.size(); 00657 trees_online_information[tree_id].interior_to_index[addr]=linear_index; 00658 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution()); 00659 00660 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_; 00661 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_; 00662 00663 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_; 00664 
trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_; 00665 //Store the gap 00666 double gap_left,gap_right; 00667 int i; 00668 gap_left=features(leftChild[0],split.bestSplitColumn()); 00669 for(i=1;i<leftChild.size();++i) 00670 if(features(leftChild[i],split.bestSplitColumn())>gap_left) 00671 gap_left=features(leftChild[i],split.bestSplitColumn()); 00672 gap_right=features(rightChild[0],split.bestSplitColumn()); 00673 for(i=1;i<rightChild.size();++i) 00674 if(features(rightChild[i],split.bestSplitColumn())<gap_right) 00675 gap_right=features(rightChild[i],split.bestSplitColumn()); 00676 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left; 00677 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right; 00678 } 00679 } 00680 else 00681 { 00682 //Store index list 00683 linear_index=trees_online_information[tree_id].index_lists.size(); 00684 trees_online_information[tree_id].exterior_to_index[addr]=linear_index; 00685 00686 trees_online_information[tree_id].index_lists.push_back(IndexList()); 00687 00688 trees_online_information[tree_id].index_lists.back().resize(parent.size_,0); 00689 std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin()); 00690 } 00691 } 00692 void add_to_index_list(int tree,int node,int index) 00693 { 00694 if(!this->active_) 00695 return; 00696 TreeOnlineInformation &ti=trees_online_information[tree]; 00697 ti.index_lists[ti.exterior_to_index[node]].push_back(index); 00698 } 00699 void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index) 00700 { 00701 if(!this->active_) 00702 return; 00703 trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index]; 00704 trees_online_information[src_tree].exterior_to_index.erase(src_index); 00705 } 00706 /** do something when visiting a internal node during getToLeaf 00707 * 00708 * remember as 
last node id, for finding the parent of the last external node 00709 * also: adjust class counts and borders 00710 */ 00711 template<class TR, class IntT, class TopT,class Feat> 00712 void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features) 00713 { 00714 last_node_id=index; 00715 if(adjust_thresholds) 00716 { 00717 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes"); 00718 //Check if we are in the gap 00719 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column()); 00720 TreeOnlineInformation &ti=trees_online_information[tree_id]; 00721 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]]; 00722 if(value>m.gap_left && value<m.gap_right) 00723 { 00724 //Check which site we want to go 00725 if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts)) 00726 { 00727 //We want to go left 00728 m.gap_left=value; 00729 } 00730 else 00731 { 00732 //We want to go right 00733 m.gap_right=value; 00734 } 00735 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0; 00736 } 00737 //Adjust class counts 00738 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()) 00739 { 00740 ++m.rightTotalCounts; 00741 ++m.rightCounts[current_label]; 00742 } 00743 else 00744 { 00745 ++m.leftTotalCounts; 00746 ++m.rightCounts[current_label]; 00747 } 00748 } 00749 } 00750 /** do something when visiting a extern node during getToLeaf 00751 * 00752 * Store the new index! 00753 */ 00754 }; 00755 00756 ////////////////////////////////////////////////////////////////////////////// 00757 // Out of Bag Error estimates // 00758 ////////////////////////////////////////////////////////////////////////////// 00759 00760 00761 /** Visitor that calculates the oob error of each individual randomized 00762 * decision tree. 
00763 * 00764 * After training a tree, all those samples that are OOB for this particular tree 00765 * are put down the tree and the error estimated. 00766 * the per tree oob error is the average of the individual error estimates. 00767 * (oobError = average error of one randomized tree) 00768 * Note: This is Not the OOB - Error estimate suggested by Breiman (See OOB_Error 00769 * visitor) 00770 */ 00771 class OOB_PerTreeError:public VisitorBase 00772 { 00773 public: 00774 /** Average error of one randomized decision tree 00775 */ 00776 double oobError; 00777 00778 int totalOobCount; 00779 ArrayVector<int> oobCount,oobErrorCount; 00780 00781 OOB_PerTreeError() 00782 : oobError(0.0), 00783 totalOobCount(0) 00784 {} 00785 00786 00787 bool has_value() 00788 { 00789 return true; 00790 } 00791 00792 00793 /** does the basic calculation per tree*/ 00794 template<class RF, class PR, class SM, class ST> 00795 void visit_after_tree( RF& rf, PR & pr, SM & sm, ST & st, int index) 00796 { 00797 //do the first time called. 00798 if(int(oobCount.size()) != rf.ext_param_.row_count_) 00799 { 00800 oobCount.resize(rf.ext_param_.row_count_, 0); 00801 oobErrorCount.resize(rf.ext_param_.row_count_, 0); 00802 } 00803 // go through the samples 00804 for(int l = 0; l < rf.ext_param_.row_count_; ++l) 00805 { 00806 // if the lth sample is oob... 
00807 if(!sm.is_used()[l]) 00808 { 00809 ++oobCount[l]; 00810 if( rf.tree(index) 00811 .predictLabel(rowVector(pr.features(), l)) 00812 != pr.response()(l,0)) 00813 { 00814 ++oobErrorCount[l]; 00815 } 00816 } 00817 00818 } 00819 } 00820 00821 /** Does the normalisation 00822 */ 00823 template<class RF, class PR> 00824 void visit_at_end(RF & rf, PR & pr) 00825 { 00826 // do some normalisation 00827 for(int l=0; l < (int)rf.ext_param_.row_count_; ++l) 00828 { 00829 if(oobCount[l]) 00830 { 00831 oobError += double(oobErrorCount[l]) / oobCount[l]; 00832 ++totalOobCount; 00833 } 00834 } 00835 oobError/=totalOobCount; 00836 } 00837 00838 }; 00839 00840 /** Visitor that calculates the oob error of the ensemble 00841 * This rate should be used to estimate the crossvalidation 00842 * error rate. 00843 * Here each sample is put down those trees, for which this sample 00844 * is OOB i.e. if sample #1 is OOB for trees 1, 3 and 5 we calculate 00845 * the output using the ensemble consisting only of trees 1 3 and 5. 00846 * 00847 * Using normal bagged sampling each sample is OOB for approx. 33% of trees 00848 * The error rate obtained as such therefore corresponds to crossvalidation 00849 * rate obtained using a ensemble containing 33% of the trees. 
00850 */ 00851 class OOB_Error : public VisitorBase 00852 { 00853 typedef MultiArrayShape<2>::type Shp; 00854 int class_count; 00855 bool is_weighted; 00856 MultiArray<2,double> tmp_prob; 00857 public: 00858 00859 MultiArray<2, double> prob_oob; 00860 /** Ensemble oob error rate 00861 */ 00862 double oob_breiman; 00863 00864 MultiArray<2, double> oobCount; 00865 ArrayVector< int> indices; 00866 OOB_Error() : VisitorBase(), oob_breiman(0.0) {} 00867 00868 #ifdef HasHDF5 00869 void save(std::string filen, std::string pathn) 00870 { 00871 if(*(pathn.end()-1) != '/') 00872 pathn += "/"; 00873 const char* filename = filen.c_str(); 00874 MultiArray<2, double> temp(Shp(1,1), 0.0); 00875 temp[0] = oob_breiman; 00876 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp); 00877 } 00878 #endif 00879 // negative value if sample was ib, number indicates how often. 00880 // value >=0 if sample was oob, 0 means fail 1, corrrect 00881 00882 template<class RF, class PR> 00883 void visit_at_beginning(RF & rf, PR & pr) 00884 { 00885 class_count = rf.class_count(); 00886 tmp_prob.reshape(Shp(1, class_count), 0); 00887 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0); 00888 is_weighted = rf.options().predict_weighted_; 00889 indices.resize(rf.ext_param().row_count_); 00890 if(int(oobCount.size()) != rf.ext_param_.row_count_) 00891 { 00892 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0); 00893 } 00894 for(int ii = 0; ii < rf.ext_param().row_count_; ++ii) 00895 { 00896 indices[ii] = ii; 00897 } 00898 } 00899 00900 template<class RF, class PR, class SM, class ST> 00901 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 00902 { 00903 // go through the samples 00904 int total_oob =0; 00905 int wrong_oob =0; 00906 // FIXME: magic number 10000: invoke special treatment when when msample << sample_count 00907 // (i.e. 
the OOB sample ist very large) 00908 // 40000: use at most 40000 OOB samples per class for OOB error estimate 00909 if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000) 00910 { 00911 ArrayVector<int> oob_indices; 00912 ArrayVector<int> cts(class_count, 0); 00913 std::random_shuffle(indices.begin(), indices.end()); 00914 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii) 00915 { 00916 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000) 00917 { 00918 oob_indices.push_back(indices[ii]); 00919 ++cts[pr.response()(indices[ii], 0)]; 00920 } 00921 } 00922 for(unsigned int ll = 0; ll < oob_indices.size(); ++ll) 00923 { 00924 // update number of trees in which current sample is oob 00925 ++oobCount[oob_indices[ll]]; 00926 00927 // update number of oob samples in this tree. 00928 ++total_oob; 00929 // get the predicted votes ---> tmp_prob; 00930 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll])); 00931 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 00932 rf.tree(index).parameters_, 00933 pos); 00934 tmp_prob.init(0); 00935 for(int ii = 0; ii < class_count; ++ii) 00936 { 00937 tmp_prob[ii] = node.prob_begin()[ii]; 00938 } 00939 if(is_weighted) 00940 { 00941 for(int ii = 0; ii < class_count; ++ii) 00942 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1)); 00943 } 00944 rowVector(prob_oob, oob_indices[ll]) += tmp_prob; 00945 int label = argMax(tmp_prob); 00946 00947 } 00948 }else 00949 { 00950 for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll) 00951 { 00952 // if the lth sample is oob... 00953 if(!sm.is_used()[ll]) 00954 { 00955 // update number of trees in which current sample is oob 00956 ++oobCount[ll]; 00957 00958 // update number of oob samples in this tree. 
00959 ++total_oob; 00960 // get the predicted votes ---> tmp_prob; 00961 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll)); 00962 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 00963 rf.tree(index).parameters_, 00964 pos); 00965 tmp_prob.init(0); 00966 for(int ii = 0; ii < class_count; ++ii) 00967 { 00968 tmp_prob[ii] = node.prob_begin()[ii]; 00969 } 00970 if(is_weighted) 00971 { 00972 for(int ii = 0; ii < class_count; ++ii) 00973 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1)); 00974 } 00975 rowVector(prob_oob, ll) += tmp_prob; 00976 int label = argMax(tmp_prob); 00977 00978 } 00979 } 00980 } 00981 // go through the ib samples; 00982 } 00983 00984 /** Normalise variable importance after the number of trees is known. 00985 */ 00986 template<class RF, class PR> 00987 void visit_at_end(RF & rf, PR & pr) 00988 { 00989 // ullis original metric and breiman style stuff 00990 int totalOobCount =0; 00991 int breimanstyle = 0; 00992 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll) 00993 { 00994 if(oobCount[ll]) 00995 { 00996 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0)) 00997 ++breimanstyle; 00998 ++totalOobCount; 00999 } 01000 } 01001 oob_breiman = double(breimanstyle)/totalOobCount; 01002 } 01003 }; 01004 01005 01006 /** Visitor that calculates different OOB error statistics 01007 */ 01008 class CompleteOOBInfo : public VisitorBase 01009 { 01010 typedef MultiArrayShape<2>::type Shp; 01011 int class_count; 01012 bool is_weighted; 01013 MultiArray<2,double> tmp_prob; 01014 public: 01015 01016 /** OOB Error rate of each individual tree 01017 */ 01018 MultiArray<2, double> oob_per_tree; 01019 /** Mean of oob_per_tree 01020 */ 01021 double oob_mean; 01022 /**Standard deviation of oob_per_tree 01023 */ 01024 double oob_std; 01025 01026 MultiArray<2, double> prob_oob; 01027 /** Ensemble OOB error 01028 * 01029 * \sa OOB_Error 01030 */ 01031 double oob_breiman; 01032 01033 MultiArray<2, double> oobCount; 01034 MultiArray<2, 
double> oobErrorCount; 01035 /** Per Tree OOB error calculated as in OOB_PerTreeError 01036 * (Ulli's version) 01037 */ 01038 double oob_per_tree2; 01039 01040 /**Column containing the development of the Ensemble 01041 * error rate with increasing number of trees 01042 */ 01043 MultiArray<2, double> breiman_per_tree; 01044 /** 4 dimensional array containing the development of confusion matrices 01045 * with number of trees - can be used to estimate ROC curves etc. 01046 * 01047 * oobroc_per_tree(ii,jj,kk,ll) 01048 * corresponds true label = ii 01049 * predicted label = jj 01050 * confusion matrix after ll trees 01051 * 01052 * explaination of third index: 01053 * 01054 * Two class case: 01055 * kk = 0 - (treeCount-1) 01056 * Threshold is on Probability for class 0 is kk/(treeCount-1); 01057 * More classes: 01058 * kk = 0. Threshold on probability set by argMax of the probability array. 01059 */ 01060 MultiArray<4, double> oobroc_per_tree; 01061 01062 CompleteOOBInfo() : VisitorBase(), oob_mean(0), oob_std(0), oob_per_tree2(0) {} 01063 01064 #ifdef HasHDF5 01065 /** save to HDF5 file 01066 */ 01067 void save(std::string filen, std::string pathn) 01068 { 01069 if(*(pathn.end()-1) != '/') 01070 pathn += "/"; 01071 const char* filename = filen.c_str(); 01072 MultiArray<2, double> temp(Shp(1,1), 0.0); 01073 writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree); 01074 writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree); 01075 writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree); 01076 temp[0] = oob_mean; 01077 writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp); 01078 temp[0] = oob_std; 01079 writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp); 01080 temp[0] = oob_breiman; 01081 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp); 01082 temp[0] = oob_per_tree2; 01083 writeHDF5(filename, (pathn + "ulli_error").c_str(), temp); 01084 } 01085 #endif 01086 // negative value if sample 
was ib, number indicates how often. 01087 // value >=0 if sample was oob, 0 means fail 1, corrrect 01088 01089 template<class RF, class PR> 01090 void visit_at_beginning(RF & rf, PR & pr) 01091 { 01092 class_count = rf.class_count(); 01093 if(class_count == 2) 01094 oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count())); 01095 else 01096 oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count())); 01097 tmp_prob.reshape(Shp(1, class_count), 0); 01098 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0); 01099 is_weighted = rf.options().predict_weighted_; 01100 oob_per_tree.reshape(Shp(1, rf.tree_count()), 0); 01101 breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0); 01102 //do the first time called. 01103 if(int(oobCount.size()) != rf.ext_param_.row_count_) 01104 { 01105 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0); 01106 oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0); 01107 } 01108 } 01109 01110 template<class RF, class PR, class SM, class ST> 01111 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 01112 { 01113 // go through the samples 01114 int total_oob =0; 01115 int wrong_oob =0; 01116 for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll) 01117 { 01118 // if the lth sample is oob... 01119 if(!sm.is_used()[ll]) 01120 { 01121 // update number of trees in which current sample is oob 01122 ++oobCount[ll]; 01123 01124 // update number of oob samples in this tree. 
01125 ++total_oob; 01126 // get the predicted votes ---> tmp_prob; 01127 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll)); 01128 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 01129 rf.tree(index).parameters_, 01130 pos); 01131 tmp_prob.init(0); 01132 for(int ii = 0; ii < class_count; ++ii) 01133 { 01134 tmp_prob[ii] = node.prob_begin()[ii]; 01135 } 01136 if(is_weighted) 01137 { 01138 for(int ii = 0; ii < class_count; ++ii) 01139 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1)); 01140 } 01141 rowVector(prob_oob, ll) += tmp_prob; 01142 int label = argMax(tmp_prob); 01143 01144 if(label != pr.response()(ll, 0)) 01145 { 01146 // update number of wrong oob samples in this tree. 01147 ++wrong_oob; 01148 // update number of trees in which current sample is wrong oob 01149 ++oobErrorCount[ll]; 01150 } 01151 } 01152 } 01153 int breimanstyle = 0; 01154 int totalOobCount = 0; 01155 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll) 01156 { 01157 if(oobCount[ll]) 01158 { 01159 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0)) 01160 ++breimanstyle; 01161 ++totalOobCount; 01162 if(oobroc_per_tree.shape(2) == 1) 01163 { 01164 oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++; 01165 } 01166 } 01167 } 01168 if(oobroc_per_tree.shape(2) == 1) 01169 oobroc_per_tree.bindOuter(index)/=totalOobCount; 01170 if(oobroc_per_tree.shape(2) > 1) 01171 { 01172 MultiArrayView<3, double> current_roc 01173 = oobroc_per_tree.bindOuter(index); 01174 for(int gg = 0; gg < current_roc.shape(2); ++gg) 01175 { 01176 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll) 01177 { 01178 if(oobCount[ll]) 01179 { 01180 int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))? 
01181 1 : 0; 01182 current_roc(pr.response()(ll, 0), pred, gg)+= 1; 01183 } 01184 } 01185 current_roc.bindOuter(gg)/= totalOobCount; 01186 } 01187 } 01188 breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount); 01189 oob_per_tree[index] = double(wrong_oob)/double(total_oob); 01190 // go through the ib samples; 01191 } 01192 01193 /** Normalise variable importance after the number of trees is known. 01194 */ 01195 template<class RF, class PR> 01196 void visit_at_end(RF & rf, PR & pr) 01197 { 01198 // ullis original metric and breiman style stuff 01199 oob_per_tree2 = 0; 01200 int totalOobCount =0; 01201 int breimanstyle = 0; 01202 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll) 01203 { 01204 if(oobCount[ll]) 01205 { 01206 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0)) 01207 ++breimanstyle; 01208 oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll]; 01209 ++totalOobCount; 01210 } 01211 } 01212 oob_per_tree2 /= totalOobCount; 01213 oob_breiman = double(breimanstyle)/totalOobCount; 01214 // mean error of each tree 01215 MultiArrayView<2, double> mean(Shp(1,1), &oob_mean); 01216 MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std); 01217 rowStatistics(oob_per_tree, mean, stdDev); 01218 } 01219 }; 01220 01221 /** calculate variable importance while learning. 01222 */ 01223 class VariableImportanceVisitor : public VisitorBase 01224 { 01225 public: 01226 01227 /** This Array has the same entries as the R - random forest variable 01228 * importance. 01229 * Matrix is featureCount by (classCount +2) 01230 * variable_importance_(ii,jj) is the variable importance measure of 01231 * the ii-th variable according to: 01232 * jj = 0 - (classCount-1) 01233 * classwise permutation importance 01234 * jj = rowCount(variable_importance_) -2 01235 * permutation importance 01236 * jj = rowCount(variable_importance_) -1 01237 * gini decrease importance. 
01238 * 01239 * permutation importance: 01240 * The difference between the fraction of OOB samples classified correctly 01241 * before and after permuting (randomizing) the ii-th column is calculated. 01242 * The ii-th column is permuted rep_cnt times. 01243 * 01244 * class wise permutation importance: 01245 * same as permutation importance. We only look at those OOB samples whose 01246 * response corresponds to class jj. 01247 * 01248 * gini decrease importance: 01249 * row ii corresponds to the sum of all gini decreases induced by variable ii 01250 * in each node of the random forest. 01251 */ 01252 MultiArray<2, double> variable_importance_; 01253 int repetition_count_; 01254 bool in_place_; 01255 01256 #ifdef HasHDF5 01257 void save(std::string filename, std::string prefix) 01258 { 01259 prefix = "variable_importance_" + prefix; 01260 writeHDF5(filename.c_str(), 01261 prefix.c_str(), 01262 variable_importance_); 01263 } 01264 #endif 01265 /** Constructor 01266 * \param rep_cnt (defautl: 10) how often should 01267 * the permutation take place. Set to 1 to make calculation faster (but 01268 * possibly more instable) 01269 */ 01270 VariableImportanceVisitor(int rep_cnt = 10) 01271 : repetition_count_(rep_cnt) 01272 01273 {} 01274 01275 /** calculates impurity decrease based variable importance after every 01276 * split. 
01277 */ 01278 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 01279 void visit_after_split( Tree & tree, 01280 Split & split, 01281 Region & parent, 01282 Region & leftChild, 01283 Region & rightChild, 01284 Feature_t & features, 01285 Label_t & labels) 01286 { 01287 //resize to right size when called the first time 01288 01289 Int32 const class_count = tree.ext_param_.class_count_; 01290 Int32 const column_count = tree.ext_param_.column_count_; 01291 if(variable_importance_.size() == 0) 01292 { 01293 01294 variable_importance_ 01295 .reshape(MultiArrayShape<2>::type(column_count, 01296 class_count+2)); 01297 } 01298 01299 if(split.createNode().typeID() == i_ThresholdNode) 01300 { 01301 Node<i_ThresholdNode> node(split.createNode()); 01302 variable_importance_(node.column(),class_count+1) 01303 += split.region_gini_ - split.minGini(); 01304 } 01305 } 01306 01307 /**compute permutation based var imp. 01308 * (Only an Array of size oob_sample_count x 1 is created. 01309 * - apposed to oob_sample_count x feature_count in the other method. 01310 * 01311 * \sa FieldProxy 01312 */ 01313 template<class RF, class PR, class SM, class ST> 01314 void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & st, int index) 01315 { 01316 typedef MultiArrayShape<2>::type Shp_t; 01317 Int32 column_count = rf.ext_param_.column_count_; 01318 Int32 class_count = rf.ext_param_.class_count_; 01319 01320 /* This solution saves memory uptake but not multithreading 01321 * compatible 01322 */ 01323 // remove the const cast on the features (yep , I know what I am 01324 // doing here.) data is not destroyed. 01325 //typename PR::Feature_t & features 01326 // = const_cast<typename PR::Feature_t &>(pr.features()); 01327 01328 typename PR::FeatureWithMemory_t features = pr.features(); 01329 01330 //find the oob indices of current tree. 
01331 ArrayVector<Int32> oob_indices; 01332 ArrayVector<Int32>::iterator 01333 iter; 01334 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii) 01335 if(!sm.is_used()[ii]) 01336 oob_indices.push_back(ii); 01337 01338 //create space to back up a column 01339 std::vector<double> backup_column; 01340 01341 // Random foo 01342 #ifdef CLASSIFIER_TEST 01343 RandomMT19937 random(1); 01344 #else 01345 RandomMT19937 random(RandomSeed); 01346 #endif 01347 UniformIntRandomFunctor<RandomMT19937> 01348 randint(random); 01349 01350 01351 //make some space for the results 01352 MultiArray<2, double> 01353 oob_right(Shp_t(1, class_count + 1)); 01354 MultiArray<2, double> 01355 perm_oob_right (Shp_t(1, class_count + 1)); 01356 01357 01358 // get the oob success rate with the original samples 01359 for(iter = oob_indices.begin(); 01360 iter != oob_indices.end(); 01361 ++iter) 01362 { 01363 if(rf.tree(index) 01364 .predictLabel(rowVector(features, *iter)) 01365 == pr.response()(*iter, 0)) 01366 { 01367 //per class 01368 ++oob_right[pr.response()(*iter,0)]; 01369 //total 01370 ++oob_right[class_count]; 01371 } 01372 } 01373 //get the oob rate after permuting the ii'th dimension. 01374 for(int ii = 0; ii < column_count; ++ii) 01375 { 01376 perm_oob_right.init(0.0); 01377 //make backup of orinal column 01378 backup_column.clear(); 01379 for(iter = oob_indices.begin(); 01380 iter != oob_indices.end(); 01381 ++iter) 01382 { 01383 backup_column.push_back(features(*iter,ii)); 01384 } 01385 01386 //get the oob rate after permuting the ii'th dimension. 01387 for(int rr = 0; rr < repetition_count_; ++rr) 01388 { 01389 //permute dimension. 
01390 int n = oob_indices.size(); 01391 for(int jj = 1; jj < n; ++jj) 01392 std::swap(features(oob_indices[jj], ii), 01393 features(oob_indices[randint(jj+1)], ii)); 01394 01395 //get the oob sucess rate after permuting 01396 for(iter = oob_indices.begin(); 01397 iter != oob_indices.end(); 01398 ++iter) 01399 { 01400 if(rf.tree(index) 01401 .predictLabel(rowVector(features, *iter)) 01402 == pr.response()(*iter, 0)) 01403 { 01404 //per class 01405 ++perm_oob_right[pr.response()(*iter, 0)]; 01406 //total 01407 ++perm_oob_right[class_count]; 01408 } 01409 } 01410 } 01411 01412 01413 //normalise and add to the variable_importance array. 01414 perm_oob_right /= repetition_count_; 01415 perm_oob_right -=oob_right; 01416 perm_oob_right *= -1; 01417 perm_oob_right /= oob_indices.size(); 01418 variable_importance_ 01419 .subarray(Shp_t(ii,0), 01420 Shp_t(ii+1,class_count+1)) += perm_oob_right; 01421 //copy back permuted dimension 01422 for(int jj = 0; jj < int(oob_indices.size()); ++jj) 01423 features(oob_indices[jj], ii) = backup_column[jj]; 01424 } 01425 } 01426 01427 /** calculate permutation based impurity after every tree has been 01428 * learned default behaviour is that this happens out of place. 01429 * If you have very big data sets and want to avoid copying of data 01430 * set the in_place_ flag to true. 01431 */ 01432 template<class RF, class PR, class SM, class ST> 01433 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 01434 { 01435 after_tree_ip_impl(rf, pr, sm, st, index); 01436 } 01437 01438 /** Normalise variable importance after the number of trees is known. 
01439 */ 01440 template<class RF, class PR> 01441 void visit_at_end(RF & rf, PR & pr) 01442 { 01443 variable_importance_ /= rf.trees_.size(); 01444 } 01445 }; 01446 01447 /** Verbose output 01448 */ 01449 class RandomForestProgressVisitor : public VisitorBase { 01450 public: 01451 RandomForestProgressVisitor() : VisitorBase() {} 01452 01453 template<class RF, class PR, class SM, class ST> 01454 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index){ 01455 if(index != rf.options().tree_count_-1) { 01456 std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]" 01457 << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush; 01458 } 01459 else { 01460 std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl; 01461 } 01462 } 01463 01464 template<class RF, class PR> 01465 void visit_at_end(RF const & rf, PR const & pr) { 01466 std::string a = TOCS; 01467 std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl; 01468 } 01469 01470 template<class RF, class PR> 01471 void visit_at_beginning(RF const & rf, PR const & pr) { 01472 TIC; 01473 std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl; 01474 } 01475 01476 private: 01477 USETICTOC; 01478 }; 01479 01480 01481 /** Computes Correlation/Similarity Matrix of features while learning 01482 * random forest. 01483 */ 01484 class CorrelationVisitor : public VisitorBase 01485 { 01486 public: 01487 /** gini_missc(ii, jj) describes how well variable jj can describe a partition 01488 * created on variable ii(when variable ii was chosen) 01489 */ 01490 MultiArray<2, double> gini_missc; 01491 MultiArray<2, int> tmp_labels; 01492 /** additional noise features. 01493 */ 01494 MultiArray<2, double> noise; 01495 MultiArray<2, double> noise_l; 01496 /** how well can a noise column describe a partition created on variable ii. 
01497 */ 01498 MultiArray<2, double> corr_noise; 01499 MultiArray<2, double> corr_l; 01500 01501 /** Similarity Matrix 01502 * 01503 * (numberOfFeatures + 1) by (number Of Features + 1) Matrix 01504 * gini_missc 01505 * - row normalized by the number of times the column was chosen 01506 * - mean of corr_noise subtracted 01507 * - and symmetrised. 01508 * 01509 */ 01510 MultiArray<2, double> similarity; 01511 /** Distance Matrix 1-similarity 01512 */ 01513 MultiArray<2, double> distance; 01514 ArrayVector<int> tmp_cc; 01515 01516 /** How often was variable ii chosen 01517 */ 01518 ArrayVector<int> numChoices; 01519 typedef BestGiniOfColumn<GiniCriterion> ColumnDecisionFunctor; 01520 BestGiniOfColumn<GiniCriterion> bgfunc; 01521 void save(std::string file, std::string prefix) 01522 { 01523 /* 01524 std::string tmp; 01525 #define VAR_WRITE(NAME) \ 01526 tmp = #NAME;\ 01527 tmp += "_";\ 01528 tmp += prefix;\ 01529 vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME); 01530 VAR_WRITE(gini_missc); 01531 VAR_WRITE(corr_noise); 01532 VAR_WRITE(distance); 01533 VAR_WRITE(similarity); 01534 vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data())); 01535 #undef VAR_WRITE 01536 */ 01537 } 01538 template<class RF, class PR> 01539 void visit_at_beginning(RF const & rf, PR & pr) 01540 { 01541 typedef MultiArrayShape<2>::type Shp; 01542 int n = rf.ext_param_.column_count_; 01543 gini_missc.reshape(Shp(n +1,n+ 1)); 01544 corr_noise.reshape(Shp(n + 1, 10)); 01545 corr_l.reshape(Shp(n +1, 10)); 01546 01547 noise.reshape(Shp(pr.features().shape(0), 10)); 01548 noise_l.reshape(Shp(pr.features().shape(0), 10)); 01549 RandomMT19937 random(RandomSeed); 01550 for(int ii = 0; ii < noise.size(); ++ii) 01551 { 01552 noise[ii] = random.uniform53(); 01553 noise_l[ii] = random.uniform53() > 0.5; 01554 } 01555 bgfunc = ColumnDecisionFunctor( rf.ext_param_); 01556 tmp_labels.reshape(pr.response().shape()); 01557 
tmp_cc.resize(2); 01558 numChoices.resize(n+1); 01559 // look at allaxes 01560 } 01561 template<class RF, class PR> 01562 void visit_at_end(RF const & rf, PR const & pr) 01563 { 01564 typedef MultiArrayShape<2>::type Shp; 01565 similarity.reshape(gini_missc.shape()); 01566 similarity = gini_missc;; 01567 MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1)); 01568 rowStatistics(corr_noise, mean_noise); 01569 mean_noise/= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data()); 01570 int rC = similarity.shape(0); 01571 for(int jj = 0; jj < rC-1; ++jj) 01572 { 01573 rowVector(similarity, jj) /= numChoices[jj]; 01574 rowVector(similarity, jj) -= mean_noise(jj, 0); 01575 } 01576 for(int jj = 0; jj < rC; ++jj) 01577 { 01578 similarity(rC -1, jj) /= numChoices[jj]; 01579 } 01580 rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0); 01581 similarity = abs(similarity); 01582 FindMinMax<double> minmax; 01583 inspectMultiArray(srcMultiArrayRange(similarity), minmax); 01584 01585 for(int jj = 0; jj < rC; ++jj) 01586 similarity(jj, jj) = minmax.max; 01587 01588 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)) 01589 += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose(); 01590 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2; 01591 columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose(); 01592 for(int jj = 0; jj < rC; ++jj) 01593 similarity(jj, jj) = 0; 01594 01595 FindMinMax<double> minmax2; 01596 inspectMultiArray(srcMultiArrayRange(similarity), minmax2); 01597 for(int jj = 0; jj < rC; ++jj) 01598 similarity(jj, jj) = minmax2.max; 01599 distance.reshape(gini_missc.shape(), minmax2.max); 01600 distance -= similarity; 01601 } 01602 01603 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 01604 void visit_after_split( Tree & tree, 01605 Split & split, 01606 Region & parent, 01607 Region & leftChild, 01608 Region & rightChild, 01609 Feature_t & features, 01610 Label_t & labels) 01611 { 01612 
if(split.createNode().typeID() == i_ThresholdNode) 01613 { 01614 double wgini; 01615 tmp_cc.init(0); 01616 for(int ii = 0; ii < parent.size(); ++ii) 01617 { 01618 tmp_labels[parent[ii]] 01619 = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold()); 01620 ++tmp_cc[tmp_labels[parent[ii]]]; 01621 } 01622 double region_gini = bgfunc.loss_of_region(tmp_labels, 01623 parent.begin(), 01624 parent.end(), 01625 tmp_cc); 01626 01627 int n = split.bestSplitColumn(); 01628 ++numChoices[n]; 01629 ++(*(numChoices.end()-1)); 01630 //this functor does all the work 01631 for(int k = 0; k < features.shape(1); ++k) 01632 { 01633 bgfunc(columnVector(features, k), 01634 0, 01635 tmp_labels, 01636 parent.begin(), parent.end(), 01637 tmp_cc); 01638 wgini = (region_gini - bgfunc.min_gini_); 01639 gini_missc(n, k) 01640 += wgini; 01641 } 01642 for(int k = 0; k < 10; ++k) 01643 { 01644 bgfunc(columnVector(noise, k), 01645 0, 01646 tmp_labels, 01647 parent.begin(), parent.end(), 01648 tmp_cc); 01649 wgini = (region_gini - bgfunc.min_gini_); 01650 corr_noise(n, k) 01651 += wgini; 01652 } 01653 01654 for(int k = 0; k < 10; ++k) 01655 { 01656 bgfunc(columnVector(noise_l, k), 01657 0, 01658 tmp_labels, 01659 parent.begin(), parent.end(), 01660 tmp_cc); 01661 wgini = (region_gini - bgfunc.min_gini_); 01662 corr_l(n, k) 01663 += wgini; 01664 } 01665 bgfunc(labels,0, tmp_labels, parent.begin(), parent.end(),tmp_cc); 01666 wgini = (region_gini - bgfunc.min_gini_); 01667 gini_missc(n, columnCount(gini_missc)-1) 01668 += wgini; 01669 01670 region_gini = split.region_gini_; 01671 #if 1 01672 Node<i_ThresholdNode> node(split.createNode()); 01673 gini_missc(rowCount(gini_missc)-1, 01674 node.column()) 01675 +=split.region_gini_ - split.minGini(); 01676 #endif 01677 for(int k = 0; k < 10; ++k) 01678 { 01679 split.bgfunc(columnVector(noise, k), 01680 0, 01681 labels, 01682 parent.begin(), parent.end(), 01683 parent.classCounts()); 01684 corr_noise(rowCount(gini_missc)-1, 01685 k) 
01686 += wgini; 01687 } 01688 #if 0 01689 for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k) 01690 { 01691 wgini = region_gini - split.min_gini_[k]; 01692 01693 gini_missc(rowCount(gini_missc)-1, 01694 split.splitColumns[k]) 01695 += wgini; 01696 } 01697 01698 for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k) 01699 { 01700 split.bgfunc(columnVector(features, split.splitColumns[k]), 01701 labels, 01702 parent.begin(), parent.end(), 01703 parent.classCounts()); 01704 wgini = region_gini - split.bgfunc.min_gini_; 01705 gini_missc(rowCount(gini_missc)-1, 01706 split.splitColumns[k]) += wgini; 01707 } 01708 #endif 01709 // remember to partition the data according to the best. 01710 gini_missc(rowCount(gini_missc)-1, 01711 columnCount(gini_missc)-1) 01712 += region_gini; 01713 SortSamplesByDimensions<Feature_t> 01714 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold()); 01715 std::partition(parent.begin(), parent.end(), sorter); 01716 } 01717 } 01718 }; 01719 01720 01721 } // namespace visitors 01722 } // namespace rf 01723 } // namespace vigra 01724 01725 //@} 01726 #endif // RF_VISITORS_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|