Skip to content

Commit

Permalink
Reorganize the code, prepare for Monaco GP
Browse files Browse the repository at this point in the history
  • Loading branch information
mlusiak committed May 23, 2021
1 parent b9baca0 commit cfce520
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 207 deletions.
5 changes: 5 additions & 0 deletions src/Tyres/DomainModels/Driver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,9 @@ public class Driver
public string Car { get; set; }

}

public class Top10Driver : Driver
{
public string StartingCompound { get; set; }
}
}
229 changes: 22 additions & 207 deletions src/Tyres/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
using Tyres.DataModels;
using Tyres.DomainModels;
using Tyres.StaticData;
using Tyres.ViewHelpers;

namespace Tyres
{
public class Program
{
private static string DatasetsLocation = @"../../../../../data/Tyres.csv";
private static string DatasetsLocation = @"../../../../../data/TyreStints.csv";

static void Main(string[] args)
{
Expand Down Expand Up @@ -62,26 +63,25 @@ static void Main(string[] args)
*/

uint experimentTime = 120;
var experimentTime = 600u;
Console.WriteLine("=============== Training the model ===============");
Console.WriteLine($"Running AutoML regression experiment for {experimentTime} seconds...");
ExperimentResult<RegressionMetrics> experimentResult = mlContext.Auto()
var experimentResult = mlContext.Auto()
.CreateRegressionExperiment(experimentTime)
.Execute(trainingData, progressHandler: null, labelColumnName: "Laps");

// Print top models found by AutoML
Console.WriteLine();
PrintTopModels(experimentResult);
TrainingHelper.PrintTopModels(experimentResult);

Console.WriteLine("===== Evaluating model's accuracy with test data =====");
RunDetail<RegressionMetrics> best = experimentResult.BestRun;
var best = experimentResult.BestRun;

ITransformer trainedModel = best.Model;
IDataView predictions = trainedModel.Transform(testingData);
var trainedModel = best.Model;
var predictions = trainedModel.Transform(testingData);

var metrics = mlContext.Regression.Evaluate(predictions, labelColumnName: "Laps", scoreColumnName: "Score");
PrintRegressionMetrics(best.TrainerName, metrics);

TrainingHelper.PrintRegressionMetrics(best.TrainerName, metrics);


// Run sample predictions
Expand All @@ -96,208 +96,23 @@ static void Main(string[] args)
var mvLaps = mvPred.Distance / 5412f;


// Run predictions for Bahrain
var bahrain2021 = new Race()
{
Track = Season2021.Tracks["Bahrain"],
Drivers = Season2021.Drivers,
TyreCompounds = new List<string>() { "C2", "C3", "C4" },
AirTemperature = 20.5f,
TrackTemperature = 28.3f
};
//PrintAllPredictions(predictionEngine, bahrain2021);

PrintTop10Catalunya(predictionEngine);
}

private static void PrintAllPredictions(PredictionEngine<TyreStint, TyreStintPrediction> predictionEngine, Race race)
{
foreach (var d in race.Drivers)
{
foreach (var c in race.TyreCompounds)
{
var prediction = predictionEngine.Predict(new TyreStint()
{
Track = race.Track.Name,
TrackLength = race.Track.Distance,
Team = d.Team,
Car = d.Car,
Driver = d.Name,
Compound = c,
AirTemperature = race.AirTemperature,
TrackTemperature = race.TrackTemperature,
Reason = "Pit Stop"
});
Console.WriteLine($"| {d.Name} | {c} | {prediction.Distance / race.Track.Distance} | |");
}
}

}


private static void PrintTop10Catalunya(PredictionEngine<TyreStint, TyreStintPrediction> predictionEngine)
{
var Top10Catalunya = new List<Top10Driver>()
var top10Monaco = new List<Top10Driver>()
{
new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Max Verstappen", StartingCompound = "C3"},
new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Valtteri Bottas", StartingCompound = "C3"},
new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Lewis Hamilton", StartingCompound = "C3"},
new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Carlos Sainz", StartingCompound = "C3"},
new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Sergio Pérez", StartingCompound = "C3"},
new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Lando Norris", StartingCompound = "C3"},
new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Charles Leclerc", StartingCompound = "C3"},
new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Daniel Ricciardo", StartingCompound = "C3"},
new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Esteban Ocon", StartingCompound = "C3"},
new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Fernando Alonso", StartingCompound = "C3"},
new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Charles Leclerc", StartingCompound = "C5"},
new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Max Verstappen", StartingCompound = "C5"},
new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Valtteri Bottas", StartingCompound = "C5"},
new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Carlos Sainz", StartingCompound = "C5"},
new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Lando Norris", StartingCompound = "C5"},
new Top10Driver() {Team = "Toro Rosso / AlphaTauri", Car = "AT02", Name = "Pierre Gasly", StartingCompound = "C5"},
new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Lewis Hamilton", StartingCompound = "C5"},
new Top10Driver() {Team = "Force India / Racing Point / Aston Martin", Car = "AMR21", Name = "Sebastian Vettel", StartingCompound = "C5"},
new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Sergio Pérez", StartingCompound = "C5"},
new Top10Driver() {Team = "Sauber / Alfa Romeo", Car = "C41", Name = "Antonio Giovinazzi", StartingCompound = "C5"},
new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Esteban Ocon", StartingCompound = "C5"},
};


Console.WriteLine("Catalunya");
Console.WriteLine("==========");
foreach (var d in Top10Catalunya)
{
var prediction = predictionEngine.Predict(new TyreStint()
{
Track = Season2021.Tracks["Catalunya"].Name,
TrackLength = Season2021.Tracks["Catalunya"].Distance,
Team = d.Team,
Car = d.Car,
Driver = d.Name,
Compound = d.StartingCompound,
AirTemperature = 25.0f,
TrackTemperature = 43.7f,
Reason = "Pit Stop"
});
Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance } | | |");
}
Console.WriteLine("");
DataHelper.PrintPredictionTable(predictionEngine, "Monaco", 25.0f, 43.7f, top10Monaco);
}



private static void PrintTop10ImolaPortimao(PredictionEngine<TyreStint, TyreStintPrediction> predictionEngine)
{
var ImolaTop10 = new List<Top10Driver>()
{
new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Lewis Hamilton", StartingCompound = "C3"},
new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Sergio Pérez", StartingCompound = "C4"},
new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Max Verstappen", StartingCompound = "C3"},
new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Charles Leclerc", StartingCompound = "C4"},
new Top10Driver() {Team = "Toro Rosso / AlphaTauri", Car = "AT02", Name = "Pierre Gasly", StartingCompound = "C4"},
new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Daniel Ricciardo", StartingCompound = "C4"},
new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Lando Norris", StartingCompound = "C4"},
new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Valtteri Bottas", StartingCompound = "C3"},
new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Esteban Ocon", StartingCompound = "C4"},
new Top10Driver() {Team = "Force India / Racing Point / Aston Martin", Car = "AMR21", Name = "Lance Stroll", StartingCompound = "C4"},
};

var PortimaoTop10 = new List<Top10Driver>()
{
new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Valtteri Bottas", StartingCompound = "C2"},
new Top10Driver() {Team = "Mercedes", Car = "W12", Name = "Lewis Hamilton", StartingCompound = "C2"},
new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Max Verstappen", StartingCompound = "C2"},
new Top10Driver() {Team = "Red Bull", Car = "RB16B", Name = "Sergio Pérez", StartingCompound = "C2"},
new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Carlos Sainz", StartingCompound = "C3"},
new Top10Driver() {Team = "Renault / Alpine", Car = "A521", Name = "Esteban Ocon", StartingCompound = "C3"},
new Top10Driver() {Team = "McLaren", Car = "MCL35M", Name = "Lando Norris", StartingCompound = "C3"},
new Top10Driver() {Team = "Ferrari", Car = "SF21", Name = "Charles Leclerc", StartingCompound = "C2"},
new Top10Driver() {Team = "Toro Rosso / AlphaTauri", Car = "AT02", Name = "Pierre Gasly", StartingCompound = "C3"},
new Top10Driver() {Team = "Force India / Racing Point / Aston Martin", Car = "AMR21", Name = "Sebastian Vettel", StartingCompound = "C3"},
};


Console.WriteLine("Imola");
Console.WriteLine("==========");
foreach (var d in ImolaTop10)
{
var prediction = predictionEngine.Predict(new TyreStint()
{
Track = Season2021.Tracks["Imola"].Name,
TrackLength = Season2021.Tracks["Imola"].Distance,
Team = d.Team,
Car = d.Car,
Driver = d.Name,
Compound = d.StartingCompound,
AirTemperature = 9.3f,
TrackTemperature = 17.5f,
Reason = "Pit Stop"
});
Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance / Season2021.Tracks["Imola"].Distance} | | |");
}
Console.WriteLine("");

Console.WriteLine("Portimão");
Console.WriteLine("==========");
foreach (var d in PortimaoTop10)
{
var prediction = predictionEngine.Predict(new TyreStint()
{
Track = Season2021.Tracks["Portimão"].Name,
TrackLength = Season2021.Tracks["Portimão"].Distance,
Team = d.Team,
Car = d.Car,
Driver = d.Name,
Compound = d.StartingCompound,
AirTemperature = 19.8f,
TrackTemperature = 40.3f,
Reason = "Pit Stop"
});
Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance / Season2021.Tracks["Portimão"].Distance} | | |");
}
}

