Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Improvements for Dataset Logger and its javadocs #6595

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions megamek/src/megamek/ai/dataset/ActionAndState.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

/**
* Represents an action and the state of the board after the action is performed.
* @param round
* @param unitAction
* @param boardUnitState
* @param round game round
* @param unitAction unit action performed
* @param boardUnitState state of the board when the action is performed
* @author Luana Coppio
*/
public record ActionAndState(int round, UnitAction unitAction, List<UnitState> boardUnitState){}
25 changes: 12 additions & 13 deletions megamek/src/megamek/ai/dataset/DatasetParser.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,29 +14,31 @@
*/
package megamek.ai.dataset;

import megamek.common.*;
import megamek.common.Entity;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
* Parses a dataset from one or more files and turns it into a training dataset.
* The dataset currently expected is the game_actions_log.tsv file generated by the megamek client hosting the game if the
* option is enabled
* <p>Parses a dataset from one or more game_action tsv files and turns it into a training dataset.</p>
* <p>The dataset currently expected is the {@code game_actions*.tsv} file generated by the host of the game if the
* option is enabled.</p>
* <p>This Dataset Parser handles multiple files by ingesting them one after the other and shifting the ID of each unit by an offset,
* which allows the data of all datasets to be loaded as if it were a single one.</p>
* @author Luana Coppio
*/
public class DatasetParser {

private enum LineType {
MOVE_ACTION_HEADER_V1("PLAYER_ID\tENTITY_ID\tCHASSIS\tMODEL\tFACING\tFROM_X\tFROM_Y\tTO_X\tTO_Y\tHEXES_MOVED\tDISTANCE" +
"\tMP_USED\tMAX_MP\tMP_P\tHEAT_P\tARMOR_P\tINTERNAL_P\tJUMPING\tPRONE\tLEGAL\tSTEPS"),
MOVE_ACTION_HEADER_V2(UnitActionField.getHeaderLine()),
MOVE_ACTION_HEADER_V2(UnitActionField.getPartialHeaderLine(0, 22)),
MOVE_ACTION_HEADER_V3(UnitActionField.getHeaderLine()),
STATE_HEADER_V1("ROUND\tPHASE\tPLAYER_ID\tENTITY_ID\tCHASSIS\tMODEL\tTYPE\tROLE\tX\tY\tFACING\tMP\tHEAT\tPRONE\tAIRBORNE" +
"\tOFF_BOARD\tCRIPPLED\tDESTROYED\tARMOR_P\tINTERNAL_P\tDONE"),
STATE_HEADER_V2("ROUND\tPHASE\tTEAM_ID\tPLAYER_ID\tENTITY_ID\tCHASSIS\tMODEL\tTYPE\tROLE\tX\tY\tFACING\tMP\tHEAT\tPRONE" +
Expand Down Expand Up @@ -75,7 +77,8 @@ public DatasetParser parse(File file) {
try (BufferedReader reader = new BufferedReader(new FileReader(file))) {
String line;
while ((line = reader.readLine()) != null) {
if (line.startsWith(LineType.MOVE_ACTION_HEADER_V1.getText()) || line.startsWith(LineType.MOVE_ACTION_HEADER_V2.getText())) {
if (line.startsWith(LineType.MOVE_ACTION_HEADER_V1.getText()) || line.startsWith(LineType.MOVE_ACTION_HEADER_V2.getText())
|| line.startsWith(LineType.MOVE_ACTION_HEADER_V3.getText())) {
// Parse action line
String actionLine = reader.readLine();
if (actionLine == null) break;
Expand Down Expand Up @@ -136,9 +139,7 @@ public TrainingDataset getTrainingDataset() {
private UnitAction parseActionLine(String actionLine) {
try {
UnitAction unitAction = unitActionSerde.fromTsv(actionLine, idOffset);
if (highestEntityId < unitAction.id()) {
highestEntityId = unitAction.id();
}
highestEntityId = Math.max(highestEntityId, unitAction.id());
return unitAction;
} catch (Exception e) {
throw new RuntimeException("Error parsing action line: " + actionLine, e);
Expand All @@ -148,9 +149,7 @@ private UnitAction parseActionLine(String actionLine) {
private UnitState parseStateLine(String stateLine) {
try {
UnitState unitState = unitStateSerde.fromTsv(stateLine, entities, idOffset);
if (highestEntityId < unitState.id()) {
highestEntityId = unitState.id();
}
highestEntityId = Math.max(highestEntityId, unitState.id());
return unitState;
} catch (Exception e) {
throw new RuntimeException("Error parsing state line: " + stateLine, e);
Expand Down
49 changes: 42 additions & 7 deletions megamek/src/megamek/ai/dataset/TrainingDataset.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@
import java.util.stream.Collectors;

/**
* Represents a training dataset.
* This dataset is used to train a model optimize the action selection of a bot player.
* <p>Represents a training dataset.
* This dataset is used to train a model to optimize the action selection of a bot player.</p>
* <p>Internally it holds two lists that represent the same dataset, offset by one round.
* This is a peculiarity of its internal representation, so you should not worry about it
* unless you plan to use its iterator.</p>
*
* @author Luana Coppio
*/
public class TrainingDataset {
Expand Down Expand Up @@ -59,14 +63,14 @@ public ActionAndState next() {

/**
* Create a new training dataset from a list of action and state pairs.
* This will filter out any actions that do not have a corresponding state for the player.
* This will discard all actions taken by a player different from player 0
* This will discard all actions taken by non-human players
* @param actionAndStates The list of action and state pairs.
*/
public TrainingDataset(List<ActionAndState> actionAndStates) {
Set<Integer> entityIds = actionAndStates.stream().map(ActionAndState::unitAction).map(UnitAction::id).collect(Collectors.toSet());
for (int entityId : entityIds) {
List<ActionAndState> actionsForEntity = actionAndStates.stream()
.filter(actionAndState -> actionAndState.unitAction().isHuman())
.filter(actionAndState -> actionAndState.unitAction().id() == entityId)
.filter(actionAndState -> actionAndState.boardUnitState().stream().anyMatch(u -> u.id() == entityId))
.toList();
Expand All @@ -80,11 +84,22 @@ public TrainingDataset(List<ActionAndState> actionAndStates) {
}
}

/**
* <p>Create a new training dataset from two lists of action and state pairs. The second list holds the state of the game
* in the round after each action was made.</p>
* <p>This constructor performs no filtering: the elements of both lists are added to the internal lists as given,
* so it expects that all action and state pairs are human actions.</p>
* @param actionAndStates The list of action and state pairs.
* @param nextRoundActionAndState The list of action and state pairs for the next round.
*/
public TrainingDataset(List<ActionAndState> actionAndStates, List<ActionAndState> nextRoundActionAndState) {
this.actionAndStates.addAll(actionAndStates);
this.nextRoundActionAndState.addAll(nextRoundActionAndState);
}

/**
* Get the height of the board.
* @return The height of the board.
*/
public int boardHeight() {
int maxY = Integer.MIN_VALUE;
for (var actionState : actionAndStates) {
Expand All @@ -98,6 +113,10 @@ public int boardHeight() {
return maxY + 5;
}

/**
* Get the width of the board.
* @return The width of the board.
*/
public int boardWidth() {
int maxX = Integer.MIN_VALUE;
for (var actionState : actionAndStates) {
Expand All @@ -111,10 +130,20 @@ public int boardWidth() {
return maxX + 5;
}

/**
* <p>Get the size of the training dataset.</p>
* Even though internally it has two lists, one with the current actions and another with the state of the game in the following
* round, only the number of entries in the list of current actions is counted.
* @return The size of the training dataset.
*/
public int size() {
return Math.min(actionAndStates.size(), nextRoundActionAndState.size());
return actionAndStates.size();
}

/**
* Check if the training dataset is empty.
* The dataset is considered empty when either of its two internal lists
* ({@code actionAndStates} or {@code nextRoundActionAndState}) is empty.
* @return True if the training dataset is empty.
*/
public boolean isEmpty() {
return actionAndStates.isEmpty() || nextRoundActionAndState.isEmpty();
}
Expand All @@ -123,15 +152,21 @@ public boolean isEmpty() {
* Get an iterator for the training dataset.
* This iterator always returns the current state and then the next state in the dataset,
* giving you access to the state of the board before and after an action.
* <p>This iterator will return the action and state for the current round and then the next round for the same unit.</p>
* <p>It always returns results in this two-step way, the current actionAndState followed by the next round's actionAndState. This is necessary
* because during training we need to know the direct result of the action taken to judge whether it was a good or a bad decision.</p>
* @return An iterator for the training dataset.
*/
public Iterator<ActionAndState> iterator() {
return new TrainingDatasetIterator(this);
}

/**
* Sample a training dataset with a given batch size.
* This will randomly sample the dataset, the same index will not be sampled twice. It is returned out of order.
* Sample a training dataset with a given {@code batchSize}.
* This will not modify the original dataset,
* and it will randomly sample the dataset.
* The same index will not be sampled twice.
* It is returned out of order.
* @param batchSize The batch size.
* @return A new training dataset with the sampled data.
*/
Expand Down
16 changes: 16 additions & 0 deletions megamek/src/megamek/ai/dataset/TsvSerde.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,24 @@
*/
package megamek.ai.dataset;

/**
* <p>Abstract base class for serializing objects to TSV (tab-separated values) format.</p>
* <p>It deliberately declares no {@code fromTsv} method: the author could not find a good common API for it,
* so deserialization is left to each subclass.</p>
* @param <T> type of object to serialize/deserialize
* @author Luana Coppio
*/
public abstract class TsvSerde<T> {

/**
* Serializes an object to TSV format.
* @param obj object to serialize
* @return the object serialized in TSV format
*/
public abstract String toTsv(T obj);

/**
* Returns the header line for the TSV format.
* @return the header line
*/
public abstract String getHeaderLine();
}
9 changes: 7 additions & 2 deletions megamek/src/megamek/ai/dataset/UnitAction.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
*/
public record UnitAction(int id, int teamId, int playerId, String chassis, String model, int facing, int fromX, int fromY, int toX, int toY, int hexesMoved, int distance, int mpUsed,
int maxMp, double mpP, double heatP, double armorP, double internalP, boolean jumping, boolean prone,
boolean legal, double chanceOfFailure, List<MovePath.MoveStepType> steps) {
boolean legal, double chanceOfFailure, List<MovePath.MoveStepType> steps, boolean bot) {

public static UnitAction fromMovePath(MovePath movePath) {
Entity entity = movePath.getEntity();
Expand Down Expand Up @@ -74,14 +74,19 @@ public static UnitAction fromMovePath(MovePath movePath) {
movePath.getFinalProne(),
movePath.isMoveLegal(),
chanceOfFailure,
steps
steps,
entity.getOwner().isBot()
);
}

/**
* Returns the position built from {@code fromX} and {@code fromY}.
* @return a new {@code Coords} for ({@code fromX}, {@code fromY})
*/
public Coords currentPosition() {
return new Coords(fromX, fromY);
}

/**
* Checks whether this action is flagged as not coming from a bot.
* @return {@code true} when the {@code bot} flag is {@code false}
*/
public boolean isHuman() {
return !bot;
}

/**
* Returns the position built from {@code toX} and {@code toY}.
* @return a new {@code Coords} for ({@code toX}, {@code toY})
*/
public Coords finalPosition() {
return new Coords(toX, toY);
}
Expand Down
14 changes: 13 additions & 1 deletion megamek/src/megamek/ai/dataset/UnitActionField.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ public enum UnitActionField {
LEGAL("LEGAL"),
STEPS("STEPS"),
TEAM_ID("TEAM_ID"),
CHANCE_OF_FAILURE("CHANCE_OF_FAILURE");
CHANCE_OF_FAILURE("CHANCE_OF_FAILURE"),
IS_BOT("IS_BOT");

private final String headerName;

Expand All @@ -60,4 +61,15 @@ public static String getHeaderLine() {
.map(UnitActionField::getHeaderName)
.collect(Collectors.joining("\t"));
}

/**
* Builds a partial TSV header line (joined by tabs) from a contiguous range of the enum constants,
* taken in declaration order.
* @param startsAt number of leading constants to skip, i.e. the zero-based index of the first constant included
* @param endsAt zero-based index one past the last constant included; at most {@code endsAt - startsAt} names are emitted
* @return the header names of the selected constants joined by tab characters
*/
public static String getPartialHeaderLine(int startsAt, int endsAt) {
return Arrays.stream(values())
.skip(startsAt)
.limit(endsAt - startsAt)
.map(UnitActionField::getHeaderName)
.collect(Collectors.joining("\t"));
}
}
9 changes: 6 additions & 3 deletions megamek/src/megamek/ai/dataset/UnitActionSerde.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,8 @@
import java.util.List;
import java.util.stream.Collectors;


/**
* serializer and deserializer for game dataset UnitAction
* <p>Serializer and deserializer for UnitAction to/from TSV format.</p>
* @author Luana Coppio
*/
public class UnitActionSerde extends TsvSerde<UnitAction> {
Expand Down Expand Up @@ -57,6 +56,7 @@ public String toTsv(UnitAction obj) {
row[UnitActionField.PRONE.ordinal()] = obj.prone() ? "1" : "0";
row[UnitActionField.LEGAL.ordinal()] = obj.legal() ? "1" : "0";
row[UnitActionField.CHANCE_OF_FAILURE.ordinal()] = LOG_DECIMAL.format(obj.chanceOfFailure());
row[UnitActionField.IS_BOT.ordinal()] = obj.bot() ? "1" : "0";

// For STEPS, join the list of MoveStepType values with a space.
row[UnitActionField.STEPS.ordinal()] = obj.steps().stream()
Expand Down Expand Up @@ -88,11 +88,13 @@ public UnitAction fromTsv(String line, int idOffset) throws NumberFormatExceptio
boolean jumping = "1".equals(parts[UnitActionField.JUMPING.ordinal()]);
boolean prone = "1".equals(parts[UnitActionField.PRONE.ordinal()]);
boolean legal = "1".equals(parts[UnitActionField.LEGAL.ordinal()]);
boolean bot = false;
int teamId = -1;
double chanceOfFailure = 0.0;
if (parts.length >= 23) {
teamId = Integer.parseInt(parts[UnitActionField.TEAM_ID.ordinal()]);
chanceOfFailure = Double.parseDouble(parts[UnitActionField.CHANCE_OF_FAILURE.ordinal()]);
bot = "1".equals(parts[UnitActionField.IS_BOT.ordinal()]);
}
// Convert the steps field (a space-separated list) back to a List of MoveStepType.
List<MovePath.MoveStepType> steps = Arrays.stream(
Expand Down Expand Up @@ -124,7 +126,8 @@ public UnitAction fromTsv(String line, int idOffset) throws NumberFormatExceptio
prone,
legal,
chanceOfFailure,
steps
steps,
bot
);
}

Expand Down
10 changes: 10 additions & 0 deletions megamek/src/megamek/ai/dataset/UnitState.java
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ public record UnitState(int id, GamePhase phase, int teamId, int round, int play
boolean offBoard, boolean crippled, boolean destroyed, double armorP,
double internalP, boolean done, int maxRange, int totalDamage, Entity entity) {

/**
* Creates a UnitState from an {@code entity}.
* @param entity The entity to which the state belongs
* @param game The game reference
* @return The UnitState
*/
public static UnitState fromEntity(Entity entity, Game game) {
return new UnitState(
entity.getId(),
Expand Down Expand Up @@ -79,6 +85,10 @@ public static UnitState fromEntity(Entity entity, Game game) {
entity);
}

/**
* Returns the position of the unit.
* @return The position
*/
public Coords position() {
return new Coords(x, y);
}
Expand Down
2 changes: 1 addition & 1 deletion megamek/src/megamek/ai/dataset/UnitStateSerde.java
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import java.util.Map;

/**
* serializer and deserializer for game dataset UnitState
* <p>Serializer and deserializer for UnitState to/from TSV format.</p>
* @author Luana Coppio
*/
public class UnitStateSerde extends TsvSerde<UnitState> {
Expand Down
Loading