torch_pgn.train package

Submodules

torch_pgn.train.Trainer module

Class that simplifies performing multiple training runs.

class torch_pgn.train.Trainer.Trainer(args)

Bases: object

Class that loads and holds the arguments and working_data objects. Allows for easy evaluation and retraining when the same working_data is going to be used multiple times.

get_score()

Returns the score.

load_checkpoint(path)

Loads a checkpoint file and sets it as the model to be used in training.
:param path: The path of the checkpoint file to be loaded.

load_data()
run_training()

Runs training.

set_hyperopt_args(hyperopt_args, reload_data=False)

Changes the arguments used for training. The working_data arguments must be the same as the initialization args unless reload_data is set to True.
:param hyperopt_args: The new argument object.
:param reload_data: Boolean toggle that determines whether the working_data will be reloaded.
:return: None
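
The point of the Trainer is that the (expensive) working_data is loaded once and reused across runs with different arguments. The toy class below is a hypothetical stand-in, not the torch_pgn implementation: it only mirrors the method names documented above (load_data, set_hyperopt_args, run_training, get_score), and its "training" and its args keys (num_points, lr) are invented to make the data-reuse contract visible.

```python
class ToyTrainer:
    """Hypothetical stand-in illustrating the Trainer pattern; not torch_pgn code."""

    def __init__(self, args):
        self.args = dict(args)
        self.score = None
        self.load_count = 0  # counts how many times the data was (re)loaded
        self.load_data()

    def load_data(self):
        # Stand-in for an expensive dataset load.
        self.load_count += 1
        self.data = list(range(self.args["num_points"]))

    def set_hyperopt_args(self, hyperopt_args, reload_data=False):
        # Swap in new arguments; only reload the data when explicitly asked to.
        self.args = dict(hyperopt_args)
        if reload_data:
            self.load_data()

    def run_training(self):
        # Toy "training": the score depends on the cached data and the learning rate.
        self.score = sum(self.data) * self.args["lr"]

    def get_score(self):
        return self.score
```

Calling set_hyperopt_args repeatedly with reload_data=False leaves load_count at 1, which is the saving the Trainer is meant to provide during repeated runs such as hyperparameter optimization.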

torch_pgn.train.cross_validate_model module

torch_pgn.train.cross_validate_model.cross_validation(args, train_data)

Function to run cross-validation to train a model.
:param args: TrainArgs object containing the parameters for the cross-validation run.
:param train_data: The training working_data to be used in cross-validation.
:param test_data: The testing working_data, if loaded, to be used to evaluate model performance.
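
Cross-validation repeatedly splits the training data into a training part and a held-out validation part. A minimal k-fold index splitter is sketched below; the fold count, the seeded shuffle, and the round-robin fold assignment are assumptions, since the documentation above does not say how torch_pgn builds its folds.

```python
import random


def k_fold_indices(num_items, num_folds, seed=0):
    """Yield (train_idx, val_idx) pairs for k-fold cross-validation."""
    indices = list(range(num_items))
    random.Random(seed).shuffle(indices)
    # Assign items round-robin so fold sizes differ by at most one.
    folds = [indices[i::num_folds] for i in range(num_folds)]
    for k in range(num_folds):
        val_idx = sorted(folds[k])
        # Every item not in fold k goes into the training split.
        train_idx = sorted(i for j, fold in enumerate(folds) if j != k for i in fold)
        yield train_idx, val_idx
```

Each (train_idx, val_idx) pair would be used to subset train_data, train one model per fold (for example with train_model), and average the per-fold validation scores.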

torch_pgn.train.evaluate_model module

torch_pgn.train.evaluate_model.evaluate(model, data_loader, args, metrics, mean=0, std=1, remove_norm=True)

Function used to evaluate the model performance on a given dataset.
:param model: The model to be evaluated.
:param data_loader: The DataLoader containing the working_data the model will be evaluated on.
:param args: TrainArgs object containing the relevant arguments for evaluation.
:param metrics: The metrics used to evaluate the model on the given working_data.
:param mean: The mean of the non-normalized working_data.
:param std: The standard deviation of the non-normalized working_data.
:return: The value of the metrics.
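
The mean and std parameters suggest that targets are z-score normalized during training and mapped back to their original units before scoring. The sketch below assumes that remove_norm=True means "undo normalization before computing metrics" and that both predictions and labels are stored in normalized units; evaluate_sketch and denormalize are hypothetical names, and plain lists stand in for the model output and the DataLoader.

```python
import math


def denormalize(values, mean=0.0, std=1.0):
    # Invert z-score normalization: x = z * std + mean.
    return [v * std + mean for v in values]


def rmse(pred, true):
    return math.sqrt(sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true))


def evaluate_sketch(predictions, labels, metrics, mean=0.0, std=1.0, remove_norm=True):
    """Score normalized predictions and labels, optionally in original units."""
    if remove_norm:
        # Assumption: both arrays are normalized, so both are mapped back.
        predictions = denormalize(predictions, mean, std)
        labels = denormalize(labels, mean, std)
    return {name: fn(predictions, labels) for name, fn in metrics.items()}
```

With remove_norm=True, an error metric such as RMSE is reported in the original target units, which here means it is std times larger than the normalized-space value.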

torch_pgn.train.hyperopt module

Hyperparameter optimization. Adapted from: https://github.com/chemprop/chemprop/blob/master/chemprop/hyperparameter_optimization.py

torch_pgn.train.hyperopt.hyperopt(args)

Runs hyperparameter optimization.
:param args: The arguments class containing the arguments used for optimization and training.
:return: None
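
Hyperparameter optimization repeatedly trains with candidate settings and keeps the one with the best score. The chemprop module referenced above is built on the hyperopt package; the sketch below deliberately swaps in a plain exhaustive grid search so that it runs with the standard library only. The objective and the search space are hypothetical and would, in practice, train a model and return a validation loss.

```python
import itertools


def grid_search(objective, space):
    """Return the parameter dict from space that minimizes objective(params)."""
    names = list(space)
    best_params, best_score = None, float("inf")
    # Try every combination of candidate values.
    for values in itertools.product(*(space[name] for name in names)):
        params = dict(zip(names, values))
        score = objective(params)  # e.g. train with params, return validation loss
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score
```

The objective is minimized, so a higher-is-better metric such as r2 would have to be negated before being returned.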

torch_pgn.train.hyperopt.hyperparameter_optimization()

Processes hyperparameter optimization arguments and initiates an optimization run using the specified parameters.

This function serves as the entry point for the torch_pgn_hyperparameter_optimization command-line command.

torch_pgn.train.run_training module

torch_pgn.train.run_training.run_training(args)

Wrapper for running training given a TrainArgs object.
:param args: TrainArgs object with parameters to be used in training.
:return: A torch_pgn.train.Trainer.Trainer object.

torch_pgn.train.run_training.train()

Processes training arguments and runs training using the specified parameters.

This serves as the entry point for the command-line 'train' command.

torch_pgn.train.train module

torch_pgn.train.train.train(model, data_loader, loss_function, optimizer, scheduler, train_args, epoch_num=0, logger=None, writer=None, device='cpu')

Trains the model for one epoch.
:param model: A PFPNetwork to be trained.
:param data_loader: The DataLoader to be used for the epoch of training.
:param loss_function: The function used to calculate the loss of the model.
:param optimizer: The optimizer used for training.
:param scheduler: A learning rate scheduler.
:param train_args: The arguments used to determine the training method.
:param epoch_num: The epoch number of training.
:param logger: A TensorBoard logger used to record details of training.
:param writer: A TensorBoard writer used to output recorded training details.
:return: The average loss of the epoch.
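
One epoch of training follows the usual loop: forward pass on each batch, loss, gradients, optimizer step, and accumulation of the loss for the epoch average. The sketch below shows that structure on a hypothetical toy linear model in pure Python with hand-written gradients; it is not the torch_pgn implementation, omits the scheduler, logger, and writer, and assumes the epoch average is weighted by batch size.

```python
def train_one_epoch(weights, batches, lr=0.1):
    """One epoch of mini-batch SGD for y = w0 + w1 * x with MSE loss.

    Returns the updated weights and the average per-sample loss of the epoch.
    """
    w0, w1 = weights
    total_loss, total_samples = 0.0, 0
    for xs, ys in batches:
        n = len(xs)
        # Forward pass and per-batch MSE.
        preds = [w0 + w1 * x for x in xs]
        errors = [p - y for p, y in zip(preds, ys)]
        batch_loss = sum(e * e for e in errors) / n
        # Gradients of the batch MSE with respect to w0 and w1.
        grad_w0 = 2 * sum(errors) / n
        grad_w1 = 2 * sum(e * x for e, x in zip(errors, xs)) / n
        # Optimizer step (plain SGD).
        w0 -= lr * grad_w0
        w1 -= lr * grad_w1
        # Weight by batch size so a short final batch is not over-counted.
        total_loss += batch_loss * n
        total_samples += n
    return (w0, w1), total_loss / total_samples
```

In PyTorch the gradient lines would be replaced by loss.backward() followed by optimizer.step(), with optimizer.zero_grad() at the start of each batch.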

torch_pgn.train.train_model module

torch_pgn.train.train_model.train_model(args, train_data, validation_data, test_data=None)

Function to run a complete training run. The function also constructs the model and writes the output of training to the specified model directory (see documentation).
:param args: The TrainArgs container that holds the training parameters and settings.
:param train_data: The training working_data as a ProximityGraphDataset.
:param validation_data: The validation working_data as a ProximityGraphDataset.
:param test_data: The testing working_data as a ProximityGraphDataset (defaults to None for training runs where the held-out test set is not used for evaluation).
:return: The best model from training, as determined by validation score, and a dictionary of the metrics used to evaluate validation performance.
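
Returning "the best model as determined by validation score" implies keeping a snapshot of the weights from the best epoch rather than the weights from the last epoch. A minimal sketch of that selection step follows; select_best_model, train_step, and validate are hypothetical names, and a dict stands in for a model.

```python
import copy


def select_best_model(model, train_step, validate, num_epochs, higher_is_better=False):
    """Train for num_epochs and return a copy of the model from the best validation epoch."""
    best_model, best_score, best_epoch = copy.deepcopy(model), None, -1
    for epoch in range(num_epochs):
        train_step(model)
        score = validate(model)
        if higher_is_better:
            improved = best_score is None or score > best_score
        else:
            improved = best_score is None or score < best_score
        if improved:
            # Snapshot the weights so later epochs cannot overwrite the best ones.
            best_model, best_score, best_epoch = copy.deepcopy(model), score, epoch
    return best_model, best_score, best_epoch
```

For a PyTorch module, copying model.state_dict() (or deep-copying the module) at the best epoch serves the same purpose.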

torch_pgn.train.train_utils module

torch_pgn.train.train_utils.format_batch(train_args, data)
torch_pgn.train.train_utils.get_labels(data_loader)

Helper function to get the ground truth values.
:param data_loader: The DataLoader to get the ground truth values from.
:return: A NumPy array of labels.

torch_pgn.train.train_utils.get_metric_functions(metrics)

Returns the relevant metric functions given a list of valid metrics.
:param metrics: A list of valid metrics in {'rmse', 'mse', 'r2', 'pcc', 'aucroc', 'aucprc'}.
:return: A dictionary that maps metrics to functions.
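
The mapping from metric names to functions can be illustrated with pure-Python versions of the regression metrics. This is a sketch, not the torch_pgn code: r2 is assumed to be the coefficient of determination (not the squared Pearson correlation), and 'aucroc'/'aucprc' are left out because a real implementation would likely delegate them to scikit-learn (for example sklearn.metrics.roc_auc_score).

```python
import math


def _mse(pred, true):
    return sum((p - t) ** 2 for p, t in zip(pred, true)) / len(true)


def _rmse(pred, true):
    return math.sqrt(_mse(pred, true))


def _r2(pred, true):
    # Coefficient of determination: 1 - SS_res / SS_tot.
    mean_t = sum(true) / len(true)
    ss_res = sum((t - p) ** 2 for p, t in zip(pred, true))
    ss_tot = sum((t - mean_t) ** 2 for t in true)
    return 1.0 - ss_res / ss_tot


def _pcc(pred, true):
    # Pearson correlation coefficient.
    mp, mt = sum(pred) / len(pred), sum(true) / len(true)
    cov = sum((p - mp) * (t - mt) for p, t in zip(pred, true))
    sp = math.sqrt(sum((p - mp) ** 2 for p in pred))
    st = math.sqrt(sum((t - mt) ** 2 for t in true))
    return cov / (sp * st)


_METRICS = {"mse": _mse, "rmse": _rmse, "r2": _r2, "pcc": _pcc}


def get_metric_functions_sketch(metrics):
    """Map each requested metric name to its function, rejecting unknown names."""
    unknown = set(metrics) - _METRICS.keys()
    if unknown:
        raise ValueError(f"Unsupported metrics: {sorted(unknown)}")
    return {m: _METRICS[m] for m in metrics}
```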

torch_pgn.train.train_utils.load_args(path, device)

Loads args from a checkpoint file.
:param path: The path which contains the checkpoint file.
:param device: The device to set the args.device parameter to.
:return: The loaded arguments.

torch_pgn.train.train_utils.load_checkpoint(path, device, return_args=False)

Loads a checkpoint.
:param path: The path which contains the checkpoint file.
:param device: The device to load the model onto.
:return: The loaded model file.

torch_pgn.train.train_utils.make_save_directories(save_directory)

Formats the empty save directory so that it has the expected structure.
:param save_directory: An empty directory where the output of training will be saved.
:return: None

torch_pgn.train.train_utils.mse_loss(predicted, actual, num_graphs)

Returns the MSE loss for the given predicted and ground truth values for a given number of graphs (batch size).
:param predicted: The output of the model on the given batch of working_data.
:param actual: The ground truth value for the given graphs.
:param num_graphs: The number of graphs being evaluated (not used).
:return: The average MSE loss over the given graphs.

torch_pgn.train.train_utils.parse_loss(args)

Parses the loss function argument and returns the appropriate loss function (currently either RMSE or MSE).
:param args: An instance of TrainArgs.
:return: The loss function.
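
The two losses in this module and the name-based dispatch of parse_loss can be sketched as follows. These are hypothetical pure-Python stand-ins: the real functions operate on tensors, parse_loss reads the choice from a TrainArgs instance (the attribute name is not documented here, so the sketch takes a plain string), and the RMSE sketch assumes one target value per graph so that num_graphs is the divisor.

```python
import math


def mse_loss_sketch(predicted, actual, num_graphs):
    # num_graphs is unused, mirroring the documented mse_loss signature.
    return sum((p - a) ** 2 for p, a in zip(predicted, actual)) / len(actual)


def rmse_loss_sketch(predicted, actual, num_graphs):
    # Assumption: one target per graph, so the mean is taken over num_graphs.
    return math.sqrt(sum((p - a) ** 2 for p, a in zip(predicted, actual)) / num_graphs)


def parse_loss_sketch(loss_name):
    """Return the loss function registered under loss_name."""
    losses = {"mse": mse_loss_sketch, "rmse": rmse_loss_sketch}
    try:
        return losses[loss_name]
    except KeyError:
        raise ValueError(f"Unknown loss function: {loss_name!r}") from None
```

In PyTorch the RMSE would typically be written as torch.sqrt(torch.nn.functional.mse_loss(predicted, actual)).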

torch_pgn.train.train_utils.predict(model, data_loader, args, progress_bar=True, return_labels=False, remove_norm=False)

Returns the result of applying the specified model to the working_data in the data_loader.
:param model: The model being used to predict.
:param data_loader: The pytorch_geometric DataLoader object containing the working_data to be evaluated.
:param args: The TrainArgs object that contains the required accessory arguments.
:param return_labels: Boolean toggle that determines whether to collect and return the labels while generating the predictions. This is useful when the dataset being analyzed is shuffled.
:return: The raw output of the model as a numpy array.
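
The reason return_labels matters for shuffled data: if labels are collected in a separate pass, a shuffling loader yields them in a different order from the predictions. Collecting both in the same pass keeps them aligned. The sketch below uses a hypothetical shuffled_batches generator as a stand-in for a shuffling DataLoader and a plain callable as the model.

```python
import random


def shuffled_batches(pairs, batch_size, seed):
    """Toy stand-in for a shuffling DataLoader: yields (features, targets) batches."""
    pairs = list(pairs)
    random.Random(seed).shuffle(pairs)
    for i in range(0, len(pairs), batch_size):
        chunk = pairs[i:i + batch_size]
        yield [x for x, _ in chunk], [y for _, y in chunk]


def predict_sketch(model, data_loader, return_labels=False):
    """Apply model to every batch; optionally collect labels in the same pass."""
    predictions, labels = [], []
    for features, targets in data_loader:
        predictions.extend(model(x) for x in features)
        if return_labels:
            # Gathered alongside the predictions, so the order always matches.
            labels.extend(targets)
    return (predictions, labels) if return_labels else predictions
```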

torch_pgn.train.train_utils.rmse_loss(predicted, actual, num_graphs)

Returns the RMSE loss for the given predicted and ground truth values for a given number of graphs (batch size).
:param predicted: The output of the model on the given batch of working_data.
:param actual: The ground truth value for the given graphs.
:param num_graphs: The number of graphs being evaluated.
:return: The average RMSE loss over the given graphs.

torch_pgn.train.train_utils.save_checkpoint(path, model, args)

Saves the current state of training, including the model and the arguments used to instantiate the model.
:param path: The path to save the state to.
:param model: The current model.
:param args: The training arguments used to construct/parameterize the model.
:return: None
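
Storing the arguments next to the weights is what lets load_checkpoint and load_args rebuild the model later. In a PyTorch codebase this would typically use torch.save and torch.load(path, map_location=device); the sketch below substitutes the standard-library pickle module so it runs without PyTorch, and the checkpoint keys "state_dict" and "args" are hypothetical.

```python
import argparse
import pickle


def save_checkpoint_sketch(path, model_state, args):
    """Store model parameters and the args needed to rebuild the model in one file."""
    with open(path, "wb") as f:
        pickle.dump({"state_dict": model_state, "args": vars(args)}, f)


def load_checkpoint_sketch(path, return_args=False):
    """Load the stored state, optionally returning the saved args as well."""
    with open(path, "rb") as f:
        checkpoint = pickle.load(f)
    args = argparse.Namespace(**checkpoint["args"])
    # A real loader would rebuild the model from args here and load state_dict into it.
    state = checkpoint["state_dict"]
    return (state, args) if return_args else state
```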

Module contents