from pyspark import SparkConf
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.recommendation import ALS
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.sql import SparkSession
from pyspark.sql.types import *
from pyspark.sql import functions as F
import csv
from itertools import islice
def load_test_data(data_path):
    # Read the test CSV (header row skipped) and return the user ids from its first column.
    user_id_list = []
    with open(data_path, "r", encoding="utf-8") as f:
        reader = csv.reader(f)
        for line in islice(reader, 1, None):
            user_id_list.append(int(line[0]))
    return user_id_list
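
# Note: load_test_data assumes a CSV with a header row and the user id in the
# first column; a hypothetical ./dataset/book_test_dataset.csv (values
# illustrative, not from the original) might look like:
#
#   user_id
#   101
#   102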
if __name__ == '__main__':
    conf = SparkConf().setAppName('collaborativeFiltering').setMaster('local')
    spark = SparkSession.builder.config(conf=conf).getOrCreate()
    ratingResourcesPath = "./dataset/new_dataset.csv"
    ratingSamples = spark.read.format('csv').option('header', 'true').load(ratingResourcesPath) \
        .withColumn("userIdInt", F.col("user_id").cast(IntegerType())) \
        .withColumn("itemIdInt", F.col("item_id").cast(IntegerType())) \
        .withColumn("ratingFloat", F.col("rating").cast(FloatType()))
    training, test = ratingSamples.randomSplit([0.8, 0.2])
    # Build the recommendation model using ALS on the training data.
    # Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics.
    als = ALS(regParam=0.01, maxIter=5, userCol='userIdInt', itemCol='itemIdInt', ratingCol='ratingFloat',
              coldStartStrategy='drop')
    model = als.fit(training)
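    # The CrossValidator / ParamGridBuilder imports suggest hyperparameter tuning
    # was intended; a minimal sketch (the regParam / rank grids below are
    # illustrative assumptions, not values from the original):
    # param_grid = ParamGridBuilder() \
    #     .addGrid(als.regParam, [0.01, 0.1]) \
    #     .addGrid(als.rank, [10, 50]) \
    #     .build()
    # cv = CrossValidator(estimator=als, estimatorParamMaps=param_grid,
    #                     evaluator=RegressionEvaluator(predictionCol='prediction',
    #                                                   labelCol='ratingFloat',
    #                                                   metricName='rmse'),
    #                     numFolds=3)
    # model = cv.fit(training).bestModel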
    # Evaluate the model by computing the RMSE on the test data.
    predictions = model.transform(test)
    model.itemFactors.show(10, truncate=False)
    model.userFactors.show(10, truncate=False)
    evaluator = RegressionEvaluator(predictionCol="prediction", labelCol='ratingFloat', metricName='rmse')
    rmse = evaluator.evaluate(predictions)
    print("Root-mean-square error = {}".format(rmse))
    recommendations = model.recommendForAllUsers(10)
    test_user_id_list = load_test_data("./dataset/book_test_dataset.csv")
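    # More scalable alternative (a sketch): only the test users need
    # recommendations, so recommendForUserSubset (Spark >= 2.3) avoids ranking
    # items for every user; test_users_df is a name introduced here:
    # test_users_df = spark.createDataFrame([(uid,) for uid in test_user_id_list],
    #                                       ['userIdInt'])
    # recommendations = model.recommendForUserSubset(test_users_df, 10)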
    file = open("submission_lfm_2.csv", "w", encoding="utf-8", newline="")
    csv_writer = csv.writer(file)
    csv_writer.writerow(["user_id", "item_id"])
    count = 0
    for uid in test_user_id_list:
        rows = recommendations.where(recommendations.userIdInt == uid).collect()
        if not rows:
            # Users absent from the training data get no row from recommendForAllUsers.
            continue
        # rows[0]['recommendations'] is a list of Rows with fields (itemIdInt, rating);
        # write only the item id, not the stringified Row.
        for rec in rows[0]['recommendations']:
            csv_writer.writerow([str(uid), str(rec['itemIdInt'])])
        count += 1
        if count % 20 == 0:
            print("Recommendation progress: {:.2f}%".format(count / len(test_user_id_list) * 100))
    file.close()
    spark.stop()