How can one use the RemoveIsolatedNodes transform in PyTorch Geometric?

I am working on a graph classification problem in PyTorch Geometric, and some of my graphs contain isolated nodes (which can cause problems). For example, my dataset is a list of PyTorch Geometric Data objects:

dataset = [graph1, graph2, graph3, ...]

where graph1 is a PyTorch Geometric Data object containing the graph's structure, node features and label. PyTorch Geometric already has a transform for precisely this task (RemoveIsolatedNodes), but the documentation doesn't say how to apply it, since it's a class whose constructor takes no arguments.



Solution 1:[1]

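To do that, you can use the remove_isolated_nodes function from torch_geometric.utils. It returns the relabeled edge_index, the filtered edge_attr and a boolean mask of the nodes that were kept, so the same mask has to be applied to the node features; otherwise x still contains the isolated nodes and no longer lines up with edge_index. Passing num_nodes also matters, because an isolated node with the highest index cannot be inferred from edge_index alone. The code might look as follows:

from torch_geometric.utils import remove_isolated_nodes

for graph in dataset:
    graph.edge_index, graph.edge_attr, mask = remove_isolated_nodes(graph.edge_index, graph.edge_attr, num_nodes=graph.num_nodes)
    graph.x = graph.x[mask]  # drop the isolated nodes' features so x matches the relabeled edge_index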

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 LucTuc