[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_decisionTree.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 #ifndef VIGRA_RANDOM_FOREST_DT_HXX
00037 #define VIGRA_RANDOM_FOREST_DT_HXX
00038 
00039 #include <algorithm>
00040 #include <map>
00041 #include <numeric>
00042 #include "vigra/multi_array.hxx"
00043 #include "vigra/mathutil.hxx"
00044 #include "vigra/array_vector.hxx"
00045 #include "vigra/sized_int.hxx"
00046 #include "vigra/matrix.hxx"
00047 #include "vigra/random.hxx"
00048 #include "vigra/functorexpression.hxx"
00049 #include <vector>
00050 
00051 #include "rf_common.hxx"
00052 #include "rf_nodeproxy.hxx"
00053 namespace vigra
00054 {
00055 
00056 namespace detail
00057 {
00058  // todo FINALLY DECIDE TO USE CAMEL CASE OR UNDERSCORES !!!!!!
00059 /** decisiontree classifier. 
00060  *
00061  * This class is actually meant to be used in conjunction with the 
00062  * Random Forest Classifier 
00063  * - My suggestion would be to use the RandomForest classifier with 
00064  *   following parameters instead of directly using this 
00065  *   class (Preprocessing default values etc is handled in there):
00066  *
00067  * \code
00068  *      RandomForest decisionTree(RF_Traits::Options_t()
00069  *                                  .features_per_node(RF_ALL)
00070  *                                  .tree_count(1)            );
00071  * \endcode
00072  * 
00073  * \todo remove the classCount and featurecount from the topology
00074  *       array. Pass ext_param_ to the nodes!
00075  * \todo Use relative addressing of nodes?
00076  */
00077 class DecisionTree
00078 {
00079     /**\todo make private?*/
00080   public:
00081     
00082     /** value type of container array. use whenever referencing it
00083      */
00084     typedef Int32 TreeInt;
00085 
00086     ArrayVector<TreeInt>  topology_;
00087     ArrayVector<double>   parameters_;
00088 
00089     ProblemSpec<> ext_param_;
00090     unsigned int classCount_;
00091 
00092 
00093   public:
00094     /** \brief Create tree with parameters */
00095     template<class T>
00096     DecisionTree(ProblemSpec<T> ext_param)
00097     :
00098         ext_param_(ext_param),
00099         classCount_(ext_param.class_count_)
00100     {}
00101 
00102     /**clears all memory used.
00103      */
00104     void reset(unsigned int classCount = 0)
00105     {
00106         if(classCount)
00107             classCount_ = classCount;
00108         topology_.clear();
00109         parameters_.clear();
00110     }
00111 
00112 
00113     /** learn a Tree
00114      *
00115      * \tparam  StackEntry_t The Stackentry containing Node/StackEntry_t 
00116      *          Information used during learing. Each Split functor has a 
00117      *          Stack entry associated with it (Split_t::StackEntry_t)
00118      * \sa RandomForest::learn()
00119      */
00120     template <  class U, class C,
00121                 class U2, class C2,
00122                 class StackEntry_t,
00123                 class Stop_t,
00124                 class Split_t,
00125                 class Visitor_t,
00126                 class Random_t >
00127     void learn(     MultiArrayView<2, U, C> const      & features,
00128                     MultiArrayView<2, U2, C2> const    & labels,
00129                     StackEntry_t const &                 stack_entry,
00130                     Split_t                              split,
00131                     Stop_t                               stop,
00132                     Visitor_t &                          visitor,
00133                     Random_t &                           randint);
00134     template <  class U, class C,
00135              class U2, class C2,
00136              class StackEntry_t,
00137              class Stop_t,
00138              class Split_t,
00139              class Visitor_t,
00140              class Random_t>
00141     void continueLearn(   MultiArrayView<2, U, C> const       & features,
00142                           MultiArrayView<2, U2, C2> const     & labels,
00143                           StackEntry_t const &                  stack_entry,
00144                           Split_t                               split,
00145                           Stop_t                                stop,
00146                           Visitor_t &                           visitor,
00147                           Random_t &                            randint,
00148                           //an index to which the last created exterior node will be moved (because it is not used anymore)
00149                           int                                   garbaged_child=-1);
00150 
00151     /** is a node a Leaf Node? */
00152     inline bool isLeafNode(TreeInt in) const
00153     {
00154         return (in & LeafNodeTag) == LeafNodeTag;
00155     }
00156 
00157     /** data driven traversal from root to leaf
00158      *
00159      * traverse through tree with data given in features. Use Visitors to 
00160      * collect statistics along the way. 
00161      */
00162     template<class U, class C, class Visitor_t>
00163     TreeInt getToLeaf(MultiArrayView<2, U, C> const & features, 
00164                       Visitor_t  & visitor) const
00165     {
00166         TreeInt index = 2;
00167         while(!isLeafNode(topology_[index]))
00168         {
00169             visitor.visit_internal_node(*this, index, topology_[index],features);
00170             switch(topology_[index])
00171             {
00172                 case i_ThresholdNode:
00173                 {
00174                     Node<i_ThresholdNode> 
00175                                 node(topology_, parameters_, index);
00176                     index = node.next(features);
00177                     break;
00178                 }
00179                 case i_HyperplaneNode:
00180                 {
00181                     Node<i_HyperplaneNode> 
00182                                 node(topology_, parameters_, index);
00183                     index = node.next(features);
00184                     break;
00185                 }
00186                 case i_HypersphereNode:
00187                 {
00188                     Node<i_HypersphereNode> 
00189                                 node(topology_, parameters_, index);
00190                     index = node.next(features);
00191                     break;
00192                 }
00193 #if 0 
00194                 // for quick prototyping! has to be implemented.
00195                 case i_VirtualNode:
00196                 {
00197                     Node<i_VirtualNode> 
00198                                 node(topology_, parameters, index);
00199                     index = node.next(features);
00200                 }
00201 #endif
00202                 default:
00203                     vigra_fail("DecisionTree::getToLeaf():"
00204                                "encountered unknown internal Node Type");
00205             }
00206         }
00207         visitor.visit_external_node(*this, index, topology_[index],features);
00208         return index;
00209     }
00210     /** traverse tree to get statistics
00211      *
00212      * Tree is traversed in order the Nodes are in memory (i.e. if no 
00213      * relearning//pruning scheme is utilized this will be pre order)
00214      */
00215     template<class Visitor_t>
00216     void traverse_mem_order(Visitor_t visitor) const
00217     {
00218         TreeInt index = 2;
00219         Int32 ii = 0;
00220         while(index < topology_.size())
00221         {
00222             if(isLeafNode(topology_[index]))
00223             {
00224                 visitor
00225                     .visit_external_node(*this, index, topology_[index]);
00226             }
00227             else
00228             {
00229                 visitor
00230                     ._internal_node(*this, index, topology_[index]);
00231             }
00232         }
00233     }
00234 
00235     template<class Visitor_t>
00236     void traverse_post_order(Visitor_t visitor,  TreeInt start = 2) const
00237     {
00238         typedef TinyVector<double, 2> Entry; 
00239         std::vector<Entry > stack;
00240         std::vector<double> result_stack;
00241         stack.push_back(Entry(2, 0));
00242         int addr; 
00243         while(!stack.empty())
00244         {
00245             addr = stack.back()[0];
00246             NodeBase node(topology_, parameters_, stack.back()[0]);
00247             if(stack.back()[1] == 1)
00248             {
00249                 stack.pop_back();
00250                 double leftRes = result_stack.back();
00251                 double rightRes = result_stack.back();
00252                 result_stack.pop_back();
00253                 result_stack.pop_back();
00254                 result_stack.push_back(rightRes+ leftRes);
00255                 visitor.visit_internal_node(*this, 
00256                                             addr, 
00257                                             node.typeID(), 
00258                                             rightRes+leftRes);
00259             }
00260             else
00261             {
00262                 if(isLeafNode(node.typeID()))
00263                 {
00264                     visitor.visit_external_node(*this, 
00265                                                 addr, 
00266                                                 node.typeID(), 
00267                                                 node.weights());
00268                     stack.pop_back();
00269                     result_stack.push_back(node.weights());
00270                 }
00271                 else
00272                 {
00273                     stack.back()[1] = 1; 
00274                     stack.push_back(Entry(node.child(0), 0));
00275                     stack.push_back(Entry(node.child(1), 0));
00276                 }
00277                     
00278             }
00279         }
00280     }
00281 
00282     /** same thing as above, without any visitors */
00283     template<class U, class C>
00284     TreeInt getToLeaf(MultiArrayView<2, U, C> const & features) const
00285     {
00286         RF_Traits::StopVisiting_t stop;
00287         return getToLeaf(features, stop);
00288     }
00289 
00290 
00291     template <class U, class C>
00292     ArrayVector<double>::iterator
00293     predict(MultiArrayView<2, U, C> const & features) const
00294     {
00295         TreeInt nodeindex = getToLeaf(features);
00296         switch(topology_[nodeindex])
00297         {
00298             case e_ConstProbNode:
00299                 return Node<e_ConstProbNode>(topology_, 
00300                                              parameters_,
00301                                              nodeindex).prob_begin();
00302                 break;
00303 #if 0 
00304             //first make the Logistic regression stuff...
00305             case e_LogRegProbNode:
00306                 return Node<e_LogRegProbNode>(topology_, 
00307                                               parameters_,
00308                                               nodeindex).prob_begin();
00309 #endif            
00310             default:
00311                 vigra_fail("DecisionTree::predict() :"
00312                            " encountered unknown external Node Type");
00313         }
00314         return ArrayVector<double>::iterator();
00315     }
00316 
00317 
00318 
00319     template <class U, class C>
00320     Int32 predictLabel(MultiArrayView<2, U, C> const & features) const
00321     {
00322         ArrayVector<double>::const_iterator weights = predict(features);
00323         return argMax(weights, weights+classCount_) - weights;
00324     }
00325 
00326 };
00327 
00328 
00329 template <  class U, class C,
00330             class U2, class C2,
00331             class StackEntry_t,
00332             class Stop_t,
00333             class Split_t,
00334             class Visitor_t,
00335             class Random_t>
00336 void DecisionTree::learn(   MultiArrayView<2, U, C> const       & features,
00337                             MultiArrayView<2, U2, C2> const     & labels,
00338                             StackEntry_t const &                  stack_entry,
00339                             Split_t                               split,
00340                             Stop_t                                stop,
00341                             Visitor_t &                           visitor,
00342                             Random_t &                            randint)
00343 {
00344     this->reset();
00345     topology_.reserve(256);
00346     parameters_.reserve(256);
00347     topology_.push_back(features.shape(1));
00348     topology_.push_back(classCount_);
00349     continueLearn(features,labels,stack_entry,split,stop,visitor,randint);
00350 }
00351 
template <  class U, class C,
            class U2, class C2,
            class StackEntry_t,
            class Stop_t,
            class Split_t,
            class Visitor_t,
            class Random_t>
// Grows the tree by repeatedly splitting the regions on the work stack.
// Called by learn() with a fresh topology header; may also be called on an
// existing tree to regrow a subtree (see garbaged_child below).
void DecisionTree::continueLearn(   MultiArrayView<2, U, C> const       & features,
                            MultiArrayView<2, U2, C2> const     & labels,
                            StackEntry_t const &                  stack_entry,
                            Split_t                               split,
                            Stop_t                                stop,
                            Visitor_t &                           visitor,
                            Random_t &                            randint,
                            //an index to which the last created exterior node will be moved (because it is not used anymore)
                            int                                   garbaged_child)
{
    // Work stack of regions still to be turned into nodes (depth first).
    std::vector<StackEntry_t> stack;
    stack.reserve(128);
    // Scratch entries filled by split.findBestSplit() with the two child regions.
    ArrayVector<StackEntry_t> child_stack_entry(2, stack_entry);
    stack.push_back(stack_entry);
    // Start offset in topology_ of the most recently created node.
    size_t last_node_pos = 0;
    // Kept outside the loop: after the loop it still refers to the last
    // processed region, which the garbaged_child relocation below needs.
    StackEntry_t top=stack.back();

    while(!stack.empty())
    {

        // Take an element of the stack. Obvious ain't it?
        top = stack.back();
        stack.pop_back();

        // Make sure no data from the last round has remained in Pipeline;
        child_stack_entry[0].reset();
        child_stack_entry[1].reset();
        split.reset();


        //Either the Stopping criterion decides that the split should
        //produce a Terminal Node or the Split itself decides what
        //kind of node to make
        TreeInt NodeID;
        if(stop(top))
            NodeID = split.makeTerminalNode(features,
                                            labels,
                                            top,
                                            randint);
        else
            NodeID = split.findBestSplit(features,
                                         labels,
                                         top,
                                         child_stack_entry,
                                         randint);

        // do some visiting yawn - just added this comment as eye candy
        // (looks odd otherwise with my syntax highlighting....
        visitor.visit_after_split(*this, split, top,
                                  child_stack_entry[0],
                                  child_stack_entry[1],
                                  features,
                                  labels);


        // Update the Child entries of the parent
        // Using InteriorNodeBase because exact parameter form not needed.
        // look at the Node base before getting scared.
        // The new node will be appended at the current end of topology_,
        // so that offset is back-patched into the parent's child slot.
        last_node_pos = topology_.size();
        if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
        {
            NodeBase(topology_,
                     parameters_,
                     top.leftParent).child(0) = last_node_pos;
        }
        else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
        {
            NodeBase(topology_,
                     parameters_,
                     top.rightParent).child(1) = last_node_pos;
        }


        // Supply the split functor with the Node type it requires.
        // set the address to which the children of this node should point
        // to and push back children onto stack
        if(!isLeafNode(NodeID))
        {
            // Each child records the pending node's address in exactly one
            // parent slot (left child -> child(0), right child -> child(1));
            // the unused slot is marked with -1 (== DecisionTreeNoParent).
            child_stack_entry[0].leftParent = topology_.size();
            child_stack_entry[1].rightParent = topology_.size();
            child_stack_entry[0].rightParent = -1;
            child_stack_entry[1].leftParent = -1;
            stack.push_back(child_stack_entry[0]);
            stack.push_back(child_stack_entry[1]);
        }

        //copy the newly created node form the split functor to the
        //decision tree.
        NodeBase(split.createNode(), topology_, parameters_ );
    }
    // Relearning support: the caller invalidated the exterior node at
    // garbaged_child; overwrite it with the last node created above and
    // shrink the arrays so no dead node remains at the end.
    if(garbaged_child!=-1)
    {
        Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).copy(Node<e_ConstProbNode>(topology_,parameters_,last_node_pos));

        int last_parameter_size = Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).parameters_size();
        topology_.resize(last_node_pos);
        parameters_.resize(parameters_.size() - last_parameter_size);

        // Re-patch the parent of the relocated node to its new address.
        if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
            NodeBase(topology_,
                     parameters_,
                     top.leftParent).child(0) = garbaged_child;
        else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
            NodeBase(topology_,
                     parameters_,
                     top.rightParent).child(1) = garbaged_child;
    }
}
00467 
00468 } //namespace detail
00469 
00470 } //namespace vigra
00471 
00472 #endif //VIGRA_RANDOM_FOREST_DT_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.0 (Thu Aug 25 2011)