Package mvpa :: Package tests :: Module test_transerror
[hide private]
[frames | no frames]

Source Code for Module mvpa.tests.test_transerror

  1  # emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- 
  2  # vi: set ft=python sts=4 ts=4 sw=4 et: 
  3  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  4  # 
  5  #   See COPYING file distributed along with the PyMVPA package for the 
  6  #   copyright and license terms. 
  7  # 
  8  ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## 
  9  """Unit tests for PyMVPA classifier cross-validation""" 
 10   
 11  import unittest 
 12  from mvpa.support.copy import copy 
 13   
 14  from mvpa.base import externals 
 15  from mvpa.datasets import Dataset 
 16  from mvpa.datasets.splitters import OddEvenSplitter 
 17   
 18  from mvpa.clfs.meta import MulticlassClassifier 
 19  from mvpa.clfs.transerror import \ 
 20       TransferError, ConfusionMatrix, ConfusionBasedError 
 21  from mvpa.algorithms.cvtranserror import CrossValidatedTransferError 
 22   
 23  from mvpa.clfs.stats import MCNullDist 
 24   
 25  from mvpa.misc.exceptions import UnknownStateError 
 26   
 27  from tests_warehouse import datasets, sweepargs 
 28  from tests_warehouse_clfs import * 
29 30 -class ErrorsTests(unittest.TestCase):
31
def testConfusionMatrix(self):
    """Check ConfusionMatrix construction, accumulation and arithmetic.

    Covers building a matrix from list, tuple and array inputs, the
    error raised while the matrix holds no samples, adding sets
    (including one carrying a label not seen before), the per-set
    matrices, string rendering, and the ``+=`` and ``+`` operators.
    """
    reg = [1, 1, 1, 2, 2, 2, 3, 3, 3]
    regl = [1, 2, 1, 2, 2, 2, 3, 2, 1]
    correct_cm = [[2, 0, 1], [1, 3, 1], [0, 0, 1]]

    # Check if we are ok with any input type - either list, or N.array,
    # or tuple
    target_forms = [reg, tuple(reg), list(reg), N.array(reg)]
    prediction_forms = [regl, tuple(regl), list(regl), N.array(regl)]
    for t in target_forms:
        for p in prediction_forms:
            cm = ConfusionMatrix(targets=t, predictions=p)
            # check table content
            self.assertTrue((cm.matrix == correct_cm).all())

    # Do a bit more thorough checking
    cm = ConfusionMatrix()
    # No samples -- raise exception
    self.assertRaises(ZeroDivisionError, lambda x: x.percentCorrect, cm)

    cm.add(reg, regl)

    self.assertEqual(len(cm.sets), 1,
                     msg="Should have a single set so far")
    self.assertEqual(
        cm.matrix.shape, (3, 3),
        msg="should be square matrix (len(reglabels) x len(reglabels)")

    # ConfusionMatrix must complain if the number of samples differs
    self.assertRaises(ValueError, cm.add, reg, N.array([1]))

    # check table content -- the rejected add() must not have altered it
    self.assertTrue((cm.matrix == correct_cm).all())

    # lets add with new labels (not yet known)
    cm.add(reg, N.array([1, 4, 1, 2, 2, 2, 4, 2, 1]))

    self.assertEqual(cm.labels, [1, 2, 3, 4],
                     msg="We should have gotten 4th label")

    matrices = cm.matrices          # separate CM per each given set
    self.assertEqual(len(matrices), 2,
                     msg="Have gotten two splits")

    self.assertTrue(
        (matrices[0].matrix + matrices[1].matrix == cm.matrix).all(),
        msg="Total votes should match the sum across split CMs")

    # check pretty print
    # just a silly test to make sure that printing works
    self.assertTrue(len(cm.asstring(header=True, summary=True,
                                    description=True)) > 100)
    self.assertTrue(len(str(cm)) > 100)
    # and that it knows some parameters for printing
    self.assertTrue(len(cm.asstring(summary=True, header=False)) > 100)

    # lets check iadd -- just itself to itself
    cm += cm
    self.assertEqual(len(cm.matrices), 4, msg="Must be 4 sets now")

    # lets check add -- just itself to itself
    cm2 = cm + cm
    self.assertEqual(len(cm2.matrices), 8, msg="Must be 8 sets now")
    self.assertEqual(cm2.percentCorrect, cm.percentCorrect,
                     msg="Percent of corrrect should remain the same ;-)")

    self.assertEqual(cm2.error, 1.0-cm.percentCorrect/100.0,
                     msg="Test if we get proper error value")
