# This program reads a map series and applies kriging to produce a
# reconstruction with global coverage.
# If a distance is specified, the covariance function is constructed
# using the specified distance rather than calculated from the data.
# Usage:
#  python krig1vb.py map.dat > new.map
#  python krig1vb.py map.dat 830.0 > new.map


import sys, math, numpy, scipy.linalg, scipy.optimize


# read a month of map data
def read_map( lines ):
  """Parse one month of gridded map data.

  lines[0] holds two integers, the month and the year, in either order:
  the smaller is taken as the month and the larger as the year.  The
  following 36 lines each hold (at least) 72 whitespace separated values.
  A value written without a decimal point is scaled by 0.01.  Any value
  that is not greater than -99.0 after scaling is treated as missing and
  stored as NaN.

  Returns (year, month, smap), where smap is a 36x72 list of lists whose
  row order is reversed relative to the order of the input lines.
  Raises IndexError if a row or the list of lines is too short.
  """
  nrow, ncol = 36, 72
  header = lines[0].split()
  # sorting lets the header be given as "month year" or "year month"
  month, year = sorted( [int(header[0]), int(header[1])] )
  smap = []
  for j in range(nrow):
    words = lines[j+1].split()
    row = []
    for i in range(ncol):
      word = words[i]
      # values without a decimal point are stored in hundredths
      t = float( word ) if '.' in word else 0.01*float( word )
      row.append( t if t > -99.0 else numpy.nan )
    smap.append( row )
  smap.reverse()
  return year, month, smap


# write a month of map data
def write_map( year, month, smap ):
  """Format one month of map data as a list of text lines.

  The first line holds the year and month.  One line follows per row of
  smap, with the rows emitted in reverse order.  Each value is written
  with "%7.3f"; NaN values are written as -99.9.  Fields are separated
  by single spaces and every line ends with a newline.
  """
  lines = ["%4d %2d\n"%(year,month)]
  for row in reversed( smap ):
    fields = [ "-99.9" if numpy.isnan(val) else "%7.3f"%(val) for val in row ]
    lines.append( " ".join( fields ) + "\n" )
  return lines


# calculate great circle distance
def distance( la1, ln1, la2, ln2 ):
  """Great circle distance between two points on a sphere of radius 6371.0.

  la1,ln1 and la2,ln2 are the latitude and longitude of the two points
  in radians.  The cosine of the separation angle is scaled by 0.99999999
  before the arc cosine is taken, which keeps the argument strictly
  inside [-1,1] when rounding would otherwise push it just outside.  As a
  consequence coincident points give a small positive distance rather
  than exactly zero.
  """
  cos_angle = ( math.sin(la1)*math.sin(la2) +
                math.cos(la1)*math.cos(la2)*math.cos(ln1-ln2) )
  angle = math.acos( 0.99999999*cos_angle )
  return 6371.0*angle


# functional form for correlations
def fitfunc(p,x):
  """Evaluate exp(-p[0] - p[1]*x).

  p is a sequence whose first two entries are the constant and linear
  coefficients of the exponent; x may be a scalar or a numpy array.
  """
  offset, rate = p[0], p[1]
  return numpy.exp( -offset - rate*x )


# prepare a list of intercell distances for flattened maps
def prepare_dists():
  """Build the matrix of distances between all cells of a 36x72 grid.

  Cell coordinates are taken at the cell centres of a regular grid with
  36 latitude rows spanning -90..90 degrees and 72 longitude columns
  spanning -180..180 degrees.  Cells are numbered as in a row-major
  flattened map (longitude varies fastest).

  Returns a symmetric [2592,2592] numpy array whose [i,j] entry is
  distance() evaluated between cells i and j.
  """
  nla, nln = 36, 72
  # cell centre coordinates in radians
  la = [(math.radians(180*(i+0.5)/nla- 90.0)) for i in range(nla)]
  ln = [(math.radians(360*(i+0.5)/nln-180.0)) for i in range(nln)]
  # per-cell coordinates in flattened order: each latitude is repeated
  # across its row while the longitudes cycle once per row
  las = numpy.repeat( la, nln )
  lns = numpy.tile( ln, nla )
  npt = nla*nln
  dists = numpy.zeros([npt,npt])
  # evaluate the upper triangle only and mirror it
  for i in range(npt):
    for j in range(i,npt):
      d = distance( las[i], lns[i], las[j], lns[j] )
      dists[i,j] = d
      dists[j,i] = d
  return dists


