Source code for pinnicle.utils.data_misfit

import deepxde as dde
import deepxde.backend as bkd
from deepxde.backend import backend_name, tf, jax, torch

# ---- tensorflow {{{
[docs] def surface_log_vel_misfit_tf(v_true, v_pred): """Compute SurfaceLogVelMisfit: This function is for tensorflow backend """ epsvel=2.220446049250313e-16 return bkd.reduce_mean(bkd.square((tf.math.log((tf.abs(v_pred)+epsvel)/(tf.abs(v_true)+epsvel)))))
[docs] def mean_squared_log_error_tf(y_true, y_pred): """ use tensorflow function to compute mean squared log error """ return tf.keras.losses.MeanSquaredLogarithmicError()(y_true, y_pred)
#}}} # ---- jax {{{
[docs] def surface_log_vel_misfit_jax(v_true, v_pred): """Compute SurfaceLogVelMisfit: This function is for jax """ epsvel=2.220446049250313e-16 return bkd.reduce_mean(bkd.square((jax.numpy.log((jax.numpy.abs(v_pred)+epsvel)/(jax.numpy.abs(v_true)+epsvel)))))
[docs] def mean_squared_log_error_jax(y_true, y_pred): """ use jax/numpy function to compute mean squared log error """ return bkd.reduce_mean(bkd.square(jax.numpy.log(y_true+1.0) - jax.numpy.log(y_pred+1.0)))
#}}} # ---- pytorch {{{
[docs] def surface_log_vel_misfit_pytorch(v_true, v_pred): """Compute SurfaceLogVelMisfit: This function is for pytorch backend """ epsvel=2.220446049250313e-16 return bkd.reduce_mean(bkd.square((torch.log((bkd.abs(v_pred)+epsvel)/(bkd.abs(v_true)+epsvel)))))
[docs] def mean_squared_log_error_pytorch(y_true, y_pred): """ use jax/numpy function to compute mean squared log error """ return bkd.reduce_mean(bkd.square(torch.log(y_true+1.0) - torch.log(y_pred+1.0)))
#}}} # ---------------
[docs] def loss_dict_tf(): return { "VEL_LOG": surface_log_vel_misfit_tf, "MEAN_SQUARE_LOG": mean_squared_log_error_tf }
[docs] def loss_dict_jax(): return { "VEL_LOG": surface_log_vel_misfit_jax, "MEAN_SQUARE_LOG": mean_squared_log_error_jax }
[docs] def loss_dict_pytorch(): return { "VEL_LOG": surface_log_vel_misfit_pytorch, "MEAN_SQUARE_LOG": mean_squared_log_error_pytorch }
if backend_name == "tensorflow": LOSS_DICT = loss_dict_tf() elif backend_name == "jax": LOSS_DICT = loss_dict_jax() elif backend_name == "pytorch": LOSS_DICT = loss_dict_pytorch()
[docs] def get(identifier): """Retrieves a loss function. Args: identifier: A loss identifier. String name of a loss function, or a loss function. Returns: A loss function. """ if isinstance(identifier, (list, tuple)): return list(map(get, identifier)) if isinstance(identifier, str): if identifier in LOSS_DICT: return LOSS_DICT[identifier] elif identifier in dde.losses.LOSS_DICT: return identifier if callable(identifier): return identifier raise ValueError("Could not interpret loss function identifier:", identifier)