0

0

Positional Encoding高频位置编码

P粉084495128

P粉084495128

发布时间:2025-07-29 11:05:31

|

379人浏览过

|

来源于php中文网

原创

本文围绕位置编码(Positional Encoding)在神经网络拟合图片中的作用展开实验。通过PaddlePaddle构建NeRF2D神经网络,分别在不使用和使用位置编码的情况下,以坐标预测像素RGB值。对比发现,加入含三角函数的位置编码后,拟合结果更清晰,验证了其提升神经网络表示能力的效果。

☞☞☞AI 智能聊天, 问答助手, AI 智能搜索, 免费无限量使用 DeepSeek R1 模型☜☜☜

positional encoding高频位置编码 - php中文网

Positional Encoding 高频位置编码

Positional Encoding 是神经网络设计的常用技巧。 例如 NeRF 提到, 用一个神经网络来表示一个场景。给定任何一个像素点的坐标(和观察方向)作为输入,神经网络输出这个点的像素值。但是, 如果输入只是单纯的坐标, 则神经网络表示的场景往往比较模糊. 但是如果额外输入坐标的多个三角函数值,

[sin(x),sin(2x),sin(4x),sin(8x),...,cos(x),cos(2x),cos(4x),cos(8x),...][sin(x),sin(2x),sin(4x),sin(8x),...,cos(x),cos(2x),cos(4x),cos(8x),...]

则所得到的结果可能会更加清晰。这被称为位置编码 (positional encoding)。


本项目利用 PaddlePaddle 做一个2D版本的简单实验 (这个实验在 CVPR 2020 Tutorial 的视频中提到):

已有一张图片, 用一个神经网络拟合它:对于任意一个点的二维坐标 (x,y)(x,y), 输出(预测)图片在 (x,y)(x,y) 处的像素值 (r,g,b)(r,g,b). 如果加入位置编码 (positional encoding), 则输入变成了

[x,y,sin(x),sin(y),sin(2x),sin(2y),...,cos(x),cos(y),cos(2x),cos(2y),...][x,y,sin(x),sin(y),sin(2x),sin(2y),...,cos(x),cos(y),cos(2x),cos(2y),...]

输出图片在该点的 RGB 值。

并比较不加入与加入位置编码的两种方法清晰度。


本 notebook 执行用时 (用AI Studio的V100 32GB) < 10分钟

Nanonets
Nanonets

基于AI的自学习OCR文档处理,自动捕获文档数据

下载

下图范例来自CVPR 2020 Tutorial, 左边是真实图片, 中间是不用位置编码的神经网络拟合结果, 右边采用了位置编码更加清晰。

本项目将复现类似的结果。Positional Encoding高频位置编码 - php中文网        

In [2]
import paddlefrom matplotlib import pyplot as plt 
import numpy as np 
from tqdm import tqdm 
from PIL import Image
   

准备工作

很少有更简单的数据准备工作了——随便找一张图片即可。

In [51]
# 随便找一张图片, 放进项目的文件夹里并用 Image 库打开img = Image.open('work/example1.jpg')
img = np.array(img)print(img.shape)
plt.figure(figsize=(5,7))
plt.imshow(img)
       
(960, 720, 3)
       
<matplotlib.image.AxesImage at 0x7fb48d122290>
               
<Figure size 360x504 with 1 Axes>
               
In [14]
# 类似 NeRF 的想法, 用一个神经网络表示图片:给一个像素坐标 (x,y), 预测该像素点的 RGB 值。# 神经网络简单地用若干 Linear 层 和 ReLU 激活函数堆叠而成。class NeRF2D(paddle.nn.Layer):
    def __init__(self, input_size: int = 2, layers: int = 5, hidden_size: int = 256):
        super().__init__()        assert layers >= 2, 'MLP should have at least 2 layers.'
        input_size :int = input_size # 2d 输入: 像素坐标 (x,y)
        output_size:int = 3          # 3d 输出: 坐标的对应RGB值 (R,G,B)
    
        # 保存一些参数
        self.input_size = input_size 
        self.output_size = output_size 
        self.layers = layers 

        # 用一个 Layerlist 存放所有层
        self.mlps = paddle.nn.LayerList()
        mlp_dim = [input_size] + [hidden_size] * (layers - 1) + [output_size]        for layer in range(layers):            # 全连接层
            self.mlps.append(
                paddle.nn.Linear(mlp_dim[layer], mlp_dim[layer+1])
            )            # ReLU层
            self.mlps.append(paddle.nn.ReLU())    def forward(self, x):
        # 让 x 通过所有层
        for layer in range(self.layers * 2): 
            x = self.mlps[layer](x)        return x
   

