# Copyright (C) 2015 Matthias C. M. Troffaes & Lewis Paton
# matthias.troffaes@durham.ac.uk, l.w.paton@durham.ac.uk
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

from __future__ import division

import collections
import functools
import itertools
import matplotlib
matplotlib.use('pgf') # must come before importing pyplot
matplotlib.rcParams.update({'font.size': 9, 'text.usetex': True})
ext = '.pgf'
import matplotlib.gridspec
import matplotlib.pyplot as plt
import numpy as np
import operator
import os.path
from pprint import pprint

import crop_data
import dirichlet_prior
import expectation
import logistic_prior
import markov_chain
import smooth


def get_crop_data(region):
    """Load the crop data of one region through crop_data.get_xijs.

    The region name selects the region-specific rain, crop and soil CSV
    files; profit.csv and nitrogen.csv are shared by all regions. Files
    are looked up in "../data/", which is resolved relative to the current
    working directory. Callers in this script unpack every returned item
    as (x, i, j, year).
    """
    data_folder = "../data/"

    def _csv(basename):
        return os.path.join(data_folder, basename)

    return crop_data.get_xijs(
        rain_filename=_csv("rain_%s.csv" % region),
        profit_filename=_csv("profit.csv"),
        nitrogen_filename=_csv("nitrogen.csv"),
        crop_filename=_csv("crop_%s.csv" % region),
        soil_filename=_csv("soil_%s.csv" % region))


def get_xijss(region, parts=1):
    """Return the (x, i, j) tuples of a region, randomly dealt into parts.

    region is a region name, or "all" for the anglia and mease data
    combined. The data is shuffled with np.random (so the split depends on
    numpy's global random state) and dealt round-robin into ``parts``
    lists. This is what cross validation needs: each part can serve as a
    test set for a model trained on the remaining parts.
    """
    if region == "all":
        source_regions = ("anglia", "mease")
    else:
        source_regions = (region,)
    # The mease region has a very low frequency of soil type 1: too little
    # data for training, so those rows are dropped. Remember that
    # i = (soil type, previous crop). The filter is NOT applied for "all".
    drop_soil_type_1 = (region == "mease")
    all_xijs = [
        (x, i, j)
        for source in source_regions
        for x, i, j, year in get_crop_data(source)
        if not (drop_soil_type_1 and i[0] == 1)]
    shuffled = list(range(len(all_xijs)))
    np.random.shuffle(shuffled)
    result = []
    for part in range(parts):
        result.append([all_xijs[index] for index in shuffled[part::parts]])
    # the parts together hold exactly the original tuples
    assert sorted(all_xijs) == sorted(itertools.chain.from_iterable(result))
    return result


def get_i_xjs(xijs):
    """Group a sequence of (x, i, j) tuples by i.

    Yields (i, ((x1, j1), (x2, j2), ...)) pairs in sorted order of i; within
    a group the original order of the tuples is kept (the sort is stable).

    The inner sequence is materialised as a tuple. A lazy view into an
    itertools.groupby group becomes empty as soon as the outer iterator
    advances, so e.g. ``list(get_i_xjs(...))`` would otherwise silently
    return empty groups.
    """
    keyfunc = operator.itemgetter(1)  # index 1 of xij is i
    for ii, xijs_group in itertools.groupby(
            sorted(xijs, key=keyfunc), key=keyfunc):
        yield ii, tuple((x, j) for (x, i, j) in xijs_group)


def get_x_bounds(xijs):
    """Return the componentwise bounds of the regressors x in xijs.

    Row 0 of the returned array holds the minimum and row 1 the maximum of
    each regressor component over all (x, i, j) tuples. This is useful to
    choose sensible values for xs_prior.
    """
    regressors = np.array([x for x, _i, _j in xijs])
    lowest = np.min(regressors, axis=0)
    highest = np.max(regressors, axis=0)
    return np.vstack([lowest, highest])


def get_xs_prior(xijs):
    """Suggest xs_prior values by peeking at the data.

    Returns every corner of the bounding box of the regressors as a list of
    tuples: the cartesian product, over components, of the distinct sorted
    {min, max} values of that component.
    """
    bounds = get_x_bounds(xijs)
    per_component = [sorted(set(column)) for column in bounds.T]
    corners = itertools.product(*per_component)
    return [tuple(corner) for corner in corners]


def get_initial_distribution(region, i0, year0, js):
    """Empirical distribution of crops over js for one i[0] value and year.

    Counts the rows of the crop data with i[0] == i0 and year == year0 by
    their crop i[1], and returns the relative frequencies in the order of
    js as a numpy array. This is useful to get a reasonable initial state.
    """
    counts = collections.Counter(
        i[1] for x, i, j, year in get_crop_data(region)
        if i[0] == i0 and year == year0)
    dist = np.array([counts[crop] for crop in js])
    total = sum(dist)
    # every observed crop must be listed in js, otherwise mass is lost
    assert sum(counts.values()) == total
    return dist / total


def get_initial_lower_expectation(func, dist):
    """Expectation of the gamble func under the precise distribution dist.

    dist is a numpy array of probabilities aligned with func. Because the
    initial distribution is precise, its lower expectation is simply the
    ordinary expectation (the dot product).
    """
    return np.dot(dist, func)


def get_i_posterior_results(xijs, func_of_t):
    """This is the main function: find posterior hyperparameters for every
    state, and apply some function on the resulting hyperparameters.

    xijs must be an indexable sequence of (x, i, j) tuples (it is indexed
    and iterated more than once). Yields (i, js, results) per state i,
    where results lazily maps func_of_t over
    logistic_prior.get_t_posterior_extremes for that state.
    """
    xs_prior = get_xs_prior(xijs)
    # the prior points must live in the same space as the regressors
    assert len(xs_prior[0]) == len(xijs[0][0])
    for i, xjs in get_i_xjs(xijs):
        js, nk = logistic_prior.nk_dict(xjs)
        t_extremes = logistic_prior.get_t_posterior_extremes(
            xs_prior=xs_prior, sval_prior=2, nk=nk)
        yield i, js, itertools.imap(func_of_t, t_extremes)


