0

0

如何在 JAX 中实现支持动态形状的 next 函数(字符串重写系统)

碧海醫心

碧海醫心

发布时间:2026-01-03 13:10:49

|

124人浏览过

|

来源于php中文网

原创

如何在 JAX 中实现支持动态形状的 next 函数(字符串重写系统)

本文详解如何在 jax 中绕过动态形状限制,用静态形状语义实现可向量化、可递归应用的规则替换函数,适用于字符串重写系统等场景。

JAX 的核心约束之一是:所有变换(如 vmap、jit、grad)要求中间和输出数组具有静态形状(static shape)——即形状不能依赖于运行时张量值。而原始 replace_first_one 函数中,jnp.where(arr == 1)[0] 的长度、jnp.concatenate 的结果长度均随输入数据动态变化,导致 vmap 报错:size argument of jnp.nonzero must be statically specified。

要使函数兼容 vmap,关键在于消除所有动态形状分支,转而采用“填充-掩码-切片”范式,确保每一步输出尺寸完全可推导且恒定。以下是完整、可向量化、可嵌套调用的实现方案:

✅ 正确实现(静态形状兼容版)

import jax
import jax.numpy as jnp
from jax import vmap

# 所有规则必须统一长度(静态形状前提)
rules_int = jnp.array([
    [0, 0],   # rule 0 → length=2
    [1, 1],   # rule 1 → length=2
], dtype=jnp.int32)

def replace_first_one(arr, action):
    """
    静态形状安全的首 '1' 替换函数:
    - 输入 arr 形状固定为 (L,),action 为标量
    - 输出形状固定为 (L + 1),因每次替换:删1个元素 + 插入2个 → 净增1
    - 使用 jnp.where(..., size=1) 强制返回固定长度索引
    - 使用 dynamic_update_slice 避免 concat 动态拼接
    """
    L = arr.shape[0]
    # 查找第一个 1 的位置;若不存在,index = L(越界,后续逻辑自动跳过替换)
    indices = jnp.where(arr == 1, size=1, fill_value=L)[0]
    index = indices[0]  # scalar index, static at compile time

    # 待插入规则向量
    rule = rules_int[action]

    # 预分配输出数组(长度 = L - 1 + len(rule) = L + 1)
    output = jnp.zeros(L + 1, dtype=arr.dtype)

    # 构造插入前段:arr[:index]
    pre = jnp.where(jnp.arange(L) < index, arr, 0)
    # 构造插入后段:arr[index+1:]
    post = jnp.where(jnp.arange(L) > index, arr, 0)

    # 拼接三段(逻辑上)→ 但实际用 dynamic_update_slice 实现高效写入
    # 更简洁做法:先填入 arr[0:index] 和 arr[index+1:],再覆盖插入 rule
    # 我们分步构造:
    output = output.at[:index].set(arr[:index])
    output = output.at[index:index+2].set(rule)  # rule 长度固定为 2
    output = output.at[index+2:].set(arr[index+1:L-1])  # 注意对齐长度

    # ⚠️ 上述 set 可能越界,更鲁棒写法:使用 dynamic_update_slice + mask
    # 推荐最终版本(零拷贝、边界安全):
    output = jnp.zeros(L + 1, dtype=arr.dtype)
    # 写入前段 [0:index]
    output = jax.lax.dynamic_update_slice(output, arr[:index], (0,))
    # 写入规则 [index:index+2]
    output = jax.lax.dynamic_update_slice(output, rule, (index,))
    # 写入后段 [index+2:]
    tail_start = index + 1  # 原 arr 中跳过 index 后的起始位置
    tail_len = L - tail_start
    pad_len = (L + 1) - (index + 2) - tail_len
    padded_tail = jnp.pad(arr[tail_start:], (0, pad_len), constant_values=0)
    output = jax.lax.dynamic_update_slice(output, padded_tail, (index + 2,))

    return output

# ✅ 现在可安全 vmap
batch_arr = jnp.array([
    [1, 4, 5, 1],  # → 替换第0个1 → [0,0,4,5,1]
    [6, 1, 8, 1],  # → 替换第1个1 → [6,1,1,8,1]
])
batch_actions = jnp.array([0, 1])

vectorized_replace = vmap(replace_first_one, in_axes=(0, 0))
result = vectorized_replace(batch_arr, batch_actions)
print(result)
# [[0 0 4 5 1]
#  [6 1 1 8 1]]

? 支持递归应用(字符串重写系统)

若需反复应用规则直至无 1 可替换(即模拟图灵机或 L-system),可用 jax.lax.while_loop 实现静态迭代上限下的循环:

PathFinder
PathFinder

AI驱动的销售漏斗分析工具

下载
def rewrite_until_stable(init_arr, max_steps=10):
    def cond_fn(state):
        arr, step = state
        has_one = jnp.any(arr == 1)
        return jnp.logical_and(has_one, step < max_steps)

    def body_fn(state):
        arr, step = state
        # 找到首个 1 对应的 action(示例:固定用 rule 0;实际可查表)
        index = jnp.where(arr == 1, size=1, fill_value=arr.shape[0])[0][0]
        action = jnp.where(index < arr.shape[0], 0, 0)  # placeholder
        new_arr = replace_first_one(arr, action)
        return new_arr, step + 1

    final_arr, _ = jax.lax.while_loop(cond_fn, body_fn, (init_arr, 0))
    return final_arr

