#include "ellib.h"
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <functional>
#include <iterator>
#include <memory>
#include <string>

using namespace ellib;

// Structure parameters
double totalSize = 1;             // Overall size of the structure; getParameters() scales the design fractions by it
double strain = 0.2;              // Total in-plane compressive strain applied over the first iterRamp iterations
double modulus = 1.4;             // Passed to BarAndHinge::setModulus; also scales rigidity, ramp force and convergence
double meshSize = totalSize / 20; // Mesh length scale: sets the minimum feature width and the force/convergence scales
                                  // NOTE(review): not passed to genmesh, presumably must match its element size — confirm

// Minimisation
double convergenceFactor = 1e-2;  // Scales the convergence criterion set in initialise()
double maxStep = 0.1;             // Scales the minimiser's maximum step during the BITSS search
int iterMax = 20000;              // Maximum L-BFGS iterations per minimisation
int iterRamp = 500;               // Iterations over which the strain is ramped and the region forces are applied
double forceFactor = 1e-3;        // Scales the region forces applied during the ramp

// Genetic algorithm
int popSize = 48;                 // Population size; InverseBarrier::energy assumes it is split evenly over MPI ranks
int maxGen = 20;                  // Maximum number of generations (GenAlg::setMaxIter)
double pertubation = 0.3;         // Perturbation size for every design fraction (spelling follows GenAlg::setPertubation)
double mutationRate = 0.25;       // Passed to GenAlg::setMutationRate
double selectionRate = 0.33;      // Passed to GenAlg::setSelectionRate


// Reads whitespace-separated values of type T from `filename` into a flat, row-major vector; the file is
// expected to hold nCol values per line. Values are read until end of file (or the first token that cannot
// be parsed as T), so a missing trailing newline or trailing blank lines no longer drop or pad rows. An
// incomplete final row is discarded so the result size is always a multiple of nCol.
// Returns an empty vector if the file cannot be opened or nCol <= 0.
template <typename T = double>
vector<T> loadData(std::string filename, int nCol=3) {
  vector<T> data;
  if (nCol <= 0) return data;
  std::ifstream file(filename);
  T value{};
  while (file >> value) data.push_back(value);
  // Drop an incomplete final row so callers can rely on data.size() % nCol == 0
  data.resize(data.size() - data.size() % nCol);
  return data;
}


// Loads a whitespace-separated table with nCol columns (via loadData) and returns it as one inner
// vector per row.
template <typename T = double>
vector2d<T> loadData2d(std::string filename, int nCol=3) {
  auto flat = loadData<T>(filename, nCol);
  const int rows = flat.size() / nCol;
  vector2d<T> table(rows, vector<T>(nCol));
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < nCol; ++c) {
      table[r][c] = flat[r*nCol + c];
    }
  }
  return table;
}


template <typename T>
void outputData(vector<T> coords, std::string filename, int ncol=3, int rank=0) {
  if (mpi.rank != rank) return;
  std::ofstream file(filename);
  for (int i=0; i<(int)coords.size()/ncol; i++) {
    for (int j=0; j<ncol; j++) {
      file << coords[ncol*i+j] << " ";
    }
    file << std::endl;
  }
  file.close();
}

template <typename T>
void outputData(vector2d<T> coordList, std::string filename, int ncol=3, int rank=0) {
  vector<double> allCoords;
  for (auto& coords: coordList) allCoords.insert(allCoords.end(), coords.begin(), coords.end());
  outputData(allCoords, filename, ncol, rank);
}


// Returns one flag per degree of freedom (three per node, in node order): all three degrees of freedom of
// a node are fixed when its region id is negative, and free otherwise.
vector<bool> getFixedNodes(const vector<int>& regions) {
  vector<bool> fixed;
  fixed.reserve(3*regions.size());
  for (int region: regions) {
    const bool isFixed = region < 0;
    for (int k=0; k<3; k++) fixed.push_back(isFixed);
  }
  return fixed;
}


