/*
 * Copyright (c) 2010 Nicolas George
 * Copyright (c) 2011 Stefano Sabatini
 * Copyright (c) 2014 Andrey Utkin
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */

/**
 * @file
 * API for AI code to create an OCR model and decode its output.
 * @example ocr_postprocess.h
 *
 * @added by cube.sun@netint.ca
 * Uses the network API to create the OCR model and post-process its output.
 */

#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include "nierrno.h"
#include "ocr_postprocess.h"
#include "ni_log.h"

/*
 * Character dictionary used to turn recognizer class indices into text.
 * Row 0 holds the placeholder "e", which the decoder never emits; the
 * following rows are filled by get_characters(). Each row stores at most
 * 7 bytes plus the NUL terminator. The row count presumably matches the
 * recognizer's number of output classes (char_lib) - TODO confirm.
 * Static storage, so every row starts out as an empty string.
 */
static char characters_str[6625][8];

/*
 * Load the character dictionary into characters_str.
 *
 * Entry 0 is always the placeholder "e". With a readable file, each line
 * becomes one entry starting at index 1. "\n" and "\r\n" terminators are
 * stripped. Lines longer than 7 bytes are truncated, and the rest of such a
 * line is skipped so that the following lines keep their index. A final
 * " " entry is appended after the last line. If the file cannot be opened
 * (or character_file is NULL), entries 1..36 are filled with the single
 * characters "0".."9", "a".."z" instead.
 *
 * @param character_file path of the dictionary file, or NULL
 * @return 0 (a missing file is not treated as an error; the fallback is used)
 */
static int get_characters(const char *character_file)
{
    const size_t entries = sizeof(characters_str) / sizeof(characters_str[0]);
    char buf[sizeof(characters_str[0])] = { 0 };
    size_t index = 1;
    FILE *fp;

    snprintf(characters_str[0], sizeof(characters_str[0]), "%s", "e");

    fp = character_file ? fopen(character_file, "rb") : NULL;
    if (!fp) {
        static const char fallback[] = "0123456789abcdefghijklmnopqrstuvwxyz";
        size_t i;

        pr_log("can not assess to character_file %s\n",
               character_file ? character_file : "(null)");
        /* One single-character entry per fallback character. */
        for (i = 0; fallback[i] != '\0' && index < entries; i++, index++) {
            characters_str[index][0] = fallback[i];
            characters_str[index][1] = '\0';
        }
        /* fp is NULL here: there is nothing to fclose(). */
        return 0;
    }

    /* Keep the last row free for the trailing " " entry. */
    while (index < entries - 1 && fgets(buf, sizeof(buf), fp) != NULL) {
        /* No '\n' and not at EOF: the line did not fit into buf. */
        int line_unfinished = strchr(buf, '\n') == NULL && !feof(fp);

        buf[strcspn(buf, "\r\n")] = '\0';
        snprintf(characters_str[index], sizeof(characters_str[index]), "%s", buf);
        index++;

        if (line_unfinished) {
            int c;
            /* Drop the rest of the line so it does not become extra entries. */
            while ((c = fgetc(fp)) != EOF && c != '\n') {
            }
        }
    }

    if (index == entries - 1 && fgetc(fp) != EOF) {
        pr_err("character_file %s has more than %d entries, rest ignored\n",
               character_file, (int)(entries - 2));
    }

    snprintf(characters_str[index], sizeof(characters_str[index]), "%s", " ");
    fclose(fp);
    return 0;
}

/*
 * Decode one recognition output layer into text and store it in det_res.
 *
 * The layer output is read as char_num consecutive rows of char_lib scores
 * (predictions[step * char_lib + class]). For every step, the highest-scoring
 * class is selected; the scan stops early once a score above 0.51 is seen.
 * Class 0 is never emitted. When remove_duplicate is set, a class equal to
 * the previous step's class (compared before any dropping) is treated as a
 * repeat and skipped. So "a a a" decodes to "a" and "a 0 a" decodes to "aa".
 *
 * @param l                layer to decode (uses output, char_num, char_lib)
 * @param remove_duplicate collapse consecutive repeats of the same class
 * @param det_res          receives the decoded text and the mean probability
 *                         of the emitted characters ("" and 0 when nothing
 *                         was decoded)
 * @return 0
 */
