
Find the top k sums of two sorted arrays

You are given two sorted arrays, of sizes n and m respectively. Your task (should you choose to accept it) is to output the largest k sums of the form a[i]+b[j].

An O(k log k) solution can be found here. There are rumors of an O(k) or O(n) solution. Does one exist?


I found the responses at your link mostly vague and poorly structured. Here's a start with an O(k * log(k)) algorithm. (Earlier versions of this answer claimed O(k * log(min(m, n))) and then O(k * log(m + n)); see the edit note at the end.)

Suppose both arrays are sorted in decreasing order, with a of size m and b of size n. Imagine you computed the m*n matrix of sums as follows:

for i from 0 to m - 1
    for j from 0 to n - 1
        sums[i][j] = a[i] + b[j]
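Materializing this matrix and sorting it already gives a brute-force answer in O(mn log(mn)) time and O(mn) space, which is handy as a reference when testing faster approaches. The following is a minimal sketch of that baseline; the class and method names (BruteForceTopK, topKSums) are hypothetical and int overflow is ignored.

```java
import java.util.Arrays;

public class BruteForceTopK {
    // Brute-force reference: build every cell of the sums matrix, sort, keep the k largest.
    // O(mn log(mn)) time and O(mn) space, so only suitable for small inputs or testing.
    static int[] topKSums(int[] a, int[] b, int k) {
        int[] sums = new int[a.length * b.length];
        int idx = 0;
        for (int x : a) {
            for (int y : b) {
                sums[idx++] = x + y;  // one cell of the sums matrix
            }
        }
        Arrays.sort(sums);  // ascending order
        int count = Math.min(k, sums.length);
        int[] result = new int[count];
        for (int t = 0; t < count; t++) {
            result[t] = sums[sums.length - 1 - t];  // take from the largest end
        }
        return result;
    }

    public static void main(String[] args) {
        int[] a = {5, 3, 1};
        int[] b = {4, 2};
        System.out.println(Arrays.toString(topKSums(a, b, 3)));  // [9, 7, 7]
    }
}
```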

In this matrix, values never increase going down or to the right. With that in mind, here is an algorithm that performs a graph search through this matrix in order of decreasing sums.

q : priority queue (decreasing) := empty priority queue
add (0, 0) to q with priority a[0] + b[0]
while k > 0 and q is not empty:
    k--
    (i, j) : tuple of int,int := pop q
    output a[i] + b[j]
    if i + 1 < m:
        add (i + 1, j) to q with priority a[i + 1] + b[j]
    if j + 1 < n:
        add (i, j + 1) to q with priority a[i] + b[j + 1]

Analysis:

  1. The loop is executed k times.
    1. There is one pop operation per iteration.
    2. There are up to two insert operations per iteration.
  2. Since each iteration adds at most one net element, the maximum size of the priority queue is O(k).
  3. The priority queue can be implemented with a binary heap, giving O(log(size)) pop and insert.
  4. Therefore this algorithm is O(k * log(k)).

Note that the general priority queue abstract data type needs to be modified to ignore duplicate entries. Alternatively, you could maintain a separate set structure that first checks for membership in the set before adding to the queue, and removes from the set after popping from the queue. (If sums can tie, it is safer to keep popped cells in the set: otherwise a popped cell can be re-added by a neighbour with an equal sum that is popped later.) Neither of these ideas would worsen the time or space complexity.

I could write this up in Java if there's any interest.

Edit: fixed complexity. There is an algorithm which has the complexity I described, but it is slightly different from this one. You would have to take care to avoid adding certain nodes. My simple solution adds many nodes to the queue prematurely.
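The answer offers a Java write-up; what follows is a minimal sketch of the pseudocode above in Java, not the answerer's own code. It assumes both arrays are sorted in decreasing order, and the names TopKSums and topKSums are hypothetical. It takes the separate-set approach to duplicates, but never removes a cell from the set, so equal sums cannot cause a cell to be queued twice.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;

public class TopKSums {
    // Returns the k largest sums a[i] + b[j]; a and b are assumed sorted in decreasing order.
    static List<Integer> topKSums(int[] a, int[] b, int k) {
        List<Integer> result = new ArrayList<>();
        if (a.length == 0 || b.length == 0) return result;
        int n = b.length;
        // Queue entries are {sum, i, j}; the comparator makes it a max-heap on sum.
        PriorityQueue<int[]> q = new PriorityQueue<>((x, y) -> Integer.compare(y[0], x[0]));
        // Cells ever enqueued, encoded as i * n + j. Entries are never removed.
        Set<Long> seen = new HashSet<>();
        q.add(new int[] {a[0] + b[0], 0, 0});
        seen.add(0L);
        while (result.size() < k && !q.isEmpty()) {
            int[] top = q.poll();
            result.add(top[0]);
            int i = top[1], j = top[2];
            // Neighbour below (i + 1, j), if it exists and was never enqueued.
            if (i + 1 < a.length && seen.add((long) (i + 1) * n + j)) {
                q.add(new int[] {a[i + 1] + b[j], i + 1, j});
            }
            // Neighbour to the right (i, j + 1), if it exists and was never enqueued.
            if (j + 1 < n && seen.add((long) i * n + (j + 1))) {
                q.add(new int[] {a[i] + b[j + 1], i, j + 1});
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[] a = {5, 3, 1};
        int[] b = {4, 2};
        System.out.println(topKSums(a, b, 4));  // [9, 7, 7, 5]
    }
}
```

`Set.add` returns false when the cell is already present, so the membership check and the insertion happen in one call.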


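import java.util.PriorityQueue;

public class MaxSum {

    // A candidate cell (aIdx, bIdx) of the sums matrix together with its sum.
    private static class FrontierElem implements Comparable<FrontierElem> {
        int value;
        int aIdx;
        int bIdx;

        public FrontierElem(int value, int aIdx, int bIdx) {
            this.value = value;
            this.aIdx = aIdx;
            this.bIdx = bIdx;
        }

        @Override
        public int compareTo(FrontierElem o) {
            // Reversed order so the PriorityQueue acts as a max-heap;
            // Integer.compare avoids the overflow of o.value - value.
            return Integer.compare(o.value, value);
        }
    }

    // Prints the k largest sums a[i] + b[j] (a and b sorted in decreasing order),
    // each followed by the current queue size.
    public static void findMaxSum(int[] a, int[] b, int k) {
        if (a.length == 0 || b.length == 0) return;
        // frontierA[i] = b-index of the queued element with a-index i (null if none);
        // frontierB[j] = a-index of the queued element with b-index j (null if none).
        Integer[] frontierA = new Integer[a.length];
        Integer[] frontierB = new Integer[b.length];
        PriorityQueue<FrontierElem> q = new PriorityQueue<>();
        frontierA[0] = frontierB[0] = 0;
        q.add(new FrontierElem(a[0] + b[0], 0, 0));
        // Stop early if k exceeds the number of cells, instead of polling an empty queue.
        while (k > 0 && !q.isEmpty()) {
            FrontierElem f = q.poll();
            System.out.println(f.value + "    " + q.size());
            k--;
            frontierA[f.aIdx] = frontierB[f.bIdx] = null;
            int fRight = f.aIdx + 1;
            int fDown = f.bIdx + 1;
            if (fRight < a.length && frontierA[fRight] == null) {
                q.add(new FrontierElem(a[fRight] + b[f.bIdx], fRight, f.bIdx));
                frontierA[fRight] = f.bIdx;
                frontierB[f.bIdx] = fRight;
            }
            if (fDown < b.length && frontierB[fDown] == null) {
                q.add(new FrontierElem(a[f.aIdx] + b[fDown], f.aIdx, fDown));
                frontierA[f.aIdx] = fDown;
                frontierB[fDown] = f.aIdx;
            }
        }
    }
}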

The idea is similar to the other solution, but with the observation that, as you add elements of the matrix to the result set, the next element can only come from a place where the current set is concave (a cell whose neighbour in each direction toward (0, 0) is already in the set or lies outside the matrix). I call these elements frontier elements; I keep track of their positions in two arrays and of their values in a priority queue. This keeps the queue smaller, but I haven't yet worked out by how much. It seems to be about sqrt(k), but I'm not entirely sure about that.

(Of course the frontierA/B arrays could be simple boolean arrays, but this way they fully define my result set. This isn't used anywhere in this example but might be useful otherwise.)


Since the precondition is that the arrays are sorted (here in ascending order), consider the following example with N = 5:

A[]={ 1,2,3,4,5}

B[]={ 496,497,498,499,500}

Since the arrays are sorted in ascending order, A[N-1] + B[N-1] is the largest sum, so insert it into a max-heap together with the indexes of the A and B elements (why the indexes? We'll see shortly):

H.insert(A[N-1]+B[N-1],N-1,N-1);

Now:

 while (K > 0 && !H.empty()) { // stop after K sums, or when the heap is empty
     (sum, i, j) = H.pop();    // the next largest sum you are looking for
     output sum;
     K--;

     // Use the indexes we got from the pop to select the next candidate sums:
     // if i & j are the indexes in A & B, the candidates are A[i]+B[j-1] and A[i-1]+B[j].
     // Insert each one that has not been inserted into the heap before (Hash records inserted pairs).
     if (j > 0 && !Hash[i, j-1]) { Hash[i, j-1] = true; H.insert(A[i]+B[j-1], i, j-1); }
     if (i > 0 && !Hash[i-1, j]) { Hash[i-1, j] = true; H.insert(A[i-1]+B[j], i-1, j); }
 }

 Popping K times gives you the K largest sums.

Hope this helps
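For the ascending arrays in this example, the heap search above might be sketched in Java as follows. This is an illustration, not the answerer's code: the names TopKSumsAscending and topKSums are hypothetical, and a set of encoded index pairs (i * n + j) stands in for Hash[i, j].

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;

public class TopKSumsAscending {
    // A and B are assumed sorted in ascending order: start at the last indexes, move toward 0.
    static List<Integer> topKSums(int[] A, int[] B, int k) {
        List<Integer> out = new ArrayList<>();
        if (A.length == 0 || B.length == 0) return out;
        int n = B.length;
        // Heap entries are {sum, i, j}; largest sum first.
        PriorityQueue<int[]> heap = new PriorityQueue<>((x, y) -> Integer.compare(y[0], x[0]));
        Set<Long> inserted = new HashSet<>();  // plays the role of Hash[i, j]
        int i0 = A.length - 1, j0 = n - 1;
        heap.add(new int[] {A[i0] + B[j0], i0, j0});
        inserted.add((long) i0 * n + j0);
        while (out.size() < k && !heap.isEmpty()) {
            int[] top = heap.poll();  // the next largest sum and its indexes
            out.add(top[0]);
            int i = top[1], j = top[2];
            if (i > 0 && inserted.add((long) (i - 1) * n + j)) {
                heap.add(new int[] {A[i - 1] + B[j], i - 1, j});
            }
            if (j > 0 && inserted.add((long) i * n + (j - 1))) {
                heap.add(new int[] {A[i] + B[j - 1], i, j - 1});
            }
        }
        return out;
    }

    public static void main(String[] args) {
        int[] A = {1, 2, 3, 4, 5};
        int[] B = {496, 497, 498, 499, 500};
        System.out.println(topKSums(A, B, 5));  // [505, 504, 504, 503, 503]
    }
}
```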


Many thanks to @rlibby and @xuhdev for such an original idea for solving this kind of problem. In a coding interview exercise I had a similar task: find the N largest sums formed by K elements from K arrays sorted in descending order, which means we must pick exactly 1 element from each sorted array to build each sum.

Signature: List<Integer> findHighestSums(int[][] lists, int n)

Example: given the lists

[5,4,3,2,1]
[4,1]
[5,0,0]
[6,4,2]
[1]

and n = 5, your method should return a list of size 5:

[21,20,19,19,18]

Below is my code; please read the block comments carefully :D

import java.util.*;

private class Pair implements Comparable<Pair>{
    String state;

    int sum;

    public Pair(String state, int sum) {
        this.state = state;
        this.sum = sum;
    }

    @Override
    public int compareTo(Pair o) {
        // Max heap (Integer.compare avoids the overflow of o.sum - this.sum)
        return Integer.compare(o.sum, this.sum);
    }
}

List<Integer> findHighestSums(int[][] lists, int n) {

    int numOfLists = lists.length;
    int totalCharacterInState = 0;

    /*
     * Represent the state (the chosen index in each list) as a String.
     * The number of characters for list i is Math.ceil(Math.log10(lists[i].length)), with a minimum of 1.
     * For example:
     *      If list1 contains from 11 to 100 elements,
     *      its part of the state needs 2 characters.
     */
    int[] positionStartingCharacterOfListState = new int[numOfLists + 1];
    positionStartingCharacterOfListState[0] = 0;

    // Use <= here so that the array also records where the last list's characters end
    for(int i = 1; i <= numOfLists; i++) {  
        int previousListNumOfCharacters = 1;
        if(lists[i-1].length > 10) {
            previousListNumOfCharacters = (int)Math.ceil(Math.log10(lists[i-1].length));
        }
        positionStartingCharacterOfListState[i] = positionStartingCharacterOfListState[i-1] + previousListNumOfCharacters;
        totalCharacterInState += previousListNumOfCharacters;
    }

    // Every state generated so far: makes sure each combination of indexes is used only once
    Set<String> states = new HashSet<>();
    List<Integer> result = new ArrayList<>();
    StringBuilder sb = new StringBuilder();

    // Max heap of <state, sum> pairs
    PriorityQueue<Pair> pq = new PriorityQueue<>();

    char[] stateChars = new char[totalCharacterInState];
    Arrays.fill(stateChars, '0');
    sb.append(stateChars);
    String firstState = sb.toString();
    states.add(firstState);

    int firstLargestSum = 0;
    for(int i = 0; i < numOfLists; i++) firstLargestSum += lists[i][0];

    // Imagine this is the initial state in a graph
    pq.add(new Pair(firstState, firstLargestSum));

    while(n > 0) {
        // In case n is larger than the number of combinations of all list entries 
        if(pq.isEmpty()) break;
        Pair top = pq.poll();
        String currentState = top.state;
        int currentSum = top.sum;

        /*
         * Loop over all lists and generate the new states that differ from the current state in exactly one list's index.
         * For example, from the initial state (Stage 0) 0 0 0 0 0
         * the next states (Stage 1) for the example input are:
         *  1 0 0 0 0 (choose the element at index 1 of the 1st list)
         *  0 1 0 0 0 (choose the element at index 1 of the 2nd list)
         *  0 0 1 0 0 (choose the element at index 1 of the 3rd list)
         *  0 0 0 1 0 (choose the element at index 1 of the 4th list)
         * The 5th list has only one element, so its index cannot advance:
         * always check that the new index does not exceed the list's length.
         */
        for(int i = 0; i < numOfLists; i++) {
            int indexInList = Integer.parseInt(
                    currentState.substring(positionStartingCharacterOfListState[i], positionStartingCharacterOfListState[i+1]));
            if( indexInList < lists[i].length - 1) {
                int numberOfCharacters = positionStartingCharacterOfListState[i+1] - positionStartingCharacterOfListState[i];
                sb = new StringBuilder(currentState.substring(0, positionStartingCharacterOfListState[i]));
                sb.append(String.format("%0" + numberOfCharacters + "d", indexInList + 1));
                sb.append(currentState.substring(positionStartingCharacterOfListState[i+1]));
                String newState = sb.toString();
                if(!states.contains(newState)) {

                    // The newSum is always <= currentSum
                    int newSum = currentSum - lists[i][indexInList] + lists[i][indexInList+1];

                    states.add(newState);
                    // The priority queue lets us immediately retrieve the largest sum at stage k while tracking all other unused states.
                    // New states are then generated from the state of that stage-k largest sum.
                    // Sums of newly generated states are not guaranteed to be larger than sums of older unused states.
                    pq.add(new Pair(newState, newSum));
                }

            }
        }
        result.add(currentSum);
        n--;
    }
    return result;
}

Let me explain the complexity of my solution:

  1. The while loop in my answer executes N times; each iteration works on the max heap (priority queue).
  2. Each iteration does 1 poll, costing O(log(heap size)). Each iteration adds at most K new states, so the heap holds at most 1 + N * K entries.
  3. Each iteration does up to K insertions, each costing O(log(heap size)). Therefore the heap work is O(N * K * log(N * K)); building and hashing each new state string adds time proportional to the state's length.
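The String encoding of states works, but an index array compared by value avoids the log10 width bookkeeping. Below is a minimal sketch of the same search with that swap, assuming every list is non-empty and sorted in descending order; HighestSums, State and key are hypothetical names, not part of the original answer.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.PriorityQueue;
import java.util.Set;

public class HighestSums {
    // Heap entry: the chosen index in each list and the resulting sum.
    static final class State {
        final int[] idx;
        final int sum;
        State(int[] idx, int sum) { this.idx = idx; this.sum = sum; }
    }

    // Value-based key for the visited set (ArrayList implements equals/hashCode by content).
    static List<Integer> key(int[] idx) {
        List<Integer> k = new ArrayList<>(idx.length);
        for (int v : idx) k.add(v);
        return k;
    }

    // Assumes every list is non-empty and sorted in descending order.
    static List<Integer> findHighestSums(int[][] lists, int n) {
        List<Integer> result = new ArrayList<>();
        PriorityQueue<State> pq = new PriorityQueue<>((x, y) -> Integer.compare(y.sum, x.sum));
        Set<List<Integer>> seen = new HashSet<>();
        int[] start = new int[lists.length];  // index 0 in every list
        int firstSum = 0;
        for (int[] list : lists) firstSum += list[0];
        pq.add(new State(start, firstSum));
        seen.add(key(start));
        while (result.size() < n && !pq.isEmpty()) {
            State top = pq.poll();
            result.add(top.sum);
            // Successors: advance exactly one list to its next element.
            for (int i = 0; i < lists.length; i++) {
                int cur = top.idx[i];
                if (cur + 1 < lists[i].length) {
                    int[] next = top.idx.clone();
                    next[i] = cur + 1;
                    if (seen.add(key(next))) {
                        int nextSum = top.sum - lists[i][cur] + lists[i][cur + 1];
                        pq.add(new State(next, nextSum));
                    }
                }
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[][] lists = {{5, 4, 3, 2, 1}, {4, 1}, {5, 0, 0}, {6, 4, 2}, {1}};
        System.out.println(findHighestSums(lists, 5));  // [21, 20, 19, 19, 18]
    }
}
```

Each new state still costs time proportional to K to copy and hash, the same order as building the String state.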