[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest_hdf5_impex.hxx
00001 /************************************************************************/
00002 /*                                                                      */
00003 /*       Copyright 2009 by Rahul Nair and  Ullrich Koethe               */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 
00037 #ifndef VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX
00038 #define VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX
00039 
00040 #include "random_forest.hxx"
00041 #include "hdf5impex.hxx"
00042 #include <cstdio>
00043 #include <string>
00044 
00045 #ifdef HasHDF5
00046 
00047 namespace vigra 
00048 {
00049 
00050 namespace detail
00051 {
00052 
00053 
00054 /** shallow search the hdf5 group for containing elements
00055  * returns negative value if unsuccessful
00056  * \param grp_id    hid_t containing path to group.
00057  * \param cont      reference to container that supports
00058  *                  insert(). valuetype of cont must be
00059  *                  std::string
00060  */
00061 template<class Container>
00062 bool find_groups_hdf5(hid_t grp_id, Container &cont)
00063 {
00064     
00065     //get group info
00066 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6)
00067     hsize_t size;
00068     H5Gget_num_objs(grp_id, &size);
00069 #else
00070     hsize_t size;
00071     H5G_info_t ginfo;
00072     herr_t      status; 
00073     status = H5Gget_info (grp_id , &ginfo);
00074     if(status < 0)
00075         std::runtime_error("find_groups_hdf5():"
00076                            "problem while getting group info");
00077     size = ginfo.nlinks;
00078 #endif
00079     for(hsize_t ii = 0; ii < size; ++ii)
00080     {
00081 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6)
00082         ssize_t buffer_size = 
00083                 H5Gget_objname_by_idx(grp_id, 
00084                                       ii, NULL, 0 ) + 1;
00085 #else
00086         std::ptrdiff_t buffer_size =
00087                 H5Lget_name_by_idx(grp_id, ".",
00088                                    H5_INDEX_NAME,
00089                                    H5_ITER_INC,
00090                                    ii, 0, 0, H5P_DEFAULT)+1;
00091 #endif
00092         ArrayVector<char> buffer(buffer_size);
00093 #if (H5_VERS_MAJOR == 1 && H5_VERS_MINOR <= 6)
00094         buffer_size = 
00095                 H5Gget_objname_by_idx(grp_id, 
00096                                       ii, buffer.data(), 
00097                                       (size_t)buffer_size );
00098 #else
00099         buffer_size =
00100                 H5Lget_name_by_idx(grp_id, ".",
00101                                    H5_INDEX_NAME,
00102                                    H5_ITER_INC,
00103                                    ii, buffer.data(),
00104                                    (size_t)buffer_size,
00105                                    H5P_DEFAULT);
00106 #endif
00107         cont.insert(cont.end(), std::string(buffer.data()));
00108     }
00109     return true;
00110 }
00111 
00112 
00113 /** shallow search the hdf5 group for containing elements
00114  * returns negative value if unsuccessful
00115  * \param filename name of hdf5 file
00116  * \param groupname path in hdf5 file
00117  * \param cont      reference to container that supports
00118  *                  insert(). valuetype of cont must be
00119  *                  std::string
00120  */
00121 template<class Container>
00122 bool find_groups_hdf5(std::string filename, 
00123                               std::string groupname, 
00124                               Container &cont)
00125 {
00126     //check if file exists
00127     FILE* pFile;
00128     pFile = fopen ( filename.c_str(), "r" );
00129     if ( pFile == NULL)
00130     {   
00131         return 0;
00132     }
00133     //open the file
00134     HDF5Handle file_id(H5Fopen(filename.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT),
00135                        &H5Fclose, "Unable to open HDF5 file");
00136     HDF5Handle grp_id;
00137     if(groupname == "")
00138     {
00139         grp_id = HDF5Handle(file_id, 0, "");
00140     }
00141     else
00142     {
00143         grp_id = HDF5Handle(H5Gopen(file_id, groupname.c_str(), H5P_DEFAULT),
00144                             &H5Gclose, "Unable to open group");
00145 
00146     }
00147     bool res =  find_groups_hdf5(grp_id, cont); 
00148     return res; 
00149 }
00150 
/** Count the decimal digits of an integer.
 *
 * Zero and negative values yield 1 (the sign is not counted and the
 * magnitude of negative numbers is not inspected).
 *
 * Uses integer division on the argument instead of growing a power of
 * ten, so no intermediate value can overflow for large inputs.
 */
inline int get_number_of_digits(int in)
{
    int num = 1;
    while(in >= 10)
    {
        in /= 10;
        ++num;
    }
    return num;
}
00164 
00165 inline std::string make_padded_number(int number, int max_number)
00166 {
00167     int max_digit_ct = get_number_of_digits(max_number);
00168     char buffer [50];
00169     std::sprintf(buffer, "%d", number);
00170     std::string padding = "";
00171     std::string numeral = buffer;
00172     int digit_ct = get_number_of_digits(number); 
00173     for(int gg = 0; gg < max_digit_ct - digit_ct; ++ gg)
00174         padding = padding + "0";
00175     return padding + numeral;
00176 }
00177 
00178 /** write a ArrayVector to a hdf5 dataset.
00179  */
00180 template<class U, class T>
00181 void write_array_2_hdf5(hid_t & id, 
00182                         ArrayVector<U> const & arr, 
00183                         std::string    const & name, 
00184                         T  type) 
00185 {
00186     hsize_t size = arr.size(); 
00187     vigra_postcondition(H5LTmake_dataset (id, 
00188                                           name.c_str(), 
00189                                           1, 
00190                                           &size, 
00191                                           type, 
00192                                           arr.begin()) 
00193                         >= 0,
00194                         "write_array_2_hdf5():"
00195                         "unable to write dataset");
00196 }
00197 
00198 
00199 template<class U, class T>
00200 void write_hdf5_2_array(hid_t & id, 
00201                         ArrayVector<U>       & arr, 
00202                         std::string    const & name, 
00203                         T  type) 
00204 {   
00205     // The last three values of get_dataset_info can be NULL
00206     // my EFFING FOOT! that is valid for HDF5 1.8 but not for
00207     // 1.6 - but documented the other way around AAARRHGHGHH
00208     hsize_t size; 
00209     H5T_class_t a; 
00210     size_t b;
00211     vigra_postcondition(H5LTget_dataset_info(id, 
00212                                              name.c_str(), 
00213                                              &size, 
00214                                              &a, 
00215                                              &b) >= 0,
00216                         "write_hdf5_2_array(): "
00217                         "Unable to locate dataset");
00218     arr.resize((typename ArrayVector<U>::size_type)size);
00219     vigra_postcondition(H5LTread_dataset (id, 
00220                                           name.c_str(),
00221                                           type, 
00222                                           arr.data()) >= 0,
00223                         "write_array_2_hdf5():"
00224                         "unable to read dataset");
00225 }
00226 
00227 
00228 inline void options_import_HDF5(hid_t & group_id,
00229                          RandomForestOptions & opt, 
00230                          std::string name)
00231 {
00232     ArrayVector<double> serialized_options;
00233     write_hdf5_2_array(group_id, serialized_options,
00234                         name, H5T_NATIVE_DOUBLE); 
00235     opt.unserialize(serialized_options.begin(),
00236                     serialized_options.end());
00237 }
00238 
00239 inline void options_export_HDF5(hid_t & group_id,
00240                          RandomForestOptions const & opt, 
00241                          std::string name)
00242 {
00243     ArrayVector<double> serialized_options(opt.serialized_size());
00244     opt.serialize(serialized_options.begin(),
00245                   serialized_options.end());
00246     write_array_2_hdf5(group_id, serialized_options,
00247                       name, H5T_NATIVE_DOUBLE); 
00248 }
00249 
/** Tags for the element types a label dataset may have.
 *
 * OTHER is used for everything that is not one of the listed integer
 * or floating point types.
 */
struct MyT
{
    enum type
    {
        // signed integers
        INT8   = 1,
        INT16  = 2,
        INT32  = 3,
        INT64  = 4,
        // unsigned integers
        UINT8  = 5,
        UINT16 = 6,
        UINT32 = 7,
        UINT64 = 8,
        // floating point
        FLOAT  = 9,
        DOUBLE = 10,
        // anything else
        OTHER  = 3294
    };
};
00256 
00257 
00258 
00259 #define create_type_of(TYPE, ENUM) \
00260 inline MyT::type type_of(TYPE)\
00261 {\
00262     return MyT::ENUM; \
00263 }
00264 create_type_of(Int8, INT8)
00265 create_type_of(Int16, INT16)
00266 create_type_of(Int32, INT32)
00267 create_type_of(Int64, INT64)
00268 create_type_of(UInt8, UINT8)
00269 create_type_of(UInt16, UINT16)
00270 create_type_of(UInt32, UINT32)
00271 create_type_of(UInt64, UINT64)
00272 create_type_of(float, FLOAT)
00273 create_type_of(double, DOUBLE)
00274 #undef create_type_of
00275 
00276 inline MyT::type type_of_hid_t(hid_t group_id, std::string name)
00277 {
00278     hid_t m_dataset_handle = 
00279     H5Dopen(group_id, name.c_str(), H5P_DEFAULT);
00280     hid_t datatype = H5Dget_type(m_dataset_handle);
00281     H5T_class_t dataclass = H5Tget_class(datatype);
00282     size_t datasize  = H5Tget_size(datatype);
00283     H5T_sign_t datasign  = H5Tget_sign(datatype);
00284     MyT::type result = MyT::OTHER; 
00285     if(dataclass == H5T_FLOAT)
00286     {
00287         if(datasize == 4)
00288             result = MyT::FLOAT;
00289         else if(datasize == 8)
00290             result = MyT::DOUBLE;
00291     }
00292     else if(dataclass == H5T_INTEGER)   
00293     {
00294         if(datasign == H5T_SGN_NONE)
00295         {
00296             if(datasize ==  1)
00297                 result = MyT::UINT8;
00298             else if(datasize == 2)
00299                 result = MyT::UINT16;
00300             else if(datasize == 4)
00301                 result = MyT::UINT32;
00302             else if(datasize == 8)
00303                 result = MyT::UINT64;
00304         }
00305         else
00306         {
00307             if(datasize ==  1)
00308                 result = MyT::INT8;
00309             else if(datasize == 2)
00310                 result = MyT::INT16;
00311             else if(datasize == 4)
00312                 result = MyT::INT32;
00313             else if(datasize == 8)
00314                 result = MyT::INT64;
00315         }
00316     }
00317     H5Tclose(datatype);
00318     H5Dclose(m_dataset_handle);
00319     return result;
00320 }
00321 
00322 template<class T>
00323 void problemspec_import_HDF5(hid_t & group_id, 
00324                              ProblemSpec<T>  & param, 
00325                              std::string name)
00326 {
00327     hid_t param_id = H5Gopen (group_id, 
00328                               name.c_str(), 
00329                               H5P_DEFAULT);
00330 
00331     vigra_postcondition(param_id >= 0, 
00332                         "problemspec_import_HDF5():"
00333                         " Unable to open external parameters");
00334 
00335     //get a map containing all the double fields
00336     std::set<std::string> ext_set;
00337     find_groups_hdf5(param_id, ext_set);
00338     std::map<std::string, ArrayVector <double> > ext_map;
00339     std::set<std::string>::iterator iter;
00340     if(ext_set.find(std::string("labels")) == ext_set.end())
00341         std::runtime_error("labels are missing");
00342     for(iter = ext_set.begin(); iter != ext_set.end(); ++ iter)
00343     {
00344         if(*iter != std::string("labels"))
00345         {
00346             ext_map[*iter] = ArrayVector<double>();
00347             write_hdf5_2_array(param_id, ext_map[*iter], 
00348                                *iter, H5T_NATIVE_DOUBLE);
00349         }
00350     }
00351     param.make_from_map(ext_map);
00352     //load_class_labels
00353     switch(type_of_hid_t(param_id,"labels" ))
00354     {
00355         #define SOME_CASE(type_, enum_) \
00356       case MyT::enum_ :\
00357         {\
00358             ArrayVector<type_> tmp;\
00359             write_hdf5_2_array(param_id, tmp, "labels", H5T_NATIVE_##enum_);\
00360             param.classes_(tmp.begin(), tmp.end());\
00361         }\
00362             break;
00363         SOME_CASE(UInt8,    UINT8);
00364         SOME_CASE(UInt16,   UINT16);
00365         SOME_CASE(UInt32,   UINT32);
00366         SOME_CASE(UInt64,   UINT64);
00367         SOME_CASE(Int8,     INT8);
00368         SOME_CASE(Int16,    INT16);
00369         SOME_CASE(Int32,    INT32);
00370         SOME_CASE(Int64,    INT64);
00371         SOME_CASE(double,   DOUBLE);
00372         SOME_CASE(float,    FLOAT);
00373         default:
00374             std::runtime_error("exportRF_HDF5(): unknown class type"); 
00375         #undef SOME_CASE
00376     }
00377     H5Gclose(param_id);
00378 }
00379 
00380 template<class T>
00381 void problemspec_export_HDF5(hid_t & group_id, 
00382                              ProblemSpec<T> const & param, 
00383                              std::string name)
00384 {
00385     hid_t param_id = H5Gcreate(group_id, name.c_str(), 
00386                                         H5P_DEFAULT, 
00387                                         H5P_DEFAULT, 
00388                                         H5P_DEFAULT);
00389     vigra_postcondition(param_id >= 0, 
00390                         "problemspec_export_HDF5():"
00391                         " Unable to create external parameters");
00392 
00393     //get a map containing all the double fields
00394     std::map<std::string, ArrayVector<double> > serialized_param;
00395     param.make_map(serialized_param);
00396     std::map<std::string, ArrayVector<double> >::iterator iter;
00397     for(iter = serialized_param.begin(); iter != serialized_param.end(); ++iter)
00398         write_array_2_hdf5(param_id, iter->second, iter->first, H5T_NATIVE_DOUBLE);
00399     
00400     //save class_labels
00401     switch(type_of(param.classes[0]))
00402     {
00403         #define SOME_CASE(type) \
00404         case MyT::type:\
00405             write_array_2_hdf5(param_id, param.classes, "labels", H5T_NATIVE_##type);\
00406             break;
00407         SOME_CASE(UINT8);
00408         SOME_CASE(UINT16);
00409         SOME_CASE(UINT32);
00410         SOME_CASE(UINT64);
00411         SOME_CASE(INT8);
00412         SOME_CASE(INT16);
00413         SOME_CASE(INT32);
00414         SOME_CASE(INT64);
00415         SOME_CASE(DOUBLE);
00416         SOME_CASE(FLOAT);
00417         default:
00418             std::runtime_error("exportRF_HDF5(): unknown class type"); 
00419         #undef SOME_CASE
00420     }
00421     H5Gclose(param_id);
00422 }
00423 
00424 inline void dt_import_HDF5( hid_t & group_id,
00425                             RF_Traits::DecisionTree_t & tree,
00426                             std::string name)
00427 {
00428     //check if ext_param was written and write it if not
00429     if(tree.ext_param_.actual_msample_ == 0)
00430     {
00431         problemspec_import_HDF5(group_id, tree.ext_param_, "_ext_param");
00432         tree.classCount_ = tree.ext_param_.class_count_;
00433     }
00434     
00435     hid_t tree_id =H5Gopen (group_id, name.c_str(), H5P_DEFAULT);
00436     //write down topology
00437     write_hdf5_2_array(tree_id, 
00438                        tree.topology_, 
00439                        "topology", 
00440                        H5T_NATIVE_INT);
00441     //write down parameters
00442     write_hdf5_2_array(tree_id, 
00443                        tree.parameters_, 
00444                        "parameters", 
00445                        H5T_NATIVE_DOUBLE);
00446     H5Gclose(tree_id);
00447 }
00448 
00449 
00450 inline void dt_export_HDF5( hid_t & group_id,
00451                             RF_Traits::DecisionTree_t const & tree,
00452                             std::string name)
00453 {
00454     //check if ext_param was written and write it if not
00455     hid_t e_id = H5Gopen (group_id, 
00456                           "_ext_param", 
00457                           H5P_DEFAULT);
00458     if(e_id < 0)
00459     {
00460         problemspec_export_HDF5(group_id,
00461                                 tree.ext_param_, 
00462                                 "_ext_param"); 
00463     }
00464     else H5Gclose(e_id);
00465     
00466     //make the folder for the tree.
00467     hid_t tree_id = H5Gcreate(group_id, name.c_str(), 
00468                                         H5P_DEFAULT, 
00469                                         H5P_DEFAULT, 
00470                                         H5P_DEFAULT);
00471     //write down topology
00472     write_array_2_hdf5(tree_id, 
00473                        tree.topology_, 
00474                        "topology", 
00475                        H5T_NATIVE_INT);
00476     //write down parameters
00477     write_array_2_hdf5(tree_id, 
00478                        tree.parameters_, 
00479                        "parameters", 
00480                        H5T_NATIVE_DOUBLE);
00481     H5Gclose(tree_id);
00482 }
00483 } //namespace detail
00484 
00485 template<class T>
00486 bool rf_export_HDF5(RandomForest<T> const &rf, 
00487                     std::string filename, 
00488                     std::string pathname = "",
00489                     bool overwriteflag = false)
00490 { 
00491     using detail::make_padded_number;
00492     using detail::options_export_HDF5;
00493     using detail::problemspec_export_HDF5;
00494     using detail::dt_export_HDF5;
00495     //if file exists delete it.
00496     FILE* pFile = fopen ( filename.c_str(), "r" );
00497     if ( pFile != NULL && !overwriteflag)
00498         return 0;
00499     else if(pFile != 0 &&std::remove(filename.c_str()) != 0)
00500         return 0;
00501     
00502     //create a new file and group.
00503     hid_t file_id = H5Fcreate(filename.c_str(), H5F_ACC_TRUNC, 
00504                                                 H5P_DEFAULT, 
00505                                                 H5P_DEFAULT);
00506     vigra_postcondition(file_id >= 0, 
00507                         "rf_export_HDF5(): Unable to open file.");
00508     std::cerr << pathname.c_str();
00509     hid_t group_id = pathname== "" ?
00510                         file_id
00511                     :   H5Gcreate(file_id, pathname.c_str(), 
00512                                            H5P_DEFAULT, 
00513                                            H5P_DEFAULT, 
00514                                            H5P_DEFAULT);
00515 
00516     vigra_postcondition(group_id >= 0, 
00517                         "rf_export_HDF5(): Unable to create group");
00518 
00519     //save serialized options
00520         options_export_HDF5(group_id, rf.options(), "_options"); 
00521     //save external parameters
00522         problemspec_export_HDF5(group_id, rf.ext_param(), "_ext_param");
00523     //save trees
00524     
00525     int tree_count = rf.options_.tree_count_;
00526     for(int ii = 0; ii < tree_count; ++ii)
00527     {
00528         std::string treename =  "Tree_"  + 
00529                                 make_padded_number(ii, tree_count -1);
00530         dt_export_HDF5(group_id, rf.tree(ii), treename); 
00531     }
00532     
00533     //clean up the mess
00534     if(pathname != "")
00535         H5Gclose(group_id);
00536     H5Fclose(file_id);
00537 
00538     return 1;
00539 }
00540 
00541 
00542 template<class T>
00543 bool rf_import_HDF5(RandomForest<T> &rf, 
00544                     std::string filename, 
00545                     std::string pathname = "")
00546 { 
00547     using detail::find_groups_hdf5;
00548     using detail::options_import_HDF5;
00549     using detail::problemspec_import_HDF5;
00550     using detail::dt_export_HDF5;
00551     //if file exists delete it.
00552     FILE* pFile = fopen ( filename.c_str(), "r" );
00553     if ( pFile == NULL)
00554         return 0;
00555     
00556     //open file
00557     hid_t file_id = H5Fopen (filename.c_str(), 
00558                              H5F_ACC_RDONLY, 
00559                              H5P_DEFAULT);
00560     
00561     vigra_postcondition(file_id >= 0, 
00562                         "rf_import_HDF5(): Unable to open file.");
00563     hid_t group_id = pathname== "" ?
00564                         file_id
00565                     :   H5Gopen (file_id, 
00566                                  pathname.c_str(), 
00567                                  H5P_DEFAULT);
00568     
00569     vigra_postcondition(group_id >= 0, 
00570                         "rf_export_HDF5(): Unable to create group");
00571 
00572     //get serialized options
00573         options_import_HDF5(group_id, rf.options_, "_options"); 
00574     //save external parameters
00575         problemspec_import_HDF5(group_id, rf.ext_param_, "_ext_param");
00576     // TREE SAVING TIME
00577     // get all groups in base path
00578     
00579     std::set<std::string> tree_set;
00580     std::set<std::string>::iterator iter; 
00581     find_groups_hdf5(filename, pathname, tree_set);
00582     
00583     for(iter = tree_set.begin(); iter != tree_set.end(); ++iter)
00584     {
00585         if((*iter)[0] != '_')
00586         {
00587             rf.trees_.push_back(RF_Traits::DecisionTree_t(rf.ext_param_));
00588             dt_import_HDF5(group_id, rf.trees_.back(), *iter); 
00589         }
00590     }
00591     
00592     //clean up the mess
00593     if(pathname != "")
00594         H5Gclose(group_id);
00595     H5Fclose(file_id);
00596     rf.tree_indices_.resize(rf.tree_count());
00597     for(int ii = 0; ii < rf.tree_count(); ++ii)
00598         rf.tree_indices_[ii] = ii; 
00599     return 1;
00600 }
00601 } // namespace vigra
00602 
00603 #endif // HasHDF5
00604 
00605 #endif // VIGRA_RANDOM_FOREST_IMPEX_HDF5_HXX
00606 

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.7.0 (Thu Aug 25 2011)