"""Source code for ctlearn.run_model: train CTLearn models and run predictions with them."""

import atexit
import argparse
import importlib
import logging
import math
from random import randint
import os
import glob
from pprint import pformat
import sys

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import yaml

import tensorflow as tf
from tensorflow.python import debug as tf_debug

from dl1_data_handler.reader import DL1DataReaderSTAGE1, DL1DataReaderDL1DH
from ctlearn.data_loader import KerasBatchGenerator
from ctlearn.output_handler import *
from ctlearn.utils import *


def _select_monitor(reco_tasks):
    """Return the ``(monitor, mode)`` pair watched by the checkpoint callback.

    The total loss is minimised by default. When exactly one reconstruction
    task is configured, the metric specific to that task is monitored instead.
    """
    monitor, monitor_mode = 'loss', 'min'
    if len(reco_tasks) == 1:
        if 'particletype' in reco_tasks:
            monitor, monitor_mode = 'auc', 'max'
        if 'energy' in reco_tasks:
            monitor = 'mae_energy'
        if 'direction' in reco_tasks:
            monitor = 'mae_direction'
    return monitor, monitor_mode


def _compute_class_weights(simulated_particles, logger):
    """Return per-class weights, or ``None`` if fewer than two classes exist.

    ``simulated_particles`` maps shower primary ids to event counts and holds
    an additional ``'total'`` entry. Scaling by ``total / 2`` keeps the loss at
    a similar magnitude; the sum of the weights of all examples stays the same.
    """
    # 'total' plus a single particle id means there is nothing to balance.
    if len(simulated_particles) <= 2:
        return None

    logger.info(" Apply class weights:")
    shower_primary_id_to_class = {
        0: 1,    # gamma
        101: 0,  # proton
        1: 2,    # electron
        255: 0,  # MAGIC real data
    }
    total = simulated_particles['total']
    logger.info(" Total number: {}".format(total))
    class_weight = {}
    for particle_id, num_particles in simulated_particles.items():
        if particle_id == 'total':
            continue
        class_id = shower_primary_id_to_class[particle_id]
        logger.info(" Breakdown by {}: {}".format(class_id, num_particles))
        class_weight[class_id] = (1 / num_particles) * (total / 2.0)
    logger.info(" Class weights: {}".format(class_weight))
    return class_weight


def _plot_training_history(model_dir, logger):
    """Save one ``<metric>.png`` per training metric found in the CSV log."""
    training_log = pd.read_csv(model_dir + '/training_log.csv')
    epochs = training_log['epoch'] + 1
    for metric in training_log.columns:
        if metric == 'epoch' or metric.startswith("val_"):
            continue
        logger.info("Plotting training history: {}".format(metric))
        fig, ax = plt.subplots()
        ax.plot(epochs, training_log[metric])
        labels = ['train']
        # Not every logged column necessarily has a validation counterpart.
        val_metric = f'val_{metric}'
        if val_metric in training_log.columns:
            ax.plot(epochs, training_log[val_metric])
            labels.append('val')
        ax.set_title(f'CTLearn training history - {metric}')
        ax.set_xlabel('epoch')
        ax.set_ylabel(metric)
        ax.legend(labels, loc='upper left')
        fig.savefig(f'{model_dir}/{metric}.png')
        # Close the figure so repeated runs do not accumulate open figures.
        plt.close(fig)


