#include "petri_net_sog.hpp"

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <map>
#include <set>
#include <stack>
#include <string>
#include <vector>

#include "Net.hpp"
#include "bvec.h"

using namespace std;

// Initial sizes handed to bdd_init: number of BDD nodes and number of
// operation-cache entries.
constexpr int BDD_INITIAL_NUM_NODES = 1000 * 1000;
constexpr int BDD_SIZE_CACHES = 1000 * 1000;

// vector of model's places for the print handler
const vector<Place> *v_places = nullptr;

/**
 * Error hook installed into BuDDy through bdd_error_hook (see the
 * PetriNetSOG constructor). It forwards the error to BuDDy's default handler.
 * @param err_code BuDDy error code reported by the failing operation
 */
void BDDErrorHandler(int err_code) {
  // no custom reporting: delegate to the library's built-in handler
  bdd_default_errhandler(err_code);
}

/**
 * Handler to convert the FDD integer identifier into something readable by
 the
 * end user
 * @param o
 * @param var
 */
void PrintHandler(ostream &o, int var) {
  o << (*v_places)[var / 2].name;
  if (var % 2) {
    o << "_p";
  }
}

/**
 * Pick the first transition of `firable_trans` (in the set's iteration order)
 * that does not belong to `covered_trans`.
 * @param covered_trans transitions already covered
 * @param firable_trans transitions firable from the current aggregate
 * @return that transition, or -1 when every firable transition is covered
 */
int SelectFirableTrans(Set covered_trans, const Set &firable_trans) {
  auto candidate = firable_trans.begin();
  while (candidate != firable_trans.end() &&
         covered_trans.find(*candidate) != covered_trans.end()) {
    ++candidate;
  }
  return candidate == firable_trans.end() ? -1 : *candidate;
}

/**
 * For every transition of `trace` whose counter in `trans_obs` is still
 * positive, reset that counter to 2.
 * The lookup uses operator[], so a transition missing from the map is
 * inserted with the value 0 (and then left at 0).
 * @param trace     transitions to walk
 * @param trans_obs per-transition counters, updated in place
 */
void ReinitCycle(const Path &trace, std::map<int, int> &trans_obs) {
  for (const auto t : trace) {
    int &remaining = trans_obs[t];
    if (remaining > 0) remaining = 2;
  }
}

/*****************************************************************************/
/*                                Class Trans                                */
/*****************************************************************************/
Trans::Trans(const bdd &var, bddPair *pairs_table, const bdd &postrel,
             const bdd &prerel, const bdd &precond, const bdd &postcond)
    : var(var),
      pairs_table(pairs_table),
      precond(precond),
      postcond(postcond),
      postrel(postrel),
      prerel(prerel) {}

/**
 * Image of `op`: markings reached by firing this transition once.
 * The relational product quantifies the current-marking variables, then the
 * primed result is renamed back to current-marking variables.
 */
bdd Trans::operator()(const bdd &op) const {
  return bdd_replace(bdd_relprod(op, postrel, var), pairs_table);
}

/**
 * Pre-image of `op`: markings from which firing this transition once leads
 * into `op`. Same scheme as operator() but with the backward relation.
 */
bdd Trans::operator[](const bdd &op) const {
  return bdd_replace(bdd_relprod(op, prerel, var), pairs_table);
}

/*****************************************************************************/
/*                         Class PetriNetSOG                                 */
/*****************************************************************************/
/**
 * Build the symbolic (BDD/FDD) encoding of a Petri net.
 *
 * Every place p gets two FDD blocks: 2*p for its current marking and 2*p+1
 * for its primed (next) marking. Each transition is encoded as a Trans that
 * holds its forward/backward relations and pre/post conditions.
 *
 * @param petri_net     the model
 * @param obs_trans     observable transitions (only the keys are used here)
 * @param non_obs_trans unobservable transition identifiers (moved in)
 * @param bound         marking bound used for places without a capacity
 * @param init          whether to initialize the BuDDy package (bdd_init)
 */