def get_i_betas(xijs):
    """MAP estimation.

    Returns a zero-argument callable; calling it yields (i, js, results)
    for every state i, where results are logistic_prior.get_map_estimate
    applied to each posterior extreme.
    """
    def _get_i_betas():
        estimator = logistic_prior.get_map_estimate
        return get_i_posterior_results(xijs, estimator)
    return _get_i_betas


def get_i_betasamples(xijs, num=10000):
    """MCMC estimation.

    Returns a zero-argument callable. Calling it yields (i, js, results)
    for every state i, where results lazily apply
    logistic_prior.get_mcmc_sample (drawing ``num`` samples) to each
    posterior extreme.

    The callable itself is returned, not its result, just like
    get_i_betas and get_i_thetas: main() invokes ``get_i_params()``, so
    returning an already-evaluated generator would fail there.
    """
    def _get_i_betasamples():
        return get_i_posterior_results(
            xijs,
            functools.partial(logistic_prior.get_mcmc_sample, num=num))
    return _get_i_betasamples


def get_i_dirichlet_posterior_results(xijs, func_of_t):
    """Multinomial/Dirichlet analogue of get_i_posterior_results.

    For every state i in xijs, yield (i, js, results) where js and nk come
    from dirichlet_prior.nk_list on that state's (x, j) pairs and results
    lazily maps func_of_t over
    dirichlet_prior.get_t_posterior_extremes(sval_prior=2, nk=nk).
    """
    for i, xjs in get_i_xjs(xijs):
        js, nk = dirichlet_prior.nk_list(xjs)
        t_extremes = dirichlet_prior.get_t_posterior_extremes(
            sval_prior=2, nk=nk)
        yield i, js, itertools.imap(func_of_t, t_extremes)


def get_i_thetas(xijs):
    """Least squares estimation of theta parameters of multinomial model
    with a Dirichlet prior.

    Returns a zero-argument callable. For a Dirichlet prior the least
    squares estimate is exactly the t values themselves, so every posterior
    extreme ts is mapped to ts[0].
    """
    def _first_t(ts):
        return ts[0]

    def _get_i_thetas():
        return get_i_dirichlet_posterior_results(xijs, _first_t)
    return _get_i_thetas


def main(tag, region, get_i_params, pmfunc, plots=None):
    """Produce the requested plots for one fitted model.

    get_i_params is a zero-argument callable returning an iterable of
    (i, js, params) triples, one per state i; js are the categories
    (next crops) that params is defined over.
    The state of the Markov chain, i.e. crop type, is i[1].
    Any other states that are not used as regression parameters are in i[0].
    The function pmfunc is called as pmfunc(params, x).

    plots is a set of plot names. Recognised names: 'bar', 'decision',
    'decision_rb', 'probability', 'time', 'regressor', 'density'. When
    None, defaults to {'bar', 'decision', 'probability', 'time',
    'regressor'}. tag and region are passed on to the plotting functions
    (tag is used there as the output filename prefix).
    """
    if plots is None:
        plots = {'bar', 'decision', 'probability', 'time', 'regressor'}
    # store in a tuple so we can iterate multiple times
    i_params = tuple((i, js, tuple(params))
                     for i, js, params in get_i_params())
    # plotting styles and display names, keyed by crop code
    color = {'WH': 'red', 'LE': 'green', 'WH2': 'purple', 'RA': 'orange', 'OT': 'blue'}
    hatch = {'WH': '////', 'WH2': r'\\\\', 'OT': 'xxxx', 'LE': '-', 'RA': '|'}
    label = {'WH': "wheat", 'LE': "legumes", 'OT': "other", 'WH2': "wheat2", 'RA': "rapeseed"}
    linestyle = {'WH': '-', 'WH2': '-', 'LE': '--', 'RA': '-.', 'OT': ':'}
    # display names and line styles keyed by soil type i[0]
    soilname = {1: 'light', 2: 'medium', 3: 'heavy'}
    soillinestyle = {1: ':', 2: '--', 3: '-'}
    if 'bar' in plots:
        main_barplots(tag, i_params, pmfunc, color, hatch, label, linestyle, soilname)
    if 'decision' in plots:
        main_decision_plots(tag, region, i_params, pmfunc, color, hatch, label, linestyle, soilname)
    if 'decision_rb' in plots:
        main_decision_rb(tag, region, i_params, pmfunc, label, soilname)
    if 'probability' in plots:
        main_probability_plots(tag, region, i_params, pmfunc, color, hatch, label, linestyle, soilname)
    # time and regressor plots are made once per distinct i[0] value
    i0s = tuple(sorted(set(i[0] for i, js, params in i_params)))
    for i0 in i0s:
        if 'time' in plots:
            assert region != "all"  # timeplot broken for region="all"
            main_time_plots(tag, region, i0, i_params, pmfunc, color, hatch, label, linestyle, soilname)
        if 'regressor' in plots:
            main_regressor_plots(tag, region, i0, i_params, pmfunc, color, hatch, label, linestyle, soilname)
    # density plots are made once per distinct previous crop i[1]
    i1s = tuple(sorted(set(i[1] for i, js, params in i_params)))
    for i1 in i1s:
        if 'density' in plots:
            main_empirical_density_plots(tag, region, i1, i_params, pmfunc, label, soillinestyle, soilname)


