代码拉取完成,页面将自动刷新
同步操作将从 OpenDocCN/pytorch-doc-zh 强制同步,此操作会覆盖自 Fork 仓库以来所做的任何修改,且无法恢复!!!
确定后同步将在后台操作,完成时将刷新页面,请耐心等待。
译者:@飞龙
作者: Soumith Chintala
首先, 你需要编写你的 C 函数.
下面你可以找到模块的正向和反向函数的示例实现, 它将两个输入相加.
在你的 .c
文件中, 你可以使用 #include <TH/TH.h>
直接包含 TH, 以及使用 #include <THC/THC.h>
包含 THC.
ffi (外来函数接口) 工具会确保编译器可以在构建过程中找到它们.
/* src/my_lib.c */
#include <TH/TH.h>
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
THFloatTensor *output)
{
if (!THFloatTensor_isSameSizeAs(input1, input2))
return 0;
THFloatTensor_resizeAs(output, input1);
THFloatTensor_cadd(output, input1, 1.0, input2);
return 1;
}
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input)
{
THFloatTensor_resizeAs(grad_input, grad_output);
THFloatTensor_fill(grad_input, 1);
return 1;
}
代码没有任何限制, 除了你必须准备单个头文件, 它会列出所有你想要从 Python 调用的函数.
它会由 ffi 用于生成合适的包装.
/* src/my_lib.h */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2, THFloatTensor *output);
int my_lib_add_backward(THFloatTensor *grad_output, THFloatTensor *grad_input);
现在, 你需要一个超短的文件, 它会构建你的自定义扩展:
# build.py
from torch.utils.ffi import create_extension
ffi = create_extension(
name='_ext.my_lib',
headers='src/my_lib.h',
sources=['src/my_lib.c'],
with_cuda=False
)
ffi.build()
你运行它之后, pytorch 会创建一个 _ext
目录, 并把 my_lib
放到里面.
包名称可以在最终模块名称之前, 包含任意数量的包 (包括没有). 如果构建成功, 你可以导入你的扩展, 就像普通的 Python 文件.
# functions/add.py
import torch
from torch.autograd import Function
from _ext import my_lib
class MyAddFunction(Function):
def forward(self, input1, input2):
output = torch.FloatTensor()
my_lib.my_lib_add_forward(input1, input2, output)
return output
def backward(self, grad_output):
grad_input = torch.FloatTensor()
my_lib.my_lib_add_backward(grad_output, grad_input)
return grad_input
# modules/add.py
from torch.nn import Module
from functions.add import MyAddFunction
class MyAddModule(Module):
def forward(self, input1, input2):
return MyAddFunction()(input1, input2)
# main.py
import torch
import torch.nn as nn
from torch.autograd import Variable
from modules.add import MyAddModule
class MyNetwork(nn.Module):
def __init__(self):
super(MyNetwork, self).__init__()
self.add = MyAddModule()
def forward(self, input1, input2):
return self.add(input1, input2)
model = MyNetwork()
input1, input2 = Variable(torch.randn(5, 5)), Variable(torch.randn(5, 5))
print(model(input1, input2))
print(input1 + input2)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。