-
Notifications
You must be signed in to change notification settings - Fork 3
/
correlation_coeff.py
570 lines (536 loc) · 50.1 KB
/
correlation_coeff.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
"""
MIT License
Copyright (c) 2020 Shantanu Ghosh
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import numpy as np
import scipy.stats
from Propensity_score_LR import Propensity_socre_LR
from Sparse_Propensity_score import Sparse_Propensity_score
from Utils import Utils
from dataloader import DataLoader
from shallow_train import shallow_train
class Correlation:
def correlation_coeff(self):
device = Utils.get_device()
csv_path = "Dataset/ihdp_sample.csv"
dL = DataLoader()
np_covariates_X, np_covariates_Y = dL.preprocess_for_graphs(csv_path)
ps_train_set = dL.convert_to_tensor(np_covariates_X, np_covariates_Y)
# ps_score_list_SAE_e2e, ps_score_list_SAE_stacked_all_layer_active, \
# ps_score_list_SAE_stacked_cur_layer_active = self.__train_propensity_net_SAE(ps_train_set, device)
#
# ps_list_25_10_ = [0.3365058898925781, 0.1499626338481903, 0.1780136674642563, 0.014368762262165546,
# 0.03886539489030838, 0.20776931941509247, 0.23725758492946625, 0.32863256335258484,
# 0.09364322572946548, 0.012019762769341469, 0.18714001774787903, 0.0018665395909920335,
# 0.0009031076915562153, 0.7351804971694946, 0.34464964270591736, 0.05372120812535286,
# 0.05463121086359024, 0.027270885184407234, 0.9168853759765625, 0.015908556059002876,
# 0.012545470148324966, 0.1441732794046402, 0.814822793006897, 0.007629018276929855,
# 0.018120190128684044, 0.17656627297401428, 0.24364547431468964, 0.1578308790922165,
# 0.1086936742067337, 0.1105617806315422, 0.2680777609348297, 0.0032791185658425093,
# 0.011415788903832436, 0.0007648891769349575, 0.09267505258321762, 0.03489454835653305,
# 0.28469085693359375, 0.14813412725925446, 0.03304534777998924, 0.6469749212265015,
# 0.04388434439897537, 0.545077383518219, 0.7376824021339417, 0.6212267875671387,
# 0.7427332997322083, 0.054039932787418365, 0.3192288875579834, 0.007769970688968897,
# 0.10768510401248932, 0.3792364001274109, 0.4514809548854828, 0.20769090950489044,
# 0.05112271010875702, 0.49506205320358276, 0.44735732674598694, 0.01088651642203331,
# 0.02849946916103363, 0.19301249086856842, 0.11508192867040634, 0.3784361779689789,
# 0.001563536236062646, 0.07678636163473129, 0.25700536370277405, 0.45361462235450745,
# 0.004687300883233547, 0.0033183200284838676, 0.019274068996310234, 0.15051142871379852,
# 0.004528303164988756, 0.0018420269479975104, 0.01595301926136017, 0.08681289106607437,
# 0.22254373133182526, 0.018458187580108643, 0.08991475403308868, 0.947604238986969,
# 0.6897139549255371, 0.0030965672340244055, 0.026209598407149315, 0.6746258735656738,
# 0.08136255294084549, 0.12220372259616852, 0.036555200815200806, 0.10239823162555695,
# 0.04042569547891617, 0.3926636576652527, 0.2064315229654312, 0.20770993828773499,
# 0.06455257534980774, 0.01789627969264984, 0.5330932140350342, 0.4693591892719269,
# 0.013274206779897213, 0.011393911205232143, 0.2535511553287506, 0.2980523407459259,
# 0.0027889872435480356, 0.0545501746237278, 0.2769225537776947, 0.2731384336948395,
# 0.004256804008036852, 0.0023430141154676676, 0.0005068830214440823, 0.000699451076798141,
# 0.0022647250443696976, 0.0003424181486479938, 0.022324901074171066, 0.0013551500160247087,
# 0.7797508239746094, 0.6964963674545288, 0.014664944261312485, 0.02154652401804924,
# 0.0856279656291008, 0.14901261031627655, 0.010696598328649998, 0.04867595434188843,
# 0.03504950925707817, 0.13861963152885437, 0.0028549698181450367, 0.04880546033382416,
# 0.011075888760387897, 0.00708325020968914, 0.0010257628746330738, 0.055154118686914444,
# 0.012662780471146107, 0.07496662437915802, 0.0013854283606633544, 0.0028660844545811415,
# 0.09046581387519836, 0.0011151449289172888, 0.028155051171779633, 0.30173027515411377,
# 0.06071389839053154, 0.0016787570202723145, 0.001979523804038763, 0.0028057654853910208,
# 0.015017005614936352, 0.010070901364088058, 0.006253891158849001, 0.21099860966205597,
# 0.022334696725010872, 0.026325948536396027, 0.1336345672607422, 0.05553610250353813,
# 0.05453195050358772, 0.00019252521451562643, 0.02985788695514202, 0.0018535612616688013,
# 0.026285812258720398, 0.04275912046432495, 0.0023135803639888763, 0.0018321247771382332,
# 0.0006235734326764941, 0.00591914402320981, 0.2980855703353882, 0.17876943945884705,
# 0.0014780719066038728, 0.00849723070859909, 0.008124170824885368, 0.06357981264591217,
# 0.0008707551751285791, 0.0006862155860289931, 0.0020929002203047276, 0.004281805362552404,
# 0.008322462439537048, 0.0003459450672380626, 0.3390841484069824, 0.7481536269187927,
# 0.5106906890869141, 0.007757176645100117, 0.014206422492861748, 0.0008175206603482366,
# 0.2657463550567627, 0.0007753681857138872, 0.0255946796387434, 0.008109605871140957,
# 7.611249020555988e-05, 0.018274854868650436, 0.0005144730675965548, 0.006225120276212692,
# 0.0014877908397465944, 0.012845293618738651, 0.00511661171913147, 0.001515206997282803,
# 0.20127516984939575, 0.3089359998703003, 0.010438602417707443, 0.012137453071773052,
# 0.19797934591770172, 0.51243656873703, 0.011713785119354725, 0.36488205194473267,
# 0.41408273577690125, 0.0014261336764320731, 0.0037824432365596294, 0.0035551516339182854,
# 0.15636710822582245, 0.03848598524928093, 0.028867216780781746, 0.007937123067677021,
# 0.0036818101070821285, 0.01173222716897726, 0.37630602717399597, 0.019481666386127472,
# 0.1464228481054306, 0.6574530601501465, 0.0071101500652730465, 0.22448740899562836,
# 0.09596722573041916, 0.31040894985198975, 0.04507742449641228, 0.07932645082473755,
# 0.03685774281620979, 0.03653689846396446, 0.09363304078578949, 0.18121066689491272,
# 0.3218793272972107, 0.17433294653892517, 0.0004975460469722748, 0.4453623294830322,
# 0.03790195286273956, 0.02381177432835102, 0.0013224466238170862, 0.18974913656711578,
# 0.1949874609708786, 0.005261131562292576, 0.12926769256591797, 0.007863630540668964,
# 0.07362367957830429, 0.10933012515306473, 0.19522283971309662, 0.14927752315998077,
# 0.45020321011543274, 0.05047404766082764, 0.03324407339096069, 0.027298346161842346,
# 0.08437549322843552, 0.19904881715774536, 0.33098074793815613, 0.09043396264314651,
# 0.318674236536026, 0.0010779169388115406, 0.24404488503932953, 0.3186350464820862,
# 0.0025669021997600794, 0.1811981052160263, 0.10796911269426346, 0.2746645212173462,
# 0.23726630210876465, 0.05970800295472145, 0.2774723172187805, 0.04485521838068962,
# 0.361568808555603, 0.566271960735321, 0.1212383583188057, 0.24455808103084564,
# 0.15271854400634766, 0.11322562396526337, 0.0335141196846962, 0.2837356626987457,
# 0.017213769257068634, 0.2849556803703308, 0.5101184844970703, 0.006656118668615818,
# 0.30922847986221313, 0.1917605698108673, 0.3227929472923279, 0.1147906556725502,
# 0.1325288861989975, 0.28648844361305237, 0.28225016593933105, 0.11361508071422577,
# 0.026927348226308823, 0.030898606404662132, 0.03599568083882332, 0.15442655980587006,
# 0.2199012190103531, 0.6688342690467834, 0.1601114124059677, 0.1656370311975479,
# 0.55259770154953, 0.07239291071891785, 0.3681557774543762, 0.11491131037473679,
# 0.24899432063102722, 0.11821242421865463, 0.15365850925445557, 0.10120883584022522,
# 0.08047141134738922, 0.04527793824672699, 0.31227877736091614, 0.017080089077353477,
# 0.014100546948611736, 0.43841680884361267, 0.0137359369546175, 0.6115028858184814,
# 0.037726692855358124, 0.06776237487792969, 0.35551199316978455, 0.27479875087738037,
# 0.0707501545548439, 0.5179652571678162, 0.17444543540477753, 0.4420629143714905,
# 0.21244992315769196, 0.03635856881737709, 0.010214539244771004, 0.011419345624744892,
# 0.30859944224357605, 0.2109130173921585, 0.001227161381393671, 0.0694965124130249,
# 0.17135953903198242, 0.13579487800598145, 0.031732890754938126, 0.11949918419122696,
# 0.7135012745857239, 0.08305172622203827, 0.23292042315006256, 0.24462968111038208,
# 0.0019088764674961567, 0.0016780936857685447, 0.07673459500074387, 0.031392328441143036,
# 0.001692366087809205, 0.006584825459867716, 0.026167025789618492, 0.012796289287507534,
# 0.2347840815782547, 0.015075706876814365, 0.030416738241910934, 0.21498487889766693,
# 0.002433608751744032, 0.039517778903245926, 0.023482363671064377, 0.0048379176296293736,
# 0.007254105526953936, 0.010515145026147366, 0.05815873667597771, 0.07883276045322418,
# 0.046819135546684265, 0.11181945353746414, 0.5231369137763977, 0.6762533783912659,
# 0.5818216800689697, 0.010390495881438255, 0.0018854711670428514, 0.006176910828799009,
# 0.10908140242099762, 0.05964774638414383, 0.0043340506963431835, 0.021595647558569908,
# 0.1715017408132553, 0.08218520134687424, 0.029034217819571495, 0.03759973123669624,
# 0.052282314747571945, 0.0028271391056478024, 0.0017544492147862911, 0.0007720172288827598,
# 0.0002776577020995319, 0.0495792031288147, 0.0015277613420039415, 0.03857247158885002,
# 0.034544143825769424, 0.001066438970156014, 0.0035532282199710608, 0.0799378827214241,
# 0.1229076161980629, 0.01684192195534706, 0.062368784099817276, 0.016936620697379112,
# 0.014279799535870552, 0.027907371520996094, 0.16180448234081268, 0.11875884234905243,
# 0.06431959569454193, 0.005786654073745012, 0.019491439685225487, 0.0338176004588604,
# 2.2590898879570886e-05, 2.1832103811902925e-05, 3.270710294600576e-05, 0.0005826150882057846,
# 0.002877677557989955, 0.005280376877635717, 0.0002973616647068411, 0.8202976584434509,
# 0.024798845872282982, 1.1437114153522998e-05, 0.007136042229831219, 0.0014124587178230286,
# 0.0002376262127654627, 5.6359633163083345e-05, 0.0016063374932855368, 0.019292164593935013,
# 0.03795655444264412, 0.0005354272434487939, 0.05028882622718811, 0.015524177812039852,
# 0.0006413861410692334, 0.0004307524941395968, 4.496751353144646e-05, 0.016009889543056488,
# 0.0011763031361624599, 4.372852708911523e-05, 0.017063191160559654, 8.582950249547139e-05,
# 0.0005633897962979972, 0.7130284309387207, 0.0007774485857225955, 0.025489015504717827,
# 0.0009836634853854775, 6.91479435772635e-05, 0.00032565215951763093, 0.0003849559579975903,
# 2.9286558856256306e-05, 0.0010570368031039834, 7.5815851232619025e-06, 1.8844430087483488e-05,
# 0.0006864393362775445, 0.0002953780931420624, 0.0004228455072734505, 0.0005543739534914494,
# 0.0101822130382061, 0.00023549457546323538, 0.0013668056344613433, 0.004204078111797571,
# 0.0021591142285615206, 0.00911776814609766, 0.00017250377277377993, 0.005282348487526178,
# 4.2991941882064566e-05, 0.004472709726542234, 0.0005618775612674654, 0.00024218297039624304,
# 0.00013537018094211817, 0.000215559994103387, 0.00025669983006082475, 0.008431092835962772,
# 0.0007646934245713055, 0.00011747117241611704, 0.005473301745951176, 0.0003126113733742386,
# 0.011448969133198261, 0.0015557686565443873, 0.008569751866161823, 0.003180823056027293,
# 0.000507085002027452, 0.537781298160553, 0.048898302018642426, 0.7545585036277771,
# 0.01993025466799736, 0.012768946588039398, 0.02008453756570816, 0.024491190910339355,
# 0.3734889328479767, 0.525874674320221, 0.13256166875362396, 0.020536212250590324,
# 0.0003381606948096305, 0.012623176909983158, 0.035666972398757935, 0.10173744708299637,
# 0.45949655771255493, 0.446538507938385, 0.6836756467819214, 0.005370558239519596,
# 0.019793491810560226, 0.0017227550270035863, 0.04722445085644722, 0.0002769494312815368,
# 0.01190416980534792, 0.03385474905371666, 0.00997639074921608, 0.06512054800987244,
# 0.00017778291658032686, 0.010086605325341225, 0.15036971867084503, 0.03493537753820419,
# 0.014694745652377605, 0.009590410627424717, 0.06938852369785309, 0.06201145052909851,
# 0.0012069909134879708, 0.09807892143726349, 0.012826574966311455, 0.0022779686842113733,
# 0.0009397809044457972, 0.08430462330579758, 0.010916265659034252, 0.018465103581547737,
# 0.09284492582082748, 0.0007923140656203032, 0.13442473113536835, 0.00029501141398213804,
# 0.02977648377418518, 0.003169003874063492, 0.06374705582857132, 0.07424328476190567,
# 0.05603868141770363, 0.0005273306160233915, 0.001537116477265954, 0.03864126652479172,
# 0.013761372305452824, 0.004713751841336489, 0.00015459433780051768, 0.06548599898815155,
# 0.002222091192379594, 0.02363699860870838, 0.002148520899936557, 0.048032067716121674,
# 0.047242671251297, 0.018189378082752228, 0.0013216364895924926, 0.004806154407560825,
# 0.0391436442732811, 0.04827449470758438, 0.032201990485191345, 0.009258149191737175,
# 0.058295246213674545, 0.5682844519615173, 0.0066831293515861034, 0.009114665910601616,
# 0.11043088138103485, 0.005345445591956377, 0.0011419858783483505, 0.06710591167211533,
# 0.001545222825370729, 0.0035176719538867474, 0.0028600646182894707, 0.26787179708480835,
# 0.012030825950205326, 0.7333755493164062, 0.0924898236989975, 0.5482035279273987,
# 0.7483367919921875, 0.15186171233654022, 0.031270235776901245, 0.054689954966306686,
# 0.22983664274215698, 0.1712011694908142, 0.46153223514556885, 0.5191338658332825,
# 0.6390472054481506, 0.6187522411346436, 0.09661407023668289, 0.0486166886985302,
# 0.05376244708895683, 0.13228805363178253, 0.18485110998153687, 0.022491177543997765,
# 0.1794504076242447, 0.5787538886070251, 0.2970709502696991, 0.35477644205093384,
# 0.07963622361421585, 0.23696789145469666, 0.07290654629468918, 0.6408109664916992,
# 0.07272060215473175, 0.24273446202278137, 0.49851101636886597, 0.2297876626253128,
# 0.012370138429105282, 0.06483983993530273, 0.032283250242471695, 0.10416410118341446,
# 0.04155949875712395, 0.5365368127822876, 0.2685471475124359, 0.7304875254631042,
# 0.16873101890087128, 0.10765017569065094, 0.5361544489860535, 0.5075755715370178,
# 0.34477660059928894, 0.03796958178281784, 0.58806973695755, 0.32740262150764465,
# 0.081614650785923, 0.1383964717388153, 0.060737576335668564, 0.5047486424446106,
# 0.5085154175758362, 0.9475945234298706, 0.36091873049736023, 0.5656833648681641,
# 0.03296098858118057, 0.338140070438385, 0.4402046799659729, 0.3290066421031952,
# 0.31168612837791443, 0.28048989176750183, 0.1813652515411377, 0.2380431592464447,
# 0.4760478138923645, 0.7804376482963562, 0.0025435779243707657, 0.28977110981941223,
# 0.5569596886634827, 0.7877252697944641, 0.6883821487426758, 0.18101866543293,
# 0.046706292778253555, 0.06395147740840912, 0.3652743399143219, 0.27191266417503357,
# 0.5388174653053284, 0.934138834476471, 0.1701051890850067, 0.11266148090362549,
# 0.12104987353086472, 0.2805715799331665, 0.09523230791091919, 0.33334827423095703,
# 0.4600788354873657, 0.27511101961135864, 0.26013627648353577, 0.23540319502353668,
# 0.6293138265609741, 0.14958517253398895, 0.03855852037668228, 0.49199220538139343,
# 0.5062726140022278, 0.11719761043787003, 0.2523224949836731, 0.3535851240158081,
# 0.566465437412262, 0.09446115046739578, 0.7193458080291748, 0.15530748665332794,
# 0.9024344086647034, 0.619306743144989, 0.15570005774497986, 0.34082600474357605,
# 0.18265828490257263, 0.09077121317386627, 0.14945639669895172, 0.022453971207141876,
# 0.31075263023376465, 0.682900607585907, 0.8920421004295349, 0.010393393225967884,
# 0.3244038224220276, 0.6724012494087219, 0.46097925305366516, 0.30986684560775757,
# 0.8810898661613464, 0.2677571475505829, 0.5119701027870178, 0.28144901990890503,
# 0.8146527409553528, 0.021037837490439415, 0.01196334045380354, 0.2313356250524521,
# 0.38446417450904846, 0.2571585476398468, 0.020353596657514572, 0.16669420897960663,
# 0.13793134689331055, 0.10052387416362762, 0.3346887528896332, 0.9261825680732727,
# 0.0702410638332367, 0.004401514772325754, 0.12539096176624298, 0.4944795072078705,
# 0.029531555250287056, 0.3734147846698761, 0.3408965468406677, 0.6572421789169312,
# 0.026275571435689926, 0.0018235023599117994, 0.6667882800102234, 0.9024632573127747,
# 0.5058287382125854, 0.010618552565574646, 0.5984737873077393, 0.7620850205421448,
# 0.3205288052558899, 0.4149850010871887, 0.015739887952804565, 0.9125362634658813,
# 0.502493679523468, 0.010216968134045601, 0.501264750957489, 0.07352140545845032,
# 0.2797524034976959, 0.49197113513946533, 0.786327064037323, 0.36495694518089294,
# 0.4937792718410492, 0.14923380315303802, 0.6707140207290649, 0.501875638961792,
# 0.06246095895767212, 0.8036077618598938, 0.5027488470077515, 0.7773066163063049,
# 0.0021783197298645973, 0.6643004417419434, 0.3796654939651489, 0.16182230412960052,
# 0.7063586711883545, 0.45415860414505005, 0.24265441298484802, 0.16849082708358765,
# 0.19779305160045624, 0.6337084770202637, 0.007776966318488121, 0.008619499392807484,
# 0.17919377982616425, 0.48299065232276917, 0.6280980110168457, 0.4688548147678375,
# 0.6413920521736145, 0.6029754877090454, 0.012571978382766247, 0.046973809599876404,
# 0.021499888971447945, 0.3052295744419098, 0.2845366895198822, 0.00994188617914915,
# 0.10093153268098831, 0.1707608848810196, 0.23015841841697693, 0.06455522775650024,
# 0.1998596489429474, 0.3655388355255127, 0.058542635291814804, 0.12859119474887848,
# 0.010832173749804497, 0.30879446864128113, 0.002360060578212142, 0.022943060845136642,
# 0.4519581198692322, 0.0038818521425127983, 0.018221097066998482, 0.03482630103826523,
# 0.46418139338493347, 0.293793261051178, 0.284146785736084, 0.567644476890564,
# 0.04316479712724686, 0.2176886945962906, 0.1264413446187973]
#
# ps_list_25_1_ = [0.3379203677177429, 0.08033178001642227, 0.11174498498439789, 0.10511206835508347,
# 0.09409968554973602, 0.3389323651790619, 0.4676828980445862, 0.2845773696899414,
# 0.34697651863098145, 0.1188291609287262, 0.18519265949726105, 0.11659646034240723,
# 0.10756745934486389, 0.5182259678840637, 0.19513317942619324, 0.12016185373067856,
# 0.10851199924945831, 0.23103117942810059, 0.5130442380905151, 0.10422299057245255,
# 0.12924565374851227, 0.19306707382202148, 0.15570373833179474, 0.08920545130968094,
# 0.11787081509828568, 0.20190349221229553, 0.1098686084151268, 0.17671145498752594,
# 0.3291429877281189, 0.19228456914424896, 0.32180550694465637, 0.09072616696357727,
# 0.10286889225244522, 0.08054888248443604, 0.1468377560377121, 0.1350991576910019,
# 0.2739027738571167, 0.09400739520788193, 0.226099893450737, 0.330873042345047,
# 0.10626007616519928, 0.21616625785827637, 0.3201182782649994, 0.4568440020084381,
# 0.28163012862205505, 0.2174510508775711, 0.27267563343048096, 0.095782570540905,
# 0.21306340396404266, 0.410489559173584, 0.2737514078617096, 0.35910260677337646,
# 0.22166503965854645, 0.4428495168685913, 0.22678689658641815, 0.11109378188848495,
# 0.11375221610069275, 0.2611459493637085, 0.2519790530204773, 0.11791415512561798,
# 0.11206243187189102, 0.177147775888443, 0.14860646426677704, 0.4627283215522766,
# 0.10039688646793365, 0.10842303186655045, 0.08730098605155945, 0.29866018891334534,
# 0.08010110259056091, 0.09551994502544403, 0.14067666232585907, 0.2497076541185379,
# 0.3944653570652008, 0.23115666210651398, 0.13938982784748077, 0.364315390586853,
# 0.3562399744987488, 0.11433114111423492, 0.1725475937128067, 0.4301903545856476,
# 0.3056260943412781, 0.15229269862174988, 0.22505047917366028, 0.1735696643590927,
# 0.0971640944480896, 0.16004854440689087, 0.1976790428161621, 0.13193854689598083,
# 0.3161192536354065, 0.14686021208763123, 0.5846854448318481, 0.31580817699432373,
# 0.18629013001918793, 0.08484101295471191, 0.09537515789270401, 0.3434858024120331,
# 0.10682468861341476, 0.19226087629795074, 0.26956018805503845, 0.15277314186096191,
# 0.12857294082641602, 0.08057508617639542, 0.07679995149374008, 0.1047675833106041,
# 0.08422505110502243, 0.06999059021472931, 0.10171198099851608, 0.07677847146987915,
# 0.15215422213077545, 0.16075272858142853, 0.08869733661413193, 0.09686969220638275,
# 0.09150903671979904, 0.0781916081905365, 0.09806479513645172, 0.09962109476327896,
# 0.07797189801931381, 0.0918135792016983, 0.09466836601495743, 0.09855043143033981,
# 0.12996505200862885, 0.08051473647356033, 0.06815771013498306, 0.09546299278736115,
# 0.08362782001495361, 0.11420125514268875, 0.08837661147117615, 0.12373937666416168,
# 0.09438571333885193, 0.07398839294910431, 0.1034855917096138, 0.0940454751253128,
# 0.18418551981449127, 0.08162592351436615, 0.09283247590065002, 0.07650069892406464,
# 0.12705810368061066, 0.13589298725128174, 0.15828269720077515, 0.15503646433353424,
# 0.09274999797344208, 0.08680147677659988, 0.15224039554595947, 0.10009095072746277,
# 0.18986031413078308, 0.08080025762319565, 0.12865050137043, 0.07820256054401398,
# 0.12452127784490585, 0.11700420081615448, 0.09310127049684525, 0.07275225967168808,
# 0.11126531660556793, 0.09180044382810593, 0.13112688064575195, 0.08694812655448914,
# 0.06809123605489731, 0.08481784909963608, 0.07682234048843384, 0.2608363628387451,
# 0.09524120390415192, 0.06815149635076523, 0.07735829800367355, 0.11197857558727264,
# 0.1268691122531891, 0.07402941584587097, 0.09421929717063904, 0.15647074580192566,
# 0.16466929018497467, 0.07895758748054504, 0.14755190908908844, 0.09386354684829712,
# 0.10496420413255692, 0.07964422553777695, 0.12397526204586029, 0.08362539857625961,
# 0.0779443010687828, 0.07657966017723083, 0.08073914796113968, 0.07922621816396713,
# 0.09476624429225922, 0.07532476633787155, 0.08583004772663116, 0.07195556908845901,
# 0.1327061951160431, 0.14415329694747925, 0.07226187735795975, 0.0883597880601883,
# 0.14913462102413177, 0.24547496438026428, 0.09024978429079056, 0.11838574707508087,
# 0.15504583716392517, 0.09119581431150436, 0.08399540185928345, 0.10815741866827011,
# 0.11589019000530243, 0.10153666883707047, 0.09390411525964737, 0.07619214802980423,
# 0.07812057435512543, 0.08548244833946228, 0.11566725373268127, 0.10686670988798141,
# 0.10194274038076401, 0.10816298425197601, 0.08748067170381546, 0.1876164823770523,
# 0.1500776708126068, 0.08230278640985489, 0.08594276756048203, 0.20574739575386047,
# 0.11169629544019699, 0.22771914303302765, 0.1846017688512802, 0.33805206418037415,
# 0.14562547206878662, 0.09274417906999588, 0.07372625172138214, 0.19955700635910034,
# 0.1818850189447403, 0.1189478263258934, 0.16006088256835938, 0.20672540366649628,
# 0.20583385229110718, 0.11225705593824387, 0.34550175070762634, 0.10130612552165985,
# 0.24093493819236755, 0.19475600123405457, 0.20586057007312775, 0.10869556665420532,
# 0.18263497948646545, 0.10113207250833511, 0.14840352535247803, 0.09122555702924728,
# 0.13351793587207794, 0.10630591213703156, 0.19350220263004303, 0.12751320004463196,
# 0.16567805409431458, 0.10013032704591751, 0.08638010919094086, 0.13779936730861664,
# 0.11805395781993866, 0.30811506509780884, 0.146883025765419, 0.25101804733276367,
# 0.09848921746015549, 0.13806134462356567, 0.25471898913383484, 0.17874951660633087,
# 0.255443811416626, 0.20805366337299347, 0.15204297006130219, 0.14999887347221375,
# 0.15268422663211823, 0.17508354783058167, 0.08976596593856812, 0.1398988962173462,
# 0.09645717591047287, 0.1159277856349945, 0.14440910518169403, 0.07565773278474808,
# 0.13520027697086334, 0.11220096051692963, 0.1028069332242012, 0.2288600653409958,
# 0.12974737584590912, 0.15512900054454803, 0.08404076099395752, 0.23312216997146606,
# 0.15702815353870392, 0.1029176190495491, 0.09227321296930313, 0.1628303974866867,
# 0.17035378515720367, 0.10521204769611359, 0.20241181552410126, 0.23169000446796417,
# 0.44350606203079224, 0.16696611046791077, 0.07475010305643082, 0.16135017573833466,
# 0.3554341197013855, 0.07840917259454727, 0.20144526660442352, 0.19635410606861115,
# 0.13333843648433685, 0.0981815829873085, 0.17648528516292572, 0.17903916537761688,
# 0.11099276691675186, 0.19530867040157318, 0.2208928018808365, 0.27517062425613403,
# 0.21012140810489655, 0.22613537311553955, 0.33992090821266174, 0.09171807020902634,
# 0.11086485534906387, 0.18492792546749115, 0.34429702162742615, 0.2387910634279251,
# 0.0884818583726883, 0.09061111509799957, 0.1027636006474495, 0.10671345144510269,
# 0.09762077033519745, 0.0940227136015892, 0.07084992527961731, 0.09765908122062683,
# 0.232946515083313, 0.07568907737731934, 0.09147863835096359, 0.09983860701322556,
# 0.16434283554553986, 0.11484641581773758, 0.24659647047519684, 0.06586812436580658,
# 0.068105049431324, 0.07518505305051804, 0.06646528840065002, 0.06706419587135315,
# 0.0666697770357132, 0.07399208098649979, 0.06842298805713654, 0.07366248965263367,
# 0.11733152717351913, 0.07348740100860596, 0.07400475442409515, 0.08745106309652328,
# 0.06658423691987991, 0.08807829767465591, 0.07048414647579193, 0.06731195002794266,
# 0.0667862594127655, 0.06755343079566956, 0.0771898552775383, 0.07650241255760193,
# 0.06892631947994232, 0.07154147326946259, 0.07231113314628601, 0.07003416121006012,
# 0.07674272358417511, 0.06903069466352463, 0.06661684066057205, 0.06689491122961044,
# 0.06601860374212265, 0.07058467715978622, 0.06769707053899765, 0.06661198288202286,
# 0.06611908227205276, 0.07348769158124924, 0.07225937396287918, 0.06652242690324783,
# 0.07274767011404037, 0.07056131958961487, 0.06580803543329239, 0.07052428275346756,
# 0.06569661945104599, 0.08991485089063644, 0.06847529858350754, 0.07429274916648865,
# 0.07409750670194626, 0.06628153473138809, 0.06616908311843872, 0.06844473630189896,
# 0.07324247807264328, 0.06954976171255112, 0.06814105808734894, 0.0712147131562233,
# 0.074394591152668, 0.06742560863494873, 0.07226438075304031, 0.07545743882656097,
# 0.07423445582389832, 0.07021569460630417, 0.06984533369541168, 0.07083151489496231,
# 0.06711813062429428, 0.06652997434139252, 0.06614476442337036, 0.06855210661888123,
# 0.07990744709968567, 0.0663495808839798, 0.07844773679971695, 0.08343928307294846,
# 0.0821342021226883, 0.06901166588068008, 0.10589518398046494, 0.07188180088996887,
# 0.07157508283853531, 0.0681111067533493, 0.07356706261634827, 0.0815301388502121,
# 0.11776763945817947, 0.06981409341096878, 0.06879104673862457, 0.08132286369800568,
# 0.0765751451253891, 0.07690572738647461, 0.07142578065395355, 0.07825205475091934,
# 0.07103368639945984, 0.07875823974609375, 0.07392983883619308, 0.07411762326955795,
# 0.0766012966632843, 0.3564867079257965, 0.0766976848244667, 0.14572568237781525,
# 0.07337143272161484, 0.07280640304088593, 0.07129362970590591, 0.06885618716478348,
# 0.06704150140285492, 0.077064648270607, 0.07058810442686081, 0.06832868605852127,
# 0.07754631340503693, 0.07145898044109344, 0.06995534151792526, 0.07496441155672073,
# 0.07013433426618576, 0.07240597903728485, 0.10411080718040466, 0.0959438607096672,
# 0.07177792489528656, 0.06929577142000198, 0.07246581465005875, 0.07868687808513641,
# 0.07281142473220825, 0.08662836998701096, 0.07436662167310715, 0.07987122237682343,
# 0.06664933264255524, 0.07654903829097748, 0.08229488879442215, 0.15083537995815277,
# 0.07470010966062546, 0.07251765578985214, 0.09497548639774323, 0.06846904754638672,
# 0.15289916098117828, 0.07200797647237778, 0.08071577548980713, 0.07831956446170807,
# 0.07458789646625519, 0.08725757896900177, 0.09455887973308563, 0.2644568681716919,
# 0.07413512468338013, 0.09620001912117004, 0.19033920764923096, 0.10465718060731888,
# 0.07459383457899094, 0.12188681215047836, 0.08848969638347626, 0.0729663148522377,
# 0.09172821044921875, 0.11408009380102158, 0.10667550563812256, 0.22077663242816925,
# 0.17935574054718018, 0.1142631322145462, 0.1501648873090744, 0.09774630516767502,
# 0.1449747234582901, 0.10361362993717194, 0.1659749150276184, 0.08628331869840622,
# 0.11507805436849594, 0.1370941549539566, 0.08692242950201035, 0.14166636765003204,
# 0.08353467285633087, 0.10970727354288101, 0.12364047765731812, 0.08308661729097366,
# 0.0838446244597435, 0.09004805982112885, 0.0922333151102066, 0.0907277762889862,
# 0.07333531230688095, 0.10785099118947983, 0.08015729486942291, 0.07752010226249695,
# 0.07313675433397293, 0.07234988361597061, 0.08453842997550964, 0.10183485597372055,
# 0.16524863243103027, 0.07150931656360626, 0.1604643017053604, 0.07207813858985901,
# 0.08752842247486115, 0.07193150371313095, 0.12204501032829285, 0.08384659886360168,
# 0.08388139307498932, 0.07872971892356873, 0.0876607745885849, 0.21355974674224854,
# 0.07816658914089203, 0.07312383502721786, 0.07654910534620285, 0.09895306825637817,
# 0.07290901988744736, 0.09818287938833237, 0.08132809400558472, 0.1075129508972168,
# 0.1466694176197052, 0.11932603269815445, 0.0771527960896492, 0.07712545245885849,
# 0.11907567083835602, 0.08663253486156464, 0.09479498863220215, 0.07886169850826263,
# 0.08942850679159164, 0.09484688937664032, 0.08148527145385742, 0.06964564323425293,
# 0.08288293331861496, 0.07364758849143982, 0.08717617392539978, 0.09090696275234222,
# 0.07271939516067505, 0.09713780879974365, 0.0763545110821724, 0.18362347781658173,
# 0.5545430183410645, 0.26221218705177307, 0.3838467299938202, 0.46218499541282654,
# 0.26045113801956177, 0.45540812611579895, 0.4221225082874298, 0.3614152669906616,
# 0.1366126388311386, 0.16206485033035278, 0.5038032531738281, 0.2902289927005768,
# 0.2119099646806717, 0.31361427903175354, 0.37998029589653015, 0.38603320717811584,
# 0.14085134863853455, 0.21522927284240723, 0.19219551980495453, 0.1872643083333969,
# 0.4470926821231842, 0.19259795546531677, 0.282594233751297, 0.22281725704669952,
# 0.3255426585674286, 0.4634615182876587, 0.26365500688552856, 0.37351512908935547,
# 0.508141815662384, 0.49897488951683044, 0.30935314297676086, 0.46181929111480713,
# 0.13552100956439972, 0.5179079174995422, 0.19075614213943481, 0.4927549958229065,
# 0.5789120197296143, 0.2973661720752716, 0.21202242374420166, 0.32899150252342224,
# 0.3226039707660675, 0.44071412086486816, 0.4104224145412445, 0.4136391282081604,
# 0.4722258746623993, 0.442671000957489, 0.08691488206386566, 0.5229451656341553,
# 0.2952505946159363, 0.2907719612121582, 0.3506596088409424, 0.34873366355895996,
# 0.3267822861671448, 0.5174949765205383, 0.5636297464370728, 0.3392939567565918,
# 0.07363885641098022, 0.14511844515800476, 0.320680171251297, 0.3352411985397339,
# 0.5368587374687195, 0.4812179505825043, 0.14985674619674683, 0.3182353973388672,
# 0.4464661180973053, 0.35860976576805115, 0.19414635002613068, 0.5518289804458618,
# 0.469176322221756, 0.49551206827163696, 0.28820276260375977, 0.5039542317390442,
# 0.23549997806549072, 0.18568815290927887, 0.41822177171707153, 0.1836974322795868,
# 0.10600315034389496, 0.20200100541114807, 0.2589026391506195, 0.3581222593784332,
# 0.5563390851020813, 0.4048709571361542, 0.5162954330444336, 0.5014818906784058,
# 0.5004151463508606, 0.35632580518722534, 0.27494022250175476, 0.3898845911026001,
# 0.2832420766353607, 0.4526689946651459, 0.2306213527917862, 0.47144100069999695,
# 0.3981761932373047, 0.4635953903198242, 0.28985679149627686, 0.22576779127120972,
# 0.40988221764564514, 0.12165339291095734, 0.46984848380088806, 0.36851778626441956,
# 0.19499817490577698, 0.4588679373264313, 0.4828434884548187, 0.4232226610183716,
# 0.5319685339927673, 0.3356710374355316, 0.21753956377506256, 0.10879117995500565,
# 0.4280291199684143, 0.36341890692710876, 0.4854242205619812, 0.17575155198574066,
# 0.3442493677139282, 0.208584263920784, 0.32545241713523865, 0.4206635057926178,
# 0.2800803780555725, 0.16305994987487793, 0.48876315355300903, 0.35154592990875244,
# 0.17298296093940735, 0.28987374901771545, 0.11071117222309113, 0.09981473535299301,
# 0.4593276381492615, 0.260109007358551, 0.24488680064678192, 0.1600152850151062,
# 0.47259771823883057, 0.2779066860675812, 0.39084455370903015, 0.5487063527107239,
# 0.36314353346824646, 0.10107510536909103, 0.08711779117584229, 0.15643468499183655,
# 0.34702417254447937, 0.470808207988739, 0.2508446276187897, 0.5278204679489136,
# 0.28929007053375244, 0.0886211097240448, 0.5804604291915894, 0.41490957140922546,
# 0.35269007086753845, 0.1058516651391983, 0.46096688508987427, 0.5447578430175781,
# 0.44033148884773254, 0.2437741905450821, 0.10845130681991577, 0.5577320456504822,
# 0.44957783818244934, 0.20271475613117218, 0.4464792311191559, 0.43857160210609436,
# 0.296863853931427, 0.3602867126464844, 0.38602396845817566, 0.27463096380233765,
# 0.2849353849887848, 0.5244476199150085, 0.354557067155838, 0.5444074869155884,
# 0.21093562245368958, 0.5376427173614502, 0.35977283120155334, 0.5170581936836243,
# 0.19430838525295258, 0.23492859303951263, 0.15488427877426147, 0.368257075548172,
# 0.4056648313999176, 0.23561181128025055, 0.2177893966436386, 0.30739855766296387,
# 0.36112192273139954, 0.4604761302471161, 0.11801379173994064, 0.14239811897277832,
# 0.2744101881980896, 0.2876740097999573, 0.1349114626646042, 0.37739816308021545,
# 0.36094599962234497, 0.4731997549533844, 0.1941041499376297, 0.20215287804603577,
# 0.3192158341407776, 0.3475882411003113, 0.08104734122753143, 0.1281113624572754,
# 0.33204546570777893, 0.265578031539917, 0.32456478476524353, 0.2734600901603699,
# 0.3355303704738617, 0.3098866045475006, 0.16760748624801636, 0.21950837969779968,
# 0.23259064555168152, 0.33919844031333923, 0.11106793582439423, 0.11545408517122269,
# 0.19700825214385986, 0.13278980553150177, 0.16070523858070374, 0.10105132311582565,
# 0.20112450420856476, 0.20004449784755707, 0.230228453874588, 0.3819774389266968,
# 0.15220323204994202, 0.2341819554567337, 0.28712669014930725]
#
# print(len(ps_list_25_10_))
# print(len(ps_list_25_1_))
#
# ps_list_25_10_ = np.array(ps_list_25_10_)
# ps_list_25_1_ = np.array(ps_list_25_1_)
# self.train_ps_25_1_SAE(ps_train_set, device)
# self.cal_correlation_coeff(ps_score_list_SAE_e2e, ps_score_list_SAE_stacked_all_layer_active,
# ps_score_list_SAE_stacked_cur_layer_active, ps_list_25_10_,
# ps_list_25_1_)
ps_list_LR = self.__train_propensity_net_LR(np_covariates_X, np_covariates_Y)
ps_list_LR_lasso = self.__train_propensity_net_LR_Lasso(np_covariates_X, np_covariates_Y)
print(len(ps_list_LR))
print(len(ps_list_LR_lasso))
self.cal_correlation_coeff_lr_lasso(ps_list_LR, ps_list_LR_lasso)
@staticmethod
def cal_correlation_coeff_lr_lasso(ps_lr, ps_lasso):
corr_lr_vs_lasso = scipy.stats.pearsonr(ps_lr, ps_lasso)[0]
print("corr_lr_vs_lasso: {0}".format(corr_lr_vs_lasso))
@staticmethod
def cal_correlation_coeff(ps_score_list_SAE_e2e, ps_score_list_SAE_stacked_all_layer_active,
ps_score_list_SAE_stacked_cur_layer_active, ps_list_25_10_,
ps_list_25_1_):
corr_e2e_25_1 = scipy.stats.pearsonr(ps_score_list_SAE_e2e, ps_list_25_1_)[0]
corr_e2e_25_10 = scipy.stats.pearsonr(ps_score_list_SAE_e2e, ps_list_25_10_)[0]
corr_e2e_stacked_all = scipy.stats.pearsonr(ps_score_list_SAE_e2e,
ps_score_list_SAE_stacked_all_layer_active)[0]
corr_e2e_stacked_cur = scipy.stats.pearsonr(ps_score_list_SAE_e2e,
ps_score_list_SAE_stacked_cur_layer_active)[0]
corr_stacked_all_25_10 = scipy.stats.pearsonr(ps_score_list_SAE_stacked_all_layer_active,
ps_list_25_10_)[0]
corr_stacked_all_25_1 = scipy.stats.pearsonr(ps_score_list_SAE_stacked_all_layer_active,
ps_list_25_1_)[0]
corr_stacked_all_stacked_cur = scipy.stats.pearsonr(ps_score_list_SAE_stacked_all_layer_active,
ps_score_list_SAE_stacked_cur_layer_active)[0]
corr_stacked_cur_25_1 = scipy.stats.pearsonr(ps_score_list_SAE_stacked_cur_layer_active,
ps_list_25_1_)[0]
corr_stacked_cur_25_10 = scipy.stats.pearsonr(ps_score_list_SAE_stacked_cur_layer_active,
ps_list_25_10_)[0]
corr_25_10_25_1 = scipy.stats.pearsonr(ps_list_25_1_,
ps_list_25_10_)[0]
print("E2E vs all")
print("corr_e2e_25_1: {0}".format(corr_e2e_25_1))
print("corr_e2e_25_10: {0}".format(corr_e2e_25_10))
print("corr_e2e_stacked_all: {0}".format(corr_e2e_stacked_all))
print("corr_e2e_stacked_cur: {0}".format(corr_e2e_stacked_cur))
print("stacked all vs all")
print("corr_stacked_all_25_10: {0}".format(corr_stacked_all_25_10))
print("corr_stacked_all_25_1: {0}".format(corr_stacked_all_25_1))
print("corr_stacked_all_stacked_cur: {0}".format(corr_stacked_all_stacked_cur))
print("stacked cur vs all")
print("corr_stacked_cur_25_1: {0}".format(corr_stacked_cur_25_1))
print("corr_stacked_cur_25_10: {0}".format(corr_stacked_cur_25_10))
print("25-10 vs all")
print("corr_25_10_25_1: {0}".format(corr_25_10_25_1))
print("Min: {0}".format(min(corr_e2e_25_1, corr_e2e_25_10, corr_e2e_stacked_all, corr_e2e_stacked_cur,
corr_stacked_all_25_10, corr_stacked_all_25_1,
corr_stacked_all_stacked_cur,
corr_stacked_cur_25_1, corr_stacked_cur_25_10,
corr_25_10_25_1)))
print("Max: {0}".format(max(corr_e2e_25_1, corr_e2e_25_10, corr_e2e_stacked_all, corr_e2e_stacked_cur,
corr_stacked_all_25_10, corr_stacked_all_25_1,
corr_stacked_all_stacked_cur,
corr_stacked_cur_25_1, corr_stacked_cur_25_10,
corr_25_10_25_1)))
@staticmethod
def __train_propensity_net_LR(np_covariates_X_train, np_covariates_Y_train):
    """Run ``Propensity_socre_LR.train`` and return the first item of its result.

    ``train`` is expected to yield exactly two values; the second one is
    discarded. Per the original comment, the first value is the propensity
    score list produced by logistic regression.
    """
    lr_outcome = Propensity_socre_LR.train(np_covariates_X_train, np_covariates_Y_train)
    propensity_scores, _discarded = lr_outcome
    return propensity_scores
def __train_propensity_net_LR_Lasso(self, np_covariates_X_train, np_covariates_Y_train):
    """Run ``Propensity_socre_LR.train`` with ``regularized=True``.

    ``train`` is expected to yield exactly two values; only the first is
    returned. Per the original comment, this is the Lasso variant of the
    logistic-regression propensity scoring.
    """
    lasso_outcome = Propensity_socre_LR.train(
        np_covariates_X_train,
        np_covariates_Y_train,
        regularized=True,
    )
    propensity_scores, _discarded = lasso_outcome
    return propensity_scores
@staticmethod
def train_ps_25_1_SAE(ps_train_set, device):
    """Train a ``shallow_train`` network on ``ps_train_set`` and print its eval output.

    The network is trained with a fixed hyper-parameter set, then evaluated
    on the same ``ps_train_set``; whatever ``eval`` returns is printed.
    Nothing is returned.

    Args:
        ps_train_set: data set used both as ``train_set`` and as the eval input.
        device: forwarded unchanged to ``train`` and ``eval``.
    """
    # Marked by the original author as the best parameter list.
    hyper_params = {
        "epochs": 400,
        "lr": 0.001,
        "batch_size": 32,
        "shuffle": True,
        "train_set": ps_train_set,
        "sparsity_probability": 0.08,
        "weight_decay": 0.0003,
        "BETA": 0.4,
    }

    trainer = shallow_train()
    print("############### Propensity Score SAE net Training ###############")
    trained_classifier = trainer.train(hyper_params, device, phase="train")

    # Score the training set itself with the classifier that was just trained.
    scores = trainer.eval(
        ps_train_set,
        device,
        phase="eval",
        sparse_classifier=trained_classifier,
    )
    print(scores)
def __train_propensity_net_SAE(self, ps_train_set, device):
    """Train ``Sparse_Propensity_score`` and evaluate its three classifiers.

    ``train`` is expected to return three classifiers (named here end-to-end,
    stacked with all layers active, and stacked with the current layer
    active, following the original variable names). Each one is evaluated
    on ``ps_train_set`` in that order.

    Args:
        ps_train_set: data set used both as ``train_set`` and as the eval input.
        device: forwarded unchanged to ``train`` and ``eval``.

    Returns:
        A 3-tuple of the ``eval`` results, in the order
        (e2e, stacked all-layer-active, stacked cur-layer-active).
    """
    # Marked by the original author as the best parameter list.
    hyper_params = {
        "epochs": 400,
        "lr": 0.001,
        "batch_size": 32,
        "shuffle": True,
        "train_set": ps_train_set,
        "sparsity_probability": 0.8,
        "weight_decay": 0.0003,
        "BETA": 0.1,
        "input_nodes": 25,
        "classifier_epoch": 50,
        "model_save_path": "./Propensity_Model/SAE_PS_model_iter_id_epoch_{0}_lr_{1}.pth",
    }

    sae_net = Sparse_Propensity_score()
    print("############### Propensity Score SAE net Training ###############")
    e2e_clf, stacked_all_clf, stacked_cur_clf = sae_net.train(hyper_params, device, phase="train")

    # Evaluate the classifiers one after another, keeping the e2e -> all -> cur order.
    score_lists = []
    for classifier in (e2e_clf, stacked_all_clf, stacked_cur_clf):
        score_lists.append(
            sae_net.eval(ps_train_set, device, phase="eval", sparse_classifier=classifier)
        )

    scores_e2e, scores_stacked_all, scores_stacked_cur = score_lists
    return scores_e2e, scores_stacked_all, scores_stacked_cur
# Script entry point: run the correlation analysis only when this file is
# executed directly, so that importing the module has no side effects.
if __name__ == "__main__":
    Correlation().correlation_coeff()