# fit data for a combination of land/ocean maps
def fit_correlation( dists, nmap1, nmap2 ):
  """Fit the functional form of fitfunc() to an empirical variogram.

  dists is the [npt,npt] matrix of intercell distances.  nmap1 and nmap2
  are [npt,nmonth] arrays of map values in which NaN marks missing data.

  For every pair of cells (i1,i2) separated by less than 10000.0, half
  the squared difference between the nmap1 series of cell i1 and the
  nmap2 series of cell i2 is accumulated into distance bins of width
  300.0, using only the months where both values are present.  The bin
  means give an empirical semivariogram.  The mean of the upper half of
  the non-empty bins (Yinf) is taken as its large distance limit, and
  Yinf minus the semivariogram is fitted with fitfunc() by least squares.

  Side effects: the fitted coefficients are printed to stderr, and the
  file "variogram.dat" is written in the current directory with one line
  per non-empty bin holding: mean distance, Yinf-semivariance, fitted
  value, semivariance, Yinf-fitted value.

  Returns the fitted coefficient array.
  Raises Exception("Leastsq fail") if leastsq reports no solution.
  """
  nbin     = 100      # number of distance bins
  binwidth = 300.0    # bin width, in the units of dists
  maxdist  = 10000.0  # pairs at or beyond this separation are ignored
  # unpacking also checks that nmap1 is two dimensional
  npt, nmonth = nmap1.shape

  # per bin: number of samples, summed distance, summed half squared diff
  dhistn = numpy.zeros( nbin )
  dhistx = numpy.zeros( nbin )
  dhisty = numpy.zeros( nbin )
  for i1 in range(npt):
    series1 = nmap1[i1,:]
    for i2 in range(npt):
      dist = dists[i1,i2]
      if dist < maxdist:
        # round to the nearest bin
        d = int(dist/binwidth+0.5)
        # NaN in either series gives NaN here and is excluded below
        diff = numpy.square(series1-nmap2[i2,:])/2.0
        n = numpy.sum(numpy.logical_not(numpy.isnan(diff)))
        if n > 0:
          dhistn[d] += n
          dhistx[d] += n*dist
          dhisty[d] += numpy.nansum(diff)

  # histogram means over the non-empty bins
  used = dhistn > 0.5
  X = dhistx[used]/dhistn[used]
  Y = dhisty[used]/dhistn[used]
  # convert semivariogram to correlation; floor division keeps the slice
  # index an integer whichever division rules are in force
  Yinf = numpy.mean( Y[len(Y)//2:] )
  Y = Yinf - Y

  # for more terms...
  #  e = -p[0]-p[1]*x-p[2]*numpy.power(x,2) #-p[3]*numpy.power(x,3)-p[4]*numpy.power(x,4)
  #  p0 = [0.0, 0.0, 0.0, 0.0, 0.0]

  def errfunc(p,y,x):
    return y-fitfunc(p,x)

  # fit analytical form
  p0 = [0.0, 0.0]
  p1,ier = scipy.optimize.leastsq( errfunc, p0[:], args=(Y,X) )
  sys.stderr.write( "%s\n"%(p1,) )
  # leastsq reports a solution with any of the status codes 1, 2, 3 or 4
  if ier not in (1, 2, 3, 4): raise( Exception( "Leastsq fail" ) )

  with open( "variogram.dat", "w" ) as f:
    for i in range(len(X)):
      fit = fitfunc(p1,X[i])
      f.write( "%12.2f %12.4f %12.4f %12.4f %12.4f\n"%(X[i],Y[i],fit,Yinf-Y[i],Yinf-fit))

  # return function coeffs
  return p1


# build the correlation matrix from the distance matrix
def prepare_correls( dists, p ):
  """Evaluate fitfunc() for every pair of cells.

  dists is the square numpy array of intercell distances and p the
  coefficients passed to fitfunc().  Only the upper triangle (including
  the diagonal) of dists is used; it is mirrored into the lower triangle
  so the result is symmetric even if dists is not.

  Returns a numpy array with the same shape as dists.
  """
  # one vectorised evaluation instead of a call per matrix element
  upper = numpy.triu( fitfunc( p, dists ) )
  # add the strict upper triangle, transposed, to fill the lower triangle
  return upper + numpy.triu( upper, 1 ).T


# precondition the matrix to ensure a[i,i]=1
def precondition( correl ):
  """Rescale a square matrix symmetrically so that its diagonal is 1.

  With s[i] = 1/sqrt(correl[i,i]) the result is
  s[i]*correl[i,j]*s[j], so a symmetric input stays symmetric.  The
  input is not modified.

  Raises ValueError if a diagonal entry is negative and
  ZeroDivisionError if a diagonal entry is zero.
  """
  n = correl.shape[0]
  s = numpy.array( [ 1.0/math.sqrt( correl[i,i] ) for i in range(n) ] )
  # scale row i by s[i] and column j by s[j]; a plain s*correl*s would
  # broadcast s along the columns twice and leave the rows unscaled
  return s[:,numpy.newaxis]*correl*s[numpy.newaxis,:]


# interpolate a map
def interpolate( data, correls ):
  """Fill the missing (NaN) entries of a flattened map by kriging.

  data is a 1-d numpy array in which NaN marks unobserved cells and
  correls is the square matrix relating every pair of cells, indexed in
  the same order as data.

  The observed-observed block of correls is bordered with a row and a
  column of ones (and a zero corner), and the observed-unobserved block
  with a row of ones, so that the weights solved for each unobserved
  cell are constrained to sum to one.  Each unobserved cell is then set
  to the weighted sum of the observed values.

  Array shapes are reported on stderr as a progress diagnostic.

  Returns a new array: observed entries are copied unchanged and NaN
  entries are replaced by their interpolated values.  The input is not
  modified.  If nothing is missing a plain copy is returned.
  """
  unobsflag = numpy.isnan(data)
  obsflag = numpy.logical_not( unobsflag )
  # nothing to interpolate: skip the solve altogether
  if not unobsflag.any():
    return data.copy()
  # set up matrices
  obsrows = correls[obsflag,:]
  a = obsrows[:,obsflag]
  b = obsrows[:,unobsflag]
  c = data[obsflag]
  nobs = a.shape[0]
  a = numpy.vstack( [
        numpy.hstack( [ a, numpy.ones([nobs,1]) ] ),
        numpy.hstack( [ numpy.ones([1,nobs]), numpy.zeros([1,1]) ] )
    ] )
  b = numpy.vstack( [ b,numpy.ones([1,b.shape[1]]) ] )
  # the trailing zero cancels the multiplier row of the solution
  c = numpy.hstack( [ c,numpy.zeros([1]) ] )
  # solve for basis function weights
  sys.stderr.write( "%s %s %s\n"%(a.shape, b.shape, c.shape) )
  x = scipy.linalg.solve( a, b )
  sys.stderr.write( "%s\n"%(x.shape,) )
  # calculate temperatures
  t = numpy.dot( c, x )
  sys.stderr.write( "%s\n"%(t.shape,) )
  # and store: t is ordered like the unobserved cells of data
  result = data.copy()
  result[unobsflag] = t
  sys.stderr.write( "%s %s\n"%(result.shape, result[0]) )
  return result


# MAIN PROGRAM
# Command line: the map file, then an optional distance.  When the distance
# is given it defines the correlation function directly, otherwise the
# function is fitted to the data.
# default values
datafile1 = None
dist      = None
if len(sys.argv) > 1:
  datafile1 = sys.argv[1]
if len(sys.argv) > 2:
  dist = float(sys.argv[2])
if datafile1 is None:
  sys.exit( "Usage: python %s map.dat [distance] > new.map"%(sys.argv[0]) )

# read data: each month occupies 37 lines (1 header line + 36 map rows)
nmonths = 9999
with open( datafile1 ) as f:
  lines1 = f.readlines()
# floor division keeps the month count an integer for range()
nmonths = min(nmonths,len(lines1)//37)

# parse maps
maps = []
for m in range(nmonths):
  year1,month1,tmap1 = read_map( lines1[37*m:37*m+37] )
  maps.append( ( year1, month1, tmap1 ) )

# distance matrix
dists = prepare_dists()

# collect the flattened maps of the base period, one column per month
base0,base1 = 1981,2010
nmap1 = []
for m in range(nmonths):
  year1, month1, tmap1 = maps[m]
  if base0 <= year1 <= base1:
    nmap1.append( numpy.array(tmap1).flatten() )
nmap1 = numpy.array(nmap1).T

# prepare correlation functions
if dist is None:
  p00 = fit_correlation( dists, nmap1, nmap1 )
else:
  p00 = [0.0,1.0/dist]
sys.stderr.write( "Distance  %s\n"%(1.0/p00[1]) )

# prepare Kriging correlation matrices
correl = prepare_correls( dists, p00 )
#correl = precondition( correl )
sys.stderr.write( "%s\n"%(correl,) )

# produce interpolated maps
for m in range(0,nmonths):
  year1, month1, tmap1 = maps[m]
  sys.stderr.write( "Processing: %s %s\n"%(year1, month1) )
  # flatten data
  data1 = numpy.array(tmap1).flatten()
  # interpolate
  data2 = interpolate( data1, correl )
  # expand data
  tmap1 = data2.reshape(36,72).tolist()
  # write: the lines already end with a newline
  for l in write_map( year1, month1, tmap1 ):
    sys.stdout.write( l )

