SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 2010 Soeren Sonnenburg 00008 * Copyright (C) 2010 Berlin Institute of Technology 00009 */ 00010 00011 #ifndef _SCATTERKERNELNORMALIZER_H___ 00012 #define _SCATTERKERNELNORMALIZER_H___ 00013 00014 #include "kernel/KernelNormalizer.h" 00015 #include "kernel/IdentityKernelNormalizer.h" 00016 #include "kernel/Kernel.h" 00017 #include "features/Labels.h" 00018 #include "lib/io.h" 00019 00020 namespace shogun 00021 { 00022 class CScatterKernelNormalizer: public CKernelNormalizer 00023 { 00024 00025 public: 00027 CScatterKernelNormalizer() : CKernelNormalizer() 00028 { 00029 init(); 00030 } 00031 00034 CScatterKernelNormalizer(float64_t const_diag, float64_t const_offdiag, 00035 CLabels* labels,CKernelNormalizer* normalizer=NULL) 00036 : CKernelNormalizer() 00037 { 00038 init(); 00039 00040 m_testing_class=-1; 00041 m_const_diag=const_diag; 00042 m_const_offdiag=const_offdiag; 00043 00044 ASSERT(labels) 00045 SG_REF(labels); 00046 m_labels=labels; 00047 00048 if (normalizer==NULL) 00049 normalizer=new CIdentityKernelNormalizer(); 00050 SG_REF(normalizer); 00051 m_normalizer=normalizer; 00052 00053 SG_DEBUG("Constructing ScatterKernelNormalizer with const_diag=%g" 00054 " const_offdiag=%g num_labels=%d and normalizer='%s'\n", 00055 const_diag, const_offdiag, labels->get_num_labels(), 00056 normalizer->get_name()); 00057 } 00058 00060 virtual ~CScatterKernelNormalizer() 00061 { 00062 SG_UNREF(m_labels); 00063 SG_UNREF(m_normalizer); 00064 } 00065 00068 virtual bool init(CKernel* k) 00069 { 00070 m_normalizer->init(k); 00071 return true; 00072 } 00073 00078 int32_t get_testing_class() 00079 { 00080 return m_testing_class; 00081 } 00082 00087 void set_testing_class(int32_t c) 00088 { 00089 m_testing_class=c; 00090 } 00091 00097 inline virtual float64_t normalize(float64_t value, int32_t idx_lhs, 00098 int32_t idx_rhs) 00099 { 00100 value=m_normalizer->normalize(value, idx_lhs, idx_rhs); 00101 float64_t c=m_const_offdiag; 00102 00103 if (m_testing_class>=0) 00104 { 00105 if (m_labels->get_label(idx_lhs) == m_testing_class) 00106 c=m_const_diag; 00107 } 00108 else 00109 { 00110 if (m_labels->get_label(idx_lhs) == m_labels->get_label(idx_rhs)) 00111 c=m_const_diag; 00112 00113 } 00114 return value*c; 00115 } 00116 00121 inline virtual float64_t normalize_lhs(float64_t value, int32_t idx_lhs) 00122 { 00123 SG_ERROR("normalize_lhs not implemented"); 00124 return 0; 00125 } 00126 00131 inline virtual float64_t normalize_rhs(float64_t value, int32_t idx_rhs) 00132 { 00133 SG_ERROR("normalize_rhs not implemented"); 00134 return 0; 00135 } 00136 00138 inline virtual const char* get_name() const 00139 { 00140 return "ScatterKernelNormalizer"; 00141 } 00142 00143 private: 00144 void init() 00145 { 00146 m_const_diag = 1.0; 00147 m_const_offdiag = 1.0; 00148 00149 m_labels = NULL; 00150 m_normalizer = NULL; 00151 00152 m_testing_class = -1; 00153 00154 00155 m_parameters->add(&m_testing_class, "m_testing_class" 00156 "Testing Class."); 00157 m_parameters->add(&m_const_diag, "m_const_diag" 00158 "Factor to multiply to diagonal elements."); 00159 m_parameters->add(&m_const_offdiag, "m_const_offdiag" 00160 "Factor to multiply to off-diagonal elements."); 00161 00162 m_parameters->add((CSGObject**) &m_labels, "m_labels", "Labels"); 00163 m_parameters->add((CSGObject**) &m_normalizer, "m_normalizer", "Kernel normalizer."); 00164 } 00165 00166 protected: 00167 00169 float64_t m_const_diag; 00171 float64_t m_const_offdiag; 00172 00174 CLabels* m_labels; 00175 00177 CKernelNormalizer* m_normalizer; 00178 00180 int32_t m_testing_class; 00181 }; 00182 } 00183 #endif 00184