[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_split.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
00036 #define VIGRA_RANDOM_FOREST_SPLIT_HXX
00037 #include <algorithm>
00038 #include <map>
00039 #include <numeric>
00040 #include "../mathutil.hxx"
00041 #include "../array_vector.hxx"
00042 #include "../sized_int.hxx"
00043 #include "../matrix.hxx"
00044 #include "../random.hxx"
00045 #include "../functorexpression.hxx"
00046 #include "rf_nodeproxy.hxx"
00047 #include "rf_sampling.hxx"
00048 //#include "../hokashyap.hxx"
00049 //#include "vigra/rf_helpers.hxx"
00050 
00051 namespace vigra
00052 {
00053 
// Intentionally left incomplete (never defined). Declaring a variable of
// this type inside a template that should never be instantiated (see
// SplitBase::findBestSplit) turns a missing override in a derived split
// functor into a compile-time error.
class CompileTimeError;
00057 
00058 
00059 namespace detail
00060 {
00061     template<class Tag>
00062     class Normalise
00063     {
00064     public:
00065         template<class Iter>
00066         static void exec(Iter begin, Iter  end)
00067         {}
00068     };
00069 
00070     template<>
00071     class Normalise<ClassificationTag>
00072     {
00073     public:
00074         template<class Iter>
00075         static void exec (Iter begin, Iter end)
00076         {
00077             double bla = std::accumulate(begin, end, 0.0);
00078             for(int ii = 0; ii < end - begin; ++ii)
00079                 begin[ii] = begin[ii]/bla ;
00080         }
00081     };
00082 }
00083 
00084 
00085 /** Base Class for all SplitFunctors used with the \ref RandomForest class
00086     defines the interface used while learning a tree.
00087 **/
00088 template<class Tag>
00089 class SplitBase
00090 {
00091   public:
00092 
00093     typedef Tag           RF_Tag;
00094     typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
00095                                         StackEntry_t;
00096 
00097     ProblemSpec<>           ext_param_;
00098 
00099     NodeBase::T_Container_type          t_data;
00100     NodeBase::P_Container_type          p_data;
00101 
00102     NodeBase                            node_;
00103 
00104     /** returns the DecisionTree Node created by
00105         \ref findBestSplit or \ref makeTerminalNode.
00106     **/
00107 
00108     template<class T>
00109     void set_external_parameters(ProblemSpec<T> const & in)
00110     {
00111         ext_param_ = in;
00112         t_data.push_back(in.column_count_);
00113         t_data.push_back(in.class_count_);
00114     }
00115 
00116     NodeBase & createNode() 
00117     {
00118         return node_;
00119     }
00120 
00121     int classCount() const
00122     {
00123         return int(t_data[1]);
00124     }
00125 
00126     int featureCount() const
00127     {
00128         return int(t_data[0]);
00129     }
00130 
00131     /** resets internal data. Should always be called before
00132         calling findBestSplit or makeTerminalNode
00133     **/
00134     void reset()
00135     {
00136         t_data.resize(2);
00137         p_data.resize(0);
00138     }
00139 
00140 
00141     /** findBestSplit has to be implemented in derived split functor.
00142         these functions only insures That a CompileTime error is issued
00143         if no such method was defined.
00144     **/
00145 
00146     template<class T, class C, class T2, class C2, class Region, class Random>
00147     int findBestSplit(MultiArrayView<2, T, C> features,
00148                       MultiArrayView<2, T2, C2> labels,
00149                       Region region,
00150                       ArrayVector<Region> childs,
00151                       Random randint)
00152     {
00153         CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
00154         return 0;
00155     }
00156 
00157     /** default action for creating a terminal Node.
00158         sets the Class probability of the remaining region according to
00159         the class histogram
00160     **/
00161     template<class T, class C, class T2,class C2, class Region, class Random>
00162     int makeTerminalNode(MultiArrayView<2, T, C> features,
00163                       MultiArrayView<2, T2, C2>  labels,
00164                       Region &                   region,
00165                       Random                     randint)
00166     {
00167         Node<e_ConstProbNode> ret(t_data, p_data);
00168         node_ = ret;
00169         if(ext_param_.class_weights_.size() != region.classCounts().size())
00170         {
00171         std::copy(          region.classCounts().begin(),
00172                             region.classCounts().end(),
00173                             ret.prob_begin());
00174         }
00175         else
00176         {
00177         std::transform(     region.classCounts().begin(),
00178                             region.classCounts().end(),
00179                             ext_param_.class_weights_.begin(),
00180                             ret.prob_begin(), std::multiplies<double>());
00181         }
00182         detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
00183         ret.weights() = region.size();  
00184         return e_ConstProbNode;
00185     }
00186 
00187 
00188 };
00189 
00190 /** Functor to sort the indices of a feature Matrix by a certain dimension
00191 **/
00192 template<class DataMatrix>
00193 class SortSamplesByDimensions
00194 {
00195     DataMatrix const & data_;
00196     MultiArrayIndex sortColumn_;
00197     double thresVal_;
00198   public:
00199 
00200     SortSamplesByDimensions(DataMatrix const & data, 
00201                             MultiArrayIndex sortColumn,
00202                             double thresVal = 0.0)
00203     : data_(data),
00204       sortColumn_(sortColumn),
00205       thresVal_(thresVal)
00206     {}
00207 
00208     void setColumn(MultiArrayIndex sortColumn)
00209     {
00210         sortColumn_ = sortColumn;
00211     }
00212     void setThreshold(double value)
00213     {
00214         thresVal_ = value; 
00215     }
00216 
00217     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00218     {
00219         return data_(l, sortColumn_) < data_(r, sortColumn_);
00220     }
00221     bool operator()(MultiArrayIndex l) const
00222     {
00223         return data_(l, sortColumn_) < thresVal_;
00224     }
00225 };
00226 
00227 template<class DataMatrix>
00228 class DimensionNotEqual
00229 {
00230     DataMatrix const & data_;
00231     MultiArrayIndex sortColumn_;
00232 
00233   public:
00234 
00235     DimensionNotEqual(DataMatrix const & data, 
00236                             MultiArrayIndex sortColumn)
00237     : data_(data),
00238       sortColumn_(sortColumn)
00239     {}
00240 
00241     void setColumn(MultiArrayIndex sortColumn)
00242     {
00243         sortColumn_ = sortColumn;
00244     }
00245 
00246     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00247     {
00248         return data_(l, sortColumn_) != data_(r, sortColumn_);
00249     }
00250 };
00251 
00252 template<class DataMatrix>
00253 class SortSamplesByHyperplane
00254 {
00255     DataMatrix const & data_;
00256     Node<i_HyperplaneNode> const & node_;
00257 
00258   public:
00259 
00260     SortSamplesByHyperplane(DataMatrix              const & data, 
00261                             Node<i_HyperplaneNode>  const & node)
00262     :       
00263             data_(data), 
00264             node_()
00265     {}
00266 
00267     /** calculate the distance of a sample point to a hyperplane
00268      */
00269     double operator[](MultiArrayIndex l) const
00270     {
00271         double result_l = -1 * node_.intercept();
00272         for(int ii = 0; ii < node_.columns_size(); ++ii)
00273         {
00274             result_l +=     rowVector(data_, l)[node_.columns_begin()[ii]] 
00275                         *   node_.weights()[ii];
00276         }
00277         return result_l;
00278     }
00279 
00280     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00281     {
00282         return (*this)[l]  < (*this)[r];
00283     }
00284 
00285 };
00286 
00287 /** makes a Class Histogram given indices in a labels_ array
00288  *  usage: 
00289  *      MultiArrayView<2, T2, C2> labels = makeSomeLabels()
00290  *      ArrayVector<int> hist(numberOfLabels(labels), 0);
00291  *      RandomForestClassCounter<T2, C2, ArrayVector> counter(labels, hist);
00292  *
00293  *      Container<int> indices = getSomeIndices()
00294  *      std::for_each(indices, counter);
00295  */
00296 template <class DataSource, class CountArray>
00297 class RandomForestClassCounter
00298 {
00299     DataSource  const &     labels_;
00300     CountArray        &     counts_;
00301 
00302   public:
00303 
00304     RandomForestClassCounter(DataSource  const & labels, 
00305                              CountArray & counts)
00306     : labels_(labels),
00307       counts_(counts)
00308     {
00309         reset();
00310     }
00311 
00312     void reset()
00313     {
00314         counts_.init(0);
00315     }
00316 
00317     void operator()(MultiArrayIndex l) const
00318     {
00319         counts_[labels_[l]] +=1;
00320     }
00321 };
00322 
00323 
/** Functor to calculate the best possible split based on the Gini index,
    given labels and features along a given axis (see BestGiniOfColumn).
*/
00327 
namespace detail
{
    /** Read-only pseudo-array in which every element equals N.

        Stands in for a weight array whose entries are all N; for example,
        GiniCriterion passes ConstArr<1> to obtain the unweighted impurity.
    **/
    template<int N>
    class ConstArr
    {
    public:
        double operator[](size_t /*index*/) const
        {
            return static_cast<double>(N);
        }
    };
}
00342 
00343 
00344 
00345 
00346 
00347 /** Functor to calculate the gini impurity
00348  */
00349 class GiniCriterion
00350 {
00351 public:
00352     /**caculate the weighted gini impurity based on class histogram
00353      * and class weights
00354      */
00355     template<class Array, class Array2>
00356     double operator()        (Array     const & hist, 
00357                               Array2    const & weights, 
00358                               double            total = 1.0) const
00359     {
00360         return impurity(hist, weights, total);
00361     }
00362     
00363     /** calculate the gini based impurity based on class histogram
00364      */
00365     template<class Array>
00366     double operator()(Array const & hist, double total = 1.0) const
00367     {
00368         return impurity(hist, total);
00369     }
00370     
00371     /** static version of operator(hist total)
00372      */
00373     template<class Array>
00374     static double impurity(Array const & hist, double total)
00375     {
00376         return impurity(hist, detail::ConstArr<1>(), total);
00377     }
00378 
00379     /** static version of operator(hist, weights, total)
00380      */
00381     template<class Array, class Array2>
00382     static double impurity   (Array     const & hist, 
00383                               Array2    const & weights, 
00384                               double            total)
00385     {
00386 
00387         int     class_count     = hist.size();
00388         double  gini            = 0;
00389         if(class_count == 2)
00390         {
00391             double w            = weights[0] * weights[1];
00392             gini                = w * (hist[0] * hist[1] / total);
00393         }
00394         else
00395         {
00396             for(int ii = 0; ii < class_count; ++ii)
00397             {
00398                 double w        = weights[ii];
00399                 gini           += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) );
00400             }
00401         }
00402         return gini; 
00403     }
00404 };
00405 
00406 
00407 template <class DataSource, class Impurity= GiniCriterion>
00408 class ImpurityLoss
00409 {
00410 
00411     DataSource  const &         labels_;
00412     ArrayVector<double>        counts_;
00413     ArrayVector<double> const & class_weights_;
00414     double                      total_counts_;
00415     Impurity                    impurity_;
00416 
00417   public:
00418 
00419     template<class T>
00420     ImpurityLoss(DataSource  const & labels, 
00421                                 ProblemSpec<T> const & ext_)
00422     : labels_(labels),
00423       counts_(ext_.class_count_, 0.0),
00424       class_weights_(ext_.class_weights_),
00425       total_counts_(0.0)
00426     {}
00427 
00428     void reset()
00429     {
00430         counts_.init(0);
00431         total_counts_ = 0.0;
00432     }
00433 
00434     template<class Counts>
00435     double increment_histogram(Counts const & counts)
00436     {
00437         std::transform(counts.begin(), counts.end(),
00438                        counts_.begin(), counts_.begin(),
00439                        std::plus<double>());
00440         total_counts_ = std::accumulate( counts_.begin(), 
00441                                          counts_.end(),
00442                                          0.0);
00443         return impurity_(counts_, class_weights_, total_counts_);
00444     }
00445 
00446     template<class Counts>
00447     double decrement_histogram(Counts const & counts)
00448     {
00449         std::transform(counts.begin(), counts.end(),
00450                        counts_.begin(), counts_.begin(),
00451                        std::minus<double>());
00452         total_counts_ = std::accumulate( counts_.begin(), 
00453                                          counts_.end(),
00454                                          0.0);
00455         return impurity_(counts_, class_weights_, total_counts_);
00456     }
00457 
00458     template<class Iter>
00459     double increment(Iter begin, Iter end)
00460     {
00461         for(Iter iter = begin; iter != end; ++iter)
00462         {
00463             counts_[labels_[*iter]] +=1;
00464             total_counts_ +=1;
00465         }
00466         return impurity_(counts_, class_weights_, total_counts_);
00467     }
00468 
00469     template<class Iter>
00470     double decrement(Iter begin, Iter end)
00471     {
00472         for(Iter iter = begin; iter != end; ++iter)
00473         {
00474             counts_[labels_[*iter]] -=1;
00475             total_counts_ -=1;
00476         }
00477         return impurity_(counts_, class_weights_, total_counts_);
00478     }
00479 
00480     template<class Iter, class Resp_t>
00481     double init (Iter begin, Iter end, Resp_t resp)
00482     {
00483         reset();
00484         std::copy(resp.begin(), resp.end(), counts_.begin());
00485         total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0); 
00486         return impurity_(counts_,class_weights_, total_counts_);
00487     }
00488     
00489     ArrayVector<double> const & response()
00490     {
00491         return counts_;
00492     }
00493 };
00494 
00495 template <class DataSource>
00496 class RegressionForestCounter
00497 {
00498     typedef MultiArrayShape<2>::type Shp;
00499     DataSource const &      labels_;
00500     ArrayVector <double>    mean_;
00501     ArrayVector <double>    variance_;
00502     ArrayVector <double>    tmp_;
00503     size_t                  count_;
00504 
00505     template<class T>
00506     RegressionForestCounter(DataSource const & labels, 
00507                             ProblemSpec<T> const & ext_)
00508     :
00509         labels_(labels),
00510         mean_(ext_.response_size, 0.0),
00511         variance_(ext_.response_size, 0.0),
00512         tmp_(ext_.response_size),
00513         count_(0)
00514     {}
00515     
00516     //  west's alorithm for incremental variance
00517     // calculation
00518     template<class Iter>
00519     double increment (Iter begin, Iter end)
00520     {
00521         for(Iter iter = begin; iter != end; ++iter)
00522         {
00523             ++count_;
00524             for(int ii = 0; ii < mean_.size(); ++ii)
00525                 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
00526             double f  = 1.0 / count_,
00527                    f1 = 1.0 - f;
00528             for(int ii = 0; ii < mean_.size(); ++ii)
00529                 mean_[ii] += f*tmp_[ii]; 
00530             for(int ii = 0; ii < mean_.size(); ++ii)
00531                 variance_[ii] += f1*sq(tmp_[ii]);
00532         }
00533         return std::accumulate(variance_.begin(), 
00534                                variance_.end(),
00535                                0.0,
00536                                std::plus<double>())
00537                 /(count_ -1);
00538     }
00539 
00540     template<class Iter>
00541     double decrement (Iter begin, Iter end)
00542     {
00543         for(Iter iter = begin; iter != end; ++iter)
00544         {
00545             --count_;
00546             for(int ii = 0; ii < mean_.size(); ++ii)
00547                 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
00548             double f  = 1.0 / count_,
00549                    f1 = 1.0 + f;
00550             for(int ii = 0; ii < mean_.size(); ++ii)
00551                 mean_[ii] -= f*tmp_[ii]; 
00552             for(int ii = 0; ii < mean_.size(); ++ii)
00553                 variance_[ii] -= f1*sq(tmp_[ii]);
00554         }
00555         return std::accumulate(variance_.begin(), 
00556                                variance_.end(),
00557                                0.0,
00558                                std::plus<double>())
00559                 /(count_ -1);
00560     }
00561 
00562     template<class Iter, class Resp_t>
00563     double init (Iter begin, Iter end, Resp_t resp)
00564     {
00565         reset();
00566         return increment(begin, end);
00567     }
00568     
00569 
00570     ArrayVector<double> const & response()
00571     {
00572         return mean_;
00573     }
00574 
00575     void reset()
00576     {
00577         mean_.init(0.0);
00578         variance_.init(0.0);
00579         count_ = 0; 
00580     }
00581 };
00582 
00583 template<class Tag, class Datatyp>
00584 struct LossTraits;
00585 
00586 struct LSQLoss
00587 {};
00588 
00589 template<class Datatype>
00590 struct LossTraits<GiniCriterion, Datatype>
00591 {
00592     typedef ImpurityLoss<Datatype, GiniCriterion> type;
00593 };
00594 
00595 template<class Datatype>
00596 struct LossTraits<LSQLoss, Datatype>
00597 {
00598     typedef RegressionForestCounter<Datatype> type;
00599 };
00600 
00601 template<class LineSearchLossTag>
00602 class BestGiniOfColumn
00603 {
00604 public:
00605     ArrayVector<double>     class_weights_;
00606     ArrayVector<double>     bestCurrentCounts[2];
00607     double                  min_gini_;
00608     ptrdiff_t               min_index_;
00609     double                  min_threshold_;
00610     ProblemSpec<>           ext_param_;
00611 
00612     BestGiniOfColumn()
00613     {}
00614 
00615     template<class T> 
00616     BestGiniOfColumn(ProblemSpec<T> const & ext)
00617     :
00618         class_weights_(ext.class_weights_),
00619         ext_param_(ext)
00620     {
00621         bestCurrentCounts[0].resize(ext.class_count_);
00622         bestCurrentCounts[1].resize(ext.class_count_);
00623     }
00624     template<class T> 
00625     void set_external_parameters(ProblemSpec<T> const & ext)
00626     {
00627         class_weights_ = ext.class_weights_; 
00628         ext_param_ = ext;
00629         bestCurrentCounts[0].resize(ext.class_count_);
00630         bestCurrentCounts[1].resize(ext.class_count_);
00631     }
00632     /** calculate the best gini split along a Feature Column
00633      * \param column, the feature vector - has to support the [] operator
00634      * \param labels, the label vector 
00635      * \param begin 
00636      * \param end     (in and out)
00637      *                begin and end iterators to the indices of the
00638      *                samples in the current region. 
00639      *                the range begin - end is sorted by the column supplied
00640      *                during function execution.
00641      * \param class_counts
00642      *                class histogram of the range. 
00643      *
00644      *  precondition: begin, end valid range, 
00645      *                class_counts positive integer valued array with the 
00646      *                class counts in the current range.
00647      *                labels.size() >= max(begin, end); 
00648      *  postcondition:
00649      *                begin, end sorted by column given. 
00650      *                min_gini_ contains the minimum gini found or 
00651      *                NumericTraits<double>::max if no split was found.
00652      *                min_index_ countains the splitting index in the range
00653      *                or invalid data if no split was found.
00654      *                BestCirremtcounts[0] and [1] contain the 
00655      *                class histogram of the left and right region of 
00656      *                the left and right regions. 
00657      */
00658     template<   class DataSourceF_t,
00659                 class DataSource_t, 
00660                 class I_Iter, 
00661                 class Array>
00662     void operator()(DataSourceF_t   const & column,
00663                     DataSource_t    const & labels,
00664                     I_Iter                & begin, 
00665                     I_Iter                & end,
00666                     Array           const & region_response)
00667     {
00668         std::sort(begin, end, 
00669                   SortSamplesByDimensions<DataSourceF_t>(column, 0));
00670         typedef typename 
00671             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
00672         LineSearchLoss left(labels, ext_param_);
00673         LineSearchLoss right(labels, ext_param_);
00674 
00675         
00676 
00677         min_gini_ = right.init(begin, end, region_response);
00678         min_threshold_ = *begin;
00679         min_index_     = 0;
00680         DimensionNotEqual<DataSourceF_t> comp(column, 0); 
00681         
00682         I_Iter iter = begin;
00683         I_Iter next = std::adjacent_find(iter, end, comp);
00684         while( next  != end)
00685         {
00686 
00687             double loss = right.decrement(iter, next + 1) 
00688                     +     left.increment(iter , next + 1);
00689             if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_))
00690             {
00691                 bestCurrentCounts[0] = left.response();
00692                 bestCurrentCounts[1] = right.response();
00693                 min_gini_       = loss < min_gini_? loss : min_gini_;
00694                 min_index_      = next - begin +1 ;
00695                 min_threshold_  = (column[*next] + column[*(next +1)])/2;
00696             }
00697             iter = next +1 ;
00698             next = std::adjacent_find(iter, end, comp);
00699         }
00700 
00701     }
00702 
00703     template<class DataSource_t, class Iter, class Array>
00704     double loss_of_region(DataSource_t const & labels,
00705                           Iter & begin, 
00706                           Iter & end, 
00707                           Array const & region_response) const
00708     {
00709         typedef typename 
00710             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
00711         LineSearchLoss region_loss(labels, ext_param_);
00712         return 
00713             region_loss.init(begin, end, region_response);
00714     }
00715 
00716 };
00717 
00718 template<class ColumnDecisionFunctor, class Tag>
00719 class ThresholdSplit: public SplitBase<Tag>
00720 {
00721   public:
00722 
00723 
00724     typedef SplitBase<Tag> SB;
00725     
00726     ArrayVector<Int32>          splitColumns;
00727     ColumnDecisionFunctor       bgfunc;
00728 
00729     double                      region_gini_;
00730     ArrayVector<double>         min_gini_;
00731     ArrayVector<ptrdiff_t>      min_indices_;
00732     ArrayVector<double>         min_thresholds_;
00733 
00734     int                         bestSplitIndex;
00735 
00736     double minGini() const
00737     {
00738         return min_gini_[bestSplitIndex];
00739     }
00740     int bestSplitColumn() const
00741     {
00742         return splitColumns[bestSplitIndex];
00743     }
00744     double bestSplitThreshold() const
00745     {
00746         return min_thresholds_[bestSplitIndex];
00747     }
00748 
00749     template<class T>
00750     void set_external_parameters(ProblemSpec<T> const & in)
00751     {
00752         SB::set_external_parameters(in);        
00753         bgfunc.set_external_parameters( SB::ext_param_);
00754         int featureCount_ = SB::ext_param_.column_count_;
00755         splitColumns.resize(featureCount_);
00756         for(int k=0; k<featureCount_; ++k)
00757             splitColumns[k] = k;
00758         min_gini_.resize(featureCount_);
00759         min_indices_.resize(featureCount_);
00760         min_thresholds_.resize(featureCount_);
00761     }
00762 
00763 
00764     template<class T, class C, class T2, class C2, class Region, class Random>
00765     int findBestSplit(MultiArrayView<2, T, C> features,
00766                       MultiArrayView<2, T2, C2>  labels,
00767                       Region & region,
00768                       ArrayVector<Region>& childRegions,
00769                       Random & randint)
00770     {
00771 
00772         typedef typename Region::IndexIterator IndexIterator;
00773         if(region.size() == 0)
00774         {
00775            std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
00776                         "continuing learning process...."; 
00777         }
00778         // calculate things that haven't been calculated yet. 
00779         
00780         if(std::accumulate(region.classCounts().begin(),
00781                            region.classCounts().end(), 0) != region.size())
00782         {
00783             RandomForestClassCounter<   MultiArrayView<2,T2, C2>, 
00784                                         ArrayVector<Int32> >
00785                 counter(labels, region.classCounts());
00786             std::for_each(  region.begin(), region.end(), counter);
00787             region.classCountsIsValid = true;
00788         }
00789 
00790         // Is the region pure already?
00791         region_gini_ = bgfunc.loss_of_region(labels,
00792                                              region.begin(), 
00793                                              region.end(),
00794                                              region.classCounts());
00795         if(region_gini_ <= SB::ext_param_.precision_)
00796             return  makeTerminalNode(features, labels, region, randint);
00797 
00798         // select columns  to be tried.
00799         for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
00800             std::swap(splitColumns[ii], 
00801                       splitColumns[ii+ randint(features.shape(1) - ii)]);
00802 
00803         // find the best gini index
00804         bestSplitIndex              = 0;
00805         double  current_min_gini    = region_gini_;
00806         int     num2try             = features.shape(1);
00807         for(int k=0; k<num2try; ++k)
00808         {
00809             //this functor does all the work
00810             bgfunc(columnVector(features, splitColumns[k]),
00811                    labels, 
00812                    region.begin(), region.end(), 
00813                    region.classCounts());
00814             min_gini_[k]            = bgfunc.min_gini_; 
00815             min_indices_[k]         = bgfunc.min_index_;
00816             min_thresholds_[k]      = bgfunc.min_threshold_;
00817 
00818             if(     bgfunc.min_gini_ < current_min_gini
00819                &&  !closeAtTolerance(bgfunc.min_gini_, current_min_gini))
00820             {
00821                 current_min_gini = bgfunc.min_gini_;
00822                 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
00823                 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
00824                 childRegions[0].classCountsIsValid = true;
00825                 childRegions[1].classCountsIsValid = true;
00826 
00827                 bestSplitIndex   = k;
00828                 num2try = SB::ext_param_.actual_mtry_;
00829             }
00830         }
00831 
00832         // did not find any suitable split
00833         if(closeAtTolerance(current_min_gini, region_gini_))
00834             return  makeTerminalNode(features, labels, region, randint);
00835         
00836         //create a Node for output
00837         Node<i_ThresholdNode>   node(SB::t_data, SB::p_data);
00838         SB::node_ = node;
00839         node.threshold()    = min_thresholds_[bestSplitIndex];
00840         node.column()       = splitColumns[bestSplitIndex];
00841         
00842         // partition the range according to the best dimension 
00843         SortSamplesByDimensions<MultiArrayView<2, T, C> > 
00844             sorter(features, node.column(), node.threshold());
00845         IndexIterator bestSplit =
00846             std::partition(region.begin(), region.end(), sorter);
00847         // Save the ranges of the child stack entries.
00848         childRegions[0].setRange(   region.begin()  , bestSplit       );
00849         childRegions[0].rule = region.rule;
00850         childRegions[0].rule.push_back(std::make_pair(1, 1.0));
00851         childRegions[1].setRange(   bestSplit       , region.end()    );
00852         childRegions[1].rule = region.rule;
00853         childRegions[1].rule.push_back(std::make_pair(1, 1.0));
00854 
00855         return i_ThresholdNode;
00856     }
00857 };
00858 
00859 typedef  ThresholdSplit<BestGiniOfColumn<GiniCriterion> >                 GiniSplit;
00860 typedef  ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag>         RegressionSplit;
00861 
00862 } //namespace vigra
00863 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.0 (Thu Aug 25 2011)