Source code for ctapipe.core.traits

"""
Traitlet implementations for ctapipe
"""
import os
import pathlib
from urllib.parse import urlparse

import traitlets
import traitlets.config
from astropy.time import Time
from traitlets import Undefined

from ctapipe.core.plugins import detect_and_import_plugins

from .component import Component, non_abstract_children
from .telescope_component import TelescopeParameter

# Public API of this module, i.e. the names exported by
# ``from ctapipe.core.traits import *``.
# NOTE(review): ``ComponentName``, ``ComponentNameList``, ``UseEnum`` and
# ``TraitType`` are defined at module level below but are not listed here --
# confirm whether that omission is intentional.
__all__ = [
    # Implemented here
    "AstroTime",
    "BoolTelescopeParameter",
    "IntTelescopeParameter",
    "FloatTelescopeParameter",
    "classes_with_traits",
    "create_class_enum_trait",
    "has_traits",
    # imported from traitlets
    "Path",  # NOTE: defined in this module, not a plain traitlets alias
    "Bool",
    "CRegExp",
    "CaselessStrEnum",
    "CInt",
    "Dict",
    "Enum",
    "Float",
    "Int",
    "Integer",
    "List",
    "Long",
    "Set",
    "TraitError",
    "Tuple",
    "Unicode",
    "flag",
    "observe",
]

import logging

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)


# Aliases: re-export commonly used traitlets names from this module so that
# they can be imported from here instead of from ``traitlets`` directly.
Bool = traitlets.Bool
Int = traitlets.Int
CInt = traitlets.CInt
Integer = traitlets.Integer
Float = traitlets.Float
Long = traitlets.Long
Unicode = traitlets.Unicode
Dict = traitlets.Dict
Enum = traitlets.Enum
List = traitlets.List
Set = traitlets.Set
CRegExp = traitlets.CRegExp
CaselessStrEnum = traitlets.CaselessStrEnum
UseEnum = traitlets.UseEnum
TraitError = traitlets.TraitError
TraitType = traitlets.TraitType
Tuple = traitlets.Tuple
observe = traitlets.observe
# ``flag`` is the ``boolean_flag`` helper from ``traitlets.config``.
flag = traitlets.config.boolean_flag


class AstroTime(TraitType):
    """A trait representing a point in Time, as understood by `astropy.time`"""

    def validate(self, obj, value):
        """Convert ``value`` into a ``Time`` instance using the ISO format.

        Parameters
        ----------
        obj : object
            Object owning the trait; only forwarded to ``self.error``.
        value : object
            Anything the ``astropy.time.Time`` constructor accepts, e.g. an
            ISO time string or an existing ``Time`` instance.

        Returns
        -------
        astropy.time.Time
            The parsed time with its ``format`` set to ``"iso"``.

        Notes
        -----
        A ``ValueError`` raised while building the ``Time`` object is
        reported through ``self.error(obj, value)``.
        """
        try:
            the_time = Time(value)
            the_time.format = "iso"
            return the_time
        except ValueError:
            return self.error(obj, value)

    def info(self):
        """Return the human-readable description of accepted values."""
        info = "an ISO8601 datestring or Time instance"
        if self.allow_none:
            # leading space is required, otherwise the text reads
            # "... Time instanceor None"
            info += " or None"
        return info


