-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathAssignment-4
63 lines (48 loc) · 1.78 KB
/
Assignment-4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import org.apache.spark.sql
import org.apache.spark.sql._
import org.apache.spark.sql.types._
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{VectorAssembler, StringIndexer}
import org.apache.spark.sql.SparkSession
/** Converts `s` to a Double, using -1.0 as a sentinel for some inputs.
  *
  * @param s text to convert
  * @return -1.0 when `s` is empty or contains a '.' character; otherwise the
  *         value of `s.toDouble` (which throws NumberFormatException for
  *         non-numeric text)
  * NOTE(review): this helper is not referenced anywhere in the visible script.
  */
def dot(s: String): Double =
  if (s.isEmpty || s.contains(".")) -1.0
  else s.toDouble
// Load the raw challenge CSV. The header row supplies column names and malformed
// rows are dropped. Without `inferSchema`, Spark's CSV reader yields string columns.
val rawDF = spark.read
  .format("csv")
  .option("header", "true")
  .option("mode", "DROPMALFORMED")
  .load("caravan-insurance-challenge.csv")
val colNames = rawDF.schema.names

// Cast columns 1 .. n-2 to IntegerType. The first and last columns are kept as-is.
// Doing every cast in a single `select` avoids the deep projection chain that
// calling `withColumn` once per column (~85 times) would build. Column order is preserved.
val lastCastIdx = colNames.length - 2
val projectedCols = colNames.indices.map { i =>
  val name = colNames(i)
  if (i >= 1 && i <= lastCastIdx) rawDF(name).cast(IntegerType).as(name)
  else rawDF(name)
}
val df = rawDF.select(projectedCols: _*)
/** Returns the rows of `df` whose ORIGIN column equals `origin`, with ORIGIN removed.
  *
  * An exact `===` comparison is used because `like` would treat `_` and `%` in the
  * pattern as wildcards.
  */
def splitByOrigin(origin: String): DataFrame =
  df.filter(df("ORIGIN") === origin).drop("ORIGIN")

val trainDF = splitByOrigin("train")
val testDF = splitByOrigin("test")
// Feature columns are indices 1 .. n-2 of the CSV header; the last column is the label.
// NOTE(review): presumably column 0 is ORIGIN (already dropped from both splits) — confirm.
val features = colNames.slice(1, colNames.length - 1)
val label = colNames(colNames.length - 1)
println("Features and labels identified")

// Pack the feature columns into the single vector column that LogisticRegression reads.
val assembler = new VectorAssembler().setInputCols(features).setOutputCol("features")
val assembledTrain = assembler.transform(trainDF)
val assembledTest = assembler.transform(testDF)

// Fit the label indexer ONCE, on the training split, and reuse the fitted model for
// both splits. StringIndexer assigns indices by label frequency in the data it is fit
// on, so fitting it separately on the test split could give the two splits different
// label encodings. With the default handleInvalid="error", a test label unseen during
// training will raise instead of being silently re-mapped.
val labelIndexer = new StringIndexer().setInputCol(label).setOutputCol("label")
val labelModel = labelIndexer.fit(assembledTrain)
val trainDF2 = labelModel.transform(assembledTrain)
val testDF2 = labelModel.transform(assembledTest)
// Train a logistic-regression classifier with default parameters. It reads the
// "features" and "label" columns prepared above.
println("Building Model")
val estimator = new LogisticRegression()
val model = estimator.fit(trainDF2)

// Score the test split and display the features, true label and predicted label.
println("Model Built. Making Predictions.")
val predictions = model.transform(testDF2)
predictions
  .select(predictions("features"), predictions("label"), predictions("prediction"))
  .show()