import sys
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from netCDF4 import Dataset

# import copy
# import time


def get_header_v2():
    """Return the default v2 header dictionary.

    Most fields start out empty and are filled in later; only the ASW
    command identifiers and the default target carry fixed values.
    """
    empty_fields = [
        "METAKERNEL",
        "START_DATE",
        "NAME",
        "OBSID",
        "SCRIPTID",
        "RUN_TIME",
        "PWR",
        "ENE",
        "NUM_MEAS_CTS",
        "NUM_MEAS_ACS",
        "NUM_MEAS_CCH",
        "DV",
        "DR",
        "FLM_TOT_STEPS",
        "AT_TOT_STEPS",
        "CT_TOT_STEPS",
    ]
    header = dict.fromkeys(empty_fields, "")
    header["ASW_COMMAND"] = "210/33"
    header["ASW_NAME"] = "SWI_RunScript"
    header["TARGET"] = "NONE"
    return header


def get_geom_header():
    """Return the geometry header template.

    Maps each geometry field name to a ``[description, unit, default]``
    triple; numeric fields default to -1, the target name to "NONE".
    """
    return {
        "SUN_DISTANCE": ["SC-Sun distance", "km", -1],
        "VELOCITY_Y-AXIS_ANGLE": ["Angle between SC-vel and SC-y axis", "deg", -1],
        "NAIF_ID": ["Name of a target", "", "NONE"],
        "DIAMETER": ["Angular diameter", "mrad", -1],
        "SUB-SC_LATITUDE": ["Sub-SC latitude", "deg", -1],
        "SUB-SC_LONGITUDE": ["Sub-SC longitude", "deg", -1],
        "NORTH_POLE_ANGLE": ["SC-y axi and target spin axis angle", "deg", -1],
        "SUB-SOLAR_LATITUDE": ["Sub-Sun latitude", "deg", -1],
        "SUB-SOLAR_LONGITUDE": ["Sub-Sun longitude", "deg", -1],
        "PHASE_ANGLE": ["Phase angle", "deg", -1],
        "LOCAL_TIME": ["Local time at sub-SC point", "hrs", -1],
        "BETA_ANGLE": ["Beta angle", "deg", -1],
        "RADIAL_RELATIVE_VELOCITY": ["SC-Target relative radial velocity", "m/s", -1],
        "CT_OFFSET": ["CT angle to target", "deg", -1],
        "AT_OFFSET": ["AT angle to target", "deg", -1],
        "AT_STEPS": ["AT number of steps to target", "steps", -1],
        "CT_STEPS": ["CT number of steps to target", "steps", -1],
        "CT_DRIFT_RATE": ["CT drift rate", "mrad/s", -1],
        "AT_DRIFT_RATE": ["AT drift rate", "mrad/s", -1],
        "TOTAL_DRIFT_RATE": ["Combined drift rate", "mrad/s", -1],
        "LEADING_SIDE_STATUS": ["Leading/Trailing", "", -1],
        "MOON_EVENT_ID": ["Non relevant = 0, Occultation by Jupiter = 1, Transit = 2, Eclipse = 3, Solar occultation = 4, Out-of-range = 5", "", -1],
    }


