# 使用 TB-Net 白盒推荐模型 [![查看源文件](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/r1.8/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/r1.8/docs/xai/docs/source_zh_cn/using_tbnet.md) ## 什么是 TB-Net TB-Net是一个基于知识图谱的可解释推荐系统,它将用户和商品的交互信息以及物品的属性信息在知识图谱中构建子图,并利用双向传导的计算方法对图谱中的路径进行计算,最后得到可解释的推荐结果。 论文:Shendi Wang, Haoyang Li, Xiao-Hui Li, Caleb Chen Cao, Lei Chen. Tower Bridge Net (TB-Net): Bidirectional Knowledge Graph Aware Embedding Propagation for Explainable Recommender Systems ## 准备 ### 下载数据集 首先,我们要下一个用例数据包并解压到一个本地 [XAI原码包](https://gitee.com/mindspore/xai) 中的`models/whitebox/tbnet`文件夹: ```bash wget https://mindspore-website.obs.myhuaweicloud.com/notebook/datasets/xai/tbnet_data.tar.gz tar -xf tbnet_data.tar.gz git clone https://gitee.com/mindspore/xai.git mv data xai/models/whitebox/tbnet ``` `xai/models/whitebox/tbnet/` 文件夹结构: ```bash . └─tbnet ├─README.md ├─README_CN.md ├─data │ └─steam # Steam 用户历史行为数据集 │ ├─LICENSE │ ├─config.json # 超参和训练配置 │ ├─src_infer.csv # 推理用原始数据 │ ├─src_test.csv # 测试用原始数据 │ └─src_train.csv # 训练用原始数据 ├─src │ ├─dataset.py # 数据集加载器 │ ├─embedding.py # 实体嵌入模组 │ ├─metrics.py # 模型度量 │ ├─path_gen.py # 数据预处理器 │ ├─recommend.py # 推理结果集成器 │ └─tbnet.py # TB-Net网络架构 ├─export.py # 导出MINDIR/AIR文件脚本 ├─preprocess.py # 数据预处理脚本 ├─eval.py # 评估网络脚本 ├─infer.py # 推理和解释脚本 ├─train.py # 训练网络脚本 └─tbnet_config.py # 配置阅读器 ``` ### 准备 Python 环境 TB-Net 是 XAI 的一部份,用户在安装好 [MindSpore](https://mindspore.cn/install) 及 [XAI](https://www.mindspore.cn/xai/docs/zh-CN/r1.8/installation.html) 后即可使用,支持 GPU。 ## 数据预处理 本步骤的完整用例代码:[preprocess.py](https://gitee.com/mindspore/xai/blob/r1.8/models/whitebox/tbnet/preprocess.py) 。 在训练 TB-Net 前我们必须把原始数据转换为关系路径数据。 ### 原始数据格式 Steam 数据集的所有原始数据文件都拥有完全相同的 CSV 格式,文件头是: `user,item,rating,developer,genre,category,release_year` 头三个列是必需的,而且它们的次序及意义是固定的: - `user`:字串,用户ID,同一用户的数据必须被结集在同一个文件中相邻的行,把数据分散在不相邻的行或横跨不同的文件会导致错误的结果。 - `item`:字串,商品ID。 - `rating`:单字符,商品评级,可选:`c`(用户跟该商品有过互动如点击,但没有购买过)、`p`(用户购买过该商品)、`x`(其他商品)。 (备注:Steam 数据集并没有 `c` 评级的商品) 由于以上三个列的次序及意义是固定的,所以用户可以自定义它们的名称,例如 `uid,iid,act` 等。 其余的列 `developer,genre,category,release_year` 是商品的属性(即关系)列,储存字串属性值ID。用户须自定义列的名称(即关系名称)并在所有相关的原始数据文件中保持一致。最少要有一个属性列,但并没有最大数量限制。如果在一个属性中商品具有超过一个的属性值ID,它们必须由`;`分隔。如果商品不具有谋些属性,请把该属性留空。 不同使用目的原始数据文件的具体内容都有一些区别: - `src_train.csv`:训练用,在总体上,`p` 评级的商品行数要和 `c`、`x` 评级的商品行数之和大致持平,可以使用二次采样达致,无须为每个用户列出所有的商品。 - `src_test.csv`:评估用,跟 `src_train.csv` 一样, 但数据量较少。 - `src_infer.csv`:推理用,只能含有一个用户的数据,而且要把所有 `c`、`p` 及 `x` 评级的商品都列出。在 [preprocess.py](https://gitee.com/mindspore/xai/blob/r1.8/models/whitebox/tbnet/preprocess.py) 中,只有 `c` 或 `x` 评级的商品才会成为关系路径数据中的候选推荐商品。 ### 转换为关系路径数据 ```python import io import json from src.path_gen import PathGen path_gen = PathGen(per_item_paths=39) path_gen.generate("./data/steam/src_train.csv", "./data/steam/train.csv") # 储存ID映射表以留待推理时 Recommender 使用 with io.open("./data/steam/id_maps.json", mode="w", encoding="utf-8") as f: json.dump(path_gen.id_maps(), f, indent=4) # 把在 src_test.csv 及 src_infer.csv 新遇到的商品及属性ID都视为默生实体,内部ID 0 会用来代表它们 path_gen.grow_id_maps = False path_gen.generate("./data/steam/src_test.csv", "./data/steam/test.csv") # src_infer.csv 只含有一个用户的数据,只有 c 或 x 评级的商品才会成为 infer.csv 中的候选推荐商品 path_gen.subject_ratings = "cx" path_gen.generate("./data/steam/src_infer.csv", "./data/steam/infer.csv") ``` `PathGen` 是负责把原始数据转换为关系路径数据的类。 ### 关系路径数据格式 关系路径数据是没有文件头的CSV(全部为整数值),对应的数据列如下: `,