torch_pgn.models package
Submodules
torch_pgn.models.DimeNet module
- class torch_pgn.models.DimeNet.DimeNet(args: TrainArgs, atom_dim: int)
Bases: Module
The directional message passing neural network (DimeNet) from the "Directional Message Passing for Molecular Graphs" paper. DimeNet transforms messages based on the angle between them in a rotation-equivariant fashion.
Note
For an example of using a pretrained DimeNet variant, see examples/qm9_pretrained_dimenet.py.
- Parameters:
hidden_channels (int) -- Hidden embedding size.
out_channels (int) -- Size of each output sample.
num_blocks (int) -- Number of building blocks.
num_bilinear (int) -- Size of the bilinear layer tensor.
num_spherical (int) -- Number of spherical harmonics.
num_radial (int) -- Number of radial basis functions.
cutoff (float, optional) -- Cutoff distance for interatomic interactions. (default: 5.0)
max_num_neighbors (int, optional) -- The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)
envelope_exponent (int, optional) -- Shape of the smooth cutoff. (default: 5)
num_before_skip (int, optional) -- Number of residual layers in the interaction blocks before the skip connection. (default: 1)
num_after_skip (int, optional) -- Number of residual layers in the interaction blocks after the skip connection. (default: 2)
num_output_layers (int, optional) -- Number of linear layers for the output blocks. (default: 3)
act (str or Callable, optional) -- The activation function. (default: "swish")
- forward(node_feats, edge_index, pos, batch=None)
- reset_parameters()
- training: bool
- triplets(edge_index, num_nodes)
- url = 'https://github.com/klicperajo/dimenet/raw/master/pretrained/dimenet'
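The angle-based message passing described above operates on triplets of atoms: pairs of directed edges k->j and j->i that share the middle atom j. The following is a minimal pure-Python sketch of that idea, not the implementation behind triplets() or forward(). The helper names find_triplets and angle_at are hypothetical, and the convention of measuring the angle at the shared atom j is an assumption.

```python
import math

def find_triplets(edges):
    """Enumerate (k, j, i) such that edges k->j and j->i both exist and k != i."""
    incoming = {}
    for src, dst in edges:
        incoming.setdefault(dst, []).append(src)
    triplets = []
    for j, i in edges:
        for k in incoming.get(j, []):
            if k != i:  # skip the degenerate triplet i->j->i
                triplets.append((k, j, i))
    return triplets

def angle_at(pos, k, j, i):
    """Angle in radians at atom j between the vectors j->k and j->i (assumed convention)."""
    v1 = [a - b for a, b in zip(pos[k], pos[j])]
    v2 = [a - b for a, b in zip(pos[i], pos[j])]
    dot = sum(a * b for a, b in zip(v1, v2))
    cross = [
        v1[1] * v2[2] - v1[2] * v2[1],
        v1[2] * v2[0] - v1[0] * v2[2],
        v1[0] * v2[1] - v1[1] * v2[0],
    ]
    cross_norm = math.sqrt(sum(c * c for c in cross))
    # atan2 of |v1 x v2| and v1 . v2 is numerically stabler than acos
    return math.atan2(cross_norm, dot)

# Toy geometry: atom 0 in the centre, atoms 1 and 2 at a right angle to each other.
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.0, 0.0)]
edges = [(0, 1), (1, 0), (0, 2), (2, 0)]
triplets = find_triplets(edges)
angles = [angle_at(pos, k, j, i) for k, j, i in triplets]
```

Because the angle depends only on relative positions, rotating or translating the whole molecule leaves it unchanged.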
- class torch_pgn.models.DimeNet.DimeNetPlusPlus(args: TrainArgs, atom_dim: int)
Bases: DimeNet
The DimeNet++ from the "Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules" paper.
DimeNetPlusPlus is an upgrade to the DimeNet model that is 8x faster and 10% more accurate than DimeNet.
- Parameters:
hidden_channels (int) -- Hidden embedding size.
out_channels (int) -- Size of each output sample.
num_blocks (int) -- Number of building blocks.
int_emb_size (int) -- Size of embedding in the interaction block.
basis_emb_size (int) -- Size of basis embedding in the interaction block.
out_emb_channels (int) -- Size of embedding in the output block.
num_spherical (int) -- Number of spherical harmonics.
num_radial (int) -- Number of radial basis functions.
cutoff (float, optional) -- Cutoff distance for interatomic interactions. (default: 5.0)
max_num_neighbors (int, optional) -- The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)
envelope_exponent (int, optional) -- Shape of the smooth cutoff. (default: 5)
num_before_skip (int, optional) -- Number of residual layers in the interaction blocks before the skip connection. (default: 1)
num_after_skip (int, optional) -- Number of residual layers in the interaction blocks after the skip connection. (default: 2)
num_output_layers (int, optional) -- Number of linear layers for the output blocks. (default: 3)
act (str or Callable, optional) -- The activation function. (default: "swish")
- training: bool
- url = 'https://raw.githubusercontent.com/gasteigerjo/dimenet/master/pretrained/dimenet_pp'
torch_pgn.models.FPEncoder module
- class torch_pgn.models.FPEncoder.FPEncoder(args: TrainArgs)
Bases: Module
Pass-through encoder module for fp datasets.
- forward(data)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
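The note on forward() applies to every module on this page. A minimal sketch of the difference, using a hypothetical PassThrough module as a stand-in for an encoder (it is not FPEncoder itself, whose constructor needs a TrainArgs object):

```python
import torch
from torch import nn

class PassThrough(nn.Module):
    """Hypothetical stand-in for an encoder; it simply returns its input."""

    def forward(self, x):
        return x

calls = []
model = PassThrough()
# A forward hook is invoked as hook(module, inputs, output) after forward() runs.
handle = model.register_forward_hook(
    lambda module, inputs, output: calls.append("hook")
)

x = torch.ones(2, 3)
model(x)          # goes through __call__, so the hook fires
model.forward(x)  # bypasses __call__, so the hook is skipped
handle.remove()   # detach the hook once it is no longer needed
```

After both calls, calls holds a single entry: only the call on the instance triggered the hook.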
torch_pgn.models.GGNet module
- class torch_pgn.models.GGNet.GGNet(args: TrainArgs, node_dim: int, bond_dim: int)
Bases: Module
Stock network from the QM9 prediction paper by Gilmer et al.
- forward(data)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
torch_pgn.models.dimenet_utils module
- class torch_pgn.models.dimenet_utils.BesselBasisLayer(num_radial, cutoff=5.0, envelope_exponent=5)
Bases: Module
- forward(dist)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- reset_parameters()
- training: bool
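BesselBasisLayer takes num_radial and cutoff, which suggests the radial basis from the DimeNet paper: sqrt(2/c) * sin(n*pi*d/c) / d for n = 1..num_radial, where c is the cutoff. The sketch below evaluates that formula in plain Python. It assumes the layer follows the paper's form; the actual layer may use learnable frequencies, a different normalisation, and an envelope factor. The function name bessel_basis is hypothetical.

```python
import math

def bessel_basis(d, num_radial=6, cutoff=5.0):
    """Evaluate sqrt(2/c) * sin(n*pi*d/c) / d for n = 1..num_radial (requires d > 0)."""
    norm = math.sqrt(2.0 / cutoff)
    return [
        norm * math.sin(n * math.pi * d / cutoff) / d
        for n in range(1, num_radial + 1)
    ]

# One interatomic distance, halfway to the cutoff.
basis = bessel_basis(2.5)
```

Each distance is expanded into num_radial features, so a batch of E edge distances would yield an (E, num_radial) array.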
- class torch_pgn.models.dimenet_utils.EmbeddingBlock(atom_dim, num_radial, hidden_channels, act)
Bases: Module
- forward(x, rbf, i, j)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- reset_parameters()
- training: bool
- class torch_pgn.models.dimenet_utils.Envelope(exponent)
Bases: Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
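The "smooth cutoff" controlled by envelope_exponent can be illustrated with the polynomial envelope from the DimeNet paper, u(d) = 1 + a*d^p + b*d^(p+1) + c*d^(p+2), evaluated on the scaled distance d = r / cutoff. The sketch below is written under assumptions: that Envelope follows this polynomial, and that p = envelope_exponent + 1 (so the default of 5 gives p = 6). The function name envelope is hypothetical, and the actual module may fold in an extra 1/d factor.

```python
def envelope(d, p=6):
    """Polynomial envelope on the scaled distance d = r / cutoff.

    Equals 1 at d = 0 and decays to 0 at d = 1 with vanishing first and
    second derivatives, so the basis functions switch off smoothly.
    """
    if d >= 1.0:
        return 0.0
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1 + a * d**p + b * d ** (p + 1) + c * d ** (p + 2)

# Sample the envelope from the origin out to the cutoff.
values = [envelope(d) for d in (0.0, 0.5, 0.999, 1.0)]
```

A larger p keeps the envelope close to 1 for longer before it drops toward the cutoff.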
- class torch_pgn.models.dimenet_utils.InteractionBlock(hidden_channels, num_bilinear, num_spherical, num_radial, num_before_skip, num_after_skip, act)
Bases: Module
- forward(x, rbf, sbf, idx_kj, idx_ji)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- reset_parameters()
- training: bool
- class torch_pgn.models.dimenet_utils.InteractionPPBlock(hidden_channels, int_emb_size, basis_emb_size, num_spherical, num_radial, num_before_skip, num_after_skip, act)
Bases: Module
- forward(x, rbf, sbf, idx_kj, idx_ji)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- reset_parameters()
- training: bool
- class torch_pgn.models.dimenet_utils.OutputBlock(num_radial, hidden_channels, out_channels, num_layers, act)
Bases: Module
- forward(x, rbf, i, num_nodes=None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class torch_pgn.models.dimenet_utils.OutputPPBlock(num_radial, hidden_channels, out_emb_channels, out_channels, num_layers, act)
Bases: Module
- forward(x, rbf, i, num_nodes=None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
- class torch_pgn.models.dimenet_utils.ResidualLayer(hidden_channels, act)
Bases: Module
- forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- reset_parameters()
- training: bool
- class torch_pgn.models.dimenet_utils.SphericalBasisLayer(num_spherical, num_radial, cutoff=5.0, envelope_exponent=5)
Bases: Module
- forward(dist, angle, idx_kj)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool
torch_pgn.models.dmpnn_encoder module
- class torch_pgn.models.dmpnn_encoder.MPNEncoder(args, atom_fdim: int, bond_fdim: int)
Bases: Module
An MPNEncoder is a message passing neural network for encoding a molecule.
- forward(molgraph, features_batch=None)
Encodes a batch of molecular graphs.
- Parameters:
molgraph -- A BatchMolGraph representing a batch of molecular graphs.
features_batch -- A list of numpy arrays containing additional features.
atom_descriptors_batch -- A list of numpy arrays containing additional atomic descriptors.
- Returns:
A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.
- training: bool
torch_pgn.models.model module
- class torch_pgn.models.model.PGNNetwork(args: TrainArgs, node_dim: int, bond_dim: int)
Bases: Module
A network that includes the message passing PFPEncoder and a feed-forward network for learning tasks.
- construct_encoder()
Constructs the message passing network for encoding proximity graphs.
- construct_feed_forward()
Constructs the feed-forward network used for regression tasks.
- forward(data)
Runs the PGNNetwork on the input.
- Parameters:
data -- Batch of proximity graphs.
- Returns:
Output of the PGNNetwork.
- training: bool
torch_pgn.models.nn_utils module
Helper functions for the model construction functions.
- torch_pgn.models.nn_utils.get_pool_function(pool_type: str)
Returns a functional form of the pooling function selected by pool_type.
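A string-to-function lookup of this kind is often written as a small registry. The sketch below shows the pattern in plain Python. The registry contents ("sum", "mean", "max") and the name get_pool_function_sketch are hypothetical; the pool types that torch_pgn accepts, and the graph pooling functions it returns, are not documented here and may differ.

```python
from statistics import mean

# Hypothetical registry; the pool types accepted by torch_pgn may differ.
POOL_FUNCTIONS = {"sum": sum, "mean": mean, "max": max}

def get_pool_function_sketch(pool_type: str):
    """Return the callable registered under pool_type, or raise ValueError."""
    try:
        return POOL_FUNCTIONS[pool_type]
    except KeyError:
        raise ValueError(f"Unsupported pool type: {pool_type!r}") from None

pool = get_pool_function_sketch("mean")
pooled = pool([1.0, 2.0, 3.0])  # pools three node values into one graph-level value
```

Failing loudly on an unknown key keeps a misspelled configuration value from silently selecting a default.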
- torch_pgn.models.nn_utils.get_sparse_fnctn(sparse_type: str)
- torch_pgn.models.nn_utils.index_select_ND(source, index)
Selects the message features from source corresponding to the atom or bond indices in index.
- Parameters:
source -- A tensor of shape (num_bonds, hidden_size) containing message features.
index -- A tensor of shape (num_atoms/num_bonds, max_num_bonds) containing the atom or bond indices to select from source.
- Returns:
A tensor of shape (num_atoms/num_bonds, max_num_bonds, hidden_size) containing the message features corresponding to the atoms/bonds specified in index.
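The shape contract above can be checked with nested lists standing in for tensors. This is a sketch of the selection semantics only, assuming each entry of index is a row number into source. The name index_select_nd_sketch is hypothetical, and the real function operates on PyTorch tensors.

```python
def index_select_nd_sketch(source, index):
    """For each row of index, gather the matching rows of source.

    source: list of num_bonds rows, each of length hidden_size.
    index:  list of rows, each holding max_num_bonds row numbers into source.
    Returns a nested list shaped (len(index), max_num_bonds, hidden_size).
    """
    return [[source[i] for i in row] for row in index]

source = [[0.0, 0.0], [1.0, 1.5], [2.0, 2.5]]  # (num_bonds=3, hidden_size=2)
index = [[1, 2], [0, 1]]                       # (2 atoms, max_num_bonds=2)
selected = index_select_nd_sketch(source, index)
```

Row 0 of source is all zeros here, which mirrors the common trick of reserving index 0 as padding for atoms with fewer than max_num_bonds neighbors (an assumption about usage, not something this page states).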
torch_pgn.models.pfp_encoder module
- class torch_pgn.models.pfp_encoder.PFPEncoder(args: TrainArgs, node_dim: int, bond_dim: int)
Bases: Module
Message passing network used to encode interaction graphs for use in either regression or classification tasks.
- construct_nn_conv()
- forward(data)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- training: bool