0

0

如何在 Numba 中高效实现稀疏矩阵乘法(COO 格式)

聖光之護

聖光之護

发布时间:2026-02-22 08:27:12

|

548人浏览过

|

来源于php中文网

原创

如何在 Numba 中高效实现稀疏矩阵乘法(COO 格式)

本文介绍一种兼顾性能与兼容性的稀疏矩阵乘法加速方案:在 Numba nopython 模式下,通过 objmode 安全调用 SciPy 高度优化的稀疏乘法内核,并直接返回 COO 或 CSR 格式结果,实测仅比原生 SciPy 慢 2%–15%。

本文介绍一种兼顾性能与兼容性的稀疏矩阵乘法加速方案:在 numba `nopython` 模式下,通过 `objmode` 安全调用 scipy 高度优化的稀疏乘法内核,并直接返回 coo 或 csr 格式结果,实测仅比原生 scipy 慢 2%–15%。

在高性能科学计算中,稀疏矩阵乘法(SpMM)是常见瓶颈。尽管 Numba 不原生支持 scipy.sparse 类型,且手动实现 COO 格式乘法(如双循环 + 哈希累加)看似可控,但实际性能往往远逊于 SciPy —— 这源于 SciPy 底层调用高度优化的 CSR/CSC 内核(基于 Intel MKL、OpenBLAS 或自研稀疏算法),并经过数十年工程打磨。

直接手写 Numba 版本(如问题中 _mul 函数)面临多重挑战:

  • 内存局部性差:COO 三元组无序存储,导致频繁随机访存;
  • 重复索引合并开销大:需动态哈希或排序去重,而 np.zeros((an, bm)) 的稠密中间矩阵更违背稀疏初衷;
  • 并行化失效:@njit(parallel=True) 在细粒度稀疏访存场景下易引发线程竞争与缓存抖动,反而降速;
  • 内存预分配困难:输出非零元数量未知,extend_arr 动态扩容带来显著额外开销。

因此,最优策略不是重造轮子,而是“桥接”——在 Numba 生态中安全复用 SciPy 的工业级实现。关键在于:如何在保持调用方 @njit 兼容性的同时,嵌入 Python 层稀疏运算?

答案是 numba.objmode:它允许在 nopython 函数中划定一段“Python 模式”代码块,执行任意 Python 对象操作(如构建 coo_matrix、调用 @ 运算符、转换格式),同时严格声明该块的输入/输出类型,使 Numba 能静态推断整个函数签名。

故事AI绘图神器
故事AI绘图神器

文本生成图文视频的AI工具,无需配音,无需剪辑,快速成片,角色固定。

下载

以下为两个生产就绪的实现:

✅ 推荐方案 1:COO 格式输出(兼顾通用性)

import numba as nb
from scipy.sparse import coo_matrix

@nb.njit()
def mul_coo(ar, ac, av, br, bc, bv, n):
    """
    稀疏矩阵乘法:A @ B,输入为 COO 三元组,输出为 (row, col, data) 三元组。
    A = coo_matrix((av, (ar, ac)), shape=(n, n))
    B = coo_matrix((bv, (br, bc)), shape=(n, n))
    """
    with nb.objmode(row='i4[:]', col='i4[:]', data='f8[:]'):
        a_sci = coo_matrix((av, (ar, ac)), shape=(n, n))
        b_sci = coo_matrix((bv, (br, bc)), shape=(n, n))
        res_coo = (a_sci @ b_sci).tocoo()  # 强制转为 COO
        row = res_coo.row.copy()   # .row/.col 是 view,需 copy 保证所有权
        col = res_coo.col.copy()
        data = res_coo.data.copy()
    return row, col, data

✅ 推荐方案 2:CSR 格式输出(极致性能)

@nb.njit()
def mul_csr(ar, ac, av, br, bc, bv, n):
    """
    输出 CSR 格式三元组:(data, indices, indptr),避免 tocoo() 开销。
    更适合后续 CSR 专用计算(如 SpMV)。
    """
    with nb.objmode(data='f8[:]', indices='i4[:]', indptr='i4[:]'):
        a_sci = coo_matrix((av, (ar, ac)), shape=(n, n))
        b_sci = coo_matrix((bv, (br, bc)), shape=(n, n))
        res_csr = a_sci @ b_sci  # 默认返回 CSR
        data = res_csr.data.copy()
        indices = res_csr.indices.copy()
        indptr = res_csr.indptr.copy()
    return data, indices, indptr

