from math import log10
import math
import numpy as na
#nullfmt = NullFormatter()
from mpi4py import MPI
import re
import glob
import os
import time

# MPI handles: every process runs this script; `rank` is this process's index
# in COMM_WORLD and `size` is the number of processes.
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size=comm.size

# Run configuration.
#   redshift          -- used to build the names of the input files and of the
#                        output files.
#   redshift_end      -- used in the output file names and in the final
#                        ((1+redshift_end)/(1+redshift))**2 temperature rescale.
#   redshift_previous -- used to locate the Delta_* files of the previous step.
#   tfinal            -- integration time; multiplied by `yr` where it is used.
redshift = 15.0
redshift_end = 11.0
redshift_previous = 15.4
tfinal =   150439000      # in years
# Reference numbers left by the author; presumably tfinal values (years) for
# other redshift intervals / end redshifts -- TODO confirm.
# 24: 4345107; 23.5:4615649; 23: 9980088 22: 11126760 21: 12466632 20: 6807916
# 19.5: 7235949; 19.00: 7702654; 18.50: 9740892 17.91: 8176002 17.45: 8666562  16.99:7954699 
# 16.60: 12848160 15.99: 13803675 15.40: 10208909; 15.00: 28263000 
# to z=10 209660000 ; z=11 150439000 z=12 101439000 z=13 61439000
#to z=6  681803000  ;  z=7   505439000   ;z=8 378439000   ;  z=9  282439000   
# The arrays handled below hold nx**3 cells. Note that several input file names
# hard-code "512" independently of this value.
nx = 512
Ex = 1e3   # photon energy (eV)
# Not referenced anywhere else in the visible part of the file.
cfl = 0.5

# Physical constants; the values look like CGS units -- TODO confirm.
# NOTE(review): erg_eV multiplies a temperature inside CH() before the log is
# taken; its value matches the Boltzmann constant in eV/K, so the name is
# probably misleading -- confirm.
erg_eV = 8.61423e-5
# G, H0, Omega_b and Omega_m are only referenced by the commented-out rhoH
# expression below.
G = 6.673e-8
H0 = 71.0
Omega_b = 0.0449
Omega_m = 0.266
# yH multiplies the density field and divides the HII fraction when the inputs
# are loaded; presumably the hydrogen mass fraction -- confirm.
yH = 0.76
# Divides the photon energy (in eV) in the photoionization rate; presumably
# the hydrogen ionization energy in eV.
IH = 13.6
h = 6.626e-27
# Converts between temperature and the thermal energy variable (Eth = kb * T).
kb = 1.38e-16
# Multiplies energies given in eV (e.g. Ex*eV).
eV = 1.602e-12
# Multiplies times given in years (tfinal*yr, 1e6*yr).
yr = 3.1557e7
# Divides rhoH*x to obtain ne in the integration loop.
# NOTE(review): 1.661e-24 matches the atomic mass unit in grams rather than
# the hydrogen-atom mass (~1.67e-24 g) -- confirm this is intended.
mH = 1.661e-24

#rhoH = 3 * (H0/3.086e19)**2 / (8 * na.pi * G) * yH * (1.0+redshift)**3
# Frequency corresponding to Ex; not referenced later in the visible code.
nu = Ex*eV / h

# Only referenced by a commented-out temperature rescale further down.
tc = 1.459

#read Xray flux

if rank == 0:
 filelist = glob.glob("jXRAY_*_%6.3f_dat"%redshift)
 xraybins = len(filelist)
 xrayzstart = na.zeros([xraybins])
 xrayzend = na.zeros([xraybins])
 xrayenergy = na.zeros([xraybins])
 Xray_r = []
 for i,name in enumerate(filelist):
  xrayzstart[i] = float(re.findall(r"[-+]?\d*\.\d+|\d+",name)[0])
  xrayzend[i] = float(re.findall(r"[-+]?\d*\.\d+|\d+",name)[1])
  xrayenergy[i] = Ex*(1+redshift)/(1+(xrayzstart[i]+xrayzend[i])/2)
  f = open(name,'r')
  #f.seek(12,os.SEEK_SET)
  #Xray_r.append(na.fromfile(f,dtype='float64').reshape(nx,nx,nx))
  Xray_r.append(na.fromfile(f,dtype='float64')*(1+redshift)/(1+(xrayzstart[i]+xrayzend[i])/2))
  print "Reading %s"%name, xrayzstart[i], xrayzend[i], xrayenergy[i] 
  f.close()
else: 
 xraybins = None