def main_barplots(tag, i_params, pmfunc, color, hatch, label, linestyle, soilname):
    """Transition probability bounds under different scenarios.

    For every (i, js, params) in i_params, draw a 2x2 grid of bar charts,
    one per regressor x = (1, x1, x2) with x1 in (10, 100) and x2 in
    (50, 150). For each next crop j in js, a solid bar reaches the lower
    probability (expectation.get_lower_pmf) and a translucent bar spans
    from the lower to the upper probability (expectation.get_upper_pmf).
    Each figure is saved as "<tag>-barplot-<i[0]>-<i[1]>" with the module
    extension ext and as .png. linestyle is unused.
    """
    x1s = (10, 100)
    x2s = (50, 150)
    for i, js, params in i_params:
        # at the moment, i[0] is just the soil type
        # i[1] is always the Markov chain state (crop type)
        print "barplot; soil = %s; crop = %s" % (soilname[i[0]], label[i[1]])
        gs = matplotlib.gridspec.GridSpec(len(x1s), len(x2s))
        fig = plt.figure(figsize=(3.5, 3.5))
        for k1, x1 in enumerate(x1s):
            for k2, x2 in enumerate(x2s):
                x = (1, x1, x2)
                ax = fig.add_subplot(gs[k1, k2], aspect=3)
                # NOTE(review): the x-range is fixed to 3 categories (bars are
                # centred at .5, 1.5, 2.5, ...); for len(js) > 3 the extra
                # bars fall outside the axes - confirm js never exceeds 3
                ax.axis([0, 3, 0, 1])
                ax.set_xticks([.5 + tick for tick, j in enumerate(js)])
                ax.set_xticklabels([label[j] for j in js])
                ax.set_title("$x=(1,%s,%s)$" % (x1, x2))
                # lower and upper pmf are aligned with js by position
                for cat, (low, upp) in enumerate(
                    itertools.izip(
                        expectation.get_lower_pmf(
                            pmfunc, params, x),
                        expectation.get_upper_pmf(
                            pmfunc, params, x),
                        )):
                    j = js[cat]
                    kwargs = dict(
                        color=color[j],
                        hatch=hatch[j],
                        align="center",
                        linewidth=2,
                        )
                    # solid bar up to the lower bound, then a translucent
                    # bar stacked on top covering the imprecision
                    ax.bar(.5 + cat, low, **kwargs)
                    ax.bar(.5 + cat, upp - low, bottom=low, alpha=0.5, **kwargs)
        plt.tight_layout()  # improve spacing between subplots
        plt.savefig("%s-barplot-%i-%s%s" % (tag, i[0], i[1], ext))
        plt.savefig("%s-barplot-%i-%s%s" % (tag, i[0], i[1], ".png"))
        plt.close(fig)


def get_x_linspace(xijs, axis, num=20):
    """Return (xs, zs): a grid of regressor vectors for plotting.

    zs holds num evenly spaced values from the minimum to the maximum of
    regressor component ``axis`` in xijs. Each element of xs is the
    midpoint of the data's bounds in every component, except component
    ``axis``, which takes the matching value from zs.
    """
    lower, upper = get_x_bounds(xijs)
    midpoint = np.mean([lower, upper], axis=0)
    zs = np.linspace(lower[axis], upper[axis], num=num)

    def _grid_point(z):
        return np.array([
            z if dim == axis else centre_value
            for dim, centre_value in enumerate(midpoint)])

    return [_grid_point(z) for z in zs], zs


def main_empirical_density_plots(
        tag, region, i1, i_params, pmfunc,
        label, soillinestyle, soilname):
    """Plot the kernel-smoothed density of observations from crop i1.

    For every state (i0, i1) present in i_params, evaluate
    smooth.get_density on that state's observed (x, j) pairs along a grid
    over regressor axis 2 (plotted as "nitrogen"), with the other
    regressors held at the midpoint of their data bounds (see
    get_x_linspace). All i0 curves share one figure, which is saved as
    "<tag>-densityplot-<i1>" with the module extension ext and as .png.
    pmfunc is unused.
    """
    i0_params = tuple((i[0], js, params)
                      for i, js, params in i_params if i[1] == i1)
    # NOTE(review): i0s is computed but not used below
    i0s = tuple(i0 for i0, js, params in i0_params)
    xijs = get_xijss(region, parts=1)[0]
    xs, zs = get_x_linspace(xijs, axis=2)
    # presumably per-regressor kernel bandwidths - confirm against smooth
    kern = smooth.get_gaussian_kernel(np.array([1, 20, 20]))
    # observed (x, j) pairs keyed by i[0], restricted to previous crop i1
    i0_xjs = {i[0]: list(xjs) for i, xjs in get_i_xjs(xijs) if i[1] == i1}
    fig = plt.figure(figsize=(2.2, 2.2))
    for i0, js, params in i0_params:
        print "densityplot; soil = %s; crop = %s" % (soilname[i0], label[i1])
        # calculate "data density" of the smoother
        xjs = i0_xjs[i0]
        density = smooth.get_density(xjs, kern)
        empirical_density = np.array([density(x) for x in xs])
        plt.plot(zs, empirical_density, color="black", linewidth=2, linestyle=soillinestyle[i0], label=soilname[i0])
    plt.legend(prop={'size': 6}, loc=0)
    fig.suptitle(
        "from %s" % label[i1],
        x=0.25, y=1.0, horizontalalignment="left")
    plt.xlabel("nitrogen")
    plt.ylabel("observations")
    plt.xlim((min(zs), max(zs)))
    plt.tight_layout()
    fig.savefig("%s-densityplot-%s%s" % (tag, i1, ext))
    fig.savefig("%s-densityplot-%s%s" % (tag, i1, ".png"))
    plt.close(fig)


