Modified A* - A Working Model


Introduction

In this article, we will go over an implementation of the Modified A* algorithm in Python that solves the 8-Puzzle problem, and give a general idea of how the same approach can be applied to other problems. If you are not sure what the A* algorithm is, or what the modification is, I recommend going through the topic here.

Prerequisites: Python (Jupyter notebook) and the A* algorithm; an understanding of copy/deep copy also helps.

Algorithms

Below is the algorithm for Modified A*. You will find it very similar to the algorithm for plain A*, which is expected: the main modification is the added depth parameter, which controls how many levels below the current node are expanded before the next node is taken from the fringe. Note that a depth of 1 is equivalent to the normal A* algorithm.

mod_A(state, depth):
    root_node = initialiseRootNode(state)
    iterated_node = root_node
    fringe = PriorityQueue()
    while not isGoalState(iterated_node):
        # returns all new child nodes up to the given depth
        child_nodes = mod_exp(iterated_node, depth)
        for child_node in child_nodes:
            cost_val = getCost(child_node)       # path cost + heuristic
            fringe.put((cost_val, child_node))
        if fringe.size() > 0:
            _, iterated_node = fringe.get()      # node with the lowest cost
        else:
            break                                # nothing left to explore: no solution
    return iterated_node

Algorithm for mod_exp, a breadth-first expansion limited to the given depth:

mod_exp(node, depth):
    child_list = expand(node)    # expand generates the immediate child nodes
    # two queues over the same children: temp1 collects every node to return,
    # temp2 holds the current level that still has to be expanded
    temp1 = make_fifo_queue(child_list)
    temp2 = make_fifo_queue(child_list)
    d = depth - 1
    while d > 0:
        next_level = []          # children of *all* nodes in the current level
        while temp2.size() != 0:
            C = temp2.get()
            grandchildren = expand(C)
            for G in grandchildren:
                temp1.put(G)
            next_level.extend(grandchildren)
        d -= 1
        temp2 = make_fifo_queue(next_level)
    return temp1
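Since the goal is also to give a general idea of how this can be reused for other problems, here is a minimal problem-agnostic sketch of both pseudocode functions. It is an illustration under stated assumptions, not the article's implementation: states must be hashable, every move costs 1, the caller supplies hypothetical successors(), heuristic() and is_goal() callbacks, and states are marked as explored when they are generated, as the article's code does. It uses heapq with a counter as tie-breaker instead of queue.PriorityQueue, and plain per-level lists instead of the two FIFO queues.

```python
import heapq
from itertools import count


def mod_exp(start, depth, expand):
    """Collect the new nodes up to `depth` levels below `start`, level by level.

    `expand(node)` must return only children that have not been seen before.
    """
    found = []
    level = [start]
    for _ in range(depth):
        next_level = []
        for node in level:              # expand *every* node of the current level
            next_level.extend(expand(node))
        found.extend(next_level)
        level = next_level
    return found


def mod_a(start, depth, successors, heuristic, is_goal):
    """Modified A*: returns the path from start to a goal, or None if none is found."""
    explored = {start}
    parent = {start: None}
    g = {start: 0}                      # path cost, one per move
    tie = count()                       # tie-breaker so heapq never compares states

    def expand(state):
        children = []
        for child in successors(state):
            if child not in explored:   # marked explored when generated
                explored.add(child)
                parent[child] = state
                g[child] = g[state] + 1
                children.append(child)
        return children

    fringe = []
    current = start
    while not is_goal(current):
        for child in mod_exp(current, depth, expand):
            heapq.heappush(fringe, (g[child] + heuristic(child), next(tie), child))
        if not fringe:
            return None                 # every reachable state tried, no goal
        current = heapq.heappop(fringe)[2]

    path = []
    while current is not None:          # walk parent links back to the start
        path.append(current)
        current = parent[current]
    return path[::-1]


# Toy problem: walk along the integers 0..20 from 0 to 7, one step at a time.
def successors(s):
    return [n for n in (s - 1, s + 1) if 0 <= n <= 20]


print(mod_a(0, 3, successors, lambda s: abs(7 - s), lambda s: s == 7))
# [0, 1, 2, 3, 4, 5, 6, 7]
```

With depth=1, mod_exp() only returns the immediate children, which is the plain A* case mentioned above. On a binary tree where node n has children 2n and 2n+1, mod_exp(1, 2, ...) returns both full levels, [2, 3, 4, 5, 6, 7], not just the children of the last node expanded.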

The Code

Now that the core algorithm is in place, we can get into the nitty-gritty of implementing it for the 8-Puzzle problem.

Setting Up Of The Problem

Unsurprisingly, the first thing we have to tackle is setting up the 8-Puzzle problem. How are we going to represent it in our code? Which variables do we need to get the algorithm running? There are many ways to do this, but today we will go through one of the more intuitive ones.

First, let us represent the puzzle. Our case is a very simple one: an 8-Puzzle is a 3x3 grid holding eight tiles, 1 to 8, and one empty slot (e), and a tile adjacent to the empty slot can be slid into it. We represent a state with a simple dictionary of 9 entries, filled from a shuffled list ['e',1,2,3,4,5,6,7,8], whose keys 1 to 9 map onto the 3x3 grid. Each key corresponds to a (row, column) cell as follows: 1 - [1,1], 2 - [1,2], 3 - [1,3], 4 - [2,1], 5 - [2,2], 6 - [2,3], 7 - [3,1], 8 - [3,2], 9 - [3,3]. Generating a random state therefore looks like this:

import numpy as np

base_arr = ['e',1,2,3,4,5,6,7,8]
np.random.shuffle(base_arr)
# Keys 1-9 are chosen on purpose: key k also names the tile that belongs at
# that position in the goal state (position 9 holds the empty slot 'e')
state = {
    1: base_arr[0],
    2: base_arr[1],
    3: base_arr[2],
    4: base_arr[3],
    5: base_arr[4],
    6: base_arr[5],
    7: base_arr[6],
    8: base_arr[7],
    9: base_arr[8],
}
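The key-to-cell table above is just integer division by 3. Here is a small sketch of both directions of that mapping; key_to_cell() and cell_to_key() are hypothetical helpers, not part of the article's code, and rows and columns are 1-based to match the table.

```python
def key_to_cell(key):
    """Map a dictionary key 1..9 to its 1-based (row, column) cell."""
    row, col = divmod(key - 1, 3)
    return row + 1, col + 1


def cell_to_key(row, col):
    """Inverse mapping: a 1-based (row, column) cell back to its key."""
    return 3 * (row - 1) + col


print(key_to_cell(6))     # (2, 3)
print(cell_to_key(3, 1))  # 7
```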

Now that we have looked at state generation, we need to think about validity: not every state produced by the method above is a valid (that is, solvable) shuffle of the 8-Puzzle. Luckily, this problem has already been solved for us by the inversion test. A full explanation of the inversion test is beyond the scope of this article, and readers are encouraged to look it up; in short, on a 3x3 board a state is solvable exactly when the number of inversions (pairs of tiles that appear in decreasing order when the board is read row by row, ignoring the empty slot) is even. The resulting code looks as follows:

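def is_solvable(state):
    # read the board row by row; the empty slot becomes 0 and is ignored below
    a = []
    inv = 0
    for i in range(1,10):
        val = state[i]
        if val == 'e':
            a.append(0)
        else:
            a.append(val)

    # count inversions: pairs of tiles that appear in decreasing order
    for i in range(0,8):
        for j in range(i+1,9):
            if a[i] and a[j] and (a[i] > a[j]):
                inv += 1

    # solvable exactly when the number of inversions is even
    return (inv % 2) == 0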
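Why parity is the right test: a horizontal slide does not change the row-by-row order of the tiles, and a vertical slide moves one tile past exactly two others, so every legal move changes the inversion count by -2, 0 or +2 and keeps its parity. Exchanging two tiles, which no sequence of legal moves can do, flips it. A small sketch; count_inversions() is a hypothetical helper equivalent to the counting loop above.

```python
def count_inversions(state):
    """Count pairs of tiles, read row by row, that are in decreasing order (blank ignored)."""
    tiles = [state[k] for k in range(1, 10) if state[k] != 'e']
    return sum(1 for i in range(len(tiles))
               for j in range(i + 1, len(tiles))
               if tiles[i] > tiles[j])


goal = {k: k for k in range(1, 9)}
goal[9] = 'e'

moved = dict(goal)
moved[6], moved[9] = 'e', 6      # legal move: tile 6 slides down into the blank

swapped = dict(goal)
swapped[1], swapped[2] = 2, 1    # not a legal move: tiles 1 and 2 exchanged

for name, board in [("goal", goal), ("moved", moved), ("swapped", swapped)]:
    n = count_inversions(board)
    print(name, n, "solvable" if n % 2 == 0 else "unsolvable")
# goal 0 solvable
# moved 2 solvable
# swapped 1 unsolvable
```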

Great, now we can create solvable 8-Puzzle states. Next, we need to be able to generate the child states of a given 8-Puzzle state. A simple analysis gives the following insights:

  1. If the empty spot of the 3x3 grid is at the center, 4 movements are possible.

  2. If the empty spot is in the middle of an edge, 3 movements are possible.

  3. If the empty spot is at a corner, 2 movements are possible.

Figure: the possible movements in the three distinct cases.
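The three cases follow directly from the empty slot's row and column, so the hard-coded move lists can also be derived. This is a sketch, assuming the article's convention that each letter names the direction the tile slides (so 'l' needs a tile to the right of the empty slot); possible_moves_from_coords() is a hypothetical name.

```python
def possible_moves_from_coords(empty_key):
    """Derive the move letters from the empty slot's position (key 1..9).

    Letters name the direction the *tile* slides into the empty slot:
    'r' needs a tile to its left, 'l' a tile to its right,
    'u' a tile below it, and 'd' a tile above it.
    """
    row, col = divmod(empty_key - 1, 3)  # 0-based row and column
    moves = []
    if col > 0:
        moves.append('r')
    if col < 2:
        moves.append('l')
    if row < 2:
        moves.append('u')
    if row > 0:
        moves.append('d')
    return moves


for key in range(1, 10):
    print(key, possible_moves_from_coords(key))
```

The lists come out in the same order as the hand-written table, and their lengths are 2 at the corners, 3 on the edges and 4 at the center.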

Now, let us write a function that returns the possible moves when a state is passed to it.

# outputs all possible moves from the given state (a dictionary)
# ['l' - moves tile to the left, 'r' - moves tile to the right, 'u' - moves tile up, 'd' - moves tile down]
def possible_moves(state):
    emt_pos = None
    for i in range(1,10):
        if state[i] == 'e':
            # empty position in the puzzle
            emt_pos = i
            break

    if emt_pos == 1:
        return ['l','u']
    if emt_pos == 2:
        return ['r','l','u']
    if emt_pos == 3:
        return ['r','u']
    if emt_pos == 4:
        return ['l','u','d']
    if emt_pos == 5:
        return ['r','l','u','d']
    if emt_pos == 6:
        return ['r','u','d']
    if emt_pos == 7:
        return ['l','d']
    if emt_pos == 8:
        return ['r','l','d']
    if emt_pos == 9:
        return ['r','d']

So far so good. Next, we will encapsulate the states in a class named Puzzle8 (a Python class name cannot start with a digit, so 8Puzzle is not an option) for clearer and more intuitive usage. We will also add methods to move the tiles on the board, along with other helper methods needed for our use case.

# This sets up the board. Remember: the class does not check which moves are
# legal; it is up to the user to enter valid states and perform only valid moves.
# e - stands for empty
# goal state:
# 1 2 3
# 4 5 6
# 7 8 e
class Puzzle8:

    # class-level variables that describe the whole problem tree, not just one node

    # number of nodes in the graph
    node_num = 1

    # maximum number of nodes in the priority queue at any point (for analysis only)
    max_fringe_len = 1

    # set of explored states (stored as strings), to avoid revisiting states
    explored = set()

    def __init__(self, parent):
        self.parent = parent
        # the board dictionary (number_pos) and depth are assigned by the caller

    # checks if a given state is solvable using the inversion test (specific to the 8-Puzzle)
    def is_solvable(self):
        a = []
        inv = 0
        for i in range(1,10):
            val = self.number_pos[i]
            if val == 'e':
                a.append(0)
            else:
                a.append(val)

        for i in range(0,8):
            for j in range(i+1,9):
                if a[i] and a[j] and (a[i] > a[j]):
                    inv += 1

        return (inv % 2) == 0

    # defines equality for Puzzle8 objects: every position must hold the same tile
    def __eq__(self, other):
        return all(self.number_pos[i] == other.number_pos[i] for i in range(1,10))

    # defines "less than"; needed so objects can be added to heaps and priority
    # queues when their costs tie
    def __lt__(self, other):
        return False

    # outputs all possible moves from the given state
    # ['l' - moves tile to the left, 'r' - moves tile to the right, 'u' - moves tile up, 'd' - moves tile down]
    def possible_moves(self):
        for i in range(1,10):
            if self.number_pos[i] == 'e':
                # empty position in the puzzle
                emt_pos = i
                break

        if emt_pos == 1:
            return ['l','u']
        if emt_pos == 2:
            return ['r','l','u']
        if emt_pos == 3:
            return ['r','u']
        if emt_pos == 4:
            return ['l','u','d']
        if emt_pos == 5:
            return ['r','l','u','d']
        if emt_pos == 6:
            return ['r','u','d']
        if emt_pos == 7:
            return ['l','d']
        if emt_pos == 8:
            return ['r','l','d']
        if emt_pos == 9:
            return ['r','d']

    # Note: the move methods are named after the direction the tile moves; the
    # empty slot always moves the opposite way (a tile moving left means the
    # empty slot moves right).

    # moves the tile to the right of the empty slot into the empty slot
    def mov_left(self):
        for i in range(1,10):
            if self.number_pos[i] == 'e':
                emt_pos = i
                break
        self.number_pos[emt_pos] = self.number_pos[emt_pos+1]
        self.number_pos[emt_pos+1] = 'e'

    # moves the tile to the left of the empty slot into the empty slot
    def mov_right(self):
        for i in range(1,10):
            if self.number_pos[i] == 'e':
                emt_pos = i
                break
        self.number_pos[emt_pos] = self.number_pos[emt_pos-1]
        self.number_pos[emt_pos-1] = 'e'

    # moves the tile below the empty slot into the empty slot
    def mov_up(self):
        for i in range(1,10):
            if self.number_pos[i] == 'e':
                emt_pos = i
                break
        self.number_pos[emt_pos] = self.number_pos[emt_pos+3]
        self.number_pos[emt_pos+3] = 'e'

    # moves the tile above the empty slot into the empty slot
    def mov_down(self):
        for i in range(1,10):
            if self.number_pos[i] == 'e':
                emt_pos = i
                break
        self.number_pos[emt_pos] = self.number_pos[emt_pos-3]
        self.number_pos[emt_pos-3] = 'e'

    # prints the board represented by the object
    def show_board(self):
        print(self.number_pos[1], self.number_pos[2], self.number_pos[3])
        print(self.number_pos[4], self.number_pos[5], self.number_pos[6])
        print(self.number_pos[7], self.number_pos[8], self.number_pos[9])

    # returns a copy of the board dictionary
    def get_board(self):
        return self.number_pos.copy()

    # checks if the state already exists in the explored set
    def exist(self):
        return self.dict_to_str() in Puzzle8.explored

    # converts the board dictionary into a string so that it can be hashed and
    # stored in a set; returns the string
    def dict_to_str(self):
        return ''.join(str(self.number_pos[i]) for i in range(1,10))

    # returns True if the board is in its goal state, False otherwise
    # (if tiles 1-8 are in place, position 9 must hold the empty slot)
    def goal_test(self):
        for i in range(1,9):
            if self.number_pos[i] != i:
                return False
        return True

OK, that was a lot of extra code without explanation. Don't worry, we will go over it now (reading the code and its comments in parallel is highly recommended):

  1. mov_left(), mov_right(), mov_up(), mov_down(): these methods move the tiles. The direction in each name is the direction the tile moves; the empty slot moves the opposite way.

  2. __eq__(), __lt__(): these redefine the equality check (==) and the less-than check (<) for the object. Equality simply compares all tile positions, whereas the less-than method is only needed so that the objects can be added to heaps and priority queues.

  3. exist(), dict_to_str(), explored: why are these seemingly unrelated members listed together? Because they all serve a single purpose: keeping track of explored nodes. explored is a set shared by all Puzzle8 objects that records every state our algorithm has reached. This is not as simple as it seems, because the number of stored states can grow very large, and checking whether a state has already been visited would become slow with a list. That is why we use a set(): sets are hash-based, so membership checks take constant time on average. Dictionaries cannot be stored in a set, so the dict_to_str() method converts the board dictionary into a compact string, and that string is stored in the explored set. The exist() method is later used to check whether a newly created state is already in the explored set.

  4. show_board(): Used to print the state in a 3x3 grid, to make it visually appealing.

  5. get_board(): returns a copy of the dictionary storing the 8 Puzzle state.

  6. goal_test(): checks if the goal has been reached.
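To see the move semantics and the string keys in isolation, here is a small standalone sketch on plain dictionaries. slide() and board_key() are hypothetical helpers written for illustration: they mirror the mov_*() methods and dict_to_str(), but slide() returns a new board instead of modifying one in place, and it assumes the move is legal (check it with possible_moves() first).

```python
# Offset from the empty slot to the tile that slides into it, per move letter
# (the same arithmetic as mov_left / mov_right / mov_up / mov_down).
TILE_OFFSET = {'l': 1, 'r': -1, 'u': 3, 'd': -3}


def slide(board, move):
    """Return a new board with one tile slid into the empty slot; `board` is unchanged.

    Assumes `move` is legal for this board (no bounds checking here).
    """
    new = dict(board)
    blank = next(k for k, v in new.items() if v == 'e')
    src = blank + TILE_OFFSET[move]
    new[blank], new[src] = new[src], 'e'
    return new


def board_key(board):
    """String key for a board, read row by row (like dict_to_str)."""
    return ''.join(str(board[k]) for k in range(1, 10))


goal = {k: k for k in range(1, 9)}
goal[9] = 'e'

after = slide(goal, 'r')            # tile 8 slides right, so the empty slot moves left
print(board_key(after))             # 1234567e8

explored = {board_key(goal)}
back = slide(after, 'l')            # tile 8 slides back to the left
print(board_key(back) in explored)  # True: the search would skip this state
```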

Heuristics

Our next task is to decide on heuristics. Here we will test two different heuristics:

  1. Mismatch count: Number of tiles not in their position.

  2. Manhattan distance: the Manhattan distance of a single tile is the absolute difference between its current and goal x coordinates plus the absolute difference between its current and goal y coordinates. Our heuristic is the sum of these distances over all tiles (the empty slot is not counted).

Figure: goal state (left board) and shuffled state (right board).

Mismatch count: tiles 4, 5, 7 and 8 are not in their goal positions, so the mismatch count is 4.

Manhattan distance: tiles 1, 2, 3 and 6 are in their goal positions, so their Manhattan distance is 0. Tile 5 should be at (2,2) but is at (3,1), so its distance is |2-3| + |2-1| = 2. Similarly, tiles 4, 7 and 8 each have a Manhattan distance of 1, so the sum of all the Manhattan distances is 2 + 1 + 1 + 1 = 5.
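The two worked numbers can be checked in code. The figure is not reproduced here, so the board below is an assumption: one arrangement consistent with the description (tiles 1, 2, 3 and 6 in place, tile 5 at (3,1), tiles 4, 7 and 8 one step away). misplaced() and manhattan() are hypothetical standalone stand-ins for the heuristic methods. The sketch also checks a general property: every misplaced tile is at least one square from its goal, so the Manhattan distance is never smaller than the mismatch count.

```python
import random


def misplaced(board):
    """Mismatch count: tiles (not the empty slot) that are not on their goal key."""
    return sum(1 for k in range(1, 10) if board[k] != 'e' and board[k] != k)


def manhattan(board):
    """Sum of |row difference| + |column difference| over all tiles."""
    total = 0
    for k in range(1, 10):
        tile = board[k]
        if tile == 'e':
            continue
        row, col = divmod(k - 1, 3)               # current cell (0-based)
        goal_row, goal_col = divmod(tile - 1, 3)  # goal cell: key == tile number
        total += abs(row - goal_row) + abs(col - goal_col)
    return total


# Assumed board, consistent with the worked example:
# 1 2 3
# 7 4 6
# 5 e 8
board = {1: 1, 2: 2, 3: 3, 4: 7, 5: 4, 6: 6, 7: 5, 8: 'e', 9: 8}
print(misplaced(board), manhattan(board))  # 4 5

# Dominance check on random boards (solvable or not, the property holds either way).
rng = random.Random(0)
cells = ['e', 1, 2, 3, 4, 5, 6, 7, 8]
for _ in range(1000):
    rng.shuffle(cells)
    sample = {k: cells[k - 1] for k in range(1, 10)}
    assert manhattan(sample) >= misplaced(sample)
```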

The code for the heuristics, as methods of the Puzzle8 class, is given below:

    # the two heuristics below are methods of the Puzzle8 class

    # number of misplaced tiles (mismatch count); the empty slot is not counted
    def heur1(self):
        h1 = 0
        for i in range(1,10):
            tile = self.number_pos[i]
            if tile != 'e' and tile != i:
                h1 += 1
        return h1

    # sum over all tiles of how many squares each tile is away from its goal position (Manhattan distance)
    def heur2(self):
        h2 = 0
        for i in range(1,10):
            tile = self.number_pos[i]
            if tile != 'e':
                # key i is the tile's current position; since the keys also name the tile
                # that belongs at each position, the tile's value is its goal position.
                # divmod turns a 1-based key into a 0-based (row, column) pair.
                cur_row, cur_col = divmod(i - 1, 3)
                goal_row, goal_col = divmod(tile - 1, 3)
                h2 += abs(cur_row - goal_row) + abs(cur_col - goal_col)
        return h2
 

The Algorithm Implementation

Finally, we will look at the code implementation of the modified A* algorithm along with its helper functions, followed by a brief explanation of what each function does.

# returns all the unexplored children of the node as a list of Puzzle8 objects
def expand(S):
    # maps each move letter to the Puzzle8 method that performs it
    movers = {'r': Puzzle8.mov_right, 'l': Puzzle8.mov_left,
              'u': Puzzle8.mov_up, 'd': Puzzle8.mov_down}
    children_nodes = []

    # The possible moves are iterated through, and an object is created for each resulting state.
    for move in S.possible_moves():
        child = Puzzle8(S)
        child.number_pos = S.get_board()   # work on a copy of the parent's board
        child.depth = S.depth + 1          # path cost: one move more than the parent
        movers[move](child)
        # keep the child only if this state has never been reached before
        if not child.exist():
            Puzzle8.node_num += 1
            Puzzle8.explored.add(child.dict_to_str())
            children_nodes.append(child)
    return children_nodes

import queue as q

# converts a list into a FIFO queue and returns the queue
def make_fifo_queue(lis):
    result = q.Queue()
    for ele in lis:
        result.put(ele)
    return result


# modified BFS: expands up to the given depth and returns a queue with all the new nodes found
def mod_exp(S, depth):
    child_list = expand(S)
    temp1 = make_fifo_queue(child_list)   # every node to be returned
    temp2 = make_fifo_queue(child_list)   # nodes of the current level, still to be expanded
    d = depth - 1
    while d > 0:
        next_level = []                   # children of *all* nodes in the current level
        while temp2.qsize() != 0:
            C = temp2.get()
            temp3_1 = expand(C)
            for child in temp3_1:
                temp1.put(child)
            next_level.extend(temp3_1)
        d -= 1
        temp2 = make_fifo_queue(next_level)
    return temp1


from queue import PriorityQueue

# modified A* search, the "AI" part
def mod_A(initial_state, heur, depth=1):
    root = Puzzle8(None)
    root.number_pos = initial_state
    root.depth = 0

    Puzzle8.explored.add(root.dict_to_str())
    fringe = PriorityQueue()
    S = root
    print("initial state:")
    root.show_board()
    while not S.goal_test():
        # get all the new nodes up to the given depth
        temp = mod_exp(S, depth)

        # iterate through the queue
        while temp.qsize() != 0:
            # pop the first element from the queue
            node = temp.get()

            # cost of the node: node.depth is the path cost (each move adds 1)
            # plus the heuristic selected with the heur flag
            if heur == 1:
                # mismatch count heuristic
                val = node.heur1() + node.depth
            else:
                # Manhattan distance heuristic
                val = node.heur2() + node.depth

            # add the node to the priority queue
            fringe.put((val, node))

            # track the largest fringe size, for analysis only
            Puzzle8.max_fringe_len = max(Puzzle8.max_fringe_len, fringe.qsize())

        if fringe.qsize() != 0:
            # pop the node with the lowest cost and continue the search from there;
            # the priority queue guarantees the cheapest node comes out first
            S = fringe.get()[1]
        else:
            # an empty fringe means every reachable state was tried: no solution
            break
    # returns the solution object (or the last object checked if there is no solution)
    return S

Let us look at the crux of the code we wrote:

  1. mod_A(): the code equivalent of the modified A* algorithm. Readers are encouraged to compare the function with the algorithm; you will find them strikingly similar, apart from a few extra variables that are initialized and used.

  2. mod_exp(): returns the child nodes that have not been visited yet, up to a given depth (more details are in the post linked in the introduction). The algorithm for this type of search was already given above.

  3. expand(): a helper for mod_exp() that returns the immediate child nodes of a given node. It is specific to the Puzzle8 class, so readers will need to adapt it to the problem they are working on.

  4. make_fifo_queue(): converts a list into a FIFO queue.

Driver and driver related functions

Finally, we have the driver that runs everything. This part is the most flexible: it is up to the user to run the above code however they want. In this particular case, we will run the following experiments:

  1. 10 randomly generated cases, with depth = 1, heuristic = mismatch count.

  2. 10 randomly generated cases, with depth = 1, heuristic = Manhattan distance.

  3. 10 randomly generated cases, with depth = 5, heuristic = mismatch count.

  4. 10 randomly generated cases, with depth = 5, heuristic = Manhattan distance.

The code for these tests looks something like this, once we add a couple of lines for a visually pleasing output:

import numpy as np

# creates 10 random states in the Puzzle8 format and returns them as a list of dictionaries
# also makes sure they are solvable
def create_states():
    result = []
    base_arr = ['e',1,2,3,4,5,6,7,8]
    while len(result) < 10:
        np.random.shuffle(base_arr)
        # key k -> tile at position k, filled row by row from the shuffled list
        state = {k: base_arr[k-1] for k in range(1,10)}
        p = Puzzle8(None)
        p.number_pos = state
        # keep the state only if it passes the inversion test
        if p.is_solvable():
            result.append(state)
    return result

# prints the path from the final node back to the initial state by following parent links
def reproducePath(S):
    temp = S
    while temp is not None:
        temp.show_board()
        print("_________")
        temp = temp.parent
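reproducePath() prints the boards from the final state back to the initial one, because it follows parent links. To see the moves in the order they were played, one option is to collect the chain first and reverse it. A minimal sketch; Node and path_from_root() are hypothetical names, and Node stands in for Puzzle8 since only the parent attribute matters here.

```python
class Node:
    """Hypothetical stand-in for Puzzle8: only the parent link is used."""

    def __init__(self, label, parent=None):
        self.label = label
        self.parent = parent


def path_from_root(node):
    """Follow parent links up to the root, then reverse into start-to-goal order."""
    path = []
    while node is not None:
        path.append(node)
        node = node.parent
    path.reverse()
    return path


root = Node('start')
mid = Node('middle', root)
goal = Node('goal', mid)
path = path_from_root(goal)
print([n.label for n in path])  # ['start', 'middle', 'goal']
print(len(path) - 1)            # 2 moves, which equals the goal node's depth
```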


# driver function to be called
def driver():
    states = create_states()

    # (depth, heuristic flag, heuristic name) for the four experiments
    experiments = [(1, 1, 'mismatch count'), (1, 2, 'manhattan distance'),
                   (5, 1, 'mismatch count'), (5, 2, 'manhattan distance')]
    res = None
    for depth, heur, name in experiments:
        print(f"solving 10 states with depth = {depth}, '{name}' heuristic:")
        print(" ")
        for i in range(10):
            # reset the class-level bookkeeping before every search
            Puzzle8.node_num = 1
            Puzzle8.max_fringe_len = 1
            Puzzle8.explored = set()
            res = mod_A(states[i], heur, depth)
            print("state:", i+1, ", path length:", res.depth,
                  ", number of nodes generated:", Puzzle8.node_num,
                  ", maximum fringe size:", Puzzle8.max_fringe_len)
            print(" ")
        print(" ")

    # Just here to show the path taken by the last state and to demonstrate reproducePath().
    reproducePath(res)

Output

Finally, we tie all the parts together and run them (find the final run here).

Conclusion

In this article, we built a simple graph-search-based AI that can solve the 8-Puzzle problem. Along the way we looked at the following questions: how to represent a state so that it can be iterated over easily, how to recognize a goal state, and which heuristics we can use and how they compare. Note also that we chose the 8-Puzzle problem for the (modified) A* algorithm on purpose, since not every problem can be solved with this algorithm.

In the end, the intention of this article was to help the reader get a good understanding of the practice of applying existing algorithms to new problems, improving the results by testing multiple heuristics, and potentially improving the core algorithm itself. Finally, I would like to thank everyone for having the patience to read through the entire article, and I wish you all the best in your career.
