
# -*- coding: utf-8 -*-
"""
Created on Mon Apr 15 00:32:46 2024

@author: shara
"""
import numpy as np
# Defines an m-length trellis quantizer with N states.

def conv_encoder(h0, h1, state_seq, u):
    """Return the 1-based branch label produced by a rate-1/2 binary encoder step.

    The input bit ``u`` is prepended to the binary state string, and each output
    bit is the mod-2 inner product of that register with a tap string.

    Args:
        h0: tap string (e.g. '101') producing the high output bit.
        h1: tap string producing the low output bit.
        state_seq: current state as a binary string.
        u: input bit (0 or 1); it is converted with ``str`` before prepending.

    Returns:
        ``2 * high + low + 1``, i.e. a label in 1..4.
    """
    register = str(u) + state_seq
    parity_high = 0
    parity_low = 0
    for pos, bit in enumerate(register):
        bit_value = np.int64(bit)
        parity_high = parity_high + bit_value * np.int64(h0[pos])
        parity_low = parity_low + bit_value * np.int64(h1[pos])
    return 2 * (parity_high % 2) + (parity_low % 2) + 1

def update_list_in_dictionary( dict_name, key, value ):
    """Append ``value`` to the list stored under ``key``, creating the list if needed.

    ``dict_name`` is modified in place and also returned for convenience.
    """
    dict_name.setdefault(key, []).append(value)
    return dict_name

h0_map = { 4: '101', 8: '1101', 16:'11001', 32:'101001', 64:'1010001', 128:'10111001', 256:'101100101' }
h1_map = { 4: '010', 8: '0010', 16:'00100', 32:'000100', 64:'0010100', 128:'01101010', 256:'010011110' }

# h0_map = { 4: '101', 8: '1011', 16:'10101', 32:'100101', 64:'1000101', 128:'10011101', 256:'101001101' }
# h1_map = { 4: '010', 8: '0100', 16:'00100', 32:'100101', 64:'0010100', 128:'01010110', 256:'011110010' }

