package savilerow.model;
/*

    Savile Row http://savilerow.cs.st-andrews.ac.uk/
    Copyright (C) 2014-2024 Peter Nightingale
    
    This file is part of Savile Row.
    
    Savile Row is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.
    
    Savile Row is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.
    
    You should have received a copy of the GNU General Public License
    along with Savile Row.  If not, see <http://www.gnu.org/licenses/>.

*/

import java.io.BufferedWriter;
import java.io.CharArrayWriter;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.RandomAccessFile;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

import savilerow.CmdFlags;
import savilerow.expression.*;

//  Output backend that reuses the Sat encoder but writes the model as OPB (pseudo-Boolean)
//  constraints: each clause becomes a linear inequality over 0/1 variables, and arbitrary
//  weighted linear constraints can also be written via addConstraint.
public class PB extends Sat
{
    public PB(SymbolTable _global_symbols) {
        super(_global_symbols);
    }
    
    //  Override the methods that write SAT clauses to the output file to use the OPB format. 
    //  A clause (l1 or ... or ln) is written as  l1 + ... + ln >= 1, where a negated literal
    //  not(x) is rewritten as (1 - x) so that only positive variables appear in the output.
    
    //  Right-hand side of the clause currently being written: reset to 1 by clauseStart,
    //  decremented by writeLiteral for every negated literal, and emitted by clauseEnd.
    protected int rhsval=1;
    
    //  Write one literal of the current clause as a unit-coefficient term.
    //  A negative lit denotes the negation of variable -lit: not(x) = (1-x), so the term
    //  "-1 x" is written and the constant 1 is moved to the right-hand side (rhsval--).
    @Override
    protected void writeLiteral(long lit) throws IOException {
        if(lit<0) {
            outstream.write("-1 x"+(-lit)+" ");
            rhsval--;
        }
        else {
            outstream.write("+1 x"+lit+" ");
        }
    }
    
    //  Begin a new clause: the right-hand side starts at 1 (at least one literal must hold),
    //  then defer to Sat.clauseStart.
    @Override
    protected void clauseStart() throws IOException {
        rhsval=1;
        super.clauseStart();
    }
    
    //  Terminate the current clause with its adjusted right-hand side, e.g. the clause
    //  (x1 or not x2) ends up as  "+1 x1 -1 x2 >= 0;". Sat.clauseEnd is not called.
    @Override
    protected void clauseEnd() throws IOException {
        outstream.write(">= "+rhsval+";");
        outstream.newLine();
    }
    
    //  Write comment as an OPB comment line (prefixed with "* ").
    @Override
    public void addComment(String comment) throws IOException
    {
        outstream.write("* ");
        outstream.write(comment);
        outstream.newLine();
    }
    
    //  Flush, sync and close the output stream, then overwrite the start of CmdFlags.pbfile
    //  in place with the OPB header line "* #variable= V #constraint= C" plus ten spaces of padding.
    //  NOTE(review): assumes the file already begins with a placeholder header line at least as long
    //  as the new header plus padding; otherwise the start of the next line is overwritten — confirm in Sat.
    //  NOTE(review): V is variableNumber-1 (presumably variableNumber is the next unused id), and C is
    //  numClauses; constraints written through addConstraint do not increment numClauses here — confirm callers do.
    @Override
    public void finaliseOutput() throws IOException
    {
        outstream.flush();
        fw.getFD().sync();
        outstream.close();
        
        String header="* #variable= "+(variableNumber-1)+" #constraint= "+numClauses;
        //  try-with-resources so the file handle is released even if seek/write fails.
        try(RandomAccessFile f=new RandomAccessFile(CmdFlags.pbfile, "rws")) {  //  rws to make sure everything is sync'd.
            f.seek(0);
            f.write(header.getBytes(StandardCharsets.US_ASCII));
            f.write("          ".getBytes(StandardCharsets.US_ASCII));  //  Write some spaces in case there was a header line already that was longer.
        }
    }
    
    ///   Output constraints, objective function
    
    //  Rewrite  sum coeffs[i]*bools[i]  op  rhs  so that every literal is positive.
    //  A negative entry in bools denotes the negation of that variable:
    //     k(not x) op rhs   ~~~>   k(1-x) op rhs ~~~>  k-kx op rhs ~~~> -kx op rhs-k
    //  coeffs and bools are modified IN PLACE; the comparison operator is unaffected.
    //  Returns the adjusted right-hand side.
    public long normaliseConstraint(ArrayList<Long> coeffs, ArrayList<Long> bools, long rhs) {
        for(int i=0; i<coeffs.size(); i++) {
            long b=bools.get(i);
            if(b<0) {
                long k=coeffs.get(i);
                rhs-=k;
                coeffs.set(i, -k);
                bools.set(i, -b);
            }
        }
        return rhs;
    }
    
