Skip to content

Commit

Permalink
Merge pull request #6595 from Scoppio/fix/gameactions-log-overwriteap…
Browse files Browse the repository at this point in the history
…pend-only

Fix: Improvements for Dataset Logger and its javadocs
  • Loading branch information
HammerGS authored Feb 24, 2025
2 parents 10c1647 + d049701 commit a9edb26
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 49 deletions.
6 changes: 3 additions & 3 deletions megamek/src/megamek/ai/dataset/ActionAndState.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

/**
* Represents an action and the state of the board after the action is performed.
* @param round
* @param unitAction
* @param boardUnitState
* @param round game round
* @param unitAction unit action performed
* @param boardUnitState state of the board when the action is performed
* @author Luana Coppio
*/
public record ActionAndState(int round, UnitAction unitAction, List<UnitState> boardUnitState){}
25 changes: 12 additions & 13 deletions megamek/src/megamek/ai/dataset/DatasetParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,31 @@
*/
package megamek.ai.dataset;

import megamek.common.*;
import megamek.common.Entity;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* Parses a dataset from one or more files and turns it into a training dataset.
* The dataset currently expected is the game_actions_log.tsv file generated by the megamek client hosting the game if the
* option is enabled
* <p>Parses a dataset from one or more game_action tsv files and turns it into a training dataset.</p>
* <p>The dataset currently expected is the {@code game_actions*.tsv} file generated by the host of the game if the
* option is enabled</p>
* <p>This Dataset Parser handles multiple files by ingesting them one after the other and shifting the ID of each unit by an offset,
* so that the full data of all datasets can be loaded as if it were a single one.</p>
* @author Luana Coppio
*/
public class DatasetParser {

private enum LineType {
MOVE_ACTION_HEADER_V1("PLAYER_ID\tENTITY_ID\tCHASSIS\tMODEL\tFACING\tFROM_X\tFROM_Y\tTO_X\tTO_Y\tHEXES_MOVED\tDISTANCE" +
"\tMP_USED\tMAX_MP\tMP_P\tHEAT_P\tARMOR_P\tINTERNAL_P\tJUMPING\tPRONE\tLEGAL\tSTEPS"),
MOVE_ACTION_HEADER_V2(UnitActionField.getHeaderLine()),
MOVE_ACTION_HEADER_V2(UnitActionField.getPartialHeaderLine(0, 22)),
MOVE_ACTION_HEADER_V3(UnitActionField.getHeaderLine()),
STATE_HEADER_V1("ROUND\tPHASE\tPLAYER_ID\tENTITY_ID\tCHASSIS\tMODEL\tTYPE\tROLE\tX\tY\tFACING\tMP\tHEAT\tPRONE\tAIRBORNE" +
"\tOFF_BOARD\tCRIPPLED\tDESTROYED\tARMOR_P\tINTERNAL_P\tDONE"),
STATE_HEADER_V2("ROUND\tPHASE\tTEAM_ID\tPLAYER_ID\tENTITY_ID\tCHASSIS\tMODEL\tTYPE\tROLE\tX\tY\tFACING\tMP\tHEAT\tPRONE" +
Expand Down Expand Up @@ -75,7 +77,8 @@ public DatasetParser parse(File file) {
try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
String line;
while ((line = reader.readLine()) != null) {
if (line.startsWith(LineType.MOVE_ACTION_HEADER_V1.getText()) || line.startsWith(LineType.MOVE_ACTION_HEADER_V2.getText())) {
if (line.startsWith(LineType.MOVE_ACTION_HEADER_V1.getText()) || line.startsWith(LineType.MOVE_ACTION_HEADER_V2.getText())
|| line.startsWith(LineType.MOVE_ACTION_HEADER_V3.getText())) {
// Parse action line
String actionLine = reader.readLine();
if (actionLine == null) break;
Expand Down Expand Up @@ -136,9 +139,7 @@ public TrainingDataset getTrainingDataset() {
private UnitAction parseActionLine(String actionLine) {
try {
UnitAction unitAction = unitActionSerde.fromTsv(actionLine, idOffset);
if (highestEntityId < unitAction.id()) {
highestEntityId = unitAction.id();
}
highestEntityId = Math.max(highestEntityId, unitAction.id());
return unitAction;
} catch (Exception e) {
throw new RuntimeException("Error parsing action line: " + actionLine, e);
Expand All @@ -148,9 +149,7 @@ private UnitAction parseActionLine(String actionLine) {
private UnitState parseStateLine(String stateLine) {
try {
UnitState unitState = unitStateSerde.fromTsv(stateLine, entities, idOffset);
if (highestEntityId < unitState.id()) {
highestEntityId = unitState.id();
}
highestEntityId = Math.max(highestEntityId, unitState.id());
return unitState;
} catch (Exception e) {
throw new RuntimeException("Error parsing state line: " + stateLine, e);
Expand Down
49 changes: 42 additions & 7 deletions megamek/src/megamek/ai/dataset/TrainingDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
import java.util.stream.Collectors;

/**
* Represents a training dataset.
* This dataset is used to train a model optimize the action selection of a bot player.
* <p>Represents a training dataset.
* This dataset is used to train a model to optimize the action selection of a bot player.</p>
* <p>Internally it holds two lists that represent the same dataset offset by one round.
* This is a detail of its internal representation, so you do not need to worry about it
* unless you plan to use its iterator.</p>
*
* @author Luana Coppio
*/
public class TrainingDataset {
Expand Down Expand Up @@ -59,14 +63,14 @@ public ActionAndState next() {

/**
* Create a new training dataset from a list of action and state pairs.
* This will filter out any actions that do not have a corresponding state for the player.
* This will discard all actions taken by a player different from player 0
* This will discard all actions taken by non-human players
* @param actionAndStates The list of action and state pairs.
*/
public TrainingDataset(List<ActionAndState> actionAndStates) {
Set<Integer> entityIds = actionAndStates.stream().map(ActionAndState::unitAction).map(UnitAction::id).collect(Collectors.toSet());
for (int entityId : entityIds) {
List<ActionAndState> actionsForEntity = actionAndStates.stream()
.filter(actionAndState -> actionAndState.unitAction().isHuman())
.filter(actionAndState -> actionAndState.unitAction().id() == entityId)
.filter(actionAndState -> actionAndState.boardUnitState().stream().anyMatch(u -> u.id() == entityId))
.toList();
Expand All @@ -80,11 +84,22 @@ public TrainingDataset(List<ActionAndState> actionAndStates) {
}
}

/**
* <p>Create a new training dataset from two lists of action and state pairs. The second list holds the state of the game
* in the round after each action was made.</p>
* <p>The elements of both lists are added to this dataset's internal lists as given. No filtering is applied here, so
* this expects that all action and state pairs are human actions.</p>
* <p>NOTE(review): nothing here checks that both lists have the same length; confirm that callers guarantee it.</p>
* @param actionAndStates The list of action and state pairs.
* @param nextRoundActionAndState The list of action and state pairs for the next round.
*/
public TrainingDataset(List<ActionAndState> actionAndStates, List<ActionAndState> nextRoundActionAndState) {
this.actionAndStates.addAll(actionAndStates);
this.nextRoundActionAndState.addAll(nextRoundActionAndState);
}

/**
* Get the height of the board.
* @return The height of the board.
*/
public int boardHeight() {
int maxY = Integer.MIN_VALUE;
for (var actionState : actionAndStates) {
Expand All @@ -98,6 +113,10 @@ public int boardHeight() {
return maxY + 5;
}

/**
* Get the width of the board.
* @return The width of the board.
*/
public int boardWidth() {
int maxX = Integer.MIN_VALUE;
for (var actionState : actionAndStates) {
Expand All @@ -111,10 +130,20 @@ public int boardWidth() {
return maxX + 5;
}

/**
* <p>Get the size of the training dataset.</p>
* Even though internally it has two lists, the one of current action and another with the state of the game in the following
* round, this only considers the number of actions in one of the lists.
* @return The size of the training dataset.
*/
public int size() {
return Math.min(actionAndStates.size(), nextRoundActionAndState.size());
return actionAndStates.size();
}

/**
* Check if the training dataset is empty.
* The dataset counts as empty when either of its two internal lists (the current actions or the next round
* states) has no elements.
* <p>NOTE(review): if {@code size()} counts only the current-action list, this method can report empty while
* {@code size()} is non-zero whenever the two lists differ in length; confirm they always have the same length.</p>
* @return True if the training dataset is empty.
*/
public boolean isEmpty() {
return actionAndStates.isEmpty() || nextRoundActionAndState.isEmpty();
}
Expand All @@ -123,15 +152,21 @@ public boolean isEmpty() {
* Get an iterator for the training dataset.
* This iterator always returns the current state and then the next state in the dataset,
* giving you access to the state of the board before and after an action.
* <p>This iterator will return the action and state for the current round and then the next round for the same unit.</p>
* <p>It always returns in this two-step way, the current actionAndState followed by the next round's actionAndState. This is necessary
* because during training we need to know the direct result of the action taken to judge whether it was a good or a bad decision.</p>
* @return An iterator for the training dataset.
*/
public Iterator<ActionAndState> iterator() {
return new TrainingDatasetIterator(this);
}

/**
* Sample a training dataset with a given batch size.
* This will randomly sample the dataset, the same index will not be sampled twice. It is returned out of order.
* Sample a training dataset with a given {@code batchSize}.
* This will not modify the original dataset,
* and it will randomly sample the dataset.
* The same index will not be sampled twice.
* It is returned out of order.
* @param batchSize The batch size.
* @return A new training dataset with the sampled data.
*/
Expand Down
16 changes: 16 additions & 0 deletions megamek/src/megamek/ai/dataset/TsvSerde.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,24 @@
*/
package megamek.ai.dataset;

/**
* <p>Abstract base class for serializing objects to TSV (tab-separated values) format.</p>
* <p>No common {@code fromTsv} method is declared here because a single good API for it was not found;
* subclasses that support deserialization define their own {@code fromTsv} with the arguments they need.</p>
* @param <T> type of object to serialize/deserialize
* @author Luana Coppio
*/
public abstract class TsvSerde<T> {

/**
* Serializes an object to TSV format.
* @param obj object to serialize
* @return the object serialized in TSV format
*/
public abstract String toTsv(T obj);

/**
* Returns the header line for the TSV format produced by this serializer.
* @return the header line
*/
public abstract String getHeaderLine();
}
9 changes: 7 additions & 2 deletions megamek/src/megamek/ai/dataset/UnitAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
*/
public record UnitAction(int id, int teamId, int playerId, String chassis, String model, int facing, int fromX, int fromY, int toX, int toY, int hexesMoved, int distance, int mpUsed,
int maxMp, double mpP, double heatP, double armorP, double internalP, boolean jumping, boolean prone,
boolean legal, double chanceOfFailure, List<MovePath.MoveStepType> steps) {
boolean legal, double chanceOfFailure, List<MovePath.MoveStepType> steps, boolean bot) {

public static UnitAction fromMovePath(MovePath movePath) {
Entity entity = movePath.getEntity();
Expand Down Expand Up @@ -74,14 +74,19 @@ public static UnitAction fromMovePath(MovePath movePath) {
movePath.getFinalProne(),
movePath.isMoveLegal(),
chanceOfFailure,
steps
steps,
entity.getOwner().isBot()
);
}

/**
* Returns the {@code fromX}/{@code fromY} pair of this action as coordinates
* (presumably the unit's position before the move; confirm against the logger).
* @return a new {@code Coords} built from {@code fromX} and {@code fromY}
*/
public Coords currentPosition() {
return new Coords(fromX, fromY);
}

/**
* Tells whether this action was recorded for a non-bot player.
* @return {@code true} when the {@code bot} component is {@code false}
*/
public boolean isHuman() {
return !bot;
}

/**
* Returns the {@code toX}/{@code toY} pair of this action as coordinates
* (presumably the unit's position after the move; confirm against the logger).
* @return a new {@code Coords} built from {@code toX} and {@code toY}
*/
public Coords finalPosition() {
return new Coords(toX, toY);
}
Expand Down
14 changes: 13 additions & 1 deletion megamek/src/megamek/ai/dataset/UnitActionField.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public enum UnitActionField {
LEGAL("LEGAL"),
STEPS("STEPS"),
TEAM_ID("TEAM_ID"),
CHANCE_OF_FAILURE("CHANCE_OF_FAILURE");
CHANCE_OF_FAILURE("CHANCE_OF_FAILURE"),
IS_BOT("IS_BOT");

private final String headerName;

Expand All @@ -60,4 +61,15 @@ public static String getHeaderLine() {
.map(UnitActionField::getHeaderName)
.collect(Collectors.joining("\t"));
}

/**
* Builds a partial TSV header line (joined by tabs) from a contiguous range of the enum constants,
* taken in declaration order.
* @param startsAt number of leading constants to skip, i.e. the zero-based index of the first constant included
* @param endsAt zero-based index one past the last constant included; at most {@code endsAt - startsAt} constants are taken
* @return the header names of the selected constants joined by tab characters
* @throws IllegalArgumentException if {@code startsAt} is negative or {@code endsAt} is smaller than {@code startsAt}
*/
public static String getPartialHeaderLine(int startsAt, int endsAt) {
return Arrays.stream(values())
.skip(startsAt)
.limit(endsAt - startsAt)
.map(UnitActionField::getHeaderName)
.collect(Collectors.joining("\t"));
}
}
9 changes: 6 additions & 3 deletions megamek/src/megamek/ai/dataset/UnitActionSerde.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
import java.util.List;
import java.util.stream.Collectors;


/**
* serializer and deserializer for game dataset UnitAction
* <p>serializer and deserializer for UnitAction to/from TSV format.</p>
* @author Luana Coppio
*/
public class UnitActionSerde extends TsvSerde<UnitAction> {
Expand Down Expand Up @@ -57,6 +56,7 @@ public String toTsv(UnitAction obj) {
row[UnitActionField.PRONE.ordinal()] = obj.prone() ? "1" : "0";
row[UnitActionField.LEGAL.ordinal()] = obj.legal() ? "1" : "0";
row[UnitActionField.CHANCE_OF_FAILURE.ordinal()] = LOG_DECIMAL.format(obj.chanceOfFailure());
row[UnitActionField.IS_BOT.ordinal()] = obj.bot() ? "1" : "0";

// For STEPS, join the list of MoveStepType values with a space.
row[UnitActionField.STEPS.ordinal()] = obj.steps().stream()
Expand Down Expand Up @@ -88,11 +88,13 @@ public UnitAction fromTsv(String line, int idOffset) throws NumberFormatExceptio
boolean jumping = "1".equals(parts[UnitActionField.JUMPING.ordinal()]);
boolean prone = "1".equals(parts[UnitActionField.PRONE.ordinal()]);
boolean legal = "1".equals(parts[UnitActionField.LEGAL.ordinal()]);
boolean bot = false;
int teamId = -1;
double chanceOfFailure = 0.0;
if (parts.length >= 23) {
teamId = Integer.parseInt(parts[UnitActionField.TEAM_ID.ordinal()]);
chanceOfFailure = Double.parseDouble(parts[UnitActionField.CHANCE_OF_FAILURE.ordinal()]);
bot = "1".equals(parts[UnitActionField.IS_BOT.ordinal()]);
}
// Convert the steps field (a space-separated list) back to a List of MoveStepType.
List<MovePath.MoveStepType> steps = Arrays.stream(
Expand Down Expand Up @@ -124,7 +126,8 @@ public UnitAction fromTsv(String line, int idOffset) throws NumberFormatExceptio
prone,
legal,
chanceOfFailure,
steps
steps,
bot
);
}

Expand Down
10 changes: 10 additions & 0 deletions megamek/src/megamek/ai/dataset/UnitState.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ public record UnitState(int id, GamePhase phase, int teamId, int round, int play
boolean offBoard, boolean crippled, boolean destroyed, double armorP,
double internalP, boolean done, int maxRange, int totalDamage, Entity entity) {

/**
* Creates a UnitState from an {@code entity}.
* @param entity The entity to which the state belongs
* @param game The game reference
* @return The UnitState
*/
public static UnitState fromEntity(Entity entity, Game game) {
return new UnitState(
entity.getId(),
Expand Down Expand Up @@ -79,6 +85,10 @@ public static UnitState fromEntity(Entity entity, Game game) {
entity);
}

/**
* Returns the position of the unit.
* @return a new {@code Coords} built from this state's {@code x} and {@code y} components
*/
public Coords position() {
return new Coords(x, y);
}
Expand Down
2 changes: 1 addition & 1 deletion megamek/src/megamek/ai/dataset/UnitStateSerde.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import java.util.Map;

/**
* serializer and deserializer for game dataset UnitState
* <p>serializer and deserializer for UnitState to/from TSV format.</p>
* @author Luana Coppio
*/
public class UnitStateSerde extends TsvSerde<UnitState> {
Expand Down
Loading

0 comments on commit a9edb26

Please sign in to comment.