#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "problema.h"
#include "method.h"

// Initial (deliberately large) step tried by the Armijo line search.
#define STEP0    100

// The Armijo step is shrunk by a random factor drawn from [TAU, 1 - TAU].
#define TAU      0.01
// Armijo sufficient-decrease coefficient (omega_1).
#define OMEGA    0.0001

// Upper bound on the loop counter of one Armijo line search.
#define MAX_ITER 100000

// Row/column capacity of the fixed-size matrices accepted by copyMatrix().
#define MAX_DIMENSION       100

// Prints the b x b matrix `a` on stdout, only when
// getLogLevelDisplayed() <= DEBUG.
// Wrapped in do { } while (0) so it behaves as a single statement (safe in
// if/else bodies); it uses its own loop indices instead of the caller's i/j,
// and evaluates `b` only once.
#define printMatrix(a,b)                                                      \
    do                                                                        \
    {                                                                         \
        if(getLogLevelDisplayed() <= DEBUG)                                   \
        {                                                                     \
            size_t pm_n_ = (size_t)(b);                                       \
            size_t pm_i_, pm_j_;                                              \
            printf("Matrix :\n");                                             \
            for(pm_i_ = 0; pm_i_ < pm_n_; pm_i_++)                            \
            {                                                                 \
                printf("\t[");                                                \
                for(pm_j_ = 0; pm_j_ < pm_n_; pm_j_++)                        \
                {                                                             \
                    printf("%4.4lf",(a)[pm_i_][pm_j_]);                       \
                    if(pm_j_ != pm_n_ - 1)                                    \
                        printf(", ");                                         \
                }                                                             \
                printf("]\n");                                                \
            }                                                                 \
        }                                                                     \
    } while(0)

// Overwrites the dimension x dimension matrix with the identity matrix:
// 1.0 on the diagonal, 0.0 everywhere else. The solver uses it as the
// initial inverse-Hessian approximation S_0.
static
void buildSymetricMatrix(double** matrix, size_t dimension)
{
    size_t row, col;

    for(row = 0; row < dimension; row++)
    {
        // clear the whole row, then set its diagonal entry
        for(col = 0; col < dimension; col++)
            matrix[row][col] = 0.0;
        matrix[row][row] = 1.0;
    }
}

//============ CALCUL MATRICIEL ==============

// Row vector times matrix:
//      v^t * m
// Returns a newly allocated vector r of length `dimension` with
// r[i] = sum_j v[j] * m[j][i]. The caller frees it.
// Returns NULL if the allocation fails.
static
double* TxM(double v[], double** m, size_t dimension)
{
    double* result;
    size_t i, j;

    result = malloc(dimension * sizeof *result);
    if(result == NULL)
        return NULL;

    for(i = 0; i < dimension; i++)
    {
        result[i] = 0.0;
        for(j = 0; j < dimension; j++)
            result[i] += v[j] * m[j][i];
    }

    return result;
}
// Matrix times column vector:
//      m * v
// Returns a newly allocated vector r of length `dimension` with
// r[i] = sum_j m[i][j] * v[j]. The caller frees it.
// Returns NULL if the allocation fails.
static
double* MxV(double** m, double* v, size_t dimension)
{
    double* result;
    size_t i, j;

    result = malloc(dimension * sizeof *result);
    if(result == NULL)
        return NULL;

    for(i = 0; i < dimension; i++)
    {
        result[i] = 0.0;
        for(j = 0; j < dimension; j++)
            result[i] += v[j] * m[i][j];
    }

    return result;
}

// Outer product of two column vectors:
//      v1 * v2^t
// Returns a newly allocated dimension x dimension matrix R with
// R[i][j] = v1[i] * v2[j]. Release it with freeMatrix().
// Returns NULL if an allocation fails.
static
double** VxT(double* v1, double* v2, size_t dimension)
{
    double** result;
    size_t i, j;

    // array of row pointers: sizeof *result == sizeof(double*)
    result = malloc(dimension * sizeof *result);
    if(result == NULL)
        return NULL;

    for(i = 0; i < dimension; i++)
    {
        result[i] = malloc(dimension * sizeof *result[i]);
        if(result[i] == NULL)
        {
            // release the rows allocated so far before giving up
            while(i > 0)
                free(result[--i]);
            free(result);
            return NULL;
        }
        for(j = 0; j < dimension; j++)
            result[i][j] = v1[i] * v2[j];
    }

    return result;
}

// Dot product (a scalar) of two vectors:
//      v1^t * v2
// Accumulates v1[k] * v2[k] for k = 0 .. dimension-1, in order.
// Returns 0.0 when dimension is 0.
static
double TxV(double* v1, double* v2, size_t dimension)
{
    double sum = 0.0;
    size_t k = 0;

    while(k < dimension)
    {
        sum += v1[k] * v2[k];
        ++k;
    }

    return sum;
}

// Matrix times scalar:
//      s * m
// Returns a newly allocated dimension x dimension matrix R with
// R[i][j] = m[i][j] * s. Release it with freeMatrix().
// Returns NULL if an allocation fails.
static
double** MxS(double** m, double s, size_t dimension)
{
    double** result;
    size_t i, j;

    // array of row pointers: sizeof *result == sizeof(double*)
    result = malloc(dimension * sizeof *result);
    if(result == NULL)
        return NULL;

    for(i = 0; i < dimension; i++)
    {
        result[i] = malloc(dimension * sizeof *result[i]);
        if(result[i] == NULL)
        {
            // release the rows allocated so far before giving up
            while(i > 0)
                free(result[--i]);
            free(result);
            return NULL;
        }
        for(j = 0; j < dimension; j++)
            result[i][j] = m[i][j] * s;
    }

    return result;
}

// Matrix product:
//      m1 * m2
// Returns a newly allocated dimension x dimension matrix R with
// R[i][j] = sum_k m1[i][k] * m2[k][j]. Release it with freeMatrix().
// Returns NULL if an allocation fails.
static
double** MxM(double** m1, double** m2, size_t dimension)
{
    double** result;
    double   buffer;
    size_t i, j, k;

    // array of row pointers: sizeof *result == sizeof(double*)
    result = malloc(dimension * sizeof *result);
    if(result == NULL)
        return NULL;

    for(i = 0; i < dimension; i++)
    {
        result[i] = malloc(dimension * sizeof *result[i]);
        if(result[i] == NULL)
        {
            // release the rows allocated so far before giving up
            while(i > 0)
                free(result[--i]);
            free(result);
            return NULL;
        }
        for(j = 0; j < dimension; j++)
        {
            buffer = 0.0;
            for(k = 0; k < dimension; k++)
                buffer += m1[i][k] * m2[k][j];
            result[i][j] = buffer;
        }
    }

    return result;
}

// Matrix sum:
//      m1 + m2
// Returns a newly allocated dimension x dimension matrix R with
// R[i][j] = m1[i][j] + m2[i][j]. Release it with freeMatrix().
// Returns NULL if an allocation fails.
static
double** MplusM(double** m1, double** m2, size_t dimension)
{
    double** result;
    size_t i, j;

    // array of row pointers: sizeof *result == sizeof(double*)
    result = malloc(dimension * sizeof *result);
    if(result == NULL)
        return NULL;

    for(i = 0; i < dimension; i++)
    {
        result[i] = malloc(dimension * sizeof *result[i]);
        if(result[i] == NULL)
        {
            // release the rows allocated so far before giving up
            while(i > 0)
                free(result[--i]);
            free(result);
            return NULL;
        }
        for(j = 0; j < dimension; j++)
            result[i][j] = m1[i][j] + m2[i][j];
    }

    return result;
}

// Releases a matrix allocated row by row: each of the `dimension` rows is
// freed, then the array of row pointers itself.
// A NULL matrix is ignored, so it is safe to call on the result of a
// failed allocation.
static
void freeMatrix(double** matrix, size_t dimension)
{
    size_t i;

    if(matrix == NULL)
        return;

    for(i = 0; i < dimension; i++)
        free(matrix[i]);
    free(matrix);
}

// Copies the top-left dimension x dimension part of a fixed-size
// MAX_DIMENSION x MAX_DIMENSION array into a newly allocated row-by-row
// matrix. Release it with freeMatrix().
// Returns NULL if an allocation fails.
static
double ** copyMatrix(double m[MAX_DIMENSION][MAX_DIMENSION], size_t dimension)
{
    double** result;
    size_t i, j;

    // array of row pointers: sizeof *result == sizeof(double*)
    result = malloc(dimension * sizeof *result);
    if(result == NULL)
        return NULL;

    for(i = 0; i < dimension; i++)
    {
        result[i] = malloc(dimension * sizeof *result[i]);
        if(result[i] == NULL)
        {
            // release the rows allocated so far before giving up
            while(i > 0)
                free(result[--i]);
            free(result);
            return NULL;
        }
        for(j = 0; j < dimension; j++)
            result[i][j] = m[i][j];
    }

    return result;
}

