# Copyright (C) 2002-2020 CERN for the benefit of the ATLAS collaboration
"""Merging of multi-part HLT chain definitions (parallel, serial or automatic by alignment group)."""

from AthenaCommon.Logging import logging
log = logging.getLogger( __name__ )

from TriggerMenuMT.HLTMenuConfig.Menu.MenuComponents import Chain, ChainStep, EmptyMenuSequence, RecoFragmentsPool
from TriggerMenuMT.HLTMenuConfig.Menu.MenuAlignmentTools import get_alignment_group_ordering as getAlignmentGroupOrdering
from collections import OrderedDict
from copy import deepcopy


def mergeChainDefs(listOfChainDefs, chainDict):
    """Combine the per-part Chain objects of a multi-part chain into one Chain.

    The strategy is read from ``chainDict["mergingStrategy"]``:
      * ``"parallel"`` - merge all parts step by step (mergeParallel);
      * ``"serial"``   - run the parts' steps one after another (mergeSerial);
      * ``"auto"``     - parallel-merge the parts sharing an alignment group,
                         then serially merge the groups in the order returned
                         by getAlignmentGroupOrdering().

    :param listOfChainDefs: list of Chain() objects, one for each chain part
    :param chainDict: chain dictionary; its 'mergingStrategy', 'mergingOffset'
        and 'chainName' entries are read
    :returns: the merged Chain, or None (after logging an error) when the
        merging strategy is not known
    """
    strategy = chainDict["mergingStrategy"]
    offset = chainDict["mergingOffset"]
    log.info("[mergeChainDefs] %s: Combine by using %s merging", chainDict['chainName'], strategy)

    if strategy=="parallel":
        return mergeParallel(listOfChainDefs,  offset)

    if strategy=="serial":
        return mergeSerial(listOfChainDefs)

    if strategy=="auto":
        ordering = getAlignmentGroupOrdering()

        # indices of the chain parts, grouped by their (first) alignment group
        merging_dict = OrderedDict()
        for ich, cConfig in enumerate(listOfChainDefs):
            chain_ag = cConfig.alignmentGroups[0]
            if chain_ag not in ordering:
                # NOTE(review): such parts are never visited by the loop over
                # `ordering` below, so they do not end up in the merged chain
                log.error("[mergeChainDefs] Alignment group %s can't be auto-merged because it's not in the grouping list!",chain_ag)
            merging_dict.setdefault(chain_ag, []).append(ich)

        # parallel-merge within each alignment group, visiting groups in the global ordering
        tmp_merged = []
        for ag in ordering:
            if ag not in merging_dict:
                continue
            indices = merging_dict[ag]
            if len(indices) > 1:
                tmp_merged.append(mergeParallel([listOfChainDefs[i] for i in indices], offset))
            else:
                tmp_merged.append(listOfChainDefs[indices[0]])

        # only serial merge if necessary
        if len(tmp_merged) == 1:
            return tmp_merged[0]

        return mergeSerial(tmp_merged)

    # the two format arguments must be passed separately for logging to render the message
    log.error("[mergeChainDefs] Merging failed for %s. Merging strategy '%s' not known.", chainDict['chainName'], strategy)
    return None


