Original: https://pytorch.org/tutorials/beginner/transformer_tutorial.html
Translator: zanshuxun
This tutorial shows how to train a seq2seq model with the nn.Transformer module. Click here to download the full example code.
PyTorch 1.2 ships a standard transformer module based on the paper "Attention is All You Need". The transformer model performs better on many seq2seq problems while being more parallelizable to train. The nn.Transformer module relies on an attention mechanism (another recently added module is nn.MultiheadAttention) to capture global dependencies between input and output. nn.Transformer is highly modular, so its individual components (such as the nn.TransformerEncoder used in this tutorial) can easily be adapted and composed.
In this tutorial we train an nn.TransformerEncoder model on a language modeling task: given a sequence of words, assign a probability to the word (or sequence of words) that follows it. The sequence of tokens is first passed through an embedding layer, then through a positional encoding layer that accounts for word order (see the next section). nn.TransformerEncoder consists of multiple nn.TransformerEncoderLayer layers. Along with the input sequence, a square attention mask is required: since later words are predicted from earlier ones, the model must not attend to future positions during training, so those positions are masked out. To obtain a probability for each word, the output of nn.TransformerEncoder is fed into a Linear layer followed by a softmax.
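As a minimal sketch (not part of the original tutorial code), this is what such a square attention mask looks like for a sequence of length 3: positions a token may attend to hold 0.0, masked (future) positions hold -inf.
import torch

sz = 3
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
# mask is now:
# [[0., -inf, -inf],
#  [0.,   0., -inf],
#  [0.,   0.,   0.]]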
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output
The PositionalEncoding module injects information about the relative or absolute position of the tokens in the sequence. The positional encodings have the same dimension as the embeddings, so the two can be summed. Here we use sine and cosine functions of different frequencies to encode word positions.
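For reference, this is the standard formulation from the Transformer paper, which the code below implements; the div_term tensor in the code corresponds to the factor 10000^{-2i/d_model}:

PE_{(pos,\,2i)} = \sin\!\left(pos / 10000^{2i/d_{model}}\right)
PE_{(pos,\,2i+1)} = \cos\!\left(pos / 10000^{2i/d_{model}}\right)
\text{div\_term} = \exp\!\left(-\tfrac{2i \ln 10000}{d_{model}}\right) = 10000^{-2i/d_{model}}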
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
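With both classes defined, a quick shape check (an illustrative addition with arbitrary toy hyperparameters, not part of the original tutorial) confirms that a [seq_len, batch] tensor of token indices yields a [seq_len, batch, ntoken] tensor of scores:
toy_model = TransformerModel(ntoken=1000, ninp=64, nhead=2, nhid=128, nlayers=2)
toy_tokens = torch.randint(0, 1000, (35, 8))   # 35 positions, batch of 8
toy_scores = toy_model(toy_tokens)
print(toy_scores.shape)                        # torch.Size([35, 8, 1000])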
Training uses the Wikitext-2 dataset from torchtext. The vocab object below numericalizes the tokens in the dataset into tensors. The batchify() function arranges the data into batches: it splits the training set into batch_size columns and trims off any leftover tokens that would not fill a complete column. For example, with the alphabet as the input sequence (total length 26) and batch_size set to 4, batchify() would divide it into 4 sequences of length 6:

A  G  M  S
B  H  N  T
C  I  O  U
D  J  P  V
E  K  Q  W
F  L  R  X

These columns are treated as independent by the model, which means the dependence between G and F cannot be learned, but it allows for more efficient batch processing. A small sketch of this reshape follows the code below.
import torchtext
from torchtext.data.utils import get_tokenizer

TEXT = torchtext.data.Field(tokenize=get_tokenizer("basic_english"),
                            init_token='<sos>',
                            eos_token='<eos>',
                            lower=True)
train_txt, val_txt, test_txt = torchtext.datasets.WikiText2.splits(TEXT)
TEXT.build_vocab(train_txt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
    data = TEXT.numericalize([data.examples[0].text])
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_txt, batch_size)
val_data = batchify(val_txt, eval_batch_size)
test_data = batchify(test_txt, eval_batch_size)
Out:
downloading wikitext-2-v1.zip
extracting
The get_batch() function generates the input and target sequences for the transformer model, subdividing the source data into chunks of length bptt. For the language modeling task, the model needs the following words as the Target. For example, with a bptt value of 2 and i = 0, get_batch() would produce the following two variables (continuing the alphabet example above):

data (shape [2, 4]):
A  G  M  S
B  H  N  T

target (the next word at each position, flattened into a vector of 8 elements):
B  H  N  T  C  I  O  U

Note that the chunks are along dimension 0, consistent with the S (sequence) dimension of the Transformer model; dimension 1 is the batch dimension N.
bptt = 35
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target
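A quick sanity check (illustrative only, not part of the original tutorial) on a toy source tensor: with 10 rows and 4 columns, the first chunk is capped at len(source) - 1 - i = 9 positions, and the target contains the same number of positions shifted by one and flattened:
toy_source = torch.arange(40).view(10, 4)
toy_data, toy_target = get_batch(toy_source, 0)
print(toy_data.shape)    # torch.Size([9, 4])  -- capped below bptt by the source length
print(toy_target.shape)  # torch.Size([36])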
The model is built with the hyperparameters below. The vocabulary size equals the length of the vocab object.
ntokens = len(TEXT.vocab.stoi) # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
CrossEntropyLoss is used as the loss function, and SGD implements the stochastic gradient descent optimizer. The initial learning rate is set to 5.0, and StepLR adjusts the learning rate across epochs. During training, nn.utils.clip_grad_norm_ scales down all gradients when their norm grows too large, to prevent exploding gradients.
criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
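Two illustrative one-off checks (toy values, not part of the tutorial) of what these utilities do: clip_grad_norm_ rescales gradients in place so their total norm does not exceed the given maximum, and StepLR with a step size of 1 multiplies the learning rate by gamma after every epoch.
p = torch.nn.Parameter(torch.ones(3))
p.grad = torch.full((3,), 10.0)               # gradient norm is about 17.3
torch.nn.utils.clip_grad_norm_([p], 0.5)      # p.grad is rescaled so its norm is 0.5

toy_opt = torch.optim.SGD([p], lr=5.0)
toy_sched = torch.optim.lr_scheduler.StepLR(toy_opt, 1.0, gamma=0.95)
toy_opt.step()
toy_sched.step()
print(toy_opt.param_groups[0]['lr'])          # 4.75 == 5.0 * 0.95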
import time
def train():
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    ntokens = len(TEXT.vocab.stoi)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()
def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    ntokens = len(TEXT.vocab.stoi)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            output = eval_model(data)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)
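A short note on the weighting above (a reading of the code, not part of the original text): the last chunk may be shorter than bptt, so each chunk's mean loss is weighted by its sequence length before dividing by the total number of predicted positions, which yields a proper per-position average:

\text{avg loss} = \frac{\sum_{c} \text{len}(c) \cdot \text{loss}(c)}{\text{len(data\_source)} - 1}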
The training set is iterated over for at most epochs epochs. If the model achieves its best validation loss so far, it is saved. The learning rate is adjusted after each epoch.
best_val_loss = float("inf")
epochs = 3 # The number of epochs
best_model = None
for epoch in range(1, epochs + 1):
epoch_start_time = time.time()
train()
val_loss = evaluate(model, val_data)
print('-' * 89)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
val_loss, math.exp(val_loss)))
print('-' * 89)
if val_loss < best_val_loss:
best_val_loss = val_loss
best_model = model
scheduler.step()
Out:
| epoch 1 | 200/ 2981 batches | lr 5.00 | ms/batch 29.47 | loss 8.04 | ppl 3112.50
| epoch 1 | 400/ 2981 batches | lr 5.00 | ms/batch 28.38 | loss 6.78 | ppl 882.16
| epoch 1 | 600/ 2981 batches | lr 5.00 | ms/batch 28.38 | loss 6.38 | ppl 589.27
| epoch 1 | 800/ 2981 batches | lr 5.00 | ms/batch 28.40 | loss 6.23 | ppl 508.15
| epoch 1 | 1000/ 2981 batches | lr 5.00 | ms/batch 28.41 | loss 6.12 | ppl 454.63
| epoch 1 | 1200/ 2981 batches | lr 5.00 | ms/batch 28.40 | loss 6.09 | ppl 441.65
| epoch 1 | 1400/ 2981 batches | lr 5.00 | ms/batch 28.42 | loss 6.04 | ppl 418.77
| epoch 1 | 1600/ 2981 batches | lr 5.00 | ms/batch 28.41 | loss 6.04 | ppl 421.53
| epoch 1 | 1800/ 2981 batches | lr 5.00 | ms/batch 28.40 | loss 5.96 | ppl 387.98
| epoch 1 | 2000/ 2981 batches | lr 5.00 | ms/batch 28.41 | loss 5.96 | ppl 386.42
| epoch 1 | 2200/ 2981 batches | lr 5.00 | ms/batch 28.42 | loss 5.85 | ppl 346.77
| epoch 1 | 2400/ 2981 batches | lr 5.00 | ms/batch 28.42 | loss 5.89 | ppl 362.54
| epoch 1 | 2600/ 2981 batches | lr 5.00 | ms/batch 28.42 | loss 5.90 | ppl 364.01
| epoch 1 | 2800/ 2981 batches | lr 5.00 | ms/batch 28.43 | loss 5.80 | ppl 329.20
-----------------------------------------------------------------------------------------
| end of epoch 1 | time: 88.26s | valid loss 5.73 | valid ppl 307.01
-----------------------------------------------------------------------------------------
| epoch 2 | 200/ 2981 batches | lr 4.51 | ms/batch 28.58 | loss 5.79 | ppl 328.13
| epoch 2 | 400/ 2981 batches | lr 4.51 | ms/batch 28.42 | loss 5.77 | ppl 319.25
| epoch 2 | 600/ 2981 batches | lr 4.51 | ms/batch 28.42 | loss 5.60 | ppl 270.79
| epoch 2 | 800/ 2981 batches | lr 4.51 | ms/batch 28.44 | loss 5.63 | ppl 279.91
| epoch 2 | 1000/ 2981 batches | lr 4.51 | ms/batch 28.44 | loss 5.58 | ppl 265.99
| epoch 2 | 1200/ 2981 batches | lr 4.51 | ms/batch 28.44 | loss 5.61 | ppl 273.55
| epoch 2 | 1400/ 2981 batches | lr 4.51 | ms/batch 28.42 | loss 5.63 | ppl 277.59
| epoch 2 | 1600/ 2981 batches | lr 4.51 | ms/batch 28.45 | loss 5.66 | ppl 287.09
| epoch 2 | 1800/ 2981 batches | lr 4.51 | ms/batch 28.44 | loss 5.58 | ppl 266.00
| epoch 2 | 2000/ 2981 batches | lr 4.51 | ms/batch 28.44 | loss 5.61 | ppl 272.58
| epoch 2 | 2200/ 2981 batches | lr 4.51 | ms/batch 28.44 | loss 5.50 | ppl 244.59
| epoch 2 | 2400/ 2981 batches | lr 4.51 | ms/batch 28.44 | loss 5.57 | ppl 262.87
| epoch 2 | 2600/ 2981 batches | lr 4.51 | ms/batch 28.44 | loss 5.58 | ppl 265.65
| epoch 2 | 2800/ 2981 batches | lr 4.51 | ms/batch 28.44 | loss 5.51 | ppl 246.48
-----------------------------------------------------------------------------------------
| end of epoch 2 | time: 88.16s | valid loss 5.53 | valid ppl 253.40
-----------------------------------------------------------------------------------------
| epoch 3 | 200/ 2981 batches | lr 4.29 | ms/batch 28.58 | loss 5.55 | ppl 256.02
| epoch 3 | 400/ 2981 batches | lr 4.29 | ms/batch 28.43 | loss 5.55 | ppl 256.76
| epoch 3 | 600/ 2981 batches | lr 4.29 | ms/batch 28.46 | loss 5.36 | ppl 212.31
| epoch 3 | 800/ 2981 batches | lr 4.29 | ms/batch 28.44 | loss 5.42 | ppl 225.88
| epoch 3 | 1000/ 2981 batches | lr 4.29 | ms/batch 28.46 | loss 5.38 | ppl 217.24
| epoch 3 | 1200/ 2981 batches | lr 4.29 | ms/batch 28.45 | loss 5.41 | ppl 223.82
| epoch 3 | 1400/ 2981 batches | lr 4.29 | ms/batch 28.43 | loss 5.42 | ppl 226.87
| epoch 3 | 1600/ 2981 batches | lr 4.29 | ms/batch 28.44 | loss 5.47 | ppl 238.34
| epoch 3 | 1800/ 2981 batches | lr 4.29 | ms/batch 28.45 | loss 5.41 | ppl 223.13
| epoch 3 | 2000/ 2981 batches | lr 4.29 | ms/batch 28.45 | loss 5.44 | ppl 230.23
| epoch 3 | 2200/ 2981 batches | lr 4.29 | ms/batch 28.44 | loss 5.32 | ppl 205.28
| epoch 3 | 2400/ 2981 batches | lr 4.29 | ms/batch 28.44 | loss 5.40 | ppl 221.60
| epoch 3 | 2600/ 2981 batches | lr 4.29 | ms/batch 28.45 | loss 5.42 | ppl 224.76
| epoch 3 | 2800/ 2981 batches | lr 4.29 | ms/batch 28.44 | loss 5.34 | ppl 209.38
-----------------------------------------------------------------------------------------
| end of epoch 3 | time: 88.18s | valid loss 5.48 | valid ppl 240.75
-----------------------------------------------------------------------------------------
The saved best model is evaluated on the test dataset to check the final result.
test_loss = evaluate(best_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)
Out:
=========================================================================================
| End of training | test loss 5.39 | test ppl 219.13
=========================================================================================
Total running time of the script: (4 minutes 42.167 seconds)