package savilerow.treetransformer;
/*

    Savile Row http://savilerow.cs.st-andrews.ac.uk/
    Copyright (C) 2014-2024 Peter Nightingale and Luke Ryan
    
    This file is part of Savile Row.
    
    Savile Row is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    
    Savile Row is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
    
    You should have received a copy of the GNU General Public License
    along with Savile Row.  If not, see <http://www.gnu.org/licenses/>.

*/

import java.util.ArrayList;

import savilerow.CmdFlags;
import savilerow.expression.*;
import savilerow.model.Model;

/**
 * Tree transformer that handles {@code Circuit} constraints.
 *
 * <p>If the selected backend has a native circuit constraint (and {@code propagate} is false),
 * the constraint is kept. Its matrix is re-indexed from 0, with values shifted accordingly, when
 * the index domain has a negative lower bound. Otherwise the constraint is replaced by a
 * decomposition built from AllDifferent, element and implication constraints over auxiliary
 * "order" variables.
 */
public class TransformDecomposeCircuit extends TreeTransformerBottomUpNoWrapper
{
    // When true, circuit is always decomposed, regardless of the chosen backend.
    private final boolean propagate;
    
    public TransformDecomposeCircuit(Model _m, boolean _propagate) { 
        super(_m);
        propagate = _propagate;
    }
    
    /**
     * Decomposes a Circuit node, or prepares it for a backend with a native circuit constraint.
     *
     * @param curnode the node currently being visited
     * @return a replacement holding the decomposition when the circuit is decomposed; null
     *         otherwise (including when the Circuit's matrix is rewritten in place for a
     *         native backend)
     */
    protected NodeReplacement processNode(ASTNode curnode)
    {
        if(curnode instanceof Circuit) {
            //  Does the backend solver have a circuit constraint?
            boolean backend=(CmdFlags.getOrtoolstrans() || CmdFlags.getChuffedtrans() || CmdFlags.getMinizinctrans() || CmdFlags.getGecodetrans() || CmdFlags.getChocotrans()) && !propagate;
            
            if(!backend) {
                //  Only a non-reified circuit is decomposed. Presumably this is because the
                //  decomposition introduces auxiliary variables.
                if(!curnode.getParent().inTopAnd()) {
                    CmdFlags.errorExit("Reified circuit not supported for the chosen backend solver.");
                }
                return new NodeReplacement(decomp(curnode));
            }
            
            //  Native circuit: if the index domain starts below 0, shift every value by -lb and
            //  re-index the matrix from 0. The Circuit node is modified in place.
            //  NOTE(review): presumably the backends' circuit does not accept negative indices.
            ASTNode mat=curnode.getChild(0);
            long lb=mat.getChild(0).getBounds().lower;
            if(lb<0) {
                ArrayList<ASTNode> x=mat.getChildren(1);
                for(int i=0; i<x.size(); i++) {
                    x.set(i, new WeightedSum(x.get(i), NumberConstant.make(-lb)));
                }
                ASTNode idx=new IntegerDomainConcrete(0, x.size()-1);
                curnode.setChild(0, CompoundMatrix.make(idx, x, false));
            }
        }
        return null;
    }
    
