Commit 60cf7ece authored by Marcel Rieger's avatar Marcel Rieger
Browse files

Simplify usage of snapshots.

parent 705f0c38
......@@ -11,12 +11,13 @@ import luigi
from dhi.tasks.base import BoxPlotTask, view_output_plots
from dhi.tasks.combine import MultiDatacardTask, POIScanTask, POIPlotTask
from dhi.tasks.snapshot import SnapshotUser
from dhi.tasks.limits import MergeUpperLimits, PlotUpperLimits
from dhi.tasks.likelihoods import MergeLikelihoodScan, PlotLikelihoodScan
from dhi.config import br_hh
class PlotExclusionAndBestFit(POIScanTask, MultiDatacardTask, POIPlotTask, BoxPlotTask):
class PlotExclusionAndBestFit(POIScanTask, MultiDatacardTask, POIPlotTask, SnapshotUser, BoxPlotTask):
show_best_fit = PlotLikelihoodScan.show_best_fit
show_best_fit_error = PlotLikelihoodScan.show_best_fit_error
......@@ -55,6 +56,14 @@ class PlotExclusionAndBestFit(POIScanTask, MultiDatacardTask, POIPlotTask, BoxPl
return reqs
def get_output_postfix(self, join=True):
parts = super(PlotExclusionAndBestFit, self).get_output_postfix(join=False)
if self.use_snapshot:
parts.append("fromsnapshot")
return self.join_postfix(parts) if join else parts
def output(self):
names = self.create_plot_names(["exclusionbestfit", self.get_output_postfix()])
return [self.local_target(name) for name in names]
......@@ -129,7 +138,7 @@ class PlotExclusionAndBestFit(POIScanTask, MultiDatacardTask, POIPlotTask, BoxPl
)
class PlotExclusionAndBestFit2D(POIScanTask, POIPlotTask):
class PlotExclusionAndBestFit2D(POIScanTask, POIPlotTask, SnapshotUser):
show_best_fit = PlotLikelihoodScan.show_best_fit
show_best_fit_error = copy.copy(PlotLikelihoodScan.show_best_fit_error)
......@@ -190,6 +199,14 @@ class PlotExclusionAndBestFit2D(POIScanTask, POIPlotTask):
return reqs
def get_output_postfix(self, join=True):
parts = super(PlotExclusionAndBestFit2D, self).get_output_postfix(join=False)
if self.use_snapshot:
parts.append("fromsnapshot")
return self.join_postfix(parts) if join else parts
def output(self):
names = self.create_plot_names(["exclusionbestfit2d", self.get_output_postfix()])
return [self.local_target(name) for name in names]
......
......@@ -15,9 +15,10 @@ from dhi.tasks.combine import (
POIPlotTask,
CreateWorkspace,
)
from dhi.tasks.snapshot import Snapshot, SnapshotUser
class GoodnessOfFitBase(POITask):
class GoodnessOfFitBase(POITask, SnapshotUser):
toys = luigi.IntParameter(
default=1,
......@@ -49,11 +50,19 @@ class GoodnessOfFitBase(POITask):
parts["gof"] = self.algorithm
return parts
def get_output_postfix(self, join=True):
parts = super(GoodnessOfFitBase, self).get_output_postfix(join=False)
if self.use_snapshot:
parts.append("fromsnapshot")
if self.frequentist_toys:
parts.append("freqtoys")
return self.join_postfix(parts) if join else parts
@property
def toys_postfix(self):
postfix = "t{}_pt{}".format(self.toys, self.toys_per_task)
if self.frequentist_toys:
postfix += "_freq"
return "t{}_pt{}".format(self.toys, self.toys_per_task)
return postfix
......@@ -82,23 +91,36 @@ class GoodnessOfFit(GoodnessOfFitBase, CombineCommandTask, law.LocalWorkflow, HT
def workflow_requires(self):
reqs = super(GoodnessOfFit, self).workflow_requires()
reqs["workspace"] = self.requires_from_branch()
reqs["workspace"] = CreateWorkspace.req(self)
if self.use_snapshot:
reqs["snapshot"] = Snapshot.req(self)
return reqs
def requires(self):
return CreateWorkspace.req(self)
reqs = {"workspace": CreateWorkspace.req(self)}
if self.use_snapshot:
reqs["snapshot"] = Snapshot.req(self, branch=0)
return reqs
def output(self):
parts = []
if self.branch == 0:
postfix = "b0_data"
parts.append("b0_data")
else:
postfix = "b{}_toy{}To{}".format(self.branch, self.branch_data[0], self.branch_data[-1])
if self.frequentist_toys:
postfix += "_freq"
name = self.join_postfix(["gof", self.get_output_postfix(), postfix])
parts.append("b{}_toy{}To{}".format(self.branch, self.branch_data[0], self.branch_data[-1]))
name = self.join_postfix(["gof", self.get_output_postfix(), parts])
return self.local_target(name + ".root")
def build_command(self):
# get the workspace to use and define snapshot args
if self.use_snapshot:
workspace = self.input()["snapshot"].path
snapshot_args = " --snapshotName MultiDimFit"
else:
workspace = self.input()["workspace"].path
snapshot_args = ""
# toy options
toy_opts = ""
if self.branch > 0:
......@@ -106,7 +128,8 @@ class GoodnessOfFit(GoodnessOfFitBase, CombineCommandTask, law.LocalWorkflow, HT
if self.frequentist_toys:
toy_opts += " --toysFrequentist"
return (
# build the command
cmd = (
"combine -M GoodnessOfFit {workspace}"
" {self.custom_args}"
" --verbose 1"
......@@ -119,16 +142,20 @@ class GoodnessOfFit(GoodnessOfFitBase, CombineCommandTask, law.LocalWorkflow, HT
" --setParameters {self.joined_parameter_values}"
" --freezeParameters {self.joined_frozen_parameters}"
" --freezeNuisanceGroups {self.joined_frozen_groups}"
" {snapshot_args}"
" {self.combine_optimization_args}"
" && "
"mv higgsCombineTest.GoodnessOfFit.mH{self.mass_int}.{self.branch}.root {output}"
).format(
self=self,
workspace=self.input().path,
workspace=workspace,
output=self.output().path,
toy_opts=toy_opts,
snapshot_args=snapshot_args,
)
return cmd
def htcondor_output_postfix(self):
postfix = super(GoodnessOfFit, self).htcondor_output_postfix()
return "{}__{}".format(postfix, self.toys_postfix)
......@@ -257,10 +284,7 @@ class PlotMultipleGoodnessOfFits(PlotGoodnessOfFit, MultiDatacardTask, BoxPlotTa
@property
def toys_postfix(self):
tpl_to_str = lambda tpl: "_".join(map(str, tpl))
postfix = "t{}_pt{}".format(tpl_to_str(self.toys), tpl_to_str(self.toys_per_task))
if self.frequentist_toys:
postfix += "_freq"
return postfix
return "t{}_pt{}".format(tpl_to_str(self.toys), tpl_to_str(self.toys_per_task))
def requires(self):
return [
......
......@@ -32,6 +32,14 @@ class LikelihoodBase(POIScanTask, SnapshotUser):
force_scan_parameters_equal_pois = True
allow_parameter_values_in_pois = True
def get_output_postfix(self, join=True):
parts = super(LikelihoodBase, self).get_output_postfix(join=False)
if self.use_snapshot:
parts.append("fromsnapshot")
return self.join_postfix(parts) if join else parts
class LikelihoodScan(LikelihoodBase, CombineCommandTask, law.LocalWorkflow, HTCondorWorkflow):
......@@ -54,11 +62,7 @@ class LikelihoodScan(LikelihoodBase, CombineCommandTask, law.LocalWorkflow, HTCo
return reqs
def output(self):
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
name = self.join_postfix(["likelihood", self.get_output_postfix(), parts]) + ".root"
name = self.join_postfix(["likelihood", self.get_output_postfix()]) + ".root"
return self.local_target(name)
def build_command(self):
......@@ -117,11 +121,7 @@ class MergeLikelihoodScan(LikelihoodBase):
return LikelihoodScan.req(self)
def output(self):
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
name = self.join_postfix(["likelihoods", self.get_output_postfix(), parts]) + ".npz"
name = self.join_postfix(["likelihoods", self.get_output_postfix()]) + ".npz"
return self.local_target(name)
@law.decorator.log
......@@ -245,8 +245,6 @@ class PlotLikelihoodScan(LikelihoodBase, POIPlotTask):
def output(self):
# additional postfix
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
if self.n_pois == 1 and self.y_log:
parts.append("log")
......@@ -481,8 +479,6 @@ class PlotMultipleLikelihoodScans(PlotLikelihoodScan, MultiDatacardTask):
def output(self):
# additional postfix
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
if self.n_pois == 1 and self.y_log:
parts.append("log")
......@@ -596,8 +592,6 @@ class PlotMultipleLikelihoodScansByModel(PlotLikelihoodScan, MultiHHModelTask):
def output(self):
# additional postfix
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
if self.n_pois == 1 and self.y_log:
parts.append("log")
......
......@@ -27,6 +27,14 @@ class UpperLimitsBase(POIScanTask, SnapshotUser):
force_scan_parameters_unequal_pois = True
def get_output_postfix(self, join=True):
parts = super(UpperLimitsBase, self).get_output_postfix(join=False)
if self.use_snapshot:
parts.append("fromsnapshot")
return self.join_postfix(parts) if join else parts
class UpperLimits(UpperLimitsBase, CombineCommandTask, law.LocalWorkflow, HTCondorWorkflow):
......@@ -49,11 +57,7 @@ class UpperLimits(UpperLimitsBase, CombineCommandTask, law.LocalWorkflow, HTCond
return reqs
def output(self):
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
name = self.join_postfix(["limit", self.get_output_postfix(), parts]) + ".root"
name = self.join_postfix(["limit", self.get_output_postfix()]) + ".root"
return self.local_target(name)
def build_command(self):
......@@ -132,11 +136,7 @@ class MergeUpperLimits(UpperLimitsBase):
return UpperLimits.req(self)
def output(self):
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
name = self.join_postfix(["limits", self.get_output_postfix(), parts]) + ".npz"
name = self.join_postfix(["limits", self.get_output_postfix()]) + ".npz"
return self.local_target(name)
@law.decorator.log
......@@ -231,8 +231,6 @@ class PlotUpperLimits(UpperLimitsBase, POIPlotTask):
def output(self):
# additional postfix
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
if self.xsec in ["pb", "fb"]:
parts.append(self.xsec)
if self.br != law.NO_STR:
......@@ -376,8 +374,6 @@ class PlotMultipleUpperLimits(PlotUpperLimits, MultiDatacardTask):
def output(self):
# additional postfix
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
if self.xsec in ["pb", "fb"]:
parts.append(self.xsec)
if self.br != law.NO_STR:
......@@ -494,8 +490,6 @@ class PlotMultipleUpperLimitsByModel(PlotUpperLimits, MultiHHModelTask):
def output(self):
# additional postfix
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
if self.xsec in ["pb", "fb"]:
parts.append(self.xsec)
if self.br != law.NO_STR:
......@@ -726,11 +720,17 @@ class PlotUpperLimitsAtPoint(POIPlotTask, SnapshotUser, MultiDatacardTask, BoxPl
for datacards in self.multi_datacards
]
def get_output_postfix(self, join=True):
parts = super(PlotUpperLimitsAtPoint, self).get_output_postfix(join=False)
if self.use_snapshot:
parts.append("fromsnapshot")
return self.join_postfix(parts) if join else parts
def output(self):
# additional postfix
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
if self.xsec in ["pb", "fb"]:
parts.append(self.xsec)
if self.br != law.NO_STR:
......@@ -888,8 +888,6 @@ class PlotUpperLimits2D(UpperLimitsBase, POIPlotTask):
def output(self):
# additional postfix
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
if self.z_log:
parts.append("log")
......
......@@ -73,10 +73,16 @@ class FitDiagnostics(POITask, CombineCommandTask, SnapshotUser, law.LocalWorkflo
reqs["snapshot"] = Snapshot.req(self, branch=0)
return reqs
def output(self):
parts = []
def get_output_postfix(self, join=True):
parts = super(FitDiagnostics, self).get_output_postfix(join=False)
if self.use_snapshot:
parts.append("fromsnapshot")
return self.join_postfix(parts) if join else parts
def output(self):
parts = []
if not self.skip_b_only:
parts.append("withBOnly")
if self.skip_save:
......@@ -141,7 +147,18 @@ class FitDiagnostics(POITask, CombineCommandTask, SnapshotUser, law.LocalWorkflo
)
class PlotPostfitSOverB(POIPlotTask):
class PostfitPlotBase(POIPlotTask, SnapshotUser):
def get_output_postfix(self, join=True):
parts = super(PostfitPlotBase, self).get_output_postfix(join=False)
if self.use_snapshot:
parts.append("fromsnapshot")
return self.join_postfix(parts) if join else parts
class PlotPostfitSOverB(PostfitPlotBase):
pois = FitDiagnostics.pois
bins = law.CSVParameter(
......@@ -262,7 +279,7 @@ class PlotPostfitSOverB(POIPlotTask):
)
class PlotNuisanceLikelihoodScans(POIPlotTask):
class PlotNuisanceLikelihoodScans(PostfitPlotBase):
x_min = copy.copy(POIPlotTask.x_min)
x_max = copy.copy(POIPlotTask.x_max)
......
......@@ -38,6 +38,14 @@ class PullsAndImpactsBase(POITask, SnapshotUser):
force_n_pois = 1
allow_parameter_values_in_pois = True
def get_output_postfix(self, join=True):
parts = super(PullsAndImpactsBase, self).get_output_postfix(join=False)
if self.use_snapshot:
parts.append("fromsnapshot")
return self.join_postfix(parts) if join else parts
class PullsAndImpacts(PullsAndImpactsBase, CombineCommandTask, law.LocalWorkflow, HTCondorWorkflow):
......@@ -101,9 +109,6 @@ class PullsAndImpacts(PullsAndImpactsBase, CombineCommandTask, law.LocalWorkflow
def output(self):
parts = [self.branch_data]
if self.use_snapshot:
parts.append("fromsnapshot")
name = self.join_postfix(["fit", self.get_output_postfix(), parts]) + ".root"
return self.local_target(name)
......@@ -196,8 +201,6 @@ class MergePullsAndImpacts(PullsAndImpactsBase):
parts = []
if self.mc_stats:
parts.append("mcstats")
if self.use_snapshot:
parts.append("fromsnapshot")
if self.only_parameters:
parts.append("only_" + law.util.create_hash(sorted(self.only_parameters)))
if self.skip_parameters:
......@@ -394,8 +397,6 @@ class PlotPullsAndImpacts(PullsAndImpactsBase, POIPlotTask, BoxPlotTask):
parts = []
if self.mc_stats:
parts.append("mcstats")
if self.use_snapshot:
parts.append("fromsnapshot")
if self.only_parameters:
parts.append("only_" + law.util.create_hash(sorted(self.only_parameters)))
if self.skip_parameters:
......
......@@ -44,6 +44,8 @@ class SignificanceBase(POIScanTask, SnapshotUser):
if not self.unblinded and self.frequentist_toys:
parts.insert(0, ["postfit"])
if self.use_snapshot:
parts.append("fromsnapshot")
return self.join_postfix(parts) if join else parts
......@@ -69,11 +71,7 @@ class SignificanceScan(SignificanceBase, CombineCommandTask, law.LocalWorkflow,
return reqs
def output(self):
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
name = self.join_postfix(["significance", self.get_output_postfix(), parts]) + ".root"
name = self.join_postfix(["significance", self.get_output_postfix()]) + ".root"
return self.local_target(name)
def build_command(self):
......@@ -126,11 +124,7 @@ class MergeSignificanceScan(SignificanceBase):
return SignificanceScan.req(self)
def output(self):
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
name = self.join_postfix(["significance", self.get_output_postfix(), parts]) + ".npz"
name = self.join_postfix(["significance", self.get_output_postfix()]) + ".npz"
return self.local_target(name)
@law.decorator.log
......@@ -207,8 +201,6 @@ class PlotSignificanceScan(SignificanceBase, POIPlotTask):
def output(self):
# additional postfix
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
if self.y_log:
parts.append("log")
......@@ -299,8 +291,6 @@ class PlotMultipleSignificanceScans(PlotSignificanceScan, MultiDatacardTask):
def output(self):
# additional postfix
parts = []
if self.use_snapshot:
parts.append("fromsnapshot")
if self.y_log:
parts.append("log")
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment