CommWordStringKernel.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 1999-2009 Soeren Sonnenburg
00008  * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "lib/common.h"
00012 #include "kernel/CommWordStringKernel.h"
00013 #include "kernel/SqrtDiagKernelNormalizer.h"
00014 #include "features/StringFeatures.h"
00015 #include "lib/io.h"
00016 
00017 using namespace shogun;
00018 
00019 CCommWordStringKernel::CCommWordStringKernel(int32_t size, bool s)
00020 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00021     use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00022 {
00023     properties |= KP_LINADD;
00024     init_dictionary(1<<(sizeof(uint16_t)*8));
00025     set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00026 }
00027 
00028 CCommWordStringKernel::CCommWordStringKernel(
00029     CStringFeatures<uint16_t>* l, CStringFeatures<uint16_t>* r, bool s,
00030     int32_t size)
00031 : CStringKernel<uint16_t>(size), dictionary_size(0), dictionary_weights(NULL),
00032     use_sign(s), use_dict_diagonal_optimization(false), dict_diagonal_optimization(NULL)
00033 {
00034     properties |= KP_LINADD;
00035 
00036     init_dictionary(1<<(sizeof(uint16_t)*8));
00037     set_normalizer(new CSqrtDiagKernelNormalizer(use_dict_diagonal_optimization));
00038     init(l,r);
00039 }
00040 
00041 
00042 bool CCommWordStringKernel::init_dictionary(int32_t size)
00043 {
00044     dictionary_size= size;
00045     delete[] dictionary_weights;
00046     dictionary_weights=new float64_t[size];
00047     SG_DEBUG( "using dictionary of %d words\n", size);
00048     clear_normal();
00049 
00050     return dictionary_weights!=NULL;
00051 }
00052 
00053 CCommWordStringKernel::~CCommWordStringKernel() 
00054 {
00055     cleanup();
00056 
00057     delete[] dictionary_weights;
00058     delete[] dict_diagonal_optimization ;
00059 }
00060   
00061 bool CCommWordStringKernel::init(CFeatures* l, CFeatures* r)
00062 {
00063     CStringKernel<uint16_t>::init(l,r);
00064 
00065     if (use_dict_diagonal_optimization)
00066     {
00067         delete[] dict_diagonal_optimization ;
00068         dict_diagonal_optimization=new int32_t[int32_t(((CStringFeatures<uint16_t>*)l)->get_num_symbols())];
00069         ASSERT(((CStringFeatures<uint16_t>*)l)->get_num_symbols() == ((CStringFeatures<uint16_t>*)r)->get_num_symbols()) ;
00070     }
00071 
00072     return init_normalizer();
00073 }
00074 
00075 void CCommWordStringKernel::cleanup()
00076 {
00077     delete_optimization();
00078     CKernel::cleanup();
00079 }
00080 
00081 float64_t CCommWordStringKernel::compute_diag(int32_t idx_a)
00082 {
00083     int32_t alen;
00084     CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00085     CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00086 
00087     bool free_av;
00088     uint16_t* av=l->get_feature_vector(idx_a, alen, free_av);
00089 
00090     float64_t result=0.0 ;
00091     ASSERT(l==r);
00092     ASSERT(sizeof(uint16_t)<=sizeof(float64_t));
00093     ASSERT((1<<(sizeof(uint16_t)*8)) > alen);
00094 
00095     int32_t num_symbols=(int32_t) l->get_num_symbols();
00096     ASSERT(num_symbols<=dictionary_size);
00097 
00098     int32_t* dic = dict_diagonal_optimization;
00099     memset(dic, 0, num_symbols*sizeof(int32_t));
00100 
00101     for (int32_t i=0; i<alen; i++)
00102         dic[av[i]]++;
00103 
00104     if (use_sign)
00105     {
00106         for (int32_t i=0; i<(int32_t) l->get_num_symbols(); i++)
00107         {
00108             if (dic[i]!=0)
00109                 result++;
00110         }
00111     }
00112     else
00113     {
00114         for (int32_t i=0; i<num_symbols; i++)
00115         {
00116             if (dic[i]!=0)
00117                 result+=dic[i]*dic[i];
00118         }
00119     }
00120     l->free_feature_vector(av, idx_a, free_av);
00121 
00122     return result;
00123 }
00124 
00125 float64_t CCommWordStringKernel::compute_helper(
00126     int32_t idx_a, int32_t idx_b, bool do_sort)
00127 {
00128     int32_t alen, blen;
00129     bool free_av, free_bv;
00130 
00131     CStringFeatures<uint16_t>* l = (CStringFeatures<uint16_t>*) lhs;
00132     CStringFeatures<uint16_t>* r = (CStringFeatures<uint16_t>*) rhs;
00133 
00134     uint16_t* av=l->get_feature_vector(idx_a, alen, free_av);
00135     uint16_t* bv=r->get_feature_vector(idx_b, blen, free_bv);
00136 
00137     uint16_t* avec=av;
00138     uint16_t* bvec=bv;
00139 
00140     if (do_sort)
00141     {
00142         if (alen>0)
00143         {
00144             avec=new uint16_t[alen];
00145             memcpy(avec, av, sizeof(uint16_t)*alen);
00146             CMath::radix_sort(avec, alen);
00147         }
00148         else
00149             avec=NULL;
00150 
00151         if (blen>0)
00152         {
00153             bvec=new uint16_t[blen];
00154             memcpy(bvec, bv, sizeof(uint16_t)*blen);
00155             CMath::radix_sort(bvec, blen);
00156         }
00157         else
00158             bvec=NULL;
00159     }
00160     else
00161     {
00162         if ( (l->get_num_preproc() != l->get_num_preprocessed()) ||
00163                 (r->get_num_preproc() != r->get_num_preprocessed()))
00164         {
00165             SG_ERROR("not all preprocessors have been applied to training (%d/%d)"
00166                     " or test (%d/%d) data\n", l->get_num_preprocessed(), l->get_num_preproc(),
00167                     r->get_num_preprocessed(), r->get_num_preproc());
00168         }
00169     }
00170 
00171     float64_t result=0;
00172 
00173     int32_t left_idx=0;
00174     int32_t right_idx=0;
00175 
00176     if (use_sign)
00177     {
00178         while (left_idx < alen && right_idx < blen)
00179         {
00180             if (avec[left_idx]==bvec[right_idx])
00181             {
00182                 uint16_t sym=avec[left_idx];
00183 
00184                 while (left_idx< alen && avec[left_idx]==sym)
00185                     left_idx++;
00186 
00187                 while (right_idx< blen && bvec[right_idx]==sym)
00188                     right_idx++;
00189 
00190                 result++;
00191             }
00192             else if (avec[left_idx]<bvec[right_idx])
00193                 left_idx++;
00194             else
00195                 right_idx++;
00196         }
00197     }
00198     else
00199     {
00200         while (left_idx < alen && right_idx < blen)
00201         {
00202             if (avec[left_idx]==bvec[right_idx])
00203             {
00204                 int32_t old_left_idx=left_idx;
00205                 int32_t old_right_idx=right_idx;
00206 
00207                 uint16_t sym=avec[left_idx];
00208 
00209                 while (left_idx< alen && avec[left_idx]==sym)
00210                     left_idx++;
00211 
00212                 while (right_idx< blen && bvec[right_idx]==sym)
00213                     right_idx++;
00214 
00215                 result+=((float64_t) (left_idx-old_left_idx))*
00216                     ((float64_t) (right_idx-old_right_idx));
00217             }
00218             else if (avec[left_idx]<bvec[right_idx])
00219                 left_idx++;
00220             else
00221                 right_idx++;
00222         }
00223     }
00224 
00225     if (do_sort)
00226     {
00227         delete[] avec;
00228         delete[] bvec;
00229     }
00230 
00231     l->free_feature_vector(av, idx_a, free_av);
00232     r->free_feature_vector(bv, idx_b, free_bv);
00233 
00234     return result;
00235 }
00236 
00237 void CCommWordStringKernel::add_to_normal(int32_t vec_idx, float64_t weight)
00238 {
00239     int32_t len=-1;
00240     bool free_vec;
00241     uint16_t* vec=((CStringFeatures<uint16_t>*) lhs)->
00242         get_feature_vector(vec_idx, len, free_vec);
00243 
00244     if (len>0)
00245     {
00246         int32_t j, last_j=0;
00247         if (use_sign)
00248         {
00249             for (j=1; j<len; j++)
00250             {
00251                 if (vec[j]==vec[j-1])
00252                     continue;
00253 
00254                 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00255                     normalize_lhs(weight, vec_idx);
00256             }
00257 
00258             dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00259                 normalize_lhs(weight, vec_idx);
00260         }
00261         else
00262         {
00263             for (j=1; j<len; j++)
00264             {
00265                 if (vec[j]==vec[j-1])
00266                     continue;
00267 
00268                 dictionary_weights[(int32_t) vec[j-1]]+=normalizer->
00269                     normalize_lhs(weight*(j-last_j), vec_idx);
00270                 last_j = j;
00271             }
00272 
00273             dictionary_weights[(int32_t) vec[len-1]]+=normalizer->
00274                 normalize_lhs(weight*(len-last_j), vec_idx);
00275         }
00276         set_is_initialized(true);
00277     }
00278 
00279     ((CStringFeatures<uint16_t>*) lhs)->free_feature_vector(vec, vec_idx, free_vec);
00280 }
00281 
00282 void CCommWordStringKernel::clear_normal()
00283 {
00284     memset(dictionary_weights, 0, dictionary_size*sizeof(float64_t));
00285     set_is_initialized(false);
00286 }
00287 
00288 bool CCommWordStringKernel::init_optimization(
00289     int32_t count, int32_t* IDX, float64_t* weights)
00290 {
00291     delete_optimization();
00292 
00293     if (count<=0)
00294     {
00295         set_is_initialized(true);
00296         SG_DEBUG("empty set of SVs\n");
00297         return true;
00298     }
00299 
00300     SG_DEBUG("initializing CCommWordStringKernel optimization\n");
00301 
00302     for (int32_t i=0; i<count; i++)
00303     {
00304         if ( (i % (count/10+1)) == 0)
00305             SG_PROGRESS(i, 0, count);
00306 
00307         add_to_normal(IDX[i], weights[i]);
00308     }
00309 
00310     set_is_initialized(true);
00311     return true;
00312 }
00313 
00314 bool CCommWordStringKernel::delete_optimization() 
00315 {
00316     SG_DEBUG( "deleting CCommWordStringKernel optimization\n");
00317 
00318     clear_normal();
00319     return true;
00320 }
00321 
00322 float64_t CCommWordStringKernel::compute_optimized(int32_t i)
00323 { 
00324     if (!get_is_initialized())
00325     {
00326       SG_ERROR( "CCommWordStringKernel optimization not initialized\n");
00327         return 0 ; 
00328     }
00329 
00330     float64_t result = 0;
00331     int32_t len = -1;
00332     bool free_vec;
00333     uint16_t* vec=((CStringFeatures<uint16_t>*) rhs)->
00334         get_feature_vector(i, len, free_vec);
00335 
00336     int32_t j, last_j=0;
00337     if (vec && len>0)
00338     {
00339         if (use_sign)
00340         {
00341             for (j=1; j<len; j++)
00342             {
00343                 if (vec[j]==vec[j-1])
00344                     continue;
00345 
00346                 result += dictionary_weights[(int32_t) vec[j-1]];
00347             }
00348 
00349             result += dictionary_weights[(int32_t) vec[len-1]];
00350         }
00351         else
00352         {
00353             for (j=1; j<len; j++)
00354             {
00355                 if (vec[j]==vec[j-1])
00356                     continue;
00357 
00358                 result += dictionary_weights[(int32_t) vec[j-1]]*(j-last_j);
00359                 last_j = j;
00360             }
00361 
00362             result += dictionary_weights[(int32_t) vec[len-1]]*(len-last_j);
00363         }
00364 
00365         result=normalizer->normalize_rhs(result, i);
00366     }
00367     ((CStringFeatures<uint16_t>*) rhs)->free_feature_vector(vec, i, free_vec);
00368     return result;
00369 }
00370 
00371 float64_t* CCommWordStringKernel::compute_scoring(
00372     int32_t max_degree, int32_t& num_feat, int32_t& num_sym, float64_t* target,
00373     int32_t num_suppvec, int32_t* IDX, float64_t* alphas, bool do_init)
00374 {
00375     ASSERT(lhs);
00376     CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00377     num_feat=1;//str->get_max_vector_length();
00378     CAlphabet* alpha=str->get_alphabet();
00379     ASSERT(alpha);
00380     int32_t num_bits=alpha->get_num_bits();
00381     int32_t order=str->get_order();
00382     ASSERT(max_degree<=order);
00383     //int32_t num_words=(int32_t) str->get_num_symbols();
00384     int32_t num_words=(int32_t) str->get_original_num_symbols();
00385     int32_t offset=0;
00386 
00387     num_sym=0;
00388     
00389     for (int32_t i=0; i<order; i++)
00390         num_sym+=CMath::pow((int32_t) num_words,i+1);
00391 
00392     SG_DEBUG("num_words:%d, order:%d, len:%d sz:%d (len*sz:%d)\n", num_words, order,
00393             num_feat, num_sym, num_feat*num_sym);
00394 
00395     if (!target)
00396         target=new float64_t[num_feat*num_sym];
00397     memset(target, 0, num_feat*num_sym*sizeof(float64_t));
00398 
00399     if (do_init)
00400         init_optimization(num_suppvec, IDX, alphas);
00401 
00402     uint32_t kmer_mask=0;
00403     uint32_t words=CMath::pow((int32_t) num_words,(int32_t) order);
00404 
00405     for (int32_t o=0; o<max_degree; o++)
00406     {
00407         float64_t* contrib=&target[offset];
00408         offset+=CMath::pow((int32_t) num_words,(int32_t) o+1);
00409 
00410         kmer_mask=(kmer_mask<<(num_bits)) | str->get_masked_symbols(0xffff, 1);
00411 
00412         for (int32_t p=-o; p<order; p++)
00413         {
00414             int32_t o_sym=0, m_sym=0, il=0,ir=0, jl=0;
00415             uint32_t imer_mask=kmer_mask;
00416             uint32_t jmer_mask=kmer_mask;
00417 
00418             if (p<0)
00419             {
00420                 il=-p;
00421                 m_sym=order-o-p-1;
00422                 o_sym=-p;
00423             }
00424             else if (p<order-o)
00425             {
00426                 ir=p;
00427                 m_sym=order-o-1;
00428             }
00429             else
00430             {
00431                 ir=p;
00432                 m_sym=p;
00433                 o_sym=p-order+o+1;
00434                 jl=order-ir;
00435                 imer_mask=(kmer_mask>>(num_bits*o_sym));
00436                 jmer_mask=(kmer_mask>>(num_bits*jl));
00437             }
00438 
00439             float64_t marginalizer=
00440                 1.0/CMath::pow((int32_t) num_words,(int32_t) m_sym);
00441             
00442             for (uint32_t i=0; i<words; i++)
00443             {
00444                 uint16_t x= ((i << (num_bits*il)) >> (num_bits*ir)) & imer_mask;
00445 
00446                 if (p>=0 && p<order-o)
00447                 {
00448 //#define DEBUG_COMMSCORING
00449 #ifdef DEBUG_COMMSCORING
00450                     SG_PRINT("o=%d/%d p=%d/%d i=0x%x x=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d marg=%g o_sym:%d m_sym:%d weight(",
00451                             o,order, p,order, i, x, imer_mask, jmer_mask, kmer_mask, il, ir, marginalizer, o_sym, m_sym);
00452 
00453                     SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n", 
00454                             alpha->remap_to_char((x>>(3*num_bits))&0x03), alpha->remap_to_char((x>>(2*num_bits))&0x03),
00455                             alpha->remap_to_char((x>>num_bits)&0x03), alpha->remap_to_char(x&0x03),
00456                             alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00457                             alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00458                             dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00459 #endif
00460                     contrib[x]+=dictionary_weights[i]*marginalizer;
00461                 }
00462                 else
00463                 {
00464                     for (uint32_t j=0; j< (uint32_t) CMath::pow((int32_t) num_words, (int32_t) o_sym); j++)
00465                     {
00466                         uint32_t c=x | ((j & jmer_mask) << (num_bits*jl));
00467 #ifdef DEBUG_COMMSCORING
00468 
00469                         SG_PRINT("o=%d/%d p=%d/%d i=0x%x j=0x%x x=0x%x c=0x%x imask=%x jmask=%x kmask=%x il=%d ir=%d jl=%d marg=%g o_sym:%d m_sym:%d weight(",
00470                                 o,order, p,order, i, j, x, c, imer_mask, jmer_mask, kmer_mask, il, ir, jl, marginalizer, o_sym, m_sym);
00471                         SG_PRINT("%c%c%c%c/%c%c%c%c)+=%g/%g\n", 
00472                                 alpha->remap_to_char((c>>(3*num_bits))&0x03), alpha->remap_to_char((c>>(2*num_bits))&0x03),
00473                                 alpha->remap_to_char((c>>num_bits)&0x03), alpha->remap_to_char(c&0x03),
00474                                 alpha->remap_to_char((i>>(3*num_bits))&0x03), alpha->remap_to_char((i>>(2*num_bits))&0x03),
00475                                 alpha->remap_to_char((i>>(1*num_bits))&0x03), alpha->remap_to_char(i&0x03),
00476                                 dictionary_weights[i]*marginalizer, dictionary_weights[i]);
00477 #endif
00478                         contrib[c]+=dictionary_weights[i]*marginalizer;
00479                     }
00480                 }
00481             }
00482         }
00483     }
00484 
00485     for (int32_t i=1; i<num_feat; i++)
00486         memcpy(&target[num_sym*i], target, num_sym*sizeof(float64_t));
00487 
00488     SG_UNREF(alpha);
00489 
00490     return target;
00491 }
00492 
00493 
00494 char* CCommWordStringKernel::compute_consensus(
00495     int32_t &result_len, int32_t num_suppvec, int32_t* IDX, float64_t* alphas)
00496 {
00497     ASSERT(lhs);
00498     ASSERT(IDX);
00499     ASSERT(alphas);
00500 
00501     CStringFeatures<uint16_t>* str=((CStringFeatures<uint16_t>*) lhs);
00502     int32_t num_words=(int32_t) str->get_num_symbols();
00503     int32_t num_feat=str->get_max_vector_length();
00504     int64_t total_len=((int64_t) num_feat) * num_words;
00505     CAlphabet* alpha=((CStringFeatures<uint16_t>*) lhs)->get_alphabet();
00506     ASSERT(alpha);
00507     int32_t num_bits=alpha->get_num_bits();
00508     int32_t order=str->get_order();
00509     int32_t max_idx=-1;
00510     float64_t max_score=0; 
00511     result_len=num_feat+order-1;
00512 
00513     //init
00514     init_optimization(num_suppvec, IDX, alphas);
00515 
00516     char* result=new char[result_len];
00517     int32_t* bt=new int32_t[total_len];
00518     float64_t* score=new float64_t[total_len];
00519 
00520     for (int64_t i=0; i<total_len; i++)
00521     {
00522         bt[i]=-1;
00523         score[i]=0;
00524     }
00525 
00526     for (int32_t t=0; t<num_words; t++)
00527         score[t]=dictionary_weights[t];
00528 
00529     //dynamic program
00530     for (int32_t i=1; i<num_feat; i++)
00531     {
00532         for (int32_t t1=0; t1<num_words; t1++)
00533         {
00534             max_idx=-1;
00535             max_score=0; 
00536 
00537             /* ignore weights the svm does not care about 
00538              * (has not seen in training). note that this assumes that zero 
00539              * weights are very unlikely to appear elsewise */
00540 
00541             //if (dictionary_weights[t1]==0.0)
00542                 //continue;
00543 
00544             /* iterate over words t ending on t1 and find the highest scoring
00545              * pair */
00546             uint16_t suffix=(uint16_t) t1 >> num_bits;
00547 
00548             for (int32_t sym=0; sym<str->get_original_num_symbols(); sym++)
00549             {
00550                 uint16_t t=suffix | sym << (num_bits*(order-1));
00551 
00552                 //if (dictionary_weights[t]==0.0)
00553                 //  continue;
00554 
00555                 float64_t sc=score[num_words*(i-1) + t]+dictionary_weights[t1];
00556                 if (sc > max_score || max_idx==-1)
00557                 {
00558                     max_idx=t;
00559                     max_score=sc;
00560                 }
00561             }
00562             ASSERT(max_idx!=-1);
00563 
00564             score[num_words*i + t1]=max_score;
00565             bt[num_words*i + t1]=max_idx;
00566         }
00567     }
00568 
00569     //backtracking
00570     max_idx=0;
00571     max_score=score[num_words*(num_feat-1) + 0];
00572     for (int32_t t=1; t<num_words; t++)
00573     {
00574         float64_t sc=score[num_words*(num_feat-1) + t];
00575         if (sc>max_score)
00576         {
00577             max_idx=t;
00578             max_score=sc;
00579         }
00580     }
00581 
00582     SG_PRINT("max_idx:%i, max_score:%f\n", max_idx, max_score);
00583     
00584     for (int32_t i=result_len-1; i>=num_feat; i--)
00585         result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(result_len-1-i)), 1) );
00586 
00587     for (int32_t i=num_feat-1; i>=0; i--)
00588     {
00589         result[i]=alpha->remap_to_char( (uint8_t) str->get_masked_symbols( (uint16_t) max_idx >> (num_bits*(order-1)), 1) );
00590         max_idx=bt[num_words*i + max_idx];
00591     }
00592 
00593     delete[] bt;
00594     delete[] score;
00595     SG_UNREF(alpha);
00596     return result;
00597 }

SHOGUN Machine Learning Toolbox - Documentation