// A constraint requiring that the weighted sum of its variables equals 0.

package queso.constraints;

import java.util.*;
import gnu.math.*;

import queso.core.*;

// note on variables: numerical constraints can use both hash_var
// and bound_var types, but the other constraints can only use
// hash_var because they prune from the middle of the domains.

public class sum_constraint extends constraint implements make_ac
{
    // Rule-based quantified sum constraint
    //
    //     weights[0]*variables[0] + ... + weights[n-1]*variables[n-1] = 0
    //
    // with three rules per variable: prune the upper bound, prune the lower
    // bound, and check 'sufficient inner variance'. The variables are given in
    // quantification order (outermost first) and have integer bounds.
    
    // The constrained variables, outermost quantifier first.
    mid_domain[] variables;
    
    // Non-zero coefficient of each variable; weights[i] multiplies variables[i].
    final int [] weights;
    
    /**
     * Creates the constraint sum(weights[i] * variables[i]) = 0.
     *
     * Both arrays are stored by reference, not copied. Their equal length and
     * the non-zero weights are only checked when assertions are enabled.
     *
     * @param variables the variables, in quantification order (outermost first)
     * @param weights   one non-zero coefficient per variable
     * @param problem   the problem this constraint belongs to (passed to the superclass)
     */
    public sum_constraint(mid_domain [] variables, int [] weights, qcsp problem)
    {
        super(problem);
        this.variables=variables;
        check_variables(variables);
        
        assert weights.length== variables.length;
        for(int weight : weights) assert weight!=0;
        
        this.weights=weights;
    }
    
    /**
     * Registers this constraint via add_wakeup on every variable, then runs an
     * initial propagation pass.
     *
     * @return the result of {@link #make_ac()}
     */
    public boolean establish()
    {
        for(mid_domain temp : variables)
        {
            temp.add_wakeup(this);
        }
        
        return make_ac();
    }
    
    /** Returns the constrained variables (the internal array, not a copy). */
    public variable_iface [] variables()
    {
        return (variable_iface []) variables;
    }
    
    /**
     * Renders the constraint as a quantifier prefix (one "A"/"E"-prefixed entry
     * per variable), then ": ", then the weighted sum followed by " = 0".
     */
    public String toString()
    {
        StringBuilder st=new StringBuilder();
        for(mid_domain v : variables)
        {
            st.append(v.quant()?"A":"E").append(v).append(' ');
        }
        st.append(": ");
        for(int i=0; i<variables.length; i++)
        {
            if(i>0)
            {
                st.append(" + ");
            }
            st.append(weights[i]).append(variables[i].toString());
        }
        st.append(" = 0");
        return st.toString();
    }
    
