-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
80 lines (62 loc) · 2.15 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import sys
import urllib.request
from collections import defaultdict
from tqdm import tqdm
import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
from tqdm.notebook import tqdm
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import IPythonConsole
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, Subset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.utils import to_networkx
from torch_geometric.nn.models import AttentiveFP
def atom_feature(atom):
return [atom.GetAtomicNum(),
atom.GetDegree(),
atom.GetNumImplicitHs(),
atom.GetIsAromatic()]
def bond_feature(bond):
return [bond.GetBondType(),
bond.GetStereo()]
def smi_to_pyg(smi, y):
mol = Chem.MolFromSmiles(smi)
if mol is None:
return None
id_pairs = ((b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol.GetBonds())
atom_pairs = [z for (i, j) in id_pairs for z in ((i, j), (j, i))]
bonds = (mol.GetBondBetweenAtoms(i, j) for (i, j) in atom_pairs)
atom_features = [atom_feature(a) for a in mol.GetAtoms()]
bond_features = [bond_feature(b) for b in bonds]
return Data(edge_index=torch.LongTensor(list(zip(*atom_pairs))),
x=torch.FloatTensor(atom_features),
edge_attr=torch.FloatTensor(bond_features),
y=torch.LongTensor([y]),
mol=mol,
smiles=smi)
@torch.no_grad()
def predict(loader, model):
y_pred = []
y_true = []
for data in loader:
out = model(data.x, data.edge_index, data.edge_attr, data.batch)
_, predicted = torch.max(out.data, 1)
y_true.extend(data.y.cpu().numpy())
y_pred.extend(predicted.cpu().numpy())
return y_pred
class MyDataset(Dataset):
def __init__(self, smiles, response):
mols = [smi_to_pyg(smi, y) for smi, y in \
tqdm(zip(smiles, response), total=len(smiles))]
self.X = [m for m in mols if m]
def __getitem__(self, idx):
return self.X[idx]
def __len__(self):
return len(self.X)