Skip to content

Commit

Permalink
Working with the Distance baked into dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
mlusiak committed May 23, 2021
1 parent 5614143 commit b90a700
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 57 deletions.
31 changes: 2 additions & 29 deletions src/Tyres/DataModels/TyreStint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,40 +24,13 @@ public class TyreStint
public float TrackTemperature;
[LoadColumn(10)]
public string Reason;
[LoadColumn(11)]
public float Laps;
}

public class DistanceTyreStint : TyreStint
{
public float Distance { get; set; }
[LoadColumn(12)]
public float Distance;
}

public class TyreStintPrediction
{
[ColumnName("Score")]
public float Distance;
}

public class LapsToDistanceInput
{
public float TrackLength { get; set; }
public float Laps { get; set; }
}

public class LapsToDistanceOutput
{
public float Distance { get; set; }
}

[CustomMappingFactoryAttribute(nameof(CustomMappings.DistanceMapping))]
public class CustomMappings : CustomMappingFactory<LapsToDistanceInput, LapsToDistanceOutput>
{
public static void DistanceMapping(LapsToDistanceInput input, LapsToDistanceOutput output) => output.Distance = input.Laps * input.TrackLength;

public override Action<LapsToDistanceInput, LapsToDistanceOutput> GetMapping()
{
return DistanceMapping;
}
}
}
2 changes: 1 addition & 1 deletion src/Tyres/DomainModels/Track.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace Tyres.DomainModels
public class Track
{
public string Name { get; set; }
public float Distance { get; set; }
public float TrackLength { get; set; }
}

public class Race
Expand Down
28 changes: 10 additions & 18 deletions src/Tyres/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ public class Program
static void Main(string[] args)
{
var mlContext = new MLContext(seed: 0);
mlContext.ComponentCatalog.RegisterAssembly(typeof(CustomMappings).Assembly);

// Load data
var data = mlContext.Data.LoadFromTextFile<TyreStint>(DatasetsLocation, ';', true);
Expand All @@ -27,36 +26,29 @@ static void Main(string[] args)
var filtered = mlContext.Data.FilterByCustomPredicate(data, (TyreStint row) => !(row.Reason.Equals("Pit Stop") || row.Reason.Equals("Race Finish")) );
var debug = mlContext.Data.CreateEnumerable<TyreStint>(filtered, reuseRowObject: false).Count();

// Transforming to add distancefeature
var customTransformer = mlContext.Transforms.CustomMapping<LapsToDistanceInput, LapsToDistanceOutput>(CustomMappings.DistanceMapping, "LapsToDistance");
var transformed = customTransformer.Fit(filtered).Transform(filtered);
var debug2 = mlContext.Data.CreateEnumerable<DistanceTyreStint>(transformed, reuseRowObject: false).ToList();

// Divide dataset into training and testing data
var split = mlContext.Data.TrainTestSplit(filtered, testFraction: 0.1);
var trainingData = split.TrainSet;
var testingData = split.TestSet;


// Run AutoML experiment
var experimentTime = 60u;
var experimentTime = 900u;
Console.WriteLine("=============== Training the model ===============");
Console.WriteLine($"Running AutoML regression experiment for {experimentTime} seconds...");
var experimentResult = mlContext.Auto()
.CreateRegressionExperiment(experimentTime)
.Execute(trainingData, testingData,
columnInformation: new ColumnInformation()
{
CategoricalColumnNames = { nameof(DistanceTyreStint.Team), nameof(DistanceTyreStint.Car), nameof(DistanceTyreStint.Driver), nameof(DistanceTyreStint.Compound), nameof(DistanceTyreStint.Reason) },
NumericColumnNames = { nameof(DistanceTyreStint.AirTemperature), nameof(DistanceTyreStint.TrackTemperature) },
LabelColumnName = nameof(DistanceTyreStint.Distance)
},
preFeaturizer: customTransformer
);
CategoricalColumnNames = { nameof(TyreStint.Team), nameof(TyreStint.Car), nameof(TyreStint.Driver), nameof(TyreStint.Compound), nameof(TyreStint.Reason) },
NumericColumnNames = { nameof(TyreStint.AirTemperature), nameof(TyreStint.TrackTemperature) },
LabelColumnName = nameof(TyreStint.Distance)
}
);


// Print top models found by AutoML
Console.WriteLine();
TrainingHelper.PrintTopModels(experimentResult);

Console.WriteLine("===== Evaluating model's accuracy with test data =====");
Expand All @@ -65,18 +57,18 @@ static void Main(string[] args)
var trainedModel = best.Model;
var predictions = trainedModel.Transform(testingData);

