SHOGUN v0.9.0
|
00001 /* 00002 * This program is free software; you can redistribute it and/or modify 00003 * it under the terms of the GNU General Public License as published by 00004 * the Free Software Foundation; either version 3 of the License, or 00005 * (at your option) any later version. 00006 * 00007 * Written (W) 1999-2009 Soeren Sonnenburg 00008 * Written (W) 1999-2008 Gunnar Raetsch 00009 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society 00010 */ 00011 00012 #include <vector> 00013 00014 #include "lib/common.h" 00015 #include "lib/io.h" 00016 #include "lib/Signal.h" 00017 #include "lib/Trie.h" 00018 #include "base/Parallel.h" 00019 00020 #include "kernel/SpectrumMismatchRBFKernel.h" 00021 #include "features/Features.h" 00022 #include "features/StringFeatures.h" 00023 00024 00025 #include <vector> 00026 #include <string> 00027 00028 #include <assert.h> 00029 00030 #ifndef WIN32 00031 #include <pthread.h> 00032 #endif 00033 00034 using namespace shogun; 00035 00036 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel(void) 00037 :CStringKernel<char>(0) 00038 { 00039 SG_UNSTABLE("CSpectrumMismatchRBFKernel::" 00040 "CSpectrumMismatchRBFKernel(void)", "\n"); 00041 00042 alphabet = NULL; 00043 degree = 0; 00044 max_mismatch = 0; 00045 AA_matrix = NULL; 00046 width = 0.0; 00047 00048 initialized = false; 00049 target_letter_0 = 0; 00050 } 00051 00052 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel (int32_t size, 00053 float64_t* AA_matrix_, int32_t nr, int32_t nc, 00054 int32_t degree_, int32_t max_mismatch_, float64_t width_) : CStringKernel<char>(size), 00055 alphabet(NULL), degree(degree_), max_mismatch(max_mismatch_), width(width_) 00056 { 00057 lhs=NULL; 00058 rhs=NULL; 00059 00060 target_letter_0=-1 ; 00061 00062 AA_matrix=NULL; 00063 set_AA_matrix(AA_matrix_, nr, nc); 00064 } 00065 00066 CSpectrumMismatchRBFKernel::CSpectrumMismatchRBFKernel( 00067 CStringFeatures<char>* l, CStringFeatures<char>* r, int32_t size, float64_t* AA_matrix_, int32_t nr, int32_t nc, int32_t degree_, int32_t max_mismatch_, float64_t width_) 00068 : CStringKernel<char>(size), alphabet(NULL), degree(degree_), max_mismatch(max_mismatch_), width(width_) 00069 { 00070 target_letter_0=-1 ; 00071 00072 AA_matrix=NULL; 00073 set_AA_matrix(AA_matrix_, nr, nc); 00074 init(l, r); 00075 } 00076 00077 CSpectrumMismatchRBFKernel::~CSpectrumMismatchRBFKernel() 00078 { 00079 cleanup(); 00080 delete[] AA_matrix ; 00081 } 00082 00083 00084 void CSpectrumMismatchRBFKernel::remove_lhs() 00085 { 00086 00087 CKernel::remove_lhs(); 00088 } 00089 00090 bool CSpectrumMismatchRBFKernel::init(CFeatures* l, CFeatures* r) 00091 { 00092 int32_t lhs_changed=(lhs!=l); 00093 int32_t rhs_changed=(rhs!=r); 00094 00095 CStringKernel<char>::init(l,r); 00096 00097 SG_DEBUG("lhs_changed: %i\n", lhs_changed); 00098 SG_DEBUG("rhs_changed: %i\n", rhs_changed); 00099 00100 CStringFeatures<char>* sf_l=(CStringFeatures<char>*) l; 00101 CStringFeatures<char>* sf_r=(CStringFeatures<char>*) r; 00102 00103 SG_UNREF(alphabet); 00104 alphabet=sf_l->get_alphabet(); 00105 CAlphabet* ralphabet=sf_r->get_alphabet(); 00106 00107 if (!((alphabet->get_alphabet()==DNA) || (alphabet->get_alphabet()==RNA))) 00108 properties &= ((uint64_t) (-1)) ^ (KP_LINADD | KP_BATCHEVALUATION); 00109 00110 ASSERT(ralphabet->get_alphabet()==alphabet->get_alphabet()); 00111 SG_UNREF(ralphabet); 00112 00113 compute_all() ; 00114 00115 return init_normalizer(); 00116 } 00117 00118 void CSpectrumMismatchRBFKernel::cleanup() 00119 { 00120 00121 SG_UNREF(alphabet); 00122 alphabet=NULL; 00123 00124 CKernel::cleanup(); 00125 } 00126 00127 float64_t CSpectrumMismatchRBFKernel::AA_helper(std::string &path, const char* joint_seq, unsigned int index) 00128 { 00129 float64_t diff=0.0 ; 00130 00131 for (unsigned int i=0; i<path.size(); i++) 00132 { 00133 if (path[i]!=joint_seq[index+i]) 00134 { 00135 diff += AA_matrix[ (path[i]-1)*128 + path[i] - 1] ; 00136 diff -= 2*AA_matrix[ (path[i]-1)*128 + joint_seq[index+i] - 1] ; 00137 diff += AA_matrix[ (joint_seq[index+i]-1)*128 + joint_seq[index+i] - 1] ; 00138 } 00139 } 00140 00141 return exp( - diff/width) ; 00142 } 00143 00144 /* 00145 float64_t CSpectrumMismatchRBFKernel::compute_helper(const char* joint_seq, 00146 std::vector<unsigned int> joint_index, std::vector<unsigned int> joint_mismatch, 00147 std::string path, unsigned int d, 00148 const int & alen) 00149 { 00150 const char* AA = "ACDEFGHIKLMNPQRSTVWY" ; 00151 const unsigned int num_AA = strlen(AA) ; 00152 00153 assert(path.size()==d) ; 00154 assert(joint_mismatch.size()==joint_index.size()) ; 00155 00156 float64_t res = 0.0 ; 00157 00158 for (unsigned int i=0; i<num_AA; i++) 00159 { 00160 std::vector<unsigned int> joint_mismatch_ ; 00161 std::vector<unsigned int> joint_index_ ; 00162 00163 for (unsigned int j=0; j<joint_index.size(); j++) 00164 { 00165 if (joint_seq[joint_index[j]+d] != AA[i]) 00166 { 00167 if (joint_mismatch[j]+1 <= (unsigned int) max_mismatch) 00168 { 00169 joint_mismatch_.push_back(joint_mismatch[j]+1) ; 00170 joint_index_.push_back(joint_index[j]) ; 00171 } 00172 } 00173 else 00174 { 00175 joint_mismatch_.push_back(joint_mismatch[j]) ; 00176 joint_index_.push_back(joint_index[j]) ; 00177 } 00178 } 00179 if (joint_mismatch_.size()>0) 00180 { 00181 std::string path_ = path + AA[i] ; 00182 00183 if (d+1 < (unsigned int) degree) 00184 { 00185 res += compute_helper(joint_seq, joint_index_, joint_mismatch_, path_, d+1, alen) ; 00186 } 00187 else 00188 { 00189 int anum=0, bnum=0; 00190 for (unsigned int j=0; j<joint_index_.size(); j++) 00191 if (joint_index_[j] < (unsigned int)alen) 00192 { 00193 if (1) 00194 { 00195 anum++ ; 00196 if (joint_mismatch_[j]==0) 00197 anum+=3 ; 00198 } 00199 else 00200 { 00201 if (joint_mismatch_[j]!=0) 00202 anum += AA_helper(path_, joint_seq, joint_index_[j]) ; 00203 else 00204 anum++ ; 00205 } 00206 } 00207 else 00208 { 00209 if (1) 00210 { 00211 bnum++ ; 00212 if (joint_mismatch_[j]==0) 00213 bnum+=3 ; 00214 } 00215 else 00216 { 00217 if (joint_mismatch_[j]!=0) 00218 bnum += AA_helper(path_, joint_seq, joint_index_[j]) ; 00219 else 00220 bnum++ ; 00221 } 00222 } 00223 00224 //fprintf(stdout, "%s: %i x %i\n", path_.c_str(), anum, bnum) ; 00225 00226 res+= anum*bnum ; 00227 } 00228 } 00229 } 00230 return res ; 00231 } 00232 */ 00233 00234 void CSpectrumMismatchRBFKernel::compute_helper_all(const char *joint_seq, 00235 std::vector<struct joint_list_struct> &joint_list, 00236 std::string path, unsigned int d) 00237 { 00238 const char* AA = "ACDEFGHIKLMNPQRSTVWY" ; 00239 const unsigned int num_AA = strlen(AA) ; 00240 00241 assert(path.size()==d) ; 00242 00243 for (unsigned int i=0; i<num_AA; i++) 00244 { 00245 std::vector<struct joint_list_struct> joint_list_ ; 00246 00247 if (d==0) 00248 fprintf(stderr, "i=%i: ", i) ; 00249 if (d==0 && target_letter_0!=-1 && (int)i != target_letter_0 ) 00250 continue ; 00251 00252 if (d==1) 00253 { 00254 fprintf(stdout, "*") ; 00255 fflush(stdout) ; 00256 } 00257 if (d==2) 00258 { 00259 fprintf(stdout, "+") ; 00260 fflush(stdout) ; 00261 } 00262 00263 for (unsigned int j=0; j<joint_list.size(); j++) 00264 { 00265 if (joint_seq[joint_list[j].index+d] != AA[i]) 00266 { 00267 if (joint_list[j].mismatch+1 <= (unsigned int) max_mismatch) 00268 { 00269 struct joint_list_struct list_item ; 00270 list_item = joint_list[j] ; 00271 list_item.mismatch = joint_list[j].mismatch+1 ; 00272 joint_list_.push_back(list_item) ; 00273 } 00274 } 00275 else 00276 joint_list_.push_back(joint_list[j]) ; 00277 } 00278 00279 if (joint_list_.size()>0) 00280 { 00281 std::string path_ = path + AA[i] ; 00282 00283 if (d+1 < (unsigned int) degree) 00284 { 00285 compute_helper_all(joint_seq, joint_list_, path_, d+1) ; 00286 } 00287 else 00288 { 00289 CArray<float64_t> feats ; 00290 feats.resize_array(kernel_matrix.get_dim1()) ; 00291 feats.zero() ; 00292 00293 for (unsigned int j=0; j<joint_list_.size(); j++) 00294 { 00295 if (width==0.0) 00296 { 00297 feats[joint_list_[j].ex_index]++ ; 00298 //if (joint_mismatch_[j]==0) 00299 // feats[joint_ex_index_[j]]+=3 ; 00300 } 00301 else 00302 { 00303 if (joint_list_[j].mismatch!=0) 00304 feats[joint_list_[j].ex_index] += AA_helper(path_, joint_seq, joint_list_[j].index) ; 00305 else 00306 feats[joint_list_[j].ex_index] ++ ; 00307 } 00308 } 00309 00310 std::vector<int> idx ; 00311 for (int r=0; r<feats.get_array_size(); r++) 00312 if (feats[r]!=0.0) 00313 idx.push_back(r) ; 00314 00315 for (unsigned int r=0; r<idx.size(); r++) 00316 for (unsigned int s=r; s<idx.size(); s++) 00317 if (s==r) 00318 kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[r],idx[s]), idx[r], idx[s]) ; 00319 else 00320 { 00321 kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[r],idx[s]), idx[r], idx[s]) ; 00322 kernel_matrix.set_element(feats[idx[r]]*feats[idx[s]] + kernel_matrix.get_element(idx[s],idx[r]), idx[s], idx[r]) ; 00323 } 00324 } 00325 } 00326 if (d==0) 00327 fprintf(stdout, "\n") ; 00328 } 00329 } 00330 00331 void CSpectrumMismatchRBFKernel::compute_all() 00332 { 00333 std::string joint_seq ; 00334 std::vector<struct joint_list_struct> joint_list ; 00335 00336 assert(lhs->get_num_vectors()==rhs->get_num_vectors()) ; 00337 kernel_matrix.resize_array(lhs->get_num_vectors(), lhs->get_num_vectors()) ; 00338 for (int i=0; i<lhs->get_num_vectors(); i++) 00339 for (int j=0; j<lhs->get_num_vectors(); j++) 00340 kernel_matrix.set_element(0, i, j) ; 00341 00342 for (int i=0; i<lhs->get_num_vectors(); i++) 00343 { 00344 int32_t alen ; 00345 bool free_avec ; 00346 char* avec = ((CStringFeatures<char>*) lhs)->get_feature_vector(i, alen, free_avec); 00347 00348 for (int apos=0; apos+degree-1<alen; apos++) 00349 { 00350 struct joint_list_struct list_item ; 00351 list_item.ex_index = i ; 00352 list_item.index = apos+joint_seq.size() ; 00353 list_item.mismatch = 0 ; 00354 00355 joint_list.push_back(list_item) ; 00356 } 00357 joint_seq += std::string(avec, alen) ; 00358 00359 ((CStringFeatures<char>*) lhs)->free_feature_vector(avec, i, free_avec); 00360 } 00361 00362 compute_helper_all(joint_seq.c_str(), joint_list, "", 0) ; 00363 } 00364 00365 00366 float64_t CSpectrumMismatchRBFKernel::compute(int32_t idx_a, int32_t idx_b) 00367 { 00368 return kernel_matrix.element(idx_a, idx_b) ; 00369 } 00370 /* 00371 bool CSpectrumMismatchRBFKernel::set_weights( 00372 float64_t* ws, int32_t d, int32_t len) 00373 { 00374 if (d==128 && len==128) 00375 { 00376 SG_DEBUG("Setting AA_matrix\n") ; 00377 memcpy(AA_matrix, ws, 128*128*sizeof(float64_t)) ; 00378 return true ; 00379 } 00380 00381 if (d==1 && len==1) 00382 { 00383 sigma=ws[0] ; 00384 SG_DEBUG("Setting sigma to %e\n", sigma) ; 00385 return true ; 00386 } 00387 00388 if (d==2 && len==2) 00389 { 00390 target_letter_0=ws[0] ; 00391 SG_DEBUG("Setting target letter to %c\n", target_letter_0) ; 00392 return true ; 00393 } 00394 00395 if (d!=degree || len<1) 00396 SG_ERROR("Dimension mismatch (should be de(seq_length | 1) x degree)\n"); 00397 00398 length=len; 00399 00400 if (length==0) 00401 length=1; 00402 00403 int32_t num_weights=degree*(length+max_mismatch); 00404 delete[] weights; 00405 weights=new float64_t[num_weights]; 00406 00407 if (weights) 00408 { 00409 for (int32_t i=0; i<num_weights; i++) { 00410 if (ws[i]) // len(ws) might be != num_weights? 00411 weights[i]=ws[i]; 00412 } 00413 return true; 00414 } 00415 else 00416 return false; 00417 } 00418 */ 00419 00420 bool CSpectrumMismatchRBFKernel::set_AA_matrix(float64_t* AA_matrix_, int32_t nr, int32_t nc) 00421 { 00422 if (AA_matrix_) 00423 { 00424 if (nr!=128 || nc!=128) 00425 SG_ERROR("AA_matrix should be of shape 128x128\n"); 00426 delete[] AA_matrix; 00427 AA_matrix=new float64_t[nc*nr]; 00428 memcpy(AA_matrix, AA_matrix_, nc*nr*sizeof(float64_t)) ; 00429 SG_DEBUG("Setting AA_matrix\n") ; 00430 memcpy(AA_matrix, AA_matrix_, 128*128*sizeof(float64_t)) ; 00431 return true ; 00432 } 00433 00434 return false; 00435 } 00436 00437 bool CSpectrumMismatchRBFKernel::set_max_mismatch(int32_t max) 00438 { 00439 max_mismatch=max; 00440 00441 if (lhs!=NULL && rhs!=NULL) 00442 return init(lhs, rhs); 00443 else 00444 return true; 00445 }