package savilerow;
/*

    Savile Row http://savilerow.cs.st-andrews.ac.uk/
    Copyright (C) 2014-2020 Peter Nightingale
    
    This file is part of Savile Row.
    
    Savile Row is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    
    Savile Row is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
    
    You should have received a copy of the GNU General Public License
    along with Savile Row.  If not, see <http://www.gnu.org/licenses/>.

*/

import java.util.*;

//  When the optimisation function is a sum, take the terms and tabulate each one with dominance.

public class TransformTabulateObj extends TreeTransformerBottomUpNoWrapper
{
    /** When true, print diagnostic information about the tabulation process. */
    public static boolean verbose=false;
    
    /**
     * Locates the constraint defining the objective variable and, if it is a
     * weighted sum of a supported shape, records the terms of the sum and the
     * sign of each term's weight. Supported shapes are:
     *   sum = optVar (ToVariable), for minimisation or maximisation;
     *   sum <= optVar, when minimising;
     *   optVar <= sum, when maximising.
     * Otherwise optTerms is left empty and the transformation does nothing.
     */
    public TransformTabulateObj(Model _m) {
        super(_m);
        min=(m.objective instanceof Minimising);
        optVar=m.objective.getChild(0);
        findOptCt(m.constraints);
        if(verbose) {
            System.out.println(optCt);
        }
        
        ASTNode sum=findObjectiveSum();
        if(sum!=null) {
            optTerms=sum.getChildren();
            optTermsPolarity=new boolean[optTerms.size()];
            for(int i=0; i<optTerms.size(); i++) {
                // Weights are read from the same node the terms came from.
                optTermsPolarity[i]=((WeightedSum)sum).getWeight(i)>0;
            }
        }
        
        optTermCt=new ArrayList<>();
        for(int i=0; i<optTerms.size(); i++) {
            optTermCt.add(new ArrayList<ASTNode>());
        }
    }
    
    // The objective variable (child 0 of the objective node).
    ASTNode optVar;
    
    // Nearest relation enclosing the first occurrence of optVar found by a
    // depth-first search of the constraints; null if optVar does not occur.
    ASTNode optCt=null;
    
    // Terms of the objective sum; empty when the objective is not a supported weighted sum.
    ArrayList<ASTNode> optTerms=new ArrayList<>();
    // optTermsPolarity[i] is true iff the weight of optTerms.get(i) is positive.
    boolean[] optTermsPolarity=new boolean[0];
    
    ArrayList<ArrayList<ASTNode>> optTermCt;  // Distinct constraints (other than optCt) containing each term of the optimisation.
    
    // Number of occurrences of each identifier seen by processNode.
    HashMap<ASTNode, Integer> idCount=new HashMap<>();
    
    private boolean min;
    
    /**
     * Returns the WeightedSum side of optCt, or null when optCt does not have
     * one of the shapes supported by this transformation.
     */
    private ASTNode findObjectiveSum() {
        // Check the type first: other relations may have fewer than two children.
        if(!(optCt instanceof ToVariable) && !(optCt instanceof LessEqual)) {
            return null;
        }
        ASTNode left=optCt.getChild(0);
        ASTNode right=optCt.getChild(1);
        
        // sum = optVar (any direction), or sum <= optVar when minimising: the sum is the left child.
        if(right.equals(optVar) && left instanceof WeightedSum && (optCt instanceof ToVariable || min)) {
            return left;
        }
        // optVar <= sum when maximising: the sum is the right child.
        if(optCt instanceof LessEqual && left.equals(optVar) && right instanceof WeightedSum && !min) {
            return right;
        }
        return null;
    }
    
    // Walks up from a node to its nearest ancestor that is a relation.
    private static ASTNode enclosingRelation(ASTNode a) {
        ASTNode p=a.getParent();
        while(! p.isRelation()) {
            p=p.getParent();
        }
        return p;
    }
    
    // True iff the list contains this exact object (reference equality, not structural equality).
    private static boolean containsIdentical(ArrayList<ASTNode> list, ASTNode node) {
        for(ASTNode n : list) {
            if(n==node) {
                return true;
            }
        }
        return false;
    }
    
