#!/usr/bin/env python3

import argparse
import glob
import logging
import os
import sys
import xml.etree.ElementTree as ET


# Fully qualified attribute names in ElementTree's "{namespace-uri}local"
# notation, as they appear as keys of Element.attrib after parsing.
_ANDROID_NS = "{http://schemas.android.com/apk/res/android}"
_APP_NS = "{http://schemas.android.com/apk/res-auto}"

ATTR_ID = _ANDROID_NS + "id"
ATTR_NAME = _ANDROID_NS + "name"
ATTR_START_DESTINATION = _APP_NS + "startDestination"
ATTR_DESTINATION = _APP_NS + "destination"
ATTR_ARG_TYPE = _APP_NS + "argType"


class Data:
    """Mutable state filled in and read by the validation passes.

    The containers are created per instance in ``__init__`` so that separate
    ``Data()`` objects never share (and accumulate into) the same dicts/lists.
    """

    def __init__(self):
        # graph id -> id of that graph's start destination
        self.navigations = {}
        # start destination id -> list of {"name": ..., "type": ...} argument dicts
        self.start_destinations = {}
        # id of every <fragment> destination seen, in encounter order (may repeat)
        self.destination_ids = []
        # android:name value of every <fragment> destination seen (may repeat)
        self.destination_names = []


def get_id_from_attr(id):
    """Return the part of an attribute value after its last "/".

    For example "@+id/home" -> "home". A value without "/" is returned
    unchanged. (The parameter name shadows the builtin ``id`` but is kept
    for compatibility with keyword callers.)
    """
    # rpartition yields ("", "", id) when there is no "/", so index 2 is
    # always the trailing segment.
    return id.rpartition("/")[2]


def get_start_destination_from_navigation(data, elem):
    """Record the start destination of a navigation graph element.

    Stores ``graph id -> start destination id`` in ``data.navigations`` and
    creates an empty argument list for the start destination in
    ``data.start_destinations``. Elements without an android:id are ignored.

    Raises:
        AssertionError: if the graph has an id but no app:startDestination.
    """
    attributes = elem.attrib
    if ATTR_ID not in attributes:
        # Graphs may not have an id; nothing is recorded for them.
        return

    graph_id = get_id_from_attr(attributes[ATTR_ID])
    if ATTR_START_DESTINATION not in attributes:
        raise AssertionError(f"Graph '{graph_id}' has no start destination")

    start_id = get_id_from_attr(attributes[ATTR_START_DESTINATION])
    data.navigations[graph_id] = start_id
    data.start_destinations[start_id] = []


def collect_destinations(data, xml_files):
    """Record every graph's start destination across all *xml_files*.

    The root element of each file is processed, followed by every nested
    <navigation> element. ".//navigation" matches descendants only, so the
    root is added explicitly.
    """
    for path in xml_files:
        root = ET.parse(path).getroot()
        for graph in [root, *root.findall(".//navigation")]:
            get_start_destination_from_navigation(data, graph)


def get_node_args(root, node_id):
    """Return the direct <argument> children of *root* as name/type dicts.

    Args:
        root: element whose <argument> children are read.
        node_id: identifier of *root*, used only in error messages.

    Returns:
        A list of ``{"name": ..., "type": ...}`` dicts in document order.

    Raises:
        AssertionError: if an argument lacks android:name or app:argType.
    """
    def describe(argument):
        attrs = argument.attrib
        if ATTR_NAME not in attrs:
            raise AssertionError(f"Argument of '{node_id}' has no name")
        name = attrs[ATTR_NAME]
        if ATTR_ARG_TYPE not in attrs:
            raise AssertionError(f"Argument '{name}' of '{node_id}' has no type")
        return {"name": name, "type": attrs[ATTR_ARG_TYPE]}

    return [describe(argument) for argument in root.findall("argument")]


def collect_start_destination_args(data, xml_files):
    """Record every <fragment> destination and the arguments of start destinations.

    For each <fragment> element in *xml_files*:
      * its id is appended to ``data.destination_ids`` and its android:name to
        ``data.destination_names``; a repeated id or name is logged as a warning;
      * if its id is already a key of ``data.start_destinations`` (i.e. it is a
        start destination recorded by collect_destinations), its <argument>
        list is stored there.

    NOTE(review): only <fragment> elements are visited; other destination
    element types, if a graph uses them, are neither recorded nor checked.

    Raises:
        AssertionError: if a fragment has no id or no name, or (via
            get_node_args) if a start destination's argument lacks a name or type.
    """
    for xml_file in xml_files:
        # Used in the error message for fragments without an id.
        xml_name = os.path.basename(xml_file)
        tree = ET.parse(xml_file)
        for elem in tree.findall(".//fragment"):
            if ATTR_ID not in elem.attrib:
                raise AssertionError(f"Found destination in '{xml_name}' with no id")
            dest_id = get_id_from_attr(elem.attrib[ATTR_ID])

            # Each destination should have a unique ID.
            # Things may break in unexpected ways if this is not true.
            if dest_id in data.destination_ids:
                logging.warning(f"Found duplicate destination id '{dest_id}'")
            data.destination_ids.append(dest_id)

            if ATTR_NAME not in elem.attrib:
                raise AssertionError(f"Destination '{dest_id}' has no name")
            dest_name = get_id_from_attr(elem.attrib[ATTR_NAME])

            # Reusing Fragments for different destinations with a different set
            # of arguments will silently break the safe-args plugin.
            if dest_name in data.destination_names:
                logging.warning(f"Duplicate destination name '{dest_name}'")
            data.destination_names.append(dest_name)

            if dest_id not in data.start_destinations:
                continue
            arguments = get_node_args(elem, dest_id)
            data.start_destinations[dest_id] = arguments


