0

0

PyMC模型中自定义对数似然的性能优化:兼论JAX兼容性与数学表达式重构

心靈之曲

心靈之曲

发布时间:2025-11-19 15:37:24

|

606人浏览过

|

来源于php中文网

原创

PyMC模型中自定义对数似然的性能优化:兼论JAX兼容性与数学表达式重构

pymc模型中,当使用自定义pytensor op定义对数似然并尝试结合blackjax采样器时,可能遭遇jax转换兼容性错误。本文将深入探讨如何实现自定义对数似然,分析blackjax集成时的挑战,并提供一种通过数学表达式重构来显著提升核心计算函数性能的通用优化策略,即使无法利用jax加速,也能有效缩短采样时间。

1. PyMC中自定义对数似然函数的实现

在贝叶斯建模中,有时标准分布无法满足特定需求,需要引入自定义的对数似然函数。PyMC(基于PyTensor)提供了一种机制,允许用户通过定义自定义的pytensor.Op来集成任意Python函数及其梯度。

1.1 定义自定义PyTensor Op

要将一个复杂的Python函数(例如,涉及外部库或数值求解器)集成到PyMC的计算图中,需要创建两个pytensor.Op类:一个用于计算函数值(对数似然),另一个用于计算其梯度。

LogLikeWithGrad (对数似然函数)

这个Op负责计算给定参数的对数似然值。它需要实现perform方法来执行实际的对数似然计算,并重载grad方法来指定如何计算梯度。

import pytensor.tensor as pt
import numpy as np
from scipy.optimize import approx_fprime # 用于数值梯度

class LogLikeWithGrad(pt.Op):
    itypes = [pt.dvector]  # 输入是一个参数向量
    otypes = [pt.dscalar]  # 输出是一个标量(对数似然值)

    def __init__(self, loglike_function):
        self.likelihood = loglike_function
        self.loglike_grad_op = LogLikeGrad(loglike_function) # 初始化梯度Op

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        logl = self.likelihood(theta)
        outputs[0][0] = np.array(logl)

    def grad(self, inputs, grad_outputs):
        (theta,) = inputs
        # 调用自定义的梯度Op来计算梯度
        grads = self.loglike_grad_op(theta)
        return [grad_outputs[0] * grads]

LogLikeGrad (对数似然梯度函数)

这个Op专门用于计算对数似然函数相对于其输入的梯度。在缺乏解析梯度的情况下,可以使用数值近似方法,例如scipy.optimize.approx_fprime。

class LogLikeGrad(pt.Op):
    itypes = [pt.dvector]  # 输入是一个参数向量
    otypes = [pt.dvector]  # 输出是一个梯度向量

    def __init__(self, loglike_function):
        self.likelihood = loglike_function

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        # 使用数值方法近似梯度
        grads = approx_fprime(theta, self.likelihood, epsilon=1e-8)
        outputs[0][0] = grads

1.2 在PyMC模型中集成自定义似然

一旦定义了自定义的LogLikeWithGrad Op,就可以将其作为pm.Potential添加到PyMC模型中。pm.Potential允许用户在模型中引入任意的对数概率贡献。

import pymc as pm

# 假设 applyMCMC 是你的核心对数似然计算函数
# 并且 param_names 和 lower/upper_boundaries 已经定义
# logl = LogLikeWithGrad(applyMCMC)

with pm.Model() as model:
    # 定义模型参数
    for i, name in enumerate(param_names):
        pm.Uniform(name, lower=lower_boundaries[0][i], upper=upper_boundaries[0][i])

    # 将所有参数组合成一个PyTensor向量
    theta = pt.as_tensor_variable([model[param] for param in param_names])

    # 将自定义对数似然作为潜力项添加到模型中
    pm.Potential("likelihood", logl(theta))

    # 执行采样
    # trace = pm.sample(draws=niter, step=pm.NUTS(), tune=500, cores=64, init="jitter+adapt_diag", progressbar=True)

2. Blackjax采样器与PyTensor Op的JAX兼容性挑战

