From 92d3a02dd50f391399adb2ba85b0b8b30118da07 Mon Sep 17 00:00:00 2001 From: Matija Marijan Date: Fri, 19 Dec 2025 12:52:10 +0100 Subject: [PATCH 1/2] Transpose protein embedding matrix, apply three Conv1d with max pooling --- models/gat.py | 17 ++++++++++------- models/gat_gcn.py | 13 +++++++++---- models/gcn.py | 13 +++++++++---- models/ginconv.py | 13 +++++++++---- 4 files changed, 37 insertions(+), 19 deletions(-) diff --git a/models/gat.py b/models/gat.py index f444714..be27cda 100644 --- a/models/gat.py +++ b/models/gat.py @@ -18,8 +18,10 @@ def __init__(self, num_features_xd=78, n_output=1, num_features_xt=25, # 1D convolution on protein sequence self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) - self.conv_xt1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) - self.fc_xt1 = nn.Linear(32*121, output_dim) + self.conv_xt_1 = nn.Conv1d(in_channels=embed_dim, out_channels=n_filters, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=2*n_filters, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=2*n_filters, out_channels=3*n_filters, kernel_size=8) + self.fc1_xt = nn.Linear(3*n_filters, output_dim) # combined layers self.fc1 = nn.Linear(256, 1024) @@ -46,12 +48,13 @@ def forward(self, data): # protein input feed-forward: target = data.target embedded_xt = self.embedding_xt(target) - conv_xt = self.conv_xt1(embedded_xt) - conv_xt = self.relu(conv_xt) + embedded_xt = torch.permute(embedded_xt, (0, 2, 1)) - # flatten - xt = conv_xt.view(-1, 32 * 121) - xt = self.fc_xt1(xt) + conv_xt = self.conv_xt_1(embedded_xt) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + xt = torch.max(conv_xt, dim = -1)[0] + xt = self.fc1_xt(xt) # concat xc = torch.cat((x, xt), 1) diff --git a/models/gat_gcn.py b/models/gat_gcn.py index d99d1b4..aee8201 100644 --- a/models/gat_gcn.py +++ b/models/gat_gcn.py @@ -23,8 +23,10 @@ def __init__(self, n_output=1, num_features_xd=78, num_features_xt=25, # 1D convolution on protein sequence self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) - self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) - self.fc1_xt = nn.Linear(32*121, output_dim) + self.conv_xt_1 = nn.Conv1d(in_channels=embed_dim, out_channels=n_filters, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=2*n_filters, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=2*n_filters, out_channels=3*n_filters, kernel_size=8) + self.fc1_xt = nn.Linear(3*n_filters, output_dim) # combined layers self.fc1 = nn.Linear(256, 1024) @@ -46,9 +48,12 @@ def forward(self, data): x = self.fc_g2(x) embedded_xt = self.embedding_xt(target) + embedded_xt = torch.permute(embedded_xt, (0, 2, 1)) + conv_xt = self.conv_xt_1(embedded_xt) - # flatten - xt = conv_xt.view(-1, 32 * 121) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + xt = torch.max(conv_xt, dim = -1)[0] xt = self.fc1_xt(xt) # concat diff --git a/models/gcn.py b/models/gcn.py index 6c583c3..f87fd5c 100644 --- a/models/gcn.py +++ b/models/gcn.py @@ -22,8 +22,10 @@ def __init__(self, n_output=1, n_filters=32, embed_dim=128,num_features_xd=78, n # protein sequence branch (1d conv) self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) - self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) - self.fc1_xt = nn.Linear(32*121, output_dim) + self.conv_xt_1 = nn.Conv1d(in_channels=embed_dim, out_channels=n_filters, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=2*n_filters, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=2*n_filters, out_channels=3*n_filters, kernel_size=8) + self.fc1_xt = nn.Linear(3*n_filters, output_dim) # combined layers self.fc1 = nn.Linear(2*output_dim, 1024) @@ -54,9 +56,12 @@ def forward(self, data): # 1d conv layers embedded_xt = self.embedding_xt(target) + embedded_xt = torch.permute(embedded_xt, (0, 2, 1)) + conv_xt = self.conv_xt_1(embedded_xt) - # flatten - xt = conv_xt.view(-1, 32 * 121) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + xt = torch.max(conv_xt, dim = -1)[0] xt = self.fc1_xt(xt) # concat diff --git a/models/ginconv.py b/models/ginconv.py index bd37f4c..a081030 100644 --- a/models/ginconv.py +++ b/models/ginconv.py @@ -41,8 +41,10 @@ def __init__(self, n_output=1,num_features_xd=78, num_features_xt=25, # 1D convolution on protein sequence self.embedding_xt = nn.Embedding(num_features_xt + 1, embed_dim) - self.conv_xt_1 = nn.Conv1d(in_channels=1000, out_channels=n_filters, kernel_size=8) - self.fc1_xt = nn.Linear(32*121, output_dim) + self.conv_xt_1 = nn.Conv1d(in_channels=embed_dim, out_channels=n_filters, kernel_size=8) + self.conv_xt_2 = nn.Conv1d(in_channels=n_filters, out_channels=2*n_filters, kernel_size=8) + self.conv_xt_3 = nn.Conv1d(in_channels=2*n_filters, out_channels=3*n_filters, kernel_size=8) + self.fc1_xt = nn.Linear(3*n_filters, output_dim) # combined layers self.fc1 = nn.Linear(256, 1024) @@ -68,9 +70,12 @@ def forward(self, data): x = F.dropout(x, p=0.2, training=self.training) embedded_xt = self.embedding_xt(target) + embedded_xt = torch.permute(embedded_xt, (0, 2, 1)) + conv_xt = self.conv_xt_1(embedded_xt) - # flatten - xt = conv_xt.view(-1, 32 * 121) + conv_xt = self.conv_xt_2(conv_xt) + conv_xt = self.conv_xt_3(conv_xt) + xt = torch.max(conv_xt, dim = -1)[0] xt = self.fc1_xt(xt) # concat From 4c567bc1af260e40a06d74e57968269a0e8ee417 Mon Sep 17 00:00:00 2001 From: Matija Marijan Date: Fri, 19 Dec 2025 13:10:13 +0100 Subject: [PATCH 2/2] Add ReLU activations between Conv1d layers --- models/gat.py | 3 +++ models/gat_gcn.py | 3 +++ models/gcn.py | 3 +++ models/ginconv.py | 3 +++ 4 files changed, 12 insertions(+) diff --git a/models/gat.py b/models/gat.py index be27cda..74ff93a 100644 --- a/models/gat.py +++ b/models/gat.py @@ -51,8 +51,11 @@ def forward(self, data): embedded_xt = torch.permute(embedded_xt, (0, 2, 1)) conv_xt = self.conv_xt_1(embedded_xt) + conv_xt = self.relu(conv_xt) conv_xt = self.conv_xt_2(conv_xt) + conv_xt = self.relu(conv_xt) conv_xt = self.conv_xt_3(conv_xt) + conv_xt = self.relu(conv_xt) xt = torch.max(conv_xt, dim = -1)[0] xt = self.fc1_xt(xt) diff --git a/models/gat_gcn.py b/models/gat_gcn.py index aee8201..2240710 100644 --- a/models/gat_gcn.py +++ b/models/gat_gcn.py @@ -51,8 +51,11 @@ def forward(self, data): embedded_xt = torch.permute(embedded_xt, (0, 2, 1)) conv_xt = self.conv_xt_1(embedded_xt) + conv_xt = self.relu(conv_xt) conv_xt = self.conv_xt_2(conv_xt) + conv_xt = self.relu(conv_xt) conv_xt = self.conv_xt_3(conv_xt) + conv_xt = self.relu(conv_xt) xt = torch.max(conv_xt, dim = -1)[0] xt = self.fc1_xt(xt) diff --git a/models/gcn.py b/models/gcn.py index f87fd5c..ec27562 100644 --- a/models/gcn.py +++ b/models/gcn.py @@ -59,8 +59,11 @@ def forward(self, data): embedded_xt = torch.permute(embedded_xt, (0, 2, 1)) conv_xt = self.conv_xt_1(embedded_xt) + conv_xt = self.relu(conv_xt) conv_xt = self.conv_xt_2(conv_xt) + conv_xt = self.relu(conv_xt) conv_xt = self.conv_xt_3(conv_xt) + conv_xt = self.relu(conv_xt) xt = torch.max(conv_xt, dim = -1)[0] xt = self.fc1_xt(xt) diff --git a/models/ginconv.py b/models/ginconv.py index a081030..8d06aed 100644 --- a/models/ginconv.py +++ b/models/ginconv.py @@ -73,8 +73,11 @@ def forward(self, data): embedded_xt = torch.permute(embedded_xt, (0, 2, 1)) conv_xt = self.conv_xt_1(embedded_xt) + conv_xt = self.relu(conv_xt) conv_xt = self.conv_xt_2(conv_xt) + conv_xt = self.relu(conv_xt) conv_xt = self.conv_xt_3(conv_xt) + conv_xt = self.relu(conv_xt) xt = torch.max(conv_xt, dim = -1)[0] xt = self.fc1_xt(xt)