import stokeslet_new as st
import matplotlib.pyplot as plt
import numpy as np
import pickle
import datetime
from streamplot import streamplot

#numCircles = range(20,250,5)
#runsPerDensity = 10

def make_nums(numCircles, runsPerDensity):
    """Build the circle-count grid laid out like the per-run stats arrays.

    Row ``i`` holds ``numCircles[i]`` repeated ``runsPerDensity`` times, so the
    result has shape ``(len(numCircles), runsPerDensity)`` and float dtype.

    Parameters
    ----------
    numCircles : iterable of numbers
        Circle counts, one per row (range, list, NumPy array, ...).
    runsPerDensity : int
        Number of columns (runs per count).

    Returns
    -------
    numpy.ndarray
    """
    # Build the column of counts once and tile it across the runs; this avoids
    # a per-element .index() lookup, which is quadratic and picks the wrong row
    # when a count appears more than once.
    counts = np.asarray(list(numCircles), dtype=float).reshape(-1, 1)
    return np.tile(counts, (1, runsPerDensity))

def load_run(num, run, dir = 'runs', with_times = False):
    """Load one saved simulation run from ``<dir>/numCircle.<num>.run.<run>.data``.

    The file holds five consecutively pickled objects, read in this order:
    ``ss``, ``pf``, ``fg``, ``start``, ``finish``.

    Parameters
    ----------
    num, run
        Circle count and run index used to build the file name.
    dir : str
        Directory containing the run files.
    with_times : bool
        If true, also return ``start`` and ``finish``.

    Returns
    -------
    tuple
        ``(ss, pf, fg)``, or ``(ss, pf, fg, start, finish)`` when
        ``with_times`` is true.
    """
    filename = dir + '/numCircle.' + str(num) + '.run.' + str(run) + '.data'

    # SECURITY: pickle.load can execute arbitrary code; only load run files
    # produced by trusted simulations.
    # The `with` block closes the file even if one of the loads raises.
    with open(filename, 'rb') as fh:
        ss = pickle.load(fh)
        pf = pickle.load(fh)
        fg = pickle.load(fh)
        start = pickle.load(fh)
        finish = pickle.load(fh)

    if with_times:
        return ss, pf, fg, start, finish
    return ss, pf, fg

def get_edge_velocities(fg):
    """Slice the velocity components along the four grid edges.

    Returns ``(U_rear, U_lead, V_top, V_bottom)``:
    the first and last columns of ``fg.U`` and the last and first rows of
    ``fg.V``, respectively.
    """
    horiz, vert = fg.U, fg.V
    return horiz[:, 0], horiz[:, -1], vert[-1, :], vert[0, :]

def get_edge_pressures(fg):
    """Slice the pressure field along the first and last grid columns.

    Returns ``(P_rear, P_lead)`` = ``(fg.P[:, 0], fg.P[:, -1])``.
    """
    pressure = fg.P
    return pressure[:, 0], pressure[:, -1]

def get_pressure_stats(dir = 'runs', numCircles = range(20,250,5), runsPerDensity = 10):
    """Average the edge pressures of every saved run, grouped by circle count.

    For each count in ``numCircles`` and each run in ``range(runsPerDensity)``
    the run is loaded with ``load_run``. The means of ``P_rear`` and ``P_lead``
    (first and last columns of ``fg.P``) are then recorded.

    Parameters
    ----------
    dir : str
        Directory holding the ``numCircle.<num>.run.<run>.data`` files.
    numCircles : iterable of int
        Circle counts to process. Any iterable is accepted.
    runsPerDensity : int
        Number of runs stored per circle count.

    Returns
    -------
    tuple
        ``(nums, P_rear_avg, P_rear_avg_mean, P_lead_avg, P_lead_avg_mean)``.
        ``nums`` and the ``*_avg`` arrays have shape
        ``(len(numCircles), runsPerDensity)``. The ``*_avg_mean`` arrays
        average those over runs (axis 1).
    """
    # Materialise once so len() works for any iterable. enumerate gives each
    # count its own row; the old .index() lookup was quadratic and reused the
    # first row for repeated counts.
    numCircles = list(numCircles)
    shape = (len(numCircles), runsPerDensity)
    nums = np.zeros(shape)
    P_rear_avg = np.zeros(shape)
    P_lead_avg = np.zeros(shape)

    for num_ind, num in enumerate(numCircles):
        for run in range(runsPerDensity):
            _, _, fg = load_run(num, run, dir)
            P_rear, P_lead = get_edge_pressures(fg)

            nums[num_ind, run] = num
            P_rear_avg[num_ind, run] = P_rear.mean()
            P_lead_avg[num_ind, run] = P_lead.mean()

    P_rear_avg_mean = P_rear_avg.mean(1)
    P_lead_avg_mean = P_lead_avg.mean(1)

    return nums, P_rear_avg, P_rear_avg_mean, P_lead_avg, P_lead_avg_mean

