/*
 * Copyright (c) 2010 Nicolas George
 * Copyright (c) 2011 Stefano Sabatini
 * Copyright (c) 2014 Andrey Utkin
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

/**
 * @file
 * API to create a VQE model and obtain post-processed detections from its output.
 * @example vqe_postprocess.c
 *
 * @added by cube.sun@netint.ca
 * Uses the network API to create the VQE model and to process its output.
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <math.h>
#include "nierrno.h"
#include "vqe_postprocess.h"
#include "ni_log.h"

/* Largest sliding-window length (in frames) the detection routines accept. */
#define MAX_WIN_SIZE 120
/* Number of output classes the VQE network is required to produce. */
#define CLASSES_NUM 8

/* Count of frames handled so far; frame_num % window_size selects the
 * sliding-window slot written next. File-global, so it is shared by every
 * model context. */
static int frame_num = 0;

/* Sliding-window history, one slot per frame: predicted class index and the
 * confidence of that prediction. */
int status[MAX_WIN_SIZE] = {0};
float confs[MAX_WIN_SIZE] = {0};

/* Internal names of the CLASSES_NUM classes in class-index order
 * (presumably the network's output order — confirm against the model). */
char *classes_name[] = {
    "gaussian_blur",   /* 0 */
    "gaussian_noise",  /* 1 */
    "ghosting_blur",   /* 2 */
    "normal_data",     /* 3 */
    "occlusion_data",  /* 4 */
    "snow_noise",      /* 5 */
    "sp_noise",        /* 6 */
    "stripe_noise",    /* 7 */
};

/* Label copied into the detection result for each class index; several
 * classes share the same user-facing label. */
char *chinese_class_name[] = {
    "Shaking_or_Blurry_抖动或模糊",  /* 0 gaussian_blur  */
    "Noise_噪声",                    /* 1 gaussian_noise */
    "Shaking_or_Blurry_抖动或模糊",  /* 2 ghosting_blur  */
    "Normal_正常",                   /* 3 normal_data    */
    "Blocking_遮挡",                 /* 4 occlusion_data */
    "Noise_噪声",                    /* 5 snow_noise     */
    "Noise_噪声",                    /* 6 sp_noise       */
    "Noise_噪声",                    /* 7 stripe_noise   */
};

/*
 * Classify a single frame from one network layer's output and record the
 * result in the sliding window.
 *
 * The predicted class is the argmax of the CLASSES_NUM scores in l->output.
 * Its confidence is the softmax probability of that class, computed in the
 * numerically stable form 1 / sum_i exp(p_i - p_max). Both values are stored
 * in slot frame_num % window_size of the file-level status[] / confs[]
 * arrays.
 *
 * @param l            layer whose output buffer holds CLASSES_NUM float scores
 * @param window_size  sliding-window length; must be in [1, MAX_WIN_SIZE]
 * @return 0 on success, -1 if window_size or l->classes is invalid
 */
static int get_vqe_detections(ni_vqe_network_layer_t *l, int window_size)
{
    if (window_size > MAX_WIN_SIZE || window_size < 1) {
        pr_err("slide window size out of scope\n");
        return -1;
    }
    if (l->classes != CLASSES_NUM) {
        pr_err("classes number not equal to 8\n");
        return -1;
    }

    const float *predictions = l->output;

    /* Seed the argmax with the first score rather than 0, so outputs whose
     * scores are all negative (e.g. raw logits) still select the true
     * maximum and the softmax below is normalised against it. */
    float max_score = predictions[0];
    int max_index = 0;
    for (int i = 1; i < CLASSES_NUM; i++) {
        if (predictions[i] > max_score) {
            max_score = predictions[i];
            max_index = i;
        }
    }

    /* Softmax of the winning class: exp(0) / sum_i exp(p_i - max).
     * Subtracting the maximum keeps every exponent <= 0, so expf() cannot
     * overflow. */
    float sum = 0.0f;
    for (int i = 0; i < CLASSES_NUM; i++) {
        sum += expf(predictions[i] - max_score);
    }

    int slot = frame_num % window_size;
    status[slot] = max_index;
    confs[slot] = 1.0f / sum;
    return 0;
}

/*
 * Classify the current frame on every output layer. Once the sliding window
 * holds ctx->window_size frames, derive a window-level verdict and store it
 * in ctx->det_res.
 *
 * Window verdict:
 *   - the normal class (index 3) unless the fraction of normal frames in the
 *     window is below ctx->alert_threshold;
 *   - otherwise the most frequent abnormal class present in the window, with
 *     equal counts broken by the larger summed confidence (lower index wins
 *     a complete tie).
 * The reported confidence is the mean per-frame confidence of the window
 * frames classified as the verdict class, or 1 if there are none.
 *
 * While the window is still filling, ctx->det_res is left untouched.
 *
 * NOTE(review): frame_num, status[] and confs[] are file-global, so all
 * contexts share one window; running several contexts at once will mix
 * their frames.
 *
 * @return 0 on success (including while the window is filling),
 *         non-zero on failure
 */
static int ni_vqe_get_detections(VQEModelCtx *ctx)
{
    /* Index of "normal_data" / "Normal_正常" in the class tables. */
    enum { NORMAL_CLASS = 3 };
    int i, ret;

    /* Classify the current frame; VQE networks normally have one output. */
    for (i = 0; i < ctx->output_number; i++) {
        ret = get_vqe_detections(&ctx->layers[i], ctx->window_size);
        if (ret != 0) {
            pr_err("failed to get vqe detection at layer %d\n", i);
            return ret;
        }
    }

    /* Window not full yet: only advance the frame counter. */
    if (frame_num < ctx->window_size - 1) {
        frame_num++;
        return 0;
    }

    const int window_size = ctx->window_size;
    if (window_size > MAX_WIN_SIZE || window_size < 1) {
        pr_err("slide window size out of scope\n");
        return -1;
    }

    /* Per-class frame counts and summed confidences over the window. */
    int status_count[CLASSES_NUM] = {0};
    float status_conf_sum[CLASSES_NUM] = {0};
    for (i = 0; i < window_size; i++) {
        status_count[status[i]]++;
        status_conf_sum[status[i]] += confs[i];
    }

    int window_status = NORMAL_CLASS;
    float normal_ratio = (float)status_count[NORMAL_CLASS] / window_size;
    if (normal_ratio < ctx->alert_threshold) {
        /* Most frequent abnormal class; equal counts go to the larger summed
         * confidence. Classes absent from the window are skipped, so a window
         * with no abnormal frames (reachable only when alert_threshold > 1)
         * stays NORMAL_CLASS instead of being reported as class 0. */
        int best_count = 0;
        int best_class = -1;
        for (int c = 0; c < CLASSES_NUM; c++) {
            if (c == NORMAL_CLASS || status_count[c] == 0)
                continue;
            if (status_count[c] > best_count ||
                (status_count[c] == best_count &&
                 status_conf_sum[c] > status_conf_sum[best_class])) {
                best_count = status_count[c];
                best_class = c;
            }
        }
        if (best_class >= 0)
            window_status = best_class;
    }

    /* Mean confidence of the frames that voted for the verdict class. */
    float window_conf = status_count[window_status] > 0
            ? status_conf_sum[window_status] / status_count[window_status]
            : 1;

    /* NOTE(review): assumes det_res->txt holds at least 34 bytes (longest
     * label plus NUL) — confirm against vqe_detection's definition. */
    strcpy(ctx->det_res->txt, chinese_class_name[window_status]);
    ctx->det_res->conf = window_conf;
    pr_log("vqe detections classes: %d, conf %f\n", window_status, ctx->det_res->conf);

    /* Once the window is full only frame_num % window_size matters, so keep
     * the counter in [window_size - 1, 2 * window_size). This avoids signed
     * overflow (UB) and the negative slot indices it would produce on long
     * runs, without changing which slot is written next. */
    frame_num++;
    if (frame_num >= 2 * window_size)
        frame_num -= window_size;
    return 0;
}

/*
 * Release every buffer owned by ctx:
 *   - the per-output tensor buffers;
 *   - the tensor pointer array;
 *   - the detection result;
 *   - the layer descriptors.
 * Each top-level pointer is reset to NULL afterwards, so a repeated call is
 * harmless. Individual out_tensor[] entries that are NULL are skipped
 * naturally because free(NULL) is a no-op.
 */
static void destroy_vqe_model(VQEModelCtx *ctx)
{
    if (ctx->out_tensor != NULL) {
        for (int idx = 0; idx < ctx->output_number; ++idx)
            free(ctx->out_tensor[idx]);
        free(ctx->out_tensor);
        ctx->out_tensor = NULL;
    }

    free(ctx->det_res);
    ctx->det_res = NULL;

    free(ctx->layers);
    ctx->layers = NULL;
}

/*
 * Initialise ctx for a VQE network. Records the input dimensions and the
 * window parameters, then allocates:
 *   - a zero-filled detection result;
 *   - one float output buffer per network output;
 *   - one layer descriptor per output pointing at that buffer.
 *
 * @param ctx              context to initialise; its previous pointer members
 *                         are discarded (not freed)
 * @param network_data     network description supplying the input sizes and
 *                         the output layer parameters
 * @param alert_threshold  when the fraction of normal frames in the window
 *                         falls below this, an abnormal class is reported
 * @param window_size      sliding-window length in frames; validated later by
 *                         the detection routines against [1, MAX_WIN_SIZE]
 * @param model_width      unused
 * @param model_height     unused
 * @return 0 on success, or NIERROR(ENOMEM) on allocation failure. On failure,
 *         everything allocated here is released via destroy_vqe_model.
 */
static int create_vqe_model(VQEModelCtx *ctx, ni_network_data_t *network_data,
        float alert_threshold, int window_size, int model_width, int model_height)
{
    int i, ret = 0;

    (void)model_width;
    (void)model_height;

    /* Start from a known state so the failure path never frees pointers
     * this call did not allocate. */
    ctx->out_tensor = NULL;
    ctx->det_res    = NULL;
    ctx->layers     = NULL;

    ctx->input_width  = network_data->linfo.in_param[0].sizes[0];
    ctx->input_height = network_data->linfo.in_param[0].sizes[1];

    ctx->alert_threshold = alert_threshold;
    ctx->window_size = window_size;
    ctx->output_number = network_data->output_num;

    /* calloc: det_res (conf and the txt label) stays zero-filled until the
     * first full window has been evaluated. */
    ctx->det_res = calloc(1, sizeof(*ctx->det_res));
    if (ctx->det_res == NULL) {
        pr_err("failed to allocate detection memory\n");
        ret = NIERROR(ENOMEM);
        goto fail;
    }

    /* Zeroed so destroy_vqe_model can skip entries not yet allocated. */
    ctx->out_tensor = calloc(network_data->output_num, sizeof(*ctx->out_tensor));
    if (ctx->out_tensor == NULL) {
        pr_err("failed to allocate output tensor bufptr\n");
        ret = NIERROR(ENOMEM);
        goto fail;
    }

    for (i = 0; i < network_data->output_num; i++) {
        ni_network_layer_params_t *p_param = &network_data->linfo.out_param[i];
        ctx->out_tensor[i] = malloc(ni_ai_network_layer_dims(p_param) * sizeof(float));
        if (ctx->out_tensor[i] == NULL) {
            pr_err("failed to allocate output tensor buffer\n");
            ret = NIERROR(ENOMEM);
            goto fail;
        }
    }

    ctx->layers = calloc(network_data->output_num, sizeof(*ctx->layers));
    if (!ctx->layers) {
        pr_err("cannot allocate network layer memory\n");
        ret = NIERROR(ENOMEM);
        goto fail;
    }

    for (i = 0; i < network_data->output_num; i++) {
        ctx->layers[i].classes = network_data->linfo.out_param[i].sizes[0];
        ctx->layers[i].output_number =
            ni_ai_network_layer_dims(&network_data->linfo.out_param[i]);

        /* The layer reads its scores straight from the tensor buffer. */
        ctx->layers[i].output = (float *)ctx->out_tensor[i];

        pr_log("network layer %d: classes %d output_number %d\n", i,
                ctx->layers[i].classes, ctx->layers[i].output_number);
    }
    return ret;
fail:
    destroy_vqe_model(ctx);
    return ret;
}

/* Exported VQE model interface: wires the generic VQEModel callbacks to this
 * file's implementations. */
VQEModel vqe = {
    .ni_get_detections  = ni_vqe_get_detections,
    .create_model       = create_vqe_model,
    .destroy_model      = destroy_vqe_model,
};
