2024年2月

前言

个人认为,这是一道非常经典的 LeetCode Hard 题目,第一次碰到时,感觉觉得如此之精妙,想破脑袋也不会想到使用什么样合适的数据结构去解决这道问题。时隔快5年,在2024年农历新年最后一个工作日,重温下这道经典的题目。

题目描述

中位数是有序序列最中间的那个数。如果序列的长度是偶数,则没有最中间的数;此时中位数是最中间的两个数的平均数。

例如:

[2,3,4],中位数是 3
[2,3],中位数是 (2 + 3) / 2 = 2.5
给你一个数组 nums,有一个长度为 k 的窗口从最左端滑动到最右端。窗口中有 k 个数,每次窗口向右移动 1 位。你的任务是找出每次窗口移动后得到的新窗口中元素的中位数,并输出由它们组成的数组。

示例:

给出 nums = [1,3,-1,-3,5,3,6,7],以及 k = 3。

窗口位置 中位数
--------------- -----
[1 3 -1] -3 5 3 6 7 1
1 [3 -1 -3] 5 3 6 7 -1
1 3 [-1 -3 5] 3 6 7 -1
1 3 -1 [-3 5 3] 6 7 3
1 3 -1 -3 [5 3 6] 7 5
1 3 -1 -3 5 [3 6 7] 6
因此,返回该滑动窗口的中位数数组 [1,-1,-1,3,5,6]。

提示:

你可以假设 k 始终有效,即:k 始终小于等于输入的非空数组的元素个数。
与真实值误差在 10 ^ -5 以内的答案将被视作正确答案。

原题快速门

说句题外话,以现在卷的程度,是不是很多校招生必会解此题呢,想当年,这样的题目,是不太可能出现在现场笔试题目中的。^.^

解题思路

首先,拿到一道题目后,要分析输入规模,那其实老的题目这点是不太规范的,往往没有给你输入规模的描述,也就少了一些提示。我理解这道题目的输入规模应该如下:
n = len(nums)
1 <= k <= n <= 10^5
那么如果是这样的数据规模,所有O(n^2)的算法可以直接劝退,不用费脑子和精力去想了。接下来,要仔细想下题目的考点是哪一个,或者哪几个组合。显然这道题目非常贴心的告诉你一个考点,滑动窗口,什么是滑动窗口,题目的示例解释的非常清楚了,k表示滑动窗口的长度,窗口往右滑动时,前面的数出窗口,后面的数进入窗口,就是这么个过程。其次,要求窗口内数据的中位数,如果每次暴力的重新计算,那么需要排序,取中间值,显然这个方案的复杂度,最坏情况下会变成O(n^2logn),应该是过不了testcase,那么这里就要引入数据结构,以O(1)时间计算中位数,如果滑动窗口内的数字保存在一个有序数组,那么计算中位数的复杂度为O(1)就很轻松了,但是删除和添加一个元素到有序数组,操作是O(N)的,那什么数据结构删除和添加一个元素是O(logn)以下的复杂度呢?啊,想到了,类似堆或者红黑树的数据结构。
那怎样使用有序的这些数据结构,实现滑动窗口中位数呢?最巧妙的设计在于,使用两个堆或者有序树的实例,一个维护前半段,一个维护后半段,那么中位数和前半段的最大值和后半段的最小值相关。如果是堆的话,前面半段使用最大堆,后面半段使用最小堆。如下图所示:
企业微信20240205-173311@2x.png

这里还有一个细节点,如果k是偶数,那么前半段和后半段的数字数量相等,否则前半段多维护一个数,所以前半段维护数字的数量是 k+1 除以 2 下取整,后半段维护数字的数量是 k 除以 2 下取整。所以当 k 为奇数时,前半段堆最大值就是中位数,否则,需要将两个堆顶的数字取平均。

代码实现

通过上面的分析,思路有了,我实际实现时,使用了Python3,以及SortedList这个有序List(底层实现采用了二分查找和分块算法。在插入元素时,它会将列表分块,并在每个块内部使用二分查找来维护有序性。这样可以在保证列表有序的同时,提高插入和查找元素的效率),效果和堆是类似的,就是优化插入和删除元素的时间复杂度。代码如下:

from sortedcontainers import SortedList

class Solution:
def medianSlidingWindow(self, nums: List[int], k: int) -> List[float]:
    n = len(nums)
    l = (k+1) // 2
    r = k // 2

    L = SortedList(nums[:k])
    R = SortedList()

    def L2R():
        x = L.pop()
        R.add(x)

    def R2L():
        x = R.pop(0)
        L.add(x)

    while len(L) > l:
        L2R()

    ans = []
    if l > r:
        ans.append(L[-1])
    else:
        ans.append((L[-1] + R[0]) / 2)

    for i in range(k, n):
        out_val = nums[i-k]
        if out_val in L:
            L.remove(out_val)
        else:
            R.remove(out_val)
        
        in_val = nums[i]
        if not L or in_val < L[-1]:
            L.add(in_val)
        else:
            R.add(in_val)
        
        if len(L) == l-1:
            R2L()
        elif len(L) == l+1:
            L2R()

        if l > r:
            ans.append(L[-1])
        else:
            ans.append((L[-1] + R[0]) / 2)
    
    return ans

结论

上述实现的代码,时间复杂度O(nlogk),空间复杂度O(k)。所谓温故知新,最近也有些堆顶堆的题目出现在周赛上面,只是比起以前的题目,现在的题目不会直接把题目考点赤裸裸写在标题里面,需要你自己抽象转换,这步其实非常难,需要你特别熟悉这些题目考点,才能在关键时刻想到如何转换。