PetriNetSOG::PetriNetSOG(const net &petri_net, const map<int, int> &obs_trans,
                         Set non_obs_trans, const int bound, const bool init)
    : nb_places(petri_net.places.size()) {
  // Per-place scratch buffers. std::vector releases them on every exit path;
  // the previous new[]/delete[] pairs leaked if anything in between threw
  // (e.g. std::bad_alloc from relation.emplace_back).
  const auto n = static_cast<std::size_t>(nb_places);
  std::vector<bvec> bvec_vars(n);   // current-marking variables of each place
  std::vector<bvec> vp(n);          // primed (next-marking) variables
  std::vector<bvec> pre_arc(n);     // input-arc weights of the current transition
  std::vector<bvec> post_arc(n);    // output-arc weights of the current transition
  std::vector<int> id_var(n);       // unprimed FDD ids of the adjacent places
  std::vector<int> idvp(n);         // primed FDD ids of the adjacent places
  std::vector<int> nb_bdd_vars(n);  // number of BDD variables used per place

  // initialize the BDD package only when asked to
  if (init) {
    bdd_init(BDD_INITIAL_NUM_NODES, BDD_SIZE_CACHES);
  }

  // suppress garbage-collection messages
  bdd_gbc_hook(nullptr);

  // route BuDDy errors through BDDErrorHandler
  bdd_error_hook(BDDErrorHandler);

  // make the place names available to the FDD print hook
  v_places = &petri_net.places;
  fdd_strm_hook(PrintHandler);

  // transitions of the model and the mapping from their names to identifiers
  transitions = petri_net.transitions;
  transitions_names = petri_net.transitionName;

  // observable transitions are the keys of obs_trans
  for (const auto &entry : obs_trans) {
    observables.insert(entry.first);
  }

  // non_obs_trans is passed by value, so it can be moved from
  non_observables = std::move(non_obs_trans);

  // Create two FDD blocks per place, in place order, so that place p owns
  // block 2*p (current marking) and block 2*p+1 (primed marking). Both have
  // capacity + 1 values, or bound + 1 when the place has no capacity.
  for (const auto &place : petri_net.places) {
    int domain = (place.hasCapacity() ? place.capacity : bound) + 1;
    fdd_extdomain(&domain, 1);
    fdd_extdomain(&domain, 1);
  }

  // bit vectors over the FDD variables of each place
  for (int p = 0; p < nb_places; p++) {
    const int var = 2 * p;
    nb_bdd_vars[p] = fdd_varnum(var);
    bvec_vars[p] = bvec_varfdd(var);
    vp[p] = bvec_varfdd(var + 1);
  }

  // initial marking
  m0 = bdd_true();
  int offset = 0;
  for (const auto &place : petri_net.places) {
    m0 = m0 & fdd_ithvar(2 * offset, place.marking);
    offset++;
  }

  // transition relations
  for (const auto &t : petri_net.transitions) {
    bdd postrel = bdd_true();
    bdd var = bdd_true();
    bdd prerel = bdd_true();
    bdd precond = bdd_true();
    bdd postcond = bdd_true();

    // reset the pre and post arc vectors to 0
    for (int p = 0; p < nb_places; p++) {
      pre_arc[p] = bvec_con(nb_bdd_vars[p], 0);
      post_arc[p] = bvec_con(nb_bdd_vars[p], 0);
    }

    // pre arcs
    Set adjacent_places;
    for (const auto &arc : t.pre) {
      const int place = arc.first;
      adjacent_places.insert(place);
      pre_arc[place] =
          pre_arc[place] + bvec_con(nb_bdd_vars[place], arc.second);
    }

    // post arcs
    for (const auto &arc : t.post) {
      const int place = arc.first;
      adjacent_places.insert(place);
      post_arc[place] =
          post_arc[place] + bvec_con(nb_bdd_vars[place], arc.second);
    }

    int nb_pairs = 0;
    for (const auto place : adjacent_places) {
      id_var[nb_pairs] = 2 * place;
      idvp[nb_pairs] = 2 * place + 1;
      var = var & fdd_ithset(2 * place);

      // image
      // precondition
      postrel = postrel & (bvec_vars[place] >= pre_arc[place]);
      precond = precond & (bvec_vars[place] >= pre_arc[place]);
      // postcondition
      postrel =
          postrel &
          (vp[place] == (bvec_vars[place] - pre_arc[place] + post_arc[place]));

      // pre-image
      // precondition
      prerel = prerel & (bvec_vars[place] >= post_arc[place]);
      // postcondition
      postcond = postcond & (bvec_vars[place] >= post_arc[place]);
      prerel = prerel & (vp[place] ==
                         (bvec_vars[place] - post_arc[place] + pre_arc[place]));

      // capacity
      if (petri_net.places[place].hasCapacity()) {
        postrel =
            postrel & (vp[place] <= bvec_con(nb_bdd_vars[place],
                                             petri_net.places[place].capacity));
      }
      nb_pairs++;
    }

    // variable pairs tell bdd_replace to rename each primed variable back to
    // its current-marking variable
    bddPair *vars_pair_table = bdd_newpair();
    fdd_setpairs(vars_pair_table, idvp.data(), id_var.data(), nb_pairs);
    relation.emplace_back(var, vars_pair_table, postrel, prerel, precond,
                          postcond);
  }
}

