import numpy as np
from math import sqrt
 
 
class A_star:
    """Grid path planner based on weighted A* with 4-connected moves.

    Cells are addressed either as ``[row, col]`` pairs or as 1-based,
    column-major linear indices (``index = col * n_rows + row + 1``); see
    ``sub2index`` / ``index2sub``.  ``matrix[row][col]`` is truthy for a free
    cell and falsy for a wall.

    Note: ``matrix_width`` holds the number of ROWS (``len(matrix)``) and
    ``matrix_length`` the number of COLUMNS (``len(matrix[0])``); the names are
    kept for backward compatibility.
    """

    # Neighbour offsets (drow, dcol) in the order findFValue reports them:
    # left, right, up, down.  a_star pairs these positions with the
    # back-pointers ["R", "L", "D", "U"] (the move that leads back to the parent).
    _NEIGHBOUR_STEPS = ((0, -1), (0, 1), (-1, 0), (1, 0))
    # Back-pointer -> (drow, dcol) that moves from a cell to its parent.
    _POINTER_STEPS = {"L": (0, -1), "R": (0, 1), "U": (-1, 0), "D": (1, 0)}

    def __init__(self, start_point=None, end_point=None, matrix=None, weights=1.1, Corner_amend=1, step=10000):
        """Set up a search problem.

        Args:
            start_point: ``[row, col]`` or 1-based linear index; default ``[0, 0]``.
            end_point: ``[row, col]`` or 1-based linear index; default ``[2, 2]``.
            matrix: 2-D grid, truthy = free, falsy = wall; default 3x3 all free.
            weights: multiplier ``w`` on the heuristic in ``f = g + w * h``.
            Corner_amend: when ``1``, ties are broken to reduce path corners.
            step: maximum number of node expansions per ``a_star()`` call.

        Raises:
            ValueError: start or end lies outside the grid.
            SystemExit: start or end cell is a wall.
        """
        # None sentinels avoid sharing mutable default objects between instances.
        if start_point is None:
            start_point = [0, 0]
        if end_point is None:
            end_point = [2, 2]
        if matrix is None:
            matrix = [[True for _ in range(3)] for _ in range(3)]
        self.matrix = matrix
        self.weights = weights
        self.corner_amend = Corner_amend
        self.matrix_length = len(self.matrix[0])
        self.matrix_width = len(self.matrix)
        self.start_point = self._checked_index(start_point)
        self.end_point = self._checked_index(end_point)
        self.startx, self.starty = self.index2sub(self.start_point)
        self.endx, self.endy = self.index2sub(self.end_point)
        self.step = step
        if not (self.matrix[self.startx][self.starty] and self.matrix[self.endx][self.endy]):
            # raise directly: the builtin exit() is an interactive helper that
            # also closes sys.stdin and is missing under `python -S`.
            raise SystemExit("start or end is wall")

    def _checked_index(self, point):
        """Return the 1-based linear index of *point*, validating it lies on the grid.

        *point* is either an integer linear index or a ``[row, col]`` pair.
        Without this check, out-of-range coordinates silently wrap onto another cell.
        """
        if isinstance(point, (int, np.integer)):
            index = int(point)
            if not 1 <= index <= self.matrix_width * self.matrix_length:
                raise ValueError(f"point {point!r} is outside the grid")
            return index
        row, col = point[0], point[1]
        if not (0 <= row < self.matrix_width and 0 <= col < self.matrix_length):
            raise ValueError(f"point {point!r} is outside the grid")
        return self.sub2index(point)

    def a_star(self):
        """Run the search from the start cell to the end cell.

        Returns:
            A list of ``[row, col]`` cells from the end cell back to the start
            cell (both inclusive), or ``None`` when the open set runs empty
            (the end is unreachable).

        Raises:
            SystemExit("Can't"): the expansion budget ``self.step`` ran out
            before the end cell entered the open set.
        """
        field = self.chansform_field()
        fieldpointers = self.chansform_fieldpointers()

        setopen = [self.start_point]
        setopencosts = [0]
        # Use the real heuristic for the start (not inf) so that weights == 0
        # does not yield 0 * inf == nan.
        setopenheuristics = [sqrt((self.endy - self.starty) ** 2 + (self.endx - self.startx) ** 2)]
        setclosed = []
        setclosedcosts = []
        # Back-pointer stored for each neighbour slot returned by findFValue.
        movementdirections = ["R", "L", "D", "U"]

        # Local budget: self.step is not consumed, so repeated calls behave alike.
        budget = self.step
        while self.end_point not in setopen and budget > 0:
            budget -= 1
            # f = g + w * h, element-wise over the open set.
            total_costs = [g + self.weights * h for g, h in zip(setopencosts, setopenheuristics)]
            temp = min(total_costs)
            ii = total_costs.index(temp)
            if setopen[ii] != self.start_point and self.corner_amend == 1:
                ii = self.Path_optimization(temp, ii, fieldpointers, setopen, setopencosts, setopenheuristics)

            costs, heuristics, posinds = self.findFValue(setopen[ii], setopencosts[ii], field, self.end_point)

            # Move the selected node from the open set to the closed set.
            setclosed.append(setopen.pop(ii))
            setclosedcosts.append(setopencosts.pop(ii))
            setopenheuristics.pop(ii)

            for jj, (cost, heuristic, posind) in enumerate(zip(costs, heuristics, posinds)):
                if cost == float("inf"):
                    continue  # off-grid neighbour or wall
                row, col = self.index2sub(posind)
                if posind in setopen:
                    k = setopen.index(posind)
                    if setopencosts[k] > cost:
                        setopencosts[k] = cost
                        setopenheuristics[k] = heuristic
                        fieldpointers[row][col] = movementdirections[jj]
                elif posind in setclosed:
                    k = setclosed.index(posind)
                    if setclosedcosts[k] > cost:
                        setclosedcosts[k] = cost
                        fieldpointers[row][col] = movementdirections[jj]
                else:
                    fieldpointers[row][col] = movementdirections[jj]
                    setopen.append(posind)
                    setopencosts.append(cost)
                    setopenheuristics.append(heuristic)

            if not setopen:
                return None
        if self.end_point in setopen:
            return self.findWayBack(self.end_point, fieldpointers)
        raise SystemExit("Can't")

    def sub2index(self, array):
        """Convert ``[row, col]`` to a 1-based, column-major linear index."""
        return int(array[1] * self.matrix_width + array[0] + 1)

    def Path_optimization(self, temp, ii, fieldpointers, setOpen, setOpenCosts, setOpenHeuristics):
        """Tie-break the node selection to reduce corners in the path.

        If the selected open node ``setOpen[ii]`` turns relative to its
        parent's incoming direction, and the cell that continues the parent's
        straight line is in the open set with the same f-cost ``temp``, return
        that cell's index instead.  Otherwise return ``ii`` unchanged.  Either
        choice has the minimal f-cost, so only tie-breaking is affected.
        """
        row, col = self.index2sub(setOpen[ii])
        pointer = fieldpointers[row][col]
        step = self._POINTER_STEPS.get(pointer)
        if step is None:
            return ii
        prow, pcol = row + step[0], col + step[1]
        if self.sub2index([prow, pcol]) == self.start_point:
            return ii
        parent_pointer = fieldpointers[prow][pcol]
        if parent_pointer == pointer:
            return ii  # already going straight
        pstep = self._POINTER_STEPS.get(parent_pointer)
        if pstep is None:
            return ii
        # Straight continuation: one more step away from the grandparent.
        erow, ecol = prow - pstep[0], pcol - pstep[1]
        # Check in row/col space so that the cell cannot wrap into another column.
        if not (0 <= erow < self.matrix_width and 0 <= ecol < self.matrix_length):
            return ii
        expected = self.sub2index([erow, ecol])
        if expected not in setOpen:
            return ii
        candidate = setOpen.index(expected)
        now_cost = setOpenCosts[candidate] + self.weights * setOpenHeuristics[candidate]
        return candidate if temp == now_cost else ii

    def findFValue(self, posind, costsofar, field, goalind):
        """Evaluate the four neighbours of *posind*.

        Returns ``[cost, heuristic, posinds]``, each a 4-item list in the order
        left, right, up, down.  ``cost`` is ``costsofar + field[r][c]`` and
        ``heuristic`` is the Euclidean distance to *goalind*.  Off-grid
        neighbours get ``inf`` for both and a placeholder index (that of
        ``[1, 1]``) which callers must ignore.
        """
        row, col = self.index2sub(posind)
        goalrow, goalcol = self.index2sub(goalind)
        cost = [float("inf")] * 4
        heuristic = [float("inf")] * 4
        neighbours = [[1, 1]] * 4  # placeholder for off-grid slots
        for k, (drow, dcol) in enumerate(self._NEIGHBOUR_STEPS):
            r, c = row + drow, col + dcol
            if 0 <= r < self.matrix_width and 0 <= c < self.matrix_length:
                neighbours[k] = [r, c]
                heuristic[k] = sqrt((goalcol - c) ** 2 + (goalrow - r) ** 2)
                cost[k] = costsofar + field[r][c]
        posinds = [self.sub2index(p) for p in neighbours]
        return [cost, heuristic, posinds]

    def chansform_field(self):
        """Build the per-cell entry-cost grid.

        The grid is a ``(rows, cols)`` float array: 1 for free cells, ``inf``
        for walls, and 0 for the start and end cells.
        """
        field = np.ones((self.matrix_width, self.matrix_length))
        for i in range(self.matrix_width):
            for j in range(self.matrix_length):
                if not self.matrix[i][j]:
                    field[i][j] = float("inf")
        field[self.startx][self.starty] = 0
        field[self.endx][self.endy] = 0
        return field

    def chansform_fieldpointers(self):
        """Build the back-pointer grid.

        Each cell starts as 1 if free or ``inf`` if a wall.  The end cell is
        marked "G" and the start cell "S".
        """
        fieldpointers = [
            [1 if self.matrix[i][j] else float("inf") for j in range(self.matrix_length)]
            for i in range(self.matrix_width)
        ]
        # Mark "G" first so that "S" wins when start == end; findWayBack then
        # returns the single cell instead of failing.
        fieldpointers[self.endx][self.endy] = "G"
        fieldpointers[self.startx][self.starty] = "S"
        return fieldpointers

    def findWayBack(self, goalposind, fieldpointers):
        """Follow back-pointers from *goalposind* to the "S" cell.

        Returns the ``[row, col]`` cells from goal to start, both inclusive.
        Prints an error and raises ``SystemExit`` on a cell without a valid pointer.
        """
        x, y = self.index2sub(goalposind)
        path = [[x, y]]
        while fieldpointers[x][y] != "S":
            step = self._POINTER_STEPS.get(fieldpointers[x][y])
            if step is None:
                print("Error Find way back")
                raise SystemExit()
            x, y = x + step[0], y + step[1]
            path.append([x, y])
        return path

    def index2sub(self, posind):
        """Convert a 1-based, column-major linear index to ``[row, col]``."""
        row = (posind - 1) % self.matrix_width
        col = (posind - 1) // self.matrix_width
        return [int(row), int(col)]
 
 
if __name__ == "__main__":
    # Demo: plan on the default 3x3 open grid and show the path (end -> start);
    # previously the result was computed and silently discarded.
    print(A_star().a_star())