from __future__ import print_function

# Copyright (C) 2002-2019 CERN for the benefit of the ATLAS collaboration

from future import standard_library
from builtins import map
from builtins import str
from builtins import zip
from builtins import object
import logging
import re
import sys
import time

# useFrontierClient selects which cursor class getFrontierCursor() returns:
#   True:  use the python bindings of frontier_client from the TrigConfDBConnection package
#   False: use a purely python-based implementation
useFrontierClient = False

def getServerUrls(frontier_servers):
    """Return the list of URLs given as ``(serverurl=...)`` entries.

    frontier_servers -- string that may contain any number of
                        ``(serverurl=<url>)`` groups; other text is ignored.

    The URLs are returned in the order in which they appear; an input
    without any such group gives an empty list.
    """
    import re
    serverurl_pattern = re.compile(r'\(serverurl=(.*?)\)')
    return serverurl_pattern.findall(frontier_servers)

def testUrl(url):
    """Return True if ``url`` can be opened, False otherwise.

    url -- URL string handed to urllib.request.urlopen.

    Only urllib.error.URLError (which includes HTTPError) is turned into a
    False result; any other exception raised by urlopen propagates to the
    caller.
    """
    import urllib.request, urllib.error
    try:
        response = urllib.request.urlopen(url)
    except urllib.error.URLError:
        return False
    # only reachability matters, so release the connection right away
    response.close()
    return True

def resolveUrl(url):
    """Return an accessible Frontier URL for ``url``, or None.

    Expects the input string to be either a plain URL (starting with
    "http://") or a $FRONTIER_SERVER style specification, i.e. a string
    that starts with a "(serverurl=...)" group.

    For a plain URL the URL itself is returned if testUrl accepts it.
    For a $FRONTIER_SERVER string the first server URL accepted by testUrl
    is returned.  None is returned when nothing is reachable or when the
    input has neither of the two forms.
    """
    import re
    if re.match("http://", url):  # simple URL specification http://...
        return url if testUrl(url) else None

    if re.match(r'\(serverurl=(.*?)\)', url):  # syntax of FRONTIER_SERVER
        # use a separate name so the parameter is not overwritten while looping
        for candidate in getServerUrls(url):
            if testUrl(candidate):
                return candidate
    return None

def getFrontierCursor(url, schema, loglevel = logging.INFO):
    """Return a Frontier cursor for the first reachable server, or None.

    url      -- plain URL or $FRONTIER_SERVER style string, passed to resolveUrl
    schema   -- schema name handed on to the cursor
    loglevel -- accepted for interface compatibility
                NOTE(review): nothing visible in this function applies it — confirm
                whether the logger level was meant to be set from it.

    Returns None (after logging a warning) when resolveUrl finds no usable
    server.  Otherwise the module-level switch useFrontierClient decides
    between the frontier_client based cursor and the pure python one.
    """
    log = logging.getLogger( "" )
    url = resolveUrl(url)
    if url is None:
        log.warning("Cannot find a valid frontier connection, will not return a Frontier cursor")
        return None
    log.info("Will use Frontier server at %s", url)
    if useFrontierClient:
        log.info("Using frontier_client from TrigConfDBConnection")
        return FrontierCursor2( url = url, schema = schema)
    log.info("Using a pure python implementation of frontier")
    # NOTE(review): the pure python cursor class is assumed to be named
    # FrontierCursor (its class line is not visible in this file) — confirm.
    return FrontierCursor( url = url, schema = schema)

# used by FrontierCursor2
def resolvebindvars(query, bindvars):
    """Replaces the bound variables :xyz with a ? in the query and
    adds the values, each preceded by a ':', at the end.

    query    -- SQL text containing bind variables written as :name
    bindvars -- dict mapping variable names to values; a name that is not
                in the dict contributes the text "None"

    Returns the rewritten query string, e.g.
    "... where id = :psk" with {"psk": 260} gives "... where id = ?:260".
    """
    import re
    log = logging.getLogger( "" )
    log.info("Query: %s", query)
    log.info("bound variables: %r", bindvars)
    # A name is one or more letters, digits or underscores.  The previous
    # class [A-z0-9]* also accepted [ \ ] ^ and ` and matched a bare colon.
    bindvar_pattern = re.compile(r':([A-Za-z0-9_]+)')
    varsextract = bindvar_pattern.findall(query)
    values = [bindvars.get(name) for name in varsextract]
    log.debug("Resolving bound variable %r with %r", varsextract,values)
    appendix = ":".join([str(v) for v in values])
    queryWithQuestionMarks = bindvar_pattern.sub('?', query)
    query = queryWithQuestionMarks + ':' + appendix
    log.info("Resolved query new style: %s", query)
    return query