private static void PrintTopModels(ExperimentResult<RegressionMetrics> experimentResult)
{
// Get top few runs ranked by R-Squared.
// R-Squared is a metric to maximize, so OrderByDescending() is correct.
// For RMSE and other regression metrics, OrderByAscending() is correct.
var topRuns = experimentResult.RunDetails
.Where(r => r.ValidationMetrics != null && !double.IsNaN(r.ValidationMetrics.RSquared))
.OrderByDescending(r => r.ValidationMetrics.RSquared).Take(3);

Console.WriteLine("Top models ranked by R-Squared --");
PrintRegressionMetricsHeader();
for (var i = 0; i < topRuns.Count(); i++)
{
var run = topRuns.ElementAt(i);
PrintIterationMetrics(i + 1, run.TrainerName, run.ValidationMetrics, run.RuntimeInSeconds);
}
}

public static void PrintRegressionMetrics(string name, RegressionMetrics metrics)
{
Console.WriteLine($"*************************************************");
Console.WriteLine($"* Metrics for {name} regression model ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"* LossFn: {metrics.LossFunction:0.##}");
Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
Console.WriteLine($"* Absolute loss: {metrics.MeanAbsoluteError:#.##}");
Console.WriteLine($"* Squared loss: {metrics.MeanSquaredError:#.##}");
Console.WriteLine($"* RMS loss: {metrics.RootMeanSquaredError:#.##}");
Console.WriteLine($"*************************************************");
}

internal static void PrintRegressionMetricsHeader()
{
CreateRow($"{"",-4} {"Trainer",-35} {"RSquared",8} {"Absolute-loss",13} {"Squared-loss",12} {"RMS-loss",8} {"Duration",9}", 114);
}

internal static void PrintIterationMetrics(int iteration, string trainerName, RegressionMetrics metrics, double? runtimeInSeconds)
{
CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,8:F4} {metrics?.MeanAbsoluteError ?? double.NaN,13:F2} {metrics?.MeanSquaredError ?? double.NaN,12:F2} {metrics?.RootMeanSquaredError ?? double.NaN,8:F2} {runtimeInSeconds.Value,9:F1}", 114);
}

private static void CreateRow(string message, int width)
{
Console.WriteLine("|" + message.PadRight(width - 2) + "|");
}



}

