-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathvalue.py
32 lines (22 loc) · 1.39 KB
/
value.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from phi.tf.flow import StateDependency, Physics, ConstantField, FieldEffect, FieldPhysics, ADD
from .pde_base import PDE
class ScalarEffectControl(Physics):
    """Physics that steers a scalar ``FieldEffect`` toward a predicted value.

    On every step the effect's field data is replaced by the difference
    between the predicted ``scalar`` data and the current ``scalar`` data,
    and the effect's age is advanced by ``dt``.
    """

    def __init__(self):
        # NOTE(review): StateDependency(name, tag, ...) presumably binds the
        # step() keyword ``name`` to the world state tagged ``tag`` — confirm.
        # 'pred' is declared blocking=True; presumably this makes the
        # framework resolve the prediction before this physics steps.
        dependencies = [
            StateDependency('scalar', 'scalar', single_state=True),
            StateDependency('pred', 'next_state_prediction', single_state=True, blocking=True),
        ]
        Physics.__init__(self, dependencies)

    def step(self, effect, dt=1.0, scalar=None, pred=None):
        """Return a copy of *effect* carrying this step's control force.

        Args:
            effect: the effect state being advanced.
            dt: time increment added to ``effect.age``.
            scalar: current scalar state (injected via the 'scalar'
                dependency); its ``data`` is the value being controlled.
            pred: prediction state (injected via the 'pred' dependency);
                ``pred.prediction.scalar.data`` is the target value.

        Returns:
            A copy of *effect* whose field holds ``target - current`` and
            whose age is ``effect.age + dt``.
        """
        target_data = pred.prediction.scalar.data
        current_data = scalar.data
        # The force is exactly the gap, so adding it reaches the prediction.
        controlled_field = effect.field.with_data(target_data - current_data)
        return effect.copied_with(field=controlled_field, age=effect.age + dt)
class IncrementPDE(PDE):
    """Minimal PDE over a single scalar whose change is fully controlled.

    The world holds a scalar field starting at ``0.0`` and an additive
    (``ADD``) effect on it. ``ScalarEffectControl`` sets that effect to the
    gap between the predicted and current scalar. Predictions are midpoints
    of the initial and target states, and no losses are defined.
    """

    def create_pde(self, world, control_trainable, constant_prediction_offset):
        """Reset *world* and add the scalar field plus its controlling effect.

        ``control_trainable`` and ``constant_prediction_offset`` are not used
        by this PDE; they are accepted (presumably for the ``PDE`` base
        interface — confirm) so callers can pass them uniformly.
        """
        world.reset(world.batch_size, add_default_objects=False)
        scalar_field = ConstantField(0.0, name='scalar', flags=())
        world.add(scalar_field, physics=FieldPhysics('scalar'))
        # Effect targeting the 'scalar' field; its data is rewritten every
        # step by ScalarEffectControl.
        control_effect = FieldEffect(ConstantField(0, flags=()), ['scalar'], ADD)
        world.add(control_effect, physics=ScalarEffectControl())

    def target_matching_loss(self, target_state, actual_state):
        """No target-matching loss is defined for this PDE; returns ``None``."""
        return None

    def total_force_loss(self, states):
        """No force loss is defined for this PDE; returns ``None``."""
        return None

    def predict(self, n, initial, target, trainable):
        """Predict the state halfway between *initial* and *target*.

        The result is *initial* with its ``scalar`` replaced by a copy whose
        data is the mean of both scalars' data, whose age is the mean of
        their ages, and whose flags are cleared. ``n`` and ``trainable`` are
        unused.
        """
        start_scalar = initial.scalar
        end_scalar = target.scalar
        midpoint_age = (start_scalar.age + end_scalar.age) / 2
        midpoint_data = (start_scalar.data + end_scalar.data) * 0.5
        midpoint_field = start_scalar.copied_with(data=midpoint_data, flags=(), age=midpoint_age)
        return initial.state_replaced(midpoint_field)