def get_stats(dir = 'runs', numCircles = range(20,250,5), runsPerDensity = 10, with_V = False):
    """Average the edge velocities of every saved run, grouped by circle count.

    For each count in ``numCircles`` and each run in ``range(runsPerDensity)``
    the run is loaded with ``load_run``. The means of the four edge profiles
    from ``get_edge_velocities`` are then recorded.

    Parameters
    ----------
    dir : str
        Directory holding the ``numCircle.<num>.run.<run>.data`` files.
    numCircles : iterable of int
        Circle counts to process. Any iterable is accepted.
    runsPerDensity : int
        Number of runs stored per circle count.
    with_V : bool
        If true, also return the top/bottom ``V`` statistics.

    Returns
    -------
    tuple
        ``(nums, U_rear_avg, U_rear_avg_mean, U_lead_avg, U_lead_avg_mean)``.
        When ``with_V`` is true, it is followed by ``V_top_avg,
        V_top_avg_mean, V_bottom_avg, V_bottom_avg_mean``.
        Per-run arrays have shape ``(len(numCircles), runsPerDensity)``.
        The ``*_avg_mean`` arrays average them over runs (axis 1).
    """
    # Materialise once so len() works for any iterable. enumerate gives each
    # count its own row; the old .index() lookup was quadratic and reused the
    # first row for repeated counts.
    numCircles = list(numCircles)
    shape = (len(numCircles), runsPerDensity)
    nums = np.zeros(shape)
    U_rear_avg = np.zeros(shape)
    U_lead_avg = np.zeros(shape)
    V_top_avg = np.zeros(shape)
    V_bottom_avg = np.zeros(shape)

    for num_ind, num in enumerate(numCircles):
        for run in range(runsPerDensity):
            _, _, fg = load_run(num, run, dir)
            U_rear, U_lead, V_top, V_bottom = get_edge_velocities(fg)

            nums[num_ind, run] = num
            U_rear_avg[num_ind, run] = U_rear.mean()
            U_lead_avg[num_ind, run] = U_lead.mean()
            V_top_avg[num_ind, run] = V_top.mean()
            V_bottom_avg[num_ind, run] = V_bottom.mean()

    U_rear_avg_mean = U_rear_avg.mean(1)
    U_lead_avg_mean = U_lead_avg.mean(1)

    if with_V:
        V_top_avg_mean = V_top_avg.mean(1)
        V_bottom_avg_mean = V_bottom_avg.mean(1)
        return (nums, U_rear_avg, U_rear_avg_mean, U_lead_avg, U_lead_avg_mean,
                V_top_avg, V_top_avg_mean, V_bottom_avg, V_bottom_avg_mean)
    return nums, U_rear_avg, U_rear_avg_mean, U_lead_avg, U_lead_avg_mean

def make_avg_fig(nums, U_rear_avg, U_rear_avg_mean, U_lead_avg, U_lead_avg_mean):
    """Scatter the per-run leading-edge flow against cylinder density.

    Each run's ``U_lead_avg`` value is drawn as a blue dot. The per-count mean
    ``U_lead_avg_mean`` is drawn as a red line. The x-axis is ``nums / 100``,
    labelled cylinders per square micrometre.

    ``U_rear_avg`` and ``U_rear_avg_mean`` are accepted so the output of
    ``get_stats`` can be passed straight through; they are not plotted.

    Returns
    -------
    matplotlib.figure.Figure
    """
    # figsize must be passed by keyword: the first positional parameter of
    # plt.figure is the figure identifier `num`, and a tuple there raises.
    fig = plt.figure(figsize=(16, 12))
    for row_nums, row_u_lead in zip(nums, U_lead_avg):
        plt.plot(row_nums / 100, row_u_lead, 'bo')
    plt.plot(nums[:, 0] / 100, U_lead_avg_mean, 'ro-')
    plt.xlabel(r'Cylinders per $\mu$m$^2$')
    plt.ylabel(r'Average outward flow at leading edge')
    return fig

