torch_pgn.train package
Submodules
torch_pgn.train.Trainer module
Class that simplifies running multiple training runs.
- class torch_pgn.train.Trainer.Trainer(args)
Bases: object
Class that loads and holds the arguments and working_data objects. Allows for easy evaluation and retraining when the same working_data is used multiple times.
- get_score()
Returns the score.
- load_checkpoint(path)
Loads a checkpoint file and sets it as the model to be used in training. :param path: The path of the checkpoint file to be loaded.
- load_data()
- run_training()
Runs training.
- set_hyperopt_args(hyperopt_args, reload_data=False)
Changes the arguments used for training. The working_data arguments must be the same as the initialization args unless reload_data is set to True. :param hyperopt_args: The new argument object. :param reload_data: Boolean toggle that determines whether the working_data will be reloaded. :return: None
torch_pgn.train.cross_validate_model module
- torch_pgn.train.cross_validate_model.cross_validation(args, train_data)
Function to run cross-validation to train a model. :param args: TrainArgs object containing the parameters for the cross-validation run. :param train_data: The training working_data to be used in cross-validation.
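Cross-validation partitions the training data into folds, trains on all but one fold and validates on the held-out fold, then repeats for every fold. The entry above does not say how many folds are used or how the data is split, so the following is only a minimal sketch assuming plain k-fold splitting over dataset indices; k_fold_indices is a hypothetical helper, not part of torch_pgn.

```python
import random
from typing import List, Tuple


def k_fold_indices(n_items: int, k: int, seed: int = 0) -> List[Tuple[List[int], List[int]]]:
    """Split range(n_items) into k (train_indices, validation_indices) pairs.

    Hypothetical helper: torch_pgn's own splitting scheme may differ.
    """
    if not 2 <= k <= n_items:
        raise ValueError("k must be between 2 and n_items")
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # deterministic shuffle for reproducibility
    # Spread the remainder over the first folds so sizes differ by at most one.
    fold_sizes = [n_items // k + (1 if i < n_items % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in fold_sizes:
        folds.append(indices[start:start + size])
        start += size
    splits = []
    for i, val_fold in enumerate(folds):
        # Every fold except the i-th one goes into the training split.
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        splits.append((train, val_fold))
    return splits
```

In a cross-validation loop, each (train, validation) pair would be used to train one model and score it on the held-out fold, and the per-fold scores would then be averaged.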
torch_pgn.train.evaluate_model module
- torch_pgn.train.evaluate_model.evaluate(model, data_loader, args, metrics, mean=0, std=1, remove_norm=True)
Function used to evaluate the model performance on a given dataset. :param model: The model to be evaluated. :param data_loader: The dataloader containing the working_data the model will be evaluated on. :param args: TrainArgs object containing the relevant arguments for evaluation. :param metrics: The metrics used to evaluate the model on the given working_data. :param mean: The mean of the non-normalized working_data. :param std: The standard deviation of the non-normalized working_data. :return: The values of the metrics.
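The mean and std parameters suggest that targets are z-score normalized for training. The sketch below assumes that remove_norm=True maps predictions and labels back to the original scale with value * std + mean before the metrics are computed; denormalize and evaluate_sketch are hypothetical names, and the sketch takes precomputed predictions and labels instead of a model and data loader.

```python
from typing import Callable, Dict, List, Sequence

MetricFn = Callable[[Sequence[float], Sequence[float]], float]


def denormalize(values: Sequence[float], mean: float, std: float) -> List[float]:
    """Invert z-score normalization: x = z * std + mean."""
    return [v * std + mean for v in values]


def evaluate_sketch(
    preds: Sequence[float],
    labels: Sequence[float],
    metrics: Dict[str, MetricFn],
    mean: float = 0.0,
    std: float = 1.0,
    remove_norm: bool = True,
) -> Dict[str, float]:
    """Compute each metric(labels, preds), optionally on the original scale."""
    if remove_norm:
        preds = denormalize(preds, mean, std)
        labels = denormalize(labels, mean, std)
    return {name: fn(labels, preds) for name, fn in metrics.items()}
```

With the defaults mean=0 and std=1 the de-normalization is the identity, so scale-dependent metrics such as RMSE only change when real statistics are passed in; an error metric measured on normalized data is multiplied by std after de-normalization.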
torch_pgn.train.hyperopt module
Hyperparameter optimization. Adapted from: https://github.com/chemprop/chemprop/blob/master/chemprop/hyperparameter_optimization.py
- torch_pgn.train.hyperopt.hyperopt(args)
Runs hyperparameter optimization. :param args: The arguments class containing the arguments used for optimization and training. :return: None
- torch_pgn.train.hyperopt.hyperparameter_optimization()
Processes hyperparameter optimization arguments and initiates an optimization run using the specified parameters.
This function serves as the entry point for the torch_pgn_hyperparameter_optimization command-line command.
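Hyperparameter optimization repeatedly trains with different settings and keeps the configuration with the best validation score. The module is adapted from chemprop's optimizer, whose search strategy is not described here, so plain random search is swapped in below purely to illustrate the loop; random_search, the search-space format and the toy objective are all hypothetical.

```python
import random
from typing import Any, Callable, Dict, List, Optional, Tuple

Params = Dict[str, Any]


def random_search(
    objective: Callable[[Params], float],
    space: Dict[str, List[Any]],
    num_iters: int,
    seed: int = 0,
) -> Tuple[Optional[Params], float]:
    """Sample configurations from `space` and return the one with the lowest objective."""
    rng = random.Random(seed)
    best_params: Optional[Params] = None
    best_score = float("inf")
    for _ in range(num_iters):
        # Draw one candidate value per hyperparameter.
        params = {name: rng.choice(values) for name, values in space.items()}
        score = objective(params)  # e.g. the validation error of a model trained with params
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score
```

In a real run the objective would train a model with the sampled arguments and return its validation score, which is why each iteration is expensive and the number of iterations is usually small.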
torch_pgn.train.run_training module
- torch_pgn.train.run_training.run_training(args)
Wrapper for running training given TrainArgs :param args: TrainArgs object with parameters to be used in training :return: torch_pgn.train.Trainer.Trainer object
- torch_pgn.train.run_training.train()
Processes training arguments and runs training using the specified parameters.
This serves as the entry point for the command-line 'train' command.
torch_pgn.train.train module
- torch_pgn.train.train.train(model, data_loader, loss_function, optimizer, scheduler, train_args, epoch_num=0, logger=None, writer=None, device='cpu')
Trains the model for an epoch. :param model: A PFPNetwork to be trained :param data_loader: The Dataloader to be used for the epoch of training :param loss_function: The function used to calculate the loss of the model :param optimizer: The optimizer used for training :param scheduler: A learning rate scheduler :param train_args: The arguments used to determine the training method :param epoch_num: The epoch number of training :param logger: A tensorboard logger used to record details of training :param writer: A tensorboard writer used to output recorded training details :return: The average loss of the epoch.
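train performs one pass over the data loader and returns the average loss of the epoch. As a toy sketch of that loop shape, the code below fits a linear model y ≈ w * x + b with plain mini-batch gradient descent on MSE; it stands in for the PFPNetwork, optimizer and scheduler only for illustration, train_one_epoch is a hypothetical name, and weighting the epoch average by batch size is an assumption.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # list of (x, y) pairs


def train_one_epoch(w: float, b: float, batches: List[Batch], lr: float = 0.1) -> Tuple[float, float, float]:
    """Run one epoch of mini-batch gradient descent on y ~ w * x + b.

    Returns the updated (w, b) and the sample-weighted average MSE of the epoch.
    """
    total_loss, total_items = 0.0, 0
    for batch in batches:
        n = len(batch)
        errors = [(w * x + b) - y for x, y in batch]
        loss = sum(e * e for e in errors) / n
        # Gradients of the batch MSE with respect to w and b.
        grad_w = 2.0 * sum(e * x for e, (x, _) in zip(errors, batch)) / n
        grad_b = 2.0 * sum(errors) / n
        w -= lr * grad_w
        b -= lr * grad_b
        total_loss += loss * n  # weight each batch loss by its size
        total_items += n
    return w, b, total_loss / total_items
```

Calling the function once per epoch and logging the returned average is the pattern the logger and writer arguments above are meant to support.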
torch_pgn.train.train_model module
- torch_pgn.train.train_model.train_model(args, train_data, validation_data, test_data=None)
Function to run a complete run of training. The function also constructs the model and writes the output of the training to the specified model directory (see documentation). :param args: The TrainArgs container that contains the training parameters and settings :param train_data: The training working_data as a ProximityGraphDataset :param validation_data: The validation working_data as a ProximityGraphDataset :param test_data: The testing working_data as a ProximityGraphDataset (defaults to None for training instances where the held-out test set is not used for evaluation) :return: The best model from training as determined by validation score, and a dictionary of the metrics used to evaluate validation performance.
torch_pgn.train.train_utils module
- torch_pgn.train.train_utils.format_batch(train_args, data)
- torch_pgn.train.train_utils.get_labels(data_loader)
Helper function to get the ground truth values. :param data_loader: The dataloader to get the ground truth values from. :return: A NumPy array of labels.
- torch_pgn.train.train_utils.get_metric_functions(metrics)
Returns the relevant metric functions given a list of valid metrics :param metrics: A list of valid metrics in {'rmse', 'mse', 'r2', 'pcc', 'aucroc', 'aucprc'} :return: A dictionary that maps metrics to functions.
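A minimal sketch of the name-to-function dispatch that get_metric_functions describes, written in pure Python for the four regression metrics. The classification metrics ('aucroc', 'aucprc') are left out, and METRICS and get_metric_functions_sketch are hypothetical names; torch_pgn's own metric implementations are not shown in this reference.

```python
import math
from typing import Callable, Dict, List, Sequence

Metric = Callable[[Sequence[float], Sequence[float]], float]


def mse(y_true, y_pred):
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def rmse(y_true, y_pred):
    return math.sqrt(mse(y_true, y_pred))


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_t = sum(y_true) / len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - mean_t) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def pcc(y_true, y_pred):
    """Pearson correlation coefficient."""
    n = len(y_true)
    mt, mp = sum(y_true) / n, sum(y_pred) / n
    cov = sum((t - mt) * (p - mp) for t, p in zip(y_true, y_pred))
    var_t = sum((t - mt) ** 2 for t in y_true)
    var_p = sum((p - mp) ** 2 for p in y_pred)
    return cov / math.sqrt(var_t * var_p)


METRICS: Dict[str, Metric] = {"mse": mse, "rmse": rmse, "r2": r2, "pcc": pcc}


def get_metric_functions_sketch(metrics: List[str]) -> Dict[str, Metric]:
    """Map each requested metric name to its function, rejecting unknown names."""
    unknown = [m for m in metrics if m not in METRICS]
    if unknown:
        raise ValueError(f"Unsupported metrics: {unknown}")
    return {m: METRICS[m] for m in metrics}
```

Returning a dictionary keeps the metric names attached to their values, so an evaluation routine can report results as {name: score} without a separate list of names.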
- torch_pgn.train.train_utils.load_args(path, device)
Loads args from a checkpoint file. :param path: The path which contains the checkpoint file. :param device: The device to set the args.device parameter to. :return: The loaded args.
- torch_pgn.train.train_utils.load_checkpoint(path, device, return_args=False)
Loads a checkpoint. :param path: The path which contains the checkpoint file. :param device: The device to load the model to. :return: The loaded model.
- torch_pgn.train.train_utils.make_save_directories(save_directory)
Prepares an empty save directory so that it has the structure expected for the output of training. :param save_directory: An empty directory where the output of training will be saved. :return: None
- torch_pgn.train.train_utils.mse_loss(predicted, actual, num_graphs)
Returns the MSE loss for the given predicted and ground truth values for a given number of graphs (batch size). :param predicted: The output of the model on the given batch of working_data :param actual: The ground truth value for the given graphs :param num_graphs: The number of graphs being evaluated (not used) :return: The average MSE loss over the given graphs.
- torch_pgn.train.train_utils.parse_loss(args)
Parses the loss function argument and returns the appropriate loss function (currently either RMSE or MSE). :param args: An instance of TrainArgs :return: The loss function
- torch_pgn.train.train_utils.predict(model, data_loader, args, progress_bar=True, return_labels=False, remove_norm=False)
Returns the result of applying the specified model to the working_data in the data_loader. :param model: The model being used to predict :param data_loader: The pytorch_geometric dataloader object containing the working_data to be evaluated. :param args: The TrainArgs object that contains the required accessory arguments :param return_labels: Boolean toggle of whether to collect and return the labels at the same time as generating the predictions. This is useful when the dataset being analyzed is shuffled. :return: The raw output of the model as a numpy array.
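A shuffled loader yields samples in a different order on every pass, so labels gathered in a separate pass would not line up with the predictions. The sketch below collects both in the same loop to keep them aligned; batches are modeled as (inputs, labels) tuples, and shuffled_loader and predict_sketch are hypothetical stand-ins for a pytorch_geometric loader and for predict.

```python
import random
from typing import Callable, List, Sequence, Tuple

Batch = Tuple[List[float], List[float]]  # (inputs, labels)


def shuffled_loader(xs: Sequence[float], ys: Sequence[float], batch_size: int, seed: int) -> List[Batch]:
    """Toy stand-in for a shuffling data loader."""
    order = list(range(len(xs)))
    random.Random(seed).shuffle(order)
    batches = []
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        batches.append(([xs[i] for i in idx], [ys[i] for i in idx]))
    return batches


def predict_sketch(model: Callable[[float], float], loader: List[Batch], return_labels: bool = False):
    preds, labels = [], []
    for inputs, targets in loader:
        preds.extend(model(x) for x in inputs)
        if return_labels:
            labels.extend(targets)  # same pass as the predictions, so the order matches
    return (preds, labels) if return_labels else preds
```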
- torch_pgn.train.train_utils.rmse_loss(predicted, actual, num_graphs)
Returns the RMSE loss for given predicted and ground_truth values for a given number of graphs (batch size) :param predicted: The output of the model on the given batch of working_data :param actual: The ground truth value for the given graphs :param num_graphs: The number of graphs being evaluated :return: The average RMSE loss over the given graphs.
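Both loss functions receive num_graphs, the number of graphs in the batch. The reductions below are assumptions, not torch_pgn's code: with one scalar target per graph, the MSE variant averages over all entries and ignores num_graphs (as documented), while the RMSE variant divides the summed squared error by num_graphs before taking the square root. The _sketch names are hypothetical, and plain lists stand in for tensors, so these versions do not support autograd.

```python
import math
from typing import Sequence


def mse_loss_sketch(predicted: Sequence[float], actual: Sequence[float], num_graphs: int) -> float:
    """Mean squared error over all entries; num_graphs is unused, as documented."""
    return sum((p - a) ** 2 for p, a in zip(predicted, actual)) / len(actual)


def rmse_loss_sketch(predicted: Sequence[float], actual: Sequence[float], num_graphs: int) -> float:
    """Square root of the summed squared error divided by the number of graphs."""
    squared_error = sum((p - a) ** 2 for p, a in zip(predicted, actual))
    return math.sqrt(squared_error / num_graphs)
```

When every graph has exactly one target, num_graphs equals the number of entries and the RMSE value is the square root of the MSE value.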
- torch_pgn.train.train_utils.save_checkpoint(path, model, args)
Save the current state of training including the model and the arguments used to instantiate the model. :param path: The path to save the state to. :param model: The current model. :param args: The training arguments used to construct/parameterize the model. :return: None
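Storing the construction arguments alongside the model parameters is what allows load_checkpoint and load_args to rebuild the model later. The sketch below shows that round trip with pickle standing in for PyTorch serialization; PyTorch code would typically call torch.save on a dict holding model.state_dict() and load it with torch.load, whose map_location argument selects the target device. The "args"/"state_dict" keys, the file name and all function names are hypothetical, not torch_pgn's checkpoint format.

```python
import os
import pickle
import tempfile
from typing import Any, Dict, Tuple


def save_checkpoint_sketch(path: str, state_dict: Dict[str, Any], args: Dict[str, Any]) -> None:
    """Write the model parameters and construction args to a single file."""
    with open(path, "wb") as f:
        pickle.dump({"args": args, "state_dict": state_dict}, f)


def load_checkpoint_sketch(path: str) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Read back (state_dict, args); only unpickle files from trusted sources."""
    with open(path, "rb") as f:
        checkpoint = pickle.load(f)
    return checkpoint["state_dict"], checkpoint["args"]


def roundtrip_demo() -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Save and reload a toy checkpoint inside a temporary directory."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "checkpoint.pkl")
        save_checkpoint_sketch(path, {"weight": [0.5, -1.0]}, {"hidden_dim": 64})
        return load_checkpoint_sketch(path)
```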