import RF_Track
import numpy

def init_bunch(rf_fname, setup, n_particles=1000):
    """Create the initial macro-particle bunch, matched to the linac FODO cell.

    The transverse Twiss functions are derived from a FODO cell made of eight
    copies of the accelerating structure in ``rf_fname`` plus two 0.1 m
    quadrupoles, with phase advance ``setup.mu``.

    Parameters
    ----------
    rf_fname : str
        Path of the 1-D RF field-map file. It is forwarded to
        ``init_rf_structure``, here only to obtain the structure length.
    setup : object
        Configuration object. Reads ``mass`` (MeV/c^2), ``Q`` (single-particle
        charge, units of e), ``population`` (real particles per bunch),
        ``P_i`` (MeV/c), ``mu`` (FODO phase advance, degrees), ``sigma_t``
        and ``sigma_pt``.
    n_particles : int, optional
        Number of macro-particles passed to ``Bunch6d_QR``. The default of
        1000 is the value that used to be hard-coded.

    Returns
    -------
    RF_Track.Bunch6d_QR
        The generated bunch.

    Raises
    ------
    ValueError
        If ``setup.mu`` is not strictly between 0 and 180 degrees. Outside
        that range the FODO beta functions below are infinite or negative.
    """
    mass = setup.mass  # MeV/c^2
    Q = setup.Q  # single-particle charge, in units of e
    population = setup.population  # number of real particles per bunch (e.g. 50 * RF_Track.pC)
    P_i = setup.P_i  # MeV/c

    # sin(mu) appears in the denominator below and must be > 0 for positive betas.
    if not 0 < setup.mu < 180:
        raise ValueError(
            f"FODO phase advance setup.mu must be in (0, 180) degrees, got {setup.mu}"
        )

    RF = init_rf_structure(rf_fname)
    L_RF = RF.get_length()  # m

    # FODO cell parameters
    mu = numpy.deg2rad(setup.mu)  # rad (setup.mu is given in degrees)
    Lquad = 0.1  # m
    Lcell = 8 * L_RF + 2 * Lquad  # m, eight structures and two quadrupoles

    # Beta functions at the centre (symmetry point) of a FODO quadrupole:
    #   beta_max = Lcell * (1 + sin(mu/2)) / sin(mu)
    #   beta_min = Lcell * (1 - sin(mu/2)) / sin(mu)
    # NOTE(review): init_linac_lattice starts with a half quadrupole whose K1
    # is positive. If positive K1 focuses horizontally in RF_Track's
    # convention, beta_x there would be beta_max rather than beta_min.
    # Confirm the plane assignment below.
    Twiss = RF_Track.Bunch6d_twiss()
    Twiss.emitt_x = 5  # mm.mrad, normalized emittance
    Twiss.emitt_y = 5  # mm.mrad, normalized emittance
    Twiss.beta_x = Lcell * (1 - numpy.sin(mu / 2)) / numpy.sin(mu)  # m
    Twiss.beta_y = Lcell * (1 + numpy.sin(mu / 2)) / numpy.sin(mu)  # m
    Twiss.sigma_t = setup.sigma_t
    Twiss.sigma_pt = setup.sigma_pt

    # Create the bunch
    bunch = RF_Track.Bunch6d_QR(mass, population, Q, P_i, Twiss, n_particles)
    return bunch


def init_rf_structure(rf_fname):
    """Build one accelerating structure from a 1-D field map, with short-range wakefields.

    ``rf_fname`` is read with ``numpy.loadtxt``. Column 0 is the longitudinal
    position, and columns 1 and 2 are the real and imaginary parts of the
    on-axis Ez. The map's reference power (37.5 MW, 100 MV/m gradient) is
    rescaled to the power giving an 80 MV/m gradient.

    Parameters
    ----------
    rf_fname : str
        Path of the field-map file. It needs at least two rows, because the
        mesh step is taken from the first two positions.

    Returns
    -------
    RF_Track.RF_FieldMap_1d
        The structure, with RF phase 0 degrees and 100 integration steps. The
        short-range wakefield from ``init_SRWF`` is attached as a collective
        effect, evaluated every 10 steps.
    """
    T = numpy.loadtxt(rf_fname)

    # NOTE(review): originally commented as MV/m. Confirm against the field
    # units RF_FieldMap_1d expects (power scaling is applied below).
    Ez = T[:, 1] + 1j * T[:, 2]
    hz = T[1, 0] - T[0, 0]  # m, mesh step (assumes uniformly spaced positions)

    freq = 11.9942e9  # Hz
    # set_phid() takes the phase in degrees (cf. init_gun / setup.phid), so no
    # degree-to-radian conversion is applied here.
    phid = 0.0  # degrees

    E_map = 100e6  # V/m, the gradient of the field map
    E_actual = 80e6  # V/m, our target gradient

    # The field map was generated assuming 37.5 MW input power, giving a 100 MV/m gradient.
    P_map = 37.5e6  # W
    # Gradient scales with sqrt(power), so power scales with gradient squared.
    P_actual = P_map * E_actual**2 / E_map**2  # W, we want to operate at 80 MV/m

    # NOTE(review): -1 is passed as the length and +1 as the direction flag.
    # Presumably -1 lets RF_Track infer the length from the map; confirm.
    RF = RF_Track.RF_FieldMap_1d(Ez, hz, -1, freq, +1, P_map, P_actual)
    RF.set_phid(phid)
    RF.set_nsteps(100)

    # Uncomment to be a little faster, using a constant Ez field.
    #RF = RF_Track.Drift( RF.get_length() )
    #RF.set_static_Efield (0, 0, - 0.8 * E_actual)

    # Attach the short-range wakefield of the structure.
    SRWF = init_SRWF(freq)
    RF.add_collective_effect(SRWF)
    RF.set_cfx_nsteps(10)
    return RF