public class Top10Driver : Driver
{
public string StartingCompound { get; set; }
}
}
1 change: 1 addition & 0 deletions src/Tyres/StaticData/Season2021.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public static class Season2021
{ "Imola", new Track() {Name = "Imola", Distance = 4909f } },
{ "Portimão", new Track() {Name = "Portimão", Distance = 4653f } },
{ "Catalunya", new Track() {Name = "Catalunya", Distance = 4675f } },
{ "Monaco", new Track() {Name = "Monaco", Distance = 3337f } }
};
}
}
35 changes: 35 additions & 0 deletions src/Tyres/ViewHelpers/DataHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
using System;
using System.Collections.Generic;
using Microsoft.ML;
using Tyres.DataModels;
using Tyres.DomainModels;
using Tyres.StaticData;

namespace Tyres.ViewHelpers
{
public class DataHelper
{
public static void PrintPredictionTable(PredictionEngine<TyreStint, TyreStintPrediction> predictionEngine, string trackName, float airTemperature, float trackTemperature, List<Top10Driver> drivers)
{
Console.WriteLine(Season2021.Tracks[trackName].Name);
Console.WriteLine("==========");
foreach (var d in drivers)
{
var prediction = predictionEngine.Predict(new TyreStint()
{
Track = Season2021.Tracks[trackName].Name,
TrackLength = Season2021.Tracks[trackName].Distance,
Team = d.Team,
Car = d.Car,
Driver = d.Name,
Compound = d.StartingCompound,
AirTemperature = airTemperature,
TrackTemperature = trackTemperature,
Reason = "Pit Stop"
});
Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance } | | |");
}
Console.WriteLine("");
}
}
}
56 changes: 56 additions & 0 deletions src/Tyres/ViewHelpers/TrainingHelper.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
using System;
using System.Linq;
using Microsoft.ML.AutoML;
using Microsoft.ML.Data;

