Example Code for Creating Histogram Plots in Python using PyTorch and Matplotlib

This Python script loads a custom dataset (HessDataset) through PyTorch's DataLoader and saves each sample, a histogram stored as an image-like array, as a PDF plot made with matplotlib. It reads the samples one at a time (batch_size=1), draws each histogram with imshow, titles the plot with the sample's target value y, and writes it to the figs/ directory. The script also imports pieces of a larger training setup (the HessClass model, the Adam optimizer, sklearn's classification_report, torchvision's KMNIST dataset), but none of them are used by the plotting loop. This suggests the script is part of a larger machine learning project in which the same data is also used to train and evaluate a model.
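HessDataset and its data file are not included here, so the following is a self-contained sketch of the same load, iterate, and save-as-PDF workflow. ToyHistDataset and save_histogram_pdfs are hypothetical names. The stand-in dataset assumes each item is a (1, H, W) histogram tensor paired with a scalar float target, which is consistent with how the script indexes x[0] and formats y, but the real HessDataset layout may differ. The histograms are built with NumPy's histogram2d purely for illustration.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyHistDataset(Dataset):
    """Hypothetical stand-in for HessDataset: each item is a (1, H, W)
    2-D histogram tensor plus a 0-dim float target tensor."""

    def __init__(self, n_items=5, bins=32, seed=0):
        rng = np.random.default_rng(seed)
        self.items = []
        for _ in range(n_items):
            target = float(rng.uniform(0.0, 1.0))
            # Random 2-D points whose spread depends on the target value.
            pts = rng.normal(0.0, 1.0 + target, size=(1000, 2))
            hist, _, _ = np.histogram2d(pts[:, 0], pts[:, 1], bins=bins,
                                        range=[[-4, 4], [-4, 4]])
            # histogram2d puts x on axis 0; transpose so rows correspond to y.
            img = torch.tensor(hist.T, dtype=torch.float32).unsqueeze(0)
            self.items.append((img, torch.tensor(target)))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]


def save_histogram_pdfs(dataset, out_dir):
    """Save one PDF per sample and return the list of written paths."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    # batch_size=1 adds a leading batch dimension of size 1.
    for x, y in DataLoader(dataset, batch_size=1):
        x, y = x[0], y[0].item()  # strip the batch dimension
        fig, ax = plt.subplots(1, 1, figsize=(10, 7))
        # Channel 0 of the (1, H, W) histogram; origin="lower" puts low y at the bottom.
        ax.imshow(x[0].numpy(), origin="lower")
        ax.set_title(f"{y:0.6f}")
        path = os.path.join(out_dir, f"Plot_at_{y:0.6f}.pdf")
        fig.savefig(path, bbox_inches="tight")
        plt.close(fig)  # release the figure so memory does not grow per sample
        paths.append(path)
    return paths


if __name__ == "__main__":
    out = tempfile.mkdtemp()
    written = save_histogram_pdfs(ToyHistDataset(n_items=5), out)
    print(len(written), "PDFs written to", out)
```

The transpose plus origin="lower" is only needed because histogram2d indexes x along the first axis; if the stored histograms already use rows for y, plain imshow is enough.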

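from dataset import HessDataset
from model import HessClass

from sklearn.metrics import classification_report
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import KMNIST
from torch.optim import Adam
from torch import nn

import matplotlib.pyplot as plt
import numpy as np
import argparse
import torch
import time
import os

from tqdm import tqdm

# Note: HessClass, classification_report, random_split, ToTensor, KMNIST,
# Adam, nn, np, argparse, torch, time and tqdm are not used by the plotting
# loop below; only HessDataset, DataLoader, plt and os are needed here.

if __name__ == "__main__":
    # idxStop=100 presumably limits how many histograms are read from the file.
    AllData = HessDataset("../sampleDataGeneration/histograms_OPLIB_200.np7", idxStop=100)

    # savefig fails if the output directory does not exist yet.
    os.makedirs("figs", exist_ok=True)

    # batch_size=1: each batch holds a single (histogram, target) pair.
    dl = DataLoader(AllData, batch_size=1)
    for (x, y) in dl:
        x = x[0]          # drop the batch dimension
        y = y[0].item()   # scalar target as a Python float
        fig, ax = plt.subplots(1, 1, figsize=(10, 7))
        ax.imshow(x[0])   # first channel (assumes x is channels-first)
        ax.set_title(f"{y:0.6f}")
        fig.savefig(f"figs/Plot_at_{y:0.6f}.pdf", bbox_inches='tight')
        plt.close(fig)    # otherwise every figure stays open in memory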
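The loop writes one PDF per sample, which for idxStop=100 means many small files. If a single file is more convenient, matplotlib's PdfPages can collect every figure as a page of one PDF. The sketch below is an alternative, not what the original script does; save_all_to_one_pdf is a hypothetical helper, and it assumes each sample has already been reduced to a 2-D array and a float target.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages


def save_all_to_one_pdf(samples, path):
    """Write each (2-D image, float target) pair as one page of a single PDF.

    Returns the number of pages written.
    """
    n_pages = 0
    with PdfPages(path) as pdf:
        for image, target in samples:
            fig, ax = plt.subplots(1, 1, figsize=(10, 7))
            ax.imshow(np.asarray(image))
            ax.set_title(f"{target:0.6f}")
            pdf.savefig(fig, bbox_inches="tight")  # append this figure as a page
            plt.close(fig)  # free the figure once it is on the page
            n_pages += 1
    return n_pages


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    demo = [(rng.random((32, 32)), float(i) / 10) for i in range(3)]
    out_path = os.path.join(tempfile.mkdtemp(), "all_histograms.pdf")
    print(save_all_to_one_pdf(demo, out_path), "pages written to", out_path)
```

With the DataLoader from the script, the samples could be supplied as a generator such as `((x[0][0], y[0].item()) for x, y in dl)`, again assuming each x batch has shape (1, C, H, W).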