// Maps the four normalised design fractions onto the physical geometry {legWidth, thickness, outerRad,
// innerRad}. The leg width is clamped from below to sqrt(3)*meshSize, and the inner radius is a fraction of
// the space between that minimum width and the outer radius. Every value is rounded to 6 decimal places
// (presumably so it matches the %.6f formatting used for the mesh directory names).
vector<double> getParameters(double legWidthFrac, double thicknessFrac, double outerRadFrac, double innerRadFrac) {
  const double minWidth = sqrt(3.0) * meshSize;
  auto roundTo6dp = [](double value) { return std::round(value * 1e6) / 1e6; };

  const double rawLegWidth = totalSize / 4 * legWidthFrac;
  const double legWidth = (rawLegWidth < minWidth) ? minWidth : rawLegWidth;
  const double thickness = totalSize / 20 * thicknessFrac;
  const double outerRad = totalSize / 3 * outerRadFrac;
  // NOTE(review): negative whenever outerRad < minWidth (outerRadFrac below ~0.26 with the default
  // settings) — confirm genmesh handles that.
  const double innerRad = (outerRad - minWidth) * innerRadFrac;

  vector<double> params{roundTo6dp(legWidth), roundTo6dp(thickness), roundTo6dp(outerRad), roundTo6dp(innerRad)};
  return params;
}
// Convenience overload taking the normalised fractions as a vector
// {legWidthFrac, thicknessFrac, outerRadFrac, innerRadFrac}; only the first four entries are read.
// Taken by const reference to avoid copying the vector on every call.
vector<double> getParameters(const vector<double>& params) {
  return getParameters(params[0], params[1], params[2], params[3]);
}


State initialise(vector<double> params0, vector<int>& regions) {
  vector<double> params = getParameters(params0[0], params0[1], params0[2], params0[3]);
  double thickness = params[1];
  
  // Generate the mesh
  char command[999];
  std::sprintf(command, "mkdir -p data/width-%.6f-outRad-%.6f-inRad-%.6f", params[0], params[2], params[3]);
  system(command);
  std::sprintf(command, "matlab -nodisplay -nosplash -nodesktop -r \"genmesh %.6f %.6f %.6f; exit\" > /dev/null", params[0], params[2], params[3]);
  system(command);

  // Read data
  char filename[999];
  std::sprintf(filename, "data/width-%.6f-outRad-%.6f-inRad-%.6f/tlist.txt", params[0], params[2], params[3]);
  auto tri = loadData2d<int>(filename, 3);
  std::sprintf(filename, "data/width-%.6f-outRad-%.6f-inRad-%.6f/coords.txt", params[0], params[2], params[3]);
  auto coords = loadData(filename, 3);
  std::sprintf(filename, "data/width-%.6f-outRad-%.6f-inRad-%.6f/regions.txt", params[0], params[2], params[3]);
  regions = loadData<int>(filename, 1);

  // Initialise the flat structure
  BarAndHinge pot;
  pot.setModulus(modulus);
  pot.setThickness(thickness);
  pot.setRigidity({1e3*modulus*pow(thickness,3)}, vector<double>{}); // Scale the stretching rigidity so it does not become huge for small thicknesses
  pot.setTriangulation(tri);
  pot.setFixed(getFixedNodes(regions));
  State state(pot, coords, {mpi.rank});
  state.convergence = convergenceFactor * modulus*pow(thickness,3) * pow(meshSize,2)/totalSize;
  return state;
}


