
Predicting in PyTorch: how to interpret a PyTorch graph model's prediction?

I have a graph-based PyTorch model and I want to predict the class of 10 graphs.

The data objects (i.e. train_dataset in the code below) look like this:

[Data(x=[10, 5], edge_index=[2, 18], y=[1]), Data(x=[15, 5], edge_index=[2, 28], y=[1]), Data(x=[13, 5], edge_index=[2, 24], y=[1]), Data(x=[18, 5], edge_index=[2, 34], y=[1]), Data(x=[14, 5], edge_index=[2, 26], y=[1]), Data(x=[13, 5], edge_index=[2, 24], y=[1]), Data(x=[15, 5], edge_index=[2, 28], y=[1]), Data(x=[19, 5], edge_index=[2, 36], y=[1]), Data(x=[15, 5], edge_index=[2, 28], y=[1]), Data(x=[27, 5], edge_index=[2, 52], y=[1])]

So I ran this (where model is a model I have built):

predict_dataset = new_dataset[0:10]  # the first 10 graphs
for data in predict_dataset:
    prediction = model(data)  # here model returns a tuple (see the output below)
    print(prediction)
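Whichever of the two printed tensors turns out to be the label, a predicted class is normally derived from the model's raw output (its logits), not from the loss. The plain-Python sketch below shows the two usual decision rules: a sigmoid with a 0.5 threshold for a binary classifier with a single output logit (the setup a BCE-with-logits loss implies), and argmax for a model with one logit per class. The helper names sigmoid, predict_binary and predict_multiclass are hypothetical, and the logits are made-up values for illustration only.

```python
import math
from typing import Sequence


def sigmoid(x: float) -> float:
    """Numerically stable logistic function: maps a logit to a probability."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def predict_binary(logit: float, threshold: float = 0.5) -> int:
    """Single-logit binary classifier (the BCEWithLogits setup):
    class 1 if sigmoid(logit) >= threshold, otherwise class 0."""
    return int(sigmoid(logit) >= threshold)


def predict_multiclass(logits: Sequence[float]) -> int:
    """One logit per class (the CrossEntropyLoss setup):
    the predicted class is the index of the largest logit (argmax)."""
    return max(range(len(logits)), key=lambda i: logits[i])


# Hypothetical raw model outputs, for illustration only
print(predict_binary(3.2))
print(predict_binary(-0.7))
print(predict_multiclass([0.1, 2.5, -1.0]))
```

With a 0.5 threshold, sigmoid(x) >= 0.5 is the same as x >= 0, so the binary rule can also be applied to the logit directly. In PyTorch the same rules are usually written as (torch.sigmoid(logits) >= 0.5).long() and logits.argmax(dim=1); note that argmax over a single logit always returns 0, so it is not a useful rule for a one-output binary model.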

And my output is:

(tensor(5.5788e-05, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), tensor(1.))
(tensor(0.0190, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), tensor(1.))
(tensor(5.0663e-05, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), tensor(1.))
(tensor(0.0338, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), tensor(1.))
(tensor(4.7684e-07, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), tensor(1.))
(tensor(2.9166, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), tensor(0.))
(tensor(0.1944, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), tensor(1.))
(tensor(0.0591, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), tensor(1.))
(tensor(1.9073e-06, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), tensor(1.))
(tensor(0.0025, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>), tensor(1.))

I'm confused about what the numbers mean. Is the last item in each tuple the predicted class? And what is the first number?

Thanks. I'm just not sure whether I've predicted properly, so all suggestions and other code examples are appreciated.

I believe the first tensor is the loss, since its grad_fn (BinaryCrossEntropyWithLogitsBackward0) shows it is the output of a binary cross-entropy loss, and the second tensor, as you said, is the index of the predicted class. It would help, though, if you showed the result of print(model), so we can understand the model better.
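To make the loss interpretation concrete: for one graph with label y in {0, 1} and logit x, binary cross-entropy with logits is -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))], i.e. minus the log of the probability the model assigned to the label y. The plain-Python sketch below computes it in a numerically stable form that is mathematically equivalent to the unweighted per-element formula of torch.nn.BCEWithLogitsLoss. It assumes (not confirmed by the post; print(model) and the forward code would settle it) that the first printed number is this per-graph loss computed against the graph's true label. Under that assumption, exp(-loss) recovers the probability given to the true label. The helper names bce_with_logits and prob_of_target are hypothetical; the loss values are the ones printed in the question.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit, written as
    max(x, 0) - x*y + log(1 + exp(-|x|)) to avoid overflow for large |x|."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def prob_of_target(loss: float) -> float:
    """For a single example with a 0/1 target, loss = -log(p_target),
    so the probability the model gave to the target class is exp(-loss)."""
    return math.exp(-loss)


# Loss values copied from the printed output in the question
losses = [5.5788e-05, 0.0190, 5.0663e-05, 0.0338, 4.7684e-07,
          2.9166, 0.1944, 0.0591, 1.9073e-06, 0.0025]
for loss in losses:
    print(f"loss={loss:.4g}  p(target)={prob_of_target(loss):.4f}")
```

Read this way, losses near 0 mean the model gave almost all of its probability to the label, while the 2.9166 entry corresponds to only about 0.05, i.e. that graph is effectively misclassified.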

Technical posts on this site are licensed under CC BY-SA 4.0. If you reprint, please credit the site URL or the original address. For questions, contact: yoyou2525@163.com.

 
© 2020-2024 STACKOOM.COM