# Source code for ctapipe.tools.apply_models
"""
Tool to apply machine learning models in bulk (as opposed to event by event).
"""
import numpy as np
import tables
from astropy.table import Table, join, vstack
from tqdm.auto import tqdm

from ctapipe.core.tool import Tool
from ctapipe.core.traits import Bool, Integer, List, Path, classes_with_traits, flag
from ctapipe.io import HDF5Merger, TableLoader, write_table
from ctapipe.io.astropy_helpers import read_table
from ctapipe.io.tableio import TelListToMaskTransform
from ctapipe.io.tableloader import _join_subarray_events
from ctapipe.reco import Reconstructor

__all__ = [
    "ApplyModels",
]


class ApplyModels(Tool):
    """Apply machine learning models to data.

    This tool predicts all events at once. To apply models in the
    regular event loop, set the appropriate options to ``ctapipe-process``.

    Models need to be trained with
    `~ctapipe.tools.train_energy_regressor.TrainEnergyRegressor`
    and
    `~ctapipe.tools.train_particle_classifier.TrainParticleClassifier`.
    """

    name = "ctapipe-apply-models"
    # Inside the class body ``__doc__`` is the class docstring above.
    description = __doc__
    examples = """
    ctapipe-apply-models \\
        --input gamma.dl2.h5 \\
        --reconstructor energy_regressor.pkl \\
        --reconstructor particle-classifier.pkl \\
        --output gamma_applied.dl2.h5
    """

    input_url = Path(
        default_value=None,
        allow_none=False,
        directory_ok=False,
        exists=True,
        help="Input dl1b/dl2 file",
    ).tag(config=True)

    output_path = Path(
        default_value=None,
        allow_none=False,
        directory_ok=False,
        help="Output file",
    ).tag(config=True)

    reconstructor_paths = List(
        Path(exists=True, directory_ok=False),
        default_value=[],
        help="Paths to trained reconstructors to be applied to the input data",
    ).tag(config=True)

    chunk_size = Integer(
        default_value=100000,
        allow_none=True,
        help="How many subarray events to load at once for making predictions.",
    ).tag(config=True)

    progress_bar = Bool(
        help="show progress bar during processing",
        default_value=True,
    ).tag(config=True)

    aliases = {
        ("i", "input"): "ApplyModels.input_url",
        ("r", "reconstructor"): "ApplyModels.reconstructor_paths",
        ("o", "output"): "ApplyModels.output_path",
        "chunk-size": "ApplyModels.chunk_size",
    }

    flags = {
        # The flag must target the trait defined on *this* tool; pointing it
        # at another tool's trait leaves ``self.progress_bar`` untouched, so
        # ``--no-progress`` would not disable the bar used in ``start``.
        **flag(
            "progress",
            "ApplyModels.progress_bar",
            "show a progress bar during event processing",
            "don't show a progress bar during event processing",
        ),
        **flag(
            "dl1-parameters",
            "HDF5Merger.dl1_parameters",
            "Include dl1 parameters",
            "Exclude dl1 parameters",
        ),
        **flag(
            "dl1-images",
            "HDF5Merger.dl1_images",
            "Include dl1 images",
            "Exclude dl1 images",
        ),
        **flag(
            "true-parameters",
            "HDF5Merger.true_parameters",
            "Include true parameters",
            "Exclude true parameters",
        ),
        **flag(
            "true-images",
            "HDF5Merger.true_images",
            "Include true images",
            "Exclude true images",
        ),
    }

    classes = [TableLoader] + classes_with_traits(Reconstructor)

    def setup(self):
        """
        Initialize components from config

        Copies the input file to the output location, opens the output file
        for in-place modification, opens a ``TableLoader`` on the input and
        loads all configured reconstructors.
        """
        self.check_output(self.output_path)
        self.log.info("Copying to output destination.")

        # The merger is only needed for the initial copy, so it is closed
        # again before the output file is re-opened below.
        with HDF5Merger(self.output_path, parent=self) as merger:
            merger(self.input_url)

        # Both handles are registered via ``enter_context`` so that they are
        # tied to the tool's context handling instead of being closed by hand.
        self.h5file = self.enter_context(tables.open_file(self.output_path, mode="r+"))
        self.loader = self.enter_context(
            TableLoader(
                self.input_url,
                parent=self,
                load_dl1_parameters=True,
                load_dl2=True,
                load_instrument=True,
                load_dl1_images=False,
                load_simulated=False,
                load_observation_info=True,
            )
        )

        self._reconstructors = [
            Reconstructor.read(path, parent=self, subarray=self.loader.subarray)
            for path in self.reconstructor_paths
        ]

    def start(self):
        """Apply models to input tables"""
        chunk_iterator = self.loader.read_telescope_events_by_id_chunked(
            self.chunk_size
        )
        bar = tqdm(
            chunk_iterator,
            desc="Applying reconstructors",
            unit=" Array Events",
            total=chunk_iterator.n_total,
            disable=not self.progress_bar,
        )
        with bar:
            # The loop iterates the chunk iterator directly (not ``bar``), so
            # progress is advanced manually by the number of rows in the chunk.
            for chunk, (start, stop, tel_tables) in enumerate(chunk_iterator):
                for reconstructor in self._reconstructors:
                    self.log.debug("Applying %s to chunk %d", reconstructor, chunk)
                    self._apply(reconstructor, tel_tables, start=start, stop=stop)

                bar.update(stop - start)

    def _apply(self, reconstructor, tel_tables, start, stop):
        """Apply one reconstructor to one chunk of telescope events.

        Parameters
        ----------
        reconstructor
            Reconstructor used for the per-telescope predictions.
        tel_tables
            Mapping of telescope id to that telescope's event table for the
            current chunk. Predicted columns are written into these tables
            in place.
        start, stop
            Row range of the current chunk; passed on to ``_combine``.
        """
        prefix = reconstructor.prefix

        for tel_id, table in tel_tables.items():
            tel = self.loader.subarray.tel[tel_id]

            if tel not in reconstructor._models:
                self.log.warning(
                    "No model in %s for telescope type %s, skipping tel %d",
                    reconstructor,
                    tel,
                    tel_id,
                )
                continue

            if len(table) == 0:
                self.log.info("No events for telescope %d", tel_id)
                continue

            predictions = reconstructor.predict_table(tel, table)
            for prop, prediction_table in predictions.items():
                # copy/overwrite columns into full feature table
                new_columns = prediction_table.colnames
                for col in new_columns:
                    table[col] = prediction_table[col]

                output_columns = ["obs_id", "event_id", "tel_id"] + new_columns
                write_table(
                    table[output_columns],
                    self.output_path,
                    f"/dl2/event/telescope/{prop}/{prefix}/tel_{tel_id:03d}",
                    append=True,
                )

        self._combine(reconstructor, tel_tables, start=start, stop=stop)

    def _combine(self, reconstructor, tel_tables, start, stop):
        """Combine telescope-wise predictions of one chunk into stereo predictions.

        The combined table is appended to the output file and its columns are
        also added to every table in ``tel_tables``.

        Parameters
        ----------
        reconstructor
            Reconstructor whose ``stereo_combiner`` performs the combination.
        tel_tables
            Mapping of telescope id to the telescope event table of the chunk.
        start, stop
            Row range used to read the matching part of the trigger table.
        """
        stereo_table = vstack(list(tel_tables.values()))

        combiner = reconstructor.stereo_combiner
        stereo_predictions = combiner.predict_table(stereo_table)
        # the stacked table is no longer needed, release it early
        del stereo_table

        # columns whose name ends in "telescopes" are converted row by row
        # with the transform; the column description is carried over
        trafo = TelListToMaskTransform(self.loader.subarray)
        for c in filter(
            lambda c: c.name.endswith("telescopes"),
            stereo_predictions.columns.values(),
        ):
            stereo_predictions[c.name] = np.array([trafo(r) for r in c])
            stereo_predictions[c.name].description = c.description

        # stacking the single telescope tables and joining
        # potentially changes the order of the subarray events.
        # to ensure events are stored in the correct order,
        # we resort to trigger table order
        trigger = read_table(
            self.h5file, "/dl1/event/subarray/trigger", start=start, stop=stop
        )[["obs_id", "event_id"]]
        trigger["__sort_index__"] = np.arange(len(trigger))

        stereo_predictions = _join_subarray_events(trigger, stereo_predictions)
        stereo_predictions.sort("__sort_index__")
        del stereo_predictions["__sort_index__"]

        write_table(
            stereo_predictions,
            self.output_path,
            f"/dl2/event/subarray/{reconstructor.property}/{combiner.prefix}",
            append=True,
        )

        # the same tel_tables are handed to every reconstructor in ``start``,
        # so the stereo columns become part of the tables seen by later ones
        for tel_table in tel_tables.values():
            _add_stereo_prediction(tel_table, stereo_predictions)
