#include <errno.h>
#include <locale.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "nierrno.h"
#include "lpr_network.h"
#include "ni_yolo_utils.h"
#include "ni_log.h"

/*
 * One mask row per output layer: row i is copied into layers[i].mask when the
 * detection model is created. A mask entry is later used as the pair index n
 * into the bias table (biases[2 * n], biases[2 * n + 1]).
 */
static int g_masks[3][3] = {{0, 1, 2}, {3, 4, 5}, {6, 7, 8}};
/*
 * Passed as the second argument to ni_get_yolov5_detections().
 * NOTE(review): presumably the order in which output layers are visited --
 * confirm against ni_yolo_utils.
 */
static int sequence[3] = {0, 1, 2};
/*
 * Bias table (9 pairs) copied verbatim into every layer's biases buffer and
 * read in pairs when decoding landmarks.
 * NOTE(review): presumably anchor width/height pairs -- confirm against the
 * trained model.
 */
static float g_biases[] = {4, 5, 8, 10, 13, 16, 23, 29, 43, 55, 73, 105, 146, 217, 231, 300, 335, 433};

/*
 * Decode the four landmark points of a single prediction cell.
 *
 * landmarks  receives four (x, y) pairs
 * x          raw layer output
 * biases     bias table, read as the pair (biases[2 * n], biases[2 * n + 1])
 * oidx       index into anchor_stride[] for this output layer
 * n          bias pair index
 * index      offset of the first landmark value inside x
 * col, row   grid cell of the prediction
 * stride     distance between consecutive landmark values inside x
 *
 * lw, lh, nw and nh are accepted but not used by this function.
 */
static void get_landmarks(double landmarks[4][2], float *x, float *biases, int oidx,
                        int n, int index, int col,
                        int row, int lw, int lh, int nw, int nh, int stride)
{
    int pt;

    for (pt = 0; pt < 4; pt++) {
        /* Values are laid out x0, y0, x1, y1, ... each one `stride` apart. */
        const float *raw_x = &x[index + (2 * pt) * stride];
        const float *raw_y = &x[index + (2 * pt + 1) * stride];

        landmarks[pt][0] = *raw_x * biases[2 * n]     + (float)col * anchor_stride[oidx];
        landmarks[pt][1] = *raw_y * biases[2 * n + 1] + (float)row * anchor_stride[oidx];
    }
}

/*
 * Compute the four landmark points belonging to one detection and rescale
 * them by dividing x coordinates by gain_x and y coordinates by gain_y.
 *
 * ctx        detection model context (its yolo_model_ctx holds the layers)
 * det        detection whose layer_idx / sub_idx / color select the cell
 * landmarks  receives four (x, y) pairs
 */
static void get_box_landmarks(LicDetModelCtx *ctx, detection *det, double landmarks[4][2],
        float gain_x, float gain_y)
{
    YoloModelCtx *yolo = &ctx->yolo_model_ctx;
    ni_roi_network_layer_t *layer = &yolo->layers[det->layer_idx];
    /* sub_idx is a flattened grid position; split it into row and column. */
    int grid_row = det->sub_idx / layer->width;
    int grid_col = det->sub_idx % layer->width;
    int pt;

    get_landmarks(landmarks, layer->output, layer->biases, layer->index,
            layer->mask[det->color],
            entry_index(layer, 0, det->color, det->sub_idx, 5), grid_col, grid_row,
            layer->width, layer->height, yolo->input_width, yolo->input_height,
            layer->width * layer->height);

    for (pt = 0; pt < 4; pt++) {
        landmarks[pt][0] /= gain_x;
        landmarks[pt][1] /= gain_y;
    }
}

/*
 * Turn the current YOLOv5 detections into an array of plate boxes.
 *
 * ctx         detection model context
 * img_width   width used to derive the x gain (input_width / img_width)
 * img_height  height used to derive the y gain (input_height / img_height)
 * plate_box   out: heap-allocated array of boxes, or NULL when none were found
 * plate_num   out: number of entries written to *plate_box
 *
 * Detections whose max_prob is 0 are skipped. The array is allocated with
 * malloc() and is not released here, so freeing it is left to the caller.
 *
 * Returns 0 on success (including "nothing detected"), the negative value
 * from ni_get_yolov5_detections() if that call fails, or NIERROR(ENOMEM).
 */