def mergeParallel(chainDefList, offset):
    """Merge chain parts step by step ("in parallel") into one combined Chain.

    Step N of the combined chain is built by makeCombinedStep() from step N of
    every part. zip_longest yields None for parts that have run out of steps,
    and makeCombinedStep() pads those entries with empty sequences.

    :param chainDefList: list of Chain() objects, all with the same name
    :param offset: merging offset; only -1 (no offset) is supported
    :returns: the combined Chain
    :raises Exception: if an offset is requested, if the parts have different
        names, or if a part already has more than one alignment group
    """
    if offset != -1:
        log.error("[mergeParallel] Offset for parallel merging not implemented.")
        raise Exception("[mergeParallel] Cannot merge this chain, exiting.")

    allSteps = []
    nSteps = []
    chainName = ''
    l1Thresholds = []
    alignmentGroups = []

    for cConfig in chainDefList:
        if chainName == '':
            chainName = cConfig.name
        elif chainName != cConfig.name:
            log.error("[mergeParallel] Something is wrong with the combined chain name: cConfig.name = %s while chainName = %s", cConfig.name, chainName)
            raise Exception("[mergeParallel] Cannot merge this chain, exiting.")

        allSteps.append(cConfig.steps)
        nSteps.append(len(cConfig.steps))
        l1Thresholds.extend(cConfig.vseeds)
        if len(cConfig.alignmentGroups) > 1:
            log.error("[mergeParallel] Parallel merging an already merged chain? This is odd! %s",cConfig.alignmentGroups)
            raise Exception("[mergeParallel] Complicated situation currently unimplemented. exiting.")
        elif len(cConfig.alignmentGroups) == 1:
            alignmentGroups.append(cConfig.alignmentGroups[0])
        else:
            log.info("[mergeParallel] Alignment groups are empty for this combined chain - if this is not _newJO, this is not ok!")

    try:
        from itertools import zip_longest
    except ImportError:  # Python 2 name
        from itertools import izip_longest as zip_longest

    # Use zip_longest so that we get None in case one chain has more steps than the other
    orderedSteps = list(zip_longest(*allSteps))
    log.debug("[mergeParallel] len(orderedSteps): %d", len(orderedSteps))
    for chainDef in chainDefList:
        log.debug('[mergeParallel] Chain object to merge (i.e. chainDef) %s', chainDef)

    combChainSteps = []
    for step_index, steps in enumerate(orderedSteps):
        log.debug("[mergeParallel] Merging step counter %d", step_index+1)
        # combChainSteps (the steps merged so far) is passed in because padding a
        # finished part reads the multiplicity of the previous combined step
        combStep = makeCombinedStep(list(steps), step_index+1, chainDefList, orderedSteps, combChainSteps)
        combChainSteps.append(combStep)

    combinedChainDef = Chain(chainName, ChainSteps=combChainSteps, L1Thresholds=l1Thresholds,
                             nSteps = nSteps, alignmentGroups = alignmentGroups)

    log.debug("[mergeParallel] Parallel merged chain %s with these steps:", chainName)
    for step in combinedChainDef.steps:
        log.debug('   %s', step)

    return combinedChainDef

def getEmptySeqName(stepName, chain_index, step_number, alignGroup):
    """Return the name of an empty (padding) sequence.

    A leading ``Step<digit>_`` prefix is removed from ``stepName``, and the result
    is ``'Empty' + alignGroup + 'Seq' + str(step_number) + '_' + stepName``.

    :param stepName: step or signature name the sequence is derived from
    :param chain_index: not used; kept for the existing call signature
    :param step_number: step number embedded in the name
    :param alignGroup: alignment group embedded in the name
    :returns: the empty-sequence name
    """
    import re
    # remove a redundant leading StepN_ (single-digit N)
    stepName = re.sub(r'^Step[0-9]_', '', stepName, count=1)

    seqName = 'Empty' + alignGroup + 'Seq' + str(step_number) + '_' + stepName
    return seqName


def getEmptyMenuSequence(flags, name, mergeUsingFeature = False):
    """Build an EmptyMenuSequence called ``name``.

    Used as the factory passed to RecoFragmentsPool.retrieve (callers in this
    module pass ``flags=None``); ``flags`` itself is ignored.

    :param flags: unused
    :param name: name of the empty sequence
    :param mergeUsingFeature: forwarded unchanged to EmptyMenuSequence
    :returns: the new EmptyMenuSequence
    """
    emptySequence = EmptyMenuSequence(name, mergeUsingFeature = mergeUsingFeature)
    return emptySequence


def getMultiplicityPerLeg(multiplicities):
    """Translate per-leg multiplicities into the tags used in empty-step names.

    :param multiplicities: iterable of integer multiplicities, one per leg
    :returns: list with '1' for a multiplicity of exactly 1 and 'N' for more than 1
    :raises Exception: if a multiplicity is smaller than 1
    """
    mult_per_leg = []
    for mult in multiplicities:
        if mult == 1:
            mult_per_leg.append('1')
        elif mult > 1:
            mult_per_leg.append('N')
        else:
            # Exception does not %-format extra arguments, so format explicitly
            raise Exception("[serial_zip] multiplicity not an expected value: %s" % (mult,))
    return mult_per_leg

