HessDataset Class for Age Estimation in Astronomy
This Python script defines HessDataset, a custom PyTorch Dataset built from histograms (Hess diagrams) of stellar populations, used to train a neural network that estimates stellar ages. It loads the histograms, stamps fixed marker patches onto each one, standardizes the whole set to zero mean and unit variance, and exposes the result through the standard Dataset length/indexing interface.
import sys
from itertools import islice

# NOTE: hard-coded, machine-specific path so that the project-local
# `mkHistFromfCMD` module can be imported; adjust for other checkouts.
sys.path.insert(1, "/home/tboudreaux/Algebrist/Astronomy/GraduateSchool/Thesis/JaoAgeDateing/Theoretical/CNNGapQuantification/sampleDataGeneration")
from mkHistFromfCMD import iter_hist_file
from mkHistFromfCMD import get_hist_file_meta

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt


class HessDataset(Dataset):
    """PyTorch Dataset of standardized Hess-diagram histograms and their ages.

    Every histogram yielded by ``iter_hist_file(path)`` is copied into a
    ``(N, 1, bins, bins)`` float32 tensor and stamped with fixed marker
    patches (see ``_add_markers``). The whole stack is then standardized
    with its global mean and standard deviation.

    Attributes:
        IPUT: float32 tensor of shape (N, 1, bins, bins), standardized.
        ages: numpy float64 array of shape (N,), one age per histogram.
        rawMean, rawStd, rawVar: statistics of the stacked (marked)
            histograms *before* standardization, i.e. the constants needed
            to normalize new inputs the same way.
        mean, std, var: statistics *after* standardization (approximately
            0, 1, 1); kept for backward compatibility.
    """

    # (color, magnitude) coordinates at which marker patches are stamped.
    MARKER_POSITIONS = ((2.3, 10.40), (2.325, 10.20), (2.31, 10.30))
    # Each patch spans indices [center - HALF_WIDTH, center + HALF_WIDTH)
    # along both histogram axes (6 bins wide for HALF_WIDTH = 3).
    MARKER_HALF_WIDTH = 3
    # Raw (pre-standardization) value written into every marker patch.
    MARKER_VALUE = 10

    def __init__(self, path, idxStop=None):
        """Load, mark, and standardize the histograms stored in ``path``.

        Args:
            path: histogram file readable by ``get_hist_file_meta`` and
                ``iter_hist_file``.
            idxStop: optional cap on the number of histograms loaded. A falsy
                value (``None`` or ``0``) loads every histogram.
        """
        length, bins = get_hist_file_meta(path)
        if idxStop:
            length = min(length, idxStop)

        # Allocate float32 directly: the result is converted to a float32
        # tensor anyway, so this avoids a float64 buffer plus a full copy.
        data = np.empty(shape=(length, 1, bins, bins), dtype=np.float32)
        ages = np.empty(shape=(length,))

        # _add_markers writes through self.IPUT, so point it at the buffer.
        self.IPUT = data
        nLoaded = 0
        # islice bounds the loop at `length`, so a file holding more entries
        # than its metadata reports cannot overflow the preallocated buffers.
        entries = islice(iter_hist_file(path), length)
        for idx, (age, loadedData) in tqdm(enumerate(entries), total=length):
            # Assignment copies H into the buffer (with a cast to float32).
            data[idx] = np.reshape(loadedData['H'], (1, bins, bins))
            self._add_markers(idx, loadedData['X'], loadedData['Y'])
            ages[idx] = age
            nLoaded = idx + 1

        # If the file held fewer entries than advertised, drop the
        # uninitialized tail of the np.empty buffers instead of letting
        # garbage values enter the statistics below.
        self.ages = ages[:nLoaded]
        self.IPUT = torch.from_numpy(data[:nLoaded])

        # Keep the pre-standardization constants: they are what new inputs
        # must be normalized with.
        self.rawMean = torch.mean(self.IPUT)
        self.rawStd = torch.std(self.IPUT)
        self.rawVar = torch.var(self.IPUT)
        self.IPUT = (self.IPUT - self.rawMean) / self.rawStd

        # Post-standardization statistics, kept for backward compatibility.
        self.mean = torch.mean(self.IPUT)
        self.std = torch.std(self.IPUT)
        self.var = torch.var(self.IPUT)

    def __len__(self):
        """Return the number of histograms in the dataset."""
        return self.IPUT.shape[0]

    def _add_markers(self, idx, color, mag):
        """Stamp the fixed marker patches onto histogram ``idx`` in place.

        For every (color, magnitude) pair in ``MARKER_POSITIONS``, the nearest
        entry of ``color`` / ``mag`` is located with argmin. A square patch
        around it is then set to ``MARKER_VALUE``.

        Args:
            idx: row of ``self.IPUT`` to modify.
            color: 1-D coordinates indexing the last histogram axis
                (presumably bin centers or edges; TODO confirm).
            mag: 1-D coordinates indexing the third histogram axis.
        """
        color = np.asarray(color)
        mag = np.asarray(mag)
        halfWidth = self.MARKER_HALF_WIDTH
        for colorTarget, magTarget in self.MARKER_POSITIONS:
            colorCenter = int(np.argmin(np.abs(color - colorTarget)))
            magCenter = int(np.argmin(np.abs(mag - magTarget)))
            # Clamp the lower bound at 0: a negative slice start would wrap
            # to the far end of the axis and silently drop the marker.
            magLo = max(magCenter - halfWidth, 0)
            colorLo = max(colorCenter - halfWidth, 0)
            self.IPUT[idx, 0,
                      magLo:magCenter + halfWidth,
                      colorLo:colorCenter + halfWidth] = self.MARKER_VALUE

    def __getitem__(self, idx):
        """Return ``(histogram, age)`` for ``idx``.

        ``idx`` may be an int, a slice, or a tensor of indices; tensors are
        converted to Python lists first.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.IPUT[idx], self.ages[idx]