print "rank %s: beginning communication"%rank

xraybins=comm.bcast(xraybins, root=0) 

if rank !=0:
  Xray_r = na.zeros(xraybins)

# Read cooling table (primordial abundances -- 2nd column)
cool_table = na.loadtxt("zcool_sd93.dat")[:,0:2]

if rank == 0:
  rho = (na.fromfile('Density_512_%5.2f.dat'%redshift))*yH
  Tem = na.fromfile('Temperature_512_%5.2f.dat'%redshift)
  #if redshift < 18.5: Tem = Tem*tc
  DTem = na.fromfile('Delta_Temperature_%5.2f_%5.2f.dat'%(redshift_previous,redshift))
  DTem[DTem < 0] = 0.0
  Tem = Tem+DTem
  ef = (na.fromfile('HII_Fraction_512_%5.2f.dat'%redshift))/yH 
  Def = na.fromfile('Delta_Electron_fraction_%5.2f_%5.2f.dat'%(redshift_previous,redshift))
  Def[Def < 0] = 0.0
  ef = ef+Def
else:
  rho = None
  Tem = None
  DTem = None
  ef = None
  xrayenergy = None
#Xflux = na.fromfile('XrayFlux.dat')

rho_local = na.zeros(nx**3/size)
Tem_local = na.zeros(nx**3/size)
ef_local  = na.zeros(nx**3/size)
#Xflux = na.zeros(nx**3/size)
Xflux_local = []
#for i in xrange(xraybins):
#  Xflux_local.append(na.zeros(nx**3/size))

comm.Scatter(rho,rho_local,root=0)
comm.Scatter(Tem,Tem_local,root=0)
comm.Scatter(ef,ef_local,root=0)
for i in xrange(xraybins):
   xray = na.zeros(nx**3/size)
   comm.Scatter(Xray_r[i],xray,root=0)  
   Xflux_local.append(xray)
xrayenergy = comm.bcast(xrayenergy, root=0)
Xflux_local=na.array(Xflux_local)
for i in xrange(xraybins):
    Xflux_local[i] = Xflux_local[i]/(Ex*eV)
Xflux_local=Xflux_local.T

if rank == 0:
   Tem = Tem-DTem
   ef = ef - Def

comm.Barrier()

print "rank %s: finishing communication"%rank

del rho,Xray_r,DTem
#del Tem, ef

def CH(T):
    """Evaluate exp(P(ln(T * erg_eV))) for a fixed 8th-degree polynomial P.

    The caller uses the result as a rate coefficient multiplied by ne in the
    ionization balance; presumably the collisional ionization coefficient of
    hydrogen as a function of temperature -- TODO confirm the source of the
    fit and its validity range.
    """
    # Coefficients in order of increasing power of the log.
    fit_coeffs = (-32.71396786, 13.536556, -5.73932875, 1.56315498,
                  -0.2877056, 3.48255977e-2, -2.63197617e-3,
                  1.11954395e-4, -2.03914985e-6)
    log_arg = na.log(T*erg_eV)
    exponent = 0.0
    for power, c in enumerate(fit_coeffs):
        exponent = exponent + c * log_arg**power
    return na.exp(exponent)

def alphaB(T):
    """Return 2.59e-13 * (T / 1e4)**(-0.7).

    Used by the caller as the recombination term multiplying x*ne; presumably
    the case-B hydrogen recombination coefficient with T in Kelvin -- confirm.
    """
    t4 = T / 1e4
    return 2.59e-13 * t4**(-0.7)

def sigmaH(E):
    """Analytic cross-section fit evaluated at energy E.

    The caller passes the per-bin photon energies (derived from Ex, which is
    in eV). Presumably the hydrogen photoionization cross-section in cm^2 --
    TODO confirm the reference for the fit constants.
    The result is zero at E = 0.4298 because of the (y - 1)**2 factor.
    """
    y = E / 0.4298
    high_energy_term = 1 + na.sqrt(E / 14.13)
    return 5.475e-14 * (y - 1)**2 * y**(-4.0185) * \
        high_energy_term**(-2.963)

def secondary_ion(x):
    """Return 0.3908 * (1 - x**0.4092)**1.7592.

    The caller passes the evolving ionized fraction and multiplies the result
    into the photoionization rate. The value is 0.3908 at x = 0 and 0 at
    x = 1. No range check is done here; the caller caps x at 1 beforehand.
    """
    neutral_term = 1.0 - x**0.4092
    return 0.3908 * neutral_term**1.7592

