import pandas as pd
import numpy as np 
from datetime import datetime, timedelta
import matplotlib as mpl
import matplotlib.pyplot as plt

class MP:
    """
    The Multi-PIP class.
    First Authored: 2/28/2021

    The methods contained in MP are grouped by functionality and outlined with brief descriptions below.

    Data I-O Related Methods
    ------------------------
        load : Imports data from CSV.

    Calculation Related Methods
    ---------------------------
        calc_NDZI : Calculates NDZI for all probe pairs.

    Plotting Related Methods
    ------------------------
        plot_NDZI : Plots NDZI for one probe pair.
        plot_z_vs_t : Plots impedance versus time for one probe pair.
        plot_z_h_vs_z_l : Plots high impedance versus low impedance for one probe pair.

    Formatting and Checking Related Methods
    ---------------------------------------
        format_datetime_plot : Sets title and axis format for datetime plot.
        prb_pair_name_crosscheck : Crosschecks probe pair with probe pair name.

    Attributes
    ----------
        version : value read from the second metadata column of the CSV header (None until load succeeds).
        name : value read from the first metadata column of the CSV header (None until load succeeds).
        metadata : pandas dataframe with columns 'Device Info', 'Start Time', 'End Time'.
        data : pandas dataframe holding the measurement table. The methods below read the columns
            'Time', 'frequency', 'impedance', 'probe pair' and 'probe pair name'.
        NDZI : pandas dataframe holding the result of the most recent calc_NDZI call.

    Please see method descriptions for more detailed documentation on each of the methods.

    """
    # Class-level defaults are kept so the attributes exist even when __init__ is bypassed.
    version = None
    name = None
    metadata = pd.DataFrame()
    data = pd.DataFrame()
    NDZI = pd.DataFrame()

    def __init__(self):
        # Give every instance its own containers instead of sharing the class-level ones.
        self.version = None
        self.name = None
        self.metadata = pd.DataFrame()
        self.data = pd.DataFrame()
        self.NDZI = pd.DataFrame()
        print("You have initialized the Multi-PIP data handler. \n Please upload a CSV data file!\n\n")

    def load(self,path):
        """
            load imports data from a CSV file and saves it to the data and metadata attributes of MP.
            If something goes wrong while loading, a message (including the underlying error) is printed
            and the previously loaded state is left untouched; no exception is propagated.

            Inputs
            ------
            path : string.
                The path to the CSV data file.

        """
        try:
            # The first three rows carry the metadata in CSV columns 1 and 2.
            metadata = pd.read_csv(path,header=None,usecols=[1,2], nrows=3).transpose()
            metadata.columns = ['Device Info','Start Time','End Time']
            name = metadata.iloc[0]['Device Info']
            version = metadata.iloc[1]['Device Info']
            # The measurement table (with its header row) starts after the first four lines.
            data = pd.read_csv(path,skiprows=4)
        except Exception as err:
            print("Something went wrong loading your data file. Check your path and data file format and try again.\n\n")
            print("Details: "+type(err).__name__+": "+str(err)+"\n")
            return

        # Commit only after everything parsed, so a failed load never leaves a half-updated object.
        self.metadata = metadata
        self.name = name
        self.version = version
        self.data = data

        print("Loaded data from device "+str(self.name)+". \nThe data file loaded has "+str(self.data.shape)+ " rows and columns.\n\n")

    def calc_NDZI(self,f_l,f_h):
        """
            calc_NDZI computes the NDZI associated with specific user entered frequencies.
            For every probe pair name, the impedance samples at f_l and f_h are paired by position and
            NDZI = (z_l - z_h) / (z_l + z_h) * 100 is evaluated for each pair.
            It then saves the current NDZI under the self.NDZI attribute.

            Inputs
            ------
            f_l : integer.
                Low frequency to use for NDZI computation.
            f_h : integer.
                High frequency to use for NDZI computation.

            Outputs
            -------
            NDZI : pandas dataframe.
                Contains NDZI values associated with f_l and f_h as well as the associated probe pair,
                probe pair name and timestamp (taken from the f_l rows) for each of the values.

            Raises
            ------
            ValueError
                If f_l or f_h is not contained in the data, or if a probe pair name has a different
                number of samples at f_l and f_h.

        """
        print("Calculating NDZI...")
        columns = ["Time","probe pair","probe pair name","NDZI","High freq","Low freq"]
        data = self.data

        # Check if frequencies inputted are in the data
        freq = pd.unique(data['frequency'])
        if not ((f_l in freq) and (f_h in freq)):
            raise ValueError("One of the frequencies you entered: ("+str(f_l)+","+str(f_h)+") is not contained in the data.")

        is_f_l = data['frequency'] == f_l
        is_f_h = data['frequency'] == f_h

        # Build one frame per probe pair name and concatenate once at the end.
        frames = []
        for prb_pair_name in pd.unique(data['probe pair name']):
            is_name = data['probe pair name'] == prb_pair_name
            rows_l = data.loc[is_f_l & is_name]
            rows_h = data.loc[is_f_h & is_name]

            # Impedance is scaled by 1/1000 as in the plotting methods.
            z_l = rows_l['impedance'].to_numpy()/1000
            z_h = rows_h['impedance'].to_numpy()/1000

            if len(z_l) != len(z_h):
                raise ValueError("Probe pair name "+str(prb_pair_name)+" has "+str(len(z_l))+" samples at "+str(f_l)+" Hz but "+str(len(z_h))+" samples at "+str(f_h)+" Hz.")
            if len(z_l) == 0:
                continue

            ind_NDZI = np.divide(z_l - z_h, z_l + z_h)*100

            frames.append(pd.DataFrame({
                "Time" : rows_l['Time'].to_numpy(),
                # Taken from the rows themselves so it cannot drift from the probe pair name.
                "probe pair" : rows_l['probe pair'].to_numpy(),
                "probe pair name" : prb_pair_name,
                "NDZI" : ind_NDZI,
                "High freq" : f_h,
                "Low freq" : f_l,
            }, columns=columns))

        if frames:
            NDZI = pd.concat(frames, ignore_index=True)
        else:
            NDZI = pd.DataFrame(columns=columns)

        # Save NDZI to attribute
        self.NDZI = NDZI
        print(" Done!\n\n")

        return NDZI

    def plot_NDZI(self, df = None, prb_pair = None, prb_pair_name = None):
        """
            plot_NDZI takes care of all of the formatting for creating a datetime/NDZI plot.

            Inputs
            ------
            df : (optional) pandas dataframe.
                Contains time, probe pair, probe pair name, and NDZI to be shown in desired plot
                (the layout returned by calc_NDZI). If omitted or empty, NDZI is computed with the
                lowest and highest frequency found in the data.
            (ONE OF THE FOLLOWING IS REQUIRED)
            prb_pair: string.
                The probe pair of the data to plot.
            prb_pair_name : string.
                The probe pair name of the data to plot.

            Outputs
            -------
            fig : matplotlib figure.
                Figure associated with the NDZI plot.
            ax : matplotlib axes.
                Axis associated with the NDZI plot.

        """
        if df is None or df.empty:
            # If no NDZI is inputted, calculate NDZI with highest and lowest frequency found in the data.
            freq = pd.unique(self.data['frequency'])
            f_l = freq.min()
            f_h = freq.max()
            print("NDZI not computed yet.\n Using default frequency values "+str(f_l)+" and "+str(f_h)+" Hz to compute NDZI. \n")
            df = self.calc_NDZI(f_l,f_h)

        # Find frequencies.
        f_h = pd.unique(df['High freq'])
        f_l = pd.unique(df['Low freq'])

        # Check whether inputs correspond to good data
        prb_pair, prb_pair_name = self.prb_pair_name_crosscheck(df, prb_pair, prb_pair_name)

        print("Plotting NDZI for probe pair name "+str(prb_pair_name)+" and associated probe pair "+str(prb_pair)+"...")

        # Get NDZI and time data for particular probe pair
        ind_prb_pr = df.loc[df['probe pair'] == prb_pair]

        NDZI = ind_prb_pr['NDZI']
        t = ind_prb_pr['Time']
        # Time values are interpreted as seconds since the epoch.
        t_i = pd.to_datetime(t.iloc[0],unit='s')
        t_f = pd.to_datetime(t.iloc[-1],unit='s')

        # Get datetime title and axis format
        title_form, ax_form = self.format_datetime_plot(t_i,t_f)

        # Plot NDZI
        fig, ax = plt.subplots(1, figsize = (12,8))
        fig.suptitle("NDZI ("+str(f_l[0])+","+str(f_h[0])+" Hz), Probe: "+str(prb_pair)+", Name: "+str(prb_pair_name)+", "+title_form)
        ax.scatter(pd.to_datetime(t,unit='s'),NDZI)
        ax.set_ylim([0,100])
        ax.set_ylabel('NDZI')
        ax.xaxis.set_major_formatter(mpl.dates.DateFormatter(ax_form))
        ax.xaxis.set_minor_formatter(mpl.dates.DateFormatter(ax_form))

        ax.legend([prb_pair_name])

        print("Done!\n")
        return fig, ax

    def format_datetime_plot(self,t_i,t_f):
        """
            format_datetime_plot sets axis datetime format options and title date information based on elapsed time for windows.

            Input
            -----
            t_i : pandas datetime
                Initial time
            t_f : pandas datetime
                Final Time

            Output
            ------
            title_form: string
                Date data to include in title e.g. 2021/05/07 -----> 2021/08/19
                (for windows of at most 300 seconds: the date and hour of t_f).
            ax_form: string
                Axis datetime format

        """
        # Calculate elapsed time between initial and final
        delta_t = t_f - t_i

        # Up to five minutes: minute/second ticks, title shows the date and hour.
        if delta_t <= timedelta(seconds = 300):
            return t_f.strftime("%Y/%m/%d")+", hour "+t_f.strftime("%H"), "%M:%S"

        title_form = t_i.strftime("%Y/%m/%d")+" -----> "+t_f.strftime("%Y/%m/%d")
        # Up to 72 hours: hour/minute ticks; anything longer: month/day ticks.
        if delta_t <= timedelta(hours = 72):
            ax_form = "%H:%M"
        else:
            ax_form = "%m/%d"

        return title_form,ax_form

    def prb_pair_name_crosscheck(self,df,prb_pair,prb_pair_name):
        """
            prb_pair_name_crosscheck validates that the inputted probe pair and probe pair name correspond to the same data.
            prb_pair_name_crosscheck checks whether or not the probe pair and name are contained in the dataframe.
            Then it does the following based on the cases:
                1. Neither are contained:
                    Throws exception saying it cannot find the inputted probe pair and name.
                2. One of the two is contained in the dataframe
                    Finds the corresponding probe pair or name (first match) and returns them.
                3. Both are contained
                    Performs crosscheck to make sure there is a one to one correspondence between probe pair name and probe pair.
                    Otherwise it throws an exception.

            Input
            -----
            df : pandas dataframe
                Data to search for probe pair and probe pair name
            prb_pair : string
                Probe pair to crosscheck with probe pair name, e.g., 'A1C1'.
            prb_pair_name : string
                Probe pair name to crosscheck with probe pair, e.g., 'TYPE1_1_LEAF_1'.

            Output
            ------
            prb_pair: string
                Probe pair corresponding to probe pair name.
            prb_pair_name: string
                Probe pair name corresponding to probe pair.

            Raises
            ------
            ValueError
                If neither input is found, or if both are found but do not map one to one onto each other.

        """
        # Plain lists keep the membership tests independent of the array type pd.unique returns.
        prb_pairs = list(pd.unique(df['probe pair']))
        prb_pair_names = list(pd.unique(df['probe pair name']))

        pair_found = prb_pair in prb_pairs
        name_found = prb_pair_name in prb_pair_names

        if not pair_found and not name_found:
            raise ValueError("Neither the probe pair nor the probe pair name was found in the data set.\n")

        if pair_found and not name_found:
            ind_prb = df.loc[df['probe pair'] == prb_pair]
            prb_pair_name = pd.unique(ind_prb['probe pair name'])[0]

        elif name_found and not pair_found:
            ind_prb_name = df.loc[df['probe pair name'] == prb_pair_name]
            prb_pair = pd.unique(ind_prb_name['probe pair'])[0]

        else:
            ind_prb = df.loc[df['probe pair'] == prb_pair]
            ind_prb_name = df.loc[df['probe pair name'] == prb_pair_name]

            pair_to_name = pd.unique(ind_prb['probe pair name'])
            name_to_pair = pd.unique(ind_prb_name['probe pair'])

            if len(pair_to_name) > 1:
                raise ValueError("The entered probe pair corresponds to more than one probe pair name.\n")

            if len(name_to_pair) > 1:
                raise ValueError("The entered probe pair name corresponds to more than one probe pair.\n")

            if pair_to_name[0] != prb_pair_name or name_to_pair[0] != prb_pair:
                raise ValueError("Your input probe pair and probe pair name do not correspond to the same data set.\n")

        return prb_pair, prb_pair_name

    def plot_z_vs_t(self, freq, prb_pair = None, prb_pair_name = None, z_upp_bnd = None, z_low_bnd = None):
        """
            plot_z_vs_t takes care of all of the formatting for creating a datetime/Impedance (Z) plot.

            Inputs
            ------
            freq : (required) integer.
                The frequency for which impedance is to be plotted.

            (ONE OF THE FOLLOWING IS REQUIRED)
            prb_pair: string.
                The probe pair of the data to plot.
            prb_pair_name : string.
                The probe pair name of the data to plot.

            z_upp_bnd : (optional) float.
                The upper bound for the impedance vs time plot in OHMS. Default is 'z_upp_bnd = Z.max() + 10 (kOhms)'.
            z_low_bnd : (optional) float.
                The lower bound for the impedance vs time plot in OHMS. Default is 'z_low_bnd = 0'.

            Outputs
            -------
            fig : matplotlib figure.
                Figure associated with the impedance plot.
            ax : matplotlib axes.
                Axis associated with the impedance plot.

            Raises
            ------
            TypeError
                If freq is not one of the frequencies in the data (exception type kept for existing callers).

        """
        # Check that given frequency is part of the data set
        frequencies = pd.unique(self.data['frequency'])
        if freq not in frequencies:
            raise TypeError("The frequency you inputted does not match any of the frequencies in the data.")

        # Check whether inputs correspond to good data
        prb_pair, prb_pair_name = self.prb_pair_name_crosscheck(self.data, prb_pair, prb_pair_name)

        print("Plotting Impedance versus time at "+str(freq)+" Hz for probe pair name "+str(prb_pair_name)+" and associated probe pair "+str(prb_pair)+"...")

        # Find rows associated with frequency and probe pair
        ind = self.data.loc[(self.data['frequency'] == freq) & (self.data['probe pair name'] == prb_pair_name)]

        # Get impedance (scaled by 1/1000 for the kOhm axis) and time
        Z = ind['impedance']/1000
        t = ind['Time']

        # Set upper and lower axis bounds for impedance; user bounds are given in ohms
        if z_upp_bnd is None:
            z_upp_bnd = Z.max()+10
        else:
            z_upp_bnd = z_upp_bnd/1000

        if z_low_bnd is None:
            z_low_bnd = 0
        else:
            z_low_bnd = z_low_bnd/1000

        # Get initial and final time (Time values are interpreted as seconds since the epoch)
        t_i = pd.to_datetime(t.iloc[0],unit='s')
        t_f = pd.to_datetime(t.iloc[-1],unit='s')

        # Get datetime title and axis format
        title_form, ax_form = self.format_datetime_plot(t_i,t_f)

        # Plot impedance
        fig, ax = plt.subplots(1, figsize = (12,8))
        fig.suptitle("Impedance at "+str(freq)+" Hz, Probe: "+str(prb_pair)+", Name: "+str(prb_pair_name)+", "+title_form)
        ax.scatter(pd.to_datetime(t,unit='s'),Z)
        ax.set_ylim([z_low_bnd,z_upp_bnd])
        ax.set_ylabel('Z (kOhms)')
        ax.xaxis.set_major_formatter(mpl.dates.DateFormatter(ax_form))
        ax.xaxis.set_minor_formatter(mpl.dates.DateFormatter(ax_form))

        ax.legend([prb_pair_name])

        print("Done!\n")
        return fig, ax

    def plot_z_h_vs_z_l(self, prb_pair = None, prb_pair_name = None, f_l = None, f_h = None, z_l_bnd = None, z_h_bnd = None):
        """
            plot_z_h_vs_z_l plots impedance at a high frequency versus impedance at a low frequency for a particular probe pair or name.

            Inputs
            ------
            (ONE OF THE FOLLOWING IS REQUIRED)
            prb_pair: string.
                The probe pair of the data to plot.
            prb_pair_name : string.
                The probe pair name of the data to plot.

            f_l : (optional) integer.
                Low frequency to plot impedance for. Default is minimum frequency found in the data.
            f_h : (optional) integer.
                High frequency to plot impedance for. Default is maximum frequency found in the data.
                If a supplied frequency is not found in the data, both fall back to their defaults.
            z_l_bnd : (optional) float.
                The upper bound for low frequency impedance in OHMS. Default is 'z_l_bnd = Z_l.max() + 50 (kOhms)'.
            z_h_bnd : (optional) float.
                The upper bound for high frequency impedance in OHMS. Default is 'z_h_bnd = Z_h.max() + 50 (kOhms)'.

            Outputs
            -------
            fig : matplotlib figure.
                Figure associated with high versus low impedance.
            ax : matplotlib axes.
                Axis associated with high versus low impedance.

        """
        # Find frequency vector in data
        freq = pd.unique(self.data['frequency'])

        # Fill in each omitted frequency with its documented default
        if f_l is None:
            f_l = freq.min()
        if f_h is None:
            f_h = freq.max()

        # Check if frequencies inputted are in vector; otherwise fall back to the defaults
        if not ((f_l in freq) and (f_h in freq)):
            f_l = freq.min()
            f_h = freq.max()
            print("Could not find frequencies inputted in data, defaulting to :\nf_l = "+str(f_l)+", f_h = "+str(f_h))

        # Check whether inputs correspond to good data
        prb_pair, prb_pair_name = self.prb_pair_name_crosscheck(self.data, prb_pair, prb_pair_name)

        print("Plotting high vs low impedance at "+str(f_l)+", "+str(f_h)+" Hz for probe pair name "+str(prb_pair_name)+" and associated probe pair "+str(prb_pair)+"...")

        # Find rows associated with high and low frequency and probe pair
        is_name = self.data['probe pair name'] == prb_pair_name
        ind_f_l = self.data.loc[(self.data['frequency'] == f_l) & is_name]
        ind_f_h = self.data.loc[(self.data['frequency'] == f_h) & is_name]

        # Get impedance (scaled by 1/1000 for the kOhm axes)
        z_l = ind_f_l['impedance']/1000
        z_h = ind_f_h['impedance']/1000

        # Set axis bounds for high and low impedance; user bounds are given in ohms
        if z_h_bnd is None:
            z_h_bnd = z_h.max()+50
        else:
            z_h_bnd = z_h_bnd/1000

        if z_l_bnd is None:
            z_l_bnd = z_l.max()+50
        else:
            z_l_bnd = z_l_bnd/1000

        # Plot high versus low impedance
        fig, ax = plt.subplots(1, figsize = (12,8))
        fig.suptitle("Low vs high impedance ("+str(f_l)+","+str(f_h)+" Hz), Probe: "+str(prb_pair)+", Name: "+str(prb_pair_name))
        ax.scatter(z_h,z_l)
        ax.set_xlim([0,z_h_bnd])
        ax.set_ylim([0,z_l_bnd])
        ax.set_xlabel('Z_h (kOhms)')
        ax.set_ylabel('Z_l (kOhms)')

        print(" Done!\n\n")
        return fig, ax

