-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathN_queens_env.py
70 lines (60 loc) · 2.21 KB
/
N_queens_env.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import gymnasium as gym
from gymnasium import spaces
import numpy as np
class NQueensEnv(gym.Env):
    """Gymnasium environment for the N-Queens puzzle.

    The board is stored as a length-``n`` integer array ``state`` where
    ``state[i]`` is the column of the queen in row ``i`` (one queen per row).
    Each action moves a single queen: action ``a`` sets
    ``state[a % n] = a // n``.  The reward is minus the number of queens that
    conflict with a queen in an earlier row (same column or same diagonal).
    An episode terminates when that count reaches zero and is truncated once
    ``max_steps`` steps have been taken without solving the board.
    """

    def __init__(self, n=8, max_steps=500):
        """Create an ``n`` x ``n`` board and draw an initial state.

        Args:
            n: Board size, which is also the number of queens.
            max_steps: Maximum number of steps per episode before truncation.
        """
        super().__init__()
        self.n = n  # Size of the board
        # One action per (row, column) pair; see ``step`` for the decoding.
        self.action_space = spaces.Discrete(n * n)
        # Observation: the column index of the queen in each row.
        self.observation_space = spaces.MultiDiscrete([n] * n)
        self.current_step = 0
        self.max_steps = max_steps
        self.reset()

    def reset(self, *, seed=None, options=None, **kwargs):
        """Start a new episode from a random permutation of columns.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset`` so that
                ``self.np_random`` (and hence the initial board) is
                reproducible.
            options: Accepted for Gymnasium API compatibility; unused.
            **kwargs: Accepted for backward compatibility; ignored.

        Returns:
            ``(observation, info)``: a copy of the board and an empty dict.
        """
        # Seeds ``self.np_random`` when ``seed`` is given; the environment's
        # own generator is used instead of the global ``np.random`` state.
        super().reset(seed=seed, options=options)
        # A permutation puts exactly one queen in every column, so only
        # diagonal conflicts are possible at the start of an episode.
        self.state = self.np_random.permutation(self.n)
        self.current_step = 0
        # Return a copy: ``self.state`` is mutated in place by ``step``.
        return self.state.copy(), {}

    def step(self, action):
        """Move one queen and score the resulting board.

        Args:
            action: Integer in ``[0, n * n)``.  The queen in row
                ``action % n`` is moved to column ``action // n``.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``observation`` is a copy of the board; ``reward`` is
            ``calculate_reward()``; ``terminated`` is True when the board has
            no conflicts; ``truncated`` is True when ``max_steps`` steps have
            been taken without solving the board; ``info`` is an empty dict.

        Raises:
            ValueError: If ``action`` is outside ``[0, n * n)``.
        """
        action = int(action)
        if not 0 <= action < self.n * self.n:
            raise ValueError(
                f"action must be in [0, {self.n * self.n}), got {action}"
            )
        self.current_step += 1
        # Decoding kept unchanged for compatibility with existing agents:
        # the remainder selects the row, the quotient the column.
        col, row = divmod(action, self.n)
        self.state[row] = col

        reward = self.calculate_reward()
        terminated = reward == 0
        # ``>=`` so an episode lasts at most ``max_steps`` steps; a solved
        # board is reported as terminated, never as truncated.
        truncated = (not terminated) and self.current_step >= self.max_steps
        return self.state.copy(), reward, terminated, truncated, {}

    def calculate_reward(self):
        """Return minus the number of queens attacked by an earlier queen.

        Queens are scanned row by row.  A queen counts as one violation when
        its column, its ``col + row`` diagonal or its ``col - row`` diagonal
        is already occupied by a queen in an earlier row.  The result is
        therefore 0 exactly when no two queens attack each other.
        """
        violations = 0
        seen_cols = set()
        seen_sums = set()   # col + row is constant along one diagonal direction
        seen_diffs = set()  # col - row is constant along the other
        for row, col in enumerate(self.state):
            diag_sum = col + row
            diag_diff = col - row
            if col in seen_cols or diag_sum in seen_sums or diag_diff in seen_diffs:
                violations += 1
            seen_cols.add(col)
            seen_sums.add(diag_sum)
            seen_diffs.add(diag_diff)
        return -violations

    def is_terminal_state(self):
        """Return True when the current board has no conflicting queens."""
        return self.calculate_reward() == 0

    def render(self, mode='human'):
        """Print the board to stdout, one line per row ('Q' queen, '_' empty).

        ``mode`` is accepted for backward compatibility and is not used.
        """
        for col in self.state:
            cells = ['_'] * self.n
            cells[col] = 'Q'
            print(' '.join(cells))