Skip to content

Commit

Permalink
Refine formula, docstring and code. (#353)
Browse files Browse the repository at this point in the history
* refine losses formula and doc of loss

* refine docstring according to pydocstyle
  • Loading branch information
HydrogenSulfate authored May 30, 2023
1 parent ca39b4b commit 10518a3
Show file tree
Hide file tree
Showing 33 changed files with 237 additions and 159 deletions.
1 change: 1 addition & 0 deletions docs/zh/api/autodiff.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,6 @@
- Jacobians
- hessian
- Hessians
- clear
show_root_heading: false
heading_level: 3
2 changes: 2 additions & 0 deletions docs/zh/api/visualize.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
handler: python
options:
members:
- Visualizer
- VisualizerScatter1D
- VisualizerScatter3D
- VisualizerVtu
Expand All @@ -12,6 +13,7 @@
- Visualizer3D
- VisualizerWeather
- save_vtu_from_dict
- save_vtu_to_mesh
- save_plot_from_1d_dict
- save_plot_from_3d_dict
show_root_heading: false
Expand Down
8 changes: 4 additions & 4 deletions ppsci/arch/afno.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

import paddle
import paddle.fft
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import nn

from ppsci.arch import activation as act_mod
from ppsci.arch import base
Expand Down Expand Up @@ -381,7 +381,7 @@ def __init__(
)

def forward(self, x):
B, C, H, W = x.shape
_, _, H, W = x.shape
if not (H == self.img_size[0] and W == self.img_size[1]):
raise ValueError(
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
Expand Down Expand Up @@ -541,7 +541,7 @@ def forward(self, x):

y = []
input = x
for i in range(self.num_timestamps):
for _ in range(self.num_timestamps):
out = self.forward_tensor(input)
y.append(out)
input = out
Expand Down Expand Up @@ -665,7 +665,7 @@ def forward(self, x):

input_wind = x
y = []
for i in range(self.num_timestamps):
for _ in range(self.num_timestamps):
with paddle.no_grad():
out_wind = self.wind_model.forward_tensor(input_wind)
out = self.forward_tensor(out_wind)
Expand Down
3 changes: 1 addition & 2 deletions ppsci/arch/physx_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,7 @@ def split_heads(self, x, k=False):
x = x.reshape(new_x_shape)
if k:
return x.transpose([0, 2, 3, 1])
else:
return x.transpose([0, 2, 1, 3])
return x.transpose([0, 2, 1, 3])

def forward(
self,
Expand Down
2 changes: 1 addition & 1 deletion ppsci/autodiff/ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def __call__(
j: int = 0,
grad_y: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
"""compute hessian matrix for given ys and xs.
"""Compute hessian matrix for given ys and xs.
Args:
ys (paddle.Tensor): Output tensor.
Expand Down
2 changes: 1 addition & 1 deletion ppsci/constraint/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from typing import Any
from typing import Dict

import paddle.io as io
from paddle import io

from ppsci import data
from ppsci import loss
Expand Down
1 change: 1 addition & 0 deletions ppsci/data/dataset/array_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(
weight: Dict[str, np.ndarray],
transforms: Optional[vision.Compose] = None,
):
super().__init__()
self.input = {key: paddle.to_tensor(value) for key, value in input.items()}
self.label = {key: paddle.to_tensor(value) for key, value in label.items()}
self.weight = {key: paddle.to_tensor(value) for key, value in weight.items()}
Expand Down
4 changes: 2 additions & 2 deletions ppsci/data/dataset/era5_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,8 @@ def read_data(self, path: str, var="fields"):
paths = [path] if path.endswith(".h5") else glob.glob(path + "/*.h5")
paths.sort()
files = []
for path in paths:
_file = h5py.File(path, "r")
for path_ in paths:
_file = h5py.File(path_, "r")
files.append(_file[var])
return files

Expand Down
4 changes: 2 additions & 2 deletions ppsci/geometry/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ def sample_boundary(self, n, random="pseudo", criteria=None, evenly=False):
):
area_dict = misc.convert_to_dict(area[:, 1:], ["area"])
return {**x_dict, **normal_dict, **area_dict}
else:
return {**x_dict, **normal_dict}

return {**x_dict, **normal_dict}

@abc.abstractmethod
def random_points(self, n: int, random: str = "pseudo"):
Expand Down
2 changes: 1 addition & 1 deletion ppsci/geometry/geometry_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def random_points(self, n, random="pseudo"):
x = np.empty((0, 2), dtype=paddle.get_default_dtype())
vbbox = self.bbox[1] - self.bbox[0]
while len(x) < n:
x_new = sampler.sample(n, 2, sampler="pseudo") * vbbox + self.bbox[0]
x_new = sampler.sample(n, 2, "pseudo") * vbbox + self.bbox[0]
x = np.vstack((x, x_new[self.is_inside(x_new)]))
return x[:n]

Expand Down
53 changes: 25 additions & 28 deletions ppsci/geometry/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def inflated_random_boundary_points(
all_points.append(normal)
all_points.append(area)

all_point = np.concatenate(all_point, axis=0)
all_points = np.concatenate(all_points, axis=0)
all_normal = np.concatenate(all_normal, axis=0)
all_area = np.concatenate(all_area, axis=0)
return all_points, all_normal, all_area
Expand Down Expand Up @@ -270,7 +270,7 @@ def random_boundary_points(self, n, random="pseudo", criteria=None):
npoint_per_triangle, np.arange(len(triangle_prob) + 1) - 0.5
)

all_point = []
all_points = []
all_normal = []
all_area = []
for i, npoint in enumerate(npoint_per_triangle):
Expand All @@ -286,17 +286,17 @@ def random_boundary_points(self, n, random="pseudo", criteria=None):
dtype=paddle.get_default_dtype(),
)

all_point.append(face_points)
all_points.append(face_points)
all_normal.append(face_normal)
all_area.append(valid_area)

all_point = np.concatenate(all_point, axis=0)
all_points = np.concatenate(all_points, axis=0)
all_normal = np.concatenate(all_normal, axis=0)
all_area = np.concatenate(all_area, axis=0)

# NOTE: use global mean area instead of local mean area
all_area = np.full_like(all_area, all_area.mean())
return all_point, all_normal, all_area
return all_points, all_normal, all_area

def sample_boundary(
self, n, random="pseudo", criteria=None, evenly=False, inflation_dist=None
Expand All @@ -311,11 +311,8 @@ def sample_boundary(
raise ValueError(
"Can't sample evenly on mesh now, please set evenly=False."
)
# points, normal, area = self.uniform_boundary_points(n, False)
else:
points, normals, areas = self.random_boundary_points(
n, random, criteria
)
# points, normal, area = self.uniform_boundary_points(n, False)
points, normals, areas = self.random_boundary_points(n, random, criteria)

x_dict = misc.convert_to_dict(points, self.dim_keys)
normal_dict = misc.convert_to_dict(
Expand Down Expand Up @@ -362,9 +359,7 @@ def sample_interior(self, n, random="pseudo", criteria=None, evenly=False):
raise NotImplementedError(
"uniformly sample for interior in mesh is not support yet"
)
# points, area = self.uniform_points(n)
else:
points, areas = self.random_points(n, random, criteria)
points, areas = self.random_points(n, random, criteria)

x_dict = misc.convert_to_dict(points, self.dim_keys)
area_dict = misc.convert_to_dict(areas, ["area"])
Expand All @@ -375,45 +370,47 @@ def sample_interior(self, n, random="pseudo", criteria=None, evenly=False):

return {**x_dict, **area_dict, **sdf_dict}

def union(self, rhs: "Mesh"):
def union(self, other: "Mesh"):
if not checker.dynamic_import_to_globals(["pymesh"]):
raise ModuleNotFoundError
import pymesh

csg = pymesh.CSGTree({"union": [{"mesh": self.py_mesh}, {"mesh": rhs.py_mesh}]})
csg = pymesh.CSGTree(
{"union": [{"mesh": self.py_mesh}, {"mesh": other.py_mesh}]}
)
return Mesh(csg.mesh)

def __or__(self, rhs: "Mesh"):
return self.union(rhs)
def __or__(self, other: "Mesh"):
return self.union(other)

def __add__(self, rhs: "Mesh"):
return self.union(rhs)
def __add__(self, other: "Mesh"):
return self.union(other)

def difference(self, rhs: "Mesh"):
def difference(self, other: "Mesh"):
if not checker.dynamic_import_to_globals(["pymesh"]):
raise ModuleNotFoundError
import pymesh

csg = pymesh.CSGTree(
{"difference": [{"mesh": self.py_mesh}, {"mesh": rhs.py_mesh}]}
{"difference": [{"mesh": self.py_mesh}, {"mesh": other.py_mesh}]}
)
return Mesh(csg.mesh)

def __sub__(self, rhs: "Mesh"):
return self.difference(rhs)
def __sub__(self, other: "Mesh"):
return self.difference(other)

def intersection(self, rhs: "Mesh"):
def intersection(self, other: "Mesh"):
if not checker.dynamic_import_to_globals(["pymesh"]):
raise ModuleNotFoundError
import pymesh

csg = pymesh.CSGTree(
{"intersection": [{"mesh": self.py_mesh}, {"mesh": rhs.py_mesh}]}
{"intersection": [{"mesh": self.py_mesh}, {"mesh": other.py_mesh}]}
)
return Mesh(csg.mesh)

def __and__(self, rhs: "Mesh"):
return self.intersection(rhs)
def __and__(self, other: "Mesh"):
return self.intersection(other)

def __str__(self) -> str:
"""Return the name of class"""
Expand All @@ -429,7 +426,7 @@ def __str__(self) -> str:


def area_of_triangles(v0, v1, v2):
"""ref https://math.stackexchange.com/questions/128991/how-to-calculate-the-area-of-a-3d-triangle
"""Ref https://math.stackexchange.com/questions/128991/how-to-calculate-the-area-of-a-3d-triangle
Args:
v0 (np.ndarray): Coordinates of the first vertex of the triangle surface with shape of [N, 3].
Expand Down
14 changes: 7 additions & 7 deletions ppsci/geometry/pointcloud.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@

import numpy as np

import ppsci.utils.misc as misc
from ppsci.geometry import geometry
from ppsci.utils import misc


class PointCloud(geometry.Geometry):
Expand Down Expand Up @@ -148,32 +148,32 @@ def uniform_points(self, n: int, boundary=True):
"""Compute the equispaced points in the geometry."""
return self.interior[:n]

def union(self, rhs):
def union(self, other):
raise NotImplementedError(
"Union operation for PointCloud is not supported yet."
)

def __or__(self, rhs):
def __or__(self, other):
raise NotImplementedError(
"Union operation for PointCloud is not supported yet."
)

def difference(self, rhs):
def difference(self, other):
raise NotImplementedError(
"Subtraction operation for PointCloud is not supported yet."
)

def __sub__(self, rhs):
def __sub__(self, other):
raise NotImplementedError(
"Subtraction operation for PointCloud is not supported yet."
)

def intersection(self, rhs):
def intersection(self, other):
raise NotImplementedError(
"Intersection operation for PointCloud is not supported yet."
)

def __and__(self, rhs):
def __and__(self, other):
raise NotImplementedError(
"Intersection operation for PointCloud is not supported yet."
)
Expand Down
2 changes: 1 addition & 1 deletion ppsci/loss/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from typing import Optional
from typing import Union

import paddle.nn as nn
from paddle import nn
from typing_extensions import Literal


Expand Down
8 changes: 5 additions & 3 deletions ppsci/loss/integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,14 @@ class IntegralLoss(base.Loss):
$$
L =
\begin{cases}
\dfrac{1}{N}\sum\limits_{i=1}^{N}{(\sum\limits_{j=1}^{M}{(x_i^j s_{j})}-y_i)^2}, & \text{if reduction='mean'} \\
\sum\limits_{i=1}^{N}{(\sum\limits_{j=1}^{M}{(x_i^j s_{j})}-y_i)^2}, & \text{if reduction='sum'}
\dfrac{1}{N} \Vert \mathbf{s} \circ \mathbf{x} - \mathbf{y} \Vert_2^2, & \text{if reduction='mean'} \\
\Vert \mathbf{s} \circ \mathbf{x} - \mathbf{y} \Vert_2^2, & \text{if reduction='sum'}
\end{cases}
$$
$M$ is the number of samples in Monte-Carlo integration.
$$
\mathbf{x}, \mathbf{y}, \mathbf{s} \in \mathcal{R}^{N}
$$
Args:
reduction (Literal["mean", "sum"], optional): Reduction method. Defaults to "mean".
Expand Down
Loading

0 comments on commit 10518a3

Please sign in to comment.