From d9239d31a677c4e8a0ace55f71e2c47181d0bded Mon Sep 17 00:00:00 2001 From: r3w0p Date: Mon, 6 Jan 2025 00:23:10 +0000 Subject: [PATCH] need big simplifications... - state: hand needs to be reduced to binary per numeral/face class - actions needs to be intent for movement and not exact movements - more "coaching" --- include/caravan/core/training.h | 2 +- src/caravan/core/training.cpp | 57 +++++++++++++------------------- src/caravan/train.cpp | 13 ++++---- test/caravan/model/test_game.cpp | 4 +-- 4 files changed, 33 insertions(+), 43 deletions(-) diff --git a/include/caravan/core/training.h b/include/caravan/core/training.h index 9679e19..03fe643 100644 --- a/include/caravan/core/training.h +++ b/include/caravan/core/training.h @@ -16,7 +16,7 @@ const uint16_t SIZE_ACTION = 5; const uint16_t SIZE_ACTION_SPACE = 920; -const uint16_t SIZE_GAME_STATE = 200; +const uint16_t SIZE_GAME_STATE = 38; const uint8_t NUM_PLAYER_ABC = 1; const uint8_t NUM_PLAYER_DEF = 2; diff --git a/src/caravan/core/training.cpp b/src/caravan/core/training.cpp index d803bf1..cae2cc8 100644 --- a/src/caravan/core/training.cpp +++ b/src/caravan/core/training.cpp @@ -38,6 +38,11 @@ uint8_t suit_to_uint8_t(Suit s) { return static_cast(s); } +uint8_t direction_to_uint8_t(Direction d) { + // ANY is 0 + return static_cast(d); +} + void add_hand_to_game_state(GameState *gs, uint16_t *i_gs, Player *player) { Hand hand = player->get_hand(); uint8_t hand_size = player->get_size_hand(); @@ -62,36 +67,28 @@ void add_hand_to_game_state(GameState *gs, uint16_t *i_gs, Player *player) { } } -void add_caravan_to_game_state(GameState *gs, uint16_t *i_gs, Caravan *caravan) { - // Current highest numeral position along caravan track - uint8_t max_track = caravan->get_size(); +void add_caravan_to_game_state(GameState *gs, uint16_t *i_gs, Game *game, Caravan *caravan) { + uint8_t caravan_size = caravan->get_size(); + + // Add whether caravan is winning + (*gs)[(*i_gs)++] = game->is_caravan_winning(caravan->get_name()); + + // Add whether numeral track is full + (*gs)[(*i_gs)++] = caravan_size == TRACK_NUMERIC_MAX; + + // Add caravan direction + (*gs)[(*i_gs)++] = direction_to_uint8_t(caravan->get_direction()); // Add caravan suit (*gs)[(*i_gs)++] = suit_to_uint8_t(caravan->get_suit()); - for (uint8_t i_track = 0; i_track < TRACK_NUMERIC_MAX; i_track++) { - // If numeral at track position, fetch slot state - if (i_track < max_track) { - Slot slot = caravan->get_slot(i_track + 1); - - // Add numeral - (*gs)[(*i_gs)++] = rank_to_uint8_t(slot.card.rank); + // Add rank of highest numeral + if (caravan_size > 0) { + Slot slot = caravan->get_slot(caravan->get_size()); + (*gs)[(*i_gs)++] = rank_to_uint8_t(slot.card.rank); - // Add face cards - for (uint8_t i_face = 0; i_face < TRACK_FACE_MAX; i_face++) { - if (i_face < slot.i_faces) { - (*gs)[(*i_gs)++] = rank_to_uint8_t(slot.faces[i_face].rank); - } else { - (*gs)[(*i_gs)++] = 0; - } - } - } else { - // No populated slot at caravan position, leave blank spaces - // for numeral and max face cards - for (uint8_t _ = 0; _ < (1 + TRACK_FACE_MAX); _++) { - (*gs)[(*i_gs)++] = 0; - } - } + } else { + (*gs)[(*i_gs)++] = 0; } } @@ -126,7 +123,7 @@ void get_game_state(GameState *gs, Game *game, PlayerName pname) { // Add state of each caravan, player's caravans first for (uint8_t i_cvn = 0; i_cvn < cvn_names_size; i_cvn++) { Caravan *caravan = table->get_caravan(cvn_names[i_cvn]); - add_caravan_to_game_state(gs, &i_gs, caravan); + add_caravan_to_game_state(gs, &i_gs, game, caravan); } } @@ -349,8 +346,6 @@ void train_on_game(Game *game, QTable &q_table, ActionSpace &action_space, Train // Otherwise, pick the optimal action from the q-table action = action_pool[action_index]; - - printf("- %llu\n", q_table[gs].size()); } // Generate input from action @@ -398,12 +393,6 @@ void train_on_game(Game *game, QTable &q_table, ActionSpace &action_space, Train } q_table[last_gs][last_action] = q_table[last_gs][last_action] + tc.learning * (tc.discount * q_table[gs][action] - q_table[last_gs][last_action]); - /* - if (game->get_winner_name() != NO_PLAYER) { - //printf("%f\n", q_table[gs][action]); - printf("%f\n", q_table[last_gs][last_action]); - } - */ } // Log last move diff --git a/src/caravan/train.cpp b/src/caravan/train.cpp index 9555b0d..a2c8679 100644 --- a/src/caravan/train.cpp +++ b/src/caravan/train.cpp @@ -34,7 +34,7 @@ int main(int argc, char *argv[]) { // Training parameters TODO user-defined arguments float discount = 0.95; float learning = 0.7; - uint32_t episode_max = 1000000; + uint32_t episode_max = 50000; // Game config uses largest deck with most samples and balance to // maximise chance of encountering every player hand combination. @@ -54,11 +54,6 @@ int main(int argc, char *argv[]) { }; for(; tc.episode <= tc.episode_max; tc.episode++) { - if (tc.episode % 100 == 0) { - printf("Episode %d\n", tc.episode); - printf("- states: %llu\n", q_table.size()); - } - // Random first player rand_first = dist_first_player(gen); gc.player_first = rand_first == NUM_PLAYER_ABC ? @@ -74,6 +69,12 @@ int main(int argc, char *argv[]) { tc.learning = learning; + if (tc.episode % 1000 == 0) { + printf("Episode %d\n", tc.episode); + printf("- explore: %.2f\n", tc.explore); + printf("- states: %llu\n", q_table.size()); + } + // Start a new game game.reset(new Game(&gc)); diff --git a/test/caravan/model/test_game.cpp b/test/caravan/model/test_game.cpp index c80d653..53fc612 100644 --- a/test/caravan/model/test_game.cpp +++ b/test/caravan/model/test_game.cpp @@ -48,7 +48,7 @@ TEST (TestGame, GetPlayerTurn) { }; Game g{&gc}; - ASSERT_EQ(g.get_player_turn(), PLAYER_ABC); + ASSERT_EQ(g.get_player_turn()->get_name(), PLAYER_ABC); } TEST (TestGame, GetWinner_NoMoves) { @@ -59,7 +59,7 @@ TEST (TestGame, GetWinner_NoMoves) { }; Game g{&gc}; - ASSERT_EQ(g.get_winner(), NO_PLAYER); + ASSERT_EQ(g.get_winner_name(), NO_PLAYER); } TEST (TestGame, PlayOption_Error_StartRound_Remove) {