Example Code for Creating Histogram Plots in Python using PyTorch and Matplotlib
This Python script loads a custom dataset (HessDataset) through PyTorch's DataLoader and saves one matplotlib figure per sample as a PDF. The dataset is constructed with idxStop=100 and wrapped in a DataLoader with a batch size of 1. For each (x, y) pair, the script strips the batch dimension, draws the first slice of x with imshow, uses y as the title, and writes the figure to figs/ under a name built from y formatted to six decimal places. The script also imports a model class (HessClass), sklearn's classification_report, the Adam optimizer, KMNIST, and other training utilities, but none of them are used here. This suggests the file is an excerpt from, or a starting point for, a larger classification project that analyzes and visualizes histogram data.
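import argparse
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import classification_report
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import KMNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm

# Project-local modules (not shown here).
from dataset import HessDataset
from model import HessClass

if __name__ == "__main__":
    AllData = HessDataset("../sampleDataGeneration/histograms_OPLIB_200.np7", idxStop=100)
    dl = DataLoader(AllData, batch_size=1)

    # savefig does not create missing directories, so make sure figs/ exists.
    os.makedirs("figs", exist_ok=True)

    for (x, y) in dl:
        # batch_size=1, so index 0 removes the batch dimension.
        x = x[0]
        # The format spec below needs a scalar, so convert the label to a float.
        y = float(y[0])

        fig, ax = plt.subplots(1, 1, figsize=(10, 7))
        ax.imshow(x[0])
        ax.set_title(y)
        plt.savefig(f"figs/Plot_at_{y:0.6f}.pdf", bbox_inches='tight')
        # Close the figure so open figures do not accumulate across iterations.
        plt.close(fig)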
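
Because HessDataset and its data file are not shown, the script above cannot be run on its own. The following is a minimal sketch of the same load, iterate, and save loop. It assumes that each sample is a single-channel 2D array paired with one scalar label; the real dataset may be laid out differently. ToyHistogramDataset and save_histogram_plots are hypothetical names introduced for illustration only and are not part of the original project.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, Dataset


class ToyHistogramDataset(Dataset):
    """Hypothetical stand-in for HessDataset: yields (image, label) pairs."""

    def __init__(self, n_items=3, size=8, seed=0):
        gen = torch.Generator().manual_seed(seed)
        # Shape (n_items, 1, size, size): one channel per 2D histogram (assumed).
        self.images = torch.rand(n_items, 1, size, size, generator=gen)
        # One scalar label per item (assumed; the real labels are not shown).
        self.labels = torch.linspace(0.1, 0.3, n_items)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


def save_histogram_plots(dataset, out_dir):
    """Save one PDF per dataset item and return the written paths."""
    os.makedirs(out_dir, exist_ok=True)
    paths = []
    for x, y in DataLoader(dataset, batch_size=1):
        image = x[0, 0].numpy()  # drop the batch and channel dimensions
        label = y[0].item()      # convert the 0-dim tensor to a Python float
        fig, ax = plt.subplots(1, 1, figsize=(10, 7))
        ax.imshow(image)
        ax.set_title(f"{label:0.6f}")
        path = os.path.join(out_dir, f"Plot_at_{label:0.6f}.pdf")
        fig.savefig(path, bbox_inches="tight")
        plt.close(fig)           # release the figure before the next iteration
        paths.append(path)
    return paths


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as tmp:
        written = save_histogram_plots(ToyHistogramDataset(), tmp)
        print(len(written), "plots written")
```

Naming files after the label means that two samples whose labels agree to six decimal places would overwrite each other. If that can happen in the real data, adding the sample index to the file name would avoid it.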