
本文详解 jax 中 `jax.scipy.linalg.expm` 批量计算失败的常见原因与解决方案,涵盖新版原生支持、旧版兼容写法及关键形状调试技巧。
在使用 JAX 计算矩阵指数(如量子线路中的参数化幺正演化 $ e^{iA} $)时,一个典型错误是:
ValueError: expected A to be a square matrix
尽管你确认最后两维是方阵(如 (4, 4)),但报错仍发生——这往往源于 输入张量的维度结构不符合 expm 的隐式批处理规则。
? 根本原因:expm 对输入形状有严格要求
jax.scipy.linalg.expm 自 JAX v0.4.7 起原生支持批量输入,但前提是:
✅ 输入数组的最后两个轴必须构成方阵(如 (..., n, n));
❌ 其余前导维度将被自动视为 batch 维度;
❌ 若中间存在非 batch 的冗余维度(如你的 A.shape = (2, 2, 2, 2, 2, 2, 2, 2, 4, 4)),它仍能工作;
⚠️ 但若 A 的最后两维不满足 n == n(例如 (4, 5)),或 A.ndim
在你的代码中,问题出在 pauli_matrix(num_qubits) 的构造逻辑:
def pauli_matrix(num_qubits):
_pauli_matrices = jnp.array(
[[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, -1j], [1j, 0]], [[1, 0], [0, -1]]]
)
# ❌ 错误:对 _pauli_matrices 重复 kronecker 积,却未指定作用于哪一组 qubit
# 且 [1:] 切片导致维度混乱,最终使 tensordot 结果 A 的 shape 不符合预期
return reduce(jnp.kron, (_pauli_matrices for _ in range(num_qubits)))[1:]该函数实际生成的是 (15, 4**num_qubits, 4**num_qubits) 形状的 Pauli 基(对 2-qubit 应为 (15, 4, 4)),但 reduce(jnp.kron, ...) 在 num_qubits=2 时会生成 (4^2, 4^2) = (16, 16) 矩阵,再 [1:] 切片得 (15, 16, 16) —— 而你 theta 是 (15, 2,2,2,2,2,2,2,2),tensordot 后 A 实际为 (2,2,2,2,2,2,2,2, 16, 16),并非你误以为的 (2,...,2,4,4)。因此 expm 接收的不是 (N, 4, 4),而是高维张量,但只要末两维是方阵,新版 JAX 就能处理。
✅ 正确做法:确保 A 的 shape 为 (..., d, d),其中 d = 2**num_qubits。
✅ 解决方案一:升级 JAX 并规范输入(推荐)
确保使用 JAX ≥ 0.4.7:
pip install --upgrade jax jaxlib
然后修正 pauli_matrix 和 SpecialUnitary:
import jax.numpy as jnp
import jax.scipy.linalg as linalg
from functools import reduce
def pauli_basis_1q():
return jnp.array([
[[1., 0.], [0., 1.]], # I
[[0., 1.], [1., 0.]], # X
[[0., -1j], [1j, 0.]], # Y
[[1., 0.], [0., -1.]], # Z
])
def pauli_matrix(num_qubits):
"""返回 (4**num_qubits - 1) 个 traceless n-qubit Pauli 算符,shape (15, 4, 4) for n=2"""
basis = pauli_basis_1q()
# 构造所有非恒等的 n-qubit Pauli 张量积:共 4^n - 1 个
from itertools import product
ops = []
for indices in product(range(4), repeat=num_qubits):
if all(i == 0 for i in indices): # skip identity
continue
op = basis[indices[0]]
for i in indices[1:]:
op = jnp.kron(op, basis[i])
ops.append(op)
return jnp.stack(ops) # shape: (15, 4, 4) for num_qubits=2
num_qubits = 2
d = 2 ** num_qubits # 4
theta = jnp.pi * jnp.random.uniform(shape=(15,)) # 简化:单组参数,shape (15,)
A = jnp.tensordot(theta, pauli_matrix(num_qubits), axes=[[0], [0]]) # -> (4, 4)
U = linalg.expm(1j * A / 2) # ✅ works: (4, 4)
# 批量示例:theta shape (8, 15) → A shape (8, 4, 4) → U shape (8, 4, 4)
theta_batch = jnp.pi * jnp.random.uniform(shape=(8, 15))
A_batch = jnp.einsum('bi,ij->bjk', theta_batch, pauli_matrix(num_qubits)) # (8, 4, 4)
U_batch = linalg.expm(1j * A_batch / 2) # ✅ native batch support
print(U_batch.shape) # (8, 4, 4)⚙️ 解决方案二:旧版 JAX 兼容写法(jnp.vectorize)
若受限于旧版 JAX(
expm_vec = jnp.vectorize(linalg.expm, signature='(n,n)->(n,n)') # A_batch shape: (B, d, d) U_batch = expm_vec(1j * A_batch / 2) # returns (B, d, d)
⚠️ 注意:vectorize 在 JIT 下可能不如原生批量高效,仅作兼容之用。
? 关键检查清单
- ✅ 使用 A.shape[-2] == A.shape[-1] 验证末两维是否为方阵;
- ✅ 避免在 tensordot 或 einsum 中引入意外维度(如你的原始 theta 有 9 维,极易出错);
- ✅ 优先用 einsum 替代嵌套 tensordot 提升可读性;
- ✅ 调试时打印 A.shape 和 A.dtype,确认无 float64(JAX 默认 float32,expm 要求浮点)。
掌握这些要点,你就能稳健地在 JAX 中实现量子态演化、李群指数映射等核心计算。










