This Python script trains a neural network (`HessClass`, imported from the separate `model` module — presumably a CNN) on a `HessDataset` loaded from a user-supplied path. Despite the leftover KMNIST import and the "KMINST CNN" argparse description, it does not use the KMNIST dataset. It loads the dataset and splits it 75/25 (with a fixed seed) into training and validation sets. It then runs a training loop with a mean-squared-error loss and the Adam optimizer, and plots the training/validation loss and accuracy over epochs. Command-line arguments specify:
- the path to the training data,
- where to save the trained model,
- where to save the loss plot,
- an optional `--idxStop` index to limit the data while debugging.

It uses PyTorch for model training and matplotlib for plotting.
from dataset import HessDataset
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: the plot is only written to a file
from model import HessClass
from sklearn.metrics import classification_report
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import KMNIST
from torch.optim import Adam
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import argparse
import torch
import time
from tqdm import tqdm

# Default hyperparameters (identical to the values the script always used).
INIT_LR = 1e-3
BATCH_SIZE = 32
EPOCHS = 20
TRAIN_SPLIT = 0.75
SPLIT_SEED = 42  # fixed seed so the train/val split is reproducible across runs


def _parse_args(argv=None):
    """Parse the command line; ``argv=None`` means ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(description="KMINST CNN")
    parser.add_argument("TrainingData", type=str, help="path to training data")
    parser.add_argument("-m", "--model", type=str, required=True,
                        help="path to save the trained model to")
    parser.add_argument("-p", "--plot", type=str, required=True,
                        help="path to save the loss plot to")
    parser.add_argument("--idxStop", type=int, default=None,
                        help="index to limit data to (for debugging)")
    parser.add_argument("--epochs", type=int, default=EPOCHS,
                        help=f"number of training epochs (default {EPOCHS})")
    parser.add_argument("--batch-size", type=int, default=BATCH_SIZE,
                        help=f"mini-batch size (default {BATCH_SIZE})")
    parser.add_argument("--lr", type=float, default=INIT_LR,
                        help=f"Adam learning rate (default {INIT_LR})")
    return parser.parse_args(argv)


def _mse_loss(output, target):
    """Mean of the element-wise squared error between *output* and *target*."""
    return torch.mean((output - target) ** 2)


def _split_dataset(dataset, train_split, seed):
    """Reproducibly split *dataset* into ``(train, val)`` subsets.

    The validation size is ``int(len * (1 - train_split))``; every remaining
    sample goes to training, so the two sizes always sum to ``len(dataset)``.
    """
    n_total = len(dataset)
    n_val = int(n_total * (1 - train_split))
    n_train = n_total - n_val
    return random_split(dataset, [n_train, n_val],
                        generator=torch.Generator().manual_seed(seed))


def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or NaN when the denominator is 0.

    An empty split (e.g. a tiny ``--idxStop``) would otherwise crash the run
    with ZeroDivisionError.
    """
    return numerator / denominator if denominator else float("nan")


def _run_epoch(model, loader, loss_fn, device, optimizer=None):
    """Make one pass over *loader*, training when *optimizer* is given.

    Without an optimizer the model is put in eval mode and gradients are
    disabled. Returns ``(mean per-batch loss, fraction of correct samples)``.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    correct = 0.0
    with torch.set_grad_enabled(training):
        for x, y in loader:
            x = x.to(device, dtype=torch.float)
            y = y.to(device, dtype=torch.float)
            pred = model(x)
            loss = loss_fn(pred, y)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # .item() yields a plain float; summing the tensor itself would
            # keep the autograd history of every batch alive for the epoch.
            total_loss += loss.item()
            # NOTE(review): argmax-vs-target "accuracy" assumes y holds class
            # indices, yet the loss is MSE on float targets — confirm this
            # metric is meaningful for HessDataset.
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    # len(loader) is the true batch count (includes a final partial batch).
    return (_safe_div(total_loss, len(loader)),
            _safe_div(correct, len(loader.dataset)))


def _plot_history(history, path):
    """Plot per-epoch loss/accuracy curves from *history* and save to *path*."""
    plt.figure()
    for key in ("train_loss", "val_loss", "train_acc", "val_acc"):
        plt.plot(history[key], label=key)
    plt.title("loss and acc")
    plt.xlabel("EPOCH")
    plt.ylabel("Loss/Acc")
    plt.legend()
    plt.savefig(path)
    plt.close()


def main(argv=None):
    """Train HessClass on a HessDataset, then save the model and a loss plot."""
    args = _parse_args(argv)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_data = HessDataset(args.TrainingData, idxStop=args.idxStop)
    print("Splitting Training and validation set")
    train_data, val_data = _split_dataset(all_data, TRAIN_SPLIT, SPLIT_SEED)
    train_loader = DataLoader(train_data, shuffle=True, batch_size=args.batch_size)
    val_loader = DataLoader(val_data, batch_size=args.batch_size)

    print("Initing Model")
    model = HessClass().to(device)
    opt = Adam(model.parameters(), lr=args.lr)

    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    print("Training")
    start_time = time.perf_counter()
    for epoch in range(args.epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, _mse_loss, device, opt)
        val_loss, val_acc = _run_epoch(model, val_loader, _mse_loss, device)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"EPOCH {epoch + 1}/{args.epochs}")
        print(f"Training Loss: {train_loss:0.6f}, Training Accuracy: {train_acc:0.4f}")
        print(f"Validation Loss: {val_loss:0.6f}, Validation Accuracy: {val_acc:0.4f}")
    end_time = time.perf_counter()
    print(f"training took {end_time - start_time}")

    # Saves the whole pickled module, so loading it later requires the
    # `model` module (HessClass) to be importable.
    torch.save(model, args.model)
    _plot_history(history, args.plot)


if __name__ == "__main__":
    main()