[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_visitors.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 #ifndef RF_VISITORS_HXX
00036 #define RF_VISITORS_HXX
00037 
00038 #ifdef HasHDF5
00039 # include "vigra/hdf5impex.hxx"
00040 #endif // HasHDF5
00041 
00042 namespace vigra
00043 {
00044 
00045     
00046     
/** Base class from which all random forest visitors derive.
 *
 * Every hook of this class is an empty (template) function, so a concrete
 * visitor only needs to redeclare the hooks it is interested in. There are no
 * virtual functions: dispatch happens statically through the visitor chain
 * (rf::VisitorNode), which only forwards a hook to a visitor whose
 * is_active() returns true.
 */
class VisitorBase
{
    public:
    /// Whether this visitor currently receives callbacks.
    bool active_;

    /// A freshly constructed visitor is active.
    VisitorBase()
        : active_(true)
    {}

    /// Stop receiving callbacks (sets active_ to false).
    void deactivate()
    {
        active_ = false;
    }

    /// Resume receiving callbacks (sets active_ to true).
    void activate()
    {
        active_ = true;
    }

    /// Current value of active_.
    bool is_active()
    {
        return active_;
    }

    /// The base visitor never supplies a value for return_val().
    bool has_value()
    {
        return false;
    }

    /** Hook called before learning starts.
     *
     * \param rf        the random forest object that called this visitor
     * \param pr        the preprocessor (Processor class) in use
     */
    template<class RF, class PR>
    void visit_at_beginning(RF const & rf, PR const & pr)
    {}

    /** Hook called after the split functor has decided how to process a
     * region (stack entry).
     *
     * \param tree       the tree that is currently being learned
     * \param split      the split object
     * \param parent     stack entry that was used to decide the split
     * \param leftChild  left stack entry that will be pushed
     * \param rightChild right stack entry that will be pushed
     * \param features   feature matrix
     * \param labels     label matrix
     * \sa RF_Traits::StackEntry_t
     */
    template<class Tree, class Split, class Region, class Feature_t, class Label_t>
    void visit_after_split( Tree          & tree,
                            Split         & split,
                            Region        & parent,
                            Region        & leftChild,
                            Region        & rightChild,
                            Feature_t     & features,
                            Label_t       & labels)
    {}

    /** Hook called after each tree has been learned.
     *
     * \param rf        the random forest object that called this visitor
     * \param pr        the preprocessor that processed the input
     * \param sm        the sampler object
     * \param st        the first stack entry
     * \param index     index of the tree that was just learned
     */
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
    {}

    /** Hook called after all trees have been learned.
     *
     * \param rf        the random forest object that called this visitor
     * \param pr        the preprocessor that processed the input
     */
    template<class RF, class PR>
    void visit_at_end(RF const & rf, PR const & pr)
    {}

    /** Hook called for an external (leaf) node while traversing a learned
     * tree.
     *
     * \param tr        the tree object that called this visitor
     * \param index     position in the tree's topology_ array
     * \param node_t    node type tag (one of the e_... / i_... NodeTags)
     * \param features  the feature data being traversed with
     * \sa NodeTags
     *
     * The concrete node can be obtained by switching on node_t and
     * constructing the matching Node object, or through the NodeBase class
     * if the node type does not matter.
     */
    template<class TR, class IntT, class TopT,class Feat>
    void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features)
    {}

    /** Hook called for an internal node while traversing a learned tree.
     *
     * \sa visit_external_node
     */
    template<class TR, class IntT, class TopT,class Feat>
    void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
    {}

    /** Value reported through RandomForest::learn().
     *
     * learn() returns the value of the first active visitor whose
     * has_value() is true, or -1.0 if there is none. This exists mainly so
     * that the OOB visitor can report the out-of-bag error, as the old
     * random forest did. The base implementation returns -1.0.
     */
    double return_val()
    {
        return -1.0;
    }
};
00168 
00169 namespace rf
00170 {
00171 
00172 /** Last Visitor that should be called to stop the recursion.
00173  */
00174 class StopVisiting: public VisitorBase
00175 {
00176     public:
00177     bool has_value()
00178     {
00179         return true;
00180     }
00181     double return_val()
00182     {
00183         return -1.0;
00184     }
00185 };
00186 /** Container elements of the statically linked Visitor list.
00187  *
00188  * use the create_visitor() factory functions to create visitors up to size 10;
00189  *
00190  */
00191 template <class Visitor, class Next = StopVisiting>
00192 class VisitorNode
00193 {
00194     public:
00195     
00196     StopVisiting    stop_;
00197     Next            next_;
00198     Visitor &       visitor_;   
00199     VisitorNode(Visitor & visitor, Next & next) 
00200     : 
00201         next_(next), visitor_(visitor)
00202     {}
00203 
00204     VisitorNode(Visitor &  visitor) 
00205     : 
00206         next_(stop_), visitor_(visitor)
00207     {}
00208 
00209     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00210     void visit_after_split( Tree          & tree, 
00211                             Split         & split,
00212                             Region        & parent,
00213                             Region        & leftChild,
00214                             Region        & rightChild,
00215                             Feature_t     & features,
00216                             Label_t       & labels)
00217     {
00218         if(visitor_.is_active())
00219             visitor_.visit_after_split(tree, split, 
00220                                        parent, leftChild, rightChild,
00221                                        features, labels);
00222         next_.visit_after_split(tree, split, parent, leftChild, rightChild,
00223                                 features, labels);
00224     }
00225 
00226     template<class RF, class PR, class SM, class ST>
00227     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00228     {
00229         if(visitor_.is_active())
00230             visitor_.visit_after_tree(rf, pr, sm, st, index);
00231         next_.visit_after_tree(rf, pr, sm, st, index);
00232     }
00233 
00234     template<class RF, class PR>
00235     void visit_at_beginning(RF & rf, PR & pr)
00236     {
00237         if(visitor_.is_active())
00238             visitor_.visit_at_beginning(rf, pr);
00239         next_.visit_at_beginning(rf, pr);
00240     }
00241     template<class RF, class PR>
00242     void visit_at_end(RF & rf, PR & pr)
00243     {
00244         if(visitor_.is_active())
00245             visitor_.visit_at_end(rf, pr);
00246         next_.visit_at_end(rf, pr);
00247     }
00248     
00249     template<class TR, class IntT, class TopT,class Feat>
00250     void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
00251     {
00252         if(visitor_.is_active())
00253             visitor_.visit_external_node(tr, index, node_t,features);
00254         next_.visit_external_node(tr, index, node_t,features);
00255     }
00256     template<class TR, class IntT, class TopT,class Feat>
00257     void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
00258     {
00259         if(visitor_.is_active())
00260             visitor_.visit_internal_node(tr, index, node_t,features);
00261         next_.visit_internal_node(tr, index, node_t,features);
00262     }
00263 
00264     double return_val()
00265     {
00266         if(visitor_.is_active() && visitor_.has_value())
00267             return visitor_.return_val();
00268         return next_.return_val();
00269     }
00270 };
00271 
00272 } //namespace rf
00273 
00274 //////////////////////////////////////////////////////////////////////////////
00275 //  Visitor Factory function up to 10 visitors                              //
00276 //////////////////////////////////////////////////////////////////////////////
00277 template<class A>
00278 rf::VisitorNode<A>
00279 create_visitor(A & a)
00280 {
00281    typedef rf::VisitorNode<A> _0_t;
00282    _0_t _0(a);
00283    return _0;
00284 }
00285 
00286 
00287 template<class A, class B>
00288 rf::VisitorNode<A, rf::VisitorNode<B> >
00289 create_visitor(A & a, B & b)
00290 {
00291    typedef rf::VisitorNode<B> _1_t;
00292    _1_t _1(b);
00293    typedef rf::VisitorNode<A, _1_t> _0_t;
00294    _0_t _0(a, _1);
00295    return _0;
00296 }
00297 
00298 
00299 template<class A, class B, class C>
00300 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C> > >
00301 create_visitor(A & a, B & b, C & c)
00302 {
00303    typedef rf::VisitorNode<C> _2_t;
00304    _2_t _2(c);
00305    typedef rf::VisitorNode<B, _2_t> _1_t;
00306    _1_t _1(b, _2);
00307    typedef rf::VisitorNode<A, _1_t> _0_t;
00308    _0_t _0(a, _1);
00309    return _0;
00310 }
00311 
00312 
00313 template<class A, class B, class C, class D>
00314 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 
00315     rf::VisitorNode<D> > > >
00316 create_visitor(A & a, B & b, C & c, D & d)
00317 {
00318    typedef rf::VisitorNode<D> _3_t;
00319    _3_t _3(d);
00320    typedef rf::VisitorNode<C, _3_t> _2_t;
00321    _2_t _2(c, _3);
00322    typedef rf::VisitorNode<B, _2_t> _1_t;
00323    _1_t _1(b, _2);
00324    typedef rf::VisitorNode<A, _1_t> _0_t;
00325    _0_t _0(a, _1);
00326    return _0;
00327 }
00328 
00329 
00330 template<class A, class B, class C, class D, class E>
00331 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 
00332     rf::VisitorNode<D, rf::VisitorNode<E> > > > >
00333 create_visitor(A & a, B & b, C & c, 
00334                D & d, E & e)
00335 {
00336    typedef rf::VisitorNode<E> _4_t;
00337    _4_t _4(e);
00338    typedef rf::VisitorNode<D, _4_t> _3_t;
00339    _3_t _3(d, _4);
00340    typedef rf::VisitorNode<C, _3_t> _2_t;
00341    _2_t _2(c, _3);
00342    typedef rf::VisitorNode<B, _2_t> _1_t;
00343    _1_t _1(b, _2);
00344    typedef rf::VisitorNode<A, _1_t> _0_t;
00345    _0_t _0(a, _1);
00346    return _0;
00347 }
00348 
00349 
00350 template<class A, class B, class C, class D, class E,
00351          class F>
00352 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 
00353     rf::VisitorNode<D, rf::VisitorNode<E, rf::VisitorNode<F> > > > > >
00354 create_visitor(A & a, B & b, C & c, 
00355                D & d, E & e, F & f)
00356 {
00357    typedef rf::VisitorNode<F> _5_t;
00358    _5_t _5(f);
00359    typedef rf::VisitorNode<E, _5_t> _4_t;
00360    _4_t _4(e, _5);
00361    typedef rf::VisitorNode<D, _4_t> _3_t;
00362    _3_t _3(d, _4);
00363    typedef rf::VisitorNode<C, _3_t> _2_t;
00364    _2_t _2(c, _3);
00365    typedef rf::VisitorNode<B, _2_t> _1_t;
00366    _1_t _1(b, _2);
00367    typedef rf::VisitorNode<A, _1_t> _0_t;
00368    _0_t _0(a, _1);
00369    return _0;
00370 }
00371 
00372 
00373 template<class A, class B, class C, class D, class E,
00374          class F, class G>
00375 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 
00376     rf::VisitorNode<D, rf::VisitorNode<E, rf::VisitorNode<F, 
00377     rf::VisitorNode<G> > > > > > >
00378 create_visitor(A & a, B & b, C & c, 
00379                D & d, E & e, F & f, G & g)
00380 {
00381    typedef rf::VisitorNode<G> _6_t;
00382    _6_t _6(g);
00383    typedef rf::VisitorNode<F, _6_t> _5_t;
00384    _5_t _5(f, _6);
00385    typedef rf::VisitorNode<E, _5_t> _4_t;
00386    _4_t _4(e, _5);
00387    typedef rf::VisitorNode<D, _4_t> _3_t;
00388    _3_t _3(d, _4);
00389    typedef rf::VisitorNode<C, _3_t> _2_t;
00390    _2_t _2(c, _3);
00391    typedef rf::VisitorNode<B, _2_t> _1_t;
00392    _1_t _1(b, _2);
00393    typedef rf::VisitorNode<A, _1_t> _0_t;
00394    _0_t _0(a, _1);
00395    return _0;
00396 }
00397 
00398 
00399 template<class A, class B, class C, class D, class E,
00400          class F, class G, class H>
00401 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 
00402     rf::VisitorNode<D, rf::VisitorNode<E, rf::VisitorNode<F, 
00403     rf::VisitorNode<G, rf::VisitorNode<H> > > > > > > >
00404 create_visitor(A & a, B & b, C & c, 
00405                D & d, E & e, F & f, 
00406                G & g, H & h)
00407 {
00408    typedef rf::VisitorNode<H> _7_t;
00409    _7_t _7(h);
00410    typedef rf::VisitorNode<G, _7_t> _6_t;
00411    _6_t _6(g, _7);
00412    typedef rf::VisitorNode<F, _6_t> _5_t;
00413    _5_t _5(f, _6);
00414    typedef rf::VisitorNode<E, _5_t> _4_t;
00415    _4_t _4(e, _5);
00416    typedef rf::VisitorNode<D, _4_t> _3_t;
00417    _3_t _3(d, _4);
00418    typedef rf::VisitorNode<C, _3_t> _2_t;
00419    _2_t _2(c, _3);
00420    typedef rf::VisitorNode<B, _2_t> _1_t;
00421    _1_t _1(b, _2);
00422    typedef rf::VisitorNode<A, _1_t> _0_t;
00423    _0_t _0(a, _1);
00424    return _0;
00425 }
00426 
00427 
00428 template<class A, class B, class C, class D, class E,
00429          class F, class G, class H, class I>
00430 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 
00431     rf::VisitorNode<D, rf::VisitorNode<E, rf::VisitorNode<F, 
00432     rf::VisitorNode<G, rf::VisitorNode<H, rf::VisitorNode<I> > > > > > > > >
00433 create_visitor(A & a, B & b, C & c, 
00434                D & d, E & e, F & f, 
00435                G & g, H & h, I & i)
00436 {
00437    typedef rf::VisitorNode<I> _8_t;
00438    _8_t _8(i);
00439    typedef rf::VisitorNode<H, _8_t> _7_t;
00440    _7_t _7(h, _8);
00441    typedef rf::VisitorNode<G, _7_t> _6_t;
00442    _6_t _6(g, _7);
00443    typedef rf::VisitorNode<F, _6_t> _5_t;
00444    _5_t _5(f, _6);
00445    typedef rf::VisitorNode<E, _5_t> _4_t;
00446    _4_t _4(e, _5);
00447    typedef rf::VisitorNode<D, _4_t> _3_t;
00448    _3_t _3(d, _4);
00449    typedef rf::VisitorNode<C, _3_t> _2_t;
00450    _2_t _2(c, _3);
00451    typedef rf::VisitorNode<B, _2_t> _1_t;
00452    _1_t _1(b, _2);
00453    typedef rf::VisitorNode<A, _1_t> _0_t;
00454    _0_t _0(a, _1);
00455    return _0;
00456 }
00457 
00458 template<class A, class B, class C, class D, class E,
00459          class F, class G, class H, class I, class J>
00460 rf::VisitorNode<A, rf::VisitorNode<B, rf::VisitorNode<C, 
00461     rf::VisitorNode<D, rf::VisitorNode<E, rf::VisitorNode<F, 
00462     rf::VisitorNode<G, rf::VisitorNode<H, rf::VisitorNode<I,
00463     rf::VisitorNode<J> > > > > > > > > >
00464 create_visitor(A & a, B & b, C & c, 
00465                D & d, E & e, F & f, 
00466                G & g, H & h, I & i,
00467                J & j)
00468 {
00469    typedef rf::VisitorNode<J> _9_t;
00470    _9_t _9(j);
00471    typedef rf::VisitorNode<I, _9_t> _8_t;
00472    _8_t _8(i, _9);
00473    typedef rf::VisitorNode<H, _8_t> _7_t;
00474    _7_t _7(h, _8);
00475    typedef rf::VisitorNode<G, _7_t> _6_t;
00476    _6_t _6(g, _7);
00477    typedef rf::VisitorNode<F, _6_t> _5_t;
00478    _5_t _5(f, _6);
00479    typedef rf::VisitorNode<E, _5_t> _4_t;
00480    _4_t _4(e, _5);
00481    typedef rf::VisitorNode<D, _4_t> _3_t;
00482    _3_t _3(d, _4);
00483    typedef rf::VisitorNode<C, _3_t> _2_t;
00484    _2_t _2(c, _3);
00485    typedef rf::VisitorNode<B, _2_t> _1_t;
00486    _1_t _1(b, _2);
00487    typedef rf::VisitorNode<A, _1_t> _0_t;
00488    _0_t _0(a, _1);
00489    return _0;
00490 }
00491 
00492 //////////////////////////////////////////////////////////////////////////////
00493 // Visitors of communal interest. Do not spam this file with stuff          //
00494 // nobody wants.                                                            //
00495 //////////////////////////////////////////////////////////////////////////////
00496 
00497 
00498 /** Vistior to gain information, later needed for online learning.
00499  */
00500 
00501 class OnlineLearnVisitor: public VisitorBase
00502 {
00503 public:
00504     //Set if we adjust thresholds
00505     bool adjust_thresholds;
00506     //Current tree id
00507     int tree_id;
00508     //Last node id for finding parent
00509     int last_node_id;
00510     //Need to now the label for interior node visiting
00511     vigra::Int32 current_label;
00512     //marginal distribution for interior nodes
00513     struct MarginalDistribution
00514     {
00515         ArrayVector<Int32> leftCounts;
00516         Int32 leftTotalCounts;
00517         ArrayVector<Int32> rightCounts;
00518         Int32 rightTotalCounts;
00519         double gap_left;
00520         double gap_right;
00521     };
00522     typedef ArrayVector<vigra::Int32> IndexList;
00523 
00524     //All information for one tree
00525     struct TreeOnlineInformation
00526     {
00527         std::vector<MarginalDistribution> mag_distributions;
00528         std::vector<IndexList> index_lists;
00529         //map for linear index of mag_distiributions
00530         std::map<int,int> interior_to_index;
00531         //map for linear index of index_lists
00532         std::map<int,int> exterior_to_index;
00533     };
00534 
00535     //All trees
00536     std::vector<TreeOnlineInformation> trees_online_information;
00537 
00538     /** Initilize, set the number of trees
00539      */
00540     template<class RF,class PR>
00541     void visit_at_beginning(RF & rf,const PR & pr)
00542     {
00543         tree_id=0;
00544         trees_online_information.resize(rf.options_.tree_count_);
00545     }
00546 
00547     /** Reset a tree
00548      */
00549     void reset_tree(int tree_id)
00550     {
00551         trees_online_information[tree_id].mag_distributions.clear();
00552         trees_online_information[tree_id].index_lists.clear();
00553         trees_online_information[tree_id].interior_to_index.clear();
00554         trees_online_information[tree_id].exterior_to_index.clear();
00555     }
00556 
00557     /** simply increase the tree count
00558     */
00559     template<class RF, class PR, class SM, class ST>
00560     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00561     {
00562         tree_id++;
00563     }
00564     
00565     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00566     void visit_after_split( Tree          & tree, 
00567                 Split         & split,
00568                             Region       & parent,
00569                             Region        & leftChild,
00570                             Region        & rightChild,
00571                             Feature_t     & features,
00572                             Label_t       & labels)
00573     {
00574         int linear_index;
00575         int addr=tree.topology_.size();
00576         if(split.createNode().typeID() == i_ThresholdNode)
00577         {
00578             if(adjust_thresholds)
00579             {
00580                 //Store marginal distribution
00581                 linear_index=trees_online_information[tree_id].mag_distributions.size();
00582                 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
00583                 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
00584 
00585                 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
00586                 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
00587 
00588                 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
00589                 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
00590                 //Store the gap
00591                 double gap_left,gap_right;
00592                 int i;
00593                 gap_left=features(leftChild[0],split.bestSplitColumn());
00594                 for(i=1;i<leftChild.size();++i)
00595                     if(features(leftChild[i],split.bestSplitColumn())>gap_left)
00596                         gap_left=features(leftChild[i],split.bestSplitColumn());
00597                 gap_right=features(rightChild[0],split.bestSplitColumn());
00598                 for(i=1;i<rightChild.size();++i)
00599                     if(features(rightChild[i],split.bestSplitColumn())<gap_right)
00600                         gap_right=features(rightChild[i],split.bestSplitColumn());
00601                 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
00602                 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
00603             }
00604         }
00605         else
00606         {
00607             //Store index list
00608             linear_index=trees_online_information[tree_id].index_lists.size();
00609             trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
00610 
00611             trees_online_information[tree_id].index_lists.push_back(IndexList());
00612 
00613             trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
00614             std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
00615         }
00616     }
00617     void add_to_index_list(int tree,int node,int index)
00618     {
00619         if(!this->active_)
00620             return;
00621         TreeOnlineInformation &ti=trees_online_information[tree];
00622         ti.index_lists[ti.exterior_to_index[node]].push_back(index);
00623     }
00624     void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
00625     {
00626         if(!this->active_)
00627             return;
00628         trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
00629         trees_online_information[src_tree].exterior_to_index.erase(src_index);
00630     }
00631     /** do something when visiting a internal node during getToLeaf
00632      *
00633      * remember as last node id, for finding the parent of the last external node
00634      * also: adjust class counts and borders
00635      */
00636     template<class TR, class IntT, class TopT,class Feat>
00637         void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
00638         {
00639             last_node_id=index;
00640             if(adjust_thresholds)
00641             {
00642                 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
00643                 //Check if we are in the gap
00644                 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
00645                 TreeOnlineInformation &ti=trees_online_information[tree_id];
00646                 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
00647                 if(value>m.gap_left && value<m.gap_right)
00648                 {
00649                     //Check which site we want to go
00650                     if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
00651                     {
00652                         //We want to go left
00653                         m.gap_left=value;
00654                     }
00655                     else
00656                     {
00657                         //We want to go right
00658                         m.gap_right=value;
00659                     }
00660                     Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
00661                 }
00662                 //Adjust class counts
00663                 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
00664                 {
00665                     ++m.rightTotalCounts;
00666                     ++m.rightCounts[current_label];
00667                 }
00668                 else
00669                 {
00670                     ++m.leftTotalCounts;
00671                     ++m.rightCounts[current_label];
00672                 }
00673             }
00674         }
00675     /** do something when visiting a extern node during getToLeaf
00676      * 
00677      * Store the new index!
00678      */
00679 };
00680 
00681 
00682 /** Visitor that calculates the oob error of the random forest. 
00683  * this is the default visitor used. 
00684  *
00685  * To bored to comment each line of this class - trust me it works.
00686  */
00687 class OOB_Visitor:public VisitorBase
00688 {
00689 public:
00690     double oobError;
00691     int totalOobCount;
00692     ArrayVector<int> oobCount,oobErrorCount;
00693 
00694     OOB_Visitor()
00695     : oobError(0.0),
00696       totalOobCount(0)
00697     {}
00698 
00699 
00700     bool has_value()
00701     {
00702         return true;
00703     }
00704     /** does the basic calculation per tree*/
00705     template<class RF, class PR, class SM, class ST>
00706     void visit_after_tree(    RF& rf, PR & pr,  SM & sm, ST & st, int index)
00707     {
00708         //do the first time called.
00709         if(int(oobCount.size()) != rf.ext_param_.row_count_)
00710         {
00711             oobCount.resize(rf.ext_param_.row_count_, 0);
00712             oobErrorCount.resize(rf.ext_param_.row_count_, 0);
00713         }
00714         // go through the samples
00715         for(int l = 0; l < rf.ext_param_.row_count_; ++l)
00716         {
00717             // if the lth sample is oob...
00718             if(!sm.is_used()[l])
00719             {
00720                 ++oobCount[l];
00721                 if(     rf.tree(index)
00722                             .predictLabel(rowVector(pr.features(), l)) 
00723                     !=  pr.response()(l,0))
00724                 {
00725                     ++oobErrorCount[l];
00726                 }
00727             }
00728 
00729         }
00730     }
00731 
00732     /** Does the normalisation
00733      */
00734     template<class RF, class PR>
00735     void visit_at_end(RF & rf, PR & pr)
00736     {
00737         // do some normalisation
00738         for(int l=0; l < (int)rf.ext_param_.row_count_; ++l)
00739         {
00740             if(oobCount[l])
00741             {
00742                 oobError += double(oobErrorCount[l]) / oobCount[l];
00743                 ++totalOobCount;
00744             }
00745         } 
00746     }
00747     
00748     //returns value of the learn function. 
00749     double return_val()
00750     {
00751         return oobError/totalOobCount;
00752     }
00753 };
00754 
00755 
00756 /** calculate variable importance while learning.
00757  */
00758 class VariableImportanceVisitor : public VisitorBase
00759 {
00760     public:
00761 
00762     /** This Array has the same entries as the R - random forest variable
00763      *  importance
00764      */
00765     MultiArray<2, double>       variable_importance_;
00766     int                         repetition_count_;
00767     bool                        in_place_;
00768 
#ifdef HasHDF5
    /// Write variable_importance_ to dataset "variable_importance_<prefix>"
    /// of the HDF5 file filename.
    void save(std::string filename, std::string prefix)
    {
        std::string const dataset_name = "variable_importance_" + prefix;
        writeHDF5(filename.c_str(), dataset_name.c_str(), variable_importance_);
    }
#endif
00778 
00779     VariableImportanceVisitor(int rep_cnt = 10) 
00780     :   repetition_count_(rep_cnt)
00781 
00782     {}
00783 
00784     /** calculates impurity decrease based variable importance after every
00785      * split.  
00786      */
00787     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00788     void visit_after_split( Tree          & tree, 
00789                             Split         & split,
00790                             Region        & parent,
00791                             Region        & leftChild,
00792                             Region        & rightChild,
00793                             Feature_t     & features,
00794                             Label_t       & labels)
00795     {
00796         //resize to right size when called the first time
00797         
00798         Int32 const  class_count = tree.ext_param_.class_count_;
00799         Int32 const  column_count = tree.ext_param_.column_count_;
00800         if(variable_importance_.size() == 0)
00801         {
00802             
00803             variable_importance_
00804                 .reshape(MultiArrayShape<2>::type(column_count, 
00805                                                  class_count+2));
00806         }
00807 
00808         if(split.createNode().typeID() == i_ThresholdNode)
00809         {
00810             Node<i_ThresholdNode> node(split.createNode());
00811             variable_importance_(node.column(),class_count+1) 
00812                 += split.region_gini_ - split.minGini();
00813         }
00814     }
00815 
00816     /**compute permutation based var imp. 
00817      * (Only an Array of size oob_sample_count x 1 is created.
00818      *  - apposed to oob_sample_count x feature_count in the other method.
00819      * 
00820      * \sa FieldProxy
00821      */
00822     template<class RF, class PR, class SM, class ST>
00823     void after_tree_ip_impl(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00824     {
00825         typedef MultiArrayShape<2>::type Shp_t;
00826         Int32                   column_count = rf.ext_param_.column_count_;
00827         Int32                   class_count  = rf.ext_param_.class_count_;  
00828         
00829         // remove the const cast on the features (yep , I know what I am 
00830         // doing here.) data is not destroyed.
00831         typename PR::Feature_t & features 
00832             = const_cast<typename PR::Feature_t &>(pr.features());
00833 
00834         //find the oob indices of current tree. 
00835         ArrayVector<Int32>      oob_indices;
00836         ArrayVector<Int32>::iterator
00837                                 iter;
00838         for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
00839             if(!sm.is_used()[ii])
00840                 oob_indices.push_back(ii);
00841 
00842         //create space to back up a column      
00843         std::vector<double>     backup_column;
00844 
00845         // Random foo
00846 #ifdef CLASSIFIER_TEST
00847         RandomMT19937           random(1);
00848 #else 
00849         RandomMT19937           random(RandomSeed);
00850 #endif
00851         UniformIntRandomFunctor<RandomMT19937>  
00852                                 randint(random);
00853 
00854 
00855         //make some space for the results
00856         MultiArray<2, double>
00857                     oob_right(Shp_t(1, class_count + 1)); 
00858         MultiArray<2, double>
00859                     perm_oob_right (Shp_t(1, class_count + 1)); 
00860             
00861         
00862         // get the oob success rate with the original samples
00863         for(iter = oob_indices.begin(); 
00864             iter != oob_indices.end(); 
00865             ++iter)
00866         {
00867             if(rf.tree(index)
00868                     .predictLabel(rowVector(features, *iter)) 
00869                 ==  pr.response()(*iter, 0))
00870             {
00871                 //per class
00872                 ++oob_right[pr.response()(*iter,0)];
00873                 //total
00874                 ++oob_right[class_count];
00875             }
00876         }
00877         //get the oob rate after permuting the ii'th dimension.
00878         for(int ii = 0; ii < column_count; ++ii)
00879         {
00880             perm_oob_right.init(0.0); 
00881             //make backup of orinal column
00882             backup_column.clear();
00883             for(iter = oob_indices.begin(); 
00884                 iter != oob_indices.end(); 
00885                 ++iter)
00886             {
00887                 backup_column.push_back(features(*iter,ii));
00888             }
00889             
00890             //get the oob rate after permuting the ii'th dimension.
00891             for(int rr = 0; rr < repetition_count_; ++rr)
00892             {               
00893                 //permute dimension. 
00894                 int n = oob_indices.size();
00895                 for(int jj = 1; jj < n; ++jj)
00896                     std::swap(features(oob_indices[jj], ii), 
00897                               features(oob_indices[randint(jj+1)], ii));
00898 
00899                 //get the oob sucess rate after permuting
00900                 for(iter = oob_indices.begin(); 
00901                     iter != oob_indices.end(); 
00902                     ++iter)
00903                 {
00904                     if(rf.tree(index)
00905                             .predictLabel(rowVector(features, *iter)) 
00906                         ==  pr.response()(*iter, 0))
00907                     {
00908                         //per class
00909                         ++perm_oob_right[pr.response()(*iter, 0)];
00910                         //total
00911                         ++perm_oob_right[class_count];
00912                     }
00913                 }
00914             }
00915             
00916             
00917             //normalise and add to the variable_importance array.
00918             perm_oob_right  /=  repetition_count_;
00919             perm_oob_right -=oob_right;
00920             perm_oob_right *= -1;
00921             perm_oob_right      /=  oob_indices.size();
00922             variable_importance_
00923                 .subarray(Shp_t(ii,0), 
00924                           Shp_t(ii+1,class_count+1)) += perm_oob_right;
00925             //copy back permuted dimension
00926             for(int jj = 0; jj < int(oob_indices.size()); ++jj)
00927                 features(oob_indices[jj], ii) = backup_column[jj];
00928         }
00929     }
00930 
00931     /** calculate permutation based impurity after every tree has been 
00932      * learned  default behaviour is that this happens out of place.
00933      * If you have very big data sets and want to avoid copying of data 
00934      * set the in_place_ flag to true. 
00935      */
00936     template<class RF, class PR, class SM, class ST>
00937     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00938     {
00939             after_tree_ip_impl(rf, pr, sm, st, index);
00940     }
00941 
00942     /** Normalise variable importance after the number of trees is known.
00943      */
00944     template<class RF, class PR>
00945     void visit_at_end(RF & rf, PR & pr)
00946     {
00947         variable_importance_ /= rf.trees_.size();
00948     }
00949 };
00950 
00951 } // namespace vigra
00952 #endif // RF_VISITORS_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

HTML generated using Doxygen and Python
vigra 1.7.0 (Thu Aug 25 2011)