#!/usr/bin/env python3
#
# Copyright (c) 2022-2023 Advanced Micro Devices, Inc. All rights reserved.

import enum
import os
import datetime
import subprocess
import argparse
import re
from collections import namedtuple
from multiprocessing import Pool 

# Outcome of a single vkcubepp run: the process return code plus the text it
# wrote to stdout ("output") and stderr ("error").
InstanceResult = namedtuple("InstanceResult", "ret_code output error")

# Per-instance latency figures (mean / standard deviation) for the DMA and
# encode stages. Only referenced by the disabled latency analysis in
# process_results.
InstanceStatistic = namedtuple(
    "InstanceStatistic",
    "latency_mean_DMA latency_std_DMA latency_mean_encode latency_std_encode",
)

# When True, run_instance prefixes each vkcubepp command with `perf record`.
# Reassigned from the --perf command-line flag when run as a script.
ENABLE_PERF = False

def run_instance(output_dir: str, gpu_num: int, enc_num: int, frame_num: int, timeout_secs: float):
    """Run one ``./vkcubepp`` instance and collect its result.

    The child is started from the current working directory, moved to the
    Round-Robin real-time scheduler at priority 99 with ``chrt``, and killed
    if it has not finished after ``timeout_secs`` seconds.

    Args:
        output_dir: Passed to vkcubepp as ``--output_directory``. When
            ``ENABLE_PERF`` is set it is created here and also receives
            ``perf.data``.
        gpu_num: Value for vkcubepp's ``--gpu_number``.
        enc_num: Value for vkcubepp's ``--enc_number``.
        frame_num: Value for vkcubepp's ``--c`` option.
        timeout_secs: Seconds to wait before killing the instance.

    Returns:
        InstanceResult holding the child's return code (a negative value -N
        means it was terminated by signal N) and its decoded stdout/stderr.

    Raises:
        subprocess.CalledProcessError, OSError: if ``chrt`` fails or cannot
            be run. The child is killed and reaped before the error
            propagates.
    """
    command = [
        "./vkcubepp",
        "--c", str(frame_num),
        "--gpu_number", str(gpu_num),
        "--enc_number", str(enc_num),
        "--encoding_mode", "0",
        "--output_directory", output_dir,
        "--use_staging",
    ]
    if ENABLE_PERF:
        # perf writes perf.data into output_dir, so it must exist beforehand.
        os.makedirs(output_dir, exist_ok=True)
        command = [
            "perf", "record",
            "-e", "block:block_rq_issue",
            "-e", "block:block_rq_complete",
            "-a",
            "-o", os.path.join(output_dir, "perf.data"),
        ] + command

    print(command)

    proc = subprocess.Popen(
        command,
        env=os.environ.copy(),
        universal_newlines=True,
        # A stray byte that is invalid in the locale encoding must not raise
        # UnicodeDecodeError in communicate() and discard the whole result.
        errors="replace",
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE)

    try:
        # Change schedule to Round-Robin and set priority to max(99)
        subprocess.check_output(["chrt", "-r", "-p", "99", str(proc.pid)])
    except BaseException:
        # Don't leave a running child behind whose pipes nobody will drain.
        proc.kill()
        proc.wait()
        raise

    try:
        stdouts, stderrs = proc.communicate(timeout=timeout_secs)
    except subprocess.TimeoutExpired:
        print("instance timeout detected!")
        # NOTE(review): with ENABLE_PERF, proc is `perf`, so this kills perf
        # only; if vkcubepp survives and still holds the pipes, the
        # communicate() below waits for it to exit. Confirm whether that is
        # acceptable.
        proc.kill()
        stdouts, stderrs = proc.communicate()

    return InstanceResult(proc.returncode, stdouts, stderrs)


def run_instance_func(args):
    """Adapter for Pool.map: expand one argument tuple into run_instance().

    Pool.map hands each worker a single item, so the per-instance arguments
    arrive packed together and are spread out positionally here.
    """
    positional = tuple(args)
    return run_instance(*positional)