98 99
def testDegenerateConfusion(self):
    """A confusion matrix with a single target label must still work.

    Some testing splits might contain just one label; building the
    matrix, rendering it and reading its accuracy must not fail.
    """
    single_label_cases = ([1], [1, 1], [0], [0, 0])
    for labels in single_label_cases:
        confusion = ConfusionMatrix(targets=labels, predictions=labels,
                                    values=labels)

        # rendering is exercised only to make sure it does not raise
        str(confusion)
        self.assertTrue(confusion.stats['ACC%'] == 100)
109 110
def testConfusionMatrixACC(self):
    """The rendered matrix must report an accuracy of 50 percent.

    Two of the four predictions below match their targets.
    """
    confusion = ConfusionMatrix(targets=[0, 0, 1, 1],
                                predictions=[1, 0, 1, 0])
    rendered = str(confusion)
    # NOTE(review): the expected substring depends on the exact spacing
    # used between the 'ACC%' label and its value in the rendered table
    # -- confirm it matches what ConfusionMatrix prints
    self.assertTrue('ACC% 50' in rendered)
116 117
119 data = N.array([1,2,1,2,2,2,3,2,1], ndmin=2).T 120 reg = [1,1,1,2,2,2,3,3,3] 121 regl = [1,2,1,2,2,2,3,2,1] 122 correct_cm = [[2,0,1], [1,3,1], [0,0,1]] 123 lm = {'apple':1, 'orange':2, 'shitty apple':1, 'candy':3} 124 cm = ConfusionMatrix(targets=reg, predictions=regl, 125 labels_map=lm) 126 # check table content 127 self.failUnless((cm.matrix == correct_cm).all()) 128 # assure that all labels are somewhere listed ;-) 129 s = str(cm) 130 for l in lm.keys(): 131 self.failUnless(l in s)
@sweepargs(l_clf=clfswh['linear', 'svm'])
def testConfusionBasedError(self, l_clf):
    """ConfusionBasedError must equal TransferError on the training set.

    Also checks that the confusion based error cannot be computed before
    the classifier was trained, that a transfer error is obtained for a
    dataset with three labels, and that a TransferError can be copied.
    """
    training_ds = datasets['uni2medium_train']
    # to check if we fail to classify for 3 labels
    three_label_ds = datasets['uni3medium_train']
    confusion_error = ConfusionBasedError(clf=l_clf)
    transfer_error = TransferError(clf=l_clf)

    # Shouldn't be able to access the state yet
    self.assertRaises(UnknownStateError, confusion_error, None)

    l_clf.train(training_ds)
    from_confusion = confusion_error(None)
    from_transfer = transfer_error(training_ds)
    self.assertEqual(
        from_confusion, from_transfer,
        msg="ConfusionBasedError should be equal to TransferError on"
            " traindataset")

    # this will print nasty WARNING but it is ok -- it is just checking code
    # NB warnings are not printed while doing whole testing
    three_label_error = transfer_error(three_label_ds)
    self.assertFalse(three_label_error is None)

    # try copying the beast -- only the fact that it works matters
    copy(transfer_error)
