Is it possible to optimize this Matlab code for doing vector quantization with centroids from k-means?
I've created a codebook using k-means of size 4000x300 (4000 centroids, each with 300 features). Using the codebook, I then want to label each input vector (for purposes of binning later on). The input matrix is of size Nx300, where N is the total number of input instances I receive.
To compute the labels, I calculate the closest centroid for each of the input vectors. To do so, I compare each input vector against all centroids and pick the centroid with the minimum distance. The label is then just the index of that centroid.
My current Matlab code looks like:
function labels = assign_labels(centroids, X)
    labels = zeros(size(X, 1), 1);
    % for each X, calculate the distance from each centroid
    for i = 1:size(X, 1)
        % distance of X_i from all j centroids is: sum((X_i - centroid_j)^2)
        % note: we leave off the sqrt as an optimization
        distances = sum(bsxfun(@minus, centroids, X(i, :)) .^ 2, 2);
        [value, label] = min(distances);
        labels(i) = label;
    end
However, this code is still fairly slow (for my purposes), and I was hoping there might be a way to optimize the code further.
One obvious issue is the for-loop, which is the bane of good performance in Matlab. I've been trying to come up with a way to get rid of it, but with no luck (I looked into using arrayfun in conjunction with bsxfun, but haven't gotten that to work). Alternatively, if someone knows of any other way to speed this up, I would greatly appreciate it.
Update
After doing some searching, I couldn't find a great solution using Matlab, so I decided to look at what is used in Python's scikits.learn package for 'euclidean_distance' (shortened):
XX = sum(X * X, axis=1)[:, newaxis]
YY = Y.copy()
YY **= 2
YY = sum(YY, axis=1)[newaxis, :]
distances = XX + YY
distances -= 2 * dot(X, Y.T)
distances = maximum(distances, 0)
which uses the binomial form of the euclidean distance ((x-y)^2 -> x^2 + y^2 - 2xy), which from what I've read usually runs faster. My completely untested Matlab translation is:
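XX = sum(data .* data, 2);                      % N-by-1 squared norms of the inputs
YY = sum(center .^ 2, 2);                       % K-by-1 squared norms of the centroids
D = bsxfun(@plus, XX, YY') - 2*data*center';    % N-by-K squared distances
[~, labels] = min(D, [], 2);                    % index of the closest centroid per row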
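One way to gain confidence in a translation like this is to compare it against the original loop on a small random problem. The sketch below does that; the sizes and the variable names (labelsLoop, labelsVec, CC) are made up for illustration and are not from the question.

```matlab
% Small random problem (illustrative sizes, not the 4000x300 case)
rng(0);
centroids = randn(50, 8);   % K-by-d
X = randn(200, 8);          % N-by-d

% Reference: the original per-row loop
labelsLoop = zeros(size(X, 1), 1);
for i = 1:size(X, 1)
    d = sum(bsxfun(@minus, centroids, X(i, :)) .^ 2, 2);
    [~, labelsLoop(i)] = min(d);
end

% Expanded form: ||x||^2 + ||c||^2 - 2*x*c'
XX = sum(X .^ 2, 2);                                % N-by-1
CC = sum(centroids .^ 2, 2);                        % K-by-1
D = bsxfun(@plus, XX, CC') - 2 * (X * centroids');  % N-by-K
[~, labelsVec] = min(D, [], 2);

% The two label vectors should normally agree (barring near-ties)
disp(isequal(labelsLoop, labelsVec));
```

Because XX is constant within each row of D, it does not change which centroid is the minimum, so it can be dropped when only the labels are needed.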
Use the following function to calculate your distances. You should see an order-of-magnitude speedup.
The two matrices A and B have the dimensions as columns and the points as rows. A is your matrix of centroids. B is your matrix of data points.
function D=getSim(A,B)
Qa=repmat(dot(A,A,2),1,size(B,1));
Qb=repmat(dot(B,B,2),1,size(A,1));
D=Qa+Qb'-2*A*B';
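To turn these distances into labels, take the minimum over the centroid dimension. Since D has one row per centroid and one column per data point, that is a column-wise minimum. The sketch below inlines the same formula so it runs on its own, and processes the data in blocks so the K-by-N matrix never has to exist all at once; chunkSize is a hypothetical tuning knob, not part of the answer above.

```matlab
rng(1);
A = randn(40, 6);     % centroids, K-by-d (illustrative sizes)
B = randn(1000, 6);   % data points, N-by-d

chunkSize = 256;      % hypothetical block size, tune to available memory
N = size(B, 1);
labels = zeros(N, 1);
Qa = sum(A .^ 2, 2);  % K-by-1 squared norms of the centroids

for s = 1:chunkSize:N
    e = min(s + chunkSize - 1, N);
    Bc = B(s:e, :);
    % K-by-n block of squared distances, same layout as getSim(A, Bc)
    D = bsxfun(@plus, Qa, sum(Bc .^ 2, 2)') - 2 * (A * Bc');
    [~, idx] = min(D, [], 1);   % closest centroid for each column
    labels(s:e) = idx(:);
end
```

For small N the loop can be skipped entirely: [~, idx] = min(getSim(A, B), [], 1) gives the same labels as a row vector.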
You can vectorize it by converting to cells and using cellfun:
[nRows,nCols]=size(X);
XCell=num2cell(X,2);
dist=reshape(cell2mat(cellfun(@(x)(sum(bsxfun(@minus,centroids,x).^2,2)),XCell,'UniformOutput',false)),size(centroids,1),nRows);
[~,labels]=min(dist);
Explanation:
- We assign each row of X to its own cell in the second line.
- This piece, @(x)(sum(bsxfun(@minus,centroids,x).^2,2)), is an anonymous function which is the same as your distances=... line. Using cellfun, we apply it to each row of X, and cell2mat stacks the results back into a numeric array.
- The labels are then the indices of the minimum row along each column.
For a true matrix implementation, you may consider trying something along the lines of:
P2 = kron(centroids, ones(size(X,1),1));
Q2 = kron(ones(size(centroids,1),1), X);
distances = reshape(sum((Q2-P2).^2,2), size(X,1), size(centroids,1));
Note: this assumes the data is organized as [x1 y1 ...; x2 y2 ...; ...]
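As a rough illustration, the snippet below runs the kron approach on a tiny problem, extracts labels with a row-wise minimum, and then estimates how much memory P2 and Q2 would need at the sizes in the question. The small sizes and N = 1000 are assumed example values.

```matlab
rng(3);
centroids = randn(5, 4);   % K-by-d (tiny illustrative sizes)
X = randn(12, 4);          % N-by-d

P2 = kron(centroids, ones(size(X, 1), 1));   % each centroid repeated N times
Q2 = kron(ones(size(centroids, 1), 1), X);   % X tiled K times
distances = reshape(sum((Q2 - P2) .^ 2, 2), size(X, 1), size(centroids, 1));
[~, labels] = min(distances, [], 2);         % N-by-1 index of closest centroid

% Rough memory estimate for the sizes in the question
K = 4000; d = 300; N = 1000;                 % N = 1000 is an assumed example value
bytesPerMatrix = K * N * d * 8;              % doubles take 8 bytes each
fprintf('P2 and Q2 would each need about %.1f GB\n', bytesPerMatrix / 1e9);
```

With 4000 centroids of 300 features, P2 and Q2 each grow by about 9.6 MB per input row, so this form is probably only practical for small N or when applied to blocks of X.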
You can use a more efficient algorithm for nearest neighbor search than brute force. The most popular approach is the Kd-Tree, which has O(log(n)) average query time instead of the O(n) brute-force complexity. Regarding a Matlab implementation of Kd-Trees, you can have a look here
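If the Statistics and Machine Learning Toolbox is available (an assumption; the implementation linked above may be a different one), its knnsearch function can build and query a kd-tree in one call. A minimal sketch on low-dimensional random data:

```matlab
% Requires the Statistics and Machine Learning Toolbox (assumed available)
rng(2);
centroids = randn(100, 3);   % K-by-d, deliberately low-dimensional
X = randn(500, 3);           % N-by-d query points

% For each row of X, return the index of the nearest row of centroids.
% 'NSMethod','kdtree' forces a kd-tree; by default knnsearch only picks
% one for low-dimensional data and otherwise searches exhaustively.
labels = knnsearch(centroids, X, 'NSMethod', 'kdtree');
```

Kd-trees tend to lose their advantage as the number of dimensions grows, so with 300 features it is worth timing this against the matrix-based distance computation before committing to it.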