不使用 Positional Encoding

用 F(x,y)F(x,y) 表示神经网络函数 ((x,y)(x,y) 是输入), 则损失函数为神经网络的预测 F(x,y)F(x,y) 与真实图片的像素值 RGB(x,y)RGB(x,y) 的差距 (L2 损失)。

L=1Nall pixels (x,y)(F(x,y)RGB(x,y))2L=N1all pixels (x,y)∑(F(x,y)−RGB(x,y))2

In [29]
w , h = img.shape[0], img.shape[1]
pi = np.pi 
# 将 w*h 个像素坐标 (x,y) 作为输入# 这里不妨将图片看成 [-pi/2, pi/2] x [-pi/2, pi/2] 的坐标系# (因为神经网络的输入和输出最好不要太大, 每个值 <= 1 比较好)inputs = paddle.zeros((w,h,2))
inputs[:,:,0] += paddle.linspace(-pi/2, pi/2, num=w).reshape((w,1)) # 纵坐标inputs[:,:,1] += paddle.linspace(-pi/2, pi/2, num=h).reshape((1,h)) # 横坐标# 输入相当于 w*h 条二维数据 (输入不是很大, 可以一次性全部计算, 不需要 batch)inputs = inputs.reshape((w*h, 2))# 作为参照的, 期望的输出是 w*h 个像素的 RGB 值, 即图片本身outputs = paddle.to_tensor(img, dtype='float32').reshape((w*h, 3))
outputs = outputs / 255. # 由于网络最后是 ReLU, 所以期望的输出应该 >= 0, 最好介于 [0,1]
   
In [44]
# 创建一个神经网络, 用 Adam 优化器net = NeRF2D()
opt = paddle.optimizer.Adam(parameters = net.parameters())
losses = []
   
In [45]
# 训练 2000 步 (3分钟左右)epochs = 2000for epoch in tqdm(range(len(losses)+1, len(losses)+epochs+1)):
    render = net(inputs) # (w*h, 3)
    # 将神经网络根据坐标预测的渲染 (render) 结果与原图片对比, 用 L2 损失作为损失函数
    loss = paddle.mean(paddle.square(render - outputs))
    losses.append(loss.item())    # 反向传播+梯度下降
    opt.clear_grad()
    loss.backward()
    opt.step()
       
100%|██████████| 2000/2000 [03:06<00:00, 10.70it/s]
       
In [46]
# 观察 loss 曲线 和 最后一次渲染的结果, # 神经网络生成的图片有模有样, 缺点是比较模糊plt.figure(figsize=(12,7))
plt.subplot(1,2,1)
plt.semilogy(losses)
plt.subplot(1,2,2)
plt.imshow(render.reshape((w,h,3)))
       
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
       
<matplotlib.image.AxesImage at 0x7fb48ce69510>
               
<Figure size 864x504 with 2 Axes>
               

使用 Positional Encoding

使用了 Positional Encoding, 即将每个输入从 (x,y)(x,y) 变成

[x,y,sin(x),sin(y),sin(2x),sin(2y),...,cos(x),cos(y),cos(2x),cos(2y),...][x,y,sin(x),sin(y),sin(2x),sin(2y),...,cos(x),cos(y),cos(2x),cos(2y),...]

这里我们最高取到 20,21,22,23,2420,21,22,23,24 倍.

In [39]
# 接下来是使用 Positional Encoding 的改进# 这时候每个输入不再是 2 维的 (x,y)# 而是多维的 (x,y,sinx,siny,cosx,cosy,sin(2x),sin(2y),cos(2x),cos(2y),...)positional_encoding = [inputs]
freqs = [1.,2.,4.,8.,16.]for freq in freqs:
    positional_encoding.append(paddle.cos(inputs * freq))
    positional_encoding.append(paddle.sin(inputs * freq))

positional_encoding = paddle.to_tensor(positional_encoding, dtype='float32') # shape = [1+2*Freqs, w*h, 2]positional_encoding = positional_encoding.transpose((1,2,0)).reshape((w*h, -1))print(positional_encoding.shape)
       
[691200, 22]
       
In [41]
# 再创建一个神经网络, 这次每条输入的维度是 22net = NeRF2D(input_size = positional_encoding.shape[1])
opt = paddle.optimizer.Adam(parameters = net.parameters())
losses = []
   