@sweepargs(l_clf=clfswh['linear', 'svm'])
def testNullDistProb(self, l_clf):
    """Transfer error on data with signal must be low and significant."""
    dataset = datasets['uni2medium']

    # define class to estimate NULL distribution of errors
    # use left tail of the distribution since we use MeanMatchFx as error
    # function and lower is better
    null_distribution = MCNullDist(permutations=10, tail='left')
    transfer_error = TransferError(clf=l_clf, null_dist=null_distribution)

    # check reasonable error range
    error = transfer_error(dataset, dataset)
    self.assertTrue(error < 0.4)

    # check that the result is highly significant since we know that the
    # data has signal
    probability = transfer_error.null_prob
    is_significant = probability < 0.01
    message = ("Failed to check that the result is highly significant "
               "(got %f) since we know that the data has signal"
               % probability)
    self.assertTrue(is_significant, msg=message)
@sweepargs(l_clf=clfswh['linear', 'svm'])
def testPerSampleError(self, l_clf):
    """TransferError must provide one 0/1 error value per sample.

    With the 'samples_error' state enabled, the state read after the
    call must hold as many entries as the dataset has samples, and every
    entry must be either 0 (correct) or 1 (misclassified).
    """
    train = datasets['uni2medium']
    terr = TransferError(clf=l_clf, enable_states=['samples_error'])
    # the returned summary error is not needed, only the state read below
    terr(train, train)
    per_sample = terr.samples_error

    # one error per sample
    self.assertTrue(len(per_sample) == train.nsamples)

    # for this simple test it can only be correct or misclassified
    # (boolean).  Every value is compared against 0 and 1 explicitly:
    # casting with dtype 'b' (a signed byte in numpy, not a boolean) and
    # summing the differences would accept any integral value and lets
    # differences of opposite sign cancel out.
    errors = N.array(list(per_sample.values()), dtype='float')
    self.assertTrue(((errors == 0) | (errors == 1)).all())
@sweepargs(clf=clfswh['multiclass'])
def testAUC(self, clf):
    """Test AUC computation

    Runs CrossValidatedTransferError with an OddEvenSplitter on the
    'uni2small' and 'uni3small' datasets and on copies of them with
    reassigned labels.  Accuracy must exceed 1.2 / number of labels and,
    when the 'labile' tests are enabled in the configuration, the
    smallest AUC must exceed 0.55.
    """
    if isinstance(clf, MulticlassClassifier):
        # TODO: handle those values correctly
        return

    clf.states._changeTemporarily(enable_states = ['values'])
    # The classifier instance comes from the decorator's sweep, so its
    # temporarily changed states have to be restored even when one of
    # the assertions below fails.
    try:
        # uni2 dataset with reordered labels
        ds2 = datasets['uni2small'].copy()
        ds2.labels = 1 - ds2.labels     # revert labels
        # same with uni3: each label is replaced by the next unique one
        ds3 = datasets['uni3small'].copy()
        unique = ds3.uniquelabels
        rotated = ds3.labels.copy()
        for i in range(3):
            rotated[ds3.labels == unique[i]] = unique[(i + 1) % 3]
        ds3.labels = rotated

        for ds in [datasets['uni2small'], ds2,
                   datasets['uni3small'], ds3]:
            cv = CrossValidatedTransferError(
                TransferError(clf),
                OddEvenSplitter(),
                enable_states=['confusion', 'training_confusion'])
            # only the confusion state is inspected, not the error itself
            cv(ds)
            stats = cv.confusion.stats
            n_labels = len(ds.uniquelabels)
            # so we at least do slightly above chance
            self.assertTrue(stats['ACC'] > 1.2 / n_labels)
            auc = stats['AUC']
            # N.isnan() instead of an identity test against N.nan: a
            # NaN produced by a computation is a different object than
            # N.nan and would slip through ``is not N.nan``
            if (n_labels == 2) or \
               (n_labels > 2 and not N.isnan(auc[0])):
                mauc = N.min(stats['AUC'])
                if cfg.getboolean('tests', 'labile', default='yes'):
                    self.assertTrue(
                        mauc > 0.55,
                        msg='All AUCs must be above chance. Got minimal '
                            'AUC=%.2g among %s' % (mauc, stats['AUC']))
    finally:
        clf.states._resetEnabledTemporarily()
