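# NOTE: web-page residue from the code host, not part of the program:
# "Code fetch complete; the page will refresh automatically."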
from __future__ import absolute_import, division, print_function
import numpy as np
import pandas as pd
def read_process(filname, sep="\t"):
    """Read a header-less, delimited ratings file into a DataFrame.

    The file is parsed into four columns named ``user``, ``item``, ``rate``
    and ``st``. One is subtracted from ``user`` and ``item`` (presumably to
    turn 1-based ids into 0-based indices -- confirm against the data set)
    and both are stored as ``int32``; ``rate`` is stored as ``float32``.
    The ``st`` column is left exactly as pandas parsed it.

    Args:
        filname: Path or file-like object accepted by ``pandas.read_csv``.
        sep: Field separator passed to ``pandas.read_csv``.

    Returns:
        pandas.DataFrame with columns ``user``, ``item``, ``rate``, ``st``.
    """
    col_names = ["user", "item", "rate", "st"]
    # The Python engine is kept on purpose: pandas needs it for separators
    # longer than one character (e.g. "::"), which callers may pass via `sep`.
    df = pd.read_csv(filname, sep=sep, header=None, names=col_names, engine="python")
    for col in ("user", "item"):
        # Shift the ids down by one, then store them compactly.
        df[col] = (df[col] - 1).astype(np.int32)
    df["rate"] = df["rate"].astype(np.float32)
    return df
class ShuffleIterator(object):
    """Endless iterator that yields randomly sampled batches.

    ``inputs`` is an indexable collection of equally long columns. The
    columns are stacked into a single 2-D array with one row per sample, so
    they are promoted to a common NumPy dtype. Each call to ``next`` draws
    ``batch_size`` row indices with ``np.random.randint`` (independently, so
    a row may appear more than once in a batch) and returns the selected
    rows split back into one 1-D array per column.

    The iterator never raises ``StopIteration`` on its own.
    """

    def __init__(self, inputs, batch_size=10):
        """Stack the columns of ``inputs`` and remember the batch size.

        Args:
            inputs: Collection of columns, addressed as ``inputs[0]`` ..
                ``inputs[len(inputs) - 1]``; all columns must have equal length.
            batch_size: Number of rows returned by each call to ``next``.
        """
        self.inputs = inputs
        self.batch_size = batch_size
        self.num_cols = len(self.inputs)
        self.len = len(self.inputs[0])
        # Columns are fetched by integer position (not by iterating) so any
        # container supporting len() and integer indexing keeps working.
        columns = [np.array(self.inputs[i]) for i in range(self.num_cols)]
        # vstack gives (num_cols, len); transpose to (len, num_cols) so that
        # a row index selects one complete sample.
        self.inputs = np.transpose(np.vstack(columns))

    def __len__(self):
        """Return the number of samples (rows), not the number of batches."""
        return self.len

    def __iter__(self):
        return self

    def __next__(self):
        # Python 3 iterator protocol; delegates to the Python 2 style name so
        # subclasses only need to override ``next``.
        return self.next()

    def next(self):
        """Return one random batch as a list of ``num_cols`` 1-D arrays."""
        ids = np.random.randint(0, self.len, (self.batch_size,))
        out = self.inputs[ids, :]
        return [out[:, i] for i in range(self.num_cols)]
class OneEpochIterator(ShuffleIterator):
    """Yield the data once, in order, in consecutive batches.

    Typically used for test data. The row indices are divided with
    ``np.array_split`` into ``ceil(len / batch_size)`` groups, so the groups
    are nearly equal in size and may be smaller than ``batch_size`` (for
    example 10 rows with ``batch_size=4`` gives groups of 4, 3 and 3 rows).
    A non-positive ``batch_size`` puts all rows into a single batch.

    After the last batch ``StopIteration`` is raised and the position is
    reset, so the same object can be iterated again for another pass.
    """

    def __init__(self, inputs, batch_size=10):
        """Prepare the per-batch index groups.

        Args:
            inputs: Collection of equally long columns (see ShuffleIterator).
            batch_size: Upper bound on rows per batch; ``<= 0`` means one
                batch containing every row.
        """
        super(OneEpochIterator, self).__init__(inputs, batch_size=batch_size)
        if batch_size > 0:
            # float() keeps this a true division even without the module's
            # `from __future__ import division`; int() hands array_split an
            # explicit integer section count instead of a NumPy float.
            num_batches = int(np.ceil(self.len / float(batch_size)))
            if num_batches > 0:
                self.idx_group = np.array_split(np.arange(self.len), num_batches)
            else:
                # Empty data: np.array_split raises ValueError for zero
                # sections, so produce no batches instead of crashing here.
                self.idx_group = []
        else:
            self.idx_group = [np.arange(self.len)]
        self.group_id = 0

    def next(self):
        """Return the next sequential batch as a list of 1-D column arrays.

        Raises:
            StopIteration: When every group has been returned; the cursor is
                rewound first so a later iteration starts from the beginning.
        """
        if self.group_id >= len(self.idx_group):
            self.group_id = 0
            raise StopIteration
        out = self.inputs[self.idx_group[self.group_id], :]
        self.group_id += 1
        return [out[:, i] for i in range(self.num_cols)]
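# NOTE: web-page residue from the code host, not part of the program:
# "This section may contain content unsuitable for display and is not shown on the page. You can review and modify it using the relevant editing functions."
# "If you confirm the content does not involve inappropriate language / pure advertising / violence / vulgar or pornographic material / infringement / piracy / false or worthless content, or content that violates national laws and regulations, you may click submit to appeal and we will handle it as soon as possible."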