def process_results(output_dirs, results):
    """Save each instance's captured output and print a return-code summary.

    For every ``(output_dir, result)`` pair, ``result.output`` is written to
    ``<output_dir>/stdout.log`` and ``result.error`` to
    ``<output_dir>/stderr.log``. Instances are then grouped by
    ``result.ret_code``. One line is printed per distinct code, highest code
    first, giving the number of instances with that code and their indices.

    Args:
        output_dirs: Per-instance output directory paths, in instance order.
        results: Per-instance objects exposing ``ret_code``, ``output`` and
            ``error`` (``InstanceResult`` in this script), in instance order.
    """
    for output_dir, res in zip(output_dirs, results):
        # Without --perf nothing in this script creates output_dir, and the
        # instance may have failed before creating it, so ensure it exists.
        os.makedirs(output_dir, exist_ok=True)
        with open(os.path.join(output_dir, 'stdout.log'), 'w') as f:
            f.write(res.output)
        with open(os.path.join(output_dir, 'stderr.log'), 'w') as f:
            f.write(res.error)

    # NOTE(review): disabled latency analysis. If re-enabled, the latency_*
    # locals are not reset per result, so a missing match reuses the previous
    # instance's values (or raises NameError for the first instance).
    #latency_DMA_pattern = re.compile(
    #    r'^render_and_color_conversion_and_DMA.*latency_average: ([-+]?\d*\.\d+).*latency_std_deviation: ([-+]?\d*\.\d+).*$',
    #    re.M)
    #latency_encode_pattern = re.compile(
    #    r'^encoding.*latency_average: ([-+]?\d*\.\d+).*latency_std_deviation: ([-+]?\d*\.\d+).*$',
    #    re.M)
    #
    #instance_statistic = []
    #for res in results:
    #    match = re.search(latency_DMA_pattern, res.output)
    #    if match is not None:
    #        latency_mean_DMA = match.group(1)
    #        latency_std_DMA  = match.group(2)
    #    match = re.search(latency_encode_pattern, res.output)
    #    if match is not None:
    #        latency_mean_encode = match.group(1)
    #        latency_std_encode  = match.group(2)
    #    instance_statistic.append(
    #        InstanceStatistic(latency_mean_DMA, latency_std_DMA, latency_mean_encode, latency_std_encode))

    #instance_DMA_latency_tuple = [
    #    (instance, x.latency_mean_DMA, x.latency_std_DMA)
    #    for instance, x in enumerate(instance_statistic)
    #    if x.latency_mean_DMA is not None and x.latency_std_DMA is not None ]
    #
    #instances_no_DMA_latency = [
    #    instance for instance, x in enumerate(instance_statistic)
    #    if x.latency_mean_DMA is None or x.latency_std_DMA is None ]

    #count = min(10, len(results))

    #print("%d instances with highest DMA latency mean: " % count,
    #    [ instance for instance, _, _ in sorted(instance_DMA_latency_tuple, key=lambda x: x[1], reverse=True)[0:count] ])
    #print("%d instances with highest DMA latency variance: " % count,
    #    [ instance for instance, _, _ in sorted(instance_DMA_latency_tuple, key=lambda x: x[2], reverse=True)[0:count] ])
    #print("Instances no DMA latency number: ", instances_no_DMA_latency)

    #instance_encode_latency_tuple = [
    #    (i, x.latency_mean_encode, x.latency_std_encode)
    #    for i, x in enumerate(instance_statistic)
    #    if x.latency_mean_encode is not None and x.latency_std_encode is not None ]

    #instances_no_encode_latency = [
    #    i for i, x in enumerate(instance_statistic)
    #    if x.latency_mean_encode is None or x.latency_std_encode is None ]
    #
    #print("%d instances with highest encode latency mean: " % count,
    #    [ instance for instance, _, _ in sorted(instance_encode_latency_tuple, key=lambda x: x[1], reverse=True)[0:count] ])
    #print("%d instances with highest encode latency variance: " % count,
    #    [ instance for instance, _, _ in sorted(instance_encode_latency_tuple, key=lambda x: x[2], reverse=True)[0:count] ])
    #print("Instances no encode latency number: ", instances_no_encode_latency)

    # Map each return code to the (instance_index, result) pairs that produced it.
    results_group_by_ret_code = {}
    for instance_index, res in enumerate(results):
        results_group_by_ret_code.setdefault(res.ret_code, []).append((instance_index, res))

    # Popen reports "killed by signal N" as return code -N.
    ret_code_to_str = {
        0  : "Success",
        -2 : "Sig interuppt",
        -6 : "Sig abort",
        -9 : "Sig killed",
        -11: "Sig Segfault",
    }

    for ret_code, grouped in sorted(results_group_by_ret_code.items(), key=lambda item: item[0], reverse=True):
        label = ret_code_to_str.get(ret_code)
        if label is not None:
            print("%s instances count: %d" % (label, len(grouped)), end=", ")
        else:
            print("Exit with unknown signal '%d' instances count: %d" % (ret_code, len(grouped)), end=", ")
        print("list: ", [instance_index for instance_index, _ in grouped])