def main_probability_plots(
        tag, region, i_params, pmfunc,
        color, hatch, label, linestyle, soilname):
    """Plot posterior transition probabilities. This is like barplots
    but instead plots posterior bounds against changes in a single
    variable.

    For every (i, js, params), regressor axis 2 (plotted as "nitrogen") is
    swept over a grid (see get_x_linspace). For each next crop j, a shaded
    band shows the lower/upper posterior probability, and a line shows a
    non-parametric estimate from smooth.get_smoother on the observed data
    of state i. Figures are saved as "<tag>-probplot-<i[0]>-<i[1]>" with
    the module extension ext and as .png.
    """
    xijs = get_xijss(region, parts=1)[0]
    xs, zs = get_x_linspace(xijs, axis=2)
    # presumably per-regressor kernel bandwidths - confirm against smooth
    kern = smooth.get_gaussian_kernel(np.array([1, 20, 20]))
    # observed (x, j) pairs for each state i
    i_xjs = {i: list(xjs) for i, xjs in get_i_xjs(xijs)}
    for i, js, params in i_params:
        print "probplot; soil = %s; crop = %s" % (soilname[i[0]], label[i[1]])
        # calculate non-parametric estimate
        empirical_prob = {}
        xjs = i_xjs[i]
        for j in js:
            smoother = smooth.get_smoother(xjs, kern, j)
            empirical_prob[j] = [smoother(x) for x in xs]
        # calculate robust Bayesian estimate
        lows = {j: [] for j in js}
        upps = {j: [] for j in js}
        for x in xs:
            # lower and upper pmf are aligned with js by position
            for cat, (low, upp) in enumerate(
                itertools.izip(
                    expectation.get_lower_pmf(
                        pmfunc, params, x),
                    expectation.get_upper_pmf(
                        pmfunc, params, x),
                    )):
                j = js[cat]
                lows[j].append(low)
                upps[j].append(upp)
        assert all(len(lows[j]) == len(xs) for j in js)
        assert all(len(upps[j]) == len(xs) for j in js)
        # do the plot
        fig = plt.figure(figsize=(2.2, 2.2))
        for j in js:
            kwargs = dict(
                color=color[j],
                hatch=hatch[j],
                alpha=0.2,
                label=label[j],
                linewidth=2,
                )
            plt.fill_between(zs, lows[j], upps[j], **kwargs)
            # the empirical curve is drawn opaque and unhatched
            del kwargs["alpha"]
            del kwargs["hatch"]
            plt.plot(zs, empirical_prob[j], linestyle[j], **kwargs)
        fig.suptitle(
            "%s soil, from %s" % (soilname[i[0]], label[i[1]]),
            x=0.25, y=1.0, horizontalalignment="left")
        # proxy artists so the legend shows the shaded bands
        ps = [plt.Rectangle(
            (0, 0), 1, 1, alpha=0.2, hatch=hatch[j], color=color[j]) for j in js]
        labels = [label[j] for j in js]
        plt.legend(ps, labels, prop={'size': 6}, loc=0)
        plt.xlabel("nitrogen")
        #plt.ylabel("probability")
        plt.ylim((0,1))
        plt.xlim((min(zs), max(zs)))
        plt.tight_layout()
        fig.savefig("%s-probplot-%i-%s%s" % (tag, i[0], i[1], ext))
        fig.savefig("%s-probplot-%i-%s%s" % (tag, i[0], i[1], ".png"))
        plt.close(fig)


def main_decision_plots(
        tag, region, i_params, pmfunc,
        color, hatch, label, linestyle, soilname):
    """Plot decision plot for posterior expected utility.

    For every state i whose previous crop i[1] is wheat ("WH"), and every
    kappa in a 26-point grid over [0, 0.25], plot the lower and upper
    posterior expectation of the gamble get_func(js, kappa, z) as the
    nitrogen value z (regressor axis 2) varies over a 50-point grid. A
    horizontal line marks the largest lower expectation over that grid.
    Figures are saved as "<tag>-decplot-<i[0]>-<i[1]>-<kappa_index>" with
    the module extension ext and as .png. Only color["LE"] is used; hatch
    and linestyle are accepted for a uniform plotting interface.
    """
    xijs = get_xijss(region, parts=1)[0]
    xs, zs = get_x_linspace(xijs, axis=2, num=50)
    # only from wheat for now
    i_params = [i_param for i_param in i_params if i_param[0][1] == "WH"]
    for i, js, params in i_params:
        for kappa_index, kappa in enumerate(np.linspace(0, 0.25, 26)):
            print(
                "decplot; soil = %s, crop = %s, kappa = %g"
                % (soilname[i[0]], label[i[1]], kappa))
            # calculate utility bounds
            lows = []
            upps = []
            for x, z in zip(xs, zs):
                func = get_func(js, kappa, z)
                lows.append(
                    expectation.get_lower_expectation(func, pmfunc, params, x))
                upps.append(
                    expectation.get_upper_expectation(func, pmfunc, params, x))
            # do the plot
            fig = plt.figure(figsize=(2.2, 2.2))
            kwargs = dict(
                color=color["LE"],
                hatch="||",
                alpha=0.2,
                linewidth=2,
                )
            plt.fill_between(zs, lows, upps, interpolate=True, **kwargs)
            # best lower expected utility achievable over the nitrogen grid
            plt.axhline(np.max(lows), color='black', linewidth=2)
            fig.suptitle(
                # raw string: "\k" is not a valid escape sequence
                r"%s soil, $\kappa=%g$" % (soilname[i[0]], kappa),
                x=0.25, y=1.0, horizontalalignment="left")
            plt.xlabel("nitrogen")
            plt.xlim((min(zs), max(zs)))
            plt.tight_layout()
            fig.savefig("%s-decplot-%i-%s-%i%s" % (tag, i[0], i[1], kappa_index, ext))
            fig.savefig("%s-decplot-%i-%s-%i%s" % (tag, i[0], i[1], kappa_index, ".png"))
            plt.close(fig)


def get_func(js, kappa, z):
    """Return decision gamble as a function of js, kappa, and z.

    The gamble pays 100 for the category "LE" and 0 for every other
    category in js, minus the cost kappa * z in every category.
    """
    payoff = np.array([100 if j == "LE" else 0 for j in js])
    cost = kappa * z
    return payoff - cost