    //  Write the linear constraint  sum coeffs[i]*bools[i]  op  rhs  as one OPB line ending in ";\n".
    //  op: -1 writes "<=", 0 writes "=", any other value writes ">=".
    //  Negated literals are normalised first (normaliseConstraint), so coeffs and bools are modified in place.
    public void addConstraint(ArrayList<Long> coeffs, ArrayList<Long> bools, int op, long rhs) throws IOException {
        long adjusted=addExpression(coeffs, bools, rhs);  // Write out the sum expression; rhs is adjusted for negated lits.
        
        if(op == -1) {
            outstream.write("<= ");
        }
        else if(op == 0) {
            outstream.write("= ");
        }
        else {
            outstream.write(">= ");
        }
        outstream.write(String.valueOf(adjusted));
        outstream.write(";\n");
    }
    
    //  Normalise the sum (making every literal positive, modifying coeffs and bools in place) and
    //  write its terms, without any comparison operator. Terms with coefficient 0 are omitted.
    //  Returns the right-hand side adjusted by the normalisation.
    public long addExpression(ArrayList<Long> coeffs, ArrayList<Long> bools, long rhs) throws IOException {
        long adjusted=normaliseConstraint(coeffs, bools, rhs);
        for(int i=0; i<coeffs.size(); i++) {
            long k=coeffs.get(i);
            if(k>0) {
                outstream.write("+"+k+" x"+bools.get(i)+" ");
            }
            else if(k<0) {
                outstream.write(k+" x"+bools.get(i)+" ");
            }
        }
        return adjusted;
    }
    
    //  Write raw text to the current output stream (the in-memory buffer while buffering is active).
    public void write(String s) throws IOException {
        outstream.write(s);
    }
    
    //  Encodes a weighted integer/bool (wt*term) as part of a sum, using the order encoding.
    //  For domain values v0 < v1 < ... < vm the term equals
    //      v0 + sum_{j=1..m} (vj - v(j-1)) * [term > v(j-1)]
    //  so one (coefficient, literal) pair per consecutive pair of values is appended to coeffs/bools,
    //  and v0 is moved to the right-hand side. Returns the adjusted right-hand side.
    //  NOTE(review): assumes getIntervalSetExp returns disjoint intervals in increasing order and is
    //  non-empty (an empty domain throws IndexOutOfBoundsException) — confirm.
    public long encodeIntegerTerm(ASTNode term, long wt, ArrayList<Long> coeffs, ArrayList<Long> bools, long rhs) {
        if(wt!=1) {
            term=new MultiplyMapper(term, NumberConstant.make(wt));
        }
        
        // break down any integers using order encoding
        ArrayList<Intpair> domain = term.getIntervalSetExp();
        
        // Adjust the RHS for smallest value; the literals below encode the excess over it.
        rhs=rhs-domain.get(0).lower;
        
        //  Explicit flag rather than a sentinel value, so a domain containing Long.MIN_VALUE is handled.
        boolean havePrev=false;
        long prevval=0;
        for(Intpair interval : domain) {
            for(long val=interval.lower; val<=interval.upper; val++) {
                if(havePrev) {
                    // Negating orderEncode(prevval) gives a literal meaning term > prevval.
                    // NOTE(review): presumably orderEncode returns the literal for term <= prevval — confirm.
                    long lit=-term.orderEncode(this, prevval);
                    coeffs.add(val-prevval);  // term exceeds prevval, so add on the difference between val and prevval.
                    bools.add(lit);
                }
                prevval=val;
                havePrev=true;
            }
        }
        return rhs;
    }
    
    //  Switch to storing constraints instead of writing them to the file. 
    //  While buffering, outstream writes into buffer and realfile holds the real output stream.
    BufferedWriter realfile=null;
    CharArrayWriter buffer=null;
    
    //  Redirect all subsequent output into a fresh in-memory buffer.
    //  NOTE(review): calling this twice without deactivateBufferConstraints in between overwrites
    //  realfile with the first buffer's writer, losing the real output stream.
    public void bufferConstraints() throws IOException {
        realfile=outstream;
        buffer=new CharArrayWriter();
        outstream=new BufferedWriter(buffer); 
    }
    
    //  Flush pending text into the buffer and resume writing to the real file.
    //  The buffered text is kept until unbufferConstraints is called.
    //  NOTE(review): without a preceding bufferConstraints, realfile is null and outstream becomes null.
    public void deactivateBufferConstraints() throws IOException {
        outstream.flush();
        outstream=realfile;  // Switch back to writing into the real file. 
        realfile=null;
    }
    
    //  Append the buffered text to the current output stream and discard the buffer.
    //  Presumably called after deactivateBufferConstraints so the text lands in the real file.
    public void unbufferConstraints() throws IOException {
        buffer.writeTo(outstream);  // Copies the chars directly, without building an intermediate String.
        buffer=null;
    }
}
