""" Copyright start
  Copyright (C) 2008 - 2025 Fortinet Inc.
  All rights reserved.
  FORTINET CONFIDENTIAL & FORTINET PROPRIETARY SOURCE CODE
  Copyright end """

import json
from framework.base.tasks import Tasks
from helpers.cmd_utils import CmdUtils
from helpers.logger import Logger
from constants import LOG_FILE

# ANSI SGR escape sequences used to highlight terminal output.
# "\033[0m" resets every attribute (colour and decoration alike).
TEXT_COLOR = dict(
    GREEN="\033[92m",  # bright green foreground
    RED="\033[91m",  # bright red foreground
    YELLOW="\033[93m",  # bright yellow foreground
    RESET="\033[0m",  # reset all attributes
)
TEXT_DECORATION = dict(
    BLINK="\033[5m",  # slow blink
    BOLD="\033[1m",  # bold / increased intensity
    RESET="\033[0m",  # reset all attributes
)


# This task (file) must be ordered after the OS-hardening task (fix_sshd) in the post-upgrade phase.
class ChangeNodeMode(Tasks):
    """Post-upgrade reminder to move this HA node out of 'upgrade' mode.

    When supported (see ``is_supported``), ``execute`` looks up this node's
    mode in the ``csadm ha list-nodes`` table. If the mode is 'upgrade', it
    prints a highlighted note asking the operator to switch the node to
    'operational'. The task never changes the mode itself.
    """

    def __init__(self) -> None:
        super().__init__()
        self.cmd_line_utilities = CmdUtils()
        self.logger = Logger.get_logger(__name__)

    @property
    def tags(self) -> str:
        """Return this task's tag ("post-upgrade")."""
        return "post-upgrade"

    def get_description(self) -> str:
        """Return the task description (currently empty)."""
        return ""

    def is_supported(self) -> bool:
        """Return True when this task should run.

        All of the following must hold:
          * the pre-upgrade "initialize" step reported a truthy
            ``flag_is_enterprise``;
          * ``self.current_version`` is newer than 7.6.0;
          * ``csadm ha list-nodes`` lists more than one node (HA).

        The cheap checks run first, so the external ``csadm`` command is only
        executed when it can still change the answer.

        Raises:
            ValueError: if a component of ``self.current_version`` is not an
                integer.
        """
        step_result = self.get_step_results("pre-upgrade", "initialize") or {}
        if not step_result.get("flag_is_enterprise"):
            return False
        # Compare versions numerically, component by component. Concatenating
        # the digits (e.g. "7.5.10" -> 7510) would wrongly rank 7.5.10 above
        # 7.6.0.
        version = tuple(int(part) for part in self.current_version.split("."))
        if version <= (7, 6, 0):
            return False
        return self.is_ha()

    def execute(self):
        """Print a reminder if this node is still in 'upgrade' mode.

        The note tells the operator to do basic sanity checks and then run
        ``csadm system env --mode operational``. Nothing is changed
        automatically.
        """
        upgrade_mode = "upgrade"
        operational_cmd = "csadm system env --mode operational"
        current_mode = self.__get_node_mode()
        # upgrade -> operational is left to the operator; only remind them.
        if current_mode != upgrade_mode:
            return
        msg = f"Current node mode is '{current_mode}'\nDo the basic sanity and then change mode to 'operational' using following command\n'{operational_cmd}'"
        color = TEXT_COLOR["YELLOW"]
        decoration = TEXT_DECORATION["BLINK"]
        # "\033[0m" resets all SGR attributes, so it clears the blink as well.
        reset = TEXT_COLOR["RESET"]
        colored_msg = f"{color}{decoration}Note:{reset}"
        print(f"{colored_msg}\n{msg}")

    def validate(self) -> bool:
        """No post-execution validation is performed; always returns True."""
        return True

    def __get_device_uuid(self):
        """Return this device's UUID from ``csadm license --get-device-uuid``.

        Surrounding whitespace is stripped from the command's standard output.
        """
        get_device_uuid_cmd = "csadm license --get-device-uuid"
        device_uuid = self.cmd_line_utilities.execute_cmd(get_device_uuid_cmd, True)
        return device_uuid["std_out"].strip()

    def __get_node_mode(self):
        """Return the mode of the node whose ``nodeId`` equals this device's UUID.

        Defaults to "operational" when the table is empty, when no node
        matches, or when the matching node has no parsed "mode".
        """
        node_details = self.all_nodes_details()
        mode = "operational"
        if node_details:
            device_uuid = self.__get_device_uuid()
            for node in node_details:
                # .get: "nodeId" is absent if the header lacked that column.
                if node.get("nodeId") == device_uuid and "mode" in node:
                    mode = node["mode"]
        return mode

    def is_ha(self):
        """Return True when ``csadm ha list-nodes`` lists more than one node."""
        return len(self.all_nodes_details()) > 1

    def all_nodes_details(self) -> list:
        """Parse ``csadm ha list-nodes`` output into a list of per-node dicts.

        The first non-separator line is treated as the header. Each header
        token is a column name, and "*" and "|" tokens are dropped. Each
        following row becomes a dict holding the available keys among
        "nodeId", "nodeName", "status" and "role", taken by header position,
        plus "mode" when the header has such a column.

        Returns:
            A list of dicts, one per data row. The list is empty when the
            command printed no table (previously this raised ``IndexError``).
        """
        data_keys = ["nodeId", "nodeName", "status", "role"]
        cmd_output = self.cmd_line_utilities.execute_cmd("csadm ha list-nodes", True)
        std_out = cmd_output["std_out"] or ""

        rows = []
        for line in std_out.split("\n"):
            # Skip table border/separator lines ("+---", "+===", "---").
            # NOTE(review): a data row whose text contains "---" is skipped too.
            if "---" in line or "+===" in line:
                continue
            tokens = [token for token in line.split() if token not in ("*", "|")]
            if tokens:
                rows.append(tokens)

        if not rows:
            # Empty or unparseable output: there are no nodes to report.
            return []

        header, *data_rows = rows
        key_positions = [(key, header.index(key)) for key in data_keys if key in header]

        mode_index = header.index("mode") if "mode" in header else None
        # A multi-word (or empty) comment shifts the columns after it, so when
        # "comment" precedes "mode" the mode value is located from the end of
        # the row. Otherwise the mode sits at its header position.
        mode_from_end = (
            mode_index is not None
            and "comment" in header
            and header.index("comment") < mode_index
        )

        node_details = []
        for row in data_rows:
            node_detail = {key: row[pos] for key, pos in key_positions if pos < len(row)}
            if mode_index is not None:
                pos = mode_index - len(header) if mode_from_end else mode_index
                # Bounds check: a short row must not raise IndexError.
                if -len(row) <= pos < len(row):
                    node_detail["mode"] = row[pos]
            node_details.append(node_detail)
        return node_details