def main_decision_rb(
        tag, region, i_params, pmfunc, label, soilname):
    """Robust bayesian decision making.

    For every state i whose previous crop i[1] is wheat ("WH") and every
    kappa in a 26-point grid over [0, 0.25], each parameter in the
    posterior set picks the nitrogen value z (from a 50-point grid along
    regressor axis 2) that maximises the expectation of
    get_func(js, kappa, z). The range of these optimal z values is
    printed. The grid's min and max are printed first. tag is unused.
    """
    xijs = get_xijss(region, parts=1)[0]
    xs, zs = get_x_linspace(xijs, axis=2, num=50)
    print(min(zs))
    print(max(zs))
    # only from wheat for now
    i_params = [i_param for i_param in i_params if i_param[0][1] == "WH"]
    for i, js, params in i_params:
        for kappa in np.linspace(0, 0.25, 26):
            # max over (expectation, z) pairs; ties in the expectation are
            # broken in favour of the larger z
            opt_zs = [
                max([
                    (expectation.get_expectation(
                        get_func(js, kappa, z), pmfunc, param, x),
                     z)
                    for x, z in zip(xs, zs)])[1]
                for param in params]
            print(
                "robust bayes decision; soil = %s, crop = %s, kappa = %g"
                % (soilname[i[0]], label[i[1]], kappa))
            print(
                "optimal nitrogen range = (%g, %g)"
                % (min(opt_zs), max(opt_zs)))


def func_project(func, js, alljs):
    """Restrict a gamble defined on alljs to the categories in js.

    func[k] is the value of the gamble at alljs[k]; the result is a numpy
    array holding, for each j in js (in that order), the value func takes
    at the first position of j in alljs.
    """
    picked = []
    for j in js:
        picked.append(func[alljs.index(j)])
    return np.array(picked)


def get_years_trace(future_xs):
    """Return (past_years, all_years) for a time trace.

    past_years are the historic years 1994..2004 inclusive; all_years
    extends them with one year per entry of future_xs, starting at 2005.
    Both are plain lists. The explicit list() calls make the concatenation
    work on Python 3 too, where range objects cannot be added.
    """
    timesteps = len(future_xs)
    past_years = list(range(1994, 2005))
    all_years = past_years + list(range(2005, 2005 + timesteps))
    return past_years, all_years


def get_crop_bounds_trace(region, i0, i_params, pmfunc, future_xs):
    """Calculate historic and future crop bounds from scenario.

    Yields (i, low, upp) for each crop i in i1s, the sorted previous crops
    of the states with i[0] == i0 (asserted to equal the sorted set of all
    js). low and upp are lists over the years of get_years_trace:
    - historic years: both equal the empirical crop frequency from
      get_initial_distribution for that year;
    - future years: lower and upper probability of being in crop i after
      each step, using one lower transition operator set per regressor x
      in future_xs, starting from the last historic distribution. The
      upper probability uses conjugacy: 1 - lower probability of the
      indicator of "not crop i".
    """
    i1_params = tuple((i[1], js, params)
                      for i, js, params in i_params if i[0] == i0)
    i1s = tuple(i1 for i1, js, params in i1_params)
    assert i1s == tuple(sorted(set(
        itertools.chain.from_iterable(js for i, js, params in i_params))))
    # get historic distributions
    past_years, all_years = get_years_trace(future_xs)
    past_crop_dists = [
        get_initial_distribution(region, i0=i0, year0=year, js=i1s)
        for year in past_years]

    # factory function, so each operator binds its own js, params and x
    # (avoids late binding of loop variables in a closure)
    def lift_lower_expectation(js, params, x):
        def lowexp(func):
            func2 = func_project(func, js, i1s)
            return expectation.get_lower_expectation(
                func=func2, pmfunc=pmfunc, params=params, extra=x)
        return lowexp

    # one list of operators (ordered like i1s) per future time step;
    # presumably markov_chain indexes the inner list by state position
    lower_transition_operators = [
        [
            # we need lower transition operators that work on
            # functions on i1s, but params is for js, so need
            # func_project to lift the operator to the larger
            # space of gambles on i1s
            lift_lower_expectation(js, params, x)
            for i1, js, params in i1_params
            ]
        for x in future_xs
        ]
    for cat, i in enumerate(i1s):
        # historic crop
        low = [dist[cat] for dist in past_crop_dists]
        upp = [dist[cat] for dist in past_crop_dists]
        # future crop
        indicator = tuple(1 if ii == i else 0 for ii in i1s)
        conjugate = tuple(0 if ii == i else 1 for ii in i1s)
        it1 = markov_chain.iterate_lower_transition_operators(
            lower_transition_operators, indicator)
        it2 = markov_chain.iterate_lower_transition_operators(
            lower_transition_operators, conjugate)
        # first iteration will just not apply any transitions
        # we have this already covered
        it1.next()
        it2.next()
        low += [
            get_initial_lower_expectation(func, past_crop_dists[-1])
            for func in it1]
        # upper probability of crop i = 1 - lower probability of not-i
        upp += [
            1 - get_initial_lower_expectation(func, past_crop_dists[-1])
            for func in it2]
        yield (i, low, upp)


def get_regressor_trace(region, i0, future_xs):
    """Get a trace of regressor values.

    Yields (m, trace) for regressor components m = 1 and m = 2. Each trace
    is one historic value per year followed by the values in future_xs.
    The historic values come from the rows of the crop data with
    i[0] == i0: per year, the first (smallest) x after sorting by
    (year, x) is taken.
    """
    year_xs = sorted(
        (year, x) for x, i, j, year in get_crop_data(region) if i[0] == i0)
    past_xs = [
        # next(...) takes the first element of each year's group
        next(group)[1]
        for _, group in itertools.groupby(year_xs, operator.itemgetter(0))]
    for m in [1, 2]:
        yield (m, [xx[m] for xx in past_xs] + [xx[m] for xx in future_xs])


