Source code for GCR.composite

"""
Composite catalog reader
"""
import warnings
from collections import defaultdict
from collections.abc import Mapping
from itertools import repeat

import numpy as np

from .base import BaseGenericCatalog

__all__ = ['CompositeSpecs', 'CompositeCatalog', 'MATCHING_FORMAT', 'MATCHING_ORDER']


# For backward compatibility: define module constants to be used in matching_columns
MATCHING_FORMAT = None
MATCHING_ORDER = tuple()


def _match(index_this, index_main, sorter=None):
    if sorter is None:
        sorter = np.argsort(index_this)
    s = np.searchsorted(index_this, index_main, sorter=sorter)
    s[s >= len(sorter)] = -1
    matching_idx = sorter[s]
    not_matched_mask = (index_this[matching_idx] != index_main)
    return matching_idx, not_matched_mask, sorter


def _slice_matched(input_data_dict, quantity_needed, matching_idx, not_matched_mask,
                   always_return_masked_array=False):
    for q in quantity_needed:
        data_this = input_data_dict[q][matching_idx]
        if always_return_masked_array or not_matched_mask.any():
            data_this = np.ma.array(data_this, mask=not_matched_mask)
        yield q, data_this


class CompositeSpecs(object):
    """
    CompositeSpecs class

    Parameters
    ----------
    instance : instance of GCR.BaseGenericCatalog
    identifier : str, name of this catalog
    matching_partition : bool, optional (default: True)
        Whether this catalog has matching partitions as the main catalog
    matching_row_order : bool, optional (default: True)
        Whether this catalog has the same row ordering as the main catalog
    matching_by_column : str, optional (default: None)
        The column in this catalog to be used to match to the main catalog.
        If set, matching_row_order is assumed to be False.
    matching_column_in_main : str, optional (default: None)
        The column in the main catalog to be used to match to this catalog.
    overwrite_quantities : bool (default: True)
        Whether this catalog should overwrite existing columns in the main catalog
    overwrite_attributes : bool (default: True)
        Whether this catalog should overwrite existing catalog attributes in the main catalog
    include_native_quantities : bool (default: True)
        Whether native quantities should be included
    is_main : bool (default: False)
        Whether this catalog is the main catalog
    """
    def __init__(
        self,
        instance,
        identifier=None,
        matching_partition=True,
        matching_row_order=True,
        matching_by_column=None,
        matching_column_in_main=None,
        overwrite_quantities=True,
        overwrite_attributes=True,
        include_native_quantities=True,
        is_main=False,
        **kwargs
    ):
        self.instance = instance
        self.identifier = str(identifier or "")

        self.matching_partition = bool(matching_partition)
        self.matching_row_order = bool(matching_row_order)
        self.matching_by_column = self.matching_column_in_main = None
        if matching_by_column:
            self.set_matching_column(matching_by_column, matching_column_in_main, matching_partition)

        # For backward compatibility
        if "matching_method" in kwargs:
            matching_method = kwargs["matching_method"]
            if matching_method == MATCHING_FORMAT or matching_method == "MATCHING_FORMAT":
                self.set_matching_format()
            elif matching_method == MATCHING_ORDER or matching_method == "MATCHING_ORDER":
                self.set_matching_order()
            else:
                self.set_matching_column(matching_method)

        self.overwrite_quantities = bool(overwrite_quantities)
        self.overwrite_attributes = bool(overwrite_attributes)
        self.include_native_quantities = bool(include_native_quantities)

        self._is_main = None
        self.is_main = is_main

        self.other_kwargs = kwargs.copy()

        self.cache = None
        self.sorter = None
        self.counter = None

    def clear(self):
        """
        clear cache
        """
        self.cache = None
        self.counter = None

    @property
    def is_main(self):
        return self._is_main

    @is_main.setter
    def is_main(self, value):
        self._is_main = bool(value)
        if self._is_main:
            self.set_matching_format()
            self.overwrite_quantities = True
            self.overwrite_attributes = True

    def set_matching_format(self):
        self.matching_partition = True
        self.matching_row_order = True
        self.matching_by_column = None
        self.matching_column_in_main = None

    def set_matching_order(self):
        self.matching_partition = False
        self.matching_row_order = True
        self.matching_by_column = None
        self.matching_column_in_main = None

    def set_matching_column(self, column, column_in_main=None, same_partition=False):
        self.matching_partition = bool(same_partition)
        self.matching_row_order = False
        self.matching_by_column = column
        self.matching_column_in_main = column_in_main if column_in_main else column

    @property
    def is_valid_matching(self):
        return bool(self.matching_row_order or self.matching_by_column)

    def get_data_iterator(self, native_filters=None):
        self.clear()
        if self.matching_partition:
            return self.instance._iter_native_dataset(native_filters)
        return repeat(None)