In [42]
# 训练 2000 步 (3分钟左右)epochs = 2000for epoch in tqdm(range(len(losses)+1, len(losses)+epochs+1)):
    render = net(positional_encoding) # (w*h, 3)
    # 将神经网络根据坐标预测的渲染 (render) 结果与原图片对比, 用 L2 损失作为损失函数
    loss = paddle.mean(paddle.square(render - outputs))
    losses.append(loss.item())

    opt.clear_grad()
    loss.backward()
    opt.step()
       
100%|██████████| 2000/2000 [03:08<00:00, 10.62it/s]
       
In [43]
# 观察 loss 曲线 和 最后一次渲染的结果, 可以看出比不用 positional encoding 要清晰很多plt.figure(figsize=(12,7))
plt.subplot(1,2,1)
plt.semilogy(losses)
plt.subplot(1,2,2)
plt.imshow(render.reshape((w,h,3)))
       
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
       
<matplotlib.image.AxesImage at 0x7fb48cc63450>
               
<Figure size 864x504 with 2 Axes>
               

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
数据分析的方法
数据分析的方法

数据分析的方法有:对比分析法,分组分析法,预测分析法,漏斗分析法,AB测试分析法,象限分析法,公式拆解法,可行域分析法,二八分析法,假设性分析法。php中文网为大家带来了数据分析的相关知识、以及相关文章等内容。

504

2023.07.04

数据分析方法有哪几种
数据分析方法有哪几种

数据分析方法有:1、描述性统计分析;2、探索性数据分析;3、假设检验;4、回归分析;5、聚类分析。本专题为大家提供数据分析方法的相关的文章、下载、课程内容,供大家免费下载体验。

292

2023.08.07

网站建设功能有哪些
网站建设功能有哪些

网站建设功能包括信息发布、内容管理、用户管理、搜索引擎优化、网站安全、数据分析、网站推广、响应式设计、社交媒体整合和电子商务等功能。这些功能可以帮助网站管理员创建一个具有吸引力、可用性和商业价值的网站,实现网站的目标。

757

2023.10.16

数据分析网站推荐
数据分析网站推荐

数据分析网站推荐:1、商业数据分析论坛;2、人大经济论坛-计量经济学与统计区;3、中国统计论坛;4、数据挖掘学习交流论坛;5、数据分析论坛;6、网站数据分析;7、数据分析;8、数据挖掘研究院;9、S-PLUS、R统计论坛。想了解更多数据分析的相关内容,可以阅读本专题下面的文章。

534

2024.03.13

Python 数据分析处理
Python 数据分析处理

本专题聚焦 Python 在数据分析领域的应用,系统讲解 Pandas、NumPy 的数据清洗、处理、分析与统计方法,并结合数据可视化、销售分析、科研数据处理等实战案例,帮助学员掌握使用 Python 高效进行数据分析与决策支持的核心技能。

82

2025.09.08

Python 数据分析与可视化
Python 数据分析与可视化

本专题聚焦 Python 在数据分析与可视化领域的核心应用,系统讲解数据清洗、数据统计、Pandas 数据操作、NumPy 数组处理、Matplotlib 与 Seaborn 可视化技巧等内容。通过实战案例(如销售数据分析、用户行为可视化、趋势图与热力图绘制),帮助学习者掌握 从原始数据到可视化报告的完整分析能力。

60

2025.10.14

Python异步编程与Asyncio高并发应用实践
Python异步编程与Asyncio高并发应用实践

本专题围绕 Python 异步编程模型展开,深入讲解 Asyncio 框架的核心原理与应用实践。内容包括事件循环机制、协程任务调度、异步 IO 处理以及并发任务管理策略。通过构建高并发网络请求与异步数据处理案例,帮助开发者掌握 Python 在高并发场景中的高效开发方法,并提升系统资源利用率与整体运行性能。

37

2026.03.12

C# ASP.NET Core微服务架构与API网关实践
C# ASP.NET Core微服务架构与API网关实践

本专题围绕 C# 在现代后端架构中的微服务实践展开,系统讲解基于 ASP.NET Core 构建可扩展服务体系的核心方法。内容涵盖服务拆分策略、RESTful API 设计、服务间通信、API 网关统一入口管理以及服务治理机制。通过真实项目案例,帮助开发者掌握构建高可用微服务系统的关键技术,提高系统的可扩展性与维护效率。

136

2026.03.11

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

47

2026.03.10

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
10分钟--Midjourney创作自己的漫画
10分钟--Midjourney创作自己的漫画

共1课时 | 0.1万人学习

Midjourney 关键词系列整合
Midjourney 关键词系列整合

共13课时 | 0.9万人学习

AI绘画教程
AI绘画教程

共2课时 | 0.2万人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号