HessDataset Class for Age Estimation in Astronomy

This Python script defines HessDataset, a custom PyTorch Dataset built from histograms of stellar populations. It prepares the data for training a neural network that estimates stellar ages: it loads every histogram and its age into memory, stamps marker patches onto each input histogram, standardizes the inputs using the mean and standard deviation of the whole dataset, and exposes the usual Dataset interface (length and indexed access) for iterating over the samples.

import sys
sys.path.insert(1, "/home/tboudreaux/Algebrist/Astronomy/GraduateSchool/Thesis/JaoAgeDateing/Theoretical/CNNGapQuantification/sampleDataGeneration")
from mkHistFromfCMD import iter_hist_file
from mkHistFromfCMD import get_hist_file_meta
import torch
 
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
 
import numpy as np
from tqdm import tqdm
 
import matplotlib.pyplot as plt
 
class HessDataset(Dataset):
    """In-memory dataset of 2D histograms paired with an age label.

    Every entry of the histogram file is loaded eagerly at construction
    time, marker patches are stamped onto each histogram, and the whole
    input tensor is then standardized with its global mean and standard
    deviation.

    Attributes set by ``__init__``:
        IPUT: float32 tensor of shape ``(N, 1, bins, bins)`` holding the
            standardized histograms.
        ages: NumPy array of length ``N`` with the age of each entry.
        rawMean, rawStd, rawVar: statistics of ``IPUT`` *before*
            standardization (the values needed to standardize new data
            the same way).
        mean, std, var: statistics of ``IPUT`` *after* standardization.
    """

    # (color, magnitude) coordinates at which a marker patch is stamped.
    # NOTE(review): presumably fiducial reference points for the network;
    # the meaning of these particular values is not visible here.
    MARKER_POINTS = ((2.3, 10.40), (2.325, 10.20), (2.31, 10.30))
    # A patch spans [center - MARKER_HALF_WIDTH, center + MARKER_HALF_WIDTH)
    # bins along each axis.
    MARKER_HALF_WIDTH = 3
    # Value written into every bin of a marker patch.
    MARKER_VALUE = 10

    def __init__(self, path, idxStop=None):
        """Load, mark and standardize the histograms stored at ``path``.

        Args:
            path: Location of the histogram file; handed unchanged to
                ``get_hist_file_meta`` and ``iter_hist_file``.
            idxStop: Optional cap on the number of entries to load. Any
                falsy value (``None`` or ``0``) means "load everything".
        """
        length, bins = get_hist_file_meta(path)
        if idxStop:
            length = min(length, idxStop)
        self.ages = np.empty(shape=(length))
        # Allocate as float32 up front: the data ends up in a float32
        # tensor anyway, so this avoids holding a float64 copy as well.
        self.IPUT = np.empty(shape=(length, 1, bins, bins), dtype=np.float32)
        loaded = 0
        for idx, (age, loadedData) in tqdm(enumerate(iter_hist_file(path)), total=length):
            if idxStop and idx >= idxStop:
                break
            # Assigning into the preallocated buffer already copies the data.
            self.IPUT[idx] = np.reshape(loadedData['H'], (1, bins, bins))
            self._add_markers(idx, loadedData['X'], loadedData['Y'])
            self.ages[idx] = age
            loaded = idx + 1
        if loaded < length:
            # The file yielded fewer entries than expected: drop the
            # never-written rows so uninitialized memory cannot leak into
            # the statistics below or into the samples.
            self.ages = self.ages[:loaded]
            self.IPUT = self.IPUT[:loaded]
        self.IPUT = torch.from_numpy(self.IPUT)
        # Keep the pre-standardization statistics; mean/std/var below are
        # overwritten with the post-standardization values.
        self.rawMean, self.rawStd, self.rawVar = torch.mean(self.IPUT), torch.std(self.IPUT), torch.var(self.IPUT)
        self.IPUT = (self.IPUT - self.rawMean) / self.rawStd
        self.mean, self.std, self.var = torch.mean(self.IPUT), torch.std(self.IPUT), torch.var(self.IPUT)

    def __len__(self):
        """Return the number of loaded histograms."""
        return self.IPUT.shape[0]

    def _add_markers(self, idx, color, mag):
        """Stamp the marker patches onto histogram ``idx`` in place.

        For each ``(color, magnitude)`` pair in ``MARKER_POINTS`` the
        nearest entry of ``color`` and of ``mag`` is located, and a square
        patch around that position is set to ``MARKER_VALUE``. ``mag``
        selects the row (third axis of ``IPUT``) and ``color`` the column
        (fourth axis).

        Args:
            idx: Index of the histogram in ``self.IPUT`` to modify.
            color: Array of color coordinates along the column axis.
            mag: Array of magnitude coordinates along the row axis.
        """
        halfWidth = self.MARKER_HALF_WIDTH
        for colorTarget, magTarget in self.MARKER_POINTS:
            colorBin = int(np.argmin(np.abs(color - colorTarget)))
            magBin = int(np.argmin(np.abs(mag - magTarget)))
            # Clamp the lower bounds at 0: a negative slice start would be
            # read as "from the end", producing an empty slice and silently
            # dropping any marker within halfWidth bins of the low edge.
            rowLo = max(magBin - halfWidth, 0)
            colLo = max(colorBin - halfWidth, 0)
            # Upper bounds past the end are clipped by slicing itself.
            self.IPUT[idx, 0, rowLo:magBin + halfWidth, colLo:colorBin + halfWidth] = self.MARKER_VALUE

    def __getitem__(self, idx):
        """Return the ``(histogram, age)`` pair(s) selected by ``idx``.

        Tensor indices are converted to plain Python values first so they
        can index both the tensor inputs and the NumPy ``ages`` array.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        return self.IPUT[idx], self.ages[idx]