Question: Please help modify the code in cart.py so that it passes the test case in test_cart.py (Python): #!/usr/bin/python import argparse import logging import sys import numpy as np
Please help modify the code in cart.py (shown below) so that it passes the test case in test_cart.py (Python):
#!/usr/bin/python
import argparse import logging import sys
import numpy as np
import gym #import gym.scoreboard.scoring from gym import wrappers, logge
def discretize_state(x, xdot, theta, thetadot):
    """Map a continuous cart-pole observation onto a discrete box index.

    The four inputs are bucketed independently (3 position bins, 3 velocity
    bins, 6 angle bins, 3 angular-velocity bins) and combined as a
    mixed-radix number, giving 3 * 3 * 6 * 3 = 162 distinct boxes.

    Args:
        x: Cart position.
        xdot: Cart velocity.
        theta: Pole angle, in radians.
        thetadot: Pole angular velocity, in radians per unit time.

    Returns:
        An int in ``0..161`` identifying the box, or ``-1`` when the cart
        position is outside ``[-2.4, 2.4]`` or the pole angle is outside
        +/- twelve degrees (the failure region).
    """
    # Angle thresholds expressed in radians.
    one_degree = 0.0174532
    six_degrees = 0.1047192
    twelve_degrees = 0.2094384
    fifty_degrees = 0.87266

    # Failure region: no box is assigned.
    if x < -2.4 or x > 2.4 or theta < -twelve_degrees or theta > twelve_degrees:
        return -1

    # Cart position: 3 bins.
    if x < -0.08:
        box = 0
    elif x < 0.08:
        box = 1
    else:
        box = 2

    # Cart velocity: 3 bins.
    box *= 3
    if xdot < -0.5:
        box += 0
    elif xdot < 0.5:
        box += 1
    else:
        box += 2

    # Pole angle: 6 bins. Every branch must be part of one if/elif chain;
    # otherwise angles below -six_degrees would also match the
    # "< -one_degree" test and collapse into the neighbouring bin.
    box *= 6
    if theta < -six_degrees:
        box += 0
    elif theta < -one_degree:
        box += 1
    elif theta < 0:
        box += 2
    elif theta < one_degree:
        box += 3
    elif theta < six_degrees:
        box += 4
    else:
        box += 5

    # Pole angular velocity: 3 bins.
    box *= 3
    if thetadot < -fifty_degrees:
        box += 0
    elif thetadot < fifty_degrees:
        box += 1
    else:
        box += 2

    return box
if __name__ == '__main__':
    # Command line: a single optional positional argument naming the
    # environment to train on.
    cli = argparse.ArgumentParser(description=None)
    cli.add_argument(
        'env_id',
        nargs='?',
        default='CartPole-v0',
        help='Select the environment to run',
    )
    args = cli.parse_args()

    # Send timestamped log records from the root logger to stderr.
    logger = logging.getLogger()
    handler = logging.StreamHandler(sys.stderr)
    handler.setFormatter(logging.Formatter('[%(asctime)s] %(message)s'))
    logger.addHandler(handler)

    # You can set the level to logging.DEBUG or logging.WARN if you
    # want to change the amount of output.
    logger.setLevel(logging.INFO)

    # Build the environment and wrap it in a Monitor that records results
    # under outdir, overwriting any previous contents.
    outdir = '/tmp/' + 'qagent' + '-results'
    env = wrappers.Monitor(
        gym.make(args.env_id), outdir, write_upon_reset=True, force=True
    )
    env.seed(0)

    # Tabular action-value estimates: one row per discretized state
    # (162 boxes), one column per action.
    Q = np.zeros([162, env.action_space.n])

    alpha = 0.7   # learning rate
    gamma = 0.97  # discount factor

    n_episodes = 50001
    for episode in range(n_episodes):
        tick = 0
        reward = 0
        done = False
        state = env.reset()
        s = discretize_state(*state[:4])

        while not done:
            tick += 1

            # Greedy policy: the first action with the largest Q-value.
            action = int(np.argmax(Q[s]))

            state, reward, done, info = env.step(action)
            sprime = discretize_state(*state[:4])

            if sprime < 0:
                # Failure state: nothing to bootstrap from, and the step
                # is penalised.
                predicted_value = 0
                reward = -5
            else:
                predicted_value = np.max(Q[sprime])

            # Q-learning update: move Q[s, action] toward the one-step
            # target reward + gamma * max_a' Q[s', a'].
            td_error = reward + gamma * predicted_value - Q[s, action]
            Q[s, action] += alpha * td_error

            s = sprime

        # Decay the learning rate once every 1000 episodes.
        if episode % 1000 == 0:
            alpha *= .99
Test case:
#!/usr/bin/env python3 from cart import CartPole import unittest import numpy as np
class TestTicTacToe(unittest.TestCase):
    """Checks that a CartPole run stays within the success bounds."""

    def test_1(self):
        """Run the agent once and bound its maximum position and angle.

        NOTE(review): the meaning of the positional CartPole arguments and
        the exact layout of the value returned by run() are not visible
        here; judging by the messages printed below, column 0 is presumably
        the cart position and column 2 the pole angle -- confirm against
        cart.py.
        """
        env_id = 'CartPole-v1'
        cartpole = CartPole(env_id, False, True, 'cart.npy')
        all_states = cartpole.run()

        # Column-wise maximum over everything the run returned.
        max_ = np.max(all_states, axis=0)
        print("max = {}".format(max_))

        position_ok = max_[0] <= 2.4
        angle_ok = max_[2] <= 0.226893
        result = position_ok and angle_ok

        print("Your max cart position = {}".format(max_[0]))
        print("Your max pole angle = {}".format(max_[2]))
        print("Cart position for success <= {}".format(2.4))
        print("Pole angle for success <= {} radians".format(0.226893))

        self.assertEqual(result, True)
# Run every unittest.TestCase defined in this module.
unittest.main()

# NOTE(review): the statements below read `tick` and `episode`, which are
# only assigned inside cart.py's training loop, so this fragment presumably
# belongs at the end of each episode there rather than after
# unittest.main() -- confirm placement before relying on it.
#
# Report the outcome of an episode: episodes shorter than 199 steps count
# as failures (reported only on every 1000th episode); anything else is
# reported as a success. The print calls use a single pre-formatted string
# so the output is the same under Python 2 and Python 3.
if tick < 199:
    if episode % 1000 == 0:
        print("fail  {}".format(tick))
else:
    print("success")
Step by Step Solution
There are 3 steps involved in the solution.
Get step-by-step solutions from verified subject matter experts
