package savilerow.expression;
/*

    Savile Row http://savilerow.cs.st-andrews.ac.uk/
    Copyright (C) 2014-2018 Peter Nightingale
    
    This file is part of Savile Row.
    
    Savile Row is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    
    Savile Row is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
    
    You should have received a copy of the GNU General Public License
    along with Savile Row.  If not, see <http://www.gnu.org/licenses/>.

*/

import savilerow.*;
import java.util.*;
import java.math.BigInteger;
import java.io.*;
import savilerow.model.SymbolTable;
import savilerow.model.Sat;
import savilerow.mdd.*;

//  Pseudo-boolean (weighted-sum <= k) constraint whose Boolean variables are grouped into at-most-one groups.

/**
 * Pseudo-boolean constraint over at-most-one (AMO) groups of Boolean variables:
 * the weighted sum of the variables must not exceed the bound {@code k}.
 *
 * <p>Child 0 is a matrix whose entries from index 1 onward are the groups.
 * Each group node carries its coefficient matrix as child 1 and its Boolean
 * variable matrix as child 2. Index 0 of each matrix is skipped throughout
 * (presumably the matrix index domain -- confirm against the matrix node type).
 * Child 1 is the bound {@code k}.
 */
public class AMOPB extends ASTNodeC {
    public static final long serialVersionUID = 1L;
    
    /**
     * @param mat matrix of AMO groups (see class comment for the expected layout)
     * @param k   upper bound on the weighted sum
     */
    public AMOPB(ASTNode mat, ASTNode k) {
        super(mat, k);
    }
    
    public ASTNode copy() {
        return new AMOPB(getChild(0), getChild(1));
    }
    
    public boolean isRelation() {
        return true;
    }
    
    /**
     * No simplification is implemented for this constraint yet.
     * NOTE(review): returning null presumably means "unchanged" to the caller -- confirm.
     */
    public ASTNode simplify() {
        return null;  // TODO: implement simplification.
    }
    
    //////////////////////////////////////////////////////////////////////////// 
    // 
    // Output methods.
    
    /**
     * Writes the constraint as a Minion {@code weightedsumleq(weights, vars, k)}.
     * The AMO grouping is flattened away; only the weighted sum is emitted.
     */
    public void toMinion(BufferedWriter b, boolean bool_context) throws IOException {
        ArrayList<ASTNode> ch = new ArrayList<ASTNode>();
        ArrayList<Long> wts = new ArrayList<Long>();
        
        ASTNode groups = getChild(0);
        for (int i = 1; i < groups.numChildren(); i++) {
            ASTNode amoproduct = groups.getChild(i);
            // Variables of this group, skipping the matrix's index-0 child.
            ch.addAll(amoproduct.getChild(2).getChildren(1));
            
            ASTNode groupCoeffs = amoproduct.getChild(1);
            for (int j = 1; j < groupCoeffs.numChildren(); j++) {
                wts.add(groupCoeffs.getChild(j).getValue());
            }
        }
        
        b.append("weightedsumleq([");
        for (int i = 0; i < wts.size(); i++) {
            b.append(String.valueOf(wts.get(i)));
            if (i < wts.size() - 1) {
                b.append(",");
            }
        }
        b.append("],[");
        for (int i = 0; i < ch.size(); i++) {
            ch.get(i).toMinion(b, false);
            if (i < ch.size() - 1) {
                b.append(",");
            }
        }
        b.append("],");
        getChild(1).toMinion(b, false);
        b.append(")");
    }
    
    ////////////////////////////////////////////////////////////////////////////
    //
    //  SAT encoding
    
    /**
     * Encodes the constraint into SAT by building an MDD over the AMO groups
     * (via AMOPBMDDBuilder), encoding it with AbioMDDEncoding, and asserting
     * the MDD's root literal as a unit clause.
     *
     * <p>The bound {@code k} is read with {@code getValue()}, so it is assumed
     * to be a constant at this point.
     *
     * @throws ArithmeticException if a coefficient or the bound does not fit in
     *         an {@code int} (the MDD builder takes ints; these values were
     *         previously truncated silently).
     */
    public void toSAT(Sat satModel) throws IOException {
        //  Construct the groups for the bools.
        ArrayList<ArrayList<Long>> X = new ArrayList<ArrayList<Long>>();
        ArrayList<ArrayList<Integer>> coeffs = new ArrayList<ArrayList<Integer>>();
        
        ASTNode groups = getChild(0);
        for (int i = 1; i < groups.numChildren(); i++) {
            ASTNode group = groups.getChild(i);
            
            ASTNode boolsAst = group.getChild(2);
            ArrayList<Long> boolgroup = new ArrayList<Long>();
            for (int j = 1; j < boolsAst.numChildren(); j++) {
                // SAT literal for this variable taking value 1.
                long lit = boolsAst.getChild(j).directEncode(satModel, 1);
                boolgroup.add(lit);
            }
            X.add(boolgroup);
            
            ASTNode coeffsAst = group.getChild(1);
            ArrayList<Integer> coeffgroup = new ArrayList<Integer>();
            for (int j = 1; j < coeffsAst.numChildren(); j++) {
                coeffgroup.add(toIntChecked(coeffsAst.getChild(j).getValue(), "coefficient"));
            }
            coeffs.add(coeffgroup);
        }
        
        sortGroupsByMaxCoeffDescending(coeffs, X);
        
        int k = toIntChecked(getChild(1).getValue(), "bound");
        AMOPBMDDBuilder a = new AMOPBMDDBuilder(coeffs, X, k, true);
        
        MDD mdd = a.getMDD();
        
        AbioMDDEncoding me = new AbioMDDEncoding(satModel);
        
        long rootlit = me.assertMDD(mdd);
        
        satModel.addClause(rootlit);
    }
    
    /**
     * Reorders the groups in descending order of their maximum coefficient.
     * The coefficient lists and literal lists are kept in lockstep.
     * (Open question from the original author: should this use the average
     * coefficient instead of the max?)
     *
     * <p>Empty coefficient groups rank lowest instead of making
     * {@code Collections.max} throw.
     */
    private static void sortGroupsByMaxCoeffDescending(ArrayList<ArrayList<Integer>> coeffs,
            ArrayList<ArrayList<Long>> X) {
        int n = coeffs.size();
        
        // Compute each group's maximum once rather than on every comparison.
        int[] groupMax = new int[n];
        for (int g = 0; g < n; g++) {
            ArrayList<Integer> cg = coeffs.get(g);
            groupMax[g] = cg.isEmpty() ? Integer.MIN_VALUE : Collections.max(cg);
        }
        
        // Deliberately an exchange sort (not stable). Switching algorithms would
        // reorder groups with equal maxima and so could change the MDD that is built.
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                if (groupMax[j] > groupMax[i]) {
                    Collections.swap(coeffs, i, j);
                    Collections.swap(X, i, j);
                    int tmp = groupMax[i];
                    groupMax[i] = groupMax[j];
                    groupMax[j] = tmp;
                }
            }
        }
    }
    
    /** Narrows a long to an int, failing loudly instead of silently wrapping. */
    private static int toIntChecked(long value, String what) {
        if (value < Integer.MIN_VALUE || value > Integer.MAX_VALUE) {
            throw new ArithmeticException("AMOPB " + what + " does not fit in an int: " + value);
        }
        return (int) value;
    }
}