PyMC 5.x 版本支持多种NUTS采样器后端,包括其默认的PyTensor后端以及基于JAX的Blackjax采样器,后者在GPU等加速设备上表现出色。然而,当模型中包含自定义的pytensor.Op时,尝试使用Blackjax采样器可能会遇到兼容性问题。

2.1 NotImplementedError 的根源

当尝试通过 pm.sample(nuts_sampler="blackjax") 使用Blackjax时,如果自定义的LogLikeWithGrad Op没有对应的JAX转换实现,PyTensor的JAX后端会抛出 NotImplementedError,错误信息通常为 No JAX conversion for the given Op: LogLikeWithGrad。

AIBox 一站式AI创作平台
AIBox 一站式AI创作平台

AIBox365一站式AI创作平台,支持ChatGPT、GPT4、Claue3、Gemini、Midjourney等国内外大模型

下载

这是因为Blackjax采样器依赖于JAX的即时编译(JIT)能力,而JAX只能编译其能够理解的操作。自定义的pytensor.Op本质上是一个Python对象,PyTensor需要一个明确的规则来告诉JAX如何将其转换为JAX可执行的操作。对于标准PyTensor操作,这些转换已经内置,但对于用户自定义的Op,则需要手动提供。

2.2 解决方案方向

解决此问题通常需要以下两种方法之一:

  1. 重写Op以完全使用JAX操作:如果自定义对数似然函数的核心逻辑可以用JAX原生的操作(如jax.numpy)表达,那么可以直接在JAX中构建这个似然函数,并将其集成到PyMC模型中。这通常涉及将所有PyTensor代码替换为JAX代码。
  2. 为自定义Op提供JAX转换规则:这是一种更高级的方法,涉及为自定义pytensor.Op编写一个JAX转换函数,告诉PyTensor的JAX后端如何将该Op转换为JAX的计算图。这通常通过注册JAX调度函数来实现,但实现起来较为复杂,且文档相对较少。

在许多情况下,特别是当自定义似然函数依赖于复杂外部库(如物理模拟器)时,直接将其完全转换为JAX操作可能非常困难或不可能。此时,即使无法利用Blackjax的JAX加速,我们仍然可以通过优化核心计算逻辑来提升采样性能。

3. 提升PyMC模型计算性能的通用策略:数学表达式优化

即使无法直接利用JAX的GPU加速,通过对核心数学计算函数进行细致的优化,也能显著提升PyMC模型的采样速度。这种优化策略侧重于减少冗余计算、避免重复的函数调用以及利用局部变量缓存中间结果。

3.1 识别并消除冗余计算

在复杂的数学表达式中,往往存在重复计算相同子表达式的情况。通过将这些子表达式的结果存储在局部变量中,可以避免多次计算,从而提高效率。

以原始代码中的dH和du函数为例,它们包含大量重复的幂运算和乘法:

  • (1 + z) 的不同幂次 ((1 + z)**2, (1 + z)**3, (1 + z)**4, (1 + z)**5)
  • math.pi * Rho_m
  • Omega_k * Phi
  • Phi * Phi (即 Phi**2)

3.2 优化示例:dH 和 du 函数

以下是针对 dH 和 du 函数的优化版本,通过引入局部变量来缓存重复计算的中间结果:

import math
import timeit # 用于性能测试

# 假设 Rho_m, Phi, u, omega_BD, Omega_k, z 为示例输入
# (为了测试方便,这里使用任意值,实际应是模型参数)
Rho_m = -1.0
Phi = 0.1
u = 3.0
omega_BD = 4.0
Omega_k = -5.0
z = 6.0

