[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_preprocessing.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 #ifndef VIGRA_RF_PREPROCESSING_HXX
00037 #define VIGRA_RF_PREPROCESSING_HXX
00038 
00039 
00040 #include "rf_common.hxx"
00041 
00042 namespace vigra
00043 {
00044 
/** Preprocessor used by the Random Forest learn function (currently the
 * only place it is used).
 *
 * Different split functors may need the data prepared in different ways,
 * e.g. regression labels must be left untouched, whereas classification
 * labels must be converted into an integral format.
 *
 * The primary template is only declared here; only specialisations for a
 * concrete Tag exist. The Tag is determined by SplitFunctor::Preprocessor_t
 * and is currently either ClassificationTag or RegressionTag. The
 * RegressionTag specialisation shows the basic interface a new
 * preprocessor (e.g. one for soft labels) has to provide.
 */
template<typename Tag, typename LabelType,
         typename T1, typename C1,
         typename T2, typename C2>
class Processor;
00063 
00064 namespace detail
00065 {
00066 
00067     /** Common helper function used in all Processors. 
00068      * This function analyses the options struct and calculates the real 
00069      * values needed for the current problem (data)
00070      */
00071     template<class T>
00072     void fill_external_parameters(RF_Traits::Options_t & options,
00073                                   ProblemSpec<T> & ext_param)
00074     {
00075         // set correct value for mtry
00076         switch(options.mtry_switch_)
00077         {
00078             case RF_SQRT:
00079                 ext_param.actual_mtry_ =
00080                     int(std::floor(
00081                             std::sqrt(double(ext_param.column_count_))
00082                             + 0.5));
00083                 break;
00084             case RF_LOG:
00085                 // this is in Breimans original paper
00086                 ext_param.actual_mtry_ =
00087                     int(1+(std::log(double(ext_param.column_count_))
00088                            /std::log(2.0)));
00089                 break;
00090             case RF_FUNCTION:
00091                 ext_param.actual_mtry_ =
00092                     options.mtry_func_(ext_param.column_count_);
00093                 break;
00094             case RF_ALL:
00095                 ext_param.actual_mtry_ = ext_param.column_count_;
00096                 break;
00097             default:
00098                 ext_param.actual_mtry_ =
00099                     options.mtry_;
00100         }
00101         // set correct value for msample
00102         switch(options.training_set_calc_switch_)
00103         {
00104             case RF_CONST:
00105                 ext_param.actual_msample_ =
00106                     options.training_set_size_;
00107                 break;
00108             case RF_PROPORTIONAL:
00109                 ext_param.actual_msample_ =
00110                     (int)std::ceil(  options.training_set_proportion_ *
00111                                      ext_param.row_count_);
00112                     break;
00113             case RF_FUNCTION:
00114                 ext_param.actual_msample_ =
00115                     options.training_set_func_(ext_param.row_count_);
00116                 break;
00117             default:
00118                 vigra_precondition(1!= 1, "unexpected error");
00119 
00120         }
00121 
00122     }
00123 }
00124 
00125 
00126 
00127 /** Preprocessor used during Classification
00128  *
00129  * This class converts the labels int Integral labels which are used by the 
00130  * standard split functor to address memory in the node objects.
00131  */
00132 template<class LabelType, class T1, class C1, class T2, class C2>
00133 class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
00134 {
00135     public:
00136     typedef Int32 LabelInt;
00137     typedef MultiArrayView<2, T1, C1> Feature_t;
00138     typedef MultiArrayView<2,LabelInt> Label_t;
00139     MultiArrayView<2, T1, C1>const &    features_;
00140     MultiArray<2, LabelInt>             intLabels_;
00141     MultiArrayView<2, LabelInt>         strata_;
00142 
00143     template<class T>
00144     Processor(MultiArrayView<2, T1, C1>const & features,   
00145               MultiArrayView<2, T2, C2>const & response,
00146               RF_Traits::Options_t &options,         
00147               ProblemSpec<T> &ext_param)
00148     :
00149         features_( features) // do not touch the features. 
00150     {
00151         // set some of the problem specific parameters 
00152         ext_param.column_count_  = features.shape(1);
00153         ext_param.row_count_     = features.shape(0);
00154         ext_param.problem_type_  = CLASSIFICATION;
00155         ext_param.used_          = true;
00156         intLabels_.reshape(response.shape());
00157 
00158         //get the class labels
00159         if(ext_param.class_count_ == 0)
00160         {
00161             // fill up a map with the current labels and then create the 
00162             // integral labels.
00163             std::set<T2>                    labelToInt;
00164             for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
00165                 labelToInt.insert(response(k,0));
00166             std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
00167             ext_param.classes_(tmp_.begin(), tmp_.end());
00168         }
00169         for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
00170         {
00171             if(std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0)) == ext_param.classes.end())
00172             {
00173                 throw std::runtime_error("unknown label type");
00174             }
00175             else
00176                 intLabels_(k, 0) = std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0))
00177                                     - ext_param.classes.begin();
00178         }
00179         // set class weights
00180         if(ext_param.class_weights_.size() == 0)
00181         {
00182             ArrayVector<T2> 
00183                 tmp((std::size_t)ext_param.class_count_, 
00184                     NumericTraits<T2>::one());
00185             ext_param.class_weights(tmp.begin(), tmp.end());
00186         }
00187 
00188         // set mtry and msample
00189         detail::fill_external_parameters(options, ext_param);
00190 
00191         // set strata
00192         strata_ = intLabels_;
00193 
00194     }
00195 
00196     /** Access the processed features
00197      */
00198     MultiArrayView<2, T1, C1>const & features()
00199     {
00200         return features_;
00201     }
00202 
00203     /** Access processed labels
00204      */
00205     MultiArrayView<2, LabelInt>& response()
00206     {
00207         return intLabels_;
00208     }
00209 
00210     /** Access processed strata
00211      */
00212     MultiArrayView<2, LabelInt>&  strata()
00213     {
00214         return intLabels_;
00215     }
00216 
00217     /** Access strata fraction sized - not used currently
00218      */
00219     ArrayVectorView< double> strata_prob()
00220     {
00221         return ArrayVectorView< double>();
00222     }
00223 };
00224 
00225 
00226 
00227 /** Regression Preprocessor - This basically does not do anything with the
00228  * data.
00229  */
00230 template<class LabelType, class T1, class C1, class T2, class C2>
00231 class Processor<RegressionTag,LabelType, T1, C1, T2, C2>
00232 {
00233 public:
00234     // only views are created - no data copied.
00235     MultiArrayView<2, T1, C1>   features_;
00236     MultiArrayView<2, T2, C2>   response_;
00237     RF_Traits::Options_t const & options_;
00238     ProblemSpec<LabelType> const &
00239                                 ext_param_;
00240     // will only be filled if needed
00241     MultiArray<2, int>      strata_;
00242     bool strata_filled;
00243 
00244     // copy the views.
00245     template<class T>
00246     Processor(  MultiArrayView<2, T1, C1>   feats,
00247                 MultiArrayView<2, T2, C2>   response,
00248                 RF_Traits::Options_t            options,
00249                 ProblemSpec<T>  ext_param)
00250     :
00251         features_(feats),
00252         response_(response),
00253         options_(options),
00254         ext_param_(ext_param)
00255     {
00256         detail::fill_external_parameters(options, ext_param);
00257         strata_ = MultiArray<2, int> (MultiArrayShape<2>::type(response_.shape(0), 1));
00258     }
00259 
00260     /** access preprocessed features
00261      */
00262     MultiArrayView<2, T1, C1> & features()
00263     {
00264         return features_;
00265     }
00266 
00267     /** access preprocessed response
00268      */
00269     MultiArrayView<2, T2, C2> & response()
00270     {
00271         return response_;
00272     }
00273 
00274     /** acess strata - this is not used currently
00275      */
00276     MultiArrayView<2, int> & strata()
00277     {
00278         return strata_;
00279     }
00280 };
00281 }
00282 #endif //VIGRA_RF_PREPROCESSING_HXX
00283 
00284 
00285 

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.0 (Thu Aug 25 2011)