0

0

[源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC

爱谁谁

爱谁谁

发布时间:2025-06-26 13:00:12

|

504人浏览过

|

来源于php中文网

原创

[源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC

目录

[源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC0x00 摘要0x01 前言1.1 先决条件1.2 基础知识1.3 代码0x02 启动2.1 总体启动2.2 启动参数服务器0x03 参数服务器0x04 Trainer0x05 对比0xFF 参考0x00 摘要

在前面的文章之中,我们已经学习了PyTorch 分布式的基本模块,接下来我们通过几篇文章来看看如何把这些模块应用到实践之中,顺便把PyTorch分布式逻辑整体梳理一下。本文介绍如何使用异步执行操作来实现批处理 RPC,大家可以学习到PyTorch对参数服务器一个新的实现方式。

本文以IMPLEMENTING BATCH RPC PROCESSING USING ASYNCHRONOUS EXECUTIONS的翻译为基础,加入了自己的理解。

0x01 前言1.1 先决条件

本文的先决条件如下:

PyTorch 分布式概述分布式 RPC 框架入门使用分布式 RPC 框架实现参数服务器RPC 异步执行装饰器

本教程演示了如何使用@rpc.functions.async_execution 装饰器构建批处理 RPC 应用程序,这有助于通过减少被阻塞的 RPC 线程的数量,并且在被调用方整合 CUDA 操作来加快训练速度。这与使用 TorchServer 进行批量推理的想法相同。Batch RPC 有助于将动作整合到较少的 CUDA 操作中,从而摊销开销。

注意:本教程需要 PyTorch v1.6.0 或更高版本。

1.2 基础知识

之前的教程已经展示了使用torch.distributed.rpc构建分布式训练应用程序的步骤,但他们没有详细说明在处理 RPC 请求时被调用方会发生什么。从 PyTorch v1.5 开始,针对每个 RPC 请求,被调用者都会启动一个线程来执行该请求中的函数,该线程会阻塞直到该函数返回。这适用于许多用例,但有一个问题:如果用户函数在 IO 上阻塞,例如使用嵌套的 RPC 调用或信号(例如等待不同的 RPC 请求来解除阻塞),则被调用者上的 RPC 线程将不得不空闲等待,直到 IO 完成或信号(signal)事件发生。因此,RPC 被调用者使用的线程可能会使用比实际需要更多。造成这个问题的原因是RPC把用户函数当成黑盒,对函数中发生的事情知之甚少。为了让用户函数能够让出和释放 RPC 线程,需要向 RPC 系统提供更多的提示。

从 v1.6.0 开始,PyTorch 通过引入两个新概念来解决这个问题:

torch.futures.Future 封装了一个异步执行,同时也支持安装回调函数。@rpc.functions.async_execution 装饰器,它允许应用程序告诉被调用者,本目标函数将返回一个future,并且可以在执行过程中多次暂停和yield。

使用这两个工具,应用程序代码可以将用户函数分解为多个较小的函数,将它们链接在一起作为Future 对象的回调方法,并返回包含最终结果的 Future给调用者。在被调用方,在获取Future对象时,它也会安装后续的 RPC 响应处理作为回调方法,这些回调会在最终结果准备好时被触发。这样,被调用者不再需要阻塞一个线程,只是等待最终返回值准备好就行。 简单的例子请参考@rpc.functions.async_execution的API文档 。

除了减少被调用者的空闲线程数量外,这些工具还使批处理 RPC 处理更容易、更快。本教程演示了如何使用@rpc.functions.async_execution 装饰器构建分布式批量更新参数服务器和批量处理强化学习应用程序 。

注:我们不考虑强化学习的领域,那样会影响我们的思路,牵扯精力

1.3 代码

因为原文主要是强化学习代码讲解,而我们只关注普通分布式批量更新参数服务器,所以需要看原始代码。

代码位于 https://github.com/pytorch/examples/blob/master/distributed/rpc/batch/parameter_server.py。先全部摘录如下:

代码语言:javascript代码运行次数:0运行复制
<code class="javascript">import osimport threadingfrom datetime import datetimeimport torchimport torch.distributed.rpc as rpcimport torch.multiprocessing as mpimport torch.nn as nnfrom torch import optimimport torchvisionbatch_size = 20image_w = 64image_h = 64num_classes = 30batch_update_size = 5num_batches = 6def timed_log(text):    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")class BatchUpdateParameterServer(object):    def __init__(self, batch_update_size=batch_update_size):        self.model = torchvision.models.resnet50(num_classes=num_classes)        self.lock = threading.Lock()        self.future_model = torch.futures.Future()        self.batch_update_size = batch_update_size        self.curr_update_size = 0        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)        for p in self.model.parameters():            p.grad = torch.zeros_like(p)    def get_model(self):        return self.model    @staticmethod    @rpc.functions.async_execution    def update_and_fetch_model(ps_rref, grads):        self = ps_rref.local_value()        timed_log(f"PS got {self.curr_update_size}/{batch_update_size} updates")        for p, g in zip(self.model.parameters(), grads):            p.grad += g        with self.lock:            self.curr_update_size += 1            fut = self.future_model            if self.curr_update_size >= self.batch_update_size:                for p in self.model.parameters():                    p.grad /= self.batch_update_size                self.curr_update_size = 0                self.optimizer.step()                self.optimizer.zero_grad()                fut.set_result(self.model)                timed_log("PS updated model")                self.future_model = torch.futures.Future()        return futclass Trainer(object):    def __init__(self, ps_rref):        self.ps_rref = ps_rref        self.loss_fn = nn.MSELoss()        self.one_hot_indices = torch.LongTensor(batch_size) \                                    .random_(0, num_classes) \                                    .view(batch_size, 1)    def get_next_batch(self):        for _ in range(num_batches):            inputs = torch.randn(batch_size, 3, image_w, image_h)            labels = torch.zeros(batch_size, num_classes) \                        .scatter_(1, self.one_hot_indices, 1)            yield inputs.cuda(), labels.cuda()    def train(self):        name = rpc.get_worker_info().name        m = self.ps_rref.rpc_sync().get_model().cuda()        for inputs, labels in self.get_next_batch():            timed_log(f"{name} processing one batch")            self.loss_fn(m(inputs), labels).backward()            timed_log(f"{name} reporting grads")            m = rpc.rpc_sync(                self.ps_rref.owner(),                BatchUpdateParameterServer.update_and_fetch_model,                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),            ).cuda()            timed_log(f"{name} got updated model")def run_trainer(ps_rref):    trainer = Trainer(ps_rref)    trainer.train()def run_ps(trainers):    timed_log("Start training")    ps_rref = rpc.RRef(BatchUpdateParameterServer())    futs = []    for trainer in trainers:        futs.append(            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,))        )    torch.futures.wait_all(futs)    timed_log("Finish training")def run(rank, world_size):    os.environ['MASTER_ADDR'] = 'localhost'    os.environ['MASTER_PORT'] = '29500'    options=rpc.TensorPipeRpcBackendOptions(        num_worker_threads=16,        rpc_timeout=0  # infinite timeout     )    if rank != 0:        rpc.init_rpc(            f"trainer{rank}",            rank=rank,            world_size=world_size,            rpc_backend_options=options        )        # trainer passively waiting for ps to kick off training iterations    else:        rpc.init_rpc(            "ps",            rank=rank,            world_size=world_size,            rpc_backend_options=options        )        run_ps([f"trainer{r}" for r in range(1, world_size)])    # block until all rpcs finish    rpc.shutdown()if __name__=="__main__":    world_size = batch_update_size + 1    mp.spawn(run, args=(world_size, ), nprocs=world_size, join=True)</code>
0x02 启动

我们首先看看如何启动。

2.1 总体启动

我们假设有一个master(rank 0),一个worker。Master 之上运行的是参数服务器,worker 之上是训练代码。

代码语言:javascript代码运行次数:0运行复制
<code class="javascript">def run(rank, world_size):    os.environ['MASTER_ADDR'] = 'localhost'    os.environ['MASTER_PORT'] = '29500'    options=rpc.TensorPipeRpcBackendOptions(        num_worker_threads=16,        rpc_timeout=0  # infinite timeout     )    if rank != 0:        rpc.init_rpc( # 训练代码            f"trainer{rank}",            rank=rank,            world_size=world_size,            rpc_backend_options=options        )        # trainer passively waiting for ps to kick off training iterations    else:        rpc.init_rpc( # 参数服务器            "ps",             rank=rank,            world_size=world_size,            rpc_backend_options=options        )        run_ps([f"trainer{r}" for r in range(1, world_size)])    # block until all rpcs finish    rpc.shutdown()if __name__=="__main__":    world_size = batch_update_size + 1    mp.spawn(run, args=(world_size, ), nprocs=world_size, join=True)</code>

逻辑如下图:

代码语言:javascript代码运行次数:0运行复制
<code class="javascript">             torch.multiprocessing.spawn                        +                        |                        |           +------------+-------------------------------------------------           |                                                             |           |                                                             |           v                                                             v+----------+----------------------------------------------+ +------------+----------------+| "ps"                                           rank = 0 | | f"trainer{rank}"   rank = 1 ||                                                         | |                             ||                                                         | |                             ||                     rpc.init_rpc                        | |         rpc.init_rpc        ||                                                         | |                             ||                                                         | |                             ||  run_ps([f"trainer{r}" for r in range(1, world_size)])  | |                             ||                                                         | |                             ||                                                         | |                             |+---------------------------------------------------------+ +-----------------------------+</code>
2.2 启动参数服务器

run_ps 启动了参数服务器和trainer。注意,这里在参数服务器之中启动 trainer,即,master 不仅仅有一个参数服务器,还负责通过 rpc 来驱动trainer上的训练循环。

Loki.Build
Loki.Build

AI原生网站构建工具

下载
代码语言:javascript代码运行次数:0运行复制
<code class="javascript">def run_ps(trainers):    timed_log("Start training")    ps_rref = rpc.RRef(BatchUpdateParameterServer())    futs = []    for trainer in trainers: # trainer 是字符串,比如"trainer1"        futs.append(            rpc.rpc_async(trainer, run_trainer, args=(ps_rref,)) # 运行run_trainer        )    torch.futures.wait_all(futs)    timed_log("Finish training")    def run_trainer(ps_rref):    trainer = Trainer(ps_rref)    trainer.train() # 调用 Trainer 的方法   </code>

具体拓展如下:

这里没有给出参数服务器和trainer的逻辑,我们会在后续分析之后陆续给出。trainer 也只给出了一个。

[源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC
0x03 参数服务器

上面图中没有给出具体参数服务器代码,我们接下来就分析一下。

这里考虑具有一个参数服务器 (PS) 和多个trainer的同步训练应用程序。在这个应用中,PS 持有参数并等待所有训练器报告梯度。在每次迭代中,它等待直到从所有训练器接收梯度,然后一次性更新所有参数。

下面的代码显示了 PS 类的实现。

PS初始化时候生成了常规SGB优化器,不是分布式优化器,而且优化器是在PS之上update_and_fetch_model方法被 @rpc.functions.async_execution所装饰,将由trainer调用。每次调用都会返回一个Future对象,该对象将被用来处理更新后的模型。大多数训练器发起的调用只是累积梯度到 .grad成员变量 ,然后立即返回,并在 PS 上产生 RPC 线程。最后到达的训练器将触发优化器步骤并消耗所有先前上报的梯度。然后它使用更新后的模型来设置future_model,这是依靠通过Future对象来依次通知来自其他训练者的先前请求,并将更新后的模型发送给所有训练者。

具体代码如下:

代码语言:javascript代码运行次数:0运行复制
<code class="javascript">batch_size = 20image_w = 64image_h = 64num_classes = 30batch_update_size = 5num_batches = 6def timed_log(text):    print(f"{datetime.now().strftime('%H:%M:%S')} {text}")class BatchUpdateParameterServer(object):    def __init__(self, batch_update_size=batch_update_size):        self.model = torchvision.models.resnet50(num_classes=num_classes)        self.lock = threading.Lock()        self.future_model = torch.futures.Future()        self.batch_update_size = batch_update_size        self.curr_update_size = 0        # 重点:这里是常规SGB优化器,不是分布式优化器        self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)        for p in self.model.parameters():            p.grad = torch.zeros_like(p)    def get_model(self):        return self.model    @staticmethod    @rpc.functions.async_execution # trainer会直接调用    def update_and_fetch_model(ps_rref, grads):        self = ps_rref.local_value()        timed_log(f"PS got {self.curr_update_size}/{batch_update_size} updates")        for p, g in zip(self.model.parameters(), grads): # 得到            p.grad += g # 累积梯度        with self.lock:            self.curr_update_size += 1            fut = self.future_model            if self.curr_update_size >= self.batch_update_size:                # 最后到达的训练器将触发优化器步骤并消耗所有先前上报的梯度。                for p in self.model.parameters():                    p.grad /= self.batch_update_size                self.curr_update_size = 0                self.optimizer.step() # 更新模型                self.optimizer.zero_grad()                fut.set_result(self.model) # 将更新后的模型发送给所有训练者                timed_log("PS updated model")                self.future_model = torch.futures.Future() # 使用更新后的模型来设置future_model        return fut # 该对象将被用来处理更新后的模型</code>

逻辑拓展如下,这里省略了参数服务器生成trainer的步骤:

手机如下:

[源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC
0x04 Trainer

对于训练器,它们都使用来自 PS 的相同参数集进行初始化。在每次迭代中执行如下操作:

每个训练器首先运行前向和后向传播以在本地生成梯度。然后,每个训练器使用 RPC 向 PS 报告其梯度,并通过同一 RPC 请求的返回值取回更新后的参数。

在训练器的实现中,目标函数是否被标记 @rpc.functions.async_execution是没有区别的。训练器只需使用 rpc_sync 调用update_and_fetch_model,其将阻塞训练器,直到返回更新的模型。

可以看到,参数服务器存储模型,模型可以返回到trainer。

代码语言:javascript代码运行次数:0运行复制
<code class="javascript">class Trainer(object):    def __init__(self, ps_rref):        self.ps_rref = ps_rref        self.loss_fn = nn.MSELoss()        self.one_hot_indices = torch.LongTensor(batch_size) \                                    .random_(0, num_classes) \                                    .view(batch_size, 1)    def get_next_batch(self):        for _ in range(num_batches):            inputs = torch.randn(batch_size, 3, image_w, image_h)            labels = torch.zeros(batch_size, num_classes) \                        .scatter_(1, self.one_hot_indices, 1)            yield inputs.cuda(), labels.cuda()    def train(self):        name = rpc.get_worker_info().name        # 从参数服务器获取模型        m = self.ps_rref.rpc_sync().get_model().cuda()        for inputs, labels in self.get_next_batch():            timed_log(f"{name} processing one batch")            # 利用模型来前向传播/反向传播            self.loss_fn(m(inputs), labels).backward()            timed_log(f"{name} reporting grads")            # 调用参数服务器的函数来提交梯度            m = rpc.rpc_sync( # rpc_sync 操作完成之后,m就是最新模型了                self.ps_rref.owner(),                BatchUpdateParameterServer.update_and_fetch_model,                args=(self.ps_rref, [p.grad for p in m.cpu().parameters()]),            ).cuda()            timed_log(f"{name} got updated model")</code>

拓展逻辑如下:

参数服务器的run_trainer 方法会直接调用 trainer.train() 方法来执行一步step。train 方法之中,会调用 self.ps_rref.rpc_sync().get_model().cuda() 从参数服务器获得模型,放到本地设备之上(图上是双向箭头,表示这是一个get/return动作,需要把模型存储在worker本地)。调用 self.loss_fn(m(inputs), labels).backward() 来进行前向传播/反向传播。调用参数服务器的 update_and_fetch_model 函数来提交梯度,这里使用了异步RPC。参数服务器的 update_and_fetch_model 之中,进行梯度累积,模型更新是通过PS之上常规SGD优化器完成,最后调用 fut.set_result(self.model) 来发布新模型给trainer。在trainer 之中,就是 m = rpc.rpc_sync(...) 这个赋值之后,m 是最新模型了。
[源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC
0x05 对比

前文结尾,我们对比参数服务器的经典实现 ps-lite 和 前两篇实现的参数服务器。

ps-lite 是类似传统服务器实现,有自己主动的业务循环,可以响应用户的显式请求,也有自己明确的逻辑,本地也有自己的KV存储。PyTorch 前两篇官方文档(本系列前两篇文章)之中,参数服务器则是另外一种思路: 参数服务器上没有主动的循环,没有KV存储,没有服务器逻辑,而是可以直接存储业务模型,ps 会把业务模型需要优化的参数返回给trainer 之上的 DistributedOptimizer。业务驱动由trainer完成:train loop代码在trainer 之中,DistributedOptimizer 在trainer 之中,DistributedOptimizer 负责进行分布式优化。本文又与上面不同,看起来更像是ps-lite,但是又糅合了RPC实现: ps进程会启动trainer的训练循环。每个迭代之中,trainer 会从参数服务器获取最新模型,前向操作/后向传播都在trainer 完成。trainer 会通过异步RPC把梯度提交给参数服务器。模型更新是通过PS之上常规SGD优化器完成。模型更新之后通过异步RPC把模型再次分发给trainer。

不得不说,官方这几篇文章快把各种实现方式玩出花来了,大家可以依据自己业务特点来参考实现。

0xFF 参考

IMPLEMENTING BATCH RPC PROCESSING USING ASYNCHRONOUS EXECUTIONS

热门AI工具

更多
DeepSeek
DeepSeek

幻方量化公司旗下的开源大模型平台

豆包大模型
豆包大模型

字节跳动自主研发的一系列大型语言模型

WorkBuddy
WorkBuddy

腾讯云推出的AI原生桌面智能体工作台

腾讯元宝
腾讯元宝

腾讯混元平台推出的AI助手

文心一言
文心一言

文心一言是百度开发的AI聊天机器人,通过对话可以生成各种形式的内容。

讯飞写作
讯飞写作

基于讯飞星火大模型的AI写作工具,可以快速生成新闻稿件、品宣文案、工作总结、心得体会等各种文文稿

即梦AI
即梦AI

一站式AI创作平台,免费AI图片和视频生成。

ChatGPT
ChatGPT

最最强大的AI聊天机器人程序,ChatGPT不单是聊天机器人,还能进行撰写邮件、视频脚本、文案、翻译、代码等任务。

相关专题

更多
什么是分布式
什么是分布式

分布式是一种计算和数据处理的方式,将计算任务或数据分散到多个计算机或节点中进行处理。本专题为大家提供分布式相关的文章、下载、课程内容,供大家免费下载体验。

407

2023.08.11

分布式和微服务的区别
分布式和微服务的区别

分布式和微服务的区别在定义和概念、设计思想、粒度和复杂性、服务边界和自治性、技术栈和部署方式等。本专题为大家提供分布式和微服务相关的文章、下载、课程内容,供大家免费下载体验。

251

2023.10.07

线程和进程的区别
线程和进程的区别

线程和进程的区别:线程是进程的一部分,用于实现并发和并行操作,而线程共享进程的资源,通信更方便快捷,切换开销较小。本专题为大家提供线程和进程区别相关的各种文章、以及下载和课程。

765

2023.08.10

github中文官网入口 github中文版官网网页进入
github中文官网入口 github中文版官网网页进入

github中文官网入口https://docs.github.com/zh/get-started,GitHub 是一种基于云的平台,可在其中存储、共享并与他人一起编写代码。 通过将代码存储在GitHub 上的“存储库”中,你可以: “展示或共享”你的工作。 持续“跟踪和管理”对代码的更改。

4225

2026.01.21

pytorch是干嘛的
pytorch是干嘛的

pytorch是一个基于python的深度学习框架,提供以下主要功能:动态图计算,提供灵活性。强大的张量操作,实现高效处理。自动微分,简化梯度计算。预构建的神经网络模块,简化模型构建。各种优化器,用于性能优化。想了解更多pytorch的相关内容,可以阅读本专题下面的文章。

469

2024.05.29

Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习
Python AI机器学习PyTorch教程_Python怎么用PyTorch和TensorFlow做机器学习

PyTorch 是一种用于构建深度学习模型的功能完备框架,是一种通常用于图像识别和语言处理等应用程序的机器学习。 使用Python 编写,因此对于大多数机器学习开发者而言,学习和使用起来相对简单。 PyTorch 的独特之处在于,它完全支持GPU,并且使用反向模式自动微分技术,因此可以动态修改计算图形。

27

2025.12.22

http与https有哪些区别
http与https有哪些区别

http与https的区别:1、协议安全性;2、连接方式;3、证书管理;4、连接状态;5、端口号;6、资源消耗;7、兼容性。本专题为大家提供相关的文章、下载、课程内容,供大家免费下载体验。

2912

2024.08.16

C# ASP.NET Core微服务架构与API网关实践
C# ASP.NET Core微服务架构与API网关实践

本专题围绕 C# 在现代后端架构中的微服务实践展开,系统讲解基于 ASP.NET Core 构建可扩展服务体系的核心方法。内容涵盖服务拆分策略、RESTful API 设计、服务间通信、API 网关统一入口管理以及服务治理机制。通过真实项目案例,帮助开发者掌握构建高可用微服务系统的关键技术,提高系统的可扩展性与维护效率。

76

2026.03.11

Go高并发任务调度与Goroutine池化实践
Go高并发任务调度与Goroutine池化实践

本专题围绕 Go 语言在高并发任务处理场景中的实践展开,系统讲解 Goroutine 调度模型、Channel 通信机制以及并发控制策略。内容包括任务队列设计、Goroutine 池化管理、资源限制控制以及并发任务的性能优化方法。通过实际案例演示,帮助开发者构建稳定高效的 Go 并发任务处理系统,提高系统在高负载环境下的处理能力与稳定性。

38

2026.03.10

热门下载

更多
网站特效
/
网站源码
/
网站素材
/
前端模板

精品课程

更多
相关推荐
/
热门推荐
/
最新课程
Git 教程
Git 教程

共21课时 | 4.2万人学习

Git版本控制工具
Git版本控制工具

共8课时 | 1.6万人学习

Git中文开发手册
Git中文开发手册

共0课时 | 94人学习

关于我们 免责申明 举报中心 意见反馈 讲师合作 广告合作 最新更新
php中文网:公益在线php培训,帮助PHP学习者快速成长!
关注服务号 技术交流群
PHP中文网订阅号
每天精选资源文章推送

Copyright 2014-2026 https://www.php.cn/ All Rights Reserved | php.cn | 湘ICP备2023035733号