###############################################################################
# (c) Copyright 2021-2024 CERN for the benefit of the LHCb Collaboration      #
#                                                                             #
# This software is distributed under the terms of the GNU General Public      #
# Licence version 3 (GPL Version 3), copied verbatim in the file "COPYING".   #
#                                                                             #
# In applying this licence, CERN does not waive the privileges and immunities #
# granted to it by virtue of its status as an Intergovernmental Organization  #
# or submit itself to any jurisdiction.                                       #
###############################################################################
"""
Make a tuple with DaVinci from the output of the sprucing

    lb-run davinci/v64r13 lbexec funtuple:main funtuple.yml
"""
from PyConf.reading import get_particles, get_pvs
import Functors as F
from DaVinciMCTools import MCTruthAndBkgCat
from FunTuple import FunctorCollection
from FunTuple import FunTuple_Particles as Funtuple
import FunTuple.functorcollections as FC
from DaVinci.algorithms import create_lines_filter
from PyConf.Algorithms import VoidFilter
from DaVinci import Options, make_config
from Configurables import LHCb__ParticlePropertySvc

# Branch names that denote the head of a decay.  configure_tuple gives these
# branches the parent variable set; every other branch gets the
# decay-product variable set.
PARENT_PARTICLES = ["KS0"]

# Tuple definition for each HLT2 line, keyed by line name.  Every entry is
# unpacked into the keyword arguments of configure_tuple:
#   type   -- which parent variable set to use for the decay head
#   fields -- branch name -> decay descriptor selecting that particle
LINES = {
    # control
    "Hlt2RD_KS0ToPiPi": {
        "type": "reco_vertex",
        "fields": {
            "KS0": "KS0 -> pi+ pi-",
            "pip": "KS0 -> ^pi+ pi-",
            "pim": "KS0 -> pi+ ^pi-",
        },
    },
}


def configure_tuple(line, fields, type):
    """Build the FunTuple algorithm that tuples the candidates of one line.

    Args:
        line: Name of the HLT2 line.  Candidates are read from
            ``/Event/HLT2/{line}/Particles`` and the tuple algorithm is
            named after the line.
        fields: Mapping of branch name to decay descriptor; passed to
            FunTuple unchanged.
        type: Variable set used for fields listed in ``PARENT_PARTICLES``,
            either ``'reco_vertex'`` or ``'neutral'``.  The name shadows the
            builtin, but it has to stay because the caller passes it by
            keyword.

    Returns:
        The configured FunTuple algorithm.

    Raises:
        KeyError: If a field listed in ``PARENT_PARTICLES`` is requested
            with a ``type`` that is neither of the two known ones.
    """
    # Momentum functors written for every particle, parent or decay product.
    momentum = {
        "PT": F.PT,
        "PX": F.PX,
        "PY": F.PY,
        "PZ": F.PZ,
        "ENERGY": F.ENERGY,
        "P": F.P,
        "FOURMOMENTUM": F.FOURMOMENTUM,
    }

    # Parents additionally store their object key.
    parent_common = {"ID": F.PARTICLE_ID, "KEY": F.OBJECT_KEY, **momentum}
    parent_functors = {
        'reco_vertex': {
            **parent_common,
            "OWNPVDIRA": F.OWNPVDIRA,
            "OWNPVFDCHI2": F.OWNPVFDCHI2,
            "OWNPVIPCHI2": F.OWNPVIPCHI2,
        },
        # if a parent particle does not have vertex information (presumably because
        # it was created from neutral objects) then we can not ask for quantities
        # like the DIRA
        'neutral': parent_common,
    }
    decay_product_functors = {"ID": F.PARTICLE_ID, **momentum}

    # Resolve the reconstructed-quantity functors of every field first, so an
    # unknown ``type`` fails before any algorithm is created.
    reco_functors = {
        name: (parent_functors[type]
               if name in PARENT_PARTICLES else decay_product_functors)
        for name in fields
    }

    particles = get_particles(f'/Event/HLT2/{line}/Particles')

    # get configured "MCTruthAndBkgCatAlg" algorithm for HLT2 output
    MCTRUTH = MCTruthAndBkgCat(particles, name=f'MCTruthAndBkgCat_hlt2_{line}')
    trueid_bkgcat_info = {
        # Important note: specify an invalid value for integer functors if there exists no truth info.
        #                 The invalid value for floating point functors is set to nan.
        "TRUEID": F.VALUE_OR(0) @ MCTRUTH(F.PARTICLE_ID),
        "TRUEKEY": F.VALUE_OR(-1) @ MCTRUTH(F.OBJECT_KEY),
        "TRUEPT": MCTRUTH(F.PT),
        "TRUEPX": MCTRUTH(F.PX),
        "TRUEPY": MCTRUTH(F.PY),
        "TRUEPZ": MCTRUTH(F.PZ),
        "TRUEENERGY": MCTRUTH(F.ENERGY),
        "TRUEP": MCTRUTH(F.P),
        "TRUEFOURMOMENTUM": MCTRUTH(F.FOURMOMENTUM),
        "BKGCAT": MCTRUTH.BkgCat,
    }

    # Give every field its own collection.  Sharing one collection between
    # several fields and extending it in place would modify the same object
    # once per field that refers to it.
    variables = {
        name: FunctorCollection({**reco, **trueid_bkgcat_info})
        for name, reco in reco_functors.items()
    }

    return Funtuple(
        name=line,
        tuple_name="DecayTree",
        fields=fields,
        variables=variables,
        inputs=particles)


def main(options: Options):
    """Entry point for ``lbexec``: configure one tuple per entry of ``LINES``.

    A line filter built from all line names in ``LINES`` is placed in front
    of the tuple algorithms.

    Args:
        options: Job options, forwarded to ``make_config``.

    Returns:
        The result of ``make_config`` for the filter followed by the tuples.
    """
    hlt2rd_filter = create_lines_filter(name="HDRFilter_Hlt2RD", lines=list(LINES))

    tuples = [configure_tuple(line, **config) for line, config in LINES.items()]

    return make_config(options, [hlt2rd_filter, *tuples])