    /**
     * Runs one pass of the upper/lower bound pruning rules (innermost variable
     * first, then outwards), followed by the sufficient-inner-variance check.
     * Domains are pruned in place.
     *
     * The pass does not restart when domains change (see HC1, Goualard and
     * Granvilliers). The constraint must therefore be requeued if a further pass
     * is wanted.
     *
     * @return false if a bound-pruning call reports failure or the inner
     *         variance check fails; true otherwise
     */
    public boolean make_ac()
    {
        final int n=variables.length;
        
        // An empty sum is identically 0, so the constraint is trivially satisfied.
        // Without this guard the code below would index variables[-1].
        if(n==0)
        {
            return true;
        }
        
        // --- Upper and lower bound pruning rules ---
        //
        // upper_total / lower_total accumulate the largest / smallest value of the
        // terms OTHER than the one currently being revised. Outer terms contribute
        // their normal range; inner universal terms contribute their range with
        // the bounds swapped (see the loop below).
        // NOTE(review): products and sums are plain int arithmetic with no
        // overflow check; very wide domains or large weights could overflow.
        
        int upper_total=0;
        int lower_total=0;
        
        // Every variable except the innermost is outer to the innermost one, so
        // sum their normal (non-inverted) ranges.
        for(int var=0; var<n-1; var++)
        {
            if(weights[var]>0)
            {
                upper_total+=variables[var].upperbound()*weights[var];
                lower_total+=variables[var].lowerbound()*weights[var];
            }
            else
            {
                upper_total+=variables[var].lowerbound()*weights[var];
                lower_total+=variables[var].upperbound()*weights[var];
            }
        }
        
        // Revise the innermost variable first.
        if(!revise_bounds(upper_total, lower_total, variables[n-1], weights[n-1]))
        {
            return false;
        }
        
        // Then walk outwards. At each step variables[var+1] (just revised, so its
        // possibly-pruned bounds are read here) becomes an inner term, and
        // variables[var] is taken out of the totals before being revised.
        for(int var=n-2; var>=0; var--)
        {
            final int inner=var+1;
            
            // Add the newly-inner term. Keep its bounds as-is for a positive-weight
            // existential or a negative-weight universal; otherwise swap them.
            // Net effect: existentials contribute their normal range of
            // weight*value, universals contribute that range inverted.
            if((weights[inner]>0 && !variables[inner].quant()) || (weights[inner]<0 && variables[inner].quant()))
            {
                upper_total+=variables[inner].upperbound()*weights[inner];
                lower_total+=variables[inner].lowerbound()*weights[inner];
            }
            else
            {
                upper_total+=variables[inner].lowerbound()*weights[inner];
                lower_total+=variables[inner].upperbound()*weights[inner];
            }
            
            // Remove the current variable's normal range. It was added with
            // these bounds above and this method has not pruned it yet.
            if(weights[var]>0)
            {
                upper_total-=variables[var].upperbound()*weights[var];
                lower_total-=variables[var].lowerbound()*weights[var];
            }
            else
            {
                upper_total-=variables[var].lowerbound()*weights[var];
                lower_total-=variables[var].upperbound()*weights[var];
            }
            
            if(!revise_bounds(upper_total, lower_total, variables[var], weights[var]))
            {
                return false;
            }
        }
        
        // --- Sufficient inner variance rule ---
        // Walk from the innermost variable outwards. Existential terms add the
        // width of their weighted range, |(ub-lb)*weight|, and universal terms
        // subtract it. A negative running total means some universal's range is
        // wider than what the existentials inside it can cover.
        
        int inner_variance=0;
        
        for(int var=n-1; var>=0; var--)
        {
            int this_variance=Math.abs((variables[var].upperbound()-variables[var].lowerbound())*weights[var]);
            if(variables[var].quant())
            {
                inner_variance-=this_variance;
            }
            else
            {
                inner_variance+=this_variance;
            }
            if(inner_variance<0)
            {
                // Only a universal term can have made the total go negative.
                assert variables[var].quant();
                return false;
            }
        }
        
        return true;
    }
    
    /** Returns true when every variable reports unit(), false otherwise. */
    public boolean entailed()
    {
        for(mid_domain v : variables)
        {
            if(!v.unit())
            {
                return false;
            }
        }
        return true;
    }
    
    /**
     * Tightens the bounds of one variable so that weight*var + others = 0 can hold,
     * where others lies in [lower_total, upper_total].
     *
     * lower_total and upper_total describe the other terms of the constraint and
     * already include the swapped bounds of inner universal terms where
     * appropriate. The variable's lower bound is derived from upper_total and
     * its upper bound from lower_total (the roles swap when weight is negative).
     *
     * @param upper_total largest value of the other terms
     * @param lower_total smallest value of the other terms
     * @param var         the variable to prune
     * @param weight      its (non-zero) coefficient
     * @return false if an exclude_upper/exclude_lower call reports failure, true otherwise
     */
    boolean revise_bounds(int upper_total, int lower_total, mid_domain var, int weight)
    {
        int var_upper_bound, var_lower_bound;
        
        if(weight>0)
        {
            var_upper_bound=div(-lower_total, weight, -1);   // round down.
            var_lower_bound=div(-upper_total, weight, 1);    // round up.
        }
        else
        {
            // Dividing by a negative weight swaps which total bounds which side.
            var_lower_bound=div(-lower_total, weight, 1);    // round up.
            var_upper_bound=div(-upper_total, weight, -1);   // round down.
        }
        
        if(var.upperbound() > var_upper_bound)
        {   // the variable's upper bound exceeds the new upper bound: prune it.
            boolean flag=var.exclude_upper(var_upper_bound, null);
            if(!flag)
                return false;
        }
        
        if(var.lowerbound() < var_lower_bound)
        {   // the variable's lower bound is below the new lower bound: prune it.
            boolean flag=var.exclude_lower(var_lower_bound, null);
            if(!flag)
                return false;
        }
        
        return true;
    }
    
