Skip to content

Commit

Permalink
fix-code: fix import error
Browse files Browse the repository at this point in the history
  • Loading branch information
BHM-Bob committed May 22, 2023
1 parent 33b8e8b commit 75b8c02
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 33 deletions.
8 changes: 5 additions & 3 deletions mbapy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
'''
Author: BHM-Bob 2262029386@qq.com
Date: 2022-11-01 22:16:49
LastEditors: BHM-Bob
LastEditTime: 2023-04-06 21:34:17
LastEditors: BHM-Bob 2262029386@qq.com
LastEditTime: 2023-05-22 17:16:33
Description:
'''
from . import base, file, plot, web, dl_torch, stats
from . import base, file, plot, web, stats
from .__version__ import (
__author__,
__author_email__,
Expand All @@ -17,6 +17,8 @@
__url__,
)

from mbapy import dl_torch as dl_torch

# def main():
# pass
"""
Expand Down
14 changes: 7 additions & 7 deletions mbapy/dl_torch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
'''
Author: BHM-Bob 2262029386@qq.com
Date: 2023-03-21 00:06:00
LastEditors: BHM-Bob
LastEditTime: 2023-05-05 21:53:16
LastEditors: BHM-Bob 2262029386@qq.com
LastEditTime: 2023-05-22 17:18:31
Description:
'''

_Params = {
'USE_VIZDOM':False,
}

# from . import bb, data, loss, m, utils, paper
from mbapy.dl_torch import bb as bb
from mbapy.dl_torch import data as data
from mbapy.dl_torch import m as m
from mbapy.dl_torch import loss as loss
from mbapy.dl_torch import utils as utils

9 changes: 6 additions & 3 deletions mbapy/dl_torch/bb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Author: BHM-Bob 2262029386@qq.com
Date: 2023-03-23 21:50:21
LastEditors: BHM-Bob 2262029386@qq.com
LastEditTime: 2023-05-22 16:48:33
LastEditTime: 2023-05-22 17:19:55
Description: Basic Blocks
'''

Expand All @@ -13,7 +13,7 @@
import torch.nn as nn
import torch.nn.functional as F

from mbapy.dl_torch import paper
from . import paper

class CnnCfg:
@torch.jit.ignore
Expand Down Expand Up @@ -482,8 +482,11 @@ class SABlock1D(SABlock):
def __init__(self, cfg:CnnCfg):
super().__init__(cfg)
self.cnn1 = GenCnn1d(cfg.inc, cfg.outc, cfg.kernel_size)
self.cnn2 = GenCnn1d(cfg.outc, cfg.outc, cfg.kernel_size)
# self.cnn2 = GenCnn1d(cfg.outc, cfg.outc, cfg.kernel_size)
self.extra = nn.Conv1d(cfg.inc, cfg.outc, 1, stride = 1, padding="same")
def forward(self, x): # [b,inc,h,w] => [b,outc,h,w]
out = torch.cat([ cnn(x) for cnn in self.cnn1 ], dim=1)
return out + self.extra(x)

class SABlock1DR(SABlockR):
"""[b, c, l] => [b, c', l']"""
Expand Down
10 changes: 5 additions & 5 deletions mbapy/dl_torch/m.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Author: BHM-Bob 2262029386@qq.com
Date: 2023-03-23 21:50:21
LastEditors: BHM-Bob 2262029386@qq.com
LastEditTime: 2023-05-12 23:06:54
LastEditTime: 2023-05-22 17:17:24
Description: Model, most of models outputs [b, c', w', h'] or [b, l', c'] or [b, D]\n
you can add tail_trans as normal transformer or out_transformer in LayerCfg of model.__init__()
'''
Expand All @@ -16,10 +16,10 @@
import torch.nn as nn
import torch.nn.functional as F

from mbapy.base import autoparse
from mbapy.dl_torch.utils import GlobalSettings
from dl_torch import bb
from dl_torch.bb import CnnCfg
from ..base import autoparse
from .utils import GlobalSettings
from . import bb
from .bb import CnnCfg

# str2net合法性前置声明
str2net = {}
Expand Down
10 changes: 6 additions & 4 deletions mbapy/dl_torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@
import torch
import torch.nn as nn

from mbapy.base import MyArgs
from mbapy.file import save_json, read_json
from ..base import MyArgs
from ..file import save_json, read_json

from mbapy.dl_torch import _Params
_Params = {
'USE_VIZDOM':False,
}

if _Params['USE_VIZDOM']:
import visdom
Expand Down Expand Up @@ -214,7 +216,7 @@ def save_checkpoint(epoch, args:GlobalSettings, model, optimizer, loss, other:di
"args":args.toDict(),
}
state.update(other)
filename = os.path.join(args.modelRoot,
filename = os.path.join(args.model_oot,
f"checkpoint_{tailName:s}_{time.asctime(time.localtime()).replace(':', '-'):s}.pth.tar")
torch.save(state, filename)

Expand Down
7 changes: 2 additions & 5 deletions mbapy/test/dl_t/BasicBlock.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,14 @@
Author: BHM-Bob 2262029386@qq.com
Date: 2022-11-04 12:33:19
LastEditors: BHM-Bob 2262029386@qq.com
LastEditTime: 2023-05-21 12:11:42
LastEditTime: 2023-05-22 17:09:20
Description: Test for Basic Blocks
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
import mbapy.dl_torch.bb as bb
except:
import dl_torch.bb as bb
import mbapy.dl_torch.bb as bb

x = torch.arange(16, dtype = torch.float32, device = 'cuda').reshape([1, 1, 4, 4])
t = F.unfold(x, 3, 1, 1, 1)
Expand Down
10 changes: 4 additions & 6 deletions mbapy/test/dl_t/data.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
'''
Author: BHM-Bob 2262029386@qq.com
Date: 2023-05-02 20:40:37
LastEditors: BHM-Bob
LastEditTime: 2023-05-06 16:56:27
LastEditors: BHM-Bob 2262029386@qq.com
LastEditTime: 2023-05-22 17:12:17
Description:
'''
import torch
import torch.nn as nn
import torch.nn.functional as F

import dl_torch as dt
dt._Params['USE_VIZDOM'] =False
from dl_torch.utils import Mprint, GlobalSettings
from dl_torch.data import DataSetRAM
from mbapy.dl_torch.utils import Mprint, GlobalSettings
from mbapy.dl_torch.data import DataSetRAM

# global settings
mp = Mprint()
Expand Down

0 comments on commit 75b8c02

Please sign in to comment.