import numpy as np

from ctapipe.core import Tool
from ctapipe.core.traits import Int, IntTelescopeParameter, Path
from ctapipe.io import TableLoader
from ctapipe.reco import CrossValidator, EnergyRegressor
from ctapipe.reco.preprocessing import check_valid_rows

class TrainEnergyRegressor(Tool):
    """
    Tool to train a `~ctapipe.reco.EnergyRegressor` on dl1b/dl2 data.

    The tool first performs a cross validation to give an initial estimate
    on the quality of the estimation and then finally trains one model
    per telescope type on the full dataset.
    """

    # NOTE: the class docstring doubles as the tool description below, so
    # editing it changes the text shown to users.
    name = "ctapipe-train-energy-regressor"
    description = __doc__

    examples = """
    ctapipe-train-energy-regressor \\
        --config train_energy_regressor.yaml \\
        --input gamma.dl2.h5 \\
        --output energy_regressor.pkl
    """

    output_path = Path(
        default_value=None,
        allow_none=False,
        directory_ok=False,
        help=(
            "Output path for the trained reconstructor."
            " At the moment, pickle is the only supported format."
        ),
    ).tag(config=True)

    n_events = IntTelescopeParameter(
        default_value=None,
        allow_none=True,
        help=(
            "Number of events for training the model."
            " If not given, all available events will be used."
        ),
    ).tag(config=True)

    random_seed = Int(
        default_value=0, help="Random seed for sampling and cross validation"
    ).tag(config=True)

    # Short command-line options mapped onto the configurable traits.
    aliases = {
        ("i", "input"): "TableLoader.input_url",
        ("o", "output"): "TrainEnergyRegressor.output_path",
        "n-events": "TrainEnergyRegressor.n_events",
        "cv-output": "CrossValidator.output_path",
    }

    # Components whose traits are configurable through this tool.
    classes = [
        TableLoader,
        EnergyRegressor,
        CrossValidator,
    ]

    def setup(self):
        """
        Initialize components from config

        Creates the table loader, the regressor, the cross validator and
        the random number generator, then checks both output paths.
        """
        self.loader = TableLoader(
            parent=self,
            load_dl1_images=False,
            load_dl1_parameters=True,
            load_dl2=True,
            load_simulated=True,
            load_instrument=True,
        )
        # The per-telescope-type lookup in ``_read_table`` needs the subarray.
        self.n_events.attach_subarray(self.loader.subarray)

        self.regressor = EnergyRegressor(self.loader.subarray, parent=self)
        self.cross_validate = CrossValidator(
            parent=self, model_component=self.regressor
        )
        # A single seeded generator is used for the event sub-sampling.
        self.rng = np.random.default_rng(self.random_seed)
        self.check_output(self.output_path, self.cross_validate.output_path)

    def start(self):
        """
        Train models per telescope type.

        For every telescope type of the subarray the events are loaded,
        the cross validation is run and the regressor is then fitted on
        the same table.
        """
        types = self.loader.subarray.telescope_types
        self.log.info("Inputfile: %s", self.loader.input_url)
        self.log.info("Training models for %d types", len(types))
        for tel_type in types:
            self.log.info("Loading events for %s", tel_type)
            table = self._read_table(tel_type)

            self.log.info("Train on %s events", len(table))
            self.cross_validate(tel_type, table)

            self.log.info("Performing final fit for %s", tel_type)
            self.regressor.fit(tel_type, table)
            self.log.info("done")

    def _read_table(self, telescope_type):
        """
        Load and prepare the training table for one telescope type.

        Parameters
        ----------
        telescope_type
            One entry of ``self.loader.subarray.telescope_types``.

        Returns
        -------
        table
            Events of the given telescope type that pass the regressor's
            quality query, reduced to the feature and target columns, with
            invalid rows removed and optionally sub-sampled to the
            configured ``n_events``.
        """
        table = self.loader.read_telescope_events([telescope_type])
        self.log.info("Events read from input: %d", len(table))

        mask = self.regressor.quality_query.get_table_mask(table)
        table = table[mask]
        self.log.info("Events after applying quality query: %d", len(table))

        table = self.regressor.feature_generator(table)

        # Keep only the columns needed for training: inputs plus the target.
        feature_names = self.regressor.features + [self.regressor.target]
        table = table[feature_names]

        valid = check_valid_rows(table)
        if np.any(~valid):
            self.log.warning("Dropping non-predictable events.")
            table = table[valid]

        # ``None`` means no limit was configured for this telescope type.
        n_events = self.n_events.tel[telescope_type]
        if n_events is not None:
            if n_events > len(table):
                # Fewer events than requested: warn and use all of them.
                self.log.warning(
                    "Number of events in table (%d) is less than requested number of events %d",
                    len(table),
                    n_events,
                )
            else:
                self.log.info("Sampling %d events", n_events)
                idx = self.rng.choice(len(table), n_events, replace=False)
                # Sorted indices keep the sampled rows in their table order.
                idx.sort()
                table = table[idx]

        return table

    def finish(self):
        """
        Write-out trained models and cross-validation results.

        The cross-validation results are only written when an output path
        for them is set; the loader is closed afterwards.
        """
        self.log.info("Writing output")
        self.regressor.write(self.output_path, overwrite=self.overwrite)
        if self.cross_validate.output_path:
            self.cross_validate.write(overwrite=self.overwrite)
        self.loader.close()
def main():
    """Create a `TrainEnergyRegressor` tool instance and run it."""
    tool = TrainEnergyRegressor()
    tool.run()


if __name__ == "__main__":
    main()