From 10518a3a412502258a175a63de9c96fba53fd720 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 30 May 2023 14:40:35 +0800 Subject: [PATCH] Refine formula, docstring and code. (#353) * refine losses formula and doc of loss * refine docstring according to pydocstyle --- docs/zh/api/autodiff.md | 1 + docs/zh/api/visualize.md | 2 + ppsci/arch/afno.py | 8 +-- ppsci/arch/physx_transformer.py | 3 +- ppsci/autodiff/ad.py | 2 +- ppsci/constraint/base.py | 2 +- ppsci/data/dataset/array_dataset.py | 1 + ppsci/data/dataset/era5_dataset.py | 4 +- ppsci/geometry/geometry.py | 4 +- ppsci/geometry/geometry_2d.py | 2 +- ppsci/geometry/mesh.py | 53 ++++++++--------- ppsci/geometry/pointcloud.py | 14 ++--- ppsci/loss/base.py | 2 +- ppsci/loss/integral.py | 8 ++- ppsci/loss/l1.py | 31 +++++++--- ppsci/loss/l2.py | 67 ++++++++++++++++++---- ppsci/loss/mse.py | 31 ++++++++-- ppsci/optimizer/__init__.py | 4 +- ppsci/optimizer/lr_scheduler.py | 4 +- ppsci/optimizer/optimizer.py | 6 +- ppsci/solver/train.py | 6 +- ppsci/solver/visu.py | 4 +- ppsci/utils/checker.py | 1 - ppsci/utils/config.py | 4 +- ppsci/utils/download.py | 89 ++++++++++++++--------------- ppsci/utils/expression.py | 5 ++ ppsci/utils/initializer.py | 4 +- ppsci/utils/logger.py | 2 +- ppsci/utils/misc.py | 18 +++--- ppsci/utils/profiler.py | 2 +- ppsci/utils/reader.py | 2 +- ppsci/visualize/base.py | 2 +- ppsci/visualize/plot.py | 8 +-- 33 files changed, 237 insertions(+), 159 deletions(-) diff --git a/docs/zh/api/autodiff.md b/docs/zh/api/autodiff.md index ec81debb3..8189609db 100644 --- a/docs/zh/api/autodiff.md +++ b/docs/zh/api/autodiff.md @@ -8,5 +8,6 @@ - Jacobians - hessian - Hessians + - clear show_root_heading: false heading_level: 3 diff --git a/docs/zh/api/visualize.md b/docs/zh/api/visualize.md index d69b85b29..2c120b412 100644 --- a/docs/zh/api/visualize.md +++ b/docs/zh/api/visualize.md @@ -4,6 +4,7 @@ handler: python options: members: + - Visualizer - VisualizerScatter1D - VisualizerScatter3D - VisualizerVtu @@ -12,6 +13,7 @@ - Visualizer3D - VisualizerWeather - save_vtu_from_dict + - save_vtu_to_mesh - save_plot_from_1d_dict - save_plot_from_3d_dict show_root_heading: false diff --git a/ppsci/arch/afno.py b/ppsci/arch/afno.py index d4d2f11a4..8c20609d1 100644 --- a/ppsci/arch/afno.py +++ b/ppsci/arch/afno.py @@ -22,8 +22,8 @@ import paddle import paddle.fft -import paddle.nn as nn import paddle.nn.functional as F +from paddle import nn from ppsci.arch import activation as act_mod from ppsci.arch import base @@ -381,7 +381,7 @@ def __init__( ) def forward(self, x): - B, C, H, W = x.shape + _, _, H, W = x.shape if not (H == self.img_size[0] and W == self.img_size[1]): raise ValueError( f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." @@ -541,7 +541,7 @@ def forward(self, x): y = [] input = x - for i in range(self.num_timestamps): + for _ in range(self.num_timestamps): out = self.forward_tensor(input) y.append(out) input = out @@ -665,7 +665,7 @@ def forward(self, x): input_wind = x y = [] - for i in range(self.num_timestamps): + for _ in range(self.num_timestamps): with paddle.no_grad(): out_wind = self.wind_model.forward_tensor(input_wind) out = self.forward_tensor(out_wind) diff --git a/ppsci/arch/physx_transformer.py b/ppsci/arch/physx_transformer.py index 5f597aaf4..a109a99da 100644 --- a/ppsci/arch/physx_transformer.py +++ b/ppsci/arch/physx_transformer.py @@ -114,8 +114,7 @@ def split_heads(self, x, k=False): x = x.reshape(new_x_shape) if k: return x.transpose([0, 2, 3, 1]) - else: - return x.transpose([0, 2, 1, 3]) + return x.transpose([0, 2, 1, 3]) def forward( self, diff --git a/ppsci/autodiff/ad.py b/ppsci/autodiff/ad.py index 8f0835749..c799b7d6a 100644 --- a/ppsci/autodiff/ad.py +++ b/ppsci/autodiff/ad.py @@ -166,7 +166,7 @@ def __call__( j: int = 0, grad_y: Optional[paddle.Tensor] = None, ) -> paddle.Tensor: - """compute hessian matrix for given ys and xs. + """Compute hessian matrix for given ys and xs. Args: ys (paddle.Tensor): Output tensor. diff --git a/ppsci/constraint/base.py b/ppsci/constraint/base.py index bee61f56c..da0f38efe 100644 --- a/ppsci/constraint/base.py +++ b/ppsci/constraint/base.py @@ -15,7 +15,7 @@ from typing import Any from typing import Dict -import paddle.io as io +from paddle import io from ppsci import data from ppsci import loss diff --git a/ppsci/data/dataset/array_dataset.py b/ppsci/data/dataset/array_dataset.py index f657274d9..6eab8364a 100644 --- a/ppsci/data/dataset/array_dataset.py +++ b/ppsci/data/dataset/array_dataset.py @@ -95,6 +95,7 @@ def __init__( weight: Dict[str, np.ndarray], transforms: Optional[vision.Compose] = None, ): + super().__init__() self.input = {key: paddle.to_tensor(value) for key, value in input.items()} self.label = {key: paddle.to_tensor(value) for key, value in label.items()} self.weight = {key: paddle.to_tensor(value) for key, value in weight.items()} diff --git a/ppsci/data/dataset/era5_dataset.py b/ppsci/data/dataset/era5_dataset.py index 00f5be6d1..89e6157e4 100644 --- a/ppsci/data/dataset/era5_dataset.py +++ b/ppsci/data/dataset/era5_dataset.py @@ -90,8 +90,8 @@ def read_data(self, path: str, var="fields"): paths = [path] if path.endswith(".h5") else glob.glob(path + "/*.h5") paths.sort() files = [] - for path in paths: - _file = h5py.File(path, "r") + for path_ in paths: + _file = h5py.File(path_, "r") files.append(_file[var]) return files diff --git a/ppsci/geometry/geometry.py b/ppsci/geometry/geometry.py index 61166d5b9..06c3b2eef 100644 --- a/ppsci/geometry/geometry.py +++ b/ppsci/geometry/geometry.py @@ -157,8 +157,8 @@ def sample_boundary(self, n, random="pseudo", criteria=None, evenly=False): ): area_dict = misc.convert_to_dict(area[:, 1:], ["area"]) return {**x_dict, **normal_dict, **area_dict} - else: - return {**x_dict, **normal_dict} + + return {**x_dict, **normal_dict} @abc.abstractmethod def random_points(self, n: int, random: str = "pseudo"): diff --git a/ppsci/geometry/geometry_2d.py b/ppsci/geometry/geometry_2d.py index eb7b05e7a..d32f9190e 100644 --- a/ppsci/geometry/geometry_2d.py +++ b/ppsci/geometry/geometry_2d.py @@ -455,7 +455,7 @@ def random_points(self, n, random="pseudo"): x = np.empty((0, 2), dtype=paddle.get_default_dtype()) vbbox = self.bbox[1] - self.bbox[0] while len(x) < n: - x_new = sampler.sample(n, 2, sampler="pseudo") * vbbox + self.bbox[0] + x_new = sampler.sample(n, 2, "pseudo") * vbbox + self.bbox[0] x = np.vstack((x, x_new[self.is_inside(x_new)])) return x[:n] diff --git a/ppsci/geometry/mesh.py b/ppsci/geometry/mesh.py index b02b839f4..8376bccf0 100644 --- a/ppsci/geometry/mesh.py +++ b/ppsci/geometry/mesh.py @@ -222,7 +222,7 @@ def inflated_random_boundary_points( all_points.append(normal) all_points.append(area) - all_point = np.concatenate(all_point, axis=0) + all_points = np.concatenate(all_points, axis=0) all_normal = np.concatenate(all_normal, axis=0) all_area = np.concatenate(all_area, axis=0) return all_points, all_normal, all_area @@ -270,7 +270,7 @@ def random_boundary_points(self, n, random="pseudo", criteria=None): npoint_per_triangle, np.arange(len(triangle_prob) + 1) - 0.5 ) - all_point = [] + all_points = [] all_normal = [] all_area = [] for i, npoint in enumerate(npoint_per_triangle): @@ -286,17 +286,17 @@ def random_boundary_points(self, n, random="pseudo", criteria=None): dtype=paddle.get_default_dtype(), ) - all_point.append(face_points) + all_points.append(face_points) all_normal.append(face_normal) all_area.append(valid_area) - all_point = np.concatenate(all_point, axis=0) + all_points = np.concatenate(all_points, axis=0) all_normal = np.concatenate(all_normal, axis=0) all_area = np.concatenate(all_area, axis=0) # NOTE: use global mean area instead of local mean area all_area = np.full_like(all_area, all_area.mean()) - return all_point, all_normal, all_area + return all_points, all_normal, all_area def sample_boundary( self, n, random="pseudo", criteria=None, evenly=False, inflation_dist=None @@ -311,11 +311,8 @@ def sample_boundary( raise ValueError( "Can't sample evenly on mesh now, please set evenly=False." ) - # points, normal, area = self.uniform_boundary_points(n, False) - else: - points, normals, areas = self.random_boundary_points( - n, random, criteria - ) + # points, normal, area = self.uniform_boundary_points(n, False) + points, normals, areas = self.random_boundary_points(n, random, criteria) x_dict = misc.convert_to_dict(points, self.dim_keys) normal_dict = misc.convert_to_dict( @@ -362,9 +359,7 @@ def sample_interior(self, n, random="pseudo", criteria=None, evenly=False): raise NotImplementedError( "uniformly sample for interior in mesh is not support yet" ) - # points, area = self.uniform_points(n) - else: - points, areas = self.random_points(n, random, criteria) + points, areas = self.random_points(n, random, criteria) x_dict = misc.convert_to_dict(points, self.dim_keys) area_dict = misc.convert_to_dict(areas, ["area"]) @@ -375,45 +370,47 @@ def sample_interior(self, n, random="pseudo", criteria=None, evenly=False): return {**x_dict, **area_dict, **sdf_dict} - def union(self, rhs: "Mesh"): + def union(self, other: "Mesh"): if not checker.dynamic_import_to_globals(["pymesh"]): raise ModuleNotFoundError import pymesh - csg = pymesh.CSGTree({"union": [{"mesh": self.py_mesh}, {"mesh": rhs.py_mesh}]}) + csg = pymesh.CSGTree( + {"union": [{"mesh": self.py_mesh}, {"mesh": other.py_mesh}]} + ) return Mesh(csg.mesh) - def __or__(self, rhs: "Mesh"): - return self.union(rhs) + def __or__(self, other: "Mesh"): + return self.union(other) - def __add__(self, rhs: "Mesh"): - return self.union(rhs) + def __add__(self, other: "Mesh"): + return self.union(other) - def difference(self, rhs: "Mesh"): + def difference(self, other: "Mesh"): if not checker.dynamic_import_to_globals(["pymesh"]): raise ModuleNotFoundError import pymesh csg = pymesh.CSGTree( - {"difference": [{"mesh": self.py_mesh}, {"mesh": rhs.py_mesh}]} + {"difference": [{"mesh": self.py_mesh}, {"mesh": other.py_mesh}]} ) return Mesh(csg.mesh) - def __sub__(self, rhs: "Mesh"): - return self.difference(rhs) + def __sub__(self, other: "Mesh"): + return self.difference(other) - def intersection(self, rhs: "Mesh"): + def intersection(self, other: "Mesh"): if not checker.dynamic_import_to_globals(["pymesh"]): raise ModuleNotFoundError import pymesh csg = pymesh.CSGTree( - {"intersection": [{"mesh": self.py_mesh}, {"mesh": rhs.py_mesh}]} + {"intersection": [{"mesh": self.py_mesh}, {"mesh": other.py_mesh}]} ) return Mesh(csg.mesh) - def __and__(self, rhs: "Mesh"): - return self.intersection(rhs) + def __and__(self, other: "Mesh"): + return self.intersection(other) def __str__(self) -> str: """Return the name of class""" @@ -429,7 +426,7 @@ def __str__(self) -> str: def area_of_triangles(v0, v1, v2): - """ref https://math.stackexchange.com/questions/128991/how-to-calculate-the-area-of-a-3d-triangle + """Ref https://math.stackexchange.com/questions/128991/how-to-calculate-the-area-of-a-3d-triangle Args: v0 (np.ndarray): Coordinates of the first vertex of the triangle surface with shape of [N, 3]. diff --git a/ppsci/geometry/pointcloud.py b/ppsci/geometry/pointcloud.py index 96260788c..38150119b 100644 --- a/ppsci/geometry/pointcloud.py +++ b/ppsci/geometry/pointcloud.py @@ -18,8 +18,8 @@ import numpy as np -import ppsci.utils.misc as misc from ppsci.geometry import geometry +from ppsci.utils import misc class PointCloud(geometry.Geometry): @@ -148,32 +148,32 @@ def uniform_points(self, n: int, boundary=True): """Compute the equispaced points in the geometry.""" return self.interior[:n] - def union(self, rhs): + def union(self, other): raise NotImplementedError( "Union operation for PointCloud is not supported yet." ) - def __or__(self, rhs): + def __or__(self, other): raise NotImplementedError( "Union operation for PointCloud is not supported yet." ) - def difference(self, rhs): + def difference(self, other): raise NotImplementedError( "Subtraction operation for PointCloud is not supported yet." ) - def __sub__(self, rhs): + def __sub__(self, other): raise NotImplementedError( "Subtraction operation for PointCloud is not supported yet." ) - def intersection(self, rhs): + def intersection(self, other): raise NotImplementedError( "Intersection operation for PointCloud is not supported yet." ) - def __and__(self, rhs): + def __and__(self, other): raise NotImplementedError( "Intersection operation for PointCloud is not supported yet." ) diff --git a/ppsci/loss/base.py b/ppsci/loss/base.py index 53a8224ae..e9ac8c17d 100644 --- a/ppsci/loss/base.py +++ b/ppsci/loss/base.py @@ -16,7 +16,7 @@ from typing import Optional from typing import Union -import paddle.nn as nn +from paddle import nn from typing_extensions import Literal diff --git a/ppsci/loss/integral.py b/ppsci/loss/integral.py index cbbafb9e8..b28cff561 100644 --- a/ppsci/loss/integral.py +++ b/ppsci/loss/integral.py @@ -28,12 +28,14 @@ class IntegralLoss(base.Loss): $$ L = \begin{cases} - \dfrac{1}{N}\sum\limits_{i=1}^{N}{(\sum\limits_{j=1}^{M}{(x_i^j s_{j})}-y_i)^2}, & \text{if reduction='mean'} \\ - \sum\limits_{i=1}^{N}{(\sum\limits_{j=1}^{M}{(x_i^j s_{j})}-y_i)^2}, & \text{if reduction='sum'} + \dfrac{1}{N} \Vert \mathbf{s} \circ \mathbf{x} - \mathbf{y} \Vert_2^2, & \text{if reduction='mean'} \\ + \Vert \mathbf{s} \circ \mathbf{x} - \mathbf{y} \Vert_2^2, & \text{if reduction='sum'} \end{cases} $$ - $M$ is the number of samples in Monte-Carlo integration. + $$ + \mathbf{x}, \mathbf{y}, \mathbf{s} \in \mathcal{R}^{N} + $$ Args: reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean". diff --git a/ppsci/loss/l1.py b/ppsci/loss/l1.py index 8bff0ed65..4984dfccd 100644 --- a/ppsci/loss/l1.py +++ b/ppsci/loss/l1.py @@ -26,11 +26,11 @@ class L1Loss(base.Loss): r"""Class for l1 loss. $$ - L = - \begin{cases} - \dfrac{1}{N}\sum\limits_{i=1}^{N}{|x_i-y_i|}, & \text{if reduction='mean'} \\ - \sum\limits_{i=1}^{N}{|x_i-y_i|}, & \text{if reduction='sum'} - \end{cases} + L = \Vert \mathbf{x} - \mathbf{y} \Vert_1 + $$ + + $$ + \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} $$ Args: @@ -39,7 +39,7 @@ class L1Loss(base.Loss): Examples: >>> import ppsci - >>> loss = ppsci.loss.L1Loss("mean") + >>> loss = ppsci.loss.L1Loss() """ def __init__( @@ -59,6 +59,7 @@ def forward(self, output_dict, label_dict, weight_dict=None): loss = F.l1_loss(output_dict[key], label_dict[key], "none") if weight_dict: loss *= weight_dict[key] + if isinstance(self.weight, (float, int)): loss *= self.weight elif isinstance(self.weight, dict) and key in self.weight: @@ -67,6 +68,8 @@ def forward(self, output_dict, label_dict, weight_dict=None): if "area" in output_dict: loss *= output_dict["area"] + loss = loss.sum(axis=1) + if self.reduction == "sum": loss = loss.sum() elif self.reduction == "mean": @@ -82,10 +85,22 @@ def forward(self, output_dict, label_dict, weight_dict=None): class PeriodicL1Loss(base.Loss): - """Class for periodic l1 loss. + r"""Class for periodic l1 loss. + + $$ + L = \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_1 + $$ + + $\mathbf{x_l} \in \mathcal{R}^{N}$ is the first half of batch output, + $\mathbf{x_r} \in \mathcal{R}^{N}$ is the second half of batch output. Args: reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean". + weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None. + + Examples: + >>> import ppsci + >>> loss = ppsci.loss.PeriodicL1Loss("mean") """ def __init__( @@ -117,6 +132,8 @@ def forward(self, output_dict, label_dict, weight_dict=None): if "area" in output_dict: loss *= output_dict["area"] + loss = loss.sum(axis=1) + if self.reduction == "sum": loss = loss.sum() elif self.reduction == "mean": diff --git a/ppsci/loss/l2.py b/ppsci/loss/l2.py index ad809cd39..d45ce6302 100644 --- a/ppsci/loss/l2.py +++ b/ppsci/loss/l2.py @@ -27,10 +27,15 @@ class L2Loss(base.Loss): r"""Class for l2 loss. $$ - L = \sum\limits_{i=1}^{N}{(x_i-y_i)^2} + L =\Vert \mathbf{x} - \mathbf{y} \Vert_2 + $$ + + $$ + \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} $$ Args: + reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean". weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None. Examples: @@ -38,8 +43,16 @@ class L2Loss(base.Loss): >>> loss = ppsci.loss.L2Loss() """ - def __init__(self, weight: Optional[Union[float, Dict[str, float]]] = None): - super().__init__("sum", weight) + def __init__( + self, + reduction: Literal["mean", "sum"] = "mean", + weight: Optional[Union[float, Dict[str, float]]] = None, + ): + if reduction not in ["mean", "sum"]: + raise ValueError( + f"reduction should be 'mean' or 'sum', but got {reduction}" + ) + super().__init__(reduction, weight) def forward(self, output_dict, label_dict, weight_dict=None): losses = 0.0 @@ -51,7 +64,13 @@ def forward(self, output_dict, label_dict, weight_dict=None): if "area" in output_dict: loss *= output_dict["area"] - loss = loss.sum() + loss = loss.sum(axis=1).sqrt() + + if self.reduction == "sum": + loss = loss.sum() + elif self.reduction == "mean": + loss = loss.mean() + if isinstance(self.weight, (float, int)): loss *= self.weight elif isinstance(self.weight, dict) and key in self.weight: @@ -62,7 +81,23 @@ def forward(self, output_dict, label_dict, weight_dict=None): class PeriodicL2Loss(base.Loss): - """Class for Periodic l2 loss.""" + r"""Class for Periodic l2 loss. + + $$ + L = \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_2 + $$ + + $\mathbf{x_l} \in \mathcal{R}^{N}$ is the first half of batch output, + $\mathbf{x_r} \in \mathcal{R}^{N}$ is the second half of batch output. + + Args: + reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean". + weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None. + + Examples: + >>> import ppsci + >>> loss = ppsci.loss.PeriodicL2Loss() + """ def __init__( self, @@ -83,17 +118,24 @@ def forward(self, output_dict, label_dict, weight_dict=None): raise ValueError( f"Length of output({n_output}) of key({key}) should be even." ) - n_output //= 2 + loss = F.mse_loss( output_dict[key][:n_output], output_dict[key][n_output:], "none" ) if weight_dict: loss *= weight_dict[key] + if "area" in output_dict: loss *= output_dict["area"] - loss = loss.sum() + loss = loss.sum(axis=1).sqrt() + + if self.reduction == "sum": + loss = loss.sum() + elif self.reduction == "mean": + loss = loss.mean() + if isinstance(self.weight, (float, int)): loss *= self.weight elif isinstance(self.weight, dict) and key in self.weight: @@ -107,11 +149,11 @@ class L2RelLoss(base.Loss): r"""Class for l2 relative loss. $$ - L = - \begin{cases} - \dfrac{1}{N}\sum\limits_{i=1}^{N}{\dfrac{\Vert \mathbf{X_i}-\mathbf{Y_i}\Vert_2}{\Vert \mathbf{Y_i}\Vert_2}}, & \text{if reduction='mean'} \\ - \sum\limits_{i=1}^{N}{\dfrac{\Vert \mathbf{X_i}-\mathbf{Y_i}\Vert_2}{\Vert \mathbf{Y_i}\Vert_2}}, & \text{if reduction='sum'} - \end{cases} + L = \dfrac{\Vert \mathbf{x} - \mathbf{y} \Vert_2}{\Vert \mathbf{y} \Vert_2} + $$ + + $$ + \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} $$ Args: @@ -148,6 +190,7 @@ def forward(self, output_dict, label_dict, weight_dict=None): loss = self.rel_loss(output_dict[key], label_dict[key]) if weight_dict is not None: loss *= weight_dict[key] + if self.reduction == "sum": loss = loss.sum() elif self.reduction == "mean": diff --git a/ppsci/loss/mse.py b/ppsci/loss/mse.py index d67d18cb0..b647459c6 100644 --- a/ppsci/loss/mse.py +++ b/ppsci/loss/mse.py @@ -29,11 +29,15 @@ class MSELoss(base.Loss): $$ L = \begin{cases} - \dfrac{1}{N}\sum\limits_{i=1}^{N}{(x_i-y_i)^2}, & \text{if reduction='mean'} \\ - \sum\limits_{i=1}^{N}{(x_i-y_i)^2}, & \text{if reduction='sum'} + \dfrac{1}{N} \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2, & \text{if reduction='mean'} \\ + \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2, & \text{if reduction='sum'} \end{cases} $$ + $$ + \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N} + $$ + Args: reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean". weight (Optional[Union[float, Dict[str, float]]]): Weight for loss. Defaults to None. @@ -83,12 +87,16 @@ class MSELossWithL2Decay(MSELoss): $$ L = \begin{cases} - \dfrac{1}{N}\sum\limits_{i=1}^{N}{(x_i-y_i)^2} + \sum\limits_{j=1}^{M}{\Vert r_j \Vert_2^2}, & \text{if reduction='mean'} \\ - \sum\limits_{i=1}^{N}{(x_i-y_i)^2} + \sum\limits_{j=1}^{M}{\Vert r_j \Vert_2^2}, & \text{if reduction='sum'} + \dfrac{1}{N} \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2 + \displaystyle\sum_{i=1}^{M}{\Vert \mathbf{K_i} \Vert_F^2}, & \text{if reduction='mean'} \\ + \Vert {\mathbf{x}-\mathbf{y}} \Vert_2^2 + \displaystyle\sum_{i=1}^{M}{\Vert \mathbf{K_i} \Vert_F^2}, & \text{if reduction='sum'} \end{cases} $$ - $M$ is the number of variables which apply regularization on. + $$ + \mathbf{x}, \mathbf{y} \in \mathcal{R}^{N}, \mathbf{K_i} \in \mathcal{R}^{O_i \times P_i} + $$ + + $M$ is the number of which apply regularization on. Args: reduction (Literal["mean", "sum"], optional): Specifies the reduction to apply to the output: 'mean' | 'sum'. Defaults to "mean". @@ -127,7 +135,18 @@ def forward(self, output_dict, label_dict, weight_dict=None): class PeriodicMSELoss(base.Loss): - """Class for periodic mean squared error loss. + r"""Class for periodic mean squared error loss. + + $$ + L = + \begin{cases} + \dfrac{1}{N} \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_2^2, & \text{if reduction='mean'} \\ + \Vert \mathbf{x_l}-\mathbf{x_r} \Vert_2^2, & \text{if reduction='sum'} + \end{cases} + $$ + + $\mathbf{x_l} \in \mathcal{R}^{N}$ is the first half of batch output, + $\mathbf{x_r} \in \mathcal{R}^{N}$ is the second half of batch output. Args: reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean". diff --git a/ppsci/optimizer/__init__.py b/ppsci/optimizer/__init__.py index 6515f1416..54c64f896 100644 --- a/ppsci/optimizer/__init__.py +++ b/ppsci/optimizer/__init__.py @@ -49,8 +49,8 @@ def build_lr_scheduler(cfg, epochs, iters_per_epoch): cfg = copy.deepcopy(cfg) cfg.update({"epochs": epochs, "iters_per_epoch": iters_per_epoch}) lr_scheduler_cls = cfg.pop("name") - lr_scheduler = eval(lr_scheduler_cls)(**cfg) - return lr_scheduler() + lr_scheduler_ = eval(lr_scheduler_cls)(**cfg) + return lr_scheduler_() def build_optimizer(cfg, model_list, epochs, iters_per_epoch): diff --git a/ppsci/optimizer/lr_scheduler.py b/ppsci/optimizer/lr_scheduler.py index 5789f2e58..3ac7be0a2 100644 --- a/ppsci/optimizer/lr_scheduler.py +++ b/ppsci/optimizer/lr_scheduler.py @@ -83,7 +83,7 @@ def __init__( @abc.abstractmethod def __call__(self, *kargs, **kwargs) -> lr.LRScheduler: - """generate an learning rate scheduler. + """Generate an learning rate scheduler. Returns: lr.LinearWarmup: learning rate scheduler. @@ -127,7 +127,7 @@ def __init__(self, learning_rate: float, last_epoch: int = -1): super().__init__() def get_lr(self) -> float: - """always return the same learning rate""" + """Always return the same learning rate""" return self.learning_rate diff --git a/ppsci/optimizer/optimizer.py b/ppsci/optimizer/optimizer.py index 06035c091..54679d265 100644 --- a/ppsci/optimizer/optimizer.py +++ b/ppsci/optimizer/optimizer.py @@ -19,9 +19,6 @@ from typing import Tuple from typing import Union -if TYPE_CHECKING: - import paddle - from paddle import nn from paddle import optimizer as optim from paddle import regularizer @@ -31,6 +28,9 @@ from ppsci.utils import logger from ppsci.utils import misc +if TYPE_CHECKING: + import paddle + __all__ = ["SGD", "Momentum", "Adam", "RMSProp", "AdamW", "LBFGS", "OptimizerList"] diff --git a/ppsci/solver/train.py b/ppsci/solver/train.py index dc5157d22..16deae148 100644 --- a/ppsci/solver/train.py +++ b/ppsci/solver/train.py @@ -17,13 +17,13 @@ from paddle.distributed.fleet.utils import hybrid_parallel_util as hpu -if TYPE_CHECKING: - from ppsci import solver - from ppsci.solver import printer from ppsci.utils import misc from ppsci.utils import profiler +if TYPE_CHECKING: + from ppsci import solver + def train_epoch_func(solver: "solver.Solver", epoch_id: int, log_freq: int): """Train program for one epoch diff --git a/ppsci/solver/visu.py b/ppsci/solver/visu.py index fd76766f1..373f6e999 100644 --- a/ppsci/solver/visu.py +++ b/ppsci/solver/visu.py @@ -18,11 +18,11 @@ import paddle +from ppsci.utils import misc + if TYPE_CHECKING: from ppsci import solver -from ppsci.utils import misc - def visualize_func(solver: "solver.Solver", epoch_id: int): """Visualization program diff --git a/ppsci/utils/checker.py b/ppsci/utils/checker.py index 8b64d00aa..c8777d5e7 100644 --- a/ppsci/utils/checker.py +++ b/ppsci/utils/checker.py @@ -31,7 +31,6 @@ def run_check() -> None: >>> import ppsci >>> ppsci.utils.run_check() # doctest: +SKIP """ - # test demo code below. import logging diff --git a/ppsci/utils/config.py b/ppsci/utils/config.py index 91d838cfb..b4d99d975 100644 --- a/ppsci/utils/config.py +++ b/ppsci/utils/config.py @@ -86,7 +86,7 @@ def print_dict(d, delimiter=0): def print_config(config): """ - visualize configs + Visualize configs Arguments: config: configs """ @@ -191,7 +191,7 @@ def parse_args(): def _is_num_seq(seq): # whether seq is all int number(it is a shape) - return isinstance(seq, (list, tuple)) and all([isinstance(x, int) for x in seq]) + return isinstance(seq, (list, tuple)) and all(isinstance(x, int) for x in seq) def replace_shape_with_inputspec_(node: AttrDict): diff --git a/ppsci/utils/download.py b/ppsci/utils/download.py index 511e3b0bd..4c0faf9a5 100644 --- a/ppsci/utils/download.py +++ b/ppsci/utils/download.py @@ -78,7 +78,6 @@ def get_path_from_url(url, root_dir, md5sum=None, check_exist=True, decompress=T Returns: str: a local path to save downloaded models & weights & datasets. """ - if not is_url(url): raise ValueError(f"Given url({url}) is not valid") # parse path after download to decompress under root_dir @@ -204,64 +203,60 @@ def _decompress(fname): def _uncompress_file_zip(filepath): - files = zipfile.ZipFile(filepath, "r") - file_list = files.namelist() - - file_dir = os.path.dirname(filepath) + with zipfile.ZipFile(filepath, "r") as files: + file_list = files.namelist() - if _is_a_single_file(file_list): - rootpath = file_list[0] - uncompressed_path = os.path.join(file_dir, rootpath) + file_dir = os.path.dirname(filepath) - for item in file_list: - files.extract(item, file_dir) + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) - elif _is_a_single_dir(file_list): - rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) + for item in file_list: + files.extract(item, file_dir) - for item in file_list: - files.extract(item, file_dir) + elif _is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) - else: - rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) - if not os.path.exists(uncompressed_path): - os.makedirs(uncompressed_path) - for item in file_list: - files.extract(item, os.path.join(file_dir, rootpath)) + for item in file_list: + files.extract(item, file_dir) - files.close() + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + if not os.path.exists(uncompressed_path): + os.makedirs(uncompressed_path) + for item in file_list: + files.extract(item, os.path.join(file_dir, rootpath)) return uncompressed_path def _uncompress_file_tar(filepath, mode="r:*"): - files = tarfile.open(filepath, mode) - file_list = files.getnames() - - file_dir = os.path.dirname(filepath) - - if _is_a_single_file(file_list): - rootpath = file_list[0] - uncompressed_path = os.path.join(file_dir, rootpath) - for item in file_list: - files.extract(item, file_dir) - elif _is_a_single_dir(file_list): - rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) - for item in file_list: - files.extract(item, file_dir) - else: - rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] - uncompressed_path = os.path.join(file_dir, rootpath) - if not os.path.exists(uncompressed_path): - os.makedirs(uncompressed_path) - - for item in file_list: - files.extract(item, os.path.join(file_dir, rootpath)) + with tarfile.open(filepath, mode) as files: + file_list = files.getnames() + + file_dir = os.path.dirname(filepath) + + if _is_a_single_file(file_list): + rootpath = file_list[0] + uncompressed_path = os.path.join(file_dir, rootpath) + for item in file_list: + files.extract(item, file_dir) + elif _is_a_single_dir(file_list): + rootpath = os.path.splitext(file_list[0])[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + for item in file_list: + files.extract(item, file_dir) + else: + rootpath = os.path.splitext(filepath)[0].split(os.sep)[-1] + uncompressed_path = os.path.join(file_dir, rootpath) + if not os.path.exists(uncompressed_path): + os.makedirs(uncompressed_path) - files.close() + for item in file_list: + files.extract(item, os.path.join(file_dir, rootpath)) return uncompressed_path diff --git a/ppsci/utils/expression.py b/ppsci/utils/expression.py index 62d913cc6..33116e35a 100644 --- a/ppsci/utils/expression.py +++ b/ppsci/utils/expression.py @@ -42,6 +42,11 @@ class ExpressionSolver(nn.Layer): def __init__(self): super().__init__() + def forward(self, *args, **kwargs): + raise NotImplementedError( + f"Use train_forward/eval_forward/visu_forward instead of forward." + ) + @jit.to_static def train_forward( self, diff --git a/ppsci/utils/initializer.py b/ppsci/utils/initializer.py index ebb58e0fe..1054b4e34 100644 --- a/ppsci/utils/initializer.py +++ b/ppsci/utils/initializer.py @@ -318,9 +318,7 @@ def _calculate_correct_fan(tensor, mode, reverse=False): mode = mode.lower() valid_modes = ["fan_in", "fan_out"] if mode not in valid_modes: - raise ValueError( - "Mode {} not supported, please use one of {}".format(mode, valid_modes) - ) + raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}") fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse) diff --git a/ppsci/utils/logger.py b/ppsci/utils/logger.py index 6a2545290..13b69cb5a 100644 --- a/ppsci/utils/logger.py +++ b/ppsci/utils/logger.py @@ -100,7 +100,7 @@ def set_log_level(log_level): def log_at_trainer0(log): """ - logs will print multi-times when calling Fleet API. + Logs will print multi-times when calling Fleet API. Only display single log and ignore the others. """ diff --git a/ppsci/utils/misc.py b/ppsci/utils/misc.py index e80b3def5..67c2346c9 100644 --- a/ppsci/utils/misc.py +++ b/ppsci/utils/misc.py @@ -53,14 +53,14 @@ def __init__(self, name="", fmt="f", postfix="", need_avg=True): self.reset() def reset(self): - """reset""" + """Reset""" self.val = 0 self.avg = 0 self.sum = 0 self.count = 0 def update(self, val, n=1): - """update""" + """Update""" self.val = val self.sum += val * n self.count += n @@ -140,23 +140,23 @@ def all_gather( return result -def convert_to_array(dict: Dict[str, np.ndarray], keys: Tuple[str, ...]) -> np.ndarray: +def convert_to_array(dict_: Dict[str, np.ndarray], keys: Tuple[str, ...]) -> np.ndarray: """Concatenate arrays in axis -1 in order of given keys. Args: - dict (Dict[str, np.ndarray]): Dict contains arrays. + dict_ (Dict[str, np.ndarray]): Dict contains arrays. keys (Tuple[str, ...]): Concatenate keys used in concatenation. Returns: np.ndarray: Concatenated array. """ - return np.concatenate([dict[key] for key in keys], axis=-1) + return np.concatenate([dict_[key] for key in keys], axis=-1) def concat_dict_list( dict_list: Tuple[Dict[str, np.ndarray], ...] ) -> Dict[str, np.ndarray]: - """concatenate arrays in tuple of dicts at axis 0. + """Concatenate arrays in tuple of dicts at axis 0. Args: dict_list (Tuple[Dict[str, np.ndarray], ...]): Tuple of dicts. @@ -187,16 +187,16 @@ def stack_dict_list( return ret -def typename(object: object) -> str: +def typename(obj: object) -> str: """Return type name of given object. Args: - object (object): Python object which is instantiated from a class. + obj (object): Python object which is instantiated from a class. Returns: str: Class name of given object. """ - return object.__class__.__name__ + return obj.__class__.__name__ def combine_array_with_time(x: np.ndarray, t: Tuple[int, ...]) -> np.ndarray: diff --git a/ppsci/utils/profiler.py b/ppsci/utils/profiler.py index 4e3d1f2f2..a67736955 100644 --- a/ppsci/utils/profiler.py +++ b/ppsci/utils/profiler.py @@ -78,7 +78,7 @@ def _parse_from_string(self, options_str): def __getitem__(self, name): if self._options.get(name, None) is None: - raise ValueError("ProfilerOptions does not have an option named %s." % name) + raise ValueError(f"ProfilerOptions does not have an option named {name}.") return self._options[name] diff --git a/ppsci/utils/reader.py b/ppsci/utils/reader.py index 5b1382dc1..48300ac96 100644 --- a/ppsci/utils/reader.py +++ b/ppsci/utils/reader.py @@ -126,7 +126,7 @@ def load_vtk_file( input_keys: Tuple[str, ...], label_keys: Optional[Tuple[str, ...]], ) -> Dict[str, np.ndarray]: - """load coordinates and attached label from the *.vtu file. + """Load coordinates and attached label from the *.vtu file. Args: filename_without_timeid (str): File name without time id. diff --git a/ppsci/visualize/base.py b/ppsci/visualize/base.py index 210d1b14a..3b58a3e27 100644 --- a/ppsci/visualize/base.py +++ b/ppsci/visualize/base.py @@ -48,7 +48,7 @@ def __init__( @abc.abstractmethod def save(self, data_dict): - """visualize result from data_dict and save as files""" + """Visualize result from data_dict and save as files""" def __str__(self): return ", ".join( diff --git a/ppsci/visualize/plot.py b/ppsci/visualize/plot.py index e7868cd8f..7aa90d91a 100644 --- a/ppsci/visualize/plot.py +++ b/ppsci/visualize/plot.py @@ -184,7 +184,7 @@ def _save_plot_from_2d_array( figsize=(num_timestamps, len(visu_keys)), ) fig.subplots_adjust(hspace=0.3) - target_flag = any(["target" in key for key in visu_keys]) + target_flag = any("target" in key for key in visu_keys) for i, data in enumerate(visu_data): if target_flag is False or "target" in visu_keys[i]: c_max = np.amax(data) @@ -430,7 +430,7 @@ def plot_weather( ax.set_xticks(xticks) ax.set_xticklabels(xticklabels) if not log_norm: - map = ax.imshow( + map_ = ax.imshow( data, interpolation="nearest", cmap=cmap, @@ -440,10 +440,10 @@ def plot_weather( ) else: norm = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax, clip=True) - map = ax.imshow( + map_ = ax.imshow( data, interpolation="nearest", cmap=cmap, aspect="auto", norm=norm ) - plt.colorbar(mappable=map, cax=None, ax=None, shrink=0.5, label=colorbar_label) + plt.colorbar(mappable=map_, cax=None, ax=None, shrink=0.5, label=colorbar_label) fig = plt.figure(facecolor="w", figsize=(7, 7)) ax = fig.add_subplot(2, 1, 1)