import plotly.graph_objects as go
import plotly.io as pio 
import numpy as np
from scipy.interpolate import RegularGridInterpolator
import matplotlib.pyplot as plt 
import sys
from ase.io.vasp import read_vasp

# Command-line arguments:
#   1. NMRCURBX, NMRCURBY or NMRCURBZ file containing the current density
#   2. stride n: plot every nth grid point in each in-plane direction
#      (n=1 plots every point, n=2 every other point, ...)
#   3. slice position along the first lattice vector in direct (fractional)
#      coordinates, e.g. 0.5 is half way through the cell
#   4. arrow-scale exponent S: an arrow for a current of magnitude |j| is
#      drawn |j| / 10**S plot-axes widths long (matplotlib quiver,
#      scale_units='width')
if len(sys.argv) < 5:
    sys.exit(f"usage: {sys.argv[0]} NMRCURB_FILE STRIDE HEIGHT SCALE_EXPONENT")

jcar = str(sys.argv[1])
num = int(sys.argv[2])
# Validate the stride up front: a zero stride would otherwise only fail
# (as an invalid slice step) after the whole data file has been parsed.
if num < 1:
    sys.exit(f"STRIDE must be a positive integer, got {num}")
height = float(sys.argv[3])
S = float(sys.argv[4])

# Parse the NMRCUR* file: a POSCAR-style structure header (system line,
# scale factor, three lattice rows, species, counts, coordinate mode,
# positions) followed by the current-density grid data.
with open(jcar, 'r') as f:
    readline = f.readline

    # ---- structure header ----
    system = readline()                         # free-text system line
    scale = float(readline())                   # multiplies the lattice rows
    lattice_matrix = np.array(
        [[float(tok) for tok in readline().split()] for _row in range(3)]
    )
    atoms = readline().split()                  # species labels
    atoms_count = [int(tok) for tok in readline().split()]
    fractional = "Direct" in readline()         # coordinate-mode line

    # One position line per atom, species by species.
    position = [
        [float(tok) for tok in readline().split()]
        for species_idx in range(len(atoms))
        for _atom in range(atoms_count[species_idx])
    ]

    readline()                                  # line separating positions from grid data

    # ---- current-density data ----
    gridc = [int(tok) for tok in readline().split()]
    # Values are packed 5 per line, so each block spans ceil(N/5) lines.
    n_lines = int(np.ceil(gridc[0] * gridc[1] * gridc[2] / 5))
    tmp = []
    # NOTE(review): four passes are made, but only 3*N values are reshaped
    # afterwards. Presumably the file holds three blocks and the last pass
    # only reads empty strings at EOF (''.split() adds nothing). Confirm
    # against the file layout.
    for _block in range(4):
        for _line in range(n_lines):
            tmp.extend(float(tok) for tok in readline().split())
        # Skip the line after each block (presumably the next block's grid
        # header or a separator; confirm).
        readline()
    jden = np.array(tmp)

# Split the flat value stream into the three current-density components.
# In C order the last reshape axis varies fastest, so the file stores the
# first grid index fastest. Reversing the axes afterwards yields arrays
# indexed [ia, ib, ic] (first, second, third grid index).
print(np.size(jden), gridc[2], gridc[1], gridc[0])
j = jden.reshape(3, *gridc[2::-1])
jx, jy, jz = (np.array(component).transpose(2, 1, 0) for component in j)
# NOTE(review): an earlier comment recorded that j[0] was all zeros for the
# author's input file.

# Squared magnitude |j|^2 at every grid point.
jdotj = jx * jx + jy * jy + jz * jz

# Fractional coordinate of every grid point along each lattice vector.
# Point k of n sits at k/n, the same spacing the script uses for the
# Cartesian grid. The grid is periodic, so 1.0 (which would coincide with
# 0.0) is excluded via endpoint=False. The previous linspace(0, 1, n)
# spacing of 1/(n-1) disagreed with that convention.
x_vals = np.linspace(0, 1, gridc[0], endpoint=False)
y_vals = np.linspace(0, 1, gridc[1], endpoint=False)
z_vals = np.linspace(0, 1, gridc[2], endpoint=False)

# Diagnostic shape report.
print("x=", np.shape(x_vals))
print("y=", np.shape(y_vals))
print("z=", np.shape(z_vals))
print("jx=", np.shape(jx))

# Cartesian coordinates of every grid point, stored in xx/yy/zz with shape
# (gridc[0], gridc[1], gridc[2]) and indexed [ia, ib, ic] like jx/jy/jz.
# To plot the cell in fractional coordinates instead, uncomment the next
# line (note that `scale` is still applied to it).
#lattice_matrix = np.array([[1,0,0],[0,1,0],[0,0,1]])
a, b, c = scale*lattice_matrix

# Grid point k along an axis with n points sits at fractional coordinate
# k/n (periodic grid: point n would coincide with point 0). Each axis is
# divided by its OWN point count: a <-> gridc[0], b <-> gridc[1],
# c <-> gridc[2]. The previous loop swapped gridc[0] and gridc[2], which
# misplaced every point on non-cubic grids.
frac_a = np.arange(gridc[0]) / gridc[0]
frac_b = np.arange(gridc[1]) / gridc[1]
frac_c = np.arange(gridc[2]) / gridc[2]
fa, fb, fc = np.meshgrid(frac_a, frac_b, frac_c, indexing='ij')

# r = fc*c + fb*b + fa*a for every point; the trailing axis holds (x, y, z).
pos = fc[..., None] * c + fb[..., None] * b + fa[..., None] * a
xx, yy, zz = pos[..., 0], pos[..., 1], pos[..., 2]

# Index of the plane along the first lattice vector (axis 0 of jx/jy/jz and
# xx/yy/zz). Wrapping modulo the periodic grid means height = 1.0 selects
# the same plane as 0.0 instead of indexing one past the end. Heights in
# (-1, 1) select exactly the same plane as before.
x_slice = int(height*gridc[0]) % gridc[0]

# Take that plane, keeping every num-th point in both in-plane directions.
z_slice = zz[x_slice, ::num, ::num]
y_slice = yy[x_slice, ::num, ::num]
Bz_slice = jz[x_slice, ::num, ::num]
By_slice = jy[x_slice, ::num, ::num]
v0_slice = jdotj[x_slice, ::num, ::num]

plt.figure(figsize=(24,20))
# Filled contours of |j|^2, with the in-plane current drawn on top.
plt.contourf(y_slice, z_slice, v0_slice, alpha=0.5)
plt.colorbar()
# With scale_units='width', an arrow of magnitude m is drawn m / scale
# axes-widths long, so a larger S gives shorter arrows.
plt.quiver(y_slice, z_slice, By_slice, Bz_slice, scale=10.0**S, scale_units='width')
plt.xlabel("y")
plt.ylabel("z")
# NOTE: "x = height" in the title is the fractional position along the first
# lattice vector, not a Cartesian x coordinate.
plt.title("Current density (jy, jz) and intensity (j dot j) in the yz-plane at x = " + str(height))

plt.savefig("fig_nmrcurbx-slice_c6h6.png")