def isFullScanRoI(inputL1Nav):
    """Tell whether an L1 navigation collection name is treated as a full-scan RoI.

    :param inputL1Nav: L1 decision collection name
    :returns: True for 'HLTNav_L1FSNOSEED', 'HLTNav_L1MET' or 'HLTNav_L1J', else False
    """
    return inputL1Nav in ('HLTNav_L1FSNOSEED', 'HLTNav_L1MET', 'HLTNav_L1J')

def noPrecedingStepsPreMerge(newsteps,chain_index,ileg):
    """Check, before merging, whether a leg has had only empty sequences so far.

    :param newsteps: rows built so far; each row is indexable by chain part and
        holds a step with a ``sequences`` list
    :param chain_index: chain part to inspect in each row
    :param ileg: index of the leg's sequence within that step
    :returns: True if every earlier sequence of this leg is an EmptyMenuSequence
        (also True when ``newsteps`` is empty), False as soon as a real one is found
    """
    return all(type(row[chain_index].sequences[ileg]).__name__ == 'EmptyMenuSequence'
               for row in newsteps)

def noPrecedingStepsPostMerge(newsteps, ileg):
    """Check, after merging, whether a leg has had only empty sequences so far.

    :param newsteps: combined steps built so far, each with a ``sequences`` list
    :param ileg: index of the sequence to inspect in each combined step
    :returns: True if every earlier sequence at ``ileg`` is an EmptyMenuSequence
        (also True when ``newsteps`` is empty), False otherwise
    """
    for mergedStep in newsteps:
        if type(mergedStep.sequences[ileg]).__name__ != 'EmptyMenuSequence':
            # a real sequence earlier on means this leg clearly has a preceding step
            return False
    return True
        