def init_SRWF(freq):
    """Return the short-range wakefield model for the accelerating structure.

    The cell geometry comes from the RF frequency and fixed iris dimensions:
    - Iris thickness: the mean of 1.67 mm and 1 mm.
    - Iris aperture radius: half the mean of 6.3 mm and 4.7 mm.
    - Cell length: one third of the RF wavelength, as this is a 2pi/3
      travelling-wave structure.
    - Gap: the cell length minus the iris thickness.

    Parameters
    ----------
    freq : float
        RF frequency in Hz.

    Returns
    -------
    RF_Track.ShortRangeWakefield
        Built from (aperture radius, gap length, cell length), all in metres.
    """
    wavelength = RF_Track.clight / freq  # m
    iris_thickness = 0.5 * (1.67 + 1) / 1e3  # m, average
    iris_radius = 0.5 * (6.3 + 4.7) / 1e3 * 0.5  # m, average aperture radius
    cell_length = wavelength / 3  # m, 2pi/3 phase advance per cell
    gap_length = cell_length - iris_thickness  # m
    return RF_Track.ShortRangeWakefield(iris_radius, gap_length, cell_length)

def init_linac_lattice(rf_fname, setup):
    """Build the main linac as a chain of FODO cells with momentum-scaled quadrupoles.

    Each FODO cell is, in order: 4 RF structures, corrector, quadrupole, BPM,
    4 RF structures, corrector, quadrupole, BPM.

    The procedure is:
    1. Autophase one cell with the reference particle. This sets the RF
       phases and gives the momentum gain per cell.
    2. Choose the number of cells needed to go from ``setup.P_i`` to
       ``setup.P_f``, rounded to the nearest integer.
    3. Build the linac as a half quadrupole (with corrector and BPM) followed
       by that many cells.
    4. Set every quadrupole for the momentum expected at its position, with
       alternating sign.

    Parameters
    ----------
    rf_fname : str
        Path of the RF field-map file, passed to ``init_rf_structure``.
    setup : object
        Configuration object. Reads ``Q`` (charge, units of e), ``P_i`` and
        ``P_f`` (MeV/c), ``phid`` (RF phase) and ``mu`` (FODO phase advance,
        degrees). The fields used by ``init_reference_particle`` are also read.

    Returns
    -------
    RF_Track.Lattice
        The complete linac.

    Raises
    ------
    ValueError
        If the autophased FODO cell does not increase the reference momentum,
        since the number of cells cannot then be determined.
    """
    Q = setup.Q  # single-particle charge, in units of e
    P_i = setup.P_i  # MeV/c
    P_f = setup.P_f  # MeV/c

    RF = init_rf_structure(rf_fname)
    RF.set_phid(setup.phid)

    # FODO cell parameters
    mu = numpy.deg2rad(setup.mu)  # rad (setup.mu is given in degrees)
    L_RF = RF.get_length()  # m
    L_quad = 0.1  # m
    L_cell = 8 * L_RF + 2 * L_quad  # m, eight structures and two quadrupoles
    # Thin-lens FODO: 1/f = 4 sin(mu/2) / L_cell
    k1L = numpy.sin(mu / 2) / (L_cell / 4)  # 1/m, integrated focusing strength

    C = RF_Track.Corrector()
    B = RF_Track.Bpm()

    # Two identical halves: 4 structures, corrector, quadrupole, BPM.
    FODO = RF_Track.Lattice()
    for _half_cell in range(2):
        for _ in range(4):
            FODO.append(RF)
        FODO.append(C)
        # Strength starts at zero; it is set below from the reference momentum.
        FODO.append(RF_Track.Quadrupole(L_quad, 0.0))
        FODO.append(B)

    # Define the reference particle
    P0 = init_reference_particle(setup)

    # autophase sets the RF phases and returns the maximum final momentum.
    P_max = FODO.autophase(P0)  # MeV/c
    FODO.unset_t0()

    P_gain = P_max - P_i  # MeV/c, momentum gain per FODO cell
    if not P_gain > 0:  # also rejects NaN
        raise ValueError(
            f"autophased FODO cell does not accelerate the reference particle "
            f"(momentum gain {P_gain} MeV/c); check setup.phid and the field map"
        )

    n_FODO = int(numpy.round((P_f - P_i) / P_gain))  # nearest integer number of cells

    # Start the linac with half a quadrupole, then the FODO cells.
    LINAC = RF_Track.Lattice()
    LINAC.append(C)
    LINAC.append(RF_Track.Quadrupole(L_quad / 2, 0.0))
    LINAC.append(B)
    for _ in range(n_FODO):
        LINAC.append(FODO)

    # Track the reference particle through the whole linac. The returned
    # particle is not needed; the call is kept for its effect on the lattice.
    # NOTE(review): presumably this re-establishes the t0 values cleared by
    # FODO.unset_t0() above. Confirm.
    LINAC.track(P0)

    k1 = k1L / L_quad  # 1/m^2 (the half quadrupole gets half the integrated strength)

    # Consecutive quadrupoles are separated by four structures (half a cell),
    # hence half the per-cell momentum gain between them.
    half_P_gain = 0.5 * P_gain
    P = P_i  # momentum at the first quadrupole
    for quad in LINAC.get_quadrupoles():
        quad.set_K1(P / Q, k1)
        P += half_P_gain
        k1 = -k1  # alternate focusing / defocusing
    return LINAC