/**
 * Unobservable closure of `from`: every state reachable from `from` by firing
 * unobservable transitions only, computed as a fixed point.
 * @param from set of states to close
 * @return the closure (contains `from`)
 */
bdd PetriNetSOG::AccessibleEpsilon(const bdd &from) const {
  bdd closure = from;
  for (;;) {
    const bdd before = closure;
    for (const int t : non_observables) {
      closure = relation[t](closure) | closure;
    }
    // stop once a full sweep adds no new state
    if (closure == before) {
      return closure;
    }
  }
}

// Merge conflict resolved: HEAD kept a OneStepBackUnobs(const bdd &,
// const Aggregate *) returning a single bdd here. It clashed with the
// definition returning pair<bdd, set<int>> further down in this file. C++
// cannot overload on return type alone, and ComputeSingleWeight relies on the
// pair-returning form, so that form is the only definition kept.

/**
 * One backward step inside an aggregate.
 *
 * Scans the unobservable transitions in the iteration order of
 * non_observables. It stops at the first t whose pre-image of `from`
 * intersects the aggregate.
 *
 * @param from set of states to step back from
 * @param aggr aggregate the predecessors must belong to
 * @return {t, pre_t(from) & aggr} for the first such t, or {-1, bddfalse}
 *         when no unobservable transition leads into `from` from `aggr`
 */
pair<int, bdd> PetriNetSOG::StepBackward(const bdd &from,
                                         const Aggregate *aggr) const {
  for (const auto t : non_observables) {
    // states of the aggregate that reach `from` by firing t
    const bdd pred = relation[t][from] & aggr->bdd_state;
    if (pred != bdd_false()) {
      return pair<int, bdd>(t, pred);
    }
  }

  // -1 is never a transition id (same convention as SelectFirableTrans);
  // this case used to report transition 0, which is a valid id
  return pair<int, bdd>(-1, bdd_false());
}

/**
 * Image of `from` by transition `t`, i.e. the markings reached by firing t
 * once from a state of `from`.
 */
bdd PetriNetSOG::GetSuccessor(const bdd &from, const int t) const {
  const auto &trans = relation[t];
  return trans(from);
}

/**
 * Observable transitions that can fire from at least one state of `from`.
 * @param from set of states (typically an aggregate)
 * @return observable transitions whose image of `from` is non-empty
 */
Set PetriNetSOG::FirableObservableTrans(const bdd &from) const {
  Set firable;
  for (const int t : observables) {
    const bool can_fire = relation[t](from) != bddfalse;
    if (can_fire) firable.insert(t);
  }
  return firable;
}
std::pair<bdd,set<int>> PetriNetSOG::OneStepBackUnobs(const bdd &from,
                                  const Aggregate *aggr) const {
    bdd res = bdd_false();
    std::pair<bdd,set<int>> C;
    //std::cout<<"One Step Back"<<std::endl;
    set<int> fireBack;
    for (const auto t : non_observables) {
        bdd pred = relation[t][from]-from;
        if((pred & aggr->bdd_state) != bdd_false()){
            res = res | pred;
            fireBack.insert(t);
            //std::cout<<transitions[t].name<<std::endl;
        }

    }
    C.first=res;
    C.second=fireBack;
    return C;
}
std::vector<int> PetriNetSOG::ComputeWeightObsPaths(std::vector<std::vector<string>> obsPaths, const SOG &g)
{
    //std::cout<<"ICI compute weight obs paths"<<endl;
    //std::cout<<obsPaths.size()<<" paths received"<<std::endl;
    std::vector<int> w;
    for(int i=0;i<obsPaths.size();i++)
        w.push_back(0);
    int cur=0;

    //std::cout<<"ICI compute weight obs paths"<<endl;

    for(std::vector<std::string> op: obsPaths)
    {
        //std::cout<<"Obs Ath"<<std::endl;
        Aggregate *a=g.initial_state;
        Aggregate *sa;
        w[cur]+=ComputeSingleInitAggregate(g.initial_state, transitions_names[op[0]]);
        for(int i=0;i<op.size()-1;i++)
        {
            //std::cout<<op[i]<<" ---> "<<op[i+1]<<std::endl;

            sa=a->GetsuccessorsOfTrans(transitions_names[op[i]]);
            w[cur]+=ComputeSingleWeight(sa,transitions_names[op[i]],transitions_names[op[i+1]]);
            a=sa;
        }
        w[cur]+=op.size();
        cur++;
    }
    return w;

}
int PetriNetSOG::ComputeSingleInitAggregate(const Aggregate *aggr, const int out)
{
    bdd source = SearchExitPoints(aggr->bdd_state, out);
    const bdd dest =m0;
    int counter = 0;
    while ((source & dest) == bdd_false()) {
        std::pair<bdd,std::set<int>> step=OneStepBackUnobs(source, aggr);
        source = step.first;
        counter++;
    }
    return counter;
}
int PetriNetSOG::ComputeSingleWeight(const Aggregate *aggr, const int in,
                                     const int out) const {
  bdd source = SearchExitPoints(aggr->bdd_state, out);
  const bdd dest = relation[in]((aggr->GetPredecessorsOfTrans(in))->bdd_state);

  // backtracking from t_target until t_source
  int counter = 0;
  while ((source & dest) == bdd_false()) {
    std::pair<bdd,std::set<int>> step=OneStepBackUnobs(source, aggr);
    source = step.first;
    counter++;
  }

  return counter;
}

