diff --git a/quickstats/_version.py b/quickstats/_version.py index feb91c702fec2a519fdf12875e677be6b793cf75..971db2c712b307faf0119f8f871c4a65c34f6df7 100644 --- a/quickstats/_version.py +++ b/quickstats/_version.py @@ -1 +1 @@ -__version__ = "0.8.3.5.11" +__version__ = "0.8.3.6" diff --git a/quickstats/components/modelling/data_modelling.py b/quickstats/components/modelling/data_modelling.py index 391eeb8cc8347f347a14e23fbf654a5f572f804b..c9e639e5d390f5bbd615d42617d01915cc68a06f 100644 --- a/quickstats/components/modelling/data_modelling.py +++ b/quickstats/components/modelling/data_modelling.py @@ -122,7 +122,6 @@ class DataModelling(ROOTObject): 'print_level': -1, 'min_fit': 2, 'max_fit': 3, - 'binned': False, 'minos': False, 'hesse': True, 'sumw2': True, @@ -403,6 +402,7 @@ class DataModelling(ROOTObject): data: Union[np.ndarray, "ROOT.RooDataSet", "ROOT.TTree", DataSource], weights: Optional[np.ndarray]=None, reset_parameters: bool = False, + data_error: str = 'poisson', ): with timer() as t: data_source = self.create_data_source(data, weights=weights) @@ -427,6 +427,7 @@ class DataModelling(ROOTObject): fit_kwargs['eval_bin_range'] = self.eval_options.get('bin_range', None) fit_kwargs['eval_nbins'] = self.eval_options.get('nbins', None) fit_kwargs['use_asym_error'] = self.fit_options.get('use_asym_error', True) + fit_kwargs['data_error'] = data_error fit_result = fit_tool.mle_fit(**fit_kwargs) if fit_result is not None: self.parameters.copy_data(fit_result.parameters) diff --git a/quickstats/components/modelling/model_parameters.py b/quickstats/components/modelling/model_parameters.py index 117c03061a740f52b2be774172a26d61c35a61b0..64169bbf3d7c9a6b09c64c5816b56dba7ac67984 100644 --- a/quickstats/components/modelling/model_parameters.py +++ b/quickstats/components/modelling/model_parameters.py @@ -47,13 +47,13 @@ class ModelParameters(RealVariableSet, metaclass=ModelParametersRegistryMeta): verbosity : Optional[str], optional Verbosity level for logging or diagnostics. Defaults to None. """ + self._cache_param_data : Dict[str, Any] = None name = name or self.__registry_key__ super().__init__( name=name, verbosity=verbosity, **kwargs ) - self._init_param_data = None self.set_parameters(components) def get_default_parameters(self) -> RealVariableSet: @@ -72,7 +72,8 @@ class ModelParameters(RealVariableSet, metaclass=ModelParametersRegistryMeta): def set_parameters( self, - components: Optional[ParametersType] = None + components: Optional[ParametersType] = None, + overwrite_cache: bool = True ) -> None: self._is_locked = False self.clear() @@ -122,12 +123,16 @@ class ModelParameters(RealVariableSet, metaclass=ModelParametersRegistryMeta): parameters = RealVariableSet(components=cloned_components) else: parameters = default_parameters - self._init_param_data = parameters.data + if overwrite_cache: + self._cache_param_data = parameters.data self.append(parameters) self._is_locked = True def reset(self) -> None: - for name, data in self._init_param_data.items(): + if self._cache_param_data is None: + self.stdout.warning('No cache parameter data available. Skipped.') + return + for name, data in self._cache_param_data.items(): self[name].set_data(**data) def prefit(self, data: DataSource) -> None: diff --git a/quickstats/components/modelling/pdf_fit_tool.py b/quickstats/components/modelling/pdf_fit_tool.py index 85d65becf0ac707ef1780cd77fbffd580b5f23f9..2ce0f3f55b3f690ca26dc995295bde44d1b75b6a 100644 --- a/quickstats/components/modelling/pdf_fit_tool.py +++ b/quickstats/components/modelling/pdf_fit_tool.py @@ -1,10 +1,16 @@ from typing import List, Optional, Dict, Any, Union -from quickstats import AbstractObject, semistaticmethod, cached_import, timer +from quickstats import AbstractObject, semistaticmethod, cached_import, timer, DescriptiveEnum from quickstats.core.typing import ArrayLike -from quickstats.utils.string_utils import split_str +from quickstats.utils.string_utils import split_str, unique_string from quickstats.interface.root import RooFitResult +class DataErrorType(DescriptiveEnum): + POISSON = (0, 'Poisson error') + SUMW2 = (1, 'SumW2 error') + NONE = (2, 'No error') + AUTO = (4, 'Error determined automatically') + class PdfFitTool(AbstractObject): def __init__(self, pdf:"ROOT.RooAbsPdf", data:"ROOT.RooAbsData", @@ -36,26 +42,61 @@ class PdfFitTool(AbstractObject): model:"ROOT.RooAbsPdf", data:"ROOT.RooAbsData", bin_range:Optional[Union[ArrayLike, str]]=None, + fit_range:Optional[str]=None, nbins:Optional[int]=None, n_float_params:int=0, + data_error:str='poisson' ) -> Dict[str, Any]: ROOT = cached_import("ROOT") observable = self.get_observable(data) if nbins is None: nbins = observable.numBins() if bin_range is None: - bin_range = observable.getRange('fitRange') - bin_low, bin_high = bin_range.first, bin_range.second + range_name = '' + bin_low, bin_high = observable.getMin(), observable.getMax() elif isinstance(bin_range, str): - bin_range = observable.getRange(bin_range) - bin_low, bin_high = bin_range.first, bin_range.second + range_name = bin_range + if not observable.hasRange(range_name): + raise RuntimeError( + f'Observable "{observable.GetName()}" does not have a range named "{range_name}"' + ) + obs_bin_range = observable.getRange(bin_range) + bin_low, bin_high = obs_bin_range.first, obs_bin_range.second else: bin_low, bin_high = bin_range + range_name = unique_string() + observable.setRange(range_name, bin_low, bin_high) # +1 is there to account for the normalization that is done internally in RootFit ndf = nbins - (n_float_params + 1) frame = observable.frame(bin_low, bin_high, nbins) - data.plotOn(frame) - model.plotOn(frame) + data_args = [] + model_args = [ROOT.RooFit.Range(range_name)] + if fit_range is not None: + data_range = fit_range + data_args.append(ROOT.RooFit.CutRange(fit_range)) + model_args.append(ROOT.RooFit.NormRange(fit_range)) + else: + data_range = range_name + data_error = DataErrorType.parse(data_error) + if data_error == DataErrorType.POISSON: + data_error_code = ROOT.RooAbsData.ErrorType.Poisson + elif data_error == DataErrorType.SUMW2: + data_error_code = ROOT.RooAbsData.ErrorType.SumW2 + elif data_error == DataErrorType.AUTO: + data_error_code = ROOT.RooAbsData.ErrorType.Auto + elif data_error == DataErrorType.NONE: + data_error_code = getattr(ROOT.RooAbsData.ErrorType, 'None') + else: + raise ValueError( + f'Unsupported data error value: {data_error}' + ) + data_args.append(ROOT.RooFit.DataError(data_error_code)) + n_data = data.sumEntries('', data_range) + model_args.append(ROOT.RooFit.Normalization(n_data, ROOT.RooAbsReal.NumEvent)) + data.plotOn(frame, *data_args) + model.plotOn(frame, *model_args) + curve = frame.findObject('', ROOT.RooCurve.Class()) + hist = frame.findObject('', ROOT.RooHist.Class()) chi2_reduced = frame.chiSquare(n_float_params) chi2 = chi2_reduced * ndf pvalue = ROOT.TMath.Prob(chi2, ndf) @@ -73,8 +114,10 @@ class PdfFitTool(AbstractObject): def get_fit_stats( self, bin_range:Optional[Union[ArrayLike, str]]=None, + fit_range:Optional[str]=None, nbins:Optional[int]=None, - n_float_params:int=0 + n_float_params:int=0, + data_error:str='poisson' ) -> Dict[str, Any]: """ Parameters @@ -86,7 +129,14 @@ class PdfFitTool(AbstractObject): Number of floating parameters in the fit. This decreases the number of degrees of freedom used in chi2 calculation. """ - return self._get_fit_stats(self.pdf, self.data, nbins=nbins, n_float_params=n_float_params, bin_range=bin_range) + return self._get_fit_stats( + self.pdf, self.data, + bin_range=bin_range, + fit_range=fit_range, + nbins=nbins, + n_float_params=n_float_params, + data_error=data_error + ) @semistaticmethod def print_fit_stats(self, fit_stats: Dict) -> None: @@ -119,7 +169,8 @@ class PdfFitTool(AbstractObject): print_level:int=-1, use_asym_error: bool = True, eval_bin_range:Optional[Union[ArrayLike, str]]=None, - eval_nbins:Optional[int]=None + eval_nbins:Optional[int]=None, + data_error:str='poisson' ): ROOT = cached_import("ROOT") @@ -200,8 +251,10 @@ class PdfFitTool(AbstractObject): n_float_params = final_fit_result.parameters.size fit_stats = self.get_fit_stats( bin_range=eval_bin_range, + fit_range=range_name, nbins=eval_nbins, - n_float_params=n_float_params + n_float_params=n_float_params, + data_error=data_error ) self.print_fit_stats(fit_stats) final_fit_result.set_stats(fit_stats) diff --git a/quickstats/plots/variable_distribution_plot.py b/quickstats/plots/variable_distribution_plot.py index 3592802427e89e6106bf066dbf869c98ae82a199..d86cee3158f22ccf6c0e608af1d6aa9f837afc21 100644 --- a/quickstats/plots/variable_distribution_plot.py +++ b/quickstats/plots/variable_distribution_plot.py @@ -912,7 +912,6 @@ class VariableDistributionPlot(HistogramPlot): ax, target=target, hist_options=hist_options, - show_error=show_error, **options, **data_options )