class Trellis:
    def __init__(self, N=8, m=100, num_partitions = 4, rate_pmf = None):
        self.N = N
        self.m = m
        self.num_partitions = num_partitions
        self.subset_codebook = {}
        self.rate_pmf = rate_pmf
        self.transition_matrix = self.generate_state_transition()
        return

    def randomize_rate( self ):
        return np.random.choice( a=len(self.subset_codebook), p=self.rate_pmf ) + 1

    def add_new_codebook( self, rate, codebook ):
        self.subset_codebook[rate] = self.set_partition(codebook)
        return

    def get_codebook_from_file( self, rate, filename ):
        with open( filename, 'rb' ) as f:
            self.subset_codebook[rate] = np.load( f )
        return

    def set_partition(self, codebook):
        codebook = np.sort(codebook)
        state_codebook = [np.array([])] * self.num_partitions
        for i in range(len(codebook)):
            state_codebook[i%self.num_partitions] = np.append( state_codebook[i%self.num_partitions], codebook[i] )
        return np.array(state_codebook)

    def generate_state_transition(self):
        #transition_matrix[j,i] is the index of the subset ( plus one) for the j->i connection. 0 if no connection
        transition_matrix = np.zeros((self.N, self.N))
        for i in range( self.N ):
            k = i % (self.N/2)
            for j in range( self.N ):
                if np.int64(j/2) == k:
                    transition_matrix[j,i] = conv_encoder( h0_map[self.N],
                                                          h1_map[self.N],
                                                          np.binary_repr(j, width = np.int64(np.log2(self.N))),
                                                          np.int64(i>=self.N/2))


        return transition_matrix

    def subset_distances(self, rate,  x):
        argmin_within_subset = np.zeros( self.num_partitions ) - 1
        min_within_subset = np.zeros( self.num_partitions ) - 1
        for k in range(self.num_partitions):
            argmin_within_subset[k] =  np.argmin( ( self.subset_codebook[rate][k] - x )**2 )
            min_within_subset[k] = (self.subset_codebook[rate][k, np.int64( argmin_within_subset[k] )] - x)**2
        return min_within_subset, argmin_within_subset

    def decode(self, x_source):
        self.m = len(x_source)
        #distance_matrix[i,j] holds the lowest cost for reaching state i at iteration j
        distance_matrix = np.ones((self.N, self.m + 1)) * np.inf
        #Path[i,j] holds the best predecessor state for reaching state i at iteration j
        path_matrix = np.ones((self.N, self.m + 1)) * -1
        distance_matrix[0,0] = 0
        path_matrix[0,0] = 0
        rate_sequence = []
        for j in range( 1, self.m + 1 ):
            active_rate = self.randomize_rate()
            rate_sequence.append( active_rate )
            dist_within_subset, arg_dist_within_subset = self.subset_distances( active_rate, x_source[j-1] )
            for i in range(self.N):
                possible_predecessors = np.where( self.transition_matrix[:, i] > 0 )[0] #Has Exactly two elements
                best_distance = np.inf
                best_predecessor = -1
                for pred in possible_predecessors:
                    subset = int( self.transition_matrix[ pred, i ] - 1 )
                    candidate_dist = distance_matrix[pred, j-1] + dist_within_subset[ subset ]
                    if candidate_dist < best_distance:
                        best_distance = candidate_dist
                        best_predecessor = pred
                distance_matrix[i,j] = best_distance
                path_matrix[i,j] = best_predecessor
        return distance_matrix, path_matrix, rate_sequence

    def get_codebook_element( self, key ):
        key_split = key.split('-')
        subset_index = int( float( key_split[0] ) )
        index_within_subset = int( float( key_split[1] ) )
        return self.subset_codebook[ subset_index, index_within_subset ]

    def find_best_path( self, distance_matrix, path_matrix, rate_sequence, train_seq ):
        cluster_dictionary = {}
        min_distance = np.min( distance_matrix[:, -1] )
        best_path = []
        code = []
        best_path.append(int( np.argmin( distance_matrix[:, -1] )) )
        for j in range( 1, self.m + 1 ):
            sample_rate = rate_sequence[-j]
            _, arg_dist_within_subset = self.subset_distances( sample_rate, train_seq[-j] )
            prev_state = best_path[-1]
            current_state = int( path_matrix[ prev_state, -j ] )
            best_path.append( current_state )
            subset = int( self.transition_matrix[ current_state, prev_state ] ) - 1
            best_codeword = self.subset_codebook[sample_rate][ subset, int( arg_dist_within_subset[ subset ] ) ]
            code.append( best_codeword )
            cluster_dictionary = update_list_in_dictionary( cluster_dictionary, str(subset) + '-' + str( int( arg_dist_within_subset[ subset ]) ), train_seq[-j]  )
        return cluster_dictionary, min_distance, np.flip(code)

    def update_codebook(self, update_dictionary, rate):
        for key in update_dictionary:
            key_split = key.split('-')
            target_subset = np.int64( float( key_split[0]  ) )
            index_within_subset = np.int64( float( key_split[1] ) )
            self.subset_codebook[rate][ target_subset, index_within_subset ] = np.mean( update_dictionary[key] )
            #self.subset_codebook = self.set_partition( self.subset_codebook.ravel() )
        return

    def train(self, training_sequence, rate, num_iter = 25,):
        for count_iter in range( num_iter ):
            distance_matrix, path_matrix, rate_sequence = self.decode( training_sequence )
            update_dictionary, min_distance, _ = self.find_best_path( distance_matrix, path_matrix, rate_sequence, training_sequence )
            print( 'Iteration ', count_iter + 1, ': ', 'MSE = ', min_distance/len(training_sequence) )
            self.update_codebook( update_dictionary, rate )
        return

    def evaluate( self, test_sequence ):
        distance_matrix, path_matrix, rate_sequence = self.decode( test_sequence )
        return self.find_best_path(distance_matrix, path_matrix, rate_sequence, test_sequence)
    def print_state_codebook(self):
        print(self.subset_codebook)
        print(self.transition_matrix)
        return