/*
  Copyright (C) 2002-2017 CERN for the benefit of the ATLAS collaboration
*/

///////////////////////////////////////////////////////////////////
//   Implementation file for class Tool_ModeDiscriminator
///////////////////////////////////////////////////////////////////
// (c) ATLAS Detector software
///////////////////////////////////////////////////////////////////
// Tool for PID of TauSeeds
///////////////////////////////////////////////////////////////////
// sebastian.fleischmann@cern.ch
///////////////////////////////////////////////////////////////////

//! C++
#include <memory>
#include <string>

//! PanTau includes
#include "PanTauAlgs/Tool_ModeDiscriminator.h"
#include "PanTauAlgs/Tool_InformationStore.h"
#include "PanTauAlgs/TauFeature.h"
#include "PanTauAlgs/PanTauSeed.h"
#include "PanTauAlgs/HelperFunctions.h"

//! Root
#include "TString.h"
#include "TFile.h"
#include "TTree.h"
#include "TH1F.h"

//!Other
#include "PathResolver/PathResolver.h"


/// Construct the tool and declare its configurable properties.
///
/// The input-algorithm and mode-case names start out as the placeholder
/// strings "InvalidInputAlg" / "InvalidModeCase" and are expected to be
/// overridden through the declared properties before initialize() runs.
///
/// @param name  Instance name, forwarded to the asg::AsgTool base class.
PanTau::Tool_ModeDiscriminator::Tool_ModeDiscriminator(
    const std::string& name ) :
        asg::AsgTool(name),
        m_Name_InputAlg("InvalidInputAlg"),
        m_Name_ModeCase("InvalidModeCase"),
        m_Tool_InformationStore("PanTau::Tool_InformationStore/Tool_InformationStore"),
        m_MVABDT_List()
{
    declareProperty("Name_InputAlg",            m_Name_InputAlg,            "Name of the input algorithm for this instance");
    declareProperty("Name_ModeCase",            m_Name_ModeCase,            "Name of the two modes to be distinguished for this instance");
    declareProperty("Tool_InformationStore",    m_Tool_InformationStore,    "Handle to the information store tool");
    // This property carries the tool *name* that initialize() binds to the
    // handle above, so it is described as a name rather than as a handle.
    declareProperty("Tool_InformationStoreName",m_Tool_InformationStoreName,"Name of the information store tool");
}



/// Destructor. Performs no work of its own; the heap objects created in
/// initialize() are released in finalize().
PanTau::Tool_ModeDiscriminator::~Tool_ModeDiscriminator() = default;



