""" Copyright start
  Copyright (C) 2008 - 2025 Fortinet Inc.
  All rights reserved.
  FORTINET CONFIDENTIAL & FORTINET PROPRIETARY SOURCE CODE
  Copyright end """

import os
import sys
import json
from framework.base.tasks import Tasks
from helpers.cmd_utils import CmdUtils
from helpers.logger import Logger
from constants import LOG_FILE

# Labels printed inside the "[ ... ]" console status column; keys equal values.
TASK_STATUS = {label: label for label in ("DONE", "FAILED")}
# Labels written into the log-file banner when a task starts / finishes.
TASK_LOG_STATUS = {label: label for label in ("STARTED", "COMPLETED")}
# ANSI escape sequences for foreground colours (RESET restores defaults).
TEXT_COLOR = dict(
    GREEN="\033[92m",
    RED="\033[91m",
    YELLOW="\033[93m",
    RESET="\033[0m",
)
# ANSI escape sequences for text attributes (blink, bold, reset).
TEXT_DECORATION = dict(BLINK="\033[5m", BOLD="\033[1m", RESET="\033[0m")


class ChangeNodeMode(Tasks):
    """Pre-upgrade task that switches the current HA node into ``upgrade`` mode.

    The task runs ``csadm system env --mode upgrade`` and then re-reads the
    node list (``csadm ha list-nodes``) to confirm the change took effect.
    If it did not, the failure is printed and logged, the end-of-phase banner
    is shown, and the process exits with status 1.
    """

    # Width of the "=" border around phase messages (SPACE_BEFORE_MSG is added).
    MSG_WIDTH = 80
    # Spaces printed before a phase message; also added to the border width.
    SPACE_BEFORE_MSG = 2
    PHASE_START_MSG = "Started executing {0} phase."
    PHASE_END_MSG = "Completed execution of {0} phase."
    TASK_STATUS_MSG = "Updating mode of the current node"
    # The task applies only to versions strictly newer than this one.
    MIN_EXCLUSIVE_VERSION = (7, 6, 0)

    def __init__(self) -> None:
        super().__init__()
        self.cmd_line_utilities = CmdUtils()
        self.logger = Logger.get_logger(__name__)

    @property
    def tags(self) -> str:
        return "pre-upgrade"

    def get_description(self) -> str:
        return ""

    def is_supported(self) -> bool:
        """Return True when this task should run on the current node.

        Requires all of: the ``flag_is_enterprise`` result of the
        ``pre-upgrade``/``initialize`` step, a current version newer than
        ``MIN_EXCLUSIVE_VERSION``, and an HA cluster (more than one node).
        Checks short-circuit so the shell command behind ``is_ha`` only runs
        when the cheaper conditions already hold.

        :raises ValueError: if ``current_version`` has a non-integer component.
        """
        step_result = self.get_step_results("pre-upgrade", "initialize")
        if not step_result["flag_is_enterprise"]:
            return False
        if self._parse_version(self.current_version) <= self.MIN_EXCLUSIVE_VERSION:
            return False
        return self.is_ha()

    @staticmethod
    def _parse_version(version: str) -> tuple:
        """Split a dotted version such as ``"7.6.1"`` into ``(7, 6, 1)``.

        Tuples compare component-wise, so ``"7.5.10"`` correctly sorts below
        ``"7.6.0"``; concatenating the digits (7510 vs 760) would not.

        :raises ValueError: if a component is not an integer.
        """
        return tuple(int(part) for part in version.split("."))

    def execute(self):
        """Switch the current node to ``upgrade`` mode and verify the change.

        On failure: prints and logs the error, writes the COMPLETED log banner,
        prints a FAILED status line and the pre-upgrade end banner, then exits
        the process with status 1.
        """
        self.add_banner_in_log_file(self.TASK_STATUS_MSG, TASK_LOG_STATUS["STARTED"])
        upgrade_mode = "upgrade"
        current_mode = self.__get_node_mode()
        print(
            f"Current node mode is '{current_mode}'\nChanging node mode to '{upgrade_mode}'..."
        )
        # operational -> upgrade. The command's result is not inspected; the
        # mode is re-read from the node list below to confirm the switch.
        mode_change_cmd = f"csadm system env --mode {upgrade_mode}"
        self.cmd_line_utilities.execute_cmd(mode_change_cmd, True)
        new_mode = self.__get_node_mode()

        if new_mode != upgrade_mode:
            msg = f"Failed to change the mode from '{current_mode}' to '{upgrade_mode}'"
            color = TEXT_COLOR["RED"]
            reset = TEXT_COLOR["RESET"]
            print(f"{color}ERROR:{reset} {msg}")
            self.logger.error(msg)
            self.add_banner_in_log_file(
                self.TASK_STATUS_MSG, TASK_LOG_STATUS["COMPLETED"]
            )
            self._print_status_msg(self.TASK_STATUS_MSG, TASK_STATUS["FAILED"])
            self._show_start_or_end_msg("preupgrade", "end")
            # Non-zero exit so the caller/shell sees that pre-upgrade failed;
            # a bare sys.exit() would report success (status 0).
            sys.exit(1)

        self.add_banner_in_log_file(self.TASK_STATUS_MSG, TASK_LOG_STATUS["COMPLETED"])
        self._print_status_msg(self.TASK_STATUS_MSG, TASK_STATUS["DONE"])

    def validate(self) -> bool:
        return True

    def _show_start_or_end_msg(self, phase, position=None):
        """Print a ``=``-bordered banner announcing the start or end of *phase*.

        :param phase: Phase name interpolated into the message template.
        :param position: ``"start"`` or ``"end"``.
        :raises ValueError: if *position* is neither ``"start"`` nor ``"end"``.
        """
        templates = {"start": self.PHASE_START_MSG, "end": self.PHASE_END_MSG}
        if position not in templates:
            raise ValueError(
                f"position must be 'start' or 'end', got {position!r}"
            )
        msg = templates[position].format(phase)
        border = "=" * (self.MSG_WIDTH + self.SPACE_BEFORE_MSG)
        final_msg = (" " * self.SPACE_BEFORE_MSG) + msg
        print("\n{}\n{}\n{}".format(border, final_msg, border))

    def add_banner_in_log_file(self, msg: str, status: str) -> None:
        """Append a ``[ STATUS ] : msg`` line to ``LOG_FILE`` if it exists.

        The status is centred in an 11-character field. A STARTED line is
        followed by one newline; any other status is followed by two so that
        consecutive tasks are visually separated. Nothing is written when the
        log file does not already exist.
        """
        status_msg = f" [{status:^11}] : {msg} "
        new_line_char = "\n" if status == TASK_LOG_STATUS["STARTED"] else "\n\n"
        if os.path.exists(LOG_FILE):
            with open(LOG_FILE, "a", encoding="utf-8") as log_file:
                log_file.write(f"{status_msg}{new_line_char}")

    def _print_status_msg(self, msg, status):
        """Print *msg* left-aligned in 70 columns followed by a coloured status.

        Messages longer than 65 characters are truncated and suffixed with
        ``...``. The status is centred in 8 characters, green for DONE and
        red otherwise.
        """
        reset = TEXT_COLOR["RESET"]
        if status == TASK_STATUS["DONE"]:
            color = TEXT_COLOR["GREEN"]
        else:
            color = TEXT_COLOR["RED"]
        truncated_message = msg[:65] + "..." if len(msg) > 65 else msg
        width = 8
        status = f"{status:^{width}}"
        colored_status = f"{color}{status}{reset}"
        final_msg = "{:<70}{}[{}]".format(truncated_message, " ", colored_status)
        print(final_msg)

    def __get_device_uuid(self):
        """Return this device's UUID from ``csadm license --get-device-uuid``."""
        get_device_uuid_cmd = "csadm license --get-device-uuid"
        device_uuid = self.cmd_line_utilities.execute_cmd(get_device_uuid_cmd, True)
        return device_uuid["std_out"].strip()

    def __get_node_mode(self):
        """Return this node's mode as reported by ``csadm ha list-nodes``.

        Falls back to ``"operational"`` when the node list is empty, this
        device's UUID is not among the ``nodeId`` values, or the matching row
        carries no ``mode`` value.
        """
        default_mode = "operational"
        node_details = self.all_nodes_details()
        if not node_details:
            return default_mode
        device_uuid = self.__get_device_uuid()
        for node in node_details:
            # .get(): a row may lack nodeId if the header had no such column.
            if node.get("nodeId") == device_uuid and "mode" in node:
                return node["mode"]
        return default_mode

    def is_ha(self):
        """Return True when ``csadm ha list-nodes`` reports more than one node."""
        return len(self.all_nodes_details()) > 1

    def all_nodes_details(self) -> list:
        """Parse the table printed by ``csadm ha list-nodes`` into dicts.

        Lines containing ``---`` or ``+===`` (table rules) are skipped and the
        ``*`` / ``|`` tokens are dropped from every other line; the first
        remaining row is taken as the column headers. For each data row the
        ``nodeId``, ``nodeName``, ``status`` and ``role`` columns are mapped by
        position. When a ``comment`` column precedes ``mode``, ``mode`` is read
        counting from the end of the row, presumably because the comment can
        contain spaces and shift later columns.

        :returns: One dict per data row; an empty list when the command printed
            nothing parsable.
        """
        data_keys = ["nodeId", "nodeName", "status", "role"]
        cmd_output = self.cmd_line_utilities.execute_cmd("csadm ha list-nodes", True)
        std_out = cmd_output["std_out"]
        raw_lines = std_out.split("\n") if std_out else []

        rows = []
        for line in raw_lines:
            # "---" also covers "+---" border lines.
            if "---" in line or "+===" in line:
                continue
            tokens = [token for token in line.split() if token not in ("*", "|")]
            if tokens:
                rows.append(tokens)

        if not rows:
            # No header row at all (e.g. empty output): report no nodes.
            return []

        header, data_rows = rows[0], rows[1:]
        # Column positions are the same for every row; look them up once.
        columns = {key: header.index(key) for key in data_keys if key in header}

        mode_from_end = None
        if (
            "comment" in header
            and "mode" in header
            and header.index("comment") < header.index("mode")
        ):
            # Negative index: position of "mode" counted from the row's end.
            mode_from_end = header.index("mode") - len(header)

        node_details = []
        for row in data_rows:
            node_detail = {
                key: row[index] for key, index in columns.items() if index < len(row)
            }
            # Guard: a row too short for the negative index has no mode value.
            if mode_from_end is not None and -mode_from_end <= len(row):
                node_detail["mode"] = row[mode_from_end]
            node_details.append(node_detail)
        return node_details
