[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_algorithm.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by Rahul Nair                             */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 #define VIGRA_RF_ALGORTIHM_HXX
00036 
#include <algorithm>
#include <iterator>
#include <map>
#include <queue>
#include <string>
#include <utility>
#include <vector>

#include "splices.hxx"
00040 namespace vigra
00041 {
00042  
00043 namespace rf
00044 {
00045 /** This namespace contains all algorithms developed for feature 
00046  * selection
00047  *
00048  */
00049 namespace algorithms
00050 {
00051 
00052 namespace detail
00053 {
00054     /** create a MultiArray containing only columns supplied between iterators
00055         b and e
00056     */
00057     template<class OrigMultiArray,
00058              class Iter,
00059              class DestMultiArray>
00060     void choose(OrigMultiArray     const & in,
00061                 Iter               const & b,
00062                 Iter               const & e,
00063                 DestMultiArray        & out)
00064     {
00065         int columnCount = std::distance(b, e);
00066         int rowCount     = in.shape(0);
00067         out.reshape(MultiArrayShape<2>::type(rowCount, columnCount));
00068         int ii = 0;
00069         for(Iter iter = b; iter != e; ++iter, ++ii)
00070         {
00071             columnVector(out, ii) = columnVector(in, *iter);
00072         }
00073     }
00074 }
00075 
00076 
00077 
/** Standard random forest error-rate callback functor
 *
 * returns the random forest out-of-bag error estimate when invoked.
 */
00082 class RFErrorCallback
00083 {
00084     RandomForestOptions options;
00085     
00086     public:
00087     /** Default constructor
00088      *
00089      * optionally supply options to the random forest classifier
00090      * \sa RandomForestOptions
00091      */
00092     RFErrorCallback(RandomForestOptions opt = RandomForestOptions())
00093     : options(opt)
00094     {}
00095 
00096     /** returns the RF OOB error estimate given features and 
00097      * labels
00098      */
00099     template<class Feature_t, class Response_t>
00100     double operator() (Feature_t const & features,
00101                        Response_t const & response)
00102     {
00103         RandomForest<>             rf(options);
00104         visitors::OOB_Error        oob;
00105         rf.learn(features, 
00106                  response, 
00107                  visitors::create_visitor(oob ));
00108         return oob.oob_breiman;
00109     }
00110 };
00111 
00112 
00113 /** Structure to hold Variable Selection results
00114  */
00115 class VariableSelectionResult
00116 {
00117     bool initialized;
00118 
00119   public:
00120     VariableSelectionResult()
00121     : initialized(false)
00122     {}
00123 
00124     typedef std::vector<int> FeatureList_t;
00125     typedef std::vector<double> ErrorList_t;
00126     typedef FeatureList_t::iterator Pivot_t;
00127 
00128     Pivot_t pivot;
00129 
00130     /** list of features. 
00131      */
00132     FeatureList_t selected;
00133     
00134     /** vector of size (number of features)
00135      *
00136      * the i-th entry encodes the error rate obtained
00137      * while using features [0 - i](including i) 
00138      *
00139      * if the i-th entry is -1 then no error rate was obtained
00140      * this may happen if more than one feature is added to the
00141      * selected list in one step of the algorithm.
00142      *
00143      * during initialisation error[m+n-1] is always filled
00144      */
00145     ErrorList_t errors;
00146     
00147 
00148     /** errorrate using no features
00149      */
00150     double no_features;
00151 
00152     template<class FeatureT, 
00153              class ResponseT, 
00154              class Iter,
00155              class ErrorRateCallBack>
00156     bool init(FeatureT const & all_features,
00157               ResponseT const & response,
00158               Iter b,
00159               Iter e,
00160               ErrorRateCallBack errorcallback)
00161     {
00162         bool ret_ = init(all_features, response, errorcallback); 
00163         if(!ret_)
00164             return false;
00165         vigra_precondition(std::distance(b, e) == selected.size(),
00166                            "Number of features in ranking != number of features matrix");
00167         std::copy(b, e, selected.begin());
00168         return true;
00169     }
00170     
00171     template<class FeatureT, 
00172              class ResponseT, 
00173              class Iter>
00174     bool init(FeatureT const & all_features,
00175               ResponseT const & response,
00176               Iter b,
00177               Iter e)
00178     {
00179         RFErrorCallback ecallback;
00180         return init(all_features, response, b, e, ecallback);
00181     }
00182 
00183 
00184     template<class FeatureT, 
00185              class ResponseT>
00186     bool init(FeatureT const & all_features,
00187               ResponseT const & response)
00188     {
00189         return init(all_features, response, RFErrorCallback());
00190     }
00191     /**initialization routine. Will be called only once in the lifetime
00192      * of a VariableSelectionResult. Subsequent calls will not reinitialize
00193      * member variables.
00194      *
00195      * This is intended, to allow continuing variable selection at a point 
00196      * stopped in an earlier iteration. 
00197      *
00198      * returns true if initialization was successful and false if 
00199      * the object was already initialized before.
00200      */
00201     template<class FeatureT, 
00202              class ResponseT,
00203              class ErrorRateCallBack>
00204     bool init(FeatureT const & all_features,
00205               ResponseT const & response,
00206               ErrorRateCallBack errorcallback)
00207     {
00208         if(initialized)
00209         {
00210             return false;
00211         }
00212         // calculate error with all features
00213         selected.resize(all_features.shape(1), 0);
00214         for(unsigned int ii = 0; ii < selected.size(); ++ii)
00215             selected[ii] = ii;
00216         errors.resize(all_features.shape(1), -1);
00217         errors.back() = errorcallback(all_features, response);
00218 
00219         // calculate error rate if no features are chosen 
00220         // corresponds to max(prior probability) of the classes
00221         std::map<typename ResponseT::value_type, int>     res_map;
00222         std::vector<int>                                 cts;
00223         int                                             counter = 0;
00224         for(int ii = 0; ii < response.shape(0); ++ii)
00225         {
00226             if(res_map.find(response(ii, 0)) == res_map.end())
00227             {
00228                 res_map[response(ii, 0)] = counter;
00229                 ++counter;
00230                 cts.push_back(0);
00231             }
00232             cts[res_map[response(ii,0)]] +=1;
00233         }
00234         no_features = double(*(std::max_element(cts.begin(),
00235                                                  cts.end())))
00236                     / double(response.shape(0));
00237 
00238         /*init not_selected vector;
00239         not_selected.resize(all_features.shape(1), 0);
00240         for(int ii = 0; ii < not_selected.size(); ++ii)
00241         {
00242             not_selected[ii] = ii;
00243         }
00244         initialized = true;
00245         */
00246         pivot = selected.begin();
00247         return true;
00248     }
00249 };
00250 
00251 
00252     
00253 /** Perform forward selection
00254  *
00255  * \param features    IN:     n x p matrix containing n instances with p attributes/features
00256  *                             used in the variable selection algorithm
00257  * \param response  IN:     n x 1 matrix containing the corresponding response
00258  * \param result    IN/OUT: VariableSelectionResult struct which will contain the results
00259  *                             of the algorithm. 
00260  *                             Features between result.selected.begin() and result.pivot will
00261  *                             be left untouched.
00262  *                             \sa VariableSelectionResult
00263  * \param errorcallback
00264  *                     IN, OPTIONAL: 
00265  *                             Functor that returns the error rate given a set of 
00266  *                             features and labels. Default is the RandomForest OOB Error.
00267  *
00268  * Forward selection subsequently chooses the next feature that decreases the Error rate most.
00269  *
00270  * usage:
00271  * \code
00272  *         MultiArray<2, double>     features = createSomeFeatures();
00273  *         MultiArray<2, int>        labels   = createCorrespondingLabels();
00274  *         VariableSelectionResult  result;
00275  *         forward_selection(features, labels, result);
00276  * \endcode
00277  * To use forward selection but ensure that a specific feature e.g. feature 5 is always 
00278  * included one would do the following
00279  *
00280  * \code
00281  *         VariableSelectionResult result;
00282  *         result.init(features, labels);
00283  *         std::swap(result.selected[0], result.selected[5]);
 *         result.pivot = result.selected.begin() + 1;
00285  *         forward_selection(features, labels, result);
00286  * \endcode
00287  *
00288  * \sa VariableSelectionResult
00289  *
00290  */                    
00291 template<class FeatureT, class ResponseT, class ErrorRateCallBack>
00292 void forward_selection(FeatureT          const & features,
00293                        ResponseT          const & response,
00294                        VariableSelectionResult & result,
00295                        ErrorRateCallBack          errorcallback)
00296 {
00297     VariableSelectionResult::FeatureList_t & selected         = result.selected;
00298     VariableSelectionResult::ErrorList_t &     errors            = result.errors;
00299     VariableSelectionResult::Pivot_t       & pivot            = result.pivot;    
00300     int featureCount = features.shape(1);
00301     // initialize result struct if in use for the first time
00302     if(!result.init(features, response, errorcallback))
00303     {
00304         //result is being reused just ensure that the number of features is
00305         //the same.
00306         vigra_precondition(selected.size() == featureCount,
00307                            "forward_selection(): Number of features in Feature "
00308                            "matrix and number of features in previously used "
00309                            "result struct mismatch!");
00310     }
00311     
00312 
00313     int not_selected_size = std::distance(pivot, selected.end());
00314     while(not_selected_size > 1)
00315     {
00316         std::vector<int> current_errors;
00317         VariableSelectionResult::Pivot_t next = pivot;
00318         for(int ii = 0; ii < not_selected_size; ++ii, ++next)
00319         {
00320             std::swap(*pivot, *next);
00321             MultiArray<2, double> cur_feats;
00322             detail::choose( features, 
00323                             selected.begin(), 
00324                             pivot+1, 
00325                             cur_feats);
00326             double error = errorcallback(cur_feats, response);
00327             current_errors.push_back(error);
00328             std::swap(*pivot, *next);
00329         }
00330         int pos = std::distance(current_errors.begin(),
00331                                 std::min_element(current_errors.begin(),
00332                                                    current_errors.end()));
00333         next = pivot;
00334         std::advance(next, pos);
00335         std::swap(*pivot, *next);
00336         errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
00337         ++pivot;
00338         not_selected_size = std::distance(pivot, selected.end());
00339     }
00340 }
00341 template<class FeatureT, class ResponseT>
00342 void forward_selection(FeatureT          const & features,
00343                        ResponseT          const & response,
00344                        VariableSelectionResult & result)
00345 {
00346     forward_selection(features, response, result, RFErrorCallback());
00347 }
00348 
00349 
00350 /** Perform backward elimination
00351  *
00352  * \param features    IN:     n x p matrix containing n instances with p attributes/features
00353  *                             used in the variable selection algorithm
00354  * \param response  IN:     n x 1 matrix containing the corresponding response
00355  * \param result    IN/OUT: VariableSelectionResult struct which will contain the results
00356  *                             of the algorithm. 
00357  *                             Features between result.pivot and result.selected.end() will
00358  *                             be left untouched.
00359  *                             \sa VariableSelectionResult
00360  * \param errorcallback
00361  *                     IN, OPTIONAL: 
00362  *                             Functor that returns the error rate given a set of 
00363  *                             features and labels. Default is the RandomForest OOB Error.
00364  *
00365  * Backward elimination subsequently eliminates features that have the least influence
00366  * on the error rate
00367  *
00368  * usage:
00369  * \code
00370  *         MultiArray<2, double>     features = createSomeFeatures();
00371  *         MultiArray<2, int>        labels   = createCorrespondingLabels();
00372  *         VariableSelectionResult  result;
00373  *         backward_elimination(features, labels, result);
00374  * \endcode
00375  * To use backward elimination but ensure that a specific feature e.g. feature 5 is always 
00376  * excluded one would do the following:
00377  *
00378  * \code
00379  *         VariableSelectionResult result;
00380  *         result.init(features, labels);
00381  *         std::swap(result.selected[result.selected.size()-1], result.selected[5]);
 *         result.pivot = result.selected.begin() + (result.selected.size()-1);
00383  *         backward_elimination(features, labels, result);
00384  * \endcode
00385  *
00386  * \sa VariableSelectionResult
00387  *
00388  */                    
00389 template<class FeatureT, class ResponseT, class ErrorRateCallBack>
00390 void backward_elimination(FeatureT              const & features,
00391                              ResponseT         const & response,
00392                           VariableSelectionResult & result,
00393                           ErrorRateCallBack         errorcallback)
00394 {
00395     int featureCount = features.shape(1);
00396     VariableSelectionResult::FeatureList_t & selected         = result.selected;
00397     VariableSelectionResult::ErrorList_t &     errors            = result.errors;
00398     VariableSelectionResult::Pivot_t       & pivot            = result.pivot;    
00399     
00400     // initialize result struct if in use for the first time
00401     if(!result.init(features, response, errorcallback))
00402     {
00403         //result is being reused just ensure that the number of features is
00404         //the same.
00405         vigra_precondition(selected.size() == featureCount,
00406                            "backward_elimination(): Number of features in Feature "
00407                            "matrix and number of features in previously used "
00408                            "result struct mismatch!");
00409     }
00410     pivot = selected.end() - 1;    
00411 
00412     int selected_size = std::distance(selected.begin(), pivot);
00413     while(selected_size > 1)
00414     {
00415         VariableSelectionResult::Pivot_t next = selected.begin();
00416         std::vector<int> current_errors;
00417         for(int ii = 0; ii < selected_size; ++ii, ++next)
00418         {
00419             std::swap(*pivot, *next);
00420             MultiArray<2, double> cur_feats;
00421             detail::choose( features, 
00422                             selected.begin(), 
00423                             pivot, 
00424                             cur_feats);
00425             double error = errorcallback(cur_feats, response);
00426             current_errors.push_back(error);
00427             std::swap(*pivot, *next);
00428         }
00429         int pos = std::distance(current_errors.begin(),
00430                                 std::max_element(current_errors.begin(),
00431                                                    current_errors.end()));
00432         next = selected.begin();
00433         std::advance(next, pos);
00434         std::swap(*pivot, *next);
00435 //        std::cerr << std::distance(selected.begin(), pivot) << " " << pos << " " << current_errors.size() << " " << errors.size() << std::endl;
00436         errors[std::distance(selected.begin(), pivot)] = current_errors[pos];
00437         selected_size = std::distance(selected.begin(), pivot);
00438         --pivot;
00439     }
00440 }
00441 
00442 template<class FeatureT, class ResponseT>
00443 void backward_elimination(FeatureT              const & features,
00444                              ResponseT         const & response,
00445                           VariableSelectionResult & result)
00446 {
00447     backward_elimination(features, response, result, RFErrorCallback());
00448 }
00449 
00450 /** Perform rank selection using a predefined ranking
00451  *
00452  * \param features    IN:     n x p matrix containing n instances with p attributes/features
00453  *                             used in the variable selection algorithm
00454  * \param response  IN:     n x 1 matrix containing the corresponding response
00455  * \param result    IN/OUT: VariableSelectionResult struct which will contain the results
00456  *                             of the algorithm. The struct should be initialized with the
00457  *                             predefined ranking.
00458  *                         
00459  *                             \sa VariableSelectionResult
00460  * \param errorcallback
00461  *                     IN, OPTIONAL: 
00462  *                             Functor that returns the error rate given a set of 
00463  *                             features and labels. Default is the RandomForest OOB Error.
00464  *
00465  * Often some variable importance, score measure is used to create the ordering in which
00466  * variables have to be selected. This method takes such a ranking and calculates the 
00467  * corresponding error rates. 
00468  *
00469  * usage:
00470  * \code
00471  *         MultiArray<2, double>     features = createSomeFeatures();
00472  *         MultiArray<2, int>        labels   = createCorrespondingLabels();
00473  *         std::vector<int>        ranking  = createRanking(features);
00474  *         VariableSelectionResult  result;
00475  *         result.init(features, labels, ranking.begin(), ranking.end());
 *         rank_selection(features, labels, result);
00477  * \endcode
00478  *
00479  * \sa VariableSelectionResult
00480  *
00481  */                    
00482 template<class FeatureT, class ResponseT, class ErrorRateCallBack>
00483 void rank_selection      (FeatureT              const & features,
00484                              ResponseT         const & response,
00485                           VariableSelectionResult & result,
00486                           ErrorRateCallBack         errorcallback)
00487 {
00488     VariableSelectionResult::FeatureList_t & selected         = result.selected;
00489     VariableSelectionResult::ErrorList_t &     errors            = result.errors;
00490     VariableSelectionResult::Pivot_t       & iter            = result.pivot;
00491     int featureCount = features.shape(1);
00492     // initialize result struct if in use for the first time
00493     if(!result.init(features, response, errorcallback))
00494     {
00495         //result is being reused just ensure that the number of features is
00496         //the same.
00497         vigra_precondition(selected.size() == featureCount,
00498                            "forward_selection(): Number of features in Feature "
00499                            "matrix and number of features in previously used "
00500                            "result struct mismatch!");
00501     }
00502     
00503     int ii = 0;
00504     for(; iter != selected.end(); ++iter)
00505     {
00506 //        std::cerr << ii<< std::endl;
00507         ++ii;
00508         MultiArray<2, double> cur_feats;
00509         detail::choose( features, 
00510                         selected.begin(), 
00511                         iter, 
00512                         cur_feats);
00513         double error = errorcallback(cur_feats, response);
00514         errors[std::distance(selected.begin(), iter)] = error;
00515 
00516     }
00517 }
00518 
00519 template<class FeatureT, class ResponseT>
00520 void rank_selection      (FeatureT              const & features,
00521                              ResponseT         const & response,
00522                           VariableSelectionResult & result)
00523 {
00524     rank_selection(features, response, result, RFErrorCallback());
00525 }
00526 
00527 
00528 
00529 enum ClusterLeafTypes{c_Leaf = 95, c_Node = 99};
00530 
00531 /* View of a Node in the hierarchical clustering 
00532  * class 
00533  * For internal use only - 
00534  * \sa NodeBase
00535  */
class ClusterNode
: public NodeBase
{
    public:

    typedef NodeBase BT;

        /**constructors **/
    ClusterNode():NodeBase(){}

    /** Append a new node holding nCol columns to the given topology and
     * parameter arrays.
     * Storage layout (inferred from the accessors below and from
     * HClustering::cluster(); confirm against NodeBase):
     *   topology: 5 fixed entries + nCol column indices,
     *   parameters: slot 0 = merge distance, 1 = node index,
     *               2 = mean, 3 = stdev, 4 = status.
     * A single-column node is tagged c_Leaf, otherwise c_Node.
     */
    ClusterNode(    int                      nCol,
                    BT::T_Container_type    &   topology,
                    BT::P_Container_type    &   split_param)
                :   BT(nCol + 5, 5,topology, split_param)
    {
        status() = 0; 
        BT::column_data()[0] = nCol;
        if(nCol == 1)
            BT::typeID() = c_Leaf;
        else
            BT::typeID() = c_Node;
    }

    /** View an existing node that starts at offset n of the given
     * (read-only) topology/parameter arrays.  The base class is told a
     * fixed topology size of 5, so the variable number of columns has to
     * be added afterwards.
     */
    ClusterNode(           BT::T_Container_type  const  &   topology,
                    BT::P_Container_type  const  &   split_param,
                    int                  n             )
                :   NodeBase(5 , 5,topology, split_param, n)
    {
        //TODO : is there a more elegant way to do this?
        BT::topology_size_ += BT::column_data()[0];
    }

    /** View the same node as another NodeBase object refers to. */
    ClusterNode( BT & node_)
        :   BT(5, 5, node_) 
    {
        //TODO : is there a more elegant way to do this?
        BT::topology_size_ += BT::column_data()[0];
        BT::parameter_size_ += 0;
    }

    /** creation index of the node (parameter slot 1), set via set_index() */
    int index()
    {
        return static_cast<int>(BT::parameters_begin()[1]);
    }
    void set_index(int in)
    {
        BT::parameters_begin()[1] = in;
    }
    /** mean value associated with this node (parameter slot 2) */
    double& mean()
    {
        return BT::parameters_begin()[2];
    }
    /** standard deviation associated with this node (parameter slot 3) */
    double& stdev()
    {
        return BT::parameters_begin()[3];
    }
    /** status flag (parameter slot 4); zeroed on construction */
    double& status()
    {
        return BT::parameters_begin()[4];
    }
};
00595 
00596 /** Stackentry class for HClustering class
00597  */
/** One work-list entry of HClustering::breadth_first_traversal():
 * a node address plus its tree level, the address of its parent and the
 * boolean flag propagated from visiting the parent.
 */
struct HC_Entry
{
    int parent;
    int level;
    int addr; 
    bool infm;

    HC_Entry(int parent_, int level_, int addr_, bool infm_)
        : parent(parent_),
          level(level_),
          addr(addr_),
          infm(infm_)
    {}
};
00608 
00609 
/** Hierarchical Clustering class. 
 * Performs single linkage clustering
 * \code
 *      Matrix<double> distance = get_distance_matrix();
 *      HClustering    linkage;
 *      linkage.cluster(distance);
 *      // Draw clustering tree.
 *      Draw<double, int> draw(features, labels, "linkagetree.graph");
 *      linkage.breadth_first_traversal(draw);
 * \endcode
 * \sa ClusterImportanceVisitor
 *
 * Once the clustering has taken place, information queries can be made
 * using the breadth_first_traversal() and iterate() methods.
 *
 */
class HClustering
{
public:
    typedef MultiArrayShape<2>::type Shp;
    // flattened cluster tree; each ClusterNode occupies a contiguous run
    // of entries (fixed topology fields plus its column indices)
    ArrayVector<int>         topology_;
    // per-node parameters (merge distance, index, mean, stdev, status)
    ArrayVector<double>     parameters_;
    // offset into topology_ of the most recently created merge node;
    // once cluster() has finished this is the root of the tree
    int                     begin_addr;
00632 
00633     // Calculates the distance between two 
00634     double dist_func(double a, double b)
00635     {
00636         return std::min(a, b); 
00637     }
00638 
00639     /** Visit each node with a Functor 
00640      * in creation order (should be depth first)
00641      */
00642     template<class Functor>
00643     void iterate(Functor & tester)
00644     {
00645 
00646         std::vector<int> stack; 
00647         stack.push_back(begin_addr); 
00648         while(!stack.empty())
00649         {
00650             ClusterNode node(topology_, parameters_, stack.back());
00651             stack.pop_back();
00652             if(!tester(node))
00653             {
00654                 if(node.columns_size() != 1)
00655                 {
00656                     stack.push_back(node.child(0));
00657                     stack.push_back(node.child(1));
00658                 }
00659             }
00660         }
00661     }
00662 
00663     /** Perform breadth first traversal of hierarchical cluster tree
00664      */
00665     template<class Functor>
00666     void breadth_first_traversal(Functor & tester)
00667     {
00668 
00669         std::queue<HC_Entry> queue; 
00670         int level = 0;
00671         int parent = -1;
00672         int addr   = -1;
00673         bool infm  = false;
00674         queue.push(HC_Entry(parent,level,begin_addr, infm)); 
00675         while(!queue.empty())
00676         {
00677             level  = queue.front().level;
00678             parent = queue.front().parent;
00679             addr   = queue.front().addr;
00680             infm   = queue.front().infm;
00681             ClusterNode node(topology_, parameters_, queue.front().addr);
00682             ClusterNode parnt;
00683             if(parent != -1)
00684             {
00685                 parnt = ClusterNode(topology_, parameters_, parent); 
00686             }
00687             queue.pop();
00688             bool istrue = tester(node, level, parnt, infm);
00689             if(node.columns_size() != 1)
00690             {
00691                 queue.push(HC_Entry(addr, level +1,node.child(0),istrue));
00692                 queue.push(HC_Entry(addr, level +1,node.child(1),istrue));
00693             }
00694         }
00695     }
#ifdef HasHDF5
    /**save to HDF5 - defunct - has to be updated to new HDF5 interface
     *
     * Writes topology_, parameters_ and begin_addr as three column-vector
     * datasets named prefix + "topology"/"parameters"/"begin_addr" into
     * the given file.
     */
    void save(std::string file, std::string prefix)
    {
        
        vigra::writeHDF5(file.c_str(), (prefix + "topology").c_str(), 
                               MultiArrayView<2, int>(
                                    Shp(topology_.size(),1),
                                    topology_.data()));
        vigra::writeHDF5(file.c_str(), (prefix + "parameters").c_str(), 
                               MultiArrayView<2, double>(
                                    Shp(parameters_.size(), 1),
                                    parameters_.data()));
        vigra::writeHDF5(file.c_str(), (prefix + "begin_addr").c_str(), 
                               MultiArrayView<2, int>(Shp(1,1), &begin_addr));
                               
    }
#endif
00715 
00716     /**Perform single linkage clustering
00717      * \param distance distance matrix used. \sa CorrelationVisitor
00718      */
00719     template<class T, class C>
00720     void cluster(MultiArrayView<2, T, C> distance)
00721     {
00722         MultiArray<2, T> dist(distance); 
00723         std::vector<std::pair<int, int> > addr; 
00724         typedef std::pair<int, int>  Entry;
00725         int index = 0;
00726         for(int ii = 0; ii < distance.shape(0); ++ii)
00727         {
00728             addr.push_back(std::make_pair(topology_.size(), ii));
00729             ClusterNode leaf(1, topology_, parameters_);
00730             leaf.set_index(index);
00731             ++index;
00732             leaf.columns_begin()[0] = ii;
00733         }
00734 
00735         while(addr.size() != 1)
00736         {
00737             //find the two nodes with the smallest distance
00738             int ii_min = 0;
00739             int jj_min = 1;
00740             double min_dist = dist((addr.begin()+ii_min)->second, 
00741                               (addr.begin()+jj_min)->second);
00742             for(unsigned int ii = 0; ii < addr.size(); ++ii)
00743             {
00744                 for(unsigned int jj = ii+1; jj < addr.size(); ++jj)
00745                 {
00746                     if(  dist((addr.begin()+ii_min)->second, 
00747                               (addr.begin()+jj_min)->second)
00748                        > dist((addr.begin()+ii)->second, 
00749                               (addr.begin()+jj)->second))
00750                     {
00751                         min_dist = dist((addr.begin()+ii)->second, 
00752                               (addr.begin()+jj)->second);
00753                         ii_min = ii; 
00754                         jj_min = jj;
00755                     }
00756                 }
00757             }
00758 
00759             //merge two nodes
00760             int col_size = 0;
00761             // The problem is that creating a new node invalidates the iterators stored
00762             // in firstChild and secondChild.
00763             {
00764                 ClusterNode firstChild(topology_, 
00765                                        parameters_, 
00766                                        (addr.begin() +ii_min)->first);
00767                 ClusterNode secondChild(topology_, 
00768                                        parameters_, 
00769                                        (addr.begin() +jj_min)->first);
00770                 col_size = firstChild.columns_size() + secondChild.columns_size();
00771             }
00772             int cur_addr = topology_.size();
00773             begin_addr = cur_addr;
00774 //            std::cerr << col_size << std::endl;
00775             ClusterNode parent(col_size,
00776                                topology_,
00777                                parameters_); 
00778             ClusterNode firstChild(topology_, 
00779                                    parameters_, 
00780                                    (addr.begin() +ii_min)->first);
00781             ClusterNode secondChild(topology_, 
00782                                    parameters_, 
00783                                    (addr.begin() +jj_min)->first);
00784             parent.parameters_begin()[0] = min_dist;
00785             parent.set_index(index);
00786             ++index;
00787             std::merge(firstChild.columns_begin(), firstChild.columns_end(),
00788                        secondChild.columns_begin(),secondChild.columns_end(),
00789                        parent.columns_begin());
00790             //merge nodes in addr
00791             int to_keep;
00792             int to_desc;
00793             int ii_keep;
00794             if(*parent.columns_begin() ==  *firstChild.columns_begin())
00795             {
00796                 parent.child(0) = (addr.begin()+ii_min)->first;
00797                 parent.child(1) = (addr.begin()+jj_min)->first;
00798                 (addr.begin()+ii_min)->first = cur_addr;
00799                 ii_keep = ii_min;
00800                 to_keep = (addr.begin()+ii_min)->second;
00801                 to_desc = (addr.begin()+jj_min)->second;
00802                 addr.erase(addr.begin()+jj_min);
00803             }
00804             else
00805             {
00806                 parent.child(1) = (addr.begin()+ii_min)->first;
00807                 parent.child(0) = (addr.begin()+jj_min)->first;
00808                 (addr.begin()+jj_min)->first = cur_addr;
00809                 ii_keep = jj_min;
00810                 to_keep = (addr.begin()+jj_min)->second;
00811                 to_desc = (addr.begin()+ii_min)->second;
00812                 addr.erase(addr.begin()+ii_min);
00813             }
00814             //update distances;
00815             
00816             for(unsigned int jj = 0 ; jj < addr.size(); ++jj)
00817             {
00818                 if(jj == ii_keep)
00819                     continue;
00820                 double bla = dist_func(
00821                                   dist(to_desc, (addr.begin()+jj)->second),
00822                                   dist((addr.begin()+ii_keep)->second,
00823                                         (addr.begin()+jj)->second));
00824 
00825                 dist((addr.begin()+ii_keep)->second,
00826                      (addr.begin()+jj)->second) = bla;
00827                 dist((addr.begin()+jj)->second,
00828                      (addr.begin()+ii_keep)->second) = bla;
00829             }
00830         }
00831     }
00832 
00833 };
00834 
00835 
/** Normalize the status value in the HClustering tree (HClustering Visitor)
 */
class NormalizeStatus
{
public:
    double n;   // divisor applied to every visited node's status

    /** Constructor
     * \param m every node's status() is divided by m
     */
    NormalizeStatus(double m)
        : n(m)
    {}

    /** Divide the node's status by n.
     * \return false — never aborts the traversal
     */
    template<class Node>
    bool operator()(Node& node)
    {
        node.status() = node.status() / n;
        return false;
    }
};
00855 
00856 
/** Perform Permutation importance on HClustering clusters
 * (See visit_after_tree() method of visitors::VariableImportance to 
 * see the basic idea. (Just that we apply the permutation not only to
 * variables but also to clusters))
 *
 * Used as an HClustering visitor: for every cluster node it permutes all
 * feature columns belonging to that cluster simultaneously and records how
 * often the tree still predicts the correct label on the OOB samples.
 */
template<class Iter, class DT>
class PermuteCluster
{
public:
    typedef MultiArrayShape<2>::type Shp;
    Matrix<double> tmp_mem_;            // scratch copy of feats_ with the current cluster's columns permuted
    MultiArrayView<2, double> perm_imp; // row per cluster: correct-prediction counts after permutation (last col = total)
    MultiArrayView<2, double> orig_imp; // correct-prediction counts on unpermuted data (last col = total)
    Matrix<double> feats_;              // features of the OOB samples only
    Matrix<int>    labels_;             // labels of the OOB samples only
    const int      nPerm;               // number of permutation repetitions per cluster
    DT const &           dt;            // decision tree used for prediction
    int index;                          // row of perm_imp for the cluster currently visited
    int oob_size;                       // number of OOB samples (b - a)

    /** Copy the OOB rows selected by [a, b) out of the full matrices.
     * \param a,b    iterator range of OOB row indices
     * \param feats  full feature matrix
     * \param labls  full response matrix
     * \param p_imp  output accumulator, one row per cluster node
     * \param o_imp  baseline accuracy counts on unpermuted data
     * \param np     number of permutations per cluster
     * \param dt_    tree to evaluate
     */
    template<class Feat_T, class Label_T>
    PermuteCluster(Iter  a, 
                   Iter  b,
                   Feat_T const & feats,
                   Label_T const & labls, 
                   MultiArrayView<2, double> p_imp, 
                   MultiArrayView<2, double> o_imp, 
                   int np,
                   DT const  & dt_)
        :tmp_mem_(_spl(a, b).size(), feats.shape(1)),
         perm_imp(p_imp),
         orig_imp(o_imp),
         feats_(_spl(a,b).size(), feats.shape(1)),
         labels_(_spl(a,b).size(),1),
         nPerm(np),
         dt(dt_),
         index(0),
         oob_size(b-a)
    {
        copy_splice(_spl(a,b),
                    _spl(feats.shape(1)),
                    feats,
                    feats_);
        copy_splice(_spl(a,b),
                    _spl(labls.shape(1)),
                    labls,
                    labels_);
    }

    /** Visit one cluster node: permute its columns nPerm times, count
     * correct OOB predictions into perm_imp(index, .), and add the
     * resulting importance to the node's status().
     * \return false — never aborts the traversal
     */
    template<class Node>
    bool operator()(Node& node)
    {
        tmp_mem_ = feats_;
        RandomMT19937 random;
        int class_count = perm_imp.shape(1) - 1;    // last column holds the per-tree total
        //permute columns together
        for(int kk = 0; kk < nPerm; ++kk)
        {
            tmp_mem_ = feats_;   // restart from unpermuted data for each repetition
            for(int ii = 0; ii < rowCount(feats_); ++ii)
            {
                // NOTE: this local 'index' shadows the member 'index';
                // the member is only used below, after this loop ends.
                int index = random.uniformInt(rowCount(feats_) - ii) +ii;
                for(int jj = 0; jj < node.columns_size(); ++jj)
                {
                    // Skip a column id equal to the feature count —
                    // presumably a sentinel for an unused slot; confirm
                    // against how HClustering fills its column lists.
                    if(node.columns_begin()[jj] != feats_.shape(1))
                        tmp_mem_(ii, node.columns_begin()[jj]) 
                            = tmp_mem_(index, node.columns_begin()[jj]);
                }
            }
            
            // Count correct predictions on the permuted data.
            for(int ii = 0; ii < rowCount(tmp_mem_); ++ii)
            {
                if(dt
                        .predictLabel(rowVector(tmp_mem_, ii)) 
                    ==  labels_(ii, 0))
                {
                    //per class
                    ++perm_imp(index,labels_(ii, 0));
                    //total
                    ++perm_imp(index, class_count);
                }
            }
        }
        // importance = (baseline correct - mean permuted correct) / #OOB samples
        double node_status  = perm_imp(index, class_count);
        node_status /= nPerm;
        node_status -= orig_imp(0, class_count);
        node_status *= -1;
        node_status /= oob_size;
        node.status() += node_status;
        ++index;
         
        return false;
    }
};
00951 
00952 /** Convert ClusteringTree into a list (HClustering visitor)
00953  */
00954 class GetClusterVariables
00955 {
00956 public:
00957     /** NumberOfClusters x NumberOfVariables MultiArrayView containing
00958      * in each row the variable belonging to a cluster
00959      */
00960     MultiArrayView<2, int>    variables;
00961     int index;
00962     GetClusterVariables(MultiArrayView<2, int> vars)
00963         :variables(vars), index(0)
00964     {}
00965 #ifdef HasHDF5
00966     void save(std::string file, std::string prefix)
00967     {
00968         vigra::writeHDF5(file.c_str(), (prefix + "_variables").c_str(), 
00969                                variables);
00970     }
00971 #endif
00972 
00973     template<class Node>
00974     bool operator()(Node& node)
00975     {
00976         for(int ii = 0; ii < node.columns_size(); ++ii)
00977             variables(index, ii) = node.columns_begin()[ii];
00978         ++index;
00979         return false;
00980     }
00981 };
/** corrects the status fields of a linkage Clustering (HClustering Visitor)
 *  
 *  such that status(currentNode) = min(status(parent), status(currentNode))
 *  \sa cluster_permutation_importance()
 */
class CorrectStatus
{
public:
    /** Clamp the current node's status to its parent's status.
     * \return true — always continue the traversal
     */
    template<class Nde>
    bool operator()(Nde & cur, int level, Nde parent, bool infm)
    {
        if(!parent.hasData_)
            return true;    // root: nothing to propagate
        // a child must never report a larger status than its parent
        if(parent.status() < cur.status())
            cur.status() = parent.status();
        return true;
    }
};
00998 
00999 
01000 /** draw current linkage Clustering (HClustering Visitor)
01001  *
01002  * create a graphviz .dot file
01003  * usage:
01004  * \code
01005  *         Matrix<double> distance = get_distance_matrix();
01006  *      linkage.cluster(distance);
01007  *      Draw<double, int> draw(features, labels, "linkagetree.graph");
01008  *      linkage.breadth_first_traversal(draw);
01009  * \endcode 
01010  */
01011 template<class T1,
01012          class T2, 
01013          class C1 = UnstridedArrayTag,
01014          class C2 = UnstridedArrayTag> 
01015 class Draw
01016 {
01017 public:
01018     typedef MultiArrayShape<2>::type Shp;
01019     MultiArrayView<2, T1, C1> const &   features_;
01020     MultiArrayView<2, T2, C2> const &   labels_;
01021     std::ofstream graphviz;
01022 
01023 
01024     Draw(MultiArrayView<2, T1, C1> const & features, 
01025          MultiArrayView<2, T2, C2> const& labels,
01026          std::string const  gz)
01027         :features_(features), labels_(labels), 
01028         graphviz(gz.c_str(), std::ios::out)
01029     {
01030         graphviz << "digraph G\n{\n node [shape=\"record\"]";
01031     }
01032     ~Draw()
01033     {
01034         graphviz << "\n}\n";
01035         graphviz.close();
01036     }
01037 
01038     template<class Nde>
01039     bool operator()(Nde & cur, int level, Nde parent, bool infm)
01040     {
01041         graphviz << "node" << cur.index() << " [style=\"filled\"][label = \" #Feats: "<< cur.columns_size() << "\\n";
01042         graphviz << " status: " << cur.status() << "\\n";
01043         for(int kk = 0; kk < cur.columns_size(); ++kk)
01044         {
01045                 graphviz  << cur.columns_begin()[kk] << " ";
01046                 if(kk % 15 == 14)
01047                     graphviz << "\\n";
01048         }
01049         graphviz << "\"] [color = \"" <<cur.status() << " 1.000 1.000\"];\n";
01050         if(parent.hasData_)
01051         graphviz << "\"node" << parent.index() << "\" -> \"node" << cur.index() <<"\";\n";
01052         return true;
01053     }
01054 };
01055 
01056 /** calculate Cluster based permutation importance while learning. (RandomForestVisitor)
01057  */
01058 class ClusterImportanceVisitor : public visitors::VisitorBase
01059 {
01060     public:
01061 
01062     /** List of variables as produced by GetClusterVariables
01063      */
01064     MultiArray<2, int>          variables;
01065     /** Corresponding importance measures
01066      */
01067     MultiArray<2, double>       cluster_importance_;
01068     /** Corresponding error
01069      */
01070     MultiArray<2, double>       cluster_stdev_;
01071     int                         repetition_count_;
01072     bool                        in_place_;
01073     HClustering            &    clustering;
01074 
01075 
01076 #ifdef HasHDF5
01077     void save(std::string filename, std::string prefix)
01078     {
01079         std::string prefix1 = "cluster_importance_" + prefix;
01080         writeHDF5(filename.c_str(), 
01081                         prefix1.c_str(), 
01082                         cluster_importance_);
01083         prefix1 = "vars_" + prefix;
01084         writeHDF5(filename.c_str(), 
01085                         prefix1.c_str(), 
01086                         variables);
01087     }
01088 #endif
01089 
01090     ClusterImportanceVisitor(HClustering & clst, int rep_cnt = 10) 
01091     :   repetition_count_(rep_cnt), clustering(clst)
01092 
01093     {}
01094 
01095     /** Allocate enough memory 
01096      */
01097     template<class RF, class PR>
01098     void visit_at_beginning(RF const & rf, PR const & pr)
01099     {
01100         Int32 const  class_count = rf.ext_param_.class_count_;
01101         Int32 const  column_count = rf.ext_param_.column_count_+1;
01102         cluster_importance_
01103             .reshape(MultiArrayShape<2>::type(2*column_count-1, 
01104                                                 class_count+1));
01105         cluster_stdev_
01106             .reshape(MultiArrayShape<2>::type(2*column_count-1, 
01107                                                 class_count+1));
01108         variables
01109             .reshape(MultiArrayShape<2>::type(2*column_count-1, 
01110                                                 column_count), -1);
01111         GetClusterVariables gcv(variables);
01112         clustering.iterate(gcv);
01113         
01114     }
01115 
01116     /**compute permutation based var imp. 
01117      * (Only an Array of size oob_sample_count x 1 is created.
01118      *  - apposed to oob_sample_count x feature_count in the other method.
01119      * 
01120      * \sa FieldProxy
01121      */
01122     template<class RF, class PR, class SM, class ST>
01123     void after_tree_ip_impl(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01124     {
01125         typedef MultiArrayShape<2>::type Shp_t;
01126         Int32                   column_count = rf.ext_param_.column_count_ +1;
01127         Int32                   class_count  = rf.ext_param_.class_count_;  
01128         
01129         // remove the const cast on the features (yep , I know what I am 
01130         // doing here.) data is not destroyed.
01131         typename PR::Feature_t & features 
01132             = const_cast<typename PR::Feature_t &>(pr.features());
01133 
01134         //find the oob indices of current tree. 
01135         ArrayVector<Int32>      oob_indices;
01136         ArrayVector<Int32>::iterator
01137                                 iter;
01138         
01139         if(rf.ext_param_.actual_msample_ < pr.features().shape(0)- 10000)
01140         {
01141             ArrayVector<int> cts(2, 0);
01142             ArrayVector<Int32> indices(pr.features().shape(0));
01143             for(int ii = 0; ii < pr.features().shape(0); ++ii)
01144                indices.push_back(ii); 
01145             std::random_shuffle(indices.begin(), indices.end());
01146             for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
01147             {
01148                 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 3000)
01149                 {
01150                     oob_indices.push_back(indices[ii]);
01151                     ++cts[pr.response()(indices[ii], 0)];
01152                 }
01153             }
01154         }
01155         else
01156         {
01157             for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
01158                 if(!sm.is_used()[ii])
01159                     oob_indices.push_back(ii);
01160         }
01161 
01162         // Random foo
01163         RandomMT19937           random(RandomSeed);
01164         UniformIntRandomFunctor<RandomMT19937>  
01165                                 randint(random);
01166 
01167         //make some space for the results
01168         MultiArray<2, double>
01169                     oob_right(Shp_t(1, class_count + 1)); 
01170         
01171         // get the oob success rate with the original samples
01172         for(iter = oob_indices.begin(); 
01173             iter != oob_indices.end(); 
01174             ++iter)
01175         {
01176             if(rf.tree(index)
01177                     .predictLabel(rowVector(features, *iter)) 
01178                 ==  pr.response()(*iter, 0))
01179             {
01180                 //per class
01181                 ++oob_right[pr.response()(*iter,0)];
01182                 //total
01183                 ++oob_right[class_count];
01184             }
01185         }
01186         
01187         MultiArray<2, double>
01188                     perm_oob_right (Shp_t(2* column_count-1, class_count + 1)); 
01189         
01190         PermuteCluster<ArrayVector<Int32>::iterator,typename RF::DecisionTree_t>
01191             pc(oob_indices.begin(), oob_indices.end(), 
01192                             pr.features(),
01193                             pr.response(),
01194                             perm_oob_right,
01195                             oob_right,
01196                             repetition_count_,
01197                             rf.tree(index));
01198         clustering.iterate(pc);
01199 
01200         perm_oob_right  /=  repetition_count_;
01201         for(int ii = 0; ii < rowCount(perm_oob_right); ++ii)
01202             rowVector(perm_oob_right, ii) -= oob_right;
01203 
01204         perm_oob_right       *= -1;
01205         perm_oob_right       /= oob_indices.size();
01206         cluster_importance_  += perm_oob_right;
01207     }
01208 
01209     /** calculate permutation based impurity after every tree has been 
01210      * learned  default behaviour is that this happens out of place.
01211      * If you have very big data sets and want to avoid copying of data 
01212      * set the in_place_ flag to true. 
01213      */
01214     template<class RF, class PR, class SM, class ST>
01215     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01216     {    
01217             after_tree_ip_impl(rf, pr, sm, st, index);
01218     }
01219 
01220     /** Normalise variable importance after the number of trees is known.
01221      */
01222     template<class RF, class PR>
01223     void visit_at_end(RF & rf, PR & pr)
01224     {
01225         NormalizeStatus nrm(rf.tree_count());
01226         clustering.iterate(nrm);
01227         cluster_importance_ /= rf.trees_.size();
01228     }
01229 };
01230 
01231 /** Perform hierarchical clustering of variables and assess importance of clusters
01232  *
01233  * \param features    IN:     n x p matrix containing n instances with p attributes/features
01234  *                             used in the variable selection algorithm
01235  * \param response  IN:     n x 1 matrix containing the corresponding response
01236  * \param linkage    OUT:    Hierarchical grouping of variables.
01237  * \param distance  OUT:    distance matrix used for creating the linkage
01238  *
01239  * Performs hierarchical clustering of variables and calculates the permutation importance
01240  * measure of each of the clusters. Use the Draw functor to create human-readable output.
01241  * The cluster-permutation importance measure corresponds to the normal permutation importance
01242  * measure with all columns corresponding to a cluster permuted. 
01243  * The importance measure for each cluster is stored as the status() field of each clusternode
01244  * \sa HClustering
01245  *
01246  * usage:
01247  * \code
01248  *         MultiArray<2, double>     features = createSomeFeatures();
01249  *         MultiArray<2, int>        labels   = createCorrespondingLabels();
01250  *         HClustering                linkage;
01251  *         MultiArray<2, double>    distance;
01252  *         cluster_permutation_importance(features, labels, linkage, distance)
01253  *        // create graphviz output
01254  *
01255  *      Draw<double, int> draw(features, labels, "linkagetree.graph");
01256  *      linkage.breadth_first_traversal(draw);
01257  *
01258  * \endcode
01259  *
01260  *
01261  */                    
01262 template<class FeatureT, class ResponseT>
01263 void cluster_permutation_importance(FeatureT              const & features,
01264                                          ResponseT         const &     response,
01265                                     HClustering               & linkage,
01266                                     MultiArray<2, double>      & distance)
01267 {
01268 
01269         RandomForestOptions opt;
01270         opt.tree_count(100);
01271         if(features.shape(0) > 40000)
01272             opt.samples_per_tree(20000).use_stratification(RF_EQUAL);
01273 
01274 
01275         vigra::RandomForest<int> RF(opt); 
01276         visitors::RandomForestProgressVisitor             progress;
01277         visitors::CorrelationVisitor                     missc;
01278         RF.learn(features, response,
01279                  create_visitor(missc, progress));
01280         distance = missc.distance;
01281         /*
01282            missc.save(exp_dir + dset.name() + "_result.h5", dset.name()+"MACH");
01283            */
01284 
01285 
01286         // Produce linkage
01287         linkage.cluster(distance);
01288         
01289         //linkage.save(exp_dir + dset.name() + "_result.h5", "_linkage_CC/");
01290         vigra::RandomForest<int> RF2(opt); 
01291         ClusterImportanceVisitor          ci(linkage);
01292         RF2.learn(features, 
01293                   response,
01294                   create_visitor(progress, ci));
01295         
01296         
01297         CorrectStatus cs;
01298         linkage.breadth_first_traversal(cs);
01299 
01300         //ci.save(exp_dir + dset.name() + "_result.h5", dset.name());
01301         //Draw<double, int> draw(dset.features(), dset.response(), exp_dir+ dset.name() + ".graph");
01302         //linkage.breadth_first_traversal(draw);
01303 
01304 }
01305 
01306     
01307 template<class FeatureT, class ResponseT>
01308 void cluster_permutation_importance(FeatureT              const & features,
01309                                          ResponseT         const &     response,
01310                                     HClustering               & linkage)
01311 {
01312     MultiArray<2, double> distance;
01313     cluster_permutation_importance(features, response, linkage, distance);
01314 }
01315 }//namespace algorithms
01316 }//namespace rf
01317 }//namespace vigra

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.1 (Sun Feb 19 2012)