
StellarGraph PaddedGraphGenerator - how to provide specific training, validation and test sets

I'm trying to train a basic graph neural network with the StellarGraph library, starting from the example provided in [0].

The example works fine, but now I would like to repeat the same exercise without N-fold cross-validation, providing specific training, validation and test sets instead. I'm trying to do so with the following code:

    import pandas as pd
    from stellargraph.layer import GCNSupervisedGraphClassification
    from stellargraph.mapper import PaddedGraphGenerator
    from tensorflow.keras import Model
    from tensorflow.keras.callbacks import EarlyStopping
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.losses import binary_crossentropy
    from tensorflow.keras.optimizers import Adam

    # graphs_training_set, graphs_validation_set, their labels, batch_size,
    # dropout_value, adam_value and num_epochs are defined elsewhere.

    # One-hot encoding of the labels
    graph_training_set_labels_encoded = pd.get_dummies(graphs_training_set_labels, drop_first=True)
    graph_validation_set_labels_encoded = pd.get_dummies(graphs_validation_set_labels, drop_first=True)

    # Training graphs first, then validation graphs
    graphs = graphs_training_set + graphs_validation_set

    # Graph generator preparation
    generator = PaddedGraphGenerator(graphs=graphs)

    train_gen = generator.flow(list(range(len(graphs_training_set))),
                               targets=graph_training_set_labels_encoded,
                               batch_size=batch_size)

    valid_gen = generator.flow(list(range(len(graphs_training_set),
                                          len(graphs_training_set) + len(graphs_validation_set))),
                               targets=graph_validation_set_labels_encoded,
                               batch_size=batch_size)

    # Early-stopping criterion
    es = EarlyStopping(monitor="val_loss",
                       min_delta=0,
                       patience=20,
                       restore_best_weights=True)

    # Model definition
    gc_model = GCNSupervisedGraphClassification(layer_sizes=[64, 64],
                                                activations=["relu", "relu"],
                                                generator=generator,
                                                dropout=dropout_value)

    x_inp, x_out = gc_model.in_out_tensors()
    predictions = Dense(units=32, activation="relu")(x_out)
    predictions = Dense(units=16, activation="relu")(predictions)
    predictions = Dense(units=1, activation="sigmoid")(predictions)

    # Create the Keras model and prepare it for training
    model = Model(inputs=x_inp, outputs=predictions)
    model.compile(optimizer=Adam(adam_value), loss=binary_crossentropy, metrics=["acc"])

    # GNN training
    history = model.fit(train_gen, epochs=num_epochs, validation_data=valid_gen, verbose=0, callbacks=[es])

    # Calculate performance on the validation data
    valid_metrics = model.evaluate(valid_gen, verbose=0)
    valid_acc = valid_metrics[model.metrics_names.index("acc")]

    print(f"Validation accuracy = {valid_acc}")

Here graphs_training_set and graphs_validation_set are lists of StellarDiGraph objects.
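
Both this code and the answer below rely on the same bookkeeping: every split is concatenated into one list that is passed to PaddedGraphGenerator, and each split is then addressed by its position in that list. A minimal sketch of the index arithmetic, using a hypothetical helper split_indices (not part of StellarGraph) and assuming the list is ordered training, then validation, then test:

```python
def split_indices(n_train, n_val, n_test=0):
    """Return contiguous index lists for graphs stored as train + val + test.

    split_indices is a hypothetical helper for illustration, not a StellarGraph API.
    """
    train = list(range(0, n_train))
    val = list(range(n_train, n_train + n_val))
    test = list(range(n_train + n_val, n_train + n_val + n_test))
    return train, val, test


# Example: the concatenated list holds 3 training, 2 validation and 1 test graph.
train_idx, val_idx, test_idx = split_indices(3, 2, 1)
print(train_idx, val_idx, test_idx)  # [0, 1, 2] [3, 4] [5]
```

Each list would be the first argument to generator.flow(...), and the number of target rows passed alongside it should match its length; checking those lengths is a cheap first step when debugging.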

I am able to run this code, but it returns NaN as the result. What could be the problem?

This is my first time using StellarGraph, and PaddedGraphGenerator in particular, so I suspect my mistake lies in how I use that generator. However, providing the training and validation sets in different ways did not produce better results.
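
For comparison, the linked demo [0] appears to pass targets to flow as NumPy arrays (selected with .iloc[...].values) rather than as DataFrames. Encoding each split separately with pd.get_dummies(..., drop_first=True) also has a pitfall: the dropped category is decided per split, so a split that happens to contain only one class ends up with no target column at all. I have not confirmed that either point explains the NaN; the sketch below, with a hypothetical helper encode_binary_targets, just makes both explicit, assuming exactly two classes:

```python
import numpy as np
import pandas as pd


def encode_binary_targets(labels):
    """One-hot encode binary labels, drop the first category, return a float (n, 1) array.

    encode_binary_targets is a hypothetical helper, not part of StellarGraph.
    """
    # dtype=float is explicit because pandas >= 2.0 returns bool columns by default.
    encoded = pd.get_dummies(pd.Series(labels), drop_first=True, dtype=float)
    if encoded.shape[1] != 1:
        # 0 columns: only one class present; >1 columns: more than two classes.
        raise ValueError(f"expected exactly 2 classes, got {encoded.shape[1] + 1}")
    return encoded.to_numpy()


y_train = encode_binary_targets(["neg", "pos", "pos", "neg"])
print(y_train.shape)  # (4, 1)
print(y_train[:, 0])  # [0. 1. 1. 0.]
```

The returned array has the same (n, 1) shape as the single sigmoid output of the model, and it would be passed as targets= in place of the DataFrame.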

Thank you in advance.

UPDATE: Fixed a typo in the code, as pointed out here (thanks to george123).

[0] https://stellargraph.readthedocs.io/en/stable/demos/graph-classification/gcn-supervised-graph-classification.html

I found a solution by digging into the StellarGraph documentation for PaddedGraphGenerator [0] and for the GCN neural network class GCNSupervisedGraphClassification [1]. I also found a similar question on the StellarGraph issue tracker [2] that points to the solution.

    from stellargraph.layer import GCNSupervisedGraphClassification
    from stellargraph.mapper import PaddedGraphGenerator
    from tensorflow.keras import Model
    from tensorflow.keras.callbacks import EarlyStopping
    from tensorflow.keras.layers import Dense
    from tensorflow.keras.losses import binary_crossentropy
    from tensorflow.keras.optimizers import Adam

    # `graphs` holds the training, validation and test graphs, in that order;
    # num_graphs_for_*, the label arrays, dropout_value, adam_value and
    # num_epochs are defined elsewhere.

    # Graph generator preparation
    generator = PaddedGraphGenerator(graphs=graphs)

    train_gen = generator.flow(list(range(num_graphs_for_training)),
                               targets=training_graphs_labels,
                               batch_size=35)
    valid_gen = generator.flow(list(range(num_graphs_for_training,
                                          num_graphs_for_training + num_graphs_for_validation)),
                               targets=validation_graphs_labels,
                               batch_size=35)

    # Early-stopping criterion
    es = EarlyStopping(monitor="val_loss",
                       min_delta=0.001,
                       patience=10,
                       restore_best_weights=True)

    # Model definition
    gc_model = GCNSupervisedGraphClassification(layer_sizes=[64, 64],
                                                activations=["relu", "relu"],
                                                generator=generator,
                                                dropout=dropout_value)

    x_inp, x_out = gc_model.in_out_tensors()
    predictions = Dense(units=32, activation="relu")(x_out)
    predictions = Dense(units=16, activation="relu")(predictions)
    predictions = Dense(units=1, activation="sigmoid")(predictions)

    # Create the Keras model and prepare it for training
    model = Model(inputs=x_inp, outputs=predictions)
    model.compile(optimizer=Adam(adam_value), loss=binary_crossentropy, metrics=["acc"])

    # GNN training
    history = model.fit(train_gen, epochs=num_epochs, validation_data=valid_gen, verbose=1, callbacks=[es])

    # Evaluate performance on the validation data
    valid_metrics = model.evaluate(valid_gen, verbose=0)
    valid_acc = valid_metrics[model.metrics_names.index("acc")]

    # Indices of the test graphs, which follow the validation graphs in `graphs`
    index_begin_test_set = num_graphs_for_training + num_graphs_for_validation
    index_end_test_set = index_begin_test_set + num_graphs_for_testing
    test_set_indices = list(range(index_begin_test_set, index_end_test_set))

    # Predict on the test set (no targets needed); the output holds one
    # sigmoid probability per test graph
    generator_for_test_set = PaddedGraphGenerator(graphs=graphs)
    test_gen = generator_for_test_set.flow(test_set_indices)
    result = model.predict(test_gen)

I hope it can be useful to others too.
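
The last step of the code above, result = model.predict(test_gen), yields sigmoid probabilities rather than class labels. A minimal sketch of turning them into 0/1 predictions and scoring them, assuming the true test labels are available and the test flow is not shuffled (so the rows of result follow test_set_indices). The helper names are hypothetical, the probabilities are made up, and the 0.5 cut-off mirrors the default threshold of Keras' binary accuracy:

```python
import numpy as np


def probabilities_to_labels(probs, threshold=0.5):
    """Turn sigmoid outputs of shape (n, 1) into 0/1 integer labels of shape (n,)."""
    return (np.asarray(probs).reshape(-1) > threshold).astype(int)


def accuracy(y_true, y_pred):
    """Fraction of positions where the true and predicted labels agree."""
    y_true = np.asarray(y_true).reshape(-1)
    return float(np.mean(y_true == y_pred))


# Made-up probabilities standing in for model.predict(test_gen)
probs = np.array([[0.9], [0.2], [0.51], [0.4]])
labels = probabilities_to_labels(probs)
print(labels)  # [1 0 1 0]
print(accuracy([1, 0, 0, 0], labels))  # 0.75
```

If the test labels are known, an alternative is to pass them as targets= to the test flow and call model.evaluate on it, exactly as is done for valid_gen.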

UPDATE: I have updated the links below to point to the master branch.

[0] https://github.com/stellargraph/stellargraph/blob/master/stellargraph/mapper/padded_graph_generator.py

[1] https://github.com/stellargraph/stellargraph/blob/master/stellargraph/layer/graph_classification.py

[2] https://github.com/stellargraph/stellargraph/issues/1980

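Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you republish, please credit this site's URL or the original source. For any questions, contact: yoyou2525@163.com.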
