1
2
3
4
5
6
7
8
"""Cross-validate a classifier on a dataset.

Provides a dataset measure that computes the transfer error of a
classifier on every fold generated by a splitter and aggregates the
per-fold errors with a combiner.
"""

__docformat__ = "restructuredtext"
12
13 from mvpa.support.copy import deepcopy
14
15 from mvpa.measures.base import DatasetMeasure
16 from mvpa.datasets.splitters import NoneSplitter
17 from mvpa.base import warning
18 from mvpa.misc.state import StateVariable, Harvestable
19 from mvpa.misc.transformers import GrandMean
20
21 if __debug__:
22 from mvpa.base import debug
23
24
"""Classifier cross-validation.

Offers a simple interface to cross-validate a classifier on the datasets
that a splitter generates from a single source dataset.

Arbitrary performance/error values can be computed: a transfer error
instance computes one error value per cross-validation fold, and a
combiner aggregates all of those values across the folds.
"""

# Optional state variables -- every one of them is disabled by default.

# list of the per-fold results returned by the transfer error
results = StateVariable(
    doc="Store individual results in the state", enabled=False)
# the splits produced by the splitter (kept only on request)
splits = StateVariable(
    doc="Store the actual splits of the data. Can be memory expensive",
    enabled=False)
# per-fold (deep-copied) transfer error instances
transerrors = StateVariable(
    doc="Store copies of transerrors at each step", enabled=False)
# confusion matrix accumulated over all test folds
confusion = StateVariable(
    doc="Store total confusion matrix (if available)", enabled=False)
# confusion matrix accumulated over all training folds
training_confusion = StateVariable(
    doc="Store total training confusion matrix (if available)",
    enabled=False)
# dict mapping each sample origid to the list of its per-fold errors
samples_error = StateVariable(doc="Per sample errors.", enabled=False)
49
50
def __init__(self,
             transerror,
             splitter=None,
             combiner='mean',
             expose_testdataset=False,
             harvest_attribs=None,
             copy_attribs='copy',
             **kwargs):
    """Set up the cross-validation of a classifier.

    :Parameters:
      transerror: TransferError instance
        Provides the classifier used for cross-validation.
      splitter: Splitter | None
        Generates the datasets for the cross-validation folds. By
        convention the first dataset of each tuple returned by the
        splitter is used to train the provided classifier (no training
        is performed if it is `None`), and the second one is used to
        generate predictions with the (trained) classifier. If `None`
        (default) an instance of
        :class:`~mvpa.datasets.splitters.NoneSplitter` is used.
      combiner: Functor | 'mean'
        Aggregates the error values of all cross-validation folds. If
        'mean' (default) the grand mean of the transfer errors is
        computed.
      expose_testdataset: bool
        In a proper pipeline the classifier must not know anything about
        the testing data, but in some cases exposing it causes only
        marginal harm and might be wanted (e.g. to provide the test
        dataset to RFE for determining its stopping point).
      harvest_attribs: list of basestr
        Which attributes of the call to store and return within the
        harvested state variable.
      copy_attribs: None | basestr
        Force copying of attribute values on harvesting.
      **kwargs
        All additional arguments are passed to the
        :class:`~mvpa.measures.base.DatasetMeasure` base class.
    """
    # initialize both base classes before touching any own attributes
    DatasetMeasure.__init__(self, **kwargs)
    Harvestable.__init__(self, harvest_attribs, copy_attribs)

    # resolve the documented defaults
    if splitter is None:
        splitter = NoneSplitter()
    if combiner == 'mean':
        combiner = GrandMean

    self.__transerror = transerror
    self.__splitter = splitter
    self.__combiner = combiner
    self.__expose_testdataset = expose_testdataset