    /**
     * Depth-first search for the first occurrence of optVar.
     * Stores the nearest enclosing relation of that occurrence in optCt.
     */
    void findOptCt(ASTNode a) {
        if(a.equals(optVar)) {
            optCt=enclosingRelation(a);
            return;
        }
        
        for(int i=0; i<a.numChildren() && optCt==null; i++) {
            findOptCt(a.getChild(i));
        }
    }
    
    /**
     * Collects information during the bottom-up pass; it never replaces nodes.
     * For each identifier it:
     *   - records the nearest enclosing relation of every occurrence of an
     *     objective term (excluding optCt itself, and recording each relation
     *     object at most once per term);
     *   - counts how often the identifier occurs.
     */
    protected NodeReplacement processNode(ASTNode curnode)
    {
        if(curnode instanceof Identifier) {
            for(int i=0; i<optTerms.size(); i++) {
                if(curnode.equals(optTerms.get(i))) {
                    ASTNode p=enclosingRelation(curnode);
                    ArrayList<ASTNode> cts=optTermCt.get(i);
                    // A term occurring several times in one constraint must not list it twice;
                    // otherwise doTabulation would overwrite the new table with 'true'.
                    if(p!=optCt && !containsIdentical(cts, p)) {
                        cts.add(p);
                    }
                }
            }
            
            Integer count=idCount.get(curnode);
            idCount.put(curnode, count==null ? 1 : count+1);
        }
        return null;
    }
    
    /**
     * For each objective term, builds a table from the conjunction of the
     * constraints containing it. The table keeps only the best value of the
     * term (smallest or largest, depending on the direction of optimisation
     * and the sign of the term's weight) for each assignment to the other
     * variables.
     *
     * The first such constraint is replaced by the table and the others by
     * 'true'. Terms are skipped when:
     *   - no constraints were collected for them;
     *   - they are not a variable of the conjunction;
     *   - they share a constraint with an already-tabulated term;
     *   - the search hits a limit.
     */
    public void doTabulation() {
        if(verbose) {
            System.out.println(optTerms);
            System.out.println(optTermCt);
        }
        
        // Constraint objects already replaced by an earlier table. Replacing one of them
        // again would overwrite that table and lose the other constraints it encodes.
        Set<ASTNode> replaced=Collections.newSetFromMap(new IdentityHashMap<ASTNode, Boolean>());
        
        for(int i=0; i<optTerms.size(); i++) {
            ArrayList<ASTNode> toreplace=optTermCt.get(i);
            if(toreplace.isEmpty() || sharesAny(toreplace, replaced)) {
                continue;
            }
            
            //  Make a conjunction of the constraints containing the variable
            //  This will copy the constraints.
            ASTNode con=new And(toreplace);
            
            ArrayList<ASTNode> vars=getVariablesOrdered(con);
            
            // Dominance is applied to the last variable, so the term must be there.
            int pos=vars.indexOf(optTerms.get(i));
            if(pos<0) {
                continue;   // Term is not a decision variable of con; dominance cannot be applied.
            }
            Collections.swap(vars, pos, vars.size()-1);
            
            ArrayList<ASTNode> domains = getDomains(vars);
            
            shortsups=new ArrayList<ArrayList<Long>>();
            failcount=0;
            
            // Remove one-occurrence variables. Don't remove the objective term (last position).
            boolean unroll=false;
            for(int j=0; j<vars.size()-1; j++) {
                Integer count=idCount.get(vars.get(j));
                if(count!=null && count==1) {   //  Should be idCount == number of occs in con. 
                    // NOTE(review): the quantifier shadows the decision variable of the same name -- confirm this works correctly.
                    con=new ExistsExpression(vars.get(j), domains.get(j), con);
                    
                    vars.remove(j);
                    domains.remove(j);
                    j--;
                    
                    unroll=true;
                }
            }
            if(unroll) {
                TransformQuantifiedExpression tqe=new TransformQuantifiedExpression(m);
                con=tqe.transform(con);
            }
            
            if(verbose) {
                System.out.println("About to tabulate:"+con);
            }
            
            // Keep the least value of the term when it pushes the objective the same way as the optimisation.
            boolean keepLeast=(optTermsPolarity[i]==min);
            boolean flag=DFSfull(vars, domains, con, new ArrayList<Long>(), 100000, 100000, keepLeast);
            
            if(!flag) continue;
            
            // rest copy-pasted from TransformMakeTable
            ArrayList<ASTNode> shortsups2=new ArrayList<ASTNode>();
            
            for(ArrayList<Long> shortsupold : shortsups) {
                ArrayList<ASTNode> shortsupnew=new ArrayList<ASTNode>(vars.size());
                for(Long value : shortsupold) {
                    shortsupnew.add(NumberConstant.make(value));
                }
                shortsups2.add(CompoundMatrix.make(shortsupnew));
            }
            
            ASTNode tab=CompoundMatrix.make(shortsups2);
            
            tab=m.cmstore.newConstantMatrixDedup(tab);
            
            ASTNode replcons=new Table(m, CompoundMatrix.make(vars), tab);
            
            toreplace.get(0).getParent().setChild(toreplace.get(0).getChildNo(), replcons);
            for(int j=1; j<toreplace.size(); j++) {
                toreplace.get(j).getParent().setChild(toreplace.get(j).getChildNo(), new BooleanConstant(true));
            }
            replaced.addAll(toreplace);
        }
    }
    
