#!/usr/bin/env python3


import cv2
import netint.network
import re
import numpy as np
import math
#import matplotlib.pyplot as plt
#import tensorflow.compat.v1 as tf
import os
import argparse
#tf.disable_eager_execution()



def AtmLight(g1, bins=2000, quantile=0.999):
    """Estimate the atmospheric light from a dark-channel image.

    Builds a ``bins``-bin histogram of ``g1`` and returns the lower edge of
    the bin in which the cumulative fraction first exceeds ``quantile``.
    At least ``1 - quantile`` of the samples are >= the returned value.

    Args:
        g1: array-like of any shape (np.histogram flattens it); must be
            non-empty.
        bins: number of histogram bins.
        quantile: cumulative fraction to locate, in (0, 1).

    Returns:
        A NumPy float scalar (a histogram bin edge) within the value range
        of ``g1``.

    Raises:
        ValueError: if ``g1`` is empty.
    """
    g1 = np.asarray(g1)
    if g1.size == 0:
        raise ValueError("AtmLight: input is empty")

    counts, edges = np.histogram(g1, bins)
    cdf = np.cumsum(counts) / float(g1.size)

    # cdf is non-decreasing, so a binary search finds the first bin whose
    # cumulative fraction exceeds `quantile`; its lower edge is edges[idx].
    # Bin 0 is included in the search, so the estimate can land in the
    # lowest bins when the quantile does.
    idx = int(np.searchsorted(cdf, quantile, side='right'))
    return edges[idx]


def TransmissionEstimate(dark, A, omega=0.75):
    """Compute an initial transmission map from a dark channel.

    Evaluates ``1 - (omega / A) * dark`` element-wise, which has the form of
    the dark-channel-prior transmission estimate.

    Args:
        dark: array of dark-channel values.
        A: atmospheric light (scalar or broadcastable); must be non-zero.
        omega: fraction of haze removed; 0.75 matches the original fixed
            value.

    Returns:
        Array broadcast from ``dark`` and ``A``. The result is not clamped,
        so it may fall outside [0, 1].
    """
    scale = omega / A
    return 1 - scale * dark



def build_nbg_model(model_file, device_id):
    """Load a compiled .nb network with netint and return its interpreter.

    Args:
        model_file: path to the .nb model file.
        device_id: device index, passed to the netint Interpreter as
            ``dev_id``.

    Returns:
        The ``netint.network.Interpreter`` instance.

    Exits the process with status 1 if the network reports no input
    tensors or no output tensors.
    """
    # `sys` is not imported at module level; without this the failure path
    # below raised NameError instead of exiting.
    import sys

    interpreter = netint.network.Interpreter(model_path=model_file, dev_id=device_id)
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    if len(input_details) == 0 or len(output_details) == 0:
        print('invalid network layer', file=sys.stderr)
        # Drop the interpreter before exiting, presumably so netint can
        # release the device handle -- confirm against netint docs.
        del interpreter
        sys.exit(1)
    return interpreter