def _add_stereo_prediction(tel_events, array_events):
    """Add columns from array_events table to tel_events table

    Parameters
    ----------
    tel_events
        Table with ``obs_id`` and ``event_id`` columns; modified in place.
    array_events
        Table with ``obs_id`` and ``event_id`` columns; all of its other
        columns are copied into ``tel_events``.
    """
    # Join on a slim helper table (keys + row index) instead of the full
    # telescope table; the index allows restoring the original row order.
    join_table = Table(
        {
            "obs_id": tel_events["obs_id"],
            "event_id": tel_events["event_id"],
            "__sort_index__": np.arange(len(tel_events)),
        }
    )
    # NOTE(review): ``join`` is called with its default join type, so this
    # assumes every row of tel_events has a match in array_events; otherwise
    # the column assignment below would see a length mismatch - confirm.
    joined = join(join_table, array_events, keys=["obs_id", "event_id"])
    del join_table

    # bring rows back into the order of tel_events, then drop the helpers
    joined.sort("__sort_index__")
    joined.remove_columns(["obs_id", "event_id", "__sort_index__"])

    for colname in joined.colnames:
        tel_events[colname] = joined[colname]


def main():
    """Run the ``ApplyModels`` tool."""
    ApplyModels().run()


if __name__ == "__main__":
    main()