# Source code for ctlearn.utils

import importlib
import logging
import os
import pkg_resources
import sys
import time

import numpy as np
import pandas as pd
import tables
import yaml

def _installed_version(*dist_names):
    """Return the version of the first installed distribution among *dist_names*.

    Returns ``'unknown'`` when none of them is installed. The version is only
    used for a header comment, so a missing distribution should not abort
    the run.
    """
    for dist_name in dist_names:
        try:
            return pkg_resources.get_distribution(dist_name).version
        except pkg_resources.DistributionNotFound:
            continue
    return 'unknown'


def setup_logging(config, log_dir, debug, log_to_file):
    """Save the run configuration to *log_dir* and (re)configure the root logger.

    A timestamped ``<YYYYmmdd_HHMMSS>_config.yml`` file is written to
    *log_dir*. Its first line is a comment recording the installed CTLearn
    and TensorFlow versions, followed by *config* dumped as block-style YAML.

    Args:
        config: Configuration object to dump with ``yaml.dump``.
        log_dir: Directory receiving the config file (and the log file when
            *log_to_file* is true). Presumably must already exist, since
            ``open`` does not create directories.
        debug: If true, set the root logger level to DEBUG, else INFO.
        log_to_file: If true, log to ``<timestamp>_logfile.log`` in
            *log_dir*; otherwise log through a ``logging.StreamHandler``.

    Returns:
        The root logger, with exactly one handler attached.
    """
    time_str = time.strftime('%Y%m%d_%H%M%S')

    # Look the versions up before opening the file, so a failure cannot
    # leave an empty config file behind.
    ctlearn_version = _installed_version("ctlearn")
    # TensorFlow may be installed as either distribution; the GPU package is
    # checked first, which preserves the previous behavior when it is installed.
    tensorflow_version = _installed_version("tensorflow-gpu", "tensorflow")

    config_filename = os.path.join(log_dir, time_str + '_config.yml')
    with open(config_filename, 'w') as outfile:
        outfile.write('# Training performed with '
                      'CTLearn version {} and TensorFlow version {}.\n'.format(
                          ctlearn_version, tensorflow_version))
        yaml.dump(config, outfile, default_flow_style=False)

    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG if debug else logging.INFO)

    # Remove handlers left over from previous runs. File handlers are also
    # closed so that the previous run's log file descriptor is released.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        if isinstance(old_handler, logging.FileHandler):
            old_handler.close()

    if log_to_file:
        logging_filename = os.path.join(log_dir, time_str + '_logfile.log')
        handler = logging.FileHandler(logging_filename)
    else:
        handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter("%(levelname)s:%(message)s"))
    logger.addHandler(handler)
    return logger
def _read_file_list(path):
    """Return the stripped, non-blank lines of the text file *path*, skipping lines starting with '#'."""
    with open(path) as f:
        stripped = (raw.strip() for raw in f)
        return [line for line in stripped if line and not line.startswith("#")]


def _resolve_data_files(config, mode):
    """Return the list of data files selected by *config* for *mode*.

    For 'train' and 'load_only', ``config['Data']['file_list']`` is used. It
    may be a list, or a path to a text file listing one file per line. For
    any other mode, the entry
    ``config['Prediction']['prediction_file_lists'][prediction_label]`` is
    used. It may be a list, a ``.txt`` file list, or a single ``.h5`` file.

    Raises:
        ValueError: If the configured value cannot be turned into a list.
    """
    if mode in ['train', 'load_only']:
        file_list = config['Data']['file_list']
        if isinstance(file_list, str):
            file_list = _read_file_list(file_list)
        if not isinstance(file_list, list):
            raise ValueError("Invalid file list '{}'. "
                             "Must be list or path to file".format(file_list))
        return file_list

    prediction = config['Prediction']
    file_list = prediction['prediction_file_lists'][prediction['prediction_label']]
    if isinstance(file_list, list):
        return file_list
    if isinstance(file_list, str):
        if file_list.endswith(".txt"):
            return _read_file_list(file_list)
        if file_list.endswith(".h5"):
            return [file_list]
    # Previously an unrecognized value silently fell back to whatever
    # Data.file_list already held (e.g. the training files); fail loudly instead.
    raise ValueError("Invalid prediction file list '{}'. "
                     "Must be list or path to file".format(file_list))


def _detect_data_format(filename):
    """Return 'stage1' or 'dl1dh' based on the root attributes of the HDF5 file *filename*."""
    with tables.open_file(filename, mode="r") as f:
        root_attrs = f.root._v_attrs
        if 'CTA PRODUCT DATA MODEL NAME' in root_attrs:
            return 'stage1'
        if 'dl1_data_handler_version' in root_attrs:
            return 'dl1dh'
    raise ValueError("Data format is not implemented in the DL1DH reader. "
                     "Available data formats are 'stage1' and 'dl1dh'.")


def _parse_filters(specs):
    """Load each filter spec from dl1_data_handler.filters (unless a 'module' is given) into {filter_fn: params}."""
    selection = {}
    for spec in specs:
        spec = {'module': 'dl1_data_handler.filters', **spec}
        filter_fn, filter_params = load_from_module(**spec)
        selection[filter_fn] = filter_params
    return selection