    /**
     * Integer division top/bottom rounded in a chosen direction.
     *
     * Java's integer division truncates towards zero. The truncated quotient
     * therefore only needs adjusting when the division is inexact and truncation
     * went the wrong way: down by one for a negative result when flooring, or up
     * by one for a positive result when taking the ceiling.
     *
     * @param top    dividend
     * @param bottom divisor; must be non-zero
     * @param round  1 to round up (ceiling), -1 to round down (floor)
     * @return the rounded quotient
     */
    final int div(int top, int bottom, int round)
    {
        assert (round==1 || round==-1) && bottom!=0;
        
        if(top==0)
        {
            return 0;
        }
        
        int quotient=top/bottom;
        int remainder=top%bottom;
        
        // The exact (unrounded) quotient is negative iff exactly one operand is.
        boolean neg=(top<0)!=(bottom<0);
        int retval;
        if(round==-1 && neg && remainder!=0)
        {
            retval=quotient-1;  // truncation rounded a negative result up; floor it.
        }
        else if(round==1 && !neg && remainder!=0)
        {
            retval=quotient+1;  // truncation rounded a positive result down; ceil it.
        }
        else
        {
            retval=quotient;    // truncation already rounded in the requested direction.
        }
        // Cross-check against arbitrary-precision division (only with -ea).
        assert new IntNum(retval).equals(IntNum.quotient(new IntNum(top), new IntNum(bottom), (round==1?IntNum.CEILING:IntNum.FLOOR)));
        return retval;
    }
}

class sum_constraint_test
{
    // Manual driver for sum_constraint. It builds
    //     E a in [-10,10], E b in [-10,10], A c in [0,3], E d in [0,3] :
    //     2a + -2b + 3c + 3d = 0
    // prints the initial domains, then nine times prints the result of
    // make_ac() followed by the domains.
    public static void main(String[] args)
    {
        final int passes=9;
        
        qcsp problem = new qcsp();
        mid_domain[] vars = new mid_domain[] {
            new existential(-10, 10, problem, "a"),
            new existential(-10, 10, problem, "b"),
            new universal(0, 3, problem, "c"),
            new existential(0, 3, problem, "d")
        };
        int[] coefficients = {2, -2, 3, 3};
        
        sum_constraint sum = new sum_constraint(vars, coefficients, problem);
        
        problem.printdomains();
        
        for(int pass=0; pass<passes; pass++)
        {
            System.out.println(sum.make_ac());
            problem.printdomains();
        }
    }
}

class sum_constraint_test3
{
    // Randomised cross-check. Solving a QCSP that consists of a single sum
    // constraint must give the same truth value as propagating the equivalent
    // SQGAC constraint. Timing and search statistics are also accumulated.
    
    stopwatch2 sw_sum;
    stopwatch2 sw_sqgac;
    
    int sum_solved_root=0;
    int sqgac_solved_root=0;
    
    int sum_nodes=0;
    int sqgac_nodes=0;   // NOTE(review): never incremented; the SQGAC problem is not searched.
    
    public static void main(String[] args)
    {
        sum_constraint_test3 o1=new sum_constraint_test3();
        o1.reset_stats();
        
        // Warm-up runs; their statistics are discarded by the reset below.
        o1.run_tests(1000);
        
        o1.reset_stats();
        o1.run_tests(10000);
        
        System.out.println("Time spent solving sum:"+o1.sw_sum.elapsedMicros());
        System.out.println("Time spent solving sqgac:"+o1.sw_sqgac.elapsedMicros());
        System.out.println("Sum solved at root:"+o1.sum_solved_root);
        System.out.println("SQGAC solved at root:"+o1.sqgac_solved_root);
        
        System.out.println("Sum search nodes:"+o1.sum_nodes);
    }
    
    // Starts fresh stopwatches and zeroes every counter.
    private void reset_stats()
    {
        sw_sum=new stopwatch2();
        sw_sqgac=new stopwatch2();
        sum_solved_root=0;
        sqgac_solved_root=0;
        sum_nodes=0;
        sqgac_nodes=0;
    }
    
    // Runs `count` random tests, printing progress after each one.
    private void run_tests(int count)
    {
        for(int i=0; i<count; i++)
        {
            test_random_prob();
            System.out.println("Done test "+i);
        }
    }
    