# used by FrontierCursor
def replacebindvars(query, bindvars):
    """Replaces the bound variables with the specified values,
    disables variable binding.

    query    -- SQL text containing bind variables written as :name
    bindvars -- dict mapping variable names to values

    Integer values are inserted as plain numbers, everything else as its
    repr() (so strings end up quoted).  Returns the rewritten query.

    Raises NameError if a name in ``bindvars`` does not occur in the query.
    """
    import re
    from builtins import int
    log = logging.getLogger( "" )
    for var,val in list(bindvars.items()):
        # Match ":var" only as a complete name, so that replacing ":id"
        # cannot damage a longer variable such as ":id2".
        var_pattern = re.compile(":%s(?![A-Za-z0-9_])" % re.escape(var))
        if var_pattern.search(query) is None:
            raise NameError("variable '%s' is not a bound variable in this query: %s" % (var, query) )
        if isinstance (val, int):
            literal = "%s" % val
        else:
            # 1-tuple keeps %-formatting correct when val is itself a tuple
            literal = "%r" % (val,)
        # a callable replacement inserts the text verbatim (no backslash processing)
        query = var_pattern.sub(lambda _match, text=literal: text, query)
        log.debug("Resolving bound variable '%s' with %r", var,val)
    log.debug("Resolved query: %s", query)
    return query


    def __init__(self, url, schema, refreshFlag=False):
        """Remember the connection settings and report them at debug level.

        url         -- Frontier server URL
        schema      -- schema name
        refreshFlag -- stored as self.refreshFlag (default False)
        """
        self.url, self.schema, self.refreshFlag = url, schema, refreshFlag
        # frontier_client is imported here; fc is not referenced again in this method
        from TrigConfDBConnection import frontier_client as fc
        logger = logging.getLogger( "" )
        logger.debug("Frontier URL      : %s", self.url)
        logger.debug("Schema            : %s", self.schema)
        logger.debug("Refresh cache     : %s", self.refreshFlag)
            query = resolvebindvars(query,bindvars)
        from TrigConfDBConnection import frontier_client as fc
        log = logging.getLogger( "" )
        log.debug("Executing query : %s", query)

        conn = fc.Connection(self.url)
        session = fc.Session(conn)

        doReload = self.refreshFlag
        queryStart = time.localtime()
        log.debug("Query started: %s", time.strftime("%m/%d/%y %H:%M:%S %Z", queryStart))

        t1 = time.time()
        req = fc.Request("frontier_request:1:DEFAULT", fc.encoding_t.BLOB)
        param = fc.Request.encodeParam(query)
        t2 = time.time()

        #nfield = session.getNumberOfFields()
        #print ("\nNumber of fields:", nfield, "\n")
        #print ("\nResult contains", nrec, "objects.\n")
        queryEnd = time.localtime()
        self.result = [r for r in session.getRecords2()]
        log.debug("Query ended: %s", time.strftime("%m/%d/%y %H:%M:%S %Z", queryEnd))
        log.debug("Query time: %s seconds", (t2-t1))
        log.debug("Result size: %i entries", len(self.result))
        s =  "FrontierCursor2:\n"
        s += "Using Frontier URL: %s\n" % self.url
        s += "Schema: %s\n" % self.schema
        s += "Refresh cache:  %s" % self.refreshFlag
        return s
    def __init__(self, url, schema, refreshFlag=False, doDecode=True, retrieveZiplevel="zip"):
        """Store the settings of the pure python Frontier cursor.

        url              -- Frontier server URL; read back by __str__ and execute
        schema           -- schema name (reported by __str__)
        refreshFlag      -- when true, execute adds a "pragma: no-cache" header
        doDecode         -- when true, fetchall calls decodeResult before returning
        retrieveZiplevel -- text appended to "encoding=BLOB" in the request built by execute
        """
        # __str__ and execute read self.url, so it has to be stored as well
        self.url = url
        self.schema = schema
        self.refreshFlag = refreshFlag
        self.retrieveZiplevel = retrieveZiplevel
        self.doDecode = doDecode

    def __str__(self):
        """Return a three-line summary: Frontier URL, schema and refresh flag."""
        parts = (
            "Using Frontier URL: %s\n" % self.url,
            "Schema: %s\n" % self.schema,
            "Refresh cache:  %s" % self.refreshFlag,
        )
        return "".join(parts)

    def execute(self, query, bindvars=None):
        """Send ``query`` to the Frontier server and keep the raw reply in self.result.

        query    -- SQL text; may contain bind variables written as :name
        bindvars -- optional dict of bind variable values; when non-empty the
                    values are substituted into the query by replacebindvars

        The query is zlib-compressed, base64-encoded with a URL-safe alphabet
        and sent as the p1 parameter of a frontier_request.  The decoded
        response text is stored in self.result; nothing is returned.
        Errors raised by urllib (e.g. URLError) propagate to the caller.
        """
        import base64, zlib, urllib.request, urllib.error, urllib.parse, time
        log = logging.getLogger( "" )
        if bindvars:
            query = replacebindvars(query,bindvars)
        log.debug("Frontier URL  : %s", self.url)
        log.debug("Refresh cache : %s", self.refreshFlag)
        log.debug("Query         : %s", query)
        compQuery = zlib.compress(query.encode("utf-8"),9)
        base64Query = base64.b64encode(compQuery).decode("utf-8")
        # characters with a special meaning in URLs are mapped to safe ones
        encQuery = base64Query.replace("+", ".").replace("\n","").replace("/","-").replace("=","_")

        frontierRequest="%s/type=frontier_request:1:DEFAULT&encoding=BLOB%s&p1=%s" % (self.url, self.retrieveZiplevel, encQuery)
        request = urllib.request.Request(frontierRequest)

        if self.refreshFlag:
            request.add_header("pragma", "no-cache")

        frontierId = "TrigConfFrontier 1.0"
        request.add_header("X-Frontier-Id", frontierId)

        queryStart = time.localtime()
        log.debug("Query started: %s", time.strftime("%m/%d/%y %H:%M:%S %Z", queryStart))
        t1 = time.time()
        response = urllib.request.urlopen(request,None,10)  # 10 second timeout
        try:
            result = response.read().decode()
        finally:
            # release the connection even if reading or decoding fails
            response.close()
        t2 = time.time()
        queryEnd = time.localtime()
        log.debug("Query ended: %s", time.strftime("%m/%d/%y %H:%M:%S %Z", queryEnd))
        log.debug("Query time: %s [seconds]", (t2-t1))
        log.debug("Result size: %i [characters]", len(result))
        self.result = result

    def fetchall(self):
        """Return self.result, running decodeResult first when doDecode is set."""
        decode_requested = self.doDecode
        if decode_requested:
            self.decodeResult()
        return self.result

    def decodeResult(self):
        """Parse the XML text in self.result and replace it with a list of row tuples.

        The text is parsed with xml.dom.minidom; the child nodes of every
        <data> element are processed, and the rows built from them are
        stored back into self.result.

        NOTE(review): several statements of this method appear to be missing
        in this copy of the file (see the notes below); as written it is not
        valid Python.  Restore the method from version control before use.
        """
        log = logging.getLogger( "" )
        from xml.dom.minidom import parseString
        import base64, zlib, curses.ascii
        #print ("Query result:\n", self.result)
        dom = parseString(self.result)
        dataList = dom.getElementsByTagName("data")
        keepalives = 0
        # Control characters represent records, but I won't bother with that now,
        # and will simply replace those by space.
        for data in dataList:
            for node in data.childNodes:
                # <keepalive /> elements may be present, combined with whitespace text
                if node.nodeName == "keepalive":
                    # this is of type Element
                    keepalives += 1
                # else assume of type Text
                # NOTE(review): the condition below has no left-hand operand and
                # no body — the statement looks truncated.
                if == "":
                if keepalives > 0:
                    print (keepalives, "keepalives received\n")
                    # NOTE(review): 'row' is read here before any visible
                    # assignment (base64 is imported above but never called), and
                    # this line sits inside the keepalive branch — the code that
                    # produces 'row' from the node text appears to be missing.
                    row = zlib.decompress(row).decode("utf-8")

                #Hack to get these lines to work in python 2
                if sys.version_info[0] < 3: 
                    row = row.encode('ascii', 'xmlcharrefreplace')
                # everything up to the first \x07 character is taken as the header row
                endFirstRow = row.find('\x07')
                firstRow = row[:endFirstRow]
                # blank out control characters so the header can be split on whitespace
                for c in firstRow:
                    if curses.ascii.isctrl(c):
                        firstRow = firstRow.replace(c, ' ')
                # header tokens alternate: even positions are used as field names,
                # odd positions as the DB type of the preceding field
                fields = [x for i,x in enumerate(firstRow.split()) if i%2==0]
                types = [x for i,x in enumerate(firstRow.split()) if i%2==1]
                ptypes = []
                for t in types:
                    if t.startswith("NUMBER"):
                        # NOTE(review): this branch has no body and nothing is ever
                        # appended to ptypes in the visible code, so the zip() further
                        # down would yield empty tuples — the type mapping is missing.
                        if ",0" in t:

                log.debug("Fields      : %r", fields)
                log.debug("DB Types    : %r", types)
                log.debug("Python Types: %r", ptypes)
                #                pattern = re.compile("\x06\x00\x00\x00.",flags=re.S)
                #replace pattern above  more restrictive version, as longerstrings in the results
                #have a size variable in the column separate that becomes visible if the string
                #is large enough - this then broke the prevous  decoding
                # NOTE(review): 're' is not imported inside this method and 'row_h'
                # is used below without a visible assignment — the step deriving
                # row_h from the remainder of 'row' appears to be missing.
                pattern = re.compile("\x06\x00\x00..",flags=re.S)
                # column separators are rewritten to the marker '.xXx.'
                row_h = pattern.sub('.xXx.',row_h)
                row_h = row_h.replace("\x86", '.xXx.')

                # split into records on '.nNn.', then each record into columns on '.xXx.'
                row_h = row_h.split('.nNn.')
                row_h = [r.split('.xXx.') for r in row_h]

                result = []
                for r in row_h:
                    # drop an empty leading column
                    if r[0]=='': r[0:1]=[]
                    # convert each column with the python type chosen for its field
                    r = tuple([t(v) for t,v in zip(ptypes,r)])
                    result.append( r )

        # NOTE(review): 'result' is only bound inside the loops above, so a
        # response without any processed node would raise NameError here.
        self.result = result

