#!/usr/bin/env python3

import sys
import argparse
import cv2
import os
import numpy as np
import netint.network

def build_nbg_model(model_file, device_id):
    """Create a ``netint.network.Interpreter`` for an NBG model file.

    Args:
        model_file: path handed to the interpreter as ``model_path``.
        device_id: device identifier handed to the interpreter as ``dev_id``.

    Returns:
        The constructed interpreter.

    Exits the process with status 1 (after printing a message) when the
    interpreter reports no input details or no output details.
    """
    nbg = netint.network.Interpreter(model_path=model_file, dev_id=device_id)
    layer_counts = (
        len(nbg.get_input_details()),
        len(nbg.get_output_details()),
    )
    if all(layer_counts):
        return nbg

    # At least one side of the network is empty: drop the interpreter
    # reference before terminating.
    print('invalid network layer')
    del nbg
    sys.exit(1)

def normalization(args, image):
    """Return ``(image - mean) * scale`` as a float32 array.

    Args:
        args: mapping with the keys ``'mean'`` and ``'scale'``; each value is
            converted with ``np.array`` and broadcast along the last axis of
            a three-axis ``image``.
        image: array-like image; it is converted to float32 before the
            arithmetic.

    Returns:
        The shifted and scaled image, cast to float32.
    """
    mean_vec, scale_vec = (np.array(args[key]) for key in ('mean', 'scale'))
    pixels = np.float32(image)
    # Two leading unit axes let the vectors broadcast over the first two
    # image axes.
    shifted = pixels - mean_vec[None, None, :]
    return np.float32(shifted * scale_vec[None, None, :])

def _apply_enhancement_curve(x, x_r, iterations=8):
    """Apply the quadratic curve ``x + x_r * (x**2 - x)`` ``iterations`` times."""
    for _ in range(iterations):
        x = x + x_r * (np.power(x, 2) - x)
    return x


def main(**args):
    """Run the zero_dce NBG model on one image and save the enhanced result.

    Keyword arguments read here (normally the parsed command-line options):
        input_file: path of the image to enhance.
        model_file: path of the NBG model.
        device_id: device used for inference.
        channel_order: ``'rgb'`` converts the BGR image loaded by OpenCV to
            RGB before inference and back to BGR before saving; any other
            value keeps the loaded channel order throughout.
        add_prec_node: when true the model input is passed un-normalised
            (the pre-processing is expected to be inside the model).
        mean, scale: used by ``normalization``.
    Other keyword arguments are accepted and ignored.

    The result is written to ``./<input basename>_res.png``. The process
    exits with status 1 when the input image cannot be read or the output
    image cannot be written.
    """
    # define output path
    file_name = os.path.splitext(os.path.basename(args['input_file']))[0]
    image_path = os.path.join('./', '{}_res.png'.format(file_name))
    is_rgb = args['channel_order'] == 'rgb'
    # (width, height) as expected by cv2.resize.
    # NOTE(review): presumably the resolution the NBG model was built for.
    model_input_size = (160, 90)

    # input pre-processing; cv2.imread returns None instead of raising when
    # the file is missing or cannot be decoded.
    im = cv2.imread(args['input_file'])  # BGR format
    if im is None:
        print('failed to read input image: {}'.format(args['input_file']))
        sys.exit(1)
    ori_height, ori_width = im.shape[:2]

    # build NBG model
    model = build_nbg_model(args['model_file'], args['device_id'])
    output_details = model.get_output_details()
    input_details = model.get_input_details()

    if is_rgb:
        im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)

    # Full-resolution image for the host-side curve. It is always normalised:
    # the curve and the final ``* 255`` below assume a float image, whether or
    # not the model carries its own pre-processing node.
    im_ori = normalization(args, im)
    if not args['add_prec_node']:
        im = im_ori

    im = cv2.resize(im, model_input_size, interpolation=cv2.INTER_LINEAR)

    # HWC -> CHW, then add the batch dimension
    im = np.expand_dims(np.transpose(im, (2, 0, 1)), axis=0)
    im_ori = np.expand_dims(np.transpose(im_ori, (2, 0, 1)), axis=0)

    # model inference
    model.set_tensor(input_details[0]['index'], im)
    model.invoke()
    inference_res = [model.get_tensor(output_index)
                     for output_index in range(len(output_details))]

    # Only the first output is used: squash it with tanh and upscale it to
    # the original resolution so it can be applied per pixel.
    curve_map = inference_res[0]
    print(curve_map.shape)
    curve_map = np.squeeze(curve_map, axis=0).transpose(1, 2, 0)
    curve_map = np.tanh(curve_map)
    curve_map = cv2.resize(curve_map, (ori_width, ori_height),
                           interpolation=cv2.INTER_LINEAR)
    x_r = np.expand_dims(curve_map.transpose(2, 0, 1), axis=0)

    enhance_image = _apply_enhancement_curve(im_ori, x_r)

    # back to an 8-bit HWC image
    opt = np.clip(enhance_image * 255, 0, 255)
    opt = np.squeeze(opt, axis=0)
    opt = opt.transpose(1, 2, 0).astype(np.uint8)
    if is_rgb:
        # Undo the BGR -> RGB conversion done above; without it the data is
        # still in the BGR order that cv2.imwrite expects.
        opt = cv2.cvtColor(opt, cv2.COLOR_RGB2BGR)

    # cv2.imwrite reports failure through its return value.
    if not cv2.imwrite(image_path, opt):
        print('failed to write output image: {}'.format(image_path))
        sys.exit(1)


if __name__ == '__main__':
    # Command-line entry point: parse the options, echo them, run main().
    parser = argparse.ArgumentParser(description='Test zero_dce')

    parser.add_argument("--model_file", type=str, default="./zerodce_rgb.nb",
                        help='path to NBG model')
    parser.add_argument("--input_file", type=str, default="./data_ori/435_0_.png",
                        help='path to input image')
    parser.add_argument('--file_format', '-f', type=str, default='nchw',
                        help='specify the model input format')
    parser.add_argument('--channel_order', '-c', type=str, default='rgb',
                        help='specify the order of channels')
    parser.add_argument('-d', '--device_id', type=int, default=0,
                        help='specify which device to run inferences')
    # mean/scale are fractional (see the defaults), so they must be parsed as
    # float; with type=int a value such as "0.00392157" is rejected.
    parser.add_argument("--mean", nargs='+', type=float, default=[0.0, 0.0, 0.0],
                        help='value of mean for model')
    parser.add_argument("--scale", nargs='+', type=float, default=[0.00392157],
                        help='value of scale for model')
    parser.add_argument('-a', '--add_prec_node', action='store_true', default=False,
                        help='the add preprocessing node open. The preprocessing steps is put into NBG model')
    argspar = parser.parse_args()

    print("\n### Test zero_dce ###")
    print("> Parameters:")
    for p, v in vars(argspar).items():
        print('\t{}: {}'.format(p, v))
    print('\n')
    main(**vars(argspar))