def Recovery(y, u, v, t, A):
    """Invert the haze model on a YUV 4:2:0 frame.

    Luma:    y2 = 16 + ((y - 16) - k*A*255) / t + k*A*255, with k = 0.859.
             This is the inversion J = (I - A') / t + A' applied to
             16-offset luma, with A' = k * 255 * A.
    Chroma:  u2 = 128 + (u - 128) / t_c  (likewise v). There is no A term,
             i.e. the atmospheric light is treated as neutral (U = V = 128).
             t_c is ``t`` resized to half width and half height with
             cv2.resize.

    Args:
        y: luma in 0..255 units, broadcast-compatible with ``t``.
        u, v: chroma planes, broadcast-compatible with shape
            (1, H // 2, W // 2, 1).
        t: transmission map at luma resolution, shape (1, H, W, 1) or
            (H, W). It is not clamped, so zeros yield inf/NaN.
        A: atmospheric-light estimate (scalar or broadcastable array).

    Returns:
        Tuple (y2, u2, v2) of unclipped float arrays.
    """
    t_2d = np.squeeze(t)
    height, width = t_2d.shape
    # cv2.resize takes the destination size as (width, height).
    t_chroma = cv2.resize(t_2d, (width // 2, height // 2))
    t_chroma = t_chroma.reshape(1, height // 2, width // 2, 1)

    # 0.859 ~= 219/255: presumably scales A into the limited (16..235) luma
    # range -- confirm.
    y2 = 16 + (y - 16.0) / t - 0.859 * A * 255 / t + 0.859 * A * 255

    u2 = 128 + (u - 128.0) / t_chroma
    v2 = 128 + (v - 128.0) / t_chroma
    return y2, u2, v2

def postprocess(output0, output1, output2):
    """Pack Y, U and V planes into one planar yuv420p (I420) buffer.

    Each plane is flattened in C order and the three are concatenated as
    Y, then U, then V. This works for any resolution: a 1920x1080 frame
    gives 1620 * 1920 bytes.

    Values are cast to uint8 by truncation. Callers must clip to 0..255
    first, because NumPy does not saturate out-of-range float-to-uint8
    casts.

    Returns:
        1-D uint8 array of length
        ``output0.size + output1.size + output2.size``.
    """
    planes = [np.ravel(plane) for plane in (output0, output1, output2)]
    return np.concatenate(planes).astype(np.uint8)


def _read_first_yuv420_frame(path, width, height):
    """Read the first frame of a raw planar yuv420p file.

    Returns:
        uint8 array of shape (height * 3 // 2, width). The Y plane occupies
        the first ``height`` rows, followed by U and then V, each packed
        into rows of ``width`` bytes.

    Raises:
        ValueError: if the file holds less than one full frame.
    """
    frame_bytes = width * height * 3 // 2
    with open(path, 'rb') as f:
        data = f.read(frame_bytes)
    if len(data) < frame_bytes:
        raise ValueError(
            f"{path}: need {frame_bytes} bytes for one {width}x{height} "
            f"yuv420p frame, found {len(data)}")
    return np.frombuffer(data, dtype=np.uint8).reshape((height * 3 // 2, width))


def _save_png(yuv_path, width, height):
    """Convert a raw yuv420p file to a PNG next to it with the ffmpeg CLI.

    This is best effort: a missing ffmpeg binary or a non-zero exit status
    is reported on stderr, but the script does not abort.
    """
    import subprocess
    import sys

    png_path = os.path.splitext(yuv_path)[0] + ".png"
    if png_path == yuv_path:
        # With `-y`, ffmpeg would overwrite the frame we just wrote.
        print(f"refusing to overwrite {yuv_path}; skipping PNG export",
              file=sys.stderr)
        return

    # Argument list, no shell: paths with spaces or shell metacharacters are
    # passed through safely.
    cmd = ["ffmpeg", "-y", "-c:v", "rawvideo", "-s:v", f"{width}x{height}",
           "-pix_fmt", "yuv420p", "-i", yuv_path, png_path]
    try:
        result = subprocess.run(cmd, check=False)
    except FileNotFoundError:
        print("ffmpeg not found; skipping PNG export", file=sys.stderr)
        return
    if result.returncode != 0:
        print(f"ffmpeg exited with status {result.returncode}", file=sys.stderr)


def main():
    """Dehaze the first frame of a raw 1920x1080 yuv420p file.

    Pipeline:
      1. Model 1 (AI_model_1.nb) receives the luma plane scaled to 0..1.
         Its output tensors 0 and 1 are reshaped to 270x480 and used as
         the downscaled luma and the dark channel.
      2. The atmospheric light (AtmLight) and an initial transmission
         (TransmissionEstimate) are computed on the host.
      3. Model 2 (AI_model_2.nb) takes the downscaled luma and the initial
         transmission. Its output 0 is reshaped to full resolution and used
         as the transmission map for Recovery on Y, U and V.
      4. The result is written as one raw yuv420p frame to --output and,
         with --save_img, also converted to PNG via ffmpeg.

    Only the first frame of the input file is processed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-i', '--input_yuv', action='store', required=True)
    parser.add_argument('-o', '--output', action='store', default='res.yuv')
    parser.add_argument('-s', '--save_img', action='store_true')
    args = parser.parse_args()

    yuv_file = os.path.realpath(args.input_yuv)

    # Every tensor shape below is derived from this fixed resolution.
    width, height = 1920, 1080
    # Each 4:2:0 chroma plane is (height/2)*(width/2) bytes, i.e. height/4
    # rows of `width` bytes in the packed frame.
    uv_height = height // 4

    NBfile_1 = "AI_model_1.nb"
    NBfile_2 = "AI_model_2.nb"

    try:
        yuv_image = _read_first_yuv420_frame(yuv_file, width, height)
    except ValueError as err:
        parser.error(str(err))  # exits with a usage message

    y = yuv_image[:height, :]
    u = yuv_image[height:height + uv_height, :]
    v = yuv_image[height + uv_height:, :]

    tf_y = y.reshape(1, height, width, 1).astype(np.float32) / 255
    tf_u = u.reshape(1, height // 2, width // 2, 1).astype(np.float32)
    tf_v = v.reshape(1, height // 2, width // 2, 1).astype(np.float32)

    # ---------------- model 1 inference ----------------
    device_id = 0
    model = build_nbg_model(NBfile_1, device_id)
    input_details = model.get_input_details()
    model.set_tensor(input_details[0]['index'], tf_y)
    model.invoke()

    ds_y = model.get_tensor(0)
    ds_dark = model.get_tensor(1)

    # Model 1 outputs are treated as quarter resolution (270x480).
    ds_shape = (height // 4, width // 4)

    A = AtmLight(ds_dark)
    A = np.clip(A, 0, 0.86)
    # Presumably moves A relative to the limited-range black level (16)
    # expected by Recovery's luma formula -- confirm.
    A = A - 16 / 255.0
    tf_A = np.reshape(A, (1,)).astype(np.float32)

    te = TransmissionEstimate(ds_dark.reshape(*ds_shape, 1), A)
    tf_te = te.reshape(1, *ds_shape, 1).astype(np.float32)
    tf_ds_y = ds_y.reshape(1, *ds_shape, 1).astype(np.float32)

    # ----------------- model 2 inference -------------------
    model_second = build_nbg_model(NBfile_2, device_id)
    input_details = model_second.get_input_details()
    model_second.set_tensor(input_details[0]['index'], tf_ds_y)
    model_second.set_tensor(input_details[1]['index'], tf_te)
    model_second.invoke()

    t = model_second.get_tensor(0)
    tf_t = t.reshape(1, height, width, 1).astype(np.float32)

    y2, u2, v2 = Recovery(tf_y * 255, tf_u, tf_v, tf_t, tf_A)

    # postprocess returns uint8, so no further clipping or cast is needed.
    img_output = postprocess(np.clip(y2, 0, 255), np.clip(u2, 0, 255), np.clip(v2, 0, 255))
    img_output.tofile(args.output)

    if args.save_img:
        _save_png(args.output, width, height)

if __name__ == '__main__':
    main()

