Find "best" complete subgraphs
While optimizing the performance of an app of mine, I ran across a huge performance bottleneck in a few lines of (Python) code.
I have N tokens, each with a value assigned to it. Some of the tokens contradict each other (e.g. tokens 8 and 12 cannot "live together"). My job is to find the k best token-groups. The value of a group of tokens is simply the sum of the values of the tokens in it.
Naïve algorithm (which I have implemented...):
- find all 2^N token-groups (subsets, not permutations) of the tokens
- Eliminate the token-groups that have contradictions in them
- Calculate the value of all remaining token-groups
- Sort token-groups by value
- Choose top K token-groups
Real-world numbers: I need the top 10 token-groups from a set of 20 tokens (for which I calculated all ~1,000,000 subsets (!)), narrowed down to 3500 non-contradicting token-groups. This took 5 seconds on my laptop...
I'm sure I can optimize steps 1+2 somehow by generating just the non-contradicting token-groups.
I'm also pretty sure I can somehow magically find the best token-group in a single search, and find a way to traverse the token-groups by diminishing value, thus finding just the 10 best I am looking for...
my actual code:
all_possibilities = sum((list(itertools.combinations(token_list, i)) for i in xrange(len(token_list) + 1)), [])
all_possibilities = [list(option) for option in all_possibilities if self._no_contradiction(option)]
all_possibilities = [(option, self._probability(option)) for option in all_possibilities]
all_possibilities.sort(key=lambda result: -result[1])  # sort by descending probability
Please help?
Tal.
A simple approach to steps 1+2 could look like this: first, define a list of tokens and a dictionary of contradictions (each key is a token and each value is a set of tokens). Then, for each token take one of two actions:
- add it to the result if it does not conflict with anything already chosen, and extend the conflicting set with the tokens that contradict the newly added token
- don't add it to the result (choose to ignore it) and move on to the next token
So here's a sample code:
token_list = ['a', 'b', 'c']
contradictions = {
    'a': set(['b']),
    'b': set(['a']),
    'c': set()
}
class Generator(object):
    def __init__(self, token_list, contradictions):
        self.list = token_list
        self.contradictions = contradictions
        self.max_start = len(self.list) - 1

    def add_no(self, start, result, conflicting):
        if start < self.max_start:
            for g in self.gen(start + 1, result, conflicting):
                yield g
        else:
            yield result[:]

    def add_yes(self, token, start, result, conflicting):
        result.append(token)
        new_conflicting = conflicting | self.contradictions[token]
        for g in self.add_no(start, result, new_conflicting):
            yield g
        result.pop()

    def gen(self, start, result, conflicting):
        token = self.list[start]
        if token not in conflicting:
            for g in self.add_yes(token, start, result, conflicting):
                yield g
        for g in self.add_no(start, result, conflicting):
            yield g

    def go(self):
        return self.gen(0, [], set())
Sample usage:
g = Generator(token_list, contradictions)
for x in g.go():
    print x
This is a recursive algorithm, so it won't work for more than a few thousand tokens (because of Python's stack limit), but you could easily create a non-recursive one.
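For instance, a non-recursive variant might use an explicit stack of decision frames instead of the call stack. This is only a sketch of that idea (the function name gen_groups is mine, assuming the same token_list/contradictions shapes as above):

```python
def gen_groups(token_list, contradictions):
    """Generate all non-contradicting token-groups without recursion."""
    results = []
    # each stack frame: (index of next token to decide on, tuple of tokens
    #                    chosen so far, tokens excluded by earlier choices)
    stack = [(0, (), frozenset())]
    while stack:
        start, chosen, conflicting = stack.pop()
        if start == len(token_list):
            results.append(list(chosen))
            continue
        token = token_list[start]
        # branch 1: skip this token
        stack.append((start + 1, chosen, conflicting))
        # branch 2: take this token, unless an earlier choice excluded it
        if token not in conflicting:
            stack.append((start + 1, chosen + (token,),
                          conflicting | frozenset(contradictions.get(token, ()))))
    return results
```

Because the stack lives on the heap, this version is limited only by memory, not by Python's recursion limit.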
An O(n log n) or O(n + m) solution for n tokens and string-length m
What differentiates your problem from the NP-complete clique problem is the fact that your "conflict" graph has structure - namely that it can be projected onto 1 dimension (it can be sorted).
That means you can divide and conquer; after all, non-overlapping ranges have no effect on each other, so there is no need to explore the complete state-space. In particular, a dynamic programming solution will work.
The outline of an algorithm
- Assume a token's position is represented as [start, end) (i.e. inclusive start, exclusive end). Sort the token list by token end; we'll be iterating over them in that order.
- You will be extending subsets of these tokens. Each subset of tokens has an end (no token can be added to the subset if it starts before the subset's end) and a cumulative value. The end of a subset of tokens is the maximum of the ends of all tokens in the subset.
- You're going to maintain a mapping (e.g. via a hashtable or array) from each index into the sorted array of tokens to the best-yet subset of non-conflicting tokens drawn from the tokens up to that index. That means the best-yet subset stored in the mapping for index J can only include tokens of index less than or equal to J.
- At each step, you'll be computing the best subset for some position J, and one of three things can occur: you may have already cached this computation in the mapping (easy), the best subset includes item J, or the best subset excludes item J. If you haven't cached it, you can only find out whether the best subset includes or excludes J by trying both options.
Now, the trick is in the cache - you need to try both options, and that looks like a recursive (exponential) search, but it needn't be.
- If the best subset for index J includes token[J], then it can't include any tokens that overlap that token - and in particular, since we sorted by token.end, there is a last token K in the list such that K < J and token[K].end <= token[J].start; for that token K we can compute the best subset too (or maybe we already have it cached).
- On the other hand, the best subset may exclude token[J], but then it is simply the best subset for index J-1.
- In either case, a special-case token[-1] with token[-1].end = 0 and subset value 0 can form the base case.
Since you only need to do this computation once per token index, this part is actually linear in the number of tokens. However, sorting the tokens naively (which I'd recommend) is O(n log n), and finding the last compatible token index given a string position is O(log n), repeated n times; so the overall running time is O(n log n). You can reduce this to O(n) by observing that you don't need to sort an arbitrary list - the maximal string position is limited and small, so you can do the sorting by indexing into the string - but it's almost certainly not worth it. Similarly, although finding one token by binary search is O(log n), you can do this by aligning two lists instead - one sorted on token end, the other on token start - thus permitting an O(n + m) implementation. Unless n can really get huge, it's not worth it.
If you iterate from the front of the string to the end, since all lookups look "back", you can remove the recursion entirely and simply directly lookup the result for a given index since it must have been calculated already anyhow.
Does this rather vague explanation help? It's a basic application of dynamic programming, which is just a fancy word for caching; so if you're confused, that's what you should read up on.
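To make the outline above concrete, here is a minimal iterative sketch of that dynamic program (my own illustration, not the answerer's code; it assumes each token is a (start, end, value) triple with an exclusive end):

```python
import bisect

def best_subset(tokens):
    """Return (best total value, chosen tokens) for non-overlapping tokens."""
    tokens = sorted(tokens, key=lambda t: t[1])  # sort by token end
    ends = [t[1] for t in tokens]
    # best[j] = (value, subset) using only the first j tokens; best[0] is the base case
    best = [(0, [])]
    for j, (start, end, value) in enumerate(tokens):
        # last token K with token[K].end <= token[J].start, via binary search
        k = bisect.bisect_right(ends, start, 0, j)
        with_j = (best[k][0] + value, best[k][1] + [tokens[j]])
        without_j = best[j]
        best.append(max(with_j, without_j, key=lambda p: p[0]))
    return best[-1]
```

The bisect_right lookup is the "last token K" step from the bullet list, and the max() is the include-or-exclude decision; everything "looks back", so no recursion is needed.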
Extending this to the top k-best solutions
If you want to find the top-K best solutions, you'll need a messy but doable extension that maps each token index not to the single best subset, but to the best K subsets so far - obviously at increased computational cost and a bit of extra code. Essentially, rather than choosing to either include or exclude token[J], you'll take the set union of both options and trim down to the K best at each token index. That's O(n log(n) + n k log(k)) if implemented straightforwardly.
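As a hypothetical sketch of that extension (again my own illustration, assuming (start, end, value) tokens), each table entry simply carries a short list of (value, subset) pairs instead of one:

```python
import bisect

def best_k_subsets(tokens, k):
    """Return up to the k best (value, subset) pairs of non-overlapping tokens."""
    tokens = sorted(tokens, key=lambda t: t[1])  # sort by token end
    ends = [t[1] for t in tokens]
    # best[j] = up to k best (value, subset) pairs using only the first j tokens
    best = [[(0, [])]]
    for j, (start, end, value) in enumerate(tokens):
        cut = bisect.bisect_right(ends, start, 0, j)
        # extend every stored solution compatible with token j by token j
        with_j = [(v + value, s + [tokens[j]]) for v, s in best[cut]]
        # merge with the solutions that exclude token j, keep the k best
        merged = sorted(with_j + best[j], key=lambda p: -p[0])
        best.append(merged[:k])
    return best[-1]
```

A production version would merge the two already-sorted lists instead of re-sorting, which is where the k log(k) factor in the bound comes from.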
A really simple way to get all the non-contradicting token-groups:
#!/usr/bin/env python
token_list = ['a', 'b', 'c']
contradictions = {
    'a': set(['b']),
    'b': set(['a']),
    'c': set()
}

result = []
while token_list:
    token = token_list.pop()
    new = [set([token])]
    for r in result:
        if token not in contradictions or not r & contradictions[token]:
            new.append(r | set([token]))
    result.extend(new)

print result
Here's a possible "heuristically optimized" approach and a small sample:
import itertools

# tokens in decreasing order of value (must all be > 0)
toks = 12, 11, 8, 7, 6, 2, 1
# contradictions (dict: highest-valued token -> set of incompatible ones)
cont = {12: set([11, 8, 7, 2]),
        11: set([8, 7, 6]),
        7: set([2]),
        2: set([1]),
        }

rec_calls = 0

def bestgroup(toks, contdict, arein=(), contset=()):
    """Recursively compute the highest-valued non-contradictory subset of toks."""
    global rec_calls
    toks = list(toks)
    while toks:
        # find the top token compatible w/the ones in `arein`
        toptok = toks.pop(0)
        if toptok in contset:
            continue
        # try to extend both with and without this toptok
        without_top = bestgroup(toks, contdict, arein, contset)
        contset = set(contset).union(contdict.get(toptok, ()))
        newarein = arein + (toptok,)
        with_top = bestgroup(toks, contdict, newarein, contset)
        rec_calls += 1
        if sum(with_top) > sum(without_top):
            return with_top
        else:
            return without_top
    return arein

def noncongroups(toks, contdict):
    """Count possible, non-contradictory subsets of toks."""
    tot = 0
    for l in range(1, len(toks) + 1):
        for c in itertools.combinations(toks, l):
            if any(contdict[k].intersection(c) for k in c if k in contdict):
                continue
            tot += 1
    return tot

print bestgroup(toks, cont)
print 'calls: %d (vs %d of %d)' % (rec_calls, noncongroups(toks, cont), 2**len(toks))
I believe this always makes as many recursive calls as there are feasible (non-contradictory) subsets, but I haven't proven it (so I'm just counting both -- noncongroups of course has nothing to do with the solution; it's there just to check that behavioral property ;-).
If this produces an acceptable speedup on your "actual use cases" benchmarks, then further optimization may introduce alpha-pruning (so you can stop recursion along paths that you know to be unproductive -- that's the motivation for the descending order in the tokens;-) and recursion elimination (using an explicit stack within the function instead). But I wanted to keep this first version simple, so it can easily be understood and verified (also, the further optimizations I have in mind are only going to help marginally, I suspect -- say, at best, halving the typical runtime, if even that much).
The following solution generates all maximal non-contradicting subsets, taking advantage of the fact that there's no point omitting an element from the solution unless it contradicts another element in the solution.
The simple optimization to avoid the second recursion in the case that the element t doesn't contradict any of the remaining elements should help make this solution efficient if the number of contradictions is small.
def solve(tokens, contradictions):
    if not tokens:
        yield set()
    else:
        tokens = set(tokens)
        t = tokens.pop()
        for solution in solve(tokens - contradictions[t], contradictions):
            yield solution | set([t])
        if contradictions[t] & tokens:
            for solution in solve(tokens, contradictions):
                if contradictions[t] & solution:
                    yield solution
This solution also demonstrates that dynamic programming (aka memoization) may be helpful to improve the performance of the solution further for some types of inputs.
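As a hypothetical illustration of that memoization idea (my own sketch, not the answerer's code), the set of remaining tokens can be frozen into a hashable cache key, with solutions returned as frozensets:

```python
def solve_memo(tokens, contradictions, cache=None):
    """Memoized variant: return the set of maximal non-contradicting subsets
    (as frozensets), keyed on the frozenset of tokens still available."""
    if cache is None:
        cache = {}
    key = frozenset(tokens)
    if key in cache:
        return cache[key]
    if not tokens:
        result = set([frozenset()])
    else:
        tokens = set(tokens)
        t = tokens.pop()
        result = set()
        # solutions that include t: solve the tokens t doesn't rule out
        for solution in solve_memo(tokens - contradictions[t], contradictions, cache):
            result.add(solution | frozenset([t]))
        # solutions that exclude t: only worthwhile if something conflicts with t
        if contradictions[t] & tokens:
            for solution in solve_memo(tokens, contradictions, cache):
                if contradictions[t] & solution:
                    result.add(solution)
    cache[key] = result
    return result
```

Whether the cache actually pays off depends on how often the same residual token set recurs, which in turn depends on the structure of the contradictions.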