std::function<void(int, State&)> initialRamp(vector2d<double> regionForces, vector<int> regions, Minimiser& min) {
  return [regionForces, regions, &min] (int iter, State& state) {
    static double convergence;
    static vector2d<double> regionStrainStep;

    if (iter == 0) {
      convergence = state.convergence;
      state.convergence = 0;
      min.setLinesearch("none");

      // Get the strain step size to apply to each fixed region
      vector<double> coords = state.coords();
      int nFixed = - *(std::min_element(regions.begin(), regions.end()));
      regionStrainStep = vector2d<double>(nFixed, {0,0,0});
      for (int iR=0; iR<nFixed; iR++) {
        vector<double> regionPos = {0,0,0};
        int nNodes = 0;
        for (int iN=0; iN<(int)state.ndof/3; iN++) {
          if (regions[iN]==-iR-1) {
            regionPos[0] += coords[3*iN];
            regionPos[1] += coords[3*iN+1];
            regionPos[2] += coords[3*iN+2];
            nNodes++;
          }
        }
        if (nNodes==0) continue;
        regionPos /= nNodes;
        regionStrainStep[iR] = - regionPos * strain / iterRamp;
      }

      // Set the forces
      int nRegions = *(std::max_element(regions.begin(), regions.end()));
      vector<int> regionNodes(nRegions, 0);
      for (int iN=0; iN<(int)state.ndof/3; iN++) {
        if (regions[iN] > 0) regionNodes[regions[iN]-1] ++;
      }
      vector2d<double> forces(state.ndof/3, {0, 0, 0});
      for (int iN=0; iN<(int)state.ndof/3; iN++) {
        if (regions[iN] <= 0) continue;
        forces[iN] = regionForces[regions[iN]-1] / regionNodes[regions[iN]-1];
      }
      dynamic_cast<BarAndHinge&>(*state.pot).setForce(forces);
    }

    if (iter < iterRamp) {
      // Apply the strain
      vector<double> coords = state.coords();
      for (size_t i=0; i<state.ndof/3; i++) {
        if (regions[i] < 0) {
          auto step = regionStrainStep[-regions[i]-1];
          coords[3*i] += step[0];
          coords[3*i+1] += step[1];
        }
        else {
          coords[3*i] *= 1 - strain/iterRamp;
          coords[3*i+1] *= 1 - strain/iterRamp;
        }
      }
      state.coords(coords);

    } else if (iter == iterRamp) {
      // Turn off the force and allow convergence
      dynamic_cast<BarAndHinge&>(*state.pot).setForce({0,0,0});
      state.convergence = convergence;
    }
  };
}


bool checkState1(vector<double> coords) {
  // Ensure that state 1 is correct. It should be down in the top-right and up in the bottom-left
  double tr = 0;
  double bl = 0;
  for (int iNode=0; iNode<(int)coords.size()/3; iNode++) {
    double x = coords[3*iNode];
    double y = coords[3*iNode+1];
    double z = coords[3*iNode+2];
    if (x+y > 0) {
      tr += z;
    } else {
      bl += z;
    }
  }
  return (tr<0 && bl>0);
}

bool checkState2(vector<double> coords) {
  // Ensure that state 2 is correct. It should be down symmetric in the top-right and bottom-left
  double tr = 0;
  double bl = 0;
  for (int iNode=0; iNode<(int)coords.size()/3; iNode++) {
    double x = coords[3*iNode];
    double y = coords[3*iNode+1];
    double z = coords[3*iNode+2];
    if (x+y > 0) {
      tr += z;
    } else {
      bl += z;
    }
  }
  return (tr>0 && bl>0) && ((tr-bl)/bl<0.02);
}


// Genetic-algorithm objective. It is evaluated at the four normalised design fractions understood by
// getParameters()/initialise() and returns 1 / barrier, where barrier = 2*E_ts - E1 - E2 is found with BITSS
// between two minima obtained under the applied strain. Minimising this value therefore maximises the
// barrier. Failed evaluations return +infinity, which is equivalent to a barrier of zero. Only the energy is
// meaningful; the gradient is identically zero.
class InverseBarrier: public NewPotential<InverseBarrier> {
  public:
    InverseBarrier() { _energyDef = true; }