def testQuery(query, bindvars):
    """Run ``query`` over Frontier on the TRIGGERDBMC connection and check the answer.

    query    -- SQL text, may contain bind variables written as :name
    bindvars -- dict with the values of the bind variables

    Returns 0 if the first column of the first decoded row is 'MC_pp_v7',
    and 1 otherwise (including the case that no Frontier cursor could be
    obtained).
    """
    log = logging.getLogger( "" )
    from TriggerJobOpts.TriggerFlags import TriggerFlags as tf
    tf.triggerUseFrontier = True

    from TrigConfigSvc.TrigConfigSvcUtils import interpretConnection
    connectionParameters = interpretConnection("TRIGGERDBMC")
    cursor = getFrontierCursor( url = connectionParameters['url'], schema = connectionParameters['schema'])
    if cursor is None:
        # getFrontierCursor has already logged why no cursor is available
        return 1
    cursor.execute(query, bindvars)
    log.info("Raw response:")
    log.info(cursor.result)
    cursor.decodeResult()
    log.info("Decoded response:")
    log.info(cursor.result[0][0])
    if cursor.result[0][0] != 'MC_pp_v7':
        return 1
    # an integer is needed in every case: the caller combines the results with max()
    return 0

def testBindVarResolution(query, bindvars):
    """Run resolvebindvars on the given query and bind variables.

    Returns 0; a failure shows up as an exception from resolvebindvars.
    """
    resolvebindvars(query, bindvars)
    return 0


# The statements below used to sit inside testBindVarResolution, which made
# the function call itself without end; they are the script entry point.
if __name__ == "__main__":
    log = logging.getLogger( "" )

    dbalias = "TRIGGERDBMC"
    query = "select distinct HPS.HPS_NAME from ATLAS_CONF_TRIGGER_RUN2_MC.HLT_PRESCALE_SET HPS where HPS.HPS_ID = :psk"
    bindvars = { "psk": 260 }

    res = testBindVarResolution(query, bindvars) # query resolution for c++ frontier client
    res = max(res, testQuery(query, bindvars)) # pure python frontier query
    # report the combined status (0 = all tests passed) as the exit code
    sys.exit(res)