Reproduce the results in the paper #17

Open
Huu-An opened this issue Dec 29, 2023 · 4 comments


Huu-An commented Dec 29, 2023

Dear Mr. @Mouradost,

I tried to reproduce the results reported in the paper, but I could not. I trained the model many times with the hyperparameters described in the paper; however, my model does not converge and its performance is poor.

Following your instructions in issue #10, I trained only the STGM (w/o Estimator). My configuration and the log of the latest training run are below. Thank you for taking the time to read and reply.

Hope to receive a response from you soon.

HuuAn


My Configuration

Hardware
GPU: RTX 3080 Ti
RAM: 32GB
CPU: AMD Ryzen 9 5900X 12-Core Processor
OS: Linux

Adjustment: changed the optimizer from AdamW to Adam

config/config.yaml

defaults:
  - dataset: metrla
  - estimator: default
  - model: stgm
  - trainer: default
  - paths: default
  - log: default
  - hydra: default
  - _self_

device: auto

config/trainers/default

_target_: trainers.default.Trainer
epochs: 200
batch_size: 64
lr: 1e-3
clip: 3.0
usegamp: true
verbose: true
model_pred_single: false # If the model predicts only one step at a time
log_level: ${log.level}
logger_name: ${log.logger}
device: ${device}

config/models/stgm.yaml

_target_: models.stgm.Model
name: STGM
in_channels: 1
hidden_channels: 32
out_channels: 1
nb_blocks: 1
timestep_max: ${dataset.window_size}
channels_last: True
device: ${device}
log_level: ${log.level}

Execution command

python run.py trainer.epochs=300 model.hidden_channels=32
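(With Hydra, dotted command-line arguments override the composed config, so this run trains for 300 epochs even though config/trainers/default sets epochs: 200, while hidden_channels=32 simply restates the value already in config/models/stgm.yaml.)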

Training log

Epoch 300/300 train: 100%|██████████████████████████████████████████████| 300/300 [1:05:04<00:00, 13.02s/it, loss=19.6, loss_val=14.4]
[2023-12-29 12:10:19,741][STGM][INFO] - Model succesfully saved
[2023-12-29 12:10:19,841][Dataset][INFO] - Dataset METRLA (test) length: 6831
Generating predictions: 100%|███████████████████| 54/54 [00:02<00:00, 23.20it/s]
[2023-12-29 12:10:22,494][root][INFO] - 

================================================================================
                                Mean Error Real                                 
--------------------------------------------------------------------------------
        MSE                 RMSE                MAE                 MAPE        
--------------------------------------------------------------------------------
       66.02                8.08                4.37               13.72%       
================================================================================
================================================================================
                               By Step Error Real                               
--------------------------------------------------------------------------------
     Steps            MSE             RMSE            MAE             MAPE      
--------------------------------------------------------------------------------
       1             40.37            6.35            3.63           10.82%     
       2             47.42            6.89            3.89           11.51%     
       3             52.39            7.24            3.99           12.12%     
       4             60.08            7.75            4.19           12.99%     
       5             61.04            7.81            4.25           13.21%     
       6             67.34            8.21            4.45           14.02%     
       7             69.31            8.33            4.47           14.06%     
       8             73.12            8.55            4.56           14.61%     
       9             76.86            8.77            4.67           14.97%     
       10            78.77            8.88            4.73           15.17%     
       11            81.88            9.05            4.78           15.58%     
       12            83.61            9.14            4.81           15.53%     
================================================================================
Mouradost (Owner) commented

Dear @HuuAnnnn,

First of all, thank you for your interest in our work and for your report. Concerning the hyperparameters described in the paper, we did mention that we use 4 layers, which translates to 4 blocks in config/models/stgm.yaml (nb_blocks: 4), and it seems that you used only 1 block. Please let me know if you achieve results comparable to the paper; otherwise, I will have a look at the code in this repo as soon as I can.
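(For reference, with the Hydra setup shown above, this change can presumably be applied either by editing nb_blocks: 1 to nb_blocks: 4 in config/models/stgm.yaml, or as an extra dotted override on the earlier command, e.g. python run.py trainer.epochs=300 model.hidden_channels=32 model.nb_blocks=4.)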

Mouradost self-assigned this Jan 4, 2024
Mouradost added the help wanted (Extra attention is needed) label Jan 4, 2024

Huu-An commented Jan 4, 2024

Thank you for your reply. I will try what you suggested. I have another question: I can't find the GLU function or the CTMixer layer applied in the code as described in the paper. Is the fusion module missing from the code?
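(For context, GLU here is the standard gated linear unit, GLU(a, b) = a * sigmoid(b). The snippet below is only a generic PyTorch sketch of such a gated fusion; the GatedFusion name and the 1x1-convolution projection are illustrative choices and not the repo's CTMixer, which, as the maintainer confirms below, was missing from the pushed code.)

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedFusion(nn.Module):
    """Generic GLU-style gate: out = value * sigmoid(gate). Illustrative only."""

    def __init__(self, channels: int):
        super().__init__()
        # One 1x1 convolution produces both halves (value and gate) at once.
        self.proj = nn.Conv2d(channels, 2 * channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # F.glu splits dim 1 in two and returns first_half * sigmoid(second_half).
        return F.glu(self.proj(x), dim=1)

# (batch, channels, nodes, timesteps); METR-LA has 207 sensors and the window here is 12 steps.
x = torch.randn(2, 32, 207, 12)
print(GatedFusion(32)(x).shape)  # torch.Size([2, 32, 207, 12])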

Huu-An closed this as completed Feb 1, 2024
Huu-An reopened this Feb 6, 2024
Mouradost (Owner) commented

I have checked the code, and indeed I pushed an old experimental version that does not contain some parts of the model. I will push the full model soon. I apologize for any inconvenience.


danaesav commented Apr 29, 2024

> I have checked the code, and indeed I pushed an old experimental version that does not contain some parts of the model. I will push the full model soon. I apologize for any inconvenience.

@Mouradost Is there an update on this?
