From 25989d02286d694f3ad7b061a6aab658a26acbf6 Mon Sep 17 00:00:00 2001
From: typoverflow
Date: Wed, 30 Aug 2023 19:29:49 +0800
Subject: [PATCH] refactor: remove redundant features (#41)

---
 UtilsRL/__init__.py                   |   9 +-
 UtilsRL/exp/__init__.py               |   6 +-
 UtilsRL/exp/argparse.py               |  14 --
 UtilsRL/exp/precision.py              |  53 -----
 UtilsRL/exp/snapshot.py               |  33 ---
 UtilsRL/logger/base_logger.py         |  15 +-
 UtilsRL/monitor.py                    | 282 --------------------------
 UtilsRL/operator/wkv_op.cpp           |  21 --
 UtilsRL/operator/wkv_op.cu            | 133 ------------
 UtilsRL/operator/wkv_op_extend.cpp    |  93 ---------
 UtilsRL/operator/wkv_op_extend.cu     | 204 -------------------
 UtilsRL/rl/buffer/trjectory_replay.py |  43 ----
 12 files changed, 8 insertions(+), 898 deletions(-)
 delete mode 100644 UtilsRL/exp/precision.py
 delete mode 100644 UtilsRL/exp/snapshot.py
 delete mode 100644 UtilsRL/monitor.py
 delete mode 100644 UtilsRL/operator/wkv_op.cpp
 delete mode 100644 UtilsRL/operator/wkv_op.cu
 delete mode 100644 UtilsRL/operator/wkv_op_extend.cpp
 delete mode 100644 UtilsRL/operator/wkv_op_extend.cu
 delete mode 100644 UtilsRL/rl/buffer/trjectory_replay.py

diff --git a/UtilsRL/__init__.py b/UtilsRL/__init__.py
index 5919236..0db2662 100644
--- a/UtilsRL/__init__.py
+++ b/UtilsRL/__init__.py
@@ -1,15 +1,8 @@
 __version__ = "0.5.9"
 
-from UtilsRL import data_structure, env, exp, logger, math, misc, plot, rl, monitor
+from UtilsRL import exp, logger
 
 __all__ = [
-    "data_structure",
-    "env",
     "exp",
     "logger",
-    "math",
-    "misc",
-    "plot",
-    "rl",
-    "monitor"
 ]
\ No newline at end of file
diff --git a/UtilsRL/exp/__init__.py b/UtilsRL/exp/__init__.py
index e2b3cd0..057e86a 100644
--- a/UtilsRL/exp/__init__.py
+++ b/UtilsRL/exp/__init__.py
@@ -1,8 +1,6 @@
 from UtilsRL.exp.argparse import parse_args, argparse_callbacks, register_argparse_callback
 from UtilsRL.exp._seed import *
 from UtilsRL.exp._device import *
-from UtilsRL.exp.snapshot import make_snapshot
-from UtilsRL.exp.precision import set_precision
 
 from UtilsRL.logger import BaseLogger
 from UtilsRL.logger import logger as url_internal_logger
@@ -63,6 +61,4 @@ def setup(args,
 
     return args
 
-
-register_argparse_callback("UtilsRL.snapshot", make_snapshot)
-register_argparse_callback("UtilsRL.precision", set_precision)
\ No newline at end of file
+ 
\ No newline at end of file
diff --git a/UtilsRL/exp/argparse.py b/UtilsRL/exp/argparse.py
index b7487f4..be23579 100644
--- a/UtilsRL/exp/argparse.py
+++ b/UtilsRL/exp/argparse.py
@@ -110,20 +110,6 @@ def traverse_add(old, new, current_key=""):
 
     traverse_add(file_args, cmd_args)
 
-    # check if there is a callback
-    for key in argparse_callbacks:
-        _args = file_args
-        _keys = key.split(".")
-        for k in _keys:
-            _args = _args.get(k, None)
-            if _args is None:
-                ret = argparse_callbacks[key](None) # call callback with default None and then break
-                file_args = update_args(file_args, ret, eval=False)
-                break
-        else:
-            ret = argparse_callbacks[key](_args)
-            file_args = update_args(file_args, ret, eval=False)
-
     return file_args
 
 def update_args(args, new_args: Optional[Union[dict, list]] = None, eval=True):
diff --git a/UtilsRL/exp/precision.py b/UtilsRL/exp/precision.py
deleted file mode 100644
index 04f8f45..0000000
--- a/UtilsRL/exp/precision.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import os
-import torch
-import numpy as np
-from UtilsRL.misc.namespace import NameSpace
-
-from typing import Optional, Dict, Union
-
-def set_precision(args: Optional[Union[torch.dtype, str, int]]):
-    if args is None: # default precision is float32
-        args = 32
-    if args in {16, 32, 64}:
-        if args == 16: prec = "float16"
-        elif args == 32: prec = "float32"
-        elif args == 64: prec = "float64"
-    elif args.lower() in {"float32", "float16", "float32", "float64", "double"}:
-        args = args.lower()
-        if args == "double":
-            prec = "float64"
-        elif args == "float":
-            prec = "float32"
-        else:
-            prec = args
-    elif args in {torch.float16, torch.float32, torch.float64}:
-        if args == torch.float16: prec = "float16"
-        elif args == torch.float32: prec = "float32"
-        elif args == torch.float64: prec = "float64"
-    else:
-        e = f"""
-        When setting precision, input format should be either
-        - int, 16 or 32 or 64;
-        - str, float16 or float or float32 or float64 or double;
-        - torch.dtype object.
-        But got {type(args)} {args}.
-        """
-        raise TypeError(e)
-
-    np_ftype, torch_ftype = {
-        "float16": [np.float16, torch.float16],
-        "float32": [np.float32, torch.float32],
-        "float64": [np.float64, torch.float64]
-    }.get(prec)
-    torch.set_default_dtype(torch_ftype)
-    if prec == "float64":
-        torch.set_default_tensor_type(torch.DoubleTensor)
-
-    return {
-        "UtilsRL.numpy_fp": np_ftype,
-        "UtilsRL.torch_fp": torch_ftype,
-        "UtilsRL.precision": prec
-    }
-
-
-
\ No newline at end of file
diff --git a/UtilsRL/exp/snapshot.py b/UtilsRL/exp/snapshot.py
deleted file mode 100644
index 38b63f0..0000000
--- a/UtilsRL/exp/snapshot.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import os
-
-from UtilsRL.misc.namespace import NameSpaceMeta
-from UtilsRL.logger import logger
-
-from typing import Optional, Dict, Union
-
-
-def make_snapshot(args: Optional[Union[Dict, NameSpaceMeta, str]]):
-    if args is None:
-        return {
-            "UtilsRL.snapshot_branch": None
-        }
-
-    prefix = branch = "/".join(["snapshot", args])
-    suffix = 0
-    while os.system(f"git --no-pager branch | grep {branch} > /dev/null 2>&1") == 0:
-        suffix += 1
-        branch = prefix + f"-{suffix}"
-    cmd = f"git add -A >/dev/null 2>&1 && \
-        git stash >/dev/null 2>&1 && \
-        git switch -c {branch} >/dev/null 2>&1 && \
-        git stash apply >/dev/null 2>&1 && \
-        git add -A >/dev/null 2>&1 && \
-        git commit -m \"snapshot: {branch}\" >/dev/null 2>&1 && \
-        git switch - >/dev/null 2>&1 && \
-        git stash pop >/dev/null 2>&1 "
-    logger.info(f"UtilsRL.snapshot: saving code snapshot to {branch}.")
-    os.system(cmd)
-
-    return {
-        "UtilsRL.snapshot_branch": branch
-    }
\ No newline at end of file
diff --git a/UtilsRL/logger/base_logger.py b/UtilsRL/logger/base_logger.py
index 638fec2..009177a 100644
--- a/UtilsRL/logger/base_logger.py
+++ b/UtilsRL/logger/base_logger.py
@@ -1,6 +1,7 @@
 from typing import Any, Dict, Optional, Sequence, Union
 
 import os
+import torch
 import pickle
 
 from datetime import datetime
@@ -26,18 +27,14 @@ def numpy_save(obj, file):
         with open(file, "w") as fp:
             np.save(fp, obj)
 
-    if protocol == "torch":
-        import torch
-        return torch.save
-    else:
-        return {
-            "pickle": pickle_save,
-            "numpy": numpy_save
-        }.get(protocol)
+    return {
+        "torch": torch.save,
+        "pickle": pickle_save,
+        "numpy": numpy_save
+    }.get(protocol)
 
 def load_fn(protocol: str="torch"):
     def torch_load(file):
-        import torch
         return torch.load(file, map_location="cpu")
 
     def pickle_load(file):
diff --git a/UtilsRL/monitor.py b/UtilsRL/monitor.py
deleted file mode 100644
index 93ff5b6..0000000
--- a/UtilsRL/monitor.py
+++ /dev/null
@@ -1,282 +0,0 @@
-import os
-import sys
-import copy
-import smtplib
-import contextlib
-import atexit
-import inspect
-import pickle
-from smtplib import SMTP_SSL, SMTP
-from email.mime.text import MIMEText
-
-from UtilsRL.third_party.tqdm import tqdm_tty, tqdm_notebook, tqdm_file
-
-from typing import 
Optional, Sequence, Union, Callable, Any - -tqdm_cls: Any = None -try: - ipy_str = str(type(get_ipython())) - if 'zmqshell' in ipy_str: - tqdm_cls = tqdm_notebook - if 'terminal' in ipy_str: - tqdm_cls = tqdm_tty -except: - if sys.stderr.isatty(): - tqdm_cls = tqdm_tty - else: - tqdm_cls = tqdm_file - - -@contextlib.contextmanager -def update_scope(func: Callable, globals: dict): - old_globals = func.__globals__.copy() - func.__globals__.update(globals) - yield func - func.__globals__.update(old_globals) - - -class MonitorError(Exception): - pass - - -class Monitor(object): - """Monitor is designed to monitor the main training loop and inform the user about the progress. \ - It mainly serves for three purposes: 1. visualize the training progress; 2. register callback functions to trigger actions \ - during certain stage of training; and 3. register context variables to save & load training context. - - :param desc: description of the Monitor, will be displayed at the left side of the progress meter. - :param out_dir: output directory of the products (model checkpoints). This must be specified if you are to use :func:`~UtilsRL.monitor.Monitor.register_context`. - """ - - @staticmethod - def eval_outer(expr): - L = inspect.currentframe().f_back.f_back.f_locals - G = inspect.currentframe().f_back.f_back.f_globals - return eval(expr, G, L) - - @staticmethod - def email(msg, to, user, password, smtp_server=None, port=None): - """Send an email from to a given receiver. - - :param str msg: the message to be sent in the email. - :param str to: email address of the receiver. - :param str user: email address of the sender. - :param str password: the password string for sender email. - :param str smtp_server: the smtp server address for email sending. If this is set to `None`, a default \ - server address will be used according to the email's host. - :param int port: the smtp server port for email sending. If this is set to `None`, a default port will \ - be used according to the email's host. - """ - def get_host(user): - return user.split("@")[1].split(".")[0] - host_info = { - "qq": ("smtp.qq.com", 587), - # "gmail": ("smtp.gmail.com", 587), - "outlook": ("smtp.office365.com", 587) - } - if smtp_server is None or port is None: - host = get_host(user) - if host not in host_info: - raise KeyError("Host {} is not supported, current supported types are: {}".format(host, list(host_info.keys()))) - smtp_server, port = host_info[host] - - _msg = MIMEText(msg, "plain", _charset="utf-8") - _msg["Subject"] = "An message from your program" - _msg["from"] = "UtilsRL.Monitor" - - with SMTP(host=smtp_server, port=port) as smtp: - smtp.starttls() - smtp.login(user = user, password = password) - smtp.sendmail(from_addr=user, to_addrs=to, msg=_msg.as_string()) - - def __init__(self, - desc: Optional[str] = None, - out_dir: Optional[str] = None, - *args, **kwargs): - - self.desc = desc if tqdm_cls == tqdm_file else "\033[1;37m[{}]\033[0m".format(desc) - self.tqdm_cls = tqdm_cls - self.out_dir = out_dir - self.args = args - self.kwargs = kwargs - - self.has_callbacks = False - self.callbacks = list() - self.exit_callbacks = list() - self.end_callbacks = list() - self.has_context = False - - def listen(self, - iterable = None, - initial: Optional[int] = 0, - total: Optional[int] = None, - miniters: Optional[int] = None): - """Set the monitor to listen at the training loop. Note that a monitor can be assigned \ - for listening only once. - - :param iterable: an iterable to listen at, for example, `range(5)`. 
- :param initial: startpoint of the iteration. - :param total: total number of iterations. If this is set to `None`, the total number of iterations will be `len(iterable)`. - :param miniters: minimum number of iterations between adjacent updates to the meter. - """ - - if hasattr(self, "iterable") and self.iterable is not None: - raise MonitorError("A monitor can only listen at one iterable.") - - self.iterable = iterable - self.tqdm = self.tqdm_cls(iterable, self.desc, total=total, initial=initial, miniters=miniters) - self.total = self.tqdm.total - self.initial = initial - self.miniters = self.tqdm.miniters - - self.global_step = self.initial - return self - - def register_callback(self, - name: str, - on: Optional[Union[str, int]] = None, - callback: Callable = None, - *args, **kwargs): - """Register callback functions which will be called when `on` is satisfied. - - :param name: the name of the callback function. - :param on: specifies the condition when `callback` is triggered. Legal values are: - - - None, then the callback will never be called. - - `exit`, then the callback will be called on exit. - - int, then the callback will be called at the beginning of `on`th iteration. - - str which represents a percentage, then the callback will be called at this \ - stage of training. - :param callback: the callback funtion. It will take args and kwargs as input, and self.global_step \ - will also be added as keyward argument. **So when defining a callback function, it's \ - recommended to add redundant kwargs with `**kwargs` to its signature**. - """ - self.has_callbacks = True - if on is None or on == False: - return - elif on == 'exit': - if name in [ec["name"] for ec in self.exit_callbacks]: - return - self.exit_callbacks.append({ - "name": name, - "on": "exit", - "callback": callback, - "args": (args, kwargs) - }) - atexit.register(callback, *args, **kwargs) - elif on == "100%" or on == "end": - if name in [ec["name"] for ec in self.end_callbacks]: - return - self.end_callbacks.append({ - "name": name, - "on": "end", - "callback": callback, - "args": (args, kwargs) - }) - else: - if name in [c["name"] for c in self.callbacks]: - return - if isinstance(on, str): - try: - per = float(on[:-1]) / 100 - assert 0 <= per < 1 - except Exception: - raise MonitorError("Invalid percentage {}".format(on)) - else: - if not isinstance(on, int): - raise MonitorError("Unrecognized condition: {}".format(on)) - self.callbacks.append({ - "name": name, - "on": on, - "callback": callback, - "args": (args, kwargs) - }) - - def register_context(self, expressions, save_every=None, save_mode="replace", load_path=None): - """Register variables as context. Monitor will save the context variables \ - periodically and restore them if the training is resumed from a checkpoint. \ - Note: only one register_context call with valid save_every is permitted. - - :param expressions: the expressions of the variables which you wish to designate as context. - :param save_every: save the context every `save_every` iterations. If set to None, the context - will not be saved. - :param save_mode: specifies the mode of saving. values are: - - - "replace": replace previously saved context. - - "append": save context without replacing. - :param load_path: specifies the path of the checkpoint of the context to load. If set to None, - the context will not be loaded. 
- """ - - if isinstance(expressions, str): - expressions = [expressions] - if save_every is None or save_every == False: - pass - elif isinstance(save_every, int) and save_every >= 1: - if self.has_context: - raise MonitorError("Only one register_context call with valid save_every is permitted.") - if self.out_dir is None: - raise MonitorError("Before using monitor to save context, you must specify the output directory.") - if save_mode not in ["replace", "append"]: - raise MonitorError("save_mode must be either 'replace' or 'append'.") - # save the infos for context saving - self.has_context = True - self.context = expressions - self.context_load_path = load_path - self.save_every = save_every - self.save_mode = save_mode - else: - raise MonitorError(f"Illegal value for save_every: {save_every}") - - ret_dict = dict() - if load_path: - # load obj from given path - for expr in expressions: - if os.path.exists(os.path.join(load_path, expr)): - with open(os.path.join(load_path, expr), "rb") as fp: - ret_dict[expr] = pickle.load(fp) - else: - ret_dict[expr] = Monitor.eval_outer(expr) - else: - # get reference from outer scope - for expr in expressions: - ret_dict[expr] = Monitor.eval_outer(expr) - - return ret_dict if len(ret_dict) > 1 else ret_dict[expressions[0]] - - def _check_callbacks(self): - if not hasattr(self, "global_step") or not hasattr(self, "total"): - raise MonitorError("Monitor must listen on an iterable before registered callbacks can be called.") - for c in self.callbacks: - if isinstance(c["on"], int): - if self.global_step == c["on"]: - c["callback"]( *c["args"][0], **c["args"][1], global_step=self.global_step) - elif isinstance(c["on"], str): - threshold = int(c["on"][:-1]) / 100 * self.total - if self.global_step >= threshold and self.global_step - 1 < threshold: - c["callback"](*c["args"][0], **c["args"][1], global_step=self.global_step) - - def __iter__(self): - tqdm_iter = iter(self.tqdm) - while True: - try: - if self.has_context \ - and self.global_step > self.initial \ - and self.global_step % self.save_every == 0: - - save_path = os.path.join(self.out_dir, str(self.global_step)) if self.save_mode == "append" \ - else os.path.join(self.out_dir, "context") - if not os.path.exists(save_path): - os.makedirs(save_path) - for expr in self.context: - obj = Monitor.eval_outer(expr) - with open(os.path.join(save_path, expr), "wb") as fp: - pickle.dump(obj, fp) - if self.has_callbacks: - self._check_callbacks() - yield next(tqdm_iter) - self.global_step += 1 - except StopIteration: - for c in self.end_callbacks: - c["callback"](*c["args"][0], **c["args"][1], global_step=self.global_step) - break diff --git a/UtilsRL/operator/wkv_op.cpp b/UtilsRL/operator/wkv_op.cpp deleted file mode 100644 index 4ac9c17..0000000 --- a/UtilsRL/operator/wkv_op.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include - -void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); -void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); - -void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { - cuda_forward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr()); -} -void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, 
torch::Tensor &gv) { - cuda_backward(B, T, C, w.data_ptr(), u.data_ptr(), k.data_ptr(), v.data_ptr(), y.data_ptr(), gy.data_ptr(), gw.data_ptr(), gu.data_ptr(), gk.data_ptr(), gv.data_ptr()); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &forward, "wkv forward"); - m.def("backward", &backward, "wkv backward"); -} - -TORCH_LIBRARY(wkv, m) { - m.def("forward", forward); - m.def("backward", backward); -} \ No newline at end of file diff --git a/UtilsRL/operator/wkv_op.cu b/UtilsRL/operator/wkv_op.cu deleted file mode 100644 index a7f4137..0000000 --- a/UtilsRL/operator/wkv_op.cu +++ /dev/null @@ -1,133 +0,0 @@ -#include -#include - -#define MIN_VALUE (-1e38) - -template -__global__ void kernel_forward(const int B, const int T, const int C, - const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, - F *__restrict__ const _y) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - const int _b = idx / C; - const int _c = idx % C; - const int _offset = _b * T * C + _c; - - F u = _u[_c]; - F w = _w[_c]; - const F *__restrict__ const k = _k + _offset; - const F *__restrict__ const v = _v + _offset; - F *__restrict__ const y = _y + _offset; - - // aa and bb are running sums divided by exp(pp) (to avoid overflow) - F aa = 0, bb = 0, pp = MIN_VALUE; - for (int i = 0; i < T; i++) { - const int ii = i * C; - const F kk = k[ii]; - const F vv = v[ii]; - - F ww = u + kk; - F p = max(pp, ww); - F e1 = exp(pp - p); - F e2 = exp(ww - p); - y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); - - ww = w + pp; - p = max(ww, kk); - e1 = exp(ww - p); - e2 = exp(kk - p); - aa = e1 * aa + e2 * vv; - bb = e1 * bb + e2; - pp = p; - } -} - -template -__global__ void kernel_backward(const int B, const int T, const int C, - const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, - const F *__restrict__ const _y, const F *__restrict__ const _gy, - F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - const int _b = idx / C; - const int _c = idx % C; - const int _offset = _b * T * C + _c; - - F u = _u[_c]; - F w = _w[_c]; - const F *__restrict__ const k = _k + _offset; - const F *__restrict__ const v = _v + _offset; - const F *__restrict__ const y = _y + _offset; - const F *__restrict__ const gy = _gy + _offset; - F *__restrict__ const gk = _gk + _offset; - F *__restrict__ const gv = _gv + _offset; - - F q[Tmax], r[Tmax]; - - F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE; - for (int i = 0; i < T; i++) { - const int ii = i * C; - const F kk = k[ii]; - const F vv = v[ii]; - const F yy = y[ii]; - - F ww = u + kk; - F p = max(pp, ww); - F e1 = exp(pp - p); - F e2 = exp(ww - p); - const F qq = gy[ii] / (e1 * bb + e2); - gw += (ga - gb * yy) * e1 * qq; - gu += (vv - yy) * e2 * qq; - q[i] = qq; - r[i] = ww - p; - - ww = w + pp; - p = max(ww, kk); - e1 = exp(ww - p); - e2 = exp(kk - p); - ga = e1 * (aa + ga); - gb = e1 * (bb + gb); - aa = e1 * aa + e2 * vv; - bb = e1 * bb + e2; - pp = p; - } - const int _offsetBC = _b * C + _c; - _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward() - _gu[_offsetBC] = gu; - - aa = 0, bb = 0, pp = MIN_VALUE; - for (int i = T - 1; i >= 0; i--) { - const int ii = i * C; - const F kk = k[ii]; - const F vv = v[ii]; - const F yy = y[ii]; - const F qq = q[i]; - const F rr = 
r[i]; - - F e1 = qq * exp(rr); - F e2 = exp(kk + pp); - gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb); - gv[ii] = e1 + e2 * aa; - - const F ww = w + pp; - const F www = rr - u - kk; - const F p = max(ww, www); - e1 = exp(ww - p); - e2 = qq * exp(www - p); - aa = e1 * aa + e2; - bb = e1 * bb - e2 * yy; - pp = p; - } -} - -void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) { - dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance - assert(B * C % threadsPerBlock.x == 0); - dim3 numBlocks(B * C / threadsPerBlock.x); - kernel_forward<<>>(B, T, C, w, u, k, v, y); -} - -void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) { - dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance - assert(B * C % threadsPerBlock.x == 0); - dim3 numBlocks(B * C / threadsPerBlock.x); - kernel_backward<<>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv); -} \ No newline at end of file diff --git a/UtilsRL/operator/wkv_op_extend.cpp b/UtilsRL/operator/wkv_op_extend.cpp deleted file mode 100644 index 147515f..0000000 --- a/UtilsRL/operator/wkv_op_extend.cpp +++ /dev/null @@ -1,93 +0,0 @@ -#include - -void cuda_forward( - int B, int T, int C, - float *w, float *u, float *k, float *v, - float *h1, float *h2, float *y, - float *oh1, float *oh2 -); - -void cuda_backward( - int B, int T, int C, - float *w, float *u, float *k, float *v, - float *h1, float *h2, float *y, - float *gy, float *goh1, float *goh2, - float *gw, float *gu, float *gk, float *gv, - float *gh1, float *gh2 -); - -void forward( - int64_t B, int64_t T, int64_t C, - torch::Tensor &w, - torch::Tensor &u, - torch::Tensor &k, - torch::Tensor &v, - torch::Tensor &h1, - torch::Tensor &h2, - torch::Tensor &y, - torch::Tensor &oh1, - torch::Tensor &oh2 -) { - cuda_forward( - B, T, C, - w.data_ptr(), - u.data_ptr(), - k.data_ptr(), - v.data_ptr(), - h1.data_ptr(), - h2.data_ptr(), - y.data_ptr(), - oh1.data_ptr(), - oh2.data_ptr() - ); -} - -void backward( - int64_t B, int64_t T, int64_t C, - torch::Tensor &w, - torch::Tensor &u, - torch::Tensor &k, - torch::Tensor &v, - torch::Tensor &h1, - torch::Tensor &h2, - torch::Tensor &y, - torch::Tensor &gy, - torch::Tensor &goh1, - torch::Tensor &goh2, - torch::Tensor &gw, - torch::Tensor &gu, - torch::Tensor &gk, - torch::Tensor &gv, - torch::Tensor &gh1, - torch::Tensor &gh2 -) { - cuda_backward( - B, T, C, - w.data_ptr(), - u.data_ptr(), - k.data_ptr(), - v.data_ptr(), - h1.data_ptr(), - h2.data_ptr(), - y.data_ptr(), - gy.data_ptr(), - goh1.data_ptr(), - goh2.data_ptr(), - gw.data_ptr(), - gu.data_ptr(), - gk.data_ptr(), - gv.data_ptr(), - gh1.data_ptr(), - gh2.data_ptr() - ); -} - -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("forward", &forward, "wkv forward"); - m.def("backward", &backward, "wkv backward"); -} - -TORCH_LIBRARY(wkv_extend, m) { - m.def("forward", forward); - m.def("backward", backward); -} \ No newline at end of file diff --git a/UtilsRL/operator/wkv_op_extend.cu b/UtilsRL/operator/wkv_op_extend.cu deleted file mode 100644 index b95a7d3..0000000 --- a/UtilsRL/operator/wkv_op_extend.cu +++ /dev/null @@ -1,204 +0,0 @@ -#include -#include - -#define MIN_VALUE (-1e38) - -template -__global__ void kernel_forward( - const int B, - const int T, - const int C, - const F *__restrict__ const _w, - const F *__restrict__ const _u, - const F *__restrict__ const _k, - const F *__restrict__ const _v, - const F 
*__restrict__ const _h1, - const F *__restrict__ const _h2, - F *__restrict__ const _y, - F *__restrict__ const _oh1, - F *__restrict__ const _oh2 -) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - const int _b = idx / C; - const int _c = idx % C; - const int _offset = _b * T * C + _c; - const int _hist_offset = _b * C + _c; - - F u = _u[_c]; - F w = _w[_c]; - F aa, bb, pp; - aa = (_h1 == NULL)? 0:_h1[_hist_offset]; - bb = (_h2 == NULL)? 0:_h2[_hist_offset*2]; - pp = (_h2 == NULL)? MIN_VALUE:_h2[_hist_offset*2+1]; - - const F *__restrict__ const k = _k + _offset; - const F *__restrict__ const v = _v + _offset; - F *__restrict__ const y = _y + _offset; - - // aa and bb are running sums divided by exp(pp) (to avoid overflow) - for (int i = 0; i < T; i++) { - const int ii = i * C; - const F kk = k[ii]; - const F vv = v[ii]; - - F ww = u + kk; - F p = max(pp, ww); - F e1 = exp(pp - p); - F e2 = exp(ww - p); - y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2); - - ww = w + pp; - p = max(ww, kk); - e1 = exp(ww - p); - e2 = exp(kk - p); - aa = e1 * aa + e2 * vv; - bb = e1 * bb + e2; - pp = p; - } - - _oh1[_hist_offset] = aa; - _oh2[_hist_offset*2] = bb; - _oh2[_hist_offset*2+1] = pp; -} - -void cuda_forward( - int B, int T, int C, - float *w, float *u, float *k, float *v, - float *h1, float *h2, float *y, - float *oh1, float *oh2 -) { - dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance - assert(B * C % threadsPerBlock.x == 0); - dim3 numBlocks(B * C / threadsPerBlock.x); - kernel_forward<<>>(B, T, C, w, u, k, v, h1, h2, y, oh1, oh2); -} - - -template -__global__ void kernel_backward( - const int B, - const int T, - const int C, - const F *__restrict__ const _w, - const F *__restrict__ const _u, - const F *__restrict__ const _k, - const F *__restrict__ const _v, - const F *__restrict__ const _h1, - const F *__restrict__ const _h2, - const F *__restrict__ const _y, - const F *__restrict__ const _gy, - const F *__restrict__ const _goh1, - const F *__restrict__ const _goh2, - F *__restrict__ const _gw, - F *__restrict__ const _gu, - F *__restrict__ const _gk, - F *__restrict__ const _gv, - F *__restrict__ const _gh1, - F *__restrict__ const _gh2 -) { - const int idx = blockIdx.x * blockDim.x + threadIdx.x; - const int _b = idx / C; - const int _c = idx % C; - const int _offset = _b * T * C + _c; - const int _hist_offset = _b * C + _c; - - F u = _u[_c]; - F w = _w[_c]; - const F *__restrict__ const k = _k + _offset; - const F *__restrict__ const v = _v + _offset; - const F *__restrict__ const y = _y + _offset; - const F *__restrict__ const gy = _gy + _offset; - F *__restrict__ const gk = _gk + _offset; - F *__restrict__ const gv = _gv + _offset; - - F q[Tmax], r[Tmax]; - F gw = 0, gu = 0, ga = 0, gb = 0; - F aa, bb, pp; - aa = (_h1 == NULL)? 0:_h1[_hist_offset]; - bb = (_h2 == NULL)? 0:_h2[_hist_offset*2]; - pp = (_h2 == NULL)? 
MIN_VALUE:_h2[_hist_offset*2+1]; - for (int i = 0; i < T; i++) { - const int ii = i * C; - const F kk = k[ii]; - const F vv = v[ii]; - const F yy = y[ii]; - - F ww = u + kk; - F p = max(pp, ww); - F e1 = exp(pp - p); - F e2 = exp(ww - p); - const F qq = gy[ii] / (e1 * bb + e2); - gw += (ga - gb * yy) * e1 * qq; - gu += (vv - yy) * e2 * qq; - q[i] = qq; - r[i] = ww - p; - - ww = w + pp; - p = max(ww, kk); - e1 = exp(ww - p); - e2 = exp(kk - p); - ga = e1 * (aa + ga); - gb = e1 * (bb + gb); - aa = e1 * aa + e2 * vv; - bb = e1 * bb + e2; - pp = p; - } - - F gaa = 0, gbb = 0, gpp = MIN_VALUE; - // if (_goh1 != NULL && _goh2 != NULL) { - gaa = _goh1[_hist_offset]; - gbb = _goh2[_hist_offset*2]; - gpp = _goh2[_hist_offset*2+1]; - if (gaa == 0 && gbb == 0) gpp = MIN_VALUE; - // actually torch will always feed _goh1 and goh2 with 0 at first back-propagation, so we set gpp = MIN_VALUE for consistency - // } - - // below is back-propagating gradients through time - gw += (gaa * ga + gbb * gb); - for (int i = T - 1; i >= 0; i--) { - const int ii = i * C; - const F kk = k[ii]; - const F vv = v[ii]; - const F yy = y[ii]; - const F qq = q[i]; - const F rr = r[i]; - - F e1 = qq * exp(rr); - F e2 = exp(kk + gpp); - gk[ii] = e1 * (vv - yy) + e2 * (gaa * vv + gbb); - gv[ii] = e1 + e2 * gaa; - - const F ww = w + gpp; - const F www = rr - u - kk; - const F p = max(ww, www); - e1 = exp(ww - p); - e2 = qq * exp(www - p); - gaa = e1 * gaa + e2; - gbb = e1 * gbb - e2 * yy; - gpp = p; - } - - if (_gh1 != NULL and _gh2 != NULL) { - _gh1[_hist_offset] = gaa; - _gh2[_hist_offset*2] = gbb; - _gh2[_hist_offset*2+1] = gpp; - } - - const int _offsetBC = _b * C + _c; - _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward() - _gu[_offsetBC] = gu; -} - -void cuda_backward( - int B, int T, int C, - float *w, float *u, float *k, float *v, - float *h1, float *h2, float *y, - float *gy, float *goh1, float *goh2, - float *gw, float *gu, float *gk, float *gv, - float *gh1, float* gh2 -) { - dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance - assert(B * C % threadsPerBlock.x == 0); - dim3 numBlocks(B * C / threadsPerBlock.x); - kernel_backward<<>>(B, T, C, w, u, k, v, h1, h2, y, gy, goh1, goh2, gw, gu, gk, gv, gh1, gh2); -} \ No newline at end of file diff --git a/UtilsRL/rl/buffer/trjectory_replay.py b/UtilsRL/rl/buffer/trjectory_replay.py deleted file mode 100644 index c882a61..0000000 --- a/UtilsRL/rl/buffer/trjectory_replay.py +++ /dev/null @@ -1,43 +0,0 @@ -# from typing import Optional, Union, Any, Sequence -# from typing import Dict as DictLike - -# import numpy as np - -# from UtilsRL.logger import logger -# from .base import Replay, SimpleReplay, FlexReplay - -# class TrajectorySimpleReplay(SimpleReplay): -# def __init__(self, max_size: int, max_traj_len: int, field_specs: Optional[DictLike]=None, *args, **kwargs): -# field_specs["valid"] = { -# "shape": [1, ], -# "dtype": np.float32, -# } -# super().__init__(max_size, field_specs, *args, **kwargs) -# self._max_traj_len = max_traj_len -# self._size = 0 - -# def reset(self): -# self._pointer = self._size = 0 -# self.fields = self.fields or {} -# for _key, _specs in self.field_specs.items(): -# initializer = _specs.get("initializer", np.zeros) -# self.fields[_key] = initializer(shape=[self._max_size, self._max_traj_len, ]+list(_specs["shape"]), dtype=_specs["stype"]) - -# def add_fields(self, new_field_specs: Optional[DictLike]=None): -# new_field_specs = new_field_specs or {} -# self.fields = 
self.fields or {} -# for _key, _specs in new_field_specs.items(): -# _old_specs = self.field_specs.get(_key, None) -# if _old_specs is None or _old_specs != _specs: -# self.field_specs[_key] = _specs -# initializer = _specs.get("initializer", np.zeros) -# self.fields[_key] = initializer(shape=[self._max_size, self._max_traj_len, ]+list(_specs["shape"]), dtype=_specs["stype"]) - -# def add_sample(self, key_or_dict: Union[str, DictLike], data: Optional[Any]=None): -# # we force data to be [Batch, Length, ...] -# def pad_or_trunc(name, values, max_len): -# if values.shape - -# if isinstance(key_or_dict, str): -# key_or_dict = {key_or_dict: data} - \ No newline at end of file