#!/usr/bin/env python3
#
# Copyright (c) "2022" Advanced Micro Devices, Inc. All rights reserved.

import argparse
import csv
import itertools
import math
import os
import re
from collections import namedtuple

# One column of a merged latency table: an instance's id, the mean and
# standard deviation of its samples, and the samples themselves.
instance_column = namedtuple(
    "instance_column",
    "instance_id mean std_deviation latency_data",
)

# Names of the per-instance CSV files looked up in every instance
# sub-directory; the merged tables are written under the same names.
RENDER_LATENCY_CSV_FILE_NAME, ENCODE_LATENCY_CSV_FILE_NAME = "render.csv", "encode.csv"

# Samples whose absolute z-score exceeds this value are reported as outliers.
# When run as a script, the --outlier_threshold option replaces it.
OUTLIER_Z_SCORE_THRESHOLD = 15.0

def read_instance_data(instance_id, instance_latency_csv_path, frames, z_score_threshold=None):
    """Load one instance's latency CSV and summarise it as an ``instance_column``.

    The CSV header must contain a column whose name contains ``latency``
    (the samples) and one whose name contains ``frame`` (used only to label
    outlier reports). At most the first ``frames`` data rows are read.
    Samples whose absolute z-score exceeds the threshold are printed as
    outliers.

    Args:
        instance_id: Identifier stored in the returned column.
        instance_latency_csv_path: Path of the per-instance CSV file.
        frames: Maximum number of data rows to read.
        z_score_threshold: Outlier cut-off. ``None`` (the default) uses the
            module-level ``OUTLIER_Z_SCORE_THRESHOLD`` at call time.

    Returns:
        An ``instance_column`` holding the samples, their mean and their
        population standard deviation, or ``None`` if the file cannot be
        opened or contains no data rows.

    Raises:
        ValueError: If the header lacks a latency column or a frame column,
            or a latency value is not a number.
    """
    if z_score_threshold is None:
        # Read the global at call time so the command-line override applies.
        z_score_threshold = OUTLIER_Z_SCORE_THRESHOLD

    try:
        # newline="" is what the csv module expects for files it parses.
        with open(instance_latency_csv_path, "r", newline="") as csv_file:
            reader = csv.DictReader(csv_file)
            fieldnames = reader.fieldnames
            if not fieldnames:
                # Completely empty file: no header, hence no data.
                print("No latency data found in: ", instance_latency_csv_path)
                return None

            latency_data_key = None
            frame_num_key = None
            for field in fieldnames:
                # A name containing both words counts as the latency column.
                if re.search(r'latency', field) is not None:
                    latency_data_key = field
                elif re.search(r'frame', field) is not None:
                    frame_num_key = field
            if latency_data_key is None or frame_num_key is None:
                raise ValueError("%s: header needs a 'latency' and a 'frame' column, got %s"
                                 % (instance_latency_csv_path, fieldnames))

            # Exactly the first `frames` rows (the old `i > frames` read one extra).
            rows = list(itertools.islice(reader, max(frames, 0)))
    except IOError:
        print("Fail to get instance latency file: ", instance_latency_csv_path)
        return None

    if not rows:
        # Header-only file: mean/std deviation are undefined.
        print("No latency data found in: ", instance_latency_csv_path)
        return None

    latency_data = [float(row[latency_data_key]) for row in rows]
    mean = sum(latency_data) / len(latency_data)
    # Population standard deviation (divides by N, not N - 1).
    std_deviation = math.sqrt(sum((x - mean) ** 2 for x in latency_data) / len(latency_data))

    # A zero deviation means every sample equals the mean: there are no
    # outliers, and the z-score would divide by zero.
    if std_deviation > 0:
        for latency, row in zip(latency_data, rows):
            if abs((latency - mean) / std_deviation) > z_score_threshold:
                print("Found outlier: %s, frame '%s', latency '%s'!"
                      % (instance_latency_csv_path, row[frame_num_key], row[latency_data_key]))

    return instance_column(instance_id, mean, std_deviation, latency_data)

