import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt


# Plot styling: render all text through LaTeX, use a serif font family,
# and set size 14 for regular text and for the tick labels on both axes.
mpl.rc('text', usetex = True)
mpl.rc('font', family = 'serif', size = '14')
mpl.rc('xtick', labelsize=14)
mpl.rc('ytick', labelsize=14)

# Directory holding the input spectrum files. File names are appended to it
# directly, so it must end with a path separator.
dir = "./"

def load_spectrum(filename, comment='#'):
    """Load a text file holding one or more equally sized spectra.

    The file is read twice: first to count how many data rows ("bins") the
    first spectrum has, then with ``np.loadtxt`` to get the numbers. The
    first spectrum ends at the first blank line after its data, or at end of
    file when the file holds a single spectrum. Every spectrum is assumed to
    have that same number of bins.

    Parameters
    ----------
    filename : str or path-like
        Path of the text file to read.
    comment : str, sequence of str, or None, optional
        Comment prefix(es). Lines whose first non-blank characters match are
        skipped when counting bins. The value is also forwarded unchanged to
        ``np.loadtxt(comments=...)``. Default ``'#'``.

    Returns
    -------
    fileSp : numpy.ndarray
        All data rows, always 2-D with shape ``(n_rows, n_columns)``.
    nBins : int
        Number of rows in one spectrum.
    nSpectra : int
        Number of complete spectra, ``n_rows // nBins``. A trailing partial
        block, if any, is not counted.

    Raises
    ------
    ValueError
        If no data rows are found for the first spectrum.
    """
    # Normalise the comment spec to a tuple usable with str.startswith
    # (np.loadtxt accepts a str, a sequence of str, or None).
    if comment is None:
        prefixes = ()
    elif isinstance(comment, str):
        prefixes = (comment,)
    else:
        prefixes = tuple(comment)

    # Count the data rows of the first spectrum. Blank lines before any data
    # (e.g. between a header and the numbers) are skipped. The first blank
    # line after some data closes the spectrum. Iterating the file stops
    # cleanly at EOF, so a single spectrum without a trailing blank line works.
    nBins = 0
    with open(filename) as fh:
        for line in fh:
            stripped = line.strip()
            if not stripped:
                if nBins:
                    break
                continue
            if not stripped.startswith(prefixes):
                nBins += 1

    if nBins == 0:
        raise ValueError(
            "no data rows found in the first spectrum of {!r}".format(filename))

    # ndmin=2 keeps the array 2-D even for a single-row file, so callers can
    # always index it as fileSp[row, column].
    fileSp = np.loadtxt(filename, comments=comment, ndmin=2)

    # Every spectrum is assumed to contain exactly nBins rows.
    nSpectra = len(fileSp) // nBins

    return fileSp, int(nBins), int(nSpectra)


def plot_spectrum(spectra, column, nBins, nSpectra, colormap = 'rainbow'):
    """Overlay each spectrum stored in ``spectra`` on the current axes.

    ``spectra`` is indexed as ``spectra[row, col]``. Rows
    ``[idx*nBins, (idx+1)*nBins)`` form spectrum ``idx``. Column 0 is used as
    x and ``column`` as y. Spectrum ``idx`` is drawn with colour
    ``cmap(idx / nSpectra)`` from ``colormap``. Both axes are then set to
    log scale.
    """
    cmap = plt.get_cmap(colormap)
    for idx in range(nSpectra):
        start = idx * nBins
        stop = start + nBins
        chunk = spectra[start:stop]
        plt.plot(chunk[:, 0], chunk[:, column], color=cmap(idx / nSpectra))
    plt.xscale('log')
    plt.yscale('log')

"""## Inflaton"""

# Load every spectrum stored in the GW spectra file. Column 0 is plotted on
# the x axis and column 1 on the y axis (labelled Omega_GW below).
infSp, nBins, nSpectra = load_spectrum(dir + "spectra_gws.txt")

plt.figure(0)
plot_spectrum(infSp, 1, nBins, nSpectra)

# Fixed plotting window and LaTeX axis labels.
plt.xlim([0.2, 50])
plt.ylim([1e-10, 1e-4])
plt.xlabel(r"$\tilde{k}$")
plt.ylabel(r"$\Omega_{\mathrm{GW}}(\tilde{k},\eta)$")

plt.savefig("spectraProjectorComparison.pdf", bbox_inches='tight', format='pdf', dpi=600)