int get_licence_plate_boxes(LicDetModelCtx *ctx, uint32_t img_width,
        uint32_t img_height, struct plate_box **plate_box, int *plate_num)
{
    YoloModelCtx *yolo = &ctx->yolo_model_ctx;
    struct plate_box *boxes;
    detection *dets;
    int det_count;
    int kept = 0;
    int filled = 0;
    int idx;
    float gain_x;
    float gain_y;

    *plate_box = NULL;
    *plate_num = 0;

    det_count = ni_get_yolov5_detections(yolo, sequence, 0);
    if (det_count < 0) {
        pr_err("cannot get detection\n");
        return det_count;
    }
    if (det_count == 0) {
        return 0;
    }

    /* First pass: count the usable detections so the array is sized exactly. */
    dets = yolo->det_cache.dets;
    for (idx = 0; idx < det_count; idx++) {
        if (dets[idx].max_prob != 0) {
            kept++;
        }
    }
    if (kept == 0) {
        return 0;
    }

    boxes = malloc(sizeof(*boxes) * kept);
    if (boxes == NULL) {
        ni_err("cannot allocate plate box\n");
        return NIERROR(ENOMEM);
    }

    gain_x = (yolo->input_width  / (float)img_width);
    gain_y = (yolo->input_height / (float)img_height);

    if (1) { //tiling mode
        for (idx = 0; idx < det_count; idx++) {
            struct plate_box *out;

            if (dets[idx].max_prob == 0) {
                continue;
            }
            out = &boxes[filled];
            ni_resize_coords_tiling_mode(&dets[idx], &out->roi_box, img_width, img_height,
                    gain_x, gain_y);
            get_box_landmarks(ctx, &dets[idx], out->landmark, gain_x, gain_y);
            filled++;
        }
    } else {
        /* Padding mode: currently unreachable, kept alongside tiling mode. */
        float gain = (float) ((yolo->input_width > yolo->input_height) ?
                yolo->input_width : yolo->input_height) /
                ((img_width > img_height) ? img_width : img_height);
        float pad0 = (yolo->input_width  - img_width  * gain) / 2.0;
        float pad1 = (yolo->input_height - img_height * gain) / 2.0;

        for (idx = 0; idx < det_count; idx++) {
            struct plate_box *out;

            if (dets[idx].max_prob == 0) {
                continue;
            }
            out = &boxes[filled];
            ni_resize_coords_padding_mode(&dets[idx], &out->roi_box, img_width, img_height,
                    gain, pad0, pad1);
            get_box_landmarks(ctx, &dets[idx], out->landmark, gain_x, gain_y);
            filled++;
        }
    }

    *plate_num = filled;
    *plate_box = boxes;
    return 0;
}

/*
 * Release every buffer owned by a YOLOv5 model context: the per-output
 * tensor buffers, the layer table with its per-layer bias arrays, and the
 * detection cache. Freed pointers are reset to NULL. A NULL ctx is ignored.
 */
static void destroy_lic_det_yolov5_model(YoloModelCtx *ctx)
{
    int idx;

    if (!ctx) {
        return;
    }

    if (ctx->out_tensor != NULL) {
        for (idx = 0; idx < ctx->output_number; idx++) {
            free(ctx->out_tensor[idx]);
        }
        free(ctx->out_tensor);
        ctx->out_tensor = NULL;
    }

    if (ctx->layers != NULL) {
        for (idx = 0; idx < ctx->output_number; idx++) {
            free(ctx->layers[idx].biases);
        }
        free(ctx->layers);
        ctx->layers = NULL;
    }

    /* free(NULL) is a no-op, so no guard is needed here. */
    free(ctx->det_cache.dets);
    ctx->det_cache.dets = NULL;
}

/*
 * Release everything owned by a licence-plate detection context: the
 * embedded YOLOv5 model buffers and the swscale context.
 *
 * A NULL ctx is ignored, matching the other destroy functions in this file.
 */
void destroy_lic_detect_model(LicDetModelCtx *ctx)
{
    if (ctx == NULL) {
        return;
    }

    destroy_lic_det_yolov5_model(&ctx->yolo_model_ctx);

    if (ctx->scale_ctx) {
        sws_freeContext(ctx->scale_ctx);
        ctx->scale_ctx = NULL;
    }
}

/*
 * Initialise the YOLOv5 part of the licence-plate detector from the network
 * description.
 *
 * ctx           context to fill in; owned pointers are overwritten
 * network_data  network description providing input/output layer sizes
 * pic_width     unused
 * pic_height    unused
 * obj_thresh    stored in ctx->obj_thresh
 * nms_thresh    stored in ctx->nms_thresh
 *
 * Allocates one output tensor buffer and one layer entry (with its own copy
 * of the bias table) per network output, plus the detection cache.
 *
 * Returns 0 on success, NIERROR(EINVAL) when the network has more outputs
 * than there are mask rows, or NIERROR(ENOMEM) on allocation failure. On
 * failure everything allocated so far is released again.
 */
