torch_pgn.models package

Submodules

torch_pgn.models.DimeNet module

class torch_pgn.models.DimeNet.DimeNet(args: TrainArgs, atom_dim: int)

Bases: Module

The directional message passing neural network (DimeNet) from the "Directional Message Passing for Molecular Graphs" paper. DimeNet transforms messages based on the angle between them in a rotation-equivariant fashion.

Note

For an example of using a pretrained DimeNet variant, see examples/qm9_pretrained_dimenet.py.

Parameters:
  • hidden_channels (int) -- Hidden embedding size.

  • out_channels (int) -- Size of each output sample.

  • num_blocks (int) -- Number of building blocks.

  • num_bilinear (int) -- Size of the bilinear layer tensor.

  • num_spherical (int) -- Number of spherical harmonics.

  • num_radial (int) -- Number of radial basis functions.

  • cutoff (float, optional) -- Cutoff distance for interatomic interactions. (default: 5.0)

  • max_num_neighbors (int, optional) -- The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • envelope_exponent (int, optional) -- Shape of the smooth cutoff. (default: 5)

  • num_before_skip (int, optional) -- Number of residual layers in the interaction blocks before the skip connection. (default: 1)

  • num_after_skip (int, optional) -- Number of residual layers in the interaction blocks after the skip connection. (default: 2)

  • num_output_layers (int, optional) -- Number of linear layers for the output blocks. (default: 3)

  • act (str or Callable, optional) -- The activation function. (default: "swish")
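The default activation name, "swish", usually refers to x * sigmoid(x). The plain-Python sketch below shows only that formula. The helper name swish is illustrative and is not claimed to be how torch_pgn resolves the act argument.

```python
import math


def swish(x: float) -> float:
    """Swish activation: x * sigmoid(x), written as x / (1 + exp(-x))."""
    return x / (1.0 + math.exp(-x))


# Swish is close to the identity for large positive inputs and close to
# zero for large negative inputs.
values = [swish(v) for v in (-10.0, 0.0, 10.0)]
```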

forward(node_feats, edge_index, pos, batch=None)
reset_parameters()
training: bool
triplets(edge_index, num_nodes)
url = 'https://github.com/klicperajo/dimenet/raw/master/pretrained/dimenet'
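Directional message passing works on pairs of edges k->j and j->i that share the middle node j. As a rough illustration of what a triplet enumeration produces, the sketch below lists those edge pairs for a small graph in plain Python. The function name find_triplets, the (source, target) edge layout and the returned pair format are assumptions made for this example; the actual return values of triplets() are not documented here.

```python
from typing import Dict, List, Tuple


def find_triplets(edges: List[Tuple[int, int]]) -> List[Tuple[int, int]]:
    """Return (idx_kj, idx_ji) pairs of edge positions for every k->j->i path.

    `edges` holds directed (source, target) pairs. Paths that fold back onto
    themselves (k == i) are skipped.
    """
    # Group the positions of incoming edges by their target node.
    incoming: Dict[int, List[int]] = {}
    for edge_pos, (_, dst) in enumerate(edges):
        incoming.setdefault(dst, []).append(edge_pos)

    triplets = []
    for idx_ji, (j, i) in enumerate(edges):
        for idx_kj in incoming.get(j, []):
            k = edges[idx_kj][0]
            if k != i:
                triplets.append((idx_kj, idx_ji))
    return triplets


# A three-atom chain 0 - 1 - 2 with both directions of each bond.
chain = [(0, 1), (1, 0), (1, 2), (2, 1)]
pairs = find_triplets(chain)
```

For the chain above, the only angles are 2->1->0 and 0->1->2, so two edge pairs are returned.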
class torch_pgn.models.DimeNet.DimeNetPlusPlus(args: TrainArgs, atom_dim: int)

Bases: DimeNet

The DimeNet++ from the "Fast and Uncertainty-Aware Directional Message Passing for Non-Equilibrium Molecules" paper.

DimeNetPlusPlus is an upgrade to the DimeNet model that is 8x faster and 10% more accurate than DimeNet.

Parameters:
  • hidden_channels (int) -- Hidden embedding size.

  • out_channels (int) -- Size of each output sample.

  • num_blocks (int) -- Number of building blocks.

  • int_emb_size (int) -- Size of embedding in the interaction block.

  • basis_emb_size (int) -- Size of basis embedding in the interaction block.

  • out_emb_channels (int) -- Size of embedding in the output block.

  • num_spherical (int) -- Number of spherical harmonics.

  • num_radial (int) -- Number of radial basis functions.

  • cutoff (float, optional) -- Cutoff distance for interatomic interactions. (default: 5.0)

  • max_num_neighbors (int, optional) -- The maximum number of neighbors to collect for each node within the cutoff distance. (default: 32)

  • envelope_exponent (int, optional) -- Shape of the smooth cutoff. (default: 5)

  • num_before_skip (int, optional) -- Number of residual layers in the interaction blocks before the skip connection. (default: 1)

  • num_after_skip (int, optional) -- Number of residual layers in the interaction blocks after the skip connection. (default: 2)

  • num_output_layers (int, optional) -- Number of linear layers for the output blocks. (default: 3)

  • act (str or Callable, optional) -- The activation function. (default: "swish")

training: bool
url = 'https://raw.githubusercontent.com/gasteigerjo/dimenet/master/pretrained/dimenet_pp'

torch_pgn.models.FPEncoder module

class torch_pgn.models.FPEncoder.FPEncoder(args: TrainArgs)

Bases: Module

Pass-through encoder module for fp datasets.

forward(data)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
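The note above can be illustrated without PyTorch. The toy class below is a deliberately simplified stand-in for torch.nn.Module, not its real implementation. The names ToyModule and forward_hooks are invented for this sketch. It shows why calling the instance runs registered hooks while calling forward() directly does not.

```python
from typing import Callable, List, Tuple


class ToyModule:
    """Toy stand-in for the call-versus-forward behaviour of a module."""

    def __init__(self) -> None:
        self.forward_hooks: List[Callable[[float, float], None]] = []

    def forward(self, x: float) -> float:
        # The "recipe" of the forward pass lives here.
        return 2.0 * x

    def __call__(self, x: float) -> float:
        # Calling the instance wraps forward() and then runs the hooks.
        out = self.forward(x)
        for hook in self.forward_hooks:
            hook(x, out)
        return out


seen: List[Tuple[float, float]] = []
module = ToyModule()
module.forward_hooks.append(lambda inp, out: seen.append((inp, out)))

module(1.0)          # hooks run
module.forward(2.0)  # hooks are skipped
```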

training: bool

torch_pgn.models.GGNet module

class torch_pgn.models.GGNet.GGNet(args: TrainArgs, node_dim: int, bond_dim: int)

Bases: Module

Stock network from the QM9 prediction paper by Gilmer et al.

forward(data)

training: bool

torch_pgn.models.dimenet_utils module

class torch_pgn.models.dimenet_utils.BesselBasisLayer(num_radial, cutoff=5.0, envelope_exponent=5)

Bases: Module

forward(dist)

reset_parameters()
training: bool
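As a rough guide to what a Bessel radial basis computes, the sketch below follows the form described in the DimeNet paper: the distance is scaled by the cutoff, a smooth polynomial envelope is applied, and each basis function is a sine with frequency n * pi. It is plain Python for a single distance and assumes fixed frequencies; it is not the layer's actual code, and the function name bessel_basis is illustrative.

```python
import math
from typing import List


def bessel_basis(dist: float, num_radial: int, cutoff: float = 5.0,
                 envelope_exponent: int = 5) -> List[float]:
    """Evaluate num_radial enveloped sine basis functions at one distance."""
    x = dist / cutoff  # scaled distance in (0, 1]
    p = envelope_exponent + 1
    # Polynomial envelope that behaves like 1/x near zero and vanishes at x = 1.
    env = (1.0 / x
           - (p + 1) * (p + 2) / 2 * x ** (p - 1)
           + p * (p + 2) * x ** p
           - p * (p + 1) / 2 * x ** (p + 1))
    return [env * math.sin(n * math.pi * x) for n in range(1, num_radial + 1)]


# Six basis values for a pair of atoms 2.5 distance units apart.
features = bessel_basis(2.5, num_radial=6)
```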
class torch_pgn.models.dimenet_utils.EmbeddingBlock(atom_dim, num_radial, hidden_channels, act)

Bases: Module

forward(x, rbf, i, j)

reset_parameters()
training: bool
class torch_pgn.models.dimenet_utils.Envelope(exponent)

Bases: Module

forward(x)

training: bool
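The exponent argument controls how sharply the smooth cutoff decays. A minimal sketch, assuming the polynomial envelope from the DimeNet paper (whether this class matches it term for term is an assumption), shows the defining property: the value falls to zero exactly at the cutoff.

```python
def envelope(x: float, exponent: int = 5) -> float:
    """Polynomial envelope for a scaled distance x = d / cutoff in (0, 1]."""
    p = exponent + 1
    a = -(p + 1) * (p + 2) / 2
    b = p * (p + 2)
    c = -p * (p + 1) / 2
    return 1.0 / x + a * x ** (p - 1) + b * x ** p + c * x ** (p + 1)


# Close to 1/x for short distances, zero at the cutoff (x == 1).
near = envelope(0.1)
at_cutoff = envelope(1.0)
```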
class torch_pgn.models.dimenet_utils.InteractionBlock(hidden_channels, num_bilinear, num_spherical, num_radial, num_before_skip, num_after_skip, act)

Bases: Module

forward(x, rbf, sbf, idx_kj, idx_ji)

reset_parameters()
training: bool
class torch_pgn.models.dimenet_utils.InteractionPPBlock(hidden_channels, int_emb_size, basis_emb_size, num_spherical, num_radial, num_before_skip, num_after_skip, act)

Bases: Module

forward(x, rbf, sbf, idx_kj, idx_ji)

reset_parameters()
training: bool
class torch_pgn.models.dimenet_utils.OutputBlock(num_radial, hidden_channels, out_channels, num_layers, act)

Bases: Module

forward(x, rbf, i, num_nodes=None)

training: bool
class torch_pgn.models.dimenet_utils.OutputPPBlock(num_radial, hidden_channels, out_emb_channels, out_channels, num_layers, act)

Bases: Module

forward(x, rbf, i, num_nodes=None)

training: bool
class torch_pgn.models.dimenet_utils.ResidualLayer(hidden_channels, act)

Bases: Module

forward(x)

reset_parameters()
training: bool
class torch_pgn.models.dimenet_utils.SphericalBasisLayer(num_spherical, num_radial, cutoff=5.0, envelope_exponent=5)

Bases: Module

forward(dist, angle, idx_kj)

training: bool

torch_pgn.models.dmpnn_encoder module

class torch_pgn.models.dmpnn_encoder.MPNEncoder(args, atom_fdim: int, bond_fdim: int)

Bases: Module

An MPNEncoder is a message passing neural network for encoding a molecule.

forward(molgraph, features_batch=None)

Encodes a batch of molecular graphs.

Parameters:
  • mol_graph -- A BatchMolGraph representing a batch of molecular graphs.

  • features_batch -- A list of numpy arrays containing additional features.

  • atom_descriptors_batch -- A list of numpy arrays containing additional atomic descriptors.

Returns:

A PyTorch tensor of shape (num_molecules, hidden_size) containing the encoding of each molecule.

training: bool

torch_pgn.models.model module

class torch_pgn.models.model.PGNNetwork(args: TrainArgs, node_dim: int, bond_dim: int)

Bases: Module

A network that combines the message passing PFPEncoder with a feed-forward network for learning tasks.

construct_encoder()

Constructs the message passing network for encoding proximity graphs.

construct_feed_forward()

Constructs the feed-forward network used for regression tasks.

forward(data)

Runs the PFPNetwork on the input.

Parameters:
  • input -- Batch of proximity graphs.

Returns:

Output of the PFPNetwork.

training: bool

torch_pgn.models.nn_utils module

Helper functions for the model construction functions.

torch_pgn.models.nn_utils.get_pool_function(pool_type: str)

Returns the functional form of the pooling function selected by pool_type.
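A lookup of this kind is often written as a small registry keyed by name. The sketch below is a hypothetical illustration in plain Python: the keys "add", "mean" and "max", the list-based reducers and the helper name get_pool_function_sketch are all assumptions, and the options torch_pgn actually accepts are not documented here.

```python
from statistics import mean
from typing import Callable, Dict, List

# Hypothetical registry: these keys are illustrative, not the values that
# torch_pgn actually accepts for pool_type.
POOL_FUNCTIONS: Dict[str, Callable[[List[float]], float]] = {
    "add": sum,
    "mean": mean,
    "max": max,
}


def get_pool_function_sketch(pool_type: str) -> Callable[[List[float]], float]:
    """Return the reducer registered under pool_type or fail loudly."""
    try:
        return POOL_FUNCTIONS[pool_type]
    except KeyError:
        raise ValueError(f"Unsupported pool_type: {pool_type!r}") from None


pool = get_pool_function_sketch("mean")
pooled = pool([1.0, 2.0, 3.0])
```

Failing with an explicit error on an unknown name keeps a misspelled configuration value from silently falling back to a default.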

torch_pgn.models.nn_utils.get_sparse_fnctn(sparse_type: str)
torch_pgn.models.nn_utils.index_select_ND(source, index)

Selects the message features from source corresponding to the atom or bond indices in index.

Parameters:
  • source -- A tensor of shape (num_bonds, hidden_size) containing message features.

  • index -- A tensor of shape (num_atoms/num_bonds, max_num_bonds) containing the atom or bond indices to select from source.

Returns:

A tensor of shape (num_atoms/num_bonds, max_num_bonds, hidden_size) containing the message features corresponding to the atoms/bonds specified in index.
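The shape behaviour described above can be checked with a small plain-Python sketch in which nested lists stand in for tensors. The helper name index_select_nd and the sample data are illustrative only; the real function operates on PyTorch tensors.

```python
from typing import List


def index_select_nd(source: List[List[float]],
                    index: List[List[int]]) -> List[List[List[float]]]:
    """Gather one row of `source` for every entry of the 2-D `index`."""
    return [[source[i] for i in row] for row in index]


# 3 bonds with hidden_size 2; 2 atoms with up to 2 incoming bonds each.
source = [[0.0, 0.0], [1.0, 1.5], [2.0, 2.5]]
index = [[1, 2], [2, 0]]
selected = index_select_nd(source, index)
```

The result has shape (2, 2, 2), that is (num_atoms, max_num_bonds, hidden_size).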

torch_pgn.models.pfp_encoder module

class torch_pgn.models.pfp_encoder.PFPEncoder(args: TrainArgs, node_dim: int, bond_dim: int)

Bases: Module

Message passing network used to encode interaction graphs for use in either regression or classification tasks.

construct_nn_conv()
forward(data)

training: bool

Module contents