# 示例:jnp.array([1]) → [0,0] → 无1 → 停止
stable = rewrite_until_stable(jnp.array([1]))
print(stable)  # [0 0]

⚠️ 注意事项与权衡

  • 规则长度必须统一:这是静态形状的硬性要求。若原始规则长度不一(如 [0,0] vs [1,1,1]),需补零或截断至最大长度,并用 mask 控制有效区域。
  • 性能提示:dynamic_update_slice 比 concatenate 更适合 JIT;避免 jnp.where 无 size 参数的用法。
  • 替代方案:若业务逻辑必须支持真正变长输出(如生成不同长度 token 序列),则应考虑:
    • 在 Python 层循环(放弃 vmap 加速);
    • 使用 jax.vmap + jax.pmap 分 batch 处理,每 batch 内部统一 padding;
    • 迁移至支持 ragged tensor 的框架(如 TensorFlow with tf.RaggedTensor),但将失去 JAX 生态优势。

总之,JAX 中的“动态形状”并非不可逾越,而是需要以静态契约重构逻辑——通过预分配、填充、掩码与 slice 操作,在编译期锁定维度,从而释放 vmap/jit 的全部潜力。

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
登录token无效
登录token无效

登录token无效解决方法:1、检查token的有效期限,如果token已经过期,需要重新获取一个新的token;2、检查token的签名,如果签名不正确,需要重新获取一个新的token;3、检查密钥的正确性,如果密钥不正确,需要重新获取一个新的token;4、使用HTTPS协议传输token,建议使用HTTPS协议进行传输 ;5、使用双因素认证,双因素认证可以提高账户的安全性。

6629

2023.09.14

登录token无效怎么办
登录token无效怎么办

登录token无效的解决办法有检查Token是否过期、检查Token是否正确、检查Token是否被篡改、检查Token是否与用户匹配、清除缓存或Cookie、检查网络连接和服务器状态、重新登录或请求新的Token、联系技术支持或开发人员等。本专题为大家提供token相关的文章、下载、课程内容,供大家免费下载体验。

843

2023.09.14

token怎么获取
token怎么获取

获取token值的方法:1、小程序调用“wx.login()”获取 临时登录凭证code,并回传到开发者服务器;2、开发者服务器以code换取,用户唯一标识openid和会话密钥“session_key”。想了解更详细的内容,可以阅读本专题下面的文章。

1092

2023.12.21

token什么意思
token什么意思

token是一种用于表示用户权限、记录交易信息、支付虚拟货币的数字货币。可以用来在特定的网络上进行交易,用来购买或出售特定的虚拟货币,也可以用来支付特定的服务费用。想了解更多token什么意思的相关内容可以访问本专题下面的文章。

2189

2024.03.01

js 字符串转数组
js 字符串转数组

js字符串转数组的方法:1、使用“split()”方法;2、使用“Array.from()”方法;3、使用for循环遍历;4、使用“Array.split()”方法。本专题为大家提供js字符串转数组的相关的文章、下载、课程内容,供大家免费下载体验。

760

2023.08.03

js截取字符串的方法
js截取字符串的方法

js截取字符串的方法有substring()方法、substr()方法、slice()方法、split()方法和slice()方法。本专题为大家提供字符串相关的文章、下载、课程内容,供大家免费下载体验。

221

2023.09.04

java基础知识汇总
java基础知识汇总

java基础知识有Java的历史和特点、Java的开发环境、Java的基本数据类型、变量和常量、运算符和表达式、控制语句、数组和字符串等等知识点。想要知道更多关于java基础知识的朋友,请阅读本专题下面的的有关文章,欢迎大家来php中文网学习。

1567

2023.10.24

字符串介绍
字符串介绍

字符串是一种数据类型,它可以是任何文本,包括字母、数字、符号等。字符串可以由不同的字符组成,例如空格、标点符号、数字等。在编程中,字符串通常用引号括起来,如单引号、双引号或反引号。想了解更多字符串的相关内容,可以阅读本专题下面的文章。

649

2023.11.24

Python异步编程与Asyncio高并发应用实践
Python异步编程与Asyncio高并发应用实践

本专题围绕 Python 异步编程模型展开,深入讲解 Asyncio 框架的核心原理与应用实践。内容包括事件循环机制、协程任务调度、异步 IO 处理以及并发任务管理策略。通过构建高并发网络请求与异步数据处理案例,帮助开发者掌握 Python 在高并发场景中的高效开发方法,并提升系统资源利用率与整体运行性能。

37

2026.03.12

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
最新Python教程 从入门到精通
最新Python教程 从入门到精通

共4课时 | 22.5万人学习

Django 教程
Django 教程

共28课时 | 5万人学习

SciPy 教程
SciPy 教程

共10课时 | 1.9万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号