# Copyright (C) 2015 Matthias C. M. Troffaes & Lewis Paton
# matthias.troffaes@durham.ac.uk, l.w.paton@durham.ac.uk
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

"""Some functions to handle crop data."""

import csv
import itertools
import logging
import sys

def iterate_csv(filename, fields, types):
    """Yield one tuple of converted values for every data row of a CSV file.

    The file must start with a header row. For each following row, the
    value in column ``fields[k]`` is converted with ``types[k]``. The
    converted values are yielded as a tuple in the order of *fields*.
    Columns not listed in *fields* are ignored. If *fields* and *types*
    differ in length, the shorter one decides how many values are produced.

    :param filename: Path of the CSV file.
    :param fields: Column names to extract, in output order.
    :param types: Callables (e.g. ``int``, ``float``, ``str``) used to
        convert the text of the corresponding column.
    :raises KeyError: if a requested column is not in the header.
    """
    if sys.version_info[0] < 3:
        # The Python 2 csv module must be given a file opened in binary mode.
        stream = open(filename, "rb")
    else:
        # The Python 3 csv module needs text mode with newline="" so that
        # newlines embedded in quoted fields are parsed correctly.
        stream = open(filename, "r", newline="")
    with stream:
        for row in csv.DictReader(stream):
            # zip (not the Python-2-only itertools.izip) works on both versions.
            yield tuple(typ(row[field]) for typ, field in zip(types, fields))

def crop_to_j(crop):
    """Return the model's crop category for a raw crop code.

    Only the first two characters of *crop* are inspected:

    * ``WH``, ``SP``, ``BU`` -> ``"WH"`` (wheat, e.g. WH1, WH2, WH3, SP1, BU1)
    * ``BE``, ``PE`` -> ``"LE"`` (legumes, e.g. BE3, BE7, BE9, BE12, PE1,
      PE5, PE9)
    * ``RA`` -> ``"RA"`` (oilseed rape, e.g. RA1, RA6, RA7, RA8)
    * anything else -> ``"OT"`` (other, e.g. barley BA1, sugar beet SU1,
      potato PO1, oats OA1)

    Second wheat is not distinguished here: every wheat code maps to "WH".
    """
    category_by_prefix = {
        "WH": "WH",
        "SP": "WH",
        "BU": "WH",
        "BE": "LE",
        "PE": "LE",
        "RA": "RA",
    }
    return category_by_prefix.get(crop[:2], "OT")

def get_xijs(rain_filename, profit_filename, nitrogen_filename, crop_filename, soil_filename):
    """Generate one (x, i, j, year) observation per usable crop transition.

    The crop file is read row by row, and each row is compared with the row
    immediately before it. A row yields an observation only when all of
    the following hold:

    * it belongs to the same field as the previous row;
    * its year directly follows the previous row's year;
    * the field has an entry in the soil file;
    * the transition is not between the "LE" (legumes) and "RA" (oilseed
      rape) categories, in either direction. Those transitions are skipped.

    Because only consecutive rows are compared, the crop file should be
    grouped by field and sorted by year within each field. The first row of
    every field never yields an observation.

    Each yielded tuple is ``(x, i, j, year)`` where:

    * x is a tuple (1, rain, nitrogen price) containing the regressor
      variables of the model for that year.
    * i is a tuple (soil, previous crop category) of non-regressor
      variables. A distinct model is fitted for each value of these
      variables.
    * j is the crop category grown this year (see crop_to_j).
    * year is the year of the observation.

    After the crop file has been fully consumed, row counts and skip counts
    are logged at DEBUG level.

    :param rain_filename: CSV with columns ``year`` and ``rain``.
    :param profit_filename: CSV with columns ``year``, ``crop`` and
        ``profit``. It is loaded, but profit is currently not used as a
        regressor.
    :param nitrogen_filename: CSV with columns ``year`` and ``price``.
    :param crop_filename: CSV with columns ``year``, ``field`` and ``crop``.
    :param soil_filename: CSV with columns ``field`` and ``soil``.
    :raises KeyError: if a year that yields an observation has no rain or
        nitrogen entry.
    """
    log = logging.getLogger(__name__)
    # book-keeping counters, reported once the crop file is exhausted
    total = 0
    missing_year = 0
    missing_soil = 0
    omitted = 0
    # tables to be joined to the full crop data
    rain = {
        year: amount for year, amount in iterate_csv(
            rain_filename, ("year", "rain"), (int, float))
        }
    # loaded for the (currently disabled) profit-margin regressors
    profit = {
        (year, crop): margin
        for year, crop, margin in iterate_csv(
            profit_filename, ("year", "crop", "profit"), (int, str, float))
        }
    nitrogen = {
        year: price for year, price in iterate_csv(
            nitrogen_filename, ("year", "price"), (int, float))
        }
    soil = {
        field: soil_type
        for field, soil_type in iterate_csv(
            soil_filename, ("field", "soil"), (str, int))
        }
    # (previous category, current category) pairs excluded from the data
    skipped_transitions = {("LE", "RA"), ("RA", "LE")}
    last_j = last_field = last_year = None

    for year, field, crop in iterate_csv(
            crop_filename, ("year", "field", "crop"), (int, str, str)):
        j = crop_to_j(crop)
        if field == last_field:
            if year != last_year + 1:
                # missing year for a field; this happens occasionally
                missing_year += 1
            elif field not in soil:
                # missing soil for the field; at the moment this happens
                # at boundaries
                missing_soil += 1
            elif (last_j, j) in skipped_transitions:
                omitted += 1
            else:
                # Second wheat is not stored in the database; a separate
                # "WH2" category for it is currently disabled, so crop_to_j
                # must never produce it.
                assert j != "WH2"
                yield (
                    (1, rain[year], nitrogen[year]),
                    (soil[field], last_j), j, year
                    )
        last_j, last_field, last_year = j, field, year
        total += 1

    log.debug(
        "crop rows: total=%i missing year=%i missing soil=%i "
        "omitted LE<->RA=%i", total, missing_year, missing_soil, omitted)

# regression tests

import nose.tools

def test_get_xijs():
    """Regression test of get_xijs against the test_data_*.csv files.

    Each field's crop history is written as the category grown in 1993,
    followed by the categories grown in 1994-1998. Every consecutive pair
    in a history is one (previous crop, current crop) transition, observed
    in the later year.
    """
    # year -> (rain, nitrogen price)
    weather = {
        1994: (45, 22),
        1995: (48, 13),
        1996: (39, 8),
        1997: (63, 99),
        1998: (51, 0),
    }
    # (soil, crop categories for 1993..1998), one entry per field
    field_histories = [
        (3, ["WH", "RA", "WH", "WH", "RA", "OT"]),  # field 1
        (9, ["RA", "WH", "WH", "RA", "OT", "WH"]),  # field 2
        (5, ["RA", "RA", "WH", "WH", "OT", "RA"]),  # field 3
    ]
    expected = []
    for soil_type, history in field_histories:
        for year, previous, current in zip(sorted(weather), history, history[1:]):
            rainfall, nitrogen_price = weather[year]
            expected.append(
                ((1, rainfall, nitrogen_price), (soil_type, previous),
                 current, year))
    actual = list(get_xijs(
        rain_filename="test_data_rain.csv",
        profit_filename="test_data_profit.csv",
        nitrogen_filename="test_data_nitrogen.csv",
        crop_filename="test_data_crop.csv",
        soil_filename="test_data_soil.csv",
    ))
    nose.tools.assert_equal(actual, expected)