static int create_lic_det_yolov5_model(YoloModelCtx *ctx, ni_network_data_t *network_data,
        int pic_width, int pic_height, float obj_thresh, float nms_thresh)
{
    int i;
    int ret = 0;

    (void)pic_width;
    (void)pic_height;

    /*
     * Put every owned pointer into a known state first, so the failure path
     * can hand the context to destroy_lic_det_yolov5_model() safely even if
     * the caller did not zero it.
     */
    ctx->out_tensor = NULL;
    ctx->layers = NULL;
    ctx->det_cache.dets = NULL;
    ctx->output_number = 0;

    ctx->obj_thresh = obj_thresh;
    ctx->nms_thresh = nms_thresh;

    ctx->input_width  = network_data->linfo.in_param[0].sizes[0];
    ctx->input_height = network_data->linfo.in_param[0].sizes[1];

    /* Each layer reads its own row of g_masks; refuse to read past the table. */
    if ((size_t)network_data->output_num > sizeof(g_masks) / sizeof(g_masks[0])) {
        pr_err("unsupported number of output layers\n");
        return NIERROR(EINVAL);
    }

    ctx->output_number = network_data->output_num;
    ctx->out_tensor = calloc(network_data->output_num, sizeof(*ctx->out_tensor));
    if (ctx->out_tensor == NULL) {
        ni_err("failed to allocate output tensor bufptr\n");
        return NIERROR(ENOMEM);
    }

    for (i = 0; i < network_data->output_num; i++) {
        ni_network_layer_params_t *p_param = &network_data->linfo.out_param[i];
        ctx->out_tensor[i] = malloc(ni_ai_network_layer_dims(p_param) * sizeof(float));
        if (ctx->out_tensor[i] == NULL) {
            pr_err("failed to allocate output tensor buffer\n");
            ret = NIERROR(ENOMEM);
            goto fail_out;
        }
    }

    /* Zero-filled so that layers not reached yet have a NULL biases pointer. */
    ctx->layers = calloc(network_data->output_num, sizeof(*ctx->layers));
    if (!ctx->layers) {
        pr_err("cannot allocate network layer memory\n");
        ret = NIERROR(ENOMEM);
        goto fail_out;
    }

    for (i = 0; i < network_data->output_num; i++) {
        ctx->layers[i].index     = i;
        ctx->layers[i].width     = network_data->linfo.out_param[i].sizes[0];
        ctx->layers[i].height    = network_data->linfo.out_param[i].sizes[1];
        ctx->layers[i].channel   = network_data->linfo.out_param[i].sizes[2];
        ctx->layers[i].component = 3;
        /* The class count is fixed here rather than derived from the channel count. */
        ctx->layers[i].classes = 2;
        ctx->layers[i].padding = 8; //number of landmarks
        ctx->layers[i].output = (float *)ctx->out_tensor[i];

        ///TODO rm [0]
        memcpy(ctx->layers[i].mask, &g_masks[i][0], sizeof(ctx->layers[i].mask));

        ctx->layers[i].biases = malloc(sizeof(g_biases));
        if (!ctx->layers[i].biases) {
            pr_err("cannot allocate network layer memory\n");
            ret = NIERROR(ENOMEM);
            goto fail_out;
        }
        memcpy(ctx->layers[i].biases, &g_biases[0], sizeof(g_biases));

        pr_log("network layer %d: w %d, h %d, ch %d, co %d, cl %d\n", i,
                ctx->layers[i].width, ctx->layers[i].height,
                ctx->layers[i].channel, ctx->layers[i].component,
                ctx->layers[i].classes);
    }

    /*
     * NOTE(review): presumably 4 box coordinates, then objectness, then the
     * 8 landmark values, then class scores (4 + 1 + 8 = 13) -- confirm.
     */
    ctx->entry_set.obj_entry = 4;
    ctx->entry_set.class_entry = 13;
    ctx->entry_set.coods_entry = 0;

    ctx->det_cache.dets_num = 0;
    ctx->det_cache.capacity = 20;
    ctx->det_cache.dets = malloc(sizeof(*ctx->det_cache.dets) * ctx->det_cache.capacity);
    if (!ctx->det_cache.dets) {
        pr_err("failed to allocate detection cache\n");
        ret = NIERROR(ENOMEM);
        goto fail_out;
    }

    return 0;

fail_out:
    /* Undo the partial initialisation instead of leaking it. */
    destroy_lic_det_yolov5_model(ctx);
    return ret;
}

