How can one use the RemoveIsolatedNodes transform in PyTorch Geometric?
I am trying to solve a graph classification problem with PyTorch Geometric, and some of my graphs contain isolated nodes (which can cause problems). My dataset is a list of PyTorch Geometric Data objects:
dataset = [graph1, graph2, graph3, ...]
where graph1 is a PyTorch Geometric Data object containing the graph's structure, node features and label. I see that PyTorch Geometric already has a transform for precisely this task (RemoveIsolatedNodes), but the documentation doesn't say how to apply it, since it's a class whose constructor takes no arguments.
Solution 1:[1]
To do that, you can use the remove_isolated_nodes function from the torch_geometric.utils module. The code might look as follows:
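from torch_geometric.utils import remove_isolated_nodes
for graph in dataset:
    graph.edge_index, _, mask = remove_isolated_nodes(graph.edge_index, num_nodes=graph.num_nodes)
    graph.x = graph.x[mask]  # mask marks non-isolated nodes; drop the features of the isolated ones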
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
Solution | Source |
---|---|
Solution 1 | LucTuc |