"""KMNIST CNN Training Script

Python script for training a convolutional neural network (CNN) on the KMNIST
dataset. It imports the required libraries and modules and parses the path to
the training data from the command line. It then defines the network
hyperparameters (learning rate, batch size, number of epochs) and splits the
dataset into training and validation subsets. Finally, it wraps the training
subset in a DataLoader for batch processing, in preparation for the training
process.

The script uses PyTorch throughout, scikit-learn for the classification
report, and matplotlib to visualize the data; as a first step it displays the
shuffled training images one at a time. It selects the GPU as the compute
device when one is available, for accelerated computing.
"""

from dataset import HessDataset
from model import HessClass
from sklearn.metrics import classification_report
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import KMNIST
from torch.optim import Adam
from torch import nn
 
import matplotlib.pyplot as plt
import numpy as np
import argparse
import torch
import time
 
from tqdm import tqdm
 
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="KMINST CNN")
    parser.add_argument("TrainingData", type=str, help="path to training data")
 
    args = parser.parse_args()
 
    INIT_LR = 1e-3
    BATCH_SIZE=32
    EPOCHS=20
 
    TRAIN_SPLIT=0.75
    VAL_SPLIT = 1-TRAIN_SPLIT
 
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
    AllData = HessDataset(args.TrainingData)
 
    print("Splitting Training and validation set")
    numTrainSamples = int(len(AllData) * TRAIN_SPLIT)
    numValSamples = int(len(AllData) * VAL_SPLIT)
    adjust = len(AllData) - (numTrainSamples + numValSamples)
    numTrainSamples += adjust
    (trainData, valData) = random_split(AllData, [numTrainSamples, numValSamples], generator=torch.Generator().manual_seed(42))
 
    trainDataLoader = DataLoader(trainData, shuffle=True, batch_size=BATCH_SIZE)
 
    for x, y in trainDataLoader:
        for sx in x:
            fig, ax = plt.subplots(1,1,figsize=(10,7))
            ax.imshow(sx[0])
            plt.show()