""" Copyright start
  Copyright (C) 2008 - 2024 Fortinet Inc.
  All rights reserved.
  FORTINET CONFIDENTIAL & FORTINET PROPRIETARY SOURCE CODE
  Copyright end """

import os
import shlex

from helpers.logger import Logger
from helpers.cmd_utils import CmdUtils
from framework.base.tasks import Tasks
from constants import REPORT_FILE, STEP_RESULT_FILE, LOG_FILE, FRAMEWORK_PATH
from helpers.utils import Utilities

# Status labels used in the start/finish banners that tasks append to LOG_FILE.
TASK_LOG_STATUS = {"STARTED": "STARTED", "COMPLETED": "COMPLETED"}

# NOTE: this task is meant to be executed last, at the end of the run.

class FixSshd(Tasks):
    """Post-upgrade task that hardens the SSH crypto-policy back-end configs
    and then prints the final upgrade summary, offering a reboot if needed.

    The directory and file names of the SSH configs to patch are read from
    the ``pre-upgrade``/``initialize`` step results.
    """

    TASK_STATUS_MSG = "Fix SSHD"

    # Ciphers / key-exchange algorithms stripped from every SSH config file:
    # 1. CBC-mode ciphers may allow an attacker to recover the plaintext
    #    message from the ciphertext.
    # 2. Weak key-exchange algorithms that Section 4 of the IETF draft
    #    "Key Exchange (KEX) Method Updates and Recommendations for Secure
    #    Shell (SSH)" (draft-ietf-curdle-ssh-kex-sha2-20) lists as SHOULD NOT /
    #    MUST NOT be enabled (diffie-hellman-group-exchange-sha1,
    #    diffie-hellman-group1-sha1, gss-gex-sha1-*, gss-group1-sha1-*,
    #    gss-group14-sha1-*, rsa1024-sha1).
    # Each entry is removed as a literal substring (see _fix_sshd).
    VULNERABLE_CIPHERS_AND_ALGORITHMS = (
        "aes128-cbc",
        "aes256-cbc",
        "diffie-hellman-group-exchange-sha1",
        "diffie-hellman-group1-sha1",
        "gss-gex-sha1-",
        "gss-group1-sha1-",
        "gss-group14-sha1-",
        "rsa1024-sha1",
    )

    def __init__(self) -> None:
        super().__init__()
        self.logger = Logger.get_logger(__name__)
        self.utilities = Utilities()
        self.cmd_line_utilities = CmdUtils()
        self.store_result_path = STEP_RESULT_FILE.format(self.current_version)
        self.report_path = REPORT_FILE.format(self.current_version)

    @property
    def tags(self) -> str:
        """Phase tag under which this task is scheduled."""
        return 'post-upgrade'

    def get_description(self) -> str:
        """Return the task description (currently empty)."""
        return ""

    @staticmethod
    def _version_key(version: str) -> tuple:
        """Turn a dotted numeric version such as '7.4.10' into (7, 4, 10)."""
        return tuple(int(part) for part in version.split('.'))

    def is_supported(self) -> bool:
        """Return True when the target version is not older than the current one.

        Versions are compared component by component. Concatenating the digits
        instead would mis-order e.g. '7.4.10' (7410) and '7.5.0' (750). Missing
        trailing components count as 0, so '7.5' equals '7.5.0'.
        """
        current = self._version_key(self.current_version)
        target = self._version_key(self.target_upgrade_version)
        width = max(len(current), len(target))
        current += (0,) * (width - len(current))
        target += (0,) * (width - len(target))
        return target >= current

    def add_banner_in_log_file(self, msg: str, status: str) -> None:
        """Append a `` [  status   ] : msg `` line to LOG_FILE if it exists.

        :param msg: text to log (typically TASK_STATUS_MSG).
        :param status: banner label, centred in an 11-character field.

        A STARTED line is followed by a single newline; any other status gets
        two, leaving a blank line after it.
        """
        status_msg = f" [{status:^11}] : {msg} "
        new_line_char = "\n" if status == TASK_LOG_STATUS["STARTED"] else "\n\n"
        if os.path.exists(LOG_FILE):
            with open(LOG_FILE, 'a') as log_file:
                log_file.write(f"{status_msg}{new_line_char}")

    def execute(self):
        """Run the task: patch SSH configs, print the summary, log banners."""
        self.add_banner_in_log_file(self.TASK_STATUS_MSG, TASK_LOG_STATUS["STARTED"])
        self._fix_sshd()
        self._show_message()
        self.add_banner_in_log_file(self.TASK_STATUS_MSG, TASK_LOG_STATUS["COMPLETED"])

    def validate(self) -> bool:
        """No post-run validation is performed; always returns True."""
        return True

    def _fix_sshd(self):
        """Strip VULNERABLE_CIPHERS_AND_ALGORITHMS from the SSH config files.

        Reads ``d_crypto_policies_back_ends`` (joined as a directory prefix)
        and ``l_ssh_config_files`` (file names under it) from the
        ``pre-upgrade``/``initialize`` step results, then runs an in-place
        ``sed`` per (file, algorithm) pair. Any exception is logged and the
        user is pointed at LOG_FILE; nothing is raised to the caller.
        """
        try:
            step_result = self.get_step_results('pre-upgrade', 'initialize')
            d_crypto_policies_back_ends = step_result['d_crypto_policies_back_ends']
            l_ssh_config_files = step_result['l_ssh_config_files']

            for f_ssh_config in l_ssh_config_files:
                config_path = f"{d_crypto_policies_back_ends}/{f_ssh_config}"
                for algorithm in self.VULNERABLE_CIPHERS_AND_ALGORITHMS:
                    # GNU sed BRE: alternation must be spelled '\|' (a bare '|'
                    # is a literal character). '\(,\|:\)\?' therefore removes
                    # one optional leading ',' or ':' together with the name,
                    # so the preceding list separator goes away with it.
                    sed_expression = r"s/\(,\|:\)\?" + algorithm + "//g"
                    cmd = "sed -i --follow-symlinks {} {}".format(
                        shlex.quote(sed_expression), shlex.quote(config_path))
                    self.cmd_line_utilities.execute_cmd(cmd, True)
        except Exception as ex:
            err_msg = "ERROR: {}".format(ex)
            self.logger.exception(err_msg)
            print(
                f"Exception occurred at fix sshd task. Refer logs at '{LOG_FILE}'"
            )

    def _show_message(self):
        """Print the upgrade-complete summary, then offer a reboot if needed.

        Messages stored under ``var_post_upgrade_msg`` in the
        ``pre-upgrade``/``initialize`` step results are printed after the
        completion line.
        """
        step_result = self.get_step_results('pre-upgrade', 'initialize')
        var_post_upgrade_msg = step_result['var_post_upgrade_msg']
        self.print_txt("\nUpgrade process is now complete.\n")
        # Display the post-upgrade failure messages, if any.
        for msg in var_post_upgrade_msg or []:
            self.print_txt(msg)
        self.print_txt("Note:")
        self.print_txt(
            "You may check the RPM upgrade logs in the directory '/var/log/cyops/install'.")
        self.print_txt(
            "In case there are any failures in the logs, contact support.")

        self._system_reboot()

    def _system_reboot(self):
        """Ask the user to reboot when ``needs-restarting -r`` says it is needed.

        ``needs-restarting -r`` exits 1 when a full reboot is required and 0
        when it is not. Any non-zero return code is treated as "reboot
        recommended". If the user agrees, the contents of the framework's
        ``workflow`` directory are removed and ``reboot`` is issued.
        """
        result = self.cmd_line_utilities.execute_cmd("needs-restarting -r")
        if result['return_code'] == 0:
            # No reboot required; nothing more to do.
            return
        self.print_txt("It is recommended to reboot the instance to ensure that your system benefits from updates.\nWould you like to reboot the system [y/n] ?")
        if self._yes_no_user_input():
            workflow_path = os.path.join(FRAMEWORK_PATH, "workflow")
            self.utilities.remove_contents(workflow_path)
            os.system("reboot")

    def _yes_no_user_input(self, max_attempts: int = 3) -> bool:
        """Read a yes/no answer from stdin.

        Accepts y/yes or n/no, case-insensitively and ignoring surrounding
        whitespace. Returns False after ``max_attempts`` invalid answers
        (printing a notice) or when stdin is closed.
        """
        for _ in range(max_attempts):
            try:
                answer = input().strip().lower()
            except EOFError:
                # Non-interactive run / closed stdin: treat as "no".
                return False
            if answer in ("y", "yes"):
                return True
            if answer in ("n", "no"):
                return False
            print("You have provided an invalid input. Enter Yes or No")
        print("Max retries reached, exiting the upgrade.")
        return False
