# [Code-hosting page banner, not part of the program] Code pull complete; the page will refresh automatically.
import time
from rdData import readData
from kFord import kFord
import numpy as np
def _assemble_folds(split, K):
    """Build the K train/validation folds and per-user row offsets.

    ``split`` maps each column name to K pieces, addressed as
    ``split[column][j]`` for ``j`` in ``range(K)``.  It must contain a
    ``'userId'`` column holding integer ids (the ids are used to size the
    offset table).

    Returns ``(train, valid, user_index)``, each a dict keyed by fold number,
    or ``None`` when some training fold lacks at least one user id in
    ``1..n_users``, where ``n_users`` is the largest id found in any piece.
    """
    columns = list(split.keys())

    # The expected user range comes from *all* pieces together.  Using only
    # the current fold's own maximum would let a fold that lost its
    # highest-numbered users pass the check unnoticed.
    pooled_ids = np.concatenate([split['userId'][j] for j in range(K)], axis=0)
    n_users = int(np.max(pooled_ids))
    # Ids 1..n_users plus one sentinel past the end, so consecutive offsets
    # bracket every user's rows.
    probe_ids = np.arange(1, n_users + 2)

    train, valid, user_index = {}, {}, {}
    for i in range(K):
        # Piece i is held out for validation; the rest is training data.
        valid[i] = {name: split[name][i] for name in columns}
        merged = {
            name: np.concatenate(
                [split[name][j] for j in range(K) if j != i], axis=0
            )
            for name in columns
        }

        # Re-order every column by user id.  Rows belonging to the same user
        # keep whatever relative order the sort happens to produce.
        # NOTE: a dense user x item rating matrix is deliberately not built
        # here; it would be far too costly for data this sparse.
        order = np.argsort(merged['userId'])
        train[i] = {name: values[order] for name, values in merged.items()}
        sorted_ids = train[i]['userId']

        # starts[k] is the first row whose id is >= k + 1.
        starts = np.searchsorted(sorted_ids, probe_ids, side='left')
        # A user with no rows has the same start offset as the next user.
        if np.any(starts[1:] <= starts[:-1]):
            return None

        # offsets[u] is the first row of user u; offsets[-1] is the row
        # count, so user u occupies rows offsets[u]:offsets[u + 1].
        # Slot 0 is unused and stays 0.
        offsets = np.zeros(n_users + 2, dtype='int64')
        offsets[1:] = starts
        offsets[-1] = sorted_ids.shape[0]
        user_index[i] = offsets

    return train, valid, user_index


def initial(K, *, data_path='ml-latest-small/ratings.csv', timeout=60,
            read_fn=None, split_fn=None):
    """Load the ratings, split them into K folds and index the rows by user.

    The data is re-read and re-split with seeds 0, 1, 2, ... until every
    training fold contains every user id, or until ``timeout`` seconds have
    elapsed.

    Args:
        K: Number of folds; must be at least 2 so that every training fold
            is non-empty.
        data_path: Path handed to the reader.
        timeout: Seconds to keep retrying.  It is checked only after a failed
            attempt, so a single slow attempt is never interrupted.
        read_fn: Callable ``(path, seed) -> (train, test)``.  Defaults to the
            module's ``readData``.
        split_fn: Callable ``(K, train, seed) -> {column: K pieces}``.
            Defaults to the module's ``kFord``.

    Returns:
        ``(train_data, valid_data, test_data, user_index, 0)`` on success.
        ``train_data[i]`` and ``valid_data[i]`` are column dicts for fold
        ``i``; ``train_data[i]`` is sorted by ``'userId'``.
        ``user_index[i][u]`` is the first training row of user ``u`` and the
        last entry is the total row count.  ``test_data`` is passed through
        from the reader unchanged.
        ``([], [], [], [], -1)`` if no usable split was found in time.

    Raises:
        ValueError: If ``K`` is smaller than 2.
    """
    if K < 2:
        raise ValueError('K must be at least 2, got %r' % (K,))

    # Resolve the defaults lazily so the module-level helpers are only
    # looked up when the caller did not supply replacements.
    read_fn = readData if read_fn is None else read_fn
    split_fn = kFord if split_fn is None else split_fn

    seed = 0
    # Monotonic clock: immune to system clock adjustments.
    started = time.monotonic()
    while True:
        raw_train, test_data = read_fn(data_path, seed)
        folds = _assemble_folds(split_fn(K, raw_train, seed), K)
        if folds is not None:
            train_data, valid_data, user_index = folds
            return train_data, valid_data, test_data, user_index, 0

        # Some training fold is missing a user: try the next seed.
        seed += 1
        if time.monotonic() - started > timeout:
            print('error,cant find fit split data')
            return ([], [], [], [], -1)
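# [Code-hosting page notice, not part of the program] This content may be unsuitable for display, so the page does not show it. You can review and revise it using the editing features.
# If you confirm that the content involves no inappropriate language, pure advertising, violence, vulgar or pornographic material, infringement, piracy, false or worthless content, or anything that violates applicable national laws and regulations, you may submit an appeal and it will be handled as soon as possible.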