def serial_zip(allSteps, chainName, chainDefList):
    """Lay out the steps of several chain parts one after another ("serially").

    Every step of every chain part becomes one row of the result. A row is a
    list with one entry per chain part: the part that owns the step contributes
    the step itself, and every other part contributes an empty ChainStep made of
    EmptyMenuSequences.

    :param allSteps: list, one entry per chain part, of that part's ChainSteps
    :param chainName: name of the combined chain (not used here)
    :param chainDefList: the Chain() objects, in the same order as ``allSteps``
    :returns: list of rows (each of length ``len(allSteps)``): all rows for
        part 0 first, then part 1, and so on
    :raises Exception: if the empty sequences and stepDicts of a part do not
        line up, or a stepDict has more than one chainPart within a leg
    """
    legs_per_part = [len(cd.steps[0].multiplicity) for cd in chainDefList]
    n_parts = len(allSteps)
    log.debug('[serial_zip] configuring chain with %d parts with multiplicities %s', n_parts, legs_per_part)

    newsteps = []
    doBonusDebug = False

    for chain_index, chainSteps in enumerate(allSteps): #per-part (horizontal) iteration
        for step_index, step in enumerate(chainSteps):  #serial step iteration
            log.debug('[serial_zip] chain_index: %s step_index: %s', chain_index, step_index)
            # create list of correct length (chainSteps in parallel)
            stepList = [None]*n_parts

            # put the step from the current sub-chain into the right place
            stepList[chain_index] = step
            log.debug('[serial_zip] Put step: %s', step.name)

            # all other chain parts' steps should contain an empty sequence
            for chain_index2, (existingStep, nLegs) in enumerate(zip(stepList, legs_per_part)): #more per-leg iteration
                if existingStep is not None:
                    continue  # this part owns the current step; keep it

                alignGroup = chainDefList[chain_index].alignmentGroups[0]
                isFS = isFullScanRoI(chainDefList[chain_index2].L1decisions[0])
                mult_per_leg = getMultiplicityPerLeg(chainDefList[chain_index2].steps[0].multiplicity)

                # parts laid out before the current one reuse the stepDicts of their
                # last step, parts still to come reuse those of their first step
                #this WILL NOT work for jets!
                if chain_index2 < chain_index:
                    emptyChainDicts = allSteps[chain_index2][-1].stepDicts
                else:
                    emptyChainDicts = allSteps[chain_index2][0].stepDicts
                if doBonusDebug:
                    log.debug("[serial_zip] emptyChainDicts %s",emptyChainDicts)

                fsSuffix = 'FS' if isFS else ''
                sigNames = [sd['chainParts'][0]['signature'] + fsSuffix for sd in emptyChainDicts]
                seqMultName = '_'.join([mult+sigName for mult, sigName in zip(mult_per_leg,sigNames)])
                seqStepName = 'Empty' + alignGroup+'Align'+str(step_index+1)+'_'+seqMultName

                seqNames = [getEmptySeqName(emptyChainDicts[iSeq]['signature'], chain_index, step_index+1, alignGroup) for iSeq in range(nLegs)]
                if doBonusDebug:
                    log.debug("[serial_zip] step name for this leg: %s", seqStepName)
                    log.debug("[serial_zip] created empty sequence(s): %s", seqNames)
                    log.debug("[serial_zip] L1decisions %s ", chainDefList[chain_index2].L1decisions)

                emptySequences = []
                for ileg in range(nLegs):
                    if isFS:
                        # mergeUsingFeature stays False until this leg has had a
                        # non-empty sequence in an earlier row
                        mergeUsingFeature = not noPrecedingStepsPreMerge(newsteps, chain_index2, ileg)
                        log.debug("[serial_zip] adding FS empty sequence with mergeUsingFeature = %s ", mergeUsingFeature)
                        emptySequences.append(RecoFragmentsPool.retrieve(getEmptyMenuSequence, flags=None, name=seqNames[ileg]+"FS", mergeUsingFeature = mergeUsingFeature))
                    else:
                        log.debug("[serial_zip] adding non-FS empty sequence")
                        emptySequences.append(RecoFragmentsPool.retrieve(getEmptyMenuSequence, flags=None, name=seqNames[ileg]))

                if len(emptySequences) != len(emptyChainDicts):
                    log.error("[serial_zip] different number of empty sequences/legs %d to stepDicts %d", len(emptySequences), len(emptyChainDicts))
                    raise Exception("[serial_zip] Cannot create this chain step, exiting.")

                # one multiplicity per leg, taken from the leg's single chainPart
                step_mult = []
                for sd in emptyChainDicts:
                    if len(sd['chainParts']) != 1:
                        log.error("[serial_zip] stepDict chainParts has length != 1 within a leg! %s",sd)
                        raise Exception("[serial_zip] Cannot create this chain step, exiting.")
                    step_mult.append(int(sd['chainParts'][0]['multiplicity']))

                if doBonusDebug:
                    log.debug('[serial_zip] step multiplicity %s',step_mult)

                stepList[chain_index2] = ChainStep( seqStepName, Sequences=emptySequences,
                                                    multiplicity = step_mult, chainDicts=emptyChainDicts,
                                                    isEmpty = True)

            newsteps.append(stepList)

    log.debug('After serial_zip')
    for s in newsteps:
        log.debug( ', '.join(map(str, [step.name for step in s]) ) )
    return newsteps


