# Source code for ctlearn.tools.train_model

"""
Tool to train a ``CTLearnModel`` on R1/DL1a data using the ``DLDataReader`` and ``DLDataLoader``.
"""

import atexit
import keras
import pandas as pd
import numpy as np
import tensorflow as tf

from ctapipe.core import Tool
from ctapipe.core.tool import ToolConfigurationError
from ctapipe.core.traits import (
    Bool,
    CaselessStrEnum,
    Path,
    Float,
    Int,
    List,
    Dict,
    classes_with_traits,
    ComponentName,
    Unicode,
)
from dl1_data_handler.reader import DLDataReader
from ctlearn import __version__ as ctlearn_version
from ctlearn.core.loader import DLDataLoader
from ctlearn.core.model import CTLearnModel
from ctlearn.utils import validate_trait_dict


class TrainCTLearnModel(Tool):
    """
    Tool to train a ``~ctlearn.core.model.CTLearnModel`` on R1/DL1a data.

    The tool trains a CTLearn model on the input data (R1 calibrated waveforms or DL1a images)
    and saves the trained model in the output directory. The input data is loaded from the input
    directories for signal and background events using the ``~dl1_data_handler.reader.DLDataReader``
    and ``~ctlearn.core.loader.DLDataLoader``. The tool supports the following reconstruction tasks:
    - Classification of the primary particle type (gamma/proton)
    - Regression of the primary particle energy
    - Regression of the primary particle arrival direction based on the offsets in camera coordinates
    - Regression of the primary particle arrival direction based on the offsets in sky coordinates
    """

    name = "ctlearn-train-model"
    description = __doc__

    examples = """
    To train a CTLearn model for the classification of the primary particle type:
    > ctlearn-train-model \\
        --signal /path/to/your/gammas_dl1_dir/ \\
        --pattern-signal "gamma_*_run1.dl1.h5" \\
        --pattern-signal "gamma_*_run10.dl1.h5" \\
        --background /path/to/your/protons_dl1_dir/ \\
        --pattern-background "proton_*_run1.dl1.h5" \\
        --pattern-background "proton_*_run10.dl1.h5" \\
        --output /path/to/your/type/ \\
        --reco type \\

    To train a CTLearn model for the regression of the primary particle energy:
    > ctlearn-train-model \\
        --signal /path/to/your/gammas_dl1_dir/ \\
        --pattern-signal "gamma_*_run1.dl1.h5" \\
        --pattern-signal "gamma_*_run10.dl1.h5" \\
        --output /path/to/your/energy/ \\
        --reco energy \\

    To train a CTLearn model for the regression of the primary particle
    arrival direction based on the offsets in camera coordinates:
    > ctlearn-train-model \\
        --signal /path/to/your/gammas_dl1_dir/ \\
        --pattern-signal "gamma_*_run1.dl1.h5" \\
        --pattern-signal "gamma_*_run10.dl1.h5" \\
        --output /path/to/your/direction/ \\
        --reco cameradirection \\

    To train a CTLearn model for the regression of the primary particle
    arrival direction based on the offsets in sky coordinates:
    > ctlearn-train-model \\
        --signal /path/to/your/gammas_dl1_dir/ \\
        --pattern-signal "gamma_*_run1.dl1.h5" \\
        --pattern-signal "gamma_*_run10.dl1.h5" \\
        --output /path/to/your/direction/ \\
        --reco skydirection \\
    """

    input_dir_signal = Path(
        help="Input directory for signal events",
        allow_none=False,
        exists=True,
        directory_ok=True,
        file_ok=False,
    ).tag(config=True)

    file_pattern_signal = List(
        trait=Unicode(),
        default_value=["*.h5"],
        help="List of specific file pattern for matching files in ``input_dir_signal``",
    ).tag(config=True)

    input_dir_background = Path(
        default_value=None,
        help="Input directory for background events",
        allow_none=True,
        exists=True,
        directory_ok=True,
        file_ok=False,
    ).tag(config=True)

    file_pattern_background = List(
        trait=Unicode(),
        default_value=["*.h5"],
        help="List of specific file pattern for matching files in ``input_dir_background``",
    ).tag(config=True)

    dl1dh_reader_type = ComponentName(DLDataReader, default_value="DLImageReader").tag(
        config=True
    )

    stack_telescope_images = Bool(
        default_value=False,
        allow_none=False,
        help=(
            "Set whether to stack the telescope images in the data loader. "
            "Requires DLDataReader mode to be ``stereo``."
        ),
    ).tag(config=True)

    sort_by_intensity = Bool(
        default_value=False,
        allow_none=True,
        help=(
            "Set whether to sort the telescope images by intensity in the data loader. "
            "Requires DLDataReader mode to be ``stereo``."
        ),
    ).tag(config=True)

    model_type = ComponentName(CTLearnModel, default_value="ResNet").tag(config=True)

    # The keyword was previously misspelled as ``exits``; the trait keyword
    # used by the other ``Path`` traits in this class is ``exists``.
    output_dir = Path(
        exists=False,
        default_value=None,
        allow_none=False,
        directory_ok=True,
        file_ok=False,
        help="Output directory for the trained reconstructor.",
    ).tag(config=True)

    reco_tasks = List(
        trait=CaselessStrEnum(["type", "energy", "cameradirection", "skydirection"]),
        allow_none=False,
        help=(
            "List of reconstruction tasks to perform. "
            "'type': classification of the primary particle type; "
            "'energy': regression of the primary particle energy; "
            "'cameradirection': reconstruction of the primary particle arrival direction in camera coordinates; "
            "'skydirection': reconstruction of the primary particle arrival direction in sky coordinates."
        ),
    ).tag(config=True)

    n_epochs = Int(
        default_value=10,
        allow_none=False,
        help="Number of epochs to train the neural network.",
    ).tag(config=True)

    batch_size = Int(
        default_value=64,
        allow_none=False,
        help="Size of the batch to train the neural network.",
    ).tag(config=True)

    validation_split = Float(
        default_value=0.1,
        help="Fraction of the data to use for validation",
        min=0.01,
        max=0.99,
    ).tag(config=True)

    save_best_validation_only = Bool(
        default_value=True,
        allow_none=False,
        help="Set whether to save the best validation checkpoint only.",
    ).tag(config=True)

    optimizer = Dict(
        default_value={
            "name": "Adam",
            "base_learning_rate": 0.0001,
            "adam_epsilon": 1.0e-8,
        },
        help=(
            "Optimizer to use for training. "
            "E.g. {'name': 'Adam', 'base_learning_rate': 0.0001, 'adam_epsilon': 1.0e-8}. "
        ),
    ).tag(config=True)

    lr_reducing = Dict(
        default_value={
            "factor": 0.5,
            "patience": 5,
            "min_delta": 0.01,
            "min_lr": 0.000001,
        },
        allow_none=True,
        help=(
            "Learning rate reducing parameters for the Keras callback. "
            "E.g. {'factor': 0.5, 'patience': 5, 'min_delta': 0.01, 'min_lr': 0.000001}. "
        ),
    ).tag(config=True)

    random_seed = Int(
        default_value=0,
        help=(
            "Random seed for shuffling the data "
            "before the training/validation split "
            "and after the end of an epoch."
        ),
    ).tag(config=True)

    save_onnx = Bool(
        default_value=False,
        allow_none=False,
        help="Set whether to save model in an ONNX file.",
    ).tag(config=True)

    early_stopping = Dict(
        default_value=None,
        allow_none=True,
        help=(
            "Early stopping parameters for the Keras callback. "
            "E.g. {'monitor': 'val_loss', 'patience': 4, 'verbose': 1, 'restore_best_weights': True}. "
        ),
    ).tag(config=True)

    aliases = {
        "signal": "TrainCTLearnModel.input_dir_signal",
        "background": "TrainCTLearnModel.input_dir_background",
        "pattern-signal": "TrainCTLearnModel.file_pattern_signal",
        "pattern-background": "TrainCTLearnModel.file_pattern_background",
        "reco": "TrainCTLearnModel.reco_tasks",
        ("o", "output"): "TrainCTLearnModel.output_dir",
    }

    classes = classes_with_traits(CTLearnModel) + classes_with_traits(DLDataReader)

    def setup(self):
        """
        Prepare everything needed for training.

        Creates the distribution strategy, collects the input files, builds
        the data reader, validates the configuration against the loaded data,
        builds the training/validation data loaders and assembles the list of
        Keras callbacks in ``self.callbacks``.

        Raises
        ------
        ToolConfigurationError
            If the output directory already exists or if the telescope image
            stacking options are inconsistent with the reader configuration.
        NotImplementedError
            If ``dl1dh_reader_type`` is ``DLFeatureVectorReader``.
        ValueError
            If fewer events than ``batch_size`` are available, or if the
            ``type`` task is requested but the reader provides no class weights.
        """
        self.log.info("ctlearn version %s", ctlearn_version)
        # Check if the output directory exists
        if self.output_dir.exists():
            raise ToolConfigurationError(
                f"Output directory {self.output_dir} already exists."
            )

        # Create a MirroredStrategy.
        self.strategy = tf.distribute.MirroredStrategy()
        # NOTE(review): this touches private TensorFlow attributes; presumably
        # a workaround for an interpreter-shutdown issue -- confirm it is
        # still required for the supported TensorFlow versions.
        atexit.register(self.strategy._extended._collective_ops._lock.locked)  # type: ignore
        n_replicas = self.strategy.num_replicas_in_sync
        self.log.info("Number of devices: %s", n_replicas)

        # Get signal input files
        self.input_url_signal = []
        for signal_pattern in self.file_pattern_signal:
            self.input_url_signal.extend(self.input_dir_signal.glob(signal_pattern))
        # Get bkg input files
        self.input_url_background = []
        if self.input_dir_background is not None:
            for background_pattern in self.file_pattern_background:
                self.input_url_background.extend(
                    self.input_dir_background.glob(background_pattern)
                )

        # Set up the data reader
        self.log.info("Loading data:")
        self.log.info("For a large dataset, this may take a while...")
        if self.dl1dh_reader_type == "DLFeatureVectorReader":
            raise NotImplementedError(
                "'DLFeatureVectorReader' is not supported in CTLearn yet. "
                "Missing stereo CTLearnModel implementation."
            )
        self.dl1dh_reader = DLDataReader.from_name(
            self.dl1dh_reader_type,
            input_url_signal=sorted(self.input_url_signal),
            input_url_background=sorted(self.input_url_background),
            parent=self,
        )
        # The event count is reused for logging, validation and the split below.
        n_events = self.dl1dh_reader._get_n_events()
        self.log.info("Number of events loaded: %s", n_events)
        if "type" in self.reco_tasks:
            self.log.info(
                "Number of signal events: %d", self.dl1dh_reader.n_signal_events
            )
            self.log.info(
                "Number of background events: %d", self.dl1dh_reader.n_bkg_events
            )

        # Check if the number of events is enough to form a batch
        if n_events < self.batch_size:
            raise ValueError(
                f"{n_events} events are not enough "
                f"to form a batch of size {self.batch_size}. Reduce the batch size."
            )
        # Check if there are at least two classes in the reader for the particle classification
        if self.dl1dh_reader.class_weight is None and "type" in self.reco_tasks:
            raise ValueError(
                "Classification task selected but less than two classes are present in the data."
            )
        # Check if stereo mode is selected for stacking telescope images
        if self.stack_telescope_images and self.dl1dh_reader.mode == "mono":
            raise ToolConfigurationError(
                "Cannot stack telescope images in mono mode. Use stereo mode for stacking."
            )
        # Check if only one telescope type is selected for stacking telescope images
        if (
            self.stack_telescope_images
            and len(list(self.dl1dh_reader.selected_telescopes)) > 1
        ):
            raise ToolConfigurationError(
                "Cannot stack telescope images from multiple telescope types. Use only one telescope type."
            )
        # Check if sorting by intensity is disabled for stacking telescope images
        if self.stack_telescope_images and self.sort_by_intensity:
            raise ToolConfigurationError(
                "Cannot stack telescope images when sorting by intensity. Disable sorting by intensity."
            )

        # Set up the data loaders for training and validation
        indices = list(range(n_events))
        # Shuffle the indices before the training/validation split
        np.random.seed(self.random_seed)
        np.random.shuffle(indices)
        n_validation_examples = int(self.validation_split * n_events)
        training_indices = indices[n_validation_examples:]
        validation_indices = indices[:n_validation_examples]
        # The per-replica batch size is scaled by the number of replicas.
        global_batch_size = self.batch_size * n_replicas
        self.training_loader = DLDataLoader(
            self.dl1dh_reader,
            training_indices,
            tasks=self.reco_tasks,
            batch_size=global_batch_size,
            random_seed=self.random_seed,
            sort_by_intensity=self.sort_by_intensity,
            stack_telescope_images=self.stack_telescope_images,
        )
        self.validation_loader = DLDataLoader(
            self.dl1dh_reader,
            validation_indices,
            tasks=self.reco_tasks,
            batch_size=global_batch_size,
            random_seed=self.random_seed,
            sort_by_intensity=self.sort_by_intensity,
            stack_telescope_images=self.stack_telescope_images,
        )

        # Set up the callbacks
        monitor = "val_loss"
        monitor_mode = "min"
        # Model checkpoint callback
        model_path = f"{self.output_dir}/ctlearn_model.keras"
        model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
            filepath=model_path,
            monitor=monitor,
            verbose=1,
            mode=monitor_mode,
            save_best_only=self.save_best_validation_only,
        )
        # Tensorboard callback
        tensorboard_callback = keras.callbacks.TensorBoard(
            log_dir=self.output_dir, histogram_freq=1
        )
        # CSV logger callback
        csv_logger_callback = keras.callbacks.CSVLogger(
            filename=f"{self.output_dir}/training_log.csv", append=True
        )
        self.callbacks = [
            model_checkpoint_callback,
            tensorboard_callback,
            csv_logger_callback,
        ]
        if self.early_stopping is not None:
            # EarlyStopping callback
            validate_trait_dict(
                self.early_stopping,
                ["monitor", "patience", "verbose", "restore_best_weights"],
            )
            early_stopping_callback = keras.callbacks.EarlyStopping(
                monitor=self.early_stopping["monitor"],
                patience=self.early_stopping["patience"],
                verbose=self.early_stopping["verbose"],
                restore_best_weights=self.early_stopping["restore_best_weights"],
            )
            self.callbacks.append(early_stopping_callback)
        # Learning rate reducing callback
        if self.lr_reducing is not None:
            # Validate the learning rate reducing parameters
            validate_trait_dict(
                self.lr_reducing, ["factor", "patience", "min_delta", "min_lr"]
            )
            lr_reducing_callback = keras.callbacks.ReduceLROnPlateau(
                monitor=monitor,
                factor=self.lr_reducing["factor"],
                patience=self.lr_reducing["patience"],
                mode=monitor_mode,
                verbose=1,
                min_delta=self.lr_reducing["min_delta"],
                min_lr=self.lr_reducing["min_lr"],
            )
            self.callbacks.append(lr_reducing_callback)

    def start(self):
        """
        Build, compile and train the model.

        The model is constructed and compiled inside the distribution
        strategy scope created in ``setup`` and then fitted on the training
        loader, using the validation loader, the reader's class weights and
        the callbacks prepared in ``setup``.
        """
        # Open a strategy scope.
        with self.strategy.scope():
            # Construct the model
            self.log.info("Setting up the model.")
            self.model = CTLearnModel.from_name(
                self.model_type,
                input_shape=self.training_loader.input_shape,
                tasks=self.reco_tasks,
                parent=self,
            ).model
            # Validate the optimizer parameters
            validate_trait_dict(self.optimizer, ["name", "base_learning_rate"])
            # Set the learning rate for the optimizer
            learning_rate = self.optimizer["base_learning_rate"]
            # Set the epsilon for the Adam optimizer
            adam_epsilon = None
            if self.optimizer["name"] == "Adam":
                # Validate the epsilon for the Adam optimizer
                validate_trait_dict(self.optimizer, ["adam_epsilon"])
                # Set the epsilon for the Adam optimizer
                adam_epsilon = self.optimizer["adam_epsilon"]
            # Select optimizer with appropriate arguments
            # Dict of optimizer_name: (optimizer_fn, optimizer_args)
            optimizers = {
                "Adadelta": (
                    keras.optimizers.Adadelta,
                    dict(learning_rate=learning_rate),
                ),
                "Adam": (
                    keras.optimizers.Adam,
                    dict(learning_rate=learning_rate, epsilon=adam_epsilon),
                ),
                "RMSProp": (
                    keras.optimizers.RMSprop,
                    dict(learning_rate=learning_rate),
                ),
                "SGD": (keras.optimizers.SGD, dict(learning_rate=learning_rate)),
            }
            # Get the optimizer function and arguments
            optimizer_fn, optimizer_args = optimizers[self.optimizer["name"]]
            # Get the losses and metrics for the model
            losses, metrics = self._get_losses_and_mertics(self.reco_tasks)
            # Compile the model
            self.log.info("Compiling CTLearn model.")
            self.model.compile(
                optimizer=optimizer_fn(**optimizer_args), loss=losses, metrics=metrics
            )

        # Train and evaluate the model
        self.log.info("Training and evaluating...")
        self.model.fit(
            self.training_loader,
            validation_data=self.validation_loader,
            epochs=self.n_epochs,
            class_weight=self.dl1dh_reader.class_weight,
            callbacks=self.callbacks,
            verbose=2,
        )
        self.log.info("Training and evaluating finished succesfully!")

    def finish(self):
        """
        Optionally export the trained model to ONNX and log the shutdown.

        Raises
        ------
        ImportError
            If ``save_onnx`` is set but ``tf2onnx`` cannot be imported.
        """
        # Saving model weights in onnx format
        if self.save_onnx:
            self.log.info("Converting Keras model into ONNX format...")
            self.log.info("Make sure tf2onnx is installed in your enviroment!")
            try:
                import tf2onnx
            except ImportError as err:
                # Chain the original error so the real import failure stays visible.
                raise ImportError(
                    "tf2onnx is not installed in your environment!"
                ) from err

            output_path = f"{self.output_dir}/ctlearn_model.onnx"
            tf2onnx.convert.from_keras(
                self.model,
                input_signature=self.model.input_layer.input._type_spec,
                output_path=output_path,
            )
            self.log.info("ONNX model saved in %s", self.output_dir)

        self.log.info("Tool is shutting down")

    def _get_losses_and_mertics(self, tasks):
        """
        Build the Keras losses and metrics for the requested tasks.

        Parameters
        ----------
        tasks : list
            List of reconstruction tasks (``type``, ``energy``,
            ``cameradirection``, ``skydirection``) to configure.

        Returns
        -------
        losses : dict or keras.losses.Loss
            Loss per task, keyed by task name. If ``type`` is the only task,
            the bare loss object is returned instead of a dictionary.
        metrics : dict or list
            Metrics per task, keyed by task name. If ``type`` is the only
            task, the bare list of metrics is returned instead of a dictionary.
        """
        losses, metrics = {}, {}
        if "type" in tasks:
            losses["type"] = keras.losses.CategoricalCrossentropy(
                reduction="sum_over_batch_size"
            )
            metrics["type"] = [
                keras.metrics.CategoricalAccuracy(name="accuracy"),
                keras.metrics.AUC(name="auc"),
            ]
            # Temp fix till keras support class weights for multiple outputs or I wrote custom loss
            # https://github.com/keras-team/keras/issues/11735
            if len(tasks) == 1:
                losses = losses["type"]
                metrics = metrics["type"]
        if "energy" in tasks:
            losses["energy"] = keras.losses.MeanAbsoluteError(
                reduction="sum_over_batch_size"
            )
            metrics["energy"] = keras.metrics.MeanAbsoluteError(name="mae_energy")
        if "cameradirection" in tasks:
            losses["cameradirection"] = keras.losses.MeanAbsoluteError(
                reduction="sum_over_batch_size"
            )
            metrics["cameradirection"] = keras.metrics.MeanAbsoluteError(
                name="mae_cameradirection"
            )
        if "skydirection" in tasks:
            losses["skydirection"] = keras.losses.MeanAbsoluteError(
                reduction="sum_over_batch_size"
            )
            metrics["skydirection"] = keras.metrics.MeanAbsoluteError(
                name="mae_skydirection"
            )
        return losses, metrics
def main():
    """Command-line entry point: instantiate the training tool and run it."""
    TrainCTLearnModel().run()
# Run the tool only when this module is executed as a script. The module
# name in that case is "__main__"; the previous comparison against "main"
# could never be true.
if __name__ == "__main__":
    main()