/*
 * Create the licence-plate detection model: the YOLOv5 model state plus a
 * swscale context converting pic_width x pic_height YUV420P pictures to the
 * network input size in planar GBR.
 *
 * Returns 0 on success. A failure from the YOLOv5 setup is returned as-is;
 * if the scale context cannot be created the model is destroyed and
 * NIERROR(ENOMEM) is returned.
 */
int create_lic_detect_model(LicDetModelCtx *ctx, ni_network_data_t *network_data,
        int pic_width, int pic_height, float obj_thresh, float nms_thresh)
{
    YoloModelCtx *yolo = &ctx->yolo_model_ctx;
    int err;

    err = create_lic_det_yolov5_model(yolo, network_data,
            pic_width, pic_height, obj_thresh, nms_thresh);
    if (err != 0) {
        return err;
    }

    ctx->scale_ctx = sws_getContext(pic_width, pic_height, AV_PIX_FMT_YUV420P,
            yolo->input_width, yolo->input_height,
            AV_PIX_FMT_GBR24P, SWS_BICUBIC, NULL, NULL, NULL);
    if (ctx->scale_ctx) {
        return 0;
    }

    ni_err("cannot get scale ctx\n");
    destroy_lic_detect_model(ctx);
    return NIERROR(ENOMEM);
}

/*
 * Release every resource owned by a licence-plate recognition context: the
 * output tensor, the character index buffer, the planar RGB buffer, both
 * plate text buffers, the two frames and the swscale context. Freed pointers
 * are reset to NULL. A NULL ctx is ignored.
 */
void destroy_lic_rec_model(LicRecModelCtx *ctx)
{
    if (!ctx) {
        return;
    }

    /* free(NULL) is a no-op, so the buffers need no individual guards. */
    free(ctx->out_tensor[0]);
    ctx->out_tensor[0] = NULL;

    free(ctx->characters);
    ctx->characters = NULL;

    free(ctx->rgb_planar_data);
    ctx->rgb_planar_data = NULL;

    free(ctx->plate_result);
    ctx->plate_result = NULL;

    free(ctx->plate_string);
    ctx->plate_string = NULL;

    av_frame_unref(&ctx->persp_frame);
    av_frame_unref(&ctx->rgb_frame);

    if (ctx->rgb_sws_ctx != NULL) {
        sws_freeContext(ctx->rgb_sws_ctx);
        ctx->rgb_sws_ctx = NULL;
    }
}

/*
 * Create the licence-plate recognition model state from the network
 * description.
 *
 * ctx           context to fill in; owned pointers and frames are overwritten
 * network_data  network description; in_param[0] gives the input size,
 *               out_param[0] gives chars_num (sizes[0]) and max_chars
 *               (sizes[1])
 *
 * Allocates the output tensor, the per-position character index buffer, two
 * frames at the network input size (YUV420P and planar GBR), a swscale
 * context between them, the planar float RGB buffer and two plate text
 * buffers of sizeof(wchar_t) * max_chars bytes each.
 *
 * Returns 0 on success or NIERROR(ENOMEM) on any failure, in which case
 * everything acquired so far has been released again.
 */
