[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2009 by Rahul Nair and Ullrich Koethe */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. */ 00033 /* */ 00034 /************************************************************************/ 00035 00036 00037 #ifndef VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX 00038 #define VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX 00039 00040 #include "random_forest.hxx" 00041 #include "hdf5impex.hxx" 00042 #include <cstdio> 00043 #include <string> 00044 00045 #ifdef HasHDF5 00046 00047 namespace vigra 00048 { 00049 00050 namespace detail 00051 { 00052 00053 00054 /** shallow search the hdf5 group for containing elements 00055 * returns negative value if unsuccessful 00056 * \param grp_id hid_t containing path to group. 00057 * \param cont reference to container that supports 00058 * insert(). valuetype of cont must be 00059 * std::string 00060 */ 00061 template<class Container> 00062 bool find_groups_hdf5(hid_t grp_id, Container &cont) 00063 { 00064 00065 //get group info 00066 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6) 00067 hsize_t size; 00068 H5Gget_num_objs(grp_id, &size); 00069 #else 00070 hsize_t size; 00071 H5G_info_t ginfo; 00072 herr_t status; 00073 status = H5Gget_info (grp_id , &ginfo); 00074 if(status < 0) 00075 std::runtime_error("find_groups_hdf5():" 00076 "problem while getting group info"); 00077 size = ginfo.nlinks; 00078 #endif 00079 for(hsize_t ii = 0; ii < size; ++ii) 00080 { 00081 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6) 00082 ssize_t buffer_size = 00083 H5Gget_objname_by_idx(grp_id, 00084 ii, NULL, 0 ) + 1; 00085 #else 00086 std::ptrdiff_t buffer_size = 00087 H5Lget_name_by_idx(grp_id, ".", 00088 H5_INDEX_NAME, 00089 H5_ITER_INC, 00090 ii, 0, 0, H5P_DEFAULT)+1; 00091 #endif 00092 ArrayVector<char> buffer(buffer_size); 00093 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6) 00094 buffer_size = 00095 H5Gget_objname_by_idx(grp_id, 00096 ii, buffer.data(), 00097 (size_t)buffer_size ); 00098 #else 00099 buffer_size = 00100 H5Lget_name_by_idx(grp_id, ".", 00101 H5_INDEX_NAME, 00102 H5_ITER_INC, 00103 ii, buffer.data(), 00104 (size_t)buffer_size, 00105 H5P_DEFAULT); 00106 #endif 00107 cont.insert(cont.end(), std::string(buffer.data())); 00108 } 00109 return true; 00110 } 00111 00112 00113 /** shallow search the hdf5 group for containing elements 00114 * returns negative value if unsuccessful 00115 * \param filename name of hdf5 file 00116 * \param groupname path in hdf5 file 00117 * \param cont reference to container that supports 00118 * insert(). valuetype of cont must be 00119 * std::string 00120 */ 00121 template<class Container> 00122 bool find_groups_hdf5(std::string filename, 00123 std::string groupname, 00124 Container &cont) 00125 { 00126 //check if file exists 00127 FILE* pFile; 00128 pFile = fopen ( filename.c_str(), "r" ); 00129 if ( pFile == NULL) 00130 { 00131 return 0; 00132 } 00133 //open the file 00134 HDF5Handle file_id(H5Fopen(filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT), 00135 &H5Fclose, "Unable to open HDF5 file"); 00136 HDF5Handle grp_id; 00137 if(groupname == "") 00138 { 00139 grp_id = HDF5Handle(file_id, 0, ""); 00140 } 00141 else 00142 { 00143 grp_id = HDF5Handle(H5Gopen(file_id, groupname.c_str(), H5P_DEFAULT), 00144 &H5Gclose, "Unable to open group"); 00145 00146 } 00147 bool res = find_groups_hdf5(grp_id, cont); 00148 return res; 00149 } 00150 00151 inline int get_number_of_digits(int in) 00152 { 00153 int num = 0; 00154 int i = 1; 00155 while(double(in) / double(i) >= 1) 00156 { 00157 i *= 10; 00158 num += 1; 00159 } 00160 if(num == 0) 00161 num = 1; 00162 return num; 00163 } 00164 00165 inline std::string make_padded_number(int number, int max_number) 00166 { 00167 int max_digit_ct = get_number_of_digits(max_number); 00168 char buffer [50]; 00169 std::sprintf(buffer, "%d", number); 00170 std::string padding = ""; 00171 std::string numeral = buffer; 00172 int digit_ct = get_number_of_digits(number); 00173 for(int gg = 0; gg < max_digit_ct - digit_ct; ++ gg) 00174 padding = padding + "0"; 00175 return padding + numeral; 00176 } 00177 00178 /** write a ArrayVector to a hdf5 dataset. 00179 */ 00180 template<class U, class T> 00181 void write_array_2_hdf5(hid_t & id, 00182 ArrayVector<U> const & arr, 00183 std::string const & name, 00184 T type) 00185 { 00186 hsize_t size = arr.size(); 00187 vigra_postcondition(H5LTmake_dataset (id, 00188 name.c_str(), 00189 1, 00190 &size, 00191 type, 00192 arr.begin()) 00193 >= 0, 00194 "write_array_2_hdf5():" 00195 "unable to write dataset"); 00196 } 00197 00198 00199 template<class U, class T> 00200 void write_hdf5_2_array(hid_t & id, 00201 ArrayVector<U> & arr, 00202 std::string const & name, 00203 T type) 00204 { 00205 // The last three values of get_dataset_info can be NULL 00206 // my EFFING FOOT! that is valid for HDF5 1.8 but not for 00207 // 1.6 - but documented the other way around AAARRHGHGHH 00208 hsize_t size; 00209 H5T_class_t a; 00210 size_t b; 00211 vigra_postcondition(H5LTget_dataset_info(id, 00212 name.c_str(), 00213 &size, 00214 &a, 00215 &b) >= 0, 00216 "write_hdf5_2_array(): " 00217 "Unable to locate dataset"); 00218 arr.resize((typename ArrayVector<U>::size_type)size); 00219 vigra_postcondition(H5LTread_dataset (id, 00220 name.c_str(), 00221 type, 00222 arr.data()) >= 0, 00223 "write_array_2_hdf5():" 00224 "unable to read dataset"); 00225 } 00226 00227 00228 inline void options_import_HDF5(hid_t & group_id, 00229 RandomForestOptions & opt, 00230 std::string name) 00231 { 00232 ArrayVector<double> serialized_options; 00233 write_hdf5_2_array(group_id, serialized_options, 00234 name, H5T_NATIVE_DOUBLE); 00235 opt.unserialize(serialized_options.begin(), 00236 serialized_options.end()); 00237 } 00238 00239 inline void options_export_HDF5(hid_t & group_id, 00240 RandomForestOptions const & opt, 00241 std::string name) 00242 { 00243 ArrayVector<double> serialized_options(opt.serialized_size()); 00244 opt.serialize(serialized_options.begin(), 00245 serialized_options.end()); 00246 write_array_2_hdf5(group_id, serialized_options, 00247 name, H5T_NATIVE_DOUBLE); 00248 } 00249 00250 struct MyT 00251 { 00252 enum type { INT8 = 1, INT16 = 2, INT32 =3, INT64=4, 00253 UINT8 = 5, UINT16 = 6, UINT32= 7, UINT64= 8, 00254 FLOAT = 9, DOUBLE = 10, OTHER = 3294}; 00255 }; 00256 00257 00258 00259 #define create_type_of(TYPE, ENUM) \ 00260 inline MyT::type type_of(TYPE)\ 00261 {\ 00262 return MyT::ENUM; \ 00263 } 00264 create_type_of(Int8, INT8) 00265 create_type_of(Int16, INT16) 00266 create_type_of(Int32, INT32) 00267 create_type_of(Int64, INT64) 00268 create_type_of(UInt8, UINT8) 00269 create_type_of(UInt16, UINT16) 00270 create_type_of(UInt32, UINT32) 00271 create_type_of(UInt64, UINT64) 00272 create_type_of(float, FLOAT) 00273 create_type_of(double, DOUBLE) 00274 #undef create_type_of 00275 00276 inline MyT::type type_of_hid_t(hid_t group_id, std::string name) 00277 { 00278 hid_t m_dataset_handle = 00279 H5Dopen(group_id, name.c_str(), H5P_DEFAULT); 00280 hid_t datatype = H5Dget_type(m_dataset_handle); 00281 H5T_class_t dataclass = H5Tget_class(datatype); 00282 size_t datasize = H5Tget_size(datatype); 00283 H5T_sign_t datasign = H5Tget_sign(datatype); 00284 MyT::type result = MyT::OTHER; 00285 if(dataclass == H5T_FLOAT) 00286 { 00287 if(datasize == 4) 00288 result = MyT::FLOAT; 00289 else if(datasize == 8) 00290 result = MyT::DOUBLE; 00291 } 00292 else if(dataclass == H5T_INTEGER) 00293 { 00294 if(datasign == H5T_SGN_NONE) 00295 { 00296 if(datasize == 1) 00297 result = MyT::UINT8; 00298 else if(datasize == 2) 00299 result = MyT::UINT16; 00300 else if(datasize == 4) 00301 result = MyT::UINT32; 00302 else if(datasize == 8) 00303 result = MyT::UINT64; 00304 } 00305 else 00306 { 00307 if(datasize == 1) 00308 result = MyT::INT8; 00309 else if(datasize == 2) 00310 result = MyT::INT16; 00311 else if(datasize == 4) 00312 result = MyT::INT32; 00313 else if(datasize == 8) 00314 result = MyT::INT64; 00315 } 00316 } 00317 H5Tclose(datatype); 00318 H5Dclose(m_dataset_handle); 00319 return result; 00320 } 00321 00322 template<class T> 00323 void problemspec_import_HDF5(hid_t & group_id, 00324 ProblemSpec<T> & param, 00325 std::string name) 00326 { 00327 hid_t param_id = H5Gopen (group_id, 00328 name.c_str(), 00329 H5P_DEFAULT); 00330 00331 vigra_postcondition(param_id >= 0, 00332 "problemspec_import_HDF5():" 00333 " Unable to open external parameters"); 00334 00335 //get a map containing all the double fields 00336 std::set<std::string> ext_set; 00337 find_groups_hdf5(param_id, ext_set); 00338 std::map<std::string, ArrayVector <double> > ext_map; 00339 std::set<std::string>::iterator iter; 00340 if(ext_set.find(std::string("labels")) == ext_set.end()) 00341 std::runtime_error("labels are missing"); 00342 for(iter = ext_set.begin(); iter != ext_set.end(); ++ iter) 00343 { 00344 if(*iter != std::string("labels")) 00345 { 00346 ext_map[*iter] = ArrayVector<double>(); 00347 write_hdf5_2_array(param_id, ext_map[*iter], 00348 *iter, H5T_NATIVE_DOUBLE); 00349 } 00350 } 00351 param.make_from_map(ext_map); 00352 //load_class_labels 00353 switch(type_of_hid_t(param_id,"labels" )) 00354 { 00355 #define SOME_CASE(type_, enum_) \ 00356 case MyT::enum_ :\ 00357 {\ 00358 ArrayVector<type_> tmp;\ 00359 write_hdf5_2_array(param_id, tmp, "labels", H5T_NATIVE_##enum_);\ 00360 param.classes_(tmp.begin(), tmp.end());\ 00361 }\ 00362 break; 00363 SOME_CASE(UInt8, UINT8); 00364 SOME_CASE(UInt16, UINT16); 00365 SOME_CASE(UInt32, UINT32); 00366 SOME_CASE(UInt64, UINT64); 00367 SOME_CASE(Int8, INT8); 00368 SOME_CASE(Int16, INT16); 00369 SOME_CASE(Int32, INT32); 00370 SOME_CASE(Int64, INT64); 00371 SOME_CASE(double, DOUBLE); 00372 SOME_CASE(float, FLOAT); 00373 default: 00374 std::runtime_error("exportRF_HDF5(): unknown class type"); 00375 #undef SOME_CASE 00376 } 00377 H5Gclose(param_id); 00378 } 00379 00380 template<class T> 00381 void problemspec_export_HDF5(hid_t & group_id, 00382 ProblemSpec<T> const & param, 00383 std::string name) 00384 { 00385 hid_t param_id = H5Gcreate(group_id, name.c_str(), 00386 H5P_DEFAULT, 00387 H5P_DEFAULT, 00388 H5P_DEFAULT); 00389 vigra_postcondition(param_id >= 0, 00390 "problemspec_export_HDF5():" 00391 " Unable to create external parameters"); 00392 00393 //get a map containing all the double fields 00394 std::map<std::string, ArrayVector<double> > serialized_param; 00395 param.make_map(serialized_param); 00396 std::map<std::string, ArrayVector<double> >::iterator iter; 00397 for(iter = serialized_param.begin(); iter != serialized_param.end(); ++iter) 00398 write_array_2_hdf5(param_id, iter->second, iter->first, H5T_NATIVE_DOUBLE); 00399 00400 //save class_labels 00401 switch(type_of(param.classes[0])) 00402 { 00403 #define SOME_CASE(type) \ 00404 case MyT::type:\ 00405 write_array_2_hdf5(param_id, param.classes, "labels", H5T_NATIVE_##type);\ 00406 break; 00407 SOME_CASE(UINT8); 00408 SOME_CASE(UINT16); 00409 SOME_CASE(UINT32); 00410 SOME_CASE(UINT64); 00411 SOME_CASE(INT8); 00412 SOME_CASE(INT16); 00413 SOME_CASE(INT32); 00414 SOME_CASE(INT64); 00415 SOME_CASE(DOUBLE); 00416 SOME_CASE(FLOAT); 00417 default: 00418 std::runtime_error("exportRF_HDF5(): unknown class type"); 00419 #undef SOME_CASE 00420 } 00421 H5Gclose(param_id); 00422 } 00423 00424 inline void dt_import_HDF5( hid_t & group_id, 00425 RF_Traits::DecisionTree_t & tree, 00426 std::string name) 00427 { 00428 //check if ext_param was written and write it if not 00429 if(tree.ext_param_.actual_msample_ == 0) 00430 { 00431 problemspec_import_HDF5(group_id, tree.ext_param_, "_ext_param"); 00432 tree.classCount_ = tree.ext_param_.class_count_; 00433 } 00434 00435 hid_t tree_id =H5Gopen (group_id, name.c_str(), H5P_DEFAULT); 00436 //write down topology 00437 write_hdf5_2_array(tree_id, 00438 tree.topology_, 00439 "topology", 00440 H5T_NATIVE_INT); 00441 //write down parameters 00442 write_hdf5_2_array(tree_id, 00443 tree.parameters_, 00444 "parameters", 00445 H5T_NATIVE_DOUBLE); 00446 H5Gclose(tree_id); 00447 } 00448 00449 00450 inline void dt_export_HDF5( hid_t & group_id, 00451 RF_Traits::DecisionTree_t const & tree, 00452 std::string name) 00453 { 00454 //check if ext_param was written and write it if not 00455 hid_t e_id = H5Gopen (group_id, 00456 "_ext_param", 00457 H5P_DEFAULT); 00458 if(e_id < 0) 00459 { 00460 problemspec_export_HDF5(group_id, 00461 tree.ext_param_, 00462 "_ext_param"); 00463 } 00464 else H5Gclose(e_id); 00465 00466 //make the folder for the tree. 00467 hid_t tree_id = H5Gcreate(group_id, name.c_str(), 00468 H5P_DEFAULT, 00469 H5P_DEFAULT, 00470 H5P_DEFAULT); 00471 //write down topology 00472 write_array_2_hdf5(tree_id, 00473 tree.topology_, 00474 "topology", 00475 H5T_NATIVE_INT); 00476 //write down parameters 00477 write_array_2_hdf5(tree_id, 00478 tree.parameters_, 00479 "parameters", 00480 H5T_NATIVE_DOUBLE); 00481 H5Gclose(tree_id); 00482 } 00483 } //namespace detail 00484 00485 template<class T> 00486 bool rf_export_HDF5(RandomForest<T> const &rf, 00487 std::string filename, 00488 std::string pathname = "", 00489 bool overwriteflag = false) 00490 { 00491 using detail::make_padded_number; 00492 using detail::options_export_HDF5; 00493 using detail::problemspec_export_HDF5; 00494 using detail::dt_export_HDF5; 00495 //if file exists delete it. 00496 FILE* pFile = fopen ( filename.c_str(), "r" ); 00497 if ( pFile != NULL && !overwriteflag) 00498 return 0; 00499 else if(pFile != 0 &&std::remove(filename.c_str()) != 0) 00500 return 0; 00501 00502 //create a new file and group. 00503 hid_t file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, 00504 H5P_DEFAULT, 00505 H5P_DEFAULT); 00506 vigra_postcondition(file_id >= 0, 00507 "rf_export_HDF5(): Unable to open file."); 00508 std::cerr << pathname.c_str(); 00509 hid_t group_id = pathname== "" ? 00510 file_id 00511 : H5Gcreate(file_id, pathname.c_str(), 00512 H5P_DEFAULT, 00513 H5P_DEFAULT, 00514 H5P_DEFAULT); 00515 00516 vigra_postcondition(group_id >= 0, 00517 "rf_export_HDF5(): Unable to create group"); 00518 00519 //save serialized options 00520 options_export_HDF5(group_id, rf.options(), "_options"); 00521 //save external parameters 00522 problemspec_export_HDF5(group_id, rf.ext_param(), "_ext_param"); 00523 //save trees 00524 00525 int tree_count = rf.options_.tree_count_; 00526 for(int ii = 0; ii < tree_count; ++ii) 00527 { 00528 std::string treename = "Tree_" + 00529 make_padded_number(ii, tree_count -1); 00530 dt_export_HDF5(group_id, rf.tree(ii), treename); 00531 } 00532 00533 //clean up the mess 00534 if(pathname != "") 00535 H5Gclose(group_id); 00536 H5Fclose(file_id); 00537 00538 return 1; 00539 } 00540 00541 00542 template<class T> 00543 bool rf_import_HDF5(RandomForest<T> &rf, 00544 std::string filename, 00545 std::string pathname = "") 00546 { 00547 using detail::find_groups_hdf5; 00548 using detail::options_import_HDF5; 00549 using detail::problemspec_import_HDF5; 00550 using detail::dt_export_HDF5; 00551 //if file exists delete it. 00552 FILE* pFile = fopen ( filename.c_str(), "r" ); 00553 if ( pFile == NULL) 00554 return 0; 00555 00556 //open file 00557 hid_t file_id = H5Fopen (filename.c_str(), 00558 H5F_ACC_RDONLY, 00559 H5P_DEFAULT); 00560 00561 vigra_postcondition(file_id >= 0, 00562 "rf_import_HDF5(): Unable to open file."); 00563 hid_t group_id = pathname== "" ? 00564 file_id 00565 : H5Gopen (file_id, 00566 pathname.c_str(), 00567 H5P_DEFAULT); 00568 00569 vigra_postcondition(group_id >= 0, 00570 "rf_export_HDF5(): Unable to create group"); 00571 00572 //get serialized options 00573 options_import_HDF5(group_id, rf.options_, "_options"); 00574 //save external parameters 00575 problemspec_import_HDF5(group_id, rf.ext_param_, "_ext_param"); 00576 // TREE SAVING TIME 00577 // get all groups in base path 00578 00579 std::set<std::string> tree_set; 00580 std::set<std::string>::iterator iter; 00581 find_groups_hdf5(filename, pathname, tree_set); 00582 00583 for(iter = tree_set.begin(); iter != tree_set.end(); ++iter) 00584 { 00585 if((*iter)[0] != '_') 00586 { 00587 rf.trees_.push_back(RF_Traits::DecisionTree_t(rf.ext_param_)); 00588 dt_import_HDF5(group_id, rf.trees_.back(), *iter); 00589 } 00590 } 00591 00592 //clean up the mess 00593 if(pathname != "") 00594 H5Gclose(group_id); 00595 H5Fclose(file_id); 00596 rf.tree_indices_.resize(rf.tree_count()); 00597 for(int ii = 0; ii < rf.tree_count(); ++ii) 00598 rf.tree_indices_[ii] = ii; 00599 return 1; 00600 } 00601 } // namespace vigra 00602 00603 #endif // HasHDF5 00604 00605 #endif // VIGRA_RANDOM_FOREST_HDF5_IMPEX_HXX 00606
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|