def write_all_instance_columns_to_csv(file_name, all_instance_columns):
    """Write the merged latency table for one latency kind to ``file_name``.

    The table is written transposed: each instance becomes one CSV column
    whose first three cells are its id, mean and standard deviation,
    followed by its latency samples. Instance columns are ordered by
    descending mean. A leading label column holds the row names and the
    sample indices. Instances with fewer samples than the longest one are
    padded with empty cells instead of being truncated or raising.

    Args:
        file_name: Output CSV path; overwritten if it exists.
        all_instance_columns: ``instance_column``-like records (attributes
            ``instance_id``, ``mean``, ``std_deviation``, ``latency_data``).
            The list itself is not modified.
    """
    if not all_instance_columns:
        print("No data found for writing ", file_name)
        return

    # sorted() rather than list.sort() so the caller's list is left untouched.
    ordered = sorted(all_instance_columns, key=lambda col: col.mean, reverse=True)

    # Size the index column by the longest instance, not the first one.
    max_samples = max(len(col.latency_data) for col in ordered)
    columns = [["instance_id", "mean", "std deviation"] + list(range(max_samples))]
    columns.extend(
        [col.instance_id, col.mean, col.std_deviation] + list(col.latency_data)
        for col in ordered
    )

    # newline="" prevents the csv module from emitting blank lines on Windows.
    with open(file_name, "w", newline="") as csv_file:
        # zip_longest transposes columns into rows, padding short columns.
        csv.writer(csv_file).writerows(itertools.zip_longest(*columns, fillvalue=""))

def main(log_dir: str, frames: int):
    """Merge every instance's render/encode latency CSVs and print a summary.

    Each sub-directory of ``log_dir`` is one instance, named by its
    directory name. Its ``render.csv`` and ``encode.csv`` are loaded with
    ``read_instance_data``. For each of the two files, the merged table is
    written to a file of the same name in the current working directory.
    The function then prints the instance with the largest latency standard
    deviation and the average of the per-instance mean latencies.

    Args:
        log_dir: Directory containing one sub-directory per instance.
        frames: Maximum number of samples to read from each instance file.
    """
    file_names = (RENDER_LATENCY_CSV_FILE_NAME, ENCODE_LATENCY_CSV_FILE_NAME)
    merged_latencies = {file_name: [] for file_name in file_names}

    # Sorted for a deterministic order. Plain files in log_dir are skipped
    # instead of producing spurious "Fail to get instance latency file" output.
    for instance_id in sorted(os.listdir(log_dir)):
        subdir_path = os.path.join(log_dir, instance_id)
        if not os.path.isdir(subdir_path):
            continue
        for file_name in file_names:
            csv_file_path = os.path.join(subdir_path, file_name)
            latency_column = read_instance_data(instance_id, csv_file_path, frames)
            if latency_column is not None:
                merged_latencies[file_name].append(latency_column)

    # Summaries are computed before writing, independent of any reordering
    # the writer might do.
    max_std_deviation_instance_id = {}
    avg_latency = {}
    for file_name, latency_columns in merged_latencies.items():
        if latency_columns:
            noisiest = max(latency_columns, key=lambda col: col.std_deviation)
            max_std_deviation_instance_id[file_name] = noisiest.instance_id
            avg_latency[file_name] = sum(col.mean for col in latency_columns) / len(latency_columns)
        else:
            max_std_deviation_instance_id[file_name] = ""
            avg_latency[file_name] = 0

    for file_name, latency_columns in merged_latencies.items():
        write_all_instance_columns_to_csv(file_name, latency_columns)

    for key, value in max_std_deviation_instance_id.items():
        print("%s: instance with largest latency variance is '%s'" % (key, value))

    for key, value in avg_latency.items():
        print("%s: average latency is '%f'" % (key, value))

if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Merge per-instance render/encode latency CSV files into summary tables')
    parser.add_argument('-l', '--log', dest='log_dir', type=str, required=True,
        help='Log directory with one sub-directory per instance; their latency data is merged into a single data sheet')
    parser.add_argument('-z', '--outlier_threshold', dest='z_score_threshold', type=float, required=False,
        default=OUTLIER_Z_SCORE_THRESHOLD,
        help='Set the z-score threshold for detecting outliers')
    parser.add_argument('-f', '--frames', dest='frames', type=int, required=False,
        default=1000,
        help='Merge only the first "frames" data into final table')
    args = parser.parse_args()

    if args.frames <= 0:
        parser.error('--frames must be a positive integer')

    # This assignment is at module scope, so it rebinds the global that
    # read_instance_data consults for its outlier threshold.
    OUTLIER_Z_SCORE_THRESHOLD = args.z_score_threshold

    # abspath resolves relative paths against the current working directory
    # and leaves absolute ones absolute, so no separate branch is needed.
    log_dir_path = os.path.abspath(args.log_dir)
    if not os.path.isdir(log_dir_path):
        parser.error("log directory '%s' does not exist" % log_dir_path)

    main(log_dir_path, args.frames)

