
Unable to reproduce the results in the article #18

Open
Jimmy-7664 opened this issue Apr 8, 2024 · 0 comments


Jimmy-7664 commented Apr 8, 2024

To Reproduce
Steps to reproduce the behavior:

  1. Prepare the DTW file by running `python dtw.py /home/user/STGM/datasets/PEMSBAY/pems-bay.h5` (this completed successfully).
  2. Modify `/src/datasets/pemsbay.py` as follows:
```python
# ML
import numpy as np
import pandas as pd

# Own
from datasets.base import BaseDataset


class Dataset(BaseDataset):
    def load_data(self):
        self._name = "PEMSBAY"
        self.adj = pd.read_csv(
            self.data_folder_path / "PEMSBAY" / "W_pemsbay.csv"
        ).values
        data = pd.read_hdf(self.data_folder_path / "PEMSBAY" / "pems-bay.h5")
        self.timestamps = data.index
        self.timestamps.freq = self.timestamps.inferred_freq
        self.time_index = self.time_to_idx(self.timestamps, freq="5min")
        self.data = data.values
        self.data_mean, self.data_std = self.data.mean(), self.data.std()
        self.data_min, self.data_max = self.data.min(), self.data.max()
        # DTW similarity matrix (dtw.npy)
        self.sim = np.load('/home/user/STGM/src/dtw.npy')
```
  3. Modify the related config files to train with the STGM_full model:
```yaml
dataset:
  _target_: datasets.pemsbay.Dataset
  name: PEMS-BAY
  data_folder_path: ${paths.dataset}
  mode: null
  train_ratio: 0.8
  val_ratio: 0.1
  window_size: 12
  batch_size: 64
  nb_worker: 16
  r: 1
estimator:
  _target_: models.estimator.Model
  name: STGM_ESTIMATOR
  in_channels: 1
  hidden_channels: 64
  bias: false
  nb_blocks: 2
  channels_last: true
  device: ${device}
  log_level: ${log.level}
model:
  _target_: models.stgm_full.Model
  name: STGM
  in_channels: 1
  hidden_channels: 32
  out_channels: 1
  nb_blocks: 4
  timestep_max: ${dataset.window_size}
  channels_last: true
  device: ${device}
  log_level: ${log.level}
trainer:
  _target_: trainers.full.Trainer
  epochs: 200
  batch_size: 64
  lr: 0.001
  clip: 3.0
  use_amp: false
  verbose: true
  model_pred_single: false
  log_level: ${log.level}
  logger_name: ${log.logger}
  device: ${device}
paths:
  log: ${hydra:runtime.cwd}
  dataset: /home/user/STGM/datasets/
  prior_weights: null
log:
  logger: wandb
  level: CRITICAL
  global_level: INFO
  project: TrafficForcastingSTGM
  entity: YOUR_USERNAME
  nb_hours: 24
  style: default
device: cuda:3
```

  4. Train STGM with the command `./run.py trainer.epochs=200 trainer.batch_size=64 dataset=pemsbay device=cuda:3 model.nb_blocks=4 log.logger=wandb model.hidden_channels=32`
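For reference, the config above (`train_ratio: 0.8`, `val_ratio: 0.1`, `window_size: 12`) implies a chronological 80/10/10 split with 12-step input windows, while the dataset class computes `data_mean`/`data_std` over the whole series. Whether `BaseDataset` re-derives the statistics from the training split, and whether the paper used the same split ratios, is not visible here, so both seem worth confirming. The sketch below shows the usual convention under stated assumptions: split by time first, standardize with training-set statistics only, then slice windows, assuming the prediction horizon equals `window_size`. `split_and_window` is a hypothetical helper, not part of the STGM code.

```python
import numpy as np


def split_and_window(data, train_ratio=0.8, val_ratio=0.1, window_size=12):
    """Hypothetical helper: chronological split, train-only z-scoring, and
    sliding windows of `window_size` inputs -> `window_size` targets.

    `data` has shape (timesteps, nodes). Assumes the horizon equals
    `window_size` and that splitting happens before windowing.
    """
    t = data.shape[0]
    n_train = int(t * train_ratio)
    n_val = int(t * val_ratio)
    parts = {
        "train": data[:n_train],
        "val": data[n_train:n_train + n_val],
        "test": data[n_train + n_val:],
    }
    # Normalization statistics come from the training part only.
    mean, std = parts["train"].mean(), parts["train"].std()

    def windows(x):
        n = x.shape[0] - 2 * window_size + 1
        if n <= 0:
            raise ValueError("split too short for the requested window size")
        z = (x - mean) / std
        inputs = np.stack([z[i:i + window_size] for i in range(n)])
        # Targets stay in original units so metrics are computed on raw values.
        targets = np.stack([x[i + window_size:i + 2 * window_size] for i in range(n)])
        return inputs, targets

    return {name: windows(x) for name, x in parts.items()}, (mean, std)
```

On the real data, `data` would be the `pd.read_hdf(...).values` array from the loader; comparing the returned `mean`/`std` with the loader's full-series `data_mean`/`data_std` shows how far apart the two conventions are.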

The results are much worse than those reported in the paper, and the same happens with the METR-LA dataset.
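One common source of mismatch on PEMS-BAY and METR-LA is the evaluation protocol: many papers on these datasets treat zero readings as missing, exclude them from MAE/RMSE/MAPE, and report metrics at horizons 3, 6 and 12 (15, 30 and 60 minutes at 5-minute sampling) rather than averaged over all steps. Which protocol the STGM paper or `trainers.full.Trainer` uses is not shown in this issue. The sketch below (hypothetical helpers `masked_metrics` and `per_horizon`, assuming predictions and targets shaped `(samples, steps, nodes)` in original units) computes the masked, per-horizon variant so it can be compared against the trainer's numbers on the same predictions.

```python
import numpy as np


def masked_metrics(pred, true, null_val=0.0):
    """Return (MAE, RMSE, MAPE %) over entries where `true` != `null_val`.

    Assumption: zero readings are missing values and are excluded.
    """
    mask = ~np.isclose(true, null_val)
    err = pred[mask] - true[mask]
    mae = np.abs(err).mean()
    rmse = np.sqrt(np.mean(err ** 2))
    mape = np.mean(np.abs(err) / np.abs(true[mask])) * 100.0
    return mae, rmse, mape


def per_horizon(pred, true, horizons=(3, 6, 12)):
    """Masked metrics at selected horizons for arrays shaped (samples, steps, nodes).

    With 5-minute sampling, steps 3/6/12 are 15/30/60 minutes ahead.
    """
    return {h: masked_metrics(pred[:, h - 1], true[:, h - 1]) for h in horizons}
```

If the trainer's reported MAE is unmasked or averaged over all 12 steps, it will not line up with a table that reports masked per-horizon values, even for the same model.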

Linux Server:

  • OS: Ubuntu 20.04
  • Python 3.11.0, PyTorch 2.0.0 and Hydra 1.3.1
  • GPU: NVIDIA RTX A6000

Could you help me figure out the problem?
Also, I found that memory usage exceeded 20 GB. How can I train STGM on an NVIDIA RTX 2070 GPU with 8 GB of memory? Looking forward to your help.
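On the memory side, the activation part of peak GPU memory grows roughly linearly with batch size, so the levers already exposed in the config are a smaller batch (`trainer.batch_size` and, presumably, `dataset.batch_size`) and mixed precision (`trainer.use_amp`, currently `false`), e.g. appending `trainer.batch_size=16 dataset.batch_size=16 trainer.use_amp=true` to the training command. Whether that alone fits STGM into 8 GB is untested. If a smaller batch hurts accuracy, gradient accumulation can keep the effective batch at 64; whether `trainers.full.Trainer` supports it is not shown here. The NumPy sketch below uses hypothetical helpers on a toy linear model (not STGM code) to show the idea: averaging gradients over four micro-batches of 16 reproduces the gradient of one batch of 64 for a mean-reduced loss.

```python
import numpy as np


def mse_grad(w, x, y):
    """Gradient of the mean squared error of a linear model x @ w."""
    return 2.0 * x.T @ (x @ w - y) / len(x)


def accumulated_grad(w, x, y, micro_batch):
    """Accumulate gradients over equal micro-batches of one logical batch.

    Each micro-batch gradient is divided by the number of micro-batches, so
    the sum equals the full-batch gradient of a mean-reduced loss.
    """
    if len(x) % micro_batch:
        raise ValueError("batch size must be a multiple of micro_batch")
    steps = len(x) // micro_batch
    grad = np.zeros_like(w)
    for i in range(steps):
        sl = slice(i * micro_batch, (i + 1) * micro_batch)
        grad += mse_grad(w, x[sl], y[sl]) / steps
    return grad


rng = np.random.default_rng(0)
x, y, w = rng.normal(size=(64, 4)), rng.normal(size=64), rng.normal(size=4)
# Four micro-batches of 16 reproduce the gradient of one batch of 64.
print(np.allclose(mse_grad(w, x, y), accumulated_grad(w, x, y, 16)))  # True
```

The equivalence is exact only when no layer depends on batch statistics (batch normalization, for example, would see 16 samples instead of 64).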

Jimmy-7664 reopened this on Aug 7, 2024