""" Copyright start
  Copyright (C) 2008 - 2025 Fortinet Inc.
  All rights reserved.
  FORTINET CONFIDENTIAL & FORTINET PROPRIETARY SOURCE CODE
  Copyright end """

import os
from helpers.logger import Logger
from helpers.cmd_utils import CmdUtils
from framework.base.tasks import Tasks
from constants import REPORT_FILE, STEP_RESULT_FILE, LOG_FILE

# Outcome labels shown in the console status column (each maps to itself).
TASK_STATUS = {label: label for label in ("DONE", "FAILED")}
# Labels written into the log-file banners around a task run.
TASK_LOG_STATUS = {label: label for label in ("STARTED", "COMPLETED")}
# ANSI escape sequences for foreground colours.
TEXT_COLOR = {
    'GREEN': '\033[92m',
    'RED': '\033[91m',
    'YELLOW': '\033[93m',
    'RESET': '\033[0m',
}
# ANSI escape sequences for text decoration.
TEXT_DECORATION = {
    'BLINK': '\033[5m',
    'BOLD': '\033[1m',
    'RESET': '\033[0m',
}

class HandleCachePostUpgradeE(Tasks):
    """Post-upgrade task that refreshes the cyops-api cache.

    Runs ``/opt/cyops-api/bin/console app:type:class:map`` as the ``nginx``
    user, prints a one-line DONE/FAILED status to stdout, and appends
    STARTED/COMPLETED banners to ``LOG_FILE`` (when that file exists).
    """

    TASK_STATUS_MSG = "Refresh cache post upgrade"

    def __init__(self) -> None:
        super().__init__()
        self.logger = Logger.get_logger(__name__)
        self.cmd_line_utilities = CmdUtils()
        # NOTE(review): ``current_version`` is presumably populated by
        # ``Tasks.__init__`` — it is not set anywhere in this class.
        self.store_result_path = STEP_RESULT_FILE.format(self.current_version)
        self.report_path = REPORT_FILE.format(self.current_version)

    @property
    def tags(self) -> str:
        """Upgrade phase this task belongs to."""
        return 'post-upgrade'

    def get_description(self) -> str:
        """Return the task description (currently empty)."""
        return ""

    def is_enterprise(self):
        """Return the ``flag_is_enterprise`` value stored by the
        pre-upgrade ``initialize`` step.

        Raises whatever indexing the step result raises (e.g. ``KeyError``)
        if that step did not record the flag.
        """
        step_result = self.get_step_results("pre-upgrade", "initialize")
        return step_result["flag_is_enterprise"]

    @staticmethod
    def _version_key(version: str) -> tuple:
        """Split a dotted version string into a tuple of ints.

        Raises ``ValueError`` if any dot-separated component is not an
        integer.
        """
        return tuple(int(part) for part in version.split('.'))

    def is_supported(self) -> bool:
        """Return whether this task should run.

        The task runs only when the target version is not older than the
        current version and the installation is flagged as enterprise.
        """
        current = self._version_key(self.current_version)
        target = self._version_key(self.target_upgrade_version)
        # Compare component-wise. Concatenating the digits ("7.6.10" -> 7610,
        # "7.7.0" -> 770) misorders multi-digit components. Zero-pad the
        # shorter tuple so that e.g. "7.6" and "7.6.0" compare equal.
        width = max(len(current), len(target))
        current += (0,) * (width - len(current))
        target += (0,) * (width - len(target))
        return target >= current and self.is_enterprise()

    def execute(self):
        """Run the cache refresh, bracketed by log-file banners."""
        self.add_banner_in_log_file(self.TASK_STATUS_MSG, TASK_LOG_STATUS["STARTED"])
        self._handle_cache_post_upgrade()
        self.add_banner_in_log_file(self.TASK_STATUS_MSG, TASK_LOG_STATUS["COMPLETED"])

    def validate(self) -> bool:
        """No post-execution validation is performed; always ``True``."""
        return True

    def _print_status_msg(self, msg, status):
        """Print ``msg`` and a coloured status column to stdout.

        ``msg`` longer than 65 characters is cut to 65 and suffixed with
        "..."; it is then left-padded to 70 columns. The status is centred
        in 8 columns, green for DONE and red otherwise.
        """
        reset = TEXT_COLOR["RESET"]
        color = TEXT_COLOR["GREEN"] if status == TASK_STATUS["DONE"] else TEXT_COLOR["RED"]

        truncated_msg = msg[:65] + "..." if len(msg) > 65 else msg
        status_text = f"{status:^8}"
        colored_status = f"{color}{status_text}{reset}"
        print(f"{truncated_msg:<70} [{colored_status}]")

    def add_banner_in_log_file(self, msg: str, status: str) -> None:
        """Append a ``[status] : msg`` banner line to ``LOG_FILE``.

        Nothing is written if ``LOG_FILE`` does not already exist. A
        COMPLETED (or any non-STARTED) banner is followed by an extra blank
        line to separate tasks.
        """
        status_line = f"[{status:^11}] : {msg}"
        newline = "\n" if status == TASK_LOG_STATUS["STARTED"] else "\n\n"
        final_msg = f"{status_line}{newline}"
        if os.path.exists(LOG_FILE):
            with open(LOG_FILE, 'a') as log_file:
                log_file.write(final_msg)

    def _run_command(self, cmd, success_msg, error_msg, shell=False):
        """Execute ``cmd`` and map its outcome to ``(ok, message)``.

        Returns ``(True, success_msg)`` on a zero return code, otherwise
        logs the failure and returns ``(False, error_msg)``.
        """
        result = self.cmd_line_utilities.execute_cmd(cmd, shell)
        return_code = result['return_code']
        if return_code != 0:
            std_err = result['std_err']
            if std_err:
                self.logger.error(f"{error_msg} Following error occurred =>\n{std_err}")
            else:
                # Previously a failure with empty stderr left no log trace.
                self.logger.error(f"{error_msg} Command exited with return code {return_code}")
            return False, error_msg
        return True, success_msg

    def _handle_cache_post_upgrade(self):
        """Run the cache-refresh command(s) and report an overall status.

        Prints DONE if every command succeeded, otherwise FAILED followed by
        each failing command's error message. Any unexpected exception is
        logged and reported as FAILED rather than propagated.
        """
        try:
            # Commands and their user-facing messages.
            tasks = [
                {
                    "cmd": "sudo -u nginx php /opt/cyops-api/bin/console app:type:class:map",
                    "success_msg": "Refreshing of the cache is complete",
                    "error_msg": "Fail to refresh cache",
                    "shell": False,
                    "key": "cache"
                }
            ]
            # Per-task outcome, keyed by the task's "key".
            results = {}

            for task in tasks:
                success, message = self._run_command(
                    cmd=task["cmd"],
                    success_msg=task["success_msg"],
                    error_msg=task["error_msg"],
                    shell=task.get("shell", False)
                )
                results[task["key"]] = {"success": success, "message": message}

            if all(res["success"] for res in results.values()):
                self._print_status_msg(self.TASK_STATUS_MSG, TASK_STATUS["DONE"])
            else:
                self._print_status_msg(self.TASK_STATUS_MSG, TASK_STATUS["FAILED"])
                for res in results.values():
                    if not res["success"]:
                        print(res["message"])

        except Exception as ex:
            # Top-level boundary for this task: log the traceback and report
            # failure instead of aborting the whole upgrade run.
            self.logger.exception(f"ERROR: {ex}")
            self._print_status_msg(self.TASK_STATUS_MSG, TASK_STATUS["FAILED"])
            print("Exception occurred at cache post upgrade task. "
                f"Refer logs at '{LOG_FILE}'")