def run_model(config, mode="train", debug=False, log_to_file=False):
    """Train a CTLearn model or run predictions with an existing one.

    Parameters
    ----------
    config : dict
        Parsed configuration. The sections read here are ``Logging``,
        ``Data``, ``Input``, ``Training``, ``Model``, ``Model Parameters``,
        ``Reco`` and ``Prediction``. The dictionary is modified in place:
        ``config['Data']`` is replaced by the value returned from
        ``setup_DL1DataReader`` and missing ``Input``/``Training`` sections
        are created.
    mode : str
        Either ``'train'`` or ``'predict'``.
    debug : bool
        Forwarded to ``setup_logging``.
    log_to_file : bool
        Forwarded to ``setup_logging``.

    Raises
    ------
    ValueError
        If ``mode`` or the data format is not supported, if the model
        directory does not exist in predict mode, or if the validation split
        is not strictly between 0 and 1.
    """
    if mode not in ('train', 'predict'):
        raise ValueError(f"Invalid mode '{mode}'. Must be 'train' or 'predict'.")

    # Load options relating to logging and checkpointing
    model_dir = config['Logging']['model_directory']
    random_seed = None
    if config['Logging'].get('add_seed', False):
        random_seed = config['Data']['seed']
        model_dir += f"/seed_{random_seed}"
    # Checked for every model directory (seeded or not) so that the function
    # also works when it is called directly rather than through main().
    if not os.path.exists(model_dir):
        if mode == 'predict':
            raise ValueError(f"Invalid output directory '{model_dir}'. "
                             "Must be a path to an existing directory in the predict mode.")
        os.makedirs(model_dir)

    # Set up logging, saving the config and optionally logging to a file
    logger = setup_logging(config, model_dir, debug, log_to_file)

    # Log the loaded configuration
    logger.debug(pformat(config))
    logger.info("Logging has been correctly set up")

    # Create a MirroredStrategy.
    strategy = tf.distribute.MirroredStrategy()
    # NOTE: relies on a private TensorFlow attribute to close the collective
    # ops pool at interpreter exit.
    atexit.register(strategy._extended._collective_ops._pool.close)  # type: ignore
    logger.info('Number of devices: {}'.format(strategy.num_replicas_in_sync))

    # Set up the DL1DataReader
    config['Data'], data_format = setup_DL1DataReader(config, mode)

    # Create data reader
    logger.info("Loading data:")
    logger.info(" For a large dataset, this may take a while...")
    if data_format == 'stage1':
        reader = DL1DataReaderSTAGE1(**config['Data'])
    elif data_format == 'dl1dh':
        reader = DL1DataReaderDL1DH(**config['Data'])
    else:
        raise ValueError(f"Unsupported data format '{data_format}'. "
                         "Must be 'stage1' or 'dl1dh'.")
    logger.info(" Number of events loaded: {}".format(len(reader)))

    # Set up the KerasBatchGenerator
    indices = list(range(len(reader)))
    if 'Input' not in config:
        config['Input'] = {}
    batch_size_per_worker = config['Input'].get('batch_size_per_worker', 64)
    batch_size = batch_size_per_worker * strategy.num_replicas_in_sync
    concat_telescopes = config['Input'].get('concat_telescopes', False)

    rest_data = None
    if mode == 'train':
        if 'Training' not in config:
            config['Training'] = {}
        # The np.float alias has been removed from NumPy; the builtin float
        # is the exact equivalent.
        validation_split = float(config['Training'].get('validation_split', 0.1))
        if not 0.0 < validation_split < 1.0:
            raise ValueError("Invalid validation split: {}. "
                             "Must be between 0.0 and 1.0".format(validation_split))
        num_training_examples = math.floor((1 - validation_split) * len(reader))
        training_indices = indices[:num_training_examples]
        validation_indices = indices[num_training_examples:]
        data = KerasBatchGenerator(reader, training_indices, batch_size=batch_size,
                                   mode=mode, concat_telescopes=concat_telescopes)
        validation_data = KerasBatchGenerator(reader, validation_indices, batch_size=batch_size,
                                              mode=mode, concat_telescopes=concat_telescopes)
    else:
        logger.info(" Simulation info for pyirf.simulations.SimulatedEventsInfo: {}".format(
            reader.simulation_info))
        data = KerasBatchGenerator(reader, indices, batch_size=batch_size, mode=mode,
                                   shuffle=False, concat_telescopes=concat_telescopes)
        # Keras only considers complete batches. In prediction mode the events
        # of the last incomplete batch must not be lost, so an additional
        # batch generator is created for the remaining events.
        rest = len(indices) % batch_size
        # With no remainder, indices[-0:] would select *all* events and the
        # generator would get batch_size=0, so it is only built when needed.
        if rest > 0:
            rest_data = KerasBatchGenerator(reader, indices[-rest:], batch_size=rest,
                                            mode=mode, shuffle=False,
                                            concat_telescopes=concat_telescopes)

    # Construct the model
    logger.info("Setting up model:")
    model_directory = config['Model'].get(
        'model_directory',
        os.path.abspath(os.path.join(os.path.dirname(__file__), "default_models/")))
    sys.path.append(model_directory)
    logger.info(" Constructing model from config.")

    # Write the model parameters in the params dictionary
    model_params = {**config['Model'], **config.get('Model Parameters', {})}
    model_params['model_directory'] = model_directory

    # The 'Training' section is only guaranteed to exist in train mode.
    training_config = config.get('Training', {})

    # Open a strategy scope.
    with strategy.scope():
        # Backbone model
        backbone_module = importlib.import_module(config['Model']['backbone']['module'])
        backbone_model = getattr(backbone_module, config['Model']['backbone']['function'])
        backbone, backbone_inputs = backbone_model(data, model_params)
        backbone_output = backbone(backbone_inputs)

        # Head model (the returned loss weights are not passed to compile())
        head_module = importlib.import_module(config['Model']['head']['module'])
        head_model = getattr(head_module, config['Model']['head']['function'])
        logits, losses, _loss_weights, metrics = head_model(
            inputs=backbone_output, tasks=config['Reco'], params=model_params)

        # Resume from a previously saved model if the directory contains one.
        if os.path.exists(os.path.join(model_dir, 'saved_model.pb')):
            logger.info(" Loading model from '{}'.".format(model_dir))
            model = tf.keras.models.load_model(model_dir)
        else:
            model = tf.keras.Model(backbone_inputs, logits, name='CTLearn_model')

        if config['Model'].get('plot_model', False):
            logger.info(" Saving the backbone architecture in '{}/backbone.png'.".format(model_dir))
            tf.keras.utils.plot_model(backbone, to_file=model_dir + '/backbone.png',
                                      show_shapes=True, show_layer_names=True)
            logger.info(" Saving the model architecture in '{}/model.png'.".format(model_dir))
            tf.keras.utils.plot_model(model, to_file=model_dir + '/model.png',
                                      show_shapes=True, show_layer_names=True)
        logger.info(" Model has been correctly set up from config.")

        optimizer = training_config.get('optimizer', 'Adam')
        logger.info(" Optimizer: {}".format(optimizer))
        adam_epsilon = float(training_config.get('adam_epsilon', 1.0e-8))
        learning_rate = float(training_config.get('base_learning_rate', 0.0001))
        logger.info(" Learning rate: {}".format(learning_rate))

        # Select optimizer with appropriate arguments
        # Dict of optimizer_name: (optimizer_fn, optimizer_args)
        optimizers = {
            'Adadelta': (tf.keras.optimizers.Adadelta,
                         dict(learning_rate=learning_rate)),
            'Adam': (tf.keras.optimizers.Adam,
                     dict(learning_rate=learning_rate, epsilon=adam_epsilon)),
            'RMSProp': (tf.keras.optimizers.RMSprop,
                        dict(learning_rate=learning_rate)),
            'SGD': (tf.keras.optimizers.SGD,
                    dict(learning_rate=learning_rate)),
        }
        optimizer_fn, optimizer_args = optimizers[optimizer]
        optimizer = optimizer_fn(**optimizer_args)

        logger.info(" Compiling model.")
        model.compile(optimizer=optimizer, loss=losses, metrics=metrics)

    if mode == 'train':
        logger.info("Setting up training:")
        logger.info(" Validation split: {}".format(validation_split))
        num_epochs = int(config['Training'].get('num_epochs', 10))
        logger.info(" Number of epochs: {}".format(num_epochs))
        logger.info(" Size of the batches per worker: {}".format(batch_size_per_worker))
        logger.info(" Size of the batches: {}".format(batch_size))
        logger.info(" Number of training steps per epoch: {}".format(
            int(num_training_examples / batch_size)))
        verbose = int(config['Training'].get('verbose', 2))
        logger.info(" Verbosity mode: {}".format(verbose))
        # ToDo: Fix multiprocessing issue; until then a single worker is forced.
        workers = 1
        logger.info(" Number of workers: {}".format(workers))
        use_multiprocessing = workers > 1
        logger.info(" Use of multiprocessing: {}".format(use_multiprocessing))

        # ToDo: Come up with a better solution for the callbacks
        # Set up the callbacks. The frequency is clamped to at least one batch
        # because very small datasets would otherwise yield a frequency of 0.
        val_freq = max(1, int((num_training_examples / batch_size) / 5))
        logger.info(" Validation frequency: {}".format(val_freq))
        monitor, monitor_mode = _select_monitor(config['Reco'])

        training_log_file = model_dir + '/training_log.csv'
        # Model checkpoint callback
        model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=model_dir,
            monitor=monitor,
            mode=monitor_mode,
            save_freq=val_freq,
            save_best_only=True)
        # Tensorboard callback
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=model_dir, histogram_freq=1, update_freq=val_freq)
        # CSV logger callback
        csv_logger_callback = tf.keras.callbacks.CSVLogger(
            filename=training_log_file, append=True)
        callbacks = [model_checkpoint_callback, tensorboard_callback, csv_logger_callback]

        # Class weights calculation
        class_weight = _compute_class_weights(reader.simulated_particles, logger)

        # Continue the epoch counter of an earlier run, if there was one.
        initial_epoch = 0
        if os.path.exists(training_log_file):
            previous_log = pd.read_csv(training_log_file)
            if not previous_log.empty:
                initial_epoch = int(previous_log['epoch'].iloc[-1]) + 1

        # Train and evaluate the model
        logger.info("Training and evaluating...")
        model.fit(x=data,
                  validation_data=validation_data,
                  batch_size=batch_size,
                  epochs=num_epochs,
                  initial_epoch=initial_epoch,
                  class_weight=class_weight,
                  callbacks=callbacks,
                  verbose=verbose,
                  workers=workers,
                  use_multiprocessing=use_multiprocessing)
        model.save(model_dir)
        logger.info("Training and evaluating finished succesfully!")

        # Plotting training history
        _plot_training_history(model_dir, logger)

    else:
        # Generate predictions and add to output
        logger.info("Predicting...")
        prediction_parts = []
        # NOTE(review): assumes the main generator yields only complete
        # batches (as stated where it is built), so it is skipped when there
        # are fewer events than one batch - confirm against KerasBatchGenerator.
        if len(indices) >= batch_size:
            prediction_parts.append(model.predict(data))
        if rest_data is not None:
            prediction_parts.append(model.predict(rest_data))
        if len(prediction_parts) == 1:
            predictions = prediction_parts[0]
        else:
            predictions = np.concatenate(prediction_parts, axis=0)

        file_stem = config['Prediction'].get('file', "experiment")
        if random_seed:
            file_stem += "_{}".format(random_seed)
        # Resolve relative to the working directory, like every other use of
        # model_dir in this function (absolute paths are unaffected).
        output_file = os.path.abspath(model_dir + "/{}.h5".format(file_stem))
        write_output(output_file, reader, indices, predictions,
                     config['Reco'], config['Prediction']['prediction_label'])

    # Clear the handlers and shut down logging so that a subsequent call
    # starts from a clean logging setup.
    logger.handlers.clear()
    logging.shutdown()
    return