    // Returns a uniformly random non-zero integer in [-10, 10] (rejection sampling).
    private static int random_nonzero_coeff()
    {
        int c;
        do
        {
            c=(int)(Math.random()*21.0)-10;
        }
        while(c==0);
        return c;
    }
    
    // Picks one of five fixed domains, each with probability 0.2, and returns it
    // as {lb, ub}. Uses exactly one Math.random() call.
    private static int[] random_bounds()
    {
        double temp=Math.random();
        if(temp<0.2) return new int[]{3, 12};     // positive
        if(temp<0.4) return new int[]{0, 9};      // includes 0
        if(temp<0.6) return new int[]{-3, 6};     // spans 0
        if(temp<0.8) return new int[]{-9, 0};     // includes 0, negative
        return new int[]{-12, -3};                // negative
    }
    
    // Builds one random 6-variable sum problem twice: once with sum_constraint
    // and once with SQGAC over the same predicate. It then compares root
    // propagation and the final truth values (via assertions) and accumulates
    // statistics.
    void test_random_prob()
    {
        final int numvars=6;
        mid_domain [] vars_sum= new mid_domain[numvars];
        mid_domain [] vars_sqgac= new mid_domain[numvars];
        
        qcsp prob_sum = new qcsp();
        qcsp prob_sqgac = new qcsp();
        
        int [] coeff=new int[numvars];
        for(int i=0; i<numvars; i++)
        {
            coeff[i]=random_nonzero_coeff();
        }
        
        // Each variable is universal with probability 0.2 and existential
        // otherwise. The identical variable is created in both problems.
        for(int i=0; i<numvars; i++)
        {
            boolean is_universal=Math.random()>0.8;
            int[] bounds=random_bounds();
            int lb=bounds[0];
            int ub=bounds[1];
            String name="x"+(i+1);
            if(is_universal)
            {
                vars_sum[i]=new universal(lb, ub, prob_sum, name);
                vars_sqgac[i]=new universal(lb, ub, prob_sqgac, name);
            }
            else
            {
                vars_sum[i]=new existential(lb, ub, prob_sum, name);
                vars_sqgac[i]=new existential(lb, ub, prob_sqgac, name);
            }
        }
        
        sum_constraint c1= new sum_constraint(vars_sum, coeff, prob_sum);
        
        predicate_wrapper pred=new sum_predicate(coeff);
        
        // SQGAC construction is timed; sum_constraint construction above is not.
        // The instance is not referenced afterwards. Presumably its constructor
        // registers it with prob_sqgac -- TODO confirm.
        sw_sqgac.start();
        new sqgac(vars_sqgac, prob_sqgac, pred);
        sw_sqgac.end();
        
        prob_sum.printdomains();
        
        sw_sum.start();
        boolean flag1=prob_sum.establish();
        sw_sum.end();
        
        sw_sqgac.start();
        boolean flag2=prob_sqgac.establish();
        sw_sqgac.end();
        
        // A false result from establish() is counted as "solved at root".
        if(!flag1) sum_solved_root++;
        if(!flag2) sqgac_solved_root++;
        
        System.out.println(c1);
        
        // SQGAC propagation is strictly stronger: if the sum constraint fails at
        // the root, SQGAC must fail too.
        assert !flag2 || flag1 : "SQGAC found true and sum found false";
        
        sw_sum.start();
        if(flag1) flag1=prob_sum.search();
        sw_sum.end();
        sum_nodes+=prob_sum.numnodes;
        
        // flag2 is only SQGAC's root propagation result; the SQGAC problem is
        // not searched. Presumably SQGAC on a single constraint already decides
        // the problem at the root -- TODO confirm, or search it as well.
        System.out.println("flag1="+flag1+", flag2="+flag2);
        assert flag1==flag2 : "Truth values do not match";
    }
}

class sum_predicate implements predicate_wrapper
{
    // Coefficients of the linear sum; coeff[k] multiplies tau.vals[k].
    // Stored by reference, not copied.
    final int [] coeff;
    
    sum_predicate(int [] coeff)
    {
        this.coeff=coeff;
    }
    
    // Accepts a tuple exactly when coeff[0]*vals[0] + ... + coeff[m-1]*vals[m-1] == 0,
    // where m is the tuple's length.
    public boolean predicate(tuple tau)
    {
        int total=0;
        int k=0;
        while(k<tau.vals.length)
        {
            total+=coeff[k]*tau.vals[k];
            k++;
        }
        return total==0;
    }
}
