###############################################################################
# (c) Copyright 2023 CERN for the benefit of the LHCb Collaboration           #
#                                                                             #
# This software is distributed under the terms of the GNU General Public      #
# Licence version 3 (GPL Version 3), copied verbatim in the file "COPYING".   #
#                                                                             #
# In applying this licence, CERN does not waive the privileges and immunities #
# granted to it by virtue of its status as an Intergovernmental Organization  #
# or submit itself to any jurisdiction.                                       #
###############################################################################
import argparse

from Moore import run_reconstruction, Options
from Moore.config import Reconstruction
from PyConf.Algorithms import (
    fromPrVeloTracksV1TracksMerger,  
    VPHitEfficiencyMonitor,
    VeloRetinaClusterTrackingSIMDFull,
    TrackEventFitter,
    TrackListRefiner,
    TrackSelectionToContainer,
    DeterministicPrescaler, 
)
from PyConf.Tools import (
    TrackMasterFitter,
    TrackInterpolator,
    TrackLinearExtrapolator,
)
from PyConf.application import( 
    default_raw_banks, 
    make_odin, 
    default_raw_event, 
)
from RecoConf.legacy_rec_hlt1_tracking import (
    make_reco_pvs,
    make_PatPV3DFuture_pvs,
    make_RetinaClusters,
    get_default_ut_clusters, 
    get_global_clusters_on_track_tool_only_velo, 
)
from RecoConf.hlt2_tracking import (
    TrackBestTrackCreator,
    make_PrStoreUTHit_empty_hits,
    get_global_measurement_provider,
    get_track_master_fitter, 
)
from RecoConf.decoders import default_VeloCluster_source
from PyConf import configurable
import Functors as F 

# Make make_reco_pvs build its PVs from VELO tracks with make_PatPV3DFuture_pvs.
make_reco_pvs.global_bind(make_pvs_from_velo_tracks=make_PatPV3DFuture_pvs)

# VELO cluster / track makers used by make_my_sequence below.
velo_clusters_algorithm_light = make_RetinaClusters
velo_tracking_algorithm = VeloRetinaClusterTrackingSIMDFull
# Raw bank type read by the VELO tracking; reused for the default VELO
# cluster decoder so both always agree.
bankType = "VPRetinaCluster"
default_VeloCluster_source.global_bind(bank_type=bankType)

# Helper returning the track master fitter (TMF) used for the track fit below
def get_my_track_master_fitter():
    """Return the track master fitter tool used by the hit-efficiency sequence.

    The fitter comes from ``get_track_master_fitter``, using the
    ``get_global_clusters_on_track_tool_only_velo`` clusters-on-track tool.
    The ``TrackMasterFitter`` defaults below are overridden while it is
    constructed: at most 2 outliers, 10 fit iterations, and the fast material
    approximation.
    """
    fitter_overrides = dict(
        MaxNumberOutliers=2,
        NumberFitIterations=10,
        FastMaterialApproximation=True,
    )
    # The call must happen inside the bind so the overrides take effect.
    with TrackMasterFitter.bind(**fitter_overrides):
        fitter = get_track_master_fitter(
            clusters_on_track_tool=get_global_clusters_on_track_tool_only_velo)
        return fitter