def _build_parameter_selection(config, args):
    """Merge the command-line quality cuts with those from the config file.

    A cut given on the command line replaces the config entry with the same
    ``col_name``; all other config entries are kept.
    """
    parameter_selection = []
    if args.size_cut:
        parameter_selection.append(
            {'col_name': 'hillas_intensity', 'min_value': args.size_cut})
    if args.leakage_cut:
        parameter_selection.append(
            {'col_name': 'leakage_intensity_width_2', 'max_value': args.leakage_cut})
    # "or []" also covers a key that is present but left empty in the YAML.
    for parameter in config['Data'].get('parameter_selection') or []:
        if parameter['col_name'] == 'hillas_intensity' and args.size_cut:
            continue
        if parameter['col_name'] == 'leakage_intensity_width_2' and args.leakage_cut:
            continue
        parameter_selection.append(parameter)
    return parameter_selection


def _apply_cli_overrides(config, args, parameter_selection):
    """Overwrite config entries with the options given on the command line."""
    if args.reco:
        config['Reco'] = args.reco
    if args.tel_types:
        config['Data']['selected_telescope_types'] = args.tel_types
    if args.allowed_tels:
        config['Data']['selected_telescope_ids'] = args.allowed_tels
    if parameter_selection:
        config['Data']['parameter_selection'] = parameter_selection
    if args.multiplicity_cut:
        config['Data']['multiplicity_selection'] = {'Subarray': args.multiplicity_cut}
    if args.output:
        config['Logging'] = {}
        config['Logging']['model_directory'] = args.output
    # Set the path to pretrained weights from the command line
    if args.pretrained_weights:
        config['Model']['pretrained_weights'] = args.pretrained_weights
        config['Model']['trainable_backbone'] = False