    // True iff any constraint in cts (by reference) is in the replaced set.
    private static boolean sharesAny(ArrayList<ASTNode> cts, Set<ASTNode> replaced) {
        for(ASTNode c : cts) {
            if(replaced.contains(c)) {
                return true;
            }
        }
        return false;
    }
    
    
    
    ////////////////////////////////////////////////////////////////////////////
    //
    //    Convert to full-length table -- dominance on last variable.
    
    ArrayList<ArrayList<Long>> shortsups;
    long failcount;
    
    /**
     * Enumerates assignments to varlist (in order) that satisfy exp, appending
     * each complete satisfying assignment to shortsups. For the last variable,
     * only the first satisfying value is kept: the least if minimising,
     * otherwise the greatest. All later values are dominated.
     *
     * @return false iff the search was aborted: the support limit or fail
     *         limit was exceeded, or exp did not simplify to a constant once
     *         all variables were assigned.
     */
    private boolean DFSfull(ArrayList<ASTNode> varlist, ArrayList<ASTNode> vardoms, ASTNode exp, ArrayList<Long> assignment, long suplimit, long faillimit, boolean minimising) {
        int depth=assignment.size();
        
        if(exp instanceof BooleanConstant) {
            if(depth==varlist.size() && exp.getValue()==1) {
                // Copy the current assignment into shortsups.
                shortsups.add(new ArrayList<Long>(assignment));
                if(verbose && shortsups.size()>suplimit) {
                    System.out.println("DFS hit supports limit");
                }
                return shortsups.size()<=suplimit; //  Continue search iff within sup limit.
            }
            if(exp.getValue()==0) {
                failcount++;
                if(verbose && failcount>faillimit) {
                    System.out.println("DFS hit fail limit");
                }
                return failcount<=faillimit;   // continue search iff failcount<=faillimit.
            }
            // When the expression evaluates to true but assignment is not long enough, continue forwardtracking.
        }
        
        if(depth==varlist.size()) {
            // Every variable is assigned but simplification did not yield a constant.
            // Abort this tabulation rather than indexing past the end of varlist.
            if(verbose) {
                System.out.println("DFS could not evaluate expression: "+exp);
            }
            return false;
        }
        
        //  Iterate through the domain of the current variable assigning each value in turn.
        ASTNode curvar=varlist.get(depth);
        ArrayList<Intpair> vals=vardoms.get(depth).getIntervalSet();
        TransformSimplify ts=new TransformSimplify();
        
        if(depth<varlist.size()-1) {
            //  Not the last variable: try every value.
            for(Intpair interval : vals) {
                for(long val=interval.lower; val<=interval.upper; val++) {
                    if(!assignAndRecurse(varlist, vardoms, exp, curvar, val, assignment, suplimit, faillimit, minimising, ts)) {
                        return false;
                    }
                }
            }
            return true;
        }
        
        //  Last variable: stop at the first value that yields a support. All remaining values are dominated by it.
        int suplistlength=shortsups.size();  // Store length of support list
        if(minimising) {
            // Take least value.
            for(int i=0; i<vals.size(); i++) {
                for(long val=vals.get(i).lower; val<=vals.get(i).upper; val++) {
                    if(!assignAndRecurse(varlist, vardoms, exp, curvar, val, assignment, suplimit, faillimit, minimising, ts)) {
                        return false;
                    }
                    if(shortsups.size()>suplistlength) {
                        return true;
                    }
                }
            }
        }
        else {
            // Take greatest value.
            for(int i=vals.size()-1; i>=0; i--) {
                for(long val=vals.get(i).upper; val>=vals.get(i).lower; val--) {
                    if(!assignAndRecurse(varlist, vardoms, exp, curvar, val, assignment, suplimit, faillimit, minimising, ts)) {
                        return false;
                    }
                    if(shortsups.size()>suplistlength) {
                        return true;
                    }
                }
            }
        }
        return true;
    }
    
