1 Star 0 Fork 0

loic_wang/TF-recomm

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
dataio.py 1.85 KB
一键复制 编辑 原始数据 按行查看 历史
from __future__ import absolute_import, division, print_function
import numpy as np
import pandas as pd
def read_process(filname, sep="\t"):
    """Load a delimited ratings file into a typed DataFrame.

    The file is read without a header row and its columns are named
    ``user``, ``item``, ``rate`` and ``st``.  The ``user`` and ``item``
    columns are decremented by one (presumably to turn 1-based ids into
    0-based indices -- confirm against the data set) and stored as
    ``int32``; ``rate`` is stored as ``float32``; ``st`` is left as parsed.

    Args:
        filname: Path or file-like object accepted by ``pandas.read_csv``.
        sep: Field separator passed to ``pandas.read_csv``.

    Returns:
        The resulting ``pandas.DataFrame``.
    """
    frame = pd.read_csv(
        filname,
        sep=sep,
        header=None,
        names=["user", "item", "rate", "st"],
        engine='python',
    )
    # Shift each id column down by one, then narrow it to int32.
    for id_col in ("user", "item"):
        frame[id_col] = (frame[id_col] - 1).astype(np.int32)
    frame["rate"] = frame["rate"].astype(np.float32)
    return frame
class ShuffleIterator(object):
    """
    Endless iterator that yields randomly sampled batches.

    ``inputs`` is a sequence of equally long columns.  The columns are
    stacked into one 2-D array with one row per sample, so all columns
    share a single NumPy dtype after stacking.  Every call to ``next``
    draws ``batch_size`` row indices uniformly at random (repeats are
    possible) and returns the selected rows split back into columns.
    The iterator never raises ``StopIteration``.
    """

    def __init__(self, inputs, batch_size=10):
        self.batch_size = batch_size
        self.num_cols = len(inputs)
        # Number of samples, taken from the first column.
        self.len = len(inputs[0])
        # Stack the columns as rows, then flip so each row is one sample.
        stacked = np.vstack([np.array(column) for column in inputs])
        self.inputs = stacked.T

    def __len__(self):
        """Return the number of samples (rows), not the number of batches."""
        return self.len

    def __iter__(self):
        return self

    def __next__(self):
        # Python 3 iterator protocol; delegates to the Python 2 style name.
        return self.next()

    def next(self):
        """Return one random batch as a list of ``num_cols`` 1-D arrays."""
        picks = np.random.randint(0, self.len, (self.batch_size,))
        batch = self.inputs[picks, :]
        return [batch[:, col] for col in range(self.num_cols)]
class OneEpochIterator(ShuffleIterator):
    """
    Sequentially generate one-epoch batches, typically for test data.

    The row indices ``0 .. len-1`` are partitioned once, in order, into
    ``ceil(len / batch_size)`` contiguous groups via ``np.array_split``,
    so groups are of equal or near-equal size and never larger than
    ``batch_size``.  A non-positive ``batch_size`` puts every row into a
    single batch.  Iteration yields each group once and then raises
    ``StopIteration``; the cursor is reset at that point so the same
    object can be iterated again for another epoch.
    """

    def __init__(self, inputs, batch_size=10):
        super(OneEpochIterator, self).__init__(inputs, batch_size=batch_size)
        if batch_size > 0:
            # Integer ceiling division: independent of true-division
            # semantics and of float rounding.
            num_batches = int(-(-self.len // batch_size))
            if num_batches > 0:
                self.idx_group = np.array_split(np.arange(self.len), num_batches)
            else:
                # Empty data: np.array_split rejects zero sections, so
                # expose an epoch with no batches instead of crashing.
                self.idx_group = []
        else:
            self.idx_group = [np.arange(self.len)]
        # Position of the next group to hand out.
        self.group_id = 0

    def next(self):
        """Return the next sequential batch as a list of 1-D column arrays.

        Raises:
            StopIteration: once every group has been returned; the cursor
                is rewound first so a later loop starts a fresh epoch.
        """
        if self.group_id >= len(self.idx_group):
            self.group_id = 0
            raise StopIteration
        out = self.inputs[self.idx_group[self.group_id], :]
        self.group_id += 1
        return [out[:, i] for i in range(self.num_cols)]
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/loic_wang/TF-recomm.git
git@gitee.com:loic_wang/TF-recomm.git
loic_wang
TF-recomm
TF-recomm
master

搜索帮助