This action will force synchronization from OpenDocCN/pytorch-doc-zh, which will overwrite any changes that you have made since you forked the repository, and can not be recovered!!!
Synchronous operation will process in the background and will refresh the page when finishing processing. Please be patient.
torch.multiprocessing
是 Python 的 multiprocessing
多进程模块的替代品。它支持完全相同的操作,但对其进行了扩展,以便所有通过多进程队列 multiprocessing.Queue
发送的张量都能将其数据移入共享内存,而且仅将其句柄发送到另一个进程。
注意:
当张量
Tensor
被发送到另一个进程时,张量的数据和梯度torch.Tensor.grad
都将被共享。
这一特性允许实现各种训练方法,如 Hogwild,A3C 或任何其他需要异步操作的训练方法。
仅 Python 3 支持进程之间共享 CUDA 张量,我们可以使用 spawn
或forkserver
启动此类方法。 Python 2 中的 multiprocessing
多进程处理只能使用 fork
创建子进程,并且 CUDA 运行时不支持多进程处理。
警告:
CUDA API 规定输出到其他进程的共享张量,只要它们被这些进程使用时,都将持续保持有效。您应该小心并确保您共享的 CUDA 张量不会超出它应该的作用范围(不会出现作用范围延伸的问题)。这对于共享模型的参数应该不是问题,但应该小心地传递其他类型的数据。请注意,此限制不适用于共享的 CPU 内存。
也可以参阅:, 使用 nn.DataParallel 替代多进程处理
产生新进程时会出现很多错误,导致死锁最常见的原因是后台线程。如果有任何持有锁或导入模块的线程,并且 fork
被调用,则子进程很可能处于崩溃状态,并且会以不同方式死锁或失败。请注意,即使您没有这样做,Python 中内置的库也可能会,更不必说 多进程处理
了。multiprocessing.Queue
多进程队列实际上是一个非常复杂的类,它产生了多个用于序列化、发送和接收对象的线程,并且它们也可能导致上述问题。如果您发现自己处于这种情况,请尝试使用multiprocessing.queues.SimpleQueue
,它不使用任何其他额外的线程。
我们正在尽可能的为您提供便利,并确保这些死锁不会发生,但有些事情不受我们控制。如果您有任何问题暂时无法应对,请尝试到论坛求助,我们会查看是否可以解决问题。
请记住,每次将张量放入多进程队列 multiprocessing.Queue
时,它必须被移动到共享内存中。如果它已经被共享,将会是一个空操作,否则会产生一个额外的内存拷贝,这会减慢整个过程。即使您有一组进程将数据发送到单个进程,也可以让它将缓冲区发送回去,这几乎是不占资源的,并且可以在发送下一批时避免产生拷贝动作。
使用多进程处理 torch.multiprocessing
,可以异步地训练一个模型,参数既可以一直共享,也可以周期性同步。在第一种情况下,我们建议发送整个模型对象,而在后者中,我们建议只发送状态字典 state_dict()
。
我们建议使用多进程处理队列 multiprocessing.Queue
在进程之间传递各种 PyTorch 对象。使用 fork
启动一个方法时,它也可能会继承共享内存中的张量和存储空间,但这种方式也非常容易出错,应谨慎使用,最好只能让高阶用户使用。而队列,尽管它们有时候不太优雅,却能在任何情况下正常工作。
警告:
你应该留意没有用
if __name__ =='__main__'
来保护的全局语句。如果使用了不同于fork
启动方法,它们将在所有子进程中执行。
具体的 Hogwild 实现可以在 示例库 中找到,但为了展示代码的整体结构,下面还有一个最简单的示例:
import torch.multiprocessing as mp
from model import MyModel
def train(model):
# 构建 data_loader,优化器等
for data, labels in data_loader:
optimizer.zero_grad()
loss_fn(model(data), labels).backward()
optimizer.step() # 更新共享的参数
if __name__ == '__main__':
num_processes = 4
model = MyModel()
# 注意:这是 "fork" 方法工作所必需的
model.share_memory()
processes = []
for rank in range(num_processes):
p = mp.Process(target=train, args=(model,))
p.start()
processes.append(p)
for p in processes:
p.join()
用户名 | 头像 | 职能 | 签名 |
---|---|---|---|
风中劲草 | 翻译 | 人生总要追求点什么 |
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。