    /**
     * Builds a decomposition of the circuit constraint over the matrix {@code curnode.getChildConst(0)}.
     *
     * <p>For a matrix x with index domain lb..lb+n-1, x[i] is the successor of i. The result is
     * the conjunction of:
     * <ul>
     *   <li>x[i] != i for every i (no element is its own successor);</li>
     *   <li>allDifferent(x);</li>
     *   <li>an encoding that rules out sub-circuits (see the helpers below).</li>
     * </ul>
     *
     * <p>NOTE(review): with a single element (n == 1), the x[i] != i constraint conflicts with the
     * order encoding's "last element returns to lb" constraint. A one-element circuit is therefore
     * always unsatisfiable under this decomposition; confirm this is the intended semantics.
     *
     * @param curnode the Circuit node; its first child is expected to be a CompoundMatrix
     * @return an And of the decomposition constraints
     */
    private ASTNode decomp(ASTNode curnode) {
        ASTNode x = curnode.getChildConst(0);
        
        assert x instanceof CompoundMatrix;
        
        ArrayList<Intpair> idx=x.getChild(0).getIntervalSet();
        
        if(idx.size()>1) {
            CmdFlags.errorExit("Index domain of matrix in circuit constraint must be a single interval.");
        }
        
        //  Without this check, an empty matrix fails below with an IndexOutOfBoundsException
        //  (idx.get(0) / ord.get(0)).
        if(idx.isEmpty() || x.numChildren()<2) {
            CmdFlags.errorExit("Matrix in circuit constraint must not be empty.");
        }
        
        long lb=idx.get(0).lower;
        
        ArrayList<ASTNode> ct=new ArrayList<>();
        
        //  No self-loops: x[i] != i, written as a binary AllDifferent.
        for(int i=1; i<x.numChildren(); i++) {
            ct.add(new AllDifferent(x.getChild(i), NumberConstant.make(i-1+lb)));
        }
        
        //  Every element has a distinct successor.
        ct.add(new AllDifferent(x));
        
        //  Choose between the two encodings. The order-variable encoding is the one in use.
        boolean explSeq=false;
        if(explSeq) {
            addExplicitSequence(x, lb, ct);
        }
        else {
            addOrderVariables(x, lb, ct);
        }
        
        return new And(ct);
    }
    
    //  Alternative (currently disabled) encoding: aux variables seq[0..n-1] list the elements in
    //  circuit order, starting from lb, with x[seq[k]] = seq[k+1] and x[seq[n-1]] = seq[0].
    private void addExplicitSequence(ASTNode x, long lb, ArrayList<ASTNode> ct) {
        ArrayList<ASTNode> seq=new ArrayList<>();
        for(int i=1; i<x.numChildren(); i++) {
            seq.add(m.global_symbols.newAuxiliaryVariable(x.getChild(0)));  // Take domain from the matrix index.
        }
        
        for(int i=0; i<seq.size()-1; i++) {
            // x[seq[i]]=seq[i+1]
            ct.add(new ToVariable(new SafeElementOne(x, new WeightedSum(seq.get(i), NumberConstant.make(-lb+1))), seq.get(i+1)));
        }
        // Close the circuit: x[seq[n-1]]=seq[0]
        ct.add(new ToVariable(new SafeElementOne(x, new WeightedSum(seq.get(seq.size()-1), NumberConstant.make(-lb+1))), seq.get(0)));
        
        ct.add(new AllDifferent(CompoundMatrix.make(seq)));
        
        //  Fix the first element.
        ct.add(new Equals(seq.get(0), NumberConstant.make(lb)));
    }
    
    //  Order-variable encoding: ord[i] in 1..n is the position of element i in the circuit.
    //  Credited to Gert Smolka (named 'jump' variables) here: https://www.gecode.org/doc-latest/reference/classKnightsReified.html
    private void addOrderVariables(ASTNode x, long lb, ArrayList<ASTNode> ct) {
        int n=x.numChildren()-1;
        
        ArrayList<ASTNode> ord=new ArrayList<>();
        for(int i=0; i<n; i++) {
            ord.add(m.global_symbols.newAuxiliaryVariable(1, n));
        }
        
        ct.add(new AllDifferent(CompoundMatrix.make(ord)));
        
        // Set the smallest index from x to have order number 1.
        ct.add(new Equals(ord.get(0), NumberConstant.make(1)));
        
        for(int i=0; i<n; i++) {
            // ord[i] != n  ->  ord[x[i]] = ord[i]+1
            // (the index is offset by -lb+1 so that matrix index lb selects the first entry of ord)
            ct.add(new Implies(new AllDifferent(ord.get(i), NumberConstant.make(n)),
                new Equals(new SafeElementOne(CompoundMatrix.make(ord), new WeightedSum(x.getChild(i+1), NumberConstant.make(-lb+1))), 
                    new WeightedSum(ord.get(i), NumberConstant.make(1)))
                ));
            
            // ord[i] = n  ->  x[i] = lb  (the last element in the order returns to the first)
            ct.add(new Implies(new Equals(ord.get(i), NumberConstant.make(n)),
                new Equals(x.getChild(i+1), NumberConstant.make(lb))
                ));
        }
    }
}