class Path(TraitType):
    """
    A path Trait for input/output files.

    String values may also be ``http(s)://`` URLs, ``dataset://`` URLs or
    ``file://`` URLs; see `validate` for how each of them is resolved.

    Attributes
    ----------
    exists: boolean or None
        If True, path must exist, if False path must not exist
    directory_ok: boolean
        If False, path must not be a directory
    file_ok: boolean
        If False, path must not be a file
    """

    def __init__(
        self,
        default_value=Undefined,
        exists=None,
        directory_ok=True,
        file_ok=True,
        **kwargs,
    ):
        """Create the trait.

        Parameters
        ----------
        default_value : object
            Default value, forwarded to the parent constructor.
        exists : bool or None
            Existence requirement checked in `validate`; ``None`` disables
            the check.
        directory_ok : bool
            Whether an existing directory is accepted.
        file_ok : bool
            Whether an existing file is accepted.
        **kwargs
            Forwarded unchanged to the parent constructor.
        """
        super().__init__(default_value=default_value, **kwargs)
        self.exists = exists
        self.directory_ok = directory_ok
        self.file_ok = file_ok

    def info(self):
        """Return the human-readable description of accepted values."""
        info = "a pathlib.Path or non-empty str for "
        if self.exists is True:
            info += "an existing"
        elif self.exists is False:
            info += "a not existing"
        else:
            info += "a"

        if self.directory_ok and self.file_ok:
            info += " directory or file"
        else:
            if self.file_ok:
                info += " file"
            if self.directory_ok:
                # leading space is required, otherwise the text reads
                # "... for adirectory"
                info += " directory"
        if self.allow_none:
            info += " or None"

        return info

    def validate(self, obj, value):
        """Validate ``value`` and turn it into an absolute path.

        Parameters
        ----------
        obj : object
            Object owning the trait; only forwarded to ``self.error``.
        value : bytes, str, pathlib.Path, None or Undefined
            The value to validate.

        Returns
        -------
        pathlib.Path, None or Undefined
            The absolute path, or ``value`` itself if it is ``None`` /
            ``Undefined`` and ``allow_none`` is set.

        Raises
        ------
        TraitError
            Raised directly if the existence, file or directory constraints
            are violated. All other invalid values are reported through
            ``self.error``.
        """
        if isinstance(value, bytes):
            value = os.fsdecode(value)

        if value is None or value is Undefined:
            if self.allow_none:
                return value
            return self.error(obj, value)

        if not isinstance(value, (str, pathlib.Path)):
            return self.error(obj, value)

        # expand any environment variables in the path:
        value = os.path.expandvars(value)

        if isinstance(value, str):
            if value == "":
                return self.error(obj, value)

            try:
                url = urlparse(value)
            except ValueError:
                return self.error(obj, value)

            if url.scheme in ("http", "https"):
                # here to avoid circular import, since every module imports
                # from ctapipe.core
                from ctapipe.utils.download import download_cached

                # presumably returns the local path of the downloaded file;
                # the result must provide ``absolute()`` (used below)
                value = download_cached(value, progress=True)
            elif url.scheme == "dataset":
                # here to avoid circular import, since every module imports
                # from ctapipe.core
                from ctapipe.utils import get_dataset_path

                # everything after the "dataset://" prefix is the dataset name
                value = get_dataset_path(value.partition("dataset://")[2])
            elif url.scheme in ("", "file"):
                # plain paths and file:// URLs
                value = pathlib.Path(url.netloc, url.path)
            else:
                # any other URL scheme is not supported
                return self.error(obj, value)

        value = value.absolute()
        exists = value.exists()
        if self.exists is not None:
            if exists != self.exists:
                raise TraitError(
                    'Path "{}" {} exist'.format(
                        value, "does not" if self.exists else "must not"
                    )
                )
        if exists:
            if not self.directory_ok and value.is_dir():
                raise TraitError(f'Path "{value}" must not be a directory')
            if not self.file_ok and value.is_file():
                raise TraitError(f'Path "{value}" must not be a file')

        return value


def create_class_enum_trait(base_class, default_value, help=None, allow_none=False):
    """Build a configurable ``CaselessStrEnum`` trait for a class hierarchy.

    The allowed values are the class names of everything returned by
    ``non_abstract_children(base_class)``.

    Parameters
    ----------
    base_class : type
        Class whose non-abstract children define the allowed names.
    default_value : str
        Default choice; has to be one of the allowed names.
    help : str or None
        Help text of the trait. If ``None``, ``"<base class name> to use."``
        is used.
    allow_none : bool
        Forwarded to ``CaselessStrEnum``.

    Returns
    -------
    The ``CaselessStrEnum`` trait, tagged with ``config=True``.

    Raises
    ------
    ValueError
        If ``default_value`` is not one of the allowed names.
    """
    help_text = f"{base_class.__name__} to use." if help is None else help

    valid_names = [child.__name__ for child in non_abstract_children(base_class)]
    if default_value not in valid_names:
        raise ValueError(f"{default_value} is not in choices: {valid_names}")

    enum_trait = CaselessStrEnum(
        valid_names,
        default_value=default_value,
        help=help_text,
        allow_none=allow_none,
    )
    return enum_trait.tag(config=True)


class ComponentName(Unicode):
    """Trait holding the name of a non-abstract subclass of a Component class.

    The accepted names are looked up via ``cls.non_abstract_subclasses()``
    each time they are needed, not when the trait is defined.
    """

    def __init__(self, cls, **kwargs):
        """Create the trait for subclasses of ``cls``.

        Raises ``TypeError`` if ``cls`` is not a subclass of ``Component``.
        Remaining keyword arguments go to the parent constructor.
        """
        # Looking up the subclasses may trigger the plugin import machinery,
        # which must not happen at import time (circular imports). While
        # this flag is False, `help` and `info_text` do not query
        # the subclasses.
        self._init_done = False

        if not issubclass(cls, Component):
            raise TypeError(f"cls must be a Component, got {cls}")

        self.cls = cls
        super().__init__(**kwargs)
        if "help" not in kwargs:
            self.help = f"The name of a {cls.__name__} subclass"
        self._init_done = True

    @property
    def help(self):
        """Help text followed by the currently known possible values."""
        names = list(self.cls.non_abstract_subclasses()) if self._init_done else []
        return f"{self._help}. Possible values: {names}"

    @help.setter
    def help(self, value):
        self._help = value

    @property
    def info_text(self):
        """Short description of the accepted values."""
        if not self._init_done:
            return f"Any subclass of {self.cls}"
        return f"Any of {list(self.cls.non_abstract_subclasses())}"

    def validate(self, obj, value):
        """Return ``value`` if it is a known subclass name (or an allowed
        ``None``); otherwise report it through ``self.error``."""
        if value is None and self.allow_none:
            return None

        if value not in self.cls.non_abstract_subclasses():
            return self.error(obj, value)

        return value