104
105
106
107
108
109
110
111
112
113
114
115
116
def _call(self, dataset):
    """Perform cross-validation on a dataset.

    `dataset` is passed to the splitter instance and serves as the source
    dataset from which the splits for the individual cross-validation
    folds are generated. For every split the transfer error is computed
    as ``transerror(split[1], split[0])`` and the per-fold results are
    aggregated with the configured combiner.

    :Parameters:
      dataset: Dataset
        Source dataset handed to the splitter.

    :Returns:
      Whatever the combiner returns for the list of per-fold results.
    """
    # NOTE: `self._harvest(locals())` below hands the local names of each
    # fold to the harvester, so `harvest_attribs` presumably refer to
    # them (e.g. 'transerror', 'split', 'result') -- keep them stable.
    results = []
    self.splits = []

    # local bindings for quicker access
    states = self.states
    clf = self.__transerror.clf
    expose_testdataset = self.__expose_testdataset

    # states the transfer error has to compute on our behalf
    terr_enable = []
    for state_var in ['confusion', 'training_confusion', 'samples_error']:
        if states.isEnabled(state_var):
            terr_enable += [state_var]

    # (re)initialize the accumulated values from a previous call
    summaryClass = clf._summaryClass
    clf_hastestdataset = hasattr(clf, 'testdataset')

    self.confusion = summaryClass()
    self.training_confusion = summaryClass()
    self.transerrors = []
    # generator expression: does not leak a loop variable into locals()
    self.samples_error = dict((origid, []) for origid in dataset.origids)

    # temporarily enable the needed states of the transfer error; they are
    # restored in the `finally` clause below even if a fold fails
    if terr_enable:
        self.__transerror.states._changeTemporarily(
            enable_states=terr_enable)

    # transfer error used in the last fold; remains the original instance
    # if the splitter yields no splits at all
    last_transerror = self.__transerror
    try:
        for split in self.__splitter(dataset):
            # only store the splits if asked for -- can be memory expensive
            if states.isEnabled("splits"):
                self.splits.append(split)

            if states.isEnabled("transerrors"):
                # work on a copy, so each stored instance keeps the
                # states of its own fold
                transerror = deepcopy(self.__transerror)
            else:
                transerror = self.__transerror

            # the classifier actually trained in this fold; for a deep
            # copy this is the copy's classifier, not `clf`
            fold_clf = transerror.clf

            # expose testdataset if requested and supported
            if clf_hastestdataset and expose_testdataset:
                fold_clf.testdataset = split[1]
            try:
                # test on split[1] after training on split[0]
                result = transerror(split[1], split[0])
            finally:
                # never leave a reference to the test data behind
                if clf_hastestdataset and expose_testdataset:
                    fold_clf.testdataset = None

            # harvest the locals of this fold (see NOTE at the top)
            self._harvest(locals())

            # store the (possibly copied) transfer error if requested
            if states.isEnabled("transerrors"):
                self.transerrors.append(transerror)

            # collect per-sample errors, keyed by the sample origid
            if states.isEnabled("samples_error"):
                for k, v in \
                    transerror.states.getvalue("samples_error").items():
                    self.samples_error[k].append(v)

            # accumulate the confusion matrices across folds in place
            for state_var in ['confusion', 'training_confusion']:
                if states.isEnabled(state_var):
                    states.getvalue(state_var).__iadd__(
                        transerror.states.getvalue(state_var))

            if __debug__:
                debug("CROSSC", "Split #%d: result %r"
                      % (len(results), result))
            results.append(result)
            last_transerror = transerror

        # since we might have operated on a copy -- bind the last used
        # transfer error back so it is reachable via `transerror`
        self.__transerror = last_transerror
    finally:
        # put the states of the transfer error back in place
        if terr_enable:
            self.__transerror.states._resetEnabledTemporarily()

    self.results = results

    # Provide the dataset's labels_map to the confusion matrices. This is
    # best-effort (e.g. the dataset may have no labels map), but only
    # ordinary errors are ignored, and each matrix is tried on its own.
    for state_var in ['confusion', 'training_confusion']:
        if states.isEnabled(state_var):
            try:
                states.getvalue(state_var).labels_map = dataset.labels_map
            except Exception:
                if __debug__:
                    debug("CROSSC",
                          "Could not assign labels_map to %s" % state_var)

    return self.__combiner(results)
225
226
# Read-only accessors for the components configured in __init__.
splitter = property(lambda self: self.__splitter,
                    doc="Access to the Splitter instance.")
transerror = property(lambda self: self.__transerror,
                      doc="Access to the TransferError instance.")
combiner = property(lambda self: self.__combiner,
                    doc="Access to the configured combiner.")
233