[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. 
*/ 00033 /* */ 00034 /************************************************************************/ 00035 00036 #ifndef VIGRA_RF_PREPROCESSING_HXX 00037 #define VIGRA_RF_PREPROCESSING_HXX 00038 00039 00040 #include "rf_common.hxx" 00041 00042 namespace vigra 00043 { 00044 00045 /** Class used while preprocessing (currently used only during learn) 00046 * 00047 * This class is internally used by the Random Forest learn function. 00048 * Different split functors may need to process the data in different manners 00049 * (i.e., regression labels that should not be touched and classification 00050 * labels that must be converted into an integral format) 00051 * 00052 * This class only exists in specialized versions, where the Tag class is 00053 * fixed. 00054 * 00055 * The Tag class is determined by Splitfunctor::Preprocessor_t. Currently 00056 * it can either be ClassificationTag or RegressionTag. Look at the 00057 * RegressionTag specialisation for the basic interface if you ever happen 00058 * to care - or need some sort of new preprocessor 00059 * (soft labels or whatever). 00060 */ 00061 template<class Tag, class LabelType, class T1, class C1, class T2, class C2> 00062 class Processor; 00063 00064 namespace detail 00065 { 00066 00067 /** Common helper function used in all Processors. 
00068 * This function analyses the options struct and calculates the real 00069 * values needed for the current problem (data) 00070 */ 00071 template<class T> 00072 void fill_external_parameters(RF_Traits::Options_t & options, 00073 ProblemSpec<T> & ext_param) 00074 { 00075 // set correct value for mtry 00076 switch(options.mtry_switch_) 00077 { 00078 case RF_SQRT: 00079 ext_param.actual_mtry_ = 00080 int(std::floor( 00081 std::sqrt(double(ext_param.column_count_)) 00082 + 0.5)); 00083 break; 00084 case RF_LOG: 00085 // this is in Breimans original paper 00086 ext_param.actual_mtry_ = 00087 int(1+(std::log(double(ext_param.column_count_)) 00088 /std::log(2.0))); 00089 break; 00090 case RF_FUNCTION: 00091 ext_param.actual_mtry_ = 00092 options.mtry_func_(ext_param.column_count_); 00093 break; 00094 case RF_ALL: 00095 ext_param.actual_mtry_ = ext_param.column_count_; 00096 break; 00097 default: 00098 ext_param.actual_mtry_ = 00099 options.mtry_; 00100 } 00101 // set correct value for msample 00102 switch(options.training_set_calc_switch_) 00103 { 00104 case RF_CONST: 00105 ext_param.actual_msample_ = 00106 options.training_set_size_; 00107 break; 00108 case RF_PROPORTIONAL: 00109 ext_param.actual_msample_ = 00110 (int)std::ceil( options.training_set_proportion_ * 00111 ext_param.row_count_); 00112 break; 00113 case RF_FUNCTION: 00114 ext_param.actual_msample_ = 00115 options.training_set_func_(ext_param.row_count_); 00116 break; 00117 default: 00118 vigra_precondition(1!= 1, "unexpected error"); 00119 00120 } 00121 00122 } 00123 } 00124 00125 00126 00127 /** Preprocessor used during Classification 00128 * 00129 * This class converts the labels int Integral labels which are used by the 00130 * standard split functor to address memory in the node objects. 
00131 */ 00132 template<class LabelType, class T1, class C1, class T2, class C2> 00133 class Processor<ClassificationTag, LabelType, T1, C1, T2, C2> 00134 { 00135 public: 00136 typedef Int32 LabelInt; 00137 typedef MultiArrayView<2, T1, C1> Feature_t; 00138 typedef MultiArrayView<2,LabelInt> Label_t; 00139 MultiArrayView<2, T1, C1>const & features_; 00140 MultiArray<2, LabelInt> intLabels_; 00141 MultiArrayView<2, LabelInt> strata_; 00142 00143 template<class T> 00144 Processor(MultiArrayView<2, T1, C1>const & features, 00145 MultiArrayView<2, T2, C2>const & response, 00146 RF_Traits::Options_t &options, 00147 ProblemSpec<T> &ext_param) 00148 : 00149 features_( features) // do not touch the features. 00150 { 00151 // set some of the problem specific parameters 00152 ext_param.column_count_ = features.shape(1); 00153 ext_param.row_count_ = features.shape(0); 00154 ext_param.problem_type_ = CLASSIFICATION; 00155 ext_param.used_ = true; 00156 intLabels_.reshape(response.shape()); 00157 00158 //get the class labels 00159 if(ext_param.class_count_ == 0) 00160 { 00161 // fill up a map with the current labels and then create the 00162 // integral labels. 
00163 std::set<T2> labelToInt; 00164 for(MultiArrayIndex k = 0; k < features.shape(0); ++k) 00165 labelToInt.insert(response(k,0)); 00166 std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end()); 00167 ext_param.classes_(tmp_.begin(), tmp_.end()); 00168 } 00169 for(MultiArrayIndex k = 0; k < features.shape(0); ++k) 00170 { 00171 if(std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0)) == ext_param.classes.end()) 00172 { 00173 throw std::runtime_error("unknown label type"); 00174 } 00175 else 00176 intLabels_(k, 0) = std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0)) 00177 - ext_param.classes.begin(); 00178 } 00179 // set class weights 00180 if(ext_param.class_weights_.size() == 0) 00181 { 00182 ArrayVector<T2> 00183 tmp((std::size_t)ext_param.class_count_, 00184 NumericTraits<T2>::one()); 00185 ext_param.class_weights(tmp.begin(), tmp.end()); 00186 } 00187 00188 // set mtry and msample 00189 detail::fill_external_parameters(options, ext_param); 00190 00191 // set strata 00192 strata_ = intLabels_; 00193 00194 } 00195 00196 /** Access the processed features 00197 */ 00198 MultiArrayView<2, T1, C1>const & features() 00199 { 00200 return features_; 00201 } 00202 00203 /** Access processed labels 00204 */ 00205 MultiArrayView<2, LabelInt>& response() 00206 { 00207 return intLabels_; 00208 } 00209 00210 /** Access processed strata 00211 */ 00212 MultiArrayView<2, LabelInt>& strata() 00213 { 00214 return intLabels_; 00215 } 00216 00217 /** Access strata fraction sized - not used currently 00218 */ 00219 ArrayVectorView< double> strata_prob() 00220 { 00221 return ArrayVectorView< double>(); 00222 } 00223 }; 00224 00225 00226 00227 /** Regression Preprocessor - This basically does not do anything with the 00228 * data. 
00229 */ 00230 template<class LabelType, class T1, class C1, class T2, class C2> 00231 class Processor<RegressionTag,LabelType, T1, C1, T2, C2> 00232 { 00233 public: 00234 // only views are created - no data copied. 00235 MultiArrayView<2, T1, C1> features_; 00236 MultiArrayView<2, T2, C2> response_; 00237 RF_Traits::Options_t const & options_; 00238 ProblemSpec<LabelType> const & 00239 ext_param_; 00240 // will only be filled if needed 00241 MultiArray<2, int> strata_; 00242 bool strata_filled; 00243 00244 // copy the views. 00245 template<class T> 00246 Processor( MultiArrayView<2, T1, C1> feats, 00247 MultiArrayView<2, T2, C2> response, 00248 RF_Traits::Options_t options, 00249 ProblemSpec<T> ext_param) 00250 : 00251 features_(feats), 00252 response_(response), 00253 options_(options), 00254 ext_param_(ext_param) 00255 { 00256 detail::fill_external_parameters(options, ext_param); 00257 strata_ = MultiArray<2, int> (MultiArrayShape<2>::type(response_.shape(0), 1)); 00258 } 00259 00260 /** access preprocessed features 00261 */ 00262 MultiArrayView<2, T1, C1> & features() 00263 { 00264 return features_; 00265 } 00266 00267 /** access preprocessed response 00268 */ 00269 MultiArrayView<2, T2, C2> & response() 00270 { 00271 return response_; 00272 } 00273 00274 /** acess strata - this is not used currently 00275 */ 00276 MultiArrayView<2, int> & strata() 00277 { 00278 return strata_; 00279 } 00280 }; 00281 } 00282 #endif //VIGRA_RF_PREPROCESSING_HXX 00283 00284 00285
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|