# Copyright (C) 2015 Matthias C. M. Troffaes & Lewis Paton
# matthias.troffaes@durham.ac.uk, l.w.paton@durham.ac.uk
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Logistic model."""

from __future__ import division

import itertools
import numbers
import numpy as np
import nose.tools
import operator
#import pymc
import scipy.optimize

# _dot, _exp and _sum are bound as default arguments to save namespace
# lookups; this speeds up the code slightly
# (and this function is called *MANY* times)
def logistic_pmf(beta, x, _dot=np.dot, _exp=np.exp, _sum=sum):
    """Return the logistic probabilities of all categories.

    The result is exp(dot(beta, x)) normalised so that its entries sum
    to one.  It is indexed by j - 1 (i.e. the first category sits at
    index 0).

    *x* must start with a one and the last row of *beta* must be zero;
    both conditions are checked with assertions.
    """
    assert x[0] == 1  # leading regressor component is the constant one
    assert (beta[-1] == 0).all()  # reference category: all-zero last row
    # XXX the dot product is the second slowest part of this function
    # XXX np.dot is slightly faster than beta.dot(x)
    weights = _exp(_dot(beta, x))
    # exp(0) for the reference category
    assert weights[-1] == 1
    total = _sum(weights)
    return weights / total

def sample_logistic_pmf(betasample, x):
    """Return the empirical expectation of the logistic function of all
    categories, i.e. the mean of logistic_pmf(beta, x) over every beta
    in *betasample* (for instance an MCMC sample of beta values).
    """
    # XXX todo: should also return an error!
    pmfs = [logistic_pmf(beta, x) for beta in betasample]
    return np.mean(pmfs, axis=0)

def log_conjugate_prior(beta, t):
    """Return the (unnormalised) log density of the multivariate
    conjugate prior at the logistic coefficients *beta*.

    *beta* is the J x (M+1) coefficient matrix.

    *t* is a dictionary of the form x -> (tval, sval), where
    tval[j - 1] is the probability of category j for regressor value x
    and sval is the weight that multiplies the contribution of x.

    The result is the sum over all x of
    sval * (tval . (beta x) - log(sum(exp(beta x)))).
    An empty *t* gives 0.
    """
    # items() rather than iteritems() so that this runs on both
    # Python 2 and Python 3
    betax_tval_sval = (
        (beta.dot(x), tval, sval)
        for x, (tval, sval) in t.items())
    return sum(
        sval * (tval.dot(betax) - np.log(sum(np.exp(betax))))
        for betax, tval, sval in betax_tval_sval)

def get_num_regressors_and_categories_from_t(t):
    """Return (number of regressors, number of categories) for *t*.

    *t* is a dictionary of the form x -> (tval, sval).  Only one entry
    of the dictionary is inspected: its x must start with a one and its
    tval must sum to one (up to rounding).  The number of regressors is
    len(x) - 1 and the number of categories is len(tval).

    Raises ValueError if *t* is empty.
    """
    # items() rather than iteritems() so that this runs on both
    # Python 2 and Python 3
    first = next(iter(t.items()), None)
    if first is None:
        raise ValueError("t must not be empty")
    x, (tval, sval) = first
    assert x[0] == 1
    np.testing.assert_almost_equal(sum(tval), 1)
    return len(x) - 1, len(tval)

def get_num_regressors_and_categories_from_beta(beta):
    """Return (number of regressors, number of categories) for *beta*:
    one less than the number of columns, and the number of rows.
    """
    num_rows, num_columns = beta.shape[0], beta.shape[1]
    return num_columns - 1, num_rows

def conjugate_prior_t_logistic(beta, t):
    """Calculate all intermediate values we need to calculate the
    final value of the conjugate prior mode function.

    For every entry x -> (tval, sval) of the dictionary *t*, yield the
    tuple (x, tval, sval, logistic_pmf(beta, x)).
    """
    # items() rather than iteritems() so that this runs on both
    # Python 2 and Python 3
    for x, (tval, sval) in t.items():
        yield x, tval, sval, logistic_pmf(beta, x)

def conjugate_prior_mode_func_from_t_logistic(t_logistic):
    """Calculate mode function from the *t_logistic* values.

    *t_logistic* is a list of (x, tval, sval, pival) tuples.  It must
    be a list, not just a generator, because we need to iterate over
    it many times.

    For every component index m of x (outer loop) and every category
    index j except the last one (inner loop), yield
    sum(sval * x[m] * (pival[j] - tval[j])) taken over all entries.
    The number of components and categories is read from the first
    entry of the list.

    Raises TypeError (on first iteration, since this is a generator)
    if *t_logistic* is not a list.
    """
    if not isinstance(t_logistic, list):
        raise TypeError("t_logistic must be a list")
    xx, ttval, ssval, ppival = t_logistic[0]
    # range rather than xrange so that this runs on both Python 2 and
    # Python 3; the ranges involved are small
    for m in range(len(xx)):
        for j in range(len(ttval) - 1):
            yield sum(sval * x[m] * (pival[j] - tval[j])
                      for x, tval, sval, pival in t_logistic)

def conjugate_prior_mode_func(beta, t):
    """Calculate mode function from *beta* and *t*."""
    # materialise the intermediate values: they are iterated many times
    t_logistic = list(conjugate_prior_t_logistic(beta, t))
    return conjugate_prior_mode_func_from_t_logistic(t_logistic)

def rfbeta_to_beta(rfbeta, num_regressors, num_categories):
    """Turn the flat array *rfbeta* into a full coefficient matrix.

    *rfbeta* is reshaped into num_categories - 1 rows of
    num_regressors + 1 entries each, and a final row of zeros is
    appended.
    """
    width = num_regressors + 1
    free_rows = rfbeta.reshape((num_categories - 1, width))
    zero_row = np.zeros((1, width))
    return np.vstack((free_rows, zero_row))

def beta_to_rfbeta(beta):
    """Return all rows of *beta* except the last one, flattened into a
    one-dimensional array.  The last row of *beta* must be zero.
    """
    # called only for its shape access; the returned sizes are not needed
    get_num_regressors_and_categories_from_beta(beta)
    last_row = beta[-1]
    assert (last_row == 0).all()  # last row zero
    free_rows = beta[:-1]
    return free_rows.flatten()

def get_map_estimate(t):
    """Return the coefficient matrix at which the conjugate prior mode
    function of *t* vanishes, found with scipy.optimize.root.

    *t* is a dictionary of the form
    (1, x1, x2, ..., xM) -> ((tval[0], ..., tval[J-1]), sval)

    Raises ValueError if the root finder does not report success.
    """
    M, J = get_num_regressors_and_categories_from_t(t)

    def mode_equations(rfbeta):
        # the solver works on the flat array of free coefficients
        beta = rfbeta_to_beta(rfbeta, M, J)
        return list(conjugate_prior_mode_func(beta, t))

    # start the search from all-zero coefficients
    start = [0] * ((J - 1) * (M + 1))
    solution = scipy.optimize.root(mode_equations, start)
    if solution.success:
        return rfbeta_to_beta(solution.x, M, J)
    raise ValueError("cannot find map estimate (%s)" % solution.message)

def get_mcmc_sample(t, num=10000):
    """Return an MCMC sample of coefficient matrices drawn from the
    conjugate prior with parameter *t*.

    *t* is a dictionary of the form x -> (tval, sval).  *num* is
    passed to the sampler as its number of iterations.  The chain is
    started at the MAP estimate of *t*.

    The result is a list with one coefficient matrix (last row zero)
    per entry of the sampled trace.
    """
    # Imported here because the module-level import is commented out:
    # without this the function raises NameError, and a local import
    # keeps the rest of the module usable when pymc is not installed.
    # NOTE(review): stochastic_from_dist and MCMC look like the PyMC 2
    # API -- confirm the installed version.
    import pymc
    M, J = get_num_regressors_and_categories_from_t(t)
    ConjugatePrior = pymc.stochastic_from_dist(
        'ConjugatePrior',
        logp=lambda rfbeta: log_conjugate_prior(
            rfbeta_to_beta(rfbeta, M, J), t),
        mv=True)
    beta_map = get_map_estimate(t)
    initial_value = beta_to_rfbeta(beta_map)
    rfbetavar = ConjugatePrior('rfbeta', value=initial_value)
    model = pymc.MCMC([rfbetavar])
    model.sample(iter=num)
    return [rfbeta_to_beta(rfbeta, M, J) for rfbeta in model.trace('rfbeta')]

def nk_dict(xjs):
    """Count the observations per regressor value and category.

    *xjs* is an iterable of (x, j) pairs, where x is a sequence whose
    first element is one and j is the observed category.

    Return a pair (js, nk).  *js* is the sorted list of the distinct
    categories that occur in *xjs*.  *nk* is a dictionary of the form
    tuple(x) -> array (n(x), k_1(x), k_2(x), ...)
    where k_j(x) is the number of observations of the j-th category of
    *js* for X=x, and n(x) is the sum of these over all j.

    Empty input gives ([], {}).
    """
    # TODO use groupby on sorted xjs
    xjs = list(xjs)
    if not xjs:
        # nothing observed: no categories and no counts
        return [], {}
    num_components = len(xjs[0][0])  # number of regressors + 1
    js = sorted(set(j for (x, j) in xjs))
    J = len(js)  # number of categories
    # position of every category in js, to avoid a linear js.index()
    # search for each observation
    cat_of = {j: cat for cat, j in enumerate(js)}
    nk = {}
    for x, j in xjs:
        assert x[0] == 1
        assert len(x) == num_components
        key = tuple(x)  # make immutable
        counts = nk.get(key)
        if counts is None:
            counts = nk[key] = np.zeros(J + 1)
        counts[0] += 1
        counts[cat_of[j] + 1] += 1
    return js, nk

def apply_dicts(func, d1, d2):
    """Combine two dictionaries key by key.

    For every key that occurs in *d1* or *d2*, the result maps that key
    to func(value in d1, value in d2), where a missing value counts as 0.
    """
    all_keys = set(itertools.chain(d1, d2))
    combined = {}
    for key in all_keys:
        combined[key] = func(d1.get(key, 0), d2.get(key, 0))
    return combined

def add_dicts(d1, d2):
    """Return the key-wise sum of *d1* and *d2* (see apply_dicts)."""
    plus = operator.add
    return apply_dicts(plus, d1, d2)

def div_dicts(d1, d2):
    """Return the key-wise true quotient of *d1* by *d2* (see apply_dicts)."""
    divide = operator.truediv
    return apply_dicts(divide, d1, d2)

def get_t_posterior(t_prior, nk):
    """
    Update prior parameters given the data.

    For nk: see nk_dict (the dictionary part of its result).

    For t_prior: dictionary from x to (tval, sval) where tval is an array
    of probabilities.

    This function updates the prior parameters to the posterior parameters.
    s(x) -> s(x) + n(x)
    t_j(x) -> (s(x)*t_j(x) + k_j(x)) / (s(x) + n(x))
    output t_posterior in same form as t_prior

    An x that occurs in only one of *t_prior* and *nk* is treated as
    having zeros in the other one.
    """
    # items() rather than iteritems() so that this runs on both
    # Python 2 and Python 3
    s = {}
    sts = {}
    for x, (tval, sval) in t_prior.items():
        s[x] = sval
        sts[x] = sval * tval
    n = {}
    ks = {}
    for x, k in nk.items():
        n[x] = k[0]  # total number of observations at x
        ks[x] = k[1:]  # observations per category at x

    numer = add_dicts(sts, ks)
    sigma = add_dicts(s, n)
    assert set(numer) == set(sigma)
    tau = div_dicts(numer, sigma)
    return {key: (tau[key], sigma[key]) for key in tau}

def get_t_prior_extremes(xs_prior, sval_prior, num_categories, epsilon=0.01):
    """Yield every prior parameter dictionary that assigns an extreme
    distribution on the categories to each x in *xs_prior*.

    An extreme distribution gives probability *epsilon* to all
    categories but one, and the remaining
    1 - epsilon * (num_categories - 1) to that one category.

    Each yielded dictionary has the form x -> (tval, sval_prior) with
    tval a numpy array; there are num_categories ** len(xs_prior) of
    them, in the order produced by itertools.product.
    """
    # range/zip rather than xrange/itertools.izip so that this runs on
    # both Python 2 and Python 3; the sequences involved are short
    # tvals is the list of extreme distributions on categories
    tvals = [[epsilon if j != jstar else (1 - epsilon * (num_categories - 1))
              for j in range(num_categories)]
             for jstar in range(num_categories)]
    # tvals_product picks a member of tvals for each x in xs_prior
    for tvals_product in itertools.product(tvals, repeat=len(xs_prior)):
        yield {
            x: (np.array(tval), sval_prior)
            for x, tval in zip(xs_prior, tvals_product)
            }

def get_t_posterior_extremes(xs_prior, sval_prior, nk):
    """Return a generator over the posterior parameters obtained by
    updating each extreme prior from get_t_prior_extremes with the
    data *nk* (see get_t_posterior).
    """
    # take the first value (n(x), k_1(x), k_2(x), ...) from the nk dictionary
    # length of this value is number of categories + 1
    # (values() rather than itervalues() so that this runs on both
    # Python 2 and Python 3)
    num_categories = len(next(iter(nk.values()))) - 1
    return (get_t_posterior(t_prior, nk)
            for t_prior in get_t_prior_extremes(
                xs_prior, sval_prior, num_categories))

# tests

def a(*args):
    """Shorthand used by the tests: numpy array of the given arguments."""
    elements = list(args)
    return np.array(elements)

test_log_conjugate_prior_data = [
    # each entry is (beta, t, expected log_conjugate_prior(beta, t))
    # where t is a dictionary of the form x -> (tval, sval)
    (a([1.2, 3.4], [0, 0]), # beta
     {(1, -2): (a(0.2, 0.8), 0.5), # x1 -> (tval1, sval1)
      (1, 3): (a(0.6, 0.4), 0.5), # x2 -> (tval2, sval2)
      },
     -2.8418511194245597), # result
    (a([-2.1, 1.9], [0, 0]),
     {(1, 5): (a(0.01, 0.99), 0.83),
      (1, 6): (a(0.93, 0.07), 0.83),
      },
     -6.6214930634420766),
    ]

def test_log_conjugate_prior():
    # Test generator: yields one check per entry of
    # test_log_conjugate_prior_data.
    for beta, t, result in test_log_conjugate_prior_data:
        # Bind the loop variables as default arguments: a plain closure
        # would look them up when the check is *run*, so a runner that
        # collects all yielded checks first would test the last entry
        # over and over.
        yield lambda beta=beta, t=t, result=result: (
            nose.tools.assert_almost_equal(
                log_conjugate_prior(beta, t), result))

def test_t_prior_vacuous_log_conjugate_prior():
    # Smoke test: evaluating the log prior for every extreme prior must
    # not raise; the values themselves are not checked.
    xs = [(1, -2), (1, 3)]
    sval = 0.5
    beta = a([1.2, 3.4], [0, 0])
    extremes = get_t_prior_extremes(xs, sval, 2)
    for t_prior in extremes:
        log_conjugate_prior(beta, t_prior)

test_logistic_data = [
    # each entry is (beta, x, expected probability of the first category);
    # the reference values come from R's plogis where noted
    ([[0.1, -2], [0, 0]], (1, 17), 1.8941617547848785e-15),
    #  R: plogis(-2, 1.033/0.97, 1/0.97)
    ([[-1.033, 0.97], [0, 0]], (1, -2), 0.04866065641364303),
    #  R: plogis(1.2, -0.64/1.3, 1/1.3)
    ([[0.64, 1.3], [0, 0]], (1, 1.2), 0.9002495),
    ]

def test_logistic():
    # Test generator: yields one check per entry of test_logistic_data,
    # comparing logistic_pmf against (result, 1 - result).
    for beta, x, result in test_logistic_data:
        beta = np.array(beta)
        x = np.array(x)
        # Bind the loop variables as default arguments: a plain closure
        # would look them up when the check is *run*, so a runner that
        # collects all yielded checks first would test the last entry
        # over and over.
        yield lambda beta=beta, x=x, result=result: (
            np.testing.assert_allclose(
                logistic_pmf(beta, x), (result, 1 - result), atol=1e-5))

def test_add_dicts():
    # keys present in only one dictionary count as 0 in the other
    left = {1: 5, 2: 3, 4: 6}
    right = {1: -2, 3: 1}
    expected = {1: 3, 2: 3, 3: 1, 4: 6}
    nose.tools.assert_equal(add_dicts(left, right), expected)

def test_nk_dict():
    # five observations of (x, category)
    observations = [
        [[1, 2], 2],
        [[1, 4], 3],
        [[1, 2], 4],
        [[1, 1], 9],
        [[1, 2], 2],
        ]
    expected_js = [2, 3, 4, 9]
    # x -> [n, count of 2, count of 3, count of 4, count of 9]
    expected_nk = {
        (1, 2): [3, 2, 0, 1, 0],
        (1, 4): [1, 0, 1, 0, 0],
        (1, 1): [1, 0, 0, 0, 1],
        }
    np.testing.assert_equal(
        nk_dict(observations), (expected_js, expected_nk))

def test_get_t_posterior():
    # (1, 0) occurs in prior and data, (1, 2) only in the prior and
    # (1, 1) only in the data
    observations = [
        [[1, 0], 1],
        [[1, 0], 1],
        [[1, 1], 2],
        [[1, 1], 1],
        ]
    js, nk = nk_dict(observations)
    t = {
        (1, 0): (a(0.8, 0.2), 2.0),
        (1, 2): (a(0.125, 0.875), 0.75),
        }
    expected = {
        (1, 2): (a(0.125, 0.875), 0.75),
        (1, 0): (a(0.9, 0.1), 4.0),
        (1, 1): (a(0.5, 0.5), 2),
        }
    np.testing.assert_equal(get_t_posterior(t, nk), expected)

def test_conjugate_prior_mode_func():
    # The mode function is
    #   sum_x x[m] * s(x) * (logistic(j, beta, x) - t(x)[j - 1])
    # With two categories and x = (1, 30) and (1, 100) this gives two
    # values (m = 0 and m = 1, both for the first category):
    #   1 * 2.1 * (p(30) - 0.9) + 1 * 1.2 * (p(100) - 0.1)
    #   30 * 2.1 * (p(30) - 0.9) + 100 * 1.2 * (p(100) - 0.1)
    # where p(v) is the logistic probability of the first category at
    # x = (1, v).
    slope = 0.05
    beta = a([-60 * slope, slope], [0, 0])
    t = {
        (1, 30): (a(0.9, 0.1), 2.1),
        (1, 100): (a(0.1, 0.9), 1.2),
        }
    expected = (-0.569949906433, 48.4884573571)
    actual = list(conjugate_prior_mode_func(beta, t))
    np.testing.assert_allclose(actual, expected)

def test_get_t_prior_extremes():
    # With epsilon=0 every extreme distribution is a unit vector, so for
    # three x values and three categories we expect all 3 ** 3 = 27
    # assignments of a unit vector to each x.  The table below lists them
    # in the order of itertools.product over xs: the category of the last
    # x, (1, 25, 0), changes fastest and that of the first x, (1, 10, 20),
    # changes slowest.
    xs = [(1, 10, 20), (1, 15, 5), (1, 25, 0)]
    sval = 3
    num_categories = 3
    np.testing.assert_equal(
        list(get_t_prior_extremes(xs, sval, num_categories, epsilon=0)),
        [
            {(1, 25, 0): ((1, 0, 0), 3), (1, 15, 5): ((1, 0, 0), 3), (1, 10, 20): ((1, 0, 0), 3)},
            {(1, 25, 0): ((0, 1, 0), 3), (1, 15, 5): ((1, 0, 0), 3), (1, 10, 20): ((1, 0, 0), 3)},
            {(1, 25, 0): ((0, 0, 1), 3), (1, 15, 5): ((1, 0, 0), 3), (1, 10, 20): ((1, 0, 0), 3)},
            {(1, 25, 0): ((1, 0, 0), 3), (1, 15, 5): ((0, 1, 0), 3), (1, 10, 20): ((1, 0, 0), 3)},
            {(1, 25, 0): ((0, 1, 0), 3), (1, 15, 5): ((0, 1, 0), 3), (1, 10, 20): ((1, 0, 0), 3)},
            {(1, 25, 0): ((0, 0, 1), 3), (1, 15, 5): ((0, 1, 0), 3), (1, 10, 20): ((1, 0, 0), 3)},
            {(1, 25, 0): ((1, 0, 0), 3), (1, 15, 5): ((0, 0, 1), 3), (1, 10, 20): ((1, 0, 0), 3)},
            {(1, 25, 0): ((0, 1, 0), 3), (1, 15, 5): ((0, 0, 1), 3), (1, 10, 20): ((1, 0, 0), 3)},
            {(1, 25, 0): ((0, 0, 1), 3), (1, 15, 5): ((0, 0, 1), 3), (1, 10, 20): ((1, 0, 0), 3)},
            {(1, 25, 0): ((1, 0, 0), 3), (1, 15, 5): ((1, 0, 0), 3), (1, 10, 20): ((0, 1, 0), 3)},
            {(1, 25, 0): ((0, 1, 0), 3), (1, 15, 5): ((1, 0, 0), 3), (1, 10, 20): ((0, 1, 0), 3)},
            {(1, 25, 0): ((0, 0, 1), 3), (1, 15, 5): ((1, 0, 0), 3), (1, 10, 20): ((0, 1, 0), 3)},
            {(1, 25, 0): ((1, 0, 0), 3), (1, 15, 5): ((0, 1, 0), 3), (1, 10, 20): ((0, 1, 0), 3)},
            {(1, 25, 0): ((0, 1, 0), 3), (1, 15, 5): ((0, 1, 0), 3), (1, 10, 20): ((0, 1, 0), 3)},
            {(1, 25, 0): ((0, 0, 1), 3), (1, 15, 5): ((0, 1, 0), 3), (1, 10, 20): ((0, 1, 0), 3)},
            {(1, 25, 0): ((1, 0, 0), 3), (1, 15, 5): ((0, 0, 1), 3), (1, 10, 20): ((0, 1, 0), 3)},
            {(1, 25, 0): ((0, 1, 0), 3), (1, 15, 5): ((0, 0, 1), 3), (1, 10, 20): ((0, 1, 0), 3)},
            {(1, 25, 0): ((0, 0, 1), 3), (1, 15, 5): ((0, 0, 1), 3), (1, 10, 20): ((0, 1, 0), 3)},
            {(1, 25, 0): ((1, 0, 0), 3), (1, 15, 5): ((1, 0, 0), 3), (1, 10, 20): ((0, 0, 1), 3)},
            {(1, 25, 0): ((0, 1, 0), 3), (1, 15, 5): ((1, 0, 0), 3), (1, 10, 20): ((0, 0, 1), 3)},
            {(1, 25, 0): ((0, 0, 1), 3), (1, 15, 5): ((1, 0, 0), 3), (1, 10, 20): ((0, 0, 1), 3)},
            {(1, 25, 0): ((1, 0, 0), 3), (1, 15, 5): ((0, 1, 0), 3), (1, 10, 20): ((0, 0, 1), 3)},
            {(1, 25, 0): ((0, 1, 0), 3), (1, 15, 5): ((0, 1, 0), 3), (1, 10, 20): ((0, 0, 1), 3)},
            {(1, 25, 0): ((0, 0, 1), 3), (1, 15, 5): ((0, 1, 0), 3), (1, 10, 20): ((0, 0, 1), 3)},
            {(1, 25, 0): ((1, 0, 0), 3), (1, 15, 5): ((0, 0, 1), 3), (1, 10, 20): ((0, 0, 1), 3)},
            {(1, 25, 0): ((0, 1, 0), 3), (1, 15, 5): ((0, 0, 1), 3), (1, 10, 20): ((0, 0, 1), 3)},
            {(1, 25, 0): ((0, 0, 1), 3), (1, 15, 5): ((0, 0, 1), 3), (1, 10, 20): ((0, 0, 1), 3)},
        ])

# Fixtures for test_beta_to_rfbeta: each entry pairs a coefficient matrix
# *beta* (last row zero) with its flattened form *rfbeta* (all rows but
# the last, row after row) and the sizes that belong to it.
data_beta_to_rfbeta = [
    dict(
        num_categories=2,
        num_regressors=1,
        beta=a([1.2, 3.4], [0, 0]),
        rfbeta=a(1.2, 3.4),
        ),
    dict(
        num_categories=3,
        num_regressors=3,
        beta=a([1.2, 3.4, 5.6, 7.8], [9.8, 7.6, 5.4, 3.2], [0, 0, 0, 0]),
        rfbeta=a(1.2, 3.4, 5.6, 7.8, 9.8, 7.6, 5.4, 3.2),
        ),
    ]

def test_beta_to_rfbeta():
    # check the size helper and both conversion directions on every
    # fixture
    for case in data_beta_to_rfbeta:
        num_categories = case["num_categories"]
        num_regressors = case["num_regressors"]
        beta, rfbeta = case["beta"], case["rfbeta"]
        sizes = get_num_regressors_and_categories_from_beta(beta)
        nose.tools.assert_equal(sizes, (num_regressors, num_categories))
        np.testing.assert_equal(beta_to_rfbeta(beta), rfbeta)
        rebuilt = rfbeta_to_beta(rfbeta, num_regressors, num_categories)
        np.testing.assert_equal(rebuilt, beta)

def test_get_lower_expectation():
    # Lower/upper expectations and pmfs over a set of two beta values,
    # evaluated at x = (1, 0), where the probability of the first
    # category is exp(b) / (exp(b) + 1) with b the intercept.
    from expectation import (
        get_lower_expectation, get_upper_expectation,
        get_lower_pmf, get_upper_pmf,
        )
    x = (1, 0)
    betas = [
        a([1.2, 3.4], [0, 0]),
        a([1.4, 2.6], [0, 0])
        ]
    v1 = np.exp(1.2) / (np.exp(1.2) + 1)  # 0.768524783499
    v2 = np.exp(1.4) / (np.exp(1.4) + 1)  # 0.802183888559

    # (gamble, expected lower expectation)
    lower_cases = [
        (a(1, 0), v1),
        (a(-1, 0), -v2),
        (a(1, 2), 2 - v2),
        (a(0, -1), v1 - 1),
        ]
    for gamble, expected in lower_cases:
        np.testing.assert_approx_equal(
            get_lower_expectation(gamble, logistic_pmf, betas, x), expected)

    # (gamble, expected upper expectation)
    upper_cases = [
        (a(1, 0), v2),
        (a(-1, 0), -v1),
        (a(1, 2), 2 - v1),
        (a(0, -1), v2 - 1),
        ]
    for gamble, expected in upper_cases:
        np.testing.assert_approx_equal(
            get_upper_expectation(gamble, logistic_pmf, betas, x), expected)

    np.testing.assert_allclose(
        get_lower_pmf(logistic_pmf, betas, x), (v1, 1 - v2))
    np.testing.assert_allclose(
        get_upper_pmf(logistic_pmf, betas, x), (v2, 1 - v1))

def test_get_sample_expectation():
    # The sample expectation at x = (1, 0) is the average over the two
    # beta values of the sample.
    from expectation import get_expectation
    x = (1, 0)
    betasample = [
        a([1.2, 3.4], [0, 0]),
        a([1.4, 2.6], [0, 0])
        ]
    v1 = np.exp(1.2) / (np.exp(1.2) + 1)  # 0.768524783499
    v2 = np.exp(1.4) / (np.exp(1.4) + 1)  # 0.802183888559

    # (gamble, expected expectation)
    cases = [
        (a(1, 0), 0.5 * (v1 + v2)),
        (a(0, 1), 1 - 0.5 * (v1 + v2)),
        ]
    for gamble, expected in cases:
        np.testing.assert_approx_equal(
            get_expectation(gamble, sample_logistic_pmf, betasample, x),
            expected)

def test_get_sample_lower_expectation():
    # Lower/upper expectations and pmfs over two beta samples at
    # x = (1, 0): each sample contributes the average over its two beta
    # values, and the bounds are the min/max over the two samples.
    from expectation import (
        get_lower_expectation, get_upper_expectation,
        get_lower_pmf, get_upper_pmf,
        )
    x = (1, 0)
    betasamples = [[
        a([1.2, 3.4], [0, 0]),
        a([1.4, 2.6], [0, 0]),
        ], [
        a([0.2, 0.4], [0, 0]),
        a([0.4, 0.6], [0, 0]),
        ]]

    v1 = np.exp(1.2) / (np.exp(1.2) + 1)
    v2 = np.exp(1.4) / (np.exp(1.4) + 1)
    w1 = np.exp(0.2) / (np.exp(0.2) + 1)
    w2 = np.exp(0.4) / (np.exp(0.4) + 1)

    # bounds on the probability of the first and of the second category
    lower_first = min(0.5 * (v1 + v2), 0.5 * (w1 + w2))
    lower_second = min(1 - 0.5 * (v1 + v2), 1 - 0.5 * (w1 + w2))
    upper_first = max(0.5 * (v1 + v2), 0.5 * (w1 + w2))
    upper_second = max(1 - 0.5 * (v1 + v2), 1 - 0.5 * (w1 + w2))

    np.testing.assert_approx_equal(
        get_lower_expectation((1, 0), sample_logistic_pmf, betasamples, x),
        lower_first)
    np.testing.assert_approx_equal(
        get_lower_expectation((0, 1), sample_logistic_pmf, betasamples, x),
        lower_second)

    np.testing.assert_approx_equal(
        get_upper_expectation((1, 0), sample_logistic_pmf, betasamples, x),
        upper_first)
    np.testing.assert_approx_equal(
        get_upper_expectation((0, 1), sample_logistic_pmf, betasamples, x),
        upper_second)

    np.testing.assert_allclose(
        get_lower_pmf(sample_logistic_pmf, betasamples, x),
        (lower_first, lower_second))
    np.testing.assert_allclose(
        get_upper_pmf(sample_logistic_pmf, betasamples, x),
        (upper_first, upper_second))
