
guox66 / wind_prediction

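This repository does not declare an open-source license file (LICENSE); before using it, check the project description and the upstream dependencies of its code.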
t_model.py 1.51 KB
Xu Guo committed on 2024-02-28 20:59: Add files via upload
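import os

import numpy as np
import torch

from config import *  # data, split rates, paths, r_name, device and the other settings used below
from utils import process, t_plot, write_log

txt_list = []  # log lines, written to disk at the end
all_data = data[:, output_n + 1]
# Chronological split: train | validation | test (no shuffling)
test_num = int(len(data) * (1 - test_rate))
train_data = data[:int(len(data) * (1 - val_rate - test_rate))]
val_data = data[int(len(data) * (1 - val_rate - test_rate)):test_num]
test_data = data[test_num:]
# Prepare the test set; test_data must already be defined at this point
Dte = process(test_data, 1, False, interval, pred_size, output_n)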
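# The .pth file holds the whole pickled model, so PyTorch >= 2.6 needs weights_only=False
model = torch.load(f'{model_path}/model-{r_name}.pth', weights_only=False)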
write_log(f'model: {r_name}', txt_list)
y_re, y_real = t_plot(result_path, r_name, model, test_num, t_date, Dte, all_data, device, pred_size,
                      interval, txt_list)
# Report indices where the ground truth is zero: MAPE below divides by y_real
n = np.where(y_real == 0)
print(n)
# Test-set error metrics (y_re: predictions, y_real: ground truth)
MSE = np.sum((y_re - y_real) ** 2) / len(y_real)
RMSE = np.sqrt(MSE)
MAE = np.sum(np.abs(y_re - y_real)) / len(y_real)
# MAPE is undefined (inf/nan) wherever y_real == 0
MAPE = np.sum(np.abs((y_re - y_real) / y_real)) / len(y_real) * 100
S = np.abs(y_re - y_real) / ((np.abs(y_re) + np.abs(y_real)) / 2)
SMAPE = np.sum(S) / len(y_real) * 100
av_y = np.mean(y_real, axis=0)  # equivalent to sum(y_real) / len(y_real)
R2 = 1 - np.sum((y_re - y_real) ** 2) / np.sum((av_y - y_real) ** 2)
write_log(f'MSE: {MSE}', txt_list)
write_log(f'RMSE: {RMSE}', txt_list)
write_log(f'MAE: {MAE}', txt_list)
write_log(f'R2: {R2}', txt_list)
write_log(f'MAPE: {MAPE} %', txt_list)
write_log(f'SMAPE: {SMAPE} %', txt_list)
os.makedirs(f'{result_path}/log', exist_ok=True)
content = ''.join(txt_list)
with open(f'{result_path}/log/log-{r_name}.txt', 'w', encoding='utf8') as f:
    f.write(content)
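The script splits `data` chronologically: the first `1 - val_rate - test_rate` fraction is for training, the next `val_rate` fraction for validation, and everything from index `test_num` onward is held out for testing. A minimal sketch of the same index arithmetic follows; `split_bounds` is a hypothetical helper and the rates are made-up values, since the real ones live in `config`, which is not shown.

```python
def split_bounds(n, val_rate, test_rate):
    """Return (train_end, test_start) for a chronological train/val/test split.

    Hypothetical helper mirroring the slicing in t_model.py:
    train = [0, train_end), val = [train_end, test_start), test = [test_start, n).
    int() truncates toward zero, as in the original script.
    """
    train_end = int(n * (1 - val_rate - test_rate))
    test_start = int(n * (1 - test_rate))
    return train_end, test_start


if __name__ == "__main__":
    series = list(range(1000))  # stand-in for `data`; the real array comes from config
    train_end, test_start = split_bounds(len(series), val_rate=0.1, test_rate=0.2)
    train, val, test = series[:train_end], series[train_end:test_start], series[test_start:]
    print(train_end, test_start, len(train), len(val), len(test))  # 700 800 700 100 200
```

Because nothing is shuffled, the test window is always the most recent stretch of the series, which keeps future values out of training, as is usual for time-series forecasting.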
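The metric block in the script can also be packaged as a standalone function for reuse. The sketch below is an assumption, not the author's code: `regression_metrics` is a hypothetical name, the inputs are assumed to be 1-D arrays, and MAPE is averaged only over nonzero targets instead of dividing by zero. That is a deliberate change from the script, which only prints the zero indices via `np.where`.

```python
import numpy as np


def regression_metrics(y_pred, y_true):
    """Compute MSE, RMSE, MAE, MAPE, SMAPE and R2 for 1-D arrays.

    Hypothetical helper. Unlike t_model.py, MAPE is averaged only over
    points where y_true != 0, so zero targets do not produce inf/nan.
    """
    y_pred = np.asarray(y_pred, dtype=float)
    y_true = np.asarray(y_true, dtype=float)
    err = y_pred - y_true

    mse = np.mean(err ** 2)
    mae = np.mean(np.abs(err))

    nonzero = y_true != 0  # mask out zero targets for MAPE
    if nonzero.any():
        mape = np.mean(np.abs(err[nonzero] / y_true[nonzero])) * 100
    else:
        mape = float("nan")

    denom = (np.abs(y_pred) + np.abs(y_true)) / 2
    smape = np.mean(np.abs(err) / denom) * 100  # still undefined if both values are 0

    ss_res = np.sum(err ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    r2 = 1 - ss_res / ss_tot

    return {"MSE": mse, "RMSE": np.sqrt(mse), "MAE": mae,
            "MAPE": mape, "SMAPE": smape, "R2": r2}


if __name__ == "__main__":
    metrics = regression_metrics([2.0, 2.0, 2.0], [1.0, 2.0, 4.0])
    for name, value in metrics.items():
        print(f"{name}: {value:.4f}")
```

Returning a dict keeps the names aligned with the `write_log` lines in the script, so each entry could be logged with the same `f'{name}: {value}'` pattern.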
https://gitee.com/guox66/wind_prediction.git
git@gitee.com:guox66/wind_prediction.git
