Source code for pinnicle.physics.timeinvariant

from . import EquationBase, Constants
from ..parameter import EquationParameter
from ..utils import slice_column, jacobian

# time invariant constraint {{{
class TimeInvariantConstraintParameter(EquationParameter, Constants):
    """ parameters for the time-invariant constraint

    Args:
        param_dict: optional dictionary of user settings forwarded to the
            base parameter class. ``None`` (the default) is treated as an
            empty dictionary.
    """
    _EQUATION_TYPE = 'Time_Invariant'

    def __init__(self, param_dict=None):
        # A fresh dict per call: a `{}` literal in the signature would be
        # created once and shared by every default-constructed instance.
        if param_dict is None:
            param_dict = {}
        # load necessary constants
        # NOTE(review): done before the base initializer, presumably because
        # set_default() reads self.variable_lb / self.variable_ub -- confirm
        # that EquationParameter.__init__ triggers set_default().
        Constants.__init__(self)
        super().__init__(param_dict)

    def set_default(self):
        """ set the default inputs, outputs, bounds, weights and residual
        names of the time-invariant constraint
        """
        self.input = ['x', 'y', 't']
        self.output = ['H', 's', 'C']
        # bounds are looked up per output name, in the same order as
        # self.output
        self.output_lb = [self.variable_lb[k] for k in self.output]
        self.output_ub = [self.variable_ub[k] for k in self.output]
        # one data weight per entry of self.output
        self.data_weights = [1.0e-6, 1.0e-6, 1.0e-8]
        # one pde weight per entry of self.residuals
        self.residuals = ["db/dt", "dC/dt"]
        self.pde_weights = [1.0e6, 1.0e6]
        # scalar variables: name:value
        self.scalar_variables = {}
class TimeInvariantConstraint(EquationBase): #{{{
    """ A temporary solution to add time invariant constraint
    db/dt=0, dC/dt=0 to the PINN

    TODO:
        1. define these for every time independent equation, similar as
           _pde_jax
        2. use tf.cond, and similar autograph for pytorch and jax (when they
           have these implemented), put everything all in _pde

    Args:
        parameters: parameter object for this constraint. ``None`` (the
            default) creates a new ``TimeInvariantConstraintParameter``.
    """
    _EQUATION_TYPE = 'Time_Invariant'

    def __init__(self, parameters=None):
        # Build the default lazily: an instance created in the signature
        # would be constructed once at class-definition time and shared
        # (and possibly mutated) by every equation that relies on the default.
        if parameters is None:
            parameters = TimeInvariantConstraintParameter()
        super().__init__(parameters)

    def _pde(self, nn_input_var, nn_output_var): #{{{
        """ time invariant constraint

        Args:
            nn_input_var: global input to the nn
            nn_output_var: global output from the nn

        Returns:
            list of two residuals, ``[s_t - H_t, C_t]``, in the same order
            as the residual names "db/dt" and "dC/dt"
        """
        # get the ids
        tid = self.local_input_var["t"]
        sid = self.local_output_var["s"]
        Hid = self.local_output_var["H"]
        Cid = self.local_output_var["C"]

        # time derivative of each output with respect to the input t
        H_t = jacobian(nn_output_var, nn_input_var, i=Hid, j=tid)
        s_t = jacobian(nn_output_var, nn_input_var, i=sid, j=tid)
        C_t = jacobian(nn_output_var, nn_input_var, i=Cid, j=tid)

        # residual
        # NOTE(review): db/dt is written as s_t - H_t, i.e. presumably
        # b = s - H -- confirm against the model definition.
        fdbdt = s_t - H_t
        fdCdt = C_t

        return [fdbdt, fdCdt] #}}}

    def _pde_jax(self, nn_input_var, nn_output_var): #{{{
        """ time invariant constraint, jax version

        Delegates to ``_pde`` with the same arguments.

        Args:
            nn_input_var: global input to the nn
            nn_output_var: global output from the nn
        """
        return self._pde(nn_input_var, nn_output_var) #}}}
#}}} #}}}