from hcs.data.LinkedList import LinkedList
from hcs.data.Face import FaceData
from hcs.algo.SplitHull import split_hull
from hcs.algo.common import is_face_visible
from hcs.algo.common import get_visible_faces_recursive
from hcs.algo.common import triangulate_edges_with_node
from hcs.algo.common import get_neighbour_points
from hcs.algo.common import count_types
from hcs.algo.SplitHull import SHNodeData


# Module metadata.
__author__ = "Gernot WALZL"
__date__ = "2010-02-25"


def subset_conflict_walk(mesh_conv_q, mesh_conv_p, animator=None):
    """
    Convenience wrapper: run SubsetConflictWalk on the given meshes.

    mesh_conv_q -- conv(Q), the hull of the subset; its faces receive
                   the conflict lists.
    mesh_conv_p -- conv(P), the hull of the full point set.
    animator    -- optional visualisation helper (None disables it).

    Returns None; the computed conflicts are stored on the node and
    face data objects.
    """
    walker = SubsetConflictWalk(mesh_conv_q, mesh_conv_p, animator)
    walker.run()


def rand_multi_split(mesh, animator=None):
    """
    Convenience wrapper: split ``mesh`` into one hull per node type.

    Runs RandMultiSplit with result-consistency checking enabled, so an
    Exception is raised if one of the produced meshes is inconsistent.

    Returns the list of result meshes; index i holds the hull extracted
    for node type i+1.
    """
    splitter = RandMultiSplit(mesh, animator)
    return splitter.run(check_result_consistency=True)



class RMSNodeData(SHNodeData):
    """
    Per-node bookkeeping for RandMultiSplit and SubsetConflictWalk.

    A single instance is shared between a node of P and its copy in
    conv(S)/conv(Q), so either side can reach the other through
    node_p / node_q.
    """

    def __init__(self):
        SHNodeData.__init__(self)
        # Marks nodes already queued by SubsetConflictWalk.
        self.visited = False
        # Face currently in conflict with this node, and this node's
        # element inside that face's nodes_conflict list.
        self.face_conflict = None
        self.face_nodes_conflict_list_element = None
        # The node in the full point set P.
        self.node_p = None
        # The node's copy inside conv(Q); None while not inserted.
        self.node_q = None
        # Real node type, saved while compute_sample relabels the
        # sample nodes with type 0.
        self.type_before = 0



class RMSFaceData(FaceData):
    """
    Face data extended with the list of nodes in conflict with the face.
    """

    def __init__(self):
        FaceData.__init__(self)
        # Nodes whose current conflict face is this face (LinkedList).
        self.nodes_conflict = LinkedList()



