KMNIST CNN Training Script
This Python script trains a Convolutional Neural Network (CNN) on the KMNIST dataset. It imports the required libraries and modules, parses the path to the training data from the command line, defines network hyperparameters such as the learning rate and number of epochs, splits the dataset into training and validation sets, wraps the training set in a DataLoader for batch processing, and then starts the training process. It uses PyTorch throughout, scikit-learn for the classification report, and matplotlib to visualize the data. When a GPU is available, the script selects it for accelerated computation. (In the code shown, the loop over the training DataLoader only displays each training image with matplotlib; the model and optimizer are imported but not yet used.)
import argparse
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import classification_report
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision.datasets import KMNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm

from dataset import HessDataset
from model import HessClass

# Network / training hyperparameters.
INIT_LR = 1e-3
BATCH_SIZE = 32
EPOCHS = 20
# Fraction of the dataset used for training; the remainder is validation.
TRAIN_SPLIT = 0.75
VAL_SPLIT = 1 - TRAIN_SPLIT
# Fixed seed so the train/validation split is reproducible across runs.
SPLIT_SEED = 42


def _split_sizes(total, val_fraction):
    """Return ``(num_train, num_val)`` for a dataset of ``total`` samples.

    The validation size is rounded down and the training set absorbs the
    remainder, so the two sizes always sum to ``total`` (which
    ``random_split`` requires when given integer lengths).
    """
    num_val = int(total * val_fraction)
    return total - num_val, num_val


def _show_samples(loader):
    """Display every image yielded by ``loader``, one matplotlib figure each.

    Only ``sample[0]`` (the first entry along the sample's leading axis) is
    drawn. Each figure is closed once ``plt.show()`` returns so figures do
    not accumulate when a non-interactive backend makes ``show`` a no-op.
    """
    for images, _labels in loader:
        for image in images:
            fig, ax = plt.subplots(1, 1, figsize=(10, 7))
            ax.imshow(image[0])
            plt.show()
            plt.close(fig)


def main():
    """Parse arguments, load and split the dataset, and display training images."""
    parser = argparse.ArgumentParser(description="KMNIST CNN")
    parser.add_argument("TrainingData", type=str, help="path to training data")
    args = parser.parse_args()

    # NOTE(review): selected but not yet used below; nothing is moved to it.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_data = HessDataset(args.TrainingData)

    print("Splitting Training and validation set")
    num_train, num_val = _split_sizes(len(all_data), VAL_SPLIT)
    train_data, val_data = random_split(
        all_data,
        [num_train, num_val],
        generator=torch.Generator().manual_seed(SPLIT_SEED),
    )

    train_loader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)
    _show_samples(train_loader)


if __name__ == "__main__":
    main()