# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the PyMVPA package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Base class for all classifiers.

At the moment, regressions are treated just as a special case of
classifier (or vice versa), so the same base class `Classifier` is
utilized for both kinds.
"""

__docformat__ = 'restructuredtext'

import numpy as N

from mvpa.support.copy import deepcopy

import time

from mvpa.misc.support import idhash
from mvpa.misc.state import StateVariable, ClassWithCollections
from mvpa.misc.param import Parameter

from mvpa.clfs.transerror import ConfusionMatrix, RegressionStatistics

from mvpa.base import warning

if __debug__:
    from mvpa.base import debug
37 """Abstract classifier class to be inherited by all classifiers
38 """
39
40
41
42 _DEV__doc__ = """
43 Required behavior:
44
45 For every classifier is has to be possible to be instantiated without
46 having to specify the training pattern.
47
48 Repeated calls to the train() method with different training data have to
49 result in a valid classifier, trained for the particular dataset.
50
51 It must be possible to specify all classifier parameters as keyword
52 arguments to the constructor.
53
54 Recommended behavior:
55
56 Derived classifiers should provide access to *values* -- i.e. that
57 information that is finally used to determine the predicted class label.
58
59 Michael: Maybe it works well if each classifier provides a 'values'
60 state member. This variable is a list as long as and in same order
61 as Dataset.uniquelabels (training data). Each item in the list
62 corresponds to the likelyhood of a sample to belong to the
63 respective class. However the semantics might differ between
64 classifiers, e.g. kNN would probably store distances to class-
65 neighbors, where PLR would store the raw function value of the
66 logistic function. So in the case of kNN low is predictive and for
67 PLR high is predictive. Don't know if there is the need to unify
68 that.
69
70 As the storage and/or computation of this information might be
71 demanding its collection should be switchable and off be default.
72
73 Nomenclature
74 * predictions : corresponds to the quantized labels if classifier spits
75 out labels by .predict()
76 * values : might be different from predictions if a classifier's predict()
77 makes a decision based on some internal value such as
78 probability or a distance.
79 """

    trained_labels = StateVariable(enabled=True,
        doc="Set of unique labels it has been trained on")

    trained_nsamples = StateVariable(enabled=True,
        doc="Number of samples it has been trained on")

    trained_dataset = StateVariable(enabled=False,
        doc="The dataset it has been trained on")

    training_confusion = StateVariable(enabled=False,
        doc="Confusion matrix of learning performance")

    predictions = StateVariable(enabled=True,
        doc="Most recent set of predictions")

    values = StateVariable(enabled=True,
        doc="Internal classifier values the most recent "
            "predictions are based on")

    training_time = StateVariable(enabled=True,
        doc="Time (in seconds) it took to train the classifier")

    predicting_time = StateVariable(enabled=True,
        doc="Time (in seconds) it took the classifier to predict")

    feature_ids = StateVariable(enabled=False,
        doc="Feature IDs which were used for the actual training.")

    _clf_internals = []
    """Describes some specifics about the classifier -- e.g. whether it is
    doing regression, etc."""

    regression = Parameter(False, allowedtype='bool',
        doc="""Whether to use the classifier as a regression. By default
        any Classifier-derived class serves as a classifier, so a
        regression-capable classifier performs binary classification.""",
        index=1001)


    retrainable = Parameter(False, allowedtype='bool',
        doc="""Whether to enable retraining for a 'retrainable' classifier.""",
        index=1002)
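
    # Parameters are regular constructor keyword arguments (a sketch;
    # ``SomeClassifier`` is a hypothetical subclass):
    #
    #   clf = SomeClassifier(regression=True, retrainable=True)
    #   clf.params.regression          # -> True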


    def __init__(self, **kwargs):
        """Cheap initialization.
        """
        ClassWithCollections.__init__(self, **kwargs)


        self.__trainednfeatures = None
        """Stores number of features for which classifier was trained.
        If None -- it wasn't trained at all"""

        self._setRetrainable(self.params.retrainable, force=True)

        if self.params.regression:
            for statevar in [ "trained_labels"]:
                if self.states.isEnabled(statevar):
                    if __debug__:
                        debug("CLF",
                              "Disabling state %s since doing regression, "
                              % statevar + "not classification")
                    self.states.disable(statevar)
            self._summaryClass = RegressionStatistics
        else:
            self._summaryClass = ConfusionMatrix
            clf_internals = self._clf_internals
            if 'regression' in clf_internals \
                   and not ('binary' in clf_internals):
                # regressions are used as binary classifiers if not
                # asked to behave as regressions, so augment a copy of
                # _clf_internals accordingly
                self._clf_internals = clf_internals + ['binary']


    def __str__(self):
        if __debug__ and 'CLF_' in debug.active:
            return "%s / %s" % (repr(self), super(Classifier, self).__str__())
        else:
            return repr(self)

    def __repr__(self, prefixes=[]):
        return super(Classifier, self).__repr__(prefixes=prefixes)


    def _pretrain(self, dataset):
        """Functionality prior to training
        """
        # So we reset all state variables and may be free up some memory
        # explicitly
        params = self.params
        if not params.retrainable:
            self.untrain()
        else:
            # just reset the states, do not untrain
            self.states.reset()
            if not self.__changedData_isset:
                self.__resetChangedData()
            _changedData = self._changedData
            __idhashes = self.__idhashes
            __invalidatedChangedData = self.__invalidatedChangedData

            if __debug__:
                debug('CLF_', "IDHashes are %s" % (__idhashes))

            # Check if the data was changed
            for key, data_ in (('traindata', dataset.samples),
                               ('labels', dataset.labels)):
                _changedData[key] = self.__wasDataChanged(key, data_)
                # If the idhash for that entry was invalidated by
                # retraining, mark the entry as changed
                if __invalidatedChangedData.get(key, False):
                    if __debug__ and not _changedData[key]:
                        debug('CLF_', 'Found that idhash for %s was '
                              'invalidated by retraining' % key)
                    _changedData[key] = True

            # Check if any parameters were changed
            for col in self._paramscols:
                changedParams = self._collections[col].whichSet()
                if len(changedParams):
                    _changedData[col] = changedParams

            self.__invalidatedChangedData = {} # reset on training

            if __debug__:
                debug('CLF_', "Obtained _changedData is %s"
                      % (self._changedData))

        if not params.regression and 'regression' in self._clf_internals \
               and not self.states.isEnabled('trained_labels'):
            # If the classifier internally does regression we need the
            # labels it was trained on
            if __debug__:
                debug("CLF", "Enabling trained_labels state since it is needed")
            self.states.enable('trained_labels')


    def _posttrain(self, dataset):
        """Functionality post training

        For instance -- computing confusion matrix

        :Parameters:
          dataset : Dataset
            Data which was used for training
        """
        if self.states.isEnabled('trained_labels'):
            self.trained_labels = dataset.uniquelabels

        self.trained_dataset = dataset
        self.trained_nsamples = dataset.nsamples

        # needs to be assigned first since below we use predict
        self.__trainednfeatures = dataset.nfeatures

        if __debug__ and 'CHECK_TRAINED' in debug.active:
            self.__trainedidhash = dataset.idhash

        if self.states.isEnabled('training_confusion') and \
               not self.states.isSet('training_confusion'):
            # we should not store predictions for training data,
            # it is confusing imho (yoh)
            self.states._changeTemporarily(
                disable_states=["predictions"])
            if self.params.retrainable:
                # we would need to recheck if data is the same,
                # XXX think if there is a way to make this all
                # efficient. For now, probably, retrainable
                # classifiers have no chance but not to use
                # training_confusion... sad
                self.__changedData_isset = False
            predictions = self.predict(dataset.samples)
            self.states._resetEnabledTemporarily()
            self.training_confusion = self._summaryClass(
                targets=dataset.labels,
                predictions=predictions)

            try:
                self.training_confusion.labels_map = dataset.labels_map
            except:
                pass

        if self.states.isEnabled('feature_ids'):
            self.feature_ids = self._getFeatureIds()


282 """Virtual method to return feature_ids used while training
283
284 Is not intended to be called anywhere but from _posttrain,
285 thus classifier is assumed to be trained at this point
286 """
287
288 return range(self.__trainednfeatures)
289
290
292 """Providing summary over the classifier"""
293
294 s = "Classifier %s" % self
295 states = self.states
296 states_enabled = states.enabled
297
298 if self.trained:
299 s += "\n trained"
300 if states.isSet('training_time'):
301 s += ' in %.3g sec' % states.training_time
302 s += ' on data with'
303 if states.isSet('trained_labels'):
304 s += ' labels:%s' % list(states.trained_labels)
305
306 nsamples, nchunks = None, None
307 if states.isSet('trained_nsamples'):
308 nsamples = states.trained_nsamples
309 if states.isSet('trained_dataset'):
310 td = states.trained_dataset
311 nsamples, nchunks = td.nsamples, len(td.uniquechunks)
312 if nsamples is not None:
313 s += ' #samples:%d' % nsamples
314 if nchunks is not None:
315 s += ' #chunks:%d' % nchunks
316
317 s += " #features:%d" % self.__trainednfeatures
318 if states.isSet('feature_ids'):
319 s += ", used #features:%d" % len(states.feature_ids)
320 if states.isSet('training_confusion'):
321 s += ", training error:%.3g" % states.training_confusion.error
322 else:
323 s += "\n not yet trained"
324
325 if len(states_enabled):
326 s += "\n enabled states:%s" % ', '.join([str(states[x])
327 for x in states_enabled])
328 return s
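
    # Usage sketch (``clf`` being any Classifier instance):
    #
    #   print clf.summary()
    #   # e.g. "Classifier ...\n trained in 0.05 sec on data with ..."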


    def clone(self):
        """Create full copy of the classifier.

        It might require classifier to be untrained first due to
        present SWIG bindings.

        TODO: think about proper re-implementation, without enrollment
        of deepcopy
        """
        try:
            return deepcopy(self)
        except:
            self.untrain()
            return deepcopy(self)
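
    # Usage sketch -- obtain an independent copy, e.g. to train the same
    # setup on different splits of a dataset:
    #
    #   clf_copy = clf.clone()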


    def _train(self, dataset):
        """Function to be actually overridden in derived classes
        """
        raise NotImplementedError


    def train(self, dataset):
        """Train classifier on a dataset

        Shouldn't be overridden in subclasses unless explicitly needed
        to do so
        """
        if __debug__:
            debug("CLF", "Training classifier %(clf)s on dataset %(dataset)s",
                  msgargs={'clf':self, 'dataset':dataset})

        self._pretrain(dataset)

        # remember the time when training started
        t0 = time.time()

        if dataset.nfeatures > 0:
            result = self._train(dataset)
        else:
            warning("Trying to train on dataset with no features present")
            if __debug__:
                debug("CLF",
                      "No features present for training, no actual training "
                      "is called")
            result = None

        self.training_time = time.time() - t0
        self._posttrain(dataset)
        return result


    def _prepredict(self, data):
        """Functionality prior to prediction
        """
        if 'notrain2predict' not in self._clf_internals:
            # check if classifier was trained if that is needed
            if not self.trained:
                raise ValueError, \
                      "Classifier %s wasn't yet trained, therefore can't " \
                      "predict" % self
            nfeatures = data.shape[1]
            # check if the number of features is the same as in the data
            # it was trained on
            if nfeatures != self.__trainednfeatures:
                raise ValueError, \
                      "Classifier %s was trained on data with %d features, " % \
                      (self, self.__trainednfeatures) + \
                      "thus can't predict for %d features" % nfeatures


        if self.params.retrainable:
            if not self.__changedData_isset:
                self.__resetChangedData()
                _changedData = self._changedData
                _changedData['testdata'] = \
                    self.__wasDataChanged('testdata', data)
                if __debug__:
                    debug('CLF_', "prepredict: Obtained _changedData is %s"
                          % (_changedData))


    def _postpredict(self, data, result):
        """Functionality after prediction is computed
        """
        self.predictions = result
        if self.params.retrainable:
            self.__changedData_isset = False

420 """Actual prediction
421 """
422 raise NotImplementedError
423
424
426 """Predict classifier on data
427
428 Shouldn't be overridden in subclasses unless explicitly needed
429 to do so. Also subclasses trying to call super class's predict
430 should call _predict if within _predict instead of predict()
431 since otherwise it would loop
432 """
433 data = N.asarray(data)
434 if __debug__:
435 debug("CLF", "Predicting classifier %(clf)s on data %(data)s",
436 msgargs={'clf':self, 'data':data.shape})
437
438
439 t0 = time.time()
440
441 states = self.states
442
443
444 states.reset(['values', 'predictions'])
445
446 self._prepredict(data)
447
448 if self.__trainednfeatures > 0 \
449 or 'notrain2predict' in self._clf_internals:
450 result = self._predict(data)
451 else:
452 warning("Trying to predict using classifier trained on no features")
453 if __debug__:
454 debug("CLF",
455 "No features were present for training, prediction is " \
456 "bogus")
457 result = [None]*data.shape[0]
458
459 states.predicting_time = time.time() - t0
460
461 if 'regression' in self._clf_internals and not self.params.regression:
462
463
464
465
466
467
468
469
470 result_ = N.array(result)
471 if states.isEnabled('values'):
472
473
474 if not states.isSet('values'):
475 states.values = result_.copy()
476 else:
477
478
479
480 states.values = states.values.copy()
481
482 trained_labels = self.trained_labels
483 for i, value in enumerate(result):
484 dists = N.abs(value - trained_labels)
485 result[i] = trained_labels[N.argmin(dists)]
486
487 if __debug__:
488 debug("CLF_", "Converted regression result %(result_)s "
489 "into labels %(result)s for %(self_)s",
490 msgargs={'result_':result_, 'result':result,
491 'self_': self})
492
493 self._postpredict(data, result)
494 return result
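
    # End-to-end usage sketch (``SomeClassifier`` and ``dataset`` are
    # hypothetical placeholders):
    #
    #   clf = SomeClassifier()
    #   clf.train(dataset)                 # sets training_time, etc.
    #   predictions = clf.predict(dataset.samples)
    #   clf.predicting_time                # seconds spent in predict()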


    def isTrained(self, dataset=None):
        """Whether the classifier was already trained.

        MUST BE USED WITH CARE IF EVER"""
        if dataset is None:
            # simply return if it was trained on anything
            return self.__trainednfeatures is not None
        else:
            res = (self.__trainednfeatures == dataset.nfeatures)
            if __debug__ and 'CHECK_TRAINED' in debug.active:
                res2 = (self.__trainedidhash == dataset.idhash)
                if res2 != res:
                    raise RuntimeError, \
                          "isTrained is weak and shouldn't be relied upon. " \
                          "Got result %s although comparing of idhash says %s" \
                          % (res, res2)
            return res
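
    # Guarding sketch (a weak check -- see the caveat in the docstring):
    #
    #   if not clf.isTrained(dataset):
    #       clf.train(dataset)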


    def _regressionIsBogus(self):
        """Some classifiers like BinaryClassifier can't be used for
        regression"""

        if self.params.regression:
            raise ValueError, "Regression mode is meaningless for %s" % \
                  self.__class__.__name__ + " thus don't enable it"


    @property
    def trained(self):
        """Whether the classifier was already trained"""
        return self.isTrained()

531 """Reset trained state"""
532 self.__trainednfeatures = None
533
534
535
536
537
538
539 super(Classifier, self).reset()
540
541
543 """Factory method to return an appropriate sensitivity analyzer for
544 the respective classifier."""
545 raise NotImplementedError
546
547
548
549
550
552 """Assign value of retrainable parameter
553
554 If retrainable flag is to be changed, classifier has to be
555 untrained. Also internal attributes such as _changedData,
556 __changedData_isset, and __idhashes should be initialized if
557 it becomes retrainable
558 """
559 pretrainable = self.params['retrainable']
560 if (force or value != pretrainable.value) \
561 and 'retrainable' in self._clf_internals:
562 if __debug__:
563 debug("CLF_", "Setting retrainable to %s" % value)
564 if 'meta' in self._clf_internals:
565 warning("Retrainability is not yet crafted/tested for "
566 "meta classifiers. Unpredictable behavior might occur")
567
568 if self.trained:
569 self.untrain()
570 states = self.states
571 if not value and states.isKnown('retrained'):
572 states.remove('retrained')
573 states.remove('repredicted')
574 if value:
575 if not 'retrainable' in self._clf_internals:
576 warning("Setting of flag retrainable for %s has no effect"
577 " since classifier has no such capability. It would"
578 " just lead to resources consumption and slowdown"
579 % self)
580 states.add(StateVariable(enabled=True,
581 name='retrained',
582 doc="Either retrainable classifier was retrained"))
583 states.add(StateVariable(enabled=True,
584 name='repredicted',
585 doc="Either retrainable classifier was repredicted"))
586
587 pretrainable.value = value
588
589
590 if value:
591 self.__idhashes = {'traindata': None, 'labels': None,
592 'testdata': None}
593 if __debug__ and 'CHECK_RETRAIN' in debug.active:
594
595
596
597
598 self.__trained = self.__idhashes.copy()
599 self.__resetChangedData()
600 self.__invalidatedChangedData = {}
601 elif 'retrainable' in self._clf_internals:
602
603 self.__changedData_isset = False
604 self._changedData = None
605 self.__idhashes = None
606 if __debug__ and 'CHECK_RETRAIN' in debug.active:
607 self.__trained = None
608
610 """For retrainable classifier we keep track of what was changed
611 This function resets that dictionary
612 """
613 if __debug__:
614 debug('CLF_',
615 'Retrainable: resetting flags on either data was changed')
616 keys = self.__idhashes.keys() + self._paramscols
617
618
619
620
621
622 self._changedData = dict(zip(keys, [False]*len(keys)))
623 self.__changedData_isset = False
624
625
627 """Check if given entry was changed from what known prior.
628
629 If so -- store only the ones needed for retrainable beastie
630 """
631 idhash_ = idhash(entry)
632 __idhashes = self.__idhashes
633
634 changed = __idhashes[key] != idhash_
635 if __debug__ and 'CHECK_RETRAIN' in debug.active:
636 __trained = self.__trained
637 changed2 = entry != __trained[key]
638 if isinstance(changed2, N.ndarray):
639 changed2 = changed2.any()
640 if changed != changed2 and not changed:
641 raise RuntimeError, \
642 'idhash found to be weak for %s. Though hashid %s!=%s %s, '\
643 'values %s!=%s %s' % \
644 (key, idhash_, __idhashes[key], changed,
645 entry, __trained[key], changed2)
646 if update:
647 __trained[key] = entry
648
649 if __debug__ and changed:
650 debug('CLF_', "Changed %s from %s to %s.%s"
651 % (key, __idhashes[key], idhash_,
652 ('','updated')[int(update)]))
653 if update:
654 __idhashes[key] = idhash_
655
656 return changed
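
    # Change detection relies on ``idhash`` fingerprints (a sketch; how
    # sensitive the fingerprint is to in-place modifications depends on
    # idhash's implementation):
    #
    #   from mvpa.misc.support import idhash
    #   h_before = idhash(dataset.samples)
    #   ...                                # possibly modify the samples
    #   changed = idhash(dataset.samples) != h_before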


    def retrain(self, dataset, **kwargs):
        """Helper to avoid checking if the data was actually changed

        Useful if just some aspects of classifier were changed since
        its previous training. For instance if dataset wasn't changed
        but only classifier parameters, then kernel matrix does not
        have to be computed.

        Words of caution: the classifier must have been trained before,
        and the results should always first be compared to the results
        of a non-'retrainable' classifier (without calling retrain).
        Some additional checks are enabled if debug id 'CHECK_RETRAIN'
        is enabled, to guard against obvious mistakes.

        :Parameters:
          kwargs
            that is what _changedData gets updated with. So, something
            like ``(params=['C'], labels=True)`` if parameter C and
            labels got changed
        """
        # Note that it also demolishes anything for repredicting,
        # which should be fine in most of the cases
        if __debug__:
            if not self.params.retrainable:
                raise RuntimeError, \
                      "Do not use re(train,predict) on non-retrainable %s" % \
                      self

            if kwargs.has_key('params') or kwargs.has_key('kernel_params'):
                raise ValueError, \
                      "Retraining for changed params not working yet"

        self.__resetChangedData()

        # local bindings
        chd = self._changedData
        ichd = self.__invalidatedChangedData

        chd.update(kwargs)
        # mark for future training the entries which were explicitly
        # reported as changed
        for key, value in kwargs.iteritems():
            if value:
                ichd[key] = True
        self.__changedData_isset = True

        # To check if we are not fooled
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('traindata', dataset.samples),
                               ('labels', dataset.labels)):
                # so it wasn't told to be invalid
                if not chd[key] and not ichd.get(key, False):
                    if self.__wasDataChanged(key, data_, update=False):
                        raise RuntimeError, \
                              "Data %s found changed although wasn't " \
                              "labeled as such" % key

        # TODO: handle changed parameters of the classifier -- for now
        # there is the explicit prohibition above

        # Check that the training data shape remained the same
        if __debug__ and 'CHECK_RETRAIN' in debug.active and self.trained \
               and not self._changedData['traindata'] \
               and self.__trained['traindata'].shape != dataset.samples.shape:
            raise ValueError, "In retrain got dataset with %s size, " \
                  "whereas previously was trained on %s size" \
                  % (dataset.samples.shape, self.__trained['traindata'].shape)
        self.train(dataset)
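
    # Usage sketch (``clf`` trained with retrainable=True; keyword arguments
    # name what has changed, here the labels of ``dataset``):
    #
    #   clf.retrain(dataset, labels=True)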


    def repredict(self, data, **kwargs):
        """Helper to avoid checking if the data was actually changed

        Useful if classifier was (re)trained but with the same data
        (so just parameters were changed), so that it could be
        repredicted easily (on the same data as before) without
        recomputing for instance the train/test kernel matrix. Should be
        used with caution and the results always compared to those of a
        non-'retrainable' classifier. Some additional checks are enabled
        if debug id 'CHECK_RETRAIN' is enabled, to guard against
        obvious mistakes.

        :Parameters:
          data
            data which is conventionally given to predict
          kwargs
            that is what _changedData gets updated with. So, something
            like ``(params=['C'], labels=True)`` if parameter C and
            labels got changed
        """
        if len(kwargs) > 0:
            raise RuntimeError, \
                  "repredict for now should be used without params since " \
                  "it makes little sense to repredict if anything got changed"
        if __debug__ and not self.params.retrainable:
            raise RuntimeError, \
                  "Do not use retrain/repredict on non-retrainable classifiers"

        self.__resetChangedData()
        chd = self._changedData
        chd.update(**kwargs)
        self.__changedData_isset = True

        # check that we are not being fooled
        if __debug__ and 'CHECK_RETRAIN' in debug.active:
            for key, data_ in (('testdata', data),):
                # so it wasn't told to be invalid
                if self.__wasDataChanged(key, data_, update=False):
                    raise RuntimeError, \
                          "Data %s found changed although wasn't " \
                          "labeled as such" % key

        # Check that the test data shape remained the same
        if __debug__ and 'CHECK_RETRAIN' in debug.active \
               and not self._changedData['testdata'] \
               and self.__trained['testdata'].shape != data.shape:
            raise ValueError, "In repredict got dataset with %s size, " \
                  "whereas previously was trained on %s size" \
                  % (data.shape, self.__trained['testdata'].shape)

        return self.predict(data)
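
    # Usage sketch (``clf`` retrainable, already trained and predicted on
    # ``data`` before; only parameters changed in between):
    #
    #   predictions = clf.repredict(data)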