def main_time_plots(tag, region, i0, i_params, pmfunc,
                    color, hatch, label, linestyle, soilname):
    """Crop traces under different scenarios.

    For soil type i0 and each scenario of future regressor values, plot in
    one column: the historic and predicted lower/upper crop probabilities
    from get_crop_bounds_trace (top row), and the rain and nitrogen
    regressor traces from get_regressor_trace (rows 2 and 3). A vertical
    line marks 2004, the last historic year in get_years_trace. The figure
    is saved as "<tag>-timeplot-<i0>" with the module extension ext and as
    .png. linestyle is unused.
    """
    i1_params = tuple((i[1], js, params)
                      for i, js, params in i_params if i[0] == i0)
    i1s = tuple(i1 for i1, js, params in i1_params)
    assert i1s == tuple(sorted(set(
        itertools.chain.from_iterable(js for i, js, params in i_params))))
    xss_scenario = (
        # first scenario
        [(1, 10, 100),
         (1, 50, 120),
         (1, 20, 150),
         (1, 55, 130),
         (1, 60, 110)
         ],
##        # second scenario
##        [(1, 25, 100),
##         (1, 20, 95),
##         (1, 30, 85),
##         (1, 40, 90),
##         (1, 35, 95)
##         ],
        )
    # range for each of the components of x
    xs_range = (None, (0, 120), (30, 170))
    # 3 rows (crop bounds, rain, nitrogen) x up to 2 scenario columns
    gs = matplotlib.gridspec.GridSpec(3, 2)
    fig = plt.figure(figsize=(5.333, 4))
    regressor_names = [None, "rain", "nitrogen"]

    for scenario, future_xs in enumerate(xss_scenario):
        print "soil = %s, scenario = %s" % (soilname[i0], str(future_xs))
        past_years, all_years = get_years_trace(future_xs)
        # crop bounds
        ax = fig.add_subplot(gs[0, scenario])
        ax.axis([min(all_years), max(all_years), 0, 1])
        #ax.set_title("$x = %s$" % str(future_xs))
        ax.set_xlabel("time")
        ax.set_ylabel("probability")

        for i, low, upp in get_crop_bounds_trace(
            region, i0, i_params, pmfunc, future_xs):
            kwargs = dict(
                color=color[i],
                hatch=hatch[i],
                alpha=0.5,
                label=label[i],
                linewidth=2,
                )
            ax.fill_between(all_years, low, upp, **kwargs)

        # proxy artists so the legend shows the shaded bands
        kwargs = dict(
            alpha=0.5,
            linewidth=2,
            )
        ps = [
            plt.Rectangle(
                (0, 0), 1, 1, hatch=hatch[i], color=color[i], **kwargs)
            for i in i1s]
        labels = [label[i] for i in i1s]
        ax.legend(ps, labels, prop={'size': 6}, loc=2)
        # boundary between historic data and predictions
        plt.axvline(x=2004, color='black')
        # regressor values
        for m, xs in get_regressor_trace(region, i0, future_xs):
            ax = fig.add_subplot(gs[m, scenario])
            ax.axis([min(all_years), max(all_years),
                     xs_range[m][0], xs_range[m][1]])
            ax.set_xlabel("time")
            ax.set_ylabel(regressor_names[m])
            plt.plot(all_years, xs)
            plt.axvline(x=2004,color='black')
    plt.tight_layout()
    plt.savefig("%s-timeplot-%i%s" % (tag, i0, ext))
    plt.savefig("%s-timeplot-%i%s" % (tag, i0, ".png"))
    plt.close(fig)


def get_crop_bounds_final(region, i0, i_params, pmfunc,
                          xs, future_length):
    """Yield (i, lows, upps): final crop bounds for each regressor in xs.

    For every x in xs, the regressor is held fixed at x for future_length
    steps, and the last lower and upper probability of each crop i from
    get_crop_bounds_trace are collected. lows and upps are lists aligned
    with xs.
    """
    bounds_by_crop = {}
    for x in xs:
        constant_future = [x] * future_length
        trace = get_crop_bounds_trace(
            region, i0, i_params, pmfunc, constant_future)
        for crop, low, upp in trace:
            lows, upps = bounds_by_crop.setdefault(crop, ([], []))
            lows.append(low[-1])
            upps.append(upp[-1])
    for lows, upps in bounds_by_crop.values():
        assert len(lows) == len(xs)
        assert len(upps) == len(xs)
    for crop, (lows, upps) in bounds_by_crop.items():
        yield (crop, lows, upps)


