代码拉取完成,页面将自动刷新
from matplotlib import pyplot as plt
from matplotlib.pyplot import figure
from tqdm import tqdm
from torch import nn
from torch.optim.lr_scheduler import StepLR
from LSTM import LSTM, CNN_LSTM, Seq2Seq
from config import *
from utils import process
def _build_model(name):
    """Instantiate the network selected by ``name``.

    Hyper-parameters (input_size, hidden_size, num_layers, pred_size,
    batch_size, device) are read from the module-level names that
    ``from config import *`` brings into scope.

    Args:
        name: one of ``'LSTM'``, ``'CNN_LSTM'`` or ``'Seq2Seq'``.

    Returns:
        The freshly constructed model (not yet moved to ``device``).

    Raises:
        ValueError: for any other ``name``, instead of leaving ``model``
            unbound and failing later with a NameError.
    """
    if name == 'LSTM':
        return LSTM(input_size, hidden_size, num_layers, pred_size, batch_size, device)
    if name == 'CNN_LSTM':
        return CNN_LSTM(input_size, hidden_size, num_layers, pred_size)
    if name == 'Seq2Seq':
        return Seq2Seq(input_size, hidden_size, num_layers, pred_size, batch_size, device)
    raise ValueError(f"unknown model_name {name!r}; expected 'LSTM', 'CNN_LSTM' or 'Seq2Seq'")


def _train_one_epoch(model, loader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    ``loader`` is iterated as ``(seq, label)`` pairs; both are moved to
    ``device`` before the forward pass. Returns ``nan`` if ``loader`` yields
    no batches.
    """
    model.train()
    batch_losses = []
    for seq, label in loader:
        seq = seq.to(device)
        label = label.to(device)
        loss = loss_fn(model(seq), label)
        batch_losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    if not batch_losses:
        return float('nan')
    return sum(batch_losses) / len(batch_losses)


def _evaluate(model, loader, loss_fn, device):
    """Return the loss summed over every ``(seq, label)`` batch of ``loader``.

    This sum (not the per-batch mean) is what gets printed, plotted and
    compared for checkpointing, so its scale depends on the number of
    validation batches.
    """
    model.eval()
    total = 0
    with torch.no_grad():  # no gradient tracking while validating
        for seq, label in loader:
            outputs = model(seq.to(device))
            total += loss_fn(outputs, label.to(device)).item()
    return total


def _plot_loss_curve(val_losses, best_epoch, best_loss, run_name, out_file):
    """Save a plot of the per-epoch validation loss to ``out_file``.

    When ``best_epoch`` is not None the minimum is marked with a dot and its
    value is written below it. The figure is closed after saving.
    """
    # SimHei is set so the Chinese legend label can be drawn; unicode_minus
    # is disabled so negative tick labels use a glyph that font provides.
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    fig = figure(figsize=(12.8, 9.6))
    plt.plot(val_losses, color='red', label='损失曲线')
    if best_epoch is not None:
        plt.scatter(best_epoch, best_loss, color='blue', s=50)
        # Label sits halfway between the marker and the x-axis.
        plt.text(best_epoch, best_loss * 0.5, '%.6f' % best_loss, ha='center', va='bottom', size=20)
    plt.title(f'LOSS-{run_name}', fontsize=20)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    plt.legend(fontsize=20)
    plt.ylim((0, max(val_losses)))
    plt.savefig(out_file)
    plt.close(fig)


if __name__ == '__main__':
    print(f'{r_name} training...')
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(f'{result_path}/train', exist_ok=True)
    Dtr = process(train_data, batch_size, True, interval, pred_size, output_n)
    DVa = process(val_data, batch_size, True, interval, pred_size, output_n)

    model = _build_model(model_name).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
    loss_fn = nn.MSELoss()
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

    # float('inf') rather than np.Inf, which NumPy 2.0 removed.
    min_val_loss = float('inf')
    m_epoch = None  # stays None if no epoch ever improves (NaN losses, zero epochs)
    loss_list = []
    progress = tqdm(range(max_epochs))
    for epoch in progress:
        train_loss = _train_one_epoch(model, Dtr, loss_fn, optimizer, device)
        total_val_loss = _evaluate(model, DVa, loss_fn, device)
        # StepLR only decays the learning rate when step() is called each epoch;
        # without this call the configured step_size/gamma had no effect.
        scheduler.step()
        loss_list.append(total_val_loss)
        progress.set_postfix(train_loss=train_loss, val_loss=total_val_loss)
        if total_val_loss < min_val_loss:
            min_val_loss = total_val_loss
            m_epoch = epoch
            torch.save(model, f"{model_path}/model-{r_name}.pth")  # checkpoint the best model so far
    print()
    if m_epoch is None:
        print('No epoch reported a comparable validation loss (NaN or no epochs run); no model was saved.')
    else:
        print(f'本次训练损失最小的epoch为{m_epoch},最小损失为{min_val_loss}')
    if loss_list:
        _plot_loss_curve(loss_list, m_epoch, min_val_loss, r_name, f'{result_path}/train/LOSS-{r_name}.png')
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。