/**
 * Compute the weight row of aggregate `a` for the incoming transition `in`.
 * The row holds one ComputeSingleWeight value per column (outgoing
 * transition), in column order, and is then stored through AddWeightRow.
 * @param in incoming observable transition
 * @param a  aggregate whose row is computed
 */
void PetriNetSOG::ComputeRowWeight(const int in, Aggregate *a) const {
  std::vector<int> row(a->columns.size());
  std::vector<int>::size_type col = 0;
  for (const int out : a->columns) {
    row[col++] = ComputeSingleWeight(a, in, out);
  }
  a->AddWeightRow(in, row);
}

/**
 * Build the SOG by depth-first exploration of aggregates.
 *
 * Each stack entry holds an aggregate and the observable transitions still to
 * be explored from it. For each explored transition t, the successor
 * aggregate (unobservable closure of the image by t) is either added to the
 * SOG or matched with an existing state. In both cases the arc is recorded
 * and, when the aggregate has weight columns, the row for t is computed.
 *
 * @param sog graph filled in place (initial state, states and arcs)
 */
void PetriNetSOG::GenerateSOG(SOG &sog) const {
  Stack st;

  // first aggregate: unobservable closure of the initial marking
  auto *initial_aggr = new Aggregate;
  const bdd initial_closure = AccessibleEpsilon(m0);
  initial_aggr->bdd_state = initial_closure;

  sog.set_initial_state(initial_aggr);
  sog.AddState(initial_aggr);

  // observable transitions to explore from it, also used as weight columns
  Set firable_trans = FirableObservableTrans(initial_closure);
  initial_aggr->InitWeightColumns(firable_trans);
  st.emplace(initial_aggr, firable_trans);

  while (!st.empty()) {
    StackElt top = st.top();
    st.pop();

    // nothing left to explore from this aggregate: drop it
    if (top.firable_trans.empty()) {
      continue;
    }

    const int t = *top.firable_trans.begin();

    // put the aggregate back without t; its other transitions are handled
    // once the subtree reached through t has been explored
    top.firable_trans.erase(t);
    st.push(top);

    auto *succ_aggr = new Aggregate;
    const bdd succ_closure =
        AccessibleEpsilon(GetSuccessor(top.aggregate->bdd_state, t));
    succ_aggr->bdd_state = succ_closure;

    Aggregate *existing = sog.FindState(succ_aggr);
    if (existing != nullptr) {
      // already in the SOG: link to it, weigh the incoming transition and
      // discard the duplicate
      sog.AddArc(top.aggregate, existing, t);
      ComputeRowWeight(t, existing);
      delete succ_aggr;
      continue;
    }

    // brand-new aggregate
    sog.AddState(succ_aggr);
    sog.AddArc(top.aggregate, succ_aggr, t);

    firable_trans = FirableObservableTrans(succ_closure);
    if (!firable_trans.empty()) {
      st.emplace(succ_aggr, firable_trans);
      succ_aggr->InitWeightColumns(firable_trans);
      ComputeRowWeight(t, succ_aggr);
    }
  }
}

