import sys
import argparse
import re

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from SWILevel01B import L01B_OBJ


def RJT(T: float, f: float):
    """Return the Rayleigh-Jeans equivalent of a physical temperature.

    Evaluates T_RJ = (h f / k) / (exp(h f / (k T)) - 1).

    Args:
        T (float): physical temperature in K (scalar or numpy array).
        f (float): observing frequency in Hz.

    Returns:
        float: the temperature on the Rayleigh-Jeans scale, in K.
    """
    planck = 6.62607015e-34  # Planck constant, J s
    boltzmann = 1.380649e-23  # Boltzmann constant, J/K
    # h*f/k has units of temperature; it sets where the RJ approximation fails.
    t_freq = planck * f / boltzmann
    return t_freq / (np.exp(t_freq / T) - 1.0)


def calibrate(hot, sky, Th: float, Tc: float):
    """Compute a noise-temperature spectrum with the Y-factor method.

    With Y = hot / sky, the system temperature is (Th - Y * Tc) / (Y - 1).
    The arithmetic is elementwise, so scalars and numpy arrays both work.

    Args:
        hot (np.array): data vector for the hot-load measurement.
        sky (np.array): data vector for the cold (sky) measurement.
        Th (float): hot load temperature (on RJ scale).
        Tc (float): cold (sky) load temperature (on RJ scale).

    Returns:
        np.array: resulting Tsys spectrum.
    """
    y_factor = hot / sky
    return (Th - y_factor * Tc) / (y_factor - 1.0)


def hot_sky_plot(hot, sky, title, ax=None):
    """Draw one sky spectrum and one hot-load spectrum against frequency.

    Args:
        hot: hot-load spectrum (one value per channel).
        sky: sky spectrum (same channel count as ``hot``).
        title: panel title.
        ax: matplotlib axes to draw on; the current axes when omitted.
    """
    target = plt.gca() if ax is None else ax
    # Linear channel -> frequency mapping, in MHz, decreasing with channel.
    freq_mhz = 6496.291429 - 0.099271 * np.arange(len(hot))
    target.set_title(title)
    target.set_xlabel("frequency [MHz]")
    target.set_ylabel("counts/cycle")
    # Sky is drawn first so it keeps the first colour of the cycle.
    for spectrum, name in ((sky, "sky"), (hot, "hot")):
        target.plot(freq_mhz, spectrum, label=name)


def tsys_plot(tsys, title, ax=None):
    """Draw a Tsys spectrum with a horizontal red line at its mean.

    Args:
        tsys: noise-temperature spectrum (one value per channel).
        title: panel title.
        ax: matplotlib axes to draw on; the current axes when omitted.
    """
    target = ax if ax is not None else plt.gca()
    # Same linear channel -> frequency (MHz) mapping as the other plots.
    freq_mhz = 6496.291429 - 0.099271 * np.arange(len(tsys))
    target.set_title(title)
    target.set_xlabel("frequency [MHz]")
    target.set_ylabel("Tn RJ[K]")
    # Fixed y-range so all panels are directly comparable.
    target.set_ylim(0, 5000)
    target.plot(freq_mhz, tsys, label="tsys")
    target.hlines(
        y=np.mean(tsys),
        xmin=np.min(freq_mhz),
        xmax=np.max(freq_mhz),
        linewidth=1,
        color="r",
    )


def plot_combs(mat, axs):
    """Plot the two comb spectra from a CTS observation sequence.

    Rows 0 and 1 of ``mat`` are drawn on ``axs[0]`` (CTS1) and ``axs[1]``
    (CTS2) respectively.

    Args:
        mat: 2-D numpy array with the CTS spectral data (row = spectrum).
        axs: a 2 element matplotlib axis object
    """
    _, nchan = mat.shape
    freq_mhz = 6496.291429 - 0.099271 * np.arange(nchan)

    for row, name in enumerate(("CTS1", "CTS2")):
        ax = axs[row]
        ax.plot(freq_mhz, mat[row, :])  # comb spectrum for this CTS
        ax.set_title(name)
        ax.set_xlabel("frequency [MHz]")
        ax.set_ylabel("counts/cycle")


def plot_all(tbl, mat, axs):
    """Plot every hot/sky spectrum pair listed in ``tbl``, one per panel.

    Args:
        tbl: table with one row per hot/sky pair; the "GHz", "hot", "sky"
            and "bias" columns are read. "hot"/"sky" are row indices into
            ``mat``.
        mat: 2-D numpy array of spectra; ``mat[i, :]`` is spectrum ``i``.
        axs: numpy array of matplotlib axes (any shape; flattened here),
            with at least one axis per row of ``tbl``.
    """
    axs = axs.flatten()
    # Use the row *position* to pick the panel, not the table's index label:
    # a table sliced from a larger one keeps its original labels, which would
    # misplace panels or run past the end of axs.
    for i, (_, row) in enumerate(tbl.iterrows()):
        ghz = float(row["GHz"])
        ihot = int(row["hot"])
        isky = int(row["sky"])
        bias = int(row["bias"])
        title = f"{isky}/{ihot}, {bias} GHz:{ghz:.3f}"
        hot_sky_plot(mat[ihot, :], mat[isky, :], title, axs[i])
    plt.tight_layout()


