ydsungan/图书推荐系统 (Book Recommendation System)

This repository does not declare an open-source license file (LICENSE); before using the code, check the project description and its upstream code dependencies.
Repository: https://gitee.com/ydsungan/book-recommendation-system.git
spark_cf.py 2.54 KB
from pyspark import SparkConf
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql import functions as F
import csv
from itertools import islice
def load_test_data(data_path):
    """Read the test-set CSV and return the list of user ids, skipping the header row."""
    user_id_list = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in islice(csv.reader(f), 1, None):
            user_id_list.append(int(line[0]))
    return user_id_list
if __name__ == '__main__':
    conf = SparkConf().setAppName('collaborativeFiltering').setMaster('local')
    spark = SparkSession.builder.config(conf=conf).getOrCreate()

    # Load the ratings CSV and cast the string columns to the numeric types ALS expects
    ratingResourcesPath = "./dataset/new_dataset.csv"
    ratingSamples = spark.read.format('csv').option('header', 'true').load(ratingResourcesPath) \
        .withColumn("userIdInt", F.col("user_id").cast(IntegerType())) \
        .withColumn("itemIdInt", F.col("item_id").cast(IntegerType())) \
        .withColumn("ratingFloat", F.col("rating").cast(FloatType()))
    training, test = ratingSamples.randomSplit((0.8, 0.2))
    # Build the recommendation model using ALS on the training data.
    # Note we set the cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics.
    als = ALS(regParam=0.01, maxIter=5, userCol='userIdInt', itemCol='itemIdInt', ratingCol='ratingFloat',
              coldStartStrategy='drop')
    model = als.fit(training)

    # Evaluate the model by computing the RMSE on the test data
    predictions = model.transform(test)
    # Inspect a few of the learned latent-factor vectors
    model.itemFactors.show(10, truncate=False)
    model.userFactors.show(10, truncate=False)
    evaluator = RegressionEvaluator(predictionCol="prediction", labelCol='ratingFloat', metricName='rmse')
    rmse = evaluator.evaluate(predictions)
    print("Root-mean-square error = {}".format(rmse))
    # Generate the top-10 recommendations for every user
    recommendations = model.recommendForAllUsers(10)

    # Write the recommended item ids for each test user to the submission file
    test_user_id_list = load_test_data("./dataset/book_test_dataset.csv")
    file = open("submission_lfm_2.csv", "w", encoding="utf-8", newline="")
    csv_writer = csv.writer(file)
    csv_writer.writerow(["user_id", "item_id"])
    count = 0
    for uid in test_user_id_list:
        res = recommendations.where(recommendations.userIdInt == uid).collect()
        res = res[0][1]  # the 'recommendations' column: a list of Row(itemIdInt, rating)
        for itemid in res:
            csv_writer.writerow([str(uid), str(itemid.itemIdInt)])
        count += 1
        if count % 20 == 0:
            print("Recommendation progress: {:.2f}%".format(count / len(test_user_id_list) * 100))
    file.close()
    spark.stop()
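
The export loop above filters and collects the full recommendations DataFrame once per test user, which launches a separate Spark job for every user id. Below is a minimal one-pass sketch, assuming the same column names and file paths as the script and that the full top-10 table fits in driver memory; it is an illustration, not part of the repository code.

    # One-pass export sketch: collect every user's top-10 list into a dict,
    # then write only the rows for the test users.
    rec_map = {row['userIdInt']: [r.itemIdInt for r in row['recommendations']]
               for row in recommendations.collect()}
    with open("submission_lfm_2.csv", "w", encoding="utf-8", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["user_id", "item_id"])
        for uid in test_user_id_list:
            for item_id in rec_map.get(uid, []):
                writer.writerow([str(uid), str(item_id)])

Users that never appear in the training split have no entry in rec_map, so rec_map.get(uid, []) simply writes no rows for them.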
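
The script imports CrossValidator and ParamGridBuilder but never uses them. The following is a minimal sketch of how they could be wired in to tune the ALS hyperparameters; the grid values and numFolds are illustrative assumptions, not taken from the repository.

    # Hypothetical tuning sketch: search a small grid and keep the model with the
    # lowest cross-validated RMSE, then re-evaluate it on the held-out test split.
    param_grid = ParamGridBuilder() \
        .addGrid(als.regParam, [0.01, 0.1]) \
        .addGrid(als.rank, [10, 20]) \
        .build()
    cv = CrossValidator(estimator=als, estimatorParamMaps=param_grid,
                        evaluator=evaluator, numFolds=3)
    cv_model = cv.fit(training)
    best_model = cv_model.bestModel
    print("Best model test RMSE = {}".format(evaluator.evaluate(best_model.transform(test))))

This reuses the als, evaluator, training, and test objects already defined in the script, so it would need to run before spark.stop() is called.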