KMeans.h
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #ifndef _KMEANS_H__
00013 #define _KMEANS_H__
00014
00015 #include <stdio.h>
00016 #include "lib/common.h"
00017 #include "lib/io.h"
00018 #include "features/SimpleFeatures.h"
00019 #include "distance/Distance.h"
00020 #include "classifier/DistanceMachine.h"
00021
00022 class CDistanceMachine;
00023
00024 namespace shogun
00025 {
00039 class CKMeans : public CDistanceMachine
00040 {
00041 public:
00043 CKMeans();
00044
00050 CKMeans(int32_t k, CDistance* d);
00051 virtual ~CKMeans();
00052
00057 virtual inline EClassifierType get_classifier_type() { return CT_KMEANS; }
00058
00067 virtual bool train(CFeatures* data=NULL);
00068
00074 virtual bool load(FILE* srcfile);
00075
00081 virtual bool save(FILE* dstfile);
00082
00087 inline void set_k(int32_t p_k)
00088 {
00089 ASSERT(p_k>0);
00090 this->k=p_k;
00091 }
00092
00097 inline int32_t get_k()
00098 {
00099 return k;
00100 }
00101
00106 inline void set_max_iter(int32_t iter)
00107 {
00108 ASSERT(iter>0);
00109 max_iter=iter;
00110 }
00111
00116 inline float64_t get_max_iter()
00117 {
00118 return max_iter;
00119 }
00120
00126 inline void get_radi(float64_t*& radi, int32_t& num)
00127 {
00128 radi=R;
00129 num=k;
00130 }
00131
00138 inline void get_centers(float64_t*& centers, int32_t& dim, int32_t& num)
00139 {
00140 centers=mus;
00141 dim=dimensions;
00142 num=k;
00143 }
00144
00150 inline void get_radiuses(float64_t** radii, int32_t* num)
00151 {
00152 size_t sz=sizeof(*R)*k;
00153 *radii=(float64_t*) malloc(sz);
00154 ASSERT(*radii);
00155
00156 memcpy(*radii, R, sz);
00157 *num=k;
00158 }
00159
00166 inline void get_cluster_centers(
00167 float64_t** centers, int32_t* dim, int32_t* num)
00168 {
00169 size_t sz=sizeof(*mus)*dimensions*k;
00170 *centers=(float64_t*) malloc(sz);
00171 ASSERT(*centers);
00172
00173 memcpy(*centers, mus, sz);
00174 *dim=dimensions;
00175 *num=k;
00176 }
00177
00182 inline int32_t get_dimensions()
00183 {
00184 return dimensions;
00185 }
00186
00187 protected:
00198 void sqdist(
00199 float64_t* x, CSimpleFeatures<float64_t>* y, float64_t *z, int32_t n1,
00200 int32_t offs, int32_t n2, int32_t m);
00201
00207 void clustknb(bool use_old_mus, float64_t *mus_start);
00208
00213 virtual CLabels* classify()
00214 {
00215 SG_NOTIMPLEMENTED;
00216 return NULL;
00217 }
00218
00224 virtual CLabels* classify(CFeatures* data)
00225 {
00226 SG_NOTIMPLEMENTED;
00227 return NULL;
00228 }
00229
00230
00231
00233 inline virtual const char* get_name() const { return "KMeans"; }
00234
00235 protected:
00237 int32_t max_iter;
00238
00240 int32_t k;
00241
00243 int32_t dimensions;
00244
00246 float64_t* R;
00247
00249 float64_t* mus;
00250
00251 private:
00253 float64_t* Weights;
00254 };
00255 }
00256 #endif
00257