class SubsetConflictWalk:
    """
    For the points of P reachable from Q through get_neighbour_points,
    find a face of conv(Q) (Q a subset of P) that is in conflict with
    (visible from) the point. The result is stored in
    node.data.face_conflict, and the node is appended to the
    face.data.nodes_conflict list of that face.

    1. Let queue be a queue with the elements in Q.
    2. While queue != 0:
       (a) Let p be the next point in queue.
       (b) If p not in Q, insert p into conv(Q), using a previously
           computed conflict facet f_p for p as a starting point.
       (c) For each neighbor q in Gamma_P(p), find a conflict facet f~_q
           in conv(Q u p), using Claim 3.3.
       (d) Using the f~_q's, find conflict facets f_q in F[Q] for
           Gamma_P(p). If q in Gamma_P(p) has not been encountered yet,
           insert it into queue.
    """

    def __init__(self, mesh_conv_q, mesh_conv_p, animator=None):
        # conv(Q): temporarily modified during run(); its faces receive
        # the conflict lists.
        self.mesh_conv_q = mesh_conv_q
        # conv(P): the hull of the full point set (not read by run()).
        self.mesh_conv_p = mesh_conv_p
        # Optional visualisation helper; None disables all animation.
        self.animator = animator


    def remove_visible_faces(self, face_hint, node_visible):
        """
        Remove from conv(Q) every face visible from node_visible.

        The visible faces are collected by get_visible_faces_recursive,
        starting at face_hint. Their highlight flag is cleared.

        Returns the LinkedList of removed faces. The face objects are
        kept, so insert_faces() can put them back later; the caller is
        responsible for destroying the list.
        """
        faces_visible = LinkedList()
        get_visible_faces_recursive(faces_visible, face_hint, node_visible)
        for face in faces_visible:
            self.mesh_conv_q.remove_face(face)
            face.highlight = False
        return faces_visible


    def insert_point(self, node_p):
        """
        Add a copy of node_p to conv(Q).

        The copy shares node_p's data object, so afterwards
        node_p.data.node_q refers to the copy. No faces are created here.

        O(1)
        """
        node_q = node_p.clone()
        node_q.data = node_p.data
        self.mesh_conv_q.add_node(node_q)
        node_q.data.node_q = node_q


    def remove_point(self, node_p):
        """
        Remove node_p's copy from conv(Q) again.

        Before the removal, all nodes in the conflict lists of the faces
        incident to the copy are gathered (faces without data are
        skipped). Afterwards node_p.data.node_q is reset to None.

        Returns a new LinkedList with the gathered conflict nodes.
        """
        node_q = node_p.data.node_q
        all_nodes_conflict = LinkedList()
        for face in node_q.faces:
            if (face.data):
                for node in face.data.nodes_conflict:
                    all_nodes_conflict.add(node)
        self.mesh_conv_q.remove_node(node_q)
        node_q.data.node_q = None
        return all_nodes_conflict


    def is_left_face(self, face, edge):
        """
        Return True if the directed edge (edge.node_a -> edge.node_b)
        occurs in the same direction in the face's node cycle
        a -> b -> c -> a. insert_faces() uses this to decide whether the
        face belongs on the edge's face_l or face_r side.

        O(1)
        """
        if (edge.node_a is face.node_a and
                edge.node_b is face.node_b):
            return True
        if (edge.node_a is face.node_b and
                edge.node_b is face.node_c):
            return True
        if (edge.node_a is face.node_c and
                edge.node_b is face.node_a):
            return True
        return False


    def insert_faces(self, faces_removed):
        """
        Put faces previously removed by remove_visible_faces() back into
        conv(Q).

        Each of the three edges of every face is re-linked to the face on
        the matching side (face_l or face_r). Afterwards the mesh's
        edges_no_faces and edges_open lists are cleared.

        O(len(faces_removed))
        """
        for face in faces_removed:
            if (self.is_left_face(face, face.edge_a)):
                face.edge_a.face_l = face
            else:
                face.edge_a.face_r = face
            if (self.is_left_face(face, face.edge_b)):
                face.edge_b.face_l = face
            else:
                face.edge_b.face_r = face
            if (self.is_left_face(face, face.edge_c)):
                face.edge_c.face_l = face
            else:
                face.edge_c.face_r = face
            self.mesh_conv_q.add_face(face)
        self.mesh_conv_q.edges_no_faces.remove_all()
        self.mesh_conv_q.edges_open.remove_all()


    def update_conflicts(self, faces_before, nodes_conflict):
        """
        Give every node in nodes_conflict a new conflict face and append
        the node to that face's conflict list.

        The faces in faces_before are tried first. If none of them is
        visible from the node, the faces on the face_l side of conv(Q)'s
        current open edges are tried.

        NOTE(review): if neither search finds a visible face,
        node.data.face_conflict is left unchanged. It presumably still
        refers to a face just removed by remove_point(); confirm this
        cannot happen.

        O(len(faces_before) * len(nodes_conflict))
        """
        for node in nodes_conflict:
            conflict_found = False
            for face in faces_before:
                if (is_face_visible(face, node)):
                    node.data.face_conflict = face
                    node.data.face_nodes_conflict_list_element = \
                        face.data.nodes_conflict.add(node)
                    conflict_found = True
                    break
            if (not conflict_found):
                # Fallback: faces adjacent to the hole left by
                # remove_point().
                for edge in self.mesh_conv_q.edges_open:
                    face = edge.face_l
                    if (is_face_visible(face, node)):
                        node.data.face_conflict = face
                        node.data.face_nodes_conflict_list_element = \
                            face.data.nodes_conflict.add(node)
                        break


    def run(self):
        """
        Execute the walk.

        Points that are not in Q are inserted into conv(Q) only
        temporarily. In step (d) they are removed again, and the faces
        removed for them are put back.

        O(len(self.mesh_conv_p.nodes))
        """
        if (self.animator):
            self.animator.set_text('SubsetConflictWalk')
            self.animator.wait()

        # 1. Let queue be a queue with the elements in Q.
        queue = LinkedList()
        for node in self.mesh_conv_q.nodes:
            queue.add(node)
            # Presumably the data object is shared with the matching node
            # of P (see RandMultiSplit.compute_sample), so that node counts
            # as visited too.
            node.data.visited = True

        # 2. While queue != 0:
        element = queue.first_element
        while(element):

        #    (a) Let p be the next point in queue.
            node_p = element.data

        #    (b) If p not in Q, insert p into conv(Q), using a previously
        #        computed conflict facet f_p for p as a starting point.
            node_inserted = False
            if (node_p.data.node_q is None):
                self.insert_point(node_p)

                if (self.animator):
                    node_p.data.node_q.highlight = True
                    self.animator.wait()

                # Remove the faces p sees, starting at its conflict face.
                # faces_removed is kept for restoring conv(Q) in step (d).
                faces_removed = self.remove_visible_faces(
                    node_p.data.face_conflict, node_p)

                if (self.animator):
                    self.animator.wait()

                # Triangulate the open edges with p's copy.
                edges_open = self.mesh_conv_q.handle_edges_open()
                triangulate_edges_with_node(self.mesh_conv_q,
                    edges_open, node_p.data.node_q)
                edges_open.destroy()
                node_inserted = True

                if (self.animator):
                    self.animator.wait()
            elif (self.animator):
                node_p.data.node_q.highlight = True
                self.animator.wait()

        #    (c) For each neighbor q in Gamma_P(p), find a conflict facet f~_q
        #        in conv(Q u p), using Claim 3.3.
            # node_p may be a conv(Q) node (initial queue entries) or a node
            # of P; data.node_p refers to the node of P in both cases.
            nodes_neighbour = get_neighbour_points(node_p.data.node_p)

            if (self.animator):
                # Display only: neighbours without a conv(Q) copy get a
                # temporary one so they can be highlighted. These copies
                # are removed again after the conflict search below.
                nodes_ani_added = LinkedList()
                for node in nodes_neighbour:
                    if (not node.data.node_q):
                        node_q = node.clone()
                        node_q.data = node.data
                        self.mesh_conv_q.add_node(node_q)
                        node.data.node_q = node_q
                        nodes_ani_added.add(node_q)
                    node.data.node_q.highlight = True
                self.animator.wait()
                for node in nodes_neighbour:
                    node.data.node_q.highlight = False

            for node in nodes_neighbour:
                if (not node.data.face_conflict):
                    # Search only the faces incident to p's copy.
                    for face in node_p.data.node_q.faces:
                        node_check = node
                        if (node.data.node_q):
                            node_check = node.data.node_q
                        if (is_face_visible(face, node_check)):
                            node.data.face_conflict = face
                            # Presumably faces created by the triangulation
                            # in (b) have no data object yet.
                            if (not face.data):
                                face.data = RMSFaceData()
                            node.data.face_nodes_conflict_list_element = \
                                face.data.nodes_conflict.add(node)

                            if (self.animator):
                                node.data.node_q.highlight = True
                                node.data.face_conflict.highlight = True
                                self.animator.wait()
                                node.data.node_q.highlight = False
                                node.data.face_conflict.highlight = False

                            break

            if (self.animator):
                node_p.data.node_q.highlight = False
                for node in nodes_ani_added:
                    self.mesh_conv_q.remove_node(node)
                    node.data.node_q = None
                nodes_ani_added.remove_all()

        #    (d) Using the f~_q's, find conflict facets f_q in F[Q] for
        #        Gamma_P(p). If q in Gamma_P(p) has not been encountered yet,
        #        insert it into queue.
            if (node_inserted):

                if (self.animator):
                    node_p.data.node_q.highlight = True
                    self.animator.wait()

                # Undo the temporary insertion of p:
                # - remove its copy;
                # - move the conflicts found on its faces to the removed
                #   faces (or faces next to the open edges);
                # - put the removed faces back into conv(Q).
                nodes_conflict = self.remove_point(node_p)

                if (self.animator):
                    self.animator.wait()

                self.update_conflicts(faces_removed, nodes_conflict)
                self.insert_faces(faces_removed)
                faces_removed.destroy()
                nodes_conflict.destroy()

                if (self.animator):
                    self.animator.wait()

            for node in nodes_neighbour:
                if (not node.data.visited):
                    queue.add(node)
                    node.data.visited = True
            nodes_neighbour.destroy()

            # Drop the processed entry and continue with the new head.
            queue.remove_element(element)
            element = queue.first_element