def init_reference_particle(setup):
    """Create a single reference particle at the initial momentum.

    The 6 phase-space coordinates are all zero except the last, which is set
    to ``setup.P_i`` (MeV/c, the unit used for P_i elsewhere in this file).
    The coordinates are built as an explicit float array, so an integer
    ``P_i`` cannot produce an integer array. The array is passed as-is: it is
    1-D, and a ``.T`` would be a no-op on it.

    Parameters
    ----------
    setup : object
        Reads ``mass``, ``population``, ``Q`` and ``P_i``.

    Returns
    -------
    RF_Track.Bunch6d
    """
    coordinates = numpy.array([0.0, 0.0, 0.0, 0.0, 0.0, setup.P_i], dtype=float)
    return RF_Track.Bunch6d(setup.mass, setup.population, setup.Q, coordinates)

def init_gun(setup, filename):
    """Build the RF gun from a 1-D on-axis Ez field-map file.

    ``filename`` is read with ``numpy.loadtxt``. Column 0 is the longitudinal
    position (m) and column 1 the Ez profile. The profile is multiplied by
    ``setup.Ez`` (MV/m, converted here to V/m). The mesh step is the span
    divided by the number of intervals, which assumes uniform spacing.

    Parameters
    ----------
    setup : object
        Reads ``Ez`` (field scale, MV/m) and ``PHID`` (RF phase, degrees).
    filename : str
        Path of the field-map file.

    Returns
    -------
    tuple
        ``(gun, s_start)``: the RF_Track.RF_FieldMap_1d element (t0 = 0, phase
        set) and the smallest position in the map, in metres.
    """
    frequency = 2.99855e9  # Hz
    table = numpy.loadtxt(filename)
    s_axis = table[:, 0]  # m
    s_start, s_end = min(s_axis), max(s_axis)  # m
    span = s_end - s_start  # m
    e_field = table[:, 1] * setup.Ez * 1e6  # V/m
    step = span / (len(s_axis) - 1)  # m
    gun = RF_Track.RF_FieldMap_1d(e_field, step, span, frequency, +1)
    gun.set_t0(0.0)
    gun.set_phid(setup.PHID)
    return gun, s_start

def init_solenoid(setup, filename):
    """Build a solenoid from a 1-D on-axis Bz field-map file.

    ``filename`` is read with ``numpy.loadtxt``. Column 0 is the longitudinal
    position (m) and column 1 the Bz profile, which is multiplied by
    ``setup.Bz`` (T). The mesh step is the span divided by the number of
    intervals, which assumes uniform spacing.

    Parameters
    ----------
    setup : object
        Reads ``Bz``, the field scale factor in T.
    filename : str
        Path of the field-map file.

    Returns
    -------
    tuple
        ``(solenoid, s_start)``: the RF_Track.Static_Magnetic_FieldMap_1d
        element and the smallest position in the map, in metres.
    """
    table = numpy.loadtxt(filename)
    s_axis = table[:, 0]  # m
    s_start = min(s_axis)  # m
    s_end = max(s_axis)  # m
    span = s_end - s_start  # m
    b_field = table[:, 1] * setup.Bz  # T
    step = span / (len(s_axis) - 1)  # m
    solenoid = RF_Track.Static_Magnetic_FieldMap_1d(b_field, step)
    return solenoid, s_start