1
2
3
4
5
6
7
8
9 """k-Nearest-Neighbour classifier."""
10
11 __docformat__ = 'restructuredtext'
12
13
14 import numpy as N
15
16 from mvpa.base import warning
17 from mvpa.misc.support import indentDoc
18 from mvpa.clfs.base import Classifier
19 from mvpa.base.dochelpers import enhancedDocString
20 from mvpa.clfs.distance import squared_euclidean_distance
21
22 if __debug__:
23 from mvpa.base import debug
24
25
class kNN(Classifier):
    """k-Nearest-Neighbour classifier.

    This is a simple classifier that bases its decision on the distances
    between the training dataset samples and the test sample(s). Distances
    are computed using a customizable distance function. A certain number
    (`k`) of nearest neighbors is selected based on the smallest distances
    and the labels of these neighboring samples are fed into a voting
    function to determine the labels of the test sample.

    Training a kNN classifier is extremely quick, as no actual training
    is performed: the training dataset is simply stored in the
    classifier. All computations are done during classifier prediction.

    .. note::
      If enabled, kNN stores the votes per class in the 'values' state after
      calling predict().

    """

    _clf_internals = ['knn', 'non-linear', 'binary', 'multiclass',
                      'notrain2predict' ]

    # NOTE(review): the ``def`` line of this method was not present in the
    # reviewed text. Parameter names follow the docstring below; the default
    # values of `k` and `voting` are assumed -- confirm against the original.
    def __init__(self, k=2, dfx=squared_euclidean_distance,
                 voting='weighted', **kwargs):
        """
        :Parameters:
          k: unsigned integer
            Number of nearest neighbours to be used for voting.
          dfx: functor
            Function to compute the distances between training and test
            samples. Default: squared euclidean distance
          voting: str
            Voting method used to derive predictions from the nearest
            neighbors. Possible values are 'majority' (simple majority of
            classes determines vote) and 'weighted' (votes are weighted
            according to the relative frequencies of each class in the
            training data).
          **kwargs:
            Additional arguments are passed to the base class.
        """
        Classifier.__init__(self, **kwargs)

        self.__k = k
        self.__dfx = dfx
        self.__voting = voting
        self.__data = None

        # Caches that are (re)built by training / weighted voting. They are
        # defined here so that every instance carries the attributes even
        # before it has been trained.
        self.__weights = None
        self.__labels = None
        self.__votes_init = None


    # NOTE(review): ``def`` line reconstructed; the body concatenates
    # `prefixes` to a list, so a list-like argument is expected.
    def __repr__(self, prefixes=None):
        """Representation of the object
        """
        if prefixes is None:
            prefixes = []
        return super(kNN, self).__repr__(
            ["k=%d" % self.__k, "dfx=%s" % self.__dfx,
             "voting=%s" % repr(self.__voting)]
            + prefixes)


    # NOTE(review): ``def`` line reconstructed; the method name is assumed
    # to be the training hook expected by the base class -- confirm.
    def _train(self, data):
        """Train the classifier.

        For kNN it is degenerate -- just stores the data.

        :Parameters:
          data
            Training dataset. Its `samples` and `uniquelabels` attributes
            are read here; `labels` and `nfeatures` are read at prediction
            time.
        """
        self.__data = data
        if __debug__:
            if str(data.samples.dtype).startswith(('uint', 'int')):
                warning("kNN: input data is in integers. "
                        "Overflow on arithmetic operations might result in"
                        " errors. Please convert dataset's samples into"
                        " floating datatype if any error is reported.")

        # invalidate everything that was derived from a previous dataset;
        # class weights are recomputed lazily on the first weighted vote
        self.__weights = None
        self.__labels = None

        # template with a zero vote count per class; copied for every vote
        uniquelabels = data.uniquelabels
        self.__votes_init = dict(zip(uniquelabels,
                                     [0] * len(uniquelabels)))


    # NOTE(review): ``def`` line reconstructed; the method name is assumed
    # to be the prediction hook expected by the base class -- confirm.
    def _predict(self, data):
        """Predict the class labels for the provided data.

        Returns a list of class labels (one for each data sample).

        The per-sample votes returned by the voting function are assigned
        to `self.values`, the labels to `self.predictions`.

        :Raises:
          ValueError
            If the configured voting method is neither 'majority' nor
            'weighted'. With `__debug__` enabled also if `data` is not
            two-dimensional or its number of features does not match the
            training data.
        """
        # make sure we are talking about arrays
        data = N.asarray(data)

        # checks only in debug mode
        if __debug__:
            if not data.ndim == 2:
                raise ValueError("Data array must be two-dimensional.")

            if not data.shape[1] == self.__data.nfeatures:
                raise ValueError("Length of data samples (features) does "
                                 "not match the classifier.")

        # distances between training and test samples, transposed so that
        # each row can be sorted and handed to the voting function as the
        # indices of training samples
        dists = self.__dfx(self.__data.samples, data).T

        # indices of the k smallest distances in every row
        knns = dists.argsort(axis=1)[:, :self.__k]

        if self.__voting == 'majority':
            vfx = self.getMajorityVote
        elif self.__voting == 'weighted':
            vfx = self.getWeightedVote
        else:
            raise ValueError("kNN told to perform unknown voting '%s'."
                             % self.__voting)

        # every result is a (winning label, votes per class) pair
        results = [vfx(knn) for knn in knns]
        predicted = [r[0] for r in results]

        self.predictions = predicted
        self.values = [r[1] for r in results]

        return predicted


    def getMajorityVote(self, knn_ids):
        """Simple voting by choosing the majority of class neighbors.

        :Parameters:
          knn_ids
            Indices of the nearest training samples.

        Returns a tuple of the label with the highest vote count and the
        list of vote counts ordered like the `uniquelabels` of the training
        data.
        """
        _data = self.__data
        labels = _data.labels

        # number of occurrences of each class among the neighbors
        votes = self.__votes_init.copy()
        for nn in knn_ids:
            votes[labels[nn]] += 1

        # the first label reaching the maximum count (in dictionary
        # iteration order) wins
        return max(votes, key=votes.get), \
               [votes[ul] for ul in _data.uniquelabels]


    def getWeightedVote(self, knn_ids):
        """Vote with classes weighted by the number of samples per class.

        Every class gets the weight ``1 - <relative frequency of the class
        in the training labels>``; the vote count of a class among the
        neighbors is multiplied by that weight.

        :Parameters:
          knn_ids
            Indices of the nearest training samples.

        Returns a tuple of the label with the highest weighted vote and the
        list of weighted votes ordered like the `uniquelabels` of the
        training data.
        """
        _data = self.__data
        uniquelabels = _data.uniquelabels

        # the weights depend on the training data only, hence they are
        # computed once and cached until the classifier is retrained
        if self.__weights is None:
            self.__labels = labels = _data.labels
            label_array = N.asarray(labels)
            # float is essential here: an integer count divided by an
            # integer length truncates under classic division, which would
            # turn every weight into 1.0 and disable the weighting
            Nlabels = float(len(labels))

            self.__weights = dict(
                [(label, 1.0 - (label_array == label).sum() / Nlabels)
                 for label in uniquelabels])

        labels = self.__labels

        # number of occurrences of each class among the neighbors
        votes = self.__votes_init.copy()
        for nn in knn_ids:
            votes[labels[nn]] += 1

        # weight the votes
        votes = [self.__weights[ul] * votes[ul] for ul in uniquelabels]

        # argmax returns the first maximum, i.e. ties are resolved in
        # favour of the class listed first in `uniquelabels`
        return uniquelabels[N.asarray(votes).argmax()], \
               votes


    def untrain(self):
        """Reset trained state"""
        self.__data = None
        # drop everything that was derived from the training data as well
        self.__weights = None
        self.__labels = None
        self.__votes_init = None
        super(kNN, self).untrain()
227