SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2007-2009 Soeren Sonnenburg 00008 * Copyright (C) 2007-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00009 */ 00010 00011 #ifndef _HIERARCHICAL_H__ 00012 #define _HIERARCHICAL_H__ 00013 00014 #include <stdio.h> 00015 #include "lib/common.h" 00016 #include "lib/io.h" 00017 #include "distance/Distance.h" 00018 #include "classifier/DistanceMachine.h" 00019 00020 namespace shogun 00021 { 00022 class CDistanceMachine; 00023 00037 class CHierarchical : public CDistanceMachine 00038 { 00039 public: 00041 CHierarchical(); 00042 00048 CHierarchical(int32_t merges, CDistance* d); 00049 virtual ~CHierarchical(); 00050 00055 virtual inline EClassifierType get_classifier_type() { return CT_HIERARCHICAL; } 00056 00065 virtual bool train(CFeatures* data=NULL); 00066 00072 virtual bool load(FILE* srcfile); 00073 00079 virtual bool save(FILE* dstfile); 00080 00085 inline void set_merges(int32_t m) 00086 { 00087 ASSERT(m>0); 00088 merges=m; 00089 } 00090 00095 inline int32_t get_merges() 00096 { 00097 return merges; 00098 } 00099 00105 inline void get_assignment(int32_t*& assign, int32_t& num) 00106 { 00107 assign=assignment; 00108 num=table_size; 00109 } 00110 00116 inline void get_merge_distance(float64_t*& dist, int32_t& num) 00117 { 00118 dist=merge_distance; 00119 num=merges; 00120 } 00121 00127 inline void get_merge_distances(float64_t** dist, int32_t* num) 00128 { 00129 size_t sz=sizeof(*merge_distance)*merges; 00130 *dist=(float64_t*) malloc(sz); 00131 ASSERT(*dist); 00132 00133 memcpy(*dist, merge_distance, sz); 00134 *num=merges; 00135 } 00136 00143 inline void get_pairs(int32_t*& tuples, int32_t& rows, int32_t& num) 00144 { 00145 tuples=pairs; 00146 rows=2; 00147 num=merges; 00148 } 00149 00156 inline void get_cluster_pairs( 00157 int32_t** tuples, int32_t* rows, int32_t* num) 00158 { 00159 *rows=2; 00160 size_t sz=sizeof(*pairs)*(*rows)*merges; 00161 *tuples=(int32_t*) malloc(sz); 00162 ASSERT(*tuples); 00163 00164 memcpy(*tuples, pairs, sz); 00165 *num=merges; 00166 } 00167 00172 virtual CLabels* classify() 00173 { 00174 SG_NOTIMPLEMENTED; 00175 return NULL; 00176 } 00177 00183 virtual CLabels* classify(CFeatures* data) 00184 { 00185 SG_NOTIMPLEMENTED; 00186 return NULL; 00187 } 00188 00190 inline virtual const char* get_name() const { return "Hierarchical"; } 00191 00192 protected: 00194 int32_t merges; 00195 00197 int32_t dimensions; 00198 00200 int32_t assignment_size; 00201 00203 int32_t* assignment; 00204 00206 int32_t table_size; 00207 00209 int32_t* pairs; 00210 00212 float64_t* merge_distance; 00213 }; 00214 } 00215 #endif