    double energy(const vector<double>& coords) const override {
      // Recover the generation and the global index of this individual from a running call counter. This
      // assumes each rank evaluates popSize / mpi.size individuals per generation, in order.
      static int call_num = -1;
      call_num++;
      int procStates = popSize / mpi.size;
      if (procStates < 1) procStates = 1;  // More ranks than individuals: avoid dividing by zero below
      int ga_iter = call_num / procStates;
      int state_num = procStates*mpi.rank + (call_num % procStates);

      vector<int> regions;
      State initState = initialise(coords, regions);

      // coords holds normalised fractions, so convert to the physical thickness exactly as initialise() does.
      // The ramp force then scales with the actual bending rigidity, like the convergence criterion.
      double thickness = getParameters(coords)[1];
      double initForce = forceFactor * modulus*pow(thickness,3) * pow(totalSize,2)/meshSize;
      // Total force per loaded region (regions 1..8). force1 pushes some regions down and others up;
      // force2 pushes every region up (+z).
      vector2d<double> force1 = {{0,0,-initForce}, {0,0,-initForce}, {0,0,initForce}, {0,0,initForce},
                                 {0,0,-initForce}, {0,0,0},          {0,0,initForce}, {0,0,0}};
      vector2d<double> force2 = {{0,0,initForce}, {0,0,initForce}, {0,0,initForce}, {0,0,initForce},
                                 {0,0,initForce}, {0,0,initForce}, {0,0,initForce}, {0,0,initForce}};

      Lbfgs min;
      min.setM(10);
      min.setMaxIter(iterMax);

      // Get the minima under an applied strain
      State state1 = initState;
      State state2 = initState;
      min.minimise(state1, initialRamp(force1, regions, min));
      bool failed = (min.iter >= min.maxIter) || (!checkState1(state1.coords()));
      min.minimise(state2, initialRamp(force2, regions, min));
      failed = failed || (min.iter >= min.maxIter) || (!checkState2(state2.coords()));
      // If they did not properly converge return a barrier of zero (an infinite objective)
      if (failed) {
        printAll(state_num, "Minimisation did not converge");
        return 1/0.0;
      }

      // Find the transition state with BITSS
      min.setLinesearch("none");
      min.setMaxStep(maxStep*sqrt(initState.ndof)*meshSize);
      Bitss bitss(state1, state2, min);
      // bitss.setLog();
      bitss.setConvergenceMethod("energy");
      bitss.setConvergenceEnergy(0.1);
      bitss.setMaxIter(30);
      bitss.setDistStep(0.15 + 1e-3*(std::rand()%1000)/1000.0); // Small differences to prevent a state always converging early
      bitss.setMaxBarrier(0.5*(state1.energy()+state2.energy()));
      bitss.run();
      State ts = bitss.getTS();

      // Save the states
      char filename[999];
      std::snprintf(filename, sizeof(filename), "outputs/pop-states-%04d-%04d.txt", ga_iter, state_num);
      vector2d<double> stateCoords{state1.coords(), state2.coords(), ts.coords()};
      outputData(stateCoords, filename, state1.ndof, mpi.rank);

      // Get the barrier
      double e1 = state1.energy();
      double e2 = state2.energy();
      double ets = ts.energy();
      double barrier = 2*ets - e1 - e2;
      // A failed search, or a "transition state" not above the minima, gives no valid barrier. Without this
      // a negative barrier would yield a negative objective that the GA would rank as the best design.
      if (bitss.checkFailed() || barrier <= 0) barrier = 0;

      auto tsPair = bitss.getPair();
      double ets1 = tsPair[0].energy();
      double ets2 = tsPair[1].energy();
      printAll(state_num, "\tCoords:", coords, "\tBarrier:", barrier, "\tE:", e1, e2, "\t", ets, "\t", ets1, ets2);
      return 1 / barrier;
    }

    vector<double> gradient(const vector<double>& coords) const override {
      return vector<double>(coords.size());
    }
};


void outputPop(int iter, GenAlg& ga) {
  mpi.barrier();
  print("Generation:", iter, "Best:", *std::min_element(ga.popEnergies.begin(), ga.popEnergies.end()));

  // Output energies and parameters
  int nPop = ga.pop.size();
  vector2d<double> popParams(nPop);
  for (int iPop=0; iPop<nPop; iPop++) {
    popParams[iPop] = getParameters(ga.pop[iPop].allCoords());
  }
  char filename[999];
  std::sprintf(filename, "outputs/pop-%04d.txt", iter);
  outputData(popParams, filename, ga.pop[0].ndof);
  std::sprintf(filename, "outputs/pop-energy-%04d.txt", iter);
  outputData(ga.popEnergies, filename, 1);
}


// Entry point: initialises MPI and runs the genetic algorithm that searches the four normalised design
// fractions for the largest energy barrier (InverseBarrier returns 1 / barrier).
int main(int argc, char** argv) {
  mpi.init(&argc, &argv);

  // Objective evaluated for each individual
  InverseBarrier objective;
  GenAlg ga(objective);

  // Population size and number of generations
  ga.setPopSize(popSize);
  ga.setMaxIter(maxGen);

  // Each of the four design fractions is searched over [0, 1]
  ga.setBounds({0,0,0,0}, {1,1,1,1});
  ga.setPertubation({pertubation, pertubation, pertubation, pertubation});
  ga.setMutationRate(mutationRate);
  ga.setSelectionRate(selectionRate);

  // Report and save every generation, then run
  ga.setIterFn(outputPop);
  ga.run();

  return 0;
}
