
vigra/random_forest/rf_split.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
00036 #define VIGRA_RANDOM_FOREST_SPLIT_HXX
00037 #include <algorithm>
00038 #include <cstddef>
00039 #include <map>
00040 #include <numeric>
00041 #include <math.h>
00042 #include "../mathutil.hxx"
00043 #include "../array_vector.hxx"
00044 #include "../sized_int.hxx"
00045 #include "../matrix.hxx"
00046 #include "../random.hxx"
00047 #include "../functorexpression.hxx"
00048 #include "rf_nodeproxy.hxx"
00049 //#include "rf_sampling.hxx"
00050 #include "rf_region.hxx"
00051 //#include "../hokashyap.hxx"
00052 //#include "vigra/rf_helpers.hxx"
00053 
00054 namespace vigra
00055 {
00056 
00057 // Incomplete Class to ensure that findBestSplit is always implemented in
00058 // the derived classes of SplitBase
00059 class CompileTimeError;
00060 
00061 
00062 namespace detail
00063 {
00064     template<class Tag>
00065     class Normalise
00066     {
00067     public:
00068         template<class Iter>
00069         static void exec(Iter begin, Iter  end)
00070         {}
00071     };
00072 
00073     template<>
00074     class Normalise<ClassificationTag>
00075     {
00076     public:
00077         template<class Iter>
00078         static void exec (Iter begin, Iter end)
00079         {
00080             double bla = std::accumulate(begin, end, 0.0);
00081             for(int ii = 0; ii < end - begin; ++ii)
00082                 begin[ii] = begin[ii]/bla ;
00083         }
00084     };
00085 }
00086 
00087 
00088 /** Base Class for all SplitFunctors used with the \ref RandomForest class
00089     defines the interface used while learning a tree.
00090 **/
00091 template<class Tag>
00092 class SplitBase
00093 {
00094   public:
00095 
00096     typedef Tag           RF_Tag;
00097     typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
00098                                         StackEntry_t;
00099 
00100     ProblemSpec<>           ext_param_;
00101 
00102     NodeBase::T_Container_type          t_data;
00103     NodeBase::P_Container_type          p_data;
00104 
00105     NodeBase                            node_;
00106 
00107     template<class T>
00108     void set_external_parameters(ProblemSpec<T> const & in)
00109     {
00110         ext_param_ = in;
00111         t_data.push_back(in.column_count_);
00112         t_data.push_back(in.class_count_);
00113     }
00114 
00115     /** returns the DecisionTree Node created by
00116         \ref findBestSplit or \ref makeTerminalNode.
00117     **/
00118     NodeBase & createNode() 
00119     {
00120         return node_;
00121     }
00122 
00123 
00124     int classCount() const
00125     {
00126         return int(t_data[1]);
00127     }
00128 
00129     int featureCount() const
00130     {
00131         return int(t_data[0]);
00132     }
00133 
00134     /** resets internal data. Should always be called before
00135         calling findBestSplit or makeTerminalNode
00136     **/
00137     void reset()
00138     {
00139         t_data.resize(2);
00140         p_data.resize(0);
00141     }
00142 
00143 
00144     /** findBestSplit() has to be implemented in the derived split functor.
00145         This default only ensures that a compile-time error is issued
00146         if no such method was defined.
00147     **/
00148 
00149     template<class T, class C, class T2, class C2, class Region, class Random>
00150     int findBestSplit(MultiArrayView<2, T, C> features,
00151                       MultiArrayView<2, T2, C2> labels,
00152                       Region region,
00153                       ArrayVector<Region> childs,
00154                       Random randint)
00155     {
00156         CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
00157         return 0;
00158     }
00159 
00160     /** default action for creating a terminal node:
00161         sets the class probabilities of the remaining region according to
00162         its class histogram.
00163     **/
00164     template<class T, class C, class T2,class C2, class Region, class Random>
00165     int makeTerminalNode(MultiArrayView<2, T, C> features,
00166                       MultiArrayView<2, T2, C2>  labels,
00167                       Region &                   region,
00168                       Random                     randint)
00169     {
00170         Node<e_ConstProbNode> ret(t_data, p_data);
00171         node_ = ret;
00172         if(ext_param_.class_weights_.size() != region.classCounts().size())
00173         {
00174             std::copy(      region.classCounts().begin(),
00175                             region.classCounts().end(),
00176                             ret.prob_begin());
00177         }
00178         else
00179         {
00180             std::transform( region.classCounts().begin(),
00181                             region.classCounts().end(),
00182                             ext_param_.class_weights_.begin(),
00183                             ret.prob_begin(), std::multiplies<double>());
00184         }
00185         detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
00186         ret.weights() = region.size();  
00187         return e_ConstProbNode;
00188     }
00189 
00190 
00191 };
00192 
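/* Typical lifecycle of a split functor during tree learning -- a sketch only;
 * the names (ext, features, labels, region, childRegions, randint) are
 * hypothetical, and the exact call sequence inside RandomForest::learn() /
 * DecisionTree may differ:
 * \code
 * GiniSplit split;                           // defined further below in this file
 * split.set_external_parameters(ext);        // ProblemSpec<> describing the problem
 * split.reset();                             // required before every findBestSplit() call
 * int node_type = split.findBestSplit(features, labels, region,
 *                                     childRegions, randint);
 * NodeBase & node = split.createNode();      // the node built by the call above
 * \endcode
 */
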
00193 /** Functor to sort the indices of a feature Matrix by a certain dimension
00194 **/
00195 template<class DataMatrix>
00196 class SortSamplesByDimensions
00197 {
00198     DataMatrix const & data_;
00199     MultiArrayIndex sortColumn_;
00200     double thresVal_;
00201   public:
00202 
00203     SortSamplesByDimensions(DataMatrix const & data, 
00204                             MultiArrayIndex sortColumn,
00205                             double thresVal = 0.0)
00206     : data_(data),
00207       sortColumn_(sortColumn),
00208       thresVal_(thresVal)
00209     {}
00210 
00211     void setColumn(MultiArrayIndex sortColumn)
00212     {
00213         sortColumn_ = sortColumn;
00214     }
00215     void setThreshold(double value)
00216     {
00217         thresVal_ = value; 
00218     }
00219 
00220     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00221     {
00222         return data_(l, sortColumn_) < data_(r, sortColumn_);
00223     }
00224     bool operator()(MultiArrayIndex l) const
00225     {
00226         return data_(l, sortColumn_) < thresVal_;
00227     }
00228 };
00229 
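/* Usage sketch (hypothetical data; this mirrors how the functor is used by
 * findBestSplit() further below):
 * \code
 * MultiArrayView<2, double> features = ...;            // N x M feature matrix
 * ArrayVector<Int32> indices = ...;                     // row indices into features
 * SortSamplesByDimensions<MultiArrayView<2, double> > byCol3(features, 3, 0.5);
 * std::sort(indices.begin(), indices.end(), byCol3);    // order rows by column 3
 * ArrayVector<Int32>::iterator pivot =                  // rows with feature < 0.5 first
 *     std::partition(indices.begin(), indices.end(), byCol3);
 * \endcode
 */
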
00230 template<class DataMatrix>
00231 class DimensionNotEqual
00232 {
00233     DataMatrix const & data_;
00234     MultiArrayIndex sortColumn_;
00235 
00236   public:
00237 
00238     DimensionNotEqual(DataMatrix const & data, 
00239                             MultiArrayIndex sortColumn)
00240     : data_(data),
00241       sortColumn_(sortColumn)
00242     {}
00243 
00244     void setColumn(MultiArrayIndex sortColumn)
00245     {
00246         sortColumn_ = sortColumn;
00247     }
00248 
00249     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00250     {
00251         return data_(l, sortColumn_) != data_(r, sortColumn_);
00252     }
00253 };
00254 
00255 template<class DataMatrix>
00256 class SortSamplesByHyperplane
00257 {
00258     DataMatrix const & data_;
00259     Node<i_HyperplaneNode> const & node_;
00260 
00261   public:
00262 
00263     SortSamplesByHyperplane(DataMatrix              const & data, 
00264                             Node<i_HyperplaneNode>  const & node)
00265     :       
00266             data_(data), 
00267             node_(node)
00268     {}
00269 
00270     /** calculate the distance of a sample point to a hyperplane
00271      */
00272     double operator[](MultiArrayIndex l) const
00273     {
00274         double result_l = -1 * node_.intercept();
00275         for(int ii = 0; ii < node_.columns_size(); ++ii)
00276         {
00277             result_l +=     rowVector(data_, l)[node_.columns_begin()[ii]] 
00278                         *   node_.weights()[ii];
00279         }
00280         return result_l;
00281     }
00282 
00283     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00284     {
00285         return (*this)[l]  < (*this)[r];
00286     }
00287 
00288 };
00289 
00290 /** makes a class histogram given indices into a labels array
00291  *  usage: 
00292  *      MultiArrayView<2, T2, C2> labels = makeSomeLabels();
00293  *      ArrayVector<int> hist(numberOfLabels(labels), 0);
00294  *      RandomForestClassCounter<MultiArrayView<2, T2, C2>, ArrayVector<int> > counter(labels, hist);
00295  *
00296  *      Container<int> indices = getSomeIndices();
00297  *      std::for_each(indices.begin(), indices.end(), counter);
00298  */
00299 template <class DataSource, class CountArray>
00300 class RandomForestClassCounter
00301 {
00302     DataSource  const &     labels_;
00303     CountArray        &     counts_;
00304 
00305   public:
00306 
00307     RandomForestClassCounter(DataSource  const & labels, 
00308                              CountArray & counts)
00309     : labels_(labels),
00310       counts_(counts)
00311     {
00312         reset();
00313     }
00314 
00315     void reset()
00316     {
00317         counts_.init(0);
00318     }
00319 
00320     void operator()(MultiArrayIndex l) const
00321     {
00322         counts_[labels_[l]] +=1;
00323     }
00324 };
00325 
00326 
00327 /** Functor to calculate the best possible split based on the Gini index,
00328     given labels and features along a given axis.
00329 */
00330 
00331 namespace detail
00332 {
00333     template<int N>
00334     class ConstArr
00335     {
00336     public:
00337         double operator[](size_t) const
00338         {
00339             return (double)N;
00340         }
00341     };
00342 
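    // ConstArr<N>() behaves like a read-only array whose entries are all N;
    // the impurity functors below pass ConstArr<1>() as unit class weights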
00343 
00344 }
00345 
00346 
00347 
00348 
00349 /** Functor to calculate the entropy based impurity
00350  */
00351 class EntropyCriterion
00352 {
00353 public:
00354     /** calculate the weighted entropy impurity based on a class histogram
00355      * and class weights
00356      */
00357     template<class Array, class Array2>
00358     double operator()        (Array     const & hist, 
00359                               Array2    const & weights, 
00360                               double            total = 1.0) const
00361     {
00362         return impurity(hist, weights, total);
00363     }
00364     
00365     /** calculate the entropy impurity based on a class histogram
00366      */
00367     template<class Array>
00368     double operator()(Array const & hist, double total = 1.0) const
00369     {
00370         return impurity(hist, total);
00371     }
00372     
00373     /** static version of operator()(hist, total)
00374      */
00375     template<class Array>
00376     static double impurity(Array const & hist, double total)
00377     {
00378         return impurity(hist, detail::ConstArr<1>(), total);
00379     }
00380 
00381     /** static version of operator()(hist, weights, total)
00382      */
00383     template<class Array, class Array2>
00384     static double impurity   (Array     const & hist, 
00385                               Array2    const & weights, 
00386                               double            total)
00387     {
00388         int     class_count     = hist.size();
00389         double  entropy         = 0.0;
00390         if(class_count == 2)
00391         {
00392             double p0           = hist[0]/total;
00393             double p1           = hist[1]/total;
00394             // p*log(p) -> 0 as p -> 0: skip empty classes to avoid NaN from log(0)
00395             entropy             = (p0 > 0.0 ? -weights[0]*p0*std::log(p0) : 0.0)
00396                                 + (p1 > 0.0 ? -weights[1]*p1*std::log(p1) : 0.0);
00397         }
00398         else
00399         {
00400             for(int ii = 0; ii < class_count; ++ii)
00401             {
00402                 double pii      = hist[ii]/total;
00403                 if(pii > 0.0)
00404                     entropy    -= weights[ii]*pii*std::log(pii);
00405             }
00406         }
00407         return total * entropy; 
00408     }
00409 };
00410 
00411 /** Functor to calculate the gini impurity
00412  */
00413 class GiniCriterion
00414 {
00415 public:
00416     /** calculate the weighted Gini impurity based on a class histogram
00417      * and class weights
00418      */
00419     template<class Array, class Array2>
00420     double operator()        (Array     const & hist, 
00421                               Array2    const & weights, 
00422                               double            total = 1.0) const
00423     {
00424         return impurity(hist, weights, total);
00425     }
00426     
00427     /** calculate the Gini impurity based on a class histogram
00428      */
00429     template<class Array>
00430     double operator()(Array const & hist, double total = 1.0) const
00431     {
00432         return impurity(hist, total);
00433     }
00434     
00435     /** static version of operator()(hist, total)
00436      */
00437     template<class Array>
00438     static double impurity(Array const & hist, double total)
00439     {
00440         return impurity(hist, detail::ConstArr<1>(), total);
00441     }
00442 
00443     /** static version of operator()(hist, weights, total)
00444      */
00445     template<class Array, class Array2>
00446     static double impurity   (Array     const & hist, 
00447                               Array2    const & weights, 
00448                               double            total)
00449     {
00450 
00451         int     class_count     = hist.size();
00452         double  gini            = 0.0;
00453         if(class_count == 2)
00454         {
00455             double w            = weights[0] * weights[1];
00456             gini                = w * (hist[0] * hist[1] / total);
00457         }
00458         else
00459         {
00460             for(int ii = 0; ii < class_count; ++ii)
00461             {
00462                 double w        = weights[ii];
00463                 gini           += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) );
00464             }
00465         }
00466         return gini; 
00467     }
00468 };
00469 
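/* Example: evaluating the two impurity criteria on a small class histogram
 * (the values follow directly from the code above; with unit class weights the
 * two-class Gini branch reduces to hist[0]*hist[1]/total):
 * \code
 * ArrayVector<double> hist(2);
 * hist[0] = 6.0; hist[1] = 2.0;
 * double total = hist[0] + hist[1];
 * double g = GiniCriterion::impurity(hist, total);     // 6*2/8 = 1.5
 * double e = EntropyCriterion::impurity(hist, total);  // 8*(-0.75*log(0.75) - 0.25*log(0.25))
 * \endcode
 * A pure region (all samples in one class) yields impurity 0 under both criteria.
 */
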
00470 
00471 template <class DataSource, class Impurity= GiniCriterion>
00472 class ImpurityLoss
00473 {
00474 
00475     DataSource  const &         labels_;
00476     ArrayVector<double>        counts_;
00477     ArrayVector<double> const  class_weights_;
00478     double                      total_counts_;
00479     Impurity                    impurity_;
00480 
00481   public:
00482 
00483     template<class T>
00484     ImpurityLoss(DataSource  const & labels, 
00485                                 ProblemSpec<T> const & ext_)
00486     : labels_(labels),
00487       counts_(ext_.class_count_, 0.0),
00488       class_weights_(ext_.class_weights_),
00489       total_counts_(0.0)
00490     {}
00491 
00492     void reset()
00493     {
00494         counts_.init(0);
00495         total_counts_ = 0.0;
00496     }
00497 
00498     template<class Counts>
00499     double increment_histogram(Counts const & counts)
00500     {
00501         std::transform(counts.begin(), counts.end(),
00502                        counts_.begin(), counts_.begin(),
00503                        std::plus<double>());
00504         total_counts_ = std::accumulate( counts_.begin(), 
00505                                          counts_.end(),
00506                                          0.0);
00507         return impurity_(counts_, class_weights_, total_counts_);
00508     }
00509 
00510     template<class Counts>
00511     double decrement_histogram(Counts const & counts)
00512     {
00513         std::transform(counts.begin(), counts.end(),
00514                        counts_.begin(), counts_.begin(),
00515                        std::minus<double>());
00516         total_counts_ = std::accumulate( counts_.begin(), 
00517                                          counts_.end(),
00518                                          0.0);
00519         return impurity_(counts_, class_weights_, total_counts_);
00520     }
00521 
00522     template<class Iter>
00523     double increment(Iter begin, Iter end)
00524     {
00525         for(Iter iter = begin; iter != end; ++iter)
00526         {
00527             counts_[labels_(*iter, 0)] +=1.0;
00528             total_counts_ +=1.0;
00529         }
00530         return impurity_(counts_, class_weights_, total_counts_);
00531     }
00532 
00533     template<class Iter>
00534     double decrement(Iter const &  begin, Iter const & end)
00535     {
00536         for(Iter iter = begin; iter != end; ++iter)
00537         {
00538             counts_[labels_(*iter,0)] -=1.0;
00539             total_counts_ -=1.0;
00540         }
00541         return impurity_(counts_, class_weights_, total_counts_);
00542     }
00543 
00544     template<class Iter, class Resp_t>
00545     double init (Iter begin, Iter end, Resp_t resp)
00546     {
00547         reset();
00548         std::copy(resp.begin(), resp.end(), counts_.begin());
00549         total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0); 
00550         return impurity_(counts_,class_weights_, total_counts_);
00551     }
00552     
00553     ArrayVector<double> const & response()
00554     {
00555         return counts_;
00556     }
00557 };
00558 
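/* Sketch of the incremental bookkeeping used by the split search further below
 * (hypothetical labels/indices; as samples move from the right to the left
 * part, the candidate split loss is the sum of both sides):
 * \code
 * ImpurityLoss<MultiArrayView<2, Int32> > left(labels, ext_param),
 *                                         right(labels, ext_param);
 * double region_loss = right.init(begin, end, region_class_counts);
 * double split_loss  = right.decrement(begin, begin + k)   // remove the first k samples ...
 *                    + left.increment(begin, begin + k);   // ... and add them to the left side
 * \endcode
 */
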
00559 template <class DataSource>
00560 class RegressionForestCounter
00561 {
00562     typedef MultiArrayShape<2>::type Shp;
00563     DataSource const &      labels_;
00564     ArrayVector <double>    mean_;
00565     ArrayVector <double>    variance_;
00566     ArrayVector <double>    tmp_;
00567     size_t                  count_;
00568   public:
00569     template<class T>
00570     RegressionForestCounter(DataSource const & labels, 
00571                             ProblemSpec<T> const & ext_)
00572     :
00573         labels_(labels),
00574         mean_(ext_.response_size, 0.0),
00575         variance_(ext_.response_size, 0.0),
00576         tmp_(ext_.response_size),
00577         count_(0)
00578     {}
00579     
00580     // West's algorithm for incremental variance
00581     // calculation
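    // (per response component: delta = x - mean;  mean += delta / n;
    //  sum_sq += delta * delta * (n - 1) / n;  the returned loss is
    //  the summed variance  sum(sum_sq) / (n - 1))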
00582     template<class Iter>
00583     double increment (Iter begin, Iter end)
00584     {
00585         for(Iter iter = begin; iter != end; ++iter)
00586         {
00587             ++count_;
00588             for(int ii = 0; ii < mean_.size(); ++ii)
00589                 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
00590             double f  = 1.0 / count_,
00591                    f1 = 1.0 - f;
00592             for(int ii = 0; ii < mean_.size(); ++ii)
00593                 mean_[ii] += f*tmp_[ii]; 
00594             for(int ii = 0; ii < mean_.size(); ++ii)
00595                 variance_[ii] += f1*sq(tmp_[ii]);
00596         }
00597         return std::accumulate(variance_.begin(), 
00598                                variance_.end(),
00599                                0.0,
00600                                std::plus<double>())
00601                 /(count_ -1);
00602     }
00603 
00604     template<class Iter>
00605     double decrement (Iter begin, Iter end)
00606     {
00607         for(Iter iter = begin; iter != end; ++iter)
00608         {
00609             --count_;
00610             for(int ii = 0; ii < mean_.size(); ++ii)
00611                 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
00612             double f  = 1.0 / count_,
00613                    f1 = 1.0 + f;
00614             for(int ii = 0; ii < mean_.size(); ++ii)
00615                 mean_[ii] -= f*tmp_[ii]; 
00616             for(int ii = 0; ii < mean_.size(); ++ii)
00617                 variance_[ii] -= f1*sq(tmp_[ii]);
00618         }
00619         return std::accumulate(variance_.begin(), 
00620                                variance_.end(),
00621                                0.0,
00622                                std::plus<double>())
00623                 /(count_ -1);
00624     }
00625 
00626     template<class Iter, class Resp_t>
00627     double init (Iter begin, Iter end, Resp_t resp)
00628     {
00629         reset();
00630         return increment(begin, end);
00631     }
00632     
00633 
00634     ArrayVector<double> const & response()
00635     {
00636         return mean_;
00637     }
00638 
00639     void reset()
00640     {
00641         mean_.init(0.0);
00642         variance_.init(0.0);
00643         count_ = 0; 
00644     }
00645 };
00646 
00647 template<class Tag, class Datatype>
00648 struct LossTraits;
00649 
00650 struct LSQLoss
00651 {};
00652 
00653 template<class Datatype>
00654 struct LossTraits<GiniCriterion, Datatype>
00655 {
00656     typedef ImpurityLoss<Datatype, GiniCriterion> type;
00657 };
00658 
00659 template<class Datatype>
00660 struct LossTraits<EntropyCriterion, Datatype>
00661 {
00662     typedef ImpurityLoss<Datatype, EntropyCriterion> type;
00663 };
00664 
00665 template<class Datatype>
00666 struct LossTraits<LSQLoss, Datatype>
00667 {
00668     typedef RegressionForestCounter<Datatype> type;
00669 };
00670 
00671 /** Given a column, choose a split that minimizes some loss
00672  */
00673 template<class LineSearchLossTag>
00674 class BestGiniOfColumn
00675 {
00676 public:
00677     ArrayVector<double>     class_weights_;
00678     ArrayVector<double>     bestCurrentCounts[2];
00679     double                  min_gini_;
00680     ptrdiff_t               min_index_;
00681     double                  min_threshold_;
00682     ProblemSpec<>           ext_param_;
00683 
00684     BestGiniOfColumn()
00685     {}
00686 
00687     template<class T> 
00688     BestGiniOfColumn(ProblemSpec<T> const & ext)
00689     :
00690         class_weights_(ext.class_weights_),
00691         ext_param_(ext)
00692     {
00693         bestCurrentCounts[0].resize(ext.class_count_);
00694         bestCurrentCounts[1].resize(ext.class_count_);
00695     }
00696     template<class T> 
00697     void set_external_parameters(ProblemSpec<T> const & ext)
00698     {
00699         class_weights_ = ext.class_weights_; 
00700         ext_param_ = ext;
00701         bestCurrentCounts[0].resize(ext.class_count_);
00702         bestCurrentCounts[1].resize(ext.class_count_);
00703     }
00704     /** calculate the best gini split along a feature column
00705      * \param column  the feature data source - must support (row, column) access
00706      * \param g       index of the column to be scanned
00707      * \param labels  the label vector
00708      * \param begin 
00709      * \param end     (in and out)
00710      *                iterators to the indices of the
00711      *                samples in the current region;
00712      *                the range [begin, end) is sorted by the supplied column
00713      *                during function execution.
00714      * \param region_response
00715      *                class histogram of the range.
00716      *
00717      *  precondition: [begin, end) is a valid range,
00718      *                region_response is a positive integer valued array with the
00719      *                class counts in the current range,
00720      *                labels.size() >= max(begin, end).
00721      *  postcondition:
00722      *                [begin, end) is sorted by the given column;
00723      *                min_gini_ contains the minimum gini found or
00724      *                NumericTraits<double>::max() if no split was found;
00725      *                min_index_ contains the splitting index in the range
00726      *                or invalid data if no split was found;
00727      *                bestCurrentCounts[0] and [1] contain the
00728      *                class histograms of the left and right regions.
00729      */
00730     template<   class DataSourceF_t,
00731                 class DataSource_t, 
00732                 class I_Iter, 
00733                 class Array>
00734     void operator()(DataSourceF_t   const & column,
00735                     int                     g,
00736                     DataSource_t    const & labels,
00737                     I_Iter                & begin, 
00738                     I_Iter                & end,
00739                     Array           const & region_response)
00740     {
00741         std::sort(begin, end, 
00742                   SortSamplesByDimensions<DataSourceF_t>(column, g));
00743         typedef typename 
00744             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
00745         LineSearchLoss left(labels, ext_param_);
00746         LineSearchLoss right(labels, ext_param_);
00747 
00748         
00749 
00750         min_gini_ = right.init(begin, end, region_response);
00751         min_threshold_ = *begin;
00752         min_index_     = 0;
00753         DimensionNotEqual<DataSourceF_t> comp(column, g); 
00754         
00755         I_Iter iter = begin;
00756         I_Iter next = std::adjacent_find(iter, end, comp);
00757         while( next  != end)
00758         {
00759 
00760             double loss = right.decrement(iter, next + 1) 
00761                      +     left.increment(iter , next + 1);
00762 #ifdef CLASSIFIER_TEST
00763             if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_))
00764 #else
00765             if(loss < min_gini_ )
00766 #endif 
00767             {
00768                 bestCurrentCounts[0] = left.response();
00769                 bestCurrentCounts[1] = right.response();
00770 #ifdef CLASSIFIER_TEST
00771                 min_gini_       = loss < min_gini_? loss : min_gini_;
00772 #else
00773                 min_gini_       = loss; 
00774 #endif
00775                 min_index_      = next - begin +1 ;
00776                 min_threshold_  = (double(column(*next,g)) + double(column(*(next +1), g)))/2.0;
00777             }
00778             iter = next +1 ;
00779             next = std::adjacent_find(iter, end, comp);
00780         }
00781     }
00782 
00783     template<class DataSource_t, class Iter, class Array>
00784     double loss_of_region(DataSource_t const & labels,
00785                           Iter & begin, 
00786                           Iter & end, 
00787                           Array const & region_response) const
00788     {
00789         typedef typename 
00790             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
00791         LineSearchLoss region_loss(labels, ext_param_);
00792         return 
00793             region_loss.init(begin, end, region_response);
00794     }
00795 
00796 };
00797 
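/* Usage sketch for a single feature column (hypothetical names; this is how
 * ThresholdSplit::findBestSplit() below drives the functor):
 * \code
 * BestGiniOfColumn<GiniCriterion> bgfunc(ext_param);    // ProblemSpec<> of the problem
 * bgfunc(features, g, labels,                           // scan feature column g;
 *        region.begin(), region.end(),                  // the index range gets sorted by it
 *        region.classCounts());                         // class histogram of the region
 * double best_loss      = bgfunc.min_gini_;       // loss of the best split (region loss if none found)
 * double best_threshold = bgfunc.min_threshold_;  // midpoint between the two adjacent feature values
 * ptrdiff_t left_size   = bgfunc.min_index_;      // number of samples falling into the left child
 * \endcode
 */
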
00798 
00799 /** Chooses mtry columns and applies the ColumnDecisionFunctor to each of
00800  * them, then chooses the column that yields the best split.
00801  */
00802 template<class ColumnDecisionFunctor, class Tag = ClassificationTag>
00803 class ThresholdSplit: public SplitBase<Tag>
00804 {
00805   public:
00806 
00807 
00808     typedef SplitBase<Tag> SB;
00809     
00810     ArrayVector<Int32>          splitColumns;
00811     ColumnDecisionFunctor       bgfunc;
00812 
00813     double                      region_gini_;
00814     ArrayVector<double>         min_gini_;
00815     ArrayVector<ptrdiff_t>      min_indices_;
00816     ArrayVector<double>         min_thresholds_;
00817 
00818     int                         bestSplitIndex;
00819 
00820     double minGini() const
00821     {
00822         return min_gini_[bestSplitIndex];
00823     }
00824     int bestSplitColumn() const
00825     {
00826         return splitColumns[bestSplitIndex];
00827     }
00828     double bestSplitThreshold() const
00829     {
00830         return min_thresholds_[bestSplitIndex];
00831     }
00832 
00833     template<class T>
00834     void set_external_parameters(ProblemSpec<T> const & in)
00835     {
00836         SB::set_external_parameters(in);        
00837         bgfunc.set_external_parameters( SB::ext_param_);
00838         int featureCount_ = SB::ext_param_.column_count_;
00839         splitColumns.resize(featureCount_);
00840         for(int k=0; k<featureCount_; ++k)
00841             splitColumns[k] = k;
00842         min_gini_.resize(featureCount_);
00843         min_indices_.resize(featureCount_);
00844         min_thresholds_.resize(featureCount_);
00845     }
00846 
00847 
00848     template<class T, class C, class T2, class C2, class Region, class Random>
00849     int findBestSplit(MultiArrayView<2, T, C> features,
00850                       MultiArrayView<2, T2, C2>  labels,
00851                       Region & region,
00852                       ArrayVector<Region>& childRegions,
00853                       Random & randint)
00854     {
00855 
00856         typedef typename Region::IndexIterator IndexIterator;
00857         if(region.size() == 0)
00858         {
00859            std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
00860                         "continuing learning process...."; 
00861         }
00862         // calculate things that haven't been calculated yet. 
00863         
00864         if(std::accumulate(region.classCounts().begin(),
00865                            region.classCounts().end(), 0) != region.size())
00866         {
00867             RandomForestClassCounter<   MultiArrayView<2,T2, C2>, 
00868                                         ArrayVector<double> >
00869                 counter(labels, region.classCounts());
00870             std::for_each(  region.begin(), region.end(), counter);
00871             region.classCountsIsValid = true;
00872         }
00873 
00874         // Is the region pure already?
00875         region_gini_ = bgfunc.loss_of_region(labels,
00876                                              region.begin(), 
00877                                              region.end(),
00878                                              region.classCounts());
00879         if(region_gini_ <= SB::ext_param_.precision_)
00880             return  makeTerminalNode(features, labels, region, randint);
00881 
00882         // select columns  to be tried.
00883         for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
00884             std::swap(splitColumns[ii], 
00885                       splitColumns[ii+ randint(features.shape(1) - ii)]);
00886 
00887         // find the best gini index
00888         bestSplitIndex              = 0;
00889         double  current_min_gini    = region_gini_;
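        // at least actual_mtry_ of the shuffled columns are examined; if none of
        // them improves on the region loss, further columns are tried until one
        // does (or all columns are exhausted) -- see the num2try update below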
00890         int     num2try             = features.shape(1);
00891         for(int k=0; k<num2try; ++k)
00892         {
00893             //this functor does all the work
00894             bgfunc(features,
00895                    splitColumns[k],
00896                    labels, 
00897                    region.begin(), region.end(), 
00898                    region.classCounts());
00899             min_gini_[k]            = bgfunc.min_gini_; 
00900             min_indices_[k]         = bgfunc.min_index_;
00901             min_thresholds_[k]      = bgfunc.min_threshold_;
00902 #ifdef CLASSIFIER_TEST
00903             if(     bgfunc.min_gini_ < current_min_gini
00904                &&  !closeAtTolerance(bgfunc.min_gini_, current_min_gini))
00905 #else
00906             if(bgfunc.min_gini_ < current_min_gini)
00907 #endif
00908             {
00909                 current_min_gini = bgfunc.min_gini_;
00910                 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
00911                 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
00912                 childRegions[0].classCountsIsValid = true;
00913                 childRegions[1].classCountsIsValid = true;
00914 
00915                 bestSplitIndex   = k;
00916                 num2try = SB::ext_param_.actual_mtry_;
00917             }
00918         }
00919 
00920         // did not find any suitable split
00921         if(closeAtTolerance(current_min_gini, region_gini_))
00922             return  makeTerminalNode(features, labels, region, randint);
00923         
00924         //create a Node for output
00925         Node<i_ThresholdNode>   node(SB::t_data, SB::p_data);
00926         SB::node_ = node;
00927         node.threshold()    = min_thresholds_[bestSplitIndex];
00928         node.column()       = splitColumns[bestSplitIndex];
00929         
00930         // partition the range according to the best dimension 
00931         SortSamplesByDimensions<MultiArrayView<2, T, C> > 
00932             sorter(features, node.column(), node.threshold());
00933         IndexIterator bestSplit =
00934             std::partition(region.begin(), region.end(), sorter);
00935         // Save the ranges of the child stack entries.
00936         childRegions[0].setRange(   region.begin()  , bestSplit       );
00937         childRegions[0].rule = region.rule;
00938         childRegions[0].rule.push_back(std::make_pair(1, 1.0));
00939         childRegions[1].setRange(   bestSplit       , region.end()    );
00940         childRegions[1].rule = region.rule;
00941         childRegions[1].rule.push_back(std::make_pair(1, 1.0));
00942 
00943         return i_ThresholdNode;
00944     }
00945 };
00946 
00947 typedef  ThresholdSplit<BestGiniOfColumn<GiniCriterion> >                    GiniSplit;
00948 typedef  ThresholdSplit<BestGiniOfColumn<EntropyCriterion> >                 EntropySplit;
00949 typedef  ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag>            RegressionSplit;
00950 
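/* The typedefs above are the split functors normally passed to RandomForest<>::learn().
 * A minimal sketch with hypothetical data, assuming the learn() overload that takes
 * visitor, split and stop arguments in that order (see random_forest.hxx for the
 * authoritative signature; rf_default() requests the default for a slot):
 * \code
 * vigra::RandomForest<> rf;
 * rf.learn(features, labels, rf_default(), EntropySplit(), rf_default());
 * \endcode
 */
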
00951 namespace rf
00952 {
00953 
00954 /** This namespace contains additional split functors.
00955  *
00956  * The split functor classes are designed in a modular fashion because new split functors may
00957  * share a lot of code with existing ones.
00958  *
00959  * ThresholdSplit implements the functionality needed by any split functor that makes its
00960  * decision via one-dimensional axis-parallel cuts. The template parameter defines how the split
00961  * along one dimension is chosen.
00962  *
00963  * The BestGiniOfColumn class chooses a split that minimizes one of the supplied loss functions
00964  * (GiniCriterion or EntropyCriterion for classification, LSQLoss for regression). Median chooses
00965  * the split in a kd-tree fashion.
00966  * Currently defined typedefs:
00967  * \code
00968  * typedef  ThresholdSplit<BestGiniOfColumn<GiniCriterion> >             GiniSplit;
00969  * typedef  ThresholdSplit<BestGiniOfColumn<EntropyCriterion> >          EntropySplit;
00970  * typedef  ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag>     RegressionSplit;
00971  * typedef  ThresholdSplit<Median>                                       MedianSplit;
00972  * typedef  ThresholdSplit<RandomSplitOfColumn>                          RandomSplit;
00973  * \endcode
00974  */
00975 namespace split
00976 {
00977 
00978 /** This Functor chooses the median value of a column
00979  */
00980 class Median
00981 {
00982 public:
00983 
00984     typedef GiniCriterion   LineSearchLossTag;
00985     ArrayVector<double>     class_weights_;
00986     ArrayVector<double>     bestCurrentCounts[2];
00987     double                  min_gini_;
00988     ptrdiff_t               min_index_;
00989     double                  min_threshold_;
00990     ProblemSpec<>           ext_param_;
00991 
00992     Median()
00993     {}
00994 
00995     template<class T> 
00996     Median(ProblemSpec<T> const & ext)
00997     :
00998         class_weights_(ext.class_weights_),
00999         ext_param_(ext)
01000     {
01001         bestCurrentCounts[0].resize(ext.class_count_);
01002         bestCurrentCounts[1].resize(ext.class_count_);
01003     }
01004   
01005     template<class T> 
01006     void set_external_parameters(ProblemSpec<T> const & ext)
01007     {
01008         class_weights_ = ext.class_weights_; 
01009         ext_param_ = ext;
01010         bestCurrentCounts[0].resize(ext.class_count_);
01011         bestCurrentCounts[1].resize(ext.class_count_);
01012     }
01013      
01014     template<   class DataSourceF_t,
01015                 class DataSource_t, 
01016                 class I_Iter, 
01017                 class Array>
01018     void operator()(DataSourceF_t   const & column,
01019                     DataSource_t    const & labels,
01020                     I_Iter                & begin, 
01021                     I_Iter                & end,
01022                     Array           const & region_response)
01023     {
01024         std::sort(begin, end, 
01025                   SortSamplesByDimensions<DataSourceF_t>(column, 0));
01026         typedef typename 
01027             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
01028         LineSearchLoss left(labels, ext_param_);
01029         LineSearchLoss right(labels, ext_param_);
01030         right.init(begin, end, region_response);
01031 
01032         min_gini_ = NumericTraits<double>::max();
01033         min_index_ = floor(double(end - begin)/2.0); 
01034         min_threshold_ =  column[*(begin + min_index_)];
01035         SortSamplesByDimensions<DataSourceF_t> 
01036             sorter(column, 0, min_threshold_);
01037         I_Iter part = std::partition(begin, end, sorter);
01038         DimensionNotEqual<DataSourceF_t> comp(column, 0); 
01039         if(part == begin)
01040         {
01041             part= std::adjacent_find(part, end, comp)+1;
01042             
01043         }
01044         if(part >= end)
01045         {
01046             return; 
01047         }
01048         else
01049         {
01050             min_threshold_ = column[*part];
01051         }
01052         min_gini_ = right.decrement(begin, part) 
01053               +     left.increment(begin , part);
01054 
01055         bestCurrentCounts[0] = left.response();
01056         bestCurrentCounts[1] = right.response();
01057         
01058         min_index_      = part - begin;
01059     }
01060 
01061     template<class DataSource_t, class Iter, class Array>
01062     double loss_of_region(DataSource_t const & labels,
01063                           Iter & begin, 
01064                           Iter & end, 
01065                           Array const & region_response) const
01066     {
01067         typedef typename 
01068             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
01069         LineSearchLoss region_loss(labels, ext_param_);
01070         return 
01071             region_loss.init(begin, end, region_response);
01072     }
01073 
01074 };
01075 
01076 typedef  ThresholdSplit<Median> MedianSplit;
01077 
01078 
01079 /** This Functor chooses a random value of a column
01080  */
01081 class RandomSplitOfColumn
01082 {
01083 public:
01084 
01085     typedef GiniCriterion   LineSearchLossTag;
01086     ArrayVector<double>     class_weights_;
01087     ArrayVector<double>     bestCurrentCounts[2];
01088     double                  min_gini_;
01089     ptrdiff_t               min_index_;
01090     double                  min_threshold_;
01091     ProblemSpec<>           ext_param_;
01092     typedef RandomMT19937   Random_t;
01093     Random_t                random;
01094 
01095     RandomSplitOfColumn()
01096     {}
01097 
01098     template<class T> 
01099     RandomSplitOfColumn(ProblemSpec<T> const & ext)
01100     :
01101         class_weights_(ext.class_weights_),
01102         ext_param_(ext),
01103         random(RandomSeed)
01104     {
01105         bestCurrentCounts[0].resize(ext.class_count_);
01106         bestCurrentCounts[1].resize(ext.class_count_);
01107     }
01108     
01109     template<class T> 
01110     RandomSplitOfColumn(ProblemSpec<T> const & ext, Random_t & random_)
01111     :
01112         class_weights_(ext.class_weights_),
01113         ext_param_(ext),
01114         random(random_)
01115     {
01116         bestCurrentCounts[0].resize(ext.class_count_);
01117         bestCurrentCounts[1].resize(ext.class_count_);
01118     }
01119   
01120     template<class T> 
01121     void set_external_parameters(ProblemSpec<T> const & ext)
01122     {
01123         class_weights_ = ext.class_weights_; 
01124         ext_param_ = ext;
01125         bestCurrentCounts[0].resize(ext.class_count_);
01126         bestCurrentCounts[1].resize(ext.class_count_);
01127     }
01128      
01129     template<   class DataSourceF_t,
01130                 class DataSource_t, 
01131                 class I_Iter, 
01132                 class Array>
01133     void operator()(DataSourceF_t   const & column,
01134                     DataSource_t    const & labels,
01135                     I_Iter                & begin, 
01136                     I_Iter                & end,
01137                     Array           const & region_response)
01138     {
01139         std::sort(begin, end, 
01140                   SortSamplesByDimensions<DataSourceF_t>(column, 0));
01141         typedef typename 
01142             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
01143         LineSearchLoss left(labels, ext_param_);
01144         LineSearchLoss right(labels, ext_param_);
01145         right.init(begin, end, region_response);
01146 
01147         
01148         min_gini_ = NumericTraits<double>::max();
01149         
01150         min_index_ = random.uniformInt(end - begin);
01151         min_threshold_ =  column[*(begin + min_index_)];
01152         SortSamplesByDimensions<DataSourceF_t> 
01153             sorter(column, 0, min_threshold_);
01154         I_Iter part = std::partition(begin, end, sorter);
01155         DimensionNotEqual<DataSourceF_t> comp(column, 0); 
01156         if(part == begin)
01157         {
01158             part= std::adjacent_find(part, end, comp)+1;
01159             
01160         }
01161         if(part >= end)
01162         {
01163             return; 
01164         }
01165         else
01166         {
01167             min_threshold_ = column[*part];
01168         }
01169         min_gini_ = right.decrement(begin, part) 
01170               +     left.increment(begin , part);
01171 
01172         bestCurrentCounts[0] = left.response();
01173         bestCurrentCounts[1] = right.response();
01174         
01175         min_index_      = part - begin;
01176     }
01177 
01178     template<class DataSource_t, class Iter, class Array>
01179     double loss_of_region(DataSource_t const & labels,
01180                           Iter & begin, 
01181                           Iter & end, 
01182                           Array const & region_response) const
01183     {
01184         typedef typename 
01185             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
01186         LineSearchLoss region_loss(labels, ext_param_);
01187         return 
01188             region_loss.init(begin, end, region_response);
01189     }
01190 
01191 };
01192 
01193 typedef  ThresholdSplit<RandomSplitOfColumn> RandomSplit;
01194 }
01195 }
01196 
01197 
01198 } //namespace vigra
01199 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX
