torch_pgn.evaluate package
Submodules
torch_pgn.evaluate.plot_utils module
- torch_pgn.evaluate.plot_utils.plot_correlation(model, args, data_loader, mean=0, std=1, metrics=None, filename='train_correlation', fit=True)
Simple method to plot correlations for a model.
:param model: The model to be evaluated.
:param args: TrainArgs object containing the parameters used to train the model.
:param data_loader: The torch DataLoader containing the working_data to be plotted.
:param filename: The name of the plot file.
:param fit: Whether to include a trendline on the plot.
:return: None (the plot is saved to the savedir results file).
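The docstring does not say how mean, std, or fit are applied internally. As a rough illustration only, the sketch below assumes that targets and predictions have already been collected (for example by running the model over data_loader). It also assumes that mean and std undo a z-score normalization (y * std + mean). The helper names denormalize, correlation_summary and plot_correlation_sketch are hypothetical and are not part of torch_pgn.

```python
import numpy as np


def denormalize(values, mean=0.0, std=1.0):
    """Map z-scored values back to the original scale (assumed normalization)."""
    return np.asarray(values, dtype=float) * std + mean


def correlation_summary(y_true, y_pred, mean=0.0, std=1.0):
    """Return the Pearson r and a least-squares trendline for pred vs. true."""
    true_vals = denormalize(y_true, mean, std)
    pred_vals = denormalize(y_pred, mean, std)
    r = float(np.corrcoef(true_vals, pred_vals)[0, 1])
    # Degree-1 polyfit gives the trendline coefficients, highest power first.
    slope, intercept = np.polyfit(true_vals, pred_vals, 1)
    return {"pearson_r": r, "slope": float(slope), "intercept": float(intercept)}


def plot_correlation_sketch(y_true, y_pred, mean=0.0, std=1.0,
                            filename="train_correlation", fit=True):
    """Hypothetical stand-in: scatter predictions against targets and save a PNG."""
    import matplotlib
    matplotlib.use("Agg")  # render to a file without needing a display
    import matplotlib.pyplot as plt

    true_vals = denormalize(y_true, mean, std)
    pred_vals = denormalize(y_pred, mean, std)
    stats = correlation_summary(y_true, y_pred, mean, std)

    fig, ax = plt.subplots()
    ax.scatter(true_vals, pred_vals, s=10)
    if fit:
        # Draw the fitted line across the observed range of true values.
        xs = np.linspace(true_vals.min(), true_vals.max(), 100)
        ax.plot(xs, stats["slope"] * xs + stats["intercept"], color="red")
    ax.set_xlabel("True value")
    ax.set_ylabel("Predicted value")
    ax.set_title(f"r = {stats['pearson_r']:.3f}")
    fig.savefig(f"{filename}.png")
    plt.close(fig)
    return stats
```

The statistics are kept separate from the plotting so that they can be checked without rendering a figure. Matplotlib is imported only when a plot is actually requested.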