@njit函数不能直接用于rolling().apply(),因其不支持pandas.Series输入;正确做法是设raw=True使窗口传入ndarray,再用@njit处理一维数组并返回标量。

为什么 rolling().apply() 不能直接用 @njit
因为 pandas 的 rolling().apply() 要求传入函数能接收一个 Series 或 ndarray(取决于 raw 参数),而 numba 的 @njit 编译后函数默认不支持 pandas.Series —— 它只认纯 numpy.ndarray,且要求类型在编译时可推断。直接把 @njit 函数塞进 apply() 会报类似 TypingError: Failed in nopython mode pipeline... cannot determine Numba type of 的错误。
正确做法:设 raw=True + 用 @njit 处理 ndarray
关键点是让 rolling 把窗口数据以 numpy.ndarray 形式传给你的函数,而不是 Series。这靠设置 raw=True 实现。此时传入函数的参数就是一维 ndarray,numba 可以顺利编译和运行。
-
raw=True是必须的,否则传进来的是Series,@njit直接拒收 - 函数签名必须只接受一个
ndarray,返回标量(如float64);不能有可选参数、不能调用 pandas 方法 - 如果窗口含
NaN,raw=True下仍会传入nan值,需在@njit函数内显式处理(比如用np.isnan()判断) - 示例:计算滚动窗口中位数(
np.median在numba中不可用,得手写)
@njit def rolling_median(arr): n = len(arr) if n == 0: return np.nan # 简单排序取中位(仅作示意;生产环境建议用 partial sort) temp = arr.copy() for i in range(n): for j in range(i + 1, n): if temp[i] > temp[j]: temp[i], temp[j] = temp[j], temp[i] if n % 2 == 1: return temp[n // 2] else: return (temp[n // 2 - 1] + temp[n // 2]) / 2.0使用方式
s = pd.Series([1.0, 3.0, 2.0, 5.0, 4.0]) result = s.rolling(3).apply(rolling_median, raw=True)
性能对比与注意事项
加速效果取决于函数复杂度。对简单操作(如
np.sum),用@njit反而可能更慢(启动开销+小数组无优势);但对自定义逻辑(如条件聚合、分位数、迭代计算),提速常达 2–10 倍。不过要注意:
-
raw=True会禁用rolling内置的min_periods对齐逻辑:若窗口内有效值不足,apply仍会传入含NaN的数组,你得自己检查有效长度 -
@njit不支持axis参数或高维数组;rolling在 DataFrame 上使用时,需逐列处理或改用apply(axis=1)+raw=True,但后者传入的是行向量(1D),不是二维切片 - 调试困难:
@njit报错信息晦涩,建议先用纯 Python 版本验证逻辑,再加装饰器
替代方案:用 numba 手写滚动循环(更可控)
当 rolling().apply(raw=True) 无法满足需求(比如要跳过 NaN、动态窗口、或需要索引信息),更稳的方式是绕过 pandas 的 rolling,直接用 numba 写一个带循环的函数,输入整个数组,输出结果数组。这样完全掌控内存布局、缺失值策略和边界行为。
例如实现等价于 s.rolling(3).mean() 的 jit 版本:
@njit
def rolling_mean_jit(arr, window):
n = len(arr)
out = np.full(n, np.nan)
for i in range(window - 1, n):
window_arr = arr[i - window + 1:i + 1]
# 过滤 NaN(numba 支持 np.nansum / np.nanmean 从 0.58+,但旧版需手动)
total, count = 0.0, 0
for j in range(window):
if not np.isnan(window_arr[j]):
total += window_arr[j]
count += 1
if count > 0:
out[i] = total / count
return out
调用
result = rolling_mean_jit(s.to_numpy(), 3)
这种写法失去 pandas 的自动对齐和日期索引保留能力,但换来最大灵活性和可预测性能。真正高频、长序列、定制逻辑强的场景,这是更可靠的选择。










