diff --git a/quickstats/_version.py b/quickstats/_version.py index 4338f6cbaa8a678be12f967dad33fa8a23001a10..463ac227e347654323f29463f1728eddd3061dca 100644 --- a/quickstats/_version.py +++ b/quickstats/_version.py @@ -1 +1 @@ -__version__ = "0.7.0.8" +__version__ = "0.8.0.0" diff --git a/quickstats/algorithms/bump_hunt/bump_hunt_1d.py b/quickstats/algorithms/bump_hunt/bump_hunt_1d.py index 25d042e547007f132a37478088a432ef1e15a300..192a9c5873c0cf7baa9947eeb1d28ffe5c567117 100644 --- a/quickstats/algorithms/bump_hunt/bump_hunt_1d.py +++ b/quickstats/algorithms/bump_hunt/bump_hunt_1d.py @@ -19,7 +19,7 @@ from quickstats.concepts import Binning, Histogram1D from quickstats.interface.pydantic import DefaultModel from quickstats.maths.numerics import sliced_sum from quickstats.utils.common_utils import execute_multi_tasks, combine_dict -from quickstats.utils.string_utils import format_dict_to_string +from quickstats.utils.string_utils import format_aligned_dict from .settings import BumpHuntMode, SignalStrengthScale, AutoScanStep __all__ = ['BumpHunt1D'] @@ -301,7 +301,7 @@ class BumpHuntOutput1D: 'Number of signals': output['nsignal'], 'Local p-value': output['local_pval'] } - result += format_dict_to_string(table, left_margin=4) + '\n' + result += format_aligned_dict(table, left_margin=4) + '\n' # Combined result result += f'Combined\n\n' @@ -320,7 +320,7 @@ class BumpHuntOutput1D: table['Global significance'] = f'> {summary["global_significance"]} (lower limit)' else: table['Global significance'] = summary['global_significance'] - result += format_dict_to_string(table, left_margin=4) + '\n' + result += format_aligned_dict(table, left_margin=4) + '\n' return result diff --git a/quickstats/components/modelling/data_modelling.py b/quickstats/components/modelling/data_modelling.py index 7affbc376ebf276e6f49ff9875661883964e62f7..f0416195835526118febab57afe013c4f9508fba 100644 --- a/quickstats/components/modelling/data_modelling.py +++ b/quickstats/components/modelling/data_modelling.py @@ -229,7 +229,10 @@ class DataModelling(ROOTObject): } if isinstance(data, np.ndarray): from quickstats.components.modelling import ArrayDataSource - data_source = ArrayDataSource(data, weights=weights, **kwargs) + data_source = ArrayDataSource(data, weights=weights, + observable_name=self.observable_name, + weight_name=self.weight_name, + **kwargs) elif isinstance(data, ROOT.RooDataSet): from quickstats.components.modelling import RooDataSetDataSource data_source = RooDataSetDataSource(data, **kwargs) @@ -342,8 +345,8 @@ class DataModelling(ROOTObject): return summary_text def create_plot(self, data: Union[np.ndarray, "ROOT.RooDataSet", "ROOT.TTree", DataSource], - weights: Optional[np.ndarray]=None, - saveas:Optional[str]=None): + weights: Optional[np.ndarray] = None, + saveas: Optional[str] = None): if not self.result: raise RuntimeError("No results to plot") data_source = self.create_data_source(data, weights=weights) diff --git a/quickstats/components/workspaces/ws_comparer.py b/quickstats/components/workspaces/ws_comparer.py index dd7ba37797683dbb280c86bf9549764be115373b..adcbb2ead4f4e602f09972e06dac8f0e56528a03 100644 --- a/quickstats/components/workspaces/ws_comparer.py +++ b/quickstats/components/workspaces/ws_comparer.py @@ -13,7 +13,7 @@ from ROOT.RooPrintable import (kName, kClassName, kValue, kArgs, kExtras, kAddre import quickstats from quickstats import semistaticmethod, timer, AbstractObject, GeneralEnum from quickstats.components import ExtendedModel -from quickstats.core.io import 
get_colored_text, format_comparison_text +from quickstats.core.io import TextColors from quickstats.utils.root_utils import load_macro, get_macro_dir from quickstats.maths.numerics import is_float, pretty_value from quickstats.components.basics import WSArgument @@ -367,12 +367,12 @@ class ComparisonData: size = len(contents) if size > 0: s = f"{indent}[{title} ({size})]\n" - summary_str += get_colored_text(s, title_color) + summary_str += TextColors.colorize(s, title_color) if not show_content: return summary_str for content in contents: s = f"{indent*2}{content}\n" - summary_str += get_colored_text(s, content_color) + summary_str += TextColors.colorize(s, content_color) return summary_str def _get_mapped_content_summary_str(self, contents_1:List[str], contents_2:List[str], @@ -389,12 +389,12 @@ class ComparisonData: raise ValueError("content_1 and content_2 must have the same size") if size_1 > 0: s = f"{indent}[{title} ({size_1})]\n" - summary_str += get_colored_text(s, title_color) + summary_str += TextColors.colorize(s, title_color) if not show_content: return summary_str for (content_1, content_2) in zip(contents_1, contents_2): - s_left, s_right = format_comparison_text(content_1, content_2, - equal_color, delete_color, insert_color) + s_left, s_right = TextColors.format_comparison(content_1, content_2, + equal_color, delete_color, insert_color) summary_str += f"{indent*2}{s_left} -> {s_right}\n" return summary_str @@ -468,7 +468,7 @@ class ComparisonData: if ext_ref is None: ext_ref = {} s = f"{self.title}:\n" - summary_str = get_colored_text(s, "bright magenta") + summary_str = TextColors.colorize(s, "bright magenta") if self.support_definition: df = self.get_df("identical") content_kwargs = { diff --git a/quickstats/concepts/__init__.py b/quickstats/concepts/__init__.py index 0fad8b635846cb7736354dbff7cf04c0e02b64b8..dd7c02954ee5bd3f977103c0c558afa172936832 100644 --- a/quickstats/concepts/__init__.py +++ b/quickstats/concepts/__init__.py @@ -1,3 +1,4 @@ from .version import Version from .binning import Binning -from .histogram1d import Histogram1D \ No newline at end of file +from .histogram1d import Histogram1D +from .stacked_histogram import StackedHistogram \ No newline at end of file diff --git a/quickstats/concepts/binning.py b/quickstats/concepts/binning.py index eee0e4da089bbc77977e19fd2f2f6fe7e37ab9d7..648194a8ab2016b81ffda488de07cc174ba30182 100644 --- a/quickstats/concepts/binning.py +++ b/quickstats/concepts/binning.py @@ -1,168 +1,328 @@ -from typing import Optional, Union, Tuple +""" +Enhanced binning utilities for numerical data analysis. + +This module provides a flexible class for defining and manipulating binning information, +supporting both uniform and non-uniform bins with comprehensive validation and error handling. 
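+
+A minimal usage sketch (a sketch only; the printed values assume default
+numpy print options)::
+
+    >>> from quickstats.concepts import Binning
+    >>> binning = Binning(4, (0.0, 1.0))
+    >>> binning.bin_edges
+    array([0.  , 0.25, 0.5 , 0.75, 1.  ])
+    >>> binning.nbins
+    4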
+""" + +from __future__ import annotations + +from typing import ( + Optional, Union, Tuple, ClassVar, + overload, cast +) +from dataclasses import dataclass +from copy import deepcopy import numpy as np +from numpy.typing import ArrayLike, NDArray -from quickstats.core.typing import ArrayLike from quickstats.maths.statistics import bin_edge_to_bin_center +class BinningError(Exception): + """Base exception for binning-related errors.""" + pass + +@dataclass +class BinningConfig: + """Configuration for binning operations.""" + min_bins: int = 1 + min_edges: int = 2 + rtol: float = 1e-05 # Relative tolerance for float comparisons + atol: float = 1e-08 # Absolute tolerance for float comparisons + class Binning: """ - A class for defining binning information. + A class for defining and manipulating binning information. + + This class provides functionality for creating and managing bin definitions, + supporting both uniform and non-uniform binning with comprehensive validation. Parameters ---------- - bins : ArrayLike or int - If ArrayLike, specifies the bin edges directly. - If int, specifies the number of bins, in which case `bin_range` must be provided. - bin_range : Optional[ArrayLike], optional - The range of the bins as a tuple (low, high), required if `bins` is an int. + bins : Union[ArrayLike, int] + Either bin edges array or number of bins + bin_range : Optional[ArrayLike], default None + Range for bin creation when bins is an integer (low, high) - Attributes - ---------- - bin_edges : np.ndarray - The edges of the bins. - bin_centers : np.ndarray - The centers of the bins. - bin_widths : np.ndarray - The widths of the bins. - nbins : int - The number of bins. - - Methods - ------- - bin_edges - Returns the edges of the bins. - bin_centers - Returns the centers of the bins. - bin_widths - Returns the widths of the bins. - nbins - Returns the number of bins. + Raises + ------ + BinningError + If initialization parameters are invalid + ValueError + If input values are out of valid ranges + + Examples + -------- + >>> # Create uniform binning with 10 bins + >>> binning = Binning(10, (0, 1)) + >>> print(binning.nbins) + 10 + + >>> # Create custom binning with specific edges + >>> edges = [0, 1, 2, 4, 8] + >>> binning = Binning(edges) + >>> print(binning.bin_widths) + array([1, 1, 2, 4]) """ - - def __init__(self, bins: Union[ArrayLike, int], bin_range: Optional[ArrayLike] = None): + + # Class-level configuration + config: ClassVar[BinningConfig] = BinningConfig() + + def __init__( + self, + bins: Union[ArrayLike, int], + bin_range: Optional[ArrayLike] = None + ) -> None: + """Initialize binning with edges or number of bins.""" + try: + self._bin_edges = self._init_bin_edges(bins, bin_range) + except Exception as e: + raise BinningError(f"Failed to initialize binning: {str(e)}") from e + + def _init_bin_edges( + self, + bins: Union[ArrayLike, int], + bin_range: Optional[ArrayLike] + ) -> NDArray: + """ + Initialize bin edges with validation. 
+ + Parameters + ---------- + bins : Union[ArrayLike, int] + Bin specification + bin_range : Optional[ArrayLike] + Optional range for bin creation + + Returns + ------- + NDArray + Validated bin edges array + + Raises + ------ + ValueError + If bin specification is invalid + """ if np.ndim(bins) == 1: - if len(bins) < 2: - raise ValueError('Number of bin edges must be greater than 1 to define a binning.') - self._bin_edges = np.array(bins) - elif np.ndim(bins) == 0: - if (not isinstance(bins, int)) or (bins < 1): - raise ValueError('Number of bins must be greater than 0 to define a binning.') - if bin_range is None: - raise ValueError('`bin_range` must be given when `bins` is a number.') - bin_low, bin_high = bin_range - if bin_low > bin_high: - raise ValueError('`bin_range[0]` can not be larger than `bin_range[1]`.') - self._bin_edges = np.linspace(bin_low, bin_high, bins + 1) - else: - raise ValueError('Invalid value for `bins`. It must be either an array representing the bin edges or a number representing the number of bins.') - - def __copy__(self): - """ - Create a copy of the current Binning instance with the same bin edges. + return self._init_from_edges(np.asarray(bins)) + + if np.ndim(bins) == 0: + return self._init_from_count( + cast(int, bins), + bin_range + ) + + raise ValueError( + "Invalid bins parameter. Must be either bin edges array or bin count." + ) + + def _init_from_edges(self, edges: NDArray) -> NDArray: + """Initialize from explicit bin edges.""" + if len(edges) < self.config.min_edges: + raise ValueError( + f"Number of bin edges must be at least {self.config.min_edges}" + ) + + if not np.all(np.diff(edges) > 0): + raise ValueError("Bin edges must be strictly increasing") + + return edges + + def _init_from_count( + self, + nbins: int, + bin_range: Optional[ArrayLike] + ) -> NDArray: + """Initialize from bin count and range.""" + if not isinstance(nbins, int) or nbins < self.config.min_bins: + raise ValueError( + f"Number of bins must be integer >= {self.config.min_bins}" + ) + + if bin_range is None: + raise ValueError("bin_range required when specifying bin count") + + bin_low, bin_high = self._validate_range(bin_range) + return np.linspace(bin_low, bin_high, nbins + 1) + + def _validate_range( + self, + bin_range: ArrayLike + ) -> Tuple[Any, Any]: + """ + Validate and convert bin range. + + Parameters + ---------- + bin_range : ArrayLike + Range specification for bins + + Returns + ------- + Tuple[Any, Any] + Tuple of (low, high) range values with original types preserved + + Raises + ------ + ValueError + If bin range is invalid + """ + try: + bin_range_arr = np.asarray(bin_range) + if bin_range_arr.shape != (2,): + raise ValueError("bin_range must be sequence of length 2") + + bin_low, bin_high = bin_range_arr + if bin_low >= bin_high: + raise ValueError("bin_range[0] must be less than bin_range[1]") + + return bin_low, bin_high + + except Exception as e: + raise ValueError(f"Invalid bin_range: {str(e)}") from e + + def __copy__(self) -> Binning: + """ + Create a shallow copy. Returns ------- Binning - A new Binning instance with the same bin edges as the current instance. + New instance with copy of bin edges """ - # Create a new instance of Binning using the current instance's bin edges new_instance = self.__class__(self._bin_edges.copy()) - return new_instance + return new_instance + + def __deepcopy__(self, memo: dict) -> Binning: + """ + Create a deep copy. 
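+
+        Invoked by :func:`copy.deepcopy`; the bin-edges array is copied
+        through the ``memo`` dictionary.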
+ + Parameters + ---------- + memo : dict + Memo dictionary for deepcopy + + Returns + ------- + Binning + New instance with deep copy of bin edges + """ + new_instance = self.__class__(deepcopy(self._bin_edges, memo)) + return new_instance def __eq__(self, other: object) -> bool: """ - Check if two Binning instances are equal by comparing bin edges. + Check equality with another Binning instance. Parameters ---------- other : object - Another Binning instance to compare against. + Object to compare with Returns ------- bool - True if bin edges have the same shape and are element-wise equal within a tolerance. + True if binnings are equal """ if not isinstance(other, Binning): return NotImplemented - # Check if the shapes of the bin edges are the same - if self.bin_edges.shape != other.bin_edges.shape: - return False - - # Check if the values in the bin edges are close - return np.allclose(self.bin_edges, other.bin_edges) - + return ( + self.bin_edges.shape == other.bin_edges.shape and + np.allclose( + self.bin_edges, + other.bin_edges, + rtol=self.config.rtol, + atol=self.config.atol + ) + ) + + def __repr__(self) -> str: + """Create string representation.""" + return ( + f"{self.__class__.__name__}(" + f"edges=[{self.bin_edges[0]}, ..., {self.bin_edges[-1]}], " + f"nbins={self.nbins})" + ) + @property - def bin_edges(self) -> np.ndarray: + def bin_edges(self) -> NDArray: """ - Returns the edges of the bins. + Get bin edges array. Returns ------- - np.ndarray - The edges of the bins. + NDArray + Array of bin edges """ return self._bin_edges.copy() @property - def bin_centers(self) -> np.ndarray: + def bin_centers(self) -> NDArray: """ - Returns the centers of the bins. + Get bin centers array. Returns ------- - np.ndarray - The centers of the bins. + NDArray + Array of bin centers """ return bin_edge_to_bin_center(self.bin_edges) @property - def bin_widths(self) -> np.ndarray: + def bin_widths(self) -> NDArray: """ - Returns the widths of the bins. + Get bin widths array. Returns ------- - np.ndarray - The widths of the bins. + NDArray + Array of bin widths """ return np.diff(self.bin_edges) @property def nbins(self) -> int: """ - Returns the number of bins. + Get number of bins. Returns ------- int - The number of bins. + Number of bins """ return len(self.bin_edges) - 1 @property - def bin_range(self) -> Tuple[float, float]: + def bin_range(self) -> Tuple[Any, Any]: """ - Returns the bin range. + Get bin range. Returns ------- - (float, float) - The bin range. + Tuple[Any, Any] + (minimum edge, maximum edge) with original types preserved """ return (self.bin_edges[0], self.bin_edges[-1]) - @property def is_uniform(self) -> bool: """ Check if binning is uniform. + A binning is uniform if all bins have the same width within + numerical tolerance. + Returns ------- bool - True if binning is uniform. 
+ True if binning is uniform """ - delta_widths = np.diff(self.bin_widths) - return np.allclose(np.zeros(delta_widths.shape), delta_widths) \ No newline at end of file + widths = self.bin_widths + return np.allclose( + widths, + widths[0], + rtol=self.config.rtol, + atol=self.config.atol + ) \ No newline at end of file diff --git a/quickstats/concepts/histogram1d.py b/quickstats/concepts/histogram1d.py index 6a7573c8a7ad36dfc7f1b2742883b43de480d356..37d9cdcc686d42818b8a7534e57fdc1045bbf515 100644 --- a/quickstats/concepts/histogram1d.py +++ b/quickstats/concepts/histogram1d.py @@ -1,507 +1,1154 @@ -from typing import Optional, Union, Tuple, Any +from __future__ import annotations + +from typing import ( + Optional, Union, Tuple, Any, Callable, Sequence, + cast, TypeVar +) +from numbers import Real import numpy as np from quickstats import stdout -from quickstats.core.typing import ArrayLike, Real -from quickstats.maths.statistics import BinErrorMode, poisson_interval, histogram -from quickstats.maths.numerics import all_integers, safe_div, is_integer +from quickstats.core.typing import ArrayLike +from quickstats.maths.numerics import all_integers, safe_div +from quickstats.maths.histograms import ( + BinErrorMode, + HistComparisonMode, + poisson_interval, + histogram, + get_histogram_mask, +) from .binning import Binning -BinErrorType = Optional[Tuple[np.ndarray, np.ndarray]] +# Type aliases for better type safety +H = TypeVar('H', bound='Histogram1D') +BinErrors = Optional[Tuple[np.ndarray, np.ndarray]] +ComparisonMode = Union[HistComparisonMode, str, Callable[[H, H], H]] class Histogram1D: """ - A class for defining a 1D histogram. - - Parameters - ---------- - bin_content : np.ndarray - The bin content of the histogram. - bin_edges : np.ndarray - The bin edges of the histogram. - bin_errors : ArrayLike, optional - The bin errors of the histogram. If None and a Poisson error mode - is used, the bin errors will be automatically calculated. - error_mode : BinErrorMode or str, default = "auto" - The method with which the bin errors are evaluated. It can - be "sumw2" (symmetric error from Wald approximation), "poisson" - (Poisson interval at one sigma), "auto" (deduce automatically from - bin_content, use "poisson" if bin content is integer type or float - type with zero remainders, and "sumw2" otherwise). Note that this - method only indicates the method with which the bin_errors are - calculated at initialization. Subsequent evaluation could differ - depending on the operations (e.g. summation or division of histograms, - and scaling of histograms). + A class representing a one-dimensional histogram with bin contents, edges, and errors. Attributes ---------- bin_content : np.ndarray - The bin content of the histogram. - bin_errors : (np.ndarray, np.ndarray) - The bin errors of the histogram. + The bin content of the histogram + bin_errors : Optional[Tuple[np.ndarray, np.ndarray]] + The bin errors (lower, upper) if available bin_edges : np.ndarray - The bin edges of the histogram. + The bin edges of the histogram bin_centers : np.ndarray - The bin centers of the histogram. + The bin centers of the histogram bin_widths : np.ndarray - The widths of the bins. - nbins : int - The number of bins. + The widths of the bins + nbins : int + The number of bins error_mode : BinErrorMode - The current error mode of the histogram. 
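+
+    Examples
+    --------
+    >>> import numpy as np
+    >>> # A minimal sketch: integer content is treated as Poisson data
+    >>> h = Histogram1D(np.array([1, 2, 3]), np.array([0., 1., 2., 3.]))
+    >>> h.nbins
+    3
+    >>> h.is_weighted()
+    False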
+ The current error mode """ - - def __init__(self, bin_content: np.ndarray, - bin_edges: np.ndarray, - bin_errors: Optional[ArrayLike] = None, - error_mode: Union[BinErrorMode, str] = "auto") -> None: + + def __init__( + self, + bin_content: np.ndarray, + bin_edges: np.ndarray, + bin_errors: Optional[ArrayLike] = None, + error_mode: Union[BinErrorMode, str] = "auto" + ) -> None: """ - Initialize the Histogram1D object by setting the bin content, edges, errors, and error mode. - """ - self.set_data(bin_content=bin_content, - bin_edges=bin_edges, - bin_errors=bin_errors, - error_mode=error_mode) + Initialize a Histogram1D instance. - def __add__(self, other: "Histogram1D") -> "Histogram1D": - """Perform addition between two histograms.""" - return self._operate('add', other) + Parameters + ---------- + bin_content : np.ndarray + The bin content of the histogram + bin_edges : np.ndarray + The bin edges of the histogram + bin_errors : Optional[ArrayLike], default None + The bin errors of the histogram. Supported formats: + - None: No errors + - scalar: Same error for all bins + - 1D array: Symmetric errors, length must match bins + - 2D array: Asymmetric errors, shape must be (2, nbins) + - Tuple[array, array]: Asymmetric errors, each array length matches bins + If None and a Poisson error mode is used, the bin errors will + be automatically calculated. + error_mode : Union[BinErrorMode, str], default "auto" + The method for error calculation. It can + be "sumw2" (symmetric error from Wald approximation), "poisson" + (Poisson interval at one sigma), or "auto" (deduced from bin content) + + Raises + ------ + ValueError + If bin_content or bin_edges are not 1D arrays + If arrays have incompatible sizes + """ + self.set_data( + bin_content=bin_content, + bin_edges=bin_edges, + bin_errors=bin_errors, + error_mode=error_mode, + ) + + def __add__(self, other: Union[Histogram1D, Real]) -> Histogram1D: + """Add another histogram or scalar value.""" + return self._operate("add", other) - def __iadd__(self, other: "Histogram1D") -> "Histogram1D": - """Perform in-place addition between two histograms.""" - return self._ioperate('add', other) + def __sub__(self, other: Union[Histogram1D, Real]) -> Histogram1D: + """Subtract another histogram or scalar value.""" + return self._operate("sub", other) - def __sub__(self, other: "Histogram1D") -> "Histogram1D": - """Perform subtraction between two histograms.""" - return self._operate('sub', other) + def __mul__(self, other: Union[Real, ArrayLike]) -> Histogram1D: + """Multiply by a scalar or array.""" + return self._operate("scale", other) - def __isub__(self, other: "Histogram1D") -> "Histogram1D": - """Perform in-place subtraction between two histograms.""" - return self._ioperate('sub', other) + def __rmul__(self, other: Union[Real, ArrayLike]) -> Histogram1D: + """Right multiplication by a scalar or array.""" + return self._operate("scale", other) - def __truediv__(self, other: Union[Real, ArrayLike, "Histogram1D"]) -> "Histogram1D": - """Perform division between a histogram and either a scalar or another histogram.""" - instance = self._operate('div', other) + def __truediv__(self, other: Union[Histogram1D, Union[Real, ArrayLike]]) -> Histogram1D: + """Divide by another histogram, scalar, or array.""" + instance = self._operate("div", other) # Ensure that bin content is treated as weighted after division instance._bin_content = instance._bin_content.astype(float) return instance + + def __iadd__(self, other: Union[Histogram1D, Real]) -> Histogram1D: + 
"""In-place addition with histogram or scalar.""" + return self._ioperate("add", other) + + def __isub__(self, other: Union[Histogram1D, Real]) -> Histogram1D: + """In-place subtraction with histogram or scalar.""" + return self._ioperate("sub", other) + + def __itruediv__(self, other: Union[Histogram1D, Real, ArrayLike]) -> Histogram1D: + """In-place division by histogram or scalar.""" + return self._ioperate("div", other) - def __itruediv__(self, other: Union[Real, ArrayLike, "Histogram1D"]) -> "Histogram1D": - """Perform in-place division between a histogram and either a scalar or another histogram.""" - return self._ioperate('div', other) + def __imul__(self, other: Union[Real, ArrayLike]) -> Histogram1D: + """ + In-place multiplication by scalar or array. - def __mul__(self, other: Union[Real, ArrayLike]) -> "Histogram1D": - """Perform multiplication of a histogram by a scalar.""" - return self._operate('scale', other) + Parameters + ---------- + other : Union[Real, ArrayLike] + Scalar or array to multiply by - def __imul__(self, other: Union[Real, ArrayLike]) -> "Histogram1D": - """Perform in-place multiplication of a histogram by a scalar.""" - return self._ioperate('scale', other) + Returns + ------- + Histogram1D + Self multiplied by other + """ + return self._ioperate("scale", other) + + def _operate( + self, + method: str, + other: Any + ) -> Histogram1D: + """ + Perform operations on histogram data. - def __rmul__(self, other: Union[Real, ArrayLike]) -> "Histogram1D": - """Perform right multiplication of a histogram by a scalar.""" - return self._operate('scale', other) + Parameters + ---------- + method : str + The operation to perform ('add', 'sub', 'div', 'scale') + other : Any + The other operand (histogram, scalar, or array) - def _operate(self, method: str, other: Any) -> "Histogram1D": - """Perform a binary operation (add, sub, div, scale) on the histogram.""" - if not hasattr(self, f'_{method}'): - raise ValueError(f'no operation named "_{method}"') - fn = getattr(self, f'_{method}') - bin_content, bin_errors = fn(other) + Returns + ------- + Histogram1D + A new histogram with the operation result + + Raises + ------ + ValueError + If operation is invalid or operands are incompatible + """ + operation = getattr(self, f"_{method}", None) + if operation is None: + raise ValueError(f'Invalid operation: "{method}"') + + bin_content, bin_errors = operation(other) + bin_content_raw, bin_errors_raw = self._operate_masked(method, other) + if isinstance(other, Histogram1D): error_mode = self._resolve_error_mode(bin_content, other._error_mode) else: error_mode = self._error_mode - # Avoid copying bin_edges - bin_edges = self._binning._bin_edges - return type(self)(bin_content=bin_content, bin_edges=bin_edges, - bin_errors=bin_errors, error_mode=error_mode) - - def _ioperate(self, method: str, other: Any) -> "Histogram1D": - """Perform an in-place binary operation (add, sub, div, scale) on the histogram.""" - if not hasattr(self, f'_{method}'): - raise ValueError(f'no operation named "_{method}"') - fn = getattr(self, f'_{method}') - bin_content, bin_errors = fn(other) + + mask = self._combine_mask(other) + self._apply_mask(mask, bin_content, bin_errors) + + instance = type(self)( + bin_content=bin_content, + bin_edges=self._binning._bin_edges, + bin_errors=bin_errors, + error_mode=error_mode, + ) + + instance._bin_content_raw = bin_content_raw + instance._bin_errors_raw = bin_errors_raw + instance._mask = mask + return instance + + def _operate_masked( + self, + method: str, + 
other: Any + ) -> Tuple[Optional[np.ndarray], BinErrors]: + """ + Handle operations with masked histograms. + + Parameters + ---------- + method : str + Operation to perform + other : Any + Other operand + + Returns + ------- + Tuple[Optional[np.ndarray], BinErrors] + Raw bin content and errors if masked, else (None, None) + """ + self_masked = self.is_masked() + other_masked = isinstance(other, Histogram1D) and other.is_masked() + + if not (self_masked or other_masked): + return None, None + + self_copy = self.copy() + other_copy = other.copy() if isinstance(other, Histogram1D) else other + + if self_masked: + self_copy.unmask() + if other_masked: + other_copy.unmask() + + operation = getattr(self_copy, f"_{method}") + return operation(other_copy) + + def _ioperate(self, method: str, other: Any) -> Histogram1D: + """ + Perform in-place operation. + + Parameters + ---------- + method : str + Operation to perform ('add', 'sub', 'div', 'scale') + other : Any + Other operand + + Returns + ------- + Histogram1D + Self with operation applied + + Raises + ------ + ValueError + If operation is invalid + """ + if not hasattr(self, f"_{method}"): + raise ValueError(f'Invalid operation: "{method}"') + + operation = getattr(self, f"_{method}") + bin_content, bin_errors = operation(other) + bin_content_raw, bin_errors_raw = self._operate_masked(method, other) + if isinstance(other, Histogram1D): - self._error_mode = self._resolve_error_mode(bin_content, other._error_mode) + self._error_mode = self._resolve_error_mode( + bin_content, other._error_mode + ) + + mask = self._combine_mask(other) + self._apply_mask(mask, bin_content, bin_errors) + self._bin_content = bin_content self._bin_errors = bin_errors - return self + self._bin_content_raw = bin_content_raw + self._bin_errors_raw = bin_errors_raw + self._mask = mask + + return self + + def _combine_mask(self, other: Any) -> np.ndarray: + """ + Combine masks between operands. + + Parameters + ---------- + other : Any + Other operand - def _validate_other(self, other: "Histogram1D") -> None: - """Ensure the other histogram is of the same binning and valid for operations.""" + Returns + ------- + np.ndarray + Combined mask or None if no masks exist + """ + mask = self._mask + if isinstance(other, Histogram1D) and other.is_masked(): + if mask is None: + mask = other._mask.copy() + else: + mask = mask | other._mask + return mask.copy() if mask is not None else None + + def _apply_mask( + self, + mask: np.ndarray, + bin_content: np.ndarray, + bin_errors: BinErrors + ) -> None: + """ + Apply mask to bin content and errors. + + Parameters + ---------- + mask : np.ndarray + Mask to apply + bin_content : np.ndarray + Bin content to mask + bin_errors : BinErrors + Bin errors to mask + """ + if mask is None: + return + + bin_content[mask] = 0 + if bin_errors is not None: + bin_errors[0][mask] = 0.0 + bin_errors[1][mask] = 0.0 + + def _validate_other(self, other: Histogram1D) -> None: + """ + Validate compatibility of another histogram. 
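+
+        Both operands must be ``Histogram1D`` instances with identical
+        binning (as defined by ``Binning.__eq__``).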
+ + Parameters + ---------- + other : Histogram1D + Histogram to validate + + Raises + ------ + ValueError + If histograms are incompatible + """ if not isinstance(other, Histogram1D): - raise ValueError(f'operation only allowed with another Histogram1D object') + raise ValueError( + "Operation only allowed between Histogram1D objects" + ) if self.binning != other.binning: - raise ValueError(f'operations not allowed between histograms with different binnings') + raise ValueError( + "Operations not allowed between histograms with different binning" + ) - def _resolve_error_mode(self, bin_content: np.ndarray, other_mode: BinErrorMode) -> BinErrorMode: + def _resolve_error_mode( + self, + bin_content: np.ndarray, + other_mode: BinErrorMode + ) -> BinErrorMode: """ - Resolve the error mode based on bin content and other histogram's error mode. - Prefer Poisson errors if possible. + Resolve error mode for operations. + + Parameters + ---------- + bin_content : np.ndarray + Current bin content + other_mode : BinErrorMode + Other histogram's error mode + + Returns + ------- + BinErrorMode + Resolved error mode """ - if ((bin_content.dtype == 'int64') or - (self._error_mode == BinErrorMode.POISSON) or - (other_mode == BinErrorMode.POISSON)): - return BinErrorMode.POISSON - return BinErrorMode.SUMW2 + # Prefer Poisson errors if possible + use_poisson = ( + bin_content.dtype == np.int64 or + self._error_mode == BinErrorMode.POISSON or + other_mode == BinErrorMode.POISSON + ) + return BinErrorMode.POISSON if use_poisson else BinErrorMode.SUMW2 - def _scale(self, val: Union[Real, ArrayLike]) -> Tuple[np.ndarray, BinErrorType]: + def _scale( + self, + val: Union[Real, ArrayLike] + ) -> Tuple[np.ndarray, BinErrors]: """ - Scale the histogram bin content and errors by a scalar value or by an array of - values with size matching the number of bins. + Scale histogram contents. 
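+
+        Non-negative integer factors keep an unweighted histogram
+        unweighted; any other factor promotes the content to float,
+        making the histogram weighted.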
+ + Parameters + ---------- + val : Union[Real, ArrayLike] + Scaling factor + + Returns + ------- + Tuple[np.ndarray, BinErrors] + Scaled bin content and errors + + Raises + ------ + ValueError + If scaling array has invalid shape """ val = np.asarray(val) is_weighted = self.is_weighted() - if (not is_weighted) and all_integers(val): - val = val.astype(int) - if not np.all(val >= 0.): - stdout.warning('scaling unweighted histogram by negative value(s) will make it weighted' - ' and force usage of sumw2 errors') + + # Handle integer scaling + if not is_weighted and all_integers(val): + val = val.astype(np.int64) + if not np.all(val >= 0): + stdout.warning( + "Scaling unweighted histogram by negative values " + "will make it weighted and force sumw2 errors" + ) val = val.astype(float) else: val = val.astype(float) - if val.ndim == 0: - val = val[()] - elif val.ndim == 1: - if val.size != self.nbins: - raise ValueError(f'size of array ({val.size}) does not match number ' - f'of bins ({self.nbins}) of histogram') - else: - raise ValueError(f'cannot scale a histogram with a value of dimension {val.ndim}') - + + # Validate scaling array shape + if val.ndim > 1: + raise ValueError(f"Cannot scale with {val.ndim}-dimensional value") + if val.ndim == 1 and val.size != self.nbins: + raise ValueError( + f"Scaling array size ({val.size}) doesn't match bins ({self.nbins})" + ) + bin_content = self._bin_content * val if self._bin_errors is None: return bin_content, None - if bin_content.dtype == 'int64': + # Handle errors based on content type + if bin_content.dtype == np.int64: bin_errors = poisson_interval(bin_content) + if self.is_masked(): + bin_errors[0][self._mask] = 0.0 + bin_errors[1][self._mask] = 0.0 else: errlo, errhi = self._bin_errors bin_errors = (val * errlo, val * errhi) return bin_content, bin_errors - def _add(self, other: "Histogram1D", neg: bool = False) -> Tuple[np.ndarray, BinErrorType]: - """Add (or subtract if `neg` is True) two histograms.""" + def _add( + self, + other: Union[Histogram1D, Real], + neg: bool = False + ) -> Tuple[np.ndarray, BinErrors]: + """ + Add/subtract histograms or scalar. + + Parameters + ---------- + other : Union[Histogram1D, Real] + Value to add/subtract + neg : bool, default False + True for subtraction + + Returns + ------- + Tuple[np.ndarray, BinErrors] + The resulting bin content and bin errors. 
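+
+        Notes
+        -----
+        When both operands carry errors and the summed content is
+        non-negative integer, Poisson intervals are recomputed; otherwise
+        the lower and upper errors are combined in quadrature.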
+ """ + if isinstance(other, Real): + # Convert scalar to histogram + bin_content = np.full( + self._bin_content.shape, + other, + dtype=self._bin_content.dtype + ) + bin_errors = ( + np.zeros_like(self._bin_content), + np.zeros_like(self._bin_content) + ) + other = type(self)( + bin_content=bin_content, + bin_edges=self._binning._bin_edges, + bin_errors=bin_errors, + ) + self._validate_other(other) - if neg: - bin_content = self._bin_content - other._bin_content - else: - bin_content = self._bin_content + other._bin_content + + # Perform addition/subtraction + bin_content = ( + self._bin_content - other._bin_content if neg + else self._bin_content + other._bin_content + ) - if (self._bin_errors is None) and (other._bin_errors is None): + if self._bin_errors is None and other._bin_errors is None: return bin_content, None - if (self._bin_errors is not None) and (other._bin_errors is not None): + # Handle errors + if self._bin_errors is not None and other._bin_errors is not None: use_poisson = False - if bin_content.dtype == 'int64': + if bin_content.dtype == np.int64: if np.all(bin_content >= 0): use_poisson = True else: - stdout.warning('Histogram has negative bin content - force usage of sumw2 errors') + stdout.warning( + "Negative bin content - forcing sumw2 errors" + ) + if use_poisson: bin_errors = poisson_interval(bin_content) + if self.is_masked(): + bin_errors[0][self._mask] = 0.0 + bin_errors[1][self._mask] = 0.0 else: - errlo = np.sqrt(self._bin_errors[0] ** 2 + other._bin_errors[0] ** 2) - errhi = np.sqrt(self._bin_errors[1] ** 2 + other._bin_errors[1] ** 2) + errlo = np.sqrt( + self._bin_errors[0] ** 2 + other._bin_errors[0] ** 2 + ) + errhi = np.sqrt( + self._bin_errors[1] ** 2 + other._bin_errors[1] ** 2 + ) bin_errors = (errlo, errhi) else: - if self._bin_errors is None: - bin_errors = other.bin_errors - else: - bin_errors = self.bin_errors + bin_errors = ( + self._bin_errors if self._bin_errors is not None + else other._bin_errors + ) return bin_content, bin_errors - def _sub(self, other: "Histogram1D") -> Tuple[np.ndarray, BinErrorType]: - """Subtract another histogram from the current histogram.""" + def _sub( + self, + other: Union[Histogram1D, Real] + ) -> Tuple[np.ndarray, BinErrors]: + """ + Subtract another histogram from the current histogram. + + Parameters + ---------- + other : Histogram1D + The other histogram to subtract. + + Returns + ------- + Tuple[np.ndarray, BinErrors] + The resulting bin content and bin errors. + """ return self._add(other, neg=True) + + def _div( + self, + other: Union[Histogram1D, Real, ArrayLike] + ) -> Tuple[np.ndarray, BinErrors]: + """ + Divide histogram by another histogram or scalar. - def _div(self, other: Any) -> Tuple[np.ndarray, ArrayLike, BinErrorType]: - """Divide the current histogram by either a scalar or another histogram.""" + Parameters + ---------- + other : Union[Histogram1D, Union[Real, ArrayLike]] + Divisor (histogram or scalar/array) + + Returns + ------- + Tuple[np.ndarray, BinErrors] + Resulting bin content and errors + + Raises + ------ + ValueError + If division by zero occurs + If histograms have incompatible binning + """ if not isinstance(other, Histogram1D): - scale = 1. 
/ other
-            return self._scale(scale)
+            # Handle scalar division
+            if np.any(other == 0):
+                raise ValueError("Division by zero")
+            return self._scale(1.0 / other)
+
         self._validate_other(other)
-        bin_content = safe_div(self._bin_content, other._bin_content, True)
-        # Enforce sumw2 error by making bin content weighted
-        if bin_content.dtype == 'int64':
-            bin_content = bin_content.astype(float)
+        bin_content = safe_div(self._bin_content, other._bin_content, False)
 
-        if (self._bin_errors is None) and (other._bin_errors is None):
-            return bin_content, None
+        # Force float type for division results
+        bin_content = bin_content.astype(float)
 
-        # NB: can be jitted / vectorized
-        def get_err(b1, b2, e1, e2):
-            b2_sq = b2 * b2
-            err2 = safe_div(e1 * e1 * b2_sq + e2 * e2 * b1 * b1, b2_sq * b2_sq, True)
-            return np.sqrt(err2)
+        if self._bin_errors is None and other._bin_errors is None:
+            return bin_content, None
 
+        # Handle errors
         err1 = self._bin_errors or (np.zeros(self.nbins), np.zeros(self.nbins))
         err2 = other._bin_errors or (np.zeros(other.nbins), np.zeros(other.nbins))
+
+        errlo, errhi = self._calculate_division_errors(
+            self._bin_content,
+            other._bin_content,
+            err1,
+            err2
+        )
+
+        if self.is_masked():
+            errlo[self._mask] = 0.0
+            errhi[self._mask] = 0.0
+
+        return bin_content, (errlo, errhi)
 
-        errlo = get_err(self._bin_content, other._bin_content, err1[0], err2[0])
-        errhi = get_err(self._bin_content, other._bin_content, err1[1], err2[1])
-        bin_errors = (errlo, errhi)
-
-        return bin_content, bin_errors
+    @staticmethod
+    def _calculate_division_errors(
+        num: np.ndarray,
+        den: np.ndarray,
+        num_errs: Tuple[np.ndarray, np.ndarray],
+        den_errs: Tuple[np.ndarray, np.ndarray]
+    ) -> Tuple[np.ndarray, np.ndarray]:
+        """
+        Calculate errors for histogram division.
+
+        Uses the error propagation formula for a ratio:
+        σ(a/b)² = (a/b)² * (σa²/a² + σb²/b²), i.e.
+        σ(a/b) = sqrt(σa²·b² + σb²·a²) / b².
+
+        Parameters
+        ----------
+        num : np.ndarray
+            Numerator values
+        den : np.ndarray
+            Denominator values
+        num_errs : Tuple[np.ndarray, np.ndarray]
+            Numerator errors (low, high)
+        den_errs : Tuple[np.ndarray, np.ndarray]
+            Denominator errors (low, high)
+
+        Returns
+        -------
+        Tuple[np.ndarray, np.ndarray]
+            Resulting low and high errors
+        """
+        den_sq = den * den
+        # Divide by b² (not b⁴): the square root has already been taken
+        errlo = safe_div(
+            np.sqrt(num_errs[0]**2 * den_sq + den_errs[0]**2 * num * num),
+            den_sq,
+            False
+        )
+        errhi = safe_div(
+            np.sqrt(num_errs[1]**2 * den_sq + den_errs[1]**2 * num * num),
+            den_sq,
+            False
+        )
+        return errlo, errhi
 
     @staticmethod
-    def _regularize_errors(bin_content: np.ndarray,
-                           bin_errors: Optional[ArrayLike] = None) -> BinErrorType:
+    def _regularize_errors(
+        bin_content: np.ndarray,
+        bin_errors: Optional[ArrayLike] = None
+    ) -> BinErrors:
+        """
+        Convert bin errors to standard format.
+
+        Converts various error input formats to the standard
+        (lower_errors, upper_errors) tuple format.
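+
+        For example, a scalar error of ``0.5`` is broadcast to identical
+        lower and upper arrays of length ``nbins``, while a ``(2, nbins)``
+        array is split into separate lower and upper components.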
+ + Parameters + ---------- + bin_content : np.ndarray + The histogram bin content + bin_errors : Optional[ArrayLike], default None + Bin errors in one of several formats: + - None: No errors + - scalar: Same error for all bins + - 1D array: Symmetric errors, length must match bins + - 2D array: Asymmetric errors, shape must be (2, nbins) + - Tuple[array, array]: Asymmetric errors, each array length matches bins + + Returns + ------- + BinErrors + Tuple of (lower_errors, upper_errors) arrays, or None + + Raises + ------ + ValueError + If error array has invalid shape or size + """ if bin_errors is None: return None + size = bin_content.size - ndim = np.ndim(bin_errors) - # make sure bin_errors is a 2-tuple - if ndim == 0: - bin_errors = (bin_errors, bin_errors) - elif ndim == 1: - if isinstance(bin_errors, tuple): - if len(bin_errors) != 2: - raise ValueError(f'`bin_errors` should have size 2 when given as a tuple.') - else: - bin_errors = (bin_errors, bin_errors) - elif ndim == 2: - bin_errors = np.asarray(bin_errors) - if bin_errors.shape[0] != 2: - raise ValueError(f'`bin_errors` should have shape (2, N) when given as an array.') - bin_errors = (bin_errors[0], bin_errors[1]) - else: - raise RuntimeError(f'`bin_errors` should have dimension less than 3.') - assert isinstance(bin_errors, tuple) and (len(bin_errors) == 2) - result = [] - for errors in bin_errors: - if isinstance(errors, Real): - errors = np.full(size, errors, dtype=float) - else: - errors = np.array(errors) - result.append(errors) - if bin_errors[0].shape != bin_errors[1].shape: - raise ValueError(f'upper and lower errors must have the same shape.') - if bin_errors[0].shape != bin_content.shape: - raise ValueError(f'bin content and bin errors (upper or lower) must have the same shape.') - return bin_errors[0], bin_errors[1] + + # Handle scalar errors + if np.isscalar(bin_errors): + err = np.full(size, bin_errors, dtype=float) + return (err, err) + + bin_errors = np.asarray(bin_errors) + + # Handle 1D error array + if bin_errors.ndim == 1: + if bin_errors.size != size: + raise ValueError( + "Error array size must match bin content size" + ) + return (bin_errors, bin_errors) + + # Handle 2D error array + if bin_errors.ndim == 2: + if bin_errors.shape != (2, size): + raise ValueError( + "2D error array must have shape (2, nbins)" + ) + return (bin_errors[0], bin_errors[1]) + + raise ValueError( + f"Error array has invalid dimension: {bin_errors.ndim}" + ) - def set_data(self, bin_content: np.ndarray, - bin_edges: np.ndarray, - bin_errors: Optional[ArrayLike] = None, - error_mode: Union[BinErrorMode, str] = "auto") -> None: + def set_data( + self, + bin_content: np.ndarray, + bin_edges: np.ndarray, + bin_errors: Optional[ArrayLike] = None, + error_mode: Union[BinErrorMode, str] = "auto", + ) -> None: """ - Set the histogram's data including bin content, bin edges, bin errors, and error mode. + Set the histogram data. Parameters ---------- bin_content : np.ndarray - The bin content for the histogram. + The bin contents bin_edges : np.ndarray - The bin edges for the histogram. - bin_errors : ArrayLike, optional - The bin errors. If None, errors will be calculated automatically. - error_mode : BinErrorMode or str, default = "auto" - The error mode to use. 
+ The bin edges + bin_errors : Optional[ArrayLike], default None + The bin errors in any valid format + error_mode : Union[BinErrorMode, str], default "auto" + Error calculation mode + + Raises + ------ + ValueError + If data arrays have invalid shapes or sizes + If bin_content and bin_edges sizes don't match """ - bin_content = np.array(bin_content) + # Validate input arrays + bin_content = np.asarray(bin_content) if bin_content.ndim != 1: - raise ValueError(f'`bin_content` must be a 1D array.') - bin_edges = np.array(bin_edges) + raise ValueError("Bin content must be 1-dimensional") + + bin_edges = np.asarray(bin_edges) if bin_edges.ndim != 1: - raise ValueError(f'`bin_edges` must be a 1D array.') + raise ValueError("Bin edges must be 1-dimensional") + if bin_content.size != (bin_edges.size - 1): - raise RuntimeError(f'number of bins from bin content (= {bin_content.size}) ' - f'is different from that from bin edges (= {bin_edges.size - 1})') + raise ValueError( + f"Expected {bin_edges.size - 1} bins from edges, " + f"got {bin_content.size} from content" + ) + # Create binning object binning = Binning(bins=bin_edges) error_mode = BinErrorMode.parse(error_mode) + # Determine content type and error mode + is_poisson_data = all_integers(bin_content) and not np.all(bin_content == 0) if error_mode == BinErrorMode.AUTO: - error_mode = BinErrorMode.POISSON if all_integers(bin_content) else BinErrorMode.SUMW2 - - # Coerce bin content type based on error mode - if error_mode == BinErrorMode.POISSON: - if (bin_content.dtype != 'int64'): - # cast to int because dtype is used to check whether histogram is weighted - if all_integers(bin_content): - bin_content = bin_content.astype(int) - else: - bin_content = bin_content.astype(float) + error_mode = ( + BinErrorMode.POISSON if is_poisson_data + else BinErrorMode.SUMW2 + ) + + # Set content type based on error mode + if error_mode == BinErrorMode.POISSON and is_poisson_data: + bin_content = bin_content.astype(np.int64) else: bin_content = bin_content.astype(float) + # Handle errors bin_errors = self._regularize_errors(bin_content, bin_errors) - if (error_mode == BinErrorMode.POISSON) and (bin_errors is None): + if error_mode == BinErrorMode.POISSON and bin_errors is None: bin_errors = poisson_interval(bin_content) - elif (error_mode == BinErrorMode.SUMW2) and bin_errors is not None: - if not np.allclose(bin_errors[0], bin_errors[1]): - raise ValueError('the given bin errors are not symmetric although the error mode is sumw2') + # Set attributes self._bin_content = bin_content self._binning = binning self._bin_errors = bin_errors self._error_mode = error_mode + self._bin_content_raw = None + self._bin_errors_raw = None + self._mask = None + + @classmethod + def create( + cls, + x: np.ndarray, + weights: Optional[np.ndarray] = None, + bins: Union[int, ArrayLike] = 10, + bin_range: Optional[ArrayLike] = None, + underflow: bool = False, + overflow: bool = False, + divide_bin_width: bool = False, + normalize: bool = False, + clip_weight: bool = False, + evaluate_error: bool = True, + error_mode: Union[BinErrorMode, str] = "auto", + ) -> Histogram1D: + """ + Create a histogram from array data. 
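+
+        This is a thin wrapper around the module-level
+        :func:`~quickstats.maths.histograms.histogram` function, which
+        performs the actual binning and error evaluation.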
+
+        Parameters
+        ----------
+        x : np.ndarray
+            Input data to histogram
+        weights : Optional[np.ndarray], default None
+            Optional weights for each data point
+        bins : Union[int, ArrayLike], default 10
+            Number of bins or bin edges
+        bin_range : Optional[ArrayLike], default None
+            Optional (min, max) range for binning
+        underflow : bool, default False
+            Include underflow in first bin
+        overflow : bool, default False
+            Include overflow in last bin
+        divide_bin_width : bool, default False
+            Normalize by bin width
+        normalize : bool, default False
+            Normalize histogram to unit area
+        clip_weight : bool, default False
+            Ignore out-of-range weights
+        evaluate_error : bool, default True
+            Calculate bin errors
+        error_mode : Union[BinErrorMode, str], default "auto"
+            Error calculation mode
+
+        Returns
+        -------
+        Histogram1D
+            New histogram instance
+        """
+        bin_content, bin_edges, bin_errors = histogram(
+            x=x,
+            weights=weights,
+            bins=bins,
+            bin_range=bin_range,
+            underflow=underflow,
+            overflow=overflow,
+            divide_bin_width=divide_bin_width,
+            normalize=normalize,
+            clip_weight=clip_weight,
+            evaluate_error=evaluate_error,
+            error_mode=error_mode,
+        )
+
+        return cls(
+            bin_content=bin_content,
+            bin_edges=bin_edges,
+            bin_errors=bin_errors,
+            error_mode=error_mode,
+        )
+
     @property
     def bin_content(self) -> np.ndarray:
-        """Returns a copy of the bin content."""
+        """Get copy of bin content array."""
         return self._bin_content.copy()
 
     @property
     def binning(self) -> Binning:
-        """Returns the binning information."""
+        """Get binning object."""
         return self._binning
 
     @property
     def bin_edges(self) -> np.ndarray:
-        """Returns the edges of the bins."""
+        """Get bin edges array."""
         return self._binning.bin_edges
 
     @property
     def bin_centers(self) -> np.ndarray:
-        """Returns the centers of the bins."""
+        """Get bin centers array."""
         return self._binning.bin_centers
 
     @property
     def bin_widths(self) -> np.ndarray:
-        """Returns the widths of the bins."""
+        """Get bin widths array."""
         return self._binning.bin_widths
 
     @property
     def nbins(self) -> int:
-        """Returns the number of bins."""
+        """Get number of bins."""
         return self._binning.nbins
 
     @property
     def bin_range(self) -> Tuple[float, float]:
-        """Returns the bin range as a tuple (low, high)."""
+        """Get (min, max) bin range."""
         return self._binning.bin_range
 
     @property
     def uniform_binning(self) -> bool:
         """Check if binning is uniform."""
+        # Note: Binning.is_uniform is a property, not a method
         return self._binning.is_uniform
 
     @property
-    def bin_errors(self) -> BinErrorType:
-        """Returns the bin errors."""
+    def bin_errors(self) -> BinErrors:
+        """
+        Get bin errors.
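+
+        The arrays are returned as copies, so mutating them does not
+        affect the histogram.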
+ + Returns + ------- + BinErrors + Tuple of (lower_errors, upper_errors) arrays or None + """ if self._bin_errors is None: return None - errlo = self._bin_errors[0].copy() - errhi = self._bin_errors[1].copy() - return errlo, errhi + return ( + self._bin_errors[0].copy(), + self._bin_errors[1].copy() + ) @property def bin_errlo(self) -> Optional[np.ndarray]: - """Returns the lower bin errors.""" + """Get lower bin errors array.""" if self._bin_errors is None: return None return self._bin_errors[0].copy() @property def bin_errhi(self) -> Optional[np.ndarray]: - """Returns the upper bin errors.""" + """Get upper bin errors array.""" if self._bin_errors is None: return None return self._bin_errors[1].copy() + @property + def rel_bin_errors(self) -> BinErrors: + """Get relative bin errors with content.""" + if self._bin_errors is None: + return None + errlo = self._bin_content - self._bin_errors[0] + errhi = self._bin_content + self._bin_errors[1] + return (errlo, errhi) + + @property + def rel_bin_errlo(self) -> Optional[np.ndarray]: + """Get relative lower bin errors with content.""" + if self._bin_errors is None: + return None + return self._bin_content - self._bin_errors[0] + + @property + def rel_bin_errhi(self) -> Optional[np.ndarray]: + """Get relative upper bin errors with content.""" + if self._bin_errors is None: + return None + return self._bin_content + self._bin_errors[1] + @property def error_mode(self) -> BinErrorMode: - """Returns the current error mode.""" + """Get current error mode.""" return self._error_mode - @classmethod - def create(cls, x:np.ndarray, weights:Optional[np.ndarray]=None, - bins:Union[int, ArrayLike]=10, - bin_range:Optional[ArrayLike]=None, - underflow:bool=False, - overflow:bool=False, - divide_bin_width:bool=False, - normalize:bool=True, - clip_weight:bool=False, - evaluate_error:bool=False, - error_mode:Union[BinErrorMode, str]="auto", - **kwargs): - """ - Create histogram from unbinned data. - - Arguments: - ------------------------------------------------------------------------------- - x: ndarray - Input data array from which the histogram is computed. - weights: (optional) ndarray - Array of weights with same shape as input data. If not given, the - input data is assumed to have unit weights. - bins: (optional) int or array of scalars, default = 10 - If integer, it defines the number of equal-width bins in the - given range. - If an array, it defines a monotonically increasing array of bin edges, - including the rightmost edge. - bin_range: (optional) array of the form (float, float) - The lower and upper range of the bins. If not provided, range is simply - ``(x.min(), x.max())``. Values outside the range are ignored. - underflow: bool, default = False - Include undeflow data in the first bin. - overflow: bool, default = False - Include overflow data in the last bin. - divide_bin_width: bool, default = False - Divide each bin by the bin width. - normalize: bool, default = True - Normalize the sum of weights to one. Weights outside the bin range will - not be counted if ``clip_weight`` is set to false, so the sum of bin - content could be less than one. - clip_weight: bool, default = False - Ignore data outside given range when evaluating total weight - used in normalization. - evaluate_error: bool, default = True - Evaluate the error of the bin contents using the given error option. - error_mode: BinErrorMode or str, default = "auto" - How to evaluate bin errors. 
If "sumw2", symmetric errors from the Wald - approximation is used (square root of sum of squares of weights). If - "poisson", asymmetric errors from Poisson interval at one sigma is - used. If "auto", it will use sumw2 error if data has unit weights, - else Poisson error will be used. - """ - bin_content, bin_edges, bin_errors = histogram(x=x, - weights=weights, - bins=bins, - bin_range=bin_range, - underflow=underflow, - overflow=overflow, - divide_bin_width=divide_bin_width, - normalize=normalize, - clip_weight=clip_weight, - evaluate_error=evaluate_error, - error_mode=error_mode) - instance = cls(bin_content=bin_content, - bin_edges=bin_edges, - bin_errors=bin_errors, - error_mode=error_mode) - return instance + @property + def bin_mask(self) -> np.ndarray: + """Get bin mask array if any.""" + if self._mask is None: + return None + return self._mask.copy() + + def has_errors(self) -> bool: + """Check if histogram has errors.""" + return self._bin_errors is not None def is_weighted(self) -> bool: - """Check if the histogram is weighted (i.e., non-integer bin content).""" - return self._bin_content.dtype != 'int64' + """ + Check if histogram is weighted. - def scale(self, val: Union[Real, ArrayLike], inplace: bool = False) -> "Histogram1D": - """Scale the histogram by a scalar value.""" - if inplace: - return self._ioperate('scale', val) - return self._operate('scale', val) + Returns True if bin content is non-integer type. + """ + return self._bin_content.dtype != np.int64 + + def is_empty(self) -> bool: + """Check if histogram is empty (zero sum).""" + return np.sum(self._bin_content) == 0 + + def sum(self) -> Union[float, int]: + """ + Calculate sum of bin contents. - def sum(self) -> float: + Returns + ------- + Union[float, int] + Sum, preserving integer type if unweighted + """ return self._bin_content.sum() def integral(self) -> float: - return (self.bin_widths * self._bin_content).sum() + """ + Calculate histogram integral. + + Returns + ------- + float + Integral (sum of bin contents * bin widths) + """ + return np.sum(self.bin_widths * self._bin_content) + + def copy(self) -> Histogram1D: + """ + Create a deep copy of histogram. + + Returns + ------- + Histogram1D + New histogram instance with copied data + """ + instance = type(self)( + bin_content=self.bin_content, + bin_edges=self.bin_edges, + bin_errors=self.bin_errors, + error_mode=self.error_mode, + ) + if self.is_masked(): + instance._bin_content_raw = self._bin_content_raw.copy() + if self.has_errors(): + instance._bin_errors_raw = ( + self._bin_errors_raw[0].copy(), + self._bin_errors_raw[1].copy(), + ) + instance._mask = self._mask.copy() + return instance + + def mask( + self, + condition: Union[Sequence[float], Callable] + ) -> None: + """ + Apply mask to histogram data. 
+ + Parameters + ---------- + condition : Union[Sequence[float], Callable] + Either [min, max] range or function returning bool + for each bin + + Examples + -------- + >>> hist.mask([1.0, 2.0]) # Mask bins outside [1, 2] + >>> hist.mask(lambda x: x > 0) # Mask bins where x <= 0 + """ + x = self.bin_centers + has_errors = self.has_errors() + + # Store raw data if needed + if self._bin_content_raw is None: + self._bin_content_raw = self._bin_content.copy() + if has_errors: + self._bin_errors_raw = ( + self._bin_errors[0].copy(), + self._bin_errors[1].copy(), + ) + y = self._bin_content + yerr = self._bin_errors + else: + y = self._bin_content_raw + yerr = self._bin_errors_raw - def normalize(self, density: bool = False, inplace: bool = False) -> "Histogram1D": - norm_factor = self.sum() + mask = get_histogram_mask(x=x, y=y, condition=condition) + y[mask] = 0 + if has_errors: + yerr[0][mask] = 0.0 + yerr[1][mask] = 0.0 + self._mask = mask + + def unmask(self) -> None: + """Remove mask and restore original data.""" + if self.is_masked(): + self._bin_content = self._bin_content_raw + self._bin_errors = self._bin_errors_raw + self._bin_content_raw = None + self._bin_errors_raw = None + self._mask = None + + def is_masked(self) -> bool: + """Check if histogram has mask applied.""" + return self._bin_content_raw is not None + + def scale( + self, + val: Union[Real, ArrayLike], + inplace: bool = False + ) -> Histogram1D: + """ + Scale histogram by value. + + Parameters + ---------- + val : Union[Real, ArrayLike] + Scale factor(s) + inplace : bool, default False + If True, modify in place + + Returns + ------- + Histogram1D + Scaled histogram + """ + if inplace: + return self._ioperate("scale", val) + return self._operate("scale", val) + + def normalize( + self, + density: bool = False, + inplace: bool = False + ) -> Histogram1D: + """ + Normalize histogram. + + Parameters + ---------- + density : bool, default False + If True, normalize by bin widths + inplace : bool, default False + If True, modify in place + + Returns + ------- + Histogram1D + Normalized histogram + """ + norm_factor = float(self.sum()) + if norm_factor == 0: + return self.copy() if not inplace else self + if density: norm_factor *= self.bin_widths + if inplace: - return self._ioperate('div', norm_factor) - return self._operate('div', norm_factor) - \ No newline at end of file + return self._ioperate("div", norm_factor) + return self._operate("div", norm_factor) + + def compare( + self, + reference: Histogram1D, + mode: ComparisonMode = "ratio" + ) -> Histogram1D: + """ + Compare with reference histogram. 
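+
+        Ratio and difference comparisons delegate to ``__truediv__`` and
+        ``__sub__`` respectively; a callable ``mode`` receives
+        ``(self, reference)`` and should return a ``Histogram1D``.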
+
+        Parameters
+        ----------
+        reference : Histogram1D
+            Reference histogram
+        mode : ComparisonMode, default "ratio"
+            Comparison mode ("ratio", "difference" or callable)
+
+        Returns
+        -------
+        Histogram1D
+            Comparison histogram
+
+        Raises
+        ------
+        ValueError
+            For invalid comparison mode
+        """
+        if callable(mode):
+            return mode(self, reference)
+
+        mode = HistComparisonMode.parse(mode)
+        if mode == HistComparisonMode.RATIO:
+            return self / reference
+        elif mode == HistComparisonMode.DIFFERENCE:
+            return self - reference
+
+        raise ValueError(f"Unknown comparison mode: {mode}")
\ No newline at end of file
diff --git a/quickstats/concepts/stacked_histogram.py b/quickstats/concepts/stacked_histogram.py
new file mode 100644
index 0000000000000000000000000000000000000000..14f5819fe34ae872677e93a4f8c4448aa3887987
--- /dev/null
+++ b/quickstats/concepts/stacked_histogram.py
@@ -0,0 +1,498 @@
+from __future__ import annotations
+
+from typing import (
+    Optional, Union, Tuple, Any, Callable, Dict, List,
+    TypeVar
+)
+try:
+    # Iterator is required by the offset_histograms generator below
+    from collections.abc import Iterable, Iterator
+except ImportError:
+    from typing import Iterable, Iterator
+
+import numpy as np
+
+from quickstats import stdout
+from quickstats.core.typing import ArrayLike, ArrayContainer
+from quickstats.maths.histograms import BinErrorMode
+from .binning import Binning
+from .histogram1d import Histogram1D, BinErrors
+
+T = TypeVar('T', bound='StackedHistogram')
+HistKey = Union[str, int]
+HistDict = Dict[HistKey, Histogram1D]
+HistList = Union[Dict[str, Histogram1D], List[Histogram1D]]
+ConditionType = Union[Tuple[float, ...], Callable]
+
+def deduce_bin_range(*arrays: ArrayLike) -> Optional[Tuple[float, float]]:
+    """
+    Deduce the global range across multiple arrays.
+
+    Parameters
+    ----------
+    *arrays : ArrayLike
+        Arrays to analyze
+
+    Returns
+    -------
+    Optional[Tuple[float, float]]
+        (min, max) range or None if no arrays provided
+    """
+    if not arrays:
+        return None
+
+    min_range = min(np.min(array) for array in arrays)
+    max_range = max(np.max(array) for array in arrays)
+    return (min_range, max_range)
+
+class StackedHistogram:
+    """
+    A class for managing stacked histograms.
+
+    This class provides functionality for working with multiple histograms
+    that can be stacked together, supporting both named and indexed access.
+
+    Attributes
+    ----------
+    histograms : Dict[HistKey, Histogram1D]
+        The component histograms
+    bin_content : np.ndarray
+        Total stacked bin content
+    bin_edges : np.ndarray
+        Bin edge locations
+    bin_errors : BinErrors
+        Combined bin errors if available
+    """
+
+    def __init__(
+        self,
+        histograms: Optional[HistList] = None
+    ) -> None:
+        """
+        Initialize StackedHistogram.
+
+        Parameters
+        ----------
+        histograms : Optional[HistList], default None
+            Initial histograms to stack.
Can be: + - Dictionary mapping names to histograms + - List of histograms (accessed by index) + """ + self.reset() + if histograms is not None: + self.set_histograms(histograms) + + def __getitem__(self, key: HistKey) -> Histogram1D: + """Get histogram by name or index.""" + try: + return self._histograms[key] + except KeyError as e: + raise KeyError(f"Histogram not found: {key}") from e + + @classmethod + def create( + cls: type[T], + sample: Union[Dict[str, ArrayLike], ArrayContainer], + weights: Optional[Union[Dict[str, ArrayLike], ArrayContainer]] = None, + bins: Union[int, ArrayLike] = 10, + bin_range: Optional[ArrayLike] = None, + underflow: bool = False, + overflow: bool = False, + divide_bin_width: bool = False, + normalize: bool = False, + clip_weight: bool = False, + evaluate_error: bool = True, + error_mode: Union[BinErrorMode, str] = "auto", + **kwargs: Any + ) -> T: + """ + Create stacked histogram from unbinned data. + + Parameters + ---------- + sample : Union[Dict[str, ArrayLike], ArrayContainer] + Input data arrays. Can be: + - Dictionary mapping names to arrays + - List/tuple/array of arrays (accessed by index) + weights : Optional[Union[Dict[str, ArrayLike], ArrayContainer]], default None + Optional weights for each sample + bins : Union[int, ArrayLike], default 10 + Number of bins or bin edges + bin_range : Optional[ArrayLike], default None + Optional (min, max) range for binning + underflow : bool, default False + Include underflow in first bin + overflow : bool, default False + Include overflow in last bin + divide_bin_width : bool, default False + Normalize by bin width + normalize : bool, default False + Normalize stacked result + clip_weight : bool, default False + Ignore out-of-range weights + evaluate_error : bool, default True + Calculate bin errors + error_mode : Union[BinErrorMode, str], default "auto" + Error calculation mode + **kwargs : Any + Additional histogram creation options + + Returns + ------- + T + New StackedHistogram instance + + Raises + ------ + ValueError + If sample is empty + If incompatible arrays provided + TypeError + If invalid sample type + """ + if not sample: + raise ValueError('Empty sample provided') + + try: + if isinstance(sample, dict): + keys = list(sample.keys()) + indexed = False + elif isinstance(sample, (tuple, list, np.ndarray)): + keys = list(range(len(sample))) + indexed = True + else: + raise TypeError( + 'Sample must be dictionary or sequence of arrays' + ) + + if bin_range is None: + bin_range = deduce_bin_range( + *(sample[key] for key in keys) + ) + + if weights is None: + weights = {k: None for k in keys} + + histograms: Union[List[Optional[Histogram1D]], Dict[str, Histogram1D]] + histograms = [None] * len(keys) if indexed else {} + + for key in keys: + histograms[key] = Histogram1D.create( + x=sample[key], + weights=weights[key], + bins=bins, + bin_range=bin_range, + underflow=underflow, + overflow=overflow, + divide_bin_width=False, # Handle later + normalize=False, # Handle later + clip_weight=clip_weight, + evaluate_error=evaluate_error, + error_mode=error_mode, + **kwargs + ) + + instance = cls(histograms=histograms) + if normalize: + instance.normalize( + density=divide_bin_width, + inplace=True + ) + return instance + + except Exception as e: + if isinstance(e, (ValueError, TypeError)): + raise + raise ValueError(f"Failed to create stacked histogram: {str(e)}") from e + + @property + def bin_content(self) -> np.ndarray: + """Get total stacked bin content.""" + return self._stacked_histogram.bin_content + + 
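As a usage sketch (an editor's illustration, not part of the patch): based on the
`create` signature above, the factory bins every input array with a shared binning
and keeps each component accessible by its key. The sample names below are
hypothetical.

```python
import numpy as np
from quickstats.concepts import StackedHistogram

rng = np.random.default_rng(0)
samples = {
    "signal": rng.normal(1.0, 0.2, size=1_000),
    "background": rng.exponential(1.0, size=5_000),
}
# bin_range is deduced across all samples when not given
stacked = StackedHistogram.create(samples, bins=20)
print(stacked.sum())            # total entries in the stack
print(stacked["signal"].sum())  # component histogram by name
```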
@property + def binning(self) -> Binning: + """Get binning object.""" + return self._stacked_histogram.binning + + @property + def bin_edges(self) -> np.ndarray: + """Get bin edges array.""" + return self._stacked_histogram.bin_edges + + @property + def bin_centers(self) -> np.ndarray: + """Get bin centers array.""" + return self._stacked_histogram.bin_centers + + @property + def bin_widths(self) -> np.ndarray: + """Get bin widths array.""" + return self._stacked_histogram.bin_widths + + @property + def nbins(self) -> int: + """Get number of bins.""" + return self._stacked_histogram.nbins + + @property + def bin_range(self) -> Tuple[float, float]: + """Get (min, max) bin range.""" + return self._stacked_histogram.bin_range + + @property + def uniform_binning(self) -> bool: + """Check if binning is uniform.""" + return self._stacked_histogram.uniform_binning + + @property + def bin_errors(self) -> BinErrors: + """Get total stacked bin errors.""" + return self._stacked_histogram.bin_errors + + @property + def bin_errlo(self) -> Optional[np.ndarray]: + """Get lower bin errors array.""" + return self._stacked_histogram.bin_errlo + + @property + def bin_errhi(self) -> Optional[np.ndarray]: + """Get upper bin errors array.""" + return self._stacked_histogram.bin_errhi + + @property + def rel_bin_errors(self) -> BinErrors: + """Get relative bin errors with content.""" + return self._stacked_histogram.rel_bin_errors + + @property + def rel_bin_errlo(self) -> Optional[np.ndarray]: + """Get relative lower bin errors with content.""" + return self._stacked_histogram.rel_bin_errlo + + @property + def rel_bin_errhi(self) -> Optional[np.ndarray]: + """Get relative upper bin errors with content.""" + return self._stacked_histogram.rel_bin_errhi + + @property + def error_mode(self) -> BinErrorMode: + """Get current error mode.""" + return self._stacked_histogram.error_mode + + @property + def bin_mask(self) -> Optional[np.ndarray]: + """Get bin mask array if any.""" + return self._stacked_histogram.bin_mask + + @property + def histograms(self) -> HistDict: + """Get dictionary of component histograms.""" + return self._histograms + + def offset_histograms(self) -> Iterator[Tuple[HistKey, Histogram1D]]: + """ + Generate histograms with cumulative offsets for stacking. + + Each histogram is offset by the sum of all previous histograms, + creating the stacked effect. + + Yields + ------ + Tuple[HistKey, Histogram1D] + (name/index, offset histogram) pairs + + Examples + -------- + >>> for name, hist in stacked.offset_histograms(): + ... plt.fill_between(hist.bin_centers, hist.bin_content) + """ + base_histogram = Histogram1D( + bin_content=np.zeros_like(self.bin_content), + bin_edges=self.bin_edges + ) + for name, histogram in self._histograms.items(): + offset_hist = histogram + base_histogram + yield name, offset_hist + base_histogram += histogram + + def reset(self) -> None: + """Reset to initial empty state.""" + self._indexed = True + self._histograms = {} + self._stacked_histogram = Histogram1D( + bin_content=np.array([0]), + bin_edges=np.array([0, 1]) + ) + + def add_histogram( + self, + histogram: Histogram1D, + name: Optional[HistKey] = None + ) -> None: + """ + Add histogram to stack. 
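+
+        The running total is updated incrementally: the first histogram
+        added initialises the total, later ones are summed into it.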
+ + Parameters + ---------- + histogram : Histogram1D + Histogram to add + name : Optional[HistKey], default None + Name for histogram if using named access + + Raises + ------ + TypeError + If histogram is not Histogram1D + ValueError + If name handling is inconsistent with current mode + """ + if not isinstance(histogram, Histogram1D): + raise TypeError('Histogram must be Histogram1D instance') + + if self.indexed: + if name is not None and name != self.count: + raise ValueError( + 'Cannot specify histogram name in indexed mode' + ) + name = self.count + else: + if name is None: + raise ValueError( + 'Must specify histogram name in named mode' + ) + + if self.is_empty(): + self._stacked_histogram = histogram.copy() + else: + self._stacked_histogram += histogram + self._histograms[name] = histogram + + def set_histograms(self, histograms: HistList) -> None: + """ + Set multiple histograms to stack. + + Parameters + ---------- + histograms : HistList + Dictionary or list of histograms + + Raises + ------ + TypeError + If histograms has invalid type + """ + self.reset() + if isinstance(histograms, dict): + self._indexed = False + for name, histogram in histograms.items(): + self.add_histogram(histogram, name=name) + elif isinstance(histograms, (list, tuple)): + self._indexed = True + for histogram in histograms: + self.add_histogram(histogram) + else: + raise TypeError( + 'Histograms must be dictionary or sequence' + ) + + @property + def indexed(self) -> bool: + """Check if using indexed access mode.""" + return self._indexed + + @property + def count(self) -> int: + """Get number of histograms in stack.""" + return len(self._histograms) + + def is_empty(self) -> bool: + """Check if stack is empty.""" + return self._stacked_histogram.is_empty() + + def is_weighted(self) -> bool: + """Check if stack contains weighted histograms.""" + return self._stacked_histogram.is_weighted() + + def sum(self) -> Union[float, int]: + """Get sum of all bin contents.""" + return self._stacked_histogram.sum() + + def integral(self) -> float: + """Get integral (sum * bin widths).""" + return self._stacked_histogram.integral() + + def normalize( + self, + density: bool = False, + inplace: bool = False + ) -> StackedHistogram: + """ + Normalize stacked histograms. + + Parameters + ---------- + density : bool, default False + Normalize by bin widths + inplace : bool, default False + Modify in place + + Returns + ------- + StackedHistogram + Normalized stack + """ + result = self if inplace else self.copy() + count = self.count + if count == 0: + return result + + for histogram in result._histograms.values(): + histogram.normalize(density=density, inplace=True) + histogram /= count + + result._stacked_histogram.normalize(density=density, inplace=True) + return result + + def copy(self) -> StackedHistogram: + """ + Create deep copy of stack. + + Returns + ------- + StackedHistogram + New instance with copied data + """ + histograms = { + key: histogram.copy() + for key, histogram in self._histograms.items() + } + instance = type(self)() + instance._indexed = self._indexed + instance._histograms = histograms + instance._stacked_histogram = self._stacked_histogram.copy() + return instance + + def mask(self, condition: ConditionType) -> None: + """ + Apply mask to stacked histogram. 
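+
+        The mask is applied to the combined total only; the component
+        histograms are left unchanged.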
+ + Parameters + ---------- + condition : ConditionType + Masking condition: + - Tuple of bin range limits + - Function returning bool for each bin + """ + self._stacked_histogram.mask(condition) + + def unmask(self) -> None: + """Remove mask and restore original data.""" + self._stacked_histogram.unmask() + + def is_masked(self) -> bool: + """Check if stack has mask applied.""" + return self._stacked_histogram.is_masked() + + def has_errors(self) -> bool: + """Check if stack has error information.""" + return self._stacked_histogram.has_errors() \ No newline at end of file diff --git a/quickstats/concepts/version.py b/quickstats/concepts/version.py index b52af28c59bd40182edc85f20c7f8b48e4b64df2..6f4db3aa31af5c98ad67603764c59237b3aa02c3 100644 --- a/quickstats/concepts/version.py +++ b/quickstats/concepts/version.py @@ -1,113 +1,170 @@ -from typing import Union, Tuple +from __future__ import annotations + +from typing import Union, Tuple, Optional import re +from functools import total_ordering + +@total_ordering class Version: - """ - A class to represent a package version. + """A class to represent and compare package versions. + + This class handles version numbers in both string ("x.y.z") and tuple ((x,y,z)) + formats, providing comparison operations and string representations. Parameters ---------- - version : str or tuple - The version information. This can be a string in the format "major.minor.micro" or "major.minor", - or a tuple of the form (major, minor, micro) or (major, minor). + version : Union[str, Tuple[int, ...], 'Version'] + The version information in one of the following formats: + - String: "major.minor.micro" or "major.minor" + - Tuple: (major, minor, micro) or (major, minor) + - Version: another Version instance Attributes ---------- major : int - Major version number. + The major version number minor : int - Minor version number. + The minor version number micro : int - Micro version number. + The micro (patch) version number Raises ------ ValueError - If the input version string or tuple is in an invalid format. + If the version format is invalid or contains non-integer values + TypeError + If the version input is of an unsupported type + + Examples + -------- + >>> v1 = Version("1.2.3") + >>> v2 = Version((1, 2)) + >>> v1 > v2 + True + >>> str(v1) + '1.2.3' """ - def __init__(self, version: Union[str, Tuple[int, int, int], Tuple[int, int]]): + # Regular expression for validating version strings + _VERSION_PATTERN = re.compile(r'^\d+(\.\d+){1,2}$') + + def __init__( + self, + version: Union[str, Tuple[int, ...], 'Version'] + ) -> None: + """Initialize a Version instance.""" + self.major: int = 0 + self.minor: int = 0 + self.micro: int = 0 + if isinstance(version, str): self._parse_version_string(version) elif isinstance(version, tuple): self._parse_version_tuple(version) elif isinstance(version, Version): - self.major, self.minor, self.micro = version.major, version.minor, version.micro + self._copy_version(version) else: - raise ValueError("Version must be a string or a tuple") - - def _parse_version_string(self, version: str): - """ - Parse the version string and set the major, minor, and micro attributes. + raise TypeError( + f"Version must be a string, tuple, or Version instance, " + f"not {type(version).__name__}" + ) + + def _parse_version_string(self, version: str) -> None: + """Parse a version string into its components. Parameters ---------- version : str - The version string to parse. 
+ Version string in the format "x.y.z" or "x.y" Raises ------ ValueError - If the version string is not in the correct format. + If the version string format is invalid """ - pattern = r'^\d+(\.\d+){1,2}$' - if not re.match(pattern, version): - raise ValueError("Invalid version string format") - parts = list(map(int, version.split('.'))) + if not self._VERSION_PATTERN.match(version): + raise ValueError( + "Invalid version string format. Expected 'x.y.z' or 'x.y'" + ) + + try: + parts = [int(part) for part in version.split('.')] + except ValueError as e: + raise ValueError("Version components must be valid integers") from e + if len(parts) == 2: - parts.append(0) - self.major, self.minor, self.micro = parts + self.major, self.minor = parts + self.micro = 0 + else: + self.major, self.minor, self.micro = parts - def _parse_version_tuple(self, version: Tuple[int, int, int]): - """ - Parse the version tuple and set the major, minor, and micro attributes. + def _parse_version_tuple(self, version: Tuple[int, ...]) -> None: + """Parse a version tuple into its components. Parameters ---------- - version : tuple - The version tuple to parse. + version : Tuple[int, ...] + Version tuple in the format (x, y, z) or (x, y) Raises ------ ValueError - If the version tuple is not in the correct format or contains non-integer elements. + If the tuple length is invalid or contains non-integer values """ + if not (2 <= len(version) <= 3): + raise ValueError("Version tuple must have 2 or 3 elements") + if not all(isinstance(part, int) for part in version): - raise ValueError("All elements of the version tuple must be integers") + raise ValueError("All version components must be integers") + if len(version) == 2: self.major, self.minor = version self.micro = 0 - elif len(version) == 3: - self.major, self.minor, self.micro = version else: - raise ValueError("Version tuple must have 2 or 3 elements") - - def __eq__(self, other): - other = Version(other) - return (self.major, self.minor, self.micro) == (other.major, other.minor, other.micro) - - def __ne__(self, other): - other = Version(other) - return not self.__eq__(other) - - def __gt__(self, other): - other = Version(other) - return (self.major, self.minor, self.micro) > (other.major, other.minor, other.micro) - - def __ge__(self, other): - other = Version(other) - return (self.major, self.minor, self.micro) >= (other.major, other.minor, other.micro) - - def __lt__(self, other): - other = Version(other) - return (self.major, self.minor, self.micro) < (other.major, other.minor, other.micro) - - def __le__(self, other): - other = Version(other) - return (self.major, self.minor, self.micro) <= (other.major, other.minor, other.micro) - - def __repr__(self): - return f"Version(major={self.major}, minor={self.minor}, micro={self.micro})" - - def __str__(self): + self.major, self.minor, self.micro = version + + def _copy_version(self, version: 'Version') -> None: + """Copy version components from another Version instance. + + Parameters + ---------- + version : Version + Source Version instance to copy from + """ + self.major = version.major + self.minor = version.minor + self.micro = version.micro + + def to_tuple(self) -> Tuple[int, int, int]: + """Convert the version to a tuple representation. 
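+
+        Two-component inputs are padded with ``micro = 0``, so
+        ``Version("1.2").to_tuple()`` returns ``(1, 2, 0)``.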
+ + Returns + ------- + Tuple[int, int, int] + Version components as a tuple (major, minor, micro) + """ + return (self.major, self.minor, self.micro) + + def __eq__(self, other: object) -> bool: + """Compare two versions for equality.""" + if not isinstance(other, (Version, str, tuple)): + return NotImplemented + other_version = Version(other) if not isinstance(other, Version) else other + return self.to_tuple() == other_version.to_tuple() + + def __lt__(self, other: Union[Version, str, Tuple[int, ...]]) -> bool: + """Compare if this version is less than another version.""" + if not isinstance(other, (Version, str, tuple)): + return NotImplemented + other_version = Version(other) if not isinstance(other, Version) else other + return self.to_tuple() < other_version.to_tuple() + + def __repr__(self) -> str: + """Return a detailed string representation of the Version instance.""" + return (f"{self.__class__.__name__}(" + f"major={self.major}, minor={self.minor}, micro={self.micro})") + + def __str__(self) -> str: + """Return a string representation of the version.""" return f"{self.major}.{self.minor}.{self.micro}" \ No newline at end of file diff --git a/quickstats/core/__init__.py b/quickstats/core/__init__.py index 65daa5164968544ca4cedabdc33908175fa8eb76..72aeb76c715e8f055498da9e8c79dde684432240 100644 --- a/quickstats/core/__init__.py +++ b/quickstats/core/__init__.py @@ -6,8 +6,12 @@ from .setup import * from .typing import * from .type_validation import * from .configuration import * -from .virtual_trees import * +from .parameters import * +from .virtual_trees import TVirtualNode, TVirtualTree from .path_manager import PathManager from .flexible_dumper import FlexibleDumper +from .trees import NamedTreeNode +from . import mappings +from .mappings import NestedDict from .abstract_object import AbstractObject \ No newline at end of file diff --git a/quickstats/core/abstract_object.py b/quickstats/core/abstract_object.py index 42bdbed9bae806b15421a126e742a4b8f8b6ec72..a73f1b861dedf714f7f1b17102905c63d31ff5b7 100644 --- a/quickstats/core/abstract_object.py +++ b/quickstats/core/abstract_object.py @@ -1,25 +1,205 @@ -from typing import Optional, Union, List, Dict +""" +Base class providing output and verbosity control functionality. -from .io import VerbosePrint +This module defines the AbstractObject class which serves as a base for objects +requiring configurable verbosity and output handling. +""" -class AbstractObject(object): - - stdout = VerbosePrint("INFO") +from __future__ import annotations + +from typing import Any, ClassVar, Optional, Union, Type + +from .decorators import hybridproperty +from .io import Verbosity, VerbosePrint + + +class AbstractObject: + """ + Base class with verbosity control and standard output management. + This class provides a foundation for objects that need configurable + output verbosity and standardized output handling. It maintains both + class-level and instance-level output controls. + + Parameters + ---------- + verbosity : Optional[Union[int, str, Verbosity]] + Verbosity level for output control. If None, uses class default. + **kwargs : Any + Additional keyword arguments for subclasses. + + Attributes + ---------- + stdout : VerbosePrint + Output handler with verbosity control. + + Examples + -------- + >>> class MyObject(AbstractObject): + ... def process(self): + ... self.stdout.info("Processing...") + ... 
+    >>> obj = MyObject(verbosity="DEBUG")
+    >>> obj.stdout.debug("Debug message")
+    [DEBUG] Debug message
+    """
+
+    # Class-level default output handler
+    _class_stdout: ClassVar[VerbosePrint] = VerbosePrint(Verbosity.INFO)
+
+    def __init__(
+        self,
+        verbosity: Optional[Union[int, str, Verbosity]] = None,
+        **kwargs: Any
+    ) -> None:
+        """
+        Initialize AbstractObject with specified verbosity.
+
+        Parameters
+        ----------
+        verbosity : Optional[Union[int, str, Verbosity]]
+            Verbosity level for output control
+        **kwargs : Any
+            Additional keyword arguments for subclasses
+        """
+        self._stdout: Optional[VerbosePrint] = None
+        self.set_verbosity(verbosity)
+        super().__init__()
+
+    @hybridproperty
+    def stdout(cls) -> VerbosePrint:
+        """Get class-level output handler."""
+        return cls._class_stdout
+
+    @stdout.instance
+    def stdout(self) -> VerbosePrint:
+        """Get instance-level output handler."""
+        # Fall back to the class-level handler when no instance-specific
+        # handler is set (set_verbosity(None) stores None here)
+        handler = getattr(self, '_stdout', None)
+        if handler is None:
+            return self.__class__._class_stdout
+        return handler
+
     @property
-    def debug_mode(self):
-        return self.stdout._verbosity._name_ == "DEBUG"
-    
-    def __init__(self, verbosity:Optional[Union[int, str]]="INFO", **kwargs):
-        
-        if verbosity is None:
-            self.stdout = AbstractObject.stdout
+    def debug_mode(self) -> bool:
+        """
+        Check if debug mode is enabled.
+
+        Returns
+        -------
+        bool
+            True if current verbosity is set to DEBUG
+        """
+        return self.stdout.verbosity == Verbosity.DEBUG
+
+    def set_verbosity(
+        self,
+        verbosity: Optional[Union[int, str, Verbosity]]
+    ) -> None:
+        """
+        Change output verbosity level.
+
+        This method detaches the instance from the class-level output handler
+        and creates a new instance-specific handler with the specified verbosity.
+
+        Parameters
+        ----------
+        verbosity : Optional[Union[int, str, Verbosity]]
+            New verbosity level. If None, uses class default.
+
+        Examples
+        --------
+        >>> obj = AbstractObject()
+        >>> obj.set_verbosity("DEBUG")
+        >>> obj.debug_mode
+        True
+
+        Raises
+        ------
+        ValueError
+            If verbosity level is invalid
+        """
+        if verbosity is not None:
+            # Create new VerbosePrint instance
+            # VerbosePrint constructor validates verbosity
+            self._stdout = VerbosePrint(verbosity)
         else:
-            self.stdout = VerbosePrint(verbosity)
-    
-    def set_verbosity(self, verbosity:Optional[Union[int, str]]):
+            # Use class-level stdout
+            self._stdout = None
+
+    @classmethod
+    def set_default_verbosity(
+        cls,
+        verbosity: Union[int, str, Verbosity]
+    ) -> None:
+        """
+        Set default verbosity for all new instances.
+
+        This method changes the class-level default verbosity which affects
+        all instances using the class-level handler.
+
+        Parameters
+        ----------
+        verbosity : Union[int, str, Verbosity]
+            New default verbosity level
+
+        Examples
+        --------
+        >>> AbstractObject.set_default_verbosity("DEBUG")
+        >>> obj = AbstractObject()  # Will use DEBUG level
+        """
+        cls._class_stdout = VerbosePrint(verbosity)
+
+    def copy_verbosity_from(self, other: AbstractObject) -> None:
+        """
+        Copy verbosity settings from another instance.
+
+        Parameters
+        ----------
+        other : AbstractObject
+            Instance to copy verbosity from
+
+        Examples
+        --------
+        >>> obj1 = AbstractObject(verbosity="DEBUG")
+        >>> obj2 = AbstractObject()
+        >>> obj2.copy_verbosity_from(obj1)
+        >>> obj2.debug_mode
+        True
+        """
+        if not hasattr(other, '_stdout') or other._stdout is None:
+            # Other instance uses class-level stdout
+            self._stdout = None
+        else:
+            # Copy instance-level stdout
+            self._stdout = other.stdout.copy()
+
+    def __getstate__(self) -> dict:
+        """
+        Support for pickling.
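+
+        Only the verbosity level of an instance-specific handler is
+        stored; ``__setstate__`` rebuilds the handler from it.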
+ + Returns + ------- + dict + State dictionary for pickling + """ + state = self.__dict__.copy() + if hasattr(self, '_stdout') and self._stdout is not None: + # Only store verbosity level if instance has custom stdout + state['_verbosity_level'] = self._stdout.verbosity + state.pop('_stdout', None) + return state + + def __setstate__(self, state: dict) -> None: """ - Change the verbosity of the current class. This will detach the class's standard output - from the centrally-managed standard output. + Support for unpickling. + + Parameters + ---------- + state : dict + State dictionary from pickling """ - self.stdout = VerbosePrint(verbosity) \ No newline at end of file + verbosity_level = state.pop('_verbosity_level', None) + self.__dict__.update(state) + + # Recreate VerbosePrint instance if needed + if verbosity_level is not None: + self._stdout = VerbosePrint(verbosity_level) + else: + self._stdout = None \ No newline at end of file diff --git a/quickstats/core/configuration.py b/quickstats/core/configuration.py index f3d165117ff31c395ffc87aa6fc9af78add6e541..171c93b0c1486d81f9c84d6d171d0311fd24a89d 100644 --- a/quickstats/core/configuration.py +++ b/quickstats/core/configuration.py @@ -16,7 +16,7 @@ from .decorators import semistaticmethod from .metaclasses import MergeAnnotationsMeta from .type_validation import get_type_validator, get_type_hint_str from .typing import NOTSET -from quickstats.utils.string_utils import format_dict_to_string +from quickstats.utils.string_utils import format_aligned_dict __all__ = ['ConfigComponent', 'ConfigScheme', 'ConfigFile', 'ConfigurableObject', 'ConfigUnit'] @@ -326,7 +326,7 @@ class ConfigComponent: def has_default(self) -> bool: return self._default is not NOTSET or self._default_factory is not NOTSET - def get_explain_text(self, indent_level: int = 0, indent_size: int = 4, line_break: int = 100, + def get_explain_text(self, indent_level: int = 0, indent_size: int = 4, linebreak: int = 100, with_default:bool=True, with_description:bool=True) -> str: """Generates and returns help text for the configuration component.""" indent = " " * indent_level * indent_size @@ -337,7 +337,7 @@ class ConfigComponent: if _is_configscheme_class(self.dtype): scheme_text = self.dtype.get_explain_text(indent_level=indent_level, indent_size=indent_size, - line_break=line_break) + linebreak=linebreak) return f"{indent}{name}\n{scheme_text}" components = { "Type": get_type_hint_str(self.dtype), @@ -348,14 +348,14 @@ class ConfigComponent: if with_description: components['Description'] = self.description or "No description provided" left_margin = indent_level * indent_size + 2 - attributes_text = format_dict_to_string(components, left_margin=left_margin, line_break=line_break) + attributes_text = format_aligned_dict(components, left_margin=left_margin, linebreak=linebreak) # TODO: display level? 
return f"{indent}{name}\n{attributes_text}" - def explain(self, line_break: int = 100, + def explain(self, linebreak: int = 100, with_default:bool=True, with_description:bool=True) -> None: - print(self.get_explain_text(line_break=line_break, + print(self.get_explain_text(linebreak=linebreak, with_default=with_default, with_description=with_description)) @@ -462,7 +462,7 @@ class ConfigScheme(AbstractObject, metaclass=MergeConfigAnnotationsMeta): @semistaticmethod def get_explain_text(self, components:Optional[List[str]]=None, - indent_level: int = 0, indent_size: int = 4, line_break: int = 100, + indent_level: int = 0, indent_size: int = 4, linebreak: int = 100, with_default:bool=True, with_description:bool=True) -> str: if inspect.isclass(self): title = self.__name__ @@ -480,18 +480,18 @@ class ConfigScheme(AbstractObject, metaclass=MergeConfigAnnotationsMeta): continue explain_text += component.get_explain_text(indent_level=indent_level+1, indent_size=indent_size, - line_break=line_break, + linebreak=linebreak, with_default=with_default, with_description=with_description) return explain_text @semistaticmethod def explain(self, components:Optional[List[str]]=None, - line_break: int = 100, + linebreak: int = 100, with_default:bool=True, with_description:bool=True) -> None: print(self.get_explain_text(components=components, - line_break=line_break, + linebreak=linebreak, with_default=with_default, with_description=with_description)) @@ -711,7 +711,7 @@ class ConfigurableObject(AbstractObject, metaclass=ConfigurableObjectMeta): @semistaticmethod def get_explain_text(self, names:Optional[Union[str , List[str]]]=None, - line_break: int = 100, + linebreak: int = 100, with_default:bool=True, with_description:bool=True) -> str: if inspect.isclass(self): @@ -736,19 +736,19 @@ class ConfigurableObject(AbstractObject, metaclass=ConfigurableObjectMeta): config = configs[name] components = required_components.get(name, None) explain_text += config.get_explain_text(components=components, - line_break=line_break, + linebreak=linebreak, with_default=with_default, with_description=with_description) return explain_text @semistaticmethod def explain_config(self, names:Optional[Union[str, List[str]]]=None, - line_break: int = 100, + linebreak: int = 100, with_default:bool=True, with_description:bool=True): print(self.get_explain_text(names, - line_break=line_break, + linebreak=linebreak, with_default=with_default, with_description=with_description)) diff --git a/quickstats/core/constraints.py b/quickstats/core/constraints.py index 4849e3fbadfc1078a80bed70529d43e45d30f335..6ba187fb411b25c104d1f81b9bb304e335adbd44 100644 --- a/quickstats/core/constraints.py +++ b/quickstats/core/constraints.py @@ -1,37 +1,103 @@ +""" +Value constraints for validating numerical and choice-based conditions. + +This module provides a set of constraint classes for enforcing value bounds +and valid choices on objects that support comparison operations. +""" + +from __future__ import annotations + +from typing import ( + Any, Set, TypeVar, Hashable, Union, Optional, + Collection, AbstractSet +) + +T = TypeVar('T', bound=Hashable) +Number = Union[int, float] + + class BaseConstraint: + """ + Base class for all constraints. + + Provides common functionality for constraint checking and representation. + """ - def __call__(self, obj): + def __call__(self, obj: Any) -> bool: + """ + Check if object satisfies constraint. 
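+
+        The base implementation accepts any object; subclasses override
+        this to enforce actual conditions.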
+ + Parameters + ---------- + obj : Any + Object to validate + + Returns + ------- + bool + True if constraint is satisfied + """ return True - def __repr__(self): - className = self.__class__.__name__ - attributes = ", ".join(f"{key}={value!r}" for key, value in self.__dict__.items() if not key.startswith('_')) - return f"{className}({attributes})" + def __repr__(self) -> str: + """Generate string representation of constraint.""" + attrs = ','.join( + f"{k}={v!r}" + for k, v in self.__dict__.items() + if not k.startswith('_') + ) + return f"{self.__class__.__name__}({attrs})" + + def __eq__(self, other: object) -> bool: + """Check equality with another constraint.""" + return isinstance(other, type(self)) - def __eq__(self): - return type(self) == type(other) + def __hash__(self) -> int: + """Generate hash based on public attributes.""" + return hash(tuple( + v for k, v in self.__dict__.items() + if not k.startswith('_') + )) @classmethod - def get_label(cls): + def get_label(cls) -> str: + """Get constraint class name.""" return cls.__name__ + class RangeConstraint(BaseConstraint): """ - Restricts the range of an object + Constraint enforcing value bounds. + + Parameters + ---------- + vmin : Number + Minimum value of the range + vmax : Number + Maximum value of the range + lbound : bool, optional + If True, vmin is inclusive, by default True + rbound : bool, optional + If True, vmax is inclusive, by default True + + Raises + ------ + ValueError + If vmin > vmax or bounds are not boolean """ - def __init__(self, vmin, vmax, lbound=True, rbound=True): - """ - Args: - vmin: Minimum value of the range. - vmax: Maximum value of the range. - lbound: If True, vmin is inclusive. Defaults to True. - rbound: If True, vmax is inclusive. Defaults to True. - """ + + def __init__( + self, + vmin: Number, + vmax: Number, + lbound: bool = True, + rbound: bool = True + ) -> None: if vmin > vmax: raise ValueError("vmin must be less than or equal to vmax") if not isinstance(lbound, bool) or not isinstance(rbound, bool): raise ValueError("lbound and rbound must be boolean values") - + self.vmin = vmin self.vmax = vmax self.lbound = lbound @@ -39,95 +105,226 @@ class RangeConstraint(BaseConstraint): self._lopt = '__le__' if lbound else '__lt__' self._ropt = '__ge__' if rbound else '__gt__' - def __call__(self, obj): + def __call__(self, obj: Any) -> bool: + """ + Check if object falls within range. 
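+
+        The check is delegated to the object's rich comparison methods,
+        so any comparable type is supported.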
+ + Parameters + ---------- + obj : Any + Object to validate + + Returns + ------- + bool + True if object is within range + + Raises + ------ + ValueError + If object doesn't support required comparison operations + """ if not hasattr(obj, self._lopt) or not hasattr(obj, self._ropt): - raise ValueError(f"{obj} does not support '{self._lopt}' and '{self._ropt}', required for this range constraint.") - - lopt_result = getattr(obj, self._lopt)(self.vmin) - ropt_result = getattr(obj, self._ropt)(self.vmax) - return lopt_result and ropt_result + raise ValueError( + f"Object does not support required comparison operators: " + f"{self._lopt} and {self._ropt}" + ) - def __eq__(self, other): + return ( + getattr(obj, self._lopt)(self.vmin) and + getattr(obj, self._ropt)(self.vmax) + ) + + def __eq__(self, other: object) -> bool: + """Check equality with another range constraint.""" if not isinstance(other, RangeConstraint): return False - return (self.vmin, self.vmax, self.lbound, self.rbound) == (other.vmin, other.vmax, other.lbound, other.rbound) + return ( + self.vmin == other.vmin and + self.vmax == other.vmax and + self.lbound == other.lbound and + self.rbound == other.rbound + ) + + def __hash__(self) -> int: + """Generate hash based on constraint parameters.""" + return hash((self.vmin, self.vmax, self.lbound, self.rbound)) + class MinConstraint(BaseConstraint): """ - Restricts the minimum value of an object + Constraint enforcing minimum value. + + Parameters + ---------- + vmin : Number + Minimum allowable value + inclusive : bool, optional + If True, vmin is included in valid range, by default True + + Raises + ------ + ValueError + If inclusive is not boolean """ - def __init__(self, vmin, inclusive=True): - """ - Args: - vmin: The minimum allowable value. - inclusive: If True, vmin is included in the valid range. Defaults to True. - """ + + def __init__(self, vmin: Number, inclusive: bool = True) -> None: if not isinstance(inclusive, bool): raise ValueError("inclusive must be a boolean value") - + self.vmin = vmin self.inclusive = inclusive self._opt = '__le__' if inclusive else '__lt__' - def __call__(self, obj): + def __call__(self, obj: Any) -> bool: + """ + Check if object meets minimum constraint. + + Parameters + ---------- + obj : Any + Object to validate + + Returns + ------- + bool + True if object meets minimum requirement + + Raises + ------ + ValueError + If object doesn't support required comparison operation + """ if not hasattr(obj, self._opt): - raise ValueError(f"{obj} does not support the '{self._opt}' comparison operator required for this minimum constraint.") - + raise ValueError( + f"Object does not support required comparison operator: {self._opt}" + ) return getattr(obj, self._opt)(self.vmin) - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + """Check equality with another minimum constraint.""" if not isinstance(other, MinConstraint): return False - return (self.vmin, self.inclusive) == (other.vmin, other.inclusive) + return self.vmin == other.vmin and self.inclusive == other.inclusive + + def __hash__(self) -> int: + """Generate hash based on constraint parameters.""" + return hash((self.vmin, self.inclusive)) class MaxConstraint(BaseConstraint): """ - Restricts the maximum value of an object + Constraint enforcing maximum value. 
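+
+    Mirror of ``MinConstraint``: validation is performed against ``vmax``
+    rather than ``vmin``.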
+
+    Parameters
+    ----------
+    vmax : Number
+        Maximum allowable value
+    inclusive : bool, optional
+        If True, vmax is included in valid range, by default True
+
+    Raises
+    ------
+    ValueError
+        If inclusive is not boolean
     """
-    def __init__(self, vmax, inclusive=True):
-        """
-        Args:
-            vmax: The maximum allowable value.
-            inclusive: If True, vmax is included in the valid range. Defaults to True.
-        """
+
+    def __init__(self, vmax: Number, inclusive: bool = True) -> None:
         if not isinstance(inclusive, bool):
             raise ValueError("inclusive must be a boolean value")
-        
+
         self.vmax = vmax
         self.inclusive = inclusive
         self._opt = '__ge__' if inclusive else '__gt__'
 
-    def __call__(self, obj):
+    def __call__(self, obj: Any) -> bool:
+        """
+        Check if object meets maximum constraint.
+
+        Parameters
+        ----------
+        obj : Any
+            Object to validate
+
+        Returns
+        -------
+        bool
+            True if object meets maximum requirement
+
+        Raises
+        ------
+        ValueError
+            If object doesn't support required comparison operation
+        """
         if not hasattr(obj, self._opt):
-            raise ValueError(f"{obj} does not support the '{self._opt}' comparison operator required for this maximum constraint.")
-        
+            raise ValueError(
+                f"Object does not support required comparison operator: {self._opt}"
+            )
         return getattr(obj, self._opt)(self.vmax)
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
+        """Check equality with another maximum constraint."""
         if not isinstance(other, MaxConstraint):
             return False
-        return (self.vmax, self.inclusive) == (other.vmax, other.inclusive)
+        return self.vmax == other.vmax and self.inclusive == other.inclusive
+
+    def __hash__(self) -> int:
+        """Generate hash based on constraint parameters."""
+        return hash((self.vmax, self.inclusive))
 
 
 class ChoiceConstraint(BaseConstraint):
     """
-    Restricts the value of an object to be among a given set of choices
+    Constraint restricting values to a set of choices.
+
+    Parameters
+    ----------
+    *choices : T
+        Allowable choices
+
+    Examples
+    --------
+    >>> constraint = ChoiceConstraint('red', 'green', 'blue')
+    >>> constraint('red')
+    True
+    >>> constraint('yellow')
+    Traceback (most recent call last):
+        ...
+    ValueError: Value not in allowed choices: ['blue', 'green', 'red']
     """
-    def __init__(self, *choices):
-        """
-        Args:
-            *choices: A variable-length list of allowable choices.
+
+    def __init__(self, *choices: T) -> None:
+        self.choices: AbstractSet[T] = frozenset(choices)
+
+    def __call__(self, obj: T) -> bool:
         """
-        self.choices = set(choices)
+        Check if object is an allowed choice.
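+
+        Unlike the range constraints, a disallowed value raises a
+        ``ValueError`` rather than returning ``False``.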
+ + Parameters + ---------- + obj : T + Object to validate - def __call__(self, obj): + Returns + ------- + bool + True if object is an allowed choice + + Raises + ------ + ValueError + If object is not among allowed choices + """ if obj not in self.choices: - raise ValueError(f"{obj} is not one of the allowed choices: {self.choices}") + raise ValueError( + f"Value not in allowed choices: {sorted(self.choices)}" + ) return True - def __eq__(self, other): + def __eq__(self, other: object) -> bool: + """Check equality with another choice constraint.""" if not isinstance(other, ChoiceConstraint): return False - # Note: Direct set comparison - return self.choices == other.choices \ No newline at end of file + return self.choices == other.choices + + def __hash__(self) -> int: + """Generate hash based on choices.""" + return hash(self.choices) \ No newline at end of file diff --git a/quickstats/core/decorators.py b/quickstats/core/decorators.py index 5faa941a64fa409f6301028b2f293ff9eae869eb..0a0105153f098a9c0ed4376be6454b9bdaeccdcd 100644 --- a/quickstats/core/decorators.py +++ b/quickstats/core/decorators.py @@ -1,253 +1,290 @@ -from typing import Optional, get_type_hints -from functools import partial, wraps -from dataclasses import dataclass, field, fields, Field -import sys -import time -import inspect -import importlib - -from .type_validation import check_type, get_type_hint_str -from .typing import NOTSET +""" +Function and class decorators for enhanced functionality. -__all__ = ["semistaticmethod", "cls_method_timer", "timer", "type_check", "strongdataclass"] +This module provides decorators for type checking, timing, and dataclass enhancements. +""" -class semistaticmethod(object): - """ - Descriptor to allow a staticmethod inside a class to use 'self' when called from an instance. +from __future__ import annotations - This custom descriptor class `semistaticmethod` enables a static method defined inside a class - to access the instance (`self`) when called from an instance, similar to how regular instance - methods can access the instance attributes. By default, static methods do not have access to - the instance and can only access the class-level attributes. +__all__ = [ + "semistaticmethod", + "cls_method_timer", + "timer", + "type_check", + "strongdataclass" +] - Note: - When defining a static method using this descriptor, it should be used like a regular method - within the class definition. It will work as a normal static method when called from the class, - and when called from an instance, it will receive the instance as the first argument. +import inspect +import time +from dataclasses import dataclass, fields +from functools import partial, wraps +from typing import ( + Any, Callable, Optional, Type, TypeVar, Union, + get_type_hints, cast, Generic +) - Args: - callable (function): The original static method defined within the class. +from .type_validation import check_type, get_type_hint_str +from .typing import NOTSET - Returns: - callable: A callable object that behaves like a static method but can also access the instance. 
- """ - def __init__(self, callable): - self.f = callable +T = TypeVar('T') +F = TypeVar('F', bound=Callable[..., Any]) - def __get__(self, obj, type=None): - if (obj is None) and (type is not None): - return partial(self.f, type) - if (obj is not None): - return partial(self.f, obj) - return self.f - @property - def __func__(self): - return self.f +class semistaticmethod(Generic[F]): + """ + Descriptor for static methods that can access instance when called from instance. + + Parameters + ---------- + func : Callable + The function to be converted into a semi-static method -def cls_method_timer(func): + Examples + -------- + >>> class MyClass: + ... @semistaticmethod + ... def my_method(self_or_cls): + ... return self_or_cls + ... + >>> MyClass.my_method() # Returns class + >>> obj = MyClass() + >>> obj.my_method() # Returns instance """ - Decorator function to measure the execution time of a class method. - - The `cls_method_timer` decorator function can be applied to any class method to automatically measure - the execution time of the method. When the decorated method is called, it records the start and end - times, calculates the time interval, and prints a message with the method name and the execution time. - Args: - func (callable): The class method to be decorated. + def __init__(self, func: F) -> None: + self.func = func - Returns: - callable: The wrapped function with timing functionality. + def __get__( + self, + obj: Optional[Any], + cls: Optional[Type[Any]] = None + ) -> Callable[..., Any]: + if obj is None and cls is not None: + return partial(self.func, cls) + if obj is not None: + return partial(self.func, obj) + return self.func - Example: - class MyClass: - @cls_method_timer - def my_method(self, n): - # Some time-consuming computation here - result = sum(range(n)) - return result + @property + def __func__(self) -> F: + """Get the original function.""" + return self.func - obj = MyClass() - obj.my_method(1000000) # The decorated method will print the execution time - # Output: "Task MyClass::my_method executed in 0.006 s" - Note: - The `cls_method_timer` function should be used as a decorator when defining a class method. - When the decorated method is called, it will print the execution time to the console. +class hybridproperty(Generic[T]): """ - def wrapper(self, *args, **kwargs): - """ - Wrapper function to measure the execution time of the class method. - - This wrapper function records the start time before calling the original method, then calls - the original method, and finally calculates and prints the execution time. - - Args: - self: The instance of the class. - *args: Variable-length argument list. - **kwargs: Keyword arguments. + Decorator for properties that work at both class and instance level. + + This decorator allows defining different behaviors for when the property + is accessed at the class level versus the instance level. - Returns: - The result returned by the original method. 
- """ + Parameters + ---------- + fcls : Callable + Function to call for class-level access + finst : Optional[Callable] + Function to call for instance-level access + If None, uses class-level implementation + """ + + def __init__( + self, + fcls: Callable[..., T], + finst: Optional[Callable[..., T]] = None + ) -> None: + self.fcls = fcls + self.finst = finst or fcls + + def __get__( + self, + instance: Optional[Any], + cls: Optional[Type] = None + ) -> T: + if instance is None and cls is not None: + return self.fcls(cls) + if instance is not None: + return self.finst(instance) + raise TypeError("Cannot access property from neither instance nor class") + + def instance(self, finst: Callable[..., T]) -> hybridproperty[T]: + """Decorator to set the instance-level implementation.""" + return type(self)(self.fcls, finst) + + +def cls_method_timer(func: F) -> F: + """ + Time execution of class methods. + + Parameters + ---------- + func : Callable + The class method to time + + Returns + ------- + Callable + Wrapped method that prints execution time + + Examples + -------- + >>> class MyClass: + ... @cls_method_timer + ... def my_method(self): + ... time.sleep(1) + """ + + @wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: t1 = time.time() result = func(self, *args, **kwargs) t2 = time.time() + method_name = f"{type(self).__name__}::{func.__name__}" - self.stdout.info(f'Task {method_name!r} executed in {(t2 - t1):.3f} s') + self.stdout.info( + f'Task {method_name!r} executed in {(t2 - t1):.3f} s' + ) return result - return wrapper + return cast(F, wrapper) -class timer: - """ - Context manager class for measuring the execution time of a code block. - - Example: - with timer() as t: - # Perform some time-consuming task here - time.sleep(2) - - print("Elapsed time:", t.interval) # outputs: "Elapsed time: 2.0 seconds" +class timer: + """ + Context manager for timing code blocks. + + Measures both real time and CPU time elapsed. + + Examples + -------- + >>> with timer() as t: + ... time.sleep(1) + >>> print(f"{t.real_time_elapsed:.1f}s") + 1.0s """ - def __enter__(self): - """ - Records the start time when entering the context. - Returns: - timer: The timer instance itself. - """ + def __init__(self) -> None: + self.start_real: float = 0.0 + self.start_cpu: float = 0.0 + self.end_real: float = 0.0 + self.end_cpu: float = 0.0 + self.interval: float = 0.0 + self.real_time_elapsed: float = 0.0 + self.cpu_time_elapsed: float = 0.0 + + def __enter__(self) -> timer: + """Start timing.""" self.start_real = time.time() self.start_cpu = time.process_time() return self - def __exit__(self, *args): - """ - Calculates the time interval when exiting the context. - - Args: - *args: Variable-length argument list. - - Returns: - None - """ + def __exit__(self, *args: Any) -> None: + """Stop timing and compute intervals.""" self.end_cpu = time.process_time() self.end_real = time.time() self.interval = self.end_real - self.start_real self.real_time_elapsed = self.interval self.cpu_time_elapsed = self.end_cpu - self.start_cpu -def type_check(func): - """ - A decorator to enforce type checking on function arguments based on their type hints. +def type_check(func: F) -> F: + """ + Decorator for runtime type checking of function arguments. + Parameters ---------- func : Callable - The function to be decorated. - + Function to add type checking to + Returns ------- Callable - The decorated function with type checking. 
- + Wrapped function that validates argument types + Raises ------ TypeError - If an argument does not match its type hint. - + If an argument's type doesn't match its annotation + Examples -------- >>> @type_check - ... def my_function(a: int, b: str, c, *args, **kwargs): - ... print(a, b, c, args, kwargs) - ... - >>> my_function(10, "hello", 20, 30, 40, key="value") - 10 hello 20 (30, 40) {'key': 'value'} - >>> my_function(10, 20, "hello") - Traceback (most recent call last): - ... - TypeError: Type check failed for the function "my_function". Argument "b" must be of type str, but got int. + ... def greet(name: str, count: int) -> str: + ... return name * count + >>> greet("hi", 3) + 'hihihi' + >>> greet("hi", "3") # Raises TypeError """ + @wraps(func) - def wrapper(*args, **kwargs): - # Retrieve the function signature + def wrapper(*args: Any, **kwargs: Any) -> Any: sig = inspect.signature(func) - bound_args = sig.bind(*args, **kwargs) - bound_args.apply_defaults() + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() - # Check each argument against its type hint - for name, value in bound_args.arguments.items(): - if ((name not in sig.parameters) or - (sig.parameters[name].annotation == sig.parameters[name].empty)): + # Check each argument against its type hint + for name, value in bound.arguments.items(): + param = sig.parameters.get(name) + if param is None or param.annotation is param.empty: continue - type_hint = sig.parameters[name].annotation - if not check_type(value, type_hint): - type_hint_str = get_type_hint_str(type_hint) - raise TypeError(f'Type check failed for the function "{func.__qualname__}". ' - f'Argument "{name}" must be of type {type_hint_str}, ' - f'but got {type(value).__name__}.') - + + if not check_type(value, param.annotation): + type_hint_str = get_type_hint_str(param.annotation) + raise TypeError( + f'Type check failed for function "{func.__qualname__}". ' + f'Argument "{name}" must be of type {type_hint_str}, ' + f'got {type(value).__name__}' + ) + return func(*args, **kwargs) - - return wrapper - -def strongdataclass(cls): - """ - A decorator to create a dataclass with strong type checking. - - Parameters - ---------- - cls : type - The class to be decorated. - - Returns - ------- - type - The decorated class with strong type checking. + + return cast(F, wrapper) + + +def strongdataclass( + cls: Optional[Type[T]] = None, + *, + init: bool = True, + repr: bool = True, + eq: bool = True, + order: bool = False, + unsafe_hash: bool = False, + frozen: bool = False, +) -> Union[Callable[[Type[T]], Type[T]], Type[T]]: """ - cls = dataclass(cls) - type_hints = get_type_hints(cls) - - for field in fields(cls): - private_name = f"_{field.name}" - public_name = field.name - type_hint = type_hints.get(field.name, NOTSET) - - # Define the getter - def getter(self, private_name=private_name): - return getattr(self, private_name) - - def setter(self, value, private_name=private_name, type_hint=type_hint): - if (type_hint is not NOTSET) and (not check_type(value, type_hint)): - public_name_ = private_name.strip("_") - type_hint_str = get_type_hint_str(type_hint) - raise TypeError(f'`{public_name_}` expects type {type_hint_str}, ' - f'got {type(value).__name__}') - setattr(self, private_name, value) - - setattr(cls, public_name, property(getter, setter)) + Create a dataclass with runtime type checking. - return cls - -def strongdataclass(cls=None, *args, **kwargs): - """ - A decorator to create a dataclass with strong type checking. 
- Parameters ---------- - cls : type - The class to be decorated. - + cls : Optional[Type] + Class to decorate + init, repr, eq, order, unsafe_hash, frozen : bool + Standard dataclass parameters + Returns ------- - type - The decorated class with strong type checking. + Union[Callable[[Type], Type], Type] + Decorated class with type checking + + Examples + -------- + >>> @strongdataclass + ... class Person: + ... name: str + ... age: int + >>> p = Person("Alice", 30) # OK + >>> p = Person("Alice", "30") # Raises TypeError """ - def wrap(cls): - - cls = dataclass(cls, *args, **kwargs) + + def wrap(cls: Type[T]) -> Type[T]: + cls = dataclass( + cls, + init=init, + repr=repr, + eq=eq, + order=order, + unsafe_hash=unsafe_hash, + frozen=frozen + ) type_hints = get_type_hints(cls) @@ -256,20 +293,32 @@ def strongdataclass(cls=None, *args, **kwargs): public_name = field.name type_hint = type_hints.get(field.name, NOTSET) - # Define the getter - def getter(self, private_name=private_name): + def getter( + self: Any, + private_name: str = private_name + ) -> Any: return getattr(self, private_name) - def setter(self, value, private_name=private_name, type_hint=type_hint): - if (type_hint is not NOTSET) and (not check_type(value, type_hint)): - public_name_ = private_name.strip("_") + def setter( + self: Any, + value: Any, + private_name: str = private_name, + type_hint: Any = type_hint + ) -> None: + if ( + type_hint is not NOTSET and + not check_type(value, type_hint) + ): + public_name = private_name.lstrip('_') type_hint_str = get_type_hint_str(type_hint) - raise TypeError(f'`{public_name_}` expects type {type_hint_str}, ' - f'got {type(value).__name__}') + raise TypeError( + f'`{public_name}` expects type {type_hint_str}, ' + f'got {type(value).__name__}' + ) setattr(self, private_name, value) - + setattr(cls, public_name, property(getter, setter)) - + return cls if cls is None: diff --git a/quickstats/core/enums.py b/quickstats/core/enums.py index 26338291d1acd789e847d603cef1000deac52ae6..53466a7bbad7d1292423f65021f121f0901cbe17 100644 --- a/quickstats/core/enums.py +++ b/quickstats/core/enums.py @@ -1,224 +1,375 @@ -from typing import Any, Optional, Union, List, Dict +from __future__ import annotations + +import sys from enum import Enum +from typing import ( + Any, Optional, Union, List, Dict, TypeVar, Type, ClassVar, + cast, NoReturn +) + +# Handle Python version differences for TypeAlias +if sys.version_info >= (3, 10): + from typing import TypeAlias +else: + from typing_extensions import TypeAlias + +__all__ = ["CaseInsensitiveStrEnum", "GeneralEnum", "DescriptiveEnum"] -__all__ = ["GeneralEnum", "DescriptiveEnum", "CaseInsensitiveStrEnum"] +# Type definitions +T = TypeVar('T', bound='GeneralEnum') +EnumValue: TypeAlias = Union[int, str, 'GeneralEnum'] class CaseInsensitiveStrEnum(str, Enum): + """ + String enumeration that supports case-insensitive comparison. + + This class extends the standard string Enum to allow case-insensitive + matching when looking up enum members. + + Examples + -------- + >>> class Format(CaseInsensitiveStrEnum): + ... JSON = "json" + ... 
XML = "xml" + >>> Format("JSON") == Format.JSON + True + >>> Format("json") == Format.JSON + True + """ + @classmethod - def _missing_(cls, value): - value = value.lower() + def _missing_(cls, value: Any) -> Optional[CaseInsensitiveStrEnum]: + """Handle missing enum values with case-insensitive matching.""" + if not isinstance(value, str): + return None + + value_lower = value.lower() for member in cls: - if member.lower() == value: + if member.lower() == value_lower: return member return None + class GeneralEnum(Enum): """ - Extended Enum class with additional parsing and lookup functionalities. - - Args: - expr (Optional[Union[int, str, GeneralEnum]], optional): The expression to parse into an enum member. - This can be an integer, a string representing the enum member name (case-insensitive), - or an existing `GeneralEnum` instance. Defaults to None. - - Returns: - GeneralEnum or None: The corresponding enum member if the expression is valid and matches any enum - member or alias. Returns None if expr is None. - - Example: - class MyEnum(GeneralEnum): - OPTION_A = 1 - OPTION_B = 2 - - # Parsing from string - option = MyEnum.parse("option_a") # Returns MyEnum.OPTION_A - - # Alias handling - MyEnum.__aliases__ = {"alias_a": "option_a"} - alias_option = MyEnum.parse("alias_a") # Returns MyEnum.OPTION_A - - # Retrieving enum member by attribute value - option_with_value = MyEnum.get_member_by_attribute("value", 2) # Returns MyEnum.OPTION_B + Enhanced enumeration with parsing and lookup capabilities. + + This class extends the standard Enum to provide additional functionality + like flexible parsing, aliasing, and attribute-based lookups. + + Attributes + ---------- + __aliases__ : ClassVar[Dict[str, str]] + Class-level mapping of alias names to enum member names + + Methods + ------- + parse(value: Optional[Union[int, str, GeneralEnum]]) -> Optional[GeneralEnum] + Convert a string, int, or enum value to an enum member + get_members() -> List[str] + Get list of all member names + get_member_by_attribute(attribute: str, value: Any) -> Optional[GeneralEnum] + Find enum member by attribute value + + Examples + -------- + >>> class Status(GeneralEnum): + ... ACTIVE = 1 + ... INACTIVE = 2 + ... __aliases__ = {"enabled": "active"} + >>> Status.parse("active") + <Status.ACTIVE: 1> + >>> Status.parse("enabled") + <Status.ACTIVE: 1> """ - __aliases__ = { - } + __aliases__: ClassVar[Dict[str, str]] = {} - @classmethod - def _missing_(cls, value: Any): - return cls.parse(value) - - @classmethod - def on_parse_exception(cls, expr: str): + def __eq__(self, other: Any) -> bool: + """ + Compare enum members with support for parsing string/int values. + + Parameters + ---------- + other : Any + Value to compare against + + Returns + ------- + bool + True if values are equal """ - Raises a runtime error for invalid options in the parse() method. + if isinstance(other, Enum): + if not isinstance(other, type(self)): + return False + return self.value == other.value + + try: + other_member = self.parse(other) + return self.value == other_member.value if other_member is not None else False + except ValueError: + return False - Args: - expr (str): The expression representing the invalid option. + def __hash__(self) -> int: + """Generate hash based on enum value.""" + return hash(self.value) - Raises: - RuntimeError: If the expression is not a valid enum member or alias, providing the list of - allowed options. 
+ @classmethod + def _missing_(cls: Type[T], value: Any) -> Optional[T]: + """Handle missing enum values by attempting to parse them.""" + try: + return cls.parse(value) + except ValueError: + return None + + @classmethod + def on_parse_exception(cls, expr: str) -> NoReturn: """ - classname = cls.__name__ - option_text = ", ".join(cls.get_members()) - raise RuntimeError(f'Invalid option "{expr}" for the enum class "{classname}" ' - f'(allowed options: {option_text}).') + Handle invalid parse attempts with detailed error message. + Parameters + ---------- + expr : str + The invalid expression that failed to parse + + Raises + ------ + ValueError + Always raised with detailed error message + """ + valid_options = ", ".join(cls.get_members()) + raise ValueError( + f'Invalid option "{expr}" for enum class "{cls.__name__}". ' + f'Allowed options: {valid_options}' + ) + @classmethod - def parse(cls, value: Optional[Union[int, str, "GeneralEnum"]] = None) -> Optional["GeneralEnum"]: + def parse( + cls: Type[T], + value: Optional[EnumValue] = None + ) -> Optional[T]: """ - Parses a given expression into the corresponding enum member or its alias. - - Args: - value (Optional[Union[int, str, GeneralEnum]], optional): The expression to parse into an enum member. - This can be an integer representing the enum value, a string representing the enum member - name (case-insensitive), or an existing `GeneralEnum` instance. Defaults to None. - - Returns: - GeneralEnum or None: The corresponding enum member if the expression is valid and matches any enum - member or alias. Returns None if expr is None. - - Raises: - RuntimeError: If the expression is not valid or does not match any enum member or alias. + Parse a value into the corresponding enum member. + + Parameters + ---------- + value : Optional[Union[int, str, GeneralEnum]] + Value to convert to enum member. Can be: + - None (returns None) + - Integer (matched against enum values) + - String (matched against member names or aliases) + - Enum instance (returned if matching type) + + Returns + ------- + Optional[T] + Corresponding enum member or None if input is None + + Raises + ------ + ValueError + If value cannot be parsed to a valid enum member + + Examples + -------- + >>> class Color(GeneralEnum): + ... RED = 1 + ... BLUE = 2 + >>> Color.parse("red") + <Color.RED: 1> + >>> Color.parse(2) + <Color.BLUE: 2> """ - if isinstance(value, str): - expr = value.strip().lower() - members_map = cls.get_members_map() - if expr in members_map: - return members_map[expr] - aliases_map = cls.get_aliases_map() - if expr in aliases_map: - return cls.parse(aliases_map[expr]) - cls.on_parse_exception(value) if value is None: return None + if isinstance(value, cls): return value + + if isinstance(value, str): + value_lower = value.strip().lower() + members_map = cls.get_members_map() + if value_lower in members_map: + return members_map[value_lower] + + aliases_map = cls.get_aliases_map() + if value_lower in aliases_map: + return cls.parse(aliases_map[value_lower]) + + cls.on_parse_exception(value) + values_map = cls.get_values_map() if value in values_map: return values_map[value] - cls.on_parse_exception(value) + cls.on_parse_exception(str(value)) + @classmethod def get_members(cls) -> List[str]: """ - Returns a list of member names in lowercase. - - Returns: - list[str]: A list of member names in lowercase. + Get list of all member names in lowercase. 
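+
+        A quick doctest sketch (``Color`` is a hypothetical enum, not part
+        of the library):
+
+        >>> class Color(GeneralEnum):
+        ...     RED = 1
+        ...     BLUE = 2
+        >>> Color.get_members()
+        ['red', 'blue']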
+
+        Returns
+        -------
+        List[str]
+            Member names in lowercase
         """
-        return [i.lower() for i in cls.__members__]
+        return [name.lower() for name in cls.__members__]
 
     @classmethod
-    def get_members_map(cls) -> Dict[str, "GeneralEnum"]:
+    def get_members_map(cls: Type[T]) -> Dict[str, T]:
         """
-        Returns a dictionary mapping lowercase member names to enum members.
-        
-        Returns:
-            dict[str, GeneralEnum]: A dictionary mapping lowercase member names to enum members.
+        Get mapping of lowercase names to enum members.
+
+        Returns
+        -------
+        Dict[str, T]
+            Mapping of {lowercase_name: enum_member}
         """
-        return {k.lower(): v for k, v in cls.__members__.items()}
+        return {
+            name.lower(): member
+            for name, member in cls.__members__.items()
+        }
 
     @classmethod
-    def get_values_map(cls) -> Dict[str, "GeneralEnum"]:
+    def get_values_map(cls: Type[T]) -> Dict[Any, T]:
         """
-        Returns a dictionary mapping enum values to enum member.
-        
-        Returns:
-            dict[str, GeneralEnum]: A dictionary mapping enum values to enum member.
+        Get mapping of enum values to members.
+
+        Returns
+        -------
+        Dict[Any, T]
+            Mapping of {enum_value: enum_member}
         """
-        return {v.value: v for k, v in cls.__members__.items()}
+        return {
+            member.value: member
+            for member in cls.__members__.values()
+        }
 
     @classmethod
-    def get_aliases_map(cls) -> Dict[str, "GeneralEnum"]:
+    def get_aliases_map(cls) -> Dict[str, str]:
         """
-        Returns a dictionary mapping lowercase aliases to enum members.
-        
-        Returns:
-            dict[str, GeneralEnum]: A dictionary mapping lowercase aliases to enum members.
+        Get mapping of lowercase aliases to member names.
+
+        Returns
+        -------
+        Dict[str, str]
+            Mapping of {lowercase_alias: member_name}
         """
-        return {k.lower(): v for k, v in cls.__aliases__.items()}
+        return {
+            alias.lower(): target.lower()
+            for alias, target in cls.__aliases__.items()
+        }
 
     @classmethod
     def has_member(cls, name: str) -> bool:
         """
-        Checks if an enum member exists with the given name (case-insensitive).
-        
-        Args:
-            name (str): The name of the enum member to check.
-            
-        Returns:
-            bool: True if the enum member exists, False otherwise.
+        Check if member exists with given name.
+
+        Parameters
+        ----------
+        name : str
+            Name to check (case-insensitive)
+
+        Returns
+        -------
+        bool
+            True if member exists
         """
         return name.lower() in cls.get_members()
 
     @classmethod
-    def get_member_by_attribute(cls, attribute: str, value: Any) -> Optional["GeneralEnum"]:
+    def get_member_by_attribute(
+        cls: Type[T],
+        attribute: str,
+        value: Any
+    ) -> Optional[T]:
         """
-        Returns the enum member that has the specified attribute with the given value.
-        
-        Args:
-            attribute (str): The name of the attribute to search for.
-            value (Any): The value of the attribute to match.
-            
-        Returns:
-            GeneralEnum or None: The enum member that matches the attribute value. Returns None if not found.
+        Find enum member by attribute value.
+
+        Parameters
+        ----------
+        attribute : str
+            Name of attribute to check
+        value : Any
+            Value to match
+
+        Returns
+        -------
+        Optional[T]
+            Matching enum member or None if not found
+
+        Examples
+        --------
+        >>> class Format(GeneralEnum):
+        ...     JSON = (1, "JavaScript Object Notation")
+        ...     def __init__(self, id, desc):
+        ...         self.id = id
+        ...         self.desc = desc
+        >>> Format.get_member_by_attribute('id', 1)
+        <Format.JSON: (1, 'JavaScript Object Notation')>
         """
-        members = cls.__members__
-        return next((x for x in members.values() if getattr(x, attribute) == value), None)
+        for member in cls.__members__.values():
+            if hasattr(member, attribute):
+                if getattr(member, attribute) == value:
+                    return member
+        return None
+
 
 class DescriptiveEnum(GeneralEnum):
     """
-    Enum class with support for additional descriptions for each enum member.
-    
-    Attributes:
-        description (str): The additional description associated with each enum member.
-        
-    Example:
-        class MyEnum(DescriptiveEnum):
-            OPTION_A = 1, "This is option A"
-            OPTION_B = 2, "This is option B"
-            
-        # Accessing enum member and its description
-        print(MyEnum.OPTION_A)  # Output: MyEnum.OPTION_A
-        print(MyEnum.OPTION_A.description)  # Output: "This is option A"
-        
-        # Parsing from string with description in on_parse_exception
-        option = MyEnum.parse("option_b")  # Returns MyEnum.OPTION_B
-        print(option.description)  # Output: "This is option B"
+    Enumeration with additional descriptive text for each member.
+
+    This class extends GeneralEnum to add a description field to each
+    enum member, useful for human-readable labels and documentation.
+
+    Parameters
+    ----------
+    value : int
+        The enum value
+    description : str, optional
+        Human-readable description of the enum member
+
+    Examples
+    --------
+    >>> class Status(DescriptiveEnum):
+    ...     ACTIVE = (1, "Item is currently active")
+    ...     INACTIVE = (2, "Item has been deactivated")
+    >>> Status.ACTIVE.description
+    'Item is currently active'
     """
-    def __new__(cls, value: int, description: str = ""):
-        """
-        Creates a new `DescriptiveEnum` instance with the given value and an optional description.
-        
-        Args:
-            value (int): The value associated with the enum member.
-            description (str, optional): An additional description for the enum member. Defaults to "".
+    description: str
 
-        Returns:
-            DescriptiveEnum: The newly created `DescriptiveEnum` instance with the given value and description.
-        """
+    def __new__(cls, value: int, description: str = "") -> DescriptiveEnum:
+        """Create new enum member with description."""
         obj = object.__new__(cls)
         obj._value_ = value
         obj.description = description
         return obj
 
     @classmethod
-    def on_parse_exception(cls, expr: str):
+    def on_parse_exception(cls, expr: str) -> NoReturn:
         """
-        Raises a runtime error when an invalid option is passed to the parse() method.
-        
-        Args:
-            expr (str): The expression representing the invalid option.
-            
-        Raises:
-            RuntimeError: If the expression is not a valid enum member or alias, providing the list of
-                allowed options along with their descriptions.
+        Handle invalid parse attempts with descriptive error.
+
+        Parameters
+        ----------
+        expr : str
+            The invalid expression
+
+        Raises
+        ------
+        ValueError
+            With detailed error including available options and descriptions
+        """
-        classname = cls.__name__
-        enum_descriptions = "".join([f'    {key.lower()} - {val.description}\n' \
-                                     for key, val in cls.__members__.items()])
-        raise RuntimeError(f'Invalid option "{expr}" for the enum class "{classname}"\n'
-                           f'    Allowed options:\n{enum_descriptions}')
\ No newline at end of file
+        class_name = cls.__name__
+        descriptions = "\n".join(
+            f"    {name.lower()} - {member.description}"
+            for name, member in cls.__members__.items()
+        )
+
+        raise ValueError(
+            f'Invalid option "{expr}" for enum class "{class_name}"\n'
+            f'Available options:\n{descriptions}'
+        )
\ No newline at end of file
diff --git a/quickstats/core/flexible_dumper.py b/quickstats/core/flexible_dumper.py
index 78a5b82cc5c3a6a64a11db8927fdbaaf60277559..09254ffd02a06381bf16d5acdbc4ddd9d8873d1e 100644
--- a/quickstats/core/flexible_dumper.py
+++ b/quickstats/core/flexible_dumper.py
@@ -1,177 +1,342 @@
-from typing import Any
+"""
+Flexible text representation generator for Python objects.
 
-class FlexibleDumper:
+This module provides a customizable dumper for creating human-readable
+text representations of nested Python data structures with support for
+various formatting options and limits.
+"""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, List, Optional, Union, Sequence, Dict
+from collections.abc import Mapping, Iterable
+
+
+@dataclass
+class DumperConfig:
     """
-    A flexible dumper that creates a text representation of a python object.
+    Configuration settings for FlexibleDumper.
+
+    Parameters
+    ----------
+    item_indent : str
+        Indentation string for regular items
+    list_indent : str
+        Indentation string for list items
+    separator : str
+        Separator between keys and values
+    skip_str : str
+        String to indicate truncation
+    indent_sequence_on_key : bool
+        Whether to indent sequences under their keys
+    max_depth : int
+        Maximum nesting depth (-1 for unlimited)
+    max_iteration : int
+        Maximum number of sequence items (-1 for unlimited)
+    max_item : int
+        Maximum number of mapping items (-1 for unlimited)
+    max_line : int
+        Maximum number of output lines (-1 for unlimited)
+    max_len : int
+        Maximum line length (-1 for unlimited)
+
+    Raises
+    ------
+    ValueError
+        If item_indent and list_indent have different lengths
     """
-    @property
-    def item_indent(self) -> str:
-        return self._item_indent
-
-    @property
-    def list_indent(self) -> str:
-        return self._list_indent
-
-    @property
-    def separator(self) -> str:
-        return self._separator
-
-    @property
-    def skip_str(self) -> str:
-        return self._skip_str
-
-    @property
-    def indent_sequence_on_key(self) -> bool:
-        return self._indent_sequence_on_key
-
-    @property
-    def max_depth(self) -> int:
-        return self._max_depth
-
-    @property
-    def max_iteration(self) -> int:
-        return self._max_iteration
-
-    @property
-    def max_item(self) -> int:
-        return self._max_item
-
-    @property
-    def max_line(self) -> int:
-        return self._max_line
-
-    @property
-    def max_len(self) -> int:
-        return self._max_len
-
-    def __init__(self, item_indent: str = '  ', list_indent: str = '- ',
-                 separator: str = ': ', skip_str: str = '...',
-                 indent_sequence_on_key: bool = True, max_depth: int = -1,
-                 max_iteration: int = -1, max_item: int = -1, max_line: int = -1, max_len: int = -1):
-        if len(item_indent) != len(list_indent):
-            raise ValueError('Length of `item_indent` must equal that of `list_indent`.')
+    item_indent: str = "  "
+    list_indent: str = "- "
+    separator: str = ": "
+    skip_str: str = "..."
+    indent_sequence_on_key: bool = True
+    max_depth: int = -1
+    max_iteration: int = -1
+    max_item: int = -1
+    max_line: int = -1
+    max_len: int = -1
+
+    def __post_init__(self) -> None:
+        """Validate configuration after initialization."""
+        if len(self.item_indent) != len(self.list_indent):
+            raise ValueError("Length of item_indent must equal that of list_indent")
-
-        self._item_indent = item_indent
-        self._list_indent = list_indent
-        self._separator = separator
-        self._skip_str = skip_str
-        self._indent_sequence_on_key = indent_sequence_on_key
-        self._max_depth = max_depth
-        self._max_iteration = max_iteration
-        self._max_item = max_item
-        self._max_line = max_line
-        self._max_len = max_len
+
+        for limit_name in ('max_depth', 'max_iteration', 'max_item', 'max_line', 'max_len'):
+            value = getattr(self, limit_name)
+            if not isinstance(value, int):
+                raise TypeError(f"{limit_name} must be an integer")
+            if value < -1:
+                raise ValueError(f"{limit_name} must be >= -1")
+
+
+class FlexibleDumper:
+    """
+    A flexible dumper that creates a text representation of Python objects.
+
+    This class provides customizable formatting for nested data structures
+    with support for depth limits, iteration limits, and line length limits.
+
+    Parameters
+    ----------
+    **config_kwargs
+        Keyword arguments passed to DumperConfig
+
+    Examples
+    --------
+    >>> dumper = FlexibleDumper(max_depth=1)
+    >>> data = {"a": [1, 2, {"b": 3}], "c": 4}
+    >>> print(dumper.dump(data))
+    a:
+      - ...
+    c: 4
+    """
+
+    def __init__(self, **config_kwargs) -> None:
+        self.config = DumperConfig(**config_kwargs)
         self.reset()
 
-    def configure(self, **kwargs):
-        for key, value in kwargs.items():
-            if not hasattr(self, f'_{key}'):
-                raise KeyError(f'Invalid property: {key}')
-            setattr(self, f'_{key}', value)
-
-    def reset(self):
-        """Reset the state of the dumper."""
-        self.terminate = False
-        self.terminate_depth = False
-        self.lines = []
-
-    def get_indent_str(self, depth: int, iteration: int = 0, list_degree: int = 0) -> str:
-        """Get the indentation string for the current depth and iteration."""
+    def reset(self) -> None:
+        """Reset the internal state of the dumper."""
+        self._terminate: bool = False
+        self._terminate_depth: bool = False
+        self._lines: List[str] = []
+
+    def _get_indent(self, depth: int, iteration: int = 0, list_degree: int = 0) -> str:
+        """
+        Generate indentation string for current context.
+
+        Parameters
+        ----------
+        depth : int
+            Current nesting depth
+        iteration : int, optional
+            Current iteration number within a sequence
+        list_degree : int, optional
+            Number of nested lists at current position
+
+        Returns
+        -------
+        str
+            Formatted indentation string
+        """
         if list_degree > 0:
             if iteration == 0:
-                return (depth - list_degree) * self.item_indent + list_degree * self.list_indent
-            return (depth - 1) * self.item_indent + self.list_indent
-        return depth * self.item_indent
+                return (
+                    self.config.item_indent * (depth - list_degree) +
+                    self.config.list_indent * list_degree
+                )
+            return (
+                self.config.item_indent * (depth - 1) +
+                self.config.list_indent
+            )
+        return self.config.item_indent * depth
+
+    def _add_line(
+        self,
+        text: str,
+        depth: int,
+        iteration: int = 0,
+        item: int = 0,
+        list_degree: int = 0
+    ) -> None:
+        """
+        Add a formatted line to the output buffer.
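+
+        The snippet below is a sketch of the truncation behaviour, going
+        through the public ``dump`` entry point (``max_iteration=2`` is an
+        arbitrary choice):
+
+        >>> dumper = FlexibleDumper(max_iteration=2)
+        >>> print(dumper.dump([1, 2, 3, 4]))
+        - 1
+        - 2
+        ...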
-    def add_line(self, text:str, depth: int, iteration: int = 0, item: int = 0, list_degree: int = 0):
-        """Add a line to the output with proper indentation and length checks."""
-        if self.terminate:
+        Parameters
+        ----------
+        text : str
+            Text content to add
+        depth : int
+            Current nesting depth
+        iteration : int, optional
+            Current iteration number
+        item : int, optional
+            Current item number
+        list_degree : int, optional
+            Number of nested lists
+        """
+        if self._terminate:
             return
-        if self.max_line > 0 and len(self.lines) >= self.max_line:
-            self.lines.append(self.skip_str)
-            self.terminate = True
+
+        if self.config.max_line > 0 and len(self._lines) >= self.config.max_line:
+            self._lines.append(self.config.skip_str)
+            self._terminate = True
             return
 
-        if self.terminate_depth and (depth <= self.max_depth):
-            self.terminate_depth = False
+        if self._terminate_depth and depth <= self.config.max_depth:
+            self._terminate_depth = False
+
+        # Handle multiline text
         if '\n' in text:
-            subtexts = text.split('\n')
-            # only need to keep the list indicator for the first line
-            self.add_line(subtexts[0], depth, iteration=iteration, item=item, list_degree=list_degree)
-            for subtext in subtexts[1:]:
-                self.add_line(subtext, depth, iteration=iteration, item=item, list_degree=0)
+            # only the first line keeps the list indicator
+            first, *rest = text.split('\n')
+            self._add_line(first, depth, iteration, item, list_degree)
+            for subtext in rest:
+                self._add_line(subtext, depth, iteration, item, 0)
             return
-
-        indent = self.get_indent_str(depth, iteration, list_degree)
+
+        indent = self._get_indent(depth, iteration, list_degree)
         line = indent + text
-
-        if self.max_len > 0 and len(line) > self.max_len:
-            line = line[:max(len(indent), self.max_len)] + self.skip_str
-
-        if self.max_depth > 0 and depth > self.max_depth:
-            if not self.terminate_depth:
-                line = indent + self.skip_str
-                self.terminate_depth = True
-            else:
-                return
-
-        if self.max_iteration > 0:
-            if iteration == self.max_iteration:
-                if list_degree > 0:
-                    line = self.get_indent_str(depth, iteration, 0) + self.skip_str
+
+        # Apply line length limit
+        if self.config.max_len > 0 and len(line) > self.config.max_len:
+            line = line[:max(len(indent), self.config.max_len)] + self.config.skip_str
+
+        # Handle depth limit
+        if self.config.max_depth > 0:
+            if depth > self.config.max_depth:
+                if not self._terminate_depth:
+                    line = indent + self.config.skip_str
+                    self._terminate_depth = True
                 else:
-                    line = indent + self.skip_str
-            elif iteration > self.max_iteration:
+                    return
+
+        # Handle iteration limit: entries past the limit are dropped, while
+        # the entry at the limit becomes a single skip marker that is still
+        # appended below
+        if self.config.max_iteration > 0:
+            if iteration > self.config.max_iteration:
                 return
+            if iteration == self.config.max_iteration:
+                base_indent = (
+                    self._get_indent(depth, iteration, 0)
+                    if list_degree > 0
+                    else indent
+                )
+                line = base_indent + self.config.skip_str
 
-        if self.max_item > 0:
-            if item == self.max_item:
-                line = indent + self.skip_str
-            elif item > self.max_item:
+        # Handle item limit: same pattern as the iteration limit
+        if self.config.max_item > 0:
+            if item > self.config.max_item:
                 return
-
-        self.lines.append(line)
+            if item == self.config.max_item:
+                line = indent + self.config.skip_str
 
-    def dump(self, data: Any, depth: int = 0, iteration: int = 0, list_degree: int = 0, root: bool = True) -> str:
-        """Dump the provided data structure to a formatted string."""
+        self._lines.append(line)
+
+    def dump(
+        self,
+        data: Any,
+        depth: int = 0,
+        iteration: int = 0,
+        list_degree: int = 0,
+        root: bool = True
+    ) -> Optional[str]:
+        """
+        Generate a formatted string representation of the data.
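+
+        A minimal usage sketch (the data values are arbitrary):
+
+        >>> dumper = FlexibleDumper()
+        >>> print(dumper.dump({'x': [1, 2]}))
+        x:
+          - 1
+          - 2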
+
+        Parameters
+        ----------
+        data : Any
+            Data structure to dump
+        depth : int, optional
+            Current nesting depth
+        iteration : int, optional
+            Current iteration number
+        list_degree : int, optional
+            Number of nested lists
+        root : bool, optional
+            Whether this is the root call
+
+        Returns
+        -------
+        Optional[str]
+            Formatted string representation if root call,
+            None for recursive calls
+
+        Notes
+        -----
+        The method handles nested data structures recursively,
+        applying configured limits and formatting rules.
+        """
         if root:
             self.reset()
 
-        if self.terminate_depth and (depth <= self.max_depth):
-            self.terminate_depth = False
-
-        if self.terminate or self.terminate_depth:
-            return
-
+        if self._terminate or (
+            self._terminate_depth and not (depth <= self.config.max_depth)
+        ):
+            return None
+
+        # Normalize list degree for nested iterations
         if iteration > 0 and list_degree > 1:
             list_degree = 1
+
+        # Handle mappings (dict-like objects)
+        if isinstance(data, Mapping) and data:
+            self._dump_mapping(data, depth, iteration, list_degree)
 
-        if isinstance(data, dict) and data:
-            for item, (key, value) in enumerate(data.items()):
-                if self.max_item > 0 and item == self.max_item:
-                    self.add_line('', depth, iteration=iteration, item=item, list_degree=list_degree)
-                    break
-                if isinstance(value, (dict, list, tuple)) and value:
-                    text = f'{key}{self.separator}'
-                    self.add_line(text, depth, iteration=iteration, list_degree=list_degree)
-                    if isinstance(value, (list, tuple)) and not self.indent_sequence_on_key:
-                        self.dump(value, depth, iteration=iteration, list_degree=0, root=False)
-                    else:
-                        self.dump(value, depth + 1, iteration=iteration, list_degree=0, root=False)
-                else:
-                    text = f'{key}{self.separator}{value}'
-                    self.add_line(text, depth, iteration=iteration, list_degree=list_degree)
-                list_degree = 0
+        # Handle sequences (list-like objects)
         elif isinstance(data, (list, tuple)) and data:
-            for subiteration, data_i in enumerate(data):
-                if self.max_iteration > 0 and subiteration == self.max_iteration:
-                    self.add_line('', depth, iteration=subiteration, list_degree=list_degree)
-                    break
-                self.dump(data_i, depth + 1, iteration=subiteration,
-                          list_degree=list_degree + 1, root=False)
-        else:
-            self.add_line(f'{data}', depth, iteration=iteration, list_degree=list_degree)
+            self._dump_sequence(data, depth, iteration, list_degree)
+        # Handle primitive values
+        else:
+            self._add_line(f"{data}", depth, iteration, list_degree=list_degree)
+
         if root:
-            return '\n'.join(self.lines)
\ No newline at end of file
+            return '\n'.join(self._lines)
+        return None
+
+    def _dump_mapping(
+        self,
+        data: Mapping[Any, Any],
+        depth: int,
+        iteration: int,
+        list_degree: int
+    ) -> None:
+        """Handle dumping of mapping types."""
+        for item_num, (key, value) in enumerate(data.items()):
+            if self.config.max_item > 0 and item_num >= self.config.max_item:
+                if item_num == self.config.max_item:
+                    self._add_line(
+                        '',
+                        depth,
+                        iteration=iteration,
+                        item=item_num,
+                        list_degree=list_degree
+                    )
+                break
+
+            # str/bytes are Sequences too; treat them as scalars so string
+            # values stay on the same line as their key
+            if (isinstance(value, (Mapping, Sequence))
+                    and not isinstance(value, (str, bytes)) and value):
+                text = f"{key}{self.config.separator}"
+                self._add_line(text, depth, iteration=iteration, list_degree=list_degree)
+
+                next_depth = depth + (
+                    0 if isinstance(value, (list, tuple))
+                    and not self.config.indent_sequence_on_key
+                    else 1
+                )
+                self.dump(value, next_depth, iteration, 0, root=False)
+            else:
+                text = f"{key}{self.config.separator}{value}"
+                self._add_line(text, depth, iteration=iteration, list_degree=list_degree)
+
+            list_degree = 0
+
+    def _dump_sequence(
+        self,
+ data: Sequence[Any], + depth: int, + iteration: int, + list_degree: int + ) -> None: + """Handle dumping of sequence types.""" + for idx, item in enumerate(data): + if self.config.max_iteration > 0 and idx >= self.config.max_iteration: + if idx == self.config.max_iteration: + self._add_line( + '', + depth, + iteration=idx, + list_degree=list_degree + ) + break + + self.dump( + item, + depth + 1, + iteration=idx, + list_degree=list_degree + 1, + root=False + ) \ No newline at end of file diff --git a/quickstats/core/io.py b/quickstats/core/io.py index 11a28074f318694f03831375bb30884f2b5eef80..3ac2a9530a00eebb5c1e8dc79bf6f4aa3c0ac5a0 100644 --- a/quickstats/core/io.py +++ b/quickstats/core/io.py @@ -1,3 +1,13 @@ +""" +Terminal output formatting and verbosity control. + +This module provides utilities for formatted console output with color support, +verbosity levels, and comparison formatting. Designed for minimal overhead +in performance-critical I/O operations. +""" + +from __future__ import annotations + import os import sys import time @@ -5,124 +15,113 @@ import difflib import logging import traceback import threading -from enum import Enum -from typing import Union, Optional +from typing import Dict, Union, Optional, ClassVar, Generator, TypeVar from functools import total_ordering from contextlib import contextmanager -__all__ = ['Verbosity', 'VerbosePrint', 'set_default_log_format'] - -text_color_map = { - None: '', - 'black': '\033[30m', - 'red': '\033[31m', - 'green': '\033[32m', - 'yellow': '\033[33m', - 'blue': '\033[34m', - 'magenta': '\033[35m', - 'cyan': '\033[36m', - 'white': '\033[37m', - 'bright black': '\033[30;1m', - 'bright red': '\033[31;1m', - 'bright green': '\033[32;1m', - 'bright yellow': '\033[33;1m', - 'bright blue': '\033[34;1m', - 'bright magenta': '\033[35;1m', - 'bright cyan': '\033[36;1m', - 'bright white': '\033[37;1m', - 'darkred': '\033[91m', - 'reset': '\033[0m', - 'okgreen': '\033[92m' -} - -def get_colored_text(text: str, color: str) -> str: - """ - Returns the text formatted with the specified color. +from .enums import DescriptiveEnum - Parameters - ---------- - text : str - The text to be colored. - color : str - The color to apply to the text. - - Returns - ------- - str - The input text with the specified color formatting. - """ - return f"{text_color_map[color]}{text}{text_color_map['reset']}" +__all__ = ['TextColors', 'Verbosity', 'VerbosePrint', 'set_default_log_format'] -def format_comparison_text(text_left: str, text_right: str, - equal_color: Optional[str] = None, - delete_color: str = "red", - insert_color: str = "green") -> tuple: - """ - Formats two texts for comparison with color coding for differences. +# Type aliases +VerbosityLevel = Union[int, 'Verbosity', str] +T = TypeVar('T') - Parameters - ---------- - text_left : str - The left text to compare. - text_right : str - The right text to compare. - equal_color : str, optional - The color for equal text. Default is None. - delete_color : str - The color for deleted text. Default is 'red'. - insert_color : str - The color for inserted text. Default is 'green'. - - Returns - ------- - tuple - A tuple containing the formatted left and right texts. 
- """ - codes = difflib.SequenceMatcher(a=text_left, b=text_right).get_opcodes() - s_left = "" - s_right = "" - for code in codes: - if code[0] == "equal": - s = get_colored_text(text_left[code[1]:code[2]], equal_color) - s_left += s - s_right += s - elif code[0] == "delete": - s_left += get_colored_text(text_left[code[1]:code[2]], delete_color) - elif code[0] == "insert": - s_right += get_colored_text(text_right[code[3]:code[4]], insert_color) - elif code[0] == "replace": - s_left += get_colored_text(text_left[code[1]:code[2]], delete_color) - s_right += get_colored_text(text_right[code[3]:code[4]], insert_color) - return s_left, s_right - -getThreads = True -getMultiprocessing = True -getProcesses = True +class TextColors: + """ANSI color codes for terminal output.""" + + CODES: ClassVar[Dict[str, str]] = { + # Standard colors + 'black': '\033[30m', + 'red': '\033[31m', + 'green': '\033[32m', + 'yellow': '\033[33m', + 'blue': '\033[34m', + 'magenta': '\033[35m', + 'cyan': '\033[36m', + 'white': '\033[37m', + # Bright colors + 'bright black': '\033[30;1m', + 'bright red': '\033[31;1m', + 'bright green': '\033[32;1m', + 'bright yellow': '\033[33;1m', + 'bright blue': '\033[34;1m', + 'bright magenta': '\033[35;1m', + 'bright cyan': '\033[36;1m', + 'bright white': '\033[37;1m', + # Special colors + 'darkred': '\033[91m', + 'okgreen': '\033[92m', + # Control + 'reset': '\033[0m', + } + + @classmethod + def colorize(cls, text: str, color: Optional[str]) -> str: + """Apply color formatting to text.""" + if not color: + return text + + color_code = cls.CODES.get(color) + if not color_code: + return text + + return f"{color_code}{text}{cls.CODES['reset']}" + + @classmethod + def format_comparison( + cls, + text_left: str, + text_right: str, + equal_color: Optional[str] = None, + delete_color: str = "red", + insert_color: str = "green" + ) -> tuple[str, str]: + """Format text comparison with color coding.""" + matcher = difflib.SequenceMatcher(a=text_left, b=text_right) + left_result = [] + right_result = [] + + for tag, i1, i2, j1, j2 in matcher.get_opcodes(): + if tag == "equal": + text = cls.colorize(text_left[i1:i2], equal_color) + left_result.append(text) + right_result.append(text) + elif tag == "delete": + left_result.append(cls.colorize(text_left[i1:i2], delete_color)) + elif tag == "insert": + right_result.append(cls.colorize(text_right[j1:j2], insert_color)) + elif tag == "replace": + left_result.append(cls.colorize(text_left[i1:i2], delete_color)) + right_result.append(cls.colorize(text_right[j1:j2], insert_color)) + + return "".join(left_result), "".join(right_result) @total_ordering -class Verbosity(Enum): +class Verbosity(DescriptiveEnum): """ - Enum for verbosity levels. - + Verbosity levels for output control. + Attributes ---------- SILENT : Verbosity - No output. + No output (level 100) CRITICAL : Verbosity - Critical errors. + Critical errors only (level 50) ERROR : Verbosity - Errors. + Errors and above (level 40) TIPS : Verbosity - Tips. + Tips and above (level 35) WARNING : Verbosity - Warnings. + Warnings and above (level 30) INFO : Verbosity - Information. + Information and above (level 20) DEBUG : Verbosity - Debugging information. + All output including debug (level 10) IGNORE : Verbosity - Ignore messages. 
+ Process all messages (level 0) """ + SILENT = (100, 'SILENT') CRITICAL = (50, 'CRITICAL') ERROR = (40, 'ERROR') @@ -132,269 +131,260 @@ class Verbosity(Enum): DEBUG = (10, 'DEBUG') IGNORE = (0, 'IGNORE') - def __new__(cls, value: int, levelname: str = ""): - obj = object.__new__(cls) - obj._value_ = value - obj.levelname = levelname - return obj - - def __lt__(self, other): - if self.__class__ is other.__class__: + def __lt__(self, other: Union[Verbosity, int, str]) -> bool: + if isinstance(other, type(self)): return self.value < other.value - elif isinstance(other, int): + if isinstance(other, int): return self.value < other - elif isinstance(other, str): - return self.value < getattr(self, other.upper()).value - return NotImplemented - - def __eq__(self, other): - if self.__class__ is other.__class__: - return self.value == other.value - elif isinstance(other, int): - return self.value == other - elif isinstance(other, str): - return self.value == getattr(self, other.upper()).value + if isinstance(other, str): + try: + other_level = getattr(self.__class__, other.upper()) + return self.value < other_level.value + except AttributeError: + return NotImplemented return NotImplemented - + + class VerbosePrint: """ - A class for managing verbose printing. - - Parameters + Configurable verbose printing with formatting support. + + Attributes ---------- - verbosity : Union[int, Verbosity, str], optional - The verbosity level. Default is Verbosity.INFO. - fmt : str, optional - The format string for messages. Default is None. - name : str, optional - The name for the logger. Default is ''. - msecfmt : str, optional - The format string for milliseconds. Default is None. - datefmt : str, optional - The date format string. Default is None. - - Methods - ------- - silent(text='', color=None, bare=False) - Silent print (no output). - tips(text='', color=None, bare=False) - Print tips. - info(text='', color=None, bare=False) - Print information. - warning(text='', color=None, bare=False) - Print warnings. - error(text='', color=None, bare=False) - Print errors. - critical(text='', color=None, bare=False) - Print critical errors. - debug(text='', color=None, bare=False) - Print debug information. - write(text='', color=None) - Write text with no formatting. - set_format(fmt=None) - Set the message format. - set_timefmt(datefmt=None, msecfmt=None) - Set the time format. - format_time() - Format the current time. 
+ FORMATS : ClassVar[Dict[str, str]] + Predefined format templates + DEFAULT_FORMAT : ClassVar[str] + Default message format + DEFAULT_DATEFORMAT : ClassVar[str] + Default date format + DEFAULT_MSECFORMAT : ClassVar[str] + Default millisecond format """ - FORMATS = { + FORMATS: ClassVar[Dict[str, str]] = { 'basic': '[%(levelname)s] %(message)s', 'detailed': '%(asctime)s | PID:%(process)d, TID:%(threadName)s | %(levelname)s | %(message)s' } - DEFAULT_FORMAT = FORMATS['basic'] - DEFAULT_DATEFORMAT = '%Y-%m-%d %H:%M:%S' - DEFAULT_MSECFORMAT = '%s.%03d' - ASCTIME_SEARCH = '%(asctime)' - - @property - def verbosity(self): - return self._verbosity - @verbosity.setter - def verbosity(self, val): - if isinstance(val, str): - try: - v = getattr(Verbosity, val.upper()) - except Exception: - raise ValueError(f"invalid verbosity level: {val}") - self._verbosity = v - else: - self._verbosity = val - - def __init__(self, verbosity: Union[int, Verbosity, str] = Verbosity.INFO, - fmt: Optional[str] = None, name: Optional[str] = '', - msecfmt: Optional[str] = None, - datefmt: Optional[str] = None): - self.verbosity = verbosity - self.set_format(fmt) - self.set_timefmt(datefmt, msecfmt) - self._name = name + DEFAULT_FORMAT: ClassVar[str] = FORMATS['basic'] + DEFAULT_DATEFORMAT: ClassVar[str] = '%Y-%m-%d %H:%M:%S' + DEFAULT_MSECFORMAT: ClassVar[str] = '%s.%03d' - def __call__(self, text: str, verbosity: Union[int, Verbosity] = Verbosity.INFO, - color: Optional[str] = None, bare: bool = False): + def __init__( + self, + verbosity: VerbosityLevel = Verbosity.INFO, + fmt: Optional[str] = None, + name: str = '', + msecfmt: Optional[str] = None, + datefmt: Optional[str] = None + ): """ - Print the text with the specified verbosity level and color. + Initialize VerbosePrint with custom configuration. Parameters ---------- - text : str - The text to print. - verbosity : Union[int, Verbosity], optional - The verbosity level. Default is Verbosity.INFO. - color : str, optional - The color to apply to the text. Default is None. - bare : bool, optional - If True, prints text without formatting. Default is False. 
+ verbosity : VerbosityLevel, optional + Initial verbosity level, by default Verbosity.INFO + fmt : str, optional + Message format template or format name ('basic' or 'detailed') + name : str, optional + Logger name for message formatting + msecfmt : str, optional + Millisecond format string, default is '%s.%03d' + datefmt : str, optional + Date format string (strftime format), default is '%Y-%m-%d %H:%M:%S' """ - if verbosity < self.verbosity: - return None - if color: - text = f"{text_color_map[color]}{text}{text_color_map['reset']}" + self._verbosity = Verbosity.parse(verbosity) + self._name = name + self.set_timefmt(datefmt, msecfmt) + self.set_format(fmt) + + @property + def verbosity(self) -> Verbosity: + """Current verbosity level.""" + return self._verbosity + + @verbosity.setter + def verbosity(self, value: VerbosityLevel) -> None: + """Set verbosity level.""" + self._verbosity = Verbosity.parse(value) + + def __call__( + self, + text: str, + verbosity: VerbosityLevel = Verbosity.INFO, + color: Optional[str] = None, + bare: bool = False + ) -> None: + """Print text with specified verbosity and formatting.""" + level = Verbosity.parse(verbosity) + if level < self.verbosity: + return + if not bare: - if hasattr(verbosity, 'levelname'): - levelname = verbosity.levelname - else: - levelname = f"Level {verbosity}" - if self._formatter.usesTime(): - asctime = self.format_time() - else: - asctime = None - if getThreads: - thread = threading.get_ident() - threadName = threading.current_thread().name - else: - thread = None - threadName = None - if getProcesses and hasattr(os, 'getpid'): - process = os.getpid() - else: - process = None - args = { - 'name': self._name, - 'message': text, - 'levelname': levelname, - 'asctime': asctime, - 'thread': thread, - 'threadName': threadName, - 'process': process - } - text = self._formatter._fmt % args + text = self._format_message(text, level) + + if color: + text = TextColors.colorize(text, color) + sys.stdout.write(f"{text}\n") - def __copy__(self): - """ - Create a shallow copy of the VerbosePrint instance. 
- """ - cls = self.__class__ - new_instance = cls.__new__(cls) - new_instance.__dict__.update(self.__dict__) - return new_instance + def _format_message(self, text: str, level: Verbosity) -> str: + """Format message with current configuration.""" + args = { + 'name': self._name, + 'message': text, + 'levelname': level.description, + } - def silent(self, text: str = '', color: Optional[str] = None, bare: bool = False): - pass + if self._needs_time: + args['asctime'] = self.format_time() - def tips(self, text: str = '', color: Optional[str] = None, bare: bool = False): - self.__call__(text, Verbosity.TIPS, color=color, bare=bare) - - def info(self, text: str = '', color: Optional[str] = None, bare: bool = False): - self.__call__(text, Verbosity.INFO, color=color, bare=bare) - - def warning(self, text: str = '', color: Optional[str] = None, bare: bool = False): - self.__call__(text, Verbosity.WARNING, color=color, bare=bare) + if self._needs_process: + args['process'] = os.getpid() if hasattr(os, 'getpid') else None + + if self._needs_thread and threading: + args['thread'] = threading.get_ident() + args['threadName'] = threading.current_thread().name + + return self._formatter._fmt % args + + def format_time(self) -> str: + """Format current time according to configuration.""" + current_time = time.time() + time_struct = self._formatter.converter(current_time) + base_time = time.strftime(self._datefmt, time_struct) - def error(self, text: str = '', color: Optional[str] = None, bare: bool = False): - self.__call__(text, Verbosity.ERROR, color=color, bare=bare) + if self._msecfmt: + msecs = int((current_time - int(current_time)) * 1000) + return self._msecfmt % (base_time, msecs) - def critical(self, text: str = '', color: Optional[str] = None, bare: bool = False): - self.__call__(text, Verbosity.CRITICAL, color=color, bare=bare) - - def debug(self, text: str = '', color: Optional[str] = None, bare: bool = False): - self.__call__(text, Verbosity.DEBUG, color=color, bare=bare) + return base_time - def write(self, text: str = '', color: Optional[str] = None): - self.__call__(text, Verbosity.SILENT, color=color, bare=True) - - def set_format(self, fmt: Optional[str] = None): + def set_format(self, fmt: Optional[str] = None) -> None: """ - Set the message format. + Set message format template. Parameters ---------- fmt : str, optional - The format string for messages. Default is None. + Format template or name ('basic' or 'detailed'). + If None, uses DEFAULT_FORMAT. """ - if fmt is None: - fmt = self.DEFAULT_FORMAT - elif fmt in self.FORMATS: + fmt = fmt or self.DEFAULT_FORMAT + if fmt in self.FORMATS: fmt = self.FORMATS[fmt] self._formatter = logging.Formatter(fmt) - def set_timefmt(self, datefmt: Optional[str] = None, msecfmt: Optional[str] = None): + self._needs_time = '%(asctime)' in fmt + self._needs_process = '%(process)' in fmt + self._needs_thread = any(key in fmt for key in ('%(thread)', '%(threadName)')) + + def set_timefmt( + self, + datefmt: Optional[str] = None, + msecfmt: Optional[str] = None + ) -> None: """ - Set the time format. + Set time format for timestamps. Parameters ---------- datefmt : str, optional - The date format string. Default is None. + Date format string (strftime format). If None, uses DEFAULT_DATEFORMAT msecfmt : str, optional - The format string for milliseconds. Default is None. 
-        """
-        if datefmt is None:
-            datefmt = self.DEFAULT_DATEFORMAT
-        if msecfmt is None:
-            msecfmt = self.DEFAULT_MSECFORMAT
-        self._datefmt = datefmt
-        self._msecfmt = msecfmt
+            Millisecond format string. If None, uses DEFAULT_MSECFORMAT.
+            Set to an empty string to disable milliseconds.
+
+        Examples
+        --------
+        >>> printer = VerbosePrint(fmt='detailed')
+        >>> printer.set_timefmt(datefmt='%H:%M:%S')  # Time only
+        >>> printer.info('Test')
+        12:34:56.123 | PID:1234, TID:MainThread | INFO | Test
 
-    def format_time(self) -> str:
+        >>> printer.set_timefmt(msecfmt='')  # Disable milliseconds
+        >>> printer.info('Test')
+        12:34:56 | PID:1234, TID:MainThread | INFO | Test
         """
-        Format the current time.
+        self._datefmt = datefmt or self.DEFAULT_DATEFORMAT
+        # explicit None check so that msecfmt='' disables milliseconds
+        # instead of silently falling back to the default
+        self._msecfmt = msecfmt if msecfmt is not None else self.DEFAULT_MSECFORMAT
+
+    # Convenience methods for different verbosity levels
+    def silent(self, text: str = '', color: Optional[str] = None, bare: bool = False) -> None:
+        """No output."""
+        pass
+
+    def debug(self, text: str = '', color: Optional[str] = None, bare: bool = False) -> None:
+        """Print debug message."""
+        self(text, Verbosity.DEBUG, color, bare)
+
+    def info(self, text: str = '', color: Optional[str] = None, bare: bool = False) -> None:
+        """Print info message."""
+        self(text, Verbosity.INFO, color, bare)
+
+    def warning(self, text: str = '', color: Optional[str] = None, bare: bool = False) -> None:
+        """Print warning message."""
+        self(text, Verbosity.WARNING, color, bare)
+
+    def error(self, text: str = '', color: Optional[str] = None, bare: bool = False) -> None:
+        """Print error message."""
+        self(text, Verbosity.ERROR, color, bare)
+
+    def critical(self, text: str = '', color: Optional[str] = None, bare: bool = False) -> None:
+        """Print critical message."""
+        self(text, Verbosity.CRITICAL, color, bare)
+
+    def tips(self, text: str = '', color: Optional[str] = None, bare: bool = False) -> None:
+        """Print tip message."""
+        self(text, Verbosity.TIPS, color, bare)
+
+    def write(self, text: str = '', color: Optional[str] = None) -> None:
+        """Write raw text."""
+        self(text, Verbosity.SILENT, color, True)
+
+    def copy(self) -> VerbosePrint:
+        """Create a copy of this printer."""
+        return self.__class__(
+            verbosity=self.verbosity,
+            fmt=self._formatter._fmt,
+            name=self._name,
+            msecfmt=self._msecfmt,
+            datefmt=self._datefmt
+        )
 
-        Returns
-        -------
-        str
-            The formatted current time string.
-        """
-        _ct = time.time()
-        ct = self._formatter.converter(_ct)
-        s = time.strftime(self._datefmt, ct)
-        if self._msecfmt:
-            msecs = int((_ct - int(_ct)) * 1000) + 0.0
-            s = self._msecfmt % (s, msecs)
-        return s
-
-    def copy(self):
-        return self.__copy__()
-
 @contextmanager
-def switch_verbosity(target: VerbosePrint, verbosity: Union[int, str]):
+def switch_verbosity(target: VerbosePrint, verbosity: VerbosityLevel) -> Generator[None, None, None]:
     """
-    Context manager to switch verbosity temporarily.
-
+    Temporarily change verbosity level.
+
     Parameters
     ----------
     target : VerbosePrint
-        The target VerbosePrint instance.
-    verbosity : Union[int, str]
-        The new verbosity level to set temporarily.
- - Yields - ------ - None + Printer to modify + verbosity : Union[int, str, Verbosity] + New verbosity level """ + original = target.verbosity try: - orig_verbosity = target.verbosity target.verbosity = verbosity yield except Exception: traceback.print_exc(file=sys.stdout) finally: - target.verbosity = orig_verbosity + target.verbosity = original + -def set_default_log_format(fmt:str='basic'): - if fmt in VerbosePrint.FORMATS: - fmt = VerbosePrint.FORMATS[fmt] - VerbosePrint.DEFAULT_FORMAT = fmt \ No newline at end of file +def set_default_log_format(fmt: str = 'basic') -> None: + """ + Set default format for new VerbosePrint instances. + + Parameters + ---------- + fmt : str + Format name or template + """ + VerbosePrint.DEFAULT_FORMAT = ( + VerbosePrint.FORMATS.get(fmt, fmt) + ) \ No newline at end of file diff --git a/quickstats/core/mappings.py b/quickstats/core/mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..71a13ff05ccaeabd257f367ba1ecbf5d280b178e --- /dev/null +++ b/quickstats/core/mappings.py @@ -0,0 +1,288 @@ +""" +Utilities for working with nested dictionaries, providing recursive update and merge capabilities. + +This module implements a custom dictionary class and helper functions for handling nested +dictionary structures with recursive update operations. +""" + +from __future__ import annotations + +from collections.abc import Mapping, MutableMapping +from copy import deepcopy +from typing import ( + Any, + Callable, + Iterable, + Optional, + Type, + TypeVar, + Union, +) + +T = TypeVar('T', bound=Mapping[str, Any]) + +def recursive_update( + target: MutableMapping[str, Any], + source: Mapping[str, Any], + / +) -> MutableMapping[str, Any]: + """ + Update a dictionary recursively with values from another dictionary. + + Parameters + ---------- + target : MutableMapping[str, Any] + The dictionary to be updated. + source : Mapping[str, Any] + The dictionary containing updates. + + Returns + ------- + MutableMapping[str, Any] + The updated dictionary (same object as target). + + Notes + ----- + This function modifies the target dictionary in-place. For nested dictionaries, + it performs a deep update rather than simply replacing the nested dictionary. + + Examples + -------- + >>> d1 = {'a': 1, 'b': {'c': 2}} + >>> d2 = {'b': {'d': 3}} + >>> recursive_update(d1, d2) + {'a': 1, 'b': {'c': 2, 'd': 3}} + """ + if not source: + return target + + for key, value in source.items(): + if ( + isinstance(value, Mapping) + and key in target + and isinstance(target[key], MutableMapping) + ): + recursive_update(target[key], value) + else: + target[key] = value + return target + + +def concatenate( + mappings: Iterable[Optional[Mapping[str, Any]]], + *, + copy: bool = False +) -> NestedDict: + """ + Concatenate multiple dictionaries recursively. + + Parameters + ---------- + mappings : Iterable[Optional[Mapping[str, Any]]] + An iterable of dictionaries to concatenate. None values are skipped. + copy : bool, optional + If True, create a deep copy of each dictionary before updating, by default False. + + Returns + ------- + NestedDict + A new NestedDict containing the concatenated result. 
+ + Examples + -------- + >>> d1 = {'a': 1, 'b': {'c': 2}} + >>> d2 = {'b': {'d': 3}, 'e': 4} + >>> concatenate([d1, d2]) + {'a': 1, 'b': {'c': 2, 'd': 3}, 'e': 4} + """ + result = NestedDict() + for mapping in mappings: + if not mapping: + continue + source = deepcopy(mapping) if copy else mapping + recursive_update(result, source) + return result + + +# Alias for convenience, keeping the same name for backward compatibility +concat = concatenate + + +def merge_classattr( + cls: Type, + attribute: str, + *, + copy: bool = False, + parse: Optional[Callable[[Any], Mapping[str, Any]]] = None +) -> NestedDict: + """ + Merge a class attribute from the class hierarchy using recursive update. + + Parameters + ---------- + cls : Type + The class whose MRO will be used to recursively update the attribute. + attribute : str + The name of the dictionary attribute to be updated. + copy : bool, optional + If True, create a deep copy of each attribute before updating, by default False. + parse : Callable[[Any], Mapping[str, Any]], optional + Function to transform the class attribute before merging. + Must return a Mapping. If None, no transformation is applied. + + Returns + ------- + NestedDict + The merged attribute dictionary after applying recursive updates. + + Raises + ------ + TypeError + If the parsed attribute is not a Mapping. + AttributeError + If the specified attribute doesn't exist and parse function is None. + + Examples + -------- + >>> class Base: + ... data = {'a': 1} + >>> class Child(Base): + ... data = {'b': 2} + >>> merge_classattr(Child, 'data') + {'a': 1, 'b': 2} + """ + result = NestedDict() + + for base_cls in reversed(cls.__mro__): + try: + base_data = getattr(base_cls, attribute) + except AttributeError: + continue + + if parse is not None: + try: + base_data = parse(base_data) + except Exception as e: + raise ValueError( + f"Failed to parse attribute '{attribute}' from {base_cls.__name__}" + ) from e + + if not isinstance(base_data, Mapping): + raise TypeError( + f"Attribute '{attribute}' in {base_cls.__name__} " + f"must be a Mapping, not {type(base_data).__name__}" + ) + + if copy: + base_data = deepcopy(base_data) + + recursive_update(result, base_data) + + return result + + +class NestedDict(dict): + """ + A dictionary subclass supporting recursive updates via operators. + + This class extends the built-in dict to provide recursive update operations + using the & and &= operators, similar to set operations but for nested dictionaries. + + Methods + ------- + merge(other) + Update the dictionary recursively with values from another mapping. + copy(deep=False) + Create a shallow or deep copy of the dictionary. + + Examples + -------- + >>> d1 = NestedDict({'a': 1, 'b': {'c': 2}}) + >>> d2 = {'b': {'d': 3}} + >>> d3 = d1 & d2 + >>> print(d3) + {'a': 1, 'b': {'c': 2, 'd': 3}} + """ + + def merge(self, other: Optional[Mapping[str, Any]]) -> None: + """ + Update the dictionary recursively with values from another mapping. + + Parameters + ---------- + other : Optional[Mapping[str, Any]] + The mapping containing updates. If None, no update is performed. + + Raises + ------ + TypeError + If other is not None and not a Mapping instance. + """ + if other is None: + return + + if not isinstance(other, Mapping): + raise TypeError( + f"Expected Mapping, got {type(other).__name__}" + ) + + recursive_update(self, other) + + def __and__(self, other: Optional[Mapping[str, Any]]) -> NestedDict: + """ + Create a new dictionary by recursively updating with another mapping. 
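+
+        For instance (a sketch; keys are arbitrary):
+
+        >>> base = NestedDict({'a': {'x': 1}})
+        >>> base & {'a': {'y': 2}}
+        {'a': {'x': 1, 'y': 2}}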
+
+        Parameters
+        ----------
+        other : Optional[Mapping[str, Any]]
+            The mapping containing updates.
+
+        Returns
+        -------
+        NestedDict
+            A new dictionary containing the merged result.
+
+        Raises
+        ------
+        TypeError
+            If other is not None and not a Mapping instance.
+        """
+        return concatenate([self, other], copy=True)
+
+    def __iand__(self, other: Optional[Mapping[str, Any]]) -> NestedDict:
+        """
+        Update the dictionary in-place recursively with another mapping.
+
+        Parameters
+        ----------
+        other : Optional[Mapping[str, Any]]
+            The mapping containing updates.
+
+        Returns
+        -------
+        NestedDict
+            Self, after applying the updates.
+
+        Raises
+        ------
+        TypeError
+            If other is not None and not a Mapping instance.
+        """
+        self.merge(other)
+        return self
+
+    def copy(self, deep: bool = False) -> NestedDict:
+        """
+        Create a copy of the dictionary.
+
+        Parameters
+        ----------
+        deep : bool, optional
+            If True, create a deep copy, by default False.
+
+        Returns
+        -------
+        NestedDict
+            A new dictionary containing the copied data.
+        """
+        return NestedDict(deepcopy(self) if deep else super().copy())
\ No newline at end of file
diff --git a/quickstats/core/parameters.py b/quickstats/core/parameters.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1226bf3a3ebbed41eaca4c2292fa935dea6cfb6
--- /dev/null
+++ b/quickstats/core/parameters.py
@@ -0,0 +1,200 @@
+from typing import Union, Optional, List, Any, Callable, Tuple, Type
+from functools import lru_cache
+import copy
+
+from quickstats.utils.string_utils import (
+    format_aligned_dict,
+    format_delimited_dict
+)
+from .type_validation import get_type_validator, get_annotation_str
+from .typing import NOTSET, NOTSETTYPE
+from .constraints import BaseConstraint
+
+__all__ = ['Parameter']
+
+def is_mutable(obj):
+    return obj.__class__.__hash__ is None
+
+def _deprecation_message(deprecated: Optional[Union[bool, str]] = None) -> Optional[str]:
+    if deprecated is None:
+        return None
+    if isinstance(deprecated, bool):
+        return 'deprecated' if deprecated else None
+    if isinstance(deprecated, str):
+        return deprecated
+    return f'{deprecated!r}'
+
+def _format_str(name: Union[str, NOTSETTYPE] = NOTSET,
+                default: Any = NOTSET,
+                default_factory: Union[Callable, NOTSETTYPE] = NOTSET,
+                dtype: type = Any,
+                tags: Optional[Tuple[str]] = None,
+                description: Optional[str] = None,
+                constraints: Optional[Tuple[BaseConstraint]] = None,
+                deprecated: Union[bool, str] = False,
+                aligned: bool = False,
+                linebreak: int = 100):
+    name = 'Parameter' if name is NOTSET else name
+    attribs = {}
+    if dtype is not Any:
+        attribs['dtype'] = get_annotation_str(dtype)
+    if default is not NOTSET:
+        attribs['default'] = f'{default!r}'
+    if default_factory is not NOTSET:
+        attribs['default_factory'] = f'{default_factory!r}'
+    if tags is not None:
+        attribs['tags'] = f'{tags!r}'
+    if constraints is not None:
+        attribs['constraints'] = f'{constraints!r}'
+    if description is not None:
+        attribs['description'] = description
+    message = _deprecation_message(deprecated)
+    if message is not None:
+        attribs['deprecation_message'] = message
+    if aligned:
+        attrib_str = format_aligned_dict(attribs, left_margin=4, linebreak=linebreak)
+        return f'{name}\n{attrib_str}'
+    # attribs values are already formatted strings, so interpolate them as-is
+    return f'{name}({", ".join(f"{key}={value}" for key, value in attribs.items())})'
+
+def Parameter(default: Any = NOTSET,
+              *,
+              default_factory: Union[Callable, NOTSETTYPE] = NOTSET,
+              tags: Optional[List[str]] = None,
+              description: Optional[str] = None,
+              constraints: Optional[List[BaseConstraint]] = None,
+              validate_default: bool = False,
+              deprecated: bool = False):
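+    """
+    Declare a parameter field; a thin factory around ``ParameterInfo``.
+
+    A minimal sketch of intended use (attribute name and values are
+    arbitrary):
+
+    >>> step_size = Parameter(1.0, description='step size')
+    >>> step_size.get_default()
+    1.0
+    """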
+    return ParameterInfo(default=default,
+                         default_factory=default_factory,
+                         tags=tags,
+                         description=description,
+                         constraints=constraints,
+                         validate_default=validate_default,
+                         deprecated=deprecated)
+
+class ParameterInfo:
+
+    __slots__ = (
+        'default',
+        'default_factory',
+        'name',
+        'annotation',
+        'tags',
+        'description',
+        'constraints',
+        'deprecated',
+        'validate_default',
+    )
+
+    def __init__(self, default: Any = NOTSET,
+                 *,
+                 default_factory: Union[Callable, NOTSETTYPE] = NOTSET,
+                 name: Union[str, NOTSETTYPE] = NOTSET,
+                 annotation: Type = Any,
+                 tags: Optional[Tuple[str]] = None,
+                 description: Optional[str] = None,
+                 constraints: Optional[Tuple[BaseConstraint]] = None,
+                 deprecated: Union[bool, str] = False,
+                 validate_default: bool = False):
+
+        if default is not NOTSET and default_factory is not NOTSET:
+            raise ValueError('cannot specify both default and default_factory')
+
+        # only reject mutables when a default is actually supplied
+        if default is not NOTSET and is_mutable(default):
+            raise ValueError('mutable default value is not allowed')
+
+        self.default = default
+        self.default_factory = default_factory
+        self.name = name
+        self.annotation = annotation
+        if (tags is not None) and not isinstance(tags, tuple):
+            tags = tuple(tags)
+        self.tags = tags
+        self.description = description
+        if (constraints is not None) and not isinstance(constraints, tuple):
+            constraints = tuple(constraints)
+        self.constraints = constraints
+        self.deprecated = deprecated
+        self.validate_default = validate_default
+
+    def __repr__(self) -> str:
+        return _format_str(name=self.name,
+                           default=self.default,
+                           default_factory=self.default_factory,
+                           dtype=self.annotation,
+                           tags=self.tags,
+                           description=self.description,
+                           constraints=self.constraints,
+                           deprecated=self.deprecated,
+                           aligned=False)
+
+    def __str__(self) -> str:
+        return _format_str(name=self.name,
+                           default=self.default,
+                           default_factory=self.default_factory,
+                           dtype=self.annotation,
+                           tags=self.tags,
+                           description=self.description,
+                           constraints=self.constraints,
+                           deprecated=self.deprecated,
+                           aligned=True)
+
+    def __set_name__(self, owner: type, name: str) -> None:
+        # the type machinery always calls this hook with (owner, name)
+        self.name = name
+
+    def __set_annotation__(self, annotation: Type) -> None:
+        self.annotation = annotation
+
+    @classmethod
+    def _repr(cls, label: str = 'Parameter',
+              default: Any = NOTSET,
+              default_factory: Union[Callable, NOTSETTYPE] = NOTSET,
+              dtype: Union[type, NOTSETTYPE] = NOTSET,
+              tags: Optional[Tuple[str]] = None,
+              description: Optional[str] = None,
+              constraints: Optional[Tuple[BaseConstraint]] = None,
+              deprecated: Union[bool, str] = False):
+
+        attribs = {}
+        if dtype is not NOTSET:
+            attribs['dtype'] = get_annotation_str(dtype)
+        if default is not NOTSET:
+            attribs['default'] = str(default)
+        if default_factory is not NOTSET:
+            attribs['default_factory'] = str(default_factory)
+        if tags is not None:
+            attribs['tags'] = str(tags)
+        if constraints is not None:
+            attribs['constraints'] = str(constraints)
+        if description is not None:
+            attribs['description'] = description
+        if deprecated is not None:
+            attribs['deprecated'] = str(deprecated)
+        attrib_str = format_delimited_dict(attribs)
+        return f'{label}({attrib_str})'
+
+    @property
+    def label(self) -> str:
+        return 'Parameter'
+
+    @property
+    def required(self) -> bool:
+        return (self.default is NOTSET) and (self.default_factory is NOTSET)
+
+    @property
+    def validator(self) -> Callable:
+        return get_type_validator(self.annotation)
+
+    @property
+    def has_default(self) -> bool:
+        return self.default is not NOTSET or self.default_factory is not NOTSET
+
+    def get_default(self, evaluate_factory: bool = True) -> Any:
+        if self.default is not NOTSET:
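+            # default and default_factory are mutually exclusive (enforced
+            # in __init__), so at most one of the branches below applies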
+ return self.default + if (self.default_factory is not NOTSET) and evaluate_factory: + return self.default_factory() + return NOTSET + + def type_check(self, value: Any) -> bool: + return self.validator(value) + + def constraint_check(self, value: Any) -> bool: + return all(constraint(value) for constraint in self.constraints) \ No newline at end of file diff --git a/quickstats/core/tree_data.py b/quickstats/core/tree_data.py new file mode 100644 index 0000000000000000000000000000000000000000..2a7135b8cd8a237ea7e079829b058ea0975b54da --- /dev/null +++ b/quickstats/core/tree_data.py @@ -0,0 +1,148 @@ +from typing import Any, Dict, List, Optional +from quickstats.core.typing import NOTSET + +class TreeData: + """ + A tree-like data structure with hierarchical key-value storage, + supporting nested namespaces and flexible key lookups. + """ + + def __init__(self, data: Optional[Dict[str, Any]] = None, delimiter: str = '.'): + self._delimiter = delimiter + self.load(data) + + @classmethod + def create(cls, data: Optional[Dict[str, Any]] = None, delimiter: str = '.'): + """ + Factory method for creating a TreeData instance. + """ + return cls(data, delimiter=delimiter) + + def load(self, data: Optional[Dict[str, Any]] = None) -> None: + """ + Load a dictionary into the tree structure, resetting existing data. + """ + self._branches: Dict[str, Any] = {} + if data: + for key, value in data.items(): + self.set(key, value) + + def __setitem__(self, key: str, value: Any) -> None: + """ + Set a value in the tree structure. + """ + self.set(key, value) + + def __getitem__(self, key: str) -> Any: + """ + Get a value from the tree structure. + """ + value = self.get(key) + if value is NOTSET: + raise KeyError(f'{key} not found in tree') + return value + + def __getattr__(self, key: str) -> Any: + """ + Get a branch or value as an attribute. + """ + return self._branches[key] + + def __contains__(self, key: str) -> bool: + """ + Check if a key exists in the tree. + """ + return self.get(key) is not NOTSET + + def _branch_split(self, key: str, maxsplit: int = -1): + """ + Split the key based on the delimiter. + """ + return key.split(self._delimiter, maxsplit) + + def get(self, key: str, value: Any = NOTSET) -> Any: + """ + Retrieve a value from the tree, navigating nested branches. + + Parameters + ---------- + key : str + The key to look up, potentially containing delimiters for nested access. + value : any + The value to return if the key is not found (default is NOTSET). + + Returns + ------- + any + The value found in the tree, or `value` if not found. + """ + current = self + for branch in self._branch_split(key): + if not isinstance(current, TreeData): + return value + current = current._branches.get(branch, value) + return current + + def set(self, key: str, value: Any) -> None: + """ + Set a value in the tree, creating nested branches as needed. + + Parameters + ---------- + key : str + The key to set, potentially containing delimiters for nested access. + value : any + The value to assign at the given key. + """ + if self._delimiter not in key: + self._branches[key] = value + else: + branch, subbranch = self._branch_split(key, 1) + if branch not in self._branches: + self._branches[branch] = self.create(delimiter=self._delimiter) + self._branches[branch][subbranch] = value + + def branches(self) -> List[str]: + """ + Return the list of top-level branches in the tree. + + Returns + ------- + list of str + A list of keys representing the top-level branches. 
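+
+        Examples
+        --------
+        >>> tree = TreeData({'a.b': 1, 'c': 2})
+        >>> tree.branches()
+        ['a']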
+ """ + return [key for key in self._branches if isinstance(self._branches[key], TreeData)] + + def todict(self, nested: bool = False) -> Dict[str, Any]: + """ + Convert the TreeData object to a dictionary. + + Parameters + ---------- + nested : bool + If True, return a nested dictionary. If False, return a flat dictionary with delimiter-separated keys. + + Returns + ------- + dict + The dictionary representation of the tree. + """ + result = {} + for key, value in self._branches.items(): + if isinstance(value, TreeData): + if nested: + result[key] = value.todict(nested=True) + else: + sub_dict = value.todict(nested=False) + for subkey, subvalue in sub_dict.items(): + result[f"{key}{self._delimiter}{subkey}"] = subvalue + else: + result[key] = value + return result + + def __repr__(self) -> str: + """ + String representation of the TreeData object, showing its structure. + """ + class_name = self.__class__.__name__ + return f"<{class_name}: {self._branches}>" diff --git a/quickstats/core/trees.py b/quickstats/core/trees.py new file mode 100644 index 0000000000000000000000000000000000000000..a408efc37871b703284e0a4d64c68469abbc07f3 --- /dev/null +++ b/quickstats/core/trees.py @@ -0,0 +1,988 @@ +""" +A module providing a flexible tree data structure with named nodes. + +This module implements a tree structure where each node has a name, optional data, +and can have multiple children. Nodes can be accessed using domain-style notation +(e.g., 'parent.child.grandchild') and support dict-like operations. + +Examples +-------- +>>> # Create a basic tree +>>> root = NamedTreeNode("root", data="root_data") +>>> root.add_child(NamedTreeNode("child1", "child1_data")) +>>> print(root) +root: 'root_data' + child1: 'child1_data' + +>>> # Use domain notation +>>> root.set("new_data", domain="child1") +>>> print(root.get(domain="child1")) +'new_data' + +>>> # Use dictionary updates +>>> root |= {"name": "child2", "data": "child2_data", "children": {}} +>>> print(root) +root: 'root_data' + child1: 'new_data' + child2: 'child2_data' +""" +from __future__ import annotations + +from typing import ( + Any, Optional, List, Dict, Union, Iterator, TypeVar, Generic, + Sequence, Mapping, ClassVar, get_args, get_origin, Type +) +from dataclasses import dataclass +import copy +import re + +from .type_validation import check_type + +# Type variables and aliases +T = TypeVar('T') +DomainType = Optional[str] +NodeData = TypeVar('NodeData') + +class TreeError(Exception): + """Base exception for tree-related errors.""" + pass + +class InvalidNodeError(TreeError): + """Exception raised for invalid node operations.""" + pass + +class DomainError(TreeError): + """Exception raised for invalid domain operations.""" + pass + +class ValidationError(TreeError): + """Exception raised for data validation errors.""" + pass + +@dataclass +class NodeConfig: + """Configuration for tree nodes. + + Parameters + ---------- + separator : str, default='.' + Separator used for domain paths + allow_none_data : bool, default=True + Whether to allow None as valid data + validate_names : bool, default=False + Whether to validate node names against pattern + validate_data_type : bool, default=False + Whether to validate data against the node's type parameter + name_pattern : str, default=r'^[a-zA-Z][a-zA-Z0-9_]*$' + Pattern for valid node names if validate_names is True + """ + separator: str = '.' 
+ allow_none_data: bool = True + validate_names: bool = False # Default to False for performance + validate_data_type: bool = False # Default to False for performance + name_pattern: str = r'^[a-zA-Z][a-zA-Z0-9_]*$' + +class NamedTreeNode(Generic[NodeData]): + """ + A tree node with a name, optional data, and child nodes. + + This class implements a flexible tree structure where each node has: + - A unique name within its parent's scope + - Optional data of any type + - Zero or more child nodes + - Support for domain-style access (e.g., 'parent.child.grandchild') + + Attributes + ---------- + name : str + The name of the node + data : Optional[NodeData] + The data stored in the node + children : Dict[str, NamedTreeNode] + Dictionary of child nodes keyed by their names + + Examples + -------- + >>> # Create a basic tree + >>> root = NamedTreeNode[str]("root", "root_data") + >>> root.add_child(NamedTreeNode("child1", "child1_data")) + >>> root["child2"] = "child2_data" + + >>> # Access data + >>> print(root.get("child1")) + 'child1_data' + >>> print(root["child2"]) + 'child2_data' + + >>> # Update with dictionary + >>> root |= { + ... "name": "child3", + ... "data": "child3_data", + ... "children": {} + ... } + """ + + # Class-level configuration + config: ClassVar[NodeConfig] = NodeConfig() + + def __init__( + self, + name: str = 'root', + data: Optional[NodeData] = None, + separator: Optional[str] = None + ) -> None: + """ + Initialize a named tree node. + + Parameters + ---------- + name : str, default 'root' + The name of the node. Must be a valid identifier. + data : Optional[NodeData], default None + The data to store in the node. + separator : Optional[str], default None + The separator to use for domain strings. If None, uses class default. + + Raises + ------ + ValidationError + If the name is invalid or data validation fails. + """ + self._validate_name(name) + self._name = name + self._data = data + self._children = {} + self._data_type = None + self._data_type_inferred = False + self._separator = separator or self.config.separator + + def _infer_data_type(self) -> None: + """Infer the data type when needed.""" + if not self._data_type_inferred: + if hasattr(self, "__orig_class__"): + # Extract the type from the instantiated generic class + type_args = get_args(self.__orig_class__) + if type_args: + self._data_type = type_args[0] + self._data_type_inferred = True + + def _validate_name(self, name: str) -> None: + """Validate node name.""" + + if not isinstance(name, str): + raise ValidationError(f"Name must be a string, got {type(name)}") + + if not name: + raise ValidationError("Name cannot be empty") + + if self.config.validate_names: + if not re.match(self.config.name_pattern, name): + raise ValidationError( + f"Invalid name '{name}'. Must match pattern: {self.config.name_pattern}" + ) + + def _validate_data(self, data: Optional[NodeData]) -> Optional[NodeData]: + """ + Validate node data. 
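+
+        Type checking only runs when ``config.validate_data_type`` is True
+        and a data type could be inferred from the generic parameter
+        (e.g. ``NamedTreeNode[str]``); otherwise only the None check applies.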
+ + Parameters + ---------- + data : Optional[NodeData] + Data to validate + + Returns + ------- + Optional[NodeData] + Validated data + + Raises + ------ + ValidationError + If data validation fails + """ + # Handle None data + if data is None: + if not self.config.allow_none_data: + raise ValidationError("None data not allowed") + return None + + # Perform type validation if enabled and type is known + if self.config.validate_data_type and self.data_type is not None: + if not check_type(data, self.data_type): + raise ValidationError( + f"Data type mismatch: expected {self._data_type.__name__}, " + f"got {type(data).__name__}" + ) + + return data + + def _split_domain(self, domain: Optional[str] = None) -> List[str]: + """Split domain string into components.""" + if not domain: + return [] + + if not isinstance(domain, str): + raise DomainError(f"Domain must be a string, got {type(domain)}") + + return domain.split(self._separator) + + def _format_domain(self, *components: Optional[str]) -> str: + """Format domain components into a domain string.""" + return self._separator.join(comp for comp in components if comp) + + def __setitem__(self, domain: str, data: NodeData) -> None: + """ + Set data for a node at the specified domain. + + Parameters + ---------- + domain : str + The domain path to the node + data : NodeData + The data to set + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> root["child.grandchild"] = "data" + >>> print(root["child.grandchild"]) + 'data' + """ + try: + self.set(data=data, domain=domain) + except Exception as e: + raise DomainError(f"Failed to set item at '{domain}': {str(e)}") from e + + def __getitem__(self, domain: str) -> NodeData: + """ + Get data from a node at the specified domain. + + Parameters + ---------- + domain : str + The domain path to the node + + Returns + ------- + NodeData + The data at the specified domain + + Raises + ------ + KeyError + If the domain doesn't exist + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> root["child"] = "data" + >>> print(root["child"]) + 'data' + """ + node = self.traverse_domain(domain, create=False) + if node is None: + raise KeyError(f"Domain not found: '{domain}'") + return node.data + + def __or__(self, other: Union[NamedTreeNode[NodeData], Dict[str, Any]]) -> NamedTreeNode[NodeData]: + """ + Combine this node with another node or dictionary. + + Parameters + ---------- + other : Union[NamedTreeNode[NodeData], Dict[str, Any]] + The other node or dictionary to combine with + + Returns + ------- + NamedTreeNode[NodeData] + A new node combining both trees + + Examples + -------- + >>> node1 = NamedTreeNode[str]("node1", "data1") + >>> node2 = NamedTreeNode[str]("node2", "data2") + >>> combined = node1 | node2 + >>> print(combined.name) + 'node2' + """ + if isinstance(other, dict): + other = self.from_dict(other) + elif not isinstance(other, NamedTreeNode): + raise TypeError("Can only combine with another NamedTreeNode or dict") + + new_node = self.create(self._name, self._data) + new_node.update(self) + new_node.update(other) + return new_node + + def __ior__(self, other: Union[NamedTreeNode[NodeData], Dict[str, Any]]) -> NamedTreeNode[NodeData]: + """ + Update this node with another node or dictionary in-place. 
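+        This is equivalent to calling ``update`` with ``other``; the node
+        itself is returned.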
+ + Parameters + ---------- + other : Union[NamedTreeNode[NodeData], Dict[str, Any]] + The other node or dictionary to update with + + Returns + ------- + NamedTreeNode[NodeData] + This node, updated + + Examples + -------- + >>> node = NamedTreeNode[str]("node", "old_data") + >>> node |= {"name": "node", "data": "new_data"} + >>> print(node.data) + 'new_data' + """ + self.update(other) + return self + + def __ror__(self, other: Dict[str, Any]) -> NamedTreeNode[NodeData]: + """ + Combine a dictionary with this node. + + Parameters + ---------- + other : Dict[str, Any] + The dictionary to combine with + + Returns + ------- + NamedTreeNode[NodeData] + A new node combining both + """ + new_node = self.from_dict(other) + return new_node | self + + def __contains__(self, domain: str) -> bool: + """ + Check if a domain exists in the tree. + + Parameters + ---------- + domain : str + The domain to check for + + Returns + ------- + bool + True if the domain exists + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> root["child"] = "data" + >>> print("child" in root) + True + """ + try: + return self.traverse_domain(domain, create=False) is not None + except DomainError: + return False + + def __copy__(self) -> NamedTreeNode[NodeData]: + """Create a shallow copy.""" + new_node = self.create(self._name, self._data) + new_node._children = self._children.copy() + return new_node + + def __deepcopy__(self, memo: Dict[int, Any]) -> NamedTreeNode[NodeData]: + """Create a deep copy.""" + new_node = self.create(self._name, copy.deepcopy(self._data, memo)) + new_node._children = { + name: copy.deepcopy(child, memo) + for name, child in self._children.items() + } + return new_node + + def __repr__(self, level: int = 0) -> str: + """ + Create a string representation of the tree. + + Parameters + ---------- + level : int, default 0 + The current indentation level + + Returns + ------- + str + A formatted string representation + """ + indent = " " * level + result = [f"{indent}{self._name}: {repr(self._data)}"] + + for child in self._children.values(): + result.append(child.__repr__(level + 1)) + + return "\n".join(result) + + def __iter__(self) -> Iterator[NamedTreeNode[NodeData]]: + """ + Iterate over child nodes. + + Yields + ------ + NamedTreeNode[NodeData] + Each child node + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> root["child1"] = "data1" + >>> root["child2"] = "data2" + >>> for child in root: + ... print(child.name) + child1 + child2 + """ + return iter(self._children.values()) + + @classmethod + def from_dict( + cls, + data: Dict[str, Any] + ) -> NamedTreeNode[NodeData]: + """ + Create a new tree from a dictionary. + + Parameters + ---------- + data : Dict[str, Any] + Dictionary containing node data and children + + Returns + ------- + NamedTreeNode[NodeData] + A new tree node + + Examples + -------- + >>> data = { + ... "name": "root", + ... "data": "root_data", + ... "children": { + ... "child1": { + ... "name": "child1", + ... "data": "child1_data" + ... } + ... } + ... 
}
+        >>> root = NamedTreeNode[str].from_dict(data)
+        """
+        if not isinstance(data, dict):
+            raise TypeError("Expected dictionary input")
+
+        name = data.get('name', 'root')
+        node_data = data.get('data')
+        node = cls(name, node_data)
+
+        children = data.get('children', {})
+        if not isinstance(children, dict):
+            raise TypeError("Children must be a dictionary")
+
+        for child_name, child_data in children.items():
+            # fall back to the mapping key when a child dict omits 'name'
+            child_data = dict(child_data)
+            child_data.setdefault('name', child_name)
+            node.add_child(cls.from_dict(child_data))
+
+        return node
+
+    @classmethod
+    def from_mapping(
+        cls,
+        data: Mapping[str, NodeData]
+    ) -> NamedTreeNode[NodeData]:
+        """
+        Create a new tree from a mapping.
+
+        Parameters
+        ----------
+        data : Mapping[str, NodeData]
+            Mapping containing node data
+
+        Returns
+        -------
+        NamedTreeNode[NodeData]
+            A new tree node
+
+        Examples
+        --------
+        >>> data = {
+        ...     None: "root_data",
+        ...     "child1": "child1_data",
+        ...     "child2": "child2_data"
+        ... }
+        >>> root = NamedTreeNode[str].from_mapping(data)
+        """
+        node = cls()
+        node.set(data.get(None))
+
+        for name, value in data.items():
+            if name is not None:
+                node[name] = value
+
+        return node
+
+    @classmethod
+    def create(
+        cls,
+        name: str,
+        data: Optional[NodeData] = None
+    ) -> NamedTreeNode[NodeData]:
+        """
+        Create a new tree node.
+
+        Parameters
+        ----------
+        name : str
+            Name for the node
+        data : Optional[NodeData]
+            Data for the node
+
+        Returns
+        -------
+        NamedTreeNode[NodeData]
+            A new tree node
+        """
+        return cls(name, data)
+
+    @property
+    def name(self) -> str:
+        """Get the node's name."""
+        return self._name
+
+    @property
+    def data(self) -> Optional[NodeData]:
+        """Get the node's data."""
+        return self._data
+
+    @data.setter
+    def data(self, value: Optional[NodeData]) -> None:
+        # a setter is needed so that `node.data = value` (as used in the
+        # traverse examples below) validates and stores the new data
+        self._data = self._validate_data(value)
+
+    @property
+    def children(self) -> Dict[str, NamedTreeNode[NodeData]]:
+        """Get the dictionary of child nodes keyed by name."""
+        return self._children
+
+    @property
+    def data_type(self) -> Optional[Type[NodeData]]:
+        """Access the inferred data type."""
+        self._infer_data_type()
+        return self._data_type
+
+    @property
+    def namespaces(self) -> List[str]:
+        """
+        Get list of immediate child names.
+
+        Returns
+        -------
+        List[str]
+            List of child names
+        """
+        return list(self._children.keys())
+
+    @property
+    def domains(self) -> List[str]:
+        """
+        Get list of all domain paths in the tree.
+
+        Returns
+        -------
+        List[str]
+            List of domain paths
+
+        Examples
+        --------
+        >>> root = NamedTreeNode[str]("root")
+        >>> root["a.b"] = "data1"
+        >>> root["a.c"] = "data2"
+        >>> print(root.domains)
+        ['a.b', 'a.c']
+        """
+        result = []
+        for namespace, node in self._children.items():
+            subdomains = node.domains
+            if not subdomains:
+                result.append(namespace)
+            else:
+                result.extend([
+                    self._format_domain(namespace, subdomain)
+                    for subdomain in subdomains
+                ])
+        return result
+
+    def format(self, *components: Optional[str]) -> str:
+        """Format domain components into a domain string."""
+        return self._format_domain(*components)
+
+    def copy(self, deep: bool = False) -> NamedTreeNode[NodeData]:
+        """
+        Create a copy of the tree.
+
+        Parameters
+        ----------
+        deep : bool, default False
+            If True, creates a deep copy
+
+        Returns
+        -------
+        NamedTreeNode[NodeData]
+            A copy of the tree
+
+        Examples
+        --------
+        >>> root = NamedTreeNode[str]("root", "data")
+        >>> root["child"] = "child_data"
+        >>> copy1 = root.copy()  # Shallow copy
+        >>> copy2 = root.copy(deep=True)  # Deep copy
+        """
+        return self.__deepcopy__({}) if deep else self.__copy__()
+
+    def add_child(self, child_node: NamedTreeNode[NodeData]) -> None:
+        """
+        Add a child node to the tree.
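+        Child names must be unique within a parent: adding a child whose
+        name already exists raises ValidationError.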
+ + Parameters + ---------- + child_node : NamedTreeNode[NodeData] + The child node to add + + Raises + ------ + TypeError + If child_node is not a NamedTreeNode + ValidationError + If child node validation fails + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> child = NamedTreeNode("child", "data") + >>> root.add_child(child) + """ + if not isinstance(child_node, NamedTreeNode): + raise TypeError("Child must be a NamedTreeNode instance") + + self._validate_child(child_node) + self._children[child_node.name] = child_node + + def _validate_child(self, child: NamedTreeNode[NodeData]) -> None: + """Validate a child node before adding.""" + if child.name in self._children: + raise ValidationError(f"Child name '{child.name}' already exists") + + def get_child( + self, + name: str, + default: Optional[NamedTreeNode[NodeData]] = None + ) -> Optional[NamedTreeNode[NodeData]]: + """ + Get a child node by name. + + Parameters + ---------- + name : str + Name of the child node + default : Optional[NamedTreeNode[NodeData]], default None + Value to return if child doesn't exist + + Returns + ------- + Optional[NamedTreeNode[NodeData]] + The child node or default value + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> root["child"] = "data" + >>> child = root.get_child("child") + >>> print(child.data) + 'data' + """ + return self._children.get(name, default) + + def remove_child(self, name: str) -> Optional[NamedTreeNode[NodeData]]: + """ + Remove and return a child node. + + Parameters + ---------- + name : str + Name of the child to remove + + Returns + ------- + Optional[NamedTreeNode[NodeData]] + The removed child node, or None if not found + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> root["child"] = "data" + >>> removed = root.remove_child("child") + >>> print(removed.data) + 'data' + """ + return self._children.pop(name, None) + + def traverse( + self, + *namespaces: str, + create: bool = False + ) -> Optional[NamedTreeNode[NodeData]]: + """ + Traverse the tree through multiple namespaces. + + Parameters + ---------- + *namespaces : str + Sequence of namespace names to traverse + create : bool, default False + Whether to create missing nodes during traversal + + Returns + ------- + Optional[NamedTreeNode[NodeData]] + The final node or None if not found + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> node = root.traverse("a", "b", "c", create=True) + >>> node.data = "data" + >>> print(root["a.b.c"]) + 'data' + """ + node = self + for namespace in namespaces: + if not namespace: + continue + + subnode = node._children.get(namespace) + if subnode is None: + if create: + subnode = self.create(namespace) + node.add_child(subnode) + else: + return None + node = subnode + return node + + def traverse_domain( + self, + domain: Optional[str] = None, + create: bool = False + ) -> Optional[NamedTreeNode[NodeData]]: + """ + Traverse the tree using a domain string. 
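+        The domain is split on the node's separator and traversed one
+        component at a time via ``traverse``.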
+ + Parameters + ---------- + domain : Optional[str] + Domain path (e.g., "parent.child.grandchild") + create : bool, default False + Whether to create missing nodes during traversal + + Returns + ------- + Optional[NamedTreeNode[NodeData]] + The final node or None if not found + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> node = root.traverse_domain("a.b.c", create=True) + >>> node.data = "data" + >>> print(root.get("a.b.c")) + 'data' + """ + components = self._split_domain(domain) + return self.traverse(*components, create=create) + + def get( + self, + domain: Optional[str] = None, + default: Any = None, + strict: bool = False + ) -> Optional[NodeData]: + """ + Get data from a node at the specified domain. + + Parameters + ---------- + domain : Optional[str] + Domain path to the node + default : Any, default None + Value to return if node not found + strict : bool, default False + If True, raises KeyError for missing nodes + + Returns + ------- + Optional[NodeData] + The node's data or default value + + Raises + ------ + KeyError + If strict=True and node not found + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> root["a.b"] = "data" + >>> print(root.get("a.b")) + 'data' + >>> print(root.get("x.y", default="not found")) + 'not found' + """ + node = self.traverse_domain(domain) + if strict and node is None: + raise KeyError(f"Domain not found: '{domain}'") + return node.data if node is not None else default + + def set( + self, + data: NodeData, + domain: Optional[str] = None + ) -> None: + """ + Set data for a node at the specified domain. + + Parameters + ---------- + data : NodeData + The data to set + domain : Optional[str] + Domain path to the node + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> root.set("data", "a.b.c") + >>> print(root.get("a.b.c")) + 'data' + """ + node = self.traverse_domain(domain, create=True) + if node: + node._data = self._validate_data(data) + + def update( + self, + other: Union[NamedTreeNode[NodeData], Dict[str, Any]] + ) -> None: + """ + Update the tree with another tree or dictionary. + + Parameters + ---------- + other : Union[NamedTreeNode[NodeData], Dict[str, Any]] + The source to update from + + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> root.update({ + ... "name": "root", + ... "data": "new_data", + ... "children": { + ... "child": {"name": "child", "data": "child_data"} + ... } + ... }) + """ + if isinstance(other, dict): + other = self.from_dict(other) + elif not isinstance(other, NamedTreeNode): + raise TypeError( + "Expected NamedTreeNode or dict, " + f"got {type(other).__name__}" + ) + + # Update name and data + self._name = other.name + self._data = self._validate_data(other.data) + + # Update children + for name, child in other._children.items(): + if name in self._children: + self._children[name].update(child) + else: + self._children[name] = child + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the tree to a dictionary representation. + + Returns + ------- + Dict[str, Any] + Dictionary representation of the tree + + Examples + -------- + >>> root = NamedTreeNode[str]("root", "data") + >>> root["child"] = "child_data" + >>> dict_repr = root.to_dict() + >>> print(dict_repr['children']['child']['data']) + 'child_data' + """ + return { + "name": self._name, + "data": self._data, + "children": { + name: child.to_dict() + for name, child in self._children.items() + } + } + + def clear(self) -> None: + """ + Remove all children from the tree. 
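+        The node's own name and data are left untouched.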
+ + Examples + -------- + >>> root = NamedTreeNode[str]("root") + >>> root["a"] = "data1" + >>> root["b"] = "data2" + >>> root.clear() + >>> print(len(root.children)) + 0 + """ + self._children.clear() + + def merge( + self, + other: NamedTreeNode[NodeData], + strategy: str = 'replace' + ) -> None: + """ + Merge another tree into this one. + + Parameters + ---------- + other : NamedTreeNode[NodeData] + The tree to merge from + strategy : str, default 'replace' + Merge strategy ('replace' or 'keep') + + Examples + -------- + >>> tree1 = NamedTreeNode[str]("root") + >>> tree1["a"] = "data1" + >>> tree2 = NamedTreeNode[str]("root") + >>> tree2["b"] = "data2" + >>> tree1.merge(tree2) + """ + if not isinstance(other, NamedTreeNode): + raise TypeError("Can only merge with another NamedTreeNode") + + if strategy not in {'replace', 'keep'}: + raise ValueError( + "Invalid merge strategy. Must be 'replace' or 'keep'" + ) + + # Merge data if needed + if strategy == 'replace' or self._data is None: + self._data = self._validate_data(other.data) + + # Merge children + for name, other_child in other._children.items(): + if name in self._children: + self._children[name].merge(other_child, strategy) + else: + self._children[name] = other_child.copy(deep=True) \ No newline at end of file diff --git a/quickstats/core/type_validation.py b/quickstats/core/type_validation.py index d10cd783ee69d2695fb5e92ec1a68c53774add96..faa60e44ed7533fff8041ba26c8c9028561626eb 100644 --- a/quickstats/core/type_validation.py +++ b/quickstats/core/type_validation.py @@ -1,213 +1,380 @@ +""" +Runtime type validation utilities for Python type hints. + +This module provides efficient runtime validation of objects against Python type hints, +including support for generics, unions, literals, and other complex type annotations. +Designed for Python 3.8+ with focus on performance and extensibility. +""" + +from __future__ import annotations + +__all__ = [ + 'ValidatorFactory', + 'check_type', + 'get_type_hint_str', + 'get_annotation_str', + 'get_type_validator' +] + from collections import abc from functools import lru_cache -import types -from typing import Any, Callable, Dict, List, Tuple, Union, get_args, get_origin - -# types.UnionType requires python 3.10+ -try: - from types import UnionType - union_like_types = {Union, UnionType} -except ImportError: - union_like_types = {Union} +from typing import ( + Any, Callable, Dict, List, Tuple, Union, Optional, + TypeVar, get_args, get_origin, Type, final +) +# Handle Literal type availability try: from typing import Literal - has_literal = True + HAS_LITERAL = True except ImportError: - Literal = None - has_literal = False + try: + from typing_extensions import Literal + HAS_LITERAL = True + except ImportError: + Literal = None + HAS_LITERAL = False +# Type definitions +ValidatorFunc = Callable[[Any], bool] +TypeHint = Any +T = TypeVar('T') + +class ValidationError(TypeError): + """ + Exception raised for type validation errors. 
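+    Subclasses ``TypeError``, so existing ``isinstance(err, TypeError)``
+    checks continue to work.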
+ + Parameters + ---------- + message : str + Error description + expected_type : Optional[TypeHint] + Expected type annotation + received_type : Optional[Type] + Actual type received + value : Optional[Any] + Value that failed validation + + Examples + -------- + >>> raise ValidationError("Invalid type", List[int], str, "not a list") + ValidationError: Invalid type - expected List[int], got str ('not a list') + """ + + def __init__( + self, + message: str, + expected_type: Optional[TypeHint] = None, + received_type: Optional[Type] = None, + value: Any = None + ) -> None: + self.expected_type = expected_type + self.received_type = received_type + self.value = value + + if expected_type is not None and received_type is not None: + message = ( + f"{message} - expected {get_type_hint_str(expected_type)}, " + f"got {received_type.__name__}" + ) + if value is not None: + message = f"{message} ({value!r})" + + super().__init__(message) + + +@final class ValidatorFactory: """ - A factory class for creating validators that check if objects match specified type hints. - - Example Usage: - - Validate an integer list: - ``` - validator = ValidatorFactory.get_validator(List[int]) - print(validator([1, 2, 3])) # Expected: True - ``` - - Validate a dictionary with string keys and integer values: - ``` - validator = ValidatorFactory.get_validator(Dict[str, int]) - print(validator({'key': 1})) # Expected: True - ``` - - Validate a tuple of fixed types: - ``` - validator = ValidatorFactory.get_validator(Tuple[int, str]) - print(validator((1, 'a'))) # Expected: True - ``` - - Validate a variable-length tuple of integers: - ``` - validator = ValidatorFactory.get_validator(Tuple[int, ...]) - print(validator((1, 2, 3))) # Expected: True - ``` + Factory for creating and caching type validators. + + This class provides efficient creation and caching of validator functions + for various type hints. It handles complex type hierarchies and supports + custom validation rules. 
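+
+    Validators are cached with ``lru_cache`` (up to ``CACHE_SIZE`` entries
+    per factory method), so repeated lookups for the same type hint are cheap.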
+ + Examples + -------- + >>> validator = ValidatorFactory.get_validator(List[int]) + >>> validator([1, 2, 3]) + True + >>> validator(['a', 'b']) + False + + >>> # Validate nested types + >>> nested_validator = ValidatorFactory.get_validator(Dict[str, List[int]]) + >>> nested_validator({'nums': [1, 2, 3]}) + True """ + CACHE_SIZE: int = 1024 + + @staticmethod + def _is_optional(type_hint: TypeHint) -> bool: + """Check if type hint is Optional[T].""" + if get_origin(type_hint) is Union: + args = get_args(type_hint) + return type(None) in args + return False + @staticmethod - @lru_cache(maxsize=None) - def create_union_validator(type_args): - non_generic_types = tuple(arg for arg in type_args if not get_args(arg)) - generic_type_validators = [ValidatorFactory.get_validator(arg) for arg in type_args if get_args(arg)] + @lru_cache(maxsize=CACHE_SIZE) + def create_union_validator(type_args: Tuple[TypeHint, ...]) -> ValidatorFunc: + """Create an optimized validator for Union types.""" + simple_types = tuple( + arg for arg in type_args + if not get_args(arg) and not isinstance(arg, TypeVar) + ) - def validate(source): - return isinstance(source, non_generic_types) or any(validator(source) for validator in generic_type_validators) + complex_validators = tuple( + ValidatorFactory.get_validator(arg) + for arg in type_args + if get_args(arg) or isinstance(arg, TypeVar) + ) + if not complex_validators: + return lambda obj: isinstance(obj, simple_types) + + if not simple_types: + return lambda obj: any(v(obj) for v in complex_validators) + + def validate(obj: Any) -> bool: + return isinstance(obj, simple_types) or any(v(obj) for v in complex_validators) + return validate @staticmethod - @lru_cache(maxsize=None) - def create_sequence_validator(type_arg, container_type=(list, tuple)): - item_validator = ValidatorFactory.get_validator(type_arg[0]) if type_arg else lambda x: True - - def validate(source): - return isinstance(source, container_type) and all(item_validator(item) for item in source) + @lru_cache(maxsize=CACHE_SIZE) + def create_sequence_validator( + item_type: TypeHint, + accepted_types: Tuple[Type, ...] 
= (list, tuple) + ) -> ValidatorFunc: + """Create an optimized validator for sequence types.""" + item_validator = ( + ValidatorFactory.get_validator(item_type) + if item_type is not Any + else lambda _: True + ) + def validate(obj: Any) -> bool: + if not isinstance(obj, accepted_types): + return False + return all(item_validator(item) for item in obj) + return validate @staticmethod - @lru_cache(maxsize=None) - def create_dict_validator(type_args): - if not type_args: - return lambda source: isinstance(source, dict) - key_validator, value_validator = map(ValidatorFactory.get_validator, type_args) - - def validate(source): - return isinstance(source, dict) and all(key_validator(k) and value_validator(v) for k, v in source.items()) + @lru_cache(maxsize=CACHE_SIZE) + def create_mapping_validator( + key_type: TypeHint, + value_type: TypeHint + ) -> ValidatorFunc: + """Create an optimized validator for mapping types.""" + key_validator = ValidatorFactory.get_validator(key_type) + value_validator = ValidatorFactory.get_validator(value_type) + def validate(obj: Any) -> bool: + if not isinstance(obj, abc.Mapping): + return False + return all( + key_validator(k) and value_validator(v) + for k, v in obj.items() + ) + return validate @staticmethod - @lru_cache(maxsize=None) - def create_tuple_validator(type_args): + @lru_cache(maxsize=CACHE_SIZE) + def create_tuple_validator(type_args: Tuple[TypeHint, ...]) -> ValidatorFunc: + """Create an optimized validator for tuple types.""" if not type_args: - return lambda source: isinstance(source, tuple) - elif len(type_args) == 2 and type_args[1] is Ellipsis: + return lambda obj: isinstance(obj, tuple) + + # Handle Tuple[T, ...] + if len(type_args) == 2 and type_args[1] is Ellipsis: item_validator = ValidatorFactory.get_validator(type_args[0]) - return lambda source: isinstance(source, tuple) and all(item_validator(item) for item in source) - else: - validators = [ValidatorFactory.get_validator(arg) for arg in type_args] - def validate(source): - return isinstance(source, tuple) and len(source) == len(validators) and all(validator(item) for validator, item in zip(validators, source)) - return validate + return lambda obj: ( + isinstance(obj, tuple) and + all(item_validator(item) for item in obj) + ) + + # Handle fixed-length tuples + validators = tuple( + ValidatorFactory.get_validator(arg) + for arg in type_args + ) + + def validate(obj: Any) -> bool: + if not isinstance(obj, tuple) or len(obj) != len(validators): + return False + return all( + validator(item) + for validator, item in zip(validators, obj) + ) + + return validate @staticmethod - @lru_cache(maxsize=None) - def create_literal_validator(type_args): - literal_values = set(type_args) - - def validate(source): - return source in literal_values - - return validate + @lru_cache(maxsize=CACHE_SIZE) + def create_literal_validator(allowed_values: Tuple[Any, ...]) -> ValidatorFunc: + """Create an optimized validator for Literal types.""" + value_set = frozenset(allowed_values) + return lambda obj: obj in value_set @staticmethod - @lru_cache(maxsize=None) - def get_validator(type_hint) -> Callable: + @lru_cache(maxsize=CACHE_SIZE) + def get_validator(type_hint: TypeHint) -> ValidatorFunc: """ - Retrieves a validator function for a given type hint. - - Args: - type_hint: The type hint for which to retrieve the validator. - - Returns: - A validator function that can be used to check if an object matches the type hint. + Get or create a validator function for a type hint. 
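+
+        Bare classes fall back to a plain ``isinstance`` check. TypeVars
+        are validated against their constraints or bound when present,
+        and match anything otherwise.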
+ + Parameters + ---------- + type_hint : TypeHint + Type hint to validate against + + Returns + ------- + ValidatorFunc + Function that validates objects against the type hint """ + if type_hint is Any: + return lambda _: True + origin = get_origin(type_hint) + if origin is None: + if isinstance(type_hint, TypeVar): + if type_hint.__constraints__: + return ValidatorFactory.create_union_validator( + type_hint.__constraints__ + ) + if type_hint.__bound__: + return ValidatorFactory.get_validator(type_hint.__bound__) + return lambda _: True + return lambda obj: isinstance(obj, type_hint) + args = get_args(type_hint) - if origin in {list, abc.Sequence} and not issubclass(origin, tuple): # exclude tuples from sequence validator - return ValidatorFactory.create_sequence_validator(args) - elif origin == dict: - return ValidatorFactory.create_dict_validator(args) - elif origin in union_like_types: + # Handle Optional types + if ValidatorFactory._is_optional(type_hint): + non_none_types = tuple(arg for arg in args if arg is not type(None)) + if len(non_none_types) == 1: + validator = ValidatorFactory.get_validator(non_none_types[0]) + return lambda obj: obj is None or validator(obj) return ValidatorFactory.create_union_validator(args) - elif origin == tuple: + + # Handle common container types + if origin in (list, abc.Sequence) and not issubclass(origin, tuple): + return ValidatorFactory.create_sequence_validator( + args[0] if args else Any + ) + + if origin in (dict, abc.Mapping): + return ValidatorFactory.create_mapping_validator( + args[0] if args else Any, + args[1] if len(args) > 1 else Any + ) + + if origin == tuple: return ValidatorFactory.create_tuple_validator(args) - elif type_hint == Any: - return lambda _: True - elif has_literal and (origin == Literal): + + if origin == Union: + return ValidatorFactory.create_union_validator(args) + + if HAS_LITERAL and origin is Literal: return ValidatorFactory.create_literal_validator(args) - else: - return lambda source: isinstance(source, type_hint) + + # Handle other generic types + if args: + validator = ValidatorFactory.get_validator(origin) + return lambda obj: validator(obj) and all( + ValidatorFactory.get_validator(arg)(obj) + for arg in args + ) + + return lambda obj: isinstance(obj, origin) -get_type_validator = ValidatorFactory.get_validator -def check_type(obj, type_hint): +def check_type( + obj: Any, + type_hint: TypeHint, + *, + raise_error: bool = False +) -> bool: """ - Checks if an object matches a given type hint. + Check if an object matches a type hint. + + Parameters + ---------- + obj : Any + Object to validate + type_hint : TypeHint + Type hint to validate against + raise_error : bool, optional + If True, raises ValidationError on type mismatch - Args: - obj: The object to be checked. - type_hint: The type hint against which the object is to be validated. - - Returns: - bool: True if the object matches the type hint, False otherwise. 
- - Examples: - Check if a list only contains integers: - >>> check_type([1, 2, 3], List[int]) - True - - Check if a variable is either a string or a list of strings: - >>> check_type("hello", Union[str, List[str]]) - True - >>> check_type(["hello", "world"], Union[str, List[str]]) - True - - Validate a dictionary with string keys and integer values: - >>> check_type({'key': 42}, Dict[str, int]) - True - - Check against a tuple with specified types: - >>> check_type((1, 'a'), Tuple[int, str]) - True - - Validate a variable-length tuple of integers: - >>> check_type((1, 2, 3, 4), Tuple[int, ...]) - True - - Note: The method will return False if the object does not match the type hint: - >>> check_type([1, 'a', 3], List[int]) - False - >>> check_type({'key': 'value'}, Dict[str, int]) - False + Returns + ------- + bool + True if object matches type hint + + Raises + ------ + ValidationError + If raise_error is True and validation fails """ - return get_type_validator(type_hint)(obj) - -@lru_cache(maxsize=None) -def get_type_hint_str(type_hint) -> str: - """Converts a type hint to its string representation. + validator = ValidatorFactory.get_validator(type_hint) + result = validator(obj) + + if not result and raise_error: + raise ValidationError( + "Type mismatch", + expected_type=type_hint, + received_type=type(obj), + value=obj + ) + return result - Args: - type_hint: The type hint to convert. - Returns: - A string representation of the type hint. - """ +@lru_cache(maxsize=ValidatorFactory.CACHE_SIZE) +def get_type_hint_str(type_hint: TypeHint) -> str: + """Get human-readable string representation of a type hint.""" + if type_hint is type(None): + return 'None' + origin = get_origin(type_hint) - if origin: - args = get_args(type_hint) - if not args: - return origin.__name__ - elif origin in union_like_types: - # Process Union types, specifically checking for Optional by detecting NoneType - non_none_args = [arg for arg in args if arg is not type(None)] - args_str = " | ".join(get_type_hint_str(arg) for arg in non_none_args) - return f"Optional[{args_str}]" if type(None) in args else args_str - else: - # Process other generic types like List, Dict - args_str = ", ".join(get_type_hint_str(arg) for arg in args) - return f"{origin.__name__}[{args_str}]" - elif hasattr(type_hint, '__name__'): - return type_hint.__name__ - else: - return str(type_hint) - -@lru_cache(maxsize=None) -def get_annotation_str(annotation) -> str: - return get_type_hint_str(annotation) \ No newline at end of file + if origin is None: + if isinstance(type_hint, TypeVar): + if type_hint.__constraints__: + constraints = ' | '.join( + get_type_hint_str(t) for t in type_hint.__constraints__ + ) + return f"{type_hint.__name__}[{constraints}]" + if type_hint.__bound__: + return f"{type_hint.__name__}[{get_type_hint_str(type_hint.__bound__)}]" + return type_hint.__name__ + return getattr(type_hint, '__name__', str(type_hint)) + + args = get_args(type_hint) + if not args: + return origin.__name__ + + if origin is Union: + non_none_args = tuple(arg for arg in args if arg is not type(None)) + args_str = ' | '.join(map(get_type_hint_str, non_none_args)) + + return ( + f"Optional[{args_str}]" + if type(None) in args + else args_str + ) + + if HAS_LITERAL and origin is Literal: + values_str = ' | '.join(repr(arg) for arg in args) + return f"Literal[{values_str}]" + + args_str = ', '.join(map(get_type_hint_str, args)) + return f"{origin.__name__}[{args_str}]" + + +# Alias for backward compatibility +get_type_validator = 
ValidatorFactory.get_validator
+get_annotation_str = get_type_hint_str
\ No newline at end of file
diff --git a/quickstats/core/typing.py b/quickstats/core/typing.py
index 7a0e8789e01753d0d68ef76095010077f3f1858c..5558717bc3f86be3cdc9da1ea504d49024676d4c 100644
--- a/quickstats/core/typing.py
+++ b/quickstats/core/typing.py
@@ -1,8 +1,10 @@
 import numbers
-from typing import Union, final, Any
+from typing import Union, final, Any, TypeVar, Tuple, List
+
+import numpy as np
 from numpy.typing import ArrayLike
 
-__all__ = ["Numeric", "Scalar", "Real", "ArrayLike", "NOTSET", "NOTSETTYPE"]
+__all__ = ["Numeric", "Scalar", "Real", "ArrayLike", "NOTSET", "NOTSETTYPE", "T"]
 
 Numeric = Union[int, float]
@@ -10,6 +12,8 @@ Scalar = Numeric
 
 Real = numbers.Real
 
+ArrayContainer = Union[Tuple[ArrayLike, ...], List[ArrayLike], np.ndarray]
+
 @final
 class NOTSETTYPE:
     """A type used as a sentinel for unspecified values."""
@@ -20,4 +24,6 @@ class NOTSETTYPE:
     def __deepcopy__(self, memo: Any):
         return self
 
-NOTSET = NOTSETTYPE()
\ No newline at end of file
+NOTSET = NOTSETTYPE()
+
+T = TypeVar('T')
\ No newline at end of file
diff --git a/quickstats/core/virtual_trees.py b/quickstats/core/virtual_trees.py
index aa55d2b3ef3256c34e5a2ed1a1b63871de824f06..e50c89755aa4001eb005553531022c6517bcecf8 100644
--- a/quickstats/core/virtual_trees.py
+++ b/quickstats/core/virtual_trees.py
@@ -2,8 +2,6 @@ from typing import Optional, List
 
 from .decorators import semistaticmethod
 
-__all__ = ['TVirtualNode', 'TVirtualTree']
-
 class TVirtualNode:
     def __init__(self, name:Optional[str]=None,
                  level:Optional[int]=0,
diff --git a/quickstats/interface/pydantic/default_model.py b/quickstats/interface/pydantic/default_model.py
index 372c8164ec573961f8ab4c779839eb1996479ba8..64367f47f510ba620fe4a5d5926e63e513d11d7b 100644
--- a/quickstats/interface/pydantic/default_model.py
+++ b/quickstats/interface/pydantic/default_model.py
@@ -3,7 +3,7 @@ from typing import Any, List, Optional, Union
 from pydantic import Field, BaseModel, ConfigDict, model_validator
 
 from quickstats import VerbosePrint, check_type, FlexibleDumper, get_type_hint_str
-from quickstats.utils.string_utils import format_dict_to_string
+from quickstats.utils.string_utils import format_aligned_dict
 from .alias_generators import to_pascal
 
 __all__ = ['DefaultModel']
@@ -59,7 +59,7 @@ class DefaultModel(BaseModel):
         _dumper.configure(**kwargs)
 
     @classmethod
-    def generate_help_text(cls, line_break: int = 100) -> str:
+    def generate_help_text(cls, linebreak: int = 100) -> str:
         """
         Generate help text for the class, displaying the class name and field info.
         """
@@ -77,15 +77,14 @@ class DefaultModel(BaseModel):
                 'Default': field_default,
                 'Description': field_description
             }
-            attributes_text = format_dict_to_string(attributes, left_margin=4, line_break=line_break)
+            attributes_text = format_aligned_dict(attributes, left_margin=4, linebreak=linebreak)
             help_text += f' {field_name}:\n'
             help_text += f'{attributes_text}\n'
         return help_text
 
     @classmethod
-    def help(cls, line_break: int = 100) -> None:
+    def help(cls, linebreak: int = 100) -> None:
         """
         Prints out the help message for the class.
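+
+        Example, with ``MyModel`` a hypothetical subclass of DefaultModel::
+
+            MyModel.help(linebreak=80)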
""" - print(cls.generate_help_text(line_break=line_break)) \ No newline at end of file + print(cls.generate_help_text(linebreak=linebreak)) \ No newline at end of file diff --git a/quickstats/maths/histograms.py b/quickstats/maths/histograms.py new file mode 100644 index 0000000000000000000000000000000000000000..bc12231be67a79e95d04efcc6e461cdcdfd7d413 --- /dev/null +++ b/quickstats/maths/histograms.py @@ -0,0 +1,1169 @@ +""" +Enhanced histogram utilities for numerical data analysis. +""" + +from __future__ import annotations + +from typing import ( + Union, Optional, List, Tuple, Sequence, Callable, + TypeVar, Any, cast +) +from numbers import Real +from dataclasses import dataclass + +import numpy as np + +from quickstats import DescriptiveEnum +from quickstats.core.typing import ArrayLike +from .numerics import array_issubset, safe_div +from .statistics import poisson_interval + +BinType = Union[int, ArrayLike] +RangeType = Optional[Union[Tuple[Real, Real], List[Real]]] +HistoMaskType = Union[ArrayLike, Callable] + +class HistogramError(Exception): + """Base exception for histogram-related errors.""" + pass + +class BinError(HistogramError): + """Exception for bin-related errors.""" + pass + +class BinErrorMode(DescriptiveEnum): + AUTO = (0, "Determine bin error method from data weights") + SUMW2 = (1, "Errors with Wald approximation: sqrt(sum of weight^2)") + POISSON = (2, "Errors from Poisson interval at 68.3% (1 sigma)") + +class HistComparisonMode(DescriptiveEnum): + RATIO = (0, "Ratio of data (target / reference)") + DIFFERENCE = (1, "Difference of data (target - reference)") + +@dataclass +class HistogramConfig: + """Configuration for histogram operations.""" + bin_precision: int = 8 + ghost_threshold: float = 1e-8 + rtol: float = 1e-5 + atol: float = 1e-8 + +# Global configuration +CONFIG = HistogramConfig() + +def bin_edge_to_bin_center(bin_edges: ArrayLike) -> np.ndarray: + """ + Calculate bin centers from bin edges. + + Parameters + ---------- + bin_edges : ArrayLike + The edges of the bins + + Returns + ------- + np.ndarray + The centers of the bins + + Examples + -------- + >>> bin_edges = [0, 1, 2, 3] + >>> bin_edge_to_bin_center(bin_edges) + array([0.5, 1.5, 2.5]) + """ + bin_edges = np.asarray(bin_edges) + return (bin_edges[:-1] + bin_edges[1:]) / 2 + +def bin_center_to_bin_edge(bin_centers: ArrayLike) -> np.ndarray: + """ + Calculate bin edges from bin centers. + + Parameters + ---------- + bin_centers : ArrayLike + The centers of the bins + + Returns + ------- + np.ndarray + The edges of the bins + + Raises + ------ + BinError + If bin centers have irregular widths + + Examples + -------- + >>> bin_centers = [0.5, 1.5, 2.5] + >>> bin_center_to_bin_edge(bin_centers) + array([0., 1., 2., 3.]) + """ + try: + bin_centers = np.asarray(bin_centers) + bin_widths = np.round(np.diff(bin_centers), CONFIG.bin_precision) + + if not np.allclose( + bin_widths, + bin_widths[0], + rtol=CONFIG.rtol, + atol=CONFIG.atol + ): + raise BinError("Cannot deduce edges from centers with irregular widths") + + bin_width = bin_widths[0] + return np.concatenate([ + bin_centers - bin_width / 2, + [bin_centers[-1] + bin_width / 2] + ]) + + except Exception as e: + raise BinError(f"Failed to convert centers to edges: {str(e)}") from e + +def bin_edge_to_bin_width(bin_edges: ArrayLike) -> np.ndarray: + """ + Calculate bin widths from bin edges. 
+ + Parameters + ---------- + bin_edges : ArrayLike + The edges of the bins + + Returns + ------- + np.ndarray + The widths of the bins + + Examples + -------- + >>> bin_edges = [0, 1, 2, 3] + >>> bin_edge_to_bin_width(bin_edges) + array([1, 1, 1]) + """ + bin_edges = np.asarray(bin_edges) + return np.diff(bin_edges) + +def get_clipped_data( + x: np.ndarray, + bin_range: Optional[Sequence] = None, + clip_lower: bool = True, + clip_upper: bool = True +) -> np.ndarray: + """ + Clip data within specified range. + + Parameters + ---------- + x : np.ndarray + Data to be clipped + bin_range : Optional[Sequence], default None + Range (min, max) for clipping + clip_lower : bool, default True + Whether to clip at lower bound + clip_upper : bool, default True + Whether to clip at upper bound + + Returns + ------- + np.ndarray + Clipped array + + Examples + -------- + >>> x = np.array([1, 5, 10, 15, 20]) + >>> get_clipped_data(x, (5, 15)) + array([ 5, 5, 10, 15, 15]) + """ + if bin_range is None or (not clip_lower and not clip_upper): + return np.array(x) + + xmin = bin_range[0] if clip_lower else None + xmax = bin_range[1] if clip_upper else None + return np.clip(x, xmin, xmax) + +def normalize_range( + bin_range: Optional[Tuple[Optional[float], ...]] = None, + dimension: int = 1 +) -> Tuple[Optional[float], ...]: + """ + Normalize range for each dimension. + + Parameters + ---------- + bin_range : Optional[Tuple[Optional[float], ...]], default None + Range for each dimension + dimension : int, default 1 + Number of dimensions + + Returns + ------- + Tuple[Optional[float], ...] + Normalized range + + Raises + ------ + BinError + If range doesn't match dimensions + + Examples + -------- + >>> normalize_range((0, 1), dimension=2) + (0, 1) + """ + try: + if bin_range is None: + return tuple(None for _ in range(dimension)) + + if len(bin_range) != dimension: + raise BinError( + f"Range must have {dimension} entries, got {len(bin_range)}" + ) + + return tuple(bin_range) + + except Exception as e: + raise BinError(f"Failed to normalize range: {str(e)}") from e + +def get_histogram_bins( + sample: ArrayLike, + bins: BinType = 10, + bin_range: RangeType = None, + dimensions: int = 1 +) -> Tuple[np.ndarray, ...]: + """ + Calculate histogram bins for given sample. + + Parameters + ---------- + sample : ArrayLike + Input sample data (array or sequence of arrays) + bins : BinType, default 10 + Number of bins or bin edges for each dimension + bin_range : RangeType, default None + Range for each dimension + dimensions : int, default 1 + Number of dimensions + + Returns + ------- + Tuple[np.ndarray, ...] 
+ Bin edges for each dimension + + Raises + ------ + BinError + For invalid bin specifications or dimensions + + Examples + -------- + >>> sample = np.array([[1, 2], [3, 4], [5, 6]]) + >>> edges = get_histogram_bins(sample, bins=[3, 2], dimensions=2) + >>> [e.shape for e in edges] + [(4,), (3,)] + """ + try: + # Convert sample to ND array + try: + num_samples, sample_dimensions = np.atleast_2d(sample).shape + except ValueError as e: + raise BinError("Invalid sample shape") from e + + if sample_dimensions != dimensions: + raise BinError( + f"Sample has {sample_dimensions} dimensions, expected {dimensions}" + ) + + # Initialize arrays + num_bins = np.empty(sample_dimensions, dtype=np.intp) + bin_edges_list = [None] * sample_dimensions + bin_values = _normalize_bins(bins, sample_dimensions) + bin_range = normalize_range(bin_range, dimension=sample_dimensions) + + # Calculate edges for each dimension + for i in range(sample_dimensions): + bin_edges_list[i] = _get_dimension_edges( + sample[:, i], + bin_values[i], + bin_range[i], + dimension_idx=i + ) + num_bins[i] = len(bin_edges_list[i]) + 1 + + return tuple(bin_edges_list) + + except Exception as e: + if not isinstance(e, BinError): + e = BinError(f"Failed to calculate histogram bins: {str(e)}") + raise e + +def _normalize_bins( + bins: BinType, + dimensions: int +) -> List[Union[int, Sequence[float]]]: + """Normalize bin specification to list form.""" + try: + if isinstance(bins, (Sequence, np.ndarray)) and len(bins) == dimensions: + return list(bins) + if isinstance(bins, int): + return [bins] * dimensions + raise BinError( + f"Bins must be integer or sequence of length {dimensions}" + ) + except TypeError as e: + raise BinError("Invalid bin specification") from e + +def _get_dimension_edges( + data: np.ndarray, + bins: Union[int, Sequence[float]], + bin_range: Optional[Sequence[float]], + dimension_idx: int +) -> np.ndarray: + """Calculate bin edges for one dimension.""" + from numpy.lib.histograms import _get_outer_edges + + if np.ndim(bins) == 0: + if bins < 1: + raise BinError( + f"Number of bins must be positive, got {bins} for dimension {dimension_idx}" + ) + + min_val, max_val = _get_outer_edges(data, bin_range) + return np.linspace(min_val, max_val, bins + 1) + + if np.ndim(bins) == 1: + edges = np.asarray(bins) + if not np.all(np.diff(edges) > 0): + raise BinError( + f"Bin edges must be monotonically increasing in dimension {dimension_idx}" + ) + return edges + + raise BinError( + f"Invalid bin specification for dimension {dimension_idx}" + ) + +def _calculate_bin_errors( + bin_content: np.ndarray, + x: np.ndarray, + y: Optional[np.ndarray], + weights: Optional[np.ndarray], + bins: BinType, + bin_range: RangeType, + error_mode: Union[BinErrorMode, str], + unweighted: bool, + norm_factor: float, + is_2d: bool = False +) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]: + """Calculate bin errors for histogram.""" + error_mode = BinErrorMode.parse(error_mode) + + if error_mode == BinErrorMode.AUTO: + unweighted = unweighted or np.allclose(weights, np.ones_like(x)) + error_mode = ( + BinErrorMode.POISSON if unweighted + else BinErrorMode.SUMW2 + ) + + if error_mode == BinErrorMode.POISSON: + if is_2d: + errlo, errhi = poisson_interval(bin_content.flatten()) + errors = ( + errlo.reshape(bin_content.shape), + errhi.reshape(bin_content.shape) + ) + else: + errors = poisson_interval(bin_content) + else: # SUMW2 + assert weights is not None + if is_2d: + bin_content_weight2, _, _ = np.histogram2d( + x, y, # type: ignore + bins=bins, + 
range=bin_range,
+                weights=weights**2
+            )
+        else:
+            bin_content_weight2, _ = np.histogram(
+                x,
+                bins=bins,
+                range=bin_range,
+                weights=weights**2
+            )
+        errors = np.sqrt(bin_content_weight2)
+
+    # norm_factor may be an array when divide_bin_width is used, so an
+    # element-wise comparison with np.any is required here
+    if np.any(norm_factor != 1):
+        if isinstance(errors, tuple):
+            errors = tuple(err / norm_factor for err in errors)
+        else:
+            errors /= norm_factor
+
+    return errors
+
+def histogram(
+    x: np.ndarray,
+    weights: Optional[np.ndarray] = None,
+    bins: BinType = 10,
+    bin_range: RangeType = None,
+    underflow: bool = False,
+    overflow: bool = False,
+    divide_bin_width: bool = False,
+    normalize: bool = True,
+    clip_weight: bool = False,
+    evaluate_error: bool = False,
+    error_mode: Union[BinErrorMode, str] = "auto"
+) -> Tuple[np.ndarray, np.ndarray, Optional[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]]:
+    """
+    Compute histogram of data.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Input data array
+    weights : Optional[np.ndarray], default None
+        Weights with the same shape as the input data. If not given,
+        the input data is assumed to have unit weights.
+    bins : BinType, default 10
+        Bin specification.
+        If integer, defines the number of equal-width bins in the given range.
+        If sequence, defines a monotonically increasing array of bin edges,
+        including the rightmost edge. Default is 10.
+    bin_range : RangeType, default None
+        Data range for binning
+    underflow : bool, default False
+        Include underflow in first bin
+    overflow : bool, default False
+        Include overflow in last bin
+    divide_bin_width : bool, default False
+        Divide the bin contents (and errors) by the bin width
+    normalize : bool, default True
+        Normalize total to unity
+    clip_weight : bool, default False
+        Ignore out-of-range data for normalization
+    evaluate_error : bool, default False
+        Calculate bin errors
+    error_mode : Union[BinErrorMode, str], default "auto"
+        Error calculation method;
+        If "sumw2", symmetric errors from the Wald approximation are
+        used (square root of sum of squares of weights). If "poisson",
+        asymmetric errors from the Poisson interval at one sigma are used.
+        If "auto", Poisson errors are used when the data has unit
+        weights, otherwise "sumw2" errors are used.
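+        Errors are evaluated on the raw (unnormalized) bin contents and
+        then scaled by the same normalization factor as the contents.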
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray, Optional[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]]
+        Tuple of (bin_content, bin_edges, bin_errors)
+    """
+    try:
+        x = get_clipped_data(
+            x,
+            bin_range=bin_range,
+            clip_lower=underflow,
+            clip_upper=overflow
+        )
+
+        unweighted = weights is None
+        if unweighted:
+            weights = np.ones_like(x)
+        else:
+            # cast to float to avoid integer overflow in weighted sums
+            weights = np.asarray(weights, dtype=float)
+
+        if normalize:
+            if clip_weight and bin_range is not None:
+                first_edge, last_edge = bin_range
+                mask = (x >= first_edge) & (x <= last_edge)
+                norm_factor = weights[mask].sum()
+            else:
+                norm_factor = weights.sum()
+        else:
+            norm_factor = 1
+
+        # make sure bin_content has int type when no weights are given
+        if unweighted:
+            bin_content, bin_edges = np.histogram(x, bins=bins, range=bin_range)
+        else:
+            bin_content, bin_edges = np.histogram(
+                x, bins=bins, range=bin_range, weights=weights
+            )
+
+        if divide_bin_width:
+            bin_widths = bin_edge_to_bin_width(bin_edges)
+            norm_factor *= bin_widths
+
+        bin_errors = None
+        if evaluate_error:
+            bin_errors = _calculate_bin_errors(
+                bin_content, x, None, weights, bins, bin_range,
+                error_mode, unweighted, norm_factor
+            )
+
+        # norm_factor is an array when divide_bin_width is True
+        if np.any(norm_factor != 1):
+            bin_content = bin_content.astype(float, copy=False)
+            bin_content /= norm_factor
+
+        return bin_content, bin_edges, bin_errors
+
+    except Exception as e:
+        raise HistogramError(f"Failed to compute histogram: {str(e)}") from e
+
+def histogram2d(
+    x: np.ndarray,
+    y: np.ndarray,
+    weights: Optional[np.ndarray] = None,
+    bins: Union[BinType, Sequence[BinType]] = 10,
+    bin_range: Union[RangeType, Sequence[RangeType]] = None,
+    underflow: bool = False,
+    overflow: bool = False,
+    divide_bin_width: bool = False,
+    normalize: bool = True,
+    clip_weight: bool = False,
+    evaluate_error: bool = False,
+    error_mode: Union[BinErrorMode, str] = "auto"
+) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]]:
+    """
+    Compute 2D histogram of data.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        X-coordinates of data points
+    y : np.ndarray
+        Y-coordinates of data points
+    weights : Optional[np.ndarray], default None
+        Weights with the same shape as input data. If not given, the
+        input data is assumed to have unit weights.
+    bins : Union[BinType, Sequence[BinType]], default 10
+        Bin specification for both axes
+        - If int, the number of bins for the two dimensions (nx=ny=bins).
+        - If array_like, the bin edges for the two dimensions (x_edges=y_edges=bins).
+        - If [int, int], the number of bins in each dimension (nx, ny = bins).
+        - If [array, array], the bin edges in each dimension (x_edges, y_edges = bins).
+        - A combination [int, array] or [array, int], where int is the number of bins and array is the bin edges.
+    bin_range : Union[RangeType, Sequence[RangeType]], default None
+        The leftmost and rightmost edges of the bins along each dimension: [[xmin, xmax], [ymin, ymax]].
+        Values outside of this range will be considered outliers and not tallied in the histogram.
+    underflow : bool, default False
+        Include underflow in first bins
+    overflow : bool, default False
+        Include overflow in last bins
+    divide_bin_width : bool, default False
+        Normalize by bin area; only used when normalize is True
+    normalize : bool, default True
+        Normalize the sum of weights to one
+    clip_weight : bool, default False
+        Ignore out-of-range data for normalization
+    evaluate_error : bool, default False
+        Calculate bin errors
+    error_mode : Union[BinErrorMode, str], default "auto"
+        Error calculation method;
+        If "sumw2", symmetric errors from the Wald approximation are
+        used (square root of sum of squares of weights). If "poisson",
+        asymmetric errors from the Poisson interval at one sigma are used.
+        If "auto", "poisson" errors are used when the data has unit
+        weights, else "sumw2" errors are used.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]]
+        Tuple of (bin_content, x_edges, y_edges, bin_errors)
+    """
+    try:
+        if len(x) != len(y):
+            raise ValueError("x and y must have same length")
+
+        bin_range = normalize_range(bin_range, dimension=2)
+        x = get_clipped_data(x, bin_range[0], underflow, overflow)
+        y = get_clipped_data(y, bin_range[1], underflow, overflow)
+
+        unweighted = weights is None
+        if weights is None:
+            weights = np.ones_like(x)
+        else:
+            # cast to float to avoid integer overflow in weighted sums
+            weights = np.asarray(weights, dtype=float)
+
+        if normalize:
+            if clip_weight:
+                mask = np.ones_like(x, dtype=bool)
+                if bin_range[0] is not None:
+                    first_edge, last_edge = bin_range[0]
+                    mask &= (x >= first_edge) & (x <= last_edge)
+                if bin_range[1] is not None:
+                    first_edge, last_edge = bin_range[1]
+                    mask &= (y >= first_edge) & (y <= last_edge)
+                norm_factor = weights[mask].sum()
+            else:
+                norm_factor = weights.sum()
+        else:
+            norm_factor = 1
+
+        if unweighted:
+            bin_content, x_edges, y_edges = np.histogram2d(
+                x, y, bins=bins, range=bin_range
+            )
+        else:
+            bin_content, x_edges, y_edges = np.histogram2d(
+                x, y, bins=bins, range=bin_range, weights=weights
+            )
+
+        if divide_bin_width:
+            x_widths = np.diff(x_edges)[:, np.newaxis]
+            y_widths = np.diff(y_edges)[np.newaxis, :]
+            norm_factor *= (x_widths * y_widths)
+
+        bin_errors = None
+        if evaluate_error:
+            bin_errors = _calculate_bin_errors(
+                bin_content, x, y, weights, bins, bin_range,
+                error_mode, unweighted, norm_factor, is_2d=True
+            )
+
+        # norm_factor is a 2D array when divide_bin_width is True
+        if np.any(norm_factor != 1):
+            bin_content = bin_content.astype(float, copy=False)
+            bin_content /= norm_factor
+
+        return bin_content, x_edges, y_edges, bin_errors
+
+    except Exception as e:
+        raise HistogramError(f"Failed to compute 2D histogram: {str(e)}") from e
+
+def get_sumw2(weights: np.ndarray) -> float:
+    """
+    Calculate the square root of the sum of squared weights.
+
+    Parameters
+    ----------
+    weights : np.ndarray
+        The weights to be squared and summed
+
+    Returns
+    -------
+    float
+        Square root of the sum of squared weights
+
+    Examples
+    --------
+    >>> weights = np.array([1, 2, 3])
+    >>> get_sumw2(weights)
+    3.7416573867739413
+    """
+    return np.sqrt(np.sum(weights ** 2))
+
+def get_hist_mean(x: np.ndarray, y: np.ndarray) -> float:
+    """
+    Calculate mean value of a histogram.
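# --- illustrative aside (not part of the patch) ----------------------------
# Why histogram2d guards the normalization with np.any: with
# divide_bin_width=True, norm_factor becomes a 2D array of bin areas via
# broadcasting. Standalone numpy sketch of that shape logic:
import numpy as np

x_edges = np.array([0.0, 1.0, 3.0])         # 2 bins of widths 1 and 2
y_edges = np.array([0.0, 2.0, 3.0, 4.0])    # 3 bins of widths 2, 1, 1
x_widths = np.diff(x_edges)[:, np.newaxis]  # shape (2, 1)
y_widths = np.diff(y_edges)[np.newaxis, :]  # shape (1, 3)
areas = x_widths * y_widths                 # shape (2, 3): one area per bin
print(areas)
# [[2. 1. 1.]
#  [4. 2. 2.]]
# ----------------------------------------------------------------------------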
+ + Parameters + ---------- + x : np.ndarray + Bin centers + y : np.ndarray + Bin contents + + Returns + ------- + float + Mean value of histogram + + Examples + -------- + >>> x = np.array([0.5, 1.5, 2.5]) + >>> y = np.array([1, 2, 1]) + >>> get_hist_mean(x, y) + 1.5 + """ + return np.sum(x * y) / np.sum(y) + +def get_hist_std(x: np.ndarray, y: np.ndarray) -> float: + """ + Calculate standard deviation of a histogram. + + Parameters + ---------- + x : np.ndarray + Bin centers + y : np.ndarray + Bin contents + + Returns + ------- + float + Standard deviation of histogram + + Examples + -------- + >>> x = np.array([0.5, 1.5, 2.5]) + >>> y = np.array([1, 2, 1]) + >>> get_hist_std(x, y) + 0.7071067811865476 + """ + mean = get_hist_mean(x, y) + count = np.sum(y) + if count == 0.0: + return 0.0 + # for negative stddev (e.g. when having negative weights) - return std=0 + std2 = np.max([np.sum(y * (x - mean) ** 2) / count, 0.0]) + return np.sqrt(std2) + +def get_hist_effective_entries(y: np.ndarray, yerr: np.ndarray) -> float: + """ + Calculate effective number of entries in histogram. + + Parameters + ---------- + y : np.ndarray + Bin contents + yerr : np.ndarray + Bin uncertainties + + Returns + ------- + float + Number of effective entries + + Examples + -------- + >>> y = np.array([1, 2, 1]) + >>> yerr = np.array([0.5, 0.5, 0.5]) + >>> get_hist_effective_entries(y, yerr) + 21.333333333333332 + """ + sumw2 = np.sum(yerr ** 2) + if sumw2 != 0.0: + return (np.sum(y) ** 2) / sumw2 + return 0.0 + +def get_hist_mean_error( + x: np.ndarray, + y: np.ndarray, + yerr: np.ndarray +) -> float: + """ + Calculate error on histogram mean. + + Parameters + ---------- + x : np.ndarray + Bin centers + y : np.ndarray + Bin contents + yerr : np.ndarray + Bin uncertainties + + Returns + ------- + float + Error on mean + + Examples + -------- + >>> x = np.array([0.5, 1.5, 2.5]) + >>> y = np.array([1, 2, 1]) + >>> yerr = np.array([0.5, 0.5, 0.5]) + >>> get_hist_mean_error(x, y, yerr) + 0.15309310892394865 + """ + neff = get_hist_effective_entries(y, yerr) + if neff > 0.0: + std = get_hist_std(x, y) + return std / np.sqrt(neff) + return 0.0 + +def get_cumul_hist( + y: np.ndarray, + yerr: np.ndarray +) -> Tuple[np.ndarray, np.ndarray]: + """ + Calculate cumulative histogram and uncertainties. + + Parameters + ---------- + y : np.ndarray + Bin contents + yerr : np.ndarray + Bin uncertainties + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + Cumulative contents and uncertainties + + Examples + -------- + >>> y = np.array([1, 2, 1]) + >>> yerr = np.array([0.5, 0.5, 0.5]) + >>> get_cumul_hist(y, yerr) + (array([1, 3, 4]), array([0.5, 0.70710678, 0.8660254])) + """ + y_cum = np.cumsum(y) + yerr_cum = np.sqrt(np.cumsum(yerr ** 2)) + return y_cum, yerr_cum + +def get_bin_centers_from_range( + xlow: Real, + xhigh: Real, + nbins: int, + bin_precision: Optional[int] = None +) -> np.ndarray: + """ + Calculate bin centers for given range and number of bins. 
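# --- illustrative aside (not part of the patch) ----------------------------
# The statistics helpers above mirror ROOT's TH1 conventions:
# n_eff = (sum y)^2 / (sum yerr^2) and err(mean) = std / sqrt(n_eff).
# Pure-numpy cross-check of the docstring numbers:
import numpy as np

x = np.array([0.5, 1.5, 2.5])      # bin centers
y = np.array([1.0, 2.0, 1.0])      # bin contents
yerr = np.array([0.5, 0.5, 0.5])   # bin uncertainties

mean = np.sum(x * y) / np.sum(y)                        # 1.5
std = np.sqrt(np.sum(y * (x - mean) ** 2) / np.sum(y))  # 0.7071...
n_eff = np.sum(y) ** 2 / np.sum(yerr ** 2)              # 21.333...
print(std / np.sqrt(n_eff))                             # 0.15309...
# ----------------------------------------------------------------------------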
+
+    Parameters
+    ----------
+    xlow : Real
+        Lower bound of range
+    xhigh : Real
+        Upper bound of range
+    nbins : int
+        Number of bins
+    bin_precision : Optional[int], default None
+        Precision for rounding bin centers
+
+    Returns
+    -------
+    np.ndarray
+        Array of bin centers
+
+    Examples
+    --------
+    >>> get_bin_centers_from_range(0, 10, 5)
+    array([1., 3., 5., 7., 9.])
+    """
+    if nbins <= 0:
+        raise ValueError("Number of bins must be positive")
+    if xlow >= xhigh:
+        raise ValueError("Upper bound must be greater than lower bound")
+
+    bin_width = (xhigh - xlow) / nbins
+    low_center = xlow + bin_width / 2
+    high_center = xhigh - bin_width / 2
+    centers = np.linspace(low_center, high_center, nbins)
+
+    bin_precision = bin_precision or CONFIG.bin_precision
+    centers = np.around(centers, bin_precision)
+    return centers
+
+def get_histogram_mask(
+    x: np.ndarray,
+    condition: HistoMaskType,
+    y: Optional[np.ndarray] = None
+) -> np.ndarray:
+    """
+    Create mask for histogram data based on condition.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Primary data array
+    condition : HistoMaskType
+        Condition for masking (array of bounds or callable)
+    y : Optional[np.ndarray], default None
+        Secondary data array for 2D conditions
+
+    Returns
+    -------
+    np.ndarray
+        Boolean mask array
+
+    Raises
+    ------
+    ValueError
+        If data arrays have incompatible shapes or condition is invalid
+    TypeError
+        If condition type is unsupported
+
+    Examples
+    --------
+    >>> x = np.array([1, 2, 3, 4, 5])
+    >>> mask = get_histogram_mask(x, (2, 4))
+    >>> x[mask]
+    array([3])
+    """
+    try:
+        if y is not None and len(x) != len(y):
+            raise ValueError("x and y arrays must have same length")
+
+        if callable(condition):
+            return _apply_callable_mask(x, y, condition)
+
+        return _apply_range_mask(x, y, condition)
+
+    except Exception as e:
+        if isinstance(e, (ValueError, TypeError)):
+            raise
+        raise TypeError(f"Invalid mask condition: {condition}") from e
+
+def _apply_callable_mask(
+    x: np.ndarray,
+    y: Optional[np.ndarray],
+    condition: Callable
+) -> np.ndarray:
+    """Apply callable condition element-wise to create mask."""
+    if y is None:
+        return np.asarray([condition(xi) for xi in x])
+    return np.asarray([condition(xi, yi) for xi, yi in zip(x, y)])
+
+def _apply_range_mask(
+    x: np.ndarray,
+    y: Optional[np.ndarray],
+    condition: ArrayLike
+) -> np.ndarray:
+    """Apply range condition to create mask."""
+    condition = np.asarray(condition)
+
+    if len(condition) == 2:
+        xmin, xmax = condition
+        return (x > xmin) & (x < xmax)
+
+    if len(condition) == 4 and y is not None:
+        xmin, xmax, ymin, ymax = condition
+        return (x > xmin) & (x < xmax) & (y > ymin) & (y < ymax)
+
+    raise ValueError(
+        "Range condition must be (xmin, xmax) or (xmin, xmax, ymin, ymax)"
+    )
+
+def select_binned_data(
+    mask: np.ndarray,
+    x: np.ndarray,
+    y: np.ndarray,
+    xerr: Optional[ArrayLike] = None,
+    yerr: Optional[ArrayLike] = None
+) -> Tuple[np.ndarray, np.ndarray, Optional[ArrayLike], Optional[ArrayLike]]:
+    """
+    Select data points from binned data using mask.
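# --- illustrative aside (not part of the patch) ----------------------------
# Usage sketch for get_histogram_mask with both condition flavours, assuming
# it is exposed from quickstats.maths.histograms (import path is a guess).
import numpy as np
from quickstats.maths.histograms import get_histogram_mask

x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

# Range condition: exclusive bounds (xmin, xmax)
print(x[get_histogram_mask(x, (2, 4))])            # [3.]

# Callable condition: evaluated point by point
print(x[get_histogram_mask(x, lambda v: v >= 4)])  # [4. 5.]
# ----------------------------------------------------------------------------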
+ + Parameters + ---------- + mask : np.ndarray + Boolean mask array + x : np.ndarray + Bin centers or x-coordinates + y : np.ndarray + Bin contents or y-values + xerr : Optional[ArrayLike], default None + X-axis uncertainties + yerr : Optional[ArrayLike], default None + Y-axis uncertainties + + Returns + ------- + Tuple[np.ndarray, np.ndarray, Optional[ArrayLike], Optional[ArrayLike]] + Selected data points and uncertainties + + Examples + -------- + >>> x = np.array([1, 2, 3]) + >>> y = np.array([10, 20, 30]) + >>> mask = np.array([True, False, True]) + >>> select_binned_data(mask, x, y) + (array([1, 3]), array([10, 30]), None, None) + """ + try: + x_sel, y_sel = np.asarray(x)[mask], np.asarray(y)[mask] + xerr_sel = _select_errors(xerr, mask) + yerr_sel = _select_errors(yerr, mask) + return x_sel, y_sel, xerr_sel, yerr_sel + except Exception as e: + raise HistogramError(f"Failed to select binned data: {str(e)}") from e + +def _select_errors( + err: Optional[ArrayLike], + mask: np.ndarray +) -> Optional[Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]]: + """Select error values based on mask.""" + if err is None: + return None + + if isinstance(err, (list, tuple, np.ndarray)): + if np.ndim(err) == 2 and np.shape(err)[0] == 2: + return (_select_errors(err[0], mask), + _select_errors(err[1], mask)) + return np.asarray(err)[mask] + + return err + +def dataset_is_binned( + x: np.ndarray, + y: np.ndarray, + xlow: float, + xhigh: float, + nbins: int, + ghost_threshold: Optional[float] = None, + bin_precision: Optional[int] = None +) -> bool: + """ + Check if dataset matches expected binning. + + Parameters + ---------- + x : np.ndarray + Data x-coordinates or bin centers + y : np.ndarray + Data y-values or bin contents + xlow : float + Lower bound of range + xhigh : float + Upper bound of range + nbins : int + Number of bins + ghost_threshold : Optional[float], default None + Threshold for identifying ghost bins + bin_precision : Optional[int], default None + Precision for bin center comparison + + Returns + ------- + bool + True if dataset matches expected binning + + Raises + ------ + HistogramError + If binning validation fails + + Examples + -------- + >>> x = np.array([1., 3., 5.]) + >>> y = np.array([1, 1, 1]) + >>> dataset_is_binned(x, y, 0, 6, 3) + True + """ + try: + ghost_threshold = ghost_threshold or CONFIG.ghost_threshold + bin_precision = bin_precision or CONFIG.bin_precision + + bin_centers = get_bin_centers_from_range( + xlow, xhigh, nbins, bin_precision + ) + x = np.around(x, bin_precision) + + # Check for matching bins + if len(x) == len(bin_centers) and np.allclose( + bin_centers, x, rtol=CONFIG.rtol, atol=CONFIG.atol + ): + return True + + # Check for unit weights + if np.allclose(y, 1.0): + return False + + # Check for ghost bins + y_no_ghost = y[y > ghost_threshold] + # First check if all events have unit weight (i.e. 
unbinned data)
+        # Second check if all events have the same scaled weight (in case of lumi scaling)
+        if (np.allclose(y_no_ghost, 1.0) or
+            np.allclose(y_no_ghost, y_no_ghost[0])):
+            return False
+
+        # Check for subset relationship
+        if len(x) == len(bin_centers) or array_issubset(bin_centers, x):
+            return True
+
+        raise HistogramError("Invalid binning detected")
+
+    except Exception as e:
+        if isinstance(e, HistogramError):
+            raise
+        raise HistogramError(f"Failed to validate binning: {str(e)}") from e
+
+def fill_missing_bins(
+    x: np.ndarray,
+    y: np.ndarray,
+    xlow: float,
+    xhigh: float,
+    nbins: int,
+    value: float = 0,
+    bin_precision: Optional[int] = None
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Fill missing bins in histogram data.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Bin centers or x-coordinates
+    y : np.ndarray
+        Bin contents or y-values
+    xlow : float
+        Lower bound of range
+    xhigh : float
+        Upper bound of range
+    nbins : int
+        Number of bins
+    value : float, default 0
+        Value to fill in missing bins
+    bin_precision : Optional[int], default None
+        Precision for bin center comparison
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray]
+        Complete arrays with filled missing bins
+
+    Examples
+    --------
+    >>> x = np.array([1., 3.])
+    >>> y = np.array([10, 30])
+    >>> fill_missing_bins(x, y, 0.5, 3.5, 3)
+    (array([1., 2., 3.]), array([10.,  0., 30.]))
+    """
+    try:
+        bin_precision = bin_precision or CONFIG.bin_precision
+        bin_centers = get_bin_centers_from_range(
+            xlow, xhigh, nbins, bin_precision
+        )
+        x_rounded = np.around(x, bin_precision)
+
+        missing_bins = np.setdiff1d(bin_centers, x_rounded)
+        missing_values = np.full_like(missing_bins, value)
+
+        x_filled = np.concatenate([x, missing_bins])
+        y_filled = np.concatenate([y, missing_values])
+
+        sort_idx = np.argsort(x_filled)
+        return x_filled[sort_idx], y_filled[sort_idx]
+
+    except Exception as e:
+        raise HistogramError(
+            f"Failed to fill missing bins: {str(e)}"
+        ) from e
+
+def rebin_dataset(
+    x: np.ndarray,
+    y: np.ndarray,
+    nbins: int
+) -> Tuple[np.ndarray, np.ndarray]:
+    """
+    Rebin histogram data to new bin count.
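# --- illustrative aside (not part of the patch) ----------------------------
# How fill_missing_bins pairs with get_bin_centers_from_range: the
# (xlow, xhigh, nbins) triple must reproduce the bin centers of the partial
# data. Assumes both helpers are exposed from quickstats.maths.histograms.
import numpy as np
from quickstats.maths.histograms import (
    get_bin_centers_from_range, fill_missing_bins
)

# Three unit-width bins on [0.5, 3.5] have centers [1., 2., 3.]
print(get_bin_centers_from_range(0.5, 3.5, 3))   # [1. 2. 3.]

# Data with the center-2 bin absent; it is filled with the default value 0
x = np.array([1.0, 3.0])
y = np.array([10.0, 30.0])
print(fill_missing_bins(x, y, 0.5, 3.5, 3))
# (array([1., 2., 3.]), array([10.,  0., 30.]))
# ----------------------------------------------------------------------------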
+
+    Parameters
+    ----------
+    x : np.ndarray
+        Current bin centers
+    y : np.ndarray
+        Current bin contents
+    nbins : int
+        New number of bins
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray]
+        Rebinned data arrays
+
+    Raises
+    ------
+    HistogramError
+        If rebinning fails
+    """
+    try:
+        bin_edges = bin_center_to_bin_edge(x)
+        from quickstats.interface.root import TH1
+        hist = TH1.from_numpy_histogram(y, bin_edges=bin_edges)
+        hist.rebin(nbins)
+        return hist.bin_center, hist.bin_content
+
+    except Exception as e:
+        raise HistogramError(f"Failed to rebin dataset: {str(e)}") from e
\ No newline at end of file
diff --git a/quickstats/maths/numerics.py b/quickstats/maths/numerics.py
index d27fc2250447a56c8b79a6e4d213d63cd4d4e22d..9140a62265de2dade9aa69d4828cb62f53db4fe0 100644
--- a/quickstats/maths/numerics.py
+++ b/quickstats/maths/numerics.py
@@ -397,4 +397,111 @@ def is_integer(x: Real) -> bool:
     >>> is_integer(np.uint32(10))
     True
     """
-    return x == int(x)
\ No newline at end of file
+    return x == int(x)
+
+def min_max_to_range(min_val:Optional[float]=None, max_val:Optional[float]=None):
+    if (min_val is None) and (max_val is None):
+        return None
+    if (min_val is not None) and (max_val is not None):
+        return (min_val, max_val)
+    raise ValueError("min_val and max_val must both be None or both be given")
+
+def pivot_table(x: np.ndarray, y: np.ndarray, z: np.ndarray, missing: Any = np.nan) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+    """
+    Creates a pivot table from x, y, and z arrays by aggregating z values
+    based on the unique pairs of x and y coordinates.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        1D array of x coordinates (row labels). Duplicates are allowed.
+    y : np.ndarray
+        1D array of y coordinates (column labels). Duplicates are allowed.
+    z : np.ndarray
+        1D array of values associated with each (x, y) pair.
+    missing : Any, optional
+        The value to use for missing entries in the pivot table, by default np.nan.
+
+    Returns
+    -------
+    Tuple[np.ndarray, np.ndarray, np.ndarray]
+        A tuple containing:
+        - `X`: A 2D array where rows correspond to unique sorted values of `x`.
+        - `Y`: A 2D array where columns correspond to unique sorted values of `y`.
+        - `Z`: A 2D array (pivot table) where values are taken from `z`.
+          If there are no values for certain (x, y) pairs, the cell will be filled with `missing`.
+
+    Raises
+    ------
+    ValueError
+        If the lengths of `x`, `y`, and `z` do not match.
+
+    Examples
+    --------
+    >>> x = np.array([1, 1, 2, 2])
+    >>> y = np.array(['A', 'B', 'A', 'B'])
+    >>> z = np.array([10, 20, 30, 40])
+    >>> X, Y, Z = pivot_table(x, y, z)
+    >>> Z
+    array([[10., 20.],
+           [30., 40.]])
+    """
+    if len(x) != len(y) or len(y) != len(z):
+        raise ValueError("x, y, and z must have the same size")
+
+    X_unique, X_inv = np.unique(x, return_inverse=True)
+    Y_unique, Y_inv = np.unique(y, return_inverse=True)
+    # 'ij' indexing keeps X, Y aligned with Z's (row = x, column = y) layout
+    X, Y = np.meshgrid(X_unique, Y_unique, indexing='ij')
+
+    Z = np.full((X_unique.size, Y_unique.size), missing)
+    Z[X_inv, Y_inv] = z
+
+    return X, Y, Z
+
+def get_nan_shapes(x: np.ndarray, y: np.ndarray, z: np.ndarray, alpha: float = 0.) -> List["Polygon"]:
+    """
+    Generates alpha shapes (concave hulls) for points where `z` is NaN, using the provided `x` and `y` coordinates.
+
+    Parameters
+    ----------
+    x : np.ndarray
+        1D array of x coordinates.
+    y : np.ndarray
+        1D array of y coordinates.
+    z : np.ndarray
+        1D array of values associated with the (x, y) points. NaN values indicate where the shape should be generated.
+    alpha : float, optional
+        Alpha parameter for the alpha shape algorithm. With alpha=0 the result
+        is the convex hull; larger values produce more detailed (more concave)
+        shapes, by default 0.
+
+    Returns
+    -------
+    List[Polygon]
+        A list of Shapely Polygon objects representing the alpha shapes of the regions where `z` is NaN.
+        If only one shape is generated, it will still be returned as a list.
+
+    Examples
+    --------
+    >>> x = np.array([0., 1., 2., 3.])
+    >>> y = np.array([0., 2., 0., 3.])
+    >>> z = np.array([np.nan, np.nan, np.nan, 4.])
+    >>> shapes = get_nan_shapes(x, y, z)
+    >>> len(shapes)
+    1
+    """
+    from alphashape import alphashape
+    from shapely.geometry import Polygon, MultiPolygon
+
+    # Mask NaN values in z and stack corresponding x and y coordinates
+    mask = np.isnan(z)
+    xy = np.column_stack((x[mask], y[mask]))
+
+    # Generate the alpha shape for the masked points
+    shape = alphashape(xy, alpha=alpha)
+
+    # Return the shape(s) as a list of Polygons
+    if isinstance(shape, MultiPolygon):
+        return list(shape.geoms)
+    elif isinstance(shape, Polygon):
+        return [shape]
+    else:
+        return []
\ No newline at end of file
diff --git a/quickstats/maths/statistics.py b/quickstats/maths/statistics.py
index 50c362eaf467c2b3ffa01bfd9234ab0163b180c3..67edc10e4fa5b7e328c5856ba58186e40468be5e 100644
--- a/quickstats/maths/statistics.py
+++ b/quickstats/maths/statistics.py
@@ -102,7 +102,69 @@ def calculate_chi2(data_obs, data_exp, error_obs=None, threshold:float=3, epsilo
     }
     return result
 
-def sigma_to_confidence_level(nsigma: int) -> float:
+def pvalue_to_significance(pvalues: ArrayLike) -> np.ndarray:
+    """
+    Converts an array of p-values into significance values (Z-scores).
+
+    The function attempts to use `scipy` if available. If `scipy` is not
+    available, it falls back on `ROOT` to compute the significance.
+
+    Parameters
+    ----------
+    pvalues : ArrayLike
+        Array-like object of p-values to convert into significance.
+
+    Returns
+    -------
+    np.ndarray
+        Array of significance values corresponding to the input p-values.
+
+    Examples
+    --------
+    >>> pvalue_to_significance([0.05, 0.01])
+    array([1.64485363, 2.32634787])
+    """
+    pvalues = np.asarray(pvalues)
+
+    if module_exist('scipy'):
+        scipy = cached_import('scipy')
+        return scipy.stats.norm.isf(pvalues)
+
+    ROOT = cached_import('ROOT')
+    return np.array([ROOT.RooStats.PValueToSignificance(pvalue) for pvalue in pvalues])
+
+def significance_to_pvalue(significances: ArrayLike) -> np.ndarray:
+    """
+    Converts an array of significance values (Z-scores) into p-values.
+
+    The function attempts to use `scipy` if available. If `scipy` is not
+    available, it falls back on `ROOT` to compute the p-values.
+
+    Parameters
+    ----------
+    significances : ArrayLike
+        Array-like object of significance values to convert into p-values.
+
+    Returns
+    -------
+    np.ndarray
+        Array of p-values corresponding to the input significance values.
+
+    Examples
+    --------
+    >>> significance_to_pvalue([1.96, 2.33])
+    array([0.0249979 , 0.00990308])
+    """
+    significances = np.asarray(significances)
+
+    if module_exist('scipy'):
+        scipy = cached_import('scipy')
+        return scipy.stats.norm.sf(significances)
+
+    ROOT = cached_import('ROOT')
+    return np.array([1 - ROOT.Math.normal_cdf(significance, 1, 0) for significance in significances])
+
+def sigma_to_confidence_level(nsigma: float) -> float:
     """
     Convert a number of standard deviations (sigma) to a confidence level.
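# --- illustrative aside (not part of the patch) ----------------------------
# Round-trip sketch for the vectorized conversions added above, using the
# scipy branch: norm.isf and norm.sf are inverses of each other.
import numpy as np
from scipy import stats

pvals = np.array([0.05, 0.01, 1e-7])
z = stats.norm.isf(pvals)   # p-value -> significance, approx [1.645 2.326 5.199]
back = stats.norm.sf(z)     # significance -> p-value
assert np.allclose(back, pvals)
# ----------------------------------------------------------------------------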
@@ -112,7 +174,7 @@ def sigma_to_confidence_level(nsigma: int) -> float:
 
     Parameters
     ----------
-    nsigma : int
+    nsigma : float
         The number of standard deviations (sigma) to compute the confidence level for.
 
     Returns
@@ -141,6 +203,49 @@ def sigma_to_confidence_level(nsigma: int) -> float:
     ROOT = cached_import('ROOT')
     return ROOT.Math.erf(nsigma / np.sqrt(2.0))
 
+def confidence_level_to_chi2(q: float, k: int = 1) -> float:
+    """
+    Convert a confidence level to a chi-squared value.
+
+    Parameters
+    ----------
+    q : float
+        Confidence level (quantile).
+    k : int
+        Degrees of freedom of the chi-squared distribution.
+
+    Returns
+    -------
+    float
+        The chi-squared value at the given confidence level.
+    """
+    if module_exist('scipy'):
+        scipy = cached_import('scipy')
+        chi2 = scipy.stats.chi2.ppf(q, df=k)
+    else:
+        ROOT = cached_import('ROOT')
+        chi2 = ROOT.Math.chisquared_quantile(q, k)
+    return np.round(chi2, 8)
+
+def sigma_to_chi2(nsigma: float, k: int = 1) -> float:
+    """
+    Convert a number of standard deviations (sigma) to a chi-squared value.
+
+    Parameters
+    ----------
+    nsigma : float
+        The number of standard deviations (sigma) to compute the chi-squared value for.
+    k : int
+        Degrees of freedom of the chi-squared distribution.
+
+    Returns
+    -------
+    float
+        The chi-squared value at the given number of standard deviations.
+    """
+    q = sigma_to_confidence_level(nsigma)
+    return confidence_level_to_chi2(q, k)
+
 def poisson_interval(data: ArrayLike, nsigma: int = 1, offset: bool = True) -> ArrayLike:
     """
     Calculate the Poisson error interval for binned data.
@@ -613,6 +718,7 @@ def histogram(x:np.ndarray, weights:Optional[np.ndarray]=None,
         bin_errors = None
 
     if norm_factor != 1:
+        bin_content = bin_content.astype(float, copy=False)
         bin_content /= norm_factor
 
     return bin_content, bin_edges, bin_errors
@@ -975,11 +1081,6 @@ def select_binned_data(mask:np.ndarray, x:np.ndarray, y:np.ndarray,
     xerr, yerr = select_err(xerr), select_err(yerr)
     return x, y, xerr, yerr
 
-def pvalue_to_significance(pvalue:float):
-    import ROOT
-    significance = ROOT.RooStats.PValueToSignificance(pvalue)
-    return significance
-
 def dataset_is_binned(x:np.ndarray, y:np.ndarray, xlow:float, xhigh:float, nbins:int,
                       ghost_threshold:float=1e-8, bin_precision:int=8):
     bin_centers = get_bin_centers_from_range(xlow, xhigh, nbins, bin_precision=bin_precision)
@@ -1058,19 +1159,16 @@ def get_hist_comparison_data(reference_data, target_data,
 def get_global_pvalue_significance(x:np.ndarray, pvalue_local:Optional[np.ndarray]=None,
                                    Z_local:Optional[np.ndarray]=None, Z_ref:float=0):
-    import ROOT
-    def pval_to_Z(pvals):
-        return np.array([ROOT.RooStats.PValueToSignificance(pval) for pval in pvals])
     if (pvalue_local is None) and (Z_local is None):
         raise ValueError('either pvalue_local or Z_local must be provided')
     elif (pvalue_local is not None) and (Z_local is not None):
         raise ValueError('can not specify both pvalue_local and Z_local')
     elif (pvalue_local is not None) and (Z_local is None):
-        Z_local = pval_to_Z(pvalue_local)
+        Z_local = pvalue_to_significance(pvalue_local)
         pvalue_local = np.array(pvalue_local)
     elif (pvalue_local is None) and (Z_local is not None):
         Z_local = np.array(Z_local)
-        pvalue_local = np.array([1 - ROOT.Math.normal_cdf(s, 1, 0) for s in Z_local])
+        pvalue_local = significance_to_pvalue(Z_local)
     sort_idx = np.argsort(x)
     x = x[sort_idx]
     Z_local = Z_local[sort_idx]
@@ -1083,11 +1181,11 @@ def get_global_pvalue_significance(x:np.ndarray, pvalue_local:Optional[np.ndarra
     exp_term = np.exp(-0.5*(Z_local**2 - Z_ref**2))
     p_global = N_up * exp_term +
pvalue_local delta_p_global = np.sqrt(N_up) * exp_term - Z_global = pval_to_Z(p_global) - Z_global_delta_up = pval_to_Z(p_global + delta_p_global) + Z_global = pvalue_to_significance(p_global) + Z_global_delta_up = pvalue_to_significance(p_global + delta_p_global) mask = ~np.isinf(Z_global_delta_up) Z_global_errhi = np.where(mask, np.subtract(Z_global, Z_global_delta_up, where=mask), np.nan) - Z_global_delta_down = pval_to_Z(p_global - delta_p_global) + Z_global_delta_down = pvalue_to_significance(p_global - delta_p_global) mask = ~np.isinf(Z_global_delta_down) Z_global_errlo = np.where(mask, np.subtract(Z_global_delta_down, Z_global, where=mask), np.nan) result = { @@ -1101,9 +1199,9 @@ def get_global_pvalue_significance(x:np.ndarray, pvalue_local:Optional[np.ndarra return result HistoMaskType = Union[ArrayLike, Callable] -def get_histogram_mask(x:np.ndarray, - condition:HistoMaskType, - y:Optional[np.ndarray]=None) -> np.ndarray: +def get_histogram_mask(x: np.ndarray, + condition: HistoMaskType, + y: Optional[np.ndarray]=None) -> np.ndarray: if (y is not None) and (len(x) != len(y)): raise ValueError('x and y values must have the same size') mask = np.full(x.shape[:1], False) @@ -1112,18 +1210,20 @@ def get_histogram_mask(x:np.ndarray, mask |= np.array(list(map(condition, x))) else: mask |= np.array(list(map(condition, x, y))) - elif isinstance(condition, ArrayLike): - if len(codnition) == 2: - xmin, xmax = condition - mask |= ((x > xmin) & (x < xmax)) - elif len(condition) == 4: - xmin, xmax, ymin, ymax = condition - mask |= (x > xmin) & (x < xmax) & (y > ymin) & (y < ymax) - else: - raise ValueError("Range based mask condition must be in the form " - "(xmin, xmax) or (xmin, xmax, ymin, ymax)") else: - raise TypeError(f'Invalid mask condition: {condition}') + try: + condition = np.asarray(condition) + if len(condition) == 2: + xmin, xmax = condition + mask |= ((x > xmin) & (x < xmax)) + elif len(condition) == 4: + xmin, xmax, ymin, ymax = condition + mask |= (x > xmin) & (x < xmax) & (y > ymin) & (y < ymax) + else: + raise ValueError("range based mask condition must be in the form " + "(xmin, xmax) or (xmin, xmax, ymin, ymax)") + except (TypeError, ValueError): + raise TypeError(f'invalid mask condition: {condition}') return mask def upcast_error(size:int, values:Union[float, ArrayLike, None]=None) -> Union[np.ndarray, None]: diff --git a/quickstats/plots/__init__.py b/quickstats/plots/__init__.py index ca2e329301fe3c671439b200a2cacda077063a73..71a7d937ac406ea79e064a5c5165259c5527fb05 100644 --- a/quickstats/plots/__init__.py +++ b/quickstats/plots/__init__.py @@ -1,7 +1,11 @@ import quickstats +from . import template_styles +from . 
import template_analysis_label_options + from .core import * from .color_schemes import * +from .registry import Registry from .abstract_plot import AbstractPlot from .collective_data_plot import CollectiveDataPlot from .stat_plot_config import * @@ -9,6 +13,7 @@ from .hypotest_inverter_plot import HypoTestInverterPlot from .variable_distribution_plot import VariableDistributionPlot from .score_distribution_plot import ScoreDistributionPlot from .test_statistic_distribution_plot import TestStatisticDistributionPlot +from .histogram_plot import HistogramPlot from .general_1D_plot import General1DPlot from .two_axis_1D_plot import TwoAxis1DPlot from .general_2D_plot import General2DPlot diff --git a/quickstats/plots/abstract_plot.py b/quickstats/plots/abstract_plot.py index cf08414496b1fbf09c8336106603f08761c1985d..d8fe88d1fbeaa891040e04109fb726c139c7aab7 100644 --- a/quickstats/plots/abstract_plot.py +++ b/quickstats/plots/abstract_plot.py @@ -1,336 +1,986 @@ -from typing import Optional, Union, Dict, List, Tuple, Callable, Sequence -from cycler import cycler +""" +Enhanced plotting utilities with customizable styles, colors, labels, and annotations. + +This module provides a flexible base class for creating plots with rich customization +options for styles, colors, labels, and annotations. It supports both single plots +and ratio plots with comprehensive configuration capabilities. +""" + +from __future__ import annotations + +from typing import ( + Optional, Union, Dict, List, Tuple, Callable, Sequence, Any, + TypeVar, cast +) +from collections import defaultdict from itertools import cycle +from copy import deepcopy +from dataclasses import dataclass import numpy as np import matplotlib.pyplot as plt +from cycler import cycler +from matplotlib.artist import Artist +from matplotlib.axes import Axes +from matplotlib.legend import Legend + +from quickstats import AbstractObject, NamedTreeNode +from quickstats.core import mappings as mp +from quickstats.utils.common_utils import insert_periodic_substr +from quickstats.maths.histograms import HistComparisonMode -from quickstats import AbstractObject, semistaticmethod -from quickstats.plots import get_color_cycle, get_cmap -from quickstats.plots.color_schemes import QUICKSTATS_PALETTES -from quickstats.plots.template import ( - single_frame, ratio_frame, - parse_styles, format_axis_ticks, - parse_analysis_label_options, centralize_axis, - create_transform, draw_multiline_text, - resolve_handle_label, get_axis_limits, - CUSTOM_HANDLER_MAP, TEMPLATE_STYLES) -from quickstats.utils.common_utils import combine_dict, insert_periodic_substr -from quickstats.maths.statistics import bin_center_to_bin_edge, get_hist_comparison_data, select_binned_data -from quickstats.maths.statistics import HistComparisonMode, get_histogram_mask -from quickstats.maths.numerics import get_subsequences +from . 
import template_styles, template_analysis_label_options from .core import PlotFormat, ErrorDisplayFormat +from .colors import ( + ColorType, + ColormapType, + get_color_cycle, + get_cmap, +) +from .template import ( + single_frame, + ratio_frame, + format_axis_ticks, + centralize_axis, + draw_multiline_text, + resolve_handle_label, + get_axis_limits, + CUSTOM_HANDLER_MAP, +) -class AbstractPlot(AbstractObject): - - COLOR_CYCLE = "default" - - COLOR_PALLETE = {} - COLOR_PALLETE_SEC = {} +# Type variables for better type hints +T = TypeVar('T') +StylesType = Union[Dict[str, Any], str] +DomainType = Union[List[str], str] - LABEL_MAP = {} - - STYLES = {} - - CONFIG = { +class PlottingError(Exception): + """Base exception for plotting-related errors.""" + pass + +@dataclass +class Point: + """Data structure for plot points.""" + x: float + y: float + label: Optional[str] = None + name: Optional[str] = None + styles: Optional[Dict[str, Any]] = None + +@dataclass +class Annotation: + """Data structure for plot annotations.""" + text: str + options: Dict[str, Any] + +@dataclass +class LegendEntry: + """Data structure for legend entries.""" + handle: Any # Could be Artist, tuple/list of Artists, Collection, etc. + label: str + + def to_dict(self) -> Dict[str, Any]: + """Convert entry to dictionary format.""" + return { + "handle": self.handle, + "label": self.label + } + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> "LegendEntry": + """Create entry from dictionary format.""" + return cls( + handle=data["handle"], + label=data["label"] + ) + + def has_valid_label(self) -> bool: + """Check if entry has a valid label for legend.""" + return bool(self.label and not self.label.startswith('_')) + +class AbstractPlot(AbstractObject): + """ + A base class for creating plots with customizable styles, colors, labels, and annotations. + + This class provides a foundation for creating plots with rich customization options, + supporting both single plots and ratio plots. It handles styles, colors, labels, + and annotations with comprehensive configuration capabilities. 
+ + Parameters + ---------- + color_map : Optional[Dict[str, ColorType]], default None + Mapping of labels to colors + color_cycle : Optional[ColormapType], default None + Color cycle for sequential coloring + label_map : Optional[Dict[str, str]], default None + Mapping of internal labels to display labels + styles : Optional[StylesType], default None + Global styles for plot elements + config : Optional[Dict[str, Any]], default None + Plot configuration parameters + styles_map : Optional[Dict[str, StylesType]], default None + Target-specific style updates + config_map : Optional[Dict[str, Dict[str, Any]]], default None + Target-specific configuration updates + analysis_label_options : Optional[Union[str, Dict[str, Any]]], default None + Options for analysis labels + figure_index : Optional[int], default None + Index for the figure + verbosity : Union[int, str], default "INFO" + Logging verbosity level + + Attributes + ---------- + COLOR_MAP : Dict[str, ColorType] + Default color mapping + COLOR_CYCLE : str + Default color cycle name + LABEL_MAP : Dict[str, str] + Default label mapping + STYLES : Dict[str, Any] + Default styles + CONFIG : Dict[str, Any] + Default configuration + + Examples + -------- + >>> # Create a basic plot + >>> plot = AbstractPlot() + >>> plot.set_color_cycle('viridis') + >>> plot.add_point(1, 1, label='Point 1') + >>> ax = plot.draw_frame() + >>> plot.finalize(ax) + + >>> # Create a ratio plot with custom styles + >>> plot = AbstractPlot(styles={'figure': {'figsize': (10, 8)}}) + >>> ax_main, ax_ratio = plot.draw_frame(ratio=True) + >>> plot.draw_axis_labels(ax_main, xlabel='X', ylabel='Y') + """ + + COLOR_MAP: Dict[str, ColorType] = {} + COLOR_CYCLE: str = "default" + LABEL_MAP: Dict[str, str] = {} + STYLES: Dict[str, Any] = {} + CONFIG: Dict[str, Any] = { "xlabellinebreak": 50, "ylabellinebreak": 50, - 'ratio_line_styles':{ - 'color': 'gray', - 'linestyle': '--', - 'zorder': 0 - }, + "ratio_line_styles": { + "color": "gray", + "linestyle": "--", + "zorder": 0, + }, + 'draw_legend': True } - - def __init__(self, - color_pallete:Optional[Dict]=None, - color_pallete_sec:Optional[Dict]=None, - color_cycle:Optional[Union[List, str, "ListedColorMap"]]=None, - label_map:Optional[Dict]=None, - styles:Optional[Union[Dict, str]]=None, - analysis_label_options:Optional[Dict]=None, - figure_index:Optional[int]=None, - config:Optional[Dict]=None, - verbosity:Optional[Union[int, str]]='INFO'): - super().__init__(verbosity=verbosity) - - self.color_pallete = combine_dict(self.COLOR_PALLETE, color_pallete) - self.color_pallete_sec = combine_dict(self.COLOR_PALLETE_SEC, color_pallete_sec) - self.config = config - self.styles = styles - self.label_map = label_map - self.analysis_label_options = analysis_label_options + def __init__( + self, + color_map: Optional[Dict[str, ColorType]] = None, + color_cycle: Optional[ColormapType] = None, + label_map: Optional[Dict[str, str]] = None, + styles: Optional[StylesType] = None, + config: Optional[Dict[str, Any]] = None, + styles_map: Optional[Dict[str, StylesType]] = None, + config_map: Optional[Dict[str, Dict[str, Any]]] = None, + analysis_label_options: Optional[Union[str, Dict[str, Any]]] = None, + figure_index: Optional[int] = None, + verbosity: Union[int, str] = "INFO", + ) -> None: + """Initialize the AbstractPlot with customization options.""" + super().__init__(verbosity=verbosity) - self.reset_legend_data() + self._points: List[Point] = [] + self._annotations: List[Annotation] = [] + self._figure: Optional[plt.Figure] = None + # 
Initialize properties with validation + self.color_map = self._init_color_map(color_map) + self.set_color_cycle(color_cycle) + self.label_map = self._init_label_map(label_map) + self._styles_map = self._init_styles_map(styles) + self.update_styles_map(styles_map) + self._config_map = self._init_config_map(config) + self.update_config_map(config_map) + self.analysis_label_options = self._init_analysis_options(analysis_label_options) self.figure_index = figure_index - self.set_color_cycle(color_cycle) - - self.clear_annotations() - self.clear_points() - - self._hep_data = {} + self.reset() - def _get_fullname(self, name: str, domain: Optional[str] = None) -> str: - return name if not domain else f'{domain}.{name}' + def _init_color_map( + self, + color_map: Optional[Dict[str, ColorType]] + ) -> NamedTreeNode: + """Initialize color map with validation.""" + try: + data = mp.merge_classattr(type(self), 'COLOR_MAP', copy=True) + data &= color_map + return NamedTreeNode.from_mapping(data) + except Exception as e: + raise PlottingError(f"Failed to initialize color map: {str(e)}") from e - @property - def label_map(self) -> Dict: - return self._label_map + def _init_label_map( + self, + label_map: Optional[Dict[str, str]] + ) -> NamedTreeNode: + """Initialize label map with validation.""" + try: + data = mp.merge_classattr(type(self), 'LABEL_MAP', copy=True) + if label_map is not None: + data &= label_map + return NamedTreeNode.from_mapping(data) + except Exception as e: + raise PlottingError(f"Failed to initialize label map: {str(e)}") from e - @label_map.setter - def label_map(self, value: Optional[Dict] = None): - label_map = combine_dict(AbstractPlot.LABEL_MAP, self.LABEL_MAP) - if value is not None: - label_map = combine_dict(label_map, value) - self._label_map = label_map + def _init_styles_map( + self, + styles: Optional[StylesType] + ) -> NamedTreeNode: + """Initialize styles map with validation.""" + try: + data = template_styles.get() + data &= mp.merge_classattr( + type(self), + 'STYLES', + copy=True, + parse=template_styles.parse + ) + data &= template_styles.parse(styles) + return NamedTreeNode(data=data) + except Exception as e: + raise PlottingError(f"Failed to initialize styles map: {str(e)}") from e - @property - def config(self) -> Dict: - return self._config + def _init_config_map( + self, + config: Optional[Dict[str, Any]] + ) -> NamedTreeNode: + """Initialize configuration map with validation.""" + try: + data = mp.merge_classattr(type(self), 'CONFIG', copy=True) + data &= config + return NamedTreeNode(data=data) + except Exception as e: + raise PlottingError(f"Failed to initialize config map: {str(e)}") from e - @config.setter - def config(self, value: Optional[Dict] = None): - config = combine_dict(AbstractPlot.CONFIG, self.CONFIG) - if value is not None: - config = combine_dict(config, value) - self._config = config + def _init_analysis_options( + self, + options: Optional[Union[str, Dict[str, Any]]] + ) -> Optional[Dict[str, Any]]: + """Initialize analysis label options with validation.""" + if options is None: + return None + try: + return template_analysis_label_options.parse(options) + except Exception as e: + raise PlottingError( + f"Failed to initialize analysis label options: {str(e)}" + ) from e @property - def styles(self) -> Dict: - return self._styles - - @styles.setter - def styles(self, value: Optional[Union[Dict, str]] = None): - styles = parse_styles(value, self.STYLES) - self._styles = styles + def config(self) -> Dict[str, Any]: + """Get the configuration 
dictionary.""" + return self._config_map.data @property - def analysis_label_options(self) -> Dict: - return self._analysis_label_options + def styles(self) -> Dict[str, Any]: + """Get the styles dictionary.""" + return self._styles_map.data - @analysis_label_options.setter - def analysis_label_options(self, value: Optional[Dict] = None): - analysis_label_options = parse_analysis_label_options(value, use_default=False) - self._analysis_label_options = analysis_label_options + @property + def config_map(self) -> NamedTreeNode: + return self._config_map @property - def hep_data(self): - return self._hep_data + def styles_map(self) -> NamedTreeNode: + return self._styles_map + + def update_styles_map( + self, + data: Optional[Dict[str, StylesType]] = None + ) -> None: + """ + Update the styles map with additional styles. - def add_point(self, x: float, y: float, - label: Optional[str] = None, - name: Optional[str] = None, - styles: Optional[Dict] = None): - if label is not None and name is None: - name = label - if name is not None: - if name in self.legend_order: - raise RuntimeError(f'Point with name "{name}" already exists.') - self.legend_order.append(name) + Parameters + ---------- + data : Optional[Dict[str, StylesType]], default None + Additional style mappings to apply + + Raises + ------ + PlottingError + If style update fails + """ + if not data: + return - point = { - 'x' : x, - 'y' : y, - 'label' : label, - 'name' : name, - 'styles': styles - } - self.points.append(point) + try: + for key, value in data.items(): + self._styles_map[key] = template_styles.parse(value) + except Exception as e: + raise PlottingError( + f"Failed to update styles map: {str(e)}" + ) from e - def clear_points(self): - self.points = [] - - def add_annotation(self, text:str, **kwargs): - annotation = { - "text": text, - **kwargs - } - self.annotations.append(annotation) + def update_config_map( + self, + data: Optional[Dict[str, Dict[str, Any]]] = None + ) -> None: + """ + Update the configuration map with additional settings. - def clear_annotations(self): - self.annotations = [] - - def set_color_cycle(self, color_cycle:Optional[Union[List, str, "ListedColorMap"]]=None): - if color_cycle is None: - color_cycle = self.COLOR_CYCLE - self.cmap = get_cmap(color_cycle) - self.color_cycle = cycle(self.cmap.colors) + Parameters + ---------- + data : Optional[Dict[str, Dict[str, Any]]], default None + Additional configuration settings to apply - def reset_color_cycle(self): - self.color_cycle = cycle(self.cmap.colors) - - def get_hep_data(self): - return combine_dict(self.hep_data) + Raises + ------ + PlottingError + If configuration update fails + """ + if not data: + return + + try: + for key, value in data.items(): + self._config_map[key] = value + except Exception as e: + raise PlottingError( + f"Failed to update config map: {str(e)}" + ) from e + + def get_domain_styles( + self, + domain: Optional[str] = None, + copy: bool = True + ) -> Dict[str, Any]: + """ + Get styles for a specific domain. + + Parameters + ---------- + domain : Optional[str], default None + The domain to get styles for + copy : bool, default True + Whether to return a copy of the styles + + Returns + ------- + Dict[str, Any] + The domain styles + """ + styles = self._styles_map.get(domain, {}) + if copy: + styles = deepcopy(styles) + return defaultdict(dict, styles) + + def get_domain_label( + self, + name: str, + domain: Optional[str] = None, + fallback: bool = False + ) -> Optional[str]: + """ + Get label for a domain. 
+ + Parameters + ---------- + name : str + The name to get label for + domain : Optional[str], default None + The domain context + fallback : bool, default False + Whether to fall back to name-only lookup + + Returns + ------- + Optional[str] + The domain label if found + """ + full_domain = self.label_map.format(domain, name) + if full_domain not in self.label_map and fallback: + return self.label_map.get(name) + return self.label_map.get(full_domain) + + def add_point( + self, + x: float, + y: float, + label: Optional[str] = None, + name: Optional[str] = None, + styles: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Add a point to the plot. + + Parameters + ---------- + x : float + X-coordinate + y : float + Y-coordinate + label : Optional[str], default None + Point label + name : Optional[str], default None + Point name for legend + styles : Optional[Dict[str, Any]], default None + Point styles + + Raises + ------ + PlottingError + If point with same name already exists + """ + if label is not None and name is None: + name = label + if name is not None and name in self.legend_order: + raise PlottingError(f"Point redeclared with name: {name}") - def get_colors(self): - return get_color_cycle(self.cmap).by_key()['color'] + self._points.append(Point(x=x, y=y, label=label, name=name, styles=styles)) + + def add_annotation( + self, + text: str, + **kwargs: Any + ) -> None: + """ + Add an annotation to the plot. + + Parameters + ---------- + text : str + Annotation text + **kwargs : Any + Additional annotation options + """ + if not isinstance(text, str): + raise ValueError("Annotation text must be a string") + if not text: + raise ValueError("Annotation text cannot be empty") + + try: + self._annotations.append(Annotation(text=text, options=kwargs)) + except Exception as e: + raise PlottingError(f"Failed to add annotation: {str(e)}") from e + + def set_color_cycle(self, color_cycle: Optional[ColormapType] = None) -> None: + """ + Set the color cycle for the plot. - def get_styles(self, name:str): - return self.styles.get(name, {}) + Parameters + ---------- + color_cycle : Optional[ColormapType], default None + The color cycle to use. Can be: + - A string name of a colormap + - A Colormap instance + - A list of colors + + Raises + ------ + PlottingError + If colormap creation fails or colormap has no colors + """ + try: + color_cycle = color_cycle or self.COLOR_CYCLE + self.cmap = get_cmap(color_cycle) + + # Check if colormap has colors attribute + if hasattr(self.cmap, 'colors'): + self.color_cycle = cycle(self.cmap.colors) + else: + # For continuous colormaps that don't have colors attribute + # Sample N colors from the colormap + N = 256 # or some other appropriate number + self.color_cycle = cycle(self.cmap(np.linspace(0, 1, N))) + + except Exception as e: + raise PlottingError(f"Failed to set color cycle: {str(e)}") from e - def get_combined_styles(self, name:str, custom_styles:Optional[Dict]=None) -> Dict: - return combine_dict(self.get_styles(name), custom_styles) + def get_colors(self) -> List[str]: + """ + Get the list of colors from the current color cycle. - def get_label(self, name:str, domain: Optional[str] = None) -> Optional[str]: - if domain is not None: - name = self._get_fullname(name, domain) - return self.label_map.get(name, None) + Returns + ------- + List[str] + List of colors + """ + return get_color_cycle(self.cmap).by_key()["color"] def get_default_legend_order(self) -> List[str]: + """ + Get the default legend order. 
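# --- illustrative aside (not part of the patch) ----------------------------
# What the continuous-colormap fallback in set_color_cycle above amounts to:
# ListedColormap instances expose .colors, while LinearSegmentedColormap
# instances (e.g. "jet") do not and must be sampled. Pure matplotlib sketch:
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle

cmap = plt.get_cmap("jet")                     # no .colors attribute
colors = cycle(cmap(np.linspace(0, 1, 256)))   # sample 256 RGBA colors
print(next(colors))                            # first RGBA tuple in the cycle
# ----------------------------------------------------------------------------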
+ + Returns + ------- + List[str] + Default legend order + """ return [] def reset_legend_data(self) -> None: - self.legend_data = {} + """Reset legend data and order.""" + self.legend_data = NamedTreeNode() self.legend_order = self.get_default_legend_order() + + def get_labelled_legend_domains(self) -> List[str]: + """ + Get list of domains that have valid legend labels. + + Returns + ------- + List[str] + List of domain names with valid legend labels + """ + try: + return [ + domain for domain in self.legend_data.domains + if cast(LegendEntry, self.legend_data.get(domain)).has_valid_label() + ] + except Exception as e: + raise PlottingError( + f"Failed to get labelled legend domains: {str(e)}" + ) from e - def _get_legend_domain(self, key: str) -> Optional[str]: - substrings = key.split('.') - if len(substrings) == 2: - return substrings[0] - return None - - def get_legend_component(self, name: str, domain: Optional[str] = None) -> Dict: - key = self._get_fullname(name, domain) - return self.legend_data.get(key, None) - - def get_handle(self, name: str, domain: Optional[str] = None): - key = self._get_fullname(name, domain) - return self.legend_data.get(key, {}).get('handle', None) - - def update_legend_handles(self, handles: Dict, domain: Optional[str] = None, - raw: bool = False) -> None: - for name, handle in handles.items(): - key = self._get_fullname(name, domain) - handle, label = resolve_handle_label(handle, raw=raw) - self.legend_data[key] = { - 'handle': handle, - 'label': label - } + def get_handle(self, domain: str) -> Optional[Artist]: + """ + Get legend handle for a domain. + + Parameters + ---------- + domain : str + The domain to get handle for + + Returns + ------- + Optional[Artist] + The legend handle if found + """ + entry = self.legend_data.get(domain) + return entry.handle if entry is not None else None + + def update_legend_handles( + self, + handles: Dict[str, Artist], + domain: Optional[str] = None, + raw: bool = False + ) -> None: + """ + Update legend handles. 
+ + Parameters + ---------- + handles : Dict[str, Artist] + Mapping of names to handles + domain : Optional[str], default None + The domain context + raw : bool, default False + Whether to use raw handles + + Raises + ------ + PlottingError + If handle validation fails + """ + try: + for name, handle in handles.items(): + key = ( + domain if name is None + else self.legend_data.format(domain, name) if domain + else name + ) + handle, label = resolve_handle_label(handle, raw=raw) - def add_legend_decoration(self, decorator, targets:List[str], domain: Optional[str] = None): - if domain is not None: - targets = [self._get_fullname(target, domain) for target in targets] - for key, component in self.legend_data.items(): - if key not in targets: - continue - handle = component['handle'] - if isinstance(handle, (list, tuple)): - new_handle = (*handle, decorator) - else: - new_handle = (handle, decorator) - component['handle'] = new_handle - - def get_legend_domains(self): - domains = [] - for key in self.legend_data: - domain = self._get_legend_domain(key) - if domain not in domains: - domains.append(domain) - return domains - - def get_legend_handles_labels(self, domains: Optional[Union[List[str], str]] = None): - + entry = LegendEntry(handle=handle, label=label) + self.legend_data[key] = entry + + except Exception as e: + raise PlottingError(f"Failed to update legend handles: {str(e)}") from e + + def get_legend_handles_labels( + self, + domains: Optional[DomainType] = None + ) -> Tuple[List[Artist], List[str]]: + """ + Get handles and labels for legend creation. + + Parameters + ---------- + domains : Optional[DomainType], default None + Domains to include in the legend + + Returns + ------- + Tuple[List[Artist], List[str]] + Tuple of (handles, labels) for creating the legend + + Notes + ----- + If domains is None, all domains are included. + Handles and labels are returned in the order specified by legend_order. + """ if domains is None: domains = [None] elif isinstance(domains, str): domains = [domains] - assert isinstance(domains, list) + + handles: List[Artist] = [] + labels: List[str] = [] + + try: + for name in self.legend_order: + for domain in domains: + key = ( + self.legend_data.format(domain, name) + if domain else name + ) + entry = self.legend_data.get(key) + + if entry is None: + continue + + if entry.label.startswith("_"): + continue + + handles.append(entry.handle) + labels.append(entry.label) + + return handles, labels + + except Exception as e: + raise PlottingError( + f"Failed to get legend handles and labels: {str(e)}" + ) from e - handles, labels = [], [] - - for name in self.legend_order: - for domain in domains: - key = self._get_fullname(name, domain) - if key not in self.legend_data: - continue - handle = self.legend_data[key]['handle'] - label = self.legend_data[key]['label'] - if label.startswith('_'): - continue - handles.append(handle) - labels.append(label) - return handles, labels + def add_legend_decoration( + self, + decorator: Artist, + targets: List[str], + domain: Optional[str] = None + ) -> None: + """ + Add a decorator to specified legend entries. 
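# --- illustrative aside (not part of the patch) ----------------------------
# get_legend_handles_labels above walks legend_order and skips labels that
# start with '_' (matplotlib's hidden-label convention). Hypothetical usage;
# the `plot` object and names below are illustrative only.
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
h1, = ax.plot([0, 1], [0, 1], label="signal")
h2, = ax.plot([0, 1], [1, 0], label="_background")  # hidden from legend

# plot.update_legend_handles({"signal": h1, "background": h2})
# plot.legend_order = ["background", "signal"]
# plot.get_legend_handles_labels()  # -> ([h1], ["signal"])
# ----------------------------------------------------------------------------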
- def draw_frame(self, ratio:bool=False, **kwargs): - if ratio: - frame_method = ratio_frame - else: - frame_method = single_frame - ax = frame_method(styles=self.styles, - prop_cycle=get_color_cycle(self.cmap), - analysis_label_options=self.analysis_label_options, - figure_index=self.figure_index, - **kwargs) - if isinstance(ax, tuple): - self.draw_annotations(ax[0]) - else: - self.draw_annotations(ax) - - self.figure = plt.gcf() + Parameters + ---------- + decorator : Artist + The matplotlib artist to use as a decorator + targets : List[str] + List of legend entry names to decorate + domain : Optional[str], default None + Domain context for the targets + + Raises + ------ + PlottingError + If decoration fails for any target + ValueError + If decorator is not a valid Artist + """ + if not isinstance(decorator, Artist): + raise ValueError(f"Decorator must be an Artist, got {type(decorator)}") + + try: + if domain is not None: + targets = [ + self.legend_data.format(domain, target) + for target in targets + ] + + for target in targets: + entry = self.legend_data.get(target) + if entry is None: + continue + + handle = entry.handle + if isinstance(handle, (list, tuple)): + new_handle = (*handle, decorator) + else: + new_handle = (handle, decorator) + + # Update the entry with new handle + self.legend_data[target] = LegendEntry( + handle=new_handle, + label=entry.label + ) + + except Exception as e: + raise PlottingError( + f"Failed to add legend decoration: {str(e)}" + ) from e + + def draw_frame( + self, + ratio: bool = False, + **kwargs: Dict[str, Any] + ) -> Union[Axes, Tuple[Axes, Axes]]: + """ + Draw the plot frame. + + Parameters + ---------- + ratio : bool, default False + Whether to create a ratio plot + **kwargs : Any + Additional frame options + + Returns + ------- + Union[Axes, Tuple[Axes, Axes]] + The created axes + """ + frame_method = ratio_frame if ratio else single_frame + ax = frame_method( + styles=self.styles, + prop_cycle=get_color_cycle(self.cmap), + analysis_label_options=self.analysis_label_options, + figure_index=self.figure_index, + **kwargs, + ) + + self._figure = plt.gcf() return ax - def draw_annotations(self, ax): - for options in self.annotations: - options = combine_dict(self.styles['annotation'], options) - ax.annotate(**options) - - def draw_points(self, ax): - for point in self.points: - handle = ax.plot(point['x'], point['y'], - label=point['label'], - **point['styles']) - if point['name'] is not None: - self.update_legend_handles({point['name']: handle[0]}) - - def draw_axis_labels(self, ax, xlabel:Optional[str]=None, ylabel:Optional[str]=None, - xlabellinebreak:Optional[int]=None, ylabellinebreak:Optional[int]=None, - combined_styles:Optional[Dict]=None, - title:Optional[str]=None): - if combined_styles is None: - combined_styles = self.styles + def draw_legend( + self, + ax: Axes, + handles: Optional[List[Artist]] = None, + labels: Optional[List[str]] = None, + handler_map: Optional[Dict[Any, Any]] = None, + domains: Optional[DomainType] = None, + **kwargs: Any + ) -> Optional[Legend]: + """ + Draw the plot legend. 
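# --- illustrative aside (not part of the patch) ----------------------------
# add_legend_decoration above stacks artists into tuple handles; matplotlib
# renders such tuples via HandlerTuple, which is presumably what
# CUSTOM_HANDLER_MAP wires in. Standalone sketch of the underlying mechanism:
import matplotlib.pyplot as plt
from matplotlib.legend_handler import HandlerTuple
from matplotlib.lines import Line2D

fig, ax = plt.subplots()
line, = ax.plot([0, 1], [0, 1], color="C0")
marker = Line2D([], [], color="C0", marker="o", linestyle="none")

# A (line, marker) tuple rendered as a single combined legend entry
ax.legend([(line, marker)], ["decorated entry"],
          handler_map={tuple: HandlerTuple(ndivide=None)})
# ----------------------------------------------------------------------------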
+ + Parameters + ---------- + ax : Axes + The axes to draw on + handles : Optional[List[Artist]], default None + Legend handles + labels : Optional[List[str]], default None + Legend labels + handler_map : Optional[Dict[Any, Any]], default None + Custom handler mappings + domains : Optional[DomainType], default None + Domains to include + **kwargs : Any + Additional legend options + + Returns + ------- + Optional[Legend] + The created legend if handles exist + """ + if handles is None and labels is None: + handles, labels = self.get_legend_handles_labels(domains) + + if not handles: + return None + + handler_map = mp.concat((CUSTOM_HANDLER_MAP, handler_map)) + styles = mp.concat((self.styles["legend"], kwargs), copy=True) + styles["handler_map"] = handler_map + + return ax.legend(handles, labels, **styles) + + def draw_annotations(self, ax: Axes) -> None: + """ + Draw annotations on the plot. + + Parameters + ---------- + ax : Axes + The axes to draw on + """ + for annotation in self._annotations: + options = mp.concatenate( + (self.styles["annotation"], annotation.options), + copy=True + ) + ax.annotate(annotation.text, **options) + + def draw_points(self, ax: Axes) -> None: + """ + Draw points on the plot. + + Parameters + ---------- + ax : Axes + The axes to draw on + """ + for point in self._points: + styles = mp.concat( + (self.styles.get("point"), point.styles) + ) + handle = ax.plot( + point.x, + point.y, + label=point.label, + **styles, + ) + if point.name is not None: + self.update_legend_handles({point.name: handle[0]}) + self.legend_order.append(point.name) + + def draw_axis_labels( + self, + ax: Axes, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + xlabellinebreak: Optional[int] = None, + ylabellinebreak: Optional[int] = None, + combined_styles: Optional[Dict[str, Any]] = None, + title: Optional[str] = None, + ) -> None: + """ + Draw axis labels and title. 
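+
+        A short sketch (labels are illustrative; labels containing LaTeX
+        math are never line-broken):
+
+        >>> plotter.draw_axis_labels(ax, xlabel=r'$m_{\ell\ell}$ [GeV]',
+        ...                          ylabel='Events / 10 GeV')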
+ + Parameters + ---------- + ax : Axes + The axes to draw on + xlabel : Optional[str], default None + X-axis label + ylabel : Optional[str], default None + Y-axis label + xlabellinebreak : Optional[int], default None + Character limit for x-label line breaks + ylabellinebreak : Optional[int], default None + Character limit for y-label line breaks + combined_styles : Optional[Dict[str, Any]], default None + Combined styles for labels + title : Optional[str], default None + Plot title + """ + combined_styles = combined_styles or self.styles + if xlabel is not None: - if (xlabellinebreak is not None) and (xlabel.count("$") < 2): + if (xlabellinebreak is not None and + xlabel.count("$") < 2): # Don't break LaTeX xlabel = insert_periodic_substr(xlabel, xlabellinebreak) - ax.set_xlabel(xlabel, **combined_styles['xlabel']) + ax.set_xlabel(xlabel, **combined_styles["xlabel"]) + if ylabel is not None: - if (ylabellinebreak is not None) and (ylabel.count("$") < 2): - ylabel = insert_periodic_substr(ylabel, ylabellinebreak) - ax.set_ylabel(ylabel, **combined_styles['ylabel']) - if title is not None: - ax.set_title(title, **self.styles['title']) - - def draw_text(self, ax, text:str, x, y, - dy:float=0.05, - transform_x:str="axis", - transform_y:str="axis", - **kwargs): - styles = combine_dict(self.styles['text'], kwargs) - draw_multiline_text(ax, x, y, text, dy=dy, - transform_x=transform_x, - transform_y=transform_y, - **styles) + if (ylabellinebreak is not None and + ylabel.count("$") < 2): # Don't break LaTeX + ylabel = insert_periodic_substr(ylabel, ylabellinebreak) + ax.set_ylabel(ylabel, **combined_styles["ylabel"]) - def draw_axis_components(self, ax, xlabel:Optional[str]=None, ylabel:Optional[str]=None, - ylim:Optional[Tuple[float]]=None, xlim:Optional[Tuple[float]]=None, - xticks:Optional[List]=None, yticks:Optional[List]=None, - xticklabels:Optional[List]=None, yticklabels:Optional[List]=None, - combined_styles:Optional[Dict]=None, - title:Optional[str]=None): - if combined_styles is None: - combined_styles = self.styles - self.draw_axis_labels(ax, xlabel, ylabel, - xlabellinebreak=self.config["xlabellinebreak"], - ylabellinebreak=self.config["ylabellinebreak"], - combined_styles=combined_styles, - title=title) - - format_axis_ticks(ax, **combined_styles['axis'], - xtick_styles=combined_styles['xtick'], - ytick_styles=combined_styles['ytick']) + if title is not None: + ax.set_title(title, **self.styles["title"]) + + def draw_text( + self, + ax: Axes, + text: str, + x: float, + y: float, + dy: float = 0.05, + transform_x: str = "axis", + transform_y: str = "axis", + **kwargs: Any, + ) -> None: + """ + Draw multiline text on the plot. 
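+
+        A short sketch (coordinates are axis fractions under the default
+        transforms; the text is illustrative):
+
+        >>> plotter.draw_text(ax, 'Region A\nPass selection', 0.05, 0.95)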
+ + Parameters + ---------- + ax : Axes + The axes to draw on + text : str + The text to draw + x : float + X-coordinate + y : float + Y-coordinate + dy : float, default 0.05 + Vertical spacing between lines + transform_x : str, default "axis" + X-coordinate transform + transform_y : str, default "axis" + Y-coordinate transform + **kwargs : Any + Additional text style options + """ + styles = mp.concat((self.styles["text"], kwargs)) + draw_multiline_text( + ax, + x, + y, + text, + dy=dy, + transform_x=transform_x, + transform_y=transform_y, + **styles, + ) + + def draw_axis_components( + self, + ax: Axes, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + ylim: Optional[Tuple[float, float]] = None, + xlim: Optional[Tuple[float, float]] = None, + xticks: Optional[List[float]] = None, + yticks: Optional[List[float]] = None, + xticklabels: Optional[List[str]] = None, + yticklabels: Optional[List[str]] = None, + combined_styles: Optional[Dict[str, Any]] = None, + title: Optional[str] = None, + ) -> None: + """ + Draw axis components including labels, ticks, and limits. + + Parameters + ---------- + ax : Axes + The axes to draw on + xlabel : Optional[str], default None + X-axis label + ylabel : Optional[str], default None + Y-axis label + ylim : Optional[Tuple[float, float]], default None + Y-axis limits + xlim : Optional[Tuple[float, float]], default None + X-axis limits + xticks : Optional[List[float]], default None + X-axis tick positions + yticks : Optional[List[float]], default None + Y-axis tick positions + xticklabels : Optional[List[str]], default None + X-axis tick labels + yticklabels : Optional[List[str]], default None + Y-axis tick labels + combined_styles : Optional[Dict[str, Any]], default None + Combined styles for components + title : Optional[str], default None + Plot title + """ + combined_styles = combined_styles or self.styles + # Draw labels + self.draw_axis_labels( + ax, + xlabel, + ylabel, + xlabellinebreak=self.config["xlabellinebreak"], + ylabellinebreak=self.config["ylabellinebreak"], + combined_styles=combined_styles, + title=title, + ) + + # Format ticks + try: + format_axis_ticks( + ax, + **combined_styles["axis"], + xtick_styles=combined_styles["xtick"], + ytick_styles=combined_styles["ytick"], + ) + except Exception as e: + raise PlottingError(f"Failed to format axis ticks: {str(e)}") from e + + # Set limits and ticks if ylim is not None: ax.set_ylim(*ylim) if xlim is not None: @@ -342,320 +992,192 @@ class AbstractPlot(AbstractObject): if xticklabels is not None: ax.set_xticklabels(xticklabels) if yticklabels is not None: - ax.set_yticklabels(yticklabels) - - def set_axis_range(self, ax, - xmin:Optional[float]=None, xmax:Optional[float]=None, - ymin:Optional[float]=None, ymax:Optional[float]=None, - ypadlo:Optional[float]=None, ypadhi:Optional[float]=None, - ypad:Optional[float]=None): - xlim, ylim = get_axis_limits(ax, xmin=xmin, xmax=xmax, - ymin=ymin, ymax=ymax, - ypadlo=ypadlo, ypadhi=ypadhi, - ypad=ypad) - ax.set_xlim(*xlim) - ax.set_ylim(*ylim) - - @staticmethod - def close_all_figures(): - plt.close() - - def decorate_comparison_axis(self, ax, xlabel:str="", ylabel:str="", - mode:Union[HistComparisonMode, str]="ratio", - ylim:Optional[Sequence]=None, - ypad:Optional[float]=0.1, - draw_ratio_line:bool=True): - mode = HistComparisonMode.parse(mode) + ax.set_yticklabels(yticklabels) + + def set_axis_range( + self, + ax: Axes, + xmin: Optional[float] = None, + xmax: Optional[float] = None, + ymin: Optional[float] = None, + ymax: Optional[float] 
= None, + ypadlo: Optional[float] = None, + ypadhi: Optional[float] = None, + ypad: Optional[float] = None, + ) -> None: + """ + Set axis ranges with optional padding. + + Parameters + ---------- + ax : Axes + The axes to modify + xmin : Optional[float], default None + Minimum x-value + xmax : Optional[float], default None + Maximum x-value + ymin : Optional[float], default None + Minimum y-value + ymax : Optional[float], default None + Maximum y-value + ypadlo : Optional[float], default None + Lower y-padding fraction + ypadhi : Optional[float], default None + Upper y-padding fraction + ypad : Optional[float], default None + Symmetric y-padding fraction + """ + try: + xlim, ylim = get_axis_limits( + ax, + xmin=xmin, + xmax=xmax, + ymin=ymin, + ymax=ymax, + ypadlo=ypadlo, + ypadhi=ypadhi, + ypad=ypad, + ) + ax.set_xlim(*xlim) + ax.set_ylim(*ylim) + except Exception as e: + raise PlottingError(f"Failed to set axis range: {str(e)}") from e + + def decorate_comparison_axis( + self, + ax: Axes, + xlabel: str = "", + ylabel: str = "", + mode: Union[HistComparisonMode, str, Callable] = "ratio", + ylim: Optional[Sequence[float]] = None, + ypad: Optional[float] = 0.1, + draw_ratio_line: bool = True, + ) -> None: + """ + Decorate a comparison axis (ratio or difference plot). + + Parameters + ---------- + ax : Axes + The axes to decorate + xlabel : str, default "" + X-axis label + ylabel : str, default "" + Y-axis label + mode : Union[HistComparisonMode, str, Callable], default "ratio" + Comparison mode + ylim : Optional[Sequence[float]], default None + Y-axis limits + ypad : Optional[float], default 0.1 + Centralization padding + draw_ratio_line : bool, default True + Whether to draw the reference line + """ if ylim is not None: ax.set_ylim(ylim) - do_centralize_axis = ylim is None - if mode == HistComparisonMode.RATIO: - if do_centralize_axis: - centralize_axis(ax, which="y", ref_value=1, padding=ypad) - if draw_ratio_line: - ax.axhline(1, **self.config['ratio_line_styles']) - elif mode == HistComparisonMode.DIFFERENCE: - if do_centralize_axis: - centralize_axis(ax, which="y", ref_value=0, padding=ypad) - if draw_ratio_line: - ax.axhline(0, **self.config['ratio_line_styles']) - # set default ylabel if not given - if not ylabel: + + do_centralize = ylim is None + + if not callable(mode): + mode = HistComparisonMode.parse(mode) if mode == HistComparisonMode.RATIO: - ylabel = "Ratio" + if do_centralize: + centralize_axis(ax, which="y", ref_value=1, padding=ypad) + if draw_ratio_line: + ax.axhline(1, **self.config["ratio_line_styles"]) + ylabel = ylabel or "Ratio" elif mode == HistComparisonMode.DIFFERENCE: - ylabel = "Difference" + if do_centralize: + centralize_axis(ax, which="y", ref_value=0, padding=ypad) + if draw_ratio_line: + ax.axhline(0, **self.config["ratio_line_styles"]) + ylabel = ylabel or "Difference" + self.draw_axis_components(ax, xlabel=xlabel, ylabel=ylabel) - - def _draw_hist_from_binned_data(self, ax, x, y, - xerr=None, yerr=None, - bin_edges:Optional[np.ndarray]=None, - hide:Optional[Union[Tuple[float, float], Callable]]=None, - styles:Optional[Dict]=None): - styles = combine_dict(self.styles['hist'], styles) - # assume uniform binning - if bin_edges is None: - bin_center = np.array(x) - bin_edges = bin_center_to_bin_edge(bin_center) - bins, range = bin_edges, (bin_edges[0], bin_edges[1]) - if hide is not None: - mask = get_histogram_mask(x, y, [hide]) - y[mask] = 0. 
- hist_y, _, handle = ax.hist(x, bins, range=range, - weights=y, **styles) - assert np.allclose(y, hist_y) - return handle - - def _draw_stacked_hist_from_binned_data(self, ax, x_list, y_list, - bin_edges:Optional[np.ndarray]=None, - hide_list:Optional[List[Union[Tuple[float, float], Callable]]]=None, - styles:Optional[Dict]=None): - styles = combine_dict(self.styles['hist'], styles) - if bin_edges is None: - bin_edges_list = [] - for x in x_list: - bin_center = np.array(x) - bin_edges = bin_center_to_bin_edge(bin_center) - bin_edges_list.append(bin_edges) - - if not all(np.array_equal(bin_edges, bin_edges_list[0]) \ - for bin_edges in bin_edges_list): - raise RuntimeError('subhistograms in a stacked histogram have different binnings') - bin_edges = bin_edges_list[0] - bins, range = bin_edges, (bin_edges[0], bin_edges[1]) - if hide_list is not None: - assert len(hide_list) == len(x_list) - for i, hide in enumerate(hide_list): - if hide is None: - continue - mask = get_histogram_mask(x_list[i], y_list[i], [hide]) - y_list[i][mask] = 0. - hist_y, _, handles = ax.hist(x_list, bins, range=range, - weights=y_list, stacked=True, - **styles) - #assert ... - return hist_y, handles - - def _draw_errorbar(self, ax, x, y, - xerr=None, yerr=None, - hide:Optional[Union[Tuple[float, float], Callable]]=None, - styles:Optional[Dict]=None): - styles = combine_dict(self.styles['errorbar'], styles) - if hide is not None: - mask = ~get_histogram_mask(x, [hide]) - x, y, xerr, yerr = select_binned_data(mask, x, y, xerr, yerr) - handle = ax.errorbar(x, y, xerr=xerr, yerr=yerr, **styles) - return handle - - def _draw_fill_from_binned_data(self, ax, x, y, - xerr=None, yerr=None, - bin_edges:Optional[np.ndarray]=None, - hide:Optional[Union[Tuple[float, float], Callable]]=None, - styles:Optional[Dict]=None): - styles = combine_dict(self.styles['fill_between'], styles) - # assume uniform binning - if bin_edges is None: - bin_center = np.array(x) - bin_edges = bin_center_to_bin_edge(bin_center) - if yerr is None: - yerr = 0. - indices = np.arange(len(x)) - if hide is not None: - mask = ~get_histogram_mask(x, y, [hide]) - sections_indices = get_subsequences(indices, mask, min_length=2) - else: - sections_indices = [indices] - handles = [] - for section_indices in sections_indices: - mask = np.full(x.shape, False) - mask[section_indices] = True - x_i, y_i, xerr_i, yerr_i = select_binned_data(mask, x, y, xerr, yerr) - # extend to edge - x_i[0] = bin_edges[section_indices[0]] - x_i[-1] = bin_edges[section_indices[-1] + 1] - if isinstance(yerr_i, tuple): - yerrlo = y_i - yerr_i[0] - yerrhi = y_i + yerr_i[1] - else: - yerrlo = y_i - yerr_i - yerrhi = y_i + yerr_i - handle = ax.fill_between(x_i, yerrlo, yerrhi, **styles) - handles.append(handle) - return handles[0] - - def _draw_shade_from_binned_data(self, ax, x, y, - xerr=None, yerr=None, - bin_edges:Optional[np.ndarray]=None, - hide:Optional[Union[Tuple[float, float], Callable]]=None, - styles:Optional[Dict]=None): - styles = combine_dict(styles) - # assume uniform binning - if bin_edges is None: - bin_center = np.array(x) - bin_edges = bin_center_to_bin_edge(bin_center) - bin_widths = np.diff(bin_edges) - if yerr is None: - yerr = 0. 
- if hide is not None: - mask = ~get_histogram_mask(x, y, [hide]) - x, y, xerr, yerr = select_binned_data(mask, x, y, xerr, yerr) - bin_widths = bin_widths[mask] - if isinstance(yerr, tuple): - height = yerr[0] + yerr[1] - bottom = y - yerr[0] - else: - height = 2 * yerr - bottom = y - yerr - handle = ax.bar(x=x, height=height, bottom=bottom, - width=bin_widths, **styles) - return handle - - def draw_binned_data(self, ax, data, - draw_data:bool=True, - draw_error:bool=True, - bin_edges:Optional[np.ndarray]=None, - plot_format:Union[PlotFormat, str]='errorbar', - error_format:Union[ErrorDisplayFormat, str]='errorbar', - styles:Optional[Dict]=None, - hide:Optional[Union[Tuple[float, float], Callable]]=None, - error_styles:Optional[Dict]=None): - if (not draw_data) and (not draw_error): - raise ValueError('can not draw nothing') - if styles is None: - styles = {} - if error_styles is None: - error_styles = {} - plot_format = PlotFormat.parse(plot_format) - error_format = ErrorDisplayFormat.parse(error_format) - handle, error_handle = None, None - x, y = data['x'], data['y'] - xerr, yerr = data.get('xerr', 0), data.get('yerr', 0) + def reset_color_cycle(self) -> None: + """ + Reset the color cycle to its initial state. + + This method restarts the color cycle from the beginning of the colormap, + useful when you want to reuse the same color sequence. + """ + self.color_cycle = cycle(self.cmap.colors) + + def reset_metadata(self) -> None: + """ + Reset all metadata including legend data and order. + + This method clears all legend-related information and should be called + when starting a new plot or clearing the current one. + """ + self.reset_legend_data() + + def reset(self) -> None: + """Reset all plot data.""" + self.reset_metadata() + self._points.clear() + self._annotations.clear() + + def finalize(self, ax: Axes) -> None: + """ + Finalize the plot by drawing points and annotations. + + Parameters + ---------- + ax : Axes + The axes to finalize + """ + self.draw_points(ax) + self.draw_annotations(ax) - if draw_data: - if plot_format == PlotFormat.HIST: - handle = self._draw_hist_from_binned_data(ax, x, y, - bin_edges=bin_edges, - hide=hide, - styles=styles) - elif plot_format == PlotFormat.ERRORBAR: - if (not draw_error) or (error_format != ErrorDisplayFormat.ERRORBAR): - handle = self._draw_errorbar(ax, x, y, - hide=hide, - styles=styles) - else: - handle = self._draw_errorbar(ax, x, y, - xerr=xerr, yerr=yerr, - hide=hide, - styles=styles) - else: - raise RuntimeError(f'unsupported plot format: {plot_format.name}') + def stretch_axis( + self, + ax: Axes, + xlim: Optional[Tuple[float, float]] = None, + ylim: Optional[Tuple[float, float]] = None, + ) -> None: + """ + Stretch axis limits to encompass new ranges. + + This method extends the current axis limits to include new ranges + without shrinking the existing view. 
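+
+        A short sketch of the semantics (limits are illustrative):
+
+        >>> ax.set_xlim(0, 5)
+        >>> plotter.stretch_axis(ax, xlim=(2, 10))  # x-limits become (0, 10)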
+ + Parameters + ---------- + ax : Axes + The axes to modify + xlim : Optional[Tuple[float, float]], default None + New x-axis range to include + ylim : Optional[Tuple[float, float]], default None + New y-axis range to include + + Raises + ------ + PlottingError + If axis stretching fails + """ + try: + if xlim is not None: + curr_xlim = ax.get_xlim() + ax.set_xlim( + min(xlim[0], curr_xlim[0]), + max(xlim[1], curr_xlim[1]) + ) - if draw_error: - if error_format == ErrorDisplayFormat.FILL: - error_handle = self._draw_fill_from_binned_data(ax, x, y, yerr=yerr, - hide=hide, - styles={**error_styles, - "zorder": -1}) - elif error_format == ErrorDisplayFormat.SHADE: - error_handle = self._draw_shade_from_binned_data(ax, x, y, yerr=yerr, - bin_edges=bin_edges, - hide=hide, - styles={**error_styles, - "zorder": -1}) - elif ((error_format == ErrorDisplayFormat.ERRORBAR) and - ((not draw_data) or (plot_format !=PlotFormat.ERRORBAR))): - error_handle = self._draw_errorbar(ax, x, y, - xerr=xerr, yerr=yerr, - hide=hide, - styles={**error_styles, - "marker": 'none'}) - if isinstance(handle, list): - handle = handle[0] - handles = tuple([h for h in [handle, error_handle] if h is not None]) - return handles - - def draw_stacked_binned_data(self, ax, data, - draw_data:bool=True, - draw_error:bool=True, - bin_edges:Optional[np.ndarray]=None, - plot_format:Union[PlotFormat, str]='errorbar', - error_format_list:Union[ErrorDisplayFormat, str]='errorbar', - styles:Optional[Dict]=None, - hide_list:Optional[Union[Tuple[float, float], Callable]]=None, - error_styles_list:Optional[Dict]=None): - if (not draw_data) and (not draw_error): - raise ValueError('can not draw nothing') - n_component = len(data['x']) - if styles is None: - styles = {} - if error_styles_list is None: - error_styles_list = [{}] * n_component - plot_format = PlotFormat.parse(plot_format) - error_format_list = [ErrorDisplayFormat.parse(fmt) for fmt in error_format_list] - handles, error_handles = None, None - - x_list, y_list = data['x'], data['y'] - xerr_list, yerr_list = data.get('xerr', None), data.get('yerr', None) - if draw_data: - if plot_format == PlotFormat.HIST: - plot_func = self._draw_stacked_hist_from_binned_data - hist_y, handles = plot_func(ax, x_list, y_list, - bin_edges=bin_edges, - hide_list=hide_list, - styles=styles) - else: - raise RuntimeError(f'unsupported format for stacked plot: {plot_format.name}') - if draw_error: - error_handles = [] - def get_component(obj, index): - if obj is not None: - return obj[index] - return None - for i in range(n_component): - error_format = error_format_list[i] - x, y = x_list[i], hist_y[i] - xerr = get_component(xerr_list, i) - yerr = get_component(yerr_list, i) - hide = get_component(hide_list, i) - error_styles = get_component(error_styles_list, i) - if error_format == ErrorDisplayFormat.FILL: - error_handle = self._draw_fill_from_binned_data(ax, x, y, yerr=yerr, - hide=hide, - styles={**error_styles, - "zorder": -1}) - elif error_format == ErrorDisplayFormat.SHADE: - error_handle = self._draw_shade_from_binned_data(ax, x, y, yerr=yerr, - bin_edges=bin_edges, - hide=hide, - styles={**error_styles, - "zorder": -1}) - elif error_format == ErrorDisplayFormat.ERRORBAR: - error_handle = self._draw_errorbar(ax, x, y, - xerr=xerr, yerr=yerr, - hide=hide, - styles={**error_styles, - "marker": 'none'}) - error_handles.append(error_handle) - if error_handles is None: - return handles - if handles is None: - return error_handles - handles = [(handle, error_handle) for handle, error_handle in 
zip(handles, error_handles)]
-        return handles
-
-    def draw_legend(self, ax, handles=None, labels=None,
-                    handler_map=None, domains=None, **kwargs):
-        if (handles is None) and (labels is None):
-            handles, labels = self.get_legend_handles_labels(domains)
-        if not handles:
-            return
-        if handler_map is None:
-            handler_map = {}
-        handler_map = {**CUSTOM_HANDLER_MAP, **handler_map}
-        styles = {**self.styles['legend'], **kwargs}
-        styles['handler_map'] = handler_map
-        ax.legend(handles, labels, **styles)
-
-    def finalize(self, ax):
-        self.draw_points(ax)
\ No newline at end of file
+
+            if ylim is not None:
+                curr_ylim = ax.get_ylim()
+                ax.set_ylim(
+                    min(ylim[0], curr_ylim[0]),
+                    max(ylim[1], curr_ylim[1])
+                )
+
+        except Exception as e:
+            raise PlottingError(
+                f"Failed to stretch axis limits: {str(e)}"
+            ) from e
+
+    @staticmethod
+    def close_all_figures() -> None:
+        """Close all open matplotlib figures."""
+        plt.close("all")
\ No newline at end of file
diff --git a/quickstats/plots/bar_chart.py b/quickstats/plots/bar_chart.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d8e39c1a44fb4d602067d37ad759e6d3846cd8a
--- /dev/null
+++ b/quickstats/plots/bar_chart.py
@@ -0,0 +1,120 @@
+from typing import Dict, Optional, Union, List
+import pandas as pd
+import numpy as np
+
+from quickstats.plots import AbstractPlot, StatPlotConfig
+from quickstats.plots.template import handle_has_label
+
+class BarChart(AbstractPlot):
+
+    STYLES = {
+    }
+
+    CONFIG = {
+    }
+
+    def __init__(self, data_map:Union[pd.DataFrame, Dict[str, pd.DataFrame]],
+                 label_map:Optional[Dict]=None,
+                 styles_map:Optional[Dict]=None,
+                 color_cycle=None,
+                 styles:Optional[Union[Dict, str]]=None,
+                 analysis_label_options:Optional[Dict]=None,
+                 config:Optional[Dict]=None):
+
+        self.data_map = data_map
+        self.styles_map = styles_map
+        self.stat_configs = {}
+
+        super().__init__(color_cycle=color_cycle,
+                         label_map=label_map,
+                         styles=styles,
+                         analysis_label_options=analysis_label_options,
+                         config=config)
+
+    def get_default_legend_order(self):
+        if not isinstance(self.data_map, dict):
+            return []
+        return list(self.data_map)
+
+    def draw_single_data(self, ax, data:pd.DataFrame,
+                         xattrib:str, yattrib:str,
+                         yerrloattrib:Optional[str]=None,
+                         yerrhiattrib:Optional[str]=None,
+                         stat_configs:Optional[List[StatPlotConfig]]=None,
+                         styles:Optional[Dict]=None,
+                         label:Optional[str]=None):
+        # placeholder: bar rendering is not implemented yet
+        raise NotImplementedError('bar rendering is not implemented yet')
+
+    # TODO: resolve color assignment when drawing multiple targets
+    def draw(self, xattrib:str, yattrib:str,
+             yerrloattrib:Optional[str]=None,
+             yerrhiattrib:Optional[str]=None,
+             targets:Optional[List]=None,
+             target_alignment:str="vertical",
+             width:float=1, spacing:float=0.1,
+             xlabel:Optional[str]=None, ylabel:Optional[str]=None,
+             xmin:Optional[float]=None, xmax:Optional[float]=None,
+             ymin:Optional[float]=None, ymax:Optional[float]=None,
+             ypad:Optional[float]=0.3,
+             logx:bool=False, logy:bool=False,
+             draw_stats:bool=True):
+
+        ax = self.draw_frame(logx=logx, logy=logy)
+
+        legend_order = []
+        if isinstance(self.data_map, pd.DataFrame):
+            if draw_stats and (None in self.stat_configs):
+                stat_configs = self.stat_configs[None]
+            else:
+                stat_configs = None
+            handle, stat_handles = self.draw_single_data(ax, self.data_map,
+                                                         xattrib=xattrib,
+                                                         yattrib=yattrib,
+                                                         yerrloattrib=yerrloattrib,
+                                                         yerrhiattrib=yerrhiattrib,
+                                                         stat_configs=stat_configs,
+                                                         styles=self.styles_map)
+        elif isinstance(self.data_map, dict):
+            if targets is None:
+                targets = list(self.data_map.keys())
+            if self.styles_map is None:
+                styles_map = {k:None for k in self.data_map}
+            else:
+                styles_map = self.styles_map
+            if self.label_map is None:
+                label_map = {k:k for k in self.data_map}
+            else:
+                label_map = self.label_map
+            handles = {}
+            for target in targets:
+                data = self.data_map[target]
+                styles = styles_map.get(target, None)
+                label = label_map.get(target, target)
+                if draw_stats:
+                    if target in 
self.stat_configs: + stat_configs = self.stat_configs[target] + elif None in self.stat_configs: + stat_configs = self.stat_configs[None] + else: + stat_configs = None + else: + stat_configs = None + handle, stat_handles = self.draw_single_data(ax, data, + xattrib=xattrib, + yattrib=yattrib, + yerrloattrib=yerrloattrib, + yerrhiattrib=yerrhiattrib, + stat_configs=stat_configs, + styles=styles, + label=label) + handles[target] = handle + if stat_handles is not None: + for i, stat_handle in enumerate(stat_handles): + if handle_has_label(stat_handle): + handle_name = f"{target}_stat_handle_{i}" + handles[handle_name] = stat_handle + legend_order.extend(handles.keys()) + self.update_legend_handles(handles) + else: + raise ValueError("invalid data format") + + self.legend_order = legend_order + self.draw_legend(ax) + + self.draw_axis_components(ax, xlabel=xlabel, ylabel=ylabel) + self.set_axis_range(ax, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, ypad=ypad) + + return ax diff --git a/quickstats/plots/bidirectional_bar_chart.py b/quickstats/plots/bidirectional_bar_chart.py index ee53a887d3caba521726e4bedea37b0f3aebc548..386d8c7896c688d8b2f17c752fddfab555287150 100644 --- a/quickstats/plots/bidirectional_bar_chart.py +++ b/quickstats/plots/bidirectional_bar_chart.py @@ -103,13 +103,13 @@ class BidirectionalBarChart(AbstractPlot): } def __init__(self, collective_data:Union[Dict, "pandas.DataFrame"], - label_map:Optional[Dict]=None, color_cycle:Optional[Dict]=None, styles:Optional[Union[Dict, str]]=None, analysis_label_options:Optional[Union[Dict, str]]=None, config:Optional[Dict]=None): super().__init__(color_cycle=color_cycle, styles=styles, + label_map=label_map, analysis_label_options=analysis_label_options, config=config) if isinstance(collective_data, dict): @@ -117,7 +117,6 @@ class BidirectionalBarChart(AbstractPlot): assert len(self.collective_data) > 0 else: self.collective_data = collective_data.copy() - self.label_map = combine_dict(label_map) def draw_single_data(self, ax, width:float, x:np.ndarray, diff --git a/quickstats/plots/collective_data_plot.py b/quickstats/plots/collective_data_plot.py index eb93d62f9e95c3db71efa20fa0124978ad4c13ba..d55b38cd8d8534373cd8bacb93b0c4de74813135 100644 --- a/quickstats/plots/collective_data_plot.py +++ b/quickstats/plots/collective_data_plot.py @@ -17,6 +17,7 @@ class CollectiveDataPlot(AbstractPlot): super().__init__(color_pallete=color_pallete, color_cycle=color_cycle, + label_map=label_map, styles=styles, analysis_label_options=analysis_label_options, figure_index=figure_index, @@ -24,7 +25,6 @@ class CollectiveDataPlot(AbstractPlot): self.set_data(collective_data) self.plot_options = combine_dict({}, plot_options) - self.label_map = combine_dict({}, label_map) def set_data(self, collective_data:Dict[str, Any]): if not isinstance(collective_data, dict): diff --git a/quickstats/plots/color_schemes.py b/quickstats/plots/color_schemes.py index c90e956ddcb19fff05222d3560e61cc6772757cf..43b7517cde7825da49a3c716d4b37825db1f5d5b 100644 --- a/quickstats/plots/color_schemes.py +++ b/quickstats/plots/color_schemes.py @@ -97,5 +97,6 @@ QUICKSTATS_PALETTES = dict( "hdbs:outrageousorange", "hdbs:maroonX11", "hdbs:mintcream"], atlas_hh=["hh:darkpink", "hh:medturquoise", "hh:darkyellow", "hh:darkgreen", "hh:darkblue", "hh:lightturquoise", "hh:offwhite"], - qualitative=["#FF1F5B", "#00CD6C", "#009ADE", "#AF58BA", "#FF6C1E", "#F28522"] + qualitative=["#FF1F5B", "#00CD6C", "#009ADE", "#AF58BA", "#FF6C1E", "#F28522"], + shaded_regions=['#7ed3cd', '#fbdb80', 
'#ffb2b3', '#b1ffb2', '#b2b3fe']
 )
\ No newline at end of file
diff --git a/quickstats/plots/colors.py b/quickstats/plots/colors.py
index fb017146ff2ef6af49ed7b5e50de96ac21d77e45..f155301e31d79fc9f866873d9ac7dc759a08cce8 100644
--- a/quickstats/plots/colors.py
+++ b/quickstats/plots/colors.py
@@ -1,276 +1,426 @@
-from typing import List, Dict, Optional, Union
+"""
+Enhanced color utilities for matplotlib.
 
-from cycler import cycler
-import numpy as np
+This module provides a comprehensive set of utilities for handling colors and
+colormaps in matplotlib, including color validation, registration, and
+visualization tools.
+"""
+from typing import List, Dict, Optional, Union, Tuple, Any
+
+import numpy as np
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 from matplotlib.colors import (
-    to_rgba, get_named_colors_mapping,
-    ListedColormap, LinearSegmentedColormap
+    to_rgba,
+    get_named_colors_mapping,
+    Colormap,
+    ListedColormap,
+    LinearSegmentedColormap,
 )
-#from matplotlib.colormaps import register
-#from matplotlib.cm import get_cmap as gcm
+from cycler import cycler
+
+# Type aliases for better type checking
+ColorType = Union[
+    str,                                # named color, hex code, or grayscale
+    Tuple[float, float, float],         # RGB tuple
+    Tuple[float, float, float, float],  # RGBA tuple
+]
+
+ColormapType = Union[
+    str,
+    Colormap,
+    List[ColorType],
+]
+
+# Custom exceptions for better error handling
+class ColorError(Exception):
+    """Base exception for color-related errors."""
+    pass
 
-def get_cmap(source: Union[List[str], str, ListedColormap, LinearSegmentedColormap],
-             size: Optional[int] = None) -> ListedColormap:
+class ColorValidationError(ColorError):
+    """Exception raised for invalid color specifications."""
+    pass
+
+class ColormapError(ColorError):
+    """Exception raised for colormap-related errors."""
+    pass
+
+
+def get_cmap(
+    source: ColormapType,
+    size: Optional[int] = None,
+) -> Colormap:
     """
     Get a Matplotlib colormap from a name, list of colors, or an existing colormap.
-    
+
     Parameters
     ----------
-    source : Union[List[str], str, ListedColormap, LinearSegmentedColormap]
+    source : ColormapType
         The source for the colormap. It can be:
         - A string name of the colormap.
-        - A list of color strings.
-        - An existing colormap instance (ListedColormap or LinearSegmentedColormap).
-    size : Optional[int], default: None
-        The number of entries in the colormap lookup table. If None, the original size is used.
-    
+        - A list of color specifications.
+        - An existing Colormap instance.
+    size : Optional[int], default None
+        The number of entries in the colormap lookup table.
+        If None, the original size is used.
+
     Returns
     -------
-    ListedColormap
+    Colormap
         A Matplotlib colormap.
-    
+
     Raises
     ------
+    ColormapError
+        If the source type is invalid, ``size`` is not positive, or the
+        colormap cannot be created (underlying ``ValueError`` and
+        ``TypeError`` are re-raised as ``ColormapError``).
-    ValueError
-        If `source` is not a recognized type.
-    
-    Example
-    -------
-    >>> get_cmap('viridis', size=10)
-    >>> get_cmap(['#FF0000', '#00FF00', '#0000FF'], size=5)
-    >>> get_cmap(mpl.colormaps['viridis'], size=256)
+ + Examples + -------- + >>> # Get a built-in colormap + >>> cmap1 = get_cmap('viridis', size=10) + + >>> # Create from list of colors + >>> colors = ['#FF0000', '#00FF00', '#0000FF'] + >>> cmap2 = get_cmap(colors, size=5) + + >>> # Use existing colormap + >>> cmap3 = get_cmap(plt.cm.viridis, size=256) """ - if isinstance(source, str): - cmap = mpl.colormaps.get_cmap(source) - elif isinstance(source, (ListedColormap, LinearSegmentedColormap)): - cmap = source.copy() - elif isinstance(source, list): - cmap = ListedColormap(source) - else: - raise ValueError(f"Invalid source type for colormap: {type(source)}") - if size is not None: - return cmap.resampled(size) - return cmap - -def get_cmap_rgba(source: Union[List[str], str, ListedColormap, LinearSegmentedColormap], - size: Optional[int] = None) -> np.ndarray: + try: + # Validate size if provided + if size is not None and size <= 0: + raise ValueError("size must be positive") + + # Get colormap based on source type + if isinstance(source, str): + cmap = mpl.colormaps.get_cmap(source) + elif isinstance(source, Colormap): + cmap = source + elif isinstance(source, list): + # Validate all colors in the list + for color in source: + validate_color(color) + cmap = ListedColormap(source) + else: + raise ColormapError( + f"Invalid source type for colormap: {type(source)}. " + "Expected string, Colormap, or list of colors." + ) + + # Resample if size is specified + if size is not None: + cmap = cmap.resampled(size) + + return cmap + + except (ValueError, TypeError) as e: + raise ColormapError(f"Failed to create colormap: {str(e)}") from e + + +def get_cmap_rgba( + source: ColormapType, + size: Optional[int] = None, +) -> np.ndarray: """ Retrieve the RGBA values from a colormap. - + Parameters ---------- - source : Union[List[str], str, ListedColormap, LinearSegmentedColormap] - The source for the colormap. It can be: - - A string name of the colormap. - - A list of color strings. - - An existing colormap instance (ListedColormap or LinearSegmentedColormap). - size : Optional[int], default: None - The number of entries in the colormap lookup table. If None, the original size is used. - + source : ColormapType + The source for the colormap. + size : Optional[int], default None + The number of entries in the colormap lookup table. + If None, the original size is used. + Returns ------- np.ndarray - An array of RGBA values. - - Example - ------- - >>> get_cmap_rgba('viridis', size=10) - array([[0.267004, 0.004874, 0.329415, 1. ], - [0.282623, 0.140926, 0.457517, 1. ], - ..., - [0.993248, 0.906157, 0.143936, 1. ]]) - >>> get_cmap_rgba(['#FF0000', '#00FF00', '#0000FF'], size=5) - array([[1. , 0. , 0. , 1. ], - [0.75 , 0.5 , 0.25 , 1. ], - [0.5 , 1. , 0.5 , 1. ], - [0.25 , 0.75 , 0.75 , 1. ], - [0. , 0. , 1. , 1. ]]) - >>> get_cmap_rgba(mpl.colormaps['plasma'], size=256) + An array of RGBA values with shape (N, 4). 
+ + Examples + -------- + >>> # Get RGBA values from built-in colormap + >>> rgba1 = get_cmap_rgba('viridis', size=10) + >>> print(rgba1.shape) # (10, 4) + + >>> # Get RGBA values from custom colors + >>> colors = ['#FF0000', '#00FF00', '#0000FF'] + >>> rgba2 = get_cmap_rgba(colors, size=5) + >>> print(rgba2.shape) # (5, 4) """ cmap = get_cmap(source, size=size) - - # Ensure cmap.N is used to retrieve correct number of colors rgba_values = cmap(np.linspace(0, 1, cmap.N)) - return rgba_values -def get_cmap_rgba(source:Optional[Union[List, str]], size:Optional[int]=None) -> List[List[float]]: - cmap = get_cmap(source, size=size) - return cmap(range(cmap.N)) - -def get_rgba(color: str, alpha: float = 1.0) -> List[float]: + +def get_rgba( + color: ColorType, + alpha: float = 1.0 +) -> Tuple[float, float, float, float]: """ - Convert a color string to an RGBA list with a specified alpha value. - + Convert a color specification to an RGBA tuple with a specified alpha value. + Parameters ---------- - color : str - A color string (e.g., 'blue', '#00FF00', 'rgb(255,0,0)', etc.). - alpha : float, default: 1.0 - The alpha (transparency) value to set, in the range [0.0, 1.0]. - + color : ColorType + A color specification (e.g., 'blue', '#00FF00', (1.0, 0.0, 0.0)). + alpha : float, default 1.0 + The alpha (transparency) value, in range [0.0, 1.0]. + Returns ------- - List[float] - A list of RGBA components [R, G, B, A] with the specified alpha value. - - Example - ------- - >>> get_rgba('blue', alpha=0.5) - [0.0, 0.0, 1.0, 0.5] - >>> get_rgba('#FF5733', alpha=0.8) - [1.0, 0.3411764705882353, 0.2, 0.8] + Tuple[float, float, float, float] + An RGBA tuple (R, G, B, A) with the specified alpha value. + + Raises + ------ + ColorValidationError + If color is invalid or alpha is out of range. + + Examples + -------- + >>> get_rgba('blue', alpha=0.5) # (0.0, 0.0, 1.0, 0.5) + >>> get_rgba('#FF5733', alpha=0.8) # (1.0, 0.341, 0.2, 0.8) + >>> get_rgba((1.0, 0.0, 0.0)) # (1.0, 0.0, 0.0, 1.0) """ - rgba = list(to_rgba(color)) - rgba[-1] = alpha - return rgba + try: + if not 0 <= alpha <= 1: + raise ValueError("alpha must be between 0 and 1") + + rgba = to_rgba(color) + return rgba[:3] + (alpha,) + + except ValueError as e: + raise ColorValidationError(f"Invalid color or alpha value: {str(e)}") from e -def validate_color(color: str) -> None: + +def validate_color(color: ColorType) -> None: """ - Validate a color string by converting it to RGBA. - + Validate a color specification by attempting to convert it to RGBA. + Parameters ---------- - color : str - The color string to validate. - + color : ColorType + The color specification to validate. + Raises ------ - ValueError - If the color cannot be converted to RGBA. + ColorValidationError + If the color specification is invalid. + + Examples + -------- + >>> validate_color('blue') # OK + >>> validate_color('#FF5733') # OK + >>> validate_color('not_a_color') # Raises ColorValidationError """ try: to_rgba(color) - except ValueError: - raise ValueError(f"Invalid color value: {color}") - -def register_colors(colors: Dict[str, Union[str, Dict[str, str]]]) -> None: + except ValueError as e: + raise ColorValidationError(f"Invalid color value: {color}") from e + + +def register_colors(colors: Dict[str, Union[ColorType, Dict[str, ColorType]]]) -> None: """ - Register colors to matplotlib's color registry. - + Register colors to Matplotlib's color registry. 
+ Parameters ---------- - colors : Dict[str, Union[str, Dict[str, str]]] - A dictionary where keys are color labels and values are either color strings or - dictionaries mapping sub-labels to color strings. - + colors : Dict[str, Union[ColorType, Dict[str, ColorType]]] + A dictionary where keys are color labels and values are either: + - Color specifications + - Dictionaries mapping sub-labels to color specifications + Raises ------ - ValueError - If any of the colors cannot be converted to RGBA. + ColorValidationError + If any color specification is invalid. TypeError - If the color values are neither strings nor dictionaries. - - Example - ------- + If the color values have invalid types. + + Examples + -------- + >>> # Register simple colors >>> register_colors({ ... 'primary': '#FF0000', - ... 'secondary': {'shade1': '#00FF00', 'shade2': '#0000FF'} + ... 'secondary': '#00FF00' ... }) - """ - grouped_colors = {} - - for label, color in colors.items(): - if isinstance(color, dict): - for sublabel, subcolor in color.items(): - validate_color(subcolor) - full_label = f'{label}:{sublabel}' - grouped_colors[full_label] = subcolor - elif isinstance(color, str): - validate_color(color) - grouped_colors[label] = color - else: - raise TypeError(f"Color for '{label}' must be a string or a dictionary.") - # Extend the named colors dictionary - named_colors = get_named_colors_mapping() - named_colors.update(grouped_colors) + >>> # Register color groups + >>> register_colors({ + ... 'brand': { + ... 'light': '#FFE4E1', + ... 'main': '#FF4136', + ... 'dark': '#85144B' + ... } + ... }) + """ + try: + grouped_colors = {} + + for label, color in colors.items(): + if isinstance(color, dict): + for sublabel, subcolor in color.items(): + validate_color(subcolor) + full_label = f"{label}:{sublabel}" + grouped_colors[full_label] = subcolor + else: + validate_color(color) + grouped_colors[label] = color + + # Update the named colors mapping + named_colors = get_named_colors_mapping() + named_colors.update(grouped_colors) + + except (TypeError, AttributeError) as e: + raise TypeError( + "Colors must be color specifications or dictionaries of color specifications" + ) from e -def register_cmaps(listed_colors:Dict[str, List[str]], force:bool=True) -> None: + +def register_cmaps( + listed_colors: Dict[str, List[ColorType]], + force: bool = True, +) -> None: """ - Register listed color maps to the matplotlib registry. - + Register listed colormaps to the Matplotlib registry. + Parameters ---------- - listed_colors : Dict[str, List[str]] - A dictionary mapping from color map name to the underlying list of colors. - - Example - ------- - >>> register_cmaps({'my_cmap': ['#FF0000', '#00FF00', '#0000FF']}) + listed_colors : Dict[str, List[ColorType]] + A dictionary mapping colormap names to lists of color specifications. + force : bool, default True + Whether to overwrite existing colormaps with the same name. + + Raises + ------ + ColormapError + If colormap creation fails. + + Examples + -------- + >>> # Register simple sequential colormap + >>> register_cmaps({ + ... 'red_to_blue': ['#FF0000', '#0000FF'] + ... }) + + >>> # Register multiple colormaps + >>> register_cmaps({ + ... 'sunset': ['#FF7E5F', '#FEB47B', '#FFE66D'], + ... 'ocean': ['#1A2980', '#26D0CE'] + ... 
}) """ - for name, colors in listed_colors.items(): - cmap = ListedColormap(colors, name=name) - mpl.colormaps.register(name=name, cmap=cmap, force=force) + try: + for name, colors in listed_colors.items(): + # Validate all colors before creating colormap + for color in colors: + validate_color(color) + + cmap = ListedColormap(colors, name=name) + mpl.colormaps.register(cmap, name=name, force=force) + + except Exception as e: + raise ColormapError(f"Failed to register colormaps: {str(e)}") from e -def get_color_cycle(source: Union[List[str], str, ListedColormap]) -> cycler: + +def get_color_cycle(source: ColormapType) -> cycler: """ Convert a color source to a Matplotlib cycler object. - + Parameters ---------- - source : Union[List[str], str, ListedColormap] - The source of colors. It can be: - - A list of color strings. - - A string name of the colormap. - - A `ListedColormap` instance. - + source : ColormapType + The source of colors. Can be: + - A list of color specifications + - A string name of a colormap + - A Colormap instance + Returns ------- cycler A cycler object containing colors from the source. + + Examples + -------- + >>> # Create from list of colors + >>> cycle1 = get_color_cycle(['#FF0000', '#00FF00', '#0000FF']) + >>> plt.rc('axes', prop_cycle=cycle1) - Example - ------- - >>> get_color_cycle(['#FF0000', '#00FF00', '#0000FF']) - >>> get_color_cycle('viridis') - >>> get_color_cycle(mpl.colormaps['viridis']) + >>> # Create from built-in colormap + >>> cycle2 = get_color_cycle('viridis') + >>> plt.rc('axes', prop_cycle=cycle2) """ - if isinstance(source, str): - cmap = get_cmap(source) - colors = cmap.colors - elif isinstance(source, ListedColormap): - colors = source.colors - elif isinstance(source, list): - colors = source - else: - raise ValueError(f"Invalid source type for colors: {type(source)}") + cmap = get_cmap(source) + colors = cmap.colors if hasattr(cmap, "colors") else cmap(np.linspace(0, 1, cmap.N)) return cycler(color=colors) -# taken from https://matplotlib.org/stable/tutorials/colors/colormaps.html -def plot_color_gradients(cmap_list: List[str], size: Optional[int] = None) -> None: + +def plot_color_gradients( + cmap_list: List[str], + size: Optional[int] = None, + figsize: Optional[Tuple[float, float]] = None, +) -> None: """ Plot a series of color gradients for the given list of colormap names. - + Parameters ---------- cmap_list : List[str] - List of colormap names. - size : Optional[int], default: None - The colormap will be resampled to have `size` entries in the lookup table. - - Example - ------- - >>> plot_color_gradients(['viridis', 'plasma', 'inferno', 'magma', 'cividis']) - >>> plot_color_gradients(['Blues', 'Greens', 'Reds'], size=128) + List of colormap names to visualize. + size : Optional[int], default None + The colormap will be resampled to have `size` entries. + figsize : Optional[Tuple[float, float]], default None + Custom figure size (width, height). If None, size is computed automatically. + + Examples + -------- + >>> # Plot standard colormaps + >>> plot_color_gradients(['viridis', 'plasma', 'inferno']) + + >>> # Plot custom sized gradients + >>> plot_color_gradients( + ... ['Blues', 'Greens', 'Reds'], + ... size=128, + ... figsize=(8, 6) + ... 
) """ + # Create gradient array gradient = np.linspace(0, 1, 256) gradient = np.vstack((gradient, gradient)) - - # Calculate the figure height based on the number of colormaps + + # Calculate figure dimensions nrows = len(cmap_list) - figh = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22 - fig, axs = plt.subplots(nrows=nrows, figsize=(6.4, figh)) - fig.subplots_adjust(top=1 - 0.35 / figh, bottom=0.15 / figh, - left=0.2, right=0.99, hspace=0.4) - - # Plot each colormap gradient + if figsize is None: + fig_height = 0.35 + 0.15 + (nrows + (nrows - 1) * 0.1) * 0.22 + figsize = (6.4, fig_height) + + # Create figure and subplots + fig, axs = plt.subplots(nrows=nrows, figsize=figsize) + if nrows == 1: + axs = [axs] + + # Adjust layout + fig.subplots_adjust( + top=1 - 0.35 / figsize[1], + bottom=0.15 / figsize[1], + left=0.2, + right=0.99, + hspace=0.4, + ) + + # Plot each colormap for ax, name in zip(axs, cmap_list): cmap = get_cmap(name, size=size) - ax.imshow(gradient, aspect='auto', cmap=cmap) - ax.text(-0.01, 0.5, name, va='center', ha='right', fontsize=10, - transform=ax.transAxes) + ax.imshow(gradient, aspect="auto", cmap=cmap) + ax.text( + -0.01, + 0.5, + name, + va="center", + ha="right", + fontsize=10, + transform=ax.transAxes, + ) ax.set_axis_off() \ No newline at end of file diff --git a/quickstats/plots/core.py b/quickstats/plots/core.py index c1d0a772892dc94e8ea92b1ad92a714e17070640..e56e006b3bf5dc52f86faa6fd3c4670b444cb787 100644 --- a/quickstats/plots/core.py +++ b/quickstats/plots/core.py @@ -6,7 +6,6 @@ from matplotlib.colors import ListedColormap, LinearSegmentedColormap import quickstats from quickstats import DescriptiveEnum -from quickstats.utils.common_utils import combine_dict from .colors import ( register_colors, register_cmaps, @@ -16,7 +15,7 @@ from .colors import ( ) __all__ = ["ErrorDisplayFormat", "PlotFormat", - "set_attrib", "reload_styles", "use_style", + "reload_styles", "use_style", "register_colors", "register_cmaps", "plot_color_gradients", "get_cmap", "get_rgba", "get_cmap_rgba", "get_color_cycle"] @@ -44,34 +43,6 @@ class PlotFormat(DescriptiveEnum): obj.mpl_method = mpl_method return obj -def set_attrib(obj, **kwargs): - for key, value in kwargs.items(): - target = obj - if '.' 
in key:
-            tokens = key.split('.')
-            if len(tokens) != 2:
-                raise ValueError('maximum of 1 subfield is allowed but {} is given'.format(len(tokens)-1))
-            field, key = tokens[0], tokens[1]
-            method_name = 'Get' + field
-            if hasattr(obj, 'Get' + field):
-                target = getattr(obj, method_name)()
-            else:
-                raise ValueError('{} object does not contain the method {}'.format(type(target), method_name))
-        method_name = 'Set' + key
-        if hasattr(target, 'Set' + key):
-            method_name = 'Set' + key
-        elif hasattr(target, key):
-            method_name = key
-        else:
-            raise ValueError('{} object does not contain the method {}'.format(type(target), method_name))
-        if value is None:
-            getattr(target, method_name)()
-        elif isinstance(value, (list, tuple)):
-            getattr(target, method_name)(*value)
-        else:
-            getattr(target, method_name)(value)
-    return obj
-
 def reload_styles():
     from matplotlib import style
     style.core.USER_LIBRARY_PATHS.append(quickstats.stylesheet_path)
diff --git a/quickstats/plots/correlation_plot.py b/quickstats/plots/correlation_plot.py
index a465329a9a585927c81395c6926413e525bbb4df..f905cfdc01c6fa84c100ac42085ed0fe4bfb9d0d 100644
--- a/quickstats/plots/correlation_plot.py
+++ b/quickstats/plots/correlation_plot.py
@@ -57,8 +57,8 @@ class CorrelationPlot(AbstractPlot):
                  config: Optional[Dict] = None):
 
         self.data = data
-        self.label_map = label_map
-        super().__init__(styles=styles,
+        super().__init__(label_map=label_map,
+                         styles=styles,
                          config=config)
 
     @staticmethod
diff --git a/quickstats/plots/general_1D_plot.py b/quickstats/plots/general_1D_plot.py
index 150c7e7b83f926503ad1091d2787f634a618598e..a8e30ce35cd48aed1759e400cebb44704251e0ef 100644
--- a/quickstats/plots/general_1D_plot.py
+++ b/quickstats/plots/general_1D_plot.py
@@ -1,182 +1,282 @@
-from typing import Dict, Optional, Union, List, Tuple
-import pandas as pd
+from __future__ import annotations
+
+from typing import Dict, Optional, Union, List, Any
+from collections import defaultdict
+
 import numpy as np
+import pandas as pd
+from matplotlib.axes import Axes
+
+from quickstats.core import mappings as mp
+from .colors import ColormapType
+from .multi_data_plot import MultiDataPlot
+from .stat_plot_config import StatPlotConfig
+
+PlotStyles = Dict[str, Any]
+StatConfigs = List[StatPlotConfig]
+TargetType = Optional[Union[str, List[Optional[str]]]]
 
-from quickstats.plots import AbstractPlot, StatPlotConfig
-from quickstats.plots.core import get_rgba
-from quickstats.plots.template import create_transform, handle_has_label
-from quickstats.utils.common_utils import combine_dict
 
+class General1DPlot(MultiDataPlot):
+    """
+    Class for plotting general 1D data.
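+
+    Examples
+    --------
+    A minimal construction sketch (the target key and column names are
+    illustrative):
+
+    >>> import pandas as pd
+    >>> df = pd.DataFrame({'x': [1, 2, 3], 'y': [0.2, 0.5, 0.9]})
+    >>> plotter = General1DPlot({'nominal': df}, label_map={'nominal': 'Nominal'})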
+ """ -class General1DPlot(AbstractPlot): + COLOR_CYCLE: str = 'default' - STYLES = { + STYLES: PlotStyles = { + 'plot': { + 'marker': 'o', + 'markersize': 8 + }, 'fill_between': { - 'alpha': 0.3, - 'hatch': None, - 'linewidth': 1.0 - } + 'alpha': 0.3, + 'hatch': None, + 'linewidth': 1.0, + }, } - - CONFIG = { - 'errorband_legend': True + + CONFIG: Dict[str, bool] = { + 'isolate_error_legend': False, + 'inherit_color': True, + 'error_on_top': True, + 'draw_legend': True } - - def __init__(self, data_map:Union[pd.DataFrame, Dict[str, pd.DataFrame]], - label_map:Optional[Dict]=None, - styles_map:Optional[Dict]=None, - color_cycle=None, - styles:Optional[Union[Dict, str]]=None, - analysis_label_options:Optional[Dict]=None, - config:Optional[Dict]=None): - - self.data_map = data_map - self.label_map = label_map - self.styles_map = styles_map - - super().__init__(color_cycle=color_cycle, - styles=styles, - analysis_label_options=analysis_label_options, - config=config) - - self.stat_configs = {} - - def get_default_legend_order(self): - if not isinstance(self.data_map, dict): - return [] - else: - return list(self.data_map) - - def configure_stats(self, stat_configs:List[StatPlotConfig], - targets:Optional[Union[str, List[str]]]=None, - extend:bool=True): + + def __init__( + self, + data_map: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + color_cycle: Optional[ColormapType] = None, + label_map: Optional[Dict[str, str]] = None, + styles: Optional[Union[PlotStyles, str]] = None, + styles_map: Optional[Dict[str, Union[PlotStyles, str]]] = None, + analysis_label_options: Optional[Union[str, Dict[str, Any]]] = None, + config: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Initialize General1DPlot. + + Parameters + ---------- + data_map : Union[pd.DataFrame, Dict[str, pd.DataFrame]] + Data to plot, single DataFrame or dictionary of DataFrames + color_cycle : Optional[ColormapType], default None + Color cycle for plots + label_map : Optional[Dict[str, str]], default None + Mapping of targets to display labels + styles : Optional[Union[PlotStyles,str]], default None + Global plot styles + styles_map : Optional[Dict[str, Union[PlotStyles, str]]], default None + Target-specific style overrides + analysis_label_options : Optional[Union[str, Dict[str, Any]]], default None + Options for analysis labels + config : Optional[Dict[str, Any]], default None + Plot configuration parameters + """ + self.stat_configs: Dict[Optional[str], StatConfigs] = {} + super().__init__( + data_map=data_map, + color_cycle=color_cycle, + label_map=label_map, + styles=styles, + styles_map=styles_map, + analysis_label_options=analysis_label_options, + config=config, + ) + + def configure_stats( + self, + stat_configs: StatConfigs, + targets: TargetType = None, + extend: bool = True, + ) -> None: + """ + Configure statistical annotations for targets. 
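+
+        A usage sketch (``cfg`` stands for a previously constructed
+        ``StatPlotConfig``; the target key is illustrative):
+
+        >>> plotter.configure_stats([cfg], targets='nominal')
+        >>> plotter.configure_stats([cfg])  # stored under key None for all targets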
+ + Parameters + ---------- + stat_configs : List[StatPlotConfig] + Statistical configurations to apply + targets : Optional[Union[str, List[Optional[str]]]], default None + Targets to configure, if None applies to all + extend : bool, default True + Whether to extend existing configurations + """ if not isinstance(targets, list): targets = [targets] + for target in targets: - if extend and (target in self.stat_configs): + if extend and target in self.stat_configs: self.stat_configs[target].extend(stat_configs) else: self.stat_configs[target] = stat_configs - - def draw_single_data(self, ax, data:pd.DataFrame, - xattrib:str, yattrib:str, - yerrloattrib:Optional[str]=None, - yerrhiattrib:Optional[str]=None, - stat_configs:Optional[List[StatPlotConfig]]=None, - styles:Optional[Dict]=None, - label:Optional[str]=None): + + def draw_single_data( + self, + ax: Axes, + data: pd.DataFrame, + xattrib: str, + yattrib: str, + yerrloattrib: Optional[str] = None, + yerrhiattrib: Optional[str] = None, + styles: Optional[PlotStyles] = None, + stat_configs: Optional[StatConfigs] = None, + domain: Optional[str] = None, + ) -> None: + """ + Draw single dataset with optional error bands and statistics. + + Parameters + ---------- + ax : matplotlib.axes.Axes + Axes to draw on + data : pd.DataFrame + Data to plot + xattrib : str + Column name for x values + yattrib : str + Column name for y values + yerrloattrib : Optional[str], default None + Column name for lower y errors + yerrhiattrib : Optional[str], default None + Column name for upper y errors + styles : Optional[PlotStyles], default None + Styles for this dataset + stat_configs : Optional[StatConfigs], default None + Statistical configurations + domain : Optional[str], default None + Domain name for this dataset + """ data = data.reset_index() - x, y = data[xattrib].values, data[yattrib].values + x = data[xattrib].values + y = data[yattrib].values + indices = np.argsort(x) - x, y = x[indices], y[indices] - draw_styles = combine_dict(self.styles['plot'], styles) - fill_styles = combine_dict(self.styles['fill_between']) + x = x[indices] + y = y[indices] + + handles: Dict[str, Any] = {} + styles = mp.concat((self.styles, styles), copy=True) + styles = defaultdict(dict, styles) + + label = self.label_map.get(domain, domain) + plot_handle = ax.plot(x, y, label=label, **styles['plot']) + handles[domain] = plot_handle[0] if len(plot_handle) == 1 else plot_handle + + if yerrloattrib and yerrhiattrib: + yerrlo = data[yerrloattrib].values[indices] + yerrhi = data[yerrhiattrib].values[indices] - if (yerrloattrib is not None) and (yerrhiattrib is not None): - yerrlo = data[yerrloattrib][indices] - yerrhi = data[yerrhiattrib][indices] - handle_fill = ax.fill_between(x, yerrlo, yerrhi, - **fill_styles) - else: - handle_fill = None - - handle_plot = ax.plot(x, y, **draw_styles, label=label) - if isinstance(handle_plot, list) and (len(handle_plot) == 1): - handle_plot = handle_plot[0] - - if handle_fill and ('color' not in fill_styles): - plot_color = handle_plot.get_color() - fill_color = get_rgba(plot_color) - handle_fill.set_color(fill_color) - - if stat_configs is not None: - stat_handles = [] - for stat_config in stat_configs: + if self.config['inherit_color']: + styles['fill_between'].setdefault('color', handles[domain].get_color()) + + zorder = handles[domain].get_zorder() + styles['fill_between'].setdefault('zorder', + zorder + 0.1 if self.config['error_on_top'] else zorder - 0.1) + + error_domain = self.label_map.format(domain, 'error') + error_label = 
self.label_map.get(error_domain, error_domain) + error_handle = ax.fill_between( + x, yerrlo, yerrhi, + label=error_label, + **styles['fill_between'] + ) + + if not self.config['isolate_error_legend']: + handles[domain] = (handles[domain], error_handle) + else: + handles[error_domain] = error_handle + + if stat_configs: + for i, stat_config in enumerate(stat_configs): stat_config.set_data(y) - stat_handle = stat_config.apply(ax, handle[0]) - stat_handles.append(stat_handle) - else: - stat_handles = None - - if self.config['errorband_legend'] and (handle_fill is not None): - handles = (handle_plot, handle_fill) - else: - handles = handle_plot - return handles, stat_handles - - def draw(self, xattrib:str, yattrib:str, - yerrloattrib:Optional[str]=None, - yerrhiattrib:Optional[str]=None, - targets:Optional[List[str]]=None, - xlabel:Optional[str]=None, ylabel:Optional[str]=None, - ymin:Optional[float]=None, ymax:Optional[float]=None, - xmin:Optional[float]=None, xmax:Optional[float]=None, - ypad:Optional[float]=0.3, - logx:bool=False, logy:bool=False, - draw_stats:bool=True): - + stat_handle = stat_config.apply(ax, handles[domain]) + stat_domain = self.label_map.format(domain, f"stat_handle_{i}") + handles[stat_domain] = stat_handle + + self.update_legend_handles(handles) + self.legend_order.extend(handles.keys()) + + def draw( + self, + xattrib: str, + yattrib: str, + yerrloattrib: Optional[str] = None, + yerrhiattrib: Optional[str] = None, + targets: Optional[List[str]] = None, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, + xmin: Optional[float] = None, + xmax: Optional[float] = None, + ypad: float = 0.3, + logx: bool = False, + logy: bool = False, + draw_stats: bool = True, + legend_order: Optional[List[str]] = None, + ) -> Axes: + """ + Draw complete plot with all datasets. 
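+
+        A usage sketch with an error band (column names are illustrative):
+
+        >>> ax = plotter.draw('x', 'y',
+        ...                   yerrloattrib='y_dn',
+        ...                   yerrhiattrib='y_up',
+        ...                   xlabel='x', ylabel='y')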
+
+        Parameters
+        ----------
+        xattrib : str
+            Column name for x values
+        yattrib : str
+            Column name for y values
+        yerrloattrib : Optional[str], default None
+            Column name for lower y errors
+        yerrhiattrib : Optional[str], default None
+            Column name for upper y errors
+        targets : Optional[List[str]], default None
+            Targets to plot
+        xlabel : Optional[str], default None
+            X-axis label
+        ylabel : Optional[str], default None
+            Y-axis label
+        ymin, ymax : Optional[float], default None
+            Y-axis limits
+        xmin, xmax : Optional[float], default None
+            X-axis limits
+        ypad : float, default 0.3
+            Y-axis padding fraction
+        logx, logy : bool, default False
+            Use logarithmic scale
+        draw_stats : bool, default True
+            Draw statistical annotations
+        legend_order : Optional[List[str]], default None
+            Custom legend order
+
+        Returns
+        -------
+        matplotlib.axes.Axes
+            The plotted axes
+        """
+        self.reset_metadata()
         ax = self.draw_frame(logx=logx, logy=logy)
-
-        legend_order = []
-        if isinstance(self.data_map, pd.DataFrame):
-            if draw_stats and (None in self.stat_configs):
-                stat_configs = self.stat_configs[None]
-            else:
-                stat_configs = None
-            handle, stat_handles = self.draw_single_data(ax, self.data_map,
-                                                         xattrib=xattrib,
-                                                         yattrib=yattrib,
-                                                         yerrloattrib=yerrloattrib,
-                                                         yerrhiattrib=yerrhiattrib,
-                                                         stat_configs=stat_configs,
-                                                         styles=self.styles_map)
-        elif isinstance(self.data_map, dict):
-            if targets is None:
-                targets = list(self.data_map.keys())
-            if self.styles_map is None:
-                styles_map = {k:None for k in self.data_map}
-            else:
-                styles_map = self.styles_map
-            if self.label_map is None:
-                label_map = {k:k for k in self.data_map}
-            else:
-                label_map = self.label_map
-            handles = {}
-            for target in targets:
-                data = self.data_map[target]
-                styles = styles_map.get(target, None)
-                label = label_map.get(target, target)
-                if draw_stats:
-                    if target in self.stat_configs:
-                        stat_configs = self.stat_configs[target]
-                    elif None in self.stat_configs:
-                        stat_configs = self.stat_configs[None]
-                    else:
-                        stat_configs = None
-                else:
-                    stat_configs = None
-                handle, stat_handles = self.draw_single_data(ax, data,
-                                                             xattrib=xattrib,
-                                                             yattrib=yattrib,
-                                                             yerrloattrib=yerrloattrib,
-                                                             yerrhiattrib=yerrhiattrib,
-                                                             stat_configs=stat_configs,
-                                                             styles=styles,
-                                                             label=label)
-                handles[target] = handle
-                if stat_handles is not None:
-                    for i, stat_handle in enumerate(stat_handles):
-                        if handle_has_label(stat_handle):
-                            handle_name = f"{target}_stat_handle_{i}"
-                            handles[handle_name] = stat_handle
-            legend_order.extend(handles.keys())
-            self.update_legend_handles(handles)
-        else:
-            raise ValueError("invalid data format")
-
+
+        targets = self.resolve_targets(targets)
+        for target in targets:
+            self.draw_single_data(
+                ax,
+                self.data_map[target],
+                xattrib=xattrib,
+                yattrib=yattrib,
+                yerrloattrib=yerrloattrib,
+                yerrhiattrib=yerrhiattrib,
+                styles=self.styles_map.get(target),
+                stat_configs=(self.stat_configs.get(target, self.stat_configs.get(None))
+                              if draw_stats else None),
+                domain=target
+            )
+
         self.draw_axis_components(ax, xlabel=xlabel, ylabel=ylabel)
         self.set_axis_range(ax, xmin=xmin, xmax=xmax,
                             ymin=ymin, ymax=ymax, ypad=ypad)
-
-        self.legend_order = legend_order
-        self.draw_legend(ax)
+        self.finalize(ax)
-        return ax
+        if legend_order is not None:
+            self.legend_order = legend_order
+
+        if self.config['draw_legend']:
+            self.draw_legend(ax)
+
+        return ax
\ No newline at end of file
diff --git a/quickstats/plots/general_2D_plot.py b/quickstats/plots/general_2D_plot.py
index 469b2c1955b765d06c2d28dc9ebb321650a7ae28..cab7dbdaad2f3cb281598625f472705342ba8be3
100644 --- a/quickstats/plots/general_2D_plot.py +++ b/quickstats/plots/general_2D_plot.py @@ -1,19 +1,42 @@ -from typing import Dict, Optional, Union, List, Tuple, Callable -import pandas as pd +from __future__ import annotations + +from typing import Dict, Optional, Union, List, Tuple, Callable, Any, cast +from collections import defaultdict + import numpy as np +import pandas as pd +from matplotlib.axes import Axes import matplotlib.pyplot as plt -import matplotlib.colors as mcolors -import matplotlib.colorbar as mcolorbar +from matplotlib.colors import Normalize, LogNorm +from mpl_toolkits.axes_grid1 import make_axes_locatable -from quickstats.plots import AbstractPlot -from quickstats.plots.template import format_axis_ticks +from quickstats.core import mappings as mp from quickstats.maths.interpolation import interpolate_2d -from quickstats.utils.common_utils import combine_dict +from quickstats.maths.numerics import pivot_table +from .multi_data_plot import MultiDataPlot +from .template import ( + format_axis_ticks, + convert_size, +) +from .colors import ColorType, ColormapType -class General2DPlot(AbstractPlot): +class General2DPlot(MultiDataPlot): + """ + Class for plotting general 2D data. - STYLES = { + Provides functionality for plotting 2D data with support for: + - Colormesh plots + - Contour plots (filled and unfilled) + - Scatter plots + - Multiple colorbars + - Custom interpolation + """ + + COLOR_CYCLE: str = 'default' + + STYLES: Dict[str, Any] = { 'pcolormesh': { + 'cmap': 'plasma', 'shading': 'auto', 'rasterized': True }, @@ -25,6 +48,8 @@ class General2DPlot(AbstractPlot): 'linewidths': 2 }, 'contourf': { + 'alpha': 0.5, + 'zorder': 0 }, 'scatter': { 'c': 'hh:darkpink', @@ -35,133 +60,451 @@ class General2DPlot(AbstractPlot): 'linewidth': 1, }, 'legend': { - 'handletextpad': 0., + 'handletextpad': 0.5 }, 'clabel': { 'inline': True, 'fontsize': 10 + }, + 'axis_divider': { + 'position': 'right', + 'size': 0.3, + 'pad': 0.1 } } - CONFIG = { + CONFIG: Dict[str, Any] = { 'interpolation': 'cubic', 'num_grid_points': 500 } - - def __init__(self, data:pd.DataFrame, - styles:Optional[Union[Dict, str]]=None, - analysis_label_options:Optional[Dict]=None, - config:Optional[Dict]=None): + + def __init__( + self, + data_map: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + color_map: Optional[Dict[str, ColorType]] = None, + color_cycle: Optional[ColormapType] = None, + label_map: Optional[Dict[str, str]] = None, + styles: Optional[Dict[str, Any]] = None, + styles_map: Optional[Dict[str, Union[Dict[str, Any], str]]] = None, + analysis_label_options: Optional[Union[str, Dict[str, Any]]] = None, + config: Optional[Dict[str, Any]] = None, + ) -> None: + """Initialize General2DPlot with same parameters as MultiDataPlot.""" + self.Z_interp: Dict[Optional[str], np.ndarray] = {} + self.axis_divider: Optional[Any] = None + self.caxs: List[Axes] = [] + super().__init__( + data_map=data_map, + color_map=color_map, + color_cycle=color_cycle, + label_map=label_map, + styles=styles, + styles_map=styles_map, + analysis_label_options=analysis_label_options, + config=config, + ) + + def reset_metadata(self) -> None: + """Reset plot metadata.""" + super().reset_metadata() + self.Z_interp = {} + self.axis_divider = None + self.caxs = [] - self.data = data + def resolve_target_styles( + self, + targets: List[Optional[str]], + norm: Optional[Normalize] = None, + cmap: str = 'GnBu', + clabel_fmt: Union[str, Dict[str, Any]] = '%1.3f', + clabel_manual: Optional[Dict[str, Any]] = None, + contour_levels: 
Optional[List[float]] = None + ) -> Dict[Optional[str], Dict[str, Any]]: + """ + Resolve plotting styles for each target. + + Parameters + ---------- + targets : List[Optional[str]] + List of targets to resolve styles for + norm : Optional[Normalize], default None + Normalization to apply to all plots + cmap : str, default 'GnBu' + Colormap to use for plots + clabel_fmt : Union[str, Dict[str, Any]], default '%1.3f' + Format for contour labels + clabel_manual : Optional[Dict[str, Any]], default None + Manual contour label positions + contour_levels : Optional[List[float]], default None + Contour levels to use + + Returns + ------- + Dict[Optional[str], Dict[str, Any]] + Styles dictionary for each target + """ + target_styles = {} + for target in targets: + styles = self.get_domain_styles(target) + + # Set pcolormesh styles + styles['pcolormesh'].setdefault('norm', norm) + styles['pcolormesh'].setdefault('cmap', cmap) + + # Set contour styles + styles['contour'].setdefault('norm', norm) + styles['contour'].setdefault('levels', contour_levels) + if 'colors' not in styles['contour']: + styles['contour'].setdefault('cmap', cmap) + + # Set contour label styles + styles['clabel'].setdefault('fmt', clabel_fmt) + styles['clabel'].setdefault('manual', clabel_manual) + + # Set filled contour styles + styles['contourf'].setdefault('norm', norm) + styles['contourf'].setdefault('levels', contour_levels) + if 'colors' not in styles['contourf']: + styles['contourf'].setdefault('cmap', cmap) + + target_styles[target] = styles + return target_styles + + def get_global_norm( + self, + zattrib: str, + targets: List[str], + transform: Optional[Callable[[np.ndarray], np.ndarray]] = None + ) -> Normalize: + """Calculate global normalization across all targets.""" + if not targets: + raise ValueError('No targets specified') + + z = self.data_map[targets[0]][zattrib].values + if transform is not None: + z = transform(z) + vmin = np.min(z) + vmax = np.max(z) - super().__init__(styles=styles, - analysis_label_options=analysis_label_options, - config=config) - - def draw(self, xattrib:str, yattrib:str, zattrib:str, - xlabel:Optional[str]=None, ylabel:Optional[str]=None, - zlabel:Optional[str]=None, title:Optional[str]=None, - ymin:Optional[float]=None, ymax:Optional[float]=None, - xmin:Optional[float]=None, xmax:Optional[float]=None, - zmin:Optional[float]=None, zmax:Optional[float]=None, - logx:bool=False, logy:bool=False, logz:bool=False, - norm:Optional=None, cmap:str='GnBu', - draw_colormesh:bool=True, draw_contour:bool=False, - draw_contourf:bool=False, draw_scatter:bool=False, - draw_clabel:bool=True, draw_colorbar:bool=True, - clabel_fmt:Union[str, Dict]='%1.3f', - clabel_manual:Optional[Dict] = None, - contour_levels:Optional[Union[float, List[float]]]=None, - transform:Optional[Callable]=None, ax=None): - - if (zmin is not None) or (zmax is not None): - if norm is not None: - raise ValueError(f'Cannot specify both (zmin, zmax) and norm.') - if logz: - norm = mcolors.LogNorm(vmin=zmin, vmax=zmax) - else: - norm = mcolors.Normalize(vmin=zmin, vmax=zmax) + for target in targets[1:]: + z = self.data_map[target][zattrib].values + if transform is not None: + z = transform(z) + vmin = min(vmin, np.min(z)) + vmax = max(vmax, np.max(z)) - if ax is None: - ax = self.draw_frame(logx=logx, logy=logy) + return Normalize(vmin=vmin, vmax=vmax) + + def get_interp_data( + self, + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + domain: Optional[str] = None, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get 
interpolated data based on configuration.
+
+        Parameters
+        ----------
+        x, y, z : numpy.ndarray
+            Input data arrays
+        domain : Optional[str], default None
+            Domain name for caching
+
+        Returns
+        -------
+        Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]
+            X grid, Y grid, interpolated Z values
+        """
+        interp_method = self.config.get('interpolation')
+        if interp_method:
+            n = self.config['num_grid_points']
+            return interpolate_2d(x, y, z, method=interp_method, n=n)
+        return pivot_table(x, y, z, missing=np.nan)
+
+    def select_colorbar_target(self, handles: Dict[str, Any]) -> Optional[Any]:
+        """Select appropriate target for colorbar from available handles."""
+        for key in ['pcm', 'contourf', 'contour']:
+            if key in handles:
+                return handles[key]
+        return None
+
+    def draw_colorbar(
+        self,
+        ax: Axes,
+        mappable: Any,
+        zlabel: Optional[str],
+        styles: Dict[str, Any]
+    ) -> Any:
+        """Append a colorbar axis next to `ax`, draw the colorbar and return it."""
+        if self.axis_divider is None:
+            self.axis_divider = make_axes_locatable(ax)
+
+        pad = styles['axis_divider'].get('pad', 0.05)
+        if isinstance(pad, str):
+            pad = convert_size(pad)
+
+        orig_pad = pad
+        for cax_i in self.caxs[-1:]:
+            # Adjust padding based on existing colorbars
+            axis_width = cax_i.get_tightbbox().width
+            cbar_width = cax_i.get_window_extent().width
+            ticks_width = axis_width - cbar_width
+            pad_width = self.caxs[0].get_tightbbox().xmin - ax.get_tightbbox().xmax
+            width_ratio = ticks_width / pad_width
+            pad += orig_pad * width_ratio
+
+        styles['axis_divider']['pad'] = pad
+        cax = self.axis_divider.append_axes(**styles['axis_divider'])
-        if ax is None:
-            ax = self.draw_frame(logx=logx, logy=logy)
+        cbar = plt.colorbar(mappable, cax=cax, **styles['colorbar'])
+        styles['colorbar_label'].setdefault('label', zlabel)
+        if styles['colorbar_label']['label'] is not None:
+            cbar.set_label(**styles['colorbar_label'])
+
+        format_axis_ticks(cax, **styles['colorbar_axis'])
+        self.caxs.append(cax)
-        data = self.data
-        x, y, z = data[xattrib], data[yattrib], data[zattrib]
+        return cbar
+
+    def draw_single_data(
+        self,
+        ax: Axes,
+        x: np.ndarray,
+        y: np.ndarray,
+        z: np.ndarray,
+        zlabel: Optional[str] = None,
+        draw_colormesh: bool = True,
+        draw_contour: bool = False,
+        draw_contourf: bool = False,
+        draw_clabel: bool = True,
+        draw_scatter: bool = False,
+        draw_colorbar: bool = True,
+        styles: Optional[Dict[str, Any]] = None,
+        domain: Optional[str] = None,
+    ) -> None:
+        """Draw single dataset with selected plot types."""
+        styles = mp.concat((self.styles, styles), copy=True)
+        styles = defaultdict(dict, styles)
-        if transform:
-            z = transform(z)
-            if norm is None:
-                norm = mcolors.Normalize(vmin=np.min(z), vmax=np.max(z))
+        # Get interpolated data if needed
         if draw_colormesh or draw_contour or draw_contourf:
-            interp_method = self.config['interpolation']
-            n = self.config['num_grid_points']
-            X, Y, Z = interpolate_2d(x, y, z, method=interp_method, n=n)
-            self.Z = Z
-
-        handles = {}
+            X, Y, Z = self.get_interp_data(x, y, z, domain=domain)
+            self.Z_interp[domain] = Z
+
+        handles: Dict[str, Any] = {}
+
+        # Draw selected plot types
         if draw_colormesh:
-            pcm_styles = combine_dict(self.styles['pcolormesh'])
-            if cmap is not None:
-                pcm_styles['cmap'] = cmap
-            if norm is not None:
-                pcm_styles.pop('vmin', None)
-                pcm_styles.pop('vmax', None)
-            pcm = ax.pcolormesh(X, Y, Z, norm=norm, **pcm_styles)
-            handles['pcm'] = pcm
+            handles['pcm'] = ax.pcolormesh(X, Y, Z, **styles['pcolormesh'])
         if draw_contour:
-            contour_styles = combine_dict(self.styles['contour'])
-            if 'colors' not in contour_styles:
-                contour_styles['cmap'] = cmap
-            contour = ax.contour(X, Y, Z, levels=contour_levels,
-                                 norm=norm,
**contour_styles) + contour = ax.contour(X, Y, Z, **styles['contour']) handles['contour'] = contour - if draw_clabel: - clabel_styles = combine_dict(self.styles['clabel']) - clabel_styles['fmt'] = clabel_fmt - clabel_styles['manual'] = clabel_manual - clabel = ax.clabel(contour, **clabel_styles) - handles['clabel'] = clabel + handles['clabel'] = ax.clabel(contour, **styles['clabel']) if draw_contourf: - contourf_styles = combine_dict(self.styles['contourf']) - if 'colors' not in contourf_styles: - contourf_styles['cmap'] = cmap - contourf = ax.contourf(X, Y, Z, levels=contour_levels, - norm=norm, **contourf_styles) - handles['contourf'] = contourf + handles['contourf'] = ax.contourf(X, Y, Z, **styles['contourf']) if draw_scatter: - scatter = ax.scatter(x, y, **self.styles['scatter']) - handles['scatter'] = scatter - + handles['scatter'] = ax.scatter(x, y, **styles['scatter']) + if draw_colorbar: - if 'pcm' in handles: - mappable = pcm - elif 'contourf' in handles: - mappable = contourf - elif 'contour' in handles: - mappable = contour - else: - mappable = None - if mappable is not None: - cbar = plt.colorbar(mappable, ax=ax, **self.styles['colorbar']) - if zlabel: - cbar.set_label(zlabel, **self.styles['colorbar_label']) - format_axis_ticks(cbar.ax, **self.styles['colorbar_axis']) - handles['cbar'] = cbar - - self.update_legend_handles(handles, raw=True) - - self.draw_axis_components(ax, xlabel=xlabel, ylabel=ylabel, - title=title) - self.set_axis_range(ax, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax) + mappable = self.select_colorbar_target(handles) + if mappable: + handles['cbar'] = self.draw_colorbar(ax, mappable, zlabel, styles) - self.finalize(ax) + self.update_legend_handles(handles, raw=True, domain=domain) + + def draw( + self, + xattrib: str, + yattrib: str, + zattrib: str, + targets: Optional[List[str]] = None, + colorbar_targets: Optional[List[str]] = None, + legend_order: Optional[List[str]] = None, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + zlabel: Optional[str] = None, + title: Optional[str] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, + xmin: Optional[float] = None, + xmax: Optional[float] = None, + zmin: Optional[float] = None, + zmax: Optional[float] = None, + logx: bool = False, + logy: bool = False, + logz: bool = False, + norm: Optional[Normalize] = None, + cmap: str = 'GnBu', + draw_colormesh: bool = True, + draw_contour: bool = False, + draw_contourf: bool = False, + draw_scatter: bool = False, + draw_clabel: bool = True, + draw_colorbar: bool = True, + draw_legend: bool = True, + clabel_fmt: Union[str, Dict[str, Any]] = '%1.3f', + clabel_manual: Optional[Dict[str, Any]] = None, + contour_levels: Optional[List[float]] = None, + transform: Optional[Callable[[np.ndarray], np.ndarray]] = None, + ax: Optional[Axes] = None, + ) -> Axes: + """ + Draw complete 2D plot with all components. - self.draw_legend(ax) + Parameters + ---------- + xattrib : str + Column name for x values + yattrib : str + Column name for y values + zattrib : str + Column name for z values + targets : Optional[List[str]], default None + List of targets to plot. If None, plots all available targets + colorbar_targets : Optional[List[str]], default None + Targets to draw colorbars for. 
If None, uses all targets
+        legend_order : Optional[List[str]], default None
+            Custom order for legend entries
+        xlabel, ylabel, zlabel : Optional[str], default None
+            Axis labels
+        title : Optional[str], default None
+            Plot title
+        ymin, ymax : Optional[float], default None
+            Y-axis limits
+        xmin, xmax : Optional[float], default None
+            X-axis limits
+        zmin, zmax : Optional[float], default None
+            Z-axis (colorbar) limits. Cannot be used with norm
+        logx, logy : bool, default False
+            Use logarithmic scale for axes
+        logz : bool, default False
+            Use logarithmic scale for colorbar
+        norm : Optional[Normalize], default None
+            Custom normalization for colorbar. Cannot be used with zmin/zmax
+        cmap : str, default 'GnBu'
+            Colormap name for plot
+        draw_colormesh : bool, default True
+            Draw pcolormesh plot
+        draw_contour : bool, default False
+            Draw contour lines
+        draw_contourf : bool, default False
+            Draw filled contours
+        draw_scatter : bool, default False
+            Draw scatter plot
+        draw_clabel : bool, default True
+            Add contour labels (if draw_contour is True)
+        draw_colorbar : bool, default True
+            Draw colorbar
+        draw_legend : bool, default True
+            Draw legend
+        clabel_fmt : Union[str, Dict[str, Any]], default '%1.3f'
+            Format for contour labels
+        clabel_manual : Optional[Dict[str, Any]], default None
+            Manual positions for contour labels
+        contour_levels : Optional[List[float]], default None
+            Specific levels for contours
+        transform : Optional[Callable[[np.ndarray], np.ndarray]], default None
+            Transform to apply to z values before plotting
+        ax : Optional[Axes], default None
+            Axes to plot on. If None, creates new axes
+
+        Returns
+        -------
+        matplotlib.axes.Axes
+            The plotted axes
+
+        Raises
+        ------
+        RuntimeError
+            If no targets to plot
+        ValueError
+            If incompatible options specified (e.g., both norm and zmin/zmax)
+
+        Notes
+        -----
+        - If using draw_contour with draw_clabel, clabel_fmt and clabel_manual
+          can be used to customize the labels
+        - The transform function is applied to z values before any normalization
+        """
+        self.reset_metadata()
+        if ax is None:
+            ax = self.draw_frame(logx=logx, logy=logy)
+
+        targets = self.resolve_targets(targets)
+        if not targets:
+            raise RuntimeError('No targets to draw')
+
+        colorbar_targets = colorbar_targets or list(targets)
+
+        if zmin is not None or zmax is not None:
+            if norm is not None:
+                raise ValueError('Cannot specify both (zmin, zmax) and norm')
+            norm = (LogNorm if logz else Normalize)(
+                vmin=zmin, vmax=zmax)
+        elif norm is None:
+            norm = self.get_global_norm(zattrib, targets, transform)
+
+        target_styles = self.resolve_target_styles(
+            targets=targets,
+            norm=norm,
+            cmap=cmap,
+            clabel_fmt=clabel_fmt,
+            clabel_manual=clabel_manual,
+            contour_levels=contour_levels
+        )
+
+        for target in targets:
+            data = self.data_map[target]
+            x = data[xattrib].values
+            y = data[yattrib].values
+            z = data[zattrib].values
+
+            if transform is not None:
+                z = transform(z)
+
+            self.draw_single_data(
+                ax, x, y, z,
+                zlabel=zlabel,
+                draw_colormesh=draw_colormesh,
+                draw_contour=draw_contour,
+                draw_contourf=draw_contourf,
+                draw_scatter=draw_scatter,
+                draw_clabel=draw_clabel,
+                draw_colorbar=draw_colorbar and target in colorbar_targets,
+                styles=target_styles[target],
+                domain=target,
+            )
+
+        # Finalize plot
+        self.draw_axis_components(
+            ax,
+            xlabel=xlabel,
+            ylabel=ylabel,
+            title=title
+        )
+
+        self.set_axis_range(
+            ax,
+            xmin=xmin,
+            xmax=xmax,
+            ymin=ymin,
+            ymax=ymax
+        )
+
+        self.finalize(ax)
+
+        if legend_order is not None:
+            self.legend_order =
legend_order + else: + self.legend_order = self.get_labelled_legend_domains() + + if self.config['draw_legend']: + self.draw_legend(ax) + return ax \ No newline at end of file diff --git a/quickstats/plots/histogram_plot.py b/quickstats/plots/histogram_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..04b5817cf44c7ef58655ab5ca876ab46c3e4d82f --- /dev/null +++ b/quickstats/plots/histogram_plot.py @@ -0,0 +1,592 @@ +from typing import Optional, Union, Dict, Tuple, List, Any, TypeVar + +import numpy as np +from numpy.typing import NDArray +from matplotlib.artist import Artist +from matplotlib.axes import Axes + +from quickstats.core import mappings as mp +from quickstats.maths.numerics import get_subsequences +from quickstats.concepts import Histogram1D, StackedHistogram +from .abstract_plot import AbstractPlot +from .template import get_artist_colors, is_transparent_color, remake_handles +from .core import PlotFormat, ErrorDisplayFormat + +# Type aliases for better readability +NumericArray = NDArray[np.float64] +HistogramType = TypeVar('HistogramType', Histogram1D, StackedHistogram) +ErrorType = Optional[Union[Tuple[NumericArray, NumericArray], NumericArray]] + +def validate_arrays(y1: NumericArray, y2: NumericArray) -> None: + """ + Validate if two arrays are consistent by checking their values match. + + Parameters + ---------- + y1 : NumericArray + First array to compare + y2 : NumericArray + Second array to compare + + Raises + ------ + ValueError + If arrays have different shapes or values don't match + """ + if y1.shape != y2.shape: + raise ValueError(f"Arrays have different shapes: {y1.shape} vs {y2.shape}") + + if not np.allclose(y1, y2): + raise ValueError( + "Histogram bin values do not match the supplied weights. Please check your inputs." + ) + +def apply_mask_to_error( + error: ErrorType, + mask: np.ndarray, +) -> ErrorType: + """ + Apply boolean mask to error data. + + Parameters + ---------- + error : ErrorType + Error data as either tuple of arrays or single array + mask : NumericArray + Boolean mask to apply + + Returns + ------- + ErrorType + Masked error data or None if no error data provided + """ + if error is None: + return None + + try: + if isinstance(error, tuple): + return error[0][mask], error[1][mask] + return error[mask] + except IndexError as e: + raise ValueError(f"Error array shape doesn't match mask shape: {e}") from e + +def has_color_specification(styles: Dict[str, Any]) -> bool: + """ + Check if style dictionary contains color-related options. + + Parameters + ---------- + styles : Dict[str, Any] + Dictionary of style options + + Returns + ------- + bool + True if any color option present + """ + return bool(styles.keys() & {'color', 'facecolor', 'edgecolor', 'colors'}) + +class HistogramPlot(AbstractPlot): + """ + Enhanced histogram plotting class with support for various styles and error displays. 
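+
+    Examples
+    --------
+    A minimal sketch of the intended call pattern (not part of the tested
+    changeset): ``hist`` is assumed to be an existing
+    :class:`~quickstats.concepts.Histogram1D` instance, and the plotter is
+    assumed to be constructible without arguments.
+
+    >>> import matplotlib.pyplot as plt
+    >>> plotter = HistogramPlot()
+    >>> fig, ax = plt.subplots()
+    >>> handles = plotter.draw_histogram_data(
+    ...     ax, hist,
+    ...     plot_format='hist',    # histogram drawn with ax.hist
+    ...     error_format='fill',   # errors drawn as a filled band
+    ... )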
+ """ + + COLOR_CYCLE = "atlas_hdbs" + + STYLES = { + "hist": { + "histtype": "step", + "linestyle": "-", + "linewidth": 2, + }, + "errorbar": { + "marker": "o", + "markersize": 10, + "linestyle": "none", + "linewidth": 0, + "elinewidth": 2, + "capsize": 0, + "capthick": 0, + }, + "fill_between": { + "alpha": 0.5, + "color": "gray", + }, + "bar": { + "linewidth": 0, + "alpha": 0.5, + "color": "gray", + }, + } + + CONFIG = { + "show_xerr": False, + "error_on_top": True, + "inherit_color": True, + "combine_stacked_error": False, + "box_legend_handle": False, + "isolate_error_legend": False, + } + + def draw_hist( + self, + ax: Axes, + histogram: HistogramType, + styles: Optional[Dict[str, Any]] = None, + ) -> List[Artist]: + """ + Draw histogram on given axis. + + Parameters + ---------- + ax : Axes + Matplotlib axes to draw on + histogram : HistogramType + Histogram data to plot + styles : Optional[Dict[str, Any]], optional + Additional style options, by default None + + Returns + ------- + List[Artist] + Handles for drawn histogram elements + + Raises + ------ + TypeError + If histogram type is unsupported + ValueError + If histogram data is invalid + """ + styles = mp.concat((self.styles["hist"], styles), copy=True) + + try: + if isinstance(histogram, Histogram1D): + n, _, patches = ax.hist( + histogram.bin_centers, + weights=histogram.bin_content, + bins=histogram.bin_edges, + **styles, + ) + handles = [patches] + validate_arrays(n, histogram.bin_content) + + elif isinstance(histogram, StackedHistogram): + x = [h.bin_centers for h in histogram.histograms.values()] + y = [h.bin_content for h in histogram.histograms.values()] + n, _, patches = ax.hist( + x, + weights=y, + bins=histogram.bin_edges, + stacked=True, + **styles, + ) + handles = list(patches) + + y_base = 0. + for n_i, y_i in zip(n, y): + validate_arrays(n_i - y_base, y_i) + y_base += y_i + + else: + raise TypeError(f"Unsupported histogram type: {type(histogram)}") + + return handles + + except Exception as e: + raise ValueError(f"Failed to draw histogram: {str(e)}") from e + + def draw_errorbar( + self, + ax: Axes, + histogram: HistogramType, + styles: Optional[Dict[str, Any]] = None, + with_error: bool = True, + ) -> Artist: + """ + Draw error bars for histogram data. + + Parameters + ---------- + ax : Axes + Matplotlib axes to draw on + histogram : HistogramType + Histogram data to plot + styles : Optional[Dict[str, Any]], optional + Additional style options, by default None + with_error : bool, optional + Whether to display error bars, by default True + + Returns + ------- + Artist + Handle for error bar plot + + Raises + ------ + ValueError + If error bar plotting fails + """ + styles = mp.concat((self.styles["errorbar"], styles)) + x = histogram.bin_centers + y = histogram.bin_content + + xerr = None + yerr = None + + if with_error: + xerr = histogram.bin_widths / 2 if self.config['show_xerr'] else None + yerr = histogram.bin_errors + + if histogram.is_masked(): + mask = ~histogram.bin_mask + x = x[mask] + y = y[mask] + xerr = apply_mask_to_error(xerr, mask) + yerr = apply_mask_to_error(yerr, mask) + + try: + return ax.errorbar(x, y, xerr=xerr, yerr=yerr, **styles) + except Exception as e: + raise ValueError(f"Failed to draw error bars: {str(e)}") from e + + def draw_filled_error( + self, + ax: Axes, + histogram: HistogramType, + styles: Optional[Dict[str, Any]] = None, + ) -> Artist: + """ + Draw filled error regions on plot. 
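+
+        When the histogram carries a bin mask, the band is split into
+        contiguous unmasked sections (at least two bins each) and each
+        section is filled separately, with its ends extended to the
+        enclosing bin edges.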
+ + Parameters + ---------- + ax : Axes + Matplotlib axes to draw on + histogram : HistogramType + Histogram data to plot + styles : Optional[Dict[str, Any]], optional + Additional style options, by default None + + Returns + ------- + Artist + Handle for filled error region + + Raises + ------ + RuntimeError + If histogram is fully masked + ValueError + If filled error drawing fails + """ + styles = mp.concat((self.styles['fill_between'], styles), copy=True) + x = histogram.bin_centers + rel_yerr = histogram.rel_bin_errors + + if rel_yerr is None: + rel_yerr = (histogram.bin_content, histogram.bin_content) + styles['color'] = 'none' + styles.pop('facecolor', None) + styles.pop('edgecolor', None) + + try: + handle = None + if histogram.is_masked(): + # handle cases where data is not continuous + indices = np.arange(len(x)) + mask = ~histogram.bin_mask + section_indices = get_subsequences(indices, mask, min_length=2) + + if not len(section_indices): + raise RuntimeError('Histogram is fully masked, nothing to draw') + + for indices in section_indices: + mask = np.full(x.shape, False) + mask[indices] = True + x_i = x[mask] + rel_yerr_i = apply_mask_to_error(rel_yerr, mask) + + # Extend to edge + x_i[0] = histogram.bin_edges[indices[0]] + x_i[-1] = histogram.bin_edges[indices[-1] + 1] + + if (handle is not None) and (not has_color_specification(styles)): + styles['color'] = handle.get_facecolors()[0] + + handle_i = ax.fill_between(x_i, rel_yerr_i[0], rel_yerr_i[1], **styles) + if handle is None: + handle = handle_i + else: + handle = ax.fill_between(x, rel_yerr[0], rel_yerr[1], **styles) + + return handle + + except Exception as e: + raise ValueError(f"Failed to draw filled error: {str(e)}") from e + + def draw_shaded_error( + self, + ax: Axes, + histogram: HistogramType, + styles: Optional[Dict[str, Any]] = None, + ) -> Artist: + """ + Draw shaded error bars as bar plot. + + Parameters + ---------- + ax : Axes + Matplotlib axes to draw on + histogram : HistogramType + Histogram data to plot + styles : Optional[Dict[str, Any]], optional + Additional style options, by default None + + Returns + ------- + Artist + Handle for drawn bars + + Raises + ------ + ValueError + If error bar shading fails + """ + styles = mp.concat((self.styles["bar"], styles), copy=True) + x = histogram.bin_centers + y = histogram.bin_content + yerr = histogram.bin_errors + + if yerr is None: + yerr = (np.zeros_like(y), np.zeros_like(y)) + styles["color"] = "none" + styles.pop("facecolor", None) + styles.pop("edgecolor", None) + + height = yerr[0] + yerr[1] + bottom = y - yerr[0] + widths = histogram.bin_widths + + try: + return ax.bar(x, height=height, bottom=bottom, width=widths, **styles) + except Exception as e: + raise ValueError(f"Failed to draw shaded error: {str(e)}") from e + + def draw_histogram_data( + self, + ax: Axes, + histogram: HistogramType, + plot_format: Union[PlotFormat, str] = 'errorbar', + error_format: Union[ErrorDisplayFormat, str] = 'errorbar', + styles: Optional[Dict[str, Any]] = None, + error_styles: Optional[Dict[str, Any]] = None, + domain: str = 'main' + ) -> Dict[str, Union[Artist, Tuple[Artist, Artist]]]: + """ + Draw histogram data with specified plot and error formats. 
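+
+        Stacked histograms are drawn component by component, with
+        list-valued style options unrolled across the components; when the
+        ``combine_stacked_error`` config option is set (and the plot format
+        is not errorbar), a single combined error representation is drawn
+        for the whole stack instead.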
+ + Parameters + ---------- + ax : Axes + Matplotlib axes to draw on + histogram : HistogramType + Histogram data to plot + plot_format : Union[PlotFormat, str], optional + Format for plotting histogram, by default 'errorbar' + error_format : Union[ErrorDisplayFormat, str], optional + Format for plotting error, by default 'errorbar' + styles : Optional[Dict[str, Any]], optional + Style options for plot, by default None + error_styles : Optional[Dict[str, Any]], optional + Style options for error representation, by default None + domain : str, optional + Domain name for legend labels, by default 'main' + + Returns + ------- + Dict[str, Union[Artist, Tuple[Artist, Artist]]] + Handles for plot and error elements + + Raises + ------ + ValueError + If plotting fails + """ + styles = styles or {} + error_styles = error_styles or {} + plot_format = PlotFormat.parse(plot_format) + error_format = ErrorDisplayFormat.parse(error_format) + + plot_handles: List[Artist] = [] + error_handles: List[Artist] = [] + + try: + if plot_format == PlotFormat.HIST: + handles = self.draw_hist(ax, histogram, styles=styles) + plot_handles.extend(handles) + + def custom_draw( + histogram_current: HistogramType, + styles_current: Dict[str, Any], + error_styles_current: Dict[str, Any] + ) -> None: + if plot_format == PlotFormat.ERRORBAR: + with_error = error_format == ErrorDisplayFormat.ERRORBAR + handle = self.draw_errorbar( + ax, + histogram_current, + styles=styles_current, + with_error=with_error + ) + plot_handles.append(handle) + + # inherit colors from plot handle + # priority: edgecolor > facecolor + if not has_color_specification(error_styles_current): + plot_handle = plot_handles[len(error_handles)] + + if plot_format == PlotFormat.HIST: + # take care of case histtype = 'step' or 'stepfilled' + plot_handle = plot_handle[0] if isinstance(plot_handle, list) else plot_handle + colors = get_artist_colors(plot_handle) + color = colors['edgecolor'] + if is_transparent_color(color): + color = colors['facecolor'] + elif plot_format == PlotFormat.ERRORBAR: + plot_handle = plot_handle[0] + colors = get_artist_colors(plot_handle) + color = colors['markeredgecolor'] + if is_transparent_color(color): + color = colors['markerfacecolor'] + else: + raise ValueError(f'Unsupported plot format: {plot_format}') + + zorder = plot_handle.get_zorder() + if self.config['error_on_top']: + error_styles_current['zorder'] = zorder + 0.1 + if self.config['inherit_color']: + error_styles_current['color'] = color + + if error_styles_current.get('color', None) is None: + error_styles_current.pop('color', None) + + if error_format == ErrorDisplayFormat.FILL: + handle = self.draw_filled_error(ax, histogram_current, styles=error_styles_current) + error_handles.append(handle) + elif error_format == ErrorDisplayFormat.SHADE: + handle = self.draw_shaded_error(ax, histogram_current, styles=error_styles_current) + error_handles.append(handle) + elif (error_format == ErrorDisplayFormat.ERRORBAR) and (plot_format != PlotFormat.ERRORBAR): + error_styles_current = mp.concat((error_styles_current, {'marker': 'none'}), copy=True) + handle = self.draw_errorbar( + ax, + histogram_current, + styles=error_styles_current, + with_error=True + ) + error_handles.append(handle) + + combine_stacked_error = self.config['combine_stacked_error'] + # must draw error for individual histogram when plot with errorbar + if plot_format == PlotFormat.ERRORBAR: + combine_stacked_error = False + + if isinstance(histogram, Histogram1D) or \ + (isinstance(histogram, 
StackedHistogram) and combine_stacked_error): + error_color = error_styles.get('color') + # use artist default color when drawn + if isinstance(error_color, list): + error_color = None + error_label = error_styles.get('label') + if isinstance(error_label, list): + error_label = self.label_map.get(f'{domain}.error', domain) + error_styles = mp.concat( + (error_styles, {'color': error_color, 'label': error_label}), + copy=True + ) + custom_draw(histogram, styles, error_styles) + elif isinstance(histogram, StackedHistogram): + def make_list(option: Any) -> List[Any]: + if option is None: + return [None] * histogram.count + if not isinstance(option, list): + return [option] * histogram.count + if len(option) != histogram.count: + raise ValueError( + f"Option list length ({len(option)}) does not match histogram count ({histogram.count})" + ) + return option + + colors = make_list(styles.get('color', None)) + labels = make_list(styles.get('label', None)) + error_colors = make_list(error_styles.get('color', None)) + error_labels = make_list(error_styles.get('label', None)) + + for i, (_, histogram_i) in enumerate(histogram.offset_histograms): + styles_i = mp.concat( + (styles, {'color': colors[i], 'label': labels[i]}), + copy=True + ) + error_styles_i = mp.concat( + (error_styles, {'color': error_colors[i], 'label': error_labels[i]}), + copy=True + ) + custom_draw(histogram_i, styles_i, error_styles_i) + + handles = {} + # there should be one-to-one correspondence between plot handle and error handle + # except when plotting stacked histograms but showing merged errors + if isinstance(histogram, StackedHistogram) and combine_stacked_error: + if len(plot_handles) != histogram.count or (len(error_handles) != 1 and histogram.has_errors()): + raise ValueError( + f"Mismatch in handle counts. 
Expected {histogram.count} plot handles " + f"and 1 error handle, got {len(plot_handles)} and {len(error_handles)}" + ) + for name, handle in zip(histogram.histograms.keys(), plot_handles): + handles[f'{domain}.{name}'] = handle + if histogram.has_errors(): + handles[f'{domain}.error'] = error_handles[0] + else: + if plot_format == PlotFormat.ERRORBAR and error_format == ErrorDisplayFormat.ERRORBAR: + error_handles = [None] * len(plot_handles) + elif len(plot_handles) != len(error_handles): + raise ValueError( + f"Mismatch in handle counts: {len(plot_handles)} plot handles " + f"vs {len(error_handles)} error handles" + ) + + if isinstance(histogram, StackedHistogram): + keys = [f'{domain}.{name}' for name in histogram.histograms.keys()] + else: + keys = [domain] + + isolate_error_legend = self.config['isolate_error_legend'] + for key, plot_handle, error_handle in zip(keys, plot_handles, error_handles): + # case histogram plot with histtype = 'step' or 'stepfilled' + if isinstance(plot_handle, list): + plot_handle = plot_handle[0] + + if error_handle is None: + handles[key] = plot_handle + else: + if isolate_error_legend: + handles[key] = plot_handle + handles[f'{key}.error'] = error_handle + else: + handles[key] = ( + (plot_handle, error_handle) + if self.config['error_on_top'] + else (error_handle, plot_handle) + ) + + if not self.config['box_legend_handle']: + handles = { + key: remake_handles([handle], polygon_to_line=True, fill_border=False)[0] + for key, handle in handles.items() + } + + return handles + + except Exception as e: + raise ValueError(f"Failed to draw histogram data: {str(e)}") from e \ No newline at end of file diff --git a/quickstats/plots/likelihood_1D_plot.py b/quickstats/plots/likelihood_1D_plot.py index 73763004f46a789789c93058f87821e005b9c66b..99c7ccd190eb60355580f67de4396708f305a909 100644 --- a/quickstats/plots/likelihood_1D_plot.py +++ b/quickstats/plots/likelihood_1D_plot.py @@ -1,196 +1,246 @@ -from typing import Dict, Optional, Union, List +from __future__ import annotations + +from typing import Dict, Optional, Union, List, Tuple, Any + import pandas as pd import numpy as np -from quickstats.plots import General1DPlot -from quickstats.plots.template import create_transform -from quickstats.utils.common_utils import combine_dict +from quickstats.core import mappings as mp +from quickstats.maths.interpolation import get_intervals +from .general_1D_plot import General1DPlot +from .likelihood_mixin import LikelihoodMixin +from .template import create_transform +from .colors import ColormapType + +class Likelihood1DPlot(LikelihoodMixin, General1DPlot): + """ + Class for plotting 1D likelihood scans with confidence levels. 
+ """ -class Likelihood1DPlot(General1DPlot): + DOF: int = 1 + COLOR_CYCLE: str = "atlas_hdbs" - STYLES = { - 'annotation':{ + STYLES: Dict[str, Any] = { + 'plot': { + 'marker': 'none' + }, + 'annotation': { 'fontsize': 20 }, - 'text':{ + 'text': { 'fontsize': 20 - } - } - - CONFIG = { - # intervals to include in the plot - "interval_formats": { - "68_95" : ('0.68', '0.95'), - "one_two_sigma" : ('1sigma', '2sigma') }, - 'sigma_line_styles':{ + 'level_line': { 'color': 'gray', 'linestyle': '--' }, - 'sigma_text_styles':{ + 'level_text': { 'x': 0.98, 'ha': 'right', - 'color': 'gray' - }, - 'sigma_interval_styles':{ + 'color': 'gray' + }, + 'level_interval': { 'loc': (0.2, 0.4), 'main_text': '', - 'sigma_text': r'{sigma_label}: {xlabel}$\in {intervals}$', + 'interval_text': r'{level_label}: {xlabel}$\in {intervals}$', 'dy': 0.05, - 'decimal_place': 2 - }, - 'errorband_legend': True + 'decimal_place': 2 + } } - coverage_proba_data = { - '0.68': { - 'qmu': 0.99, - 'label': '68% CL' - }, - '1sigma': { - 'qmu': 1, # 68.2% - 'label': '1 $\sigma$' - }, - '0.95': { - 'qmu': 3.84, - 'label': '95% CL' - }, - '2sigma': { - 'qmu': 4, - 'label': '2 $\sigma$' + LABEL_MAP = { + 'confidence_level': '{level:.0%} CL', + 'sigma_level': r'{level:.0g} $\sigma$', + } + + CONFIG = { + 'level_key': { + 'confidence_level': '{level_str}_CL', + 'sigma_level': '{level_str}_sigma', } } - - def __init__(self, data_map:Union[pd.DataFrame, Dict[str, pd.DataFrame]], - label_map:Optional[Dict]=None, - styles_map:Optional[Dict]=None, - color_cycle=None, - styles:Optional[Union[Dict, str]]=None, - analysis_label_options:Optional[Dict]=None, - config:Optional[Dict]=None): - super().__init__(data_map=data_map, - label_map=label_map, - styles_map=styles_map, - color_cycle=color_cycle, - styles=styles, - analysis_label_options=analysis_label_options, - config=config) - self.intervals = {} - - def get_sigma_levels_values_labels(self, interval_format:Union[str, List[str]]="one_two_sigma"): - if isinstance(interval_format, (list, tuple)): - sigma_levels = list(interval_format) - else: - if interval_format not in self.config['interval_formats']: - choices = ','.join([f'"{choice}"' for choice in self.config['interval_formats']]) - raise ValueError(f'undefined sigma interval format: {interval_format} (choose from {choices})') - sigma_levels = self.config['interval_formats'][interval_format] - sigma_values = [] - sigma_labels = [] - for sigma_level in sigma_levels: - if sigma_level not in self.coverage_proba_data: - raise RuntimeError(f'undefined sigma level: {sigma_level}') - sigma_values.append(self.coverage_proba_data[sigma_level]['qmu']) - sigma_labels.append(self.coverage_proba_data[sigma_level]['label']) - return sigma_levels, sigma_values, sigma_labels - - def get_sigma_intervals(self, x:np.ndarray, y:np.ndarray, - interval_format:Union[str, List[str]]="one_two_sigma"): - from quickstats.maths.interpolation import get_intervals - sigma_levels, sigma_values, sigma_labels = self.get_sigma_levels_values_labels(interval_format) - sigma_intervals = {} - for i, (sigma_level, sigma_value, sigma_label) in enumerate(zip(sigma_levels, sigma_values, sigma_labels)): - intervals = get_intervals(x, y, sigma_value) - sigma_intervals[sigma_level] = intervals - return sigma_intervals - - def get_bestfit(self, x:np.ndarray, y:np.ndarray): + + def __init__( + self, + data_map: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + color_cycle: Optional[ColormapType] = None, + label_map: Optional[Dict[str, str]] = None, + styles: Optional[Dict[str, Any]] = 
None,
+        styles_map: Optional[Dict[str, Dict[str, Any]]] = None,
+        analysis_label_options: Optional[Dict[str, Any]] = None,
+        config: Optional[Dict[str, Any]] = None
+    ) -> None:
+        self.intervals: Dict[str, List[Tuple[float, float]]] = {}
+        super().__init__(
+            data_map=data_map,
+            color_cycle=color_cycle,
+            label_map=label_map,
+            styles=styles,
+            styles_map=styles_map,
+            analysis_label_options=analysis_label_options,
+            config=config
+        )
+
+    def reset_metadata(self) -> None:
+        """Reset plot metadata."""
+        super().reset_metadata()
+        self.intervals.clear()
+
+    def get_bestfit(self, x: np.ndarray, y: np.ndarray) -> Tuple[float, float]:
+        """Find the best fit point (minimum of likelihood)."""
         bestfit_idx = np.argmin(y)
-        bestfit_x = x[bestfit_idx]
-        bestfit_y = y[bestfit_idx]
-        return bestfit_x, bestfit_y
+        return x[bestfit_idx], y[bestfit_idx]
+
+    def get_intervals(
+        self,
+        x: np.ndarray,
+        y: np.ndarray,
+        level_specs: Optional[Dict[str, Dict[str, Any]]] = None
+    ) -> Optional[Dict[str, List[Tuple[float, float]]]]:
+        """Compute, for each level, the x intervals where y falls below the level's chi2 threshold."""
+        if not level_specs:
+            return None
+        intervals = {}
+        for key, spec in level_specs.items():
+            intervals[key] = get_intervals(x, y, spec['chi2'])
+        return intervals
+
+    def draw_level_lines(
+        self,
+        ax: Any,
+        level_specs: Optional[Dict[str, Dict[str, Any]]] = None
+    ) -> None:
+        """Draw horizontal lines indicating confidence/sigma levels."""
+        if not level_specs:
+            return
+
+        line_styles = mp.concat((self.styles['line'], self.styles['level_line']), copy=True)
+        text_styles = mp.concat((self.styles['text'], self.styles['level_text']), copy=True)
-    def draw_sigma_lines(self, ax, interval_format:Union[str, List[str]]="one_two_sigma"):
-        sigma_line_styles = self.config['sigma_line_styles']
-        sigma_levels, sigma_values, sigma_labels = self.get_sigma_levels_values_labels(interval_format)
+        x_pos = text_styles.get('x', 0.98)
+        text_styles.setdefault('va', 'bottom' if (0 < x_pos < 1) else 'center')
         transform = create_transform(transform_x="axis", transform_y="data")
-        ax.hlines(sigma_values, xmin=0, xmax=1, zorder=0, transform=transform,
-                  **self.config['sigma_line_styles'])
-        styles = combine_dict(self.styles['text'], self.config['sigma_text_styles'])
-        if 'va' not in styles:
-            styles['va'] = 'bottom' if ((styles['x'] > 0) and (styles['x'] < 1)) else 'center'
-        for sigma_value, sigma_label in zip(sigma_values, sigma_labels):
-            ax.text(y=sigma_value, s=sigma_label, **styles, transform=transform)
+
+        for spec in level_specs.values():
+            ax.hlines(spec['chi2'], xmin=0, xmax=1, zorder=0,
+                      transform=transform, **line_styles)
+            ax.text(y=spec['chi2'], s=spec['label'],
+                    transform=transform, **text_styles)
-    def draw_sigma_intervals(self, ax, x:np.ndarray, y:np.ndarray, xlabel:str="",
-                             interval_format:Union[str, List[str]]="one_two_sigma"):
-        from quickstats.maths.interpolation import get_intervals
-        sigma_levels, sigma_values, sigma_labels = self.get_sigma_levels_values_labels(interval_format)
-        sigma_intervals = self.get_sigma_intervals(x, y, interval_format=interval_format)
-        self.intervals = sigma_intervals
-        styles = self.config['sigma_interval_styles']
-        loc = styles['loc']
-        dp = styles['decimal_place']
-        dy = styles['dy']
-        sigma_text = styles['sigma_text']
+    def draw_level_intervals(
+        self,
+        ax: Any,
+        x: np.ndarray,
+        y: np.ndarray,
+        xlabel: str = "",
+        level_specs: Optional[Dict[str, Dict[str, Any]]] = None,
+        domain: Optional[str] = None
+    ) -> None:
+        """Draw confidence/sigma interval annotations."""
+        if not level_specs:
+            return
+
+        self.intervals = self.get_intervals(x, y, level_specs)
         # do not draw when no intervals available
-        if
all(len(intervals) == 0 for intervals in sigma_intervals.values()): - return None - ax.annotate(styles['main_text'], loc, xycoords='axes fraction', **self.styles['annotation']) - for i, (sigma_level, sigma_label) in enumerate(zip(sigma_levels, sigma_labels)): - sigma_interval = sigma_intervals[sigma_level] - if len(sigma_interval) == 0: + if all(not len(intervals) for intervals in self.intervals.values()): + return + + styles = self.styles['level_interval'] + if domain is not None: + domain_styles = self.styles_map.get(domain, {}).get('level_interval') + styles = mp.concat((styles, domain_styles)) + + ax.annotate(styles['main_text'], styles['loc'], + xycoords='axes fraction', **self.styles['annotation']) + + for i, (key, spec) in enumerate(level_specs.items()): + intervals = self.intervals.get(key) + if not len(intervals): continue - interval_str = r" \cup ".join([f"[{lo:.{dp}f}, {hi:.{dp}f}]" for (lo, hi) in sigma_interval]) - interval_str = interval_str.replace('-inf', r'N.A.').replace('inf', 'N.A.') - text = sigma_text.format(sigma_label=sigma_label, - xlabel=xlabel, - intervals=interval_str) - ax.annotate(text, (loc[0], loc[1] - (i + 1) * dy), - xycoords='axes fraction', **self.styles['annotation']) - - def draw(self, xattrib:str='mu', yattrib:str='qmu', xlabel:Optional[str]=None, - ylabel:Optional[str]="$-2\Delta ln(L)$", targets:Optional[List[str]]=None, - ymin:float=0, ymax:float=7, xmin:Optional[float]=None, xmax:Optional[float]=None, - draw_sigma_line:bool=True, - #draw_sm_line:bool=False, - draw_sigma_intervals:Union[str, bool]=False, - interval_format:Union[str, List[str]]="one_two_sigma"): - # ylabel = "$-2\Delta ln(L)$" - ax = super().draw(xattrib=xattrib, yattrib=yattrib, - xlabel=xlabel, ylabel=ylabel, targets=targets, - ymin=ymin, ymax=ymax, xmin=xmin, xmax=xmax) - - if draw_sigma_line: - self.draw_sigma_lines(ax, interval_format=interval_format) + + interval_str = self._format_intervals(intervals, styles['decimal_place']) + text = styles['interval_text'].format( + level_label=spec['label'], + xlabel=xlabel, + intervals=interval_str + ) + y_pos = styles['loc'][1] - (i + 1) * styles['dy'] + ax.annotate(text, (styles['loc'][0], y_pos), + xycoords='axes fraction', **self.styles['annotation']) + + @staticmethod + def _format_intervals(intervals: List[Tuple[float, float]], dp: int) -> str: + """Format intervals for display.""" + parts = [f"[{lo:.{dp}f}, {hi:.{dp}f}]" for lo, hi in intervals] + return r" \cup ".join(parts).replace('-inf', 'N.A.').replace('inf', 'N.A.') + + def draw( + self, + xattrib: str, + yattrib: str = 'qmu', + xlabel: Optional[str] = None, + ylabel: Optional[str] = r"$-2\Delta ln(L)$", + targets: Optional[List[str]] = None, + ymin: float = 0, + ymax: float = 7, + xmin: Optional[float] = None, + xmax: Optional[float] = None, + draw_level_lines: bool = True, + draw_level_intervals: bool = False, + sigma_levels: Optional[Tuple[float, ...]] = (1, 2), + confidence_levels: Optional[Tuple[float, ...]] = None + ) -> Any: """ - if draw_sm_line: - transform = create_transform(transform_y="axis", transform_x="data") - sm_line_styles = self.config['sm_line_styles'] - sm_values = self.config['sm_values'] - sm_names = self.config['sm_names'] - ax.vlines(sm_values, ymin=0, ymax=1, zorder=0, transform=transform, - **sm_line_styles) - if sm_names: - sm_pos = self.config['sm_pos'] - for sm_value, sm_name in zip(sm_values, sm_names): - ax.text(sm_value, sm_pos, sm_name, color='gray', ha='right', rotation=90, - va='bottom' if (sm_pos > 0 and sm_pos < 1) else 'center', 
fontsize=20, transform=transform) + Draw likelihood profile plot. + + Parameters + ---------- + xattrib : str + Column name for x values + yattrib : str, default 'qmu' + Column name for likelihood values + xlabel, ylabel : Optional[str] + Axis labels + targets : Optional[List[str]] + Targets to plot + ymin, ymax : float + Y-axis limits + xmin, xmax : Optional[float] + X-axis limits + draw_level_lines : bool, default True + Draw horizontal lines for confidence levels + draw_level_intervals : bool, default False + Draw confidence interval annotations + sigma_levels : Optional[Tuple[float, ...]], default (1, 2) + Sigma levels to indicate + confidence_levels : Optional[Tuple[float, ...]] + Confidence levels to indicate """ - if draw_sigma_intervals: - if isinstance(self.data_map, pd.DataFrame): - x = self.data_map[xattrib].values - y = self.data_map[yattrib].values - elif isinstance(self.data_map, dict): - if not isinstance(draw_sigma_intervals, str): - raise RuntimeError("name of the target likelihood curve must be specified " - "when drawing sigma intervals") - target = draw_sigma_intervals - x = self.data_map[target][xattrib].values - y = self.data_map[target][yattrib].values - else: - raise ValueError("invalid data format") - self.draw_sigma_intervals(ax, x, y, xlabel=xlabel, interval_format=interval_format) - return ax + targets = self.resolve_targets(targets) + ax = super().draw( + xattrib=xattrib, + yattrib=yattrib, + xlabel=xlabel, + ylabel=ylabel, + targets=targets, + ymin=ymin, + ymax=ymax, + xmin=xmin, + xmax=xmax + ) + + level_specs = self.get_level_specs(sigma_levels, confidence_levels) + + if draw_level_lines: + self.draw_level_lines(ax, level_specs) + + if draw_level_intervals: + for target in targets: + data = self.data_map[target] + self.draw_level_intervals( + ax, + data[xattrib].values, + data[yattrib].values, + xlabel=xlabel, + level_specs=level_specs, + domain=target + ) + + return ax \ No newline at end of file diff --git a/quickstats/plots/likelihood_2D_plot.py b/quickstats/plots/likelihood_2D_plot.py index a14bc31362e2c11344112f4bc0960ee01c93a81e..122165ec8a46e1c78e6cd8c2dfec2c345a5c816d 100644 --- a/quickstats/plots/likelihood_2D_plot.py +++ b/quickstats/plots/likelihood_2D_plot.py @@ -1,51 +1,38 @@ -from typing import Dict, Optional, Union, List, Sequence +from __future__ import annotations -from functools import partial -from itertools import repeat +from typing import Dict, Optional, Union, List, Tuple, Any import numpy as np import pandas as pd -import matplotlib.pyplot as plt - -from quickstats.plots import AbstractPlot -from quickstats.plots.template import create_transform, format_axis_ticks, isolate_contour_styles -from quickstats.utils.common_utils import combine_dict -from quickstats.utils.string_utils import remove_neg_zero -from quickstats.maths.interpolation import get_regular_meshgrid +from matplotlib.axes import Axes from matplotlib.lines import Line2D from matplotlib.patches import Polygon, Rectangle +from quickstats.core import mappings as mp +from quickstats.maths.numerics import get_nan_shapes +from quickstats.utils.string_utils import remove_neg_zero +from .general_2D_plot import General2DPlot +from .likelihood_mixin import LikelihoodMixin +from .colors import ColorType, ColormapType -class Likelihood2DPlot(AbstractPlot): - - STYLES = { - 'pcolormesh': { - 'cmap': 'GnBu', - 'shading': 'auto', - 'rasterized': True - }, - 'colorbar': { - 'pad': 0.02, - }, - 'contour': { - 'linestyles': 'solid', - 'linewidths': 3 - }, - 'contourf': { - 'alpha': 0.5, - 
'zorder': 0 - }, +class Likelihood2DPlot(LikelihoodMixin, General2DPlot): + """ + Class for plotting 2D likelihood scans with confidence levels. + """ + + DOF: int = 2 + + COLOR_CYCLE: str = 'default' + + STYLES: Dict[str, Any] = { 'polygon': { 'fill': True, 'hatch': '/', 'alpha': 0.5, 'color': 'gray' }, - 'alphashape': { - 'alpha': 2 - }, 'bestfit': { - 'marker': 'P', + 'marker': '*', 'linewidth': 0, 'markersize': 15 }, @@ -55,394 +42,529 @@ class Likelihood2DPlot(AbstractPlot): 'markersize': 20, 'color': '#E9F1DF', 'markeredgecolor': 'black' + }, + 'contourf': { + 'extend': 'min' } } - LABEL_MAP = { - 'contour': '{sigma_label}', + COLOR_MAP: Dict[str, str] = { + '1_sigma': 'hh:darkblue', + '2_sigma': '#F2385A', + '3_sigma': '#FDC536', + '0p68_CL': 'hh:darkblue', + '0p95_CL': '#F2385A', + '0p99_CL': '#FDC536', + 'contour.1_sigma': '#000000', + 'contour.2_sigma': '#F2385A', + 'contourf.1_sigma': '#4AD9D9', + 'contourf.2_sigma': '#FDC536', + } + + LABEL_MAP: Dict[str, str] = { + 'confidence_level': '{level:.0%} CL', + 'sigma_level': r'{level:.0g} $\sigma$', 'bestfit': 'Best fit ({x:.2f}, {y:.2f})', - 'polygon': 'Nan NLL region' + 'polygon': 'Nan NLL region', } - CONFIG = { - # intervals to include in the plot - 'interval_formats': { - "68_95" : ('0.68', '0.95'), - "one_two_sigma" : ('1sigma', '2sigma'), - "68_95_99" : ('0.68', '0.95', '0.99'), - "one_two_three_sigma" : ('1sigma', '2sigma', '3sigma') - }, + CONFIG: Dict[str, Any] = { 'interpolation': 'cubic', 'num_grid_points': 500, - 'sm_values': None, - 'sm_line_styles': {} - } - - # qmu from https://pdg.lbl.gov/2018/reviews/rpp2018-rev-statistics.pdf#page=31 - COVERAGE_PROBA_DATA = { - '0.68': { - 'qmu': 2.30, - 'label': '68% CL', - 'color': "hh:darkblue" - }, - '1sigma': { - 'qmu': 2.30, # 68.2% - 'label': r'1 $\sigma$', - 'color': "hh:darkblue" - }, - '0.90': { - 'qmu': 4.61, - 'label': '90% CL', - 'color': "#36b1bf" - }, - '0.95': { - 'qmu': 5.99, - 'label': '95% CL', - 'color': "#F2385A" - }, - '2sigma': { - 'qmu': 6.18, # 95.45% - 'label': r'2 $\sigma$', - 'color': "#F2385A" + 'level_key': { + 'confidence_level': '{level_str}_CL', + 'sigma_level': '{level_str}_sigma', }, - '0.99': { - 'qmu': 9.21, - 'label': '99% CL', - 'color': "#FDC536" - }, - '3sigma': { - 'qmu': 11.83, # 99.73% - 'label': r'3 $\sigma$', - 'color': "#FDC536" - } + 'remove_nan_points_within_distance': None, + 'shade_nan_points': False, + 'alphashape_alpha': 2 } - def __init__(self, data_map: Union[pd.DataFrame, Dict[str, pd.DataFrame]], - label_map: Optional[Dict] = None, - color_cycle: Optional[Union[List, str, "ListedColorMap"]]=None, - styles: Optional[Union[Dict, str]] = None, - styles_map: Optional[Dict] = None, - analysis_label_options: Optional[Dict] = None, - config: Optional[Dict] = None): + def __init__( + self, + data_map: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + color_map: Optional[Dict[str, ColorType]] = None, + color_cycle: Optional[ColormapType] = None, + label_map: Optional[Dict[str, str]] = None, + styles: Optional[Dict[str, Any]] = None, + styles_map: Optional[Dict[str, Union[Dict[str, Any], str]]] = None, + analysis_label_options: Optional[Union[Dict[str, Any], str]] = None, + config: Optional[Dict[str, Any]] = None, + ) -> None: + self.nan_shapes = {} + super().__init__( + data_map=data_map, + color_map=color_map, + color_cycle=color_cycle, + label_map=label_map, + styles=styles, + styles_map=styles_map, + analysis_label_options=analysis_label_options, + config=config, + ) - self.data_map = data_map - self.styles_map = 
combine_dict(styles_map) - self.coverage_proba_data = combine_dict(self.COVERAGE_PROBA_DATA) - self.highlight_data = [] - super().__init__(label_map=label_map, - styles=styles, - color_cycle=color_cycle, - analysis_label_options=analysis_label_options, - config=config) - - def get_sigma_levels(self, interval_format:str="one_two_three_sigma"): - if interval_format not in self.config['interval_formats']: - choices = ','.join([f'"{choice}"' for choice in self.config['interval_formats']]) - raise ValueError(f'undefined sigma interval format: {interval_format} (choose from {choices})') - sigma_levels = self.config['interval_formats'][interval_format] - return sigma_levels + def reset_metadata(self) -> None: + """Reset plot metadata including NaN shapes.""" + super().reset_metadata() + self.nan_shapes = {} + + def _get_color_for_level( + self, + target: str, + artist: str, + key: str, + used_colors: List[str], + default_colors: List[str], + color_index: int + ) -> Tuple[str, int]: + """Get color for contour level.""" + for domain in [ + self.color_map.format(target, artist, key), + self.color_map.format(target, key), + self.color_map.format(artist, key), + self.color_map.format(key) + ]: + if domain in self.color_map: + color = self.color_map[domain] + if color not in used_colors: + return color, color_index + + if color_index >= len(default_colors): + self.stdout.warning( + 'Number of colors required exceeds available colors. Recycling colors.' + ) + color = default_colors[color_index % len(default_colors)] + return color, color_index + 1 - def get_nan_shapes(self, data: pd.DataFrame, - xattrib: str, yattrib: str, - zattrib: str = 'qmu'): - df_nan = data[np.isnan(data[zattrib])] - xy = df_nan[[xattrib, yattrib]].values - import alphashape - shape = alphashape.alphashape(xy, alpha=self.config['alphashape_alpha']) - if hasattr(shape, 'geoms'): - shapes = [s for s in shape.geoms] + def resolve_target_styles(self, **kwargs) -> Dict[str, Dict[str, Any]]: + """Resolve styles for targets with level-specific colors.""" + level_specs = kwargs.pop('level_specs', {}) + if level_specs: + contour_levels = [spec['chi2'] for spec in level_specs.values()] else: - shapes = [shape] - return shapes - - def draw_shades(self, ax, shapes): - if len(shapes) == 0: - return None + contour_levels = None + + target_styles = super().resolve_target_styles(contour_levels=contour_levels, **kwargs) + + default_colors = self.get_colors() + color_index = 0 + + for target, styles in target_styles.items(): + for artist in ['contour', 'contourf']: + if artist not in styles or 'colors' in styles[artist]: + continue + + colors = [] + for key in level_specs: + color, color_index = self._get_color_for_level( + target, artist, key, colors, + default_colors, color_index + ) + colors.append(color) + + styles[artist]['colors'] = colors + styles[artist].pop('cmap') + + return target_styles + + def get_bestfit( + self, + x: np.ndarray, + y: np.ndarray, + z: np.ndarray + ) -> Tuple[float, float, float]: + """Find the point of minimum likelihood.""" + mask = (~np.isnan(z)) & (z >= 0.) 
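+        # keep only well-defined scan points: drop NaN and negative -2*Delta ln(L) values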
+ x, y, z = x[mask], y[mask], z[mask] + bestfit_idx = np.argmin(z) + return x[bestfit_idx], y[bestfit_idx], z[bestfit_idx] + + def _remove_points_near_shapes( + self, + X: np.ndarray, + Y: np.ndarray, + Z: np.ndarray, + shapes: List[Any], + distance: float + ) -> np.ndarray: + """Remove points within distance of NaN shapes.""" + from shapely import Point + + XY = np.column_stack((X.ravel(), Y.ravel())) + mask = np.full(XY.shape[0], False) + for shape in shapes: - x, y = shape.exterior.coords.xy - xy = np.column_stack((np.array(x).ravel(), np.array(y).ravel())) - polygon = Polygon(xy, **self.config['polygon_styles'], - label=self.config['polygon_label']) - ax.add_patch(polygon) - if 'shade' not in self.legend_data: - self.update_legend_handles({'shade': polygon}) - self.legend_order.append('shade') - - def draw_single_data(self, ax, data: pd.DataFrame, - xattrib: str, yattrib: str, - zattrib: str = 'qmu', - config: Optional[Dict] = None, - styles: Optional[Dict] = None, - draw_contour: bool = True, - draw_contourf: bool =False, - draw_colormesh: bool = False, - draw_clabel: bool = False, - draw_colorbar: bool =True, - clabel_size=None, - interval_format:str="one_two_three_sigma", - remove_nan_points_within_distance:Optional[float]=None, - shade_nan_points:bool=False, - domain: Optional[str] = None): - - handles = {} - sigma_handles = {} - if config is None: - config = self.config - if styles is None: - styles = self.styles + x_ext, y_ext = shape.exterior.coords.xy + min_x, max_x = np.min(x_ext) - distance, np.max(x_ext) + distance + min_y, max_y = np.min(y_ext) - distance, np.max(y_ext) + distance + # only focus on points within the largest box formed by the convex hull + distance + box_mask = (((XY[:, 0] > min_x) & (XY[:, 0] < max_x)) & + ((XY[:, 1] > min_y) & (XY[:, 1] < max_y))) + points_in_box = XY[box_mask] + # remove points inside the polygon + inside_mask = np.array([shape.contains(Point(xy)) for xy in points_in_box]) + points_outside = points_in_box[~inside_mask] + # remove points within distance d of the polygon + near_mask = np.array([ + shape.exterior.distance(Point(xy)) < distance + for xy in points_outside + ]) - sigma_levels = self.get_sigma_levels(interval_format=interval_format) - sigma_values = [self.coverage_proba_data[level]['qmu'] for level in sigma_levels] - sigma_labels = [self.coverage_proba_data[level]['label'] for level in sigma_levels] - sigma_colors = [self.coverage_proba_data[level]['color'] for level in sigma_levels] - - contour_labels = [] - for sigma_label in sigma_labels: - contour_label_fmt = self.get_label('contour', domain=domain) - if not contour_label_fmt: - contour_label_fmt = self.get_label('contour') - contour_label = contour_label_fmt.format(sigma_label=sigma_label) - contour_labels.append(contour_label) - - interpolate_method = self.config.get('interpolation', None) - if interpolate_method: - from scipy import interpolate - x, y, z = data[xattrib], data[yattrib], data[zattrib] - # remove nan data - mask = ~np.isnan(z) - x, y, z = x[mask], y[mask], z[mask] + box_indices = np.arange(mask.shape[0])[box_mask] + mask[box_indices[inside_mask]] = True + outside_indices = box_indices[~inside_mask] + mask[outside_indices[near_mask]] = True - n = self.config.get('num_grid_points', 500) - X, Y = get_regular_meshgrid(x, y, n=n) - Z = interpolate.griddata(np.stack((x, y), axis=1), z, (X, Y), interpolate_method) - else: - X_unique = np.sort(data[xattrib].unique()) - Y_unique = np.sort(data[yattrib].unique()) - X, Y = np.meshgrid(X_unique, Y_unique) - Z = 
(data.pivot_table(index=xattrib, columns=yattrib, values=zattrib).T.values - - data[zattrib].min()) - - # deal with regions with undefined likelihood - if (remove_nan_points_within_distance is not None) or (shade_nan_points): - nan_shapes = self.get_nan_shapes(data, xattrib, yattrib, zattrib) + Z_new = Z.copy() + Z_new[mask.reshape(Z.shape)] = np.nan + return Z_new + + def get_interp_data( + self, + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + domain: Optional[str] = None, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Get interpolated data handling NaN regions.""" + X, Y, Z = super().get_interp_data(x, y, z) + + distance = self.config.get('remove_nan_points_within_distance') + shade_nan_points = self.config.get('shade_nan_points', False) + if distance or shade_nan_points: + alpha = self.config.get('alphashape_alpha', 2) + nan_shapes = get_nan_shapes(x, y, z, alpha=alpha) + self.nan_shapes[domain] = nan_shapes else: nan_shapes = None - if (remove_nan_points_within_distance is not None) and (len(nan_shapes) > 0): - if len(nan_shapes) > 0: - from shapely import Point - XY = np.column_stack((X.ravel(), Y.ravel())) - d = remove_nan_points_within_distance - for shape in nan_shapes: - x_ext, y_ext = shape.exterior.coords.xy - min_x_cutoff, max_x_cutoff = np.min(x_ext) - d, np.max(x_ext) + d - min_y_cutoff, max_y_cutoff = np.min(y_ext) - d, np.max(y_ext) + d - # only focus on points within the largest box formed by the convex hull + distance - box_mask = (((XY[:, 0] > min_x_cutoff) & (XY[:, 0] < max_x_cutoff)) & - ((XY[:, 1] > min_y_cutoff) & (XY[:, 1] < max_y_cutoff))) - mask = np.full(box_mask.shape, False) - XY_box = XY[box_mask] - # remove points inside the polygon - mask_int = np.array([shape.contains(Point(xy)) for xy in XY_box]) - XY_box_ext = XY_box[~mask_int] - # remove points within distance d of the polygon - mask_int_d = np.array([shape.exterior.distance(Point(xy)) < d for xy in XY_box_ext]) - slice_int = np.arange(mask.shape[0])[box_mask][mask_int] - slice_int_d = np.arange(mask.shape[0])[box_mask][~mask_int][mask_int_d] - mask[slice_int] = True - mask[slice_int_d] = True - Z[mask.reshape(Z.shape)] = np.nan - - if draw_colormesh: - pcm = ax.pcolormesh(X, Y, Z, **styles['pcolormesh']) - handles['pcm'] = pcm - - if sigma_values: - if draw_contour: - contour_styles = combine_dict(styles['contour']) - if 'colors' not in contour_styles: - contour_styles['colors'] = sigma_colors - contour = ax.contour(X, Y, Z, levels=sigma_values, **contour_styles) - handles['contour'] = contour + if distance and nan_shapes: + Z = self._remove_points_near_shapes(X, Y, Z, nan_shapes, distance) + + return X, Y, Z + + def select_colorbar_target(self, handles: Dict[str, Any]) -> Optional[Any]: + """Select target for colorbar from available handles.""" + return handles.get('pcm') + + def _get_contour_styles( + self, + domain: Optional[str], + artist: str + ) -> Dict[str, Any]: + """Get styles for specified contour artist.""" + handle = self.get_handle(self.legend_data.format(domain, artist)) + if handle: + return { + 'facecolors': handle.get_facecolor(), + 'edgecolors': handle.get_edgecolor(), + 'linestyles': getattr(handle, 'linestyles', '-'), + 'linewidths': handle.get_linewidth(), + 'hatches': getattr(handle, 'hatches', None) + } + return {} - if draw_clabel: - clabel = ax.clabel(contour, **styles['clabel']) - handles['clabel'] = clabel - - # handle for individual contour level - sigma_contour_styles = isolate_contour_styles(contour_styles) - for i, (styles_, label_, color_) in 
enumerate(zip(sigma_contour_styles, contour_labels, sigma_colors)): - kwargs = combine_dict(styles_) - kwargs['label'] = label_ - if 'color' not in kwargs: - kwargs['color'] = color_ - handle = Line2D([0], [0], **kwargs) - key = f'contour_level_{i}' - sigma_handles[key] = handle - if key not in self.legend_order: - self.legend_order.append(key) - - if draw_contourf: - contourf_styles = combine_dict(styles['contourf']) - if 'colors' not in contourf_styles: - contourf_styles['colors'] = sigma_colors - sigma_values_ = [-np.inf] + sigma_values - contourf = ax.contourf(X, Y, Z, levels=sigma_values_, **contourf_styles) - handles['contourf'] = contourf - - # handle for individual contourf level - sigma_contourf_styles = isolate_contour_styles(contourf_styles) - for styles_, label_, color_ in zip(sigma_contourf_styles, contour_labels, sigma_colors): - kwargs = combine_dict(styles_) - kwargs['label'] = label_ - if 'color' not in kwargs: - kwargs['color'] = color_ - kwargs['facecolor'] = kwargs.pop('color') - handle = Rectangle((0, 0), 1, 1, **kwargs) - key = f'contourf_{label_}' - sigma_handles[key] = handle - if key not in self.legend_order: - self.legend_order.append(key) - - if draw_colorbar: - if 'pcm' in handles: - mappable = pcm - elif 'contourf' in handles: - mappable = contourf - elif 'contour' in handles: - mappable = contour - else: - mappable = None - if mappable is not None: - cbar = plt.colorbar(mappable, ax=ax, **styles['colorbar']) - cbar.set_label(zlabel, **styles['colorbar_label']) - format_axis_ticks(cbar.ax, **styles['colorbar_axis']) - handles['cbar'] = cbar + def create_custom_handles( + self, + level_specs: Optional[Dict[str, Dict[str, Any]]] = None, + domain: Optional[str] = None + ) -> None: + """Create custom legend handles for contours.""" + if not level_specs: + return + + handles = {} + contour_styles = self._get_contour_styles(domain, 'contour') + contourf_styles = self._get_contour_styles(domain, 'contourf') - if shade_nan_points and (len(nan_shapes) > 0): - self.draw_shades(ax, nan_shapes) + def get_style(style_dict: Dict[str, Any], key: str, idx: int) -> Any: + value = style_dict.get(key) + if isinstance(value, (list, tuple, np.ndarray)) and len(value) > idx: + return value[idx] + return value + + for i, (key, spec) in enumerate(level_specs.items()): + if contour_styles and contourf_styles: + handle = Rectangle((0, 0), 1, 1, + label=spec['label'], + edgecolor=get_style(contour_styles, 'edgecolors', i), + facecolor=get_style(contourf_styles, 'facecolors', i), + linestyle=get_style(contour_styles, 'linestyles', i), + linewidth=get_style(contour_styles, 'linewidths', i), + hatch=get_style(contourf_styles, 'hatches', i) + ) + elif contour_styles: + handle = Line2D([0], [0], + label=spec['label'], + color=get_style(contour_styles, 'edgecolors', i), + linestyle=get_style(contour_styles, 'linestyles', i) + ) + elif contourf_styles: + handle = Rectangle((0, 0), 1, 1, + label=spec['label'], + facecolor=get_style(contourf_styles, 'facecolors', i), + hatch=get_style(contourf_styles, 'hatches', i) + ) + else: + continue + handles[key] = handle - self.update_legend_handles(handles, raw=True, domain=domain) - self.update_legend_handles(sigma_handles, domain=domain) - - def is_single_data(self): - return not isinstance(self.data_map, dict) - - def add_highlight(self, x: float, y: float, label: str = "SM prediction", - styles: Optional[Dict] = None): - highlight_data = { - 'x': x, - 'y': y, - 'label': label, - 'styles': styles - } - self.highlight_data.append(highlight_data) - - def 
resolve_targets(self, targets: Optional[List[str]] = None) -> List[Optional[str]]:
-        if targets is None:
-            targets = [None] if self.is_single_data() else list(self.data_map)
-        return targets
-
-    def resolve_target_styles(self, targets: List[Optional[str]]):
-        target_styles = {}
-        for target in targets:
-            styles = self.styles_map.get(target, {})
-            target_styles[target] = combine_dict(self.styles, styles)
-        return target_styles
+        self.update_legend_handles(handles, domain=domain)
 
-    def draw_highlight(self, ax, x, y, label:str,
-                       styles:Optional[Dict]=None,
-                       domain:Optional[str]=None):
-        if styles is None:
-            styles = self.styles['highlight']
-        handle = ax.plot(x, y, label=label, **styles)
-        key = f'highlight_{label}'
-        self.update_legend_handles({key: handle[0]}, domain=domain)
-        if key not in self.legend_order:
-            self.legend_order.append(key)
-
-    def draw(self, xattrib:str, yattrib:str, zattrib:str='qmu',
-             targets:Optional[List[str]]=None,
-             xlabel: Optional[str] = "", ylabel: Optional[str] = "",
-             zlabel: Optional[str] = r"$-2\Delta ln(L)$",
-             title: Optional[str] = None,
-             ymax:Optional[float]=None, ymin:Optional[float]=None,
-             xmin:Optional[float]=None, xmax:Optional[float]=None,
-             draw_contour: bool = True,
-             draw_contourf: bool = False,
-             draw_colormesh: bool = False,
-             draw_clabel: bool = False,
-             draw_colorbar: bool = False,
-             draw_bestfit:Union[List[str], bool]=True,
-             draw_sm_line: bool = False,
-             draw_legend: bool = True,
-             legend_order: Optional[List[str]] = None,
-             interval_format:str="one_two_sigma",
-             remove_nan_points_within_distance:Optional[float]=None,
-             shade_nan_points:bool=False):
+    def draw_bestfit(
+        self,
+        ax: Axes,
+        x: np.ndarray,
+        y: np.ndarray,
+        z: np.ndarray,
+        styles: Optional[Dict[str, Any]] = None,
+        domain: Optional[str] = None
+    ) -> None:
+        """Draw best fit point (minimum likelihood)."""
+        styles = mp.concat((self.styles.get('bestfit'), styles), copy=True)
+        bestfit_x, bestfit_y, bestfit_z = self.get_bestfit(x, y, z)
+        bestfit_label_fmt = self.get_domain_label('bestfit', domain=domain, fallback=True)
+        if bestfit_label_fmt:
+            bestfit_label = bestfit_label_fmt.format(x=bestfit_x, y=bestfit_y)
+            styles['label'] = remove_neg_zero(bestfit_label)
+
+        handle = ax.plot(bestfit_x, bestfit_y, **styles)
+        self.update_legend_handles({'bestfit': handle}, domain=domain)
+
+    def draw_nan_shapes(
+        self,
+        ax: Axes,
+        shapes: List[Any],
+        styles: Optional[Dict[str, Any]] = None,
+        domain: Optional[str] = None
+    ) -> None:
+        """Draw shapes around NaN regions."""
+        if not shapes:
+            return
+
+        styles = mp.concat((self.styles.get('polygon'), styles), copy=True)
+        label = self.get_domain_label('nan_shape', domain=domain, fallback=True)
+        if label:
+            styles['label'] = label
+
+        handle = None
+        for shape in shapes:
+            x, y = shape.exterior.coords.xy
+            xy = np.column_stack((np.array(x).ravel(), np.array(y).ravel()))
+            polygon = Polygon(xy, **styles)
+            ax.add_patch(polygon)
+            if handle is None:
+                handle = polygon
+
+        if handle is not None:
+            self.update_legend_handles({'nan_shape': handle}, domain=domain)
+
+    def draw_single_data(
+        self,
+        ax: Axes,
+        x: np.ndarray,
+        y: np.ndarray,
+        z: np.ndarray,
+        zlabel: Optional[str] = None,
+        draw_colormesh: bool = True,
+        draw_contour: bool = False,
+        draw_contourf: bool = False,
+        draw_clabel: bool = False,
+        draw_scatter: bool = False,
+        draw_colorbar: bool = True,
+        draw_bestfit: bool = True,
+        styles: Optional[Dict[str, Any]] = None,
+        level_specs: Optional[Dict[str, Dict[str, Any]]] = None,
+        domain: Optional[str] = None,
+    ) -> None:
+        """Draw 
single dataset with all components.""" + styles = styles or {} + super().draw_single_data( + ax=ax, + x=x, + y=y, + z=z, + zlabel=zlabel, + draw_colormesh=draw_colormesh, + draw_contour=draw_contour, + draw_contourf=draw_contourf, + draw_clabel=draw_clabel, + draw_scatter=draw_scatter, + draw_colorbar=draw_colorbar, + styles=styles, + domain=domain + ) + + self.create_custom_handles(level_specs=level_specs, domain=domain) + + if draw_bestfit: + self.draw_bestfit(ax, x, y, z, styles=styles.get('bestfit'), domain=domain) + + if self.config['shade_nan_points']: + self.draw_nan_shapes( + ax, + self.nan_shapes.get(domain), + styles=styles.get('polygon'), + domain=domain + ) + + def draw( + self, + xattrib: str, + yattrib: str, + zattrib: str = 'qmu', + targets: Optional[List[str]] = None, + colorbar_targets: Optional[List[str]] = None, + legend_order: Optional[List[str]] = None, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + zlabel: Optional[str] = r"$-2\Delta ln(L)$", + title: Optional[str] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, + xmin: Optional[float] = None, + xmax: Optional[float] = None, + zmax: Optional[float] = None, + logx: bool = False, + logy: bool = False, + norm: Optional[Any] = None, + cmap: str = 'GnBu', + draw_colormesh: bool = False, + draw_contour: bool = True, + draw_contourf: bool = False, + draw_scatter: bool = False, + draw_clabel: bool = False, + draw_colorbar: bool = True, + draw_bestfit: bool = True, + sigma_levels: Optional[Tuple[float, ...]] = (1, 2), + confidence_levels: Optional[Tuple[float, ...]] = None + ) -> Axes: + """ + Draw 2D likelihood plot. + + Parameters + ---------- + xattrib : str + Column name for x values + yattrib : str + Column name for y values + zattrib : str, default 'qmu' + Column name for likelihood values + targets : Optional[List[str]], default None + Target datasets to plot + colorbar_targets : Optional[List[str]], default None + Targets to draw colorbars for + legend_order : Optional[List[str]], default None + Custom order for legend entries + xlabel, ylabel, zlabel : Optional[str], default None + Axis labels + title : Optional[str], default None + Plot title + ymin, ymax : Optional[float], default None + Y-axis limits + xmin, xmax : Optional[float], default None + X-axis limits + zmax : Optional[float], default None + Maximum z value for normalization + logx, logy : bool, default False + Use logarithmic scale for axes + norm : Optional[Any], default None + Custom normalization for colormap + cmap : str, default 'GnBu' + Colormap name + draw_colormesh : bool, default False + Draw pcolormesh plot + draw_contour : bool, default True + Draw contour lines + draw_contourf : bool, default False + Draw filled contours + draw_scatter : bool, default False + Draw scatter plot + draw_clabel : bool, default False + Add contour labels + draw_colorbar : bool, default True + Draw colorbar + draw_bestfit : bool, default True + Draw best fit point + sigma_levels : Optional[Tuple[float, ...]], default (1, 2) + Sigma levels to show + confidence_levels : Optional[Tuple[float, ...]], default None + Confidence levels to show + + Returns + ------- + matplotlib.axes.Axes + The plotted axes + + Raises + ------ + RuntimeError + If no targets to draw + ValueError + If incompatible normalization options specified + """ + self.reset_metadata() + ax = self.draw_frame(logx=logx, logy=logy) + targets = self.resolve_targets(targets) - target_styles = self.resolve_target_styles(targets=targets) - - self.reset_legend_data() 
- if legend_order is not None: - self.legend_order = legend_order - ax = self.draw_frame() + if not targets: + raise RuntimeError('No targets to draw') + + colorbar_targets = colorbar_targets or list(targets) - for target, styles in target_styles.items(): - if (target is None): - data = self.data_map - elif target in self.data_map: - data = self.data_map[target] - else: - raise RuntimeError(f'No input data found for the target "{target}".') - - self.draw_single_data(ax, data, xattrib=xattrib, yattrib=yattrib, - zattrib=zattrib, styles=styles, - draw_colormesh=draw_colormesh, - draw_contour=draw_contour, - draw_contourf=draw_contourf, - draw_clabel=draw_clabel, - draw_colorbar=draw_colorbar, - interval_format=interval_format, - remove_nan_points_within_distance=remove_nan_points_within_distance, - shade_nan_points=shade_nan_points, - domain=target) + if zmax is not None: + if norm is not None: + raise ValueError('Cannot specify both zmax and norm') + norm = Normalize(vmin=0, vmax=zmax) - if ((draw_bestfit is True) or - (isinstance(draw_bestfit, (list, tuple)) and target in draw_bestfit)): - valid_data = data.query(f'{zattrib} >= 0') - bestfit_idx = np.argmin(valid_data[zattrib].values) - bestfit_x = valid_data.iloc[bestfit_idx][xattrib] - bestfit_y = valid_data.iloc[bestfit_idx][yattrib] - bestfit_label_fmt = self.get_label('bestfit', domain=target) - if not bestfit_label_fmt: - bestfit_label_fmt = self.get_label('bestfit') - bestfit_label = bestfit_label_fmt.format(x=bestfit_x, y=bestfit_y) - bestfit_label = remove_neg_zero(bestfit_label) - self.draw_highlight(ax, bestfit_x, bestfit_y, - label=bestfit_label, - styles=styles['bestfit'], - domain=target) + if norm is None: + norm = self.get_global_norm(zattrib, targets) + + level_specs = self.get_level_specs( + sigma_levels=sigma_levels, + confidence_levels=confidence_levels + ) - if self.highlight_data: - for options in self.highlight_data: - self.draw_highlight(ax, **options) - - if draw_sm_line and self.config['sm_values'] is not None: - sm_x, sm_y = self.config['sm_values'] - transform = create_transform(transform_x="data", transform_y="axis") - ax.vlines(sm_x, ymin=0, ymax=1, zorder=0, transform=transform, - **self.config['sm_line_styles']) - transform = create_transform(transform_x="axis", transform_y="data") - ax.hlines(sm_y, xmin=0, xmax=1, zorder=0, transform=transform, - **self.config['sm_line_styles']) + target_styles = self.resolve_target_styles( + targets=targets, + norm=norm, + cmap=cmap, + level_specs=level_specs + ) + for target in targets: + data = self.data_map[target] + x = data[xattrib].values + y = data[yattrib].values + z = data[zattrib].values + + target_draw_colorbar = draw_colorbar and target in colorbar_targets + + self.draw_single_data( + ax, x, y, z, + zlabel=zlabel, + draw_colormesh=draw_colormesh, + draw_contour=draw_contour, + draw_contourf=draw_contourf, + draw_scatter=draw_scatter, + draw_clabel=draw_clabel, + draw_colorbar=target_draw_colorbar, + draw_bestfit=draw_bestfit, + styles=target_styles[target], + level_specs=level_specs, + domain=target + ) + self.draw_axis_components(ax, xlabel=xlabel, ylabel=ylabel, title=title) self.set_axis_range(ax, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax) - - if draw_legend: - legend_domains = self.get_legend_domains() - self.draw_legend(ax, domains=legend_domains) - - return ax + self.finalize(ax) + + if legend_order is not None: + self.legend_order = legend_order + else: + self.legend_order = self.get_labelled_legend_domains() + + if self.config['draw_legend']: + 
self.draw_legend(ax) + + return ax \ No newline at end of file diff --git a/quickstats/plots/likelihood_mixin.py b/quickstats/plots/likelihood_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..6aa8cf27b6c7fa091cc3b782fcf764003762e4a7 --- /dev/null +++ b/quickstats/plots/likelihood_mixin.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +from typing import Optional, Dict, Union, Tuple + +from quickstats.maths.numerics import str_encode_value +from quickstats.maths.statistics import sigma_to_chi2, confidence_level_to_chi2 +from .abstract_plot import AbstractPlot + +class LikelihoodMixin(AbstractPlot): + """ + Mixin class for likelihood plots providing statistical confidence level handling. + + This mixin adds functionality for working with confidence levels and sigma + levels in likelihood plots, including chi-square conversion and labeling. + """ + + DOF: int = 1 + + LABEL_MAP: Dict[str, str] = { + 'confidence_level': '{level:.0%} CL', + 'sigma_level': r'{level:.0g} $\sigma$', + } + + CONFIG: Dict[str, Dict[str, str]] = { + 'level_key': { + 'confidence_level': '{level_str}_CL', + 'sigma_level': '{level_str}_sigma', + } + } + + def get_chi2_value(self, level: Union[int, float], use_sigma: bool = False) -> float: + """Convert confidence/sigma level to chi-square value.""" + if use_sigma: + return sigma_to_chi2(level, k=self.DOF) + return confidence_level_to_chi2(level, self.DOF) + + def get_level_label(self, level: Union[int, float], use_sigma: bool = False) -> str: + """Get formatted label for confidence/sigma level.""" + key = 'sigma_level' if use_sigma else 'confidence_level' + return self.label_map.get(key, '').format(level=level) + + def get_level_key(self, level: Union[int, float], use_sigma: bool = False) -> str: + """Get dictionary key for confidence/sigma level.""" + key = 'sigma_level' if use_sigma else 'confidence_level' + level_str = str_encode_value(level) + return self.config.get('level_key', {}).get(key, '').format(level_str=level_str) + + def get_level_specs( + self, + sigma_levels: Optional[Tuple[Union[int, float], ...]] = None, + confidence_levels: Optional[Tuple[float, ...]] = None + ) -> Dict[str, Dict[str, Union[float, str]]]: + """ + Get specifications for all confidence/sigma levels. 
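+
+        Each level is converted to a chi-square threshold for ``DOF`` degrees
+        of freedom; for ``DOF = 2`` this gives roughly 2.30 for 1 sigma and
+        6.18 for 2 sigma.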
+ + Parameters + ---------- + sigma_levels : Optional[Tuple[Union[int, float], ...]] + Sigma levels to include + confidence_levels : Optional[Tuple[float, ...]] + Confidence levels to include + + Returns + ------- + Dict[str, Dict[str, Union[float, str]]] + Dictionary of level specifications, sorted by chi-square value + """ + specs = {} + for levels, use_sigma in [(sigma_levels, True), (confidence_levels, False)]: + if not levels: + continue + for level in levels: + chi2 = self.get_chi2_value(level, use_sigma) + label = self.get_level_label(level, use_sigma) + key = self.get_level_key(level, use_sigma) + specs[key] = {'chi2': chi2, 'label': label} + + # make sure the levels are ordered in increasing chi2 + return dict(sorted(specs.items(), key=lambda x: x[1]['chi2'])) \ No newline at end of file diff --git a/quickstats/plots/multi_data_plot.py b/quickstats/plots/multi_data_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4e8583f776b887a5939d34b5f602c0308133f3 --- /dev/null +++ b/quickstats/plots/multi_data_plot.py @@ -0,0 +1,150 @@ +from __future__ import annotations + +from typing import Dict, List, Optional, Union, Any + +import pandas as pd + +from .abstract_plot import AbstractPlot +from .colors import ColorType, ColormapType + +DataFrameMap = Dict[Optional[str], pd.DataFrame] + +class MultiDataPlot(AbstractPlot): + """ + Plot class supporting multiple datasets with customizable styles. + + This class extends AbstractPlot to handle multiple pandas DataFrames, + allowing separate styling and labeling for each dataset. + """ + + def __init__( + self, + data_map: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + color_map: Optional[Dict[str, ColorType]] = None, + color_cycle: Optional[ColormapType] = None, + label_map: Optional[Dict[str, str]] = None, + styles: Optional[Dict[str, Any]] = None, + styles_map: Optional[Dict[str, Union[Dict[str, Any], str]]] = None, + analysis_label_options: Optional[Union[str, Dict[str, Any]]] = None, + config: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Initialize MultiDataPlot. + + Parameters + ---------- + data_map : Union[pd.DataFrame, Dict[str, pd.DataFrame]] + Input data. Either a single DataFrame or a dictionary mapping + target names to DataFrames + color_map : Optional[Dict[str, ColorType]], default None + Mapping of targets to colors + color_cycle : Optional[ColormapType], default None + Color cycle for automatic color assignment + label_map : Optional[Dict[str, str]], default None + Mapping of targets to display labels + styles : Optional[Dict[str, Any]], default None + Global plot styles + styles_map : Optional[Dict[str, Union[Dict[str, Any], str]]], default None + Target-specific style overrides + analysis_label_options : Optional[Union[str, Dict[str, Any]]], default None + Options for analysis labels + config : Optional[Dict[str, Any]], default None + Additional configuration parameters + + Raises + ------ + ValueError + If data_map is not a DataFrame or valid dictionary + """ + if not isinstance(data_map, (pd.DataFrame, dict)): + raise ValueError( + "data_map must be a pandas DataFrame or dictionary of DataFrames" + ) + + self.load_data(data_map) + + super().__init__( + color_map=color_map, + color_cycle=color_cycle, + label_map=label_map, + styles=styles, + styles_map=styles_map, + analysis_label_options=analysis_label_options, + config=config, + ) + + def load_data(self, data_map: Union[pd.DataFrame, Dict[str, pd.DataFrame]]) -> None: + """ + Load data into plot instance. 
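+
+        A minimal sketch (``plot``, ``df``, ``df_a`` and ``df_b`` are
+        illustrative names for an instance and input DataFrames):
+
+        >>> plot.load_data(df)                      # doctest: +SKIP
+        >>> plot.load_data({'a': df_a, 'b': df_b})  # doctest: +SKIP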
+ + Parameters + ---------- + data_map : Union[pd.DataFrame, Dict[str, pd.DataFrame]] + Input data. Single DataFrame will be stored with key None + + Raises + ------ + ValueError + If any value in data_map dictionary is not a DataFrame + """ + if isinstance(data_map, pd.DataFrame): + self.data_map: DataFrameMap = {None: data_map} + else: + for key, value in data_map.items(): + if not isinstance(value, pd.DataFrame): + raise ValueError( + f"Value for key '{key}' is not a pandas DataFrame" + ) + self.data_map = data_map + + def is_single_data(self) -> bool: + """ + Check if plot contains only a single dataset. + + Returns + ------- + bool + True if only one dataset is present with key None + """ + return len(self.data_map) == 1 and None in self.data_map + + def resolve_targets( + self, + targets: Optional[List[str]] = None + ) -> List[Optional[str]]: + """ + Resolve target names for plotting. + + Parameters + ---------- + targets : Optional[List[str]], default None + Requested target names. If None, uses all available targets + + Returns + ------- + List[Optional[str]] + Resolved list of target names + + Raises + ------ + ValueError + If targets specified for single dataset or invalid targets provided + """ + if self.is_single_data(): + if targets and set(targets) != {None}: + raise ValueError( + "Cannot specify targets when only one dataset is present" + ) + return [None] + + if targets is None: + return list(self.data_map.keys()) + + # Validate all requested targets exist + invalid_targets = set(targets) - set(self.data_map.keys()) + if invalid_targets: + raise ValueError( + f"Invalid targets specified: {sorted(invalid_targets)}" + ) + + return targets \ No newline at end of file diff --git a/quickstats/plots/np_ranking_plot.py b/quickstats/plots/np_ranking_plot.py index d8e113490db6aa2c2374cb15bc91e9153fdb902f..fbc79b184a5be28f0ca51c94dee32c721ffde47a 100644 --- a/quickstats/plots/np_ranking_plot.py +++ b/quickstats/plots/np_ranking_plot.py @@ -9,14 +9,16 @@ import click import numpy as np import pandas as pd import matplotlib.pyplot as plt +from matplotlib.axes import Axes import matplotlib.transforms as transforms from matplotlib.patches import Rectangle +from matplotlib.lines import Line2D from matplotlib.backends.backend_pdf import PdfPages from quickstats import AbstractObject, semistaticmethod -from quickstats.plots.template import draw_analysis_label, format_axis_ticks, parse_transform, draw_hatches, \ - draw_sigma_bands, draw_sigma_lines, get_artist_dimension, draw_text, centralize_axis from quickstats.maths.numerics import ceildiv from quickstats.utils.common_utils import json_load +from quickstats.plots.template import draw_analysis_label, format_axis_ticks, parse_transform, draw_hatches, \ + get_artist_dimension, draw_text, centralize_axis BASE_STYLE = { 'pull': { @@ -147,6 +149,63 @@ TREX_STYLE = { }, **BASE_STYLE } +def draw_sigma_bands(axis: Axes, ymax: float, height: float = 1.0) -> None: + """ + Draw sigma bands on the axis. + + Parameters + ---------- + axis : matplotlib.axes.Axes + The axis to draw on. + ymax : float + Maximum y-value for the bands. + height : float, default 1.0 + Height of the bands. 
+ + Returns + ------- + None + """ + # +- 2 sigma band + axis.add_patch( + Rectangle( + (-2, -height / 2), 4, ymax + height / 2, fill=True, color="yellow" + ) + ) + # +- 1 sigma band + axis.add_patch( + Rectangle( + (-1, -height / 2), 2, ymax + height / 2, fill=True, color="lime" + ) + ) + + +def draw_sigma_lines( + axis: Axes, ymax: float, height: float = 1.0, **styles +) -> None: + """ + Draw sigma lines on the axis. + + Parameters + ---------- + axis : matplotlib.axes.Axes + The axis to draw on. + ymax : float + Maximum y-value for the lines. + height : float, default 1.0 + Height of the lines. + **styles + Additional style arguments for the lines. + + Returns + ------- + None + """ + y_values = [-height / 2, ymax * height - height / 2] + axis.add_line(Line2D([-1, -1], y_values, **styles)) + axis.add_line(Line2D([1, 1], y_values, **styles)) + axis.add_line(Line2D([0, 0], y_values, **styles)) + class NPRankingPlot(AbstractObject): def __init__(self, input_dir:str=None, poi:Optional[str]=None, version:int=1, verbosity:Optional[Union[int, str]]=None): diff --git a/quickstats/plots/pdf_distribution_plot.py b/quickstats/plots/pdf_distribution_plot.py index 60037d66e36e58aa98d2e6b4ed6de7e37dfcd5ad..b8fa7cc263d7f8fb765e8b1b348f24623fd50373 100644 --- a/quickstats/plots/pdf_distribution_plot.py +++ b/quickstats/plots/pdf_distribution_plot.py @@ -59,6 +59,7 @@ class PdfDistributionPlot(AbstractPlot): figure_index:Optional[int]=None, config:Optional[Dict]=None): super().__init__(color_cycle=color_cycle, + label_map=label_map, styles=styles, figure_index=figure_index, analysis_label_options=analysis_label_options, config=config) @@ -68,10 +69,6 @@ class PdfDistributionPlot(AbstractPlot): self.plot_options = {} else: self.plot_options = plot_options - if label_map is None: - self.label_map = {} - else: - self.label_map = label_map self.colors = {} self.annotation = None diff --git a/quickstats/plots/registry.py b/quickstats/plots/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..efe2331420cbb1d4bb948404ba344fea2d566679 --- /dev/null +++ b/quickstats/plots/registry.py @@ -0,0 +1,35 @@ +from typing import Optional, Dict, Union +import copy + +from quickstats import NamedTreeNode, NestedDict + +class Registry(NamedTreeNode): + + def _validate_data(self, data: Optional[Dict] = None) -> NestedDict: + if data is None: + return NestedDict() + return NestedDict(copy.deepcopy(data)) + + @property + def data(self) -> NestedDict: + return self._data.copy(deep=True) + + def use(self, name: str) -> None: + data = self.get(name, strict=True) + self._data = data + + def parse(self, source: Optional[Union[str, Dict]] = None) -> NestedDict: + if source is None: + return {} + if isinstance(source, str): + try: + return self.get(source, strict=True) + except KeyError: + raise KeyError(f'template does not exist: {source}') + return NestedDict(copy.deepcopy(source)) + + def chain(self, *sources) -> NestedDict: + result = self.data + for source in sources: + result &= self.parse(source) + return result \ No newline at end of file diff --git a/quickstats/plots/score_distribution_plot.py b/quickstats/plots/score_distribution_plot.py index 7dcf3bb0124f13fbcbe2de3c9df2337fb1b64c34..d982f2f0e4c7e2934bbe33fd770507e5b4be76aa 100644 --- a/quickstats/plots/score_distribution_plot.py +++ b/quickstats/plots/score_distribution_plot.py @@ -11,7 +11,7 @@ from matplotlib.patches import Polygon from quickstats.maths.statistics import poisson_interval from quickstats.utils.common_utils import combine_dict 
from quickstats.plots import AbstractPlot -from quickstats.plots.template import single_frame, parse_styles, create_transform, format_axis_ticks +from quickstats.plots.template import single_frame, create_transform, format_axis_ticks class ScoreDistributionPlot(AbstractPlot): diff --git a/quickstats/plots/stat_plot_config.py b/quickstats/plots/stat_plot_config.py index abd9779671b26c197a3854c12e9eb840e973369d..c5ed6cfff644b885856e21a1b0c80dba3ac4c4bf 100644 --- a/quickstats/plots/stat_plot_config.py +++ b/quickstats/plots/stat_plot_config.py @@ -1,120 +1,184 @@ -from typing import List, Union, Dict, Optional, Callable +from __future__ import annotations + +from typing import Dict, List, Optional, Callable, Any, TypeVar, Union import numpy as np -import matplotlib.pyplot as plt +from numpy.typing import NDArray from quickstats import GeneralEnum +T = TypeVar('T', bound=np.number) + class StatMeasure(GeneralEnum): - MEAN = (0, np.mean) - STD = (1, np.std) - MIN = (2, np.min) - MAX = (3, np.max) + """Statistical measure enumeration with associated numpy operations.""" + + MEAN = (0, np.mean) + STD = (1, np.std) + MIN = (2, np.min) + MAX = (3, np.max) MEDIAN = (4, np.median) - def __new__(cls, value:int, operator:Callable): + def __new__(cls, value: int, operator: Callable[[NDArray[T]], T]) -> StatMeasure: obj = object.__new__(cls) obj._value_ = value obj.operator = operator - return obj + return obj class StatPlotConfig: + """Configuration for statistical plot elements.""" @property - def stat_measures(self): + def stat_measures(self) -> List[StatMeasure]: + """Get statistical measures.""" return self._stat_measures @stat_measures.setter - def stat_measures(self, values): - parsed = [] - for value in values: - parsed.append(StatMeasure.parse(value)) - self._stat_measures = parsed - - def __init__(self, stat_measures:List[Union[StatMeasure, str]], - axis_method:str, options:Dict, - handle_options:Optional[Dict]=None, - handle_return_method:Optional[Callable]=None): - self.stat_measures = stat_measures - self.axis_method = axis_method - self.options = options - if handle_options is None: - self.handle_options = self.get_default_handle_options() - else: - self.handle_options = handle_options - self.quantities = {} + def stat_measures(self, values: List[Union[StatMeasure, str]]) -> None: + """Set statistical measures with validation.""" + self._stat_measures = [StatMeasure.parse(value) for value in values] + + def __init__( + self, + stat_measures: List[Union[StatMeasure, str]], + axis_method: str, + options: Dict[str, Any], + handle_options: Optional[Dict[str, Callable]] = None, + handle_return_method: Optional[Callable] = None + ) -> None: + """ + Initialize statistical plot configuration. 
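+
+        A minimal sketch, assuming an existing axes ``ax`` (names here are
+        illustrative only):
+
+        >>> config = StatPlotConfig(
+        ...     ["mean"], "axvline", {"x": lambda q: q["mean"]}
+        ... )
+        >>> config.set_data(np.array([1.0, 2.0, 3.0]))
+        >>> config.apply(ax)  # doctest: +SKIP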
+ + Parameters + ---------- + stat_measures : List[Union[StatMeasure, str]] + List of statistical measures to compute + axis_method : str + Matplotlib axes method to call + options : Dict[str, Any] + Options for the plot method + handle_options : Optional[Dict[str, Callable]], default None + Options derived from handle properties + handle_return_method : Optional[Callable], default None + Method to process the returned handle + """ + self.stat_measures = stat_measures + self.axis_method = axis_method + self.options = options + self.handle_options = handle_options or self.get_default_handle_options() + self.quantities: Dict[str, float] = {} self.handle_return_method = handle_return_method - - def set_data(self, x:np.ndarray): - quantities = {} - for stat_measure in self.stat_measures: - quantity = stat_measure.operator(x) - name = stat_measure.name.lower() - quantities[name] = quantity - self.quantities = quantities - - def get_default_handle_options(self): + + def set_data(self, data: NDArray[T]) -> None: + """ + Compute statistical measures from data. + + Parameters + ---------- + data : numpy.ndarray + Input data array + """ + self.quantities = { + measure.name.lower(): measure.operator(data) + for measure in self.stat_measures + } + + def get_default_handle_options(self) -> Dict[str, Callable]: + """Get default handle options.""" return {} - - def apply(self, ax, main_handle:Optional=None): + + def apply(self, ax: Any, main_handle: Optional[Any] = None) -> Any: + """ + Apply statistical plot to axes. + + Parameters + ---------- + ax : matplotlib.axes.Axes + Target axes + main_handle : Optional[Any], default None + Main plot handle for style matching + + Returns + ------- + Any + Plot handle or processed result + + Raises + ------ + RuntimeError + If method not found or data not initialized + """ if not hasattr(ax, self.axis_method): - raise RuntimeError(f"matplotlib.axes.Axes has no method named {self.axis_method}") + raise RuntimeError(f"Axes has no method: {self.axis_method}") if not self.quantities: - raise RuntimeError("stat data not initialized") + raise RuntimeError("Statistical data not initialized") + method = getattr(ax, self.axis_method) - resolved_options = {} - for name in self.options: - if callable(self.options[name]): - resolved_options[name] = self.options[name](self.quantities) - else: - resolved_options[name] = self.options[name] - if (main_handle is not None) and (self.handle_options is not None): - for name in self.handle_options: - if name in resolved_options: - continue - resolved_options[name] = self.handle_options[name](main_handle) + resolved_options = { + name: option(self.quantities) if callable(option) else option + for name, option in self.options.items() + } + + if main_handle is not None and self.handle_options: + for name, option in self.handle_options.items(): + if name not in resolved_options: + resolved_options[name] = option(main_handle) + result = method(**resolved_options) if self.handle_return_method is not None: return self.handle_return_method(result) return result class HandleMatchConfig(StatPlotConfig): - def get_default_handle_options(self): - handle_options = { - "color": lambda handle: handle.get_color() - } - return handle_options - + """Base class for configurations with handle property matching.""" + + def get_default_handle_options(self) -> Dict[str, Callable]: + """Get default color-matching options.""" + return {"color": lambda handle: handle.get_color()} + class AverageLineH(HandleMatchConfig): - def __init__(self, **styles): - 
options = { - "y": lambda x: x["mean"], - **styles - } - super().__init__(["mean"], "axhline", options=options) - + """Horizontal average line configuration.""" + + def __init__(self, **styles: Any) -> None: + super().__init__( + ["mean"], + "axhline", + options={"y": lambda x: x["mean"], **styles} + ) + class AverageLineV(HandleMatchConfig): - def __init__(self, **styles): - options = { - "x": lambda x: x["mean"], - **styles - } - super().__init__(["mean"], "axvline", options=options) - + """Vertical average line configuration.""" + + def __init__(self, **styles: Any) -> None: + super().__init__( + ["mean"], + "axvline", + options={"x": lambda x: x["mean"], **styles} + ) + class StdBandH(HandleMatchConfig): - def __init__(self, **styles): - options = { - "ymin": lambda x: x["mean"] - x["std"], - "ymax": lambda x: x["mean"] + x["std"], - **styles - } - super().__init__(["mean", "std"], "axhspan", options=options) - + """Horizontal standard deviation band configuration.""" + + def __init__(self, **styles: Any) -> None: + super().__init__( + ["mean", "std"], + "axhspan", + options={ + "ymin": lambda x: x["mean"] - x["std"], + "ymax": lambda x: x["mean"] + x["std"], + **styles + } + ) + class StdBandV(HandleMatchConfig): - def __init__(self, **styles): - options = { - "xmin": lambda x: x["mean"] - x["std"], - "xmax": lambda x: x["mean"] + x["std"], - **styles - } - super().__init__(["mean", "std"], "axvspan", options=options) \ No newline at end of file + """Vertical standard deviation band configuration.""" + + def __init__(self, **styles: Any) -> None: + super().__init__( + ["mean", "std"], + "axvspan", + options={ + "xmin": lambda x: x["mean"] - x["std"], + "xmax": lambda x: x["mean"] + x["std"], + **styles + } + ) \ No newline at end of file diff --git a/quickstats/plots/template.py b/quickstats/plots/template.py index a368efaed5566d5c46089b6a6a07a8568e9a76ba..9f1f902f4d88f5a235243b045781f5e709e0f272 100644 --- a/quickstats/plots/template.py +++ b/quickstats/plots/template.py @@ -1,1073 +1,1507 @@ -from typing import Optional, Union, Dict, List, Tuple, Sequence, Any +from __future__ import annotations + +from typing import ( + Optional, Union, Dict, List, Tuple, Any +) +from dataclasses import dataclass import re +from enum import Enum from itertools import repeat from contextlib import contextmanager +import warnings import numpy as np import matplotlib.pyplot as plt import matplotlib.transforms as transforms -from matplotlib.axes import Axes -from matplotlib.patches import Rectangle, Polygon -from matplotlib.collections import (PolyCollection, LineCollection, PathCollection) +import matplotlib.colors as mcolors +from matplotlib.axes import Axes +from matplotlib.axis import Axis +from matplotlib.artist import Artist +from matplotlib.patches import Patch, Rectangle, Polygon from matplotlib.lines import Line2D -from matplotlib.container import Container -from matplotlib.ticker import (MultipleLocator, FormatStrFormatter, - AutoMinorLocator, ScalarFormatter, - Locator, Formatter, AutoLocator, - LogFormatter, LogFormatterSciNotation, - MaxNLocator) -from matplotlib.legend_handler import (HandlerTuple, - HandlerLine2D, - HandlerLineCollection, - HandlerPathCollection) -from quickstats.utils.common_utils import combine_dict +from matplotlib.container import Container, ErrorbarContainer +from matplotlib.image import AxesImage +from matplotlib.text import Text +from matplotlib.collections import ( + Collection, + PolyCollection, + LineCollection, + PathCollection, +) +from matplotlib.ticker 
import ( + Locator, + MaxNLocator, + AutoLocator, + AutoMinorLocator, + ScalarFormatter, + Formatter, + LogFormatterSciNotation, +) +from matplotlib.legend_handler import ( + HandlerLineCollection, + HandlerPathCollection, +) + from quickstats import DescriptiveEnum +from quickstats.core import mappings as mp +from .colors import ColorType, ColormapType +from . import template_styles + +class TransformType(str, Enum): + """Valid transform types for matplotlib coordinates.""" + FIGURE = "figure" + AXIS = "axis" + DATA = "data" + +# Custom exceptions for better error handling +class PlottingError(Exception): + """Base exception for plotting-related errors.""" + pass + +class TransformError(PlottingError): + """Exception raised for transform-related errors.""" + pass + +class StyleError(PlottingError): + """Exception raised for style-related errors.""" + pass class ResultStatus(DescriptiveEnum): + """ + Enumeration for different result statuses with descriptions and display texts. + + Attributes + ---------- + value : int + The enumeration value + description : str + Detailed description of the status + display_text : str + Short text for display purposes + """ - FINAL = (0, "Finalised results", "") - INT = (1, "Internal results", "Internal") - WIP = (2, "Work in progress results", "Work in Progress") - PRELIM = (3, "Preliminary results", "Preliminary") - OPENDATA = (4, "Open data results", "Open Data") - SIM = (5, "Simulation results", "Simulation") - SIMINT = (6, "Simulation internal results", "Simulation Internal") + FINAL = (0, "Finalised results", "") + INT = (1, "Internal results", "Internal") + WIP = (2, "Work in progress results", "Work in Progress") + PRELIM = (3, "Preliminary results", "Preliminary") + OPENDATA = (4, "Open data results", "Open Data") + SIM = (5, "Simulation results", "Simulation") + SIMINT = (6, "Simulation internal results", "Simulation Internal") SIMPRELIM = (7, "Simulation preliminary results", "Simulation Preliminary") - def __new__(cls, value:int, description:str="", display_text:str=""): + def __new__(cls, value: int, description: str = "", display_text: str = "") -> ResultStatus: obj = object.__new__(cls) obj._value_ = value obj.description = description obj.display_text = display_text return obj - + class NumericFormatter(ScalarFormatter): """ - Custom numeric formatter for matplotlib axis ticks. - - It adjusts the formatting of tick labels for integer values with an absolute magnitude less than - 1000 to display as integers without decimal places (e.g., 5 instead of 5.0). This enhances the - readability of tick labels for small integer values. - - Examples - -------- - >>> import matplotlib.pyplot as plt - >>> fig, ax = plt.subplots() - >>> ax.plot([1, 2, 3], [100, 200, 300]) - >>> ax.yaxis.set_major_formatter(NumericFormatter()) - + Enhanced numeric formatter for matplotlib axis ticks. + + This formatter improves readability by displaying small integers without + decimal places while maintaining scientific notation for large numbers. 
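+
+    Examples
+    --------
+    >>> import matplotlib.pyplot as plt
+    >>> fig, ax = plt.subplots()
+    >>> ax.plot([1, 2, 3], [100, 200, 300])
+    >>> ax.yaxis.set_major_formatter(NumericFormatter())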
""" - def __call__(self, x, pos=None): - tmp_format = self.format - if (x.is_integer() and abs(x) < 1e3): + def __call__(self, x: float, pos: Optional[int] = None) -> str: + original_format = self.format + if x.is_integer() and abs(x) < 1e3: self.format = re.sub(r"1\.\d+f", r"1.0f", self.format) result = super().__call__(x, pos) - self.format = tmp_format + self.format = original_format return result - + + class LogNumericFormatter(LogFormatterSciNotation): - def __call__(self, x, pos=None): + """Enhanced log formatter with improved handling of special cases.""" + + def __call__(self, x: float, pos: Optional[int] = None) -> str: + """Format the log-scale tick value.""" result = super().__call__(x, pos) - #result = result.replace('10^{1}', '10').replace('10^{0}', '1') + # result = result.replace('10^{1}', '10').replace('10^{0}', '1') return result - + + class CustomHandlerLineCollection(HandlerLineCollection): - def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans): - artists = super().create_artists(legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans) - # Adjust line height to center in legend + """Enhanced handler for line collections in legends.""" + + def create_artists( + self, + legend: Any, + orig_handle: Any, + xdescent: float, + ydescent: float, + width: float, + height: float, + fontsize: float, + trans: transforms.Transform + ) -> List[Line2D]: + """Create artists for legend entries with improved centering.""" + artists = super().create_artists( + legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans + ) + + # Center lines in legend box for artist in artists: - artist.set_ydata([height / 2.0, height / 2.0]) + artist.set_ydata([height / 2.0] * 2) + return artists + class CustomHandlerPathCollection(HandlerPathCollection): - def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans): - artists = super().create_artists(legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans) - # Modify the path collection offsets to center the markers in the legend + """Enhanced handler for path collections in legends.""" + + def create_artists( + self, + legend: Any, + orig_handle: Any, + xdescent: float, + ydescent: float, + width: float, + height: float, + fontsize: float, + trans: transforms.Transform + ) -> List[Collection]: + """Create artists for legend entries with improved centering.""" + artists = super().create_artists( + legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans + ) + + # Center markers in legend box for artist in artists: - offsets = np.array([[width / 2.0, height / 2.0]]) - artist.set_offsets(offsets) + artist.set_offsets([(width / 2.0, height / 2.0)]) + return artists -CUSTOM_HANDLER_MAP = {LineCollection: CustomHandlerLineCollection(), - PathCollection: CustomHandlerPathCollection()} - -TEMPLATE_STYLES = { - 'default': { - 'figure':{ - 'figsize': (11.111, 8.333), - 'dpi': 72, - 'facecolor': "#FFFFFF" - }, - 'legend_Line2D': { - 'linewidth': 3 - }, - 'legend_border': { - 'edgecolor' : 'black', - 'linewidth' : 1 - }, - 'annotation':{ - 'fontsize': 12 - }, - 'axis': { - 'major_length': 16, - 'minor_length': 8, - 'major_width': 2, - 'minor_width': 1, - 'spine_width': 2, - 'labelsize': 20, - 'offsetlabelsize': 20, - 'tick_bothsides': True, - 'x_axis_styles': {}, - 'y_axis_styles': {} - }, - 'xtick':{ - 'format': 'numeric', - 'locator': 'auto', - 'steps': None, - 'prune': None, - 'integer': False - }, - 'ytick':{ - 'format': 
'numeric', - 'locator': 'auto', - 'steps': None, - 'prune': None, - 'integer': False - }, - 'xlabel': { - 'fontsize': 22, - 'loc' : 'right', - 'labelpad': 10 - }, - 'ylabel': { - 'fontsize': 22, - 'loc' : 'top', - 'labelpad': 10 - }, - 'title':{ - 'fontsize': 20, - 'loc': 'center', - 'pad': 10 - }, - 'text':{ - 'fontsize': 20, - 'verticalalignment': 'top', - 'horizontalalignment': 'left' - }, - 'plot':{ - 'linewidth': 2 - }, - 'hist': { - 'linewidth': 2 - }, - 'errorbar': { - "marker": 'x', - "linewidth": 0, - "markersize": 0, - "elinewidth": 1, - "capsize": 2, - "capthick": 1 - }, - 'fill_between': { - "alpha": 0.5 - }, - 'legend':{ - "fontsize": 20, - "columnspacing": 0.8 - }, - 'ratio_frame':{ - 'height_ratios': (3, 1), - 'hspace': 0.07 - }, - 'barh': { - 'height': 0.5 - }, - 'bar': { - }, - 'colorbar': { - 'fraction': 0.15, - 'shrink': 1. - }, - 'contour':{ - 'linestyles': 'solid', - 'linewidths': 3 - }, - 'contourf':{ - 'alpha': 0.5, - 'zorder': 0 - }, - 'colorbar_axis': { - 'labelsize': 20, - 'y_axis_styles': { - 'labelleft': False, - 'labelright': True, - 'left': False, - 'right': True, - 'direction': 'out' - } - }, - 'colorbar_label': { - 'fontsize': 22, - 'labelpad': 0 - }, - 'clabel': { - 'inline': True, - 'fontsize': 10 - }, - 'line': { - }, - 'line_collection': { - }, - 'comp.hist': { - } - } + +# Constants +CUSTOM_HANDLER_MAP = { + LineCollection: CustomHandlerLineCollection(), + PathCollection: CustomHandlerPathCollection(), } -TEMPLATE_ANALYSIS_LABEL_OPTIONS = { - 'default': { - 'status': 'int', - 'loc': (0.05, 0.95), - 'fontsize': 25 - }, - 'ATLAS_Run2': { - 'colab': 'ATLAS', - 'status': 'int', - 'energy' : '13 TeV', - 'lumi' : "140 fb$^{-1}$", - 'fontsize': 25 - } +AXIS_LOCATOR_MAP = { + "auto": AutoLocator, + "maxn": MaxNLocator } -AXIS_LOCATOR_MAPS = { - 'auto': AutoLocator, - 'maxn': MaxNLocator +# Special text formatting patterns with improved regex +SPECIAL_TEXT_PATTERNS = { + r"\\bolditalic\{(.*?)\}": {"weight": "bold", "style": "italic"}, + r"\\italic\{(.*?)\}": {"style": "italic"}, + r"\\bold\{(.*?)\}": {"weight": "bold"}, } -def handle_has_label(handle): +SPECIAL_TEXT_REGEX = re.compile( + "|".join(f"({pattern})" for pattern in SPECIAL_TEXT_PATTERNS.keys()) +) + +def parse_transform( + target: Optional[TransformType] = None, + ax: Optional[Axes] = None +) -> Optional[transforms.Transform]: + """Parse transform objects for coordinate system transformations.""" + try: + if target == TransformType.FIGURE: + fig = plt.gcf() + if fig is None: + raise TransformError("No current figure available") + return fig.transFigure + + elif target == TransformType.AXIS: + if ax is None: + ax = plt.gca() + if ax is None: + raise TransformError("No current axes available") + return ax.transAxes + + elif target == TransformType.DATA: + if ax is None: + ax = plt.gca() + if ax is None: + raise TransformError("No current axes available") + return ax.transData + + elif target is None: + return None + + raise TransformError(f"Invalid transform target: '{target}'") + + except Exception as e: + raise TransformError(f"Failed to create transform: {str(e)}") + + +def create_transform( + transform_x: Optional[TransformType] = TransformType.AXIS, + transform_y: Optional[TransformType] = TransformType.AXIS, + ax: Optional[Axes] = None +) -> transforms.Transform: + """ + Create a blended transform from x and y components. 
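+
+    A minimal sketch (``ax`` is an assumed axes instance); here x is given
+    in axis coordinates and y in data coordinates:
+
+    >>> trans = create_transform("axis", "data", ax=ax)  # doctest: +SKIP
+    >>> ax.text(0.5, 10.0, "peak", transform=trans)      # doctest: +SKIP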
+ + Parameters + ---------- + transform_x : TransformType + Transform for x-axis + transform_y : TransformType + Transform for y-axis + ax : Optional[Axes] + Axes instance to use for transforms + + Returns + ------- + transforms.Transform + Blended transform object + """ + return transforms.blended_transform_factory( + parse_transform(transform_x, ax), + parse_transform(transform_y, ax) + ) + +@contextmanager +def change_axis(axis: Axes) -> None: + """ + Context manager for temporarily changing the current axis. + + Parameters + ---------- + axis : matplotlib.axes.Axes + The axis to temporarily set as current + """ + prev_axis = plt.gca() + try: + plt.sca(axis) + yield + finally: + plt.sca(prev_axis) + + +def handle_has_label(handle: Artist) -> bool: + """ + Check if an artist handle has a valid label. + + Parameters + ---------- + handle : matplotlib.artist.Artist + Artist to check for label + + Returns + ------- + bool + True if handle has a valid label + """ try: label = handle.get_label() - has_label = (label and not label.startswith('_')) - except: - has_label = False - return has_label - -def parse_templated_options(template: Dict[str, Dict], - inst_value: Optional[Union[Dict, str]] = None, - cls_value: Optional[Union[Dict, str]] = None, - use_default: bool = True): - if use_default: - default_value = template.get('default', {}) - else: - default_value = {} + return bool(label and not label.startswith('_')) + except AttributeError: + return False + - mapping = { - 'instance': inst_value, - 'class': cls_value, - 'default': default_value +def suggest_markersize(nbins: int) -> float: + """ + Calculate suggested marker size based on number of bins. + + Parameters + ---------- + nbins : int + Number of bins + + Returns + ------- + float + Suggested marker size + """ + BIN_MAX = 200 + BIN_MIN = 40 + SIZE_MAX = 8 + SIZE_MIN = 2 + + if nbins <= BIN_MIN: + return SIZE_MAX + + if nbins <= BIN_MAX: + slope = (SIZE_MIN - SIZE_MAX) / (BIN_MAX - BIN_MIN) + return slope * (nbins - BIN_MIN) + SIZE_MAX + + return SIZE_MIN + + +def format_axis_ticks( + ax: Axes, + x_axis: bool = True, + y_axis: bool = True, + major_length: int = 16, + minor_length: int = 8, + spine_width: int = 2, + major_width: int = 2, + minor_width: int = 1, + direction: str = "in", + label_bothsides: bool = False, + tick_bothsides: bool = False, + labelsize: Optional[int] = None, + offsetlabelsize: Optional[int] = None, + x_axis_styles: Optional[Dict[str, Any]] = None, + y_axis_styles: Optional[Dict[str, Any]] = None, + xtick_styles: Optional[Dict[str, Any]] = None, + ytick_styles: Optional[Dict[str, Any]] = None, +) -> None: + """Format axis ticks with comprehensive styling options.""" + try: + if x_axis: + _format_x_axis( + ax, major_length, minor_length, major_width, minor_width, + direction, label_bothsides, tick_bothsides, labelsize, + x_axis_styles, xtick_styles + ) + + if y_axis: + _format_y_axis( + ax, major_length, minor_length, major_width, minor_width, + direction, label_bothsides, tick_bothsides, labelsize, + y_axis_styles, ytick_styles + ) + + # Format spines + for spine in ax.spines.values(): + spine.set_linewidth(spine_width) + + _handle_offset_labels(ax, offsetlabelsize or labelsize) + + except Exception as e: + warnings.warn(f"Error formatting axis ticks: {str(e)}") + + +def _format_x_axis( + ax: Axes, + major_length: int, + minor_length: int, + major_width: int, + minor_width: int, + direction: str, + label_bothsides: bool, + tick_bothsides: bool, + labelsize: Optional[int], + x_axis_styles: Optional[Dict[str, 
Any]], + xtick_styles: Optional[Dict[str, Any]] +) -> None: + """Helper function for formatting x-axis ticks.""" + if ax.get_xaxis().get_scale() != "log": + ax.xaxis.set_minor_locator(AutoMinorLocator()) + + x_styles = { + "labelsize": labelsize, + "labeltop": label_bothsides, + "top": tick_bothsides, + "bottom": True, + "direction": direction, } + + if x_axis_styles: + x_styles.update(x_axis_styles) + + ax.tick_params( + axis="x", + which="major", + length=major_length, + width=major_width, + **x_styles + ) + ax.tick_params( + axis="x", + which="minor", + length=minor_length, + width=minor_width, + **x_styles + ) + + set_axis_tick_styles(ax.xaxis, xtick_styles) + + +def _format_y_axis( + ax: Axes, + major_length: int, + minor_length: int, + major_width: int, + minor_width: int, + direction: str, + label_bothsides: bool, + tick_bothsides: bool, + labelsize: Optional[int], + y_axis_styles: Optional[Dict[str, Any]], + ytick_styles: Optional[Dict[str, Any]] +) -> None: + """Helper function for formatting y-axis ticks.""" + if ax.get_yaxis().get_scale() != "log": + ax.yaxis.set_minor_locator(AutoMinorLocator()) + + y_styles = { + "labelsize": labelsize, + "labelleft": True, + "left": True, + "right": tick_bothsides, + "direction": direction, + } + + if y_axis_styles: + y_styles.update(y_axis_styles) + + ax.tick_params( + axis="y", + which="major", + length=major_length, + width=major_width, + **y_styles + ) + ax.tick_params( + axis="y", + which="minor", + length=minor_length, + width=minor_width, + **y_styles + ) + + set_axis_tick_styles(ax.yaxis, ytick_styles) + + +def _handle_offset_labels(ax: Axes, offsetlabelsize: Optional[int]) -> None: + """Helper function for handling offset labels.""" + if offsetlabelsize is None: + return + + for axis in (ax.xaxis, ax.yaxis): + offset_text = axis.get_offset_text() + if offset_text.get_text(): + offset_text.set_fontsize(offsetlabelsize) + axis.labelpad += offset_text.get_fontsize() + + if (ax.xaxis.get_offset_text().get_text() or + ax.yaxis.get_offset_text().get_text()): + if not isinstance(plt.gca(), plt.Subplot): + plt.tight_layout() + - options = {} - copied = False - for key, value in mapping.items(): - if not value: - continue - if isinstance(value, str): - if value not in template: - raise ValueError(f'Undefined template: {value}') - value = template[value] - assert isinstance(value, dict) - if not options: - options = value +def set_axis_tick_styles(axis: Axis, styles: Optional[Dict[str, Any]] = None) -> None: + """ + Set advanced tick styles for an axis. 
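+
+    Recognized keys mirror the ``xtick``/``ytick`` style templates:
+    ``"format"`` (``"numeric"`` or a ``matplotlib.ticker.Formatter`` instance)
+    and locator options such as ``"locator"``, ``"steps"``, ``"prune"`` and
+    ``"integer"``. A minimal sketch (hypothetical values; assumes an existing
+    ``Axes`` named ``ax``):
+
+    >>> set_axis_tick_styles(ax.xaxis, {"format": "numeric",
+    ...                                 "locator": "maxn", "integer": True})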
+ + Parameters + ---------- + axis : matplotlib.axis.Axis + The axis to style + styles : Optional[Dict[str, Any]] + Style specifications + """ + if not styles: + return + + try: + _set_axis_formatter(axis, styles.get("format")) + if axis.get_scale() != "log": + _set_axis_locator(axis, styles) + except Exception as e: + raise StyleError(f"Failed to apply axis styles: {str(e)}") + + +def _set_axis_formatter( + axis: Axis, + fmt: Optional[Union[str, Formatter]] +) -> None: + """Helper function to set axis formatter.""" + if fmt is None: + return + + if isinstance(fmt, str): + if fmt == "numeric": + formatter = ( + LogNumericFormatter() + if axis.get_scale() == "log" + else NumericFormatter() + ) else: - options = combine_dict(value, options) - copied = True - if not copied: - options = combine_dict(options) - return options - -def parse_styles(inst_value: Optional[Union[Dict, str]] = None, - cls_value: Optional[Union[Dict, str]] = None, - use_default: bool = True): - return parse_templated_options(template=TEMPLATE_STYLES, - inst_value=inst_value, - cls_value=cls_value, - use_default=use_default) - -def parse_analysis_label_options(inst_value: Optional[Union[Dict, str]] = None, - cls_value: Optional[Union[Dict, str]] = None, - use_default: bool = True): - # do not draw analysis label - if (inst_value is None) and (cls_value is None) and (not use_default): - return None - return parse_templated_options(template=TEMPLATE_ANALYSIS_LABEL_OPTIONS, - inst_value=inst_value, - cls_value=cls_value, - use_default=use_default) - -def ratio_frame(logx:bool=False, logy:bool=False, - logy_lower:Optional[bool]=False, - styles:Optional[Union[Dict, str]]=None, - analysis_label_options:Optional[Union[Dict, str]]=None, - prop_cycle:Optional[List[str]]=None, - prop_cycle_lower:Optional[List[str]]=None, - figure_index:Optional[int]=None): + raise ValueError(f"Unsupported format string: '{fmt}'") + elif isinstance(fmt, Formatter): + formatter = fmt + else: + raise ValueError(f"Invalid formatter type: {type(fmt)}") + + axis.set_major_formatter(formatter) + + +def _set_axis_locator(axis: Axis, styles: Dict[str, Any]) -> None: + """Helper function to set axis locator.""" + locator_type = styles.get("locator", "").lower() + if not locator_type: + return + + new_locator_class = AXIS_LOCATOR_MAP.get(locator_type) + if not new_locator_class: + raise ValueError(f"Unknown locator type: {locator_type}") + + new_locator = new_locator_class() + + locator_params = { + param: styles[param] + for param in getattr(new_locator, "default_params", []) + if param in styles + } + + if locator_params: + new_locator.set_params(**locator_params) + + axis.set_major_locator(new_locator) + +def ratio_frame( + logx: bool = False, + logy: bool = False, + logy_lower: Optional[bool] = None, + styles: Optional[Union[Dict[str, Any], str]] = None, + analysis_label_options: Optional[Union[Dict[str, Any], str]] = None, + prop_cycle: Optional[List[str]] = None, + prop_cycle_lower: Optional[List[str]] = None, + figure_index: Optional[int] = None, +) -> Tuple[Axes, Axes]: + """ + Create a ratio plot frame with shared x-axis. 
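+
+    The relative panel heights and spacing come from the ``ratio_frame`` entry
+    of the resolved style template. A minimal sketch (hypothetical arrays
+    ``x``, ``data`` and ``model``):
+
+    >>> ax_main, ax_ratio = ratio_frame(logy=True)
+    >>> ax_main.plot(x, model)
+    >>> ax_ratio.plot(x, data / model)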
+ + Parameters + ---------- + logx : bool + Use logarithmic x-axis + logy : bool + Use logarithmic y-axis for main plot + logy_lower : Optional[bool] + Use logarithmic y-axis for ratio plot + styles : Optional[Union[Dict[str, Any], str]] + Plot styles + analysis_label_options : Optional[Union[Dict[str, Any], str]] + Options for analysis label + prop_cycle : Optional[List[str]] + Color cycle for main plot + prop_cycle_lower : Optional[List[str]] + Color cycle for ratio plot + figure_index : Optional[int] + Figure number to use + + Returns + ------- + Tuple[Axes, Axes] + Main plot axes and ratio plot axes + """ + if figure_index is None: plt.clf() else: plt.figure(figure_index) - styles = parse_styles(styles) + + styles = template_styles.parse(styles) + gridspec_kw = { - "height_ratios": styles['ratio_frame']['height_ratios'], - "hspace": styles['ratio_frame']['hspace'] + "height_ratios": styles["ratio_frame"]["height_ratios"], + "hspace": styles["ratio_frame"]["hspace"], } - fig, (ax_main, ax_ratio) = plt.subplots(nrows=2, ncols=1, gridspec_kw=gridspec_kw, - sharex=True, **styles['figure']) + _, (ax_main, ax_ratio) = plt.subplots( + nrows=2, + ncols=1, + gridspec_kw=gridspec_kw, + sharex=True, + **styles["figure"] + ) + + # Configure scales if logx: - ax_main.set_xscale('log') - ax_ratio.set_xscale('log') - + ax_main.set_xscale("log") + ax_ratio.set_xscale("log") + + if logy: + ax_main.set_yscale("log") + if logy_lower is None: logy_lower = logy - - if logy: - ax_main.set_yscale('log') - + if logy_lower: - ax_ratio.set_yscale('log') + ax_ratio.set_yscale("log") + + # Format axes + ax_main_styles = mp.concat( + (styles["axis"], {"x_axis_styles": {"labelbottom": False}}) + ) + + format_axis_ticks( + ax_main, + x_axis=True, + y_axis=True, + xtick_styles=styles["xtick"], + ytick_styles=styles["ytick"], + **ax_main_styles + ) - ax_main_styles = combine_dict(styles['axis'], {"x_axis_styles": {"labelbottom": False}}) - format_axis_ticks(ax_main, x_axis=True, y_axis=True, xtick_styles=styles['xtick'], - ytick_styles=styles['ytick'], **ax_main_styles) - format_axis_ticks(ax_ratio, x_axis=True, y_axis=True, xtick_styles=styles['xtick'], - ytick_styles=styles['ytick'], **styles['axis']) + format_axis_ticks( + ax_ratio, + x_axis=True, + y_axis=True, + xtick_styles=styles["xtick"], + ytick_styles=styles["ytick"], + **styles["axis"] + ) + # Add analysis label if requested if analysis_label_options is not None: - draw_analysis_label(ax_main, text_options=styles['text'], **analysis_label_options) - + draw_analysis_label( + ax_main, + text_options=styles["text"], + **analysis_label_options + ) + + # Set property cycles if prop_cycle is not None: ax_main.set_prop_cycle(prop_cycle) - + if prop_cycle_lower is None: prop_cycle_lower = prop_cycle - + if prop_cycle_lower is not None: ax_ratio.set_prop_cycle(prop_cycle_lower) return ax_main, ax_ratio -def single_frame(logx:bool=False, logy:bool=False, - styles:Optional[Union[Dict, str]]=None, - analysis_label_options:Optional[Union[Dict, str]]=None, - prop_cycle:Optional[List[str]]=None, - figure_index:Optional[int]=None): + +def single_frame( + logx: bool = False, + logy: bool = False, + styles: Optional[Union[Dict[str, Any], str]] = None, + analysis_label_options: Optional[Union[Dict[str, Any], str]] = None, + prop_cycle: Optional[List[str]] = None, + figure_index: Optional[int] = None, +) -> Axes: + """ + Create a single plot frame with enhanced options. 
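+
+    A minimal sketch (hypothetical array ``values``):
+
+    >>> ax = single_frame(logy=True)
+    >>> _ = ax.hist(values, bins=50)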
+ + Parameters + ---------- + logx : bool + Use logarithmic x-axis + logy : bool + Use logarithmic y-axis + styles : Optional[Union[Dict[str, Any], str]] + Plot styles + analysis_label_options : Optional[Union[Dict[str, Any], str]] + Options for analysis label + prop_cycle : Optional[List[str]] + Color cycle + figure_index : Optional[int] + Figure number to use + + Returns + ------- + Axes + The created plot axes + """ + if figure_index is None: plt.clf() else: plt.figure(figure_index) - styles = parse_styles(styles) - fig, ax = plt.subplots(nrows=1, ncols=1, **styles['figure']) + + styles = template_styles.parse(styles) + _, ax = plt.subplots(nrows=1, ncols=1, **styles["figure"]) if logx: - ax.set_xscale('log') + ax.set_xscale("log") if logy: - ax.set_yscale('log') - - format_axis_ticks(ax, x_axis=True, y_axis=True, xtick_styles=styles['xtick'], - ytick_styles=styles['ytick'], **styles['axis']) + ax.set_yscale("log") + + format_axis_ticks( + ax, + x_axis=True, + y_axis=True, + xtick_styles=styles["xtick"], + ytick_styles=styles["ytick"], + **styles["axis"] + ) if analysis_label_options is not None: - draw_analysis_label(ax, text_options=styles['text'], **analysis_label_options) - + draw_analysis_label( + ax, + text_options=styles["text"], + **analysis_label_options + ) + if prop_cycle is not None: ax.set_prop_cycle(prop_cycle) return ax -def suggest_markersize(nbins:int): - bin_max = 200 - bin_min = 40 - size_max = 8 - size_min = 2 - if nbins <= bin_min: - return size_max - elif (nbins > bin_min) and (nbins <= bin_max): - return ((size_min - size_max) / (bin_max - bin_min))*(nbins - bin_min) + size_max - return size_min - -def format_axis_ticks(ax, x_axis=True, y_axis=True, major_length:int=16, minor_length:int=8, - spine_width:int=2, major_width:int=2, minor_width:int=1, direction:str='in', - label_bothsides:bool=False, tick_bothsides:bool=False, - labelsize:Optional[int]=None, - offsetlabelsize:Optional[int]=None, - x_axis_styles:Optional[Dict]=None, - y_axis_styles:Optional[Dict]=None, - xtick_styles:Optional[Dict]=None, - ytick_styles:Optional[Dict]=None): - if x_axis: - if (ax.get_xaxis().get_scale() != 'log'): - ax.xaxis.set_minor_locator(AutoMinorLocator()) - styles = {"labelsize":labelsize} - styles['labeltop'] = label_bothsides - #styles['labelbottom'] = True - styles['top'] = tick_bothsides - styles['bottom'] = True - styles['direction'] = direction - if x_axis_styles is not None: - styles.update(x_axis_styles) - ax.tick_params(axis="x", which="major", length=major_length, - width=major_width, **styles) - ax.tick_params(axis="x", which="minor", length=minor_length, - width=minor_width, **styles) - if y_axis: - if (ax.get_yaxis().get_scale() != 'log'): - ax.yaxis.set_minor_locator(AutoMinorLocator()) - styles = {"labelsize":labelsize} - #styles['labelleft'] = True - styles['labelright'] = label_bothsides - styles['left'] = True - styles['right'] = tick_bothsides - styles['direction'] = direction - if y_axis_styles is not None: - styles.update(y_axis_styles) - ax.tick_params(axis="y", which="major", length=major_length, - width=major_width, **styles) - ax.tick_params(axis="y", which="minor", length=minor_length, - width=minor_width, **styles) - - for axis in ['top','bottom','left','right']: - ax.spines[axis].set_linewidth(spine_width) - - set_axis_tick_styles(ax.xaxis, xtick_styles) - set_axis_tick_styles(ax.yaxis, ytick_styles) - # take care of offset labels - if offsetlabelsize is None: - offsetlabelsize = labelsize - - xaxis_offset_text = ax.xaxis.get_offset_text().get_text() - 
if xaxis_offset_text: - ax.xaxis.get_offset_text().set_fontsize(offsetlabelsize) - ax.xaxis.labelpad = ax.xaxis.labelpad + ax.xaxis.get_offset_text().get_fontsize() - yaxis_offset_text = ax.yaxis.get_offset_text().get_text() - if yaxis_offset_text: - ax.yaxis.get_offset_text().set_fontsize(offsetlabelsize) - ax.yaxis.labelpad = ax.yaxis.labelpad + ax.yaxis.get_offset_text().get_fontsize() - - if (xaxis_offset_text or yaxis_offset_text) and (plt.gca().__class__.__name__ != "AxesSubplot"): - plt.tight_layout() - -def set_axis_tick_styles(ax, styles=None): - if styles is None: - return None - - fmt = styles['format'] - if fmt is not None: - formatter = None - if isinstance(fmt, str): - if fmt == 'numeric': - if ax.get_scale() == "log": - formatter = LogNumericFormatter() - else: - formatter = NumericFormatter() - if isinstance(fmt, Formatter): - formatter = fmt - if formatter is None: - raise ValueError(f"unsupported axis tick format {fmt}") - ax.set_major_formatter(formatter) +@dataclass +class AnalysisLabelConfig: + """Configuration for analysis labels.""" + loc: Tuple[float, float] = (0.05, 0.95) + fontsize: float = 25 + status: Union[str, ResultStatus] = "int" + energy: Optional[str] = None + lumi: Optional[str] = None + colab: Optional[str] = "ATLAS" + main_text: Optional[str] = None + extra_text: Optional[str] = None + dy: float = 0.02 + dy_main: float = 0.01 + transform_x: TransformType = "axis" + transform_y: TransformType = "axis" + vertical_align: str = "top" + horizontal_align: str = "left" + text_options: Optional[Dict[str, Any]] = None + + +def draw_analysis_label( + axis: Axes, + **kwargs: Any +) -> None: + """ + Draw analysis label with comprehensive options. + + Parameters + ---------- + axis : matplotlib.axes.Axes + The axes to draw on + **kwargs : Any + Configuration options (see AnalysisLabelConfig) + """ + config = AnalysisLabelConfig(**kwargs) + + try: + status_text = ResultStatus.parse(config.status).display_text + except (ValueError, AttributeError): + status_text = str(config.status) + + with change_axis(axis): + x_pos, y_pos = config.loc - if ax.get_scale() == "log": - return None - - locator = ax.get_major_locator() - - if isinstance(locator, (AutoLocator, MaxNLocator)): - new_locator = AXIS_LOCATOR_MAPS.get(styles['locator'].lower(), type(locator))() - try: - available_params = list(new_locator.default_params) - except: - available_params = ['steps', 'prune', 'integer'] - locator_params = {} - for param in available_params: - value = styles.get(param, None) - if value is not None: - locator_params[param] = value - new_locator.set_params(**locator_params) - ax.set_major_locator(new_locator) - -def centralize_axis(ax: Axes, which: str = 'y', ref_value: float = 0, padding: float = 0.1) -> None: + # Draw main texts + y_pos = _draw_main_texts( + axis, + x_pos, + y_pos, + config.main_text, + config.colab, + status_text, + config + ) + + # Draw additional texts + _draw_additional_texts( + axis, + x_pos, + y_pos, + config.energy, + config.lumi, + config.extra_text, + config + ) + + +def _draw_main_texts( + axis: Axes, + x_pos: float, + y_pos: float, + main_text: Optional[str], + colab: Optional[str], + status_text: str, + config: AnalysisLabelConfig +) -> float: + """Helper function to draw main texts of analysis label.""" + main_texts = [] + + if main_text: + main_texts.extend(main_text.split("//")) + + if colab: + colab_text = r"\bolditalic{" + colab + "} " + status_text + main_texts.append(colab_text) + + current_y = y_pos + for text in main_texts: + _, _, current_y, _ = 
draw_text( + axis, + x_pos, + current_y, + text, + fontsize=config.fontsize, + transform_x=config.transform_x, + transform_y=config.transform_y, + horizontalalignment=config.horizontal_align, + verticalalignment=config.vertical_align + ) + current_y -= config.dy_main + + return current_y + +def _draw_additional_texts( + axis: Axes, + x_pos: float, + y_pos: float, + energy: Optional[str], + lumi: Optional[str], + extra_text: Optional[str], + config: AnalysisLabelConfig +) -> None: + """Helper function to draw additional texts of analysis label.""" + texts = [] + + # Combine energy and luminosity + elumi_parts = [] + if energy: + elumi_parts.append(r"$\sqrt{s} = $" + energy) + if lumi: + elumi_parts.append(lumi) + + if elumi_parts: + texts.append(", ".join(elumi_parts)) + + # Add extra text + if extra_text: + texts.extend(extra_text.split("//")) + + # Draw all texts + text_options = config.text_options or {} + current_y = y_pos + + for text in texts: + _, _, current_y, _ = draw_text( + axis, + x_pos, + current_y - config.dy, + text, + **text_options + ) + current_y -= config.dy + + +def draw_text( + axis: Axes, + x: float, + y: float, + text_str: str, + transform_x: TransformType = "axis", + transform_y: TransformType = "axis", + **styles: Any +) -> Tuple[float, float, float, float]: """ - Centralize the axis around a reference value. - + Draw formatted text with special styles. + Parameters ---------- - ax : matplotlib.axes.Axes - The axis to be centralized. - which : str, optional - The axis to centralize. 'x' for x-axis, 'y' for y-axis. Default is 'y'. - ref_value : float, optional - The reference value around which the axis will be centralized. Default is 0. - padding : float, optional - The padding applied around the data to create space. Default is 0.1. 
- - Example + axis : matplotlib.axes.Axes + The axes to draw on + x : float + X-coordinate + y : float + Y-coordinate + text_str : str + Text to draw + transform_x : TransformType + X-coordinate transform + transform_y : TransformType + Y-coordinate transform + **styles : Any + Additional text styles + + Returns ------- - >>> import matplotlib.pyplot as plt - >>> fig, ax = plt.subplots() - >>> ax.plot([1, 2, 3], [2, 4, 6]) - >>> centralize_axis(ax, which='y', ref_value=3) + Tuple[float, float, float, float] + Text dimensions (xmin, xmax, ymin, ymax) """ - if which not in {'x', 'y'}: - raise ValueError('axis to centralize must be either "x" or "y"') - - if which == 'x': - get_scale = ax.get_xscale - get_lim = ax.get_xlim - set_lim = ax.set_xlim - elif which == 'y': - get_scale = ax.get_yscale - get_lim = ax.get_ylim - set_lim = ax.set_ylim + with change_axis(axis): + transform = create_transform(transform_x, transform_y) + components = SPECIAL_TEXT_REGEX.split(text_str) + current_x = x + xmin = None + + for component in components: + if not component: + continue + + if SPECIAL_TEXT_REGEX.match(component): + for pattern, font_styles in SPECIAL_TEXT_PATTERNS.items(): + match = re.match(pattern, component) + if match: + text = axis.text( + current_x, + y, + match.group(1), + transform=transform, + **styles, + **font_styles + ) + break + else: + text = axis.text( + current_x, + y, + component, + transform=transform, + **styles + ) + + xmin_, current_x, ymin, ymax = get_artist_dimension(text) + if xmin is None: + xmin = xmin_ - if get_scale() == 'log': - raise ValueError('cannot centralize on a logarithmic axis') + return xmin, current_x, ymin, ymax + + +def draw_multiline_text( + axis: Axes, + x: float, + y: float, + text_str: str, + dy: float = 0.01, + transform_x: TransformType = "axis", + transform_y: TransformType = "axis", + **styles: Any +) -> None: + """Draw multi-line text with special formatting.""" + + current_y = y + lines = text_str.split("//") + + for line in lines: + _, _, current_y, _ = draw_text( + axis, + x, + current_y, + line.strip(), + transform_x=transform_x, + transform_y=transform_y, + **styles + ) + current_y -= dy + transform_x = transform_y = "axis" + + +def centralize_axis( + ax: Axes, + which: Literal["x", "y"] = "y", + ref_value: float = 0, + padding: float = 0.1 +) -> None: + """ + Centralize an axis around a reference value with padding. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axes to modify + which : Literal["x", "y"] + Which axis to centralize + ref_value : float + Reference value to center around + padding : float + Padding fraction + """ + if which not in {"x", "y"}: + raise ValueError('Axis must be either "x" or "y"') + + get_scale = ax.get_xscale if which == "x" else ax.get_yscale + get_lim = ax.get_xlim if which == "x" else ax.get_ylim + set_lim = ax.set_xlim if which == "x" else ax.set_ylim + + if get_scale() == "log": + raise ValueError("Cannot centralize logarithmic axis") + + if not (0 <= padding < 1): + raise ValueError("Padding must be between 0 and 1") lim = get_lim() delta = max(abs(ref_value - lim[0]), abs(lim[1] - ref_value)) - pad = (lim[1] - lim[0]) * padding if padding else 0. 
+ pad = (lim[1] - lim[0]) * padding if padding else 0.0 new_lim = (ref_value - delta - pad, ref_value + delta + pad) set_lim(*new_lim) -def parse_transform(target: Optional[str] = None) -> Optional[transforms.Transform]: +def get_artist_dimension( + artist: Artist, + transform: TransformType = 'axis' +) -> Tuple[float, float, float, float]: """ - Parse a string into a Matplotlib transform. - + Get dimensions of an artist's bounding box. + Parameters ---------- - target : Optional[str], default: None - The string representation of the transformation target. - Possible values: 'figure', 'axis', 'data', or an empty string. + artist : matplotlib.artist.Artist + The artist to measure + transform : TransformType + Coordinate transform for dimensions - - 'figure': Transform relative to the figure. - - 'axis': Transform relative to the axes. - - 'data': Transform relative to the data coordinates. - - None or '': Returns None. - Returns ------- - transform : Optional[transforms.Transform] - The corresponding transformation object. Returns None if the input is None or an empty string. - - Examples - -------- - >>> transform_figure = parse_transform('figure') - >>> transform_data = parse_transform('data') + Tuple[float, float, float, float] + Dimensions (xmin, xmax, ymin, ymax) """ - if target == 'figure': - fig = plt.gcf() - if fig is None: - raise ValueError("No current figure available for 'figure' transform") - return fig.transFigure - elif target == 'axis': - ax = plt.gca() - if ax is None: - raise ValueError("No current axis available for 'axis' transform") - return ax.transAxes - elif target == 'data': - ax = plt.gca() - if ax is None: - raise ValueError("No current axis available for 'data' transform") - return ax.transData - elif not target: - return None - else: - raise ValueError(f"Invalid transform target: '{target}'") -def create_transform(transform_x: str = 'axis', transform_y: str = 'axis') -> transforms.Transform: - """ - Create a composite transformation from two string representations of transformations. + axis = artist.axes or plt.gca() + artist.figure.canvas.draw() + + bbox = artist.get_window_extent() + + if transform is not None: + transform_obj = parse_transform(transform, ax=axis) + if transform_obj is not None: + bbox = bbox.transformed(transform_obj.inverted()) + + return bbox.xmin, bbox.xmax, bbox.ymin, bbox.ymax + +def draw_hatches( + axis: Axes, + ymax: float, + height: float = 1.0, + **styles: Any +) -> None: + """ + Draw hatched pattern on axis. + Parameters ---------- - transform_x : str, optional - The string representation of the transformation for the x-direction. - transform_y : str, optional - The string representation of the transformation for the y-direction. + axis : matplotlib.axes.Axes + The axes to draw on + ymax : float + Maximum y-value + height : float + Height of hatch pattern + **styles : Any + Additional style options + """ + y_values = np.arange(0, height * ymax, 2 * height) - height / 2 + transform = create_transform(transform_x="axis", transform_y="data") + + for y in y_values: + axis.add_patch( + Rectangle( + (0, y), + 1, + 1, + transform=transform, + zorder=-1, + **styles + ) + ) + + +def is_transparent_color(color: Optional[ColorType]) -> bool: + """ + Check if a color is transparent. + + Parameters + ---------- + color : Optional[ColorType] + Color to check + Returns ------- - transform : matplotlib.transforms.Transform - The composite transformation object. 
- - Examples - -------- - >>> combined_transform = create_transform('axis', 'data') - - """ - transform = transforms.blended_transform_factory(parse_transform(transform_x), - parse_transform(transform_y)) - return transform - -def get_artist_dimension(artist): + bool + True if color is transparent """ - Get the dimensions of an artist's bounding box in axis coordinates. + if color is None: + raise ValueError("Color cannot be None") + + try: + rgba = mcolors.to_rgba(color) + return rgba[3] == 0 + except ValueError as e: + raise ValueError(f"Invalid color format: {color}") from e - This function calculates the dimensions (x-min, x-max, y-min, y-max) of an artist's - bounding box in axis coordinates based on the provided artist. +def get_artist_colors( + artist: Artist, + index: int = 0 +) -> Dict[str, Optional[ColorType]]: + """ + Get color properties of an artist. + Parameters ---------- artist : matplotlib.artist.Artist - The artist for which dimensions need to be calculated. - + The artist to analyze + index : int + Index for collections + Returns ------- - xmin, xmax, ymin, ymax : float - The calculated dimensions of the artist's bounding box in axis coordinates. - - Example - ------- - >>> from matplotlib.patches import Rectangle - >>> rectangle = Rectangle((0.2, 0.3), 0.4, 0.4) - >>> xmin, xmax, ymin, ymax = get_artist_dimension(rectangle) - + Dict[str, Optional[ColorType]] + Color properties """ - axis = plt.gca() - plt.gcf().canvas.draw() - - # Get the bounding box of the artist in display coordinates - box = artist.get_window_extent() - - # Transform the bounding box to axis coordinates - points = box.transformed(axis.transAxes.inverted()).get_points().transpose() - - xmin = np.min(points[0]) - xmax = np.max(points[0]) - ymin = np.min(points[1]) - ymax = np.max(points[1]) - - return xmin, xmax, ymin, ymax - -def draw_sigma_bands(axis, ymax:float, height:float=1.0): - # +- 2 sigma band - axis.add_patch(Rectangle((-2, -height/2), 2*2, ymax + height/2, fill=True, color='yellow')) - # +- 1 sigma band - axis.add_patch(Rectangle((-1, -height/2), 1*2, ymax + height/2, fill=True, color='lime')) + if isinstance(artist, Container): + children = artist.get_children() + if not children: + raise IndexError("Artist has no children") + if index >= len(children): + raise IndexError("Index out of bounds") + artist = children[index] -def draw_sigma_lines(axis, ymax:float, height:float=1.0, **styles): - y = [-height/2, ymax*height - height/2] - axis.add_line(Line2D([-1, -1], y, **styles)) - axis.add_line(Line2D([+1, +1], y, **styles)) - axis.add_line(Line2D([0, 0], y, **styles)) + if not isinstance(artist, Artist): + raise TypeError("Invalid artist type") -def draw_hatches(axis, ymax, height=1.0, **styles): - x_min = axis.get_xlim()[0] - x_max = axis.get_xlim()[1] - x_range = x_max - x_min - y_values = np.arange(0, height*ymax, 2*height) - height/2 - transform = create_transform(transform_x='axis', transform_y='data') - for y in y_values: - axis.add_patch(Rectangle((0, y), 1, 1, **styles, zorder=-1, transform=transform)) - -special_text_fontstyles = { - re.compile(r'\\bolditalic\{(.*?)\}'): { - "weight":"bold", "style":"italic" - }, - re.compile(r'\\italic\{(.*?)\}'): { - "style":"italic" - }, - re.compile(r'\\bold\{(.*?)\}'): { - "weight":"bold" - } -} -special_text_regex = re.compile("|".join([f"({regex.pattern.replace('(', '').replace(')', '')})" - for regex in special_text_fontstyles.keys()])) + colors: Dict[str, Optional[ColorType]] = {} + + if isinstance(artist, Collection): + facecolors = 
artist.get_facecolor() + edgecolors = artist.get_edgecolor() + colors["facecolor"] = ( + facecolors[index] if len(facecolors) > index else None + ) + colors["edgecolor"] = ( + edgecolors[index] if len(edgecolors) > index else None + ) + + elif isinstance(artist, Line2D): + colors.update({ + "color": artist.get_color(), + "markerfacecolor": artist.get_markerfacecolor(), + "markeredgecolor": artist.get_markeredgecolor() + }) + + elif isinstance(artist, Patch): + colors.update({ + "facecolor": artist.get_facecolor(), + "edgecolor": artist.get_edgecolor() + }) + + elif isinstance(artist, AxesImage): + colors["cmap"] = artist.get_cmap() + + elif isinstance(artist, Text): + colors["textcolor"] = artist.get_color() + + return colors -def draw_text(axis, x:float, y:float, s:str, - transform_x:str='axis', - transform_y:str='axis', - **styles): - with change_axis(axis): - transform = create_transform(transform_x, transform_y) - components = special_text_regex.split(s) - components = [component for component in components] - xmax = x - xmin = None - for component in components: - if component and special_text_regex.match(component): - for regex, fontstyles in special_text_fontstyles.items(): - match = regex.match(component) - if match: - text = axis.text(xmax, y, match.group(1), transform=transform, - **styles, **fontstyles) - break - else: - text = axis.text(xmax, y, component, transform=transform, **styles) - xmin_, xmax, ymin, ymax = get_artist_dimension(text) - if xmin is None: - xmin = xmin_ - return xmin, xmax, ymin, ymax - -def draw_multiline_text(axis, x:float, y:float, - s:str, dy:float=0.01, - transform_x:str='axis', - transform_y:str='axis', - **styles): - components = s.split("//") - for component in components: - _, _, y, _ = draw_text(axis, x, y, component, - transform_x=transform_x, - transform_y=transform_y, - **styles) - y -= dy - transform_x, transform_y = 'axis', 'axis' - -@contextmanager -def change_axis(axis): - """ - Temporarily change the current axis to the specified axis within a context. - Parameters - ---------- - axis : matplotlib.axes._base.Axes - The axis to which the current axis will be temporarily changed. - - Examples - -------- - >>> import matplotlib.pyplot as plt - >>> from contextlib import contextmanager - - >>> @contextmanager - ... def change_axis(axis): - ... current_axis = plt.gca() - ... plt.sca(axis) - ... yield - ... plt.sca(current_axis) - - >>> fig, axes = plt.subplots(1, 2) - >>> with change_axis(axes[0]): - ... plt.plot([1, 2, 3], [4, 5, 6]) - ... plt.title('First Axis') - +def convert_size(size_str: str) -> float: """ - current_axis = plt.gca() - plt.sca(axis) - yield - plt.sca(current_axis) - -def draw_analysis_label(axis, loc=(0.05, 0.95), fontsize:float=25, status:str='int', - energy:Optional[str]=None, lumi:Optional[str]=None, - colab:Optional[str]='ATLAS', main_text:Optional[str]=None, - extra_text:Optional[str]=None, dy:float=0.02, dy_main:float=0.01, - transform_x:str='axis', transform_y:str='axis', - vertical_align:str='top', horizontal_align:str='left', - text_options:Optional[Dict]=None): - r""" - Draw analysis label and additional texts on a given axis. + Convert size string to float value. Parameters - --------------------------------------------------------------- - axis: matplotlib.pyplot.axis - Axis to be drawn on. - loc: (float, float), default = (0.05, 0.95) - The location of the analysis label and additional texts. - fontsize: float, default = 25 - Font size of the analysis label and the status label. 
- status: str or ResultStatus, default = 'int' - Display text for the analysis status. Certain keywords can be used to convert - automatically to the corresponding built-in status texts - (see `ResultStatus`). - energy: (optional) str - Display text for the Center-of-mass energy. A prefix of "\sqrt{s} = " will be - automatically appended to the front of the text. - lumi: (optional) str - Display text for the luminosity. It will be displayed as is. - colab: (optional) str - Display text for the collaboration involved in the analysis. It will be - bolded and italised. - main_text: (optional) str - Main text to be displayed before the colab text. A new line - can be added by adding a double-slash, i.e. "//". Use the "\bolditalic{<text>}" - keyword for bold-italic styled text. - extra_text: (optional) str - Extra text to be displayed after energy and luminosity texts. A new line - can be added by adding a double-slash, i.e. "//". Use the "\bolditalic{<text>}" - keyword for bold-italic styled text. - dy: float, default = 0.05 - Vertical separation between each line of the sub-texts in the axis coordinates. - dy_main: float, default = 0.02 - Vertical separation between each line of the main-texts in the axis coordinates. - transform_x: str, default = 'axis' - Coordinate transform for the x location of the analysis label. - transform_y: str, default = 'axis' - Coordinate transform for the y location of the analysis label. - vertical_align: str, default = 'top' - Vertical alignment of the analysis label. - horizontal_align: str, default = 'top' - Horizontal alignment of the analysis label. - text_options: (optional), dict - A dictionary specifying the styles for drawing texts. + ---------- + size_str : str + Size string (e.g., "50%", "0.5") + + Returns + ------- + float + Converted size value """ try: - status_text = ResultStatus.parse(status).display_text - except: - status_text = status - - with change_axis(axis): - xmin, ymin = loc - main_texts = [] - if main_text is not None: - main_texts.extend(main_text.split("//")) - if colab is not None: - # add collaboration and status text - colab_text = r"\bolditalic{" + colab + "} " + status_text - main_texts.append(colab_text) - for text in main_texts: - _, _, ymin, _ = draw_text(axis, xmin, ymin, text, - fontsize=fontsize, - transform_x=transform_x, - transform_y=transform_y, - horizontalalignment=horizontal_align, - verticalalignment=vertical_align) - ymin -= dy_main - transform_x, transform_y = 'axis', 'axis' - - # draw energy and luminosity labels as well as additional texts - elumi_text = [] - if energy is not None: - elumi_text.append(r"$\sqrt{s} = $" + energy ) - if lumi is not None: - elumi_text.append(lumi) - elumi_text = ", ".join(elumi_text) - - all_texts = [] - if elumi_text: - all_texts.append(elumi_text) - - if extra_text is not None: - all_texts.extend(extra_text.split("//")) + if size_str.endswith('%'): + return float(size_str.strip('%')) / 100 + return float(size_str) + except ValueError as e: + raise ValueError(f"Invalid size format: {size_str}") from e - if text_options is None: - text_options = {} - - for text in all_texts: - _, _, ymin, _ = draw_text(axis, xmin, ymin - dy, text, **text_options) - -def is_edgy_polygon(handle): +def is_edgy_polygon(handle: Polygon) -> bool: """ Check if a legend handle represents a polygon with only edges and no fill. - + Parameters ---------- handle : matplotlib.patches.Polygon - The legend handle to be checked. 
- + The legend handle to be checked + Returns ------- bool - True if the provided legend handle represents an edgy polygon (only edges, no fill). - False if the provided legend handle does not meet the criteria of an edgy polygon. - - Examples - -------- - >>> from matplotlib.patches import Polygon - >>> polygon_handle = Polygon([(0, 0), (1, 1), (2, 0)], edgecolor='black', fill=False) - >>> is_edgy_polygon(polygon_handle) - True + True if the handle is an edgy polygon (only edges, no fill) """ if not isinstance(handle, Polygon): return False - - if np.sum(handle.get_edgecolor()) == 0: + + edgecolor = handle.get_edgecolor() + if np.all(edgecolor == 0): return False + + return not handle.get_fill() - if handle.get_fill(): - return False - return True +def resolve_handle_label( + handle: Any, + raw: bool = False +) -> Tuple[Any, str]: + """ + Resolve the artist handle and label for the legend. + + Parameters + ---------- + handle : Any + The artist handle + raw : bool + If True, return the raw handle and label + + Returns + ------- + Tuple[Any, str] + The resolved handle and its label + + Raises + ------ + RuntimeError + If unable to extract label from handle + """ -def resolve_handle_label(handle, raw: bool = False): if raw: - if hasattr(handle, 'get_label'): - return handle, handle.get_label() - return handle, '_nolegend_' + label = getattr(handle, "get_label", lambda: "_nolegend_")() + return handle, label + if isinstance(handle, Container): label = handle.get_label() - if label.startswith('_'): + if not label or label.startswith("_"): return resolve_handle_label(handle[0]) - elif isinstance(handle, list): + elif isinstance(handle, (list, tuple)): return resolve_handle_label(handle[0]) - elif isinstance(handle, tuple): - _, label = resolve_handle_label(handle[0]) - handle = tuple([h[0] if (isinstance(h, list) and len(h) == 1) \ - else h for h in handle]) - elif hasattr(handle, 'get_label'): + elif hasattr(handle, "get_label"): label = handle.get_label() else: - raise RuntimeError('unable to extract label from the handle') + raise RuntimeError("Unable to extract label from the handle") + return handle, label + + +def remake_handles( + handles: List[Any], + polygon_to_line: bool = True, + fill_border: bool = True, + line2d_styles: Optional[Dict[str, Any]] = None, + border_styles: Optional[Dict[str, Any]] = None, +) -> List[Any]: + """ + Remake legend handles for better representation. 
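+
+    Typically applied to the output of ``ax.get_legend_handles_labels()`` so
+    that edge-only polygons show up as lines and filled areas gain a visible
+    border in the legend. A minimal sketch (assumes an existing ``Axes``
+    named ``ax``):
+
+    >>> handles, labels = ax.get_legend_handles_labels()
+    >>> ax.legend(remake_handles(handles), labels)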
-def remake_handles(handles:List, polygon_to_line:bool=True, fill_border:bool=True, - line2d_styles:Optional[Dict]=None, border_styles:Optional[Dict]=None): + Parameters + ---------- + handles : List[Any] + List of artist handles + polygon_to_line : bool + Convert polygon edges to lines in the legend + fill_border : bool + Add a border to filled patches in the legend + line2d_styles : Optional[Dict[str, Any]] + Styles for Line2D objects + border_styles : Optional[Dict[str, Any]] + Styles for border rectangles + + Returns + ------- + List[Any] + List of remade artist handles + """ + new_handles = [] for handle in handles: + subhandles = handle if isinstance(handle, (list, tuple)) else [handle] new_subhandles = [] - if isinstance(handle, (list, tuple)): - subhandles = handle - else: - subhandles = [handle] + for subhandle in subhandles: - if ((polygon_to_line) and is_edgy_polygon(subhandle)): - line2d_styles = combine_dict(line2d_styles) - subhandle = Line2D([], [], color=subhandle.get_edgecolor(), - linestyle=subhandle.get_linestyle(), - label=subhandle.get_label(), - **line2d_styles) + if polygon_to_line and is_edgy_polygon(subhandle): + line_styles = line2d_styles or {} + subhandle = Line2D( + [], + [], + color=subhandle.get_edgecolor(), + linestyle=subhandle.get_linestyle(), + label=subhandle.get_label(), + **line_styles + ) new_subhandles.append(subhandle) + if fill_border and isinstance(subhandle, PolyCollection): - border_styles = combine_dict(border_styles) - border_handle = Rectangle((0, 0), 1, 1, facecolor='none', - **border_styles) + border_style = border_styles or {} + border_handle = Rectangle( + (0, 0), + 1, + 1, + facecolor="none", + **border_style + ) new_subhandles.append(border_handle) - if len(new_subhandles) == 1: - new_subhandles = new_subhandles[0] + + if isinstance(handle, Container): + kwargs = {"label": handle.get_label()} + if isinstance(handle, ErrorbarContainer): + kwargs.update({ + "has_xerr": handle.has_xerr, + "has_yerr": handle.has_yerr + }) + new_handle = type(handle)(tuple(new_subhandles), **kwargs) else: - new_subhandles = tuple(new_subhandles) - new_handles.append(new_subhandles) + new_handle = ( + new_subhandles[0] + if len(new_subhandles) == 1 + else tuple(new_subhandles) + ) + new_handles.append(new_handle) + return new_handles -def isolate_contour_styles(styles: Dict[str, Any]) -> List[Dict[str, Any]]: +def isolate_contour_styles( + styles: Dict[str, Any] +) -> Iterator[Dict[str, Any]]: """ - Converts keyword arguments for contour or contourf to a list of - styles for each contour level, ensuring that styles are consistently applied - across different levels. - + Convert contour or contourf keyword arguments to a list of styles for each level. + Parameters ---------- - styles : dict - Dictionary of keyword arguments passed to contour or contourf. Keys can include - 'linestyles', 'linewidths', 'colors', and 'alpha', each of which can be a single value - or a sequence of values. - + styles : Dict[str, Any] + Dictionary of keyword arguments passed to contour or contourf + Returns ------- - list of dict - A list of dictionaries, each corresponding to the styles for one contour level. If - all styles are single values, the same dictionary is repeated for each contour level. - If sequences are passed for any styles, the function ensures that the lengths of - those sequences are consistent. 
- + Iterator[Dict[str, Any]] + Iterator of style dictionaries, one per contour level + Raises ------ ValueError - If the lengths of the sequences for different styles are inconsistent. + If style sequences have inconsistent lengths """ + # Map input style names to matplotlib properties style_key_map = { - 'linestyles': 'linestyle', - 'linewidths': 'linewidth', - 'colors': 'color', - 'alpha': 'alpha' + "linestyles": "linestyle", + "linewidths": "linewidth", + "colors": "color", + "alpha": "alpha", } - - # Extract relevant styles from the input dictionary + + # Extract relevant styles relevant_styles = { - new_name: styles[old_name] for old_name, new_name in style_key_map.items() if old_name in styles + new_name: styles[old_name] + for old_name, new_name in style_key_map.items() + if old_name in styles } - - # Determine the size (length) of each style property + + # Determine sizes sizes = [] - for style in relevant_styles.values(): - if isinstance(style, Sequence) and not isinstance(style, str): - sizes.append(len(style)) + for style_value in relevant_styles.values(): + if isinstance(style_value, Sequence) and not isinstance(style_value, str): + sizes.append(len(style_value)) else: sizes.append(1) - - # Handle case where there are no styles provided + if not sizes: return repeat({}) - - # Check for consistent sizes (if multiple sequences are provided) + + # Check for consistent sizes unique_sizes = np.unique([size for size in sizes if size != 1]) if len(unique_sizes) > 1: - raise ValueError("Contour styles have inconsistent sizes.") + raise ValueError("Contour styles have inconsistent sizes") - # Get the maximum size (this will determine the number of contour levels) - max_size = np.max(sizes) - - # If all styles are scalar, repeat the same dictionary + # Get maximum size + max_size = max(sizes) + + # Handle scalar case if max_size == 1: return repeat(relevant_styles) - - # Create a list of dictionaries, each corresponding to a contour level + + # Create style dictionaries for each level list_styles = [] for i in range(max_size): level_styles = { - key: value if sizes[j] == 1 else value[i] - for j, (key, value) in enumerate(relevant_styles.items()) + key: value if sizes[idx] == 1 else value[i] + for idx, (key, value) in enumerate(relevant_styles.items()) } list_styles.append(level_styles) - + return list_styles -def get_axis_limits(ax, xmin: Optional[float] = None, xmax: Optional[float] = None, - ymin: Optional[float] = None, ymax: Optional[float] = None, - ypadlo: Optional[float] = None, ypadhi: Optional[float] = None, - ypad: Optional[float] = None): + +def get_axis_limits( + ax: Axes, + xmin: Optional[float] = None, + xmax: Optional[float] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, + ypadlo: Optional[float] = None, + ypadhi: Optional[float] = None, + ypad: Optional[float] = None, +) -> Tuple[List[float], List[float]]: + """ + Calculate new axis limits with optional padding. 
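+
+    Padding fractions are defined with respect to the padded range, so
+    ``ypad=0.2`` reserves 20% of the final y-range above the data. A minimal
+    sketch (hypothetical values; assumes an existing ``Axes`` named ``ax``):
+
+    >>> xlim, ylim = get_axis_limits(ax, xmin=0.0, ypad=0.2)
+    >>> ax.set_xlim(*xlim)
+    >>> ax.set_ylim(*ylim)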
+ Parameters + ---------- + ax : matplotlib.axes.Axes + The axes to calculate limits for + xmin, xmax : Optional[float] + X-axis limits + ymin, ymax : Optional[float] + Y-axis limits + ypadlo : Optional[float] + Lower y-padding fraction + ypadhi : Optional[float] + Upper y-padding fraction + ypad : Optional[float] + Symmetric y-padding fraction + + Returns + ------- + Tuple[List[float], List[float]] + New x and y limits + + Raises + ------ + ValueError + If invalid padding values are provided + """ xlim = list(ax.get_xlim()) ylim = list(ax.get_ylim()) - - # Update x-limits if provided + + # Update x-limits if xmin is not None: xlim[0] = xmin if xmax is not None: xlim[1] = xmax - + # Check conflicting padding values - if (ypad is not None) and (ypadhi is not None): - raise ValueError(f'Cannot set both `ypad` and `ypadhi`.') - - # If ypad is set, use it for upper padding + if ypad is not None and ypadhi is not None: + raise ValueError("Cannot set both 'ypad' and 'ypadhi'") + + # Use ypad for upper padding if set if ypad is not None: ypadhi = ypad - - # Determine the lower and upper paddings - if (ypadhi is not None) or (ypadlo is not None): - - ypads = (ypadlo or 0, ypadhi or 0) - if not (0 <= ypads[0] <= 1): - raise ValueError('`ypadlo` must be between 0 and 1.') - if not (0 <= ypads[1] <= 1): - raise ValueError('`ypadhi` must be between 0 and 1.') - - # Logarithmic scale adjustment + # Calculate padding + if ypadhi is not None or ypadlo is not None: + ypad_lo = ypadlo or 0 + ypad_hi = ypadhi or 0 + + if not (0 <= ypad_lo <= 1): + raise ValueError("'ypadlo' must be between 0 and 1") + if not (0 <= ypad_hi <= 1): + raise ValueError("'ypadhi' must be between 0 and 1") + + # Handle logarithmic scale if ax.get_yaxis().get_scale() == "log": if ylim[0] <= 0: - raise ValueError("ymin must be positive in a logscale plot") - new_ymin = ylim[1] / (ylim[1] / ylim[0]) ** (1 + ypads[0]) - new_ymax = ylim[0] * (ylim[1] / ylim[0]) ** (1 + ypads[1]) + raise ValueError("Y minimum must be positive in log scale") + new_ymin = ylim[1] / (ylim[1] / ylim[0]) ** (1 + ypad_lo) + new_ymax = ylim[0] * (ylim[1] / ylim[0]) ** (1 + ypad_hi) else: - # Linear scale adjustment - range_y = ylim[1] - ylim[0] - new_ymin = ylim[0] - range_y * ypads[0] / (1 - ypads[0] - ypads[1]) - new_ymax = ylim[1] + range_y * ypads[1] / (1 - ypads[0] - ypads[1]) - - # Apply padding only if the corresponding value is set - if ypads[0]: + # Linear scale + y_range = ylim[1] - ylim[0] + new_ymin = ylim[0] - y_range * ypad_lo / (1 - ypad_lo - ypad_hi) + new_ymax = ylim[1] + y_range * ypad_hi / (1 - ypad_lo - ypad_hi) + + # Apply padding + if ypad_lo: ylim[0] = new_ymin - if ypads[1]: + if ypad_hi: ylim[1] = new_ymax - - # Override y-limits if explicitly provided + + # Override with explicit limits if ymin is not None: ylim[0] = ymin if ymax is not None: ylim[1] = ymax - - # Return the calculated xlim and ylim values - return xlim, ylim + + return xlim, ylim \ No newline at end of file diff --git a/quickstats/plots/template_analysis_label_options.py b/quickstats/plots/template_analysis_label_options.py new file mode 100644 index 0000000000000000000000000000000000000000..dee24802bf875a3a15b6c7decb5443d09dac12c9 --- /dev/null +++ b/quickstats/plots/template_analysis_label_options.py @@ -0,0 +1,37 @@ +from typing import Optional, Union, Dict + +from .registry import Registry + +REGISTRY = Registry() + +REGISTRY['default'] = { + 'status': 'int', + 'loc': (0.05, 0.95), + 'fontsize': 25 +} + +REGISTRY.use('default') + +REGISTRY['atlas'] = REGISTRY.data 
& { + 'colab': 'ATLAS' +} + +REGISTRY['atlas_run2'] = REGISTRY.data & { + 'colab': 'ATLAS', + 'energy' : '13 TeV', + 'lumi' : "140 fb$^{-1}$", +} + +REGISTRY['atlas_run3'] = REGISTRY.data & { + 'colab': 'ATLAS', + 'energy' : '13.6 TeV' +} + +REGISTRY['cms'] = REGISTRY.data & { + 'colab': 'CMS' +} + +get = REGISTRY.get +use = REGISTRY.use +parse = REGISTRY.parse +chain = REGISTRY.chain \ No newline at end of file diff --git a/quickstats/plots/template_styles.py b/quickstats/plots/template_styles.py new file mode 100644 index 0000000000000000000000000000000000000000..01c7ba036fb0f2c51e574108cc8eafd6da7d7b6c --- /dev/null +++ b/quickstats/plots/template_styles.py @@ -0,0 +1,144 @@ +from typing import Optional, Union, Dict + +from .registry import Registry + +REGISTRY = Registry() + +REGISTRY['default'] = { + 'figure':{ + 'figsize': (11.111, 8.333), + 'dpi': 72, + 'facecolor': "#FFFFFF" + }, + 'legend_Line2D': { + 'linewidth': 3 + }, + 'legend_border': { + 'edgecolor' : 'black', + 'linewidth' : 1 + }, + 'annotation':{ + 'fontsize': 12 + }, + 'axis': { + 'major_length': 16, + 'minor_length': 8, + 'major_width': 2, + 'minor_width': 1, + 'spine_width': 2, + 'labelsize': 20, + 'offsetlabelsize': 20, + 'tick_bothsides': True, + 'x_axis_styles': {}, + 'y_axis_styles': {} + }, + 'xtick':{ + 'format': 'numeric', + 'locator': 'auto', + 'steps': None, + 'prune': None, + 'integer': False + }, + 'ytick':{ + 'format': 'numeric', + 'locator': 'auto', + 'steps': None, + 'prune': None, + 'integer': False + }, + 'xlabel': { + 'fontsize': 22, + 'loc' : 'right', + 'labelpad': 10 + }, + 'ylabel': { + 'fontsize': 22, + 'loc' : 'top', + 'labelpad': 10 + }, + 'title':{ + 'fontsize': 20, + 'loc': 'center', + 'pad': 10 + }, + 'text':{ + 'fontsize': 20, + 'verticalalignment': 'top', + 'horizontalalignment': 'left' + }, + 'plot':{ + 'linewidth': 2 + }, + 'point': { + 'markersize': 10, + 'linewidth': 0 + }, + 'hist': { + 'linewidth': 2 + }, + 'errorbar': { + "marker": 'x', + "linewidth": 0, + "markersize": 0, + "elinewidth": 1, + "capsize": 2, + "capthick": 1 + }, + 'fill_between': { + "alpha": 0.5 + }, + 'legend':{ + "fontsize": 20, + "columnspacing": 0.8 + }, + 'ratio_frame':{ + 'height_ratios': (3, 1), + 'hspace': 0.07 + }, + 'barh': { + 'height': 0.5 + }, + 'bar': { + }, + 'colorbar': { + 'fraction': 0.15, + 'shrink': 1. 
+ }, + 'contour':{ + 'linestyles': 'solid', + 'linewidths': 3 + }, + 'contourf':{ + 'alpha': 0.5, + 'zorder': 0 + }, + 'colorbar_axis': { + 'labelsize': 20, + 'y_axis_styles': { + 'labelleft': False, + 'labelright': True, + 'left': False, + 'right': True, + 'direction': 'out' + } + }, + 'colorbar_label': { + 'fontsize': 22, + 'labelpad': 0 + }, + 'clabel': { + 'inline': True, + 'fontsize': 10 + }, + 'line': { + }, + 'line_collection': { + } +} + +REGISTRY.use('default') + +get = REGISTRY.get +use = REGISTRY.use +parse = REGISTRY.parse +chain = REGISTRY.chain \ No newline at end of file diff --git a/quickstats/plots/test_statistic_distribution_plot.py b/quickstats/plots/test_statistic_distribution_plot.py index 958ba30d953c3c260da107c27e011aeed3e20460..fa1a0dd0476c24285c9d788e4df91dbf71a57e73 100644 --- a/quickstats/plots/test_statistic_distribution_plot.py +++ b/quickstats/plots/test_statistic_distribution_plot.py @@ -4,7 +4,7 @@ import math import numpy as np from quickstats.plots import AbstractPlot -from quickstats.plots.template import single_frame, parse_styles +from quickstats.plots.template import single_frame class TestStatisticDistributionPlot(AbstractPlot): diff --git a/quickstats/plots/two_panel_1D_plot.py b/quickstats/plots/two_panel_1D_plot.py index 6c0fa50a4e73b5cb68d1469362ac7544cf85c912..9a1c8037e519dce8dff64847ea9720bc7274c59b 100644 --- a/quickstats/plots/two_panel_1D_plot.py +++ b/quickstats/plots/two_panel_1D_plot.py @@ -37,10 +37,10 @@ class TwoPanel1DPlot(AbstractPlot): config:Optional[Dict]=None): self.data_map = data_map - self.label_map = label_map self.styles_map = styles_map super().__init__(color_cycle=color_cycle, + label_map=label_map, styles=styles, analysis_label_options=analysis_label_options, config=config) diff --git a/quickstats/plots/upper_limit_benchmark_plot.py b/quickstats/plots/upper_limit_benchmark_plot.py index 7766740e05922beb253c2bfcc9796b7254a0eac9..385635b09e5ecc8371a2599160806a0c3d239441 100644 --- a/quickstats/plots/upper_limit_benchmark_plot.py +++ b/quickstats/plots/upper_limit_benchmark_plot.py @@ -10,7 +10,7 @@ import pandas as pd from quickstats.plots.template import create_transform, remake_handles from quickstats.plots import AbstractPlot from quickstats.utils.common_utils import combine_dict -from quickstats.maths.statistics import HistComparisonMode +from quickstats.maths.histograms import HistComparisonMode class UpperLimitBenchmarkPlot(AbstractPlot): @@ -31,36 +31,30 @@ class UpperLimitBenchmarkPlot(AbstractPlot): } } - COLOR_PALLETE = { - '2sigma' : 'hh:darkyellow', - '1sigma' : 'hh:lightturquoise', - 'expected' : 'k', - 'observed' : 'k', - 'theory' : 'darkred', + COLOR_MAP = { + '2sigma' : 'hh:darkyellow', + '1sigma' : 'hh:lightturquoise', + 'expected' : 'k', + 'observed' : 'k', + 'alt.2sigma' : '#ffcc00', + 'alt.1sigma' : '#00cc00', + 'alt.expected': 'r', + 'alt.observed': 'r', + 'theory' : 'darkred', 'theory_unc' : 'hh:darkpink' } - COLOR_PALLETE_EXTRA = { - '2sigma': '#ffcc00', - '1sigma': '#00cc00', - 'expected': 'r', - 'observed': 'r', - } - LABELS = { '2sigma': r'Expected limit $\pm 2\sigma$', '1sigma': r'Expected limit $\pm 1\sigma$', 'expected': r'Expected limit (95% CL)', 'observed': r'Observed limit (95% CL)', + 'alt.2sigma': r'Alt. Expected limit $\pm 2\sigma$', + 'alt.1sigma': r'Alt. Expected limit $\pm 1\sigma$', + 'alt.expected': r'Alt. Expected limit (95% CL)', + 'alt.observed': r'Alt. Observed limit (95% CL)', 'theory' : r'Theory prediction' } - - LABELS_EXTRA = { - '2sigma': r'Alt. 
Expected limit $\pm 2\sigma$', - '1sigma': r'Alt. Expected limit $\pm 1\sigma$', - 'expected': r'Alt. Expected limit (95% CL)', - 'observed': r'Alt. Observed limit (95% CL)', - } CONFIG = { 'xmargin': 0.3, @@ -132,10 +126,10 @@ class UpperLimitBenchmarkPlot(AbstractPlot): def theory_func(self): return self._theory_func - def __init__(self, data:pd.DataFrame, + def __init__(self, data: Union[pd.DataFrame, Dict[str, pd.DataFrame]], theory_func:Callable=None, - color_pallete:Optional[Dict]=None, - labels:Optional[Dict]=None, + color_map:Optional[Dict]=None, + label_map:Optional[Dict]=None, styles:Optional[Union[Dict, str]]=None, config:Optional[Dict]=None, custom_styles:Optional[Union[Dict, str]]=None, @@ -146,29 +140,23 @@ class UpperLimitBenchmarkPlot(AbstractPlot): dataframe with columns ("-2", "-1", "0", "1", "2", "obs") representing the corresponding limit level and rows indexed by the benchmark names """ - super().__init__(color_pallete=color_pallete, + super().__init__(color_map=color_map, + label_map=label_map, styles=styles, config=config, analysis_label_options=analysis_label_options) self.data = data.copy() self.set_theory_function(theory_func) - self.labels = combine_dict(self.LABELS, labels) self.custom_styles = combine_dict(self.CUSTOM_STYLES, custom_styles) self.alt_data = {} - self.alt_color_pallete = {} self.alt_custom_styles = {} - self.alt_labels = {} self.hline_data = [] def add_alternative_data(self, data:Optional[pd.DataFrame], key:str="alt", - color_pallete:Optional[Dict]=None, - labels:Optional[Dict]=None, custom_styles:Optional[Dict]=None): if key is None: raise RuntimeError('key can not be None') self.alt_data[key] = data - self.alt_color_pallete[key] = combine_dict(self.COLOR_PALLETE_EXTRA, color_pallete) - self.alt_labels[key] = combine_dict(self.LABELS_EXTRA, labels) self.alt_custom_styles[key] = combine_dict(self.CUSTOM_STYLES_EXTRA, custom_styles) def get_default_legend_order(self): diff --git a/quickstats/plots/variable_distribution_plot.py b/quickstats/plots/variable_distribution_plot.py index 79540abd6702c7f8abab196749e36a95934b76ae..bedfa0eb01b95d007b1767e1921fa590e887d999 100644 --- a/quickstats/plots/variable_distribution_plot.py +++ b/quickstats/plots/variable_distribution_plot.py @@ -1,714 +1,1073 @@ -from typing import Optional, Union, Dict, List, Sequence, Tuple, Callable +from typing import Optional, Union, Dict, List, Sequence, Tuple, Any, Callable +from collections import defaultdict +from copy import deepcopy import pandas as pd import numpy as np -from matplotlib.ticker import MaxNLocator -from matplotlib.lines import Line2D -from matplotlib.patches import Polygon - -from quickstats.plots import AbstractPlot, get_color_cycle -from quickstats.plots.template import centralize_axis, remake_handles -from quickstats.utils.common_utils import combine_dict, remove_duplicates -from quickstats.maths.numerics import safe_div -from quickstats.maths.statistics import (HistComparisonMode, - min_max_to_range, get_hist_data, - get_stacked_hist_data, - get_hist_comparison_data, - get_clipped_data) +from matplotlib.axes import Axes +from quickstats.core import mappings as mp +from quickstats.utils.common_utils import remove_duplicates +from quickstats.maths.histograms import HistComparisonMode +from quickstats.concepts import Histogram1D, StackedHistogram from .core import PlotFormat, ErrorDisplayFormat +from .colors import ColorType, ColormapType +from .template import get_artist_colors +from .histogram_plot import HistogramPlot + +def _merge_styles( + 
styles_map: Dict[str, Dict[str, Any]], + primary_key: Optional[str] = None, + use_sequence_options: bool = True +) -> Dict[str, Any]: + """ + Merge style dictionaries from multiple targets into a single style dictionary. + + Parameters + ---------- + styles_map : Dict[str, Dict[str, Any]] + A mapping from target names to their style dictionaries + primary_key : Optional[str], optional + The key of the primary target whose styles should be prioritized + use_sequence_options : bool, optional + Whether to collect sequence options (like 'color', 'label') into lists + + Returns + ------- + Dict[str, Any] + A merged style dictionary -class VariableDistributionPlot(AbstractPlot): + Raises + ------ + ValueError + If inconsistent style values or missing required options are found + """ + sequence_options = ["color", "label"] - COLOR_CYCLE = "simple_contrast" + if primary_key is not None: + merged_styles = deepcopy(styles_map[primary_key]) + else: + merged_styles = {} + for target, styles in styles_map.items(): + styles = deepcopy(styles) + for key, value in styles.items(): + if key in sequence_options: + continue + if key in merged_styles and merged_styles[key] != value: + targets = list(styles_map) + raise ValueError( + f"Inconsistent values for option '{key}' among targets: " + f"{', '.join(targets)}" + ) + merged_styles[key] = value + + if use_sequence_options: + for option in sequence_options: + merged_styles[option] = [styles.get(option) for styles in styles_map.values()] + if None in merged_styles[option]: + missing_targets = [ + target for target, styles in styles_map.items() + if styles.get(option) is None + ] + raise ValueError( + f"Missing '{option}' for targets: {', '.join(missing_targets)}" + ) + return merged_styles + +class VariableDistributionPlot(HistogramPlot): + """ + Class for plotting variable distributions with advanced features. 
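+
+    Examples
+    --------
+    A minimal usage sketch (the sample names and data below are hypothetical,
+    for illustration only)::
+
+        import numpy as np
+        import pandas as pd
+
+        rng = np.random.default_rng(42)
+        data_map = {
+            'signal': pd.DataFrame({'mass': rng.normal(125.0, 2.0, 10000)}),
+            'background': pd.DataFrame({'mass': rng.uniform(100.0, 150.0, 10000)})
+        }
+        plotter = VariableDistributionPlot(data_map)
+        # one histogram per sample, 50 bins over the given range
+        ax = plotter.draw('mass', bins=50, bin_range=(100.0, 150.0))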
+ """ + + COLOR_CYCLE = "atlas_hdbs" + STYLES = { - "legend": { - "handletextpad": 0.3 - }, "hist": { - 'histtype' : 'step', + 'histtype': 'step', 'linestyle': '-', 'linewidth': 2 }, - "bar": { - 'linewidth' : 0, - 'alpha' : 0.5 - }, - "fill_between": { - 'alpha' : 0.5 - }, 'errorbar': { - "marker": 'o', - "markersize": None, + 'marker': 'o', + 'markersize': 10, 'linestyle': 'none', - "linewidth": 0, - "elinewidth": 2, - "capsize": 0, - "capthick": 0 - }, - } - - CONFIG = { - 'ratio_line_styles':{ + 'linewidth': 0, + 'elinewidth': 2, + 'capsize': 0, + 'capthick': 0 + }, + 'fill_between': { + 'alpha': 0.5, + 'color': 'gray' + }, + "bar": { + 'linewidth': 0, + 'alpha': 0.5, + 'color': 'gray' + }, + 'ratio_line': { 'color': 'gray', 'linestyle': '--', 'zorder': 0 - }, + } + } + + CONFIG = { 'plot_format': 'hist', 'error_format': 'shade', - 'error_label_format': r'{label}', + 'comparison_mode': 'ratio', 'show_xerr': False, - 'stacked_label': ':stacked_{index}:', - 'box_legend_handle': False + 'error_on_top': True, + 'inherit_color': True, + 'combine_stacked_error': False, + 'box_legend_handle': False, + 'isolate_error_legend': False, + 'stacked_object_id': 'stacked_{index}', + 'comparison_object_id': 'comparison_{reference}_{target}' } - - def __init__(self, data_map:Union["pandas.DataFrame", Dict[str, "pandas.DataFrame"]], - plot_options:Optional[Dict]=None, - label_map:Optional[Dict]=None, - color_cycle:Optional[Union[List, str, "ListedColorMap"]]='simple_contrast', - styles:Optional[Union[Dict, str]]=None, - comparison_styles:Optional[Union[Dict, str]]=None, - analysis_label_options:Optional[Dict]=None, - config:Optional[Dict]=None, - verbosity:Optional[Union[int, str]]='INFO'): + + def __init__( + self, + data_map: Union[pd.DataFrame, Dict[str, pd.DataFrame]], + plot_options: Optional[Dict[str, Dict[str, Any]]] = None, + label_map: Optional[Dict[str, str]] = None, + color_cycle: Optional[ColormapType] = None, + styles: Optional[Union[Dict[str, Any], str]] = None, + analysis_label_options: Optional[Dict[str, Any]] = None, + config: Optional[Dict[str, Any]] = None, + verbosity: Union[int, str] = 'INFO' + ) -> None: """ - Parameters - ---------------------------------------------------------------------------- - data_map: pandas.DataFrame or dictionary of pandas.DataFrame - Input dataframe(s). If a dictionary is given, it should be of the form - {<sample_name>: <pandas.DataFrame>} - plot_options: dictionary - A dictionary containing the plot options for various group of samples. - It should be of the form - { <sample_group>: - { - "samples": <list of sample names>, - "weight_scale": <scale factor>, - "styles" : <options in mpl.hist or mpl.errorbar>, - "error_styles": <options in mpl.bar>, - "plot_format": "hist" or "errorbar", - "show_error": True or False, - "stack_index": <stack index>, - "hide": <list of callables / 2-tuples> - } - } - - "styles" should match the options available in mpl.hist if - `plot_format` = "hist" or mpl.errorbar if `plot_format` = "errorbar" - - "error_styles" should match the options available in mpl.errorbar if - `error_format` = "errorbar", mpl.bar if `error_format` = "shade" or - mpl.fill_between if `error_format` = "fill" + Initialize the VariableDistributionPlot. - (optional) "weight_scale" is used to scale the weights of the given - group of samples by the given factor - - (optional) "show_error" is used to specify whether to show the errorbar/ - errorbands for this particular target. 
- - (optional) "stack_index" is used when multiple stacked plots are made; - sample groups with the same stack index will be stacked; this option - is only used when `plot_format` = "hist" and the draw method is called - with the `stack` option set to True; by default a stack index of 0 will - be assigned - - (optional) "hide" defines the condition to hide portion of the data in - the plot; in case of a 2-tuple, it specifies the (start, end) range of - data that should be hidden; in case of a callable, it is a function - that takes as input the value of the variable, and outputs a boolean - value indicating whether the data should be hidden; - - Note: If "samples" is not given, it will default to [<sample_group>] - - Note: If both `plot_format` and `error_format` are errorbar, "styles" - will be used instead of "error_styles" for the error styles. - - + Parameters + ---------- + data_map : Union[pd.DataFrame, Dict[str, pd.DataFrame]] + Input dataframe(s). If dictionary, maps sample names to dataframes + plot_options : Optional[Dict[str, Dict[str, Any]]], optional + A dictionary containing the plot options for various group of samples. + It should be of the form + { <sample_group>: + { + "samples": <list of sample names>, + "weight_scale": <scale factor>, + "styles" : <matplotlib artist options>, + "error_styles": <matplotlib artist options>, + "plot_format": "hist" or "errorbar", + "error_format": "errorbar", "fill" or "shade" + "show_error": True or False, + "primary": True or False, + "stack_index": <stack index>, + "mask_condition": <callable or tuple of 2 floats> + } + } + + "styles" should match the options available in mpl.hist if + `plot_format` = "hist" or mpl.errorbar if `plot_format` = "errorbar" + + "error_styles" should match the options available in mpl.errorbar if + `error_format` = "errorbar", mpl.bar if `error_format` = "shade" or + mpl.fill_between if `error_format` = "fill" + + (optional) "weight_scale" is used to scale the weights of the given + group of samples by the given factor + + (optional) "plot_format" is used to indicate which matplotlib artist + is used to draw the variable distribution; by default the internal + value from config['plot_format'] is used; allowed formats are + "hist" or "errorbar" + + (optional) "error_format" is used to indicate which matplotlib artist + is used to draw the error information; by default the internal + value from config['error_format'] is used; allowed formats are + "errorbar", "fill" or "shade" + + (optional) "show_error" is used to specify whether to show the errorbar/ + errorbands for this particular target + + (optional) "stack_index" is used when multiple stacked plots are made; + sample groups with the same stack index will be stacked; this option + is only used when `plot_format` = "hist" and the draw method is called + with the `stack` option set to True; by default a stack index of 0 will + be assigned + + (optional) "mask_condition" defines the condition to mask portion(s) + of the data in the plot; in case of a 2-tuple, it specifies the + (start, end) bin range of data that should be hidden; in case of a + callable, it is a function that takes as input the bin_centers (x) + and bin_content (y) of the histogram, and outputs a boolean + array indicating the locations of the histogram that should be hidden + + Note: If "samples" is not given, it will default to [<sample_group>] + + Note: If both `plot_format` and `error_format` are errorbar, "styles" + will be used instead of "error_styles" for the error styles + 
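+            A hypothetical configuration (sample names, scales and styles
+            are illustrative only)::
+
+                plot_options = {
+                    'Background': {
+                        'samples': ['ttbar', 'diboson'],
+                        'show_error': True,
+                        'stack_index': 0,
+                        'styles': {'color': 'gray', 'label': 'Total background'}
+                    },
+                    'Signal': {
+                        'samples': ['ggF', 'VBF'],
+                        'weight_scale': 100.0,
+                        'plot_format': 'errorbar',
+                        'styles': {'color': 'red', 'label': r'Signal $\times$ 100'}
+                    }
+                }
+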
label_map : Optional[Dict[str, str]], optional + Mapping from target names to display labels + color_cycle : Optional[ColormapType], optional + Color cycle for plotting + styles : Optional[Union[Dict[str, Any], str]], optional + Global styles for plot artists + analysis_label_options : Optional[Dict[str, Any]], optional + Options for analysis labels + config : Optional[Dict[str, Any]], optional + Configuration parameters + verbosity : Union[int, str], optional + Logging verbosity level, by default 'INFO' """ self.load_data(data_map) self.plot_options = plot_options - self.label_map = label_map - self.reset_hist_data() - super().__init__(color_cycle=color_cycle, - styles=styles, - analysis_label_options=analysis_label_options, - config=config, - verbosity=verbosity) - - def load_data(self, data_map:Dict[str, pd.DataFrame]): + self.reset_metadata() + super().__init__( + color_cycle=color_cycle, + styles=styles, + label_map=label_map, + analysis_label_options=analysis_label_options, + config=config, + verbosity=verbosity + ) + + def load_data(self, data_map: Union[pd.DataFrame, Dict[str, pd.DataFrame]]) -> None: + """ + Load data into the plot. + + Parameters + ---------- + data_map : Union[pd.DataFrame, Dict[str, pd.DataFrame]] + Input dataframe(s). If dictionary, maps sample names to dataframes + """ if not isinstance(data_map, dict): data_map = {None: data_map} self.data_map = data_map - - def reset_hist_data(self): - self.hist_data = {} - self.hist_bin_edges = {} - self.hist_comparison_data = [] - - def set_plot_format(self, plot_format:str): + + def reset_metadata(self) -> None: + """Reset metadata, including histograms.""" + super().reset_metadata() + self.histograms = {} + + def set_plot_format(self, plot_format: str) -> None: + """ + Set the plot format. + + Parameters + ---------- + plot_format : str + The plot format to use + """ self.config['plot_format'] = PlotFormat.parse(plot_format) - - def set_error_format(self, error_format:str): + + def set_error_format(self, error_format: str) -> None: + """ + Set the error format. + + Parameters + ---------- + error_format : str + The error format to use + """ self.config['error_format'] = ErrorDisplayFormat.parse(error_format) - - def is_single_data(self): + + def is_single_data(self) -> bool: + """ + Check if there is only a single data set. + + Returns + ------- + bool + True if only a single data set is present + """ return (None in self.data_map) and (len(self.data_map) == 1) - def resolve_targets(self, targets:Optional[List[str]]=None, - plot_options:Optional[Dict]=None): + def resolve_targets( + self, + targets: Optional[List[str]] = None, + plot_options: Optional[Dict[str, Dict[str, Any]]] = None + ) -> List[Optional[str]]: + """ + Resolve the targets to be plotted. 
+ + Parameters + ---------- + targets : Optional[List[str]], optional + List of target names + plot_options : Optional[Dict[str, Dict[str, Any]]], optional + Plot options dictionary + + Returns + ------- + List[Optional[str]] + List of resolved target names + + Raises + ------ + ValueError + If targets are specified when only a single data set is present + """ if self.is_single_data(): if targets is not None: - raise ValueError('no targets should be specified if only one set of input data is given') - targets = [None] - elif targets is None: + raise ValueError( + 'No targets should be specified if only one set of input data is given' + ) + return [None] + + if targets is None: all_samples = list(self.data_map.keys()) + if plot_options is None: + return all_samples + targets = [] - if plot_options is not None: - grouped_samples = [] - for key in plot_options: - samples = plot_options[key].get("samples", [key]) - grouped_samples.extend([sample for sample in samples \ - if sample not in grouped_samples]) - targets.append(key) - targets.extend([sample for sample in all_samples \ - if sample not in grouped_samples]) - else: - targets = all_samples + grouped_samples = set() + for key, options in plot_options.items(): + targets.append(key) + samples = options.get("samples", [key]) + grouped_samples |= set(samples) + + targets.extend([sample for sample in all_samples if sample not in grouped_samples]) + return targets - - def resolve_plot_options(self, plot_options:Optional[Dict]=None, - targets:Optional[List[str]]=None): - plot_options = plot_options or self.plot_options or {} - targets = self.resolve_targets(targets, plot_options=plot_options) - resolved_plot_options = {} + + def resolve_plot_options( + self, + plot_options: Optional[Dict[str, Dict[str, Any]]] = None, + targets: Optional[List[str]] = None, + stacked: bool = False + ) -> Dict[str, Dict[str, Any]]: + """ + Resolve plot options for the given targets. 
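+
+        Each resolved target maps to a dictionary of the form::
+
+            {
+                'components': {<sample group>: <group options>, ...},
+                'plot_format': <PlotFormat>,
+                'error_format': <ErrorDisplayFormat>,
+                'styles': {...},
+                'error_styles': {...}
+            }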
+ + Parameters + ---------- + plot_options : Optional[Dict[str, Dict[str, Any]]], optional + Plot options dictionary + targets : Optional[List[str]], optional + List of target names + stacked : bool, optional + Whether to stack the plots + + Returns + ------- + Dict[str, Dict[str, Any]] + Resolved plot options + + Raises + ------ + RuntimeError + If no targets to draw + ValueError + If no samples specified for a target or duplicate samples found + """ + targets = self.resolve_targets(targets, plot_options) + if not targets: + raise RuntimeError('No targets to draw') + + plot_options = plot_options or {} colors = self.get_colors() n_colors, color_i = len(colors), 0 - if self.label_map is not None: - label_map = self.label_map - else: - label_map = {} + + resolved_plot_options = {} for target in targets: - options = combine_dict(plot_options.get(target, {})) - # use global plot format if not specified - if 'plot_format' not in options: - options['plot_format'] = PlotFormat.parse(self.config['plot_format']) - else: - options['plot_format'] = PlotFormat.parse(options['plot_format']) - # use global error format if not specified - if 'error_format' not in options: - if options['plot_format'] == PlotFormat.ERRORBAR: - options['error_format'] = ErrorDisplayFormat.ERRORBAR - else: - options['error_format'] = ErrorDisplayFormat.parse(self.config['error_format']) - else: - options['error_format'] = ErrorDisplayFormat.parse(options['error_format']) - # use default styles if not specified - if 'styles' not in options: - options['styles'] = combine_dict(self.get_styles(options['plot_format'].mpl_method)) - else: - options['styles'] = combine_dict(self.get_styles(options['plot_format'].mpl_method), options['styles']) - if 'color' not in options['styles']: + options = deepcopy(plot_options.get(target, {})) + options.setdefault('samples', [target]) + options.setdefault('primary', False) + options.setdefault('weight_scale', None) + options.setdefault('stack_index', 0) + options.setdefault('mask_condition', None) + + if not options['samples']: + raise ValueError(f'No samples specified for target "{target}"') + if len(set(options['samples'])) != len(options['samples']): + raise ValueError(f'Found duplicated samples for target "{target}": {options["samples"]}') + + plot_format = PlotFormat.parse( + options.get('plot_format', self.config['plot_format']) + ) + styles = options.get('styles', {}) + + if 'color' not in styles: if color_i == n_colors: - self.stdout.warning("Number of targets is more than the number of colors " - "available in the color map. The colors will be recycled.") - options['styles']['color'] = colors[color_i % n_colors] + self.stdout.warning( + 'Number of targets exceeds available colors in the color map. Colors will be recycled.' 
+ ) + styles['color'] = colors[color_i % n_colors] color_i += 1 - if 'label' not in options['styles']: - label = label_map.get(target, target) - # handle case of single data (no label needed) - if label is None: - label = 'None' - options['styles']['label'] = label - if 'samples' not in options: - options['samples'] = [target] - if 'error_styles' not in options: - options['error_styles'] = combine_dict(self.get_styles(options['error_format'].mpl_method)) - else: - options['error_styles'] = combine_dict(self.get_styles(options['error_format'].mpl_method), options['error_styles']) - # reuse color of the plot for the error by default - if 'color' not in options['error_styles']: - options['error_styles']['color'] = options['styles']['color'] - if 'label' not in options['error_styles']: - fmt = self.config['error_label_format'] - options['error_styles']['label'] = fmt.format(label=options['styles']['label']) - if 'stack_index' not in options: - options['stack_index'] = 0 - if 'weight_scale' not in options: - options['weight_scale'] = None - if 'hide' not in options: - options['hide'] = None + + if 'label' not in styles: + styles['label'] = self.label_map.get(target, target) or 'None' + + error_format = ErrorDisplayFormat.parse( + options.get( + 'error_format', + 'errorbar' if plot_format == 'errorbar' else self.config['error_format'] + ) + ) + + error_styles = options.get('error_styles', {}) + # Reuse color of the plot for the error by default + if 'color' not in error_styles: + error_styles['color'] = styles['color'] + if 'label' not in error_styles: + error_styles['label'] = self.label_map.get(f'{target}.error', styles['label']) or 'None' + + options.update({ + 'plot_format': plot_format, + 'error_format': error_format, + 'styles': styles, + 'error_styles': error_styles + }) resolved_plot_options[target] = options - return resolved_plot_options - def _merge_styles(self, styles_list:List[Dict]): - merged_styles = {} - sequence_args = ["color", "label"] - for styles in styles_list: - styles = styles.copy() - for key, value in styles.items(): - if key in sequence_args: - if key not in merged_styles: - merged_styles[key] = [] - merged_styles[key].append(value) - continue - if (key in merged_styles) and (value != merged_styles[key]): - raise ValueError('failed to merge style options for targets in a stacked plot: ' - f'found inconsistent values for the option "{key}"') - merged_styles[key] = value - return merged_styles + final_plot_options = {} + if not stacked: + for target in targets: + options = {'components': {}} + for key in ['plot_format', 'error_format', 'styles', 'error_styles']: + options[key] = resolved_plot_options[target].pop(key) + options['components'][target] = resolved_plot_options.pop(target) + final_plot_options[target] = options + return final_plot_options - def resolve_stacked_plot_options(self, plot_options:Dict): - stacked_plot_options = {} - targets = [target for target, options in plot_options.items() if \ - options['plot_format'] == PlotFormat.HIST] - if not targets: - raise RuntimeError('no histograms to be stacked') - target_map = {} + target_map = defaultdict(list) for target in targets: - stack_index = plot_options[target]['stack_index'] - if stack_index not in target_map: - target_map[stack_index] = [] + stack_index = resolved_plot_options[target]['stack_index'] target_map[stack_index].append(target) - stacked_plot_options = {} - for stack_index, targets in target_map.items(): + + for index, targets_group in target_map.items(): options = {} - options["components"] = 
{} - styles_list = [] - hide_list = [] - error_styles_list = [] - error_format_list = [] - for target in targets: - # modify the dictionary (intended) - target_options = plot_options.pop(target) - styles = target_options.pop("styles") - error_styles = target_options.pop("error_styles") - error_format = target_options.pop("error_format") - hide = target_options.pop("hide") - styles_list.append(styles) - error_styles_list.append(error_styles) - error_format_list.append(error_format) - hide_list.append(hide) - options["components"][target] = target_options - options['plot_format'] = PlotFormat.HIST - options['error_format_list'] = error_format_list - options['styles'] = self._merge_styles(styles_list) - options['error_styles_list'] = error_styles_list - options['hide_list'] = hide_list - label = self.config["stacked_label"].format(index=stack_index) - if label in stacked_plot_options: - raise RuntimeError(f"duplicated stack label: {label}") - stacked_plot_options[label] = options - return stacked_plot_options - - def resolve_comparison_options(self, comparison_options:Optional[Dict]=None, - plot_options:Optional[Dict]=None): + components = {target: resolved_plot_options.pop(target) for target in targets_group} + options['components'] = components + + if len(targets_group) == 1: + primary_target = targets_group[0] + else: + primary_target = next( + (t for t in targets_group if components[t]['primary']), + None + ) + + if len([t for t in targets_group if components[t]['primary']]) > 1: + raise RuntimeError( + f'Multiple primary targets found with stack index: {index}' + ) + + for format_type in ['plot', 'error']: + key = f'{format_type}_format' + format_map = {t: components[t].pop(key) for t in targets_group} + if primary_target is None: + if len(set(format_map.values())) > 1: + raise RuntimeError( + f'Inconsistent {format_type} format for targets with stack index: {index}' + ) + options[key] = next(iter(format_map.values())) + else: + options[key] = format_map[primary_target] + + styles_map = {t: components[t].pop('styles') for t in targets_group} + combine_stacked_error = self.config['combine_stacked_error'] + use_sequence_options = len(components) > 1 + styles = _merge_styles( + styles_map, primary_target, + use_sequence_options=use_sequence_options + ) + error_styles_map = {t: components[t].pop('error_styles') for t in targets_group} + use_sequence_options &= not combine_stacked_error + error_styles = _merge_styles( + error_styles_map, primary_target, + use_sequence_options=use_sequence_options + ) + # Only one target, no need to stack + if len(components) == 1: + target = targets_group[0] + else: + target = self.config["stacked_object_id"].format(index=index) + if combine_stacked_error: + error_styles['label'] = self.label_map.get(f'{target}.error', f'{target}.error') + options.update({ + 'styles': styles, + 'error_styles': error_styles + }) + final_plot_options[target] = options + return final_plot_options + + def resolve_comparison_options( + self, + comparison_options: Optional[Dict[str, Any]] = None, + plot_options: Optional[Dict[str, Dict[str, Any]]] = None + ) -> Optional[Dict[str, Any]]: + """ + Resolve comparison options for the plot. 
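+
+        A hypothetical input (target names are illustrative)::
+
+            comparison_options = {
+                'mode': 'ratio',
+                'components': [
+                    {'reference': 'background', 'target': 'data'}
+                ]
+            }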
+
+        Parameters
+        ----------
+        comparison_options : Dict[str, Any], optional
+            Comparison options dictionary
+        plot_options : Dict[str, Dict[str, Any]], optional
+            Plot options dictionary
+
+        Returns
+        -------
+        Optional[Dict[str, Any]]
+            Resolved comparison options, or None if not provided
+        """
         if comparison_options is None:
             return None
-        if plot_options is None:
-            plot_options = {}
-        comparison_options = combine_dict(comparison_options)
-        comparison_options['mode'] = HistComparisonMode.parse(comparison_options['mode'])
-        plot_colors = self.get_colors()
-        n_colors, color_i = len(plot_colors), 0
-        if 'plot_format' in comparison_options:
-            plot_format = PlotFormat.parse(comparison_options.pop('plot_format'))
-        else:
-            plot_format = PlotFormat.parse(self.config['plot_format'])
-        # temporary fix because only error plot format is supported
-        plot_format = PlotFormat.ERRORBAR
-        if 'error_format' in comparison_options:
-            error_format = ErrorDisplayFormat.parse(comparison_options.pop('error_format'))
-        else:
-            error_format = ErrorDisplayFormat.parse(self.config['error_format'])
+
+        plot_options = plot_options or {}
+        comparison_options = deepcopy(comparison_options)
+        comparison_options.setdefault('mode', self.config['comparison_mode'])
+
+        if not callable(comparison_options['mode']):
+            comparison_options['mode'] = HistComparisonMode.parse(comparison_options['mode'])
+
+        default_plot_format = comparison_options.get('plot_format', self.config['plot_format'])
+        default_error_format = comparison_options.get('error_format', self.config['error_format'])
+
         components = comparison_options['components']
         if not isinstance(components, list):
             components = [components]
+        comparison_options['components'] = components
+
+        def get_target_color(target: str, style_type: str) -> Optional[str]:
+            if target in plot_options:
+                return plot_options[target][style_type].get('color')
+            for _, options in plot_options.items():
+                names = list(options['components'].keys())
+                if target not in names:
+                    continue
+                color = options[style_type].get('color')
+                if isinstance(color, list):
+                    return color[names.index(target)]
+                return color
+            return None
+
+        inherit_color = self.config['inherit_color']
         for component in components:
-            reference = component['reference']
-            target = component['target']
-            if 'plot_format' not in component:
-                component['plot_format'] = plot_format
-            if 'error_format' not in component:
-                component['error_format'] = error_format
-            com_plot_format = PlotFormat.parse(component['plot_format'])
-            com_error_format = ErrorDisplayFormat.parse(component['error_format'])
-            if 'styles' not in component:
-                component['styles'] = combine_dict(self.get_styles(com_plot_format.mpl_method))
-            if 'error_styles' not in component:
-                component['error_styles'] = combine_dict(self.get_styles(com_error_format.mpl_method))
-            if 'color' not in component['styles']:
-                if target in plot_options:
-                    component['styles']['color'] = plot_options[target]['styles']['color']
-                else:
-                    component['styles']['color'] = colors[color_i % n_colors]
-                    color_i += 1
-            if 'color' not in component['error_styles']:
-                if target in plot_options:
-                    component['error_styles']['color'] = plot_options[target]['error_styles']['color']
-                else:
-                    component['error_styles']['color'] = component['styles']['color']
             component['mode'] = comparison_options['mode']
-        comparison_options['components'] = components
+            plot_format = PlotFormat.parse(component.get('plot_format', default_plot_format))
+            error_format = ErrorDisplayFormat.parse(component.get('error_format', 
default_error_format)) + + component.update({ + 'plot_format': plot_format, + 'error_format': error_format + }) + component.setdefault('styles', {}) + component.setdefault('error_styles', {}) + + if inherit_color: + component['styles'].setdefault( + 'color', + get_target_color(component['target'], 'styles') + ) + component['error_styles'].setdefault( + 'color', + get_target_color(component['target'], 'error_styles') + ) + return comparison_options - def draw_comparison_data(self, ax, reference_data, target_data, - bin_edges:Optional[np.ndarray]=None, - mode:Union[HistComparisonMode, str]="ratio", - draw_error:bool=True, - plot_format:Union[PlotFormat, str]='errorbar', - error_format:Union[ErrorDisplayFormat, str]='errorbar', - styles:Optional[Dict]=None, - error_styles:Optional[Dict]=None): - mode = HistComparisonMode.parse(mode) - comparison_data = get_hist_comparison_data(reference_data, - target_data, - mode=mode) - handle, error_handle = self.draw_binned_data(ax, comparison_data, - bin_edges=bin_edges, - draw_data=True, - draw_error=draw_error, - plot_format=plot_format, - error_format=error_format, - styles=styles, - error_styles=error_styles) - # expand ylim according to data range - y = comparison_data['y'] - ylim = list(ax.get_ylim()) - if ylim[0] > np.min(y): - ylim[0] = np.min(y) - if ylim[1] < np.max(y): - ylim[1] = np.max(y) - ax.set_ylim(ylim) - - self.hist_comparison_data.append(comparison_data) - - return handle, error_handle - - def deduce_bin_range(self, samples:List[str], column_name:str, - variable_scale:Optional[float]=None): - """Deduce bin range based on variable ranges from multiple samples + def get_relevant_samples( + self, + plot_options: Dict[str, Dict[str, Any]] + ) -> List[str]: """ - xmin = None - xmax = None - for sample in samples: - df = self.data_map[sample] - x = df[column_name].values - x = x[np.isfinite(x)] - if variable_scale is not None: - x = x * variable_scale - if xmin is None: - xmin = np.min(x) - else: - xmin = min(xmin, np.min(x)) - if xmax is None: - xmax = np.max(x) + Get all relevant samples from the plot options. + + Parameters + ---------- + plot_options : Dict[str, Dict[str, Any]] + Plot options dictionary + + Returns + ------- + List[str] + List of relevant sample names + """ + relevant_samples = [] + for options in plot_options.values(): + for component in options['components'].values(): + relevant_samples.extend(component['samples']) + return remove_duplicates(relevant_samples) + + def resolve_legend_order( + self, + plot_options: Dict[str, Dict[str, Any]] + ) -> List[str]: + """ + Resolve the order of legend entries. 
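+
+        For a hypothetical stacked target ``stacked_0`` with components
+        ``a`` and ``b``, ``combine_stacked_error`` enabled and
+        ``isolate_error_legend`` disabled, this would yield
+        ``['stacked_0.a', 'stacked_0.b', 'stacked_0.error']``.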
+ + Parameters + ---------- + plot_options : Dict[str, Dict[str, Any]] + Plot options dictionary + + Returns + ------- + List[str] + List of legend keys in the desired order + """ + legend_order = [] + combine_stacked_error = self.config['combine_stacked_error'] + isolate_error_legend = self.config['isolate_error_legend'] + + for target, options in plot_options.items(): + if len(options['components']) == 1: + legend_order.append(target) + if isolate_error_legend: + legend_order.append(f'{target}.error') else: - xmax = max(xmax, np.max(x)) - return (xmin, xmax) - - def get_sample_data(self, samples:List[str], - column_name:str, - variable_scale:Optional[float]=None, - weight_scale:Optional[float]=None, - weight_name:Optional[str]=None, - selection:Optional[str]=None): + for subtarget in options['components'].keys(): + legend_order.append(f'{target}.{subtarget}') + if isolate_error_legend: + legend_order.append(f'{target}.{subtarget}.error') + if combine_stacked_error: + legend_order.append(f'{target}.error') + + return legend_order + + def get_sample_data( + self, + samples: List[str], + column_name: str, + variable_scale: Optional[float] = None, + weight_scale: Optional[float] = None, + weight_name: Optional[str] = None, + selection: Optional[str] = None + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Get sample data and weights for the given samples. + + Parameters + ---------- + samples : List[str] + List of sample names + column_name : str + Name of the variable column + variable_scale : Optional[float], optional + Factor to scale the variable values + weight_scale : Optional[float], optional + Factor to scale the weights + weight_name : Optional[str], optional + Name of the weight column + selection : Optional[str], optional + Selection query to filter the data + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + Tuple containing the variable data and corresponding weights + """ df = pd.concat([self.data_map[sample] for sample in samples], ignore_index=True) + if selection is not None: df = df.query(selection) + x = df[column_name].values if variable_scale is not None: x = x * variable_scale - if weight_name is not None: - weights = df[weight_name] - else: - weights = np.ones(x.shape) + + weights = df[weight_name].values if weight_name is not None else np.ones_like(x) if weight_scale is not None: - weights = weights * weight_scale + weights = weights * weight_scale + return x, weights - def draw_stacked_target(self, ax, stack_target:str, - components:Dict, - column_name:str, - plot_format:Union[PlotFormat, str], - error_format_list:List[Union[ErrorDisplayFormat, str]], - hist_options:Dict, - styles:Dict, - error_styles_list:List[Dict], - variable_scale:Optional[float]=None, - weight_name:Optional[str]=None, - show_error:bool=False, - selection:Optional[str]=None, - hide_list:Optional[List[Union[Tuple[float, float], Callable]]]=None): - stacked_data = { - 'x' : [], - 'y' : [] - } + def deduce_bin_range( + self, + samples: List[str], + column_name: str, + variable_scale: Optional[float] = None + ) -> Tuple[float, float]: + """ + Deduce bin range based on variable ranges from multiple samples. 
+ + Parameters + ---------- + samples : List[str] + List of sample names + column_name : str + Name of the variable column + variable_scale : Optional[float], optional + Factor to scale the variable values + + Returns + ------- + Tuple[float, float] + The minimum and maximum values across all samples + """ + xmin, xmax = np.inf, -np.inf - for target, options in components.items(): - samples = options['samples'] - weight_scale = options['weight_scale'] - x, y = self.get_sample_data(samples, column_name, selection=selection, - variable_scale=variable_scale, - weight_scale=weight_scale, - weight_name=weight_name) - x = get_clipped_data(x, - bin_range=hist_options["bin_range"], - clip_lower=hist_options["underflow"], - clip_upper=hist_options["overflow"]) - stacked_data['x'].append(x) - stacked_data['y'].append(y) - - bin_edges = np.histogram_bin_edges(np.concatenate(stacked_data['x']).flatten(), - bins=hist_options["bins"], - range=hist_options["bin_range"]) - show_xerr = show_error and self.config['show_xerr'] - hist_data = get_stacked_hist_data(stacked_data['x'], - stacked_data['y'], - xerr=show_xerr, - yerr=show_error, - error_mode='auto', - **hist_options) - stacked_data = get_stacked_hist_data(stacked_data['x'], stacked_data['y'], - xerr=show_xerr, - yerr=show_error, - merge=False, - **hist_options) - handles = self.draw_stacked_binned_data(ax, stacked_data, - bin_edges=bin_edges, - plot_format=plot_format, - error_format_list=error_format_list, - draw_error=show_error, - hide_list=hide_list, - styles=styles, - error_styles_list=error_styles_list) - for i, target in enumerate(components): - self.update_legend_handles({target: handles[i]}) - self.hist_data[stack_target] = hist_data - self.hist_bin_edges[stack_target] = bin_edges - - targets = list(components) - from quickstats.utils.common_utils import dict_of_list_to_list_of_dict - hist_data_list = dict_of_list_to_list_of_dict(stacked_data) - for target, hist_data_i in zip(targets, hist_data_list): - self.hist_data[target] = hist_data_i - #self.update_legend_handles({stack_target: handles}) - - def draw_single_target(self, ax, target:str, samples:List[str], - column_name:str, - styles:Dict, error_styles:Dict, - plot_format:Union[PlotFormat, str], - error_format:Union[ErrorDisplayFormat, str], - hist_options:Dict, - variable_scale:Optional[float]=None, - weight_name:Optional[str]=None, - weight_scale:Optional[float]=None, - show_error:bool=False, - selection:Optional[str]=None, - hide:Optional[Union[Tuple[float, float], Callable]]=None): - x, weights = self.get_sample_data(samples, column_name, selection=selection, - variable_scale=variable_scale, - weight_scale=weight_scale, - weight_name=weight_name) - bin_edges = np.histogram_bin_edges(x, - bins=hist_options["bins"], - range=hist_options["bin_range"]) - show_xerr = show_error and self.config['show_xerr'] - hist_data = get_hist_data(x, weights, - xerr=show_xerr, - yerr=show_error, - error_mode='auto', - **hist_options) - handles = self.draw_binned_data(ax, hist_data, - bin_edges=bin_edges, - styles=styles, - draw_error=show_error, - plot_format=plot_format, - error_format=error_format, - error_styles=error_styles, - hide=hide) - self.hist_data[target] = hist_data - self.hist_bin_edges[target] = bin_edges - self.update_legend_handles({target: handles}) + for sample in samples: + df = self.data_map[sample] + x = df[column_name].values + if variable_scale is not None: + x = x * variable_scale + xmin = min(xmin, np.nanmin(x)) + xmax = max(xmax, np.nanmax(x)) - def draw(self, 
column_name:str, weight_name:Optional[str]=None, - targets:Optional[List[str]]=None, - selection:Optional[str]=None, - xlabel:str="", ylabel:str="Fraction of Events / {bin_width:.2f}{unit}", - unit:Optional[str]=None, bins:Union[int, Sequence]=25, - bin_range:Optional[Sequence]=None, clip_weight:bool=True, - underflow:bool=False, overflow:bool=False, divide_bin_width:bool=False, - normalize:bool=True, show_error:bool=False, - stacked:bool=False, xmin:Optional[float]=None, xmax:Optional[float]=None, - ymin:Optional[float]=None, ymax:Optional[float]=None, ypad:float=0.3, - variable_scale:Optional[float]=None, logy:bool=False, - comparison_options:Optional[Union[Dict, List[Dict]]]=None, - legend_order:Optional[List[str]]=None): + return xmin, xmax + + def draw_single_target( + self, + ax: Axes, + target: str, + components: Dict[str, Dict[str, Any]], + column_name: str, + hist_options: Dict[str, Any], + plot_format: Union[str, PlotFormat] = 'hist', + error_format: Union[str, ErrorDisplayFormat] = 'shade', + styles: Optional[Dict[str, Any]] = None, + error_styles: Optional[Dict[str, Any]] = None, + variable_scale: Optional[float] = None, + weight_name: Optional[str] = None, + show_error: bool = False, + selection: Optional[str] = None, + ) -> None: """ - - Arguments: - column_name: string - Name of the variable in the dataframe(s). - weight_name: (optional) string - If specified, weight the histogram by the "weight_name" variable - in the dataframe. - targets: (optional) list of str - List of target inputs to be included in the plot. All inputs are - included by default. - selection: str, optional - Filter the data with the given selection (a boolean expression). - xlabel: string, default = "Score" - Label of x-axis. - ylabel: string, default = "Fraction of Events / {bin_width}" - Label of y-axis. - boundaries: (optional) list of float - If specified, draw score boundaries at given values. - bins: int or sequence of scalars, default = 25 - If integer, it defines the number of equal-width bins in the given range. - If sequence, it defines a monotonically increasing array of bin edges, - including the rightmost edge. - bin_range: (optional) (float, float) - Range of histogram bins. - clip_weight: bool, default = True - If True, ignore data outside given range when evaluating total weight - used in normalization. - underflow: bool, default = False - Include undeflow data in the first bin. - overflow: bool, default = False - Include overflow data in the last bin. - divide_bin_width: bool, default = False - Divide each bin by the bin width. - normalize: bool, default = True - Normalize the sum of weights to one. Weights outside the bin range will - not be counted if ``clip_weight`` is set to false, so the sum of bin - content could be less than one. - show_error: bool, default = False - Whether to display data error. - stacked: bool, default = False - Do a stacked plot. Only histograms will be stacked (i.e. plot format - is not errorbar). Samples with different stack_index will be stacked - independently. - xmin: (optional) float - Minimum range of x-axis. - xmax: (optional) float - Maximum range of x-axis. - ymin: (optional) float - Minimum range of y-axis. - ymax: (optional) float - Maximum range of y-axis. - ypad: float, default = 0.3 - Fraction of the y-axis that should be padded. This options will be - ignored if ymax is set. - variable_scale: (optional) float - Rescale variable values by a factor. - logy: bool, default = False - Use log scale for y-axis. 
- comparison_options: (optional) dict or list of dict - One or multiple dictionaries containing instructions on - making comparison plots. - legend_order: (optional) list of str - Order of legend labels. The same order as targets will be used by default. + Draw a single target on the plot. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis on which to draw the plot. + target : str + The target name. + components : Dict + Components of the target. + column_name : str + Name of the variable column. + hist_options : Dict + Histogram options. + plot_format : Union[PlotFormat, str], optional + Format for plotting the histogram, by default 'hist'. + error_format : Union[ErrorDisplayFormat, str], optional + Format for plotting the error, by default 'shade'. + styles : Dict, optional + Styling options for the plot, by default None. + error_styles : Dict, optional + Styling options for the error representation, by default None. + variable_scale : float, optional + Factor to scale the variable values, by default None. + weight_name : str, optional + Name of the weight column, by default None. + show_error : bool, optional + Whether to display error bars, by default False. + selection : str, optional + Selection query to filter the data, by default None. """ - plot_options = self.resolve_plot_options(targets=targets) - relevant_samples = remove_duplicates([sample for options in plot_options.values() \ - for sample in options['samples']]) - if not relevant_samples: - raise RuntimeError('no targets to draw') - if stacked: - # this will remove targets that participates in stacking - stacked_plot_options = self.resolve_stacked_plot_options(plot_options) - else: - stacked_plot_options = {} - comparison_options = self.resolve_comparison_options(comparison_options, plot_options) - if legend_order is not None: - self.legend_order = list(legend_order) + def get_histogram(options: Dict[str, Any]) -> Histogram1D: + samples = options['samples'] + weight_scale = options.get('weight_scale') + mask_condition = options.get('mask_condition') + + x, weights = self.get_sample_data( + samples, + column_name, + selection=selection, + variable_scale=variable_scale, + weight_scale=weight_scale, + weight_name=weight_name + ) + + histogram = Histogram1D.create( + x, weights, + evaluate_error=show_error, + error_mode='auto', + **hist_options + ) + + if mask_condition is not None: + histogram.mask(mask_condition) + + return histogram + + # Handle stacked histograms + if len(components) > 1: + histograms = { + subtarget: get_histogram(options) + for subtarget, options in components.items() + } + histogram = StackedHistogram(histograms) + + if hist_options.get('normalize'): + density = hist_options.get('divide_bin_width', False) + histogram.normalize(density=density, inplace=True) + + self.histograms.update(histograms) + self.histograms[target] = histogram else: - self.legend_order = list(plot_options) - for options in stacked_plot_options.values(): - self.legend_order.extend([target for target in list(options['components']) \ - if target not in self.legend_order]) + options = next(iter(components.values())) + histogram = get_histogram(options) + self.histograms[target] = histogram + + handles = self.draw_histogram_data( + ax, + histogram, + plot_format=plot_format, + error_format=error_format, + styles=styles, + error_styles=error_styles, + domain=target + ) + self.update_legend_handles(handles) + + def draw_comparison_data( + self, + ax: Axes, + reference: str, + target: str, + mode: Union[str, 
HistComparisonMode] = "ratio", + plot_format: Union[str, PlotFormat] = 'errorbar', + error_format: Union[str, ErrorDisplayFormat] = 'errorbar', + styles: Optional[Dict[str, Any]] = None, + error_styles: Optional[Dict[str, Any]] = None + ) -> None: + """ + Draw comparison data on the plot. + + Parameters + ---------- + ax : matplotlib.axes.Axes + The axis on which to draw the comparison data. + reference : str + The reference histogram key. + target : str + The target histogram key. + mode : Union[HistComparisonMode, str], optional + Comparison mode, by default "ratio". + plot_format : Union[PlotFormat, str], optional + Format for plotting the histogram, by default 'errorbar'. + error_format : Union[ErrorDisplayFormat, str], optional + Format for plotting the error, by default 'errorbar'. + styles : Dict, optional + Styling options for the plot, by default None. + error_styles : Dict, optional + Styling options for the error representation, by default None. + + Raises + ------ + ValueError + If the histogram data for the reference or target is not set. + """ + for key in [reference, target]: + if key not in self.histograms: + raise ValueError(f'Histogram data not set: {key}') + target_histogram = self.histograms[target] + reference_histogram = self.histograms[reference] + comparison_histogram = target_histogram.compare(reference_histogram, mode=mode) + + domain = self.config['comparison_object_id'].format( + reference=reference, + target=target + ) + self.histograms[domain] = comparison_histogram + + handles = self.draw_histogram_data( + ax, + comparison_histogram, + plot_format=plot_format, + error_format=error_format, + styles=styles, + error_styles=error_styles, + domain=domain + ) + + if comparison_histogram.has_errors(): + ylim = ( + np.nanmin(comparison_histogram.rel_bin_errlo), + np.nanmax(comparison_histogram.rel_bin_errhi) + ) + else: + y = comparison_histogram.bin_content + ylim = (np.nanmin(y), np.nanmax(y)) + + self.stretch_axis(ax, ylim=ylim) + self.update_legend_handles(handles) + + def draw( + self, + column_name: str, + weight_name: Optional[str] = None, + targets: Optional[List[str]] = None, + selection: Optional[str] = None, + xlabel: str = "", + ylabel: str = "Fraction of Events / {bin_width:.2f}{unit}", + unit: Optional[str] = None, + bins: Union[int, Sequence[float]] = 25, + bin_range: Optional[Sequence[float]] = None, + clip_weight: bool = True, + underflow: bool = False, + overflow: bool = False, + divide_bin_width: bool = False, + normalize: bool = True, + show_error: bool = False, + stacked: bool = False, + xmin: Optional[float] = None, + xmax: Optional[float] = None, + ymin: Optional[float] = None, + ymax: Optional[float] = None, + ypad: float = 0.3, + variable_scale: Optional[float] = None, + logy: bool = False, + comparison_options: Optional[Dict[str, Any]] = None, + legend_order: Optional[List[str]] = None + ) -> Union[Axes, Tuple[Axes, Axes]]: + """ + Draw the plot with specified parameters. + + Parameters + ---------- + column_name : str + Name of the variable in the dataframe(s). + weight_name : str, optional + Name of the weight column, by default None. + targets : List[str], optional + List of target inputs to be included in the plot, by default None + (i.e. all inputs are included). + selection : str, optional + Filter the data with the given selection (a boolean expression), by default None. + The selection is applied before any variable scaling. + xlabel : str, optional + Label of x-axis, by default "". 
+        ylabel : str, optional
+            Label of y-axis, by default "Fraction of Events / {bin_width:.2f}{unit}".
+        unit : str, optional
+            Unit of the variable, by default None.
+        bins : Union[int, Sequence], optional
+            Number of bins or bin edges, by default 25.
+        bin_range : Sequence, optional
+            Range of histogram bins, by default None.
+        clip_weight : bool
+            If True, ignore data outside given range when evaluating total weight used in normalization, by default True.
+        underflow : bool
+            Include underflow data in the first bin, by default False.
+        overflow : bool
+            Include overflow data in the last bin, by default False.
+        divide_bin_width : bool
+            Divide each bin by the bin width, by default False.
+        normalize : bool
+            Normalize the sum of weights to one, by default True.
+        show_error : bool
+            Whether to display data error, by default False.
+        stacked : bool
+            Do a stacked plot, by default False.
+        xmin : float, optional
+            Minimum range of x-axis, by default None.
+        xmax : float, optional
+            Maximum range of x-axis, by default None.
+        ymin : float, optional
+            Minimum range of y-axis, by default None.
+        ymax : float, optional
+            Maximum range of y-axis, by default None.
+        ypad : float, optional
+            Fraction of the y-axis that should be padded, by default 0.3.
+            This option will be ignored if ymax is set.
+        variable_scale : float, optional
+            Rescale variable values by a factor, by default None.
+        logy : bool, optional
+            Use log scale for y-axis, by default False.
+        comparison_options : Union[Dict, List[Dict]], optional
+            Instructions for making comparison plots, by default None.
+        legend_order : List[str], optional
+            Order of legend labels, by default None.
+
+        Returns
+        -------
+        Axes or Tuple[Axes, Axes]
+            Axes object(s) for the plot. If comparison is drawn, returns a tuple of axes.
+
+        Raises
+        ------
+        RuntimeError
+            If no targets to draw.
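+
+        Examples
+        --------
+        A hypothetical call (continuing the class-level example; the weight
+        column and target names are illustrative), producing a main panel
+        and a ratio panel::
+
+            ax, ax_ratio = plotter.draw(
+                'mass',
+                weight_name='weight',
+                bins=50,
+                bin_range=(100.0, 150.0),
+                show_error=True,
+                comparison_options={
+                    'mode': 'ratio',
+                    'components': [{'reference': 'background', 'target': 'data'}]
+                }
+            )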
+ """ + plot_options = self.resolve_plot_options( + self.plot_options, + targets=targets, + stacked=stacked + ) + comparison_options = self.resolve_comparison_options( + comparison_options, + plot_options + ) + + relevant_samples = self.get_relevant_samples(plot_options) + if not relevant_samples: + raise RuntimeError('No targets to draw') + + self.legend_order = ( + list(legend_order) if legend_order is not None + else self.resolve_legend_order(plot_options) + ) + if comparison_options is not None: ax, ax_ratio = self.draw_frame(ratio=True, logy=logy) else: ax = self.draw_frame(ratio=False, logy=logy) - - if (bin_range is None) and isinstance(bins, (int, float)): - bin_range = self.deduce_bin_range(relevant_samples, column_name, variable_scale=variable_scale) - self.stdout.info(f"Using deduced bin range ({bin_range[0]:.3f}, {bin_range[1]:.3f})") - self.reset_hist_data() + if bin_range is None and isinstance(bins, int): + bin_range = self.deduce_bin_range( + relevant_samples, + column_name, + variable_scale=variable_scale + ) + self.stdout.info( + f"Using deduced bin range ({bin_range[0]:.3f}, {bin_range[1]:.3f})" + ) + + self.reset_metadata() hist_options = { - "bins" : bins, - "bin_range" : bin_range, - "underflow" : underflow, - "overflow" : overflow, - "normalize" : normalize, - "clip_weight" : clip_weight, - "divide_bin_width" : divide_bin_width + "bins": bins, + "bin_range": bin_range, + "underflow": underflow, + "overflow": overflow, + "normalize": normalize, + "clip_weight": clip_weight, + "divide_bin_width": divide_bin_width } + data_options = { 'column_name': column_name, 'weight_name': weight_name, 'variable_scale': variable_scale, 'selection': selection } - - for stack_target, options in stacked_plot_options.items(): - options['show_error'] = options.get('show_error', show_error) - self.draw_stacked_target(ax, stack_target=stack_target, - hist_options=hist_options, - **options, - **data_options) - + for target, options in plot_options.items(): - options['show_error'] = options.get('show_error', show_error) - options.pop('stack_index', None) - self.draw_single_target(ax, target=target, - hist_options=hist_options, - **options, - **data_options) - - # propagate bin width to ylabel if needed - if isinstance(bins, int): + options.setdefault('show_error', show_error) + self.draw_single_target( + ax, + target=target, + hist_options=hist_options, + **options, + **data_options + ) + + # Propagate bin width to ylabel if needed + if isinstance(bins, int) and bin_range is not None: bin_width = (bin_range[1] - bin_range[0]) / bins - if unit is None: - unit_str = "" - else: - unit_str = f" {unit}" + unit_str = "" if unit is None else f" {unit}" ylabel = ylabel.format(bin_width=bin_width, unit=unit_str) - + if unit is not None: xlabel = f"{xlabel} [{unit}]" + self.draw_axis_components(ax, xlabel=xlabel, ylabel=ylabel) - self.set_axis_range(ax, xmin=xmin, xmax=xmax, - ymin=ymin, ymax=ymax, ypad=ypad) - + self.set_axis_range(ax, xmin=xmin, xmax=xmax, ymin=ymin, ymax=ymax, ypad=ypad) + if not self.is_single_data(): handles, labels = self.get_legend_handles_labels() - if not self.config['box_legend_handle']: - handles = remake_handles(handles, polygon_to_line=True, - line2d_styles=self.styles['legend_Line2D']) self.draw_legend(ax, handles=handles, labels=labels) - + if comparison_options is not None: components = comparison_options.pop('components') - for component in components: - reference = component.pop('reference') - target = component.pop('target') - bin_edges = self.hist_bin_edges[target] 
- self.draw_comparison_data(ax_ratio, - self.hist_data[reference], - self.hist_data[target], - bin_edges=bin_edges, - **component) + for options in components: + self.draw_comparison_data(ax_ratio, **options) comparison_options['xlabel'] = ax.get_xlabel() self.decorate_comparison_axis(ax_ratio, **comparison_options) ax.set(xlabel=None) ax.tick_params(axis="x", labelbottom=False) - - if comparison_options is not None: return ax, ax_ratio - + return ax \ No newline at end of file diff --git a/quickstats/tests/test_NamedTreeNode.py b/quickstats/tests/test_NamedTreeNode.py new file mode 100644 index 0000000000000000000000000000000000000000..47a1154c64723fadd8b217a5005b165d1e9655d9 --- /dev/null +++ b/quickstats/tests/test_NamedTreeNode.py @@ -0,0 +1,602 @@ +""" +Unit tests for the NamedTreeNode class. +""" + +import unittest +from typing import Optional, Dict, Any, List +import copy + +from quickstats.tree import NamedTreeNode + + +class TestNamedTreeNodeInit(unittest.TestCase): + """Test NamedTreeNode initialization and basic properties.""" + + def test_init_defaults(self) -> None: + """Test default initialization.""" + node = NamedTreeNode[str]() + self.assertEqual(node.name, "root") + self.assertIsNone(node.data) + self.assertEqual(len(node._children), 0) + self.assertEqual(node._separator, ".") + + def test_init_custom(self) -> None: + """Test initialization with custom values.""" + node = NamedTreeNode[str]("custom", "data", separator="|") + self.assertEqual(node.name, "custom") + self.assertEqual(node.data, "data") + self.assertEqual(node._separator, "|") + + def test_init_validation(self) -> None: + """Test initialization validation.""" + # Invalid name types + with self.assertRaises(ValueError): + NamedTreeNode[str](123) # type: ignore + + # Empty name + with self.assertRaises(ValueError): + NamedTreeNode[str]("") + + # Invalid name characters + with self.assertRaises(ValueError): + NamedTreeNode[str]("invalid name") + + # Invalid separator + with self.assertRaises(ValueError): + NamedTreeNode[str]("valid", separator="") + + +class TestNamedTreeNodeProperties(unittest.TestCase): + """Test NamedTreeNode properties and attribute access.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.node = NamedTreeNode[str]("root", "root_data") + self.node["child1"] = "data1" + self.node["child2.grandchild"] = "data2" + + def test_name_property(self) -> None: + """Test name property access.""" + self.assertEqual(self.node.name, "root") + + # Name should be read-only + with self.assertRaises(AttributeError): + self.node.name = "new_name" # type: ignore + + def test_data_property(self) -> None: + """Test data property access and modification.""" + self.assertEqual(self.node.data, "root_data") + + # Data should be read-only through property + with self.assertRaises(AttributeError): + self.node.data = "new_data" # type: ignore + + def test_namespaces_property(self) -> None: + """Test namespaces property.""" + self.assertEqual(set(self.node.namespaces), {"child1", "child2"}) + + # Empty node + empty_node = NamedTreeNode[str]() + self.assertEqual(empty_node.namespaces, []) + + def test_domains_property(self) -> None: + """Test domains property.""" + expected_domains = {"child1", "child2.grandchild"} + self.assertEqual(set(self.node.domains), expected_domains) + + # Test nested domains + self.node["a.b.c.d"] = "nested" + self.assertIn("a.b.c.d", self.node.domains) + + +class TestNamedTreeNodeOperations(unittest.TestCase): + """Test NamedTreeNode operations (add, remove, get, set).""" + + 
def setUp(self) -> None: + """Set up test fixtures.""" + self.node = NamedTreeNode[str]("root") + + def test_add_child(self) -> None: + """Test adding child nodes.""" + # Add simple child + child = NamedTreeNode[str]("child", "data") + self.node.add_child(child) + self.assertEqual(self.node["child"], "data") + + # Add with existing name + with self.assertRaises(ValueError): + self.node.add_child(child) + + # Add invalid type + with self.assertRaises(TypeError): + self.node.add_child("not_a_node") # type: ignore + + # Add child with invalid name + invalid_child = NamedTreeNode[str]("invalid.name", "data") + with self.assertRaises(ValueError): + self.node.add_child(invalid_child) + + def test_get_child(self) -> None: + """Test getting child nodes.""" + child = NamedTreeNode[str]("child", "data") + self.node.add_child(child) + + # Get existing child + self.assertEqual(self.node.get_child("child"), child) + + # Get non-existent child + self.assertIsNone(self.node.get_child("missing")) + + # Get with custom default + default = NamedTreeNode[str]("default") + self.assertEqual( + self.node.get_child("missing", default), + default + ) + + def test_remove_child(self) -> None: + """Test removing child nodes.""" + self.node["child"] = "data" + + # Remove existing child + removed = self.node.remove_child("child") + self.assertEqual(removed.data, "data") + self.assertNotIn("child", self.node) + + # Remove non-existent child + self.assertIsNone(self.node.remove_child("missing")) + + +class TestNamedTreeNodeTraversal(unittest.TestCase): + """Test NamedTreeNode traversal methods.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.node = NamedTreeNode[str]("root", "root_data") + self.node["a.b.c"] = "abc_data" + self.node["a.b.d"] = "abd_data" + self.node["x.y"] = "xy_data" + + def test_traverse(self) -> None: + """Test traverse method.""" + # Traverse existing path + node = self.node.traverse("a", "b", "c") + self.assertEqual(node.data, "abc_data") + + # Traverse non-existent path without creation + self.assertIsNone(self.node.traverse("a", "b", "missing")) + + # Traverse with path creation + node = self.node.traverse("new", "path", create=True) + self.assertIsNotNone(node) + self.assertEqual(node.name, "path") + + # Traverse with empty components + self.assertEqual( + self.node.traverse("a", "", "b", "c"), + self.node.traverse("a", "b", "c") + ) + + def test_traverse_domain(self) -> None: + """Test domain-based traversal.""" + # Traverse existing domain + node = self.node.traverse_domain("a.b.c") + self.assertEqual(node.data, "abc_data") + + # Traverse with different separator + node_pipe = NamedTreeNode[str]("root", separator="|") + node_pipe["a|b|c"] = "data" + self.assertEqual(node_pipe.traverse_domain("a|b|c").data, "data") + + # Invalid domain type + with self.assertRaises(TypeError): + self.node.traverse_domain(123) # type: ignore + + +class TestNamedTreeNodeDataAccess(unittest.TestCase): + """Test NamedTreeNode data access methods.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.node = NamedTreeNode[str]("root", "root_data") + + def test_get(self) -> None: + """Test get method.""" + # Get root data + self.assertEqual(self.node.get(), "root_data") + + # Get with domain + self.node["a.b"] = "ab_data" + self.assertEqual(self.node.get("a.b"), "ab_data") + + # Get non-existent with default + self.assertEqual(self.node.get("missing", default="default"), "default") + + # Get with strict mode + with self.assertRaises(KeyError): + self.node.get("missing", strict=True) 
+ + def test_set(self) -> None: + """Test set method.""" + # Set root data + self.node.set("new_data") + self.assertEqual(self.node.data, "new_data") + + # Set with domain + self.node.set("domain_data", "a.b.c") + self.assertEqual(self.node.get("a.b.c"), "domain_data") + + # Set with validation + with self.assertRaises(ValueError): + self.node.set({"invalid": "data"}) # type: ignore + + def test_item_access(self) -> None: + """Test dictionary-style item access.""" + # Set and get + self.node["a.b"] = "ab_data" + self.assertEqual(self.node["a.b"], "ab_data") + + # Delete + del self.node["a.b"] + self.assertNotIn("a.b", self.node) + + # Get non-existent + with self.assertRaises(KeyError): + _ = self.node["missing"] + + +class TestNamedTreeNodeUpdate(unittest.TestCase): + """Test NamedTreeNode update methods.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.node = NamedTreeNode[str]("root", "root_data") + self.node["a"] = "a_data" + self.node["b.c"] = "bc_data" + + def test_update_from_dict(self) -> None: + """Test updating from dictionary.""" + update_dict = { + "name": "root", + "data": "new_data", + "children": { + "a": {"name": "a", "data": "new_a_data"}, + "d": {"name": "d", "data": "d_data"} + } + } + self.node.update(update_dict) + + self.assertEqual(self.node.data, "new_data") + self.assertEqual(self.node["a"], "new_a_data") + self.assertEqual(self.node["d"], "d_data") + self.assertEqual(self.node["b.c"], "bc_data") # Unchanged + + def test_update_from_node(self) -> None: + """Test updating from another node.""" + other = NamedTreeNode[str]("root", "other_data") + other["a"] = "other_a_data" + other["x"] = "x_data" + + self.node.update(other) + + self.assertEqual(self.node.data, "other_data") + self.assertEqual(self.node["a"], "other_a_data") + self.assertEqual(self.node["x"], "x_data") + self.assertEqual(self.node["b.c"], "bc_data") # Unchanged + + def test_merge(self) -> None: + """Test merge method.""" + other = NamedTreeNode[str]("root", "other_data") + other["a"] = "other_a_data" + other["b.d"] = "bd_data" + + # Test replace strategy + self.node.merge(other, strategy='replace') + self.assertEqual(self.node.data, "other_data") + self.assertEqual(self.node["a"], "other_a_data") + self.assertEqual(self.node["b.c"], "bc_data") + self.assertEqual(self.node["b.d"], "bd_data") + + # Test keep strategy + node2 = NamedTreeNode[str]("root", "keep_data") + node2.merge(other, strategy='keep') + self.assertEqual(node2.data, "keep_data") + + # Test invalid strategy + with self.assertRaises(ValueError): + self.node.merge(other, strategy='invalid') + + +class TestNamedTreeNodeSerialization(unittest.TestCase): + """Test NamedTreeNode serialization methods.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.node = NamedTreeNode[str]("root", "root_data") + self.node["a"] = "a_data" + self.node["b.c"] = "bc_data" + + def test_to_dict(self) -> None: + """Test conversion to dictionary.""" + data = self.node.to_dict() + + self.assertEqual(data["name"], "root") + self.assertEqual(data["data"], "root_data") + self.assertIn("a", data["children"]) + self.assertIn("b", data["children"]) + self.assertIn("c", data["children"]["b"]["children"]) + + def test_from_dict(self) -> None: + """Test creation from dictionary.""" + data = { + "name": "root", + "data": "root_data", + "children": { + "a": { + "name": "a", + "data": "a_data" + }, + "b": { + "name": "b", + "data": None, + "children": { + "c": { + "name": "c", + "data": "bc_data" + } + } + } + } + } + + node = 
NamedTreeNode[str].from_dict(data) + self.assertEqual(node.name, "root") + self.assertEqual(node.data, "root_data") + self.assertEqual(node["a"], "a_data") + self.assertEqual(node["b.c"], "bc_data") + + def test_from_mapping(self) -> None: + """Test creation from mapping.""" + data = { + None: "root_data", + "a": "a_data", + "b.c": "bc_data" + } + + node = NamedTreeNode[str].from_mapping(data) + self.assertEqual(node.data, "root_data") + self.assertEqual(node["a"], "a_data") + self.assertEqual(node["b.c"], "bc_data") + + +class TestNamedTreeNodeCopy(unittest.TestCase): + """Test NamedTreeNode copying.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.node = NamedTreeNode[str]("root", "root_data") + self.node["a"] = "a_data" + self.node["b.c"] = "bc_data" + + def test_shallow_copy(self) -> None: + """Test shallow copy.""" + copied = self.node.copy(deep=False) + + # Test independence + copied["new"] = "new_data" + self.assertNotIn("new", self.node) + + # Test shared references + copied["a"] = "modified" + self.assertEqual(self.node["a"], "a_data") # Original unchanged + + def test_deep_copy(self) -> None: + """Test deep copy.""" + copied = self.node.copy(deep=True) + + # Test independence of structure + copied["new"] = "new_data" + self.assertNotIn("new", self.node) + + # Test independence of data + copied.set("modified", "a") + self.assertEqual(self.node["a"], "a_data") # Original unchanged + + # Test nested structures + copied.set("modified_nested", "b.c") + self.assertEqual(self.node["b.c"], "bc_data") # Original unchanged + + def test_copy_special_methods(self) -> None: + """Test copy special methods.""" + # Test __copy__ + copied = copy.copy(self.node) + self.assertEqual(copied.data, self.node.data) + self.assertEqual(copied["a"], self.node["a"]) + + # Test __deepcopy__ + deep_copied = copy.deepcopy(self.node) + deep_copied["a"] = "modified" + self.assertNotEqual(deep_copied["a"], self.node["a"]) + + +class TestNamedTreeNodeOperators(unittest.TestCase): + """Test NamedTreeNode operators and special methods.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.node = NamedTreeNode[str]("root", "root_data") + self.node["a"] = "a_data" + self.node["b.c"] = "bc_data" + + def test_or_operator(self) -> None: + """Test | operator.""" + # Test with dictionary + result = self.node | {"name": "new", "data": "new_data"} + self.assertEqual(result.name, "new") + self.assertEqual(result.data, "new_data") + self.assertEqual(result["a"], "a_data") # Preserved original data + + # Test with another node + other = NamedTreeNode[str]("other", "other_data") + other["x"] = "x_data" + result = self.node | other + self.assertEqual(result.name, "other") + self.assertEqual(result.data, "other_data") + self.assertEqual(result["a"], "a_data") + self.assertEqual(result["x"], "x_data") + + # Test with invalid type + with self.assertRaises(TypeError): + _ = self.node | "invalid" # type: ignore + + def test_ior_operator(self) -> None: + """Test |= operator.""" + # Test with dictionary + original_node = self.node.copy(deep=True) + self.node |= {"name": "new", "data": "new_data"} + self.assertEqual(self.node.name, "new") + self.assertEqual(self.node.data, "new_data") + self.assertEqual(self.node["a"], "a_data") + + # Test with another node + self.node = original_node.copy(deep=True) + other = NamedTreeNode[str]("other", "other_data") + other["x"] = "x_data" + self.node |= other + self.assertEqual(self.node.name, "other") + self.assertEqual(self.node.data, "other_data") + 
self.assertEqual(self.node["x"], "x_data") + + def test_ror_operator(self) -> None: + """Test reverse or operator.""" + result = {"name": "dict", "data": "dict_data"} | self.node + self.assertEqual(result.name, "root") + self.assertEqual(result.data, "root_data") + self.assertEqual(result["a"], "a_data") + + def test_contains(self) -> None: + """Test membership testing.""" + self.assertTrue("a" in self.node) + self.assertTrue("b.c" in self.node) + self.assertFalse("missing" in self.node) + self.assertFalse("b.missing" in self.node) + + # Test with invalid types + self.assertFalse(123 in self.node) # type: ignore + self.assertFalse(None in self.node) + + def test_iteration(self) -> None: + """Test iteration over nodes.""" + children = list(self.node) + self.assertEqual(len(children), 2) # 'a' and 'b' nodes + + child_names = [child.name for child in children] + self.assertIn("a", child_names) + self.assertIn("b", child_names) + + +class TestNamedTreeNodeEdgeCases(unittest.TestCase): + """Test edge cases and error conditions.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.node = NamedTreeNode[str]("root", "root_data") + + def test_empty_domain(self) -> None: + """Test operations with empty domains.""" + # Set with empty domain + self.node.set("data", "") + self.assertEqual(self.node.data, "data") + + # Get with empty domain + self.assertEqual(self.node.get(""), "data") + + # Empty domain components + self.node.set("nested", "a..b") + self.assertEqual(self.node.get("a.b"), "nested") + + def test_invalid_domains(self) -> None: + """Test operations with invalid domains.""" + # None domain + self.node.set("data", None) + self.assertEqual(self.node.data, "data") + + # Invalid domain types + with self.assertRaises(TypeError): + self.node.set("data", 123) # type: ignore + + with self.assertRaises(TypeError): + self.node.get(["invalid"]) # type: ignore + + def test_namespace_conflicts(self) -> None: + """Test namespace conflict handling.""" + # Try to create domain that conflicts with existing node + self.node["a"] = "data" + with self.assertRaises(ValueError): + self.node["a.b"] = "conflict" + + # Try to create node that conflicts with existing domain + self.node["x.y"] = "data" + with self.assertRaises(ValueError): + self.node["x"] = "conflict" + + def test_circular_references(self) -> None: + """Test handling of circular references.""" + child = NamedTreeNode[str]("child", "child_data") + self.node.add_child(child) + + # Attempt to add parent as child of child + with self.assertRaises(ValueError): + child.add_child(self.node) + + def test_data_validation(self) -> None: + """Test data validation.""" + # Test with None data when not allowed + strict_node = NamedTreeNode[str]("strict") + strict_node.config.allow_none_data = False + + with self.assertRaises(ValueError): + strict_node.set(None) + + # Test with invalid data types + with self.assertRaises(TypeError): + self.node.set(123) # type: ignore + + +class TestNamedTreeNodePerformance(unittest.TestCase): + """Test performance characteristics.""" + + def setUp(self) -> None: + """Set up test fixtures.""" + self.node = NamedTreeNode[str]("root") + + # Create a deep tree + current = self.node + for i in range(100): + child = NamedTreeNode[str](f"child{i}", f"data{i}") + current.add_child(child) + current = child + + def test_deep_traversal(self) -> None: + """Test traversal of deep trees.""" + # This should complete quickly even with a deep tree + path = ".".join(f"child{i}" for i in range(99)) + result = self.node.get(path) + 
self.assertEqual(result, "data98")  # 99 components (child0..child98) end at the node holding "data98"
+
+    def test_large_breadth(self) -> None:
+        """Test operations on trees with many siblings."""
+        # Create many siblings; distinct names avoid clashing with the
+        # child0 chain built in setUp
+        for i in range(1000):
+            self.node[f"sibling{i}"] = f"data{i}"
+
+        # Test access time
+        self.assertEqual(self.node["sibling999"], "data999")
+
+        # Test iteration over direct children
+        count = sum(1 for _ in self.node)
+        self.assertEqual(count, 1001)  # 1000 siblings + child0 from setUp
+
+
+if __name__ == '__main__':
+    unittest.main(verbosity=2)
\ No newline at end of file
diff --git a/quickstats/tests/test_abstract_object.py b/quickstats/tests/test_abstract_object.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea779e44b4d37aff6f1138e5087fc881c3e933c3
--- /dev/null
+++ b/quickstats/tests/test_abstract_object.py
@@ -0,0 +1,205 @@
+"""Unit tests for AbstractObject class."""
+
+from __future__ import annotations
+
+import pickle
+import unittest
+from typing import Any
+
+from quickstats.core.abstract_object import AbstractObject
+from quickstats.core.io import VerbosePrint
+from quickstats.core.enums import Verbosity
+from quickstats.core.decorators import semistaticmethod
+
+
+class TestObject(AbstractObject):
+    """Test class extending AbstractObject."""
+
+    def __init__(self, value: Any = None, **kwargs: Any) -> None:
+        self.value = value
+        super().__init__(**kwargs)
+
+    @semistaticmethod
+    def log_message(self_or_cls, message: str) -> None:
+        """Log message using either instance or class stdout."""
+        self_or_cls.stdout.info(message)
+
+    @semistaticmethod
+    def get_verbosity(self_or_cls) -> Verbosity:
+        """Get current verbosity level."""
+        return self_or_cls.stdout.verbosity
+
+
+class TestAbstractObject(unittest.TestCase):
+    """Test suite for AbstractObject functionality."""
+
+    def setUp(self) -> None:
+        """Initialize test fixtures."""
+        # Reset class-level default for each test
+        TestObject._class_stdout = VerbosePrint(Verbosity.INFO)
+
+    def test_initialization(self) -> None:
+        """Test object initialization with different verbosity settings."""
+        # Default initialization
+        obj = TestObject()
+        self.assertEqual(obj.stdout.verbosity, Verbosity.INFO)
+        self.assertIsNone(obj._stdout)
+
+        # Custom verbosity
+        obj = TestObject(verbosity="DEBUG")
+        self.assertEqual(obj.stdout.verbosity, Verbosity.DEBUG)
+        self.assertIsNotNone(obj._stdout)
+
+        # None verbosity
+        obj = TestObject(verbosity=None)
+        self.assertEqual(obj.stdout.verbosity, Verbosity.INFO)
+        self.assertIsNone(obj._stdout)
+
+    def test_stdout_property(self) -> None:
+        """Test stdout property behavior at class and instance level."""
+        # Class-level access
+        self.assertEqual(TestObject.stdout.verbosity, Verbosity.INFO)
+
+        # Instance with default
+        obj1 = TestObject()
+        self.assertEqual(obj1.stdout.verbosity, Verbosity.INFO)
+        self.assertIs(obj1.stdout, TestObject._class_stdout)
+
+        # Instance with custom
+        obj2 = TestObject(verbosity="DEBUG")
+        self.assertEqual(obj2.stdout.verbosity, Verbosity.DEBUG)
+        self.assertIsNot(obj2.stdout, TestObject._class_stdout)
+
+    def test_verbosity_inheritance(self) -> None:
+        """Test verbosity inheritance from class to instance."""
+        # Set class default
+        TestObject.set_default_verbosity("DEBUG")
+
+        # New instance should inherit
+        obj = TestObject()
+        self.assertEqual(obj.stdout.verbosity, Verbosity.DEBUG)
+        self.assertTrue(obj.debug_mode)
+
+        # Custom verbosity should override
+        obj.set_verbosity("WARNING")
+        self.assertEqual(obj.stdout.verbosity, Verbosity.WARNING)
+        self.assertFalse(obj.debug_mode)
+
+    def test_set_verbosity(self) -> None:
+        """Test 
verbosity changes.""" + obj = TestObject() + + # Set custom verbosity + obj.set_verbosity("DEBUG") + self.assertEqual(obj.stdout.verbosity, Verbosity.DEBUG) + self.assertIsNotNone(obj._stdout) + + # Reset to default + obj.set_verbosity(None) + self.assertEqual(obj.stdout.verbosity, Verbosity.INFO) + self.assertIsNone(obj._stdout) + + # Test invalid verbosity + with self.assertRaises(ValueError): + obj.set_verbosity("INVALID") + + def test_copy_verbosity(self) -> None: + """Test copying verbosity between objects.""" + # Source with custom verbosity + source = TestObject(verbosity="DEBUG") + + # Copy to target + target = TestObject() + target.copy_verbosity_from(source) + self.assertEqual(target.stdout.verbosity, Verbosity.DEBUG) + self.assertIsNotNone(target._stdout) + + # Copy default verbosity + source = TestObject() + target.copy_verbosity_from(source) + self.assertIsNone(target._stdout) + self.assertEqual(target.stdout.verbosity, Verbosity.INFO) + + def test_semistaticmethod(self) -> None: + """Test stdout access with semistaticmethod.""" + # Class-level access + TestObject.log_message("class message") + self.assertEqual(TestObject.get_verbosity(), Verbosity.INFO) + + # Instance with default + obj1 = TestObject() + obj1.log_message("default message") + self.assertEqual(obj1.get_verbosity(), Verbosity.INFO) + + # Instance with custom + obj2 = TestObject(verbosity="DEBUG") + obj2.log_message("debug message") + self.assertEqual(obj2.get_verbosity(), Verbosity.DEBUG) + + def test_debug_mode(self) -> None: + """Test debug_mode property.""" + obj = TestObject() + self.assertFalse(obj.debug_mode) + + obj.set_verbosity("DEBUG") + self.assertTrue(obj.debug_mode) + + obj.set_verbosity("INFO") + self.assertFalse(obj.debug_mode) + + def test_pickling(self) -> None: + """Test serialization and deserialization.""" + # Test with custom verbosity + original = TestObject(value="test", verbosity="DEBUG") + pickled = pickle.dumps(original) + unpickled = pickle.loads(pickled) + + self.assertEqual(unpickled.value, "test") + self.assertEqual(unpickled.stdout.verbosity, Verbosity.DEBUG) + self.assertIsNotNone(unpickled._stdout) + + # Test with default verbosity + original = TestObject(value="test") + pickled = pickle.dumps(original) + unpickled = pickle.loads(pickled) + + self.assertEqual(unpickled.value, "test") + self.assertEqual(unpickled.stdout.verbosity, Verbosity.INFO) + self.assertIsNone(unpickled._stdout) + + def test_class_default_changes(self) -> None: + """Test effect of changing class-level default verbosity.""" + # Create instances + default_obj = TestObject() + custom_obj = TestObject(verbosity="WARNING") + + # Change class default + TestObject.set_default_verbosity("DEBUG") + + # Should affect instance using class default + self.assertEqual(default_obj.stdout.verbosity, Verbosity.DEBUG) + self.assertEqual(TestObject.stdout.verbosity, Verbosity.DEBUG) + + # Should not affect instance with custom verbosity + self.assertEqual(custom_obj.stdout.verbosity, Verbosity.WARNING) + + def test_edge_cases(self) -> None: + """Test edge cases and error conditions.""" + obj = TestObject() + + # Invalid verbosity values + for invalid_value in ["INVALID", -1, None]: + with self.subTest(value=invalid_value): + if invalid_value is None: + obj.set_verbosity(invalid_value) # Should work + else: + with self.assertRaises(ValueError): + obj.set_verbosity(invalid_value) + + # Copy from invalid source + with self.assertRaises(AttributeError): + obj.copy_verbosity_from("not an object") + + +if __name__ == '__main__': + 
unittest.main(verbosity=2)
\ No newline at end of file
diff --git a/quickstats/tests/test_abstract_plot.py b/quickstats/tests/test_abstract_plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..956c16e39afe95bb4989cc621baa325688f2018d
--- /dev/null
+++ b/quickstats/tests/test_abstract_plot.py
@@ -0,0 +1,159 @@
+from typing import List
+
+import pytest
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib.lines import Line2D
+from matplotlib.patches import Rectangle
+
+from quickstats import AbstractObject
+# PlottingError is raised in test_point_handling below; its import path
+# alongside AbstractPlot is assumed here.
+from quickstats.plots import AbstractPlot, LegendEntry, PlottingError
+
+class TestPlot(AbstractPlot):
+    """Test implementation of AbstractPlot."""
+    COLOR_MAP = {'test': 'red'}
+    COLOR_CYCLE = 'viridis'
+    LABEL_MAP = {'test': 'Test Label'}
+
+    def get_default_legend_order(self) -> List[str]:
+        return ['test']
+
+class TestAbstractPlot:
+    """Test suite for AbstractPlot class."""
+
+    @pytest.fixture
+    def plot(self):
+        """Create a test plot instance."""
+        return TestPlot()
+
+    @pytest.fixture
+    def ax(self):
+        """Create a test axes."""
+        fig, ax = plt.subplots()
+        yield ax
+        plt.close(fig)
+
+    def test_initialization(self, plot):
+        """Test plot initialization."""
+        assert isinstance(plot, AbstractObject)
+        assert plot.color_map.get('test') == 'red'
+        assert plot.label_map.get('test') == 'Test Label'
+        assert hasattr(plot, 'cmap')
+        assert hasattr(plot, 'color_cycle')
+
+    def test_color_cycle(self, plot):
+        """Test color cycle functionality."""
+        # Test default cycle
+        colors = plot.get_colors()
+        assert len(colors) > 0
+        assert all(isinstance(c, str) for c in colors)
+
+        # Test custom cycle
+        plot.set_color_cycle(['red', 'blue'])
+        colors = plot.get_colors()
+        assert colors == ['red', 'blue']
+
+    def test_legend_handling(self, plot, ax):
+        """Test legend related functionality."""
+        # Create test handle
+        line = Line2D([0], [0], label='test')
+
+        # Update legend handles
+        plot.update_legend_handles({'test': line})
+
+        # Check legend data
+        entry = plot.legend_data.get('test')
+        assert isinstance(entry, LegendEntry)
+        assert entry.handle == line
+        assert entry.label == 'test'
+        assert entry.has_valid_label()
+
+        # Test decoration
+        rect = Rectangle((0, 0), 1, 1)
+        plot.add_legend_decoration(rect, ['test'])
+
+        entry = plot.legend_data.get('test')
+        assert isinstance(entry.handle, tuple)
+        assert len(entry.handle) == 2
+        assert entry.handle[1] == rect
+
+    def test_point_handling(self, plot):
+        """Test point addition functionality."""
+        plot.add_point(1.0, 2.0, label='test_point')
+        assert len(plot._points) == 1
+        point = plot._points[0]
+        assert point.x == 1.0
+        assert point.y == 2.0
+        assert point.label == 'test_point'
+
+        # Test duplicate name handling
+        with pytest.raises(PlottingError):
+            plot.add_point(3.0, 4.0, name='test_point')
+
+    def test_annotation_handling(self, plot):
+        """Test annotation functionality."""
+        plot.add_annotation('test', x=0.5, y=0.5)
+        assert len(plot._annotations) == 1
+        ann = plot._annotations[0]
+        assert ann.text == 'test'
+        assert ann.options == {'x': 0.5, 'y': 0.5}
+
+        # Test invalid input
+        with pytest.raises(ValueError):
+            plot.add_annotation('')
+
+    def test_drawing(self, plot):
+        """Test drawing functionality."""
+        # Test single frame
+        ax = plot.draw_frame()
+        assert isinstance(ax, plt.Axes)
+
+        # Test ratio frame
+        ax_main, ax_ratio = plot.draw_frame(ratio=True)
+        assert isinstance(ax_main, plt.Axes)
+        assert isinstance(ax_ratio, plt.Axes)
+
+    def test_cleanup(self, plot):
+        """Test cleanup functionality."""
+        ax = plot.draw_frame()
+        plot.add_point(1.0, 
2.0) + plot.add_annotation('test') + + plot.reset() + assert len(plot._points) == 0 + assert len(plot._annotations) == 0 + assert len(plot.legend_order) == 0 + + AbstractPlot.close_all_figures() + + def test_axis_components(self, plot, ax): + """Test axis component drawing.""" + plot.draw_axis_components( + ax, + xlabel='X', + ylabel='Y', + title='Test', + xlim=(0, 1), + ylim=(0, 1), + xticks=[0, 0.5, 1], + yticks=[0, 0.5, 1] + ) + + assert ax.get_xlabel() == 'X' + assert ax.get_ylabel() == 'Y' + assert ax.get_title() == 'Test' + + def test_axis_range(self, plot, ax): + """Test axis range manipulation.""" + plot.set_axis_range( + ax, + xmin=0, + xmax=1, + ymin=0, + ymax=1, + ypad=0.1 + ) + + xlim = ax.get_xlim() + ylim = ax.get_ylim() + assert xlim == (0, 1) + assert ylim[0] < 0 # Check padding + assert ylim[1] > 1 # Check padding \ No newline at end of file diff --git a/quickstats/tests/test_constraints.py b/quickstats/tests/test_constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..4761ca1a93012e6d53485dd8da1ab30301e2a6a2 --- /dev/null +++ b/quickstats/tests/test_constraints.py @@ -0,0 +1,224 @@ +"""Unit tests for constraints module.""" + +from __future__ import annotations + +import unittest +from typing import Any, Set + +from quickstats.core.constraints import ( + BaseConstraint, + RangeConstraint, + MinConstraint, + MaxConstraint, + ChoiceConstraint +) + + +class TestConstraints(unittest.TestCase): + """Test suite for constraint classes.""" + + def test_base_constraint(self) -> None: + """Test BaseConstraint functionality.""" + constraint = BaseConstraint() + + # Test basic functionality + self.assertTrue(constraint(42)) + self.assertTrue(constraint("anything")) + + # Test representation + self.assertEqual(repr(constraint), "BaseConstraint()") + + # Test equality + other_constraint = BaseConstraint() + self.assertEqual(constraint, other_constraint) + + # Test hash + self.assertEqual(hash(constraint), hash(other_constraint)) + + def test_range_constraint(self) -> None: + """Test RangeConstraint functionality.""" + # Test inclusive bounds + constraint = RangeConstraint(1, 10) + + for value in [1, 5, 10]: + with self.subTest(value=value): + self.assertTrue(constraint(value)) + + for value in [0, 11]: + with self.subTest(value=value): + self.assertFalse(constraint(value)) + + # Test exclusive bounds + constraint = RangeConstraint(1, 10, lbound=False, rbound=False) + + for value in [2, 5, 9]: + with self.subTest(value=value): + self.assertTrue(constraint(value)) + + for value in [1, 10]: + with self.subTest(value=value): + self.assertFalse(constraint(value)) + + # Test error conditions + with self.assertRaises(ValueError): + RangeConstraint(10, 1) # Invalid range + + with self.assertRaises(ValueError): + RangeConstraint(1, 10, lbound="invalid") + + with self.assertRaises(ValueError): + constraint("invalid") # Invalid value type + + # Test equality and hash + c1 = RangeConstraint(1, 10) + c2 = RangeConstraint(1, 10) + c3 = RangeConstraint(1, 10, lbound=False) + + self.assertEqual(c1, c2) + self.assertNotEqual(c1, c3) + self.assertEqual(hash(c1), hash(c2)) + self.assertNotEqual(hash(c1), hash(c3)) + + def test_min_constraint(self) -> None: + """Test MinConstraint functionality.""" + # Test inclusive minimum + constraint = MinConstraint(5) + + for value in [5, 6, 10]: + with self.subTest(value=value): + self.assertTrue(constraint(value)) + + for value in [0, 4]: + with self.subTest(value=value): + self.assertFalse(constraint(value)) + + # Test exclusive minimum 
+ constraint = MinConstraint(5, inclusive=False) + + for value in [6, 7, 10]: + with self.subTest(value=value): + self.assertTrue(constraint(value)) + + for value in [4, 5]: + with self.subTest(value=value): + self.assertFalse(constraint(value)) + + # Test error conditions + with self.assertRaises(ValueError): + MinConstraint(5, inclusive="invalid") + + with self.assertRaises(ValueError): + constraint("invalid") + + # Test equality and hash + c1 = MinConstraint(5) + c2 = MinConstraint(5) + c3 = MinConstraint(5, inclusive=False) + + self.assertEqual(c1, c2) + self.assertNotEqual(c1, c3) + self.assertEqual(hash(c1), hash(c2)) + self.assertNotEqual(hash(c1), hash(c3)) + + def test_max_constraint(self) -> None: + """Test MaxConstraint functionality.""" + # Test inclusive maximum + constraint = MaxConstraint(5) + + for value in [0, 4, 5]: + with self.subTest(value=value): + self.assertTrue(constraint(value)) + + for value in [6, 10]: + with self.subTest(value=value): + self.assertFalse(constraint(value)) + + # Test exclusive maximum + constraint = MaxConstraint(5, inclusive=False) + + for value in [0, 3, 4]: + with self.subTest(value=value): + self.assertTrue(constraint(value)) + + for value in [5, 6]: + with self.subTest(value=value): + self.assertFalse(constraint(value)) + + # Test error conditions + with self.assertRaises(ValueError): + MaxConstraint(5, inclusive="invalid") + + with self.assertRaises(ValueError): + constraint("invalid") + + # Test equality and hash + c1 = MaxConstraint(5) + c2 = MaxConstraint(5) + c3 = MaxConstraint(5, inclusive=False) + + self.assertEqual(c1, c2) + self.assertNotEqual(c1, c3) + self.assertEqual(hash(c1), hash(c2)) + self.assertNotEqual(hash(c1), hash(c3)) + + def test_choice_constraint(self) -> None: + """Test ChoiceConstraint functionality.""" + choices = {1, "test", 3.14} + constraint = ChoiceConstraint(*choices) + + # Test valid choices + for value in choices: + with self.subTest(value=value): + self.assertTrue(constraint(value)) + + # Test invalid choices + invalid_values = [0, "invalid", 2.71] + for value in invalid_values: + with self.subTest(value=value): + with self.assertRaises(ValueError): + constraint(value) + + # Test with different types of choices + constraint = ChoiceConstraint(None, True, 42) + self.assertTrue(constraint(None)) + self.assertTrue(constraint(True)) + self.assertTrue(constraint(42)) + + with self.assertRaises(ValueError): + constraint(False) + + # Test equality and hash + c1 = ChoiceConstraint(1, 2, 3) + c2 = ChoiceConstraint(1, 2, 3) + c3 = ChoiceConstraint(1, 2, 4) + + self.assertEqual(c1, c2) + self.assertNotEqual(c1, c3) + self.assertEqual(hash(c1), hash(c2)) + self.assertNotEqual(hash(c1), hash(c3)) + + def test_constraint_combinations(self) -> None: + """Test using multiple constraints together.""" + range_constraint = RangeConstraint(1, 10) + choice_constraint = ChoiceConstraint(2, 4, 6, 8) + + # Value should satisfy both constraints + value = 4 + self.assertTrue(range_constraint(value)) + self.assertTrue(choice_constraint(value)) + + # Value in range but not in choices + value = 3 + self.assertTrue(range_constraint(value)) + with self.assertRaises(ValueError): + choice_constraint(value) + + # Value in choices but not in range + value = 12 + with self.assertRaises(ValueError): + choice_constraint(value) + self.assertFalse(range_constraint(value)) + + +if __name__ == '__main__': + unittest.main(verbosity=2) \ No newline at end of file diff --git a/quickstats/tests/test_enums.py b/quickstats/tests/test_enums.py new file 
mode 100644 index 0000000000000000000000000000000000000000..6fbb8eb3db9aac72db6f2dc03d9e7286db7b11c8 --- /dev/null +++ b/quickstats/tests/test_enums.py @@ -0,0 +1,124 @@ +"""Unit tests for core.enums module.""" + +import unittest +from quickstats.core.enums import ( + CaseInsensitiveStrEnum, + GeneralEnum, + DescriptiveEnum +) + +class TestCaseInsensitiveStrEnum(unittest.TestCase): + """Test CaseInsensitiveStrEnum functionality.""" + + def setUp(self): + class Format(CaseInsensitiveStrEnum): + JSON = "json" + XML = "xml" + self.Format = Format + + def test_case_insensitive_match(self): + """Test case-insensitive string matching.""" + self.assertEqual(self.Format("JSON"), self.Format.JSON) + self.assertEqual(self.Format("json"), self.Format.JSON) + self.assertEqual(self.Format("JsOn"), self.Format.JSON) + + def test_invalid_value(self): + """Test handling of invalid values.""" + with self.assertRaises(ValueError): + self.Format("INVALID") + self.assertIsNone(self.Format._missing_(123)) + +class TestGeneralEnum(unittest.TestCase): + """Test GeneralEnum functionality.""" + + def setUp(self): + class Status(GeneralEnum): + ACTIVE = 1 + INACTIVE = 2 + __aliases__ = { + "enabled": "active", + "disabled": "inactive" + } + self.Status = Status + + def test_parse_methods(self): + """Test various parsing methods.""" + self.assertEqual(self.Status.parse("active"), self.Status.ACTIVE) + self.assertEqual(self.Status.parse("ACTIVE"), self.Status.ACTIVE) + self.assertEqual(self.Status.parse(1), self.Status.ACTIVE) + self.assertEqual(self.Status.parse("enabled"), self.Status.ACTIVE) + self.assertIsNone(self.Status.parse(None)) + + def test_equality(self): + """Test equality comparisons.""" + self.assertEqual(self.Status.ACTIVE, "active") + self.assertEqual(self.Status.ACTIVE, 1) + self.assertNotEqual(self.Status.ACTIVE, "inactive") + self.assertNotEqual(self.Status.ACTIVE, 2) + + def test_aliases(self): + """Test alias functionality.""" + self.assertEqual(self.Status.parse("enabled"), self.Status.ACTIVE) + self.assertEqual(self.Status.parse("disabled"), self.Status.INACTIVE) + + def test_member_lookup(self): + """Test member lookup methods.""" + self.assertTrue(self.Status.has_member("active")) + self.assertTrue(self.Status.has_member("ACTIVE")) + self.assertFalse(self.Status.has_member("invalid")) + + def test_invalid_values(self): + """Test handling of invalid values.""" + with self.assertRaises(ValueError): + self.Status.parse("INVALID") + with self.assertRaises(ValueError): + self.Status.parse(999) + + def test_mapping_methods(self): + """Test mapping utility methods.""" + self.assertEqual( + set(self.Status.get_members()), + {"active", "inactive"} + ) + self.assertEqual( + self.Status.get_values_map(), + {1: self.Status.ACTIVE, 2: self.Status.INACTIVE} + ) + self.assertEqual( + self.Status.get_aliases_map(), + {"enabled": "active", "disabled": "inactive"} + ) + +class TestDescriptiveEnum(unittest.TestCase): + """Test DescriptiveEnum functionality.""" + + def setUp(self): + class Level(DescriptiveEnum): + HIGH = (1, "High priority") + MEDIUM = (2, "Medium priority") + LOW = (3, "Low priority") + self.Level = Level + + def test_descriptions(self): + """Test description attribute access.""" + self.assertEqual(self.Level.HIGH.description, "High priority") + self.assertEqual(self.Level.MEDIUM.description, "Medium priority") + self.assertEqual(self.Level.LOW.description, "Low priority") + + def test_parse_with_description(self): + """Test parsing with description preservation.""" + level = 
self.Level.parse("HIGH") + self.assertEqual(level.description, "High priority") + self.assertEqual(level.value, 1) + + def test_invalid_with_descriptions(self): + """Test error messages include descriptions.""" + with self.assertRaises(ValueError) as cm: + self.Level.parse("INVALID") + error_msg = str(cm.exception) + self.assertIn("High priority", error_msg) + self.assertIn("Medium priority", error_msg) + self.assertIn("Low priority", error_msg) + +if __name__ == '__main__': + unittest.main() diff --git a/quickstats/tests/test_flexible_dumper.py b/quickstats/tests/test_flexible_dumper.py index 47a1df04aa7e9022e1852b5b7ce419e35cbcb8cc..a77fc988618147d72ede7c39593589ce196a862e 100644 --- a/quickstats/tests/test_flexible_dumper.py +++ b/quickstats/tests/test_flexible_dumper.py @@ -1,72 +1,227 @@ +"""Unit tests for FlexibleDumper class.""" + +from __future__ import annotations + import unittest +from typing import Any import numpy as np from quickstats import FlexibleDumper + class TestFlexibleDumper(unittest.TestCase): - - def test_default_settings(self): - data = [1, 2, [[], [{}, {1: 2, "a": 'b'}]], (5, 3), np.arange(100), 'I\nlove\nyou', [dict, any, {1, 2, 3}, 'ä½ ']] - expected_output = "- 1\n- 2\n- - []\n - - {}\n - 1: 2\n a: b\n- - 5\n - 3\n- [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23\n 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47\n 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71\n 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95\n 96 97 98 99]\n- I\n love\n you\n- - <class 'dict'>\n - <built-in function any>\n - {1, 2, 3}\n - ä½ " - dumper = FlexibleDumper() - result = dumper.dump(data) - self.assertEqual(result, expected_output) + """Test suite for FlexibleDumper class.""" - def test_indent_and_separator_settings(self): - data = {"key": "value", "list": [1, 2, 3, 5]} - dumper = FlexibleDumper(item_indent='--> ', list_indent='*-> ', separator=' => ', skip_str='[...]', max_iteration=3) - expected_output = "key => value\nlist => \n--> *-> 1\n--> *-> 2\n--> *-> 3\n--> [...]" - result = dumper.dump(data) - self.assertEqual(result, expected_output) + def setUp(self) -> None: + """Initialize test fixtures.""" + self.dumper = FlexibleDumper() + self.basic_data = { + 'str': 'hello', + 'int': 42, + 'list': [1, 2, 3], + 'dict': {'a': 1, 'b': 2}, + 'set': {1, 2, 3}, + 'tuple': (4, 5, 6) + } + self.complex_data = { + 'nested': { + 'deep': { + 'deeper': { + 'deepest': 'value' + } + } + }, + 'mixed': [ + 1, + {'a': [2, 3]}, + (4, [5, {'b': 6}]) + ], + 'array': np.arange(10), + 'multiline': 'line1\nline2\nline3' + } - def test_indent_sequence_on_key(self): - data = {"key": [1, 2, {"nested": "value"}]} + def test_initialization(self) -> None: + """Test initialization with different parameters.""" + # Default initialization + self.assertEqual(self.dumper.item_indent, ' ') + self.assertEqual(self.dumper.list_indent, '- ') + + # Custom initialization + custom_dumper = FlexibleDumper( + item_indent=' ', + list_indent='* ', + separator=' = ', + skip_str='...' 
+ ) + self.assertEqual(custom_dumper.item_indent, ' ') + self.assertEqual(custom_dumper.list_indent, '* ') + self.assertEqual(custom_dumper.separator, ' = ') + self.assertEqual(custom_dumper.skip_str, '...') + + # Invalid initialization + with self.assertRaises(ValueError): + FlexibleDumper(item_indent=' ', list_indent='---') + + def test_basic_types(self) -> None: + """Test dumping of basic Python types.""" + cases = [ + (42, "42"), + ("hello", "hello"), + (True, "True"), + (None, "None"), + ([1, 2], "- 1\n- 2"), + ((1, 2), "- 1\n- 2"), + ({1, 2}, "{1, 2}"), + ({'a': 1}, "a: 1") + ] + + for data, expected in cases: + with self.subTest(data=data): + result = self.dumper.dump(data) + self.assertEqual(result, expected) + + def test_nested_structures(self) -> None: + """Test dumping of nested data structures.""" + data = { + 'list': [1, [2, 3], {'a': 4}], + 'dict': {'x': {'y': 'z'}}, + 'mixed': [1, {'a': [2, {'b': 3}]}] + } dumper = FlexibleDumper(indent_sequence_on_key=True) - expected_output = "key: \n - 1\n - 2\n - nested: value" result = dumper.dump(data) - self.assertEqual(result, expected_output) + + expected = ( + "list: \n" + " - 1\n" + " - - 2\n" + " - 3\n" + " - a: 4\n" + "dict: \n" + " x: \n" + " y: z\n" + "mixed: \n" + " - 1\n" + " - a: \n" + " - 2\n" + " - b: 3" + ) + self.assertEqual(result, expected) - def test_max_depth(self): - data = {"a": {"b": {"c": {"d": "e"}}}} + def test_limits(self) -> None: + """Test various limiting parameters.""" + # Test max_depth + deep_data = {'a': {'b': {'c': {'d': 'e'}}}} dumper = FlexibleDumper(max_depth=2) - expected_output = "a: \n b: \n c: \n ..." - result = dumper.dump(data) - self.assertEqual(result, expected_output) + self.assertEqual( + dumper.dump(deep_data), + "a: \n b: \n ..." + ) - def test_max_iteration(self): - data = [[1, 2, 3, 4, 5], {1:2, 3:4, 5:6, 7:8}] + # Test max_iteration + long_list = list(range(10)) dumper = FlexibleDumper(max_iteration=3) - expected_output = "- - 1\n - 2\n - 3\n ...\n- 1: 2\n 3: 4\n 5: 6\n 7: 8" - result = dumper.dump(data) - self.assertEqual(result, expected_output) + self.assertEqual( + dumper.dump(long_list), + "- 0\n- 1\n- 2\n..." + ) - def test_max_item(self): - data = [[1, 2, 3, 4, 5], {1:2, 3:4, 5:6, 7:8}] + # Test max_item + big_dict = {str(i): i for i in range(10)} dumper = FlexibleDumper(max_item=3) - expected_output = "- - 1\n - 2\n - 3\n - 4\n - 5\n- 1: 2\n 3: 4\n 5: 6\n ..." - result = dumper.dump(data) - self.assertEqual(result, expected_output) + result = dumper.dump(big_dict) + self.assertIn("...", result) + self.assertTrue(len(result.split('\n')) <= 4) - def test_max_line(self): - data = [1, 2, 3, 4, 5] - dumper = FlexibleDumper(max_line=3) - expected_output = "- 1\n- 2\n- 3\n..." - result = dumper.dump(data) - self.assertEqual(result, expected_output) + # Test max_len + long_text = "This is a very long text string" + dumper = FlexibleDumper(max_len=15) + self.assertEqual( + dumper.dump(long_text), + "This is a ve..." + ) - def test_max_len(self): - data = {"a": "abcdefghijklmnopqrstuvwxyz"} - dumper = FlexibleDumper(max_len=10) - expected_output = "a: abcdefg..." 
- result = dumper.dump(data) - self.assertEqual(result, expected_output) + def test_multiline_handling(self) -> None: + """Test handling of multiline strings.""" + data = { + 'text': 'line1\nline2\nline3', + 'nested': { + 'text': 'a\nb\nc' + } + } + result = self.dumper.dump(data) + expected = ( + "text: line1\n" + " line2\n" + " line3\n" + "nested: \n" + " text: a\n" + " b\n" + " c" + ) + self.assertEqual(result, expected) + + def test_numpy_array_handling(self) -> None: + """Test handling of NumPy arrays.""" + arrays = { + 'int': np.array([1, 2, 3]), + '2d': np.array([[1, 2], [3, 4]]), + 'float': np.array([1.5, 2.5, 3.5]) + } + result = self.dumper.dump(arrays) + self.assertIn("[1 2 3]", result) + self.assertIn("[[1 2]", result) + self.assertIn(" [3 4]]", result) + self.assertIn("[1.5 2.5 3.5]", result) + + def test_edge_cases(self) -> None: + """Test edge cases and potential error conditions.""" + cases = [ + ({}, ""), # Empty dict + ([], ""), # Empty list + (" ", " "), # Whitespace + ("", ""), # Empty string + (set(), "set()"), # Empty set + ({"": ""}, ": "), # Empty keys/values + ([None, None], "- None\n- None"), # None values + ] + + for data, expected in cases: + with self.subTest(data=data): + result = self.dumper.dump(data) + self.assertEqual(result, expected) + + def test_configuration_changes(self) -> None: + """Test dynamic configuration changes.""" + dumper = FlexibleDumper() + data = {'a': [1, 2, 3]} + + # Default output + default_output = dumper.dump(data) + + # Change configuration + dumper.configure( + item_indent='>>', + list_indent='*>', + separator=' = ' + ) + + modified_output = dumper.dump(data) + self.assertNotEqual(default_output, modified_output) + self.assertIn('*>', modified_output) + self.assertIn(' = ', modified_output) + + def test_invalid_configurations(self) -> None: + """Test handling of invalid configurations.""" + with self.assertRaises(ValueError): + FlexibleDumper(item_indent='', list_indent='') + + with self.assertRaises(ValueError): + FlexibleDumper(item_indent=' ', list_indent='-') + + dumper = FlexibleDumper() + with self.assertRaises(KeyError): + dumper.configure(invalid_param='value') - def test_all_constraints(self): - data = [[1, 2, 3, 4, 5, 6], [2, "abcdefghijklmnopq"], {"a": 1, "b": 2, 5: [[], [1, 5, []]], 2: 3}, [1, [1, [[1, 2]]]], 6] - dumper = FlexibleDumper(max_depth=2, max_iteration=5, max_item=3, max_line=14, max_len=20) - expected_output = "- - 1\n - 2\n - 3\n - 4\n - 5\n ...\n- - 2\n - abcdefghijklmnop...\n- a: 1\n b: 2\n 5: \n - ...\n ...\n- - 1\n..." 
- result = dumper.dump(data) - self.assertEqual(result, expected_output) if __name__ == '__main__': - unittest.main() + unittest.main(verbosity=2) \ No newline at end of file diff --git a/quickstats/tests/test_io.py b/quickstats/tests/test_io.py new file mode 100644 index 0000000000000000000000000000000000000000..5a85772d51a29543ca52cbb8f004c5cda3348347 --- /dev/null +++ b/quickstats/tests/test_io.py @@ -0,0 +1,209 @@ +"""Unit tests for core.io module.""" + +import unittest +import sys +import io +from unittest.mock import patch, MagicMock +from contextlib import contextmanager +from quickstats.core.io import ( + TextColors, + Verbosity, + VerbosePrint, + switch_verbosity, + set_default_log_format +) + +@contextmanager +def captured_output(): + """Context manager to capture stdout.""" + new_out = io.StringIO() + old_out = sys.stdout + try: + sys.stdout = new_out + yield new_out + finally: + sys.stdout = old_out + +class TestTextColors(unittest.TestCase): + """Test TextColors class functionality.""" + + def test_colorize(self): + """Test text colorization.""" + colored = TextColors.colorize("test", "red") + self.assertIn("test", colored) + self.assertIn(TextColors.CODES["red"], colored) + self.assertIn(TextColors.CODES["reset"], colored) + + def test_invalid_color(self): + """Test handling of invalid colors.""" + self.assertEqual(TextColors.colorize("test", None), "test") + self.assertEqual(TextColors.colorize("test", "invalid"), "test") + + def test_format_comparison(self): + """Test text comparison formatting.""" + left, right = TextColors.format_comparison( + "abc", "abd", + equal_color="blue", + delete_color="red", + insert_color="green" + ) + self.assertIn(TextColors.CODES["blue"], left) + self.assertIn(TextColors.CODES["red"], left) + self.assertIn(TextColors.CODES["blue"], right) + self.assertIn(TextColors.CODES["green"], right) + +class TestVerbosity(unittest.TestCase): + """Test Verbosity enum functionality.""" + + def test_ordering(self): + """Test verbosity level ordering.""" + self.assertLess(Verbosity.DEBUG, Verbosity.INFO) + self.assertLess(Verbosity.INFO, Verbosity.WARNING) + self.assertLess(Verbosity.WARNING, Verbosity.ERROR) + self.assertLess(Verbosity.ERROR, Verbosity.CRITICAL) + + def test_comparison(self): + """Test verbosity comparison with different types.""" + self.assertLess(Verbosity.DEBUG, 20) # INFO level + self.assertLess(Verbosity.DEBUG, "INFO") + self.assertEqual(Verbosity.INFO, "INFO") + self.assertEqual(Verbosity.INFO, 20) + +class TestVerbosePrint(unittest.TestCase): + """Test VerbosePrint class functionality.""" + + def setUp(self): + self.printer = VerbosePrint( + verbosity=Verbosity.INFO, + fmt="basic", + name="test" + ) + + def test_verbosity_setting(self): + """Test verbosity level setting and parsing.""" + self.printer.verbosity = "DEBUG" + self.assertEqual(self.printer.verbosity, Verbosity.DEBUG) + self.printer.verbosity = Verbosity.INFO + self.assertEqual(self.printer.verbosity, Verbosity.INFO) + self.printer.verbosity = 30 # WARNING + self.assertEqual(self.printer.verbosity, Verbosity.WARNING) + + def test_basic_output(self): + """Test basic message output.""" + with captured_output() as out: + self.printer.info("test message") + output = out.getvalue() + self.assertIn("[INFO]", output) + self.assertIn("test message", output) + + def test_format_switching(self): + """Test format switching.""" + self.printer.set_format("detailed") + with captured_output() as out: + self.printer.info("test") + detailed = out.getvalue() + self.assertIn("PID:", 
detailed) + self.assertIn("TID:", detailed) + + self.printer.set_format("basic") + with captured_output() as out: + self.printer.info("test") + basic = out.getvalue() + self.assertNotIn("PID:", basic) + self.assertNotIn("TID:", basic) + + def test_time_formatting(self): + """Test time format customization.""" + self.printer.set_format("detailed") + self.printer.set_timefmt(datefmt="%H:%M:%S") + + with captured_output() as out: + self.printer.info("test") + output = out.getvalue() + # Check time format (HH:MM:SS.mmm) + self.assertRegex(output, r"\d{2}:\d{2}:\d{2}\.\d{3}") + + def test_verbosity_filtering(self): + """Test message filtering by verbosity.""" + self.printer.verbosity = Verbosity.WARNING + with captured_output() as out: + self.printer.debug("debug") # Should not print + self.printer.info("info") # Should not print + self.printer.warning("warn") # Should print + self.printer.error("error") # Should print + + output = out.getvalue() + self.assertNotIn("debug", output) + self.assertNotIn("info", output) + self.assertIn("warn", output) + self.assertIn("error", output) + + def test_bare_output(self): + """Test bare (unformatted) output.""" + with captured_output() as out: + self.printer.info("test", bare=True) + output = out.getvalue() + self.assertEqual(output.strip(), "test") + + def test_color_output(self): + """Test colored output.""" + with captured_output() as out: + self.printer.info("test", color="red") + output = out.getvalue() + self.assertIn(TextColors.CODES["red"], output) + self.assertIn(TextColors.CODES["reset"], output) + + def test_copy(self): + """Test printer copying.""" + copy = self.printer.copy() + self.assertEqual(copy.verbosity, self.printer.verbosity) + self.assertEqual(copy._formatter._fmt, self.printer._formatter._fmt) + self.assertEqual(copy._name, self.printer._name) + +class TestVerbositySwitch(unittest.TestCase): + """Test verbosity switching context manager.""" + + def test_switch_verbosity(self): + """Test temporary verbosity switching.""" + printer = VerbosePrint(verbosity=Verbosity.INFO) + + self.assertEqual(printer.verbosity, Verbosity.INFO) + + with switch_verbosity(printer, Verbosity.DEBUG): + self.assertEqual(printer.verbosity, Verbosity.DEBUG) + + self.assertEqual(printer.verbosity, Verbosity.INFO) + + def test_switch_verbosity_with_exception(self): + """Test verbosity restoration after exception.""" + printer = VerbosePrint(verbosity=Verbosity.INFO) + + try: + with switch_verbosity(printer, Verbosity.DEBUG): + raise Exception("test error") + except Exception: + pass + + self.assertEqual(printer.verbosity, Verbosity.INFO) + +class TestDefaultFormat(unittest.TestCase): + """Test default format setting.""" + + def test_set_default_format(self): + """Test setting default format.""" + set_default_log_format("detailed") + printer = VerbosePrint() # Should use new default + self.assertEqual( + printer._formatter._fmt, + VerbosePrint.FORMATS["detailed"] + ) + + set_default_log_format("basic") # Reset to default + printer = VerbosePrint() + self.assertEqual( + printer._formatter._fmt, + VerbosePrint.FORMATS["basic"] + ) + +if __name__ == '__main__': + unittest.main() diff --git a/quickstats/tests/test_mappings.py b/quickstats/tests/test_mappings.py new file mode 100644 index 0000000000000000000000000000000000000000..ced5fed44074279893b7613bc8bb26924b3023c2 --- /dev/null +++ b/quickstats/tests/test_mappings.py @@ -0,0 +1,227 @@ +"""Unit tests for core.mappings module.""" + +import unittest +from copy import deepcopy +from typing import Dict, Any + +from 
quickstats.core.mappings import ( + recursive_update, + concatenate, + merge_classattr, + NestedDict +) + +class TestRecursiveUpdate(unittest.TestCase): + """Test recursive_update functionality.""" + + def setUp(self): + self.base_dict = { + 'a': 1, + 'b': { + 'c': 2, + 'd': {'e': 3} + } + } + + def test_simple_update(self): + """Test basic dictionary update.""" + update_dict = {'a': 10, 'x': 20} + result = recursive_update(self.base_dict.copy(), update_dict) + + self.assertEqual(result['a'], 10) + self.assertEqual(result['x'], 20) + self.assertEqual(result['b']['c'], 2) + + def test_nested_update(self): + """Test updating nested dictionaries.""" + update_dict = { + 'b': { + 'c': 20, + 'd': {'f': 30} + } + } + result = recursive_update(self.base_dict.copy(), update_dict) + + self.assertEqual(result['b']['c'], 20) + self.assertEqual(result['b']['d']['e'], 3) + self.assertEqual(result['b']['d']['f'], 30) + + def test_empty_update(self): + """Test update with empty dictionary.""" + original = self.base_dict.copy() + result = recursive_update(original, {}) + + self.assertEqual(result, original) + + def test_none_values(self): + """Test handling of None values.""" + update_dict = {'a': None, 'b': {'c': None}} + result = recursive_update(self.base_dict.copy(), update_dict) + + self.assertIsNone(result['a']) + self.assertIsNone(result['b']['c']) + + def test_new_nested_structure(self): + """Test creating new nested structures.""" + update_dict = {'new': {'nested': {'value': 42}}} + result = recursive_update(self.base_dict.copy(), update_dict) + + self.assertEqual(result['new']['nested']['value'], 42) + +class TestConcatenate(unittest.TestCase): + """Test concatenate functionality.""" + + def setUp(self): + self.dict1 = {'a': 1, 'b': {'c': 2}} + self.dict2 = {'b': {'d': 3}, 'e': 4} + + def test_basic_concatenation(self): + """Test basic dictionary concatenation.""" + result = concatenate([self.dict1, self.dict2]) + + self.assertEqual(result['a'], 1) + self.assertEqual(result['b']['c'], 2) + self.assertEqual(result['b']['d'], 3) + self.assertEqual(result['e'], 4) + + def test_copy_option(self): + """Test copy parameter behavior.""" + # Without copy + result1 = concatenate([self.dict1, self.dict2], copy=False) + self.dict1['a'] = 100 + self.assertEqual(result1['a'], 1) # Should not be affected + + # With copy + result2 = concatenate([self.dict1, self.dict2], copy=True) + self.dict1['a'] = 200 + self.assertEqual(result2['a'], 1) # Should not be affected + + def test_none_handling(self): + """Test handling of None values in input sequence.""" + result = concatenate([self.dict1, None, self.dict2]) + + self.assertEqual(result['a'], 1) + self.assertEqual(result['e'], 4) + + def test_empty_input(self): + """Test concatenation with empty input.""" + result = concatenate([]) + self.assertEqual(result, {}) + +class TestNestedDict(unittest.TestCase): + """Test NestedDict class functionality.""" + + def setUp(self): + self.nested = NestedDict({'a': 1, 'b': {'c': 2, 'd': 3}}) + + def test_merge(self): + """Test merge method.""" + self.nested.merge({'b': {'e': 4}, 'f': 5}) + + self.assertEqual(self.nested['b']['c'], 2) + self.assertEqual(self.nested['b']['e'], 4) + self.assertEqual(self.nested['f'], 5) + + def test_merge_none(self): + """Test merge with None.""" + original = deepcopy(self.nested) + self.nested.merge(None) + + self.assertEqual(self.nested, original) + + def test_and_operator(self): + """Test & operator.""" + result = self.nested & {'b': {'e': 4}} + + self.assertEqual(result['b']['c'], 2) + 
self.assertEqual(result['b']['e'], 4) + self.assertIsInstance(result, NestedDict) + + def test_iand_operator(self): + """Test &= operator.""" + self.nested &= {'b': {'e': 4}} + + self.assertEqual(self.nested['b']['c'], 2) + self.assertEqual(self.nested['b']['e'], 4) + + def test_ror_operator(self): + """Test reverse or operator.""" + result = {'b': {'e': 4}} | self.nested + + self.assertEqual(result['b']['c'], 2) + self.assertEqual(result['b']['e'], 4) + self.assertIsInstance(result, NestedDict) + + def test_copy(self): + """Test copy method.""" + # Shallow copy + shallow = self.nested.copy() + shallow['a'] = 100 + self.assertEqual(self.nested['a'], 1) + shallow['b']['c'] = 200 + self.assertEqual(self.nested['b']['c'], 200) # Nested dict is shared + + # Deep copy + deep = self.nested.copy(deep=True) + deep['b']['c'] = 300 + self.assertEqual(self.nested['b']['c'], 200) # Nested dict is separate + +class TestMergeClassAttr(unittest.TestCase): + """Test merge_classattr functionality.""" + + def setUp(self): + class Base: + data = {'a': 1} + + class Child(Base): + data = {'b': 2} + + class GrandChild(Child): + data = {'c': 3} + + self.Base = Base + self.Child = Child + self.GrandChild = GrandChild + + def test_basic_merge(self): + """Test basic attribute merging.""" + result = merge_classattr(self.Child, 'data') + + self.assertEqual(result['a'], 1) + self.assertEqual(result['b'], 2) + + def test_multi_level_merge(self): + """Test multi-level inheritance merging.""" + result = merge_classattr(self.GrandChild, 'data') + + self.assertEqual(result['a'], 1) + self.assertEqual(result['b'], 2) + self.assertEqual(result['c'], 3) + + def test_copy_option(self): + """Test copy parameter behavior.""" + result = merge_classattr(self.Child, 'data', copy=True) + self.Base.data['a'] = 100 + + self.assertEqual(result['a'], 1) # Should not be affected + + def test_missing_attribute(self): + """Test handling of missing attributes.""" + class Empty: + pass + + result = merge_classattr(Empty, 'data') + self.assertEqual(result, {}) + + def test_custom_parser(self): + """Test custom parser function.""" + def parser(data: Dict[str, Any]) -> Dict[str, Any]: + return {k.upper(): v * 2 for k, v in data.items()} + + result = merge_classattr(self.Child, 'data', parse=parser) + + self.assertEqual(result['A'], 2) + self.assertEqual(result['B'], 4) + +if __name__ == '__main__': + unittest.main() diff --git a/quickstats/tests/test_trees.py b/quickstats/tests/test_trees.py new file mode 100644 index 0000000000000000000000000000000000000000..6bff19f40271b16fe12e8a21d4ef05be64e6fe7b --- /dev/null +++ b/quickstats/tests/test_trees.py @@ -0,0 +1,199 @@ +"""Unit tests for core.trees module.""" + +import unittest +from typing import Optional, Any + +from quickstats.core.trees import ( + NodeConfig, + NamedTreeNode, + TreeError, + InvalidNodeError, + DomainError, + ValidationError +) + +class TestNodeConfig(unittest.TestCase): + """Test NodeConfig functionality.""" + + def test_default_config(self): + """Test default configuration values.""" + config = NodeConfig() + + self.assertEqual(config.separator, '.') + self.assertTrue(config.allow_none_data) + self.assertFalse(config.validate_names) + self.assertFalse(config.validate_data_type) + self.assertEqual(config.name_pattern, r'^[a-zA-Z][a-zA-Z0-9_]*$') + +class TestNamedTreeNode(unittest.TestCase): + """Test NamedTreeNode functionality.""" + + def setUp(self): + self.root = NamedTreeNode[str]("root", "root_data") + + def test_basic_operations(self): + """Test basic node 
operations."""
+        # Add child
+        child = NamedTreeNode("child1", "child1_data")
+        self.root.add_child(child)
+
+        self.assertEqual(self.root.get("child1"), "child1_data")
+        self.assertEqual(len(list(self.root)), 1)
+
+        # Remove child
+        removed = self.root.remove_child("child1")
+        self.assertEqual(removed.data, "child1_data")
+        self.assertEqual(len(list(self.root)), 0)
+
+    def test_domain_operations(self):
+        """Test domain-style operations."""
+        # Set with domain
+        self.root.set("data1", "child1.grandchild1")
+        self.root.set("data2", "child1.grandchild2")
+
+        # Get with domain
+        self.assertEqual(self.root.get("child1.grandchild1"), "data1")
+        self.assertEqual(self.root.get("child1.grandchild2"), "data2")
+
+        # Check domain exists
+        self.assertTrue("child1.grandchild1" in self.root)
+        self.assertFalse("invalid.path" in self.root)
+
+    def test_traversal(self):
+        """Test tree traversal."""
+        # Setup tree
+        self.root["a.b.c"] = "data1"
+        self.root["a.b.d"] = "data2"
+
+        # Test traverse method
+        node = self.root.traverse("a", "b")
+        self.assertIsNotNone(node)
+        self.assertEqual(node.get("c"), "data1")
+        self.assertEqual(node.get("d"), "data2")
+
+        # Test traverse_domain method
+        node = self.root.traverse_domain("a.b")
+        self.assertIsNotNone(node)
+        self.assertEqual(node.get("c"), "data1")
+
+    def test_dict_operations(self):
+        """Test dictionary-style operations."""
+        # Dictionary update
+        self.root |= {
+            "name": "child1",
+            "data": "child1_data",
+            "children": {
+                "grandchild1": {
+                    "name": "grandchild1",
+                    "data": "gc1_data"
+                }
+            }
+        }
+
+        # The children mapping becomes direct children of the updated node
+        self.assertEqual(self.root.get("grandchild1"), "gc1_data")
+
+    def test_validation(self):
+        """Test node validation."""
+        # Enable validation
+        self.root.config.validate_names = True
+
+        # Valid name
+        valid_node = NamedTreeNode("valid_name", "data")
+        self.root.add_child(valid_node)
+
+        # Invalid name
+        with self.assertRaises(ValidationError):
+            invalid_node = NamedTreeNode("123invalid", "data")
+            self.root.add_child(invalid_node)
+
+    def test_type_validation(self):
+        """Test data type validation."""
+        typed_node = NamedTreeNode[int]("numbers")
+        typed_node.config.validate_data_type = True
+
+        # Valid type
+        typed_node.set(42, "valid")
+
+        # Invalid type
+        with self.assertRaises(ValidationError):
+            typed_node.set("not a number", "invalid")
+
+    def test_copy_operations(self):
+        """Test copy operations."""
+        # Setup original
+        self.root["a.b.c"] = "data1"
+
+        # Shallow copy
+        shallow = self.root.copy()
+        shallow["a.b.c"] = "changed"
+        self.assertEqual(self.root["a.b.c"], "data1")
+
+        # Deep copy
+        deep = self.root.copy(deep=True)
+        deep["a.b.c"] = "changed"
+        self.assertEqual(self.root["a.b.c"], "data1")
+
+    def test_merge_operations(self):
+        """Test merge operations."""
+        # Setup trees
+        tree1 = NamedTreeNode[str]("root")
+        tree1["a.b"] = "data1"
+
+        tree2 = NamedTreeNode[str]("root")
+        tree2["a.c"] = "data2"
+
+        # Test merge
+        tree1.merge(tree2)
+        self.assertEqual(tree1["a.b"], "data1")
+        self.assertEqual(tree1["a.c"], "data2")
+
+    def test_error_handling(self):
+        """Test error handling."""
+        # Invalid node type
+        with self.assertRaises(TypeError):
+            self.root.add_child("not a node")
+
+        # Domain not found
+        with self.assertRaises(KeyError):
+            _ = self.root["invalid.path"]
+
+        # Duplicate child name
+        child = NamedTreeNode("child", "data")
+        self.root.add_child(child)
+        with self.assertRaises(ValidationError):
+            self.root.add_child(child)
+
+class TestTreeDataTypes(unittest.TestCase):
+    """Test tree with different data types."""
+
+    def 
+        """Test tree with integer data."""
+        tree = NamedTreeNode[int]("numbers")
+        tree["a"] = 1
+        tree["b"] = 2
+
+        self.assertEqual(tree["a"], 1)
+        self.assertIsInstance(tree["a"], int)
+
+    def test_optional_data(self):
+        """Test tree with optional data."""
+        tree = NamedTreeNode[Optional[str]]("optional")
+        tree["a"] = "data"
+        tree["b"] = None
+
+        self.assertEqual(tree["a"], "data")
+        self.assertIsNone(tree["b"])
+
+    def test_any_data(self):
+        """Test tree with Any data type."""
+        tree = NamedTreeNode[Any]("any")
+        tree["a"] = 1
+        tree["b"] = "string"
+        tree["c"] = [1, 2, 3]
+
+        self.assertEqual(tree["a"], 1)
+        self.assertEqual(tree["b"], "string")
+        self.assertEqual(tree["c"], [1, 2, 3])
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/quickstats/tests/test_type_validation.py b/quickstats/tests/test_type_validation.py
new file mode 100644
index 0000000000000000000000000000000000000000..bed02311666806ee2aabd1f8e44b40e3dd90978c
--- /dev/null
+++ b/quickstats/tests/test_type_validation.py
@@ -0,0 +1,188 @@
+"""Unit tests for type validation module."""
+
+from __future__ import annotations
+
+import unittest
+from typing import (
+    List, Dict, Union, Optional, Tuple, Set,
+    TypeVar, Any, Generic
+)
+from dataclasses import dataclass
+
+from quickstats.core.type_validation import (
+    ValidatorFactory,
+    check_type,
+    get_type_hint_str,
+    ValidationError
+)
+
+T = TypeVar('T')
+
+@dataclass
+class MockGeneric(Generic[T]):
+    value: T
+
+
+class TestTypeValidation(unittest.TestCase):
+    """Test suite for type validation functionality."""
+
+    def test_basic_types(self) -> None:
+        """Test validation of basic Python types."""
+        cases = [
+            (42, int, True),
+            ("hello", str, True),
+            (3.14, float, True),
+            (True, bool, True),
+            (42, str, False),
+            ("42", int, False),
+            (3.14, int, False),
+        ]
+
+        for value, type_hint, expected in cases:
+            with self.subTest(value=value, type=type_hint):
+                self.assertEqual(
+                    check_type(value, type_hint),
+                    expected,
+                    f"Failed for value {value} of type {type(value)}"
+                )
+
+    def test_container_types(self) -> None:
+        """Test validation of container types."""
+        cases = [
+            ([1, 2, 3], List[int], True),
+            ([1, "2", 3], List[int], False),
+            ({"a": 1, "b": 2}, Dict[str, int], True),
+            ({"a": "1"}, Dict[str, int], False),
+            ({1, 2, 3}, Set[int], True),
+            ((1, "2", 3.0), Tuple[int, str, float], True),
+            ((1, 2), Tuple[int, ...], True),
+        ]
+
+        for value, type_hint, expected in cases:
+            with self.subTest(value=value, type=type_hint):
+                self.assertEqual(check_type(value, type_hint), expected)
+
+    def test_optional_types(self) -> None:
+        """Test validation of Optional types."""
+        cases = [
+            (None, Optional[int], True),
+            (42, Optional[int], True),
+            ("hello", Optional[int], False),
+            (None, Optional[List[int]], True),
+            ([1, 2, 3], Optional[List[int]], True),
+            ([1, "2"], Optional[List[int]], False),
+        ]
+
+        for value, type_hint, expected in cases:
+            with self.subTest(value=value, type=type_hint):
+                self.assertEqual(check_type(value, type_hint), expected)
+
+    def test_union_types(self) -> None:
+        """Test validation of Union types."""
+        cases = [
+            (42, Union[int, str], True),
+            ("hello", Union[int, str], True),
+            (3.14, Union[int, str], False),
+            ([1, 2], Union[List[int], Dict[str, int]], True),
+            ({"a": 1}, Union[List[int], Dict[str, int]], True),
+            ({1: "a"}, Union[List[int], Dict[str, int]], False),
+        ]
+
+        for value, type_hint, expected in cases:
+            with self.subTest(value=value, type=type_hint):
+                self.assertEqual(check_type(value, type_hint), expected)
+
+    def test_nested_types(self) -> None:
+        """Test validation of nested type structures."""
+        cases = [
+            (
+                {"a": [1, 2], "b": [3, 4]},
+                Dict[str, List[int]],
+                True
+            ),
+            (
+                {"a": [1, "2"]},
+                Dict[str, List[int]],
+                False
+            ),
+            (
+                [[1, 2], [3, 4]],
+                List[List[int]],
+                True
+            ),
+            (
+                {"a": {"b": [1, 2]}},
+                Dict[str, Dict[str, List[int]]],
+                True
+            ),
+        ]
+
+        for value, type_hint, expected in cases:
+            with self.subTest(value=value, type=type_hint):
+                self.assertEqual(check_type(value, type_hint), expected)
+
+    def test_generic_types(self) -> None:
+        """Test validation of generic types."""
+        int_generic = MockGeneric[int](42)
+        str_generic = MockGeneric[str]("hello")
+
+        self.assertTrue(check_type(int_generic.value, int))
+        self.assertTrue(check_type(str_generic.value, str))
+        self.assertFalse(check_type(int_generic.value, str))
+        self.assertFalse(check_type(str_generic.value, int))
+
+    def test_type_hint_str(self) -> None:
+        """Test string representation of type hints."""
+        cases = [
+            (int, "int"),
+            (List[int], "List[int]"),
+            (Dict[str, int], "Dict[str, int]"),
+            (Optional[int], "Optional[int]"),
+            (Union[int, str], "int | str"),
+            (List[Dict[str, int]], "List[Dict[str, int]]"),
+            (Tuple[int, ...], "Tuple[int, ...]"),
+        ]
+
+        for type_hint, expected in cases:
+            with self.subTest(type=type_hint):
+                self.assertEqual(get_type_hint_str(type_hint), expected)
+
+    def test_validation_errors(self) -> None:
+        """Test error handling during validation."""
+        with self.assertRaises(ValidationError):
+            check_type("hello", int, raise_error=True)
+
+        with self.assertRaises(ValidationError):
+            check_type([1, "2"], List[int], raise_error=True)
+
+        # Test error message content
+        try:
+            check_type(42, str, raise_error=True)
+        except ValidationError as e:
+            self.assertIn("str", str(e))
+            self.assertIn("int", str(e))
+
+    def test_validator_factory_cache(self) -> None:
+        """Test caching behavior of ValidatorFactory."""
+        # Get validators for same type multiple times
+        validator1 = ValidatorFactory.get_validator(List[int])
+        validator2 = ValidatorFactory.get_validator(List[int])
+
+        # Should return same cached validator
+        self.assertIs(validator1, validator2)
+
+        # Test with different types
+        validator3 = ValidatorFactory.get_validator(Dict[str, int])
+        self.assertIsNot(validator1, validator3)
+
+    def test_any_type(self) -> None:
+        """Test validation with Any type."""
+        cases = [42, "hello", [1, 2], {"a": 1}, {1, 2}, (1, 2), None]
+
+        for value in cases:
+            with self.subTest(value=value):
+                self.assertTrue(check_type(value, Any))
+
+
+if __name__ == '__main__':
+    unittest.main(verbosity=2)
\ No newline at end of file
diff --git a/quickstats/utils/common_utils.py b/quickstats/utils/common_utils.py
index 0b5c499e70d78c04c42b8a5241081b5b2d79116b..815eeffefec422b692048fae23a6583c6ae64bb2 100644
--- a/quickstats/utils/common_utils.py
+++ b/quickstats/utils/common_utils.py
@@ -20,8 +20,6 @@ except ImportError:
 
 import numpy as np
 
-from quickstats import cached_import
-
 class disable_cout:
     def __enter__(self):
         import cppyy
@@ -103,11 +101,13 @@ def stdout_print(msg):
     sys.__stdout__.write(msg + '\n')
 
 def redirect_stdout(logfile_path):
+    from quickstats import cached_import
     ROOT = cached_import("ROOT")
     sys.stdout = open(logfile_path, 'w')
     ROOT.gSystem.RedirectOutput(logfile_path)
 
 def restore_stdout():
+    from quickstats import cached_import
    ROOT = cached_import("ROOT")
     if sys.stdout != sys.__stdout__:
         sys.stdout.close()
@@ -116,6 +116,7 @@ def restore_stdout():
 
 def redirect_stdout_test(func):
"""Redirect stdout to a log file""" + from quickstats import cached_import ROOT = cached_import("ROOT") @functools.wraps(func) def wrapper_timer(*args, **kwargs): diff --git a/quickstats/utils/string_utils.py b/quickstats/utils/string_utils.py index a8c5a135bf40a604187aef492bf3a08864447281..426277d3fecbf392560bf96d9baf27fc14a9cecf 100644 --- a/quickstats/utils/string_utils.py +++ b/quickstats/utils/string_utils.py @@ -366,8 +366,35 @@ def parse_format_str_with_regex(str_list, format_str, regex_map, mode: str = "se return results -def format_dict_to_string(dictionary: Dict[str, str], separator: str = " : ", - left_margin: int = 0, line_break: int = 100) -> str: +def format_delimited_dict(dictionary: dict, separator: str = '=', delimiter: str = ',') -> str: + """ + Formats a dictionary into a string, where each key-value pair is separated by the specified + separator, and different items are separated by the specified delimiter. + + Parameters + ---------- + dictionary : dict + The dictionary to format. + separator : str, optional + The string used to separate keys from values. Defaults to '='. + delimiter : str, optional + The string used to separate different key-value pairs. Defaults to ','. + + Returns + ------- + str + The formatted string where keys and values are joined by the separator, and items are separated + by the delimiter. + + Example + ------- + >>> format_delimited_dict({'key1': 'value1', 'key2': 'value2'}, '=', ',') + 'key1=value1,key2=value2' + """ + return delimiter.join([f"{key}{separator}{value}" for key, value in dictionary.items()]) + +def format_aligned_dict(dictionary: Dict[str, str], separator: str = " : ", + left_margin: int = 0, linebreak: int = 100) -> str: """ Formats a dictionary into a neatly aligned string representation, with each key-value pair on a new line. @@ -376,30 +403,30 @@ def format_dict_to_string(dictionary: Dict[str, str], separator: str = " : ", can contain multiple words. separator: The string used to separate keys from their values. Defaults to ": ". left_margin: The number of spaces to prepend to each line for indentation. Defaults to 0. - line_break: The maximum allowed width of each line, in characters, before wrapping the text to a new line. + linebreak: The maximum allowed width of each line, in characters, before wrapping the text to a new line. Defaults to 100. Returns: A string representation of the dictionary. Each key-value pair is on its own line, with lines broken such - that words are not split across lines, respecting the specified `line_break` width. + that words are not split across lines, respecting the specified `linebreak` width. Example: >>> example_dict = {"Key1": "This is a short value.", "Key2": "This is a much longer value that will be wrapped according to the specified line break width."} - >>> print(format_dict_to_string(example_dict, left_margin=4, line_break=80)) + >>> print(format_aligned_dict(example_dict, left_margin=4, linebreak=80)) Key1: This is a short value. Key2: This is a much longer value that will be wrapped according to the specified line break width. Note: The function removes existing newlines in values to prevent unexpected line breaks and treats the entire - value as a single paragraph that needs to be wrapped according to `line_break`. + value as a single paragraph that needs to be wrapped according to `linebreak`. 
""" if not dictionary: return "" max_key_length = max(len(key) for key in dictionary) indent_size = left_margin + max_key_length + len(separator) - effective_text_width = line_break - indent_size + effective_text_width = linebreak - indent_size if effective_text_width <= 0: raise ValueError("Line break width must be greater than the size of indentation and separator.") @@ -408,7 +435,7 @@ def format_dict_to_string(dictionary: Dict[str, str], separator: str = " : ", indent_string = " " * indent_size for key, value in dictionary.items(): cleaned_value = str(value).replace("\n", " ") - wrapped_value = make_multiline_text(cleaned_value, line_break, False, indent_string) + wrapped_value = make_multiline_text(cleaned_value, linebreak, False, indent_string) line = f"{' ' * left_margin}{str(key):{max_key_length}}{separator}{wrapped_value}" formatted_lines.append(line) @@ -497,4 +524,34 @@ def replace_with_mapping(s: str, mapping: Dict[str, str]) -> str: """ for old, new in mapping.items(): s = s.replace(old, new) - return s \ No newline at end of file + return s + +def indent_str(s: str, indent: int = 4, indent_char: str = ' ') -> str: + """ + Indents each line of a given string. + + Parameters + ---------- + s : str + The input string to be indented. + indent : int, optional + The number of characters to indent each line. Default is 4. + indent_char : str, optional + The character used for indentation. Default is a space (' '). + + Returns + ------- + str + The indented multi-line string. + + Examples + -------- + >>> s = "Line 1\nLine 2\nLine 3" + >>> indent_str(s, indent=4, indent_char=' ') + ' Line 1\n Line 2\n Line 3' + + >>> indent_str(s, indent=2, indent_char='-') + '--Line 1\n--Line 2\n--Line 3' + """ + indentation = indent_char * indent + return '\n'.join([f'{indentation}{line}' for line in s.splitlines()]) \ No newline at end of file diff --git a/quickstats/workspace/elements/workspace.py b/quickstats/workspace/elements/workspace.py index 2b77e34299bbbb9e5e24441e366573ca615d64d5..25a8a8f66bc68e1dba48d09b850a063e674335ef 100644 --- a/quickstats/workspace/elements/workspace.py +++ b/quickstats/workspace/elements/workspace.py @@ -5,7 +5,7 @@ from pydantic import Field, ValidationInfo, field_validator, model_validator from quickstats.core.typing import Scalar from quickstats.utils.common_utils import remove_list_duplicates -from quickstats.utils.string_utils import format_dict_to_string +from quickstats.utils.string_utils import format_aligned_dict from quickstats.interface.pydantic.helpers import resolve_field_import from quickstats.workspace.settings import ( DataStorageType, @@ -139,7 +139,7 @@ class Workspace(BaseElement): 'Categories': ', '.join(self.category_names), 'Asimov datasets': ', '.join(self.asimov_names) } - text += format_dict_to_string(items, left_margin=2, line_break=70) + text += format_aligned_dict(items, left_margin=2, linebreak=70) text += '=' * 74 + '\n' return text