5 Star 16 Fork 11

京东零售/Galileo

加入 Gitee
与超过 1200万 开发者一起发现、参与优秀开源项目,私有仓库也完全免费 :)
免费加入
克隆/下载
custom.md 3.21 KB
一键复制 编辑 原始数据 按行查看 历史
duhua 提交于 2021-09-09 10:46 . galileo first commit

Galileo定制化图模型

Galileo定制化图模型分两部分,自定义模型输入和自定义模型。

自定义模型输入

使用Galileo提供的Transforms构建模型输入见入门教程中的构建模型输入。更多的Transforms见API中Transforms接口。

这里介绍如何自定义模型输入。有两种方法。

  1. 继承galileo.BaseTransform实现其中的transform方法。

  2. 直接在Inputs类中实现transform方法,更加方便一些。

参考GCN的例子,使用了方法2实现了获取目标顶点的标签的transform

transform中的需要使用图引擎服务的接口来获取或采样数据,详细接口列表见API中图引擎服务的采样接口。

class Inputs(g.BaseInputs):
    def __init__(self, **kwargs):
        super().__init__(config=kwargs)

    def transform(self, vertices):
        label_name = self.config['label_name']
        label_dim = self.config['label_dim']
        vertices = tf.cast(vertices, tf.int64)
        vertices = tf.reshape(vertices, [-1])
        u_vertices, _ = tf.unique(vertices)
        # 使用 get_pod_feature 获取顶点的标签特征
        labels = gt.ops.get_pod_feature([u_vertices], [label_name],
                                        [label_dim], [tf.float32])[0]
        return dict(targets=u_vertices, labels=labels)

    def train_data(self):
        vertex_type = self.config['vertex_type']

        def base_dataset(**kwargs):
            return gt.VertexDataset(vertex_type, 100)

        return gt.dataset_pipeline(base_dataset, self.transform, **self.config)

自定义模型

自定义模型可以按照下图选择基类:

alt flow

Supervised类继承自BaseSupervised类,提供了图有监督模型的基本框架。

galileo.tf.Supervised

galileo.pytorch.Supervised

Unsupervised类继承自BaseUnsupervised类,提供了图无监督模型的基本框架。

galileo.tf.Unsupervised

galileo.pytorch.Unsupervised

自定义模型的子类只需要实现其中的encoder方法即可。基类会计算loss和metrics。

如果是继承自tf.keras.Model或torch.nn.Module那么需要实现计算loss和metrics,模型返回一个dict,key包括loss和metrics的名字。

说明:Tensorflow后端的相关类同时支持keras和estimator。

例如:

  1. 在tf的有监督graphSAGE示例中继承了galileo.tf.Supervised实现了encoder方法。
  2. 在tf的无监督graphSAGE示例中继承了galileo.tf.Unsupervised实现了target_encoder方法和context_encoder方法。
  3. 在tf的GCN示例中继承了tf.keras.Model,在call方法是返回一个dict。

更多的定制化图模型例子见图模型

马建仓 AI 助手
尝试更多
代码解读
代码找茬
代码优化
1
https://gitee.com/jd-platform-opensource/galileo.git
git@gitee.com:jd-platform-opensource/galileo.git
jd-platform-opensource
galileo
Galileo
main

搜索帮助

Cb406eda 1850385 E526c682 1850385