代码拉取完成,页面将自动刷新
import os
from config import *
from utils import process, t_plot, write_log
import torch
txt_list = []
all_data = data[:, output_n + 1]
Dte = process(test_data, 1, False, interval, pred_size, output_n)
test_num = int(int(len(data) * (1 - test_rate)))
train_data = data[:int(len(data) * (1 - val_rate - test_rate))]
val_data = data[int(len(data) * (1 - val_rate - test_rate)):test_num]
test_data = data[test_num:]
model = torch.load(f'{model_path}/model-{r_name}.pth')
write_log(f'model: {r_name}', txt_list)
y_re, y_real = t_plot(result_path, r_name, model, test_num, t_date, Dte, all_data, device, pred_size,
interval, txt_list)
n = np.where(y_real == 0)
print(n)
MSE = np.sum((y_re - y_real) ** 2) / len(y_real)
RMSE = np.sqrt(MSE)
MAE = np.sum(np.abs(y_re - y_real)) / len(y_real)
MAPE = np.sum(np.abs((y_re - y_real) / y_real)) / len(y_real) * 100
S = np.abs(y_re - y_real) / ((np.abs(y_re) + np.abs(y_real)) / 2)
SMAPE = np.sum(S) / len(y_real) * 100
av_y = sum(y_real) / len(y_real)
R2 = 1 - np.sum((y_re - y_real) ** 2) / np.sum((av_y - y_real) ** 2)
write_log(f'MSE: {MSE}', txt_list)
write_log(f'RMSE: {RMSE}', txt_list)
write_log(f'MAE: {MAE}', txt_list)
write_log(f'R2: {R2}', txt_list)
write_log(f'MAPE: {MAPE} %', txt_list)
write_log(f'SMAPE: {SMAPE} %', txt_list)
os.makedirs(f'{result_path}/log', exist_ok=True)
content = ''
for txt in txt_list:
content += txt
with open(f'{result_path}/log/log-{r_name}.txt', 'w+', encoding='utf8') as f:
f.write(content)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。