import logging
import math
import random
import sqlite3
import time

# Module-level logger; ReputationTracker falls back to it when no ``log``
# argument is passed to its constructor.
_default_log = logging.getLogger(__name__)

def _solve(f, eps, x0, x1):
    u0 = f(x0)
    u1 = f(x1)
    while abs(u1) > eps:
        if u0 * u1 > 0.0:
            raise RuntimeError('Cannot solve, same sign.')
        x = 0.5 * (x0 + x1)
        u = f(x)
        if u * u0 < 0.0:
            x1 = x
            u1 = u
        else:
            x0 = x
            u0 = u
    return x1

def _func(lambda_):
    def f(pS):
        return math.exp(1.0 - lambda_ / pS)
    return f

def _norm(success_dist):
    """Build the normalization residual used to solve for ``lambda_``.

    The returned function maps ``lambda_`` to the sum of
    ``_func(lambda_)(pS)`` over the values of *success_dist*, minus one, so
    its root is the ``lambda_`` at which those weights sum to one.
    """
    def residual(lambda_):
        weight = _func(lambda_)
        total = sum(map(weight, success_dist.values()))
        return total - 1.0
    return residual

def choice_dist_of_success_dist(success_dist):
    """Turn per-choice success probabilities into choice probabilities.

    Entries whose success probability is not positive are dropped.  Each
    remaining choice gets weight ``exp(1 - lambda0 / pS)``, where ``lambda0``
    is solved on [0, 20] (tolerance 1e-6) so that the weights sum to one.
    """
    positive = {term: p_success
                for term, p_success in success_dist.items()
                if p_success > 0.0}
    lambda0 = _solve(_norm(positive), 1e-6, 0.0, 20.0)
    weight = _func(lambda0)
    return {term: weight(p_success) for term, p_success in positive.items()}

_create_sql = """\
CREATE TABLE arcnagios_reputation (
  dist_name text NOT NULL,
  choice_name text NOT NULL,
  update_time double precision NOT NULL,
  recent_count double precision NOT NULL,
  recent_success double precision NOT NULL,
  PRIMARY KEY (dist_name, choice_name)
)"""

_past_dist_names_sql = """\
SELECT DISTINCT dist_name FROM arcnagios_reputation
"""

_past_choice_names_sql = """\
SELECT choice_name FROM arcnagios_reputation WHERE dist_name = ?
"""

_fetch_sql = """\
SELECT choice_name, recent_count, recent_success
FROM arcnagios_reputation WHERE dist_name = ?
"""

_submit_select_sql = """\
SELECT update_time, recent_count, recent_success
FROM arcnagios_reputation WHERE dist_name = ? AND choice_name = ?
"""
_submit_insert_sql = """\
INSERT INTO arcnagios_reputation \
    (update_time, recent_count, recent_success, dist_name, choice_name)
VALUES (?, ?, ?, ?, ?)
"""
_submit_update_sql = """\
UPDATE arcnagios_reputation
SET update_time = ?, recent_count = ?, recent_success = ?
WHERE dist_name = ? AND choice_name = ?
"""