static int get_ocr_detections(ni_ocr_network_layer_t *l, bool remove_duplicate, ocr_detection* det_res)
{
    /* Class ids at or beyond this have no dictionary entry. */
    const int dict_size = (int)(sizeof(characters_str) / sizeof(characters_str[0]));
    const float *predictions = l->output;
    char result[140] = {0};
    size_t len = 0;
    float prob_sum = 0;
    int char_count = 0;
    int prev_index = -1;   /* raw argmax of the previous step */
    int i, j;

    for (i = 0; i < l->char_num; i++) {
        const float *row = predictions + (size_t)i * l->char_lib;
        float max_prob = 0;
        int best = 0;
        int repeat;
        const char *ch;
        size_t ch_len;

        for (j = 0; j < l->char_lib; j++) {
            if (row[j] > max_prob) {
                max_prob = row[j];
                best = j;
            }
            /* Presumably softmax scores: anything above 0.5 cannot be beaten. */
            if (max_prob > 0.51) {
                break;
            }
        }

        /* Compare with the raw previous class so runs of any length collapse. */
        repeat = remove_duplicate && best == prev_index;
        prev_index = best;
        if (best == 0 || repeat || best >= dict_size) {
            continue;
        }

        ch = characters_str[best];
        ch_len = strlen(ch);
        if (len + ch_len >= sizeof(result)) {
            pr_err("ocr text longer than %d bytes, truncated\n",
                   (int)(sizeof(result) - 1));
            break;
        }
        memcpy(result + len, ch, ch_len + 1);
        len += ch_len;

        /* Special case: "ol", "oi", "o1", "0l", "0i" are rewritten as "01". */
        if (len >= 2 && (result[len-1] == 'l' || result[len-1] == 'i' || result[len-1] == '1')) {
            if (result[len-2] == 'o' || result[len-2] == '0') {
                result[len-2] = '0';
                result[len-1] = '1';
            }
        }
        prob_sum += max_prob;
        char_count++;
    }

    if (char_count) {
        strcpy(det_res->txt, result);
        det_res->prob = prob_sum / char_count;
        pr_log("det_res txt %s prob %f\n", det_res->txt, det_res->prob);
    }
    else {
        det_res->txt[0] = '\0';
        det_res->prob = 0;
    }
    return 0;
}

/*
 * Decode every output layer of the OCR model.
 *
 * Each layer is decoded with duplicate removal enabled into the single
 * shared ctx->det_res. When the model has several outputs, only the last
 * layer's result remains there afterwards.
 *
 * Returns 0 on success, or the first non-zero code from get_ocr_detections().
 */
static int ni_ocr_get_detections(OCRModelCtx *ctx)
{
    int layer = 0;

    while (layer < ctx->output_number) {
        int err = get_ocr_detections(&ctx->layers[layer], true, ctx->det_res);

        if (err) {
            pr_err("failed to get ocr detection at layer %d\n", layer);
            return err;
        }
        pr_log("layer %d, ocr detections txt: %s\n", layer, ctx->det_res->txt);
        layer++;
    }
    return 0;
}

/*
 * Release every buffer owned by an OCR model context.
 *
 * Can be called on a partially initialised context (create_ocr_model()
 * does so on its error path), provided unallocated pointers are NULL.
 * Each pointer is reset to NULL after freeing, so a repeated call is harmless.
 */
static void destroy_ocr_model(OCRModelCtx *ctx)
{
    if (ctx->out_tensor != NULL) {
        for (int k = 0; k < ctx->output_number; k++) {
            free(ctx->out_tensor[k]);   /* free(NULL) is a no-op */
        }
        free(ctx->out_tensor);
        ctx->out_tensor = NULL;
    }

    free(ctx->det_res);
    ctx->det_res = NULL;

    free(ctx->layers);
    ctx->layers = NULL;
}