def main_regressor_plots(tag, region, i0, i_params, pmfunc,
                         color, hatch, label, linestyle, soilname):
    """Long-run crop distribution under different scenarios.

    For soil type i0, sweep one regressor at a time. The first scenario
    sweeps the nitrogen price over 30..170 with rain fixed at 60. The
    second sweeps rain over 30..100 with the nitrogen price fixed at 110.
    Each value is held for 10 future steps. The top row shows the final
    lower/upper crop probabilities (get_crop_bounds_final); the two rows
    below show the rain and nitrogen price used. The figure is saved as
    "<tag>-regressorplot-<i0>" with the module extension ext and as .png.
    linestyle is unused.
    """
    i1_params = tuple((i[1], js, params)
                      for i, js, params in i_params if i[0] == i0)
    i1s = tuple(i1 for i1, js, params in i1_params)
    assert i1s == tuple(sorted(set(
        itertools.chain.from_iterable(js for i, js, params in i_params))))
    nitros = range(30, 171, 5)
    rains = range(30, 101, 5)
    vlabel_vs_xs_s = (
        ("nitrogen price", nitros, [(1, 60, nitro) for nitro in nitros]),
        ("rain",           rains,  [(1, rain, 110) for rain in rains]),
        )
    # y-range for each of the components of x; the nitrogen range must
    # cover the whole nitrogen sweep, otherwise that curve gets clipped
    xs_range = (None, (0, 120), (min(nitros), max(nitros)))
    gs = matplotlib.gridspec.GridSpec(3, 2)
    fig = plt.figure(figsize=(8, 6))
    regressor_names = [None, "rain", "nitrogen price"]

    for scenario, (vlabel, vs, xs) in enumerate(vlabel_vs_xs_s):
        print("soil = %s, scenario = %s" % (soilname[i0], str(xs)))
        # crop bounds
        ax = fig.add_subplot(gs[0, scenario])
        ax.axis([min(vs), max(vs), 0, 1])
        ax.set_xlabel(vlabel)
        ax.set_ylabel("probability")

        for i, low, upp in get_crop_bounds_final(
                region, i0, i_params, pmfunc, xs, 10):
            kwargs = dict(
                color=color[i],
                hatch=hatch[i],
                alpha=0.5,
                label=label[i],
                linewidth=2,
                )
            ax.fill_between(vs, low, upp, **kwargs)

        # proxy artists so the legend shows the shaded bands
        kwargs = dict(
            alpha=0.5,
            linewidth=2,
            )
        ps = [
            plt.Rectangle(
                (0, 0), 1, 1, hatch=hatch[i], color=color[i], **kwargs)
            for i in i1s]
        labels = [label[i] for i in i1s]
        ax.legend(ps, labels, prop={'size': 6}, loc=2)
        # regressor values
        for m in [1, 2]:
            ax = fig.add_subplot(gs[m, scenario])
            ax.axis([min(vs), max(vs),
                     xs_range[m][0], xs_range[m][1]])
            ax.set_xlabel(vlabel)
            ax.set_ylabel(regressor_names[m])
            plt.plot(vs, [x[m] for x in xs])
    plt.tight_layout()
    plt.savefig("%s-regressorplot-%i%s" % (tag, i0, ext))
    plt.savefig("%s-regressorplot-%i%s" % (tag, i0, ".png"))
    plt.close(fig)


def mean_with_default(xs, x):
    """Return the mean of the list xs, or x when xs is empty."""
    if not xs:
        return x
    return np.mean(xs)


def main2(tag, get_i_params, pmfunc, xijs_test):
    """Evaluate an imprecise classifier on test data.

    get_i_params is a zero-argument callable returning (i, js, params)
    triples. For a test tuple (x, i, j), the predicted set is
    {js[argmax(pmfunc(param, x))] for param in params of state i}.

    Returns np.array([determinacy, single_accuracy,
    indeterminate_output_size, set_accuracy]):
    - determinacy: fraction of test tuples with exactly one predicted class;
    - single_accuracy: accuracy on those tuples (1.0 if there are none);
    - indeterminate_output_size: mean size of predicted sets with two or
      more classes (2.0 if there are none);
    - set_accuracy: fraction of those sets containing the true j (1.0 if
      there are none).
    tag is unused. A KeyError is raised if a test state i was not trained.
    """
    # dict of tuples, so params can be iterated again for every test case
    i_params = {i: (js, tuple(params)) for i, js, params in get_i_params()}
    hits = []
    num_classes = []
    for x, i, j in xijs_test:
        js, params = i_params[i]
        j_predicted = set(js[np.argmax(pmfunc(param, x))] for param in params)
        hits.append(j in j_predicted)
        num_classes.append(len(j_predicted))
    # fraction of precise (single class) classifications
    determinacy = np.mean([n == 1 for n in num_classes])
    # default to 1.0 (as main3 does) so a fold without any determinate
    # prediction yields no nan that would poison the cross-validation mean
    single_accuracy = mean_with_default(
        [ok for ok, n in zip(hits, num_classes) if n == 1], 1.0)
    indeterminate_output_size = mean_with_default(
        [n for n in num_classes if n >= 2], 2.0)
    set_accuracy = mean_with_default(
        [ok for ok, n in zip(hits, num_classes) if n >= 2], 1.0)
    result = np.array([determinacy, single_accuracy,
                       indeterminate_output_size, set_accuracy])
    print("*** %s" % (result,))
    return result


# TODO this can be generalised to take any list of get_i_params functions
# along with a corresponding list of pmfuncs
def main3(tag, xijs_test, get_i_params_m, get_i_params_l):
    """
    Compare cases where logistic model
    outputs multiple predicted j with cases where the
    multinomial model outputs single j 
    """
    # store in a tuple so we can iterate multiple times
    i_params_m = {i : (js, tuple(params)) for i, js, params in
                  get_i_params_m()}
    i_params_l = {i : (js, tuple(params)) for i, js, params in
                  get_i_params_l()}
    num_classes_m = []
    num_classes_l = []
    j_pred_m = []
    j_pred_l = []
    for x, i, j in xijs_test:
        js_m, params_m = i_params_m[i]
        j_predicted_m = set(js_m[np.argmax(dirichlet_prior.multinom_pmf(param, x))]
                            for param in params_m)
        js_l, params_l = i_params_l[i]
        j_predicted_l = set(js_l[np.argmax(logistic_prior.logistic_pmf(param, x))]
                            for param in params_l)
        j_pred_m.append(j in j_predicted_m)
        j_pred_l.append(j in j_predicted_l)
        num_classes_m.append(len(j_predicted_m))
        num_classes_l.append(len(j_predicted_l))
    # we expect num_classes_m to be 1 everywhere
    accuracy_if_single = [
        [mean_with_default(
            [ok for ok, n in zip(j_pred, num_classes) if n == 1], 1.0)
         for j_pred in [j_pred_m, j_pred_l]]
        for num_classes in [num_classes_m, num_classes_l]]
    accuracy_if_set = [
        [mean_with_default(
            [ok for ok, n in zip(j_pred, num_classes) if n >= 2], 1.0)
         for j_pred in [j_pred_m, j_pred_l]]
        for num_classes in [num_classes_m, num_classes_l]]
    for row, cls in zip(accuracy_if_single, ["m", "l"]):
        for acc, which in zip(row, ["m", "l"]):
            print "accuracy of %s if %s has single class" % (which, cls), acc
    for row, cls in zip(accuracy_if_set, ["m", "l"]):
        for acc, which in zip(row, ["m", "l"]):
            print "accuracy of %s if %s has set classes " % (which, cls), acc
    flatten = lambda x: [item for sublist in x for item in sublist]
    return flatten(accuracy_if_single) + flatten(accuracy_if_set)