@configurable
def make_my_sequence(
        *,
        beginSensor=0,
        endSensor=1,
):
    """Build the VELO hit-efficiency monitoring reconstruction.

    For every sensor index in the inclusive range ``[beginSensor, endSensor]``
    a dedicated VELO tracking pass is configured with only that sensor flagged
    in ``SensorMasks``. Its tracks are converted, fitted, selected and handed to
    a ``VPHitEfficiencyMonitor`` for that sensor. All monitors run behind one
    ``DeterministicPrescaler`` (accept fraction 0.1).

    Args:
        beginSensor: First sensor index to study (inclusive).
        endSensor: Last sensor index to study (inclusive).

    Returns:
        A ``Reconstruction`` named ``'hlt2_hit_eff_reco'`` holding one
        ``VPHitEfficiencyMonitor`` per studied sensor, filtered by the
        prescaler.

    Raises:
        ValueError: if the range is empty (``beginSensor > endSensor``) or
            lies outside ``[0, 208)``, the length of the sensor mask.
    """
    # Length of the per-sensor mask handed to the VELO tracking. A sensor
    # index outside [0, n_sensors) could never be flagged in it.
    n_sensors = 208
    if not 0 <= beginSensor <= endSensor < n_sensors:
        raise ValueError(
            "invalid sensor range [{0}, {1}]: expected "
            "0 <= beginSensor <= endSensor < {2}".format(
                beginSensor, endSensor, n_sensors))

    prescaler = DeterministicPrescaler(
        name="MyEventPrescaler",
        # Keep AcceptFraction a float: an int would change the hash computation.
        AcceptFraction=float(0.1),
        SeedName="KSPrescaler",
        ODINLocation=make_odin())

    # These tools do not depend on the sensor under study, so configure them
    # once and share them between all monitors.
    track_extrapolator = TrackLinearExtrapolator()
    track_interpolator = TrackInterpolator(Extrapolator=track_extrapolator)

    monitors = []
    with get_default_ut_clusters.bind(disable_ut=True), \
         get_global_measurement_provider.bind(
             ignoreUT=True,
             velo_hits=velo_clusters_algorithm_light,
             ignoreFT=True,
             ignoreMuon=True,
             ut_hits=make_PrStoreUTHit_empty_hits):
        # Built inside the binds above (which it may depend on), but identical
        # for every sensor.
        fitter = get_my_track_master_fitter()

        for sensor_under_study in range(beginSensor, endSensor + 1):
            # Flag only the sensor under study in the tracking mask.
            # NOTE(review): presumably this excludes that sensor from the
            # tracking so its efficiency is measured unbiased - confirm.
            sensor_masks = tuple(
                sensor == sensor_under_study for sensor in range(n_sensors))

            vp_tracking = velo_tracking_algorithm(
                RawBanks=default_raw_banks(bankType),
                SensorMasks=sensor_masks,
                MaxScatterSeeding=0.1,
                MaxScatterForwarding=0.1,
                MaxScatter3hits=0.02,
                SeedingWindow=10,
                SkipForward=8)
            clusters = vp_tracking.HitsLocation

            # Convert Pr -> v1 tracks, merging forward and backward tracks.
            vp_tracks_v1 = fromPrVeloTracksV1TracksMerger(
                InputTracksLocation1=vp_tracking.TracksLocation,
                InputTracksLocation2=vp_tracking.TracksBackwardLocation,
                TrackAddClusterTool=get_global_clusters_on_track_tool_only_velo(),
            ).OutputTracksLocation

            fitted_tracks = TrackEventFitter(
                TracksInContainer=vp_tracks_v1,
                Fitter=fitter,
                MaxChi2DoF=2.8,
                name="TrackEventFitter_{hash}").TracksOutContainer

            best_tracks = TrackBestTrackCreator(
                name="TrackBestTrackCreator_{hash}",
                TracksInContainers=[fitted_tracks],
                DoNotRefit=True,
                AddGhostProb=False,
                FitTracks=False,
                MaxChi2DoF=2.8,
            ).TracksOutContainer

            selected_tracks = TrackListRefiner(
                inputLocation=best_tracks,
                Code=F.require_all(
                    F.NVPHITS >= 3,
                    F.ETA >= 1.3,
                    F.P >= 0,  # Here to make it explicit
                    F.PT >= 0)).outputLocation

            filtered_tracks = TrackSelectionToContainer(
                name="TrackSelectionToContainer_{hash}",
                InputLocation=selected_tracks).OutputLocation

            monitors.append(
                VPHitEfficiencyMonitor(
                    name="VPHitEfficiencyMonitorSensor_{0}".format(
                        sensor_under_study),
                    TrackLocation=filtered_tracks,
                    PrVPHitsLocation=clusters,
                    UseAandCSide=False,
                    MaxTrackCov=100.0,
                    SensorUnderStudy=sensor_under_study,
                    Interpolator=track_interpolator,
                    Extrapolator=track_extrapolator,
                    ExpertMode=False,
                    ResidualTolerance=0.4,
                ))

    return Reconstruction('hlt2_hit_eff_reco', monitors, filters=[prescaler])

from PyConf.application import default_raw_event


def main(options: Options, *args):
    """Configure and run the VELO hit-efficiency reconstruction.

    Args:
        options: Moore application options. Their ``ntuple_file`` value is
            moved to ``histo_file`` (see the hack below) before running.
        *args: Command-line style arguments ``stream begin_sensor end_sensor``.

    Returns:
        The configuration returned by ``run_reconstruction``.
    """
    parser = argparse.ArgumentParser(description="Moore HLT2 Hit Efficiency")
    parser.add_argument(
        "stream",
        choices=["turcal_rawbanks", "turcal_persistrecorawbanks", "nobias"],
        help="input stream, bound to default_raw_event")
    parser.add_argument(
        "begin_sensor", type=int, help="first sensor to study (inclusive)")
    parser.add_argument(
        "end_sensor", type=int, help="last sensor to study (inclusive)")
    # Use a distinct name so the parsed namespace does not shadow *args.
    cli_args = parser.parse_args(args)
    if cli_args.begin_sensor > cli_args.end_sensor:
        parser.error("begin_sensor must not be greater than end_sensor")

    # DO NOT COPY THIS WITHOUT CONSULTING DPA-WP2 FIRST
    # This is a horrible hack to make it so the histograms are written to the
    # file that is expected by Analysis Productions. Ideally we should have a
    # better way of managing output/ntuples/histograms but it's currently an
    # extremely niche use case.
    raw_options = options.dict()
    raw_options["histo_file"] = raw_options["ntuple_file"]
    raw_options["ntuple_file"] = None
    options = Options(**raw_options)
    # DO NOT COPY THIS WITHOUT CONSULTING DPA-WP2 FIRST

    default_raw_event.global_bind(stream=cli_args.stream)
    make_my_sequence.global_bind(
        beginSensor=cli_args.begin_sensor, endSensor=cli_args.end_sensor)
    config = run_reconstruction(options, make_my_sequence)
    return config