1 Star 0 Fork 0

loic_wang/TF-recomm

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
ops.py 1.88 KB
一键复制 编辑 原始数据 按行查看 历史
Guocong Song 提交于 2017-02-07 10:22 . Fix for Tensorflow v1.0 APIs
import tensorflow as tf
def inference_svd(user_batch, item_batch, user_num, item_num, dim=5, device="/cpu:0"):
    """Build the graph for a biased, SVD-style rating predictor.

    Each prediction is the dot product of a user embedding row and an item
    embedding row, plus a global bias, a per-user bias and a per-item bias.

    Args:
        user_batch: User indices handed to ``tf.nn.embedding_lookup``
            (presumably a 1-D integer tensor -- confirm against the caller).
        item_batch: Item indices handed to ``tf.nn.embedding_lookup``
            (presumably aligned element-wise with ``user_batch``).
        user_num: Number of rows in the user bias and embedding tables.
        item_num: Number of rows in the item bias and embedding tables.
        dim: Width of every embedding row.
        device: Device string for the arithmetic ops only; the variables and
            the lookups are always pinned to "/cpu:0".

    Returns:
        A ``(infer, regularizer)`` pair: the prediction tensor (op name
        "svd_inference") and the sum of the L2 losses of the looked-up user
        and item embeddings (op name "svd_regularizer").
    """
    # Variables and lookups stay on the CPU regardless of ``device``.
    # NOTE: ops are created in a fixed order so the resulting graph (variable
    # ordering, auto-generated op names) is stable; do not reorder.
    with tf.device("/cpu:0"):
        global_bias = tf.get_variable("bias_global", shape=[])
        user_bias_table = tf.get_variable("embd_bias_user", shape=[user_num])
        item_bias_table = tf.get_variable("embd_bias_item", shape=[item_num])
        user_bias = tf.nn.embedding_lookup(
            user_bias_table, user_batch, name="bias_user")
        item_bias = tf.nn.embedding_lookup(
            item_bias_table, item_batch, name="bias_item")

        user_table = tf.get_variable(
            "embd_user",
            shape=[user_num, dim],
            initializer=tf.truncated_normal_initializer(stddev=0.02),
        )
        item_table = tf.get_variable(
            "embd_item",
            shape=[item_num, dim],
            initializer=tf.truncated_normal_initializer(stddev=0.02),
        )
        user_vecs = tf.nn.embedding_lookup(
            user_table, user_batch, name="embedding_user")
        item_vecs = tf.nn.embedding_lookup(
            item_table, item_batch, name="embedding_item")

    with tf.device(device):
        # Row-wise dot product of the two embedding batches.
        dot = tf.reduce_sum(tf.multiply(user_vecs, item_vecs), 1)
        with_biases = tf.add(tf.add(dot, global_bias), user_bias)
        prediction = tf.add(with_biases, item_bias, name="svd_inference")

        # Only the looked-up embedding rows are penalised, not the biases.
        user_l2 = tf.nn.l2_loss(user_vecs)
        item_l2 = tf.nn.l2_loss(item_vecs)
        l2_term = tf.add(user_l2, item_l2, name="svd_regularizer")

    return prediction, l2_term
def optimization(infer, regularizer, rate_batch, learning_rate=0.001, reg=0.1, device="/cpu:0"):
    """Build the training cost and the Adam update op for the predictor.

    The cost is ``l2_loss(infer - rate_batch) + reg * regularizer``.

    Args:
        infer: Predicted ratings tensor.
        regularizer: Scalar regularization term added to the cost.
        rate_batch: Target ratings, subtracted element-wise from ``infer``.
        learning_rate: Learning rate passed to ``tf.train.AdamOptimizer``.
        reg: Weight applied to ``regularizer``.
        device: Device string on which the cost and train ops are placed.

    Returns:
        A ``(cost, train_op)`` pair; running ``train_op`` applies one Adam
        step and increments the graph's global step.

    Raises:
        AssertionError: If no global step tensor exists in the current graph.
    """
    global_step = tf.train.get_global_step()
    # Explicit check instead of a bare ``assert``: asserts are stripped under
    # ``python -O``, after which a missing global step would silently produce
    # a train op that never advances a step counter.
    if global_step is None:
        raise AssertionError(
            "No global step found in the current graph; create one before "
            "calling optimization()."
        )
    with tf.device(device):
        cost_l2 = tf.nn.l2_loss(tf.subtract(infer, rate_batch))
        penalty = tf.constant(reg, dtype=tf.float32, shape=[], name="l2")
        cost = tf.add(cost_l2, tf.multiply(regularizer, penalty))
        train_op = tf.train.AdamOptimizer(learning_rate).minimize(cost, global_step=global_step)
    return cost, train_op
马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/loic_wang/TF-recomm.git
git@gitee.com:loic_wang/TF-recomm.git
loic_wang
TF-recomm
TF-recomm
master

搜索帮助