#!/usr/bin/python
# -*- coding: UTF-8 -*-
import math

import pandas as pd
import torch
from torch import nn
from d2l import torch as d2l

from DotProductAttention import *
from EncoderDecoder import *
from MultiHeadAttention import *
from PositionalEncoding import *

def sequence_mask(X, valid_len, value=0):
    """Mask irrelevant entries in sequences."""
    maxlen = X.size(1)
    mask = torch.arange((maxlen), dtype=torch.float32,
                        device=X.device)[None, :] < valid_len[:, None]
    X[~mask] = value
    return X
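
# Illustrative sanity check (not part of the original file; the helper name
# `_demo_sequence_mask` is hypothetical): entries at or beyond each row's
# valid length are overwritten with the fill value.
def _demo_sequence_mask():
    X = torch.tensor([[1, 2, 3], [4, 5, 6]])
    # Prints tensor([[1, 0, 0], [4, 5, 0]])
    print(sequence_mask(X, torch.tensor([1, 2])))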

#@save
class MaskedSoftmaxCELoss(nn.CrossEntropyLoss):
    """The softmax cross-entropy loss with masks."""
    # pred shape: (batch_size, num_steps, vocab_size)
    # label shape: (batch_size, num_steps)
    # valid_len shape: (batch_size,)
    def forward(self, pred, label, valid_len):
        weights = torch.ones_like(label)
        weights = sequence_mask(weights, valid_len)
        self.reduction = 'none'
        unweighted_loss = super(MaskedSoftmaxCELoss, self).forward(
            pred.permute(0, 2, 1), label)
        weighted_loss = (unweighted_loss * weights).mean(dim=1)
        return weighted_loss
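
# Illustrative sanity check (not part of the original file; the helper name
# `_demo_masked_ce_loss` is hypothetical): with identical uniform logits, the
# per-sequence loss shrinks as fewer steps count toward the mean over
# num_steps, since masked positions contribute zero.
def _demo_masked_ce_loss():
    loss = MaskedSoftmaxCELoss()
    pred = torch.ones(3, 4, 10)                 # (batch_size, num_steps, vocab_size)
    label = torch.ones(3, 4, dtype=torch.long)  # (batch_size, num_steps)
    # Uniform logits give log(10) ~= 2.3026 per step; masking 2 of 4 steps
    # halves the mean, so this prints roughly tensor([2.3026, 1.1513, 0.5757]).
    print(loss(pred, label, torch.tensor([4, 2, 1])))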

def train_seq2seq(net, data_iter, lr, num_epochs, tgt_vocab, device):
    """Train a sequence-to-sequence model."""
    def xavier_init_weights(m):
        if type(m) == nn.Linear:
            nn.init.xavier_uniform_(m.weight)
        if type(m) == nn.GRU:
            for param in m._flat_weights_names:
                if "weight" in param:
                    nn.init.xavier_uniform_(m._parameters[param])

    net.apply(xavier_init_weights)
    net.to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    loss = MaskedSoftmaxCELoss()
    net.train()
    animator = d2l.Animator(xlabel='epoch', ylabel='loss',
                            xlim=[10, num_epochs])
    for epoch in range(num_epochs):
        timer = d2l.Timer()
        metric = d2l.Accumulator(2)  # Sum of training loss, no. of tokens
        for batch in data_iter:
            optimizer.zero_grad()
            X, X_valid_len, Y, Y_valid_len = [x.to(device) for x in batch]
            bos = torch.tensor([tgt_vocab['<bos>']] * Y.shape[0],
                               device=device).reshape(-1, 1)
            dec_input = torch.cat([bos, Y[:, :-1]], 1)  # Teacher forcing
            Y_hat, _ = net(X, dec_input, X_valid_len)
            l = loss(Y_hat, Y, Y_valid_len)
            l.sum().backward()  # Make the loss scalar for backpropagation
            d2l.grad_clipping(net, 1)
            num_tokens = Y_valid_len.sum()
            optimizer.step()
            with torch.no_grad():
                metric.add(l.sum(), num_tokens)
        if (epoch + 1) % 10 == 0:
            animator.add(epoch + 1, (metric[0] / metric[1],))
    print(f'loss {metric[0] / metric[1]:.3f}, {metric[1] / timer.stop():.1f} '
          f'tokens/sec on {str(device)}')
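
# Hypothetical usage sketch (not part of the original file): `_ToyS2S` is a
# minimal GRU encoder-decoder stand-in, not the attention-based model the
# imports above suggest; d2l.load_data_nmt and d2l.try_gpu are standard d2l
# helpers, and the hyperparameters are illustrative only.
class _ToyS2S(nn.Module):
    """Embed source and target, run a GRU over each, and project the decoder
    states to target-vocabulary logits."""
    def __init__(self, src_vocab_size, tgt_vocab_size, embed_size=32,
                 num_hiddens=32):
        super().__init__()
        self.src_emb = nn.Embedding(src_vocab_size, embed_size)
        self.tgt_emb = nn.Embedding(tgt_vocab_size, embed_size)
        self.encoder = nn.GRU(embed_size, num_hiddens, batch_first=True)
        self.decoder = nn.GRU(embed_size, num_hiddens, batch_first=True)
        self.dense = nn.Linear(num_hiddens, tgt_vocab_size)

    def forward(self, enc_X, dec_X, *args):
        # Use the encoder's final hidden state to initialize the decoder.
        _, state = self.encoder(self.src_emb(enc_X))
        out, state = self.decoder(self.tgt_emb(dec_X), state)
        # Returns (batch_size, num_steps, tgt_vocab_size) logits, as
        # train_seq2seq expects from net(X, dec_input, X_valid_len).
        return self.dense(out), state


def _demo_train_seq2seq():
    batch_size, num_steps = 64, 10
    train_iter, src_vocab, tgt_vocab = d2l.load_data_nmt(batch_size, num_steps)
    net = _ToyS2S(len(src_vocab), len(tgt_vocab))
    train_seq2seq(net, train_iter, lr=0.005, num_epochs=50,
                  tgt_vocab=tgt_vocab, device=d2l.try_gpu())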