//============================================

// BFGS update of the inverse-Hessian approximation, applied after an iteration.
//
// With  delta = x_{k+1} - x_k  and  y = grad J(x_{k+1}) - grad J(x_k),
// old_hessien (S_k, n x n with n = pb->_dimension) is overwritten in place by
//
//   S_{k+1} = S_k + (1 + (y^t S_k y) / (delta^t y)) * (delta delta^t) / (delta^t y)
//                 - (delta y^t S_k + S_k y delta^t) / (delta^t y)
//
// The update is computed entry-wise in O(n^2) with no heap-allocated
// temporaries. It is skipped (S_k left unchanged) when delta^t y is not
// strictly positive: the formula would divide by zero (e.g. x_{k+1} == x_k),
// and a non-positive curvature would destroy the positive definiteness of S.
// Gradients come from pb->_derivees[0] when available, otherwise from
// approximateGradient(). Both gradient buffers are freed here.
static
void updateHessien(
    double**  old_hessien,
    double    old_iteration[],
    double    current_iteration[],
    Problema* pb
)
{
    size_t n = pb->_dimension;
    size_t i, j, k;

    double* grad_new;
    double* grad_old;
    double  dy  = 0.0;   // delta_k^t * y_k
    double  ySy = 0.0;   // y_k^t * S_k * y_k
    double  first_term;

    if(n == 0)
        return;          // nothing to update (and zero-length VLAs are UB)

    double y_k[n];
    double delta_k[n];
    double yS[n];        // y_k^t * S_k  (row vector)
    double Sy[n];        // S_k * y_k    (column vector)

    // gradient J( x_{k+1} ), then gradient J( x_{k} )
    if( pb->_derivees[0] )
    {
        grad_new = pb->_derivees[0](current_iteration);
        grad_old = pb->_derivees[0](old_iteration);
    }
    else
    {
        grad_new = approximateGradient(
                        pb->_function,
                        current_iteration,
                        pb->_dimension,
                        0.001//FIXME
                    );
        grad_old = approximateGradient(
                        pb->_function,
                        old_iteration,
                        pb->_dimension,
                        0.001//FIXME
                    );
    }

    if(grad_new == NULL || grad_old == NULL)
    {
        free(grad_new);
        free(grad_old);
        return;
    }

    for(i = 0; i < n; i++)
    {
        y_k[i]     = grad_new[i] - grad_old[i];
        delta_k[i] = current_iteration[i] - old_iteration[i];
    }

    free(grad_new);
    free(grad_old);

    for(i = 0; i < n; i++)
        dy += delta_k[i] * y_k[i];

    // !(dy > 0) also catches NaN
    if( !(dy > 0.0) )
    {
        logMessage(DEBUG,"BFGS update skipped","delta_k^t * y_k is not positive");
        return;
    }

    // yS = y_k^t * S_k  and  Sy = S_k * y_k (S_k is not assumed symmetric)
    for(i = 0; i < n; i++)
    {
        yS[i] = 0.0;
        Sy[i] = 0.0;
        for(k = 0; k < n; k++)
        {
            yS[i] += y_k[k] * old_hessien[k][i];
            Sy[i] += old_hessien[i][k] * y_k[k];
        }
    }

    for(i = 0; i < n; i++)
        ySy += yS[i] * y_k[i];

    first_term = 1.0 + ySy / dy;

    // (delta y^t S)_{ij} = delta_i * yS_j ; (S y delta^t)_{ij} = Sy_i * delta_j
    // yS and Sy were computed beforehand, so updating S in place is safe.
    for(i = 0; i < n; i++)
    {
        for(j = 0; j < n; j++)
        {
            old_hessien[i][j] += ( first_term * delta_k[i] * delta_k[j]
                                   - delta_k[i] * yS[j]
                                   - Sy[i] * delta_k[j] ) / dy;
        }
    }
}

// Returns a pseudo-random double uniformly spread over [a, b], built from the
// C library rand() generator (seed it with srand()).
static double homeRandom(double a, double b)
{
    const double unit = (double)rand() / RAND_MAX;   // in [0, 1]
    return a + unit * (b - a);
}

