package queso.constraints;

import java.util.*;
import gnu.trove.*;
import queso.core.*;

/**
 * Table (extensional) constraint whose supports are found with Regin and Lhomme's
 * bounding-and-jumping search.
 *
 * <p>The table is sorted with TupleComparator2 (assumed lexicographic, consistent with
 * tuples.comparetuples). Each tuple is then cross-linked in two ways:
 * <ul>
 *   <li>{@code nextValue[v]} is the next tuple sharing its value on variable v;</li>
 *   <li>{@code redundantNextValue[v]} is the next tuple whose value on v equals this
 *       tuple's rotating "redundant" value.</li>
 * </ul>
 * These links let {@link #nextin} skip ahead. {@code last_tuple_pointer[v][b+offset]}
 * holds the index of the last support found for (v, b), or -1 if none has been found.
 * Every change to it is trailed through {@code gsbt}, so it is restored by
 * {@link #backtrack()}.
 *
 * <p>Only all-existential constraints are handled (see seekNextSupport).
 */
public class reginlhomme extends gac_schema implements stateful
{
    /**
     * Trail for last_tuple_pointer. Each entry records (variable, value index, old
     * pointer); backtrack_item undoes the most recent entry.
     */
    private final class reginlhomme_backtrack extends backtrack
    {
        reginlhomme_backtrack(int [][] last_tuple_pointer)
        {
            this.last_tuple_pointer=last_tuple_pointer;
        }
        
        final int [][] last_tuple_pointer;
        
        final TIntArrayList vars=new TIntArrayList();
        final TIntArrayList vals=new TIntArrayList();
        final TIntArrayList numlist=new TIntArrayList();
        
        /** Pops the newest trail entry and writes its old pointer back. */
        public void backtrack_item()
        {
            int top=vars.size()-1;
            int var=vars.remove(top);
            int val=vals.remove(top);
            int old_value=numlist.remove(top);
            
            last_tuple_pointer[var][val]=old_value;
        }
        
        /**
         * Records that last_tuple_pointer[var][val] held oldwatch before being changed.
         * Here val is an array index, i.e. value plus offset.
         */
        void add_backtrack_pair(int var, int val, int oldwatch)
        {
            vars.add(var);
            vals.add(val);
            numlist.add(oldwatch);
            backtrack_increment();
        }
    }
    
    /**
     * Builds the constraint from a predicate. The predicate is expanded into its table
     * of satisfying tuples over the variables' domains.
     */
    public reginlhomme(mid_domain [] vars, qcsp prob, predicate_wrapper pred)
    {
        super(vars, prob);
        int [][] table=queso.constraints.tuples.pred2table(pred, variables);
        arity=table_arity(table);
        last_tuple_pointer=init_last_tuple_pointer();
        tuplelist=build_tuplelist(table);
        gsbt=new reginlhomme_backtrack(last_tuple_pointer);
    }
    
    /**
     * Builds the constraint from an explicit table of allowed tuples.
     * The caller's array is not reordered; a shallow copy is sorted instead.
     */
    public reginlhomme(mid_domain [] vars, qcsp prob, int [][] tuples)
    {
        super(vars, prob);
        arity=table_arity(tuples);
        last_tuple_pointer=init_last_tuple_pointer();
        tuplelist=build_tuplelist(tuples);
        gsbt=new reginlhomme_backtrack(last_tuple_pointer);
    }
    
    /**
     * Returns the arity of the table. An empty table (unsatisfiable constraint) has no
     * first row to inspect, so the number of variables is used instead.
     */
    private int table_arity(int [][] table)
    {
        return table.length>0 ? table[0].length : variables.length;
    }
    
    /** Allocates one pointer slot per initial domain value, all set to -1 (no support yet). */
    private int [][] init_last_tuple_pointer()
    {
        int [][] pointers=new int[arity][];
        for(int i=0; i<arity; i++)
        {
            pointers[i]=new int[variables[i].domsize()];
            Arrays.fill(pointers[i], -1);
        }
        return pointers;
    }
    
    /**
     * Sorts a copy of the table and wraps each row in a TupleH.
     * Each tuple is cross-linked with its predecessors and given the next "redundant"
     * value of every variable, cycling through each domain.
     */
    private TupleH [] build_tuplelist(int [][] table)
    {
        int [][] sorted=table.clone();
        Arrays.sort((Object []) sorted, new TupleComparator2());
        
        TupleH [] list=new TupleH[sorted.length];
        int [] redvalues=new int[arity];
        for(int var=0; var<arity; var++) redvalues[var]=variables[var].lowerbound();
        
        for(int i=0; i<sorted.length; i++)
        {
            list[i]=new TupleH(sorted[i], redvalues.clone(), i);
            link_to_predecessors(list, i);
            advance_redundant_values(redvalues);
        }
        return list;
    }
    
    /**
     * Walks backwards from list[i] and fills in the forward links pointing at i.
     * Two kinds of link are set:
     * <ul>
     *   <li>nextValue of the nearest earlier tuple sharing each value;</li>
     *   <li>redundantNextValue of earlier tuples whose redundant value equals list[i]'s
     *       value and that have no link yet.</li>
     * </ul>
     * The walk stops once every value of list[i] has its single incoming nextValue link.
     */
    private void link_to_predecessors(TupleH [] list, int i)
    {
        TupleH t=list[i];
        int unlinked=arity;
        for(int j=i-1; j>=0 && unlinked>0; j--)
        {
            TupleH prev=list[j];
            for(int var=0; var<arity && unlinked>0; var++)
            {
                int v=t.values[var];
                if(prev.values[var]==v && prev.nextValue[var]==-1)
                {
                    prev.nextValue[var]=i;
                    unlinked--;
                }
                if(prev.redundantValues[var]==v && prev.redundantNextValue[var]==-1)
                {
                    prev.redundantNextValue[var]=i;
                }
            }
        }
    }
    
    /**
     * Moves every entry of redvalues to the next present value of its variable, wrapping
     * back to the lower bound after the upper bound. The bound is tested before
     * is_present, so is_present is never asked about a value past the upper bound.
     */
    private void advance_redundant_values(int [] redvalues)
    {
        for(int var=0; var<arity; var++)
        {
            int ub=variables[var].upperbound();
            int next=redvalues[var]+1;
            while(next<=ub && !variables[var].is_present(next)) next++;
            redvalues[var]=(next>ub) ? variables[var].lowerbound() : next;
        }
    }
    
    /** Debug dump of the linked tuple list, one tuple per line. */
    String tuplist()
    {
        StringBuilder sb=new StringBuilder();
        for(int i=0; i<tuplelist.length; i++)
        {
            sb.append("index: ").append(i).append(' ').append(tuplelist[i]).append('\n');
        }
        return sb.toString();
    }
    
    /** A tuple of the sorted table plus its forward links (indices into tuplelist, -1 = none). */
    private static final class TupleH
    {
        int id;  // global array index (can also be used for lex comparison of two tuples.)
        int [] values;
        int [] nextValue;           // per variable: next tuple with the same value, or -1
        int [] redundantValues;     // per variable: the extra value this tuple can jump for
        int [] redundantNextValue;  // per variable: next tuple whose value is redundantValues[var], or -1
        
        TupleH(int [] values, int [] redundantValues, int id)
        {
            this.id=id;
            this.values=values;
            this.redundantValues=redundantValues;
            nextValue=new int[values.length];
            Arrays.fill(nextValue, -1);
            redundantNextValue=new int[values.length];
            Arrays.fill(redundantNextValue, -1);
        }
        
        /** Appends the first values.length entries of a, each followed by ", ". */
        private void append_all(StringBuilder sb, int [] a)
        {
            for(int i=0; i<values.length; i++) sb.append(a[i]).append(", ");
        }
        
        @Override
        public String toString()
        {
            StringBuilder sb=new StringBuilder("values: ");
            append_all(sb, values);
            sb.append(" nextValue: ");
            append_all(sb, nextValue);
            sb.append(" redValues: ");
            append_all(sb, redundantValues);
            sb.append(" redNextValue: ");
            append_all(sb, redundantNextValue);
            return sb.toString();
        }
    }
    
    final TupleH [] tuplelist;
    final int [][] last_tuple_pointer;  // backtracking pointer to the last tuple.
    final int arity;
    final reginlhomme_backtrack gsbt;
    
    /**
     * Finds the first valid tuple in table order that contains (var, val), the single
     * assignment in {@code partial}.
     *
     * <p>Under -ea, the result is cross-checked against a brute-force linear scan.
     *
     * @param partial the (variable, value) pair to support; only length-1 partials are handled
     * @param previous_support unused; the search resumes from last_tuple_pointer instead
     * @return the support, or null if there is none. On success,
     *         last_tuple_pointer[var][val+offset] is trailed and then set to the support's index.
     */
    tuple seekNextSupport(pa partial, tuple previous_support)
    {
        assert partial.vals.length==1;  // only works for all-existential constraints.
        int var=partial.vars[0];
        int val=partial.vals[0];
        int validx=val+variables[var].offset();
        
        int nxvalid=0;
        boolean testmode=false;
        assert testmode=true;  // the assignment only runs when assertions are enabled
        if(testmode) nxvalid=first_valid_index_linear(var, val);
        
        // For other variables, compute the max (over vars) of the min (over vals) of
        // last_tuple_pointer. This gives the lower bound of tuples already checked.
        int lowerbound=last_tuple_pointer[var][validx];
        for(int i=0; i<arity; i++)
        {
            if(i!=var) lowerbound=Math.max(lowerbound, min_last_pointer(i));
        }
        
        // Any tuple lexicographically above (ub_0, ..., val, ..., ub_n) either exceeds
        // some current upper bound or has a value above val for var.
        int [] upperboundtuple=new int[arity];
        for(int i=0; i<arity; i++)
        {
            upperboundtuple[i]=(i==var) ? val : variables[i].upperbound();
        }
        
        // lowerbound is -1 when nothing has been checked yet; start at the first tuple then.
        int curtupleIndex=nextin(var, val, Math.max(lowerbound, 0));
        if(curtupleIndex==-1)
        {   // off the end of the list
            assert nxvalid==tuplelist.length;
            return null;
        }
        
        TupleH curtuple=tuplelist[curtupleIndex];
        if(tuples.comparetuples(curtuple.values, upperboundtuple)>0)
        {
            assert nxvalid==tuplelist.length;
            return null;
        }
        
        while(!tuples.valid(curtuple.values, variables))
        {
            curtupleIndex=curtuple.nextValue[var];  // curtuple=NEXT((x,a), curtuple);
            if(curtupleIndex==-1)
            {
                assert nxvalid==tuplelist.length;
                return null;
            }
            
            // Jump: for every other variable y, skip to the earliest remaining candidate
            // over y's present values. The candidates are computed from last_tuple_pointer.
            int maxjump=curtupleIndex;
            for(int y=0; y<arity; y++)
            {
                if(y!=var) maxjump=Math.max(maxjump, min_next_candidate(y, curtupleIndex));
            }
            maxjump=nextin(var, val, maxjump);
            
            if(maxjump>curtupleIndex)
            {
                assert !tuples.valid(curtuple.values, variables); // for some reason the pseudocode assumes this.
                curtupleIndex=maxjump;
            }
            
            curtuple=tuplelist[curtupleIndex];
            if(tuples.comparetuples(curtuple.values, upperboundtuple)>0)
            {
                assert nxvalid==tuplelist.length;
                return null;
            }
        }
        
        // Trail the old pointer before moving it so backtrack() can restore it.
        gsbt.add_backtrack_pair(var, validx, last_tuple_pointer[var][validx]);
        last_tuple_pointer[var][validx]=curtupleIndex;
        
        assert nxvalid==curtupleIndex;
        assert tuples.comparetuples(tuplelist[nxvalid].values, curtuple.values)==0;
        return new tuple(curtuple.values);
    }
    
    /**
     * Brute-force reference for seekNextSupport: returns the index of the first valid
     * tuple containing (var, val), or tuplelist.length if there is none.
     */
    private int first_valid_index_linear(int var, int val)
    {
        int i=0;
        while(i<tuplelist.length
              && (!tuples.valid(tuplelist[i].values, variables) || tuplelist[i].values[var]!=val))
        {
            i++;
        }
        return i;
    }
    
    /**
     * Returns the minimum of last_tuple_pointer over the present values of variable y.
     * A value with no recorded support (-1) is included on purpose; it pins the minimum at -1.
     */
    private int min_last_pointer(int y)
    {
        int offset=variables[y].offset();
        int lb=variables[y].lowerbound();
        int min=last_tuple_pointer[y][lb+offset];
        for(int b=lb+1; b<=variables[y].upperbound(); b++)
        {
            if(variables[y].is_present(b)) min=Math.min(min, last_tuple_pointer[y][b+offset]);
        }
        return min;
    }
    
    /**
     * For variable y, returns the minimum over its present values b of
     * nextin(y, b, max(last_tuple_pointer[y][b], from)).
     * NOTE(review): a value with no further tuple gives -1, which then wins the minimum
     * and disables the jump for y. This is conservative (never skips a candidate), but
     * it jumps less than ignoring such values would.
     */
    private int min_next_candidate(int y, int from)
    {
        int offset=variables[y].offset();
        int lb=variables[y].lowerbound();
        int best=nextin(y, lb, Math.max(last_tuple_pointer[y][lb+offset], from));
        for(int b=lb+1; b<=variables[y].upperbound(); b++)
        {
            if(variables[y].is_present(b))
            {
                int candidate=nextin(y, b, Math.max(last_tuple_pointer[y][b+offset], from));
                if(candidate<best) best=candidate;
            }
        }
        return best;
    }
    
    /**
     * Returns the first tuple index at or after curtuple whose value on var is val.
     * Returns curtuple itself if it already matches, and -1 if there is no such tuple
     * or curtuple is out of range.
     * Follows a redundantNextValue link as soon as a tuple's redundant value matches.
     */
    private int nextin(int var, int val, int curtuple)
    {
        if(curtuple<0 || curtuple>=tuplelist.length) return -1;
        TupleH temp=tuplelist[curtuple];
        
        while(temp.values[var]!=val)
        {
            if(temp.redundantValues[var]==val)
            {
                assert temp.redundantNextValue[var]==-1 || tuplelist[temp.redundantNextValue[var]].values[var]==val;
                return temp.redundantNextValue[var];
            }
            curtuple++;
            if(curtuple>=tuplelist.length) return -1;
            temp=tuplelist[curtuple];
        }
        return curtuple;
    }
    
    /** Debug cross-check of nextin against a plain linear scan (asserts they agree). */
    private int nextin_test(int var, int val, int curtuple)
    {
        int next1=nextin(var, val, curtuple);
        
        while(tuplelist[curtuple].values[var]!=val)
        {
            curtuple++;
            if(curtuple>=tuplelist.length) {curtuple=-1; break;}
        }
        
        assert next1==curtuple;
        return next1;
    }
    
    @Override
    public void add_backtrack_level()
    {
        gsbt.add_backtrack_level();
        super.add_backtrack_level();
    }
    
    @Override
    public void backtrack()
    {
        gsbt.backtrack();
        super.backtrack();
    }
    
    /** Stub: reports entailment only once every variable is unit (assigned). */
    public boolean entailed()
    {
        for(int i=0; i<variables.length; i++)
        {
            if(!variables[i].unit()) return false;
        }
        return true;
    }
}


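/* Obsolete helper, kept for reference only (not compiled). It scanned tuplelist
   directly; seekNextSupport instead derives this minimum from last_tuple_pointer.
   As written, its while loop can never execute: min_nexttuple starts at -1, but
   the loop condition requires min_nexttuple>-1.

private int nextinMinAllvals(int var, int curtupleIndex)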
  {
    // find the minimum next tuple across all values of variable var.
    int min_nexttuple=-1;
    
    
    int offset=variables[var].offset();
    // count domain
    int count=0;
    for(int i=variables[var].lowerbound(); i<=variables[var].upperbound(); i++)
    {
        if(variables[var].is_present(i+offset))
            count++;
    }
    
    boolean [] valdone=new boolean[variables[var].domsize()];
    Arrays.fill(valdone, false);
    
    TupleH curtuple=tuplelist[curtupleIndex];
    
    while(count>0 && min_nexttuple>-1)  // min_nexttuple gets set to -1 if redundantNextValue has -1, indicating no next tuple.
    {
        int thisval=curtuple.values[var];
        if(!valdone[thisval+offset])
        {
            count--;
            if(min_nexttuple>curtupleIndex) min_nexttuple=curtupleIndex;
            valdone[thisval+offset]=true;
        }
        if(count==0) return min_nexttuple;
        
        thisval=curtuple.redundantValues[var];
        if(!valdone[thisval+offset])
        {
            count--;
            if(min_nexttuple>curtuple.redundantNextValue[var]) min_nexttuple=curtuple.redundantNextValue[var];
            valdone[thisval+offset]=true;
        }
        
        curtupleIndex++;
        curtuple=tuplelist[curtupleIndex];
    }
    return min_nexttuple;
  }*/

/**
 * Randomised differential test. The same random comparison constraint is posted once as
 * nightingaletuples and once as reginlhomme. After every propagation, both problems must
 * agree on failure and on every domain. Mismatches are reported through assert, so run
 * with -ea.
 */
class reginlhomme_test
{
    public static void main(String[] args)
    {
        new reginlhomme_test().do_tests();
    }
    
    /** Runs 100000 random problems from a fixed seed, so any failure is reproducible. */
    void do_tests()
    {
        r=new Random(12345);
        sw=new stopwatch("GMT");
        
        for(int i=0; i<100000; i++)
        {
            test_random_prob();
            System.out.println("Done test "+(i+1));
        }
    }
    
    Random r; // random number generator.
    stopwatch sw;  // created in do_tests; not otherwise used in this class.
    
    /**
     * Builds one random 3-variable problem in two copies and compares the initial
     * propagation. Then runs 5 random descents of removals, comparing the copies after
     * each step and backtracking fully at the end of each descent.
     */
    void test_random_prob()
    {
        // Comparison selector, uniform over -2..3; its meaning is defined by comparison_predicate.
        int comp=(int) (r.nextFloat()*6)-2;
        
        final int numvars=3;
        // Variable created with existential(2, ...) and named "bx"; presumably the 0/1 result of the comparison.
        int indexforbool=(int) (r.nextFloat()*numvars);  // 0,1,2
        
        mid_domain [] vars_sqgac= new mid_domain[numvars];
        mid_domain [] vars_other= new mid_domain[numvars];
        
        qcsp prob_sqgac = new qcsp();
        qcsp prob_other = new qcsp();
        
        for(int i=0; i<numvars; i++)
        {
            if(indexforbool==i)
            {
                vars_other[i]=new existential(2, prob_other, "bx"+(i+1));
                vars_sqgac[i]=new existential(2, prob_sqgac, "bx"+(i+1));
            }
            else
            {
                double temp=r.nextFloat();
                int ub, lb;
                if(temp<0.2)
                {   // positive
                    ub=12; lb=3;
                }
                else if(temp<0.4)
                {   // includes 0.
                    ub=9; lb=0;
                }
                else if(temp<0.6)
                {   // spans 0
                    ub=6; lb=-3;
                }
                else if(temp<0.8)
                {   // includes 0 negative
                    ub=0; lb=-9;
                }
                else
                {
                    ub=-3; lb=-12;
                }
                vars_other[i]=new existential(lb, ub, prob_other, "x"+(i+1));
                vars_sqgac[i]=new existential(lb, ub, prob_sqgac, "x"+(i+1));
            }
        }
        
        predicate_wrapper pred=new comparison_predicate(comp, indexforbool);
        
        nightingaletuples c1= new nightingaletuples(vars_other, prob_other, pred);
        reginlhomme c2= new reginlhomme(vars_sqgac, prob_sqgac, pred);
        
        // start the testing
        // first check equivalence of the first pass
        
        assert same_domains(prob_other, prob_sqgac) : "something wrong in constructing the problems";
        
        boolean flag1=prob_other.establish();
        boolean flag2=prob_sqgac.establish();
        
        assert flag2==flag1 : "Found "+flag2+" for sqgac and "+flag1+" for other";
        
        if(!flag2)
        {
            return;
        }
        
        flag1=prob_other.propagate();
        flag2=prob_sqgac.propagate();
        
        assert flag2==flag1 : "Found "+flag2+" for sqgac and "+flag1+" for other";
        
        if(flag2 && !same_domains(prob_other, prob_sqgac))   // if not false, compare doms.
        {
            System.out.println("different domains");
            System.out.println("other:");
            prob_other.printdomains();
            System.out.println("SQGAC:");
            prob_sqgac.printdomains();
            assert false;
        }
        
        if(!flag2)
        {
            return;
        }
        
        prob_other.printdomains();
        for(int descents=0; descents<5; descents++)
        {
            System.out.println("New descent.");
            // now make some random removals until each variable is unit.
            int numassigned=0;
            int forward=0;
            while(numassigned<vars_other.length)
            {
                // first check that an assignment is possible
                boolean assignmentposs=false;
                for(int i=0; i<vars_other.length; i++)
                {
                    if(!vars_other[i].unit())
                    {
                        assignmentposs=true;
                        break;
                    }
                }
                
                if(!assignmentposs)
                    break;
                
                System.out.println("Assignment possible.");
                
                // Open the backtrack level BEFORE removing values, so the backtrack
                // loop at the end of the descent undoes these removals too.
                prob_other.add_backtrack_level();
                prob_sqgac.add_backtrack_level();
                forward++;
                
                // pick 3 removals to make
                for(int i=0; i<3 && numassigned<numvars; i++)
                {
                    int var=(int)(numvars*r.nextFloat());
                    
                    while(vars_other[var].unit())
                    {
                        var=(int)(numvars*r.nextFloat());
                    }
                    
                    // Truncation toward zero gives -19..19, with 0 twice as likely as any other value.
                    int val=(int) (r.nextFloat()*40-20);
                    while(!vars_sqgac[var].is_present(val))
                    {
                        val=(int) (r.nextFloat()*40-20);
                    }
                    System.out.println("Test harness: Removing variable "+prob_sqgac.variables.get(var)+" value "+val);
                    
                    ((mid_domain)prob_other.variables.get(var)).exclude(val, null);
                    ((mid_domain)prob_sqgac.variables.get(var)).exclude(val, null);
                    
                    numassigned=0;
                    for(int j=0; j<numvars; j++)
                    {
                        if(vars_sqgac[j].unit()) numassigned++;
                    }
                }
                
                for(int i=0; i<numvars; i++)
                {
                    assert !vars_other[i].empty();
                    assert !vars_sqgac[i].empty();
                }
                
                flag1=prob_other.propagate();
                flag2=prob_sqgac.propagate();
                
                if(flag2!=flag1)
                {
                    System.out.println("other:");
                    prob_other.printdomains();
                    System.out.println("sqgac:");
                    prob_sqgac.printdomains();
                    
                    assert false : "found "+flag1+" for other_constraint and "+flag2+" for sqgac";
                }
                
                if(flag2 && !same_domains(prob_other, prob_sqgac))   // if not false, compare doms.
                {
                    System.out.println("different domains");
                    System.out.println("other:");
                    prob_other.printdomains();
                    System.out.println("sqgac:");
                    prob_sqgac.printdomains();
                    
                    assert false;
                }
                
                if(!flag1 || !flag2)
                    break;  // no point making more assignments.
            }
            
            // backtrack before starting the next descent.
            for(int i=0; i<forward; i++)
            {
                prob_other.backtrack();
                prob_sqgac.backtrack();
            }
        }
    }
    
    /**
     * Returns true when both problems have the same number of variables and each pair of
     * corresponding variables agrees on size, bounds and membership of every value.
     */
    boolean same_domains(qcsp prob1, qcsp prob2)
    {
        if(prob1.variables.size()!=prob2.variables.size())
            return false;
        
        for(int i=0; i<prob1.variables.size(); i++)
        {
            mid_domain var1 = (mid_domain)prob1.variables.get(i);
            mid_domain var2 = (mid_domain)prob2.variables.get(i);
            
            if(var1.domsize()!=var2.domsize()
               || var1.lowerbound()!=var2.lowerbound()
               || var1.upperbound()!=var2.upperbound())
                return false;
            
            for(int j=var1.lowerbound(); j<=var1.upperbound(); j++)
            {
                if(var1.is_present(j)!=var2.is_present(j))
                    return false;
            }
        }
        return true;
    }
}