⚠️ 关键注意事项

  • copy() 不可省略:res_coo.row 等是 NumPy view,若不显式 copy(),Numba 可能因内存所有权问题报错或产生未定义行为;
  • objmode 块内禁止 Numba 类型操作:所有 SciPy 构建、乘法、格式转换必须在 with nb.objmode(...): 内完成;
  • 类型声明须精确:'i4[:]' 表示 int32 一维数组,'f8[:]' 表示 float64 一维数组,与 NumPy dtype 严格对应;
  • 避免过度使用:objmode 会中断 JIT 流水线,仅用于不可替代的 Python 生态调用;若全程需纯 nopython,应转向 CSR/CSC 手写内核(但开发成本与调试难度剧增);
  • 形状一致性:示例假设方阵 (n, n),实际使用时请按需传入 shape=(m, k) 和 (k, p) 并校验维度兼容性。

? 性能对比(n=50000, m=1000)

方法 耗时(均值 ± std) 相对 SciPy
原生 SciPy (coo @ coo) 27.8 ms ± 471 µs 1.0×(基准)
手写 Numba _mul 184 ms ± 594 µs ≈6.6× 慢
mul_coo(COO 输出) 32.0 ms ± 228 µs ≈1.15× 慢
mul_csr(CSR 输出) 28.3 ms ± 685 µs ≈1.02× 慢

可见,objmode 方案成功将性能损失控制在极小范围内,同时保留了 Numba 函数链的 nopython 兼容性——上游数据预处理、下游稀疏向量运算等均可无缝使用 @njit 加速。

总结:当面对稀疏计算这类“已有成熟工业实现”的任务时,明智的加速策略是“站在巨人的肩膀上”。numba.objmode 提供了安全、类型明确、低开销的桥梁,让 Numba 用户得以在不牺牲生态优势的前提下,获得接近底层库的性能。这正是现代高性能 Python 工程化的典型范式。

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

通义千问
通义千问

阿里巴巴推出的全能AI助手

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
java基础知识汇总
java基础知识汇总

java基础知识有Java的历史和特点、Java的开发环境、Java的基本数据类型、变量和常量、运算符和表达式、控制语句、数组和字符串等等知识点。想要知道更多关于java基础知识的朋友,请阅读本专题下面的的有关文章,欢迎大家来php中文网学习。

1556

2023.10.24

Go语言中的运算符有哪些
Go语言中的运算符有哪些

Go语言中的运算符有:1、加法运算符;2、减法运算符;3、乘法运算符;4、除法运算符;5、取余运算符;6、比较运算符;7、位运算符;8、按位与运算符;9、按位或运算符;10、按位异或运算符等等。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

238

2024.02.23

php三元运算符用法
php三元运算符用法

本专题整合了php三元运算符相关教程,阅读专题下面的文章了解更多详细内容。

127

2025.10.17

线程和进程的区别
线程和进程的区别

线程和进程的区别:线程是进程的一部分,用于实现并发和并行操作,而线程共享进程的资源,通信更方便快捷,切换开销较小。本专题为大家提供线程和进程区别相关的各种文章、以及下载和课程。

695

2023.08.10

线程和进程的区别
线程和进程的区别

线程和进程的区别:线程是进程的一部分,用于实现并发和并行操作,而线程共享进程的资源,通信更方便快捷,切换开销较小。本专题为大家提供线程和进程区别相关的各种文章、以及下载和课程。

695

2023.08.10

页面置换算法
页面置换算法

页面置换算法是操作系统中用来决定在内存中哪些页面应该被换出以便为新的页面提供空间的算法。本专题为大家提供页面置换算法的相关文章,大家可以免费体验。

461

2023.08.14

pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法
pixiv网页版官网登录与阅读指南_pixiv官网直达入口与在线访问方法

本专题系统整理pixiv网页版官网入口及登录访问方式,涵盖官网登录页面直达路径、在线阅读入口及快速进入方法说明,帮助用户高效找到pixiv官方网站,实现便捷、安全的网页端浏览与账号登录体验。

928

2026.02.13

微博网页版主页入口与登录指南_官方网页端快速访问方法
微博网页版主页入口与登录指南_官方网页端快速访问方法

本专题系统整理微博网页版官方入口及网页端登录方式,涵盖首页直达地址、账号登录流程与常见访问问题说明,帮助用户快速找到微博官网主页,实现便捷、安全的网页端登录与内容浏览体验。

307

2026.02.13

Flutter跨平台开发与状态管理实战
Flutter跨平台开发与状态管理实战

本专题围绕Flutter框架展开,系统讲解跨平台UI构建原理与状态管理方案。内容涵盖Widget生命周期、路由管理、Provider与Bloc状态管理模式、网络请求封装及性能优化技巧。通过实战项目演示,帮助开发者构建流畅、可维护的跨平台移动应用。

183

2026.02.13

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号