A common pattern I find myself writing is something like the following, which divides $n$ elements into $m$ components and then processes each component:
% Some big array of data
X = rand(n,3);
% Label of each entry's component in [1,m] (random labels here)
C = randi(m,n,1);
% Loop over each component
for c = 1:max(C)
  % Get the indices of the c-th component
  Ic = find(C == c);
  % Do something with just the c-th component
  Xc = X(Ic,:);
  ...
end
This code is $O(n\cdot m)$ for $n$ entries in my big array and $m$ components, because the find(C == c) line is $O(n)$. Since for loops can be slow in matlab and the % Do something part might be pretty heavy, the cost of this find(C == c) often isn't felt, especially if $m$ is small.
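To get a feel for the scaling, here's a minimal timing sketch that isolates just the find cost (the sizes here are arbitrary; pick your own):

% Sketch: measure the repeated O(n) scans in isolation
n = 1e6;
m = 1e3;
C = randi(m,n,1);
tic;
for c = 1:m
  % O(n) scan every iteration, so the loop is O(n*m) total
  Ic = find(C == c);
end
toc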
When $m$ is big this can be slow. One way to deal with it is to sort C and process the sorted chunks. That'd be $O(n\cdot \log(n))$ (every element is processed only once during the for loop):
% Sort C, keeping the sorting permutation I
[~,I] = sort(C);
% Cumulative lengths of each component
% (assumes every label in 1:max(C) appears at least once)
cumlens = [0;find(C(I(1:end-1))~=C(I(2:end)));numel(C)];
% Loop over each component
for c = 1:max(C)
  Ic = I(cumlens(c)+1:cumlens(c+1));
  ...
end
That's not too difficult once it's written out, but it's not something that I can brainlessly write without bugs 100% of the time.
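One slightly less fiddly way to build those offsets, if you're comfortable with accumarray, is to count the entries per label and take a cumulative sum; as a bonus it behaves sensibly if some label in 1:max(C) happens to be empty. A sketch (mine, not from the code above):

[~,I] = sort(C);
% Number of entries per component (empty components get 0)
counts = accumarray(C,1,[max(C) 1]);
cumlens = [0;cumsum(counts)];
for c = 1:max(C)
  Ic = I(cumlens(c)+1:cumlens(c+1));
  ...
end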
An alternative is to use sparse matrices to do the sorting for us and keep using find to get the indices:
% n-by-m sparse matrix with a 1 at (i,C(i)) for each entry i
S = sparse(1:n,C,1,n,m);
% Loop over each component
for c = 1:max(C)
  % Get the indices of the c-th component
  Ic = find(S(:,c));
  % Do something with just the c-th component
  Xc = X(Ic,:);
  ...
end
This works because matlab sparse matrices are stored in compressed sparse column format. This costs us $O(n\cdot \log(n))$ when we create the sparse matrix, but then the find is linear in the size of the component (i.e., it has output-bound complexity; just what we would like).
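As a tiny concrete example of why the column slice is cheap (made-up values, just to illustrate the layout):

% 5 entries in 2 components
C = [2;1;2;2;1];
S = sparse(1:5,C,1,5,2);
% CSC stores per-column offsets, so this jumps straight to
% column 1's nonzeros without scanning anything else:
Ic = find(S(:,1))  % returns [2;5]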
The pitfall of this method is remembering that the columns are compressed, not the rows. If you try to do S = sparse(C,1:n,1,m,n); and Ic = find(S(c,:)); you'll get burned and be back to $O(n)$ inside each loop.
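If you do find yourself holding the transposed version, it should be cheaper to transpose once (a one-time cost roughly linear in the number of nonzeros) than to slice rows inside the loop. Something like this sketch:

% Built it the wrong way around?
S = sparse(C,1:n,1,m,n);
% Transpose once so the components become compressed columns
St = S';
for c = 1:max(C)
  Ic = find(St(:,c)); % output-bound again
  ...
end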
A super common place that I run into this is when I'm using gptoolbox to process submeshes. I might have code like:
% Some mesh with O(n) vertices and faces
[V,F] = ...
% Connected component labels per face
[~,C] = connected_components(F);
% Process each component
for c = 1:max(C)
  Ic = find(C == c);
  Fc = F(Ic,:);
  % Extract submesh. (warning: this next line is O(n))
  [Vc,~,~,Fc] = remove_unreferenced(V,Fc);
  % Now Vc and Fc are both O(|Ic|). Do something with them.
  ...
end
There's another pitfall here, and this time it's gptoolbox's fault. For legacy reasons I've forgotten, the second output of remove_unreferenced(V,Fc) is #V-long regardless of how few unique elements are passed in for Fc. This causes the whole loop to be $O(n \cdot m)$ again.
I've finally added a 3rd, boolean argument to remove_unreferenced so that you can indicate that you don't care about the second output argument and you'd like the complexity to be $O(|Fc|)$. So the adjusted loop looks like:
S = sparse(1:size(F,1),C,1,size(F,1),max(C));
for c = 1:max(C)
  Ic = find(S(:,c));
  Fc = F(Ic,:);
  % Extract submesh. This next line is now O(|Ic|)
  [Vc,~,~,Fc] = remove_unreferenced(V,Fc,true);
  % Now Vc and Fc are both O(|Ic|). Do something with them.
  ...
end
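If this pattern comes up often enough, it's easy to fold into a tiny helper that returns all the per-component index lists in one go (component_indices is my own hypothetical name, not a gptoolbox function):

function IC = component_indices(C)
  % IC{c} lists the entries of C equal to c
  % (hypothetical helper, not part of gptoolbox)
  n = numel(C);
  m = max(C);
  S = sparse(1:n,C,1,n,m);
  IC = cell(m,1);
  for c = 1:m
    IC{c} = find(S(:,c)); % output-bound per component
  end
end

Then F(IC{c},:) replaces F(find(C == c),:) inside the loop, and the sparse bookkeeping lives in one place.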