/// Configure the tool and load one MVAUtils::BDT per pt bin.
///
/// Steps:
///  1. bind and retrieve the information store tool,
///  2. read the pt bin edges, reader option, method name, BDT variable names
///     and BDT variable default values from the information store,
///  3. allocate one float per BDT variable (filled later by
///     updateReaderVariables()),
///  4. for every pt bin, resolve the weight file
///     "TrainModes_<InputAlg>_ET_<low>_<high>_<ModeCase>_<Method>.weights.root",
///     read its "BDT" tree and build an MVAUtils::BDT from it.
///
/// @return StatusCode::SUCCESS if everything was set up; StatusCode::FAILURE
///         if any information-store lookup fails, the variable name/default
///         lists differ in size, fewer than two pt bin edges are configured,
///         or a weight file cannot be found, opened, or lacks the "BDT" tree.
StatusCode PanTau::Tool_ModeDiscriminator::initialize() {

    ATH_MSG_DEBUG( name() << " initialize()" );
    m_init=true;

    ATH_MSG_DEBUG("InputAlg   : "  << m_Name_InputAlg);
    ATH_MSG_DEBUG("Mode Case  : "  << m_Name_ModeCase);

    ATH_CHECK( HelperFunctions::bindToolHandle( m_Tool_InformationStore, m_Tool_InformationStoreName ) );
    ATH_CHECK(m_Tool_InformationStore.retrieve());

    //get the required information from the informationstore tool
    ATH_MSG_DEBUG("Get infos from information store & configure...");
    ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble("ModeDiscriminator_BinEdges_Pt", m_BinEdges_Pt));
    ATH_CHECK( m_Tool_InformationStore->getInfo_String("ModeDiscriminator_ReaderOption", m_ReaderOption) );
    ATH_CHECK( m_Tool_InformationStore->getInfo_String("ModeDiscriminator_TMVAMethod", m_MethodName) );

    //build the name of the variable that contains the variable list for this discri tool
    const std::string varNameList_Prefix  = "ModeDiscriminator_BDTVariableNames_";
    const std::string varNameList_Full    = varNameList_Prefix + m_Name_InputAlg + "_" + m_Name_ModeCase;
    ATH_CHECK( m_Tool_InformationStore->getInfo_VecString(varNameList_Full, m_List_BDTVariableNames) );

    const std::string varDefaultValueList_Prefix  = "ModeDiscriminator_BDTVariableDefaults_";
    const std::string varDefaultValueList_Full    = varDefaultValueList_Prefix + m_Name_InputAlg + "_" + m_Name_ModeCase;
    ATH_CHECK( m_Tool_InformationStore->getInfo_VecDouble(varDefaultValueList_Full, m_List_BDTVariableDefaultValues) );

    //consistency check:
    // Number of feature names and feature default values has to match
    if( m_List_BDTVariableDefaultValues.size() != m_List_BDTVariableNames.size() ) {
        ATH_MSG_ERROR("Number of variable names does not match number of default values! Check jobOptions!");
        return StatusCode::FAILURE;
    }

    // At least two edges are needed to form one bin. Without this guard the
    // unsigned expression size()-1 below wraps around for an empty edge list
    // and the bin loop would read far past the end of the vector.
    if( m_BinEdges_Pt.size() < 2 ) {
        ATH_MSG_ERROR("Need at least two pt bin edges, but got " << m_BinEdges_Pt.size() << "! Check jobOptions!");
        return StatusCode::FAILURE;
    }

    //! ////////////////////////
    //! Create list of BDT variables to link to the reader
    // One heap float per variable; the BDTs are handed these pointers via
    // SetPointers() and finalize() deletes them.
    m_List_BDTVariableValues.clear();
    for(unsigned int iVar=0; iVar<m_List_BDTVariableNames.size(); iVar++) {
        m_List_BDTVariableValues.push_back(new float(0));
    }

    //! ////////////////////////
    //! Create reader for each pT Bin
    const unsigned int nPtBins = m_BinEdges_Pt.size() - 1; // nBins =  Edges-1
    for(unsigned int iPtBin=0; iPtBin<nPtBins; iPtBin++) {
        ATH_MSG_DEBUG("PtBin " << iPtBin+1 << " / " << nPtBins);

        const double bin_lowerVal   = m_BinEdges_Pt[iPtBin];
        const double bin_upperVal   = m_BinEdges_Pt[iPtBin+1];

        // Edges are divided by 1000 before being put into the file name.
        const std::string bin_lowerStr = m_HelperFunctions.convertNumberToString(bin_lowerVal/1000.);
        const std::string bin_upperStr = m_HelperFunctions.convertNumberToString(bin_upperVal/1000.);

        const std::string curPtBin     = "ET_" + bin_lowerStr + "_" + bin_upperStr;
        const std::string curModeCase  = m_Name_ModeCase;

        //! ////////////////////////
        //! weight files
        ATH_MSG_DEBUG("\tGet the weight file");
        std::string curWeightFile = "";
        curWeightFile += "TrainModes_";
        curWeightFile += m_Name_InputAlg + "_";
        curWeightFile += curPtBin + "_";
        curWeightFile += curModeCase + "_";
        curWeightFile += m_MethodName + ".weights.root";

        #ifndef XAOD_ANALYSIS
        std::string resolvedWeightFileName = PathResolver::find_file(curWeightFile, "DATAPATH");
        #else
        std::string resolvedWeightFileName = PathResolverFindCalibFile("PanTauAlgs/weights/"+curWeightFile);
        #endif
        if(resolvedWeightFileName == "") {
            ATH_MSG_ERROR("Weight file " << curWeightFile << " not found!");
            return StatusCode::FAILURE;
        }

        ATH_MSG_DEBUG("\t\tAdded weight file: " << resolvedWeightFileName);

        //! ////////////////////////
        //! TMVA Readers
        ATH_MSG_DEBUG("\tCreate MVAUtils::BDT");

        //setup variables for reader
        // (only logged here; the actual link is made by SetPointers() below)
        for(unsigned int iVar=0; iVar<m_List_BDTVariableNames.size(); iVar++) {
            TString variableNameForReader = "tau_pantauFeature_" + m_Name_InputAlg + "_" + m_List_BDTVariableNames[iVar];
            ATH_MSG_DEBUG("\t\tAdding variable to reader: " << variableNameForReader << " var stored at: " << (m_List_BDTVariableValues[iVar]));
        }//end loop over variables

        // Own the file through a unique_ptr so that it is closed and freed on
        // every path out of this iteration instead of being leaked.
        // NOTE(review): this assumes MVAUtils::BDT copies everything it needs
        // from the tree during construction - confirm against MVAUtils.
        std::unique_ptr<TFile> fBDT( TFile::Open(resolvedWeightFileName.c_str()) );
        if( !fBDT || fBDT->IsZombie() ) {
            ATH_MSG_ERROR("Could not open weight file " << resolvedWeightFileName);
            return StatusCode::FAILURE;
        }

        TTree* tBDT = dynamic_cast<TTree*>( fBDT->Get("BDT") );
        if( tBDT == nullptr ) {
            ATH_MSG_ERROR("No TTree named BDT found in weight file " << resolvedWeightFileName);
            return StatusCode::FAILURE;
        }

        MVAUtils::BDT* curBDT = new MVAUtils::BDT(tBDT);
        curBDT->SetPointers(m_List_BDTVariableValues);

        ATH_MSG_DEBUG("\t\tStoring new MVAUtils::BDT at " << curBDT);
        m_MVABDT_List.push_back(curBDT);

    }//end loop over pt bins to get weight files, reference hists and MVAUtils::BDT objects

    return StatusCode::SUCCESS;
}