def are_actions_valid(data, xml_files):
    """Compare each graph-targeting action's arguments with its target's start destination.

    For every <action> in *xml_files* whose destination id is a key of
    ``data.navigations`` (i.e. a graph), the action's <argument> list is
    compared, ignoring order, with the arguments stored in
    ``data.start_destinations`` for that graph's start destination. Every
    comparison is logged: mismatches at ERROR level, matches at INFO level.
    Actions whose destination is a known non-graph destination id are skipped.

    Expects ``data`` to have been filled by collect_destinations and
    collect_start_destination_args first.

    Returns:
        True if at least one action's arguments differ from those of its
        target's start destination, False otherwise. NOTE(review): despite
        the name, True means a problem was found.

    Raises:
        AssertionError: if an action has no id or no destination, if its
            destination id is neither a recorded graph nor a recorded
            destination, or (via get_node_args) if an action argument lacks
            a name or type.
    """
    def output(is_error, msg):
        # Mismatches are errors; matching actions are only informational.
        if is_error:
            logging.error(msg)
        else:
            logging.info(msg)

    mismatch_found = False
    for xml_file in xml_files:
        xml_name = os.path.basename(xml_file)
        tree = ET.parse(xml_file)
        for elem in tree.findall(".//action"):
            if ATTR_ID not in elem.attrib:
                raise AssertionError(f"Found action in '{xml_name}' with no id")
            action_id = get_id_from_attr(elem.attrib[ATTR_ID])

            if ATTR_DESTINATION not in elem.attrib:
                raise AssertionError(f"Action '{action_id}' in '{xml_name}' has no destination")
            dest_id = get_id_from_attr(elem.attrib[ATTR_DESTINATION])

            if dest_id not in data.navigations:
                if dest_id not in data.destination_ids:
                    # This can happen if you forget to pass some xml file
                    raise AssertionError(f"Could not find {dest_id} for action {action_id}")
                # Only actions targeting a graph are checked.
                continue

            action_args = get_node_args(elem, action_id)
            start_dest = data.navigations[dest_id]
            dest_args = data.start_destinations[start_dest]

            # Argument dicts are unhashable, so take list differences in both
            # directions; this makes the comparison order-insensitive.
            dest_minus_action = [i for i in dest_args if i not in action_args]
            action_minus_dest = [i for i in action_args if i not in dest_args]
            # bool() keeps mismatch_found (and the return value) a real boolean
            # rather than whichever difference list happened to be non-empty.
            is_error = bool(dest_minus_action or action_minus_dest)
            mismatch_found = mismatch_found or is_error
            output(is_error, "")
            output(is_error, f"{xml_name}: {action_id} -> {dest_id} ({start_dest})")
            output(is_error, f"  Dest  : {dest_args}")
            output(is_error, f"  Action: {action_args}")

    return mismatch_found


def are_graphs_valid(data, xml_files):
    """Run every validation pass over *xml_files*, filling *data* as it goes.

    *xml_files* is iterated once per pass, so it should be a re-iterable
    sequence rather than a one-shot iterator.

    Returns:
        The result of are_actions_valid, which is truthy when at least one
        argument mismatch was found.
    """
    for collect in (collect_destinations, collect_start_destination_args):
        collect(data, xml_files)
    return are_actions_valid(data, xml_files)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Check that navigation actions pass the arguments declared "
                    "by the start destination of the graph they target.")
    parser.add_argument("xml_files", metavar="PATH", type=str, nargs="+",
                        help="The paths to the navigation graphs")
    parser.add_argument("--verbose", "-v", action='store_true', default=False,
                        help="Also log actions whose arguments match")
    args = parser.parse_args()

    logging_level = logging.INFO if args.verbose else logging.WARNING
    logging.basicConfig(format='%(message)s', level=logging_level)

    # De-duplicate paths (after making them absolute) and process them in a
    # stable order.
    xml_files = sorted({os.path.abspath(path) for path in args.xml_files})

    data = Data()
    # are_graphs_valid returns a truthy value when a mismatch was found;
    # report that through the exit status instead of only via the log.
    mismatch_found = are_graphs_valid(data, xml_files)
    sys.exit(1 if mismatch_found else 0)
