
vigra/random_forest/rf_visitors.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 #ifndef RF_VISITORS_HXX
00036 #define RF_VISITORS_HXX
00037 
00038 #ifdef HasHDF5
00039 # include "vigra/hdf5impex.hxx"
00040 #else
00041 # include "vigra/impex.hxx"
00042 # include "vigra/multi_array.hxx"
00043 # include "vigra/multi_impex.hxx"
00044 # include "vigra/inspectimage.hxx"
00045 #endif // HasHDF5
00046 #include <vigra/windows.h>
00047 #include <iostream>
00048 #include <iomanip>
00049 #include <vigra/timing.hxx>
00050 
00051 namespace vigra
00052 {
00053 namespace rf
00054 {
00055 /** \addtogroup MachineLearning Machine Learning
00056 **/
00057 //@{
00058 
00059 /**
00060     This namespace contains all classes and methods related to extracting information during 
00061     learning of the random forest. All Visitors share the same interface defined in 
00062     visitors::VisitorBase. The member methods are invoked at certain points of the main code in 
00063     the order they were supplied.
00064     
00065     For the Random Forest, the Visitor concept is implemented as a statically linked list 
00066     (using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The 
00067     VisitorNode object calls the next Visitor after one of its visit() methods has terminated.
00068     
00069     To simplify usage, create_visitor() factory methods are supplied.
00070     Use create_visitor() to supply visitor objects to the RandomForest::learn() method.
00071     It is possible to supply more than one visitor; they will then be invoked in serial order.
00072 
00073     The calculated information is stored as public data members of the visitor classes - see the
00074     documentation of the individual visitors.
00075     
00076     When creating a new visitor, the new class should publicly inherit from visitors::VisitorBase 
00077     (see e.g. visitors::OOB_Error).
00078 
00079     \code
00080 
00081       typedef xxx feature_t; // replace xxx with the feature type you use
00082       typedef yyy label_t;   // same for the label type
00083       MultiArrayView<2, feature_t> f = get_some_features();
00084       MultiArrayView<2, label_t>   l = get_some_labels();
00085       RandomForest<> rf;
00086     
00087       // calculate the OOB error
00088       visitors::OOB_Error oob_v;
00089       // calculate variable importance
00090       visitors::VariableImportanceVisitor varimp_v;
00091 
00092       double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
00093       // the results are now stored in the public data members of oob_v and varimp_v
00094       
00095     \endcode
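
    Continuing the example above, a new visitor is written by publicly deriving from
    visitors::VisitorBase and overriding only the hooks that are needed. A minimal
    sketch (the class and member names below are made up for illustration):

    \code

      class TreeCounter : public visitors::VisitorBase
      {
        public:
          int trees_seen_;                       // number of learned trees (illustrative)

          TreeCounter() : trees_seen_(0) {}

          // invoked once after each tree has been learned
          template<class RF, class PR, class SM, class ST>
          void visit_after_tree(RF & rf, PR & pr, SM & sm, ST & st, int index)
          {
              ++trees_seen_;
          }
      };

      TreeCounter counter;
      rf.learn(f, l, visitors::create_visitor(counter));

    \endcode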
00096 */
00097 namespace visitors
00098 {
00099     
00100     
00101 /** Base class from which all visitors derive. Can be used as a template to create new 
00102  * visitors.
00103  */
00104 class VisitorBase
00105 {
00106     public:
00107     bool active_;   
00108     bool is_active()
00109     {
00110         return active_;
00111     }
00112 
00113     bool has_value()
00114     {
00115         return false;
00116     }
00117 
00118     VisitorBase()
00119         : active_(true)
00120     {}
00121 
00122     void deactivate()
00123     {
00124         active_ = false;
00125     }
00126     void activate()
00127     {
00128         active_ = true;
00129     }
00130     
00131     /** do something after the Split has decided how to process the Region
00132      * (Stack entry)
00133      *
00134      * \param tree      reference to the tree that is currently being learned
00135      * \param split     reference to the split object
00136      * \param parent    current stack entry  which was used to decide the split
00137      * \param leftChild left stack entry that will be pushed
00138      * \param rightChild
00139      *                  right stack entry that will be pushed.
00140      * \param features  features matrix
00141      * \param labels    label matrix
00142      * \sa RF_Traits::StackEntry_t
00143      */
00144     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00145     void visit_after_split( Tree          & tree, 
00146                             Split         & split,
00147                             Region        & parent,
00148                             Region        & leftChild,
00149                             Region        & rightChild,
00150                             Feature_t     & features,
00151                             Label_t       & labels)
00152     {}
00153     
00154     /** do something after each tree has been learned
00155      *
00156      * \param rf        reference to the random forest object that called this
00157      *                  visitor
00158      * \param pr        reference to the preprocessor that processed the input
00159      * \param sm        reference to the sampler object
00160      * \param st        reference to the first stack entry
00161      * \param index     index of current tree
00162      */
00163     template<class RF, class PR, class SM, class ST>
00164     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00165     {}
00166     
00167     /** do something after all trees have been learned
00168      *
00169      * \param rf        reference to the random forest object that called this
00170      *                  visitor
00171      * \param pr        reference to the preprocessor that processed the input
00172      */
00173     template<class RF, class PR>
00174     void visit_at_end(RF const & rf, PR const & pr)
00175     {}
00176     
00177     /** do something before learning starts 
00178      *
00179      * \param rf        reference to the random forest object that called this
00180      *                  visitor
00181      * \param pr        reference to the Processor class used.
00182      */
00183     template<class RF, class PR>
00184     void visit_at_beginning(RF const & rf, PR const & pr)
00185     {}
00186     /** do something while traversing the tree after it has been learned 
00187      *  (external nodes)
00188      *
00189      * \param tr        reference to the tree object that called this visitor
00190      * \param index     index in the topology_ array we currently are at
00191      * \param node_t    type of node we have (will be one of the e_... node tags)
00192      * \param weight    Node weight of current node. 
00193      * \sa  NodeTags;
00194      *
00195      * You can create the node by using a switch on node_t and the 
00196      * corresponding Node objects, or - if you do not care about the type -
00197      * use the NodeBase class.
00198      */
00199     template<class TR, class IntT, class TopT,class Feat>
00200     void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features)
00201     {}
00202     
00203     /** do something when visiting an internal node after it has been learned
00204      *
00205      * \sa visit_external_node
00206      */
00207     template<class TR, class IntT, class TopT,class Feat>
00208     void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
00209     {}
00210 
00211     /** return a double value.  The value of the first 
00212      * visitor encountered that has a return value is returned by the
00213      * RandomForest::learn() method - or -1.0 if no visitor with a return value
00214      * exists. This functionality basically only exists so that the 
00215      * OOB visitor can return the OOB error rate as in the old version 
00216      * of the random forest.
00217      */
00218     double return_val()
00219     {
00220         return -1.0;
00221     }
00222 };
00223 
00224 
00225 /** Last Visitor that should be called to stop the recursion.
00226  */
00227 class StopVisiting: public VisitorBase
00228 {
00229     public:
00230     bool has_value()
00231     {
00232         return true;
00233     }
00234     double return_val()
00235     {
00236         return -1.0;
00237     }
00238 };
00239 namespace detail
00240 {
00241 /** Container elements of the statically linked Visitor list.
00242  *
00243  * Use the create_visitor() factory functions to create visitor chains with up to 10 visitors.
00244  *
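 * For example, create_visitor(a, b) wraps the two visitors into the nested type
 * \code
 * detail::VisitorNode<A, detail::VisitorNode<B, StopVisiting> >
 * \endcode
 * (StopVisiting is the default for the Next template parameter). Each node first
 * forwards a visit_xxx() call to its own visitor - if that visitor is active -
 * and then to next_.
 *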
00245  */
00246 template <class Visitor, class Next = StopVisiting>
00247 class VisitorNode
00248 {
00249     public:
00250     
00251     StopVisiting    stop_;
00252     Next            next_;
00253     Visitor &       visitor_;   
00254     VisitorNode(Visitor & visitor, Next & next) 
00255     : 
00256         next_(next), visitor_(visitor)
00257     {}
00258 
00259     VisitorNode(Visitor &  visitor) 
00260     : 
00261         next_(stop_), visitor_(visitor)
00262     {}
00263 
00264     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00265     void visit_after_split( Tree          & tree, 
00266                             Split         & split,
00267                             Region        & parent,
00268                             Region        & leftChild,
00269                             Region        & rightChild,
00270                             Feature_t     & features,
00271                             Label_t       & labels)
00272     {
00273         if(visitor_.is_active())
00274             visitor_.visit_after_split(tree, split, 
00275                                        parent, leftChild, rightChild,
00276                                        features, labels);
00277         next_.visit_after_split(tree, split, parent, leftChild, rightChild,
00278                                 features, labels);
00279     }
00280 
00281     template<class RF, class PR, class SM, class ST>
00282     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00283     {
00284         if(visitor_.is_active())
00285             visitor_.visit_after_tree(rf, pr, sm, st, index);
00286         next_.visit_after_tree(rf, pr, sm, st, index);
00287     }
00288 
00289     template<class RF, class PR>
00290     void visit_at_beginning(RF & rf, PR & pr)
00291     {
00292         if(visitor_.is_active())
00293             visitor_.visit_at_beginning(rf, pr);
00294         next_.visit_at_beginning(rf, pr);
00295     }
00296     template<class RF, class PR>
00297     void visit_at_end(RF & rf, PR & pr)
00298     {
00299         if(visitor_.is_active())
00300             visitor_.visit_at_end(rf, pr);
00301         next_.visit_at_end(rf, pr);
00302     }
00303     
00304     template<class TR, class IntT, class TopT,class Feat>
00305     void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
00306     {
00307         if(visitor_.is_active())
00308             visitor_.visit_external_node(tr, index, node_t,features);
00309         next_.visit_external_node(tr, index, node_t,features);
00310     }
00311     template<class TR, class IntT, class TopT,class Feat>
00312     void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
00313     {
00314         if(visitor_.is_active())
00315             visitor_.visit_internal_node(tr, index, node_t,features);
00316         next_.visit_internal_node(tr, index, node_t,features);
00317     }
00318 
00319     double return_val()
00320     {
00321         if(visitor_.is_active() && visitor_.has_value())
00322             return visitor_.return_val();
00323         return next_.return_val();
00324     }
00325 };
00326 
00327 } //namespace detail
00328 
00329 //////////////////////////////////////////////////////////////////////////////
00330 //  Visitor factory functions for up to 10 visitors                         //
00331 //////////////////////////////////////////////////////////////////////////////
00332 
00333 /** factory method to be used with RandomForest::learn()
00334  */
00335 template<class A>
00336 detail::VisitorNode<A>
00337 create_visitor(A & a)
00338 {
00339    typedef detail::VisitorNode<A> _0_t;
00340    _0_t _0(a);
00341    return _0;
00342 }
00343 
00344 
00345 /** factory method to be used with RandomForest::learn()
00346  */
00347 template<class A, class B>
00348 detail::VisitorNode<A, detail::VisitorNode<B> >
00349 create_visitor(A & a, B & b)
00350 {
00351    typedef detail::VisitorNode<B> _1_t;
00352    _1_t _1(b);
00353    typedef detail::VisitorNode<A, _1_t> _0_t;
00354    _0_t _0(a, _1);
00355    return _0;
00356 }
00357 
00358 
00359 /** factory method to be used with RandomForest::learn()
00360  */
00361 template<class A, class B, class C>
00362 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
00363 create_visitor(A & a, B & b, C & c)
00364 {
00365    typedef detail::VisitorNode<C> _2_t;
00366    _2_t _2(c);
00367    typedef detail::VisitorNode<B, _2_t> _1_t;
00368    _1_t _1(b, _2);
00369    typedef detail::VisitorNode<A, _1_t> _0_t;
00370    _0_t _0(a, _1);
00371    return _0;
00372 }
00373 
00374 
00375 /** factory method to be used with RandomForest::learn()
00376  */
00377 template<class A, class B, class C, class D>
00378 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00379     detail::VisitorNode<D> > > >
00380 create_visitor(A & a, B & b, C & c, D & d)
00381 {
00382    typedef detail::VisitorNode<D> _3_t;
00383    _3_t _3(d);
00384    typedef detail::VisitorNode<C, _3_t> _2_t;
00385    _2_t _2(c, _3);
00386    typedef detail::VisitorNode<B, _2_t> _1_t;
00387    _1_t _1(b, _2);
00388    typedef detail::VisitorNode<A, _1_t> _0_t;
00389    _0_t _0(a, _1);
00390    return _0;
00391 }
00392 
00393 
00394 /** factory method to be used with RandomForest::learn()
00395  */
00396 template<class A, class B, class C, class D, class E>
00397 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00398     detail::VisitorNode<D, detail::VisitorNode<E> > > > >
00399 create_visitor(A & a, B & b, C & c, 
00400                D & d, E & e)
00401 {
00402    typedef detail::VisitorNode<E> _4_t;
00403    _4_t _4(e);
00404    typedef detail::VisitorNode<D, _4_t> _3_t;
00405    _3_t _3(d, _4);
00406    typedef detail::VisitorNode<C, _3_t> _2_t;
00407    _2_t _2(c, _3);
00408    typedef detail::VisitorNode<B, _2_t> _1_t;
00409    _1_t _1(b, _2);
00410    typedef detail::VisitorNode<A, _1_t> _0_t;
00411    _0_t _0(a, _1);
00412    return _0;
00413 }
00414 
00415 
00416 /** factory method to be used with RandomForest::learn()
00417  */
00418 template<class A, class B, class C, class D, class E,
00419          class F>
00420 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00421     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
00422 create_visitor(A & a, B & b, C & c, 
00423                D & d, E & e, F & f)
00424 {
00425    typedef detail::VisitorNode<F> _5_t;
00426    _5_t _5(f);
00427    typedef detail::VisitorNode<E, _5_t> _4_t;
00428    _4_t _4(e, _5);
00429    typedef detail::VisitorNode<D, _4_t> _3_t;
00430    _3_t _3(d, _4);
00431    typedef detail::VisitorNode<C, _3_t> _2_t;
00432    _2_t _2(c, _3);
00433    typedef detail::VisitorNode<B, _2_t> _1_t;
00434    _1_t _1(b, _2);
00435    typedef detail::VisitorNode<A, _1_t> _0_t;
00436    _0_t _0(a, _1);
00437    return _0;
00438 }
00439 
00440 
00441 /** factory method to be used with RandomForest::learn()
00442  */
00443 template<class A, class B, class C, class D, class E,
00444          class F, class G>
00445 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00446     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00447     detail::VisitorNode<G> > > > > > >
00448 create_visitor(A & a, B & b, C & c, 
00449                D & d, E & e, F & f, G & g)
00450 {
00451    typedef detail::VisitorNode<G> _6_t;
00452    _6_t _6(g);
00453    typedef detail::VisitorNode<F, _6_t> _5_t;
00454    _5_t _5(f, _6);
00455    typedef detail::VisitorNode<E, _5_t> _4_t;
00456    _4_t _4(e, _5);
00457    typedef detail::VisitorNode<D, _4_t> _3_t;
00458    _3_t _3(d, _4);
00459    typedef detail::VisitorNode<C, _3_t> _2_t;
00460    _2_t _2(c, _3);
00461    typedef detail::VisitorNode<B, _2_t> _1_t;
00462    _1_t _1(b, _2);
00463    typedef detail::VisitorNode<A, _1_t> _0_t;
00464    _0_t _0(a, _1);
00465    return _0;
00466 }
00467 
00468 
00469 /** factory method to be used with RandomForest::learn()
00470  */
00471 template<class A, class B, class C, class D, class E,
00472          class F, class G, class H>
00473 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00474     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00475     detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
00476 create_visitor(A & a, B & b, C & c, 
00477                D & d, E & e, F & f, 
00478                G & g, H & h)
00479 {
00480    typedef detail::VisitorNode<H> _7_t;
00481    _7_t _7(h);
00482    typedef detail::VisitorNode<G, _7_t> _6_t;
00483    _6_t _6(g, _7);
00484    typedef detail::VisitorNode<F, _6_t> _5_t;
00485    _5_t _5(f, _6);
00486    typedef detail::VisitorNode<E, _5_t> _4_t;
00487    _4_t _4(e, _5);
00488    typedef detail::VisitorNode<D, _4_t> _3_t;
00489    _3_t _3(d, _4);
00490    typedef detail::VisitorNode<C, _3_t> _2_t;
00491    _2_t _2(c, _3);
00492    typedef detail::VisitorNode<B, _2_t> _1_t;
00493    _1_t _1(b, _2);
00494    typedef detail::VisitorNode<A, _1_t> _0_t;
00495    _0_t _0(a, _1);
00496    return _0;
00497 }
00498 
00499 
00500 /** factory method to be used with RandomForest::learn()
00501  */
00502 template<class A, class B, class C, class D, class E,
00503          class F, class G, class H, class I>
00504 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00505     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00506     detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
00507 create_visitor(A & a, B & b, C & c, 
00508                D & d, E & e, F & f, 
00509                G & g, H & h, I & i)
00510 {
00511    typedef detail::VisitorNode<I> _8_t;
00512    _8_t _8(i);
00513    typedef detail::VisitorNode<H, _8_t> _7_t;
00514    _7_t _7(h, _8);
00515    typedef detail::VisitorNode<G, _7_t> _6_t;
00516    _6_t _6(g, _7);
00517    typedef detail::VisitorNode<F, _6_t> _5_t;
00518    _5_t _5(f, _6);
00519    typedef detail::VisitorNode<E, _5_t> _4_t;
00520    _4_t _4(e, _5);
00521    typedef detail::VisitorNode<D, _4_t> _3_t;
00522    _3_t _3(d, _4);
00523    typedef detail::VisitorNode<C, _3_t> _2_t;
00524    _2_t _2(c, _3);
00525    typedef detail::VisitorNode<B, _2_t> _1_t;
00526    _1_t _1(b, _2);
00527    typedef detail::VisitorNode<A, _1_t> _0_t;
00528    _0_t _0(a, _1);
00529    return _0;
00530 }
00531 
00532 /** factory method to be used with RandomForest::learn()
00533  */
00534 template<class A, class B, class C, class D, class E,
00535          class F, class G, class H, class I, class J>
00536 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00537     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00538     detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
00539     detail::VisitorNode<J> > > > > > > > > >
00540 create_visitor(A & a, B & b, C & c, 
00541                D & d, E & e, F & f, 
00542                G & g, H & h, I & i,
00543                J & j)
00544 {
00545    typedef detail::VisitorNode<J> _9_t;
00546    _9_t _9(j);
00547    typedef detail::VisitorNode<I, _9_t> _8_t;
00548    _8_t _8(i, _9);
00549    typedef detail::VisitorNode<H, _8_t> _7_t;
00550    _7_t _7(h, _8);
00551    typedef detail::VisitorNode<G, _7_t> _6_t;
00552    _6_t _6(g, _7);
00553    typedef detail::VisitorNode<F, _6_t> _5_t;
00554    _5_t _5(f, _6);
00555    typedef detail::VisitorNode<E, _5_t> _4_t;
00556    _4_t _4(e, _5);
00557    typedef detail::VisitorNode<D, _4_t> _3_t;
00558    _3_t _3(d, _4);
00559    typedef detail::VisitorNode<C, _3_t> _2_t;
00560    _2_t _2(c, _3);
00561    typedef detail::VisitorNode<B, _2_t> _1_t;
00562    _1_t _1(b, _2);
00563    typedef detail::VisitorNode<A, _1_t> _0_t;
00564    _0_t _0(a, _1);
00565    return _0;
00566 }
00567 
00568 //////////////////////////////////////////////////////////////////////////////
00569 // Visitors of general interest.                                            //
00570 //////////////////////////////////////////////////////////////////////////////
00571 
00572 
00573 /** Visitor that collects information needed later for online learning.
00574  */
00575 
00576 class OnlineLearnVisitor: public VisitorBase
00577 {
00578 public:
00579     //Set if we adjust thresholds
00580     bool adjust_thresholds;
00581     //Current tree id
00582     int tree_id;
00583     //Last node id for finding parent
00584     int last_node_id;
00585     //Need to know the label for interior node visiting
00586     vigra::Int32 current_label;
00587     //marginal distribution for interior nodes
00588     struct MarginalDistribution
00589     {
00590         ArrayVector<Int32> leftCounts;
00591         Int32 leftTotalCounts;
00592         ArrayVector<Int32> rightCounts;
00593         Int32 rightTotalCounts;
00594         double gap_left;
00595         double gap_right;
00596     };
00597     typedef ArrayVector<vigra::Int32> IndexList;
00598 
00599     //All information for one tree
00600     struct TreeOnlineInformation
00601     {
00602         std::vector<MarginalDistribution> mag_distributions;
00603         std::vector<IndexList> index_lists;
00604         //map for linear index of mag_distributions
00605         std::map<int,int> interior_to_index;
00606         //map for linear index of index_lists
00607         std::map<int,int> exterior_to_index;
00608     };
00609 
00610     //All trees
00611     std::vector<TreeOnlineInformation> trees_online_information;
00612 
00613     /** Initialize, set the number of trees
00614      */
00615     template<class RF,class PR>
00616     void visit_at_beginning(RF & rf,const PR & pr)
00617     {
00618         tree_id=0;
00619         trees_online_information.resize(rf.options_.tree_count_);
00620     }
00621 
00622     /** Reset a tree
00623      */
00624     void reset_tree(int tree_id)
00625     {
00626         trees_online_information[tree_id].mag_distributions.clear();
00627         trees_online_information[tree_id].index_lists.clear();
00628         trees_online_information[tree_id].interior_to_index.clear();
00629         trees_online_information[tree_id].exterior_to_index.clear();
00630     }
00631 
00632     /** simply increase the tree count
00633     */
00634     template<class RF, class PR, class SM, class ST>
00635     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00636     {
00637         tree_id++;
00638     }
00639     
00640     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00641     void visit_after_split( Tree          & tree, 
00642                             Split         & split,
00643                             Region        & parent,
00644                             Region        & leftChild,
00645                             Region        & rightChild,
00646                             Feature_t     & features,
00647                             Label_t       & labels)
00648     {
00649         int linear_index;
00650         int addr=tree.topology_.size();
00651         if(split.createNode().typeID() == i_ThresholdNode)
00652         {
00653             if(adjust_thresholds)
00654             {
00655                 //Store marginal distribution
00656                 linear_index=trees_online_information[tree_id].mag_distributions.size();
00657                 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
00658                 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
00659 
00660                 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
00661                 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
00662 
00663                 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
00664                 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
00665                 //Store the gap
00666                 double gap_left,gap_right;
00667                 int i;
00668                 gap_left=features(leftChild[0],split.bestSplitColumn());
00669                 for(i=1;i<leftChild.size();++i)
00670                     if(features(leftChild[i],split.bestSplitColumn())>gap_left)
00671                         gap_left=features(leftChild[i],split.bestSplitColumn());
00672                 gap_right=features(rightChild[0],split.bestSplitColumn());
00673                 for(i=1;i<rightChild.size();++i)
00674                     if(features(rightChild[i],split.bestSplitColumn())<gap_right)
00675                         gap_right=features(rightChild[i],split.bestSplitColumn());
00676                 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
00677                 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
00678             }
00679         }
00680         else
00681         {
00682             //Store index list
00683             linear_index=trees_online_information[tree_id].index_lists.size();
00684             trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
00685 
00686             trees_online_information[tree_id].index_lists.push_back(IndexList());
00687 
00688             trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
00689             std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
00690         }
00691     }
00692     void add_to_index_list(int tree,int node,int index)
00693     {
00694         if(!this->active_)
00695             return;
00696         TreeOnlineInformation &ti=trees_online_information[tree];
00697         ti.index_lists[ti.exterior_to_index[node]].push_back(index);
00698     }
00699     void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
00700     {
00701         if(!this->active_)
00702             return;
00703         trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
00704         trees_online_information[src_tree].exterior_to_index.erase(src_index);
00705     }
00706     /** do something when visiting an internal node during getToLeaf
00707      *
00708      * Remember the index as last_node_id, for finding the parent of the last external node;
00709      * also adjust class counts and gap borders.
00710      */
00711     template<class TR, class IntT, class TopT,class Feat>
00712         void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
00713         {
00714             last_node_id=index;
00715             if(adjust_thresholds)
00716             {
00717                 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
00718                 //Check if we are in the gap
00719                 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
00720                 TreeOnlineInformation &ti=trees_online_information[tree_id];
00721                 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
00722                 if(value>m.gap_left && value<m.gap_right)
00723                 {
00724                     //Check which side we want to go to
00725                     if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
00726                     {
00727                         //We want to go left
00728                         m.gap_left=value;
00729                     }
00730                     else
00731                     {
00732                         //We want to go right
00733                         m.gap_right=value;
00734                     }
00735                     Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
00736                 }
00737                 //Adjust class counts
00738                 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
00739                 {
00740                     ++m.rightTotalCounts;
00741                     ++m.rightCounts[current_label];
00742                 }
00743                 else
00744                 {
00745                     ++m.leftTotalCounts;
00746                     ++m.leftCounts[current_label]; // left branch was taken, so update the left counts
00747                 }
00748             }
00749         }
00750     /** do something when visiting an external node during getToLeaf
00751      * 
00752      * Store the new index!
00753      */
00754 };
00755 
00756 //////////////////////////////////////////////////////////////////////////////
00757 // Out of Bag Error estimates                                               //
00758 //////////////////////////////////////////////////////////////////////////////
00759 
00760 
00761 /** Visitor that calculates the oob error of each individual randomized
00762  * decision tree. 
00763  *
00764  * After training a tree, all samples that are OOB for this particular tree
00765  * are put down the tree and the error is estimated. 
00766  * The per-tree OOB error is the average of the individual error estimates 
00767  * (oobError = average error of one randomized tree).
00768  * Note: this is NOT the OOB error estimate suggested by Breiman (see the OOB_Error 
00769  * visitor).
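 *
 * A usage sketch (f and l stand for feature and label matrices as in the
 * namespace documentation above):
 * \code
 * visitors::OOB_PerTreeError per_tree_v;
 * RandomForest<> rf;
 * rf.learn(f, l, visitors::create_visitor(per_tree_v));
 * std::cout << "mean per-tree OOB error: " << per_tree_v.oobError << std::endl;
 * \endcode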
00770  */
00771 class OOB_PerTreeError:public VisitorBase
00772 {
00773 public:
00774     /** Average error of one randomized decision tree
00775      */
00776     double oobError;
00777 
00778     int totalOobCount;
00779     ArrayVector<int> oobCount,oobErrorCount;
00780 
00781     OOB_PerTreeError()
00782     : oobError(0.0),
00783       totalOobCount(0)
00784     {}
00785 
00786 
00787     bool has_value()
00788     {
00789         return true;
00790     }
00791 
00792 
00793     /** does the basic calculation per tree*/
00794     template<class RF, class PR, class SM, class ST>
00795     void visit_after_tree(    RF& rf, PR & pr,  SM & sm, ST & st, int index)
00796     {
00797         // done only the first time this is called
00798         if(int(oobCount.size()) != rf.ext_param_.row_count_)
00799         {
00800             oobCount.resize(rf.ext_param_.row_count_, 0);
00801             oobErrorCount.resize(rf.ext_param_.row_count_, 0);
00802         }
00803         // go through the samples
00804         for(int l = 0; l < rf.ext_param_.row_count_; ++l)
00805         {
00806             // if the lth sample is oob...
00807             if(!sm.is_used()[l])
00808             {
00809                 ++oobCount[l];
00810                 if(     rf.tree(index)
00811                             .predictLabel(rowVector(pr.features(), l)) 
00812                     !=  pr.response()(l,0))
00813                 {
00814                     ++oobErrorCount[l];
00815                 }
00816             }
00817 
00818         }
00819     }
00820 
00821     /** Does the normalisation
00822      */
00823     template<class RF, class PR>
00824     void visit_at_end(RF & rf, PR & pr)
00825     {
00826         // do some normalisation
00827         for(int l=0; l < (int)rf.ext_param_.row_count_; ++l)
00828         {
00829             if(oobCount[l])
00830             {
00831                 oobError += double(oobErrorCount[l]) / oobCount[l];
00832                 ++totalOobCount;
00833             }
00834         } 
00835         oobError/=totalOobCount;
00836     }
00837     
00838 };
00839 
00840 /** Visitor that calculates the OOB error of the ensemble.
00841  *  This rate should be used to estimate the cross-validation 
00842  *  error rate.
00843  *  Here each sample is put down those trees for which this sample
00844  *  is OOB, i.e. if sample #1 is OOB for trees 1, 3 and 5, we calculate
00845  *  the output using the ensemble consisting only of trees 1, 3 and 5. 
00846  *
00847  *  Using normal bagged sampling, each sample is OOB for approx. 33% of the trees.
00848  *  The error rate obtained this way therefore corresponds to the cross-validation
00849  *  error rate obtained with an ensemble containing 33% of the trees.
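 *
 *  In other words, a sketch of what visit_at_end() computes from the accumulated
 *  per-sample vote matrix prob_oob:
 *  \code
 *  oob_breiman =   number of samples whose aggregated OOB vote (argMax of prob_oob) is wrong
 *                / number of samples that were OOB for at least one tree
 *  \endcode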
00850  */
00851 class OOB_Error : public VisitorBase
00852 {
00853     typedef MultiArrayShape<2>::type Shp;
00854     int class_count;
00855     bool is_weighted;
00856     MultiArray<2,double> tmp_prob;
00857     public:
00858 
00859     MultiArray<2, double>       prob_oob; 
00860     /** Ensemble oob error rate
00861      */
00862     double                      oob_breiman;
00863 
00864     MultiArray<2, double>       oobCount;
00865     ArrayVector< int>           indices; 
00866     OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
00867 
00868 #ifdef HasHDF5
00869     void save(std::string filen, std::string pathn)
00870     {
00871         if(*(pathn.end()-1) != '/')
00872             pathn += "/";
00873         const char* filename = filen.c_str();
00874         MultiArray<2, double> temp(Shp(1,1), 0.0); 
00875         temp[0] = oob_breiman;
00876         writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
00877     }
00878 #endif
00879     // negative value if sample was ib (in-bag), the number indicates how often.
00880     //  value >= 0 if sample was oob: 0 means fail, 1 correct
00881 
00882     template<class RF, class PR>
00883     void visit_at_beginning(RF & rf, PR & pr)
00884     {
00885         class_count = rf.class_count();
00886         tmp_prob.reshape(Shp(1, class_count), 0); 
00887         prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
00888         is_weighted = rf.options().predict_weighted_;
00889         indices.resize(rf.ext_param().row_count_);
00890         if(int(oobCount.size()) != rf.ext_param_.row_count_)
00891         {
00892             oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
00893         }
00894         for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
00895         {
00896             indices[ii] = ii;
00897         }
00898     }
00899 
00900     template<class RF, class PR, class SM, class ST>
00901     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00902     {
00903         // go through the samples
00904         int total_oob =0;
00905         int wrong_oob =0;
00906         // FIXME: magic number 10000: invoke special treatment when msample << sample_count
00907         //                            (i.e. the OOB sample is very large)
00908         //                     40000: use at most 40000 OOB samples per class for OOB error estimate 
00909         if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
00910         {
00911             ArrayVector<int> oob_indices;
00912             ArrayVector<int> cts(class_count, 0);
00913             std::random_shuffle(indices.begin(), indices.end());
00914             for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
00915             {
00916                 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
00917                 {
00918                     oob_indices.push_back(indices[ii]);
00919                     ++cts[pr.response()(indices[ii], 0)];
00920                 }
00921             }
00922             for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
00923             {
00924                 // update number of trees in which current sample is oob
00925                 ++oobCount[oob_indices[ll]];
00926 
00927                 // update number of oob samples in this tree.
00928                 ++total_oob; 
00929                 // get the predicted votes ---> tmp_prob;
00930                 int pos =  rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
00931                 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 
00932                                                     rf.tree(index).parameters_,
00933                                                     pos);
00934                 tmp_prob.init(0); 
00935                 for(int ii = 0; ii < class_count; ++ii)
00936                 {
00937                     tmp_prob[ii] = node.prob_begin()[ii];
00938                 }
00939                 if(is_weighted)
00940                 {
00941                     for(int ii = 0; ii < class_count; ++ii)
00942                         tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
00943                 }
00944                 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
00945                 int label = argMax(tmp_prob); 
00946                 
00947             }
00948         }else
00949         {
00950             for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
00951             {
00952                 // if the lth sample is oob...
00953                 if(!sm.is_used()[ll])
00954                 {
00955                     // update number of trees in which current sample is oob
00956                     ++oobCount[ll];
00957 
00958                     // update number of oob samples in this tree.
00959                     ++total_oob; 
00960                     // get the predicted votes ---> tmp_prob;
00961                     int pos =  rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
00962                     Node<e_ConstProbNode> node ( rf.tree(index).topology_, 
00963                                                         rf.tree(index).parameters_,
00964                                                         pos);
00965                     tmp_prob.init(0); 
00966                     for(int ii = 0; ii < class_count; ++ii)
00967                     {
00968                         tmp_prob[ii] = node.prob_begin()[ii];
00969                     }
00970                     if(is_weighted)
00971                     {
00972                         for(int ii = 0; ii < class_count; ++ii)
00973                             tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
00974                     }
00975                     rowVector(prob_oob, ll) += tmp_prob;
00976                     int label = argMax(tmp_prob); 
00977                     
00978                 }
00979             }
00980         }
00981         // go through the ib samples; 
00982     }
00983 
00984     /** Compute the ensemble OOB error once the number of trees is known.
00985      */
00986     template<class RF, class PR>
00987     void visit_at_end(RF & rf, PR & pr)
00988     {
00989         // Ulli's original metric and Breiman-style error
00990         int totalOobCount =0;
00991         int breimanstyle = 0;
00992         for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
00993         {
00994             if(oobCount[ll])
00995             {
00996                 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
00997                    ++breimanstyle;
00998                 ++totalOobCount;
00999             }
01000         }
01001         oob_breiman = double(breimanstyle)/totalOobCount; 
01002     }
01003 };
01004 
01005 
01006 /** Visitor that calculates different OOB error statistics
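 *
 * A usage sketch (f and l stand for feature and label matrices; the file and
 * group names passed to save() are arbitrary examples):
 * \code
 * visitors::CompleteOOBInfo oob_info;
 * RandomForest<> rf;
 * rf.learn(f, l, visitors::create_visitor(oob_info));
 * std::cout << oob_info.oob_breiman << " " << oob_info.oob_mean << std::endl;
 * #ifdef HasHDF5
 * oob_info.save("oob_statistics.h5", "forest");  // writes all collected statistics to HDF5
 * #endif
 * \endcode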
01007  */
01008 class CompleteOOBInfo : public VisitorBase
01009 {
01010     typedef MultiArrayShape<2>::type Shp;
01011     int class_count;
01012     bool is_weighted;
01013     MultiArray<2,double> tmp_prob;
01014     public:
01015 
01016     /** OOB Error rate of each individual tree
01017      */
01018     MultiArray<2, double>       oob_per_tree;
01019     /** Mean of oob_per_tree
01020      */
01021     double                      oob_mean;
01022     /**Standard deviation of oob_per_tree
01023      */
01024     double                      oob_std;
01025     
01026     MultiArray<2, double>       prob_oob; 
01027     /** Ensemble OOB error
01028      *
01029      * \sa OOB_Error
01030      */
01031     double                      oob_breiman;
01032 
01033     MultiArray<2, double>       oobCount;
01034     MultiArray<2, double>       oobErrorCount;
01035     /** Per Tree OOB error calculated as in OOB_PerTreeError
01036      * (Ulli's version)
01037      */
01038     double                      oob_per_tree2;
01039 
01040     /**Column containing the development of the Ensemble
01041      * error rate with increasing number of trees
01042      */
01043     MultiArray<2, double>       breiman_per_tree;
01044     /** 4 dimensional array containing the development of confusion matrices 
01045      * with number of trees - can be used to estimate ROC curves etc.
01046      *
01047      * oobroc_per_tree(ii, jj, kk, ll) 
01048      * corresponds to: true label = ii, 
01049      * predicted label = jj,
01050      * confusion matrix after ll trees.
01051      *
01052      * explanation of the third index:
01053      *
01054      * Two class case:
01055      * kk = 0 - (treeCount-1):
01056      *         the threshold on the probability for class 0 is kk/(treeCount-1);
01057      * More classes:
01058      * kk = 0. The threshold on the probability is set by argMax of the probability array.
01059      */
01060     MultiArray<4, double>       oobroc_per_tree;
01061     
01062     CompleteOOBInfo() : VisitorBase(), oob_mean(0), oob_std(0), oob_per_tree2(0)  {}
01063 
01064 #ifdef HasHDF5
01065     /** save to HDF5 file
01066      */
01067     void save(std::string filen, std::string pathn)
01068     {
01069         if(*(pathn.end()-1) != '/')
01070             pathn += "/";
01071         const char* filename = filen.c_str();
01072         MultiArray<2, double> temp(Shp(1,1), 0.0); 
01073         writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
01074         writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
01075         writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
01076         temp[0] = oob_mean;
01077         writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
01078         temp[0] = oob_std;
01079         writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
01080         temp[0] = oob_breiman;
01081         writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
01082         temp[0] = oob_per_tree2;
01083         writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
01084     }
01085 #endif
01086     // negative value if sample was ib (in-bag), the number indicates how often.
01087     //  value >= 0 if sample was oob: 0 means fail, 1 correct
01088 
01089     template<class RF, class PR>
01090     void visit_at_beginning(RF & rf, PR & pr)
01091     {
01092         class_count = rf.class_count();
01093         if(class_count == 2)
01094             oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
01095         else
01096             oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
01097         tmp_prob.reshape(Shp(1, class_count), 0); 
01098         prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
01099         is_weighted = rf.options().predict_weighted_;
01100         oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
01101         breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
01102         // done only the first time this is called
01103         if(int(oobCount.size()) != rf.ext_param_.row_count_)
01104         {
01105             oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
01106             oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
01107         }
01108     }
01109 
01110     template<class RF, class PR, class SM, class ST>
01111     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01112     {
01113         // go through the samples
01114         int total_oob =0;
01115         int wrong_oob =0;
01116         for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
01117         {
01118             // if the lth sample is oob...
01119             if(!sm.is_used()[ll])
01120             {
01121                 // update number of trees in which current sample is oob
01122                 ++oobCount[ll];
01123 
01124                 // update number of oob samples in this tree.
01125                 ++total_oob; 
01126                 // get the predicted votes ---> tmp_prob;
01127                 int pos =  rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
01128                 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 
01129                                                     rf.tree(index).parameters_,
01130                                                     pos);
01131                 tmp_prob.init(0); 
01132                 for(int ii = 0; ii < class_count; ++ii)
01133                 {
01134                     tmp_prob[ii] = node.prob_begin()[ii];
01135                 }
01136                 if(is_weighted)
01137                 {
01138                     for(int ii = 0; ii < class_count; ++ii)
01139                         tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
01140                 }
01141                 rowVector(prob_oob, ll) += tmp_prob;
01142                 int label = argMax(tmp_prob); 
01143                 
01144                 if(label != pr.response()(ll, 0))
01145                 {
01146                     // update number of wrong oob samples in this tree.
01147                     ++wrong_oob;
01148                     // update number of trees in which current sample is wrong oob
01149                     ++oobErrorCount[ll];
01150                 }
01151             }
01152         }
01153         int breimanstyle = 0;
01154         int totalOobCount = 0;
01155         for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
01156         {
01157             if(oobCount[ll])
01158             {
01159                 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
01160                    ++breimanstyle;
01161                 ++totalOobCount;
01162                 if(oobroc_per_tree.shape(2) == 1)
01163                 {
01164                     oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
01165                 }
01166             }
01167         }
01168         if(oobroc_per_tree.shape(2) == 1)
01169             oobroc_per_tree.bindOuter(index)/=totalOobCount;
01170         if(oobroc_per_tree.shape(2) > 1)
01171         {
01172             MultiArrayView<3, double> current_roc 
01173                     = oobroc_per_tree.bindOuter(index);
01174             for(int gg = 0; gg < current_roc.shape(2); ++gg)
01175             {
01176                 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
01177                 {
01178                     if(oobCount[ll])
01179                     {
01180                         int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
01181                                         1 : 0; 
01182                         current_roc(pr.response()(ll, 0), pred, gg)+= 1; 
01183                     }
01184                 }
01185                 current_roc.bindOuter(gg)/= totalOobCount;
01186             }
01187         }
01188         breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
01189         oob_per_tree[index] = double(wrong_oob)/double(total_oob);
01190         // go through the ib samples; 
01191     }
01192 
01193     /** Compute the final OOB statistics once the number of trees is known.
01194      */
01195     template<class RF, class PR>
01196     void visit_at_end(RF & rf, PR & pr)
01197     {
01198         // Ulli's original metric and Breiman-style error
01199         oob_per_tree2 = 0; 
01200         int totalOobCount =0;
01201         int breimanstyle = 0;
01202         for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
01203         {
01204             if(oobCount[ll])
01205             {
01206                 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
01207                    ++breimanstyle;
01208                 oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
01209                 ++totalOobCount;
01210             }
01211         }
01212         oob_per_tree2 /= totalOobCount; 
01213         oob_breiman = double(breimanstyle)/totalOobCount; 
01214         // mean error of each tree
01215         MultiArrayView<2, double> mean(Shp(1,1), &oob_mean);
01216         MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
01217         rowStatistics(oob_per_tree, mean, stdDev);
01218     }
01219 };
01220 
01221 /** calculate variable importance while learning.
01222  */
01223 class VariableImportanceVisitor : public VisitorBase
01224 {
01225     public:
01226 
01227     /** This array has the same entries as the R random forest variable
01228      *  importance.
01229      *  The matrix is featureCount by (classCount + 2).
01230      *  variable_importance_(ii,jj) is the variable importance measure of 
01231      *  the ii-th variable according to:
01232      *  jj = 0 - (classCount-1)
01233      *      classwise permutation importance 
01234      *  jj = columnCount(variable_importance_) - 2
01235      *      permutation importance
01236      *  jj = columnCount(variable_importance_) - 1
01237      *      gini decrease importance.
01238      *
01239      *  permutation importance:
01240      *  The difference between the fraction of OOB samples classified correctly
01241      *  before and after permuting (randomizing) the ii-th column is calculated.
01242      *  The ii-th column is permuted rep_cnt times.
01243      *
01244      *  class wise permutation importance:
01245      *  same as permutation importance. We only look at those OOB samples whose 
01246      *  response corresponds to class jj.
01247      *
01248      *  gini decrease importance:
01249      *  row ii corresponds to the sum of all gini decreases induced by variable ii 
01250      *  in each node of the random forest.
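     *
     *  A usage sketch (f and l stand for feature and label matrices; ii is some
     *  feature index):
     *  \code
     *  visitors::VariableImportanceVisitor varimp_v;
     *  RandomForest<> rf;
     *  rf.learn(f, l, visitors::create_visitor(varimp_v));
     *  // total permutation importance of feature ii:
     *  double perm_imp = varimp_v.variable_importance_(ii, rf.class_count());
     *  // gini decrease importance of feature ii:
     *  double gini_imp = varimp_v.variable_importance_(ii, rf.class_count() + 1);
     *  \endcode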
01251      */
01252     MultiArray<2, double>       variable_importance_;
01253     int                         repetition_count_;
01254     bool                        in_place_;
01255 
01256 #ifdef HasHDF5
01257     void save(std::string filename, std::string prefix)
01258     {
01259         prefix = "variable_importance_" + prefix;
01260         writeHDF5(filename.c_str(), 
01261                         prefix.c_str(), 
01262                         variable_importance_);
01263     }
01264 #endif
01265     /** Constructor
01266      * \param rep_cnt (default: 10) how often the permutation 
01267      * should take place. Set to 1 to make the calculation faster (but
01268      * possibly less stable)
01269      */
01270     VariableImportanceVisitor(int rep_cnt = 10) 
01271     :   repetition_count_(rep_cnt)
01272 
01273     {}
01274 
01275     /** calculates impurity decrease based variable importance after every
01276      * split.  
01277      */
01278     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
01279     void visit_after_split( Tree          & tree, 
01280                             Split         & split,
01281                             Region        & parent,
01282                             Region        & leftChild,
01283                             Region        & rightChild,
01284                             Feature_t     & features,
01285                             Label_t       & labels)
01286     {
01287         //resize to right size when called the first time
01288         
01289         Int32 const  class_count = tree.ext_param_.class_count_;
01290         Int32 const  column_count = tree.ext_param_.column_count_;
01291         if(variable_importance_.size() == 0)
01292         {
01293             
01294             variable_importance_
01295                 .reshape(MultiArrayShape<2>::type(column_count, 
01296                                                  class_count+2));
01297         }
01298 
01299         if(split.createNode().typeID() == i_ThresholdNode)
01300         {
01301             Node<i_ThresholdNode> node(split.createNode());
01302             variable_importance_(node.column(),class_count+1) 
01303                 += split.region_gini_ - split.minGini();
01304         }
01305     }
01306 
01307     /** compute permutation-based variable importance. 
01308      * (Only an array of size oob_sample_count x 1 is created
01309      *  - as opposed to oob_sample_count x feature_count in the other method.)
01310      * 
01311      * \sa FieldProxy
01312      */
01313     template<class RF, class PR, class SM, class ST>
01314     void after_tree_ip_impl(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01315     {
01316         typedef MultiArrayShape<2>::type Shp_t;
01317         Int32                   column_count = rf.ext_param_.column_count_;
01318         Int32                   class_count  = rf.ext_param_.class_count_;  
01319         
01320         /* This solution saves memory but is not
01321          * multithreading compatible.
01322          */
01323         // remove the const cast on the features (yep, I know what I am 
01324         // doing here.) The data is not destroyed.
01325         //typename PR::Feature_t & features 
01326         //    = const_cast<typename PR::Feature_t &>(pr.features());
01327 
01328         typename PR::FeatureWithMemory_t features = pr.features();
01329 
01330         //find the oob indices of current tree. 
01331         ArrayVector<Int32>      oob_indices;
01332         ArrayVector<Int32>::iterator
01333                                 iter;
01334         for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
01335             if(!sm.is_used()[ii])
01336                 oob_indices.push_back(ii);
01337 
01338         //create space to back up a column      
01339         std::vector<double>     backup_column;
01340 
01341         // Random foo
01342 #ifdef CLASSIFIER_TEST
01343         RandomMT19937           random(1);
01344 #else 
01345         RandomMT19937           random(RandomSeed);
01346 #endif
01347         UniformIntRandomFunctor<RandomMT19937>  
01348                                 randint(random);
01349 
01350 
01351         //make some space for the results
01352         MultiArray<2, double>
01353                     oob_right(Shp_t(1, class_count + 1)); 
01354         MultiArray<2, double>
01355                     perm_oob_right (Shp_t(1, class_count + 1)); 
01356             
01357         
01358         // get the oob success rate with the original samples
01359         for(iter = oob_indices.begin(); 
01360             iter != oob_indices.end(); 
01361             ++iter)
01362         {
01363             if(rf.tree(index)
01364                     .predictLabel(rowVector(features, *iter)) 
01365                 ==  pr.response()(*iter, 0))
01366             {
01367                 //per class
01368                 ++oob_right[pr.response()(*iter,0)];
01369                 //total
01370                 ++oob_right[class_count];
01371             }
01372         }
01373         //for each dimension ii: measure the oob rate after permuting it
01374         for(int ii = 0; ii < column_count; ++ii)
01375         {
01376             perm_oob_right.init(0.0); 
01377             //make a backup of the original column
01378             backup_column.clear();
01379             for(iter = oob_indices.begin(); 
01380                 iter != oob_indices.end(); 
01381                 ++iter)
01382             {
01383                 backup_column.push_back(features(*iter,ii));
01384             }
01385             
01386             //get the oob rate after permuting the ii'th dimension.
01387             for(int rr = 0; rr < repetition_count_; ++rr)
01388             {               
01389                 //permute this dimension among the oob samples (Fisher-Yates shuffle)
01390                 int n = oob_indices.size();
01391                 for(int jj = 1; jj < n; ++jj)
01392                     std::swap(features(oob_indices[jj], ii), 
01393                               features(oob_indices[randint(jj+1)], ii));
01394 
01395                 //get the oob success rate after permuting
01396                 for(iter = oob_indices.begin(); 
01397                     iter != oob_indices.end(); 
01398                     ++iter)
01399                 {
01400                     if(rf.tree(index)
01401                             .predictLabel(rowVector(features, *iter)) 
01402                         ==  pr.response()(*iter, 0))
01403                     {
01404                         //per class
01405                         ++perm_oob_right[pr.response()(*iter, 0)];
01406                         //total
01407                         ++perm_oob_right[class_count];
01408                     }
01409                 }
01410             }
01411             
01412             
01413             //normalise and add to the variable_importance array.
01414             perm_oob_right /= repetition_count_;
01415             perm_oob_right -= oob_right;
01416             perm_oob_right *= -1;
01417             perm_oob_right /= oob_indices.size();
01418             variable_importance_
01419                 .subarray(Shp_t(ii,0), 
01420                           Shp_t(ii+1,class_count+1)) += perm_oob_right;
01421             //copy back permuted dimension
01422             for(int jj = 0; jj < int(oob_indices.size()); ++jj)
01423                 features(oob_indices[jj], ii) = backup_column[jj];
01424         }
01425     }
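    // In summary, after_tree_ip_impl() adds, for every feature ii and every
    // class c (plus the "total" entry at index class_count):
    //
    //   variable_importance_(ii, c) +=
    //       ( oob_right[c] - perm_oob_right[c] / repetition_count_ ) / n_oob
    //
    // where n_oob is the number of out-of-bag samples of the current tree;
    // visit_at_end() finally divides the whole array by the number of trees.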
01426 
01427     /** Calculate permutation based variable importance after every tree
01428      * has been learned. The default behaviour is that this happens out of
01429      * place (the feature matrix is copied). If you have very big data sets
01430      * and want to avoid copying the data, set the in_place_ flag to true.
01431      */
01432     template<class RF, class PR, class SM, class ST>
01433     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01434     {
01435         after_tree_ip_impl(rf, pr, sm, st, index);
01436     }
01437 
01438     /** Normalise variable importance after the number of trees is known.
01439      */
01440     template<class RF, class PR>
01441     void visit_at_end(RF & rf, PR & pr)
01442     {
01443         variable_importance_ /= rf.trees_.size();
01444     }
01445 };
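// Usage sketch (not part of the original source, signatures may differ
// between library versions): a VariableImportanceVisitor is typically
// attached to training via the create_visitor() helper from this header:
//
//     vigra::RandomForest<int> rf(vigra::RandomForestOptions().tree_count(255));
//     vigra::rf::visitors::VariableImportanceVisitor var_imp(10);
//     rf.learn(features, labels, vigra::rf::visitors::create_visitor(var_imp));
//     // row ii of var_imp.variable_importance_ now holds the importance
//     // measures of feature ii (see the layout comment above).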
01446 
01447 /** Prints verbose progress output while the random forest is learned.
01448  */
01449 class RandomForestProgressVisitor : public VisitorBase {
01450     public:
01451     RandomForestProgressVisitor() : VisitorBase() {}
01452 
01453     template<class RF, class PR, class SM, class ST>
01454     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index){
01455         if(index != rf.options().tree_count_-1) {
01456             std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
01457                       << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
01458         }
01459         else {
01460             std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
01461         }
01462     }
01463     
01464     template<class RF, class PR>
01465     void visit_at_end(RF const & rf, PR const & pr) {
01466         std::string elapsed = TOCS;
01467         std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << elapsed << std::endl;
01468     }
01469     
01470     template<class RF, class PR>
01471     void visit_at_beginning(RF const & rf, PR const & pr) {
01472         TIC;
01473         std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
01474     }
01475     
01476     private:
01477     USETICTOC;
01478 };
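// Usage sketch (assumption, not part of the original source): several
// visitors can be combined in one create_visitor() call, so progress output
// is usually requested together with other visitors:
//
//     vigra::rf::visitors::RandomForestProgressVisitor progress;
//     rf.learn(features, labels,
//              vigra::rf::visitors::create_visitor(var_imp, progress));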
01479 
01480 
01481 /** Computes a correlation/similarity matrix of the features while
01482  * learning the random forest.
01483  */
01484 class CorrelationVisitor : public VisitorBase
01485 {
01486     public:
01487     /** gini_missc(ii, jj) describes how well variable jj can describe a
01488      * partition created on variable ii (when variable ii was chosen).
01489      */ 
01490     MultiArray<2, double>   gini_missc;
01491     MultiArray<2, int>      tmp_labels;
01492     /** additional noise features. 
01493      */
01494     MultiArray<2, double>   noise;
01495     MultiArray<2, double>   noise_l;
01496     /** How well a noise column can describe a partition created on variable ii.
01497      */
01498     MultiArray<2, double>   corr_noise;
01499     MultiArray<2, double>   corr_l;
01500 
01501     /** Similarity Matrix
01502      * 
01503      * (numberOfFeatures + 1) by (numberOfFeatures + 1) matrix, derived
01504      * from gini_missc with
01505      *  - each row normalized by the number of times the variable was chosen,
01506      *  - the mean of corr_noise subtracted,
01507      *  - and the result symmetrised. 
01508      *          
01509      */
01510     MultiArray<2, double>   similarity;
01511     /** Distance Matrix (maximum similarity minus similarity)
01512      */
01513     MultiArray<2, double>   distance;
01514     ArrayVector<int>        tmp_cc;
01515     
01516     /** How often variable ii was chosen.
01517      */
01518     ArrayVector<int>        numChoices;
01519     typedef BestGiniOfColumn<GiniCriterion> ColumnDecisionFunctor;
01520     BestGiniOfColumn<GiniCriterion>         bgfunc;
01521     void save(std::string file, std::string prefix)
01522     {
01523         /* HDF5 serialization of the visitor's matrices - currently disabled:
01524         std::string tmp;
01525 #define VAR_WRITE(NAME) \
01526         tmp = #NAME;\
01527         tmp += "_";\
01528         tmp += prefix;\
01529         vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
01530         VAR_WRITE(gini_missc);
01531         VAR_WRITE(corr_noise);
01532         VAR_WRITE(distance);
01533         VAR_WRITE(similarity);
01534         vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
01535 #undef VAR_WRITE
01536 */
01537     }
01538     template<class RF, class PR>
01539     void visit_at_beginning(RF const & rf, PR  & pr)
01540     {
01541         typedef MultiArrayShape<2>::type Shp;
01542         int n = rf.ext_param_.column_count_;
01543         gini_missc.reshape(Shp(n+1, n+1));
01544         corr_noise.reshape(Shp(n+1, 10));
01545         corr_l.reshape(Shp(n+1, 10));
01546 
01547         noise.reshape(Shp(pr.features().shape(0), 10));
01548         noise_l.reshape(Shp(pr.features().shape(0), 10));
01549         RandomMT19937 random(RandomSeed);
01550         for(int ii = 0; ii < noise.size(); ++ii)
01551         {
01552             noise[ii]   = random.uniform53();
01553             noise_l[ii] = random.uniform53()  > 0.5;
01554         }
01555         bgfunc = ColumnDecisionFunctor( rf.ext_param_);
01556         tmp_labels.reshape(pr.response().shape()); 
01557         tmp_cc.resize(2);
01558         numChoices.resize(n+1);
01559         // look at all axes
01560     }
01561     template<class RF, class PR>
01562     void visit_at_end(RF const & rf, PR const & pr)
01563     {
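        // Post-processing of gini_missc into the similarity / distance
        // matrices (cf. the member documentation above):
        //   1) normalize each row by how often that variable was chosen,
        //   2) subtract the mean gini decrease achieved by the noise columns,
        //   3) take absolute values, symmetrise and fix up the diagonal,
        //   4) distance = max(similarity) - similarity.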
01564         typedef MultiArrayShape<2>::type Shp;
01565         similarity.reshape(gini_missc.shape());
01566         similarity = gini_missc;
01567         MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1));
01568         rowStatistics(corr_noise, mean_noise);
01569         mean_noise/= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data());        
01570         int rC = similarity.shape(0);
01571         for(int jj = 0; jj < rC-1; ++jj)
01572         {
01573             rowVector(similarity, jj) /= numChoices[jj];
01574             rowVector(similarity, jj) -= mean_noise(jj, 0);
01575         }
01576         for(int jj = 0; jj < rC; ++jj)
01577         {
01578             similarity(rC -1, jj) /= numChoices[jj];
01579         }
01580         rowVector(similarity, rC -  1) -= mean_noise(rC-1, 0);
01581         similarity = abs(similarity);
01582         FindMinMax<double> minmax;
01583         inspectMultiArray(srcMultiArrayRange(similarity), minmax);
01584         
01585         for(int jj = 0; jj < rC; ++jj)
01586             similarity(jj, jj) = minmax.max;
01587         
01588         similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)) 
01589             += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
01590         similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;  
01591         columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
01592         for(int jj = 0; jj < rC; ++jj)
01593             similarity(jj, jj) = 0;
01594         
01595         FindMinMax<double> minmax2;
01596         inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
01597         for(int jj = 0; jj < rC; ++jj)
01598             similarity(jj, jj) = minmax2.max;
01599         distance.reshape(gini_missc.shape(), minmax2.max);
01600         distance -= similarity; 
01601     }
01602 
01603     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
01604     void visit_after_split( Tree          & tree, 
01605                             Split         & split,
01606                             Region        & parent,
01607                             Region        & leftChild,
01608                             Region        & rightChild,
01609                             Feature_t     & features,
01610                             Label_t       & labels)
01611     {
01612         if(split.createNode().typeID() == i_ThresholdNode)
01613         {
01614             double wgini;
01615             tmp_cc.init(0); 
01616             for(int ii = 0; ii < parent.size(); ++ii)
01617             {
01618                 tmp_labels[parent[ii]] 
01619                     = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
01620                 ++tmp_cc[tmp_labels[parent[ii]]];
01621             }
01622             double region_gini = bgfunc.loss_of_region(tmp_labels, 
01623                                                        parent.begin(),
01624                                                        parent.end(),
01625                                                        tmp_cc);
01626 
01627             int n = split.bestSplitColumn(); 
01628             ++numChoices[n];
01629             ++(*(numChoices.end()-1));
01630             //this functor does all the work
01631             for(int k = 0; k < features.shape(1); ++k)
01632             {
01633                 bgfunc(columnVector(features, k),
01634                        0,
01635                        tmp_labels, 
01636                        parent.begin(), parent.end(), 
01637                        tmp_cc);
01638                 wgini = (region_gini - bgfunc.min_gini_);
01639                 gini_missc(n, k) 
01640                     += wgini;
01641             }
01642             for(int k = 0; k < 10; ++k)
01643             {
01644                 bgfunc(columnVector(noise, k),
01645                        0,
01646                        tmp_labels, 
01647                        parent.begin(), parent.end(), 
01648                        tmp_cc);
01649                 wgini = (region_gini - bgfunc.min_gini_);
01650                 corr_noise(n, k) 
01651                     += wgini;
01652             }
01653             
01654             for(int k = 0; k < 10; ++k)
01655             {
01656                 bgfunc(columnVector(noise_l, k),
01657                        0,
01658                        tmp_labels, 
01659                        parent.begin(), parent.end(), 
01660                        tmp_cc);
01661                 wgini = (region_gini - bgfunc.min_gini_);
01662                 corr_l(n, k) 
01663                     += wgini;
01664             }
01665             bgfunc(labels,0,  tmp_labels, parent.begin(), parent.end(),tmp_cc);
01666             wgini = (region_gini - bgfunc.min_gini_);
01667             gini_missc(n, columnCount(gini_missc)-1) 
01668                 += wgini;
01669             
01670             region_gini = split.region_gini_;
01671 #if 1 
01672             Node<i_ThresholdNode> node(split.createNode());
01673             gini_missc(rowCount(gini_missc)-1, 
01674                                   node.column()) 
01675                  +=split.region_gini_ - split.minGini();
01676 #endif
01677             for(int k = 0; k < 10; ++k)
01678             {
01679                 split.bgfunc(columnVector(noise, k),
01680                              0,
01681                              labels, 
01682                              parent.begin(), parent.end(), 
01683                              parent.classCounts());
01684                 wgini = region_gini - split.bgfunc.min_gini_;
01685                 corr_noise(rowCount(gini_missc)-1, k)
01686                      += wgini;
01687             }
01688 #if 0
01689             for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
01690             {
01691                 wgini = region_gini - split.min_gini_[k];
01692                 
01693                 gini_missc(rowCount(gini_missc)-1, 
01694                                       split.splitColumns[k]) 
01695                      += wgini;
01696             }
01697             
01698             for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
01699             {
01700                 split.bgfunc(columnVector(features, split.splitColumns[k]),
01701                              labels, 
01702                              parent.begin(), parent.end(), 
01703                              parent.classCounts());
01704                 wgini = region_gini - split.bgfunc.min_gini_;
01705                 gini_missc(rowCount(gini_missc)-1, 
01706                                       split.splitColumns[k]) += wgini;
01707             }
01708 #endif
01709             gini_missc(rowCount(gini_missc)-1, 
01710                        columnCount(gini_missc)-1) 
01711                 += region_gini;
01712             // remember to partition the data according to the best split.
01713             SortSamplesByDimensions<Feature_t> 
01714                 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
01715             std::partition(parent.begin(), parent.end(), sorter);
01716         }
01717     }
01718 };
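// Usage sketch (assumption, not part of the original source): like the other
// visitors, a CorrelationVisitor is attached via create_visitor(); after
// learning, its similarity and distance members can be inspected, e.g. as
// input for a clustering of the features:
//
//     vigra::rf::visitors::CorrelationVisitor corr;
//     rf.learn(features, labels, vigra::rf::visitors::create_visitor(corr));
//     double d = corr.distance(0, 1);   // dissimilarity of features 0 and 1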
01719 
01720 
01721 } // namespace visitors
01722 } // namespace rf
01723 } // namespace vigra
01724 
01725 //@}
01726 #endif // RF_VISITORS_HXX
