3 Star 8 Fork 4

ljpassingby/电力负荷预测

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
文件
该仓库未声明开源许可证文件(LICENSE),使用请关注具体项目描述及其代码上游依赖。
克隆/下载
predict_model.py 3.24 KB
一键复制 编辑 原始数据 按行查看 历史
ljpassingby 提交于 2020-06-16 11:08 . 添加了predict_model.py
#coding:utf-8
from pyspark import SparkConf
from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import GradientBoostedTrees
from pyspark.mllib.tree import GradientBoostedTreesModel
import numpy as np
import pandas as pd
import sys
def extract_label(record):
    """Return the regression target of one parsed record.

    The target is the record's second field (index 1), converted to float.

    Args:
        record: Indexable sequence of fields for one row.

    Returns:
        float: ``record[1]`` as a float.
    """
    raw_label = record[1]
    return float(raw_label)
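# Variant: all weather features plus the building head count.
# def extract_features_dt(record):
#     return np.array(record[6:15] + record[2:4])
# Variant: all weather features, without the building head count.
# def extract_features_dt(record):
#     return np.array(record[6:14] + record[2:4])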
# Feature variant in use (per the original author's note): no weather
# features, only the building head count.
def extract_features_dt(record):
    """Build the feature array for one parsed record.

    The features are the fields from index 11 to the end of the record,
    followed by the fields at indices 2 and 3.

    Args:
        record: Sequence of fields for one row; the two slices are joined
            with ``+``, so it must support slice concatenation (e.g. a list).

    Returns:
        numpy.ndarray: ``record[11:]`` followed by ``record[2:4]``.
    """
    trailing_fields = record[11:]
    leading_fields = record[2:4]
    return np.array(trailing_fields + leading_fields)
# Helpers for evaluating the regression model.
# Absolute error of a single prediction.
def abs_error(actual, pred):
    """Return the absolute difference between a prediction and its target.

    Args:
        actual: Observed value.
        pred: Predicted value.

    Returns:
        ``|pred - actual|`` as computed by NumPy.
    """
    deviation = pred - actual
    return np.absolute(deviation)
# MAPE (mean absolute percentage error) computation.
def cal_error(true_vs_predicted):
    """Print and return the MAPE of a collection of (actual, predicted) pairs.

    Each pair contributes ``|predicted - actual| / |actual|``; the result is
    the mean of those ratios. Pairs whose actual value is zero are not
    filtered out here, so the caller must exclude them if they can occur.

    Args:
        true_vs_predicted: Object exposing ``map(...).mean()`` over
            ``(actual, predicted)`` pairs, such as the RDD returned by
            ``plot_days``.

    Returns:
        The computed MAPE (the same value that is printed).
    """
    # Divide by the magnitude of the actual value so a negative target
    # cannot flip the sign of its contribution.
    mape = true_vs_predicted.map(
        lambda t: abs_error(t[0], t[1]) / abs(t[0])
    ).mean()
    print('MAPE: %2.4f' % mape)
    # Hand the value back so callers can use it instead of receiving None.
    return mape
# Return an RDD of (actual, predicted) pairs for the requested time window.
def plot_days(dt_model_gbt, records, day1, day2):
    """Predict over the records dated in ``[day1, day2)`` and pair with truth.

    Records containing an empty-string field are discarded. The remaining
    records are kept when the timestamp parsed from their first field is
    greater than or equal to ``day1`` and strictly less than ``day2``.

    Args:
        dt_model_gbt: Trained model whose ``predict`` accepts an RDD of
            feature vectors.
        records: RDD of parsed rows (sequences of fields); field 0 must be
            parseable by ``pandas.to_datetime``.
        day1: Inclusive lower bound of the window, in any form accepted by
            ``pandas.to_datetime``.
        day2: Exclusive upper bound of the window, same form as ``day1``.

    Returns:
        An RDD of ``(actual_label, predicted_value)`` pairs (not a collected
        list).
    """
    # Parse the window bounds once here instead of once per record inside
    # the filter.
    window_start = pd.to_datetime(day1)
    window_end = pd.to_datetime(day2)

    def in_window(point):
        # Reject incomplete rows first so their timestamp is never parsed.
        if '' in point:
            return False
        # Parse each record's timestamp a single time for both bounds.
        return window_start <= pd.to_datetime(point[0]) < window_end

    data_test_data = records.filter(in_window)
    data_test = data_test_data.map(
        lambda point: LabeledPoint(extract_label(point),
                                   extract_features_dt(point)))
    preds_gbt = dt_model_gbt.predict(data_test.map(lambda p: p.features))
    actual = data_test.map(lambda p: p.label)
    # Labels and predictions derive from the same RDD, so zip pairs them
    # element by element.
    true_vs_predicted_gbt = actual.zip(preds_gbt)
    return true_vs_predicted_gbt
# Return the predictions (without ground truth) and their timestamps for the
# requested time window.
def plot_load(dt_model_gbt, records, day1, day2):
    """Predict over the records dated in ``[day1, day2)``.

    Records containing an empty-string field are discarded. The remaining
    records are kept when the timestamp parsed from their first field is
    greater than or equal to ``day1`` and strictly less than ``day2``.

    Args:
        dt_model_gbt: Trained model whose ``predict`` accepts an RDD of
            feature vectors.
        records: RDD of parsed rows (sequences of fields); field 0 must be
            parseable by ``pandas.to_datetime``.
        day1: Inclusive lower bound of the window, in any form accepted by
            ``pandas.to_datetime``.
        day2: Exclusive upper bound of the window, same form as ``day1``.

    Returns:
        tuple: ``(preds_gbt, preds_time)`` where ``preds_gbt`` is the RDD of
        predicted values and ``preds_time`` is the RDD of the matching
        records' first field (the raw timestamp). Neither is collected here.
    """
    # Parse the window bounds once here instead of once per record inside
    # the filter.
    window_start = pd.to_datetime(day1)
    window_end = pd.to_datetime(day2)

    def in_window(point):
        # Reject incomplete rows first so their timestamp is never parsed.
        if '' in point:
            return False
        # Parse each record's timestamp a single time for both bounds.
        return window_start <= pd.to_datetime(point[0]) < window_end

    data_test_data = records.filter(in_window)
    preds_time = data_test_data.map(lambda p: p[0])
    preds_gbt = dt_model_gbt.predict(
        data_test_data.map(lambda p: extract_features_dt(p)))
    return preds_gbt, preds_time
# ---------------------------------------------------------------------------
# Script body: load a saved gradient-boosted-trees model and print its load
# predictions for the window [time1, time2).
#
# Command-line arguments:
#   1. time1     - inclusive start of the prediction window
#   2. time2     - exclusive end of the prediction window
#   3. modelname - prefix used to locate both the CSV data and the saved model
# ---------------------------------------------------------------------------
if len(sys.argv) < 4:
    # Fail with a readable message instead of a bare IndexError.
    sys.exit('usage: %s <time1> <time2> <modelname>' % sys.argv[0])

time1 = sys.argv[1]
time2 = sys.argv[2]
modelname = sys.argv[3]
file_path = 'hdfs://192.168.1.5:9000/spark/building/' + modelname + '_before13_count.csv'
model_path = 'hdfs://192.168.1.5:9000/spark/building/' + modelname + '_load_model'
sc = SparkContext("spark://192.168.1.5:7077", "a predict spark app")
# Alternate cluster settings kept from the original author:
#file_path = 'hdfs://202.114.96.180:9000/user/xzxu/spark/building/huaning_before13_count.csv'
#sc = SparkContext("spark://202.114.96.180:7077","a predict spark app")
try:
    raw_data = sc.textFile(file_path)
    # Each CSV line becomes a list of string fields.
    records = raw_data.map(lambda x: x.split(','))
    records.persist()
    model = GradientBoostedTreesModel.load(sc, model_path)
    #model = GradientBoostedTreesModel.load(sc, 'hdfs://202.114.96.180:9000/user/xzxu/spark/building/huaning_load_model')
    #true_vs_predicted = plot_days(model, records, '2017-12-25', '2017-12-29')
    # NOTE: despite its name, true_vs_predicted holds only the predicted
    # values here, because plot_load returns predictions and timestamps.
    true_vs_predicted, preds_time = plot_load(model, records, time1, time2)
    true_vs_predicted.persist()
    #print (cal_error(true_vs_predicted))
    # Bring the results to the driver, rounding predictions to 2 decimals.
    predictions = [round(value, 2) for value in true_vs_predicted.collect()]
    preds_time = preds_time.collect()
finally:
    # Always release the Spark application's resources, including on failure.
    sc.stop()
dic = {"predictions":predictions,"times":preds_time}
print (dic)
Loading...
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
Java
1
https://gitee.com/ljpassingby/pmw.git
git@gitee.com:ljpassingby/pmw.git
ljpassingby
pmw
电力负荷预测
master

搜索帮助