# Copyright (C) 2015 Matthias C. M. Troffaes & Lewis Paton
# matthias.troffaes@durham.ac.uk, l.w.paton@durham.ac.uk
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Code for calculating lower and upper expectations and probability
mass functions.
"""

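# params = iterable of parameter values over which to take the envelope
#          (the minimum or maximum); param = a single element of params
# extra = any additional arguments passed through to pmfunc unchanged;
#         no envelope is taken over them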

# pmfunc(param, extra) must return a numpy array containing the
# probability mass of every singleton

# func is a random variable, given as its value on every singleton, in the
# same order as the array returned by pmfunc; it can be any sequence
# (tuple, list, array, ...)

def get_expectation(func, pmfunc, param, extra):
    """Return the expectation of *func* under the pmf ``pmfunc(param, extra)``.

    :param func: value of the random variable on every singleton, in the
        same order as the masses returned by *pmfunc*.
    :param pmfunc: callable ``pmfunc(param, extra)`` returning a numpy array
        with the probability mass of every singleton.
    :param param: the single parameter value to evaluate *pmfunc* at.
    :param extra: additional argument forwarded to *pmfunc* unchanged.
    :return: the dot product of the masses with *func*.
    """
    masses = pmfunc(param, extra)
    return masses.dot(func)

def get_lower_expectation(func, pmfunc, params, extra):
    """Return the lower expectation of *func* over all parameters in *params*.

    This is the smallest value of ``get_expectation(func, pmfunc, param,
    extra)`` as *param* ranges over *params*. The built-in ``min`` raises
    ``ValueError`` if *params* is empty.
    """
    expectations = [
        get_expectation(func, pmfunc, param, extra) for param in params
    ]
    return min(expectations)

def get_upper_expectation(func, pmfunc, params, extra):
    """Return the upper expectation of *func* over all parameters in *params*.

    This is the largest value of ``get_expectation(func, pmfunc, param,
    extra)`` as *param* ranges over *params*. The built-in ``max`` raises
    ``ValueError`` if *params* is empty.
    """
    expectations = [
        get_expectation(func, pmfunc, param, extra) for param in params
    ]
    return max(expectations)

def get_lower_pmf(pmfunc, params, extra):
    """Return the lower envelope of the pmfs ``pmfunc(param, extra)``.

    :param pmfunc: callable ``pmfunc(param, extra)`` returning a 1-D numpy
        array with the probability mass of every singleton.
    :param params: iterable of parameter values to take the envelope over.
    :param extra: additional argument forwarded to *pmfunc* unchanged.
    :return: array whose i-th entry is the minimum over *params* of the
        i-th mass.
    :raises ValueError: if *params* is empty (nothing to stack).
    """
    # np.vstack requires a real sequence; passing a generator is deprecated
    # since NumPy 1.16 and rejected with TypeError by newer releases.
    return np.min(np.vstack([pmfunc(param, extra) for param in params]), axis=0)

def get_upper_pmf(pmfunc, params, extra):
    """Return the upper envelope of the pmfs ``pmfunc(param, extra)``.

    :param pmfunc: callable ``pmfunc(param, extra)`` returning a 1-D numpy
        array with the probability mass of every singleton.
    :param params: iterable of parameter values to take the envelope over.
    :param extra: additional argument forwarded to *pmfunc* unchanged.
    :return: array whose i-th entry is the maximum over *params* of the
        i-th mass.
    :raises ValueError: if *params* is empty (nothing to stack).
    """
    # np.vstack requires a real sequence; passing a generator is deprecated
    # since NumPy 1.16 and rejected with TypeError by newer releases.
    return np.max(np.vstack([pmfunc(param, extra) for param in params]), axis=0)

# test code

import numpy as np

def test_get_expectation():
    """Check expectations and pmf envelopes on the family ``(p, 1 - p)``.

    The random variable takes value -1 on the first singleton and 2 on the
    second; the envelopes are taken over p in {0.1, 0.5}.
    """
    def pmfunc(param, extra):
        return np.array([param, 1 - param])

    func = (-1, 2)
    params = (0.1, 0.5)
    check = np.testing.assert_almost_equal

    # point expectations
    check(get_expectation(func, pmfunc, 0.5, None), 0.5)
    check(get_expectation(func, pmfunc, 0.1, None), 1.7)
    # lower/upper envelopes of the expectation
    check(get_lower_expectation(func, pmfunc, params, None), 0.5)
    check(get_upper_expectation(func, pmfunc, params, None), 1.7)
    # lower/upper envelopes of the pmf, singleton by singleton
    check(get_lower_pmf(pmfunc, params, None), (0.1, 0.5))
    check(get_upper_pmf(pmfunc, params, None), (0.5, 0.9))