def plot_noise(tbl, mat, axs, *, t_cold: float = 2.73):
    """Plot the Tsys spectrum of every hot/sky pair and tabulate the means.

    Both load temperatures are converted to the Rayleigh-Jeans scale at the
    row's frequency before the Y-factor calibration.

    Args:
        tbl: table with one row per hot/sky pair; the "GHz", "hot", "sky",
            "bias" and "Th" columns are read. "hot"/"sky" are row indices
            into ``mat``; "Th" is the hot-load physical temperature in K.
        mat: 2-D numpy array of spectra; ``mat[i, :]`` is spectrum ``i``.
        axs: numpy array of matplotlib axes (any shape; flattened here),
            with at least one axis per row of ``tbl``.
        t_cold: physical temperature in K assumed for the cold (sky) load.
            The default of 2.73 matches the previously hard-coded value.

    Returns:
        pd.DataFrame: columns "GHz" and "Tsys" (channel-mean Tsys in K on
        the RJ scale), one row per row of ``tbl``.
    """
    axs = axs.flatten()
    ghz_values = []
    mean_tsys = []
    # Pick the panel by row position, not by index label (labels of a sliced
    # table need not start at 0).
    for i, (_, row) in enumerate(tbl.iterrows()):
        ghz = float(row["GHz"])
        ihot = int(row["hot"])
        isky = int(row["sky"])
        bias = int(row["bias"])
        title = f"{isky}/{ihot}, {bias} GHz:{ghz:.3f}"
        freq_hz = ghz * 1.0e9
        tsys = calibrate(
            mat[ihot, :],
            mat[isky, :],
            Th=RJT(float(row["Th"]), freq_hz),
            Tc=RJT(t_cold, freq_hz),
        )
        tsys_plot(tsys, title, axs[i])
        ghz_values.append(ghz)
        mean_tsys.append(np.mean(tsys))
    plt.tight_layout()
    return pd.DataFrame(data={"GHz": ghz_values, "Tsys": mean_tsys})


def main(argv=None):
    """Plot comb, hot/sky and Tsys spectra from a TSYS_CTS / TSYS_ACS file.

    Args:
        argv: command-line arguments without the program name; ``None``
            means ``sys.argv[1:]`` (argparse default).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("ncfile", help="name of science data file")
    args = parser.parse_args(argv)
    print(args.ncfile)

    # The spectrometer backend is taken from the file name.
    m = re.search(r"TSYS_(CTS|ACS)", args.ncfile)
    if not m:
        # sys.exit(str) prints the message to stderr and exits with status 1.
        sys.exit("file name not recognized")
    backend = m.group(1)

    obj = L01B_OBJ(args.ncfile)
    obj.read_l01b_netcdf()

    if backend == "CTS":
        mat = obj.get_cts_mat()
        df = obj.get_cts_obs_table()
        get_tbl = L01B_OBJ.get_cts_tbl
    else:
        mat = obj.get_acs_mat()
        df = obj.get_acs_obs_table()
        get_tbl = L01B_OBJ.get_acs_tbl
    print(df)

    if backend == "CTS":
        if (df["FLM"] == "SKY").all():
            # NOTE(review): assumes exactly 31 rows with the 16 HOT spectra
            # first; any other table length raises here. Confirm the layout.
            df["FLM"] = np.array(["HOT"] * 16 + ["SKY"] * 15)
            print("corrected for missing HOT labels")

        print("Comb spectra")
        fig, axs = plt.subplots(1, 2, figsize=(12, 6), sharey=True)
        plot_combs(mat, axs)
        plt.show()

    # Receiver selector passed to get_*_tbl, and its label for the messages.
    for which, label in ((1, "600"), (2, "1200")):
        print(f"Hot/cold data for {label} GHz receiver")
        tbl = get_tbl(df, which)
        fig, axs = plt.subplots(3, 5, figsize=(16, 10), sharex=True, sharey=True)
        plot_all(tbl, mat, axs)
        plt.show()

        print(f"Tsys spectra for {label} GHz receiver")
        fig, axs = plt.subplots(3, 5, figsize=(16, 10), sharex=True, sharey=False)
        noise = plot_noise(tbl, mat, axs)
        print(noise)
        plt.show()


if __name__ == "__main__":
    main()