namespace Tyres.ViewHelpers
{
public class TrainingHelper
{
public static void PrintTopModels(ExperimentResult<RegressionMetrics> experimentResult)
{
// Get top few runs ranked by R-Squared.
// R-Squared is a metric to maximize, so OrderByDescending() is correct.
// For RMSE and other regression metrics, OrderByAscending() is correct.
var topRuns = experimentResult.RunDetails
.Where(r => r.ValidationMetrics != null && !double.IsNaN(r.ValidationMetrics.RSquared))
.OrderByDescending(r => r.ValidationMetrics.RSquared).Take(3);

Console.WriteLine("Top models ranked by R-Squared --");
PrintRegressionMetricsHeader();
for (var i = 0; i < topRuns.Count(); i++)
{
var run = topRuns.ElementAt(i);
PrintIterationMetrics(i + 1, run.TrainerName, run.ValidationMetrics, run.RuntimeInSeconds);
}
}

public static void PrintRegressionMetrics(string name, RegressionMetrics metrics)
{
Console.WriteLine($"*************************************************");
Console.WriteLine($"* Metrics for {name} regression model ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"* LossFn: {metrics.LossFunction:0.##}");
Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
Console.WriteLine($"* Absolute loss: {metrics.MeanAbsoluteError:#.##}");
Console.WriteLine($"* Squared loss: {metrics.MeanSquaredError:#.##}");
Console.WriteLine($"* RMS loss: {metrics.RootMeanSquaredError:#.##}");
Console.WriteLine($"*************************************************");
}

internal static void PrintRegressionMetricsHeader()
{
CreateRow($"{"",-4} {"Trainer",-35} {"RSquared",8} {"Absolute-loss",13} {"Squared-loss",12} {"RMS-loss",8} {"Duration",9}", 114);
}

internal static void PrintIterationMetrics(int iteration, string trainerName, RegressionMetrics metrics, double? runtimeInSeconds)
{
CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,8:F4} {metrics?.MeanAbsoluteError ?? double.NaN,13:F2} {metrics?.MeanSquaredError ?? double.NaN,12:F2} {metrics?.RootMeanSquaredError ?? double.NaN,8:F2} {runtimeInSeconds.Value,9:F1}", 114);
}

private static void CreateRow(string message, int width)
{
Console.WriteLine("|" + message.PadRight(width - 2) + "|");
}
}
}

0 comments on commit cfce520

Please sign in to comment.