236 237 238 239
def testConfusionPlot(self):
    """Based on existing cell dataset results.

    Builds a ConfusionMatrix from recorded (targets, predictions) sets,
    checks that its string form lists the mapped label '3kHz' and, if
    plotting is available, that plotting with the first two labels
    swapped reorders the plotted matrix accordingly.

    NOTE(review): the original note stated that this was left in for
    possible future testing and is not a part of the unittests suite,
    yet the ``test`` prefix makes default loaders collect it -- confirm
    the intent.
    """
    #from matplotlib import rc as rcmpl
    #rcmpl('font',**{'family':'sans-serif','sans-serif':['DejaVu Sans']})
    ##rcmpl('text', usetex=True)
    ##rcmpl('font', family='sans', style='normal', variant='normal',
    ##      weight='bold', stretch='normal', size='large')
    #import numpy as N
    #from mvpa.clfs.transerror import \
    #     TransferError, ConfusionMatrix, ConfusionBasedError

    array = N.array
    uint8 = N.uint8
    # One (targets, predictions) pair per set.
    # NOTE(review): the first targets array appears to hold fewer
    # entries than its predictions array (207 vs 217) -- confirm against
    # the recorded data.
    sets = [
        (array([47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44], dtype=uint8),
         array([40, 39, 47, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 46,
                45, 38, 44, 39, 46, 38, 39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38,
                40, 47, 43, 45, 41, 44, 40, 46, 42, 38, 39, 40, 43, 45, 41, 44, 39,
                46, 42, 47, 38, 38, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45,
                41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 38,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 47, 43, 45, 41, 44, 40, 46,
                42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43, 45, 41,
                44, 47, 46, 42, 47, 38, 39, 43, 45, 40, 44, 40, 46, 42, 47, 39, 40,
                43, 45, 41, 44, 38, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 41,
                47, 39, 38, 46, 45, 41, 44, 40, 46, 42, 40, 38, 38, 43, 45, 41, 44,
                40, 45, 42, 47, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 42, 43,
                45, 41, 44, 39, 46, 42, 39, 39, 39, 47, 45, 41, 44], dtype=uint8)),
        (array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8),
         array([40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 47, 46, 42, 47, 39, 40, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 39, 46, 42, 47, 47, 47, 43, 45, 41, 44, 40,
                46, 42, 43, 39, 38, 43, 45, 41, 44, 38, 38, 42, 38, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 40, 42, 47, 40,
                40, 43, 45, 41, 44, 38, 38, 42, 47, 38, 38, 47, 45, 41, 44, 40, 46,
                42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 47, 39, 43, 45, 41,
                44, 40, 46, 42, 39, 39, 42, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
                43, 45, 41, 44, 47, 46, 42, 40, 39, 39, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 40, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44,
                38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 46, 47, 38, 39, 43,
                45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 39,
                39, 38, 47, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43], dtype=uint8)),
        (array([45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47], dtype=uint8),
         array([45, 41, 44, 40, 46, 42, 47, 39, 46, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40,
                46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 43, 43, 45,
                40, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
                40, 43, 45, 41, 44, 40, 47, 42, 38, 47, 38, 43, 45, 41, 44, 40, 40,
                42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 38, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46, 42, 47, 40, 38,
                43, 45, 41, 44, 40, 46, 38, 38, 39, 38, 43, 45, 41, 44, 39, 46, 42,
                47, 40, 39, 43, 45, 38, 44, 38, 46, 42, 47, 47, 40, 43, 45, 41, 44,
                40, 40, 42, 47, 40, 38, 43, 39, 41, 44, 41, 46, 42, 39, 39, 38, 38,
                45, 41, 44, 38, 46, 40, 46, 46, 46, 43, 45, 38, 44, 40, 46, 42, 39,
                39, 45, 43, 45, 41, 44, 38, 46, 42, 38, 39, 39, 43, 45, 41, 38, 40,
                46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42, 40], dtype=uint8)),
        (array([39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40], dtype=uint8),
         array([39, 38, 43, 45, 41, 44, 40, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 41, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 38, 43, 47, 38, 38, 43, 45, 41, 44, 39, 46, 42, 39, 39,
                38, 43, 45, 41, 44, 43, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 39, 38, 38, 43, 45, 40,
                44, 47, 46, 38, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 39, 38,
                43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 47, 44, 45, 46, 42,
                38, 39, 41, 43, 45, 41, 44, 38, 38, 42, 39, 40, 40, 43, 45, 41, 39,
                40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 47, 42, 47, 38, 38, 43,
                45, 41, 44, 47, 46, 42, 47, 40, 47, 43, 45, 41, 44, 40, 46, 42, 47,
                38, 39, 43, 45, 41, 44, 40, 46, 42, 39, 38, 43, 45, 46, 44, 38, 46,
                42, 47, 38, 44, 43, 45, 42, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
                44, 38, 46, 42, 39, 39, 38, 43, 45, 41, 44, 40], dtype=uint8)),
        (array([46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
         array([46, 42, 39, 38, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 42, 43, 45,
                42, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47,
                40, 43, 45, 41, 44, 41, 46, 42, 38, 39, 38, 43, 45, 41, 44, 38, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 46, 38, 38, 43, 45, 41,
                44, 39, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39,
                43, 45, 41, 44, 40, 47, 42, 47, 38, 39, 43, 45, 41, 44, 39, 46, 42,
                47, 39, 46, 43, 45, 41, 44, 39, 46, 42, 39, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 40, 46, 42, 39, 39, 38, 43,
                45, 41, 44, 40, 38, 42, 46, 39, 38, 43, 45, 41, 44, 38, 46, 42, 46,
                46, 38, 43, 45, 41, 44, 40, 46, 42, 47, 47, 38, 38, 45, 41, 44, 38,
                38, 42, 43, 39, 40, 43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 47, 45,
                46, 44, 40, 46, 42, 47, 40, 38, 43, 45, 41, 44, 40, 46, 42, 47, 40,
                38, 43, 45, 41, 44, 38, 46, 42, 38, 39, 38, 47, 45], dtype=uint8)),
        (array([41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39], dtype=uint8),
         array([41, 44, 38, 46, 42, 47, 39, 47, 40, 45, 41, 44, 40, 46, 42, 38, 40,
                38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45, 41, 44, 46, 38,
                42, 40, 38, 39, 43, 45, 41, 44, 41, 46, 42, 47, 47, 38, 43, 45, 41,
                44, 40, 46, 42, 38, 39, 39, 43, 45, 41, 44, 38, 46, 42, 47, 43, 39,
                43, 45, 41, 44, 40, 46, 42, 38, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                40, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39, 39, 39, 43, 45, 41, 44,
                40, 46, 42, 39, 38, 47, 43, 45, 38, 44, 40, 38, 42, 47, 38, 38, 43,
                45, 41, 44, 40, 38, 46, 47, 38, 38, 43, 45, 41, 44, 41, 46, 42, 40,
                38, 38, 40, 45, 41, 44, 40, 40, 42, 43, 38, 40, 43, 39, 41, 44, 40,
                40, 42, 47, 38, 46, 43, 45, 41, 44, 47, 41, 42, 43, 40, 47, 43, 45,
                41, 44, 41, 38, 42, 40, 39, 40, 43, 45, 41, 44, 39, 43, 42, 47, 39,
                40, 43, 45, 41, 44, 42, 46, 42, 47, 40, 46, 43, 45, 41, 44, 38, 46,
                42, 47, 47, 38, 43, 45, 41, 44, 40, 38, 39, 47, 38], dtype=uint8)),
        (array([38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46], dtype=uint8),
         array([39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41, 44, 41, 46,
                42, 47, 47, 39, 43, 45, 41, 44, 40, 46, 42, 47, 38, 39, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 40, 43, 45, 41, 44, 40, 46, 42, 47, 45, 38,
                43, 45, 41, 44, 38, 46, 42, 47, 38, 39, 43, 45, 41, 44, 40, 46, 42,
                39, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 40, 39, 43, 45, 41, 44, 40, 39, 42, 40, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 39,
                39, 47, 43, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45, 41, 44, 40,
                46, 42, 46, 47, 39, 47, 45, 41, 44, 40, 46, 42, 47, 39, 39, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 38, 46, 42, 47, 39,
                38, 43, 45, 42, 44, 39, 47, 42, 39, 39, 47, 43, 47, 40, 44, 40, 46,
                42, 39, 39, 38, 39, 45, 41, 44, 40, 46, 42, 47, 38, 38, 43, 45, 41,
                44, 46, 38, 42, 47, 39, 43, 43, 45, 41, 44, 40, 46], dtype=uint8)),
        (array([42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42,
                47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44,
                40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43,
                45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47,
                39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40,
                46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45,
                41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39,
                38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46,
                42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 40, 46, 42, 47, 39, 38, 43, 45, 41, 44, 40, 46, 42, 47, 39, 38,
                43, 45, 41, 44, 40, 46, 42, 47, 39, 38, 43, 45], dtype=uint8),
         array([42, 38, 38, 40, 43, 45, 41, 44, 39, 46, 42, 47, 39, 38, 43, 45, 41,
                44, 39, 38, 42, 47, 41, 40, 43, 45, 41, 44, 40, 41, 42, 47, 38, 46,
                43, 45, 41, 44, 41, 41, 42, 40, 39, 39, 43, 45, 41, 44, 46, 45, 42,
                39, 39, 40, 43, 45, 41, 44, 40, 46, 42, 40, 44, 38, 43, 41, 41, 44,
                39, 46, 42, 39, 39, 39, 43, 45, 41, 44, 40, 43, 42, 47, 39, 39, 43,
                45, 41, 44, 40, 47, 42, 38, 46, 39, 47, 45, 41, 44, 39, 46, 42, 47,
                41, 38, 43, 45, 41, 44, 42, 46, 42, 46, 39, 38, 43, 45, 41, 44, 41,
                46, 42, 46, 39, 38, 43, 45, 41, 44, 40, 46, 42, 38, 38, 38, 43, 45,
                41, 44, 38, 46, 42, 39, 40, 43, 43, 45, 41, 44, 39, 38, 40, 40, 38,
                38, 43, 45, 41, 44, 41, 40, 42, 39, 39, 39, 43, 45, 41, 44, 40, 46,
                42, 47, 40, 40, 43, 45, 41, 44, 40, 46, 42, 41, 39, 39, 43, 45, 41,
                44, 40, 38, 42, 40, 39, 46, 43, 45, 41, 44, 47, 46, 42, 47, 39, 38,
                43, 45, 41, 44, 41, 46, 42, 43, 39, 39, 43, 45], dtype=uint8))]
    labels_map = {'12kHz': 40,
                  '20kHz': 41,
                  '30kHz': 42,
                  '3kHz': 38,
                  '7kHz': 39,
                  'song1': 43,
                  'song2': 44,
                  'song3': 45,
                  'song4': 46,
                  'song5': 47}
    try:
        cm = ConfusionMatrix(sets=sets, labels_map=labels_map)
    except Exception:
        # Report what went wrong instead of failing silently.  Only
        # Exception is caught so that interrupts are not turned into
        # test failures.
        import sys
        self.fail("ConfusionMatrix could not be built: %s"
                  % (sys.exc_info()[1],))
    self.assertTrue('3kHz / 38' in cm.asstring())

    if externals.exists("pylab plottable"):
        import pylab as P
        P.figure()
        labels_order = ("3kHz", "7kHz", "12kHz", "20kHz","30kHz", None,
                        "song1","song2","song3","song4","song5")
        #print cm
        #fig, im, cb = cm.plot(origin='lower', labels=labels_order)
        # plot with the first two labels swapped ...
        swapped_order = (labels_order[1:2] + labels_order[:1]
                         + labels_order[2:])
        fig, im, cb = cm.plot(labels=swapped_order, numbers=True)
        # ... so the top-left 2x2 corner of the plotted matrix must be
        # the correspondingly permuted corner of the original matrix
        self.assertTrue(cm._plotted_confusionmatrix[0,0] == cm.matrix[1,1])
        self.assertTrue(cm._plotted_confusionmatrix[0,1] == cm.matrix[1,0])
        self.assertTrue(cm._plotted_confusionmatrix[1,1] == cm.matrix[0,0])
        self.assertTrue(cm._plotted_confusionmatrix[1,0] == cm.matrix[0,1])
        P.close(fig)
        fig, im, cb = cm.plot(labels=labels_order, numbers=True)
        P.close(fig)
        # P.show()
def testConfusionPlot2(self):
    """Based on a sample confusion which plots incorrectly

    Builds a ConfusionMatrix from (targets, predictions, values) sets
    and, if plotting is available, checks that the matrix plotted with
    ``origin='lower'`` equals the computed one.
    """
    array = N.array
    sets = [(array([1, 2]), array([1, 1]),
             array([[ 0.54343765,  0.45656235],
                    [ 0.92395853,  0.07604147]])),
            (array([1, 2]), array([1, 1]),
             array([[ 0.98030832,  0.01969168],
                    [ 0.78998763,  0.21001237]])),
            (array([1, 2]), array([1, 1]),
             array([[ 0.86125263,  0.13874737],
                    [ 0.83674113,  0.16325887]])),
            (array([1, 2]), array([1, 1]),
             array([[ 0.57870383,  0.42129617],
                    [ 0.59702509,  0.40297491]])),
            (array([1, 2]), array([1, 1]),
             array([[ 0.89530255,  0.10469745],
                    [ 0.69373919,  0.30626081]])),
            (array([1, 2]), array([1, 1]),
             array([[ 0.75015218,  0.24984782],
                    [ 0.9339767 ,  0.0660233 ]])),
            (array([1, 2]), array([1, 2]),
             array([[ 0.97826616,  0.02173384],
                    [ 0.38620638,  0.61379362]])),
            (array([2]), array([2]),
             array([[ 0.46893776,  0.53106224]]))]
    try:
        cm = ConfusionMatrix(sets=sets)
    except Exception:
        # Report what went wrong instead of failing silently.  Only
        # Exception is caught so that interrupts are not turned into
        # test failures.
        import sys
        self.fail("ConfusionMatrix could not be built: %s"
                  % (sys.exc_info()[1],))
    if externals.exists("pylab plottable"):
        import pylab as P
        fig, im, cb = cm.plot(origin='lower', numbers=True)
        self.assertTrue((cm._plotted_confusionmatrix == cm.matrix).all())
        P.close(fig)
def suite():
    """Return a test suite with every ``test*`` method of ErrorsTests."""
    # unittest.makeSuite() was only a shortcut for this very call; it is
    # deprecated and no longer exists in recent Python versions, while
    # the loader API is available everywhere.
    return unittest.TestLoader().loadTestsFromTestCase(ErrorsTests)
if __name__ == '__main__':
    # When executed as a script, the only action is importing the
    # project's ``runner`` module.
    # NOTE(review): presumably that import triggers the test run as a
    # side effect -- confirm in runner.
    import runner