def _load_prediction_config(args, parameter_selection, random_seed):
    """Reload the config file and prepare it for one prediction run.

    The file is read again for every run because ``run_model`` modifies the
    config dictionary it receives.
    """
    with open(args.config_file, 'r') as config_file:
        config = yaml.safe_load(config_file)
    _apply_cli_overrides(config, args, parameter_selection)
    config['Data']['shuffle'] = False
    config['Data']['seed'] = random_seed
    if args.random_seed:
        config['Logging']['add_seed'] = True
    if 'Prediction' not in config:
        config['Prediction'] = {}
    return config


def main():
    """Command-line entry point: train and/or predict with a CTLearn model.

    Parses the command-line arguments, loads the YAML configuration (or the
    default configuration of the selected default model), applies the
    command-line overrides and calls ``run_model`` once for training and/or
    once per prediction input.

    Raises
    ------
    ValueError
        If the output directory does not exist in a predict-only run, or if
        the random seed is not a 4 digit integer.
    """
    parser = argparse.ArgumentParser(
        description=("Train/Predict with a CTLearn model."))
    parser.add_argument(
        '--config_file', '-c',
        help="Path to YAML configuration file with training options")
    parser.add_argument(
        '--input', '-i',
        help='Input directories (not required when file_list is set in the config file)',
        nargs='+')
    parser.add_argument(
        '--pattern', '-p',
        help='Pattern to mask unwanted files from the data input directory',
        default=["*.h5"],
        nargs='+')
    parser.add_argument(
        '--mode', '-m',
        default="train",
        help="Mode to run CTLearn; valid options: train, predict, or train_and_predict")
    parser.add_argument(
        '--output', '-o',
        help="Output directory, where the logging, model weights and processed output files are stored")
    parser.add_argument(
        '--reco', '-r',
        help='Reconstruction task to perform; valid options: particletype, energy, and/or direction',
        nargs='+')
    parser.add_argument(
        '--default_model', '-d',
        help="Default CTLearn Model; valid options: TRN, TRN_cleaned, mergedTRN, mergedTRN_cleaned, CNNRNN, and CNNRNN_cleaned")
    parser.add_argument(
        '--pretrained_weights', '-w',
        help='Path to the pretrained weights')
    parser.add_argument(
        '--tel_types', '-t',
        help='Selection of telescope types; valid option: LST_LST_LSTCam, LST_MAGIC_MAGICCam, MST_MST_FlashCam, MST_MST_NectarCam, SST_SCT_SCTCam, and/or SST_ASTRI_ASTRICam',
        nargs='+')
    parser.add_argument(
        '--allowed_tels', '-a',
        type=int,
        help='List of allowed tel_ids, others will be ignored. Selected tel_ids will be ignored, when their telescope type is not selected',
        nargs='+')
    parser.add_argument(
        '--size_cut', '-z',
        type=float,
        help="Hillas intensity cut to perform")
    parser.add_argument(
        '--leakage_cut', '-l',
        type=float,
        help="Leakage intensity cut to perform")
    parser.add_argument(
        '--multiplicity_cut', '-u',
        type=int,
        help="Multiplicity cut to perform")
    parser.add_argument(
        '--num_epochs', '-e',
        type=int,
        help="Number of epochs to train")
    parser.add_argument(
        '--batch_size', '-b',
        type=int,
        help="Batch size per worker")
    parser.add_argument(
        '--random_seed', '-s',
        type=int,
        help="Selection of random seed (4 digits)")
    parser.add_argument(
        '--log_to_file',
        action='store_true',
        help="Log to a file in model directory instead of terminal")
    parser.add_argument(
        '--debug',
        action='store_true',
        help="Print debug/logger messages")

    args = parser.parse_args()

    run_training = 'train' in args.mode
    run_prediction = 'predict' in args.mode
    # Without this check an unrecognised mode would silently do nothing.
    if not (run_training or run_prediction):
        parser.error(f"invalid mode '{args.mode}'; valid options: "
                     "train, predict, or train_and_predict")

    # Use the default CTLearn config file if no config file is provided
    # and a default CTLearn model is selected.
    if args.default_model and not args.config_file:
        default_config_files = os.path.abspath(os.path.join(
            os.path.dirname(__file__), "default_config_files/"))
        args.config_file = f'{default_config_files}/{args.default_model}.yml'
    if not args.config_file:
        parser.error("either --config_file or --default_model must be provided")

    with open(args.config_file, 'r') as config_file:
        config = yaml.safe_load(config_file)

    # The merged selection is computed once from the initial config and
    # re-applied to every config that is reloaded for prediction.
    parameter_selection = _build_parameter_selection(config, args)
    _apply_cli_overrides(config, args, parameter_selection)

    # Create output directory if it doesn't exist already. Only a predict-only
    # run needs an existing directory; when training runs first it creates the
    # model that the prediction step then uses.
    model_directory = config['Logging']['model_directory']
    if not os.path.exists(model_directory):
        if not run_training:
            raise ValueError(f"Invalid output directory '{model_directory}'. "
                             "Must be a path to an existing directory in the predict mode.")
        os.makedirs(model_directory)

    # Overwrite the number of epochs, batch size and random seed in the config file
    if args.num_epochs:
        if 'Training' not in config:
            config['Training'] = {}
        config['Training']['num_epochs'] = args.num_epochs
    if args.batch_size:
        if 'Input' not in config:
            config['Input'] = {}
        config['Input']['batch_size_per_worker'] = args.batch_size
    if args.random_seed:
        if 1000 <= args.random_seed <= 9999:
            config['Data']['seed'] = args.random_seed
            config['Logging']['add_seed'] = True
        else:
            raise ValueError("Random seed: '{}'. "
                             "Must be 4 digit integer!".format(args.random_seed))
    random_seed = config['Data'].get('seed', 1234)

    if run_training:
        # Shuffle the data in train mode as default
        if 'shuffle' not in config['Data']:
            config['Data']['shuffle'] = True

        # Training file handling
        training_file_list = f"{model_directory}/training_file_list.txt"
        if args.input:
            # Append mode: a list left by an earlier run is extended, not replaced.
            with open(training_file_list, 'a') as file_list:
                for input_dir in args.input:
                    abs_file_dir = os.path.abspath(input_dir)
                    for pattern in args.pattern:
                        files = glob.glob(os.path.join(abs_file_dir, pattern))
                        for file in sorted(files):
                            file_list.write(f"{file}\n")
        # Also picks up a file list written by a previous run.
        if os.path.exists(training_file_list):
            config['Data']['file_list'] = training_file_list

        run_model(config, mode='train', debug=args.debug, log_to_file=args.log_to_file)

    if run_prediction:
        if args.input:
            # One prediction run (and one output file) per matching input file.
            for input_dir in args.input:
                abs_file_dir = os.path.abspath(input_dir)
                for pattern in args.pattern:
                    files = glob.glob(os.path.join(abs_file_dir, pattern))
                    for file in sorted(files):
                        config = _load_prediction_config(args, parameter_selection, random_seed)
                        config['Prediction']['file'] = (
                            os.path.basename(file)
                            .replace("_S_", "_E_")
                            .replace("dl1", "dl2")
                            .replace(".h5", ""))
                        config['Prediction']['prediction_label'] = 'data'
                        config['Prediction']['prediction_file_lists'] = {'data': file}
                        run_model(config, mode='predict', debug=args.debug,
                                  log_to_file=args.log_to_file)
        else:
            # One prediction run per file list named in the config file. The
            # keys are copied first because the config is reloaded in the loop.
            for key in list(config['Prediction']['prediction_file_lists']):
                config = _load_prediction_config(args, parameter_selection, random_seed)
                config['Prediction']['prediction_label'] = key
                run_model(config, mode='predict', debug=args.debug,
                          log_to_file=args.log_to_file)
# Run the command-line interface only when this module is executed as a script.
if __name__ == "__main__":
    main()