def mergeSerial(chainDefList):
    """Merge chain parts serially, so that each part's steps run after the previous part's.

    serial_zip() lays the parts' steps out one after another, padding the other
    parts with empty steps. makeCombinedStep() then merges each resulting row
    into a single combined ChainStep.

    :param chainDefList: list of Chain() objects, all with the same name
    :returns: the combined Chain
    :raises Exception: if the parts have different names
    """
    allSteps = []
    nSteps = []
    chainName = ''
    l1Thresholds = []
    alignmentGroups = []
    log.debug('[mergeSerial] Merge chainDefList:')
    for cConfig in chainDefList:
        if chainName == '':
            chainName = cConfig.name
        elif chainName != cConfig.name:
            log.error("[mergeSerial] Something is wrong with the combined chain name: cConfig.name = %s while chainName = %s", cConfig.name, chainName)
            raise Exception("[mergeSerial] Cannot merge this chain, exiting.")

        allSteps.append(cConfig.steps)
        nSteps.append(len(cConfig.steps))
        l1Thresholds.extend(cConfig.vseeds)
        alignmentGroups.extend(cConfig.alignmentGroups)

    serialSteps = serial_zip(allSteps, chainName, chainDefList)
    # the rows hold the input chains' own ChainStep objects; merge a deep copy
    # (presumably so the input chains are not shared with the result)
    mySerialSteps = deepcopy(serialSteps)
    for chainDef in chainDefList:
        log.debug('[mergeSerial] Chain object to merge (i.e. chainDef) %s', chainDef)

    combChainSteps = []
    for step_index, steps in enumerate(mySerialSteps):
        combChainSteps.append(makeCombinedStep(list(steps), step_index+1, chainDefList))

    # check if all chain parts have the same number of steps
    if all(x == nSteps[0] for x in nSteps):
        log.info("[mergeSerial] All chain parts have the same number of steps")
    else:
        log.info("[mergeSerial] Have to deal with uneven number of chain steps, there might be none's appearing in sequence list => to be fixed")

    combinedChainDef = Chain(chainName, ChainSteps=combChainSteps, L1Thresholds=l1Thresholds,
                             nSteps = nSteps, alignmentGroups = alignmentGroups)

    log.debug("[mergeSerial] Serial merged chain %s with these steps:", chainName)
    for step in combinedChainDef.steps:
        log.debug('   %s', step)

    return combinedChainDef