def setup_DL1DataReader(config, mode):
    """Prepare ``config['Data']`` as keyword settings for the DL1 data reader.

    The function mutates *config* in place:

    - resolves the data file list for *mode* into ``config['Data']['file_list']``;
    - detects the data format ('stage1' or 'dl1dh') from the first file;
    - removes ``config['Data']['allow_overwrite']`` (default True);
    - for 'dl1dh' data, replaces the event/image selection spec lists with
      ``{filter_fn: params}`` dicts;
    - builds event_info and transform instances from the ``config['Reco']`` tasks;
    - converts interpolation image shapes from lists to tuples;
    - in 'predict' mode with ``save_identifiers``, requests the event/obs
      ids (and the telescope id in mono mode) as extra info.

    Args:
        config: Full CTLearn configuration dict.
        mode: 'train', 'load_only', or a prediction mode (e.g. 'predict').

    Returns:
        Tuple ``(config['Data'], data_format)``.

    Raises:
        ValueError: If the file list is invalid or empty, or the data format
            is not recognized.
    """
    data_config = config['Data']

    data_files = _resolve_data_files(config, mode)
    if not data_files:
        raise ValueError("No data files found in the configured file list.")
    data_config['file_list'] = data_files

    data_format = _detect_data_format(data_files[0])

    allow_overwrite = data_config.pop('allow_overwrite', True)

    selected_telescope_types = data_config['selected_telescope_types']
    # The camera name is the last '_'-separated component of each telescope type.
    camera_types = [tel_type.split("_")[-1] for tel_type in selected_telescope_types]

    tasks = config['Reco']
    transformations = []
    event_info = []
    if data_format == 'dl1dh':
        data_config['event_selection'] = _parse_filters(data_config.get('event_selection', {}))
        data_config['image_selection'] = _parse_filters(data_config.get('image_selection', {}))
        if 'direction' in tasks:
            event_info.extend(['src_pos_cam_x', 'src_pos_cam_y'])
            transformations.append({'name': 'AltAz',
                                    'args': {'alt_col_name': 'src_pos_cam_x',
                                             'az_col_name': 'src_pos_cam_y',
                                             'deg2rad': False}})
    elif 'direction' in tasks:
        event_info.extend(['true_alt', 'true_az'])
        transformations.append({'name': 'DeltaAltAz_fix_subarray'})
    if 'particletype' in tasks:
        event_info.append('true_shower_primary_id')
        transformations.append({'name': 'ShowerPrimaryID'})
    if 'energy' in tasks:
        event_info.append('true_energy')
        transformations.append({'name': 'MCEnergy'})

    concat_telescopes = config.get('Input', {}).get('concat_telescopes', False)
    if data_config['mode'] == 'stereo' and not concat_telescopes:
        for tel_desc in selected_telescope_types:
            transformations.append({'name': 'SortTelescopes',
                                    'args': {'sorting': 'size', 'tel_desc': f'{tel_desc}'}})

    # Convert interpolation image shapes from lists (as parsed from YAML) to tuples.
    mapping_settings = data_config.get('mapping_settings', {})
    if 'interpolation_image_shape' in mapping_settings:
        mapping_settings['interpolation_image_shape'] = {
            cam: tuple(shape)
            for cam, shape in mapping_settings['interpolation_image_shape'].items()}

    if allow_overwrite:
        data_config['event_info'] = event_info
        # mapping_settings is optional in the config; create it rather than raise KeyError.
        data_config.setdefault('mapping_settings', {})['camera_types'] = camera_types
    else:
        # Use the user-supplied transforms verbatim instead of the derived ones.
        transformations = data_config.get('transforms', [])

    transforms = []
    for spec in transformations:
        spec = {'module': 'dl1_data_handler.transforms', **spec}
        transform_cls, transform_args = load_from_module(**spec)
        transforms.append(transform_cls(**transform_args))
    data_config['transforms'] = transforms

    # When predicting, also load identifiers so they can be written alongside predictions.
    if mode == 'predict':
        prediction = config.setdefault('Prediction', {})
        if prediction.get('save_identifiers', False):
            data_config.setdefault('event_info', []).extend(['event_id', 'obs_id'])
            if data_config['mode'] == 'mono':
                data_config.setdefault('array_info', []).append('id')

    return data_config, data_format
def load_from_module(name, module, path=None, args=None):
    """Import *module* and fetch its attribute *name*.

    Args:
        name: Attribute to look up on the imported module.
        module: Dotted module name passed to ``importlib.import_module``.
        path: Optional directory appended to ``sys.path`` (once) before
            importing.
        args: Optional parameters returned alongside the attribute.

    Returns:
        Tuple ``(attribute, params)``. ``params`` is *args* itself, or a new
        empty dict when *args* is None.
    """
    extend_search_path = path is not None and path not in sys.path
    if extend_search_path:
        sys.path.append(path)
    target = getattr(importlib.import_module(module), name)
    if args is None:
        args = {}
    return target, args