class L01B_OBJ(object):
    """In-memory representation of one SWI L01B netCDF product.

    The filename stem is expected to be ``OBSID_SCRIPTID_SCRIPTNAME`` (plus
    the ``.nc`` suffix); the constructor parses that stem and scans the file
    for its dimensions and the names of the commanded (script) parameters.
    The variable data themselves are loaded separately with
    :meth:`read_l01b_netcdf`.
    """

    def __init__(self, fin):
        """
        Read in the L01B file. From the filename, fin, we will get exactly
        the format we expect. Eg. CTS, ACS, or CTS+CCH, or ACS+CCH, or CCH only.

        Therefore the reading of these files is slightly different than the LO time
        ordered files.

        Parameters
        ----------
        fin : str or Path
            Path to the L01B netCDF file; exits with status 1 if the file
            does not exist or the parameter count is inconsistent.
        """

        ncfile = Path(fin)
        if ncfile.is_file():
            print(f"L01B file: {fin} exists")
        else:
            print(f"ERROR0 in swi_netcdf_readers L01B object, file: {fin} not found")
            sys.exit(1)

        self.cmd_scriptid = ""  # commanded script and obsid and script name
        self.cmd_obsid = 0
        self.cmd_name = ""
        # Dimensions of interest; -1 marks "not present in this file".
        self.cmd_dims = {
            "nprof_acs": -1,
            "nprof_cts": -1,
            "nprof_cch": -1,
            "nprof_headers": -1,
            "num_cmd_fields": -1,
        }

        self.fname = fin
        self.basename = ncfile.stem        # get rid of the .nc in the filename
        # We are left only with OBSID_SCRIPTID_SCRIPTNAME

        self.allvars = {}  # all variables read from the file
        # will hold string variable from GMT (generated at run-time) !!!
        self.allvars["DATE_ST"] = np.array([], dtype='datetime64[s]')

        rs0 = self.basename.split("_")
        self.cmd_scriptid = rs0[1]
        self.cmd_obsid = int(rs0[0])
        # the script name itself may contain underscores, so re-join the rest
        self.cmd_name = "_".join(rs0[2:])
        self.cmd_keys = []  # extracted from the l01b file by brute force below

        # We deliberately avoid a dependency on swi_scripts_fm here: instead
        # of asking the script definition for its parameter names, we derive
        # them by elimination from the variable names stored in the file.
        f = Dataset(self.fname, "r", format="NETCDF4")  # 'a' would allow read-write

        for dim in f.dimensions.values():
            if dim.name in self.cmd_dims:
                self.cmd_dims[dim.name] = int(dim.size)

        allvars = f.variables.keys()
        geomhdr = get_geom_header().keys()
        hdr = get_header_v2()

        def _is_cmd_parameter(k):
            # A commanded parameter is any variable that is neither a
            # geometry field, nor a header field, nor one of the known
            # science/HK/bookkeeping variables.
            if k in geomhdr:
                return False
            if k in hdr:
                return False
            for sub in ["HK", "RAW", "HDR", "ACS_CONVERTED", "GMT", "CTSID", "ACSID", "CCHID", "TMLN_SCRIPT"]:
                if sub in k:
                    return False
            return True

        self.cmd_keys = list(filter(_is_cmd_parameter, allvars))
        print(f"Number of script parameters from fields {len(self.cmd_keys)}, "
              f"from dimensions = {self.cmd_dims['num_cmd_fields']}")

        # Sanity check: the brute-force list must agree with the file's count.
        if len(self.cmd_keys) != self.cmd_dims["num_cmd_fields"]:
            print("Number of script_parameters not the same ... this should not be")
            sys.exit(1)

        f.close()

    def print_L01B_template(self):
        """Print a summary of the template: file name, commanded script
        identifiers and every dimension that is actually present (> 0)."""
        hdr_line = "================== Contents of the L01B template ====================="
        print(hdr_line)
        print(f"{self.fname} ")
        print(f"CMD_SCRIPTID: {self.cmd_scriptid}")
        print(f"CMD_OBSID: {self.cmd_obsid}")
        print(f"CMD_NAME: {self.cmd_name}")
        print("DIMENSIONS: ")
        for key, size in self.cmd_dims.items():
            if size > 0:
                print(key, size)

        print("=" * len(hdr_line))

    def print_tmln_info(self):
        """
        Print explanations of the tmln_script table (20 columns).
        """

        labels = {
            "CMD": "Name of command executed inside the script",
            "IDX1": "Index into CTS, ACS or CCH matrix to subset taken for LO1.",
            "IDX2": "Index into CTS, ACS or CCH matrix to subset taken for LO2.",
            "INTCYCLES": "Integration time in cycles (cts, cch, acs).",
            "COMB": "Flag 0/1 indicating whether this spectra is on comb",
            "FLM": "Flip mirror position HOT/SKY",
            "AT": "Current steps of AT",
            "CT": "Current steps of CT",
            "BIAS1": "Index into the bias table for REC1",
            "SHIFT1": "Shift -2 to 2 allowed to fine tune BIAS1",
            "K1": "K-value = BIAS1+SHIFT1",
            "LO1": "LO1 frequency setting (GHz)",
            "MODEC": "MODE-compression (40, 16 bits or A, B)",
            "MODES": "MODE-shift (bit shift applied)",
            "MODEA": "MODE-averaging (0,2,4,8,16 channels)",
            "BIAS2": "Index into the bias table for REC2",
            "SHIFT2": "Shift -2 to 2 allowed to fine tune BIAS2",
            "K2": "K-value = BIAS2 + SHIFT2",
            "LO2": "LO2 frequency setting (GHz)",
            "CCH_INTCYCLES": "Integration cycles for CCH (if used)",
        }

        for i, (key, text) in enumerate(labels.items()):
            print(f"{i:2d} {key:<14}: {text}")

    def print_available_fields(self, flag="all"):
        """
        Print the names of the fields available in ``self.allvars``.

        Parameters
        ----------
        flag : str
            One of "all", "hk", "geom", "cmd", "sci", "other"; exits with
            status 1 for anything else. Commanded parameters ("cmd") are
            printed together with their first value.
        """

        if flag not in ["all", "hk", "geom", "cmd", "sci", "other"]:
            print(f"Printing available L01B fields flag {flag} not recognized")
            sys.exit(1)

        def is_housekeeping(key):
            # DATE_ST and FLAG_HDR are treated as housekeeping as well.
            if "HK" in key or key in ["DATE_ST", "FLAG_HDR"]:
                return True
            return False

        def print_section(keys, title, with_data=False):
            print()
            print("---------------------------------------------")
            print(f"Available fields with {title}")
            print("---------------------------------------------")
            for k in keys:
                if with_data:
                    print(f"{k:<15} =  {self.allvars[k][0]}")
                else:
                    print(k)

        scilist = ["CTS_RAW", "ACS_RAW", "CCH_RAW", "CTSID", "ACSID", "CCHID", "ACS_CONVERTED"]
        titles = {"hk": "housekeeping",
                  "geom": "geometry",
                  "cmd": "commanded parameters",
                  "sci": "science data",
                  "other": "other data"}
        sections = {"hk": list(filter(is_housekeeping, self.allvars.keys())),
                    "geom": get_geom_header().keys(),
                    "cmd": self.cmd_keys,
                    "sci": list(filter(lambda k: k in scilist, self.allvars.keys())),
                    "other": ["TMLN_SCRIPT"]}

        if flag == "all":
            for grp in ["hk", "geom", "cmd", "sci", "other"]:
                print_section(sections[grp], titles[grp], with_data=(grp == "cmd"))
        else:
            print_section(sections[flag], titles[flag], with_data=(flag == "cmd"))

    def read_l01b_netcdf(self):
        """
        Read in an already filled file: load every netCDF variable into
        ``self.allvars`` and, from the GMT time-stamp variable, derive
        DATE_ST plus the INCOMPLETE_SCRIPT / MISSED_EVENTS bookkeeping.
        """

        f = Dataset(self.fname, "r", format="NETCDF4")  # 'a' would allow read-write
        d = {v: f.variables[v][:] for v in f.variables.keys()}
        missing = 0

        def ST2DT(st):
            """Parse ISO time strings; count entries too short to be a date."""
            dt = []
            n_missing = 0
            if len(st) == 0:
                return np.array(dt), n_missing
            # Stamps shorter than 20 chars carry no fractional seconds.
            fmt = "%Y-%m-%dT%H:%M:%S" if len(st[0]) < 20 else "%Y-%m-%dT%H:%M:%S.%f"
            for t in st:
                if len(t) > 10:
                    dt.append(datetime.strptime(t, fmt))
                else:
                    n_missing += 1  # placeholder entry: the event never ran
            return np.array(dt), n_missing

        for name, values in d.items():
            if "GMT" in name:
                time0 = values[:]
                # only the missing-count is used; the parsed datetimes are
                # superseded by the numpy datetime64 conversion below
                _, missing = ST2DT(time0)
                self.allvars["DATE_ST"] = time0.astype("datetime64[ms]")

            self.allvars[name] = values[:]

        if missing > 0:
            print(f"Found {missing} missing dates. "
                  "Either we are loading an empty L1B file to fill with defaults, "
                  "or data are incomplete")
            self.allvars["INCOMPLETE_SCRIPT"] = True
            self.allvars["MISSED_EVENTS"] = missing
        else:
            self.allvars["INCOMPLETE_SCRIPT"] = False
            self.allvars["MISSED_EVENTS"] = 0
        f.close()

    def get_tmln_table(self):
        """Return the script-parameter timeline as a typed DataFrame.

        Each row of ``self.allvars["TMLN_SCRIPT"]`` holds the 20 fields
        listed below, in this fixed column order.
        """
        tmln = self.allvars["TMLN_SCRIPT"]
        columns = {"CMD": "object",
                   "IDX1": "int16",
                   "IDX2": "int16",
                   "INTCYCLES": "int32",
                   "COMB": "int16",
                   "FLM": "object",
                   "AT": "int16",
                   "CT": "int16",
                   "BIAS1": "int16",
                   "SHIFT1": "int16",
                   "K1": "int16",
                   "LO1": "float",
                   "MODEC": "int16",
                   "MODES": "int16",
                   "MODEA": "int16",
                   "BIAS2": "int16",
                   "SHIFT2": "int16",
                   "K2": "int16",
                   "LO2": "float",
                   "CCH_INTCYCLES": "int32"}

        # Transpose the row-oriented timeline into per-column lists.
        data = {col: [row[i] for row in tmln]
                for i, col in enumerate(columns)}
        return pd.DataFrame(data=data).astype(columns)

    def get_acs_obs_table(self):
        """Return the ACS-relevant script parameters, one row per command.

        Adds the hot-load temperature (converted from degC to K via +273.15)
        and the per-receiver time stamps to the timeline table.
        NOTE(review): the reshape(-1, 4) assumes four HK/time samples per
        timeline row -- confirm against the L01B writer.
        """
        df = self.get_tmln_table()
        # Hot-load temperature: mean of the four load sensors.
        T_hot = np.mean([self.allvars["HK_CHL_1_T"],
                         self.allvars["HK_CHL_2_T"],
                         self.allvars["HK_CHL_3_T"],
                         self.allvars["HK_CHL_4_T"]],
                        axis=0).reshape(-1, 4)
        df["THOT1"] = T_hot[:, 0] + 273.15
        df["THOT2"] = T_hot[:, 1] + 273.15
        df["DT1"] = self.allvars["DATE_ST"].reshape(-1, 4)[:, 0]
        df["DT2"] = self.allvars["DATE_ST"].reshape(-1, 4)[:, 1]
        cols = ["CMD", "FLM",
                "DT1", "IDX1", "LO1", "BIAS1", "THOT1",
                "DT2", "IDX2", "LO2", "BIAS2", "THOT2"]
        return df.loc[:, cols]

    def get_cts_obs_table(self):
        """Return the CTS-relevant script parameters, one row per command.

        Adds the hot-load temperature (converted from degC to K via +273.15)
        and the per-receiver time stamps to the timeline table.
        NOTE(review): the reshape(-1, 2) and [0::2]/[1::2] slicing assume two
        HK/time samples per timeline row -- confirm against the L01B writer.
        """
        df = self.get_tmln_table()
        # Hot-load temperature: mean of the four load sensors.
        T_hot = np.mean([self.allvars["HK_CHL_1_T"],
                         self.allvars["HK_CHL_2_T"],
                         self.allvars["HK_CHL_3_T"],
                         self.allvars["HK_CHL_4_T"]],
                        axis=0).reshape(-1, 2)
        df["THOT1"] = T_hot[:, 0] + 273.15
        df["THOT2"] = T_hot[:, 1] + 273.15
        df["DT1"] = self.allvars["DATE_ST"][0::2]
        df["DT2"] = self.allvars["DATE_ST"][1::2]
        cols = ["CMD", "COMB", "FLM",
                "DT1", "IDX1", "LO1", "BIAS1", "THOT1",
                "DT2", "IDX2", "LO2", "BIAS2", "THOT2"]
        return df.loc[:, cols]

    def get_acs_mat(self):
        """ Get the ACS raw data matrix (the ACS_CONVERTED variable). """
        return self.allvars["ACS_CONVERTED"]

    def get_cts_mat(self):
        """ Get a scaled (i.e. counts per cycle) copy of the raw data.

        The first two rows are comb measurements and are scaled with the comb
        integration time; the remaining rows use the normal integration time.
        The bit shift applied on board is undone via the 2**bitshift factor.
        """
        ctsmat = self.allvars["CTS_RAW"]
        bitshift = self.allvars["MODEON_SHFT"][0]
        tint = self.allvars["T_ON"][0]
        tintcomb = self.allvars["INT_COMB"][0]
        scale = 2**bitshift / tint
        scale_comb = 2**bitshift / tintcomb
        mat = ctsmat.copy()    # ctsmat is a masked array, take a copy
        n = mat.shape[0]       # get number of rows
        # column vector of per-row scale factors: comb rows first, then normal
        factor = np.concatenate((scale_comb*np.ones(2), scale*np.ones(n-2)), axis=0).reshape(n, 1)
        return mat*factor

    @staticmethod
    def get_cts_tbl(df, rx):
        """ Get CTS table columns for a specific receiver.

        Parameters
        ----------
        df : pandas.DataFrame
            Output of :meth:`get_cts_obs_table`.
        rx : int
            Receiver number (1 or 2); selects the LOx/IDXx/BIASx/THOTx columns.

        Comb rows are dropped, then HOT and SKY rows are paired side by side.
        NOTE(review): assumes equal numbers of HOT and SKY rows -- pandas
        raises on a length mismatch.
        """
        LO = f"LO{rx}"
        IDX = f"IDX{rx}"
        BIAS = f"BIAS{rx}"
        THOT = f"THOT{rx}"

        no_comb = df[df["COMB"] == 0]    # filter out comb data
        hot = no_comb[no_comb["FLM"] == "HOT"]
        sky = no_comb[no_comb["FLM"] == "SKY"]
        wide = pd.DataFrame(data={'GHz': np.array(hot[LO]),
                                  'bias': np.array(hot[BIAS]),
                                  'hot': np.array(hot[IDX]),
                                  'sky': np.array(sky[IDX]),
                                  'Th': np.array(hot[THOT])})
        return wide

    @staticmethod
    def get_acs_tbl(df, rx):
        """ Get ACS table columns for a specific receiver.

        Parameters
        ----------
        df : pandas.DataFrame
            Output of :meth:`get_acs_obs_table`.
        rx : int
            Receiver number (1 or 2); selects the LOx/IDXx/BIASx/THOTx columns.

        HOT and SKY rows are paired side by side (no comb filtering for ACS).
        NOTE(review): assumes equal numbers of HOT and SKY rows -- pandas
        raises on a length mismatch.
        """
        LO = f"LO{rx}"
        IDX = f"IDX{rx}"
        BIAS = f"BIAS{rx}"
        THOT = f"THOT{rx}"

        hot = df[df["FLM"] == "HOT"]
        sky = df[df["FLM"] == "SKY"]
        wide = pd.DataFrame(data={'GHz': np.array(hot[LO]),
                                  'bias': np.array(hot[BIAS]),
                                  'hot': np.array(hot[IDX]),
                                  'sky': np.array(sky[IDX]),
                                  'Th': np.array(hot[THOT])})
        return wide