var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: nameof(DistanceTyreStint.Distance), scoreColumnName: "Score");
var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: nameof(TyreStint.Distance), scoreColumnName: "Score");
TrainingHelper.PrintRegressionMetrics(best.TrainerName, metrics);


// Run sample predictions
var predictionEngine = mlContext.Model.CreatePredictionEngine<TyreStint, TyreStintPrediction>(trainedModel);

var lh = new DistanceTyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Mercedes", Car = "W12", Driver = "Lewis Hamilton", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" };
var lh = new TyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Mercedes", Car = "W12", Driver = "Lewis Hamilton", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" };
var lhPred = predictionEngine.Predict(lh);
var lhLaps = lhPred.Distance / 5412f;

var mv = new DistanceTyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Red Bull", Car = "RB16B", Driver = "Max Verstappen", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" };
var mv = new TyreStint() { Track = "Bahrain International Circuit", TrackLength = 5412f, Team = "Red Bull", Car = "RB16B", Driver = "Max Verstappen", Compound = "C3", AirTemperature = 20.5f, TrackTemperature = 28.3f, Reason = "Pit Stop" };
var mvPred = predictionEngine.Predict(mv);
var mvLaps = mvPred.Distance / 5412f;

Expand All @@ -97,7 +89,7 @@ static void Main(string[] args)
new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Esteban Ocon", StartingCompound = "C5"},
};

DataHelper.PrintPredictionTable(predictionEngine, "Monaco", 25.0f, 43.7f, top10Monaco);
DataHelper.PrintPredictionTable(predictionEngine, "Monaco", 20.0f, 35.0f, top10Monaco);
}
}
}
10 changes: 5 additions & 5 deletions src/Tyres/StaticData/Season2021.cs
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ public static class Season2021

public static Dictionary<string, Track> Tracks = new Dictionary<string, Track>()
{
{ "Bahrain", new Track() {Name = "Bahrain International Circuit", Distance = 5412f} },
{ "Imola", new Track() {Name = "Imola", Distance = 4909f } },
{ "Portimão", new Track() {Name = "Portimão", Distance = 4653f } },
{ "Catalunya", new Track() {Name = "Catalunya", Distance = 4675f } },
{ "Monaco", new Track() {Name = "Monaco", Distance = 3337f } }
{ "Bahrain", new Track() {Name = "Bahrain International Circuit", TrackLength = 5412f} },
{ "Imola", new Track() {Name = "Imola", TrackLength = 4909f } },
{ "Portimão", new Track() {Name = "Portimão", TrackLength = 4653f } },
{ "Catalunya", new Track() {Name = "Catalunya", TrackLength = 4675f } },
{ "Monaco", new Track() {Name = "Monaco", TrackLength = 3337f } }
};
}
}
10 changes: 6 additions & 4 deletions src/Tyres/ViewHelpers/DataHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,16 @@ public class DataHelper
{
public static void PrintPredictionTable(PredictionEngine<TyreStint, TyreStintPrediction> predictionEngine, string trackName, float airTemperature, float trackTemperature, List<Top10Driver> drivers)
{
Console.WriteLine(Season2021.Tracks[trackName].Name);
var monaco = Season2021.Tracks[trackName];

Console.WriteLine(monaco.Name);
Console.WriteLine("==========");
foreach (var d in drivers)
{
var prediction = predictionEngine.Predict(new TyreStint()
{
Track = Season2021.Tracks[trackName].Name,
TrackLength = Season2021.Tracks[trackName].Distance,
Track = monaco.Name,
TrackLength = monaco.TrackLength,
Team = d.Team,
Car = d.Car,
Driver = d.Name,
Expand All @@ -27,7 +29,7 @@ public static void PrintPredictionTable(PredictionEngine<TyreStint, TyreStintPre
TrackTemperature = trackTemperature,
Reason = "Pit Stop"
});
Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance } | | |");
Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance / monaco.TrackLength } | | |");
}
Console.WriteLine("");
}
Expand Down
1 change: 1 addition & 0 deletions src/Tyres/ViewHelpers/TrainingHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ public class TrainingHelper
{
public static void PrintTopModels(ExperimentResult<RegressionMetrics> experimentResult)
{
Console.WriteLine();
// Get top few runs ranked by R-Squared.
// R-Squared is a metric to maximize, so OrderByDescending() is correct.
// For RMSE and other regression metrics, OrderByAscending() is correct.
Expand Down

0 comments on commit b90a700

Please sign in to comment.