package savilerow.treetransformer;
/*

    Savile Row http://savilerow.cs.st-andrews.ac.uk/
    Copyright (C) 2014-2018 Peter Nightingale
    
    This file is part of Savile Row.
    
    Savile Row is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    
    Savile Row is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
    
    You should have received a copy of the GNU General Public License
    along with Savile Row.  If not, see <http://www.gnu.org/licenses/>.

*/


import savilerow.expression.*;
import savilerow.model.*;

import java.util.ArrayList;
import java.util.HashMap;

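//   Catch cases where the AMO-PB encoding (pseudo-Boolean constraint with
//   at-most-one groups) can be used. At the moment this is very basic: only a
//   WeightedSum compared with a constant using < or <= (on either side) is handled.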

public class TransformSumToAMOPB extends TreeTransformerBottomUpNoWrapper
{
    public TransformSumToAMOPB() { super(null); }
    
    /**
     * Rewrites an inequality between a WeightedSum and a constant
     * ({@code sum < k}, {@code sum <= k}, {@code k < sum}, {@code k <= sum})
     * into an AMOPB constraint.
     *
     * <p>The inequality is first normalised to {@code sum_i w_i*x_i <= cmp}
     * with every weight non-negative and every term having a non-negative
     * lower bound. Each variable-like term (Identifier, ShiftMapper or
     * MultiplyMapper) then contributes one group of
     * (coefficient, {@code x_i = k}) pairs, measured relative to its smallest
     * domain value. Constant terms are moved to the right-hand side.
     *
     * @param curnode the node currently being visited
     * @return a replacement AMOPB node, or {@code null} if {@code curnode} does
     *         not match one of the supported shapes or contains a term this
     *         transformation cannot handle
     */
    protected NodeReplacement processNode(ASTNode curnode)
    {
        ArrayList<ASTNode> ch=null;
        ArrayList<Long> wts=null;
        long cmp=0;
        
        //  Catch all cases of a sum in a binop (except equality, decomposed earlier).
        //  ch and wts are always fresh copies: they are modified below, and the
        //  original node must be left untouched if we later return null.
        if(curnode instanceof Less && curnode.getChild(0) instanceof WeightedSum && curnode.getChild(1).isConstant()) {
            // sum < k  becomes  sum <= k-1
            ch=new ArrayList<ASTNode>(curnode.getChild(0).getChildren());
            wts=new ArrayList<Long>(((WeightedSum)curnode.getChild(0)).getWeights());
            cmp=curnode.getChild(1).getValue()-1;
        }
        else if(curnode instanceof LessEqual && curnode.getChild(0) instanceof WeightedSum && curnode.getChild(1).isConstant()) {
            // sum <= k  is already in the required form
            ch=new ArrayList<ASTNode>(curnode.getChild(0).getChildren());
            wts=new ArrayList<Long>(((WeightedSum)curnode.getChild(0)).getWeights());
            cmp=curnode.getChild(1).getValue();
        }
        else if(curnode instanceof Less && curnode.getChild(1) instanceof WeightedSum && curnode.getChild(0).isConstant()) {
            // k < sum  becomes  -sum < -k  becomes  -sum <= -k-1
            ch=new ArrayList<ASTNode>(curnode.getChild(1).getChildren());
            wts=negated(((WeightedSum)curnode.getChild(1)).getWeights());
            cmp=-curnode.getChild(0).getValue()-1;
        }
        else if(curnode instanceof LessEqual && curnode.getChild(1) instanceof WeightedSum && curnode.getChild(0).isConstant()) {
            // k <= sum  becomes  -sum <= -k
            ch=new ArrayList<ASTNode>(curnode.getChild(1).getChildren());
            wts=negated(((WeightedSum)curnode.getChild(1)).getWeights());
            cmp=-curnode.getChild(0).getValue();
        }
        
        if(ch==null) {
            return null;
        }
        
        //  Flip any negative weights:  w*x  with w<0  becomes  (-w)*(x*-1).
        for(int i=0; i<ch.size(); i++) {
            if(wts.get(i)<0) {
                wts.set(i, -wts.get(i));
                ch.set(i, new MultiplyMapper(ch.get(i), NumberConstant.make(-1)));
            }
        }
        
        //  Shift any terms that go below 0.  w*(x+s) = w*x + w*s, so the
        //  right-hand side grows by w*s.
        for(int i=0; i<ch.size(); i++) {
            Intpair bnds=ch.get(i).getBounds();
            if(bnds.lower<0) {
                ch.set(i, new ShiftMapper(ch.get(i), NumberConstant.make(-bnds.lower)));
                cmp=cmp+(-bnds.lower)*wts.get(i);
            }
        }
        
        //  One [coefficients, booleans] pair per variable-like term.
        ArrayList<ASTNode> amoproducts=new ArrayList<ASTNode>();
        
        for(int i=0; i<ch.size(); i++) {
            ASTNode term=ch.get(i);
            long coeff=wts.get(i);
            
            if(term instanceof Identifier || term instanceof ShiftMapper || term instanceof MultiplyMapper) {
                //  Case one. The sum directly contains a decision variable (possibly mapped).
                //  Every domain value except the smallest gets a (coefficient, term=k) pair;
                //  the smallest value's contribution moves to the right-hand side.
                ArrayList<Intpair> dom=term.getIntervalSetExp();
                long smallestval=dom.get(0).lower;
                
                ArrayList<ASTNode> coeffs_onevar=new ArrayList<ASTNode>();
                ArrayList<ASTNode> bools_onevar=new ArrayList<ASTNode>();
                
                for(int j=0; j<dom.size(); j++) {
                    for(long k=dom.get(j).lower; k<=dom.get(j).upper; k++) {
                        if(k!=smallestval) {
                            coeffs_onevar.add(NumberConstant.make((coeff*k)-(coeff*smallestval)));
                            bools_onevar.add(new Equals(term, NumberConstant.make(k)));
                        }
                    }
                }
                
                cmp -= smallestval*coeff;
                
                amoproducts.add(CompoundMatrix.make(CompoundMatrix.make(coeffs_onevar), CompoundMatrix.make(bools_onevar)));
            }
            else if(term.isConstant()) {
                cmp -= term.getValue()*coeff;  //  Move to the other side of the <=.
            }
            else {
                //  Unsupported term; leave the original node unchanged.
                return null;
            }
        }
        
        return new NodeReplacement(new AMOPB(CompoundMatrix.make(amoproducts), NumberConstant.make(cmp)));
    }
    
    /**
     * Returns a new list holding the negation of each weight. The argument is
     * not modified.
     */
    private static ArrayList<Long> negated(ArrayList<Long> weights)
    {
        ArrayList<Long> out=new ArrayList<Long>(weights.size());
        for(long w : weights) {
            out.add(-w);
        }
        return out;
    }
}