def secondary_heat(x):
    """Return 0.9971 * (1 - (1 - x**0.2663)**1.3163).

    The caller passes the evolving ionized fraction and multiplies the result
    into the heating rate. The value is 0 at x = 0 and 0.9971 at x = 1.
    No range check is done here; the caller caps x at 1 beforehand.
    """
    neutral_term = (1.0 - x**0.2663)**1.3163
    return 0.9971 * (1.0 - neutral_term)



# Per-cell time integration of the ionized fraction x and the thermal energy
# variable Eth = kb*T over tfinal years, using forward-Euler sub-steps that
# change x and Eth by at most ~1% each (and never exceed dt0).
T_final_local = na.zeros(nx**3/size)
x_final_local = na.zeros(nx**3/size)

# Cross-section for every X-ray bin.
sigmalist = [sigmaH(E) for E in xrayenergy]
sigma_per_bin = na.asarray(sigmalist)

# Sum over bins of flux * sigma * energy factor. These sums do not depend on
# the evolving state, so they are evaluated once per cell here instead of on
# every sub-step; only the secondary_ion/secondary_heat prefactors change.
ion_rate_local = na.dot(Xflux_local, sigma_per_bin * (xrayenergy / IH))
heat_rate_local = na.dot(Xflux_local, sigma_per_bin * (xrayenergy * eV))

# Cooling table columns: log10(T) and the log10 cooling value.
cool_logT = cool_table[:,0]
cool_logL = cool_table[:,1]

t_end = tfinal*yr
dt0 = 1e6 * yr      # upper bound on a single sub-step

for ii in xrange(len(rho_local)):
    rhoH = rho_local[ii]
    x = ef_local[ii]
    Eth = kb * Tem_local[ii]
    t = 0.0
    while t < t_end:
        ne = rhoH*x/mH
        T = Eth / kb
        if x > 1: x = 1
        kph = secondary_ion(x) * ion_rate_local[ii]
        heat = secondary_heat(x) * heat_rate_local[ii]
        if T < 1e4:
            cool = 0.0
        else:
            # The table is not used above log10(T) = 8.45.
            logT = min(log10(T), 8.45)
            # Linear interpolation inside the bracketing table interval.
            # (The previous searchsorted-based code used the interval *after*
            # the bracketing one, i.e. it extrapolated.)
            # NOTE(review): the table value is used directly as an energy
            # loss rate, without a density factor -- confirm this is intended.
            cool = 10.0**na.interp(logT, cool_logT, cool_logL)
        dx_dt = (1.0 - x) * (kph - ne*CH(T)) - x*ne*alphaB(T)
        dE_dt = heat - cool
        # Step-size control. A vanishing rate imposes no limit, and x == 0
        # must not force dt = 0 (that would stall the loop forever).
        dt = dt0
        if x > 0 and dx_dt != 0:
            dt = min(dt, 0.01 * abs(x/dx_dt))
        if dE_dt != 0:
            dt = min(dt, 0.01 * abs(Eth/dE_dt))
        # Land exactly on t_end instead of overshooting it.
        remaining = t_end - t
        if dt >= remaining:
            dt = remaining
            t = t_end
        else:
            t += dt
        x += dx_dt * dt
        Eth += dE_dt * dt
    T_final_local[ii] = Eth/kb
    x_final_local[ii] = x

# Collect the per-rank results on root and write the output files.
# Only root needs full-size receive buffers.
is_root = (rank == 0)
T_final = na.zeros(nx**3) if is_root else None
x_final = na.zeros(nx**3) if is_root else None

comm.Gather(T_final_local, T_final, root=0)
comm.Gather(x_final_local, x_final, root=0)

if is_root:
    # Scale the temperature change relative to the starting field by
    # ((1+redshift_end)/(1+redshift))**2 before adding it back.
    z_scale = ((1+redshift_end)/(1+redshift))**2
    T_final = (T_final - Tem)*z_scale + Tem
    z_pair = (redshift, redshift_end)
    T_final.tofile('Final_Temperature_%5.2f_%5.2f.dat' % z_pair)
    x_final.tofile('Final_Electron_fraction_%5.2f_%5.2f.dat' % z_pair)
    # Deltas relative to the starting fields.
    delta_T = T_final - Tem
    delta_x = x_final - ef
    delta_T.tofile('Delta_Temperature_%5.2f_%5.2f.dat' % z_pair)
    delta_x.tofile('Delta_Electron_fraction_%5.2f_%5.2f.dat' % z_pair)
#print time.time()-t_start
