library IEEE;
use IEEE.STD_LOGIC_1164.ALL;
use IEEE.NUMERIC_STD.ALL;
use IEEE.fixed_pkg.all;
use work.array_type.ALL;

-- Top-level testbench for Neuron_top_level. It declares no generics or ports:
-- the clock, reset, enable, weights and input spikes are all generated inside
-- its architecture.
entity Neuron_top_level_tb is
end entity Neuron_top_level_tb;

architecture Behavioral of Neuron_top_level_tb is

-- A separator is required between a number and its unit ("10 ns", not "10ns").
constant clk_period : time := 10 ns;
constant Dendrite_NO : NATURAL := 4;

-- neuron parameters start
constant Resting_potential : NATURAL := 40;
-- The membrane potential (MP) drops to this value and enters the refractory period;
-- MP then increases by 1 every clk period.
constant Refract_potential_start : NATURAL := 37;
-- When MP reaches Refract_potential_end the refractory period ends and the neuron
-- becomes sensitive to incoming spikes again.
constant Refract_potential_end : NATURAL := 40;
constant threshold_potential : SFIXED := to_SFIXED(45, 7, -12);
constant Leaky_rate : SFIXED := to_SFIXED(0.00390625, 3, -12);     -- 2**-8
constant Recharge_rate : SFIXED := to_SFIXED(0.00390625, 3, -12);  -- 2**-8
-- neuron parameters end

-- Weight-initialisation options, applied by test_process after 'en' is raised.
-- The defaults reproduce the original stimulus: custom weights only.
constant Use_random_weights : BOOLEAN := false;  -- pulse weights_rst after the RNG warm-up
constant RNG_warmup_cycles  : NATURAL := 84;     -- change this value to try different initial random weights
constant Use_custom_weights : BOOLEAN := true;   -- load the fixed weights listed in test_process

-- Spike pattern replayed on each dendrite. Bit 0 (the right-most character) is
-- applied first, then bit 1, ... up to bit Pattern_length - 1, and then the
-- pattern repeats. It is never modified, so it is a constant.
constant Pattern_length : NATURAL := 32;
constant Spike_in_pattern : array_2d_STD(0 to Dendrite_NO - 1)(Pattern_length - 1 downto 0) :=
    ("10001000100010001000100010001000",
     "11101110111011101110111011101110",
     "10101010101010101010101010101010",
     "10000000000000001000000000000000");

-- Stimulus signals start at '0' so the UUT never sees 'U' inputs during the
-- initial wait in test_process.
signal clk : STD_LOGIC := '0';
signal rst : STD_LOGIC := '0';
signal en : STD_LOGIC := '0';
signal Spike_in : STD_LOGIC_VECTOR(0 to Dendrite_NO - 1) := (others => '0');
signal Spike_out : STD_LOGIC;
signal weights_rst : STD_LOGIC := '0';

signal custom_weight_en : STD_LOGIC := '0';
signal custom_weight : array_2d_SFIXED (0 to Dendrite_NO - 1)(1 downto -12) := (others => (others => '0'));


begin

    -- Free-running clock with period clk_period.
    clk_process : process
    begin
        clk <= '0';
        wait for clk_period/2;
        clk <= '1';
        wait for clk_period/2;
    end process;

    UUT : entity work.Neuron_top_level
        Generic map(Dendrite_NO => Dendrite_NO,
            Resting_potential => Resting_potential,
            Refract_potential_start => Refract_potential_start,
            Refract_potential_end => Refract_potential_end,
            thredhold_potential => threshold_potential,  -- formal generic name as declared by Neuron_top_level
            Leaky_rate => Leaky_rate,
            Recharge_rate => Recharge_rate)

        Port map (clk => clk,
            rst => rst,
            en => en,
            Spike_in => Spike_in,
            Spike_out => Spike_out,
            weight_rst => weights_rst,
            custom_weight_en => custom_weight_en,
            custom_weight => custom_weight);

    -- Stimulus sequence: idle, reset pulse, enable, optional weight initialisation,
    -- then replay Spike_in_pattern forever. The simulation must be stopped by the run time.
    test_process : process
    begin

        wait for 100 ns;
        -- Align to a falling edge. Every later wait is a multiple of clk_period,
        -- so stimuli always change half a period away from rising edges.
        wait until falling_edge(clk);

        -- set all inputs to their idle value
        rst <= '0';
        en <= '0';
        Spike_in <= (others => '0');
        weights_rst <= '0';
        custom_weight_en <= '0';
        custom_weight <= (others => (others => '0'));
        wait for clk_period * 10;

        -- initial reset: one clk period pulse
        rst <= '1';
        wait for clk_period;
        rst <= '0';
        wait for clk_period * 10;

        en <= '1';

        -- optional random weight initialisation
        if Use_random_weights then
            wait for clk_period * RNG_warmup_cycles;  -- let the RNG run before sampling weights
            weights_rst <= '1';
            wait for clk_period;
            weights_rst <= '0';
            wait for clk_period;
        end if;

        -- optional custom weights: one entry per dendrite (Dendrite_NO = 4).
        -- custom_weight_en is left asserted afterwards.
        if Use_custom_weights then
            wait for clk_period * 10;
            custom_weight_en <= '1';
            custom_weight <= (to_SFIXED(1.0, 1, -12),
                              to_SFIXED(0.5, 1, -12),
                              to_SFIXED(-0.5, 1, -12),
                              to_SFIXED(-1, 1, -12));
        end if;

        wait for clk_period * 20;

        -- feed input spikes: one pattern bit per dendrite per clk period, repeated forever
        loop
            for i in 0 to Pattern_length - 1 loop
                for j in Spike_in'range loop
                    Spike_in(j) <= Spike_in_pattern(j)(i);
                end loop;
                wait for clk_period;
            end loop;
        end loop;

    end process;

end Behavioral;