/*
 * Allocate the buffers of an OCR recognition model and load its character
 * dictionary.
 *
 * @param ctx            context to fill; its out_tensor, det_res and layers
 *                       pointers are overwritten
 * @param network_data   network description; the input size is read from
 *                       linfo.in_param[0], and one layer is set up per
 *                       linfo.out_param[] entry (output_num of them)
 * @param model_width    unused
 * @param model_height   unused
 * @param character_file path of the character dictionary passed to
 *                       get_characters()
 * @return 0 on success or NIERROR(ENOMEM) on allocation failure; on failure
 *         every buffer allocated so far is released via destroy_ocr_model().
 */
static int create_ocr_model(OCRModelCtx *ctx, ni_network_data_t *network_data,
        int model_width, int model_height, const char *character_file)
{
    int i, ret = 0;

    (void)model_width;
    (void)model_height;

    /* Start clean so the error path never frees stale/uninitialised pointers. */
    ctx->out_tensor = NULL;
    ctx->det_res    = NULL;
    ctx->layers     = NULL;

    ctx->input_width  = network_data->linfo.in_param[0].sizes[0];
    ctx->input_height = network_data->linfo.in_param[0].sizes[1];

    ctx->output_number = network_data->output_num;

    /* calloc so det_res->txt starts out as an empty string. */
    ctx->det_res = calloc(1, sizeof(*ctx->det_res));
    if (ctx->det_res == NULL) {
        pr_err("failed to allocate detection memory\n");
        ret = NIERROR(ENOMEM);
        goto fail;
    }
    ctx->det_res->prob = 0;

    ctx->out_tensor = calloc(network_data->output_num, sizeof(*ctx->out_tensor));
    if (ctx->out_tensor == NULL) {
        pr_err("failed to allocate output tensor bufptr\n");
        ret = NIERROR(ENOMEM);
        goto fail;
    }

    for (i = 0; i < network_data->output_num; i++) {
        ni_network_layer_params_t *p_param = &network_data->linfo.out_param[i];
        ctx->out_tensor[i] = malloc(ni_ai_network_layer_dims(p_param) * sizeof(float));
        if (ctx->out_tensor[i] == NULL) {
            pr_err("failed to allocate output tensor buffer\n");
            ret = NIERROR(ENOMEM);
            goto fail;
        }
    }

    ctx->layers = calloc(network_data->output_num, sizeof(*ctx->layers));
    if (!ctx->layers) {
        pr_err("cannot allocate network layer memory\n");
        ret = NIERROR(ENOMEM);
        goto fail;
    }

    for (i = 0; i < network_data->output_num; i++) {
        ctx->layers[i].char_lib      = network_data->linfo.out_param[i].sizes[0];
        ctx->layers[i].char_num      = network_data->linfo.out_param[i].sizes[1];
        ctx->layers[i].output_number =
            ni_ai_network_layer_dims(&network_data->linfo.out_param[i]);

        /* The layer decodes directly from the float output tensor buffer. */
        ctx->layers[i].output = (float *)ctx->out_tensor[i];

        pr_log("network layer %d: char_lib %d, char_num %d output_number %d\n", i,
                ctx->layers[i].char_lib, ctx->layers[i].char_num, ctx->layers[i].output_number);
    }

    /* A missing dictionary file is not fatal: get_characters() falls back. */
    get_characters(character_file);
    return ret;
fail:
    destroy_ocr_model(ctx);
    return ret;
}

/*
 * Operation table through which callers use this OCR post-processing
 * implementation: model creation, decoding of the model output, and teardown.
 */
OCRModel ocr = {
    .ni_get_detections = ni_ocr_get_detections,
    .destroy_model     = destroy_ocr_model,
    .create_model      = create_ocr_model,
};