// Objective value after a steepest-descent trial step of length ro:
//      J( x - ro * grad J(x) )
// x is not modified. The gradient comes from pb->_derivees[0] when available,
// otherwise from approximateGradient(); it is freed here.
static double ESwithRo(
    double    x[],
    Problema* pb,
    double    ro)
{
    size_t  i;
    double  copy[pb->_dimension];   // trial point (x itself is left untouched)
    double* calculGradient;

    if( pb->_derivees[0] )
    {
        calculGradient = pb->_derivees[0](x);
    }
    else
    {
        calculGradient = approximateGradient(
                            pb->_function,
                            x,
                            pb->_dimension,
                            0.001//FIXME
                        );
    }

    // Move every component: the gradient has pb->_dimension entries, so a
    // fixed bound would either overrun the arrays or ignore coordinates.
    for(i = 0; i < pb->_dimension; i++)
        copy[i] = x[i] - ro * calculGradient[i];

    free(calculGradient);

    // NOTE(review): if pb->_function returns a heap-allocated array, that
    // pointer is dropped here - confirm its ownership in problema.h.
    return pb->_function(copy,pb->_dimension)[0];
}

/**
 * Armijo sufficient-decrease term for the steepest-descent direction
 * d_k = - grad J(x_k):
 *      OMEGA * step * grad J(x_k)^t * d_k  =  - OMEGA * step * ||grad J(x_k)||^2
 * The directional derivative along d_k is minus the SQUARED gradient norm,
 * which matches the "d_{k} * d_{k}" term logged by computeArmijoStep().
 * The gradient comes from pb->_derivees[0] when available, otherwise from
 * approximateGradient(); it is freed here.
 */
static double EScomplement(
    double x[],
    Problema* pb,
    double step)
{
    double* calculGradient;
    double  squaredNorm = 0.0;
    size_t  i;

    if( pb->_derivees[0] )
    {
        calculGradient = pb->_derivees[0](x);
    }
    else
    {
        calculGradient = approximateGradient(
                            pb->_function,
                            x,
                            pb->_dimension,
                            0.001//FIXME
                        );
    }

    // grad^t * grad, computed directly so it does not depend on whether
    // normeEuclidienne() returns the norm or its square
    for(i = 0; i < pb->_dimension; i++)
        squaredNorm += calculGradient[i] * calculGradient[i];

    free(calculGradient);

    return - OMEGA * step * squaredNorm;
}

/**
 * Randomised Armijo line search along d_k = - grad J(x_k).
 *
 * Starting from STEP0, the step is multiplied by a random factor drawn from
 * [TAU, 1 - TAU] while
 *      J( x_k + step * d_k ) > J( x_k ) + EScomplement(x_k, step)
 * The search stops when that test fails, when MAX_ITER - 1 reductions have
 * been made, or when the next step would underflow to 0.
 * Returns the last step tried.
 *
 * NOTE(review): resolveByNearlyNewton moves along - S_k * grad J(x_k), not
 * along d_k, so the decrease test is checked for a different direction than
 * the one actually used.
 */
static double computeArmijoStep(
    double x[],
    Problema* pb)
{
    double step = STEP0;          // initial step - deliberately "large"
    double candidate;             // next step, accepted only if non-zero
    int    counter = 0;           // loop counter
    double eswro, escomp, es;     // terms of the Armijo condition

    logMessage(DEBUG,"newline","");
    logMessage(DEBUG,"méthode armijo : pas initial",doubleToString(step));

    // First trial. J(x_k) does not depend on the step, so it is evaluated once.
    eswro  = ESwithRo(x, pb, step);
    es     = pb->_function(x,pb->_dimension)[0];
    escomp = EScomplement(x,pb,step);

    logMessage(DEBUG,"condition : J( x_{k} + ro_{k}^{i} * d_{k})",doubleToString(eswro));
    logMessage(DEBUG,"condition : J( x_{k} )",doubleToString(es));
    logMessage(DEBUG,"condition : - omega_{1} * ro_{k}^{i} * d_{k} * d_{k}",doubleToString(escomp));
    logMessage(DEBUG,"condition : eswro > es + escomp",printBoolean( eswro > es + escomp ));

    while( ( eswro > es + escomp ) && ( ++counter < MAX_ITER ) )
    {
        // new step - RANDOMLY shrunk
        candidate = homeRandom(TAU,1 - TAU) * step;
        if(candidate == 0)
        {
            // The step cannot shrink any further (underflow). Retrying with
            // 0 * factor would loop forever, so keep the last non-zero step.
            logMessage(DEBUG,"actualisation pas","underflow, search stopped");
            break;
        }
        step = candidate;

        logMessage(DEBUG,"actualisation pas",doubleToString(step));

        eswro  = ESwithRo(x, pb, step);
        escomp = EScomplement(x,pb,step);

        logMessage(DEBUG,"condition : J( x_{k} + ro_{k}^{i} * d_{k})",doubleToString(eswro));
        logMessage(DEBUG,"condition : J( x_{k} )",doubleToString(es));
        logMessage(DEBUG,"condition : - omega_{1} * ro_{k}^{i} * d_{k} * d_{k}",doubleToString(escomp));
        logMessage(DEBUG,"condition : eswro > es + escomp",printBoolean( eswro > es + escomp ));
    }

    logMessage(DEBUG,"counter value",integerToString(counter,10));
    logMessage(DEBUG,"counter limit",integerToString(MAX_ITER,10));
    logMessage(DEBUG,"resultat optimisation pas",doubleToString(step));

    return step;
}