def makeCombinedStep(parallel_steps, stepNumber, chainDefList, allSteps = None, currentChainSteps = None):
    """Merge one step from each chain part into a single combined ChainStep.

    :param parallel_steps: one ChainStep (or None) per entry of ``chainDefList``;
        None means that part has no step at this position, so an empty sequence
        is inserted for it
    :param stepNumber: 1-based number of the step being built
    :param chainDefList: the Chain() objects being merged
    :param allSteps: not used; kept so existing callers (mergeParallel) still work
    :param currentChainSteps: combined steps already built for this chain; needed
        only to pad a None entry (its multiplicity is taken from the previous step)
    :returns: the combined ChainStep; its sequences, multiplicities, stepDicts and
        combo tool configs are the concatenation of those of the merged steps
    :raises AssertionError: if ``parallel_steps`` and ``chainDefList`` differ in length
    """
    import re
    from TrigCompositeUtils.TrigCompositeUtils import legName

    if currentChainSteps is None:
        currentChainSteps = []

    stepName = 'merged' #we will renumber all steps after chains are aligned #Step' + str(stepNumber)
    stepSeq = []
    stepMult = []
    stepDicts = []
    comboHypoTools = []
    comboHypo = None
    log.verbose("[makeCombinedStep] steps %s ", parallel_steps)

    # this function only makes sense if we are merging steps corresponding to the chains in the chainDefList
    assert len(chainDefList)==len(parallel_steps), "[makeCombinedStep] makeCombinedStep: Length of chain defs %d does not match length of steps to merge %d" % (len(chainDefList), len(parallel_steps))

    leg_counter = 0

    for chain_index, step in enumerate(parallel_steps): #this is a horizontal merge!
        if step is None:
            # this happens for merging chains with different numbers of steps, we need to "pad" out with empty sequences to propogate the decisions
            # all other chain parts' steps should contain an empty sequence
            new_stepDict = deepcopy(chainDefList[chain_index].steps[-1].stepDicts[-1])

            seqName = getEmptySeqName(new_stepDict['signature'], chain_index, stepNumber, chainDefList[0].alignmentGroups[0])

            # NOTE(review): chain_index is used as a leg index below (in
            # noPrecedingStepsPostMerge and multiplicity[...]); this assumes the
            # parts before this one have a single leg each - confirm
            if isFullScanRoI(chainDefList[chain_index].L1decisions[0]):
                mergeUsingFeature = not noPrecedingStepsPostMerge(currentChainSteps, chain_index)
                stepSeq.append(RecoFragmentsPool.retrieve(getEmptyMenuSequence, flags=None, name=seqName+"FS", mergeUsingFeature = mergeUsingFeature))
                fsSuffix = 'FS'
            else:
                stepSeq.append(RecoFragmentsPool.retrieve(getEmptyMenuSequence, flags=None, name=seqName))
                fsSuffix = ''
            currentStepName = 'Empty' + chainDefList[chain_index].alignmentGroups[0]+'Align'+str(stepNumber)+'_'+new_stepDict['chainParts'][0]['multiplicity']+new_stepDict['signature']+fsSuffix

            log.debug("[makeCombinedStep]  step %s,  empty sequence %s", currentStepName, seqName)

            #stepNumber is indexed from 1, need the previous step indexed from 0, so do - 2
            prev_step_mult = int(currentChainSteps[stepNumber-2].multiplicity[chain_index])
            stepMult.append(prev_step_mult)
            # we need a chain dict here, use the one corresponding to this leg of the chain
            oldLegName = re.sub(r'^leg[0-9]{3}_', '', new_stepDict['chainName'], count=1)
            new_stepDict['chainName'] = legName(oldLegName,leg_counter)
            stepDicts.append(new_stepDict)
            leg_counter += 1
        else:
            # a real step: append its content to the combined step
            log.debug("[makeCombinedStep]  step %s, multiplicity  = %s", step.name, str(step.multiplicity))
            if step.sequences:
                log.debug("[makeCombinedStep]    with sequences = %s", ' '.join(map(str, [seq.name for seq in step.sequences])))

            # this function only works if the input chains are single-object chains (one menu seuqnce)
            if len(step.sequences) > 1:
                log.debug("[makeCombinedStep] combining in an already combined chain")

            # NOTE(review): the combo hypo selection was not visible in the source;
            # this keeps the first step's config and lets any config not named
            # "ComboHypoCfg" (presumably a specialised one) override it - confirm
            if comboHypo is None or (hasattr(step.comboHypoCfg, '__name__') and step.comboHypoCfg.__name__ != "ComboHypoCfg"):
                comboHypo = step.comboHypoCfg

            #remove redundant instances of StepN_ and merged_ (happens when merging already merged chains)
            currentStepName = re.sub(r'^Step[0-9]_', '', step.name, count=1)
            currentStepName = re.sub(r'^merged_', '', currentStepName, count=1)

            stepSeq.extend(step.sequences)
            # a step without multiplicities still contributes one (zero) entry
            if len(step.multiplicity) == 0:
                stepMult.append(0)
            else:
                stepMult.extend(step.multiplicity)
            comboHypoTools.extend(step.comboToolConfs)
            # update the chain dict list for the combined step with the chain dict from this step
            log.debug('[makeCombinedStep] adding step dictionaries %s',step.stepDicts)

            for new_stepDict in deepcopy(step.stepDicts):
                # renumber the leg prefix: strip any existing legNNN_ and apply this leg's index
                oldLegName = re.sub(r'^leg[0-9]{3}_', '', new_stepDict['chainName'], count=1)
                new_stepDict['chainName'] = legName(oldLegName,leg_counter)
                log.debug("[makeCombinedStep] stepDict naming old: %s, new: %s", oldLegName, new_stepDict['chainName'])
                stepDicts.append(new_stepDict)
                leg_counter += 1

        # the step naming for combined chains needs to be revisted!!
        stepName += '_' + currentStepName
        log.debug('[makeCombinedStep] current step name %s',stepName)

    theChainStep = ChainStep(stepName, Sequences=stepSeq, multiplicity=stepMult, chainDicts=stepDicts, comboHypoCfg=comboHypo, comboToolConfs=comboHypoTools)
    log.info("[makeCombinedStep] Merged step: \n %s", theChainStep)

    return theChainStep