def _init_worker(enable_perf):
    """Pool initializer: copy the parent's ENABLE_PERF value into the worker.

    Workers started with the 'spawn' or 'forkserver' start method re-import
    this module, where ENABLE_PERF has its module default, so the value chosen
    on the command line has to be handed over explicitly.
    """
    global ENABLE_PERF
    ENABLE_PERF = enable_perf


def main(instance_num: int, frame_num: int, enc_num: int, gpu_num: int):
    """Run ``instance_num`` vkcubepp instances in parallel and report results.

    A timestamped log directory is created under the current working
    directory, with one ``instance_<i>`` sub-directory per instance.
    Instance ``i`` is given GPU index ``i % gpu_num`` and encoder index
    ``i % enc_num``, so instances are spread round-robin.

    Args:
        instance_num: Number of instances (and worker processes); must be >= 1.
        frame_num: Frames each instance renders.
        enc_num: Number of encoders to spread instances over (values < 1 are
            treated as 1).
        gpu_num: Number of GPUs to spread instances over (values < 1 are
            treated as 1).

    Raises:
        ValueError: if ``instance_num`` is less than 1.
    """
    if instance_num < 1:
        raise ValueError("instance_num must be at least 1, got %d" % instance_num)

    fps = 30
    # Allow roughly twice the nominal run time at 30 fps, but never under 5 s.
    timeout_secs = max((frame_num / fps) * 2.1, 5)
    print("Set frames=", frame_num)
    print("Set timeout=", timeout_secs)

    log_dir_name = os.path.join(os.getcwd(), datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'))
    os.makedirs(log_dir_name)
    print("Instances outputs saved to directory ", log_dir_name)

    # A count of 0 would make the round-robin modulo raise ZeroDivisionError.
    gpu_count = max(gpu_num, 1)
    enc_count = max(enc_num, 1)

    output_dir_args = [os.path.join(log_dir_name, 'instance_%d' % i) for i in range(instance_num)]
    instance_args = [
        (output_dir, i % gpu_count, i % enc_count, frame_num, timeout_secs)
        for i, output_dir in enumerate(output_dir_args)
    ]

    print("Running %d instances.." % instance_num)
    # The context manager tears the pool down even if map() raises.
    with Pool(instance_num, initializer=_init_worker, initargs=(ENABLE_PERF,)) as pool:
        instance_results = pool.map(run_instance_func, instance_args)

    process_results(output_dir_args, instance_results)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Run multiple encoding instances')
    parser.add_argument('--instance', dest='instance', type=int, required=True,
                        help='number of instances to run')
    parser.add_argument('--frames', dest='frames', type=int, default=1000,
                        help='number of frames for each instance to run')
    parser.add_argument('--perf', dest='perf', action='store_true',
                        help='Enable perf record on each instance')
    # Both flags must write to dest='perf'. Without an explicit dest,
    # --no-perf would set an unused 'no_perf' attribute and have no effect.
    parser.add_argument('--no-perf', dest='perf', action='store_false',
                        help='Disable perf record on each instance')
    # main() assigns instance i to index (i % count), so these are counts.
    # A default of 0 would divide by zero there.
    parser.add_argument('--enc_number', dest='enc_num', type=int, default=1,
                        help='number of encoder cards to spread instances over (round-robin)')
    parser.add_argument('--gpu_number', dest='gpu_num', type=int, default=1,
                        help='number of GPUs to spread instances over (round-robin)')
    parser.set_defaults(perf=False)

    args = parser.parse_args()
    # Module-level assignment, so this rebinds the global read by run_instance.
    ENABLE_PERF = args.perf
    main(args.instance, args.frames, args.enc_num, args.gpu_num)