/**
 * Explore the SOG depth-first and collect observable paths (sequences of
 * observable transitions), building the visited part of the SOG in `sog`.
 *
 * trans_obs maps each observable transition to a counter. The counter is
 * decremented each time the transition is picked as an uncovered one; when it
 * reaches 0 the transition becomes covered. As soon as every observable
 * transition is covered, the current trace is added and the paths collected
 * so far are returned.
 *
 * @param sog       graph filled in place
 * @param trans_obs per-transition counters (a copy; the caller's map is not
 *                  modified)
 * @return the set of recorded traces
 */
Paths PetriNetSOG::ObservablePaths(SOG &sog, map<int, int> trans_obs) const {
  Paths observable_paths;
  Path current_trace;
  Set covered_trans;
  Set firable_trans;
  Stack st;

  // construction of the first aggregate: unobservable closure of m0
  auto *c = new Aggregate;
  {
    const bdd complete_aggr = AccessibleEpsilon(m0);
    c->bdd_state = complete_aggr;

    firable_trans = FirableObservableTrans(complete_aggr);
    st.emplace(c, firable_trans);
  }

  sog.set_initial_state(c);
  sog.AddState(c);

  // `old` is cleared when an uncovered transition is appended to the trace.
  // It is set again after backtracking (an existing aggregate was reached, or
  // an aggregate was fully explored). A trace is recorded only while it is
  // clear.
  // NOTE(review): presumably this avoids recording traces that add no newly
  // chosen uncovered transition -- confirm intent.
  bool old = true;
  while (!st.empty()) {
    StackElt elt = st.top();
    st.pop();

    // if there are firable transitions
    if (!elt.firable_trans.empty()) {
      // choose a transition from the firable transitions
      int t = SelectFirableTrans(covered_trans, elt.firable_trans);
      if (t != -1) {
        old = false;
        trans_obs[t]--;
        current_trace.push_back(t);
        if (trans_obs[t] == 0) {
          covered_trans.insert(t);
          // check if all observable transitions are covered and return the
          // set of generated paths
          if (covered_trans.size() == observables.size()) {
            observable_paths.insert(current_trace);
            return observable_paths;
          }
        }
      } else {
        // case: when there is no firable transition that is not covered,
        // take the first firable transition
        t = *elt.firable_trans.begin();
        current_trace.push_back(t);
      }

      // remove the handled transition and put the same aggregate in the
      // stack again
      elt.firable_trans.erase(t);
      st.push(elt);

      // computes the successor
      {
        auto *reached_aggr = new Aggregate;
        bdd complete_aggr =
            AccessibleEpsilon(GetSuccessor(elt.aggregate->bdd_state, t));
        reached_aggr->bdd_state = complete_aggr;
        Aggregate *pos = sog.FindState(reached_aggr);

        // if aggregate does not exist in the sog
        if (!pos) {
          firable_trans = FirableObservableTrans(complete_aggr);

          // add the aggregate and its predecessors to the graph
          sog.AddState(reached_aggr);
          sog.AddArc(elt.aggregate, reached_aggr, t);

          // if the aggregate has no firable transitions
          if (firable_trans.empty()) {
            if (!old) {
              observable_paths.insert(current_trace);
            }
            ReinitCycle(current_trace, trans_obs);

            current_trace.pop_back();
          } else {
            st.emplace(reached_aggr, firable_trans);
          }
        } else {  // aggregate already exists
          if (!old) {
            observable_paths.insert(current_trace);
          }
          ReinitCycle(current_trace, trans_obs);

          current_trace.pop_back();
          sog.AddArc(elt.aggregate, pos, t);
          delete reached_aggr;
          old = true;
        }
      }
    } else {
      // Aggregate fully explored: drop the transition that led to it. The
      // initial aggregate is reached through no transition, so the trace is
      // empty when it is popped, and pop_back() on an empty container is UB.
      if (!current_trace.empty()) {
        current_trace.pop_back();
      }
      old = true;
    }
  }

  return observable_paths;
}

/**
 * Entry points of the aggregates visited along `path`.
 *
 * Walks `path` in the SOG from the initial aggregate. The bottom of the
 * returned stack is {initial aggregate, m0}. For each transition t of the path
 * except the last one, the entry {aggregate reached by t, image by t of the
 * previous aggregate} is pushed. The top therefore holds the aggregate in
 * which the last transition of `path` fires.
 *
 * @param path sequence of observable transitions
 * @param sog  SOG to follow
 * @return stack of {aggregate, entry states}; only the initial entry for an
 *         empty path
 */
