凌云的博客

行胜于言

LeetCode 算法题 4. 寻找两个有序数组的中位数

分类:algorithm| 发布时间:2016-06-25 23:04:00


题目

给定两个大小为 m 和 n 的有序数组 nums1 和 nums2。

请你找出这两个有序数组的中位数,并且要求算法的时间复杂度为 O(log(m + n))。

你可以假设 nums1 和 nums2 不会同时为空。

示例 1:

nums1 = [1, 3]
nums2 = [2]

则中位数是 2.0

示例 2:

nums1 = [1, 2]
nums2 = [3, 4]

则中位数是 (2 + 3)/2 = 2.5

解法1

将两个有序数组合成一个有序数组,然后直接查出结果。 这个算法简单易懂,不过时间复杂度是 O(m+n),不符合题目要求, 无论如何,先给出实现。

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        vector<int> merge;
        int m = nums1.size(), n = nums2.size();
        merge.reserve(m + n);
        int i = 0, j = 0;

        for (; i != m && j !=n;) {
            if (nums1[i] <= nums2[j]) {
                merge.push_back(nums1[i++]);
            } else {
                merge.push_back(nums2[j++]);
            }
        }

        while (i != m) {
            merge.push_back(nums1[i++]);
        }

        while (j != n) {
            merge.push_back(nums2[j++]);
        }

        int median = (m + n - 1) / 2;
        if ((m + n) % 2) {
            return merge[median];
        } else {
            return (double(merge[median]) + merge[median + 1]) / 2;
        }
    }
};

解法2

本题目关键在于在将个有序的数组中找出特定序号的值。 假设要找出两个数组中从小到大排 k 位的数字。 设数组 nums1 长度为 m,数组 nums2 长度为 n,并且 m < n。

    int a = min(k / 2, m);
    int b = k - a;

对比 nums1[a] 与 nums2[b] 的大小,若 nums1[a] 小于等于 nums2[b],则 第 k 个数字必定在 nums1 的第 a 个之后,或者在 nums2 中(可以通过反证法,证明这一点) 否则第 k 个数字必定在 nums1 或者在 nums2 的第 b 个数字之后(同样可以通过反证法证明)。

给出 C++ 的递归实现,此算法的复杂度为 O(log(m + n)),满足题目要求。

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        const int m = nums1.size(), n = nums2.size();
        int median = (m + n - 1) / 2 + 1;
        if ((m + n) % 2) {
            return find(nums1, m, nums2, n, median);
        } else {
            return (double(find(nums1, m, nums2, n, median)) + find(nums1, m, nums2, n, median + 1)) / 2;
        }
    }

private:
    int find(vector<int> &nums1, int m, vector<int> &nums2, int n, int off, int mStart = 0, int nStart = 0) {
        if (m > n) {
            return find(nums2, n, nums1, m, off, nStart, mStart);
        }

        if (!m) {
            return nums2[off + nStart - 1];
        }

        if (1 == off) {
            return min(nums1[mStart], nums2[nStart]);
        }

        int a = min(off / 2, m);
        int b = off - a;
        if (nums1[a + mStart - 1] <= nums2[b + nStart - 1]) {
            return find(nums1, m - a, nums2, n, off - a, mStart + a, nStart);
        } else {
            return find(nums1, m, nums2, n - b, off - b, mStart, nStart + b);
        }
    }
};

当然,稍微修改下就能变成非递归:

class Solution {
public:
    double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
        const int m = nums1.size(), n = nums2.size();
        int median = (m + n - 1) / 2 + 1;
        if ((m + n) % 2) {
            return find(&nums1, m, &nums2, n, median);
        } else {
            return (double(find(&nums1, m, &nums2, n, median)) + find(&nums1, m, &nums2, n, median + 1)) / 2;
        }
    }

private:
    int find(vector<int> *nums1, int m, vector<int> *nums2, int n, int off) {
        int mStart = 0, nStart = 0;
        while (true) {
            if (m > n) {
                swap(nums1, nums2);
                swap(m, n);
                swap(mStart, nStart);
            }

            if (!m) {
                return (*nums2)[off + nStart - 1];
            }

            if (1 == off) {
                return min((*nums1)[mStart], (*nums2)[nStart]);
            }

            int a = min(off / 2, m);
            int b = off - a;
            if ((*nums1)[a + mStart - 1] <= (*nums2)[b + nStart - 1]) {
                m -= a;
                off -= a;
                mStart += a;
            } else {
                n -= b;
                off -= b;
                nStart += b;
            }
        }
    }
};