How do node features get passed around in the MessagePassing base class? #2120
Answered by rusty1s

HongtaoYang asked this question in Q&A
I'm confused about how node features get passed around in the `MessagePassing` base class. See the toy example below:

```python
import numpy as np
import torch
from torch_geometric.nn import MessagePassing

# Construct a simple graph. Node features are 2-dimensional vectors whose
# values are simply the node id, e.g. for node 0 the feature is [0, 0].
edge_index = np.array([[0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5],
                       [1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4]])
x = np.array([[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5]])
edge_index = torch.from_numpy(edge_index)
x = torch.from_numpy(x)

class MyMessagePassing(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(MyMessagePassing, self).__init__(aggr='add')

    def message(self, x_j: torch.Tensor) -> torch.Tensor:
        print(x)
        print(x_j)
        return x_j

    def forward(self, x, edge_index):
        x = x + 0.1  # toggling this does not change the print(x) output
        self.propagate(edge_index, x=x)

gcn = MyMessagePassing(2, 2)
output = gcn(x, edge_index)
```

In the example, I simply inherit from `MessagePassing` and print `x` inside `message`. The problem is that no matter what operations I apply to `x` in `forward` before calling `propagate`, the `print(x)` output inside `message` stays the same.
Answered by rusty1s on Feb 16, 2021:
The `x` comes from global scope, and is therefore unmodified. Fixable by changing the `message` header to:

```python
def message(self, x: torch.Tensor, x_j: torch.Tensor):
    print(x)
    print(x_j)
```

Answer selected by HongtaoYang.
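The reason the fix works: `propagate()` inspects the signature of `message()` and only forwards the keyword arguments that `message()` declares, so a name like `x` that is not in the signature falls back to ordinary Python scoping and resolves to the module-level tensor. A pure-Python sketch of that mechanism (the class `MiniMessagePassing` and everything in it are hypothetical illustrations, not the real torch_geometric implementation):

```python
import inspect

class MiniMessagePassing:
    # Hypothetical stand-in for MessagePassing, for illustration only:
    # propagate() inspects message()'s signature and injects only the
    # keyword arguments that message() actually declares.
    def propagate(self, edge_index, **kwargs):
        src, dst = edge_index
        if "x" in kwargs:
            # "Lift" node features onto edges: x_j[k] = x[src[k]]
            kwargs["x_j"] = [kwargs["x"][j] for j in src]
        params = inspect.signature(self.message).parameters
        msg_kwargs = {k: v for k, v in kwargs.items() if k in params}
        return self.message(**msg_kwargs)

x = [0, 1, 2]  # module-level x, like the tensor in the question

class Buggy(MiniMessagePassing):
    def message(self, x_j):        # 'x' not declared, so not injected;
        return x, x_j              # 'x' here resolves to the global x

class Fixed(MiniMessagePassing):
    def message(self, x, x_j):     # 'x' declared, so propagate injects it
        return x, x_j

edge_index = [[0, 1], [1, 2]]      # edges 0->1 and 1->2
shifted = [v + 10 for v in x]      # mimic the x = x + 0.1 inside forward
print(Buggy().propagate(edge_index, x=shifted))  # -> ([0, 1, 2], [10, 11])
print(Fixed().propagate(edge_index, x=shifted))  # -> ([10, 11, 12], [10, 11])
```

`Buggy` prints the untouched global `x` no matter what was passed to `propagate`, while `Fixed` sees the shifted features, which mirrors the behavior in the question.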