"""Deep-learning training script for the KMNIST dataset.

Trains a convolutional neural network (CNN) on KMNIST: the data is loaded
through the local HessDataset wrapper and split into training and validation
sets, the model architecture is defined in a separate class (HessClass), the
training loop runs with a loss function and optimizer, and training/validation
loss and accuracy are plotted over the epochs. Command-line arguments specify
the path to the training data, where to save the trained model, and where to
save the loss plot. PyTorch handles model training and matplotlib the plotting.
"""
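# Example invocation (script name and paths are illustrative):
#   python train.py path/to/training_data -m output/model.pth -p output/loss.png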

# Select a non-interactive matplotlib backend before pyplot is imported,
# so figures can be rendered on headless machines.
import matplotlib
matplotlib.use('Agg')

import argparse
import time

import matplotlib.pyplot as plt
import numpy as np  # used by the disabled test-set evaluation below
import torch
from sklearn.metrics import classification_report  # used by the disabled test-set evaluation below
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from dataset import HessDataset
from model import HessClass
 
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="KMNIST CNN")
    parser.add_argument("TrainingData", type=str, help="path to the training data")
    parser.add_argument("-m", "--model", type=str, required=True, help="path to save the trained model to")
    parser.add_argument("-p", "--plot", type=str, required=True, help="path to save the loss plot to")
    parser.add_argument("--idxStop", type=int, default=None, help="index to truncate the data at (for debugging)")

    args = parser.parse_args()
 
    INIT_LR = 1e-3
    BATCH_SIZE = 32
    EPOCHS = 20

    TRAIN_SPLIT = 0.75
    VAL_SPLIT = 1 - TRAIN_SPLIT
 
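    # Prefer a CUDA GPU when available; otherwise fall back to the CPU.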
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
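    # Load the full dataset; idxStop optionally truncates it for quick debugging runs.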
    AllData = HessDataset(args.TrainingData, idxStop=args.idxStop)
 
    print("Splitting Training and validation set")
    numTrainSamples = int(len(AllData) * TRAIN_SPLIT)
    numValSamples = int(len(AllData) * VAL_SPLIT)
    adjust = len(AllData) - (numTrainSamples + numValSamples)
    numTrainSamples += adjust
    (trainData, valData) = random_split(AllData, [numTrainSamples, numValSamples], generator=torch.Generator().manual_seed(42))
 
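    # Shuffle only the training loader; the validation set is iterated in order.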
    trainDataLoader = DataLoader(trainData, shuffle=True, batch_size=BATCH_SIZE)
    valDataLoader = DataLoader(valData, batch_size=BATCH_SIZE)
 
    # Number of batches per epoch; len(dataloader) already counts a partial final batch,
    # so the averaged losses below are not biased by floor division.
    trainSteps = len(trainDataLoader)
    valSteps = len(valDataLoader)
 
    print("Initing Model")
    model = HessClass().to(device)
 
    opt = Adam(model.parameters(), lr=INIT_LR)
    # lossFn = nn.NLLLoss()
    lossFn = lambda output, target : torch.mean((output-target)**2)
 
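    # H collects per-epoch metrics for the loss/accuracy plot at the end.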
    H = {
            "train_loss" : list(),
            "train_acc" : list(),
            "val_loss" : list(),
            "val_acc" : list()
            }
 
    print("Training")
 
    startTime = time.perf_counter()
 
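    # Main loop: one pass over the training set per epoch, followed by validation.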
    for e in range(EPOCHS):
        # put the model in training mode (enables dropout, batch-norm updates, etc.)
        model.train()

        totalTrainLoss = 0
        totalValLoss = 0

        trainCorrect = 0
        valCorrect = 0

        for (x, y) in tqdm(trainDataLoader, desc=f"Epoch {e + 1}/{EPOCHS}", leave=False):
            (x, y) = (x.to(device, dtype=torch.float), y.to(device, dtype=torch.float))

            pred = model(x)
            loss = lossFn(pred, y)

            opt.zero_grad()
            loss.backward()
            opt.step()

            # .item() detaches the scalar loss; accumulating the raw tensor would
            # keep every batch's autograd graph alive and leak memory.
            totalTrainLoss += loss.item()
            # accuracy bookkeeping assumes y holds integer class indices
            trainCorrect += (pred.argmax(1) == y).type(torch.float).sum().item()
 
        # Validation pass: eval mode plus no_grad avoids dropout/batch-norm
        # updates and gradient tracking.
        model.eval()
        with torch.no_grad():
            for (x, y) in valDataLoader:
                (x, y) = (x.to(device, dtype=torch.float), y.to(device, dtype=torch.float))

                pred = model(x)
                totalValLoss += lossFn(pred, y).item()

                # as above, accuracy assumes y holds integer class indices
                valCorrect += (pred.argmax(1) == y).type(torch.float).sum().item()
 
        # Average the accumulated losses over the number of batches and
        # convert the correct counts into accuracies.
        avgTrainLoss = totalTrainLoss / trainSteps
        avgValLoss = totalValLoss / valSteps

        trainCorrect = trainCorrect / len(trainDataLoader.dataset)
        valCorrect = valCorrect / len(valDataLoader.dataset)

        H['train_loss'].append(avgTrainLoss)
        H['train_acc'].append(trainCorrect)
        H['val_loss'].append(avgValLoss)
        H['val_acc'].append(valCorrect)
 
        print(f"EPOCH {e+1}/{EPOCHS}")
        print(f"Training Loss: {avgTrainLoss:0.6f}, Training Accuracy: {trainCorrect:0.4f}")
        print(f"Validation Loss: {avgValLoss:0.6f}, Validation Accuracy: {valCorrect:0.4f}")
 
    endTime = time.perf_counter()
    print(f"Training took {endTime - startTime:0.2f} seconds")
    # print("Testing")
    #
    # with torch.no_grad():
    #     model.eval()
    #
    #     preds = []
    #
    #     for (x,y) in testDataLoader:
    #         x = x.to(device)
    #
    #         pred = model(x)
    #         preds.extend(pred.argmax(axis=1).cpu().numpy())
    #
    # print(classification_report(testData.targets.cpu().numpy(), np.array(preds), target_names=testData.classes))
    torch.save(model, args.model)
    #
    # Plot the training/validation loss and accuracy curves collected in H.
    plt.style.use('ggplot')
    plt.figure()
    plt.plot(H['train_loss'], label='train_loss')
    plt.plot(H['val_loss'], label='val_loss')
    plt.plot(H['train_acc'], label='train_acc')
    plt.plot(H['val_acc'], label='val_acc')

    plt.title("Training/Validation Loss and Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Loss/Accuracy")
    plt.legend()
    plt.savefig(args.plot)