class ReputationTracker:
    """Track per-choice success counts in SQLite and pick choices from them.

    Samples are grouped into named distributions (``dist_name``).  Within a
    distribution each choice (``choice_name``) accumulates exponentially
    decayed counts of attempts and successes (see :meth:`submit`).  Choices
    are drawn at random with probabilities derived from their smoothed
    success rates (see :meth:`choice_dist`).

    Configuration is read from the ``[reputation]`` section (``busy_timeout``,
    ``sample_lifetime``) and from per-distribution
    ``[reputation_dist:<dist_name>]`` sections (``sample_lifetime``).
    """

    def __init__(self, config, db_path, log=_default_log):
        """Create a tracker; the database is opened lazily.

        :param config: object with ``has_section``, ``has_option`` and
            ``getfloat`` methods (presumably a ``configparser`` instance).
        :param db_path: path passed to :func:`sqlite3.connect`.
        :param log: logger; stored but not used by this class itself.
        """
        self._log = log
        self._config = config
        self._db_path = db_path
        self._db = None     # sqlite3 connection once connect() has run
        self._choices = {}  # dist_name -> most recent result of choose()

    def _config_section_float(self, section_name, var, default):
        """Return option *var* of *section_name* as a float, or *default*
        if the section or the option is missing."""
        if self._config.has_section(section_name) \
                and self._config.has_option(section_name, var):
            return self._config.getfloat(section_name, var)
        return default

    def _config_float(self, var, default=None):
        """Read float option *var* from the ``[reputation]`` section."""
        return self._config_section_float('reputation', var, default)

    def _config_dist_float(self, dist_name, var, default=None):
        """Read float option *var* from ``[reputation_dist:<dist_name>]``."""
        return self._config_section_float(
                'reputation_dist:' + dist_name, var, default)

    @property
    def _busy_timeout(self):
        """Timeout in seconds passed to sqlite3.connect (default 10)."""
        return self._config_float('busy_timeout', 10.0)

    @property
    def _default_sample_lifetime(self):
        """Decay time constant of samples in seconds (default 2 days)."""
        return self._config_float('sample_lifetime', 172800.0)

    def connect(self):
        """Open the database unless already open, creating the table if
        needed.  On failure no connection is kept and the error propagates.
        """
        if self._db is not None:
            return
        db = sqlite3.connect(self._db_path, timeout=self._busy_timeout)
        try:
            self._ensure_table(db)
        except Exception:
            db.close()
            raise
        self._db = db

    @staticmethod
    def _ensure_table(db):
        """Create the reputation table on *db* unless it already exists."""
        try:
            db.execute(_create_sql)
        except sqlite3.OperationalError:
            # Usually "table already exists", possibly created concurrently by
            # another process.  Any other failure must not be hidden, so
            # re-raise unless the table is actually present.
            row = db.execute(
                    "SELECT 1 FROM sqlite_master"
                    " WHERE type = 'table' AND name = 'arcnagios_reputation'"
                    ).fetchone()
            if row is None:
                raise

    def past_dist_names(self):
        """Return the names of all distributions that have stored samples."""
        self.connect()
        return [name for (name,) in self._db.execute(_past_dist_names_sql)]

    def past_choice_names(self, dist_name):
        """Return the names of all choices stored for *dist_name*."""
        self.connect()
        return [name for (name,)
                in self._db.execute(_past_choice_names_sql, (dist_name,))]

    def disconnect(self):
        """Close the database connection if it is open."""
        if self._db is not None:
            self._db.close()
            self._db = None

    def success_dist(self, dist_name):
        """Return ``{choice_name: smoothed success rate}`` for *dist_name*.

        The rate is ``(recent_success + 0.25) / (recent_count + 0.5)``; the
        smoothing keeps choices with few samples away from 0 and 1.
        """
        self.connect()
        cur = self._db.execute(_fetch_sql, (dist_name,))
        return {k: (nS + 0.25) / (n + 0.5) for (k, n, nS) in cur}

    def choice_dist(self, dist_name, choice_names, success_dist=None):
        """Return choice probabilities over *choice_names*.

        Choices without a recorded success rate get the mean rate of those
        that have one, or 0.5 if none do.  If *success_dist* is None it is
        loaded with :meth:`success_dist`.
        """
        if success_dist is None:
            success_dist = self.success_dist(dist_name)
        if success_dist:
            avg_success = sum(success_dist.values()) / len(success_dist)
        else:
            avg_success = 0.5
        restricted_success_dist = \
            {k: success_dist.get(k, avg_success) for k in choice_names}
        return choice_dist_of_success_dist(restricted_success_dist)

    def submit(self, dist_name, choice_name, is_success):
        """Record one outcome for *choice_name* in *dist_name* and commit.

        Stored counts are first decayed by ``exp(-elapsed / sample_lifetime)``
        (per-distribution lifetime if configured and non-zero, otherwise the
        global default), then the attempt, and the success if *is_success*,
        are added.
        """
        self.connect()
        rows = self._db.execute(
                _submit_select_sql, (dist_name, choice_name)).fetchall()
        t_now = time.time()
        if rows:
            # (dist_name, choice_name) is the primary key.
            assert len(rows) == 1
            t_past, recent_count, recent_success = rows[0]
            sql = _submit_update_sql
        else:
            t_past, recent_count, recent_success = (t_now, 0.0, 0.0)
            sql = _submit_insert_sql

        sample_lifetime = \
            self._config_dist_float(dist_name, 'sample_lifetime') \
            or self._default_sample_lifetime
        # Clamp the elapsed time at zero so a clock stepping backwards cannot
        # inflate old samples (scale > 1).
        elapsed = max(0.0, t_now - t_past)
        scale = math.exp(-elapsed / sample_lifetime)
        recent_count = scale * recent_count + 1.0
        recent_success *= scale
        if is_success:
            recent_success += 1.0
        self._db.execute(
                sql,
                (t_now, recent_count, recent_success, dist_name, choice_name))
        self._db.commit()

    def choose_off_the_record(self, dist_name, choice_names):
        """Draw a choice from :meth:`choice_dist` without remembering it.

        Falls back to the first of *choice_names* if rounding leaves the
        random draw uncovered by the (approximately normalized) weights.
        """
        assert choice_names
        choice_dist = self.choice_dist(dist_name, choice_names)
        p = random.uniform(0.0, 1.0)
        for (choice_name, pS) in choice_dist.items():
            p -= pS
            if p < 0.0:
                return choice_name
        return next(iter(choice_names))

    def choose(self, dist_name, choice_names):
        """Draw a choice and remember it as the choice for *dist_name*."""
        choice_name = self.choose_off_the_record(dist_name, choice_names)
        self._choices[dist_name] = choice_name
        return choice_name

    def choices(self):
        """Return the live ``{dist_name: choice_name}`` dict of choices made
        by :meth:`choose` (not a copy)."""
        return self._choices