/**
 * Quasi-Newton (BFGS) minimisation of pb->_function starting at init_vector.
 *
 * Each iteration k:
 *   - picks a step with computeArmijoStep(x_k, pb);
 *   - computes x_{k+1} = x_k - step * S_k * grad J(x_k), where S_k is the
 *     inverse-Hessian approximation (identity at k = 0);
 *   - updates S_k with updateHessien().
 * The loop runs at least once. It stops when
 * normeEuclidienne(x_{k+1} - x_k) <= m->_epsilon_in, or after
 * m->_max_iteration iterations.
 *
 * On return, m holds:
 *   - _last_complexity: number of iterations performed;
 *   - _last_result_function_value: J at the final iterate;
 *   - _last_result: a newly allocated copy of the final iterate (the previous
 *     buffer is freed);
 *   - when pb->_solution is known, _last_absolute_error =
 *     normeEuclidienne(x_final - solution) and _last_relative_error =
 *     _last_absolute_error / normeEuclidienne(solution).
 * args is not used. All working memory is released before returning.
 */
void resolveByNearlyNewton(
    Problema*    pb,
    Method*      m,
    double       init_vector[],
    void*        args[])
{
    // iterates x_{k+1} and x_k
    double current_iteration[pb->_dimension];
    double old_iteration    [pb->_dimension];
    double diff[pb->_dimension];

    double norme;
    double step;
    double* derivee_buffer = NULL;   // grad J(x_k)
    double* vector_buffer = NULL;    // S_k * grad J(x_k)

    double old_error = -1, current_error;
    int counter = 0;
    size_t i, j;                     // j may be referenced by the printMatrix macro expansion
    boolean toContinue = 1;

    double** hessien = NULL;         // S_k, pb->_dimension x pb->_dimension

    (void)args;

    logMessage(DEBUG,"newline","");
    logMessage(DEBUG,"function name","resolveByNearlyNewton");
    logMessage(DEBUG,"parameters : problem name",pb->_name);
    logMessage(DEBUG,"parameters : method name",m->_name);
    logMessage(DEBUG,"parameters : vector",printVector(init_vector,pb->_dimension));
    logMessage(DEBUG,"parameters : epsilon in",doubleToString(m->_epsilon_in));
    logMessage(DEBUG,"parameters : epsilon out",doubleToString(m->_epsilon_out));
    logMessage(DEBUG,"parameters : patience",integerToString(m->_max_iteration,10));

    // array of row pointers: sizeof *hessien == sizeof(double*)
    hessien = malloc(pb->_dimension * sizeof *hessien);
    if(hessien == NULL)
    {
        fprintf(stderr, "resolveByNearlyNewton: cannot allocate the hessian\n");
        return;
    }
    for(i = 0; i < pb->_dimension; i++)
    {
        hessien[i] = malloc(pb->_dimension * sizeof *hessien[i]);
        if(hessien[i] == NULL)
        {
            while(i > 0)
                free(hessien[--i]);
            free(hessien);
            fprintf(stderr, "resolveByNearlyNewton: cannot allocate the hessian\n");
            return;
        }
    }

    // old_iteration <- init_vector
    memcpy(
            old_iteration,
            init_vector,
            sizeof(double) * pb->_dimension);

    // S_0 = identity
    buildSymetricMatrix(hessien,pb->_dimension);

    if(pb->_solution)
    {
        logMessage(DEBUG,"solution","bundled with the problem");
        logMessage(DEBUG,"solution value",printVector(pb->_solution,pb->_dimension));
    }
    if( pb->_derivees[0] )
        logMessage(DEBUG,"derivee acquisition","we got the first derivee");
    else
        logMessage(DEBUG,"derivee acquisition","dynamically calculated");

    do
    {
        printMatrix(hessien,pb->_dimension);

        // compute step
        step = computeArmijoStep(old_iteration,pb);

        logMessage(DEBUG,"armijo step",doubleToString(step));

        // release the previous iteration's gradient before recomputing it
        free(derivee_buffer);
        // use the exact first derivative when the problem provides it
        if( pb->_derivees[0] )
        {
            derivee_buffer = pb->_derivees[0](old_iteration);
        }
        else
        {
            derivee_buffer = approximateGradient(
                                pb->_function,
                                old_iteration,
                                pb->_dimension,
                                0.001//FIXME
                            );
        }

        logMessage(DEBUG,"derivee value",printVector(derivee_buffer,pb->_dimension));

        free(vector_buffer);

        vector_buffer = MxV(hessien,derivee_buffer,pb->_dimension);

        logMessage(DEBUG,"Sk * gradient value",printVector(vector_buffer,pb->_dimension));

        // x_{k+1} = x_k - step * S_k * grad J(x_k)
        for(i = 0; i < pb->_dimension; i++)
            current_iteration[i] = old_iteration[i] - step * vector_buffer[i];

        logMessage(DEBUG,"current iteration",printVector(current_iteration,pb->_dimension));
        logMessage(DEBUG,"old iteration",printVector(old_iteration,pb->_dimension));

        // update of the inverse-Hessian approximation
        updateHessien(hessien,old_iteration,current_iteration,pb);

        // difference between the two iterates
        for(i = 0; i < pb->_dimension; i++)
            diff[i] = current_iteration[i] - old_iteration[i];

        // its norm drives the stopping test
        norme = normeEuclidienne(diff,pb->_dimension);

        logMessage(DEBUG,"difference's norme value",doubleToString(norme));

        // error against the known solution, when there is one
        if(pb->_solution){
            current_error = error(pb->_dimension,current_iteration,pb->_solution);

            logMessage(DEBUG,"error",doubleToString(current_error));
            if(old_error != -1)
                logMessage(DEBUG,"error difference between iteration",doubleToString(fabs(current_error - old_error)));

            old_error = current_error;
        }

        // prepare the next iteration
        memcpy(old_iteration,current_iteration,pb->_dimension * sizeof(double));

        toContinue = toContinue && ( norme > m->_epsilon_in );
        // TODO
        //toContinue = toContinue && ( fabs(pb->_function(old_iteration,pb->_dimension) - pb->_function(current_iteration,pb->_dimension)) > m->_epsilon_out );
        toContinue = toContinue && ( ++counter < m->_max_iteration );
        // TODO
        //toContinue = toContinue && (pb->_solution && old_error != -1) ? fabs(current_error - old_error) < m->_epsilon_in : 1)

        logMessage(DEBUG,"counter",integerToString(counter,10));

        logMessage(DEBUG,"stop condition 1 : norme > m->_epsilon_in",printBoolean(norme > m->_epsilon_in));
        // counter was already incremented above: log the condition actually tested
        logMessage(DEBUG,"stop condition 2 : counter < m->_max_iteration",printBoolean(counter < m->_max_iteration));
        logMessage(DEBUG,"should we continue",printBoolean(toContinue));
    }
    while( toContinue );

    m->_last_complexity            = counter;
    m->_last_result_function_value = pb->_function(current_iteration,pb->_dimension)[0];
    if(pb->_solution)
    {
        // Distance of the final iterate to the known solution (not the size of
        // the last step, which is what diff held after the loop).
        for(i = 0; i < pb->_dimension; i++)
            diff[i] = current_iteration[i] - pb->_solution[i];
        m->_last_absolute_error = normeEuclidienne(diff,pb->_dimension);
        m->_last_relative_error = m->_last_absolute_error / normeEuclidienne(pb->_solution,pb->_dimension);
    }

    // replace any previous result buffer (free(NULL) is a no-op)
    free(m->_last_result);
    m->_last_result = malloc(sizeof(double) * pb->_dimension);
    if(m->_last_result != NULL)
        memcpy(m->_last_result, current_iteration, sizeof(double) * pb->_dimension);

    // release the working buffers
    free(derivee_buffer);
    free(vector_buffer);
    freeMatrix(hessien,pb->_dimension);
}