class ComponentNameList(List):
    """Trait holding a list of names of Component subclasses.

    Every element of the list is validated by a ``ComponentName`` trait
    created for the same base class.
    """

    def __init__(self, cls, **kwargs):
        """Create the trait for subclasses of ``cls``.

        Raises ``TypeError`` if ``cls`` is not a subclass of ``Component``.
        Remaining keyword arguments go to the parent constructor.
        """
        # Looking up the subclasses may trigger the plugin import machinery,
        # which must not happen at import time (circular imports). While
        # this flag is False, `help` and `info_text` do not query
        # the subclasses.
        self._init_done = False
        if not issubclass(cls, Component):
            raise TypeError(f"cls must be a Component, got {cls}")

        self.cls = cls
        element_trait = ComponentName(cls)
        super().__init__(trait=element_trait, **kwargs)

        if "help" not in kwargs:
            self.help = f"A list of {cls.__name__} subclass names"
        self._init_done = True

    @property
    def help(self):
        """Help text followed by the currently known possible values."""
        names = list(self.cls.non_abstract_subclasses()) if self._init_done else []
        return f"{self._help}. Possible values: {names}"

    @help.setter
    def help(self, value):
        self._help = value

    @property
    def info_text(self):
        """Short description of the accepted values."""
        if not self._init_done:
            return f"A list of {self.cls} subclasses"
        return f"A list of {list(self.cls.non_abstract_subclasses())}"


def classes_with_traits(base_class):
    """Collect ``base_class`` and its non-abstract children that have traits.

    If ``base_class`` defines ``plugin_entry_point``, plugins for that entry
    point are detected and imported first. For every class that has a
    ``classes`` attribute, the listed sub-components are processed
    recursively and their results are appended as well.

    Parameters
    ----------
    base_class : type
        Class at the root of the hierarchy to inspect.

    Returns
    -------
    list
        Classes for which ``has_traits`` is true, in discovery order. The
        list is not de-duplicated.
    """

    if hasattr(base_class, "plugin_entry_point"):
        detect_and_import_plugins(base_class.plugin_entry_point)

    all_classes = [base_class] + non_abstract_children(base_class)
    with_traits = []

    for cls in all_classes:
        if has_traits(cls):
            with_traits.append(cls)

        # add subcomponents
        if not hasattr(cls, "classes"):
            continue

        # ``classes`` may be a List trait (use its default value) or a
        # plain iterable of classes
        if isinstance(cls.classes, List):
            classes = cls.classes.default()
        else:
            classes = cls.classes

        # Best effort: a broken ``classes`` entry must not break anyone, so
        # failures are ignored -- but logged, so that they can be diagnosed.
        try:
            for component in classes:
                with_traits.extend(classes_with_traits(component))
        except Exception:
            logger.debug(
                "Ignoring error while collecting sub-components of %s",
                cls,
                exc_info=True,
            )

    return with_traits


def has_traits(cls, ignore=("config", "parent")):
    """Tell whether ``cls`` declares any trait besides the ignored ones.

    All our components carry at least the ``config`` and ``parent`` traits
    inherited from `traitlets.config.Configurable`, which is why these two
    are ignored by default.

    Parameters
    ----------
    cls : type
        Class providing ``class_trait_names()``.
    ignore : iterable of str
        Trait names that do not count.

    Returns
    -------
    bool
        True if at least one trait name of ``cls`` is not in ``ignore``.
    """
    trait_names = set(cls.class_trait_names())
    # there are "real" traits exactly when not every name is ignored
    return not trait_names.issubset(ignore)


class FloatTelescopeParameter(TelescopeParameter):
    """a `~ctapipe.core.telescope_component.TelescopeParameter` with Float trait type

    The ``trait`` argument of the parent class is fixed to a new ``Float()``
    instance.
    """

    def __init__(self, **kwargs):
        """Create a new FloatTelescopeParameter

        All keyword arguments are forwarded unchanged to the
        ``TelescopeParameter`` constructor.
        """
        super().__init__(trait=Float(), **kwargs)


class IntTelescopeParameter(TelescopeParameter):
    """a `~ctapipe.core.telescope_component.TelescopeParameter` with Int trait type

    The ``trait`` argument of the parent class is fixed to a new
    ``Integer()`` instance.
    """

    def __init__(self, **kwargs):
        """Create a new IntTelescopeParameter

        All keyword arguments are forwarded unchanged to the
        ``TelescopeParameter`` constructor.
        """
        super().__init__(trait=Integer(), **kwargs)


class BoolTelescopeParameter(TelescopeParameter):
    """a `~ctapipe.core.telescope_component.TelescopeParameter` with Bool trait type

    The ``trait`` argument of the parent class is fixed to a new ``Bool()``
    instance.
    """

    def __init__(self, **kwargs):
        """Create a new BoolTelescopeParameter

        All keyword arguments are forwarded unchanged to the
        ``TelescopeParameter`` constructor.
        """
        super().__init__(trait=Bool(), **kwargs)