stack<AggrPair> PetriNetSOG::SearchEntryPoints(Path path,
                                               const SOG &sog) const {
  stack<AggrPair> entry_points;

  AggrPair point;
  point.first = sog.initial_state;
  point.second = m0;
  entry_points.push(point);

  // path.end() - 1 below would be formed before begin() on an empty path (UB)
  if (path.empty()) {
    return entry_points;
  }

  Aggregate *aggr = sog.initial_state;
  for (auto k = path.begin(); k != path.end() - 1; ++k) {
    const int t = *k;
    // entry states of the next aggregate: image by t of the current one
    point.second = relation[t](point.first->bdd_state);

    // follow the arc labelled t (the aggregate stays the same if none exists)
    for (const auto &succ : aggr->successors) {
      if (succ.transition == t) {
        aggr = succ.state;
        break;
      }
    }

    point.first = aggr;
    entry_points.push(point);
  }

  return entry_points;
}

/**
 * Exit points of `from` for transition `t`: the states of `from` that satisfy
 * t's enabling condition (precond).
 */
bdd PetriNetSOG::SearchExitPoints(const bdd &from, const int t) const {
  const bdd &enabling = relation[t].precond;
  return from & enabling;
}

/**
 * Concretize an observable path into a path of the model that also contains
 * the unobservable transitions.
 *
 * The aggregates along `path` are processed from last to first. In each one,
 * SubPathAggregate reconstructs a sequence of unobservable transitions from
 * the aggregate's entry states to the current source states. That sequence is
 * prepended to the result, preceded in the loop order by the observable
 * transition fired from it.
 *
 * @param path sequence of observable transitions
 * @param sog  SOG the path belongs to
 * @return the concretized path; empty for an empty `path`
 */
Path PetriNetSOG::AbstractPath(Path path, const SOG &sog) const {
  Path abstract_paths;

  // nothing to concretize; the loop below would read *(path.end() - 1)
  if (path.empty()) {
    return abstract_paths;
  }

  bdd source;
  stack<AggrPair> entry_points = SearchEntryPoints(path, sog);

  bool flag = false;
  while (!entry_points.empty()) {
    int trans = *(path.end() - 1);
    const AggrPair entry_aggr = entry_points.top();
    entry_points.pop();

    const Aggregate *aggr = entry_aggr.first;
    const bdd target = entry_aggr.second;

    // last aggregate: states enabling the final transition; earlier ones:
    // pre-image by `trans` of the source reached in the following aggregate
    source = flag ? relation[trans][source]
                  : SearchExitPoints(aggr->bdd_state, trans);

    Path path_aggregate = SubPathAggregate(&source, target, aggr);
    abstract_paths.insert(abstract_paths.begin(), trans);
    abstract_paths.insert(abstract_paths.begin(), path_aggregate.begin(),
                          path_aggregate.end());
    path.pop_back();
    flag = true;
  }

  return abstract_paths;
}

/**
 * Backward search inside `aggr`, from `*source` towards `target`, over
 * unobservable transitions. StepBackward picks the first usable transition
 * at each step.
 *
 * @param source in: states to start from; out: the set of states the search
 *               stopped at
 * @param target states to reach (the aggregate's entry states)
 * @param aggr   aggregate the search is confined to
 * @return the transitions found, in forward order. The search stops early,
 *         and the returned path does not reach `target`, in two cases: no
 *         predecessor exists, or a set of states is revisited. StepBackward
 *         depends only on its input, so a revisit means it would cycle
 *         forever; both cases used to hang.
 */
Path PetriNetSOG::SubPathAggregate(bdd *source, const bdd &target,
                                   const Aggregate *aggr) const {
  Path path;

  bdd current_state = *source;
  std::vector<bdd> seen;
  while ((target & current_state) == bdd_false()) {
    seen.push_back(current_state);
    const pair<int, bdd> step = StepBackward(current_state, aggr);

    // dead end (no predecessor in the aggregate) or cycle: stop here
    if (step.second == bdd_false() ||
        std::find(seen.begin(), seen.end(), step.second) != seen.end()) {
      break;
    }

    path.insert(path.begin(), step.first);
    current_state = step.second;
  }

  *source = current_state;
  return path;
}