package savilerow.treetransformer;
/*

    Savile Row http://savilerow.cs.st-andrews.ac.uk/
    Copyright (C) 2014-2024 Peter Nightingale and Luke Ryan
    
    This file is part of Savile Row.
    
    Savile Row is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    
    Savile Row is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
    
    You should have received a copy of the GNU General Public License
    along with Savile Row.  If not, see <http://www.gnu.org/licenses/>.

*/

import java.util.ArrayList;

import savilerow.CmdFlags;
import savilerow.expression.*;
import savilerow.model.Model;

/**
 * Replaces each {@code Inverse} constraint with a conjunction of simpler constraints,
 * unless the current output target is left to handle {@code Inverse} itself.
 * <p>
 * Two decompositions are available:
 * <ul>
 *   <li>a pairwise {@code Iff} encoding, used when translating for SAT and not propagating;</li>
 *   <li>an element-based channelling encoding plus an {@code AllDifferent}, used otherwise.</li>
 * </ul>
 */
public class TransformDecomposeInverse extends TreeTransformerBottomUpNoWrapper
{
    // Supplied by the caller. When true, Inverse is always decomposed and the
    // SAT-specific pairwise encoding is never chosen.
    private final boolean propagate;
    
    public TransformDecomposeInverse(Model _m, boolean _propagate) {
        super(_m);
        this.propagate = _propagate;
    }
    
    /**
     * Decomposes {@code curnode} when it is an {@code Inverse} that is not being kept for the backend.
     *
     * @param curnode the node currently visited by the bottom-up traversal
     * @return a replacement holding the decomposition, or {@code null} to leave the node unchanged
     */
    protected NodeReplacement processNode(ASTNode curnode)
    {
        if(!(curnode instanceof Inverse)) {
            return null;
        }
        if(keepsInverseForBackend()) {
            return null;
        }
        return new NodeReplacement(decomp(curnode));
    }
    
    /**
     * Returns true when Inverse should stay intact: one of the OR-Tools, Chuffed, MiniZinc,
     * Gecode or Choco translations is selected and this pass is not propagating.
     */
    private boolean keepsInverseForBackend() {
        boolean nativeTarget = CmdFlags.getOrtoolstrans()
            || CmdFlags.getChuffedtrans()
            || CmdFlags.getMinizinctrans()
            || CmdFlags.getGecodetrans()
            || CmdFlags.getChocotrans();
        return nativeTarget && !propagate;
    }
    
    /**
     * Builds the decomposition of {@code Inverse(first, second)}.
     * <p>
     * Each argument is read as a one-dimensional matrix. Child 0 is its index domain, which is
     * expected to be a single interval, and children 1..n are its elements in index order.
     *
     * @param curnode the Inverse node; child 0 and child 1 are the two matrices
     * @return an {@code And} over the generated constraints
     */
    private ASTNode decomp(ASTNode curnode) {
        ASTNode first = curnode.getChildConst(0);
        ASTNode second = curnode.getChildConst(1);
        
        ArrayList<Intpair> firstIdx = first.getChild(0).getIntervalSet();
        ArrayList<Intpair> secondIdx = second.getChild(0).getIntervalSet();
        
        // Only contiguous index domains are handled.
        assert firstIdx.size()==1;
        assert secondIdx.size()==1;
        
        int firstLo = (int) firstIdx.get(0).lower;
        int firstHi = (int) firstIdx.get(0).upper;
        int secondLo = (int) secondIdx.get(0).lower;
        int secondHi = (int) secondIdx.get(0).upper;
        
        ArrayList<ASTNode> conjuncts = new ArrayList<>();
        
        if(CmdFlags.getSattrans() && !propagate) {
            addPairwiseIffs(conjuncts, first, firstLo, firstHi, second, secondLo, secondHi);
        }
        else {
            addChannelling(conjuncts, first, firstLo, firstHi, second, secondLo);
            addChannelling(conjuncts, second, secondLo, secondHi, first, firstLo);
            // The channelling constraints already imply that 'first' is all-different when its
            // values are in range. Presumably this constraint is posted to strengthen
            // propagation. No symmetric AllDifferent is posted on 'second'.
            conjuncts.add(new AllDifferent(first));
        }
        
        return new And(conjuncts);
    }
    
    /**
     * Appends {@code first[a]=b <-> second[b]=a} for every index value a of {@code first} and
     * every index value b of {@code second}. The intent is that a SAT-style solver can unify the
     * two direct-encoding literals of each pair.
     * <p>
     * NOTE(review): only values b in [secondLo, secondHi] (and a in [firstLo, firstHi]) are
     * mentioned, so this encoding alone does not forbid an element taking a value outside the
     * other matrix's index range. Presumably the domains are already restricted before this
     * pass runs; confirm.
     */
    private static void addPairwiseIffs(ArrayList<ASTNode> out, ASTNode first, int firstLo, int firstHi,
                                        ASTNode second, int secondLo, int secondHi) {
        for(int a=firstLo; a<=firstHi; a++) {
            for(int b=secondLo; b<=secondHi; b++) {
                Equals firstTakesB = new Equals(first.getChild(a-firstLo+1), NumberConstant.make(b));
                Equals secondTakesA = new Equals(second.getChild(b-secondLo+1), NumberConstant.make(a));
                out.add(new Iff(firstTakesB, secondTakesA));
            }
        }
    }
    
    /**
     * For every index value v of {@code src}, appends {@code dst[src[v]] = v}.
     * <p>
     * The value {@code src[v]} is shifted by {@code 1-dstLo} before it is used as the lookup
     * position. This suggests that SafeElementOne indexes from 1; confirm against its definition.
     */
    private static void addChannelling(ArrayList<ASTNode> out, ASTNode src, int srcLo, int srcHi,
                                       ASTNode dst, int dstLo) {
        for(int v=srcLo; v<=srcHi; v++) {
            WeightedSum position = new WeightedSum(src.getChild(v-srcLo+1), NumberConstant.make(1-dstLo));
            SafeElementOne lookedUp = new SafeElementOne(dst, position);
            out.add(new Equals(lookedUp, NumberConstant.make(v)));
        }
    }
}