int create_lic_rec_model(LicRecModelCtx *ctx, ni_network_data_t *network_data)
{
    int ret;
    ni_network_layer_params_t *param;

    /*
     * Bring every owned resource into a known-empty state before the first
     * point of failure: the failure path calls destroy_lic_rec_model(), which
     * frees all of these fields and unrefs both frames.
     */
    ctx->out_tensor[0] = NULL;
    ctx->characters = NULL;
    ctx->rgb_planar_data = NULL;
    ctx->plate_result = NULL;
    ctx->plate_string = NULL;
    ctx->rgb_sws_ctx = NULL;
    memset(&ctx->persp_frame, 0, sizeof(ctx->persp_frame));
    memset(&ctx->rgb_frame, 0, sizeof(ctx->rgb_frame));

    param = &network_data->linfo.in_param[0];
    ctx->input_width  = param->sizes[0];
    ctx->input_height = param->sizes[1];

    param = &network_data->linfo.out_param[0];

    ctx->chars_num = param->sizes[0];
    ctx->max_chars = param->sizes[1];

    ctx->out_tensor[0] = malloc(sizeof(float) * ctx->chars_num * ctx->max_chars);
    if (ctx->out_tensor[0] == NULL) {
        pr_err("cannot allocate lic_rec model tensor buffer\n");
        /* Same error code as every other allocation failure in this function. */
        ret = NIERROR(ENOMEM);
        goto fail_out;
    }

    ctx->characters = malloc(sizeof(int) * ctx->max_chars);
    if (ctx->characters == NULL) {
        pr_err("cannot allocate lic_rec model characters buffer\n");
        ret = NIERROR(ENOMEM);
        goto fail_out;
    }

    ctx->persp_frame.width = ctx->input_width;
    ctx->persp_frame.height = ctx->input_height;
    ctx->persp_frame.format = AV_PIX_FMT_YUV420P;
    if (av_frame_get_buffer(&ctx->persp_frame, 32)) {
        pr_err("cannot get frame buffer\n");
        ret = NIERROR(ENOMEM);
        goto fail_out;
    }

    ctx->rgb_frame.width = ctx->input_width;
    ctx->rgb_frame.height = ctx->input_height;
    ctx->rgb_frame.format = AV_PIX_FMT_GBR24P;
    if (av_frame_get_buffer(&ctx->rgb_frame, 32)) {
        pr_err("cannot get frame buffer\n");
        ret = NIERROR(ENOMEM);
        goto fail_out;
    }

    ctx->rgb_sws_ctx = sws_getContext(ctx->input_width, ctx->input_height, AV_PIX_FMT_YUV420P,
            ctx->rgb_frame.width, ctx->rgb_frame.height,
            ctx->rgb_frame.format, SWS_BICUBIC, NULL, NULL, NULL);
    if (!ctx->rgb_sws_ctx) {
        pr_err("cannot get scale ctx\n");
        ret = NIERROR(ENOMEM);
        goto fail_out;
    }

    ctx->rgb_planar_data = malloc(sizeof(float) * ctx->input_width * ctx->input_height * 3);
    if (ctx->rgb_planar_data == NULL) {
        pr_err("cannot allocate data\n");
        ret = NIERROR(ENOMEM);
        goto fail_out;
    }

    ctx->plate_str_len = sizeof(wchar_t) * ctx->max_chars;
    ctx->plate_result = malloc(ctx->plate_str_len);
    if (ctx->plate_result == NULL) {
        pr_err("cannot allocate plate result\n");
        ret = NIERROR(ENOMEM);
        goto fail_out;
    }
    ctx->plate_string = malloc(ctx->plate_str_len);
    if (ctx->plate_string == NULL) {
        pr_err("cannot allocate plate string\n");
        ret = NIERROR(ENOMEM);
        goto fail_out;
    }

    pr_log("lic_rec model dim: %dx%d\n", ctx->input_width, ctx->input_height);
    return 0;

fail_out:
    destroy_lic_rec_model(ctx);
    return ret;
}

/* Class index -> plate character. Index 0 ('#') is the blank class. */
static const wchar_t plate_chars[] = L"#京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航危0123456789ABCDEFGHJKLMNPQRSTUVWXYZ险品";

/*
 * Decode the recognition network output into plate characters.
 *
 * ctx           recognition context; out_tensor[0] is read as max_chars rows
 *               of chars_num scores, and ctx->characters receives the
 *               best-scoring class index of every row
 * plate_result  receives the decoded characters; must have room for
 *               max_chars entries. No terminating L'\0' is written.
 *
 * Decoding is greedy: take the best class per position, collapse runs of
 * the same class, then drop blanks (class 0). The previous class is tracked
 * across blanks as well, so a character repeated on both sides of a blank is
 * emitted twice (e.g. "8, blank, 8" yields "88"). Class indices beyond the
 * character table are ignored.
 *
 * Returns the number of characters written to plate_result.
 */
int get_plate(LicRecModelCtx *ctx, wchar_t *plate_result)
{
    /* Number of usable table entries, excluding the string terminator. */
    const int table_len = (int)(sizeof(plate_chars) / sizeof(plate_chars[0])) - 1;
    int k, n;
    int last_char_index = -1;

    /* Best class per position. */
    for (n = 0; n < ctx->max_chars; n++) {
        int max_index = 0;
        float max = ctx->out_tensor[0][ctx->chars_num * n];
        for (k = 1; k < ctx->chars_num; k++) {
            if (max < ctx->out_tensor[0][ctx->chars_num * n + k]) {
                max = ctx->out_tensor[0][ctx->chars_num * n + k];
                max_index = k;
            }
        }
        ctx->characters[n] = max_index;
    }

    k = 0;
    for (n = 0; n < ctx->max_chars; n++) {
        int cur = ctx->characters[n];

        if (cur != 0 && cur != last_char_index && cur < table_len) {
            plate_result[k++] = plate_chars[cur];
        }
        /* Remember every position, blanks included, so a blank ends a run. */
        last_char_index = cur;
    }

    return k;
}
