All minimum spanning trees implementation
I've been looking for an implementation (I'm using the networkx library) that finds all the minimum spanning trees (MSTs) of an undirected weighted graph.
I can only find implementations of Kruskal's algorithm and Prim's algorithm, both of which return only a single MST.
I've seen papers that address this problem (such as "Representing all minimum spanning trees with applications to counting and generation"), but my head tends to explode partway through trying to work out how to translate them into code.
In fact, I've not been able to find an implementation in any language!
I don't know if this is the solution, but it's a solution (it's the graph version of a brute force, I would say):
- Find the MST of the graph using Kruskal's or Prim's algorithm. This should be O(E log V).
- Generate all spanning trees. As I understand from 2 minutes' worth of Google, this can be done in O(Elog(V) + V + n) for n = number of spanning trees, and can possibly be improved.
- Filter the list generated in step #2 by the tree's weight being equal to the MST's weight. This should be O(n) for n as the number of trees generated in step #2.
Note: Do this lazily! Generating every spanning tree up front and filtering afterwards means holding all n trees in memory at once, and n can be exponential in V. Instead, generate a tree, examine its weight, add it to the result list if it is an MST, and discard it otherwise.
Overall time complexity: O(Elog(V) + V + n) for G(V,E) with n spanning trees
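The three steps above can be sketched in a few lines of pure Python. This is only an illustration under stated assumptions: the names `acyclic_union`, `mst_weight` and `brute_force_msts` are made up here, edges are assumed to be `(u, v, weight)` tuples over nodes `0..n-1`, the graph is assumed to be connected, and candidate trees come from `itertools.combinations` rather than from a dedicated spanning-tree enumeration algorithm, so it does not reach the complexity quoted above.

```python
from itertools import combinations


def acyclic_union(n, edges):
    """Scan edges in the given order and keep those that do not close a cycle."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    kept = []
    for u, v, w in edges:
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            kept.append((u, v, w))
    return kept


def mst_weight(n, edges):
    """Step 1: Kruskal's algorithm (assumes the graph is connected)."""
    tree = acyclic_union(n, sorted(edges, key=lambda e: e[2]))
    return sum(w for _, _, w in tree)


def brute_force_msts(n, edges):
    """Steps 2 and 3, done lazily: yield each spanning tree of MST weight."""
    target = mst_weight(n, edges)
    for candidate in combinations(edges, n - 1):
        if sum(w for _, _, w in candidate) != target:
            continue  # wrong weight, discard immediately
        # n-1 edges with no cycle on n nodes form a spanning tree.
        if len(acyclic_union(n, candidate)) == n - 1:
            yield list(candidate)


# A 4-cycle of weight-1 edges plus a heavier diagonal.
edges = [(0, 1, 1), (1, 2, 1), (2, 3, 1), (3, 0, 1), (0, 2, 2)]
for tree in brute_force_msts(4, edges):
    print(tree)
```

Because `brute_force_msts` is a generator, only one candidate tree is held in memory at a time, which is the point of the "do this lazily" note.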
Ronald Rivest has a nice implementation in Python, mst.py
You can find an idea in the work of Sorensen and Janssens (2005).
The idea is to generate the spanning trees in order of increasing weight and to stop the enumeration as soon as you reach a tree that is heavier than the first one.
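One way to picture "enumerate in increasing order, then stop" is a partition-based search with a priority queue. The sketch below is not taken from the paper and should not be read as its exact algorithm. Each partition fixes some edges as forced and some as forbidden, its cheapest tree is found with a constrained Kruskal, and partitions are popped by weight until one exceeds the minimum. The names `constrained_mst` and `msts_in_order` are hypothetical, and edges are assumed to be `(u, v, weight)` tuples over nodes `0..n-1`.

```python
import heapq
from itertools import count


def constrained_mst(n, edges, included, excluded):
    """Kruskal restricted to a partition: edge indices in `included` are forced
    in, those in `excluded` are forbidden. Returns (weight, indices) or None
    when no spanning tree satisfies the constraints."""
    parent = list(range(n))

    def find(x):
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    free = sorted((i for i in range(len(edges))
                   if i not in included and i not in excluded),
                  key=lambda i: edges[i][2])
    tree, total = [], 0
    for i in list(included) + free:  # forced edges first, then by weight
        u, v, w = edges[i]
        ru, rv = find(u), find(v)
        if ru != rv:
            parent[ru] = rv
            tree.append(i)
            total += w
        elif i in included:
            return None  # the forced edges close a cycle
    return (total, tree) if len(tree) == n - 1 else None


def msts_in_order(n, edges):
    """Yield spanning trees by non-decreasing weight and stop at the first
    tree that is heavier than the minimum."""
    tie = count()  # tie-breaker so the heap never has to compare lists/sets
    first = constrained_mst(n, edges, frozenset(), frozenset())
    if first is None:
        return  # disconnected graph: no spanning tree at all
    best = first[0]
    heap = [(best, next(tie), first[1], frozenset(), frozenset())]
    while heap:
        weight, _, tree, included, excluded = heapq.heappop(heap)
        if weight > best:
            break  # everything still queued is at least this heavy
        yield [edges[i] for i in tree]
        # Split the other trees of this partition into disjoint sub-partitions:
        # the j-th child forbids the j-th open tree edge and forces the earlier ones.
        forced = set(included)
        for i in tree:
            if i in included:
                continue
            child_excluded = excluded | {i}
            result = constrained_mst(n, edges, frozenset(forced), child_excluded)
            if result is not None:
                heapq.heappush(heap, (result[0], next(tie), result[1],
                                      frozenset(forced), child_excluded))
            forced.add(i)


edges = [(0, 1, 1), (1, 2, 1), (2, 3, 1), (3, 0, 1), (0, 2, 2)]
for tree in msts_in_order(4, edges):
    print(tree)
```

The early `break` is what keeps the search from touching most of the heavier spanning trees, although each MST still costs one constrained Kruskal run per open edge.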
Here's a short Python implementation, basically a recursive variation of Kruskal's. It uses the weight of the first MST found to limit the size of the search space thereafter. It definitely still has exponential complexity, but it is better than generating every spanning tree. Some test code is also included.
[Note: this was just my own experimentation for fun and possible inspiration of further thoughts on the problem from others, it's not an attempt to specifically implement any of the solutions suggested in other supplied answers here]
# Disjoint set find (and collapse)
def find(nd, djset):
    uv = nd
    while djset[uv] >= 0:
        uv = djset[uv]
    if djset[nd] >= 0:
        djset[nd] = uv
    return uv

# Disjoint set union (does not modify djset)
def union(nd1, nd2, djset):
    unionset = djset.copy()
    if unionset[nd2] < unionset[nd1]:
        nd1, nd2 = nd2, nd1
    unionset[nd1] += unionset[nd2]
    unionset[nd2] = nd1
    return unionset

# Bitmask convenience methods; uses bitmasks
# internally to represent MST edge combinations
def setbit(j, mask):
    return mask | (1 << j)

def isbitset(j, mask):
    return (mask >> j) & 1

def masktoedges(mask, sedges):
    return [sedges[i] for i in range(len(sedges))
            if isbitset(i, mask)]

# Upper-bound count of viable MST edge combinations, i.e.
# count of edge subsequences of length: NEDGES, w/sum: WEIGHT
def count_subsequences(sedges, weight, nedges):
    def count(i, target, length, cache):
        tkey = (i, target, length)
        if tkey in cache:
            return cache[tkey]
        if i == len(sedges) or target < sedges[i][2]:
            return 0
        cache[tkey] = (count(i+1, target, length, cache) +
                       count(i+1, target - sedges[i][2], length - 1, cache) +
                       (1 if sedges[i][2] == target and length == 1 else 0))
        return cache[tkey]

    return count(0, weight, nedges, {})

# Arg: n is number of nodes in graph [0, n-1]
# Arg: sedges is list of graph edges sorted by weight
# Return: list of MSTs, where each MST is a list of edges
def find_all_msts(n, sedges):
    # Recursive variant of kruskal to find all MSTs
    def buildmsts(i, weight, mask, nedges, djset):
        nonlocal maxweight, msts
        if nedges == (n-1):
            msts.append(mask)
            if maxweight == float('inf'):
                print(f"MST weight: {weight}, MST edges: {n-1}, Total graph edges: {len(sedges)}")
                print(f"Upper bound numb viable MST edge combinations: {count_subsequences(sedges, weight, n-1)}\n")
                maxweight = weight
            return

        if i < len(sedges):
            u, v, wt = sedges[i]
            if weight + wt*((n-1) - nedges) <= maxweight:
                # Left recursive branch - include edge if valid
                nd1, nd2 = find(u, djset), find(v, djset)
                if nd1 != nd2:
                    buildmsts(i+1, weight + wt,
                              setbit(i, mask), nedges+1, union(nd1, nd2, djset))
                # Right recursive branch - always skips edge
                buildmsts(i+1, weight, mask, nedges, djset)

    maxweight, msts = float('inf'), []
    djset = {i: -1 for i in range(n)}
    buildmsts(0, 0, 0, 0, djset)
    return [masktoedges(mask, sedges) for mask in msts]

import time, numpy

def run_test_case(low=10, high=21):
    rng = numpy.random.default_rng()
    n = rng.integers(low, high)
    nedges = rng.integers(n-1, n*(n-1)//2)
    edges = set()
    while len(edges) < nedges:
        u, v = sorted(rng.choice(range(n), size=2, replace=False))
        edges.add((u, v))
    weights = sorted(rng.integers(1, 2*n, size=nedges))
    sedges = [[u, v, wt] for (u, v), wt in zip(edges, weights)]
    print(f"Numb nodes: {n}\nSorted edges: {sedges}\n")
    for i, mst in enumerate(find_all_msts(n, sedges)):
        if i == 0:
            print("MSTs:")
        print((i+1), ":", mst)

if __name__ == "__main__":
    initial = time.time()
    run_test_case(20, 35)
    print(f"\nRun time: {time.time() - initial}s")