from matplotlib import use; use('Agg')
import numpy as np
import h5py as h5
import pylab as plt
import glob, sys

# Prefix of the per-level HDF5 files (matched as "<fname>.?", i.e. a single
# character suffix) and stem of the output PNG names.
# 'ParticleVelocities_x' is an alternative prefix that can be used instead.
# coord: fractional position (0..1) along the dataset's second axis at which
# every level is sliced; 0.5 is the middle.
fname, coord = 'GridDensity', 0.5

def save_image(image, fname, vrange=None):
    """Save a 2-D array as an image file with one pixel per array element.

    The figure has no axes, ticks or frame, and is sized so that at ``DPI``
    dots per inch every array element maps to one output pixel.

    Parameters
    ----------
    image : array_like, 2-D
        Data to draw. Row 0 is drawn at the bottom (``origin='lower'``).
    fname : str
        Output path, passed to ``savefig``.
    vrange : sequence of two numbers, or None
        ``(vmin, vmax)`` colour-scale limits. ``None`` lets ``imshow``
        autoscale to the data.
    """
    DPI = 72.0
    nrows, ncols = np.shape(image)[:2]
    # figsize is (width, height): width comes from columns, height from rows.
    fig = plt.figure(figsize=(ncols / DPI, nrows / DPI), dpi=DPI)
    try:
        ax = fig.add_axes([0, 0, 1, 1], frameon=False)
        ax.set_axis_off()
        if vrange is None:
            vmin, vmax = None, None  # same as not passing them: autoscale
        else:
            vmin, vmax = vrange[0], vrange[1]
        ax.imshow(image, interpolation='nearest', origin='lower',
                  vmin=vmin, vmax=vmax)
        # Without an explicit dpi, savefig uses rcParams['savefig.dpi'],
        # which need not equal DPI and would rescale the output.
        fig.savefig(fname, dpi=DPI)
    finally:
        # Close the figure so repeated calls do not accumulate open figures.
        plt.close(fig)
    return

# Read one 2-D slice per level from the files matching "<fname>.?",
# processed in sorted filename order. Each file is expected to hold a
# dataset whose name equals the file path, indexed as [0, axis1, :, :].
files = sorted(glob.glob(fname + ".?"))
slices = []
top_dim = None  # axis-1 length of the first file in sorted order
for f in files:
    # Open explicitly read-only: older h5py versions defaulted to a
    # read/write mode, and this script never writes. `with` guarantees the
    # file is closed even if reading fails.
    with h5.File(f, 'r') as fptr:
        dset = fptr[f]
        this_dims = dset.shape
        n = this_dims[1]
        # Index of the slice at fractional position `coord` along axis 1
        # (equal to n/2 + (coord-0.5)*n for even n), clamped so that
        # coord == 1.0 selects the last slice instead of raising IndexError.
        this_coord = min(max(int(coord * n), 0), n - 1)
        if top_dim is None:
            top_dim = n
        # Only this slice is read from disk, not the whole dataset.
        slices.append(dset[0, this_coord, :, :])

# Composite the levels into one square canvas. Level i is upsampled by
# 2**(L-1-i) (L = number of files) and centred, so level 0 covers the whole
# canvas and each later level is painted over the centre. A PNG is written
# after each level is added, always using the colour range of level 0.
n_levels = len(files)
if slices:
    image_size = top_dim * 2**(n_levels - 1)
for i, s in enumerate(slices):
    nrows, ncols = s.shape
    magnify = 2**(n_levels - i - 1)
    height = nrows * magnify
    width = ncols * magnify
    # Offsets that centre this level. `//` keeps them integers (they are
    # slice bounds) and matches Python 2's integer `/` for these ints.
    j0 = image_size // 2 - height // 2
    k0 = image_size // 2 - width // 2
    # Same output as the Python 2 statement `print j0, k0`, in 2 and 3.
    print("%d %d" % (j0, k0))
    if j0 < 0 or k0 < 0 or j0 + height > image_size or k0 + width > image_size:
        raise ValueError(
            "level %d (%dx%d cells, magnified %dx) does not fit in a "
            "%dx%d image" % (i, nrows, ncols, magnify, image_size, image_size))
    if i == 0:
        image = np.zeros((image_size, image_size))
    # Nearest-neighbour upsampling: each cell becomes a magnify x magnify
    # block. Rows map to image axis 0 and columns to axis 1.
    block = np.repeat(np.repeat(s, magnify, axis=0), magnify, axis=1)
    image[j0:j0 + height, k0:k0 + width] = block
    if i == 0:
        immin = image.min()
        immax = image.max()
    save_image(image, "%s_sub%d.png" % (fname, i), vrange=[immin, immax])
