-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Reorganize the code, prepare for Monaco GP
- Loading branch information
Showing
5 changed files
with
119 additions
and
207 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using Microsoft.ML; | ||
using Tyres.DataModels; | ||
using Tyres.DomainModels; | ||
using Tyres.StaticData; | ||
|
||
namespace Tyres.ViewHelpers | ||
{ | ||
public class DataHelper | ||
{ | ||
public static void PrintPredictionTable(PredictionEngine<TyreStint, TyreStintPrediction> predictionEngine, string trackName, float airTemperature, float trackTemperature, List<Top10Driver> drivers) | ||
{ | ||
Console.WriteLine(Season2021.Tracks[trackName].Name); | ||
Console.WriteLine("=========="); | ||
foreach (var d in drivers) | ||
{ | ||
var prediction = predictionEngine.Predict(new TyreStint() | ||
{ | ||
Track = Season2021.Tracks[trackName].Name, | ||
TrackLength = Season2021.Tracks[trackName].Distance, | ||
Team = d.Team, | ||
Car = d.Car, | ||
Driver = d.Name, | ||
Compound = d.StartingCompound, | ||
AirTemperature = airTemperature, | ||
TrackTemperature = trackTemperature, | ||
Reason = "Pit Stop" | ||
}); | ||
Console.WriteLine($"| {d.Name} | {d.StartingCompound} | {prediction.Distance } | | |"); | ||
} | ||
Console.WriteLine(""); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
using System; | ||
using System.Linq; | ||
using Microsoft.ML.AutoML; | ||
using Microsoft.ML.Data; | ||
|
||
namespace Tyres.ViewHelpers | ||
{ | ||
public class TrainingHelper | ||
{ | ||
public static void PrintTopModels(ExperimentResult<RegressionMetrics> experimentResult) | ||
{ | ||
// Get top few runs ranked by R-Squared. | ||
// R-Squared is a metric to maximize, so OrderByDescending() is correct. | ||
// For RMSE and other regression metrics, OrderByAscending() is correct. | ||
var topRuns = experimentResult.RunDetails | ||
.Where(r => r.ValidationMetrics != null && !double.IsNaN(r.ValidationMetrics.RSquared)) | ||
.OrderByDescending(r => r.ValidationMetrics.RSquared).Take(3); | ||
|
||
Console.WriteLine("Top models ranked by R-Squared --"); | ||
PrintRegressionMetricsHeader(); | ||
for (var i = 0; i < topRuns.Count(); i++) | ||
{ | ||
var run = topRuns.ElementAt(i); | ||
PrintIterationMetrics(i + 1, run.TrainerName, run.ValidationMetrics, run.RuntimeInSeconds); | ||
} | ||
} | ||
|
||
public static void PrintRegressionMetrics(string name, RegressionMetrics metrics) | ||
{ | ||
Console.WriteLine($"*************************************************"); | ||
Console.WriteLine($"* Metrics for {name} regression model "); | ||
Console.WriteLine($"*------------------------------------------------"); | ||
Console.WriteLine($"* LossFn: {metrics.LossFunction:0.##}"); | ||
Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}"); | ||
Console.WriteLine($"* Absolute loss: {metrics.MeanAbsoluteError:#.##}"); | ||
Console.WriteLine($"* Squared loss: {metrics.MeanSquaredError:#.##}"); | ||
Console.WriteLine($"* RMS loss: {metrics.RootMeanSquaredError:#.##}"); | ||
Console.WriteLine($"*************************************************"); | ||
} | ||
|
||
internal static void PrintRegressionMetricsHeader() | ||
{ | ||
CreateRow($"{"",-4} {"Trainer",-35} {"RSquared",8} {"Absolute-loss",13} {"Squared-loss",12} {"RMS-loss",8} {"Duration",9}", 114); | ||
} | ||
|
||
internal static void PrintIterationMetrics(int iteration, string trainerName, RegressionMetrics metrics, double? runtimeInSeconds) | ||
{ | ||
CreateRow($"{iteration,-4} {trainerName,-35} {metrics?.RSquared ?? double.NaN,8:F4} {metrics?.MeanAbsoluteError ?? double.NaN,13:F2} {metrics?.MeanSquaredError ?? double.NaN,12:F2} {metrics?.RootMeanSquaredError ?? double.NaN,8:F2} {runtimeInSeconds.Value,9:F1}", 114); | ||
} | ||
|
||
private static void CreateRow(string message, int width) | ||
{ | ||
Console.WriteLine("|" + message.PadRight(width - 2) + "|"); | ||
} | ||
} | ||
} |