def make_pressure_fig(num, run, dir='runs'):
    """Plot the pressure field and edge pressure profiles for one run.

    Top panel: filled contours of ``fg.P`` on the ``10*fg.X`` / ``10*fg.Y``
    grid, with a colorbar and the selected ``pf.xF``/``pf.yF`` points overlaid
    as white dots.
    Bottom panel: ``P_rear`` (blue) and ``P_lead`` (red) plotted against
    ``fg.Y[:, -1]``.

    Parameters
    ----------
    num, run
        Passed to ``load_run`` to select the data file.
    dir : str
        Directory holding the run files.

    Returns
    -------
    matplotlib.figure.Figure
        Returned for consistency with ``make_fig`` so callers can save it.
    """
    ss, pf, fg = load_run(num, run, dir)
    P_rear, P_lead = get_edge_pressures(fg)

    # The last ss.nF entries of pf.xF/pf.yF in reverse order. These are
    # presumably the cylinder positions; TODO confirm against stokeslet_new.
    circle_indices = range(pf.xF.size - 1, pf.xF.size - 1 - ss.nF, -1)

    fig = plt.figure(figsize=(8, 12))

    plt.subplot(211)
    plt.contourf(10 * fg.X, 10 * fg.Y, fg.P, 200)
    plt.colorbar()
    plt.plot(10 * pf.xF[circle_indices], 10 * pf.yF[circle_indices], 'w.')

    plt.subplot(212)
    plt.plot(fg.Y[:, -1], P_rear, 'b')
    plt.plot(fg.Y[:, -1], P_lead, 'r')

    return fig

def make_fig(num, run, dir='runs'):
    """Plot the flow field and leading-edge velocity profile for one run.

    Top panel: filled contours of ``fg.speed`` with streamlines of
    ``(fg.U, fg.V)``, and the selected ``pf.xF``/``pf.yF`` points overlaid as
    white dots. Coordinates are scaled by 10 and labelled in micrometres.
    Bottom panel: ``U_lead`` (last column of ``fg.U``) plotted against ``y``,
    using the same scaling.

    Parameters
    ----------
    num, run
        Passed to ``load_run`` to select the data file. ``num / 100`` is
        shown in the title as cylinders per square micrometre.
    dir : str
        Directory holding the run files.

    Returns
    -------
    matplotlib.figure.Figure
    """
    ss, pf, fg = load_run(num, run, dir)
    U_lead = get_edge_velocities(fg)[1]

    # The last ss.nF entries of pf.xF/pf.yF in reverse order. These are
    # presumably the cylinder positions; TODO confirm against stokeslet_new.
    circle_indices = range(pf.xF.size - 1, pf.xF.size - 1 - ss.nF, -1)

    fig = plt.figure(figsize=(8, 12))

    plt.subplot(211)
    plt.contourf(10 * fg.X, 10 * fg.Y, fg.speed, 200)
    streamplot(10 * fg.x, 10 * fg.y, fg.U, fg.V, density=2)
    plt.plot(10 * pf.xF[circle_indices], 10 * pf.yF[circle_indices], 'w.')
    plt.xlabel(r'$x$ ($\mu$m)')
    plt.ylabel(r'$y$ ($\mu$m)')
    plt.title(r'Membrane speed over whole cell')

    plt.subplot(212)
    # Scale y by 10 as in the top panel so the axis matches its micrometre label.
    plt.plot(10 * fg.Y[:, -1], U_lead)
    plt.xlabel(r'$y$ ($\mu$m)')
    plt.ylabel(r'$u$')
    plt.title(r'Membrane speed normal to the leading edge')

    titleString = r'Density over whole cell: ' + str(num/100.0) + r' cylinders/$\mu$m$^2$'
    plt.suptitle(titleString)

    return fig