Skip to content

Commit 20d0f6f

Browse files
committed
merge cd and kt traintpl
1 parent e471ebd commit 20d0f6f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

59 files changed

+103
-268
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ run_edustudio(
4141
dataset='FrcSub',
4242
cfg_file_name=None,
4343
traintpl_cfg_dict={
44-
'cls': 'CDInterTrainTPL',
44+
'cls': 'EduTrainTPL',
4545
},
4646
datatpl_cfg_dict={
4747
'cls': 'CDInterExtendsQDataTPL'

docs/source/conf.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
project = 'EduStudio'
1010
copyright = '2023, HFUT-LEC'
1111
author = 'HFUT-LEC'
12-
release = 'v1.0.0-alpha4'
12+
release = 'v1.0.0-alpha5'
1313

1414
import sphinx_rtd_theme
1515
import os

docs/source/get_started/quick_start.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ run_edustudio(
1313
dataset='FrcSub',
1414
cfg_file_name=None,
1515
traintpl_cfg_dict={
16-
'cls': 'CDInterTrainTPL',
16+
'cls': 'EduTrainTPL',
1717
},
1818
datatpl_cfg_dict={
1919
'cls': 'CDInterExtendsQDataTPL'

docs/source/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
.. EduStudio documentation master file.
2-
.. title:: EduStudio v1.0.0-alpha4
2+
.. title:: EduStudio v1.0.0-alpha5
33
.. image:: assets/logo.png
44

55
=========================================================

docs/source/user_guide/reference_table.md

+42-42
Original file line numberDiff line numberDiff line change
@@ -4,51 +4,51 @@
44

55
| Model | DataTPL | TrainTPL | EvalTPL |
66
| :------ | ---------------------: | :-------------: | ------------------------------------------------------ |
7-
| IRT | CDInterDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
8-
| MIRT | CDInterDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
9-
| NCDM | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
10-
| CNCD_Q | CNCDQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
11-
| CNCD_F | CNCDFDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
12-
| DINA | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
13-
| HierCDF | HierCDFDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
14-
| CDGK | CDGKDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
15-
| CDMFKC | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
16-
| ECD | ECDDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
17-
| IRR | IRRDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
18-
| KaNCD | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
19-
| KSCD | CDInterExtendsQDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
20-
| MGCD | MGCDDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
21-
| RCD | RCDDataTPL | CDInterTrainTPL | BinaryClassificationEvalTPL |
7+
| IRT | CDInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
8+
| MIRT | CDInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
9+
| NCDM | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
10+
| CNCD_Q | CNCDQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
11+
| CNCD_F | CNCDFDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
12+
| DINA | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
13+
| HierCDF | HierCDFDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
14+
| CDGK | CDGKDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
15+
| CDMFKC | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
16+
| ECD | ECDDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
17+
| IRR | IRRDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
18+
| KaNCD | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL、CognitiveDiagnosisEvalTPL |
19+
| KSCD | CDInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
20+
| MGCD | MGCDDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
21+
| RCD | RCDDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
2222

2323
## KT models
2424

2525
| Model | DataTPL | TrainTPL | EvalTPL |
2626
| :----------- | ----------------------: | :-------------: | --------------------------- |
27-
| AKT | KTInterDataTPLCptUnfold | KTInterTrainTPL | BinaryClassificationEvalTPL |
27+
| AKT | KTInterDataTPLCptUnfold | EduTrainTPL | BinaryClassificationEvalTPL |
2828
| ATKT | KTInterDataTPLCptUnfold | AtktTrainTPL | BinaryClassificationEvalTPL |
29-
| CKT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
30-
| CL4KT | CL4KTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
31-
| CT_NCM | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
32-
| DeepIRT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
33-
| DIMKT | DIMKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
34-
| DKT | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
35-
| DKTDSC | DKTDSCDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
36-
| DKTForget | DKTForgetDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
37-
| DKT_plus | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
38-
| DKVMN | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
39-
| DTransformer | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
40-
| EERNN | EERNNDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
41-
| EKT | EERNNDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
42-
| HawkesKT | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
43-
| IEKT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
44-
| KQN | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
45-
| LPKT | LPKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
46-
| LPKT_S | LPKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
47-
| QDKT | QDKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
48-
| QIKT | KTInterExtendsQDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
49-
| RKT | RKTDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
50-
| SAINT | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
51-
| SAINT_plus | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
52-
| SAKT | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
53-
| SimpleKT | KTInterCptUnfoldDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
54-
| SKVMN | KTInterDataTPL | KTInterTrainTPL | BinaryClassificationEvalTPL |
29+
| CKT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
30+
| CL4KT | CL4KTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
31+
| CT_NCM | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
32+
| DeepIRT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
33+
| DIMKT | DIMKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
34+
| DKT | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
35+
| DKTDSC | DKTDSCDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
36+
| DKTForget | DKTForgetDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
37+
| DKT_plus | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
38+
| DKVMN | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
39+
| DTransformer | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
40+
| EERNN | EERNNDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
41+
| EKT | EERNNDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
42+
| HawkesKT | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
43+
| IEKT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
44+
| KQN | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
45+
| LPKT | LPKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
46+
| LPKT_S | LPKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
47+
| QDKT | QDKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
48+
| QIKT | KTInterExtendsQDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
49+
| RKT | RKTDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
50+
| SAINT | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
51+
| SAINT_plus | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
52+
| SAKT | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
53+
| SimpleKT | KTInterCptUnfoldDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |
54+
| SKVMN | KTInterDataTPL | EduTrainTPL | BinaryClassificationEvalTPL |

docs/source/user_guide/usage/aht.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def objective_function(args):
5454

5555

5656
search_space= {
57-
'traintpl_cfg.cls': tune.grid_search(['CDInterTrainTPL']),
57+
'traintpl_cfg.cls': tune.grid_search(['EduTrainTPL']),
5858
'datatpl_cfg.cls': tune.grid_search(['CDInterExtendsQDataTPL']),
5959
'modeltpl_cfg.cls': tune.grid_search(['KaNCD']),
6060
'evaltpl_cfg.clses': tune.grid_search([['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]),
@@ -115,7 +115,7 @@ def objective_function(args):
115115

116116

117117
space = {
118-
'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['CDInterTrainTPL']),
118+
'traintpl_cfg.cls': hp.choice('traintpl_cfg.cls', ['EduTrainTPL']),
119119
'datatpl_cfg.cls': hp.choice('datapl_cfg.cls', ['CDInterExtendsQDataTPL']),
120120
'modeltpl_cfg.cls': hp.choice('modeltpl_cfg.cls', ['KaNCD']),
121121
'evaltpl_cfg.clses': hp.choice('evaltpl_cfg.clses', [['BinaryClassificationEvalTPL', 'CognitiveDiagnosisEvalTPL']]),

docs/source/user_guide/usage/run_edustudio.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ run_edustudio(
1111
dataset='FrcSub',
1212
cfg_file_name=None,
1313
traintpl_cfg_dict={
14-
'cls': 'CDInterTrainTPL',
14+
'cls': 'EduTrainTPL',
1515
},
1616
datatpl_cfg_dict={
1717
'cls': 'CDInterExtendsQDataTPL'
@@ -48,7 +48,7 @@ datatpl_cfg:
4848
cls: CDInterDataTPL
4949

5050
traintpl_cfg:
51-
cls: CDTrainTPL
51+
cls: EduTrainTPL
5252
batch_size: 512
5353

5454
modeltpl_cfg:

edustudio/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
from __future__ import print_function
33
from __future__ import division
44

5-
__version__ = '1.0.0-alpha4'
5+
__version__ = 'v1.0.0-alpha5'

edustudio/traintpl/__init__.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,5 @@
44

55
from .base_traintpl import BaseTrainTPL
66
from .gd_traintpl import GDTrainTPL
7-
from .cd_inter_traintpl import CDInterTrainTPL
8-
from .kt_inter_traintpl import KTInterTrainTPL
7+
from .edu_traintpl import EduTrainTPL
98
from .atkt_traintpl import AtktTrainTPL

edustudio/traintpl/base_traintpl.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def get_default_cfg(cls):
6060
return cfg
6161

6262
def start(self):
63-
self.logger.info(f"TrainTPL {self.__class__.__base__} Started!")
63+
self.logger.info(f"TrainTPL {self.__class__} Started!")
6464
set_same_seeds(self.traintpl_cfg['seed'])
6565

6666
def _check_params(self):

edustudio/traintpl/cd_inter_traintpl.py

-161
This file was deleted.

edustudio/traintpl/kt_inter_traintpl.py edustudio/traintpl/edu_traintpl.py

+3-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import shutil
1111

1212

13-
class KTInterTrainTPL(GDTrainTPL):
13+
class EduTrainTPL(GDTrainTPL):
1414
default_cfg = {
1515
'num_stop_rounds': 10,
1616
'early_stop_metrics': [('auc','max')],
@@ -20,9 +20,6 @@ class KTInterTrainTPL(GDTrainTPL):
2020
'batch_size': 32,
2121
}
2222

23-
def __init__(self, cfg: UnifyConfig):
24-
super().__init__(cfg)
25-
2623
def _check_params(self):
2724
super()._check_params()
2825
assert self.traintpl_cfg['best_epoch_metric'] in set(i[0] for i in self.traintpl_cfg['early_stop_metrics'])
@@ -111,7 +108,7 @@ def evaluate(self, loader):
111108
batch_dict = self.batch_dict2device(batch_dict)
112109
eval_dict = self.model.predict(**batch_dict)
113110
pd_list[idx] = eval_dict['y_pd']
114-
gt_list[idx] = eval_dict['y_gt']
111+
gt_list[idx] = eval_dict['y_gt'] if 'y_gt' in eval_dict else batch_dict['label']
115112
y_pd = torch.hstack(pd_list)
116113
y_gt = torch.hstack(gt_list)
117114

@@ -142,7 +139,7 @@ def inference(self, loader):
142139
batch_dict = self.batch_dict2device(batch_dict)
143140
eval_dict = self.model.predict(**batch_dict)
144141
pd_list[idx] = eval_dict['y_pd']
145-
gt_list[idx] = eval_dict['y_gt']
142+
gt_list[idx] = eval_dict['y_gt'] if 'y_gt' in eval_dict else batch_dict['label']
146143
y_pd = torch.hstack(pd_list)
147144
y_gt = torch.hstack(gt_list)
148145

0 commit comments

Comments
 (0)