class RandMultiSplit:
    """
    Split ``mesh``, whose nodes carry types 1..num_types (num_types as
    returned by count_types), into one hull per type.

    1. Pick a random sample S of P of size n/chi and compute
       conv(S).
    2. For each p in P, determine a facet f_p in F[S] in conflict
       with p.
    3. For each color i:
       (a) Insert all points of C_i into conv(S).
       (b) Extract conv(C_i) from conv(C_i u S).

    After run() the list of result meshes is also available as
    self.results.
    """

    def __init__(self, mesh, animator=None):
        self.mesh = mesh
        # Optional visualisation helper; None disables all animation.
        self.animator = animator
        # Number of node types; set by run().
        self.num_types = 0
        # List of result meshes (one per type); set by run().
        self.results = None


    def compute_sample(self, num_nodes):
        """
        Build conv(S) from the first ``num_nodes`` nodes of self.mesh.

        Every node of a clone of self.mesh gets a fresh RMSNodeData that
        is shared with its origin node in self.mesh. The sample nodes are
        temporarily given type 0 so that split_hull(result, 0) extracts
        their hull. Afterwards their real types are restored, and each
        node's data.node_q refers to itself.

        NOTE(review): the sample is the first ``num_nodes`` nodes in mesh
        order, so any randomness must come from the caller's node
        ordering; run() relies on the same ordering.

        O(len(self.mesh.nodes))
        """
        result = self.mesh.clone()
        for i, node in enumerate(result.nodes):
            node_p = node.data.origin
            node.data = RMSNodeData()
            node.data.node_p = node_p
            node_p.data = node.data
            if (i < num_nodes):
                # Remember the real type; type 0 marks the sample.
                node.data.type_before = node.type
                node.type = 0
        split_hull(result, 0)
        for face in result.faces:
            face.data = RMSFaceData()
        for node in result.nodes:
            node.type = node.data.type_before
            node.data.node_q = node
        return result


    def clone_mesh(self, mesh):
        """
        Clone ``mesh`` and replace the data objects of the clone.

        Nodes get fresh SHNodeData. Faces get fresh RMSFaceData with an
        empty conflict list. Both keep the ``origin`` reference presumably
        set by mesh.clone().

        O(len(mesh.nodes))
        """
        result = mesh.clone()
        for node in result.nodes:
            node_origin = node.data.origin
            node.data = SHNodeData()
            node.data.origin = node_origin
        for face in result.faces:
            face_origin = face.data.origin
            face.data = RMSFaceData()
            face.data.origin = face_origin
        return result


    def insert_point(self, mesh, node):
        """
        Add a copy of ``node`` to ``mesh`` and return the copy.

        The copy gets fresh SHNodeData whose origin is node.data.origin.
        No faces are created here.

        O(1)
        """
        node_origin = node.data.origin
        node_insert = node.clone()
        node_insert.data = SHNodeData()
        node_insert.data.origin = node_origin
        mesh.add_node(node_insert)
        return node_insert


    def init_conflicts(self, meshes):
        """
        Move the conflicts computed on conv(S) into the per-type copies.

        For every face of meshes[i], the nodes of type i+1 are taken from
        the conflict list of the face's origin (a face of conv(S)). They
        are removed from that list and appended to the copy's own list.
        Each node's face_conflict and face_nodes_conflict_list_element are
        updated to the new face and list element.

        O(len(self.mesh.nodes))
        """
        for i in range(0, self.num_types):
            mesh = meshes[i]
            nodes_type = i+1
            for face in mesh.faces:
                original_nodes_conflict = face.data.origin.data.nodes_conflict
                # Iterate over a snapshot: matching elements are removed
                # from this linked list inside the loop.
                for node in list(original_nodes_conflict):
                    if (node.type == nodes_type):
                        # Remove via the old element before replacing it.
                        original_nodes_conflict.remove_element(
                            node.data.face_nodes_conflict_list_element)
                        node.data.face_conflict = face
                        node.data.face_nodes_conflict_list_element = \
                            face.data.nodes_conflict.add(node)


    def remove_visible_faces(self, mesh, face_hint, node_visible):
        """
        Remove from ``mesh`` every face visible from node_visible.

        The faces are collected by get_visible_faces_recursive, starting
        at face_hint, and their highlight flag is cleared.

        Returns a new LinkedList with all nodes that were in the conflict
        lists of the removed faces.
        """
        nodes_conflict = LinkedList()
        faces_visible = LinkedList()
        get_visible_faces_recursive(faces_visible, face_hint, node_visible)
        for face in faces_visible:
            mesh.remove_face(face)
            face.highlight = False
            if (face.data):
                for node in face.data.nodes_conflict:
                    nodes_conflict.add(node)
        faces_visible.destroy()
        return nodes_conflict


    def update_conflicts(self, faces, nodes_conflict):
        """
        Assign each node in nodes_conflict the first face in ``faces``
        that is visible from it, and append the node to that face's
        conflict list.

        Faces without data get a fresh RMSFaceData. Nodes that see none
        of the faces keep their previous face_conflict.

        O(len(faces) * len(nodes_conflict))
        """
        for node in nodes_conflict:
            for face in faces:
                if (is_face_visible(face, node)):
                    node.data.face_conflict = face
                    if (not face.data):
                        face.data = RMSFaceData()
                    node.data.face_nodes_conflict_list_element = \
                        face.data.nodes_conflict.add(node)
                    break


    def run(self, check_result_consistency=False):
        """
        Execute the algorithm.

        check_result_consistency -- if True, raise an Exception when a
                                    result mesh is not consistent.

        Returns a list with one mesh per node type; index i holds the
        hull extracted for type i+1. The list is also stored in
        self.results.

        O(len(self.mesh.nodes))
        """
        if (self.animator):
            self.animator.set_text('RandMultiSplit')
            self.animator.wait()

        # 1. Pick a random sample S of P of size n/chi and compute
        #    conv(S).
        # NOTE(review): count_types returning 0 would raise
        # ZeroDivisionError here.
        self.num_types = count_types(self.mesh)
        # Presumably 4 is the smallest point count with a non-degenerate
        # 3D hull (a tetrahedron).
        size_sample = max(int(len(self.mesh.nodes) / self.num_types), 4)
        mesh_s = self.compute_sample(size_sample)
        # One independent copy of conv(S) per type; C_i goes into
        # results[i-1].
        results = [self.clone_mesh(mesh_s) for _ in range(self.num_types)]

        if (self.animator):
            for node in mesh_s.nodes:
                node.data.node_p.highlight = True
            self.animator.wait()
            for node in mesh_s.nodes:
                node.data.node_p.highlight = False
            self.animator.set_mesh(mesh_s)
            self.animator.wait()

        # 2. For each p in P, determine a facet f_p in F[S] in conflict
        #    with p.
        subset_conflict_walk(mesh_s, self.mesh, self.animator)

        if (self.animator):
            self.animator.set_text('RandMultiSplit')
            self.animator.halt_skip()

        # 3. For each color i:
        #    (a) Insert all points of C_i into conv(S).
        self.init_conflicts(results)
        # The first len(mesh_s.nodes) nodes of self.mesh form the sample
        # (assuming clone() preserves node order). They are already part
        # of every result mesh.
        num_nodes_sample = len(mesh_s.nodes)
        for i, node in enumerate(self.mesh.nodes):
            if (i < num_nodes_sample):
                continue
            mesh_current = results[node.type-1]

            if (self.animator):
                self.animator.set_mesh(mesh_current)
                self.animator.wait()

            node_inserted = self.insert_point(mesh_current, node)
            face_hint = node.data.face_conflict
            nodes_conflict = self.remove_visible_faces(mesh_current,
                face_hint, node)
            mesh_current.remove_edges_no_faces()
            edges_open = mesh_current.handle_edges_open()
            triangulate_edges_with_node(mesh_current,
                edges_open, node_inserted)
            edges_open.destroy()
            self.update_conflicts(node_inserted.faces, nodes_conflict)
            nodes_conflict.destroy()

            if (self.animator):
                node_inserted.highlight = True
                self.animator.wait()
                node_inserted.highlight = False

        if (self.animator):
            self.animator.halt_skip()

        #    (b) Extract conv(C_i) from conv(C_i u S).
        for i in range(0, self.num_types):

            if (self.animator):
                self.animator.set_mesh(results[i])
                self.animator.wait()

            split_hull(results[i], i+1)
            if (check_result_consistency):
                if (not results[i].is_consistent()):
                    raise Exception('Result C_'+str(i+1)+' is inconsistent.')

            if (self.animator):
                self.animator.wait()

        self.results = results
        return results