[docs]class CompositeCatalog(BaseGenericCatalog): """ Composite Catalog class Parameters ---------- catalog_instances : A list of CompositeSpecs instance. Alternatively, a list of dictionaries that can be used to init CompositeSpecs. See CompositeSpecs docstring for details. always_return_masked_array : bool, optional (default: False) If set to True, always return masked array unless the catalogs have matching formats. Example ------- >>> cat0.list_all_quantities() ['a', 'b', 'c', 'd'] >>> cat1.list_all_quantities() ['a', 'b', 'e'] >>> cat2.list_all_quantities() ['a', 'c', 'f'] >>> cc = CompositeCatalog([cat0, cat1, cat2]) >>> cc.get_quantity_modifier('a') ('_2', 'a') >>> cc.get_quantity_modifier('b') ('_1', 'b') >>> cc.get_quantity_modifier('c') ('_2', 'c') >>> cc.get_quantity_modifier('d') ('_0', 'd') >>> cc.get_quantity_modifier('e') ('_1', 'e') >>> cc.get_quantity_modifier('f') ('_2', 'f') """ def __init__( self, catalog_instances, catalog_identifiers=None, matching_methods=None, only_use_master_attr=None, always_return_masked_array=False, **kwargs ): warnings.warn('CompositeCatalog is still an experimental feature. Use with care!') # check number of catalogs if len(catalog_instances) < 2: raise ValueError('Must have more than one catalogs to make a composite catalog!') self._catalogs = list() for i, instance in enumerate(catalog_instances): if isinstance(instance, CompositeSpecs): cat = instance elif isinstance(instance, Mapping): cat = CompositeSpecs(**instance) else: warnings.warn("Please supply a list of CompositeSpecs instances", DeprecationWarning) try: identifier = catalog_identifiers[i] except (TypeError, KeyError, IndexError): identifier = None try: matching_method = matching_methods[i] except (TypeError, KeyError, IndexError): matching_method = MATCHING_FORMAT cat = CompositeSpecs(instance, identifier, matching_method=matching_method) self._catalogs.append(cat) # check uniqueness of main catalogs main_flags = [cat.is_main for cat in self._catalogs] if sum(main_flags) > 1: raise ValueError('There can be only one main catalog!') elif not any(main_flags): self._main = self._catalogs[0] self._main.is_main = True else: # make sure main catalog is the first catalog self._main = self._catalogs.pop(main_flags.index(True)) self._catalogs.insert(0, self._main) self._catalogs = tuple(self._catalogs) # check uniqueness of identifiers identifiers = [cat.identifier for cat in self._catalogs] for i, cat in enumerate(self._catalogs): if identifiers.index(cat.identifier) != i or not cat.identifier: cat.identifier = cat.identifier + "_{}".format(i) # check all catalogs have valid matching method if not all(cat.is_valid_matching for cat in self._catalogs): raise ValueError("Not all catalogs have valid matching method!") # backward compatibility: matching_column_in_main was used to set in main catalog's matching_method main_matching_col = self._main.other_kwargs.get("matching_method") if main_matching_col and self._main.instance.has_quantity(main_matching_col): for cat in self._catalogs: if ( cat.matching_column_in_main and cat.matching_column_in_main == cat.other_kwargs.get("matching_method") ): cat.matching_column_in_main = main_matching_col self._native_filter_quantities = set(self.main.native_filter_quantities) self.native_filter_string_only = self.main.native_filter_string_only self._native_quantities = set() self._quantity_modifiers = dict() for cat in self._catalogs: for q in cat.instance.list_all_quantities(cat.include_native_quantities): key = (cat.identifier, q) self._native_quantities.add(key) if cat.overwrite_quantities or q not in self._quantity_modifiers: self._quantity_modifiers[q] = key if only_use_master_attr: for cat in self._catalogs[1:]: cat.overwrite_attributes = False self.always_return_masked_array = bool(always_return_masked_array) super(CompositeCatalog, self).__init__(**kwargs) def _subclass_init(self, **kwargs): pass def _generate_native_quantity_list(self): return self._native_quantities def _obtain_native_data_dict(self, native_quantities_needed, native_quantity_getter): native_quantities_needed_dict = defaultdict(set) for identifier, quantity in native_quantities_needed: native_quantities_needed_dict[identifier].add(quantity) order_matching_dummy_col = None for cat_id in list(native_quantities_needed_dict): cat = self._get_catalog_by_id(cat_id) if cat.is_main: continue # add matching columns if cat.matching_by_column: native_quantities_needed_dict[cat_id].add(cat.matching_by_column) native_quantities_needed_dict[self._main.identifier].add(cat.matching_column_in_main) # set order_matching_dummy_col if needed elif order_matching_dummy_col is None and cat.matching_row_order and not cat.matching_partition: if native_quantities_needed_dict[self._main.identifier]: order_matching_dummy_col = list(native_quantities_needed_dict[self._main.identifier]).pop() else: order_matching_dummy_col = list(self.master.list_all_quantities(True)).pop() native_quantities_needed_dict[self._main.identifier].add(order_matching_dummy_col) data = dict() for cat in self._catalogs: if cat.identifier not in native_quantities_needed_dict: continue quantities_needed = native_quantities_needed_dict[cat.identifier] getter = native_quantity_getter[cat.identifier] # load data if cat.is_main or cat.matching_partition: cat.cache = cat.instance._load_quantities(quantities_needed, getter) elif cat.cache is None: cat.cache = cat.instance.get_quantities(quantities_needed) # for main catalog or catalog has exact matching format, import data and do nothing else: if cat.is_main or (cat.matching_partition and cat.matching_row_order): for q, v in cat.cache.items(): data[(cat.identifier, q)] = v cat.cache = None continue # match column if needed: if cat.matching_by_column: matching_idx, not_matched_mask, cat.sorter = _match( cat.cache[cat.matching_by_column], data[(self._main.identifier, cat.matching_column_in_main)], cat.sorter, ) for q, v in _slice_matched( cat.cache, quantities_needed, matching_idx, not_matched_mask, self.always_return_masked_array ): data[(cat.identifier, q)] = v # slice rows if needed if order_matching_dummy_col is not None: if not cat.counter: cat.counter = 0 count = len(data[(self._main.identifier, order_matching_dummy_col)]) slice_this = slice(cat.counter, cat.counter + count) cat.counter += count for q in quantities_needed: data[(cat.identifier, q)] = cat.cache[q][slice_this] if cat.matching_partition: cat.cache = cat.sorter = cat.counter = None return data def _iter_native_dataset(self, native_filters=None): identifiers = tuple((cat.identifier for cat in self._catalogs)) for getters in zip(*(cat.get_data_iterator(native_filters) for cat in self._catalogs)): yield dict(zip(identifiers, getters)) def __getattr__(self, name): attr_found = False attr = None for cat in reversed(self._catalogs): if hasattr(cat.instance, name): if cat.overwrite_attributes: return getattr(cat.instance, name) else: attr = getattr(cat.instance, name) attr_found = True else: if attr_found: return attr return getattr(self.main, name) def _get_quantity_info_dict(self, quantity, default=None): native_q = self._quantity_modifiers.get(quantity) if isinstance(native_q, tuple) and len(native_q) == 2: cat_id, q = native_q try: cat = self._get_catalog_by_id(cat_id) except KeyError: pass else: return cat.instance.get_quantity_info(q, default=default) return default def _get_catalog_by_id(self, identifier): for cat in self._catalogs: if cat.identifier == identifier: return cat raise KeyError(identifier, "not exist!") @property def catalogs(self): return tuple((cat.instance for cat in self._catalogs)) @property def main(self): return self.catalogs[0] @property def master(self): # backward compatibility return self.main def __len__(self): return len(self.main)