-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.rs
113 lines (92 loc) · 3.26 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
use rustograd::engine::*;
use rustograd::nn::*;
use rand::seq::SliceRandom;
use std::fs::File;
use std::io::{BufRead, BufReader, Write};
/// Reads a CSV dataset whose first line is a header and whose remaining rows
/// look like `x0,x1,label`.
///
/// Returns `(features, labels)`: each feature row holds the first two columns
/// and each label is the third column. Blank lines are skipped and every field
/// is trimmed, so `0.5, -1.25, 1` parses the same as `0.5,-1.25,1`. Columns
/// beyond the third are ignored.
///
/// # Panics
///
/// Panics if the file cannot be opened or read, if a data row has fewer than
/// three columns, or if one of the first three fields is not a valid `f64`.
/// Each message names the 1-based line number and the offending text.
fn read_dataset(filename: &str) -> (Vec<Vec<f64>>, Vec<f64>) {
    let file = File::open(filename)
        .unwrap_or_else(|e| panic!("Unable to open file '{}': {}", filename, e));
    let reader = BufReader::new(file);
    let mut x = Vec::new();
    let mut y = Vec::new();
    // Enumerate before skipping so the line numbers in error messages match
    // the file; `skip(1)` drops the header row.
    for (idx, line) in reader.lines().enumerate().skip(1) {
        let line_no = idx + 1;
        let line = line.unwrap_or_else(|e| panic!("Unable to read line {}: {}", line_no, e));
        let line = line.trim();
        if line.is_empty() {
            // Tolerate blank lines instead of failing to parse "".
            continue;
        }
        let fields: Vec<&str> = line.split(',').map(str::trim).collect();
        assert!(
            fields.len() >= 3,
            "Line {}: expected at least 3 columns, found {}: '{}'",
            line_no,
            fields.len(),
            line
        );
        let parse = |s: &str| -> f64 {
            s.parse()
                .unwrap_or_else(|e| panic!("Parse error on line {} ('{}'): {}", line_no, s, e))
        };
        x.push(vec![parse(fields[0]), parse(fields[1])]);
        y.push(parse(fields[2]));
    }
    (x, y)
}
/// Trains a 2-16-16-1 MLP on `moon_dataset.csv` using per-sample SGD with a
/// hinge loss. It then evaluates the trained model on a 61x61 grid covering
/// [-3, 3] x [-3, 3] and writes the raw outputs to `decision_boundary_data2.csv`
/// for an external plotting script.
///
/// # Panics
///
/// Panics if the dataset cannot be read or parsed, or if the output file
/// cannot be created or written.
fn main() {
    // Read the dataset
    let (x, y) = read_dataset("moon_dataset.csv");
    let n_samples = x.len();

    // Create the MLP: 2 inputs, two hidden layers of 16, one output.
    let mut model = MLP::new(2, &[16, 16, 1]);
    println!("Model: {:?}", model);
    println!("Number of parameters: {}", model.parameters().len());

    // Training hyper-parameters, hoisted out of the per-sample loop.
    let n_epochs = 50;
    let learning_rate = 0.01;
    let mut rng = rand::thread_rng();
    // Built once and reshuffled in place every epoch.
    let mut indices: Vec<usize> = (0..n_samples).collect();

    for epoch in 0..n_epochs {
        // Accumulate the epoch loss as a plain number. Summing `ValueWrapper`s
        // here would likely chain one autograd node per sample and keep every
        // sample's forward graph alive until the epoch ends, just to print it.
        let mut total_loss = 0.0;
        let mut correct = 0;

        indices.shuffle(&mut rng);
        for &i in &indices {
            // Convert input to ValueWrapper
            let input: Vec<ValueWrapper> = x[i].iter().map(|&v| ValueWrapper::new(v)).collect();
            let target = y[i];

            // Forward pass
            let output = model.call(input)[0].clone();

            // Hinge loss: max(0, 1 - target * output).
            // NOTE(review): assumes labels are -1/+1; with 0/1 labels both this
            // loss and the accuracy check below are wrong — confirm against the
            // script that generated moon_dataset.csv.
            let mut loss = (ValueWrapper::new(1.0) + (-target * output.clone())).relu();
            total_loss += loss.0.borrow().data;

            // Compute accuracy
            if (output.0.borrow().data.signum() as i32) == (target as i32) {
                correct += 1;
            }

            // Backward pass
            model.zero_grad();
            loss.backward();

            // Plain SGD update, applied after every sample.
            for p in model.parameters() {
                let mut p_mut = p.0.borrow_mut();
                p_mut.data -= learning_rate * p_mut.grad;
            }
        }

        // Report every 5th epoch and always the final one.
        if epoch % 5 == 0 || epoch == n_epochs - 1 {
            println!(
                "Epoch {}: Loss = {:.4}, Accuracy = {}/{}",
                epoch, total_loss, correct, n_samples
            );
        }
    }

    // Generate points for visualization: a 61x61 grid over [-3, 3] x [-3, 3].
    let output_file_name = "decision_boundary_data2.csv";
    let file = File::create(output_file_name).expect("Unable to create output file");
    // Buffer the ~3.7k small rows instead of issuing one write per row.
    let mut writer = std::io::BufWriter::new(file);
    writeln!(writer, "x,y,z").expect("Unable to write output file");
    let step = 0.1;
    for j in -30..=30 {
        for i in -30..=30 {
            let px = i as f64 * step;
            let py = j as f64 * step;
            let input = vec![ValueWrapper::new(px), ValueWrapper::new(py)];
            let output = model.call(input)[0].clone();
            writeln!(writer, "{},{},{}", px, py, output.0.borrow().data)
                .expect("Unable to write output file");
        }
    }
    // Flush explicitly: BufWriter's Drop silently discards write errors.
    writer.flush().expect("Unable to flush output file");

    println!("Decision boundary data saved to '{}'", output_file_name);
    println!("Use the Python visualization script to see the results.");
}