/// Release everything that initialize() allocated on the heap.
///
/// Deletes every MVAUtils::BDT stored in m_MVABDT_List and every float in
/// m_List_BDTVariableValues, then empties both containers so that no
/// dangling pointers remain.
///
/// @return Always StatusCode::SUCCESS.
StatusCode PanTau::Tool_ModeDiscriminator::finalize() {

    //delete the readers (deleting a null pointer is a no-op, so no check is needed)
    for( MVAUtils::BDT* curBDT : m_MVABDT_List ) delete curBDT;
    m_MVABDT_List.clear();

    //delete the per-variable value slots the BDTs were pointing to
    for( float* f : m_List_BDTVariableValues ) delete f;
    m_List_BDTVariableValues.clear();

    return StatusCode::SUCCESS;
}



182
void    PanTau::Tool_ModeDiscriminator::updateReaderVariables(PanTau::PanTauSeed2* inSeed) {
183
184
185
186
187
    
    //update features used in MVA with values from current seed
    // use default value for feature if it is not present in current seed
    //NOTE! This has to be done (even if the seed pt is bad) otherwise problems with details storage
    //      [If this for loop is skipped, it is not guaranteed that all details are set to their proper default value]
188
    PanTau::TauFeature2* seedFeatures = inSeed->getFeatures();
189
190
191
192
193
194
195
196
197
198
199
200
201
202
    ATH_MSG_DEBUG( "Update the variables that are used in the readers...");
    for(unsigned int iVar=0; iVar<m_List_BDTVariableNames.size(); iVar++) {
        std::string curVar = m_Name_InputAlg + "_" + m_List_BDTVariableNames[iVar];
        
        bool    isValid;
        double  newValue = seedFeatures->value(curVar, isValid);
        if(isValid == false) {
            ATH_MSG_DEBUG("\tUse default value as the feature (the one below this line) was not calculated");
            newValue = m_List_BDTVariableDefaultValues[iVar];
            //add this feature with its default value for the details later
            seedFeatures->addFeature(curVar, newValue);
        }
        
        ATH_MSG_DEBUG("\tUpdate variable " << curVar << " from " << m_List_BDTVariableValues[iVar] << " to " << newValue);
203
        *(m_List_BDTVariableValues[iVar]) = (float)newValue;
204
205
206
207
208
209
210
    }//end loop over BDT vars for update
    
    return;
}



211
double PanTau::Tool_ModeDiscriminator::getResponse(PanTau::PanTauSeed2* inSeed, bool& isOK) {
212
213
214
215
216
    
    ATH_MSG_DEBUG("get bdt response now");
    
    updateReaderVariables(inSeed);
    
217
    if(inSeed->isOfTechnicalQuality(PanTau::PanTauSeed2::t_BadPtValue) == true) {
218
219
220
221
222
223
224
        ATH_MSG_DEBUG("WARNING Seed has bad pt value! " << inSeed->getTauJet()->pt() << " MeV");
        isOK = false;
        return -2;
    }
    
    //get the pt bin of input Seed
    //NOTE: could be moved to decay mode determinator tool...
225
    double          seedPt  = inSeed->p4().Pt();
226
227
228
229
230
231
232
233
234
235
236
237
238
239
    int             ptBin   = -1;
    for(unsigned int iPtBin=0; iPtBin<m_BinEdges_Pt.size()-1; iPtBin++) {
        if(seedPt > m_BinEdges_Pt[iPtBin] && seedPt < m_BinEdges_Pt[iPtBin+1]) {
            ptBin = iPtBin;
            break;
        }
    }
    if(ptBin == -1) {
        ATH_MSG_WARNING("Could not find ptBin for tau seed with pt " << seedPt);
        isOK = false;
        return -2.;
    }
    
    //get mva response
240
241
242
    MVAUtils::BDT*   curBDT   = m_MVABDT_List[ptBin];
    if(curBDT == 0) {
        ATH_MSG_ERROR("MVAUtils::BDT object for current tau seed points to 0");
243
244
245
246
247
248
249
250
        isOK = false;
        return -2.;
    }
    
    
//     ATH_MSG_DEBUG("Values of BDT Variables: ");
//     for(unsigned int iVar=0; iVar<m_List_BDTVariableNames.size(); iVar++) ATH_MSG_WARNING(m_List_BDTVariableNames.at(iVar) << ": " << m_List_BDTVariableValues.at(iVar) << " (stored at " << &(m_List_BDTVariableValues.at(iVar)) << ")");
    
251
    double  mvaResponse     = curBDT->GetGradBoostMVA(m_List_BDTVariableValues);
252
253
254
255
256
257
258
259
260
    ATH_MSG_DEBUG("MVA response from " << m_MethodName << " in " << m_Name_ModeCase << " is " << mvaResponse);
    
    isOK = true;
    return mvaResponse;
}