# 原始 du 函数 (为对比而保留,实际代码中应替换为优化版本)
def du_original(Rho_m, Phi, u, omega_BD, Omega_k, z):
    return (
        24 * math.pi * Rho_m * Phi**3
        + (1 + z)
        * u
        * Phi**2
        * (
            8 * math.pi * (-3 + omega_BD) * Rho_m
            - 3 * (1 + z) ** 2 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - 3
        * (1 + z) ** 2
        * u**2
        * Phi
        * (
            -4 * math.pi * omega_BD * Rho_m
            + (1 + z) ** 4 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
        - omega_BD
        * u**3
        * (
            4 * math.pi * (1 + z) ** 3 * (1 + omega_BD) * Rho_m
            + (1 + z) ** 5 * (3 + 2 * omega_BD) * Omega_k * Phi
        )
    ) / (
        (1 + z) ** 2
        * (3 + 2 * omega_BD)
        * Phi**2
        * (8 * math.pi * Rho_m + 3 * (1 + z) ** 2 * Omega_k * Phi)
    )

# 优化后的 du 函数
def du_optimized(Rho_m, Phi, u, omega_BD, Omega_k, z):
    # 缓存幂次和重复乘法
    Phi_pow2 = Phi * Phi
    Phi_pow3 = Phi_pow2 * Phi
    one_plus_z = 1 + z
    one_plus_z_pow2 = one_plus_z * one_plus_z
    one_plus_z_pow3 = one_plus_z_pow2 * one_plus_z
    one_plus_z_pow4 = one_plus_z_pow3 * one_plus_z
    one_plus_z_pow5 = one_plus_z_pow4 * one_plus_z

    # 缓存其他重复子表达式
    one_plus_z_pow2_times_3 = 3 * one_plus_z_pow2
    pi_times_Rho_m = math.pi * Rho_m
    Omega_k_times_Phi = Omega_k * Phi
    u_pow2 = u * u
    u_pow3 = u_pow2 * u
    omg1 = (3 + 2 * omega_BD) # (3 + 2 * omega_BD)
    omg = omg1 * Omega_k_times_Phi # (3 + 2 * omega_BD) * Omega_k * Phi
    omg2 = omega_BD * pi_times_Rho_m # omega_BD * math.pi * Rho_m

    return (
        24 * pi_times_Rho_m * Phi_pow3
        + one_plus_z * u * Phi_pow2 * (8 * (-3 + omega_BD) * pi_times_Rho_m - one_plus_z_pow2_times_3 * omg)
        - one_plus_z_pow2_times_3 * u_pow2 * Phi * (-4 * omg2 + one_plus_z_pow4 * omg)
        - omega_BD * u_pow3 * (4 * one_plus_z_pow3 * (pi_times_Rho_m + omg2) + one_plus_z_pow5 * omg)
    ) / (
        one_plus_z_pow2 * omg1 * Phi_pow2 * (8 * pi_times_Rho_m + one_plus_z_pow2_times_3 * Omega_k_times_Phi)
    )

# 原始 dH 函数 (为对比而保留)
def dH_original(Rho_m, Phi, u, omega_BD, Omega_k, z):
    val = (-16 * math.pi * Rho_m - 6 * (1 + z) ** 2 * Omega_k * Phi) / (
        6 * (1 + z) * u + ((1 + z) ** 2 * omega_BD * u**2) / Phi - 6 * Phi
    )
    if val >= 0:
        return -(
            (
                (1 + z)
                * (16 * math.pi * Rho_m + 6 * (1 + z) ** 2 * Omega_k * Phi)
                * (
                    (1 + z) * omega_BD * u**3
                    - 2
                    * omega_BD
                    * u
                    * ((1 + z) * du_original(Rho_m, Phi, u, omega_BD, Omega_k, z) + u)
                    * Phi
                    - 6 * du_original(Rho_m, Phi, u, omega_BD, Omega_k, z) * Phi**2
                )
                + (
                    6
                    * Phi
                    * (
                        -8 * math.pi * Rho_m
                        + (1 + z) ** 2 * Omega_k * ((1 + z) * u + 2 * Phi)
                    )
                    * (6 * Phi**2 - (1 + z) * u * ((1 + z) * omega_BD * u + 6 * Phi))
                )
                / (1 + z)
            )
            / (
                2
                * math.sqrt(val)
                * (
                    (1 + z) ** 2 * omega_BD * u**2
                    + 6 * (1 + z) * u * Phi
                    - 6 * Phi**2
                )
                ** 2
            )
        )
    else:
        return None

# 优化后的 dH 函数
def dH_optimized(Rho_m, Phi, u, omega_BD, Omega_k, z):
    # 缓存常用变量和幂次
    Phi_pow2 = Phi * Phi
    Phi_pow2_times_6 = Phi_pow2 * 6
    Phi_times_6 = Phi * 6
    one_plus_z = 1 + z
    one_plus_z_pow2 = one_plus_z * one_plus_z
    one_plus_z_times_u = one_plus_z * u
    pi_times_Rho_m = math.pi * Rho_m
    Omega_k_times_Phi = Omega_k * Phi
    u_pow2 = u * u
    u_pow3 = u_pow2 * u

    # 重新计算 duu (如果 duu 在 dH 内部被调用多次,直接内联其计算可进一步优化)
    # 此处为简洁起见,仍调用 du_optimized,但注意实际场景可内联
    # 或者,如原答案所示,直接将 du_optimized 的计算逻辑复制到此处

    # duu 的内联计算部分 (来自 du_optimized)
    Phi_pow3_du = Phi_pow2 * Phi
    one_plus_z_pow3_du = one_plus_z_pow2 * one_plus_z
    one_plus_z_pow4_du = one_plus_z_pow3_du * one_plus_z
    one_plus_z_pow5_du = one_plus_z_pow4_du * one_plus_z
    one_plus_z_pow2_times_3_du = 3 * one_plus_z_pow2
    omg1_du = (3 + 2 * omega_BD)
    omg_du = omg1_du * Omega_k_times_Phi
    omg2_du = omega_BD * pi_times_Rho_m
    duu = (
        24 * pi_times_Rho_m * Phi_pow3_du
        + one_plus_z * u * Phi_pow2 * (8 * (-3 + omega_BD) * pi_times_Rho_m - one_plus_z_pow2_times_3_du * omg_du)
        - one_plus_z_pow2_times_3_du * u_pow2 * Phi * (-4 * omg2_du + one_plus_z_pow4_du * omg_du)
        - omega_BD * u_pow3 * (4 * one_plus_z_pow3_du * (pi_times_Rho_m + omg2_du) + one_plus_z_pow5_du * omg_du)
    ) / (
        one_plus_z_pow2 * omg1_du * Phi_pow2 * (8 * pi_times_Rho_m + one_plus_z_pow2_times_3_du * Omega_k_times_Phi)
    )
    # duu 内联计算结束

    val1 = (-16 * pi_times_Rho_m - 6 * one_plus_z_pow2 * Omega_k_times_Phi)
    val = val1 / (6 * one_plus_z_times_u + (one_plus_z_pow2 * omega_BD * u_pow2) / Phi - Phi_times_6)

    if val >= 0:
        Phi_times_2 = Phi + Phi
        val2 = (one_plus_z_pow2 * omega_BD * u_pow2 + 6 * (one_plus_z_times_u * Phi - Phi_pow2))

        # 优化分子中的复杂项
        term1_numerator = one_plus_z_pow2 * val1 * (omega_BD * u * (one_plus_z * u_pow2 - Phi_times_2 * (one_plus_z * duu + u)) - duu * Phi_pow2_times_6)
        term2_numerator = Phi_times_6 * (-8 * pi_times_Rho_m + one_plus_z_pow2 * Omega_k * (one_plus_z_times_u + Phi_times_2)) * (Phi_pow2_times_6 - one_plus_z_times_u * (one_plus_z_times_u * omega_BD + Phi_times_6))

        return (term1_numerator - term2_numerator) / (2 * one_plus_z * math.sqrt(val) * val2 * val2)
    else:
        return None

# 性能测试
t_original = timeit.timeit('dH_original(Rho_m, Phi, u, omega_

相关文章

数码产品性能查询
数码产品性能查询

该软件包括了市面上所有手机CPU,手机跑分情况,电脑CPU,电脑产品信息等等,方便需要大家查阅数码产品最新情况,了解产品特性,能够进行对比选择最具性价比的商品。

下载

本站声明:本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
PHP 高并发与性能优化
PHP 高并发与性能优化

本专题聚焦 PHP 在高并发场景下的性能优化与系统调优,内容涵盖 Nginx 与 PHP-FPM 优化、Opcode 缓存、Redis/Memcached 应用、异步任务队列、数据库优化、代码性能分析与瓶颈排查。通过实战案例(如高并发接口优化、缓存系统设计、秒杀活动实现),帮助学习者掌握 构建高性能PHP后端系统的核心能力。

114

2025.10.16

PHP 数据库操作与性能优化
PHP 数据库操作与性能优化

本专题聚焦于PHP在数据库开发中的核心应用,详细讲解PDO与MySQLi的使用方法、预处理语句、事务控制与安全防注入策略。同时深入分析SQL查询优化、索引设计、慢查询排查等性能提升手段。通过实战案例帮助开发者构建高效、安全、可扩展的PHP数据库应用系统。

99

2025.11.13

JavaScript 性能优化与前端调优
JavaScript 性能优化与前端调优

本专题系统讲解 JavaScript 性能优化的核心技术,涵盖页面加载优化、异步编程、内存管理、事件代理、代码分割、懒加载、浏览器缓存机制等。通过多个实际项目示例,帮助开发者掌握 如何通过前端调优提升网站性能,减少加载时间,提高用户体验与页面响应速度。

36

2025.12.30

JavaScript浏览器渲染机制与前端性能优化实践
JavaScript浏览器渲染机制与前端性能优化实践

本专题围绕 JavaScript 在浏览器中的执行与渲染机制展开,系统讲解 DOM 构建、CSSOM 解析、重排与重绘原理,以及关键渲染路径优化方法。内容涵盖事件循环机制、异步任务调度、资源加载优化、代码拆分与懒加载等性能优化策略。通过真实前端项目案例,帮助开发者理解浏览器底层工作原理,并掌握提升网页加载速度与交互体验的实用技巧。

103

2026.03.06

TypeScript类型系统进阶与大型前端项目实践
TypeScript类型系统进阶与大型前端项目实践

本专题围绕 TypeScript 在大型前端项目中的应用展开,深入讲解类型系统设计与工程化开发方法。内容包括泛型与高级类型、类型推断机制、声明文件编写、模块化结构设计以及代码规范管理。通过真实项目案例分析,帮助开发者构建类型安全、结构清晰、易维护的前端工程体系,提高团队协作效率与代码质量。

25

2026.03.13

Python异步编程与Asyncio高并发应用实践
Python异步编程与Asyncio高并发应用实践

本专题围绕 Python 异步编程模型展开,深入讲解 Asyncio 框架的核心原理与应用实践。内容包括事件循环机制、协程任务调度、异步 IO 处理以及并发任务管理策略。通过构建高并发网络请求与异步数据处理案例,帮助开发者掌握 Python 在高并发场景中的高效开发方法,并提升系统资源利用率与整体运行性能。

44

2026.03.12

C# ASP.NET Core微服务架构与API网关实践
C# ASP.NET Core微服务架构与API网关实践

本专题围绕 C# 在现代后端架构中的微服务实践展开,系统讲解基于 ASP.NET Core 构建可扩展服务体系的核心方法。内容涵盖服务拆分策略、RESTful API 设计、服务间通信、API 网关统一入口管理以及服务治理机制。通过真实项目案例,帮助开发者掌握构建高可用微服务系统的关键技术,提高系统的可扩展性与维护效率。

177

2026.03.11

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

50

2026.03.10

Kotlin Android模块化架构与组件化开发实践
Kotlin Android模块化架构与组件化开发实践

本专题围绕 Kotlin 在 Android 应用开发中的架构实践展开,重点讲解模块化设计与组件化开发的实现思路。内容包括项目模块拆分策略、公共组件封装、依赖管理优化、路由通信机制以及大型项目的工程化管理方法。通过真实项目案例分析,帮助开发者构建结构清晰、易扩展且维护成本低的 Android 应用架构体系,提升团队协作效率与项目迭代速度。

92

2026.03.09

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 22.5万人学习

Django 教程
Django 教程

共28课时 | 5万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.9万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号