def main_all_plots_from_all_data_map(tag, region, plots):
    """Make the requested plots from a logistic MAP fit on all region data."""
    training_xijs = get_xijss(region, parts=1)[0]
    map_params = get_i_betas(training_xijs)
    main(tag, region,
         get_i_params=map_params,
         pmfunc=logistic_prior.logistic_pmf,
         plots=plots)


def main_all_plots_from_all_data_mcmc(tag, region):
    """Make the default plots from an MCMC logistic fit on all region data."""
    training_xijs = get_xijss(region, parts=1)[0]
    sampled_params = get_i_betasamples(training_xijs)
    main(tag, region,
         get_i_params=sampled_params,
         pmfunc=logistic_prior.sample_logistic_pmf)


def main2_validate_map(tag, region, parts, get_i_params_from_xijs, pmfunc):
    """n-fold cross validation for the specified region.

    The region's data is split into ``parts`` folds; each fold is tested
    with main2 on a model built by get_i_params_from_xijs from the other
    folds. Prints the column-wise mean over folds of main2's metrics
    (determinacy, single accuracy, indeterminate output size, set
    accuracy).
    """
    xijss = get_xijss(region, parts=parts)
    result = []
    for k_test, xijs_test in enumerate(xijss):
        # train on every fold except the one being tested
        xijs_train = list(itertools.chain.from_iterable(
            xijs for k_train, xijs in enumerate(xijss) if k_train != k_test))
        result.append(main2(
            tag=tag,
            get_i_params=get_i_params_from_xijs(xijs_train),
            pmfunc=pmfunc,
            xijs_test=xijs_test,
            ))
    print("%s: %i-fold cross validation" % (tag, parts))
    print(np.mean(np.array(result), axis=0))


def main3_validate(tag, region, parts):
    """n-fold cross validation comparing multinomial and logistic models.

    Each of the ``parts`` folds is tested with main3, using multinomial
    (get_i_thetas) and logistic MAP (get_i_betas) models trained on the
    remaining folds. The mean over folds of main3's accuracies is printed.
    """
    folds = get_xijss(region, parts=parts)
    result = []
    for k_test, xijs_test in enumerate(folds):
        xijs_train = [
            xij
            for k_fold, fold in enumerate(folds) if k_fold != k_test
            for xij in fold]
        result.append(main3(
            tag=tag,
            get_i_params_m=get_i_thetas(xijs_train),
            get_i_params_l=get_i_betas(xijs_train),
            xijs_test=xijs_test,
            ))
    print("%s: %i-fold cross validation" % (tag, parts))
    print(np.mean(np.array(result), axis=0))


# this function is probably broken due to different soil types and
# possibly different categories in each of the regions!
def main3_validate_map(tag="region_to_region"):
    """Region to region cross validation.

    Trains the logistic MAP model on all data of one region (anglia or
    mease) and tests it with main2 on the other region, in both directions.
    Prints the mean of main2's metrics over both directions. The optional
    tag is only passed on to main2.
    """
    xijss = [get_xijss(region="anglia", parts=1)[0],
             get_xijss(region="mease", parts=1)[0]
             ]
    result = []
    for k_test, xijs_test in enumerate(xijss):
        # materialise the training data: get_i_posterior_results indexes
        # xijs[0] and iterates it more than once, so an iterator would fail
        xijs_train = list(itertools.chain.from_iterable(
            xijs for k_train, xijs in enumerate(xijss) if k_train != k_test))
        # main2 returns [determinacy, single_accuracy,
        # indeterminate_output_size, set_accuracy]
        result.append(main2(
            tag=tag,
            get_i_params=get_i_betas(xijs_train),
            pmfunc=logistic_prior.logistic_pmf,
            xijs_test=xijs_test,
            ))
    print("region to region validation")
    print(np.mean(np.array(result), axis=0))


# The experiments below fit models and write plot files; only run them when
# this file is executed as a script, not when it is imported.
if __name__ == "__main__":

    # isipta15 plots start

    main_all_plots_from_all_data_map(
        tag="anglia", region="anglia", plots={'time'})
    main_all_plots_from_all_data_map(
        tag="all", region="all", plots={'probability', 'density', 'decision'})

    # isipta15 plots end

    # isipta15 cross validation start

    main2_validate_map(
        tag="mease_cv10_logistic", region="mease", parts=10,
        get_i_params_from_xijs=get_i_betas,
        pmfunc=logistic_prior.logistic_pmf)
    main2_validate_map(
        tag="anglia_cv10_logistic", region="anglia", parts=10,
        get_i_params_from_xijs=get_i_betas,
        pmfunc=logistic_prior.logistic_pmf)
    main2_validate_map(
        tag="all_cv10_logistic", region="all", parts=10,
        get_i_params_from_xijs=get_i_betas,
        pmfunc=logistic_prior.logistic_pmf)
    main2_validate_map(
        tag="all_cv10_multinom", region="all", parts=10,
        get_i_params_from_xijs=get_i_thetas,
        pmfunc=dirichlet_prior.multinom_pmf)

    # isipta15 cross validation end

    # extra stuff

    #main_all_plots_from_all_data_map(
    #    tag="testrb", region="anglia", plots={'decision_rb'})

    #main3_validate(tag="all_cv10_log_mul_combined", region="all", parts=10)
    #main3_validate_map()