    /**
     * Assigns val to curvar in a copy of exp, simplifies it, and continues the
     * search one level deeper. The assignment list is restored before returning.
     *
     * @return false iff the search should stop.
     */
    private boolean assignAndRecurse(ArrayList<ASTNode> varlist, ArrayList<ASTNode> vardoms, ASTNode exp, ASTNode curvar, long val,
            ArrayList<Long> assignment, long suplimit, long faillimit, boolean minimising, TransformSimplify ts) {
        ASTNode local_exp=TabulationUtils.assignValue(exp.copy(), curvar, val);
        local_exp=ts.transform(local_exp);  // make the assignment and simplify.
        
        assignment.add(val);
        boolean flag=DFSfull(varlist, vardoms, local_exp, assignment, suplimit, faillimit, minimising);
        assignment.remove(assignment.size()-1);   //  delete this assignment.
        return flag;
    }
    
    
    ///  C&P  from TransformMakeTable.
    
    /**
     * Returns the distinct decision-variable identifiers in exp in depth-first
     * order of first occurrence. Identifiers whose category is Constant or
     * lower (e.g. references to constant matrices) are skipped.
     * Assumes only decision variables and references to the constant matrices remain.
     */
    public ArrayList<ASTNode> getVariablesOrdered(ASTNode exp) {
        HashSet<ASTNode> tmp=new HashSet<ASTNode>();
        ArrayList<ASTNode> vars_ordered=new ArrayList<ASTNode>();
        getVariablesOrderedInner(exp, tmp, vars_ordered);
        return vars_ordered;
    }
    
    private void getVariablesOrderedInner(ASTNode exp, HashSet<ASTNode> varset, ArrayList<ASTNode> varlist) {
        if(exp instanceof Identifier && exp.getCategory()>ASTNode.Constant) {
            // Collect all identifiers except those that refer to a constant matrix.
            if(! varset.contains(exp)) {
                varset.add(exp);
                varlist.add(exp);
            }
        }
        else {
            for(int i=0; i<exp.numChildren(); i++) {
                getVariablesOrderedInner(exp.getChild(i), varset, varlist);
            }
        }
    }
    
    /**
     * Looks up each variable's domain in the global symbol table (by name) and
     * simplifies it. Returns the domains in the same order as varlist.
     */
    public ArrayList<ASTNode> getDomains(ArrayList<ASTNode> varlist) {
        ArrayList<ASTNode> vardoms=new ArrayList<ASTNode>();
        TransformSimplify ts=new TransformSimplify();
        for(int i=0; i<varlist.size(); i++) {
            vardoms.add(ts.transform(m.global_symbols.getDomain(varlist.get(i).toString())));
        }
        return vardoms;
    }
}

