# Copyright (C) 2015 Matthias C. M. Troffaes & Lewis Paton
# matthias.troffaes@durham.ac.uk, l.w.paton@durham.ac.uk
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Simple multinomial model."""

import numpy as np
import logistic_prior


def nk_list(xjs):
    """Aggregate the per-x counts from logistic_prior.nk_dict into one total.

    Returns a pair ``(js, total)`` where ``js`` is passed through unchanged
    from ``logistic_prior.nk_dict(xjs)`` and ``total`` is the sum of all
    values of the mapping that ``nk_dict`` returns.
    """
    js, nk = logistic_prior.nk_dict(xjs)
    # values of nk are vectors (n(x), k_1(x), k_2(x), ...)
    # NOTE(review): presumably numpy arrays, so that sum() adds them
    # elementwise -- confirm against logistic_prior.nk_dict.
    # dict.values() is used instead of itervalues() because the latter
    # does not exist on Python 3; values() works on both major versions.
    return js, sum(nk.values())


def multinom_pmf(theta, x, _sum=sum):
    """Return the multinomial probability mass function over all categories.

    The model has no covariate dependence: ``x`` is deliberately ignored
    and ``theta`` is handed back as is. ``_sum`` is accepted but not used.

    Entry ``j - 1`` of the result belongs to category ``j`` (so the first
    category sits at index 0).
    """
    pmf = theta
    return pmf


def get_t_posterior(t_prior, nk):
    """
    Update prior parameters given the data.

    For nk: a sequence (n, k_1, k_2, ...) whose first entry is the total
    count and whose remaining entries are the per-category counts.

    For t_prior: a tuple (tval, sval) where tval is an array (or any
    sequence) of probabilities, one per category.

    This function updates the prior parameters to the posterior parameters.
    s -> s + n
    t_j -> (s*t_j + k_j) / (s + n)
    output t_posterior in same form as t_prior
    """
    tval, sval = t_prior
    # Coerce to a float array: with a plain list, ``sval * tval`` would
    # repeat the list instead of scaling it, and integer input would be
    # floor-divided on Python 2.
    tval = np.asarray(tval, dtype=float)
    n = nk[0]
    ks = np.asarray(nk[1:])
    s_posterior = sval + n
    return ((sval * tval + ks) / s_posterior, s_posterior)


def get_t_prior_extremes(sval_prior, num_categories):
    """Yield the extreme prior parameters, one per category.

    For each category index jstar in 0 .. num_categories - 1 this yields
    a tuple ``(tval, sval_prior)`` where ``tval`` is an array of length
    ``num_categories`` holding 0.01 everywhere except at jstar, which
    holds ``1 - 0.01 * (num_categories - 1)`` so that the entries sum
    to one.
    """
    low = 0.01
    high = 1 - 0.01 * (num_categories - 1)
    # range() instead of xrange() so this also runs on Python 3.
    for jstar in range(num_categories):
        tval = np.full(num_categories, low)
        tval[jstar] = high
        yield (tval, sval_prior)


def get_t_posterior_extremes(sval_prior, nk):
    """Return a lazy iterable of posteriors, one per extreme prior.

    ``nk`` is (n(x), k_1(x), k_2(x), ...), so it holds one more entry
    than there are categories. Each prior from ``get_t_prior_extremes``
    is updated with ``nk`` through ``get_t_posterior``.
    """
    category_count = len(nk) - 1
    extreme_priors = get_t_prior_extremes(sval_prior, category_count)
    return (get_t_posterior(prior, nk) for prior in extreme_priors)

# tests

def test_get_t_prior_extremes():
    """Check the extreme priors for three categories with s = 2."""
    expected_tvals = [
        [0.98, 0.01, 0.01],
        [0.01, 0.98, 0.01],
        [0.01, 0.01, 0.98],
    ]
    tvals = [tval for tval, _ in get_t_prior_extremes(2, 3)]
    np.testing.assert_allclose(tvals, expected_tvals, atol=1e-5)
    # the prior strength must be passed through untouched
    svals = [sval for _, sval in get_t_prior_extremes(2, 3)]
    np.testing.assert_allclose(svals, [2, 2, 2], atol=1e-5)


def test_get_t_posterior_extremes():
    """Check the posterior t values for n = 9, k = (4, 5) and s = 2."""
    s_posterior = 9 + 2
    expected = [
        [5.98 / s_posterior, 5.02 / s_posterior],
        [4.02 / s_posterior, 6.98 / s_posterior],
    ]
    actual = [tval for tval, _ in get_t_posterior_extremes(2, [9, 4, 5])]
    np.testing.assert_allclose(actual, expected)
