# -*- coding: utf-8 -*-
"""
Base classes for plotting classes
:author: jpeacock
"""
# =============================================================================
# Imports
# =============================================================================
from __future__ import annotations
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from loguru import logger
from scipy import interpolate, stats
from .map_interpolation_tools import interpolate_to_map
from .plot_settings import PlotSettings
from .plotters import add_raster
# =============================================================================
# Base
# =============================================================================
# [docs]
class PlotBase(PlotSettings):
    """
    Base class for plotting objects.

    Provides core plotting functionality including figure management,
    saving, updating, and redrawing plots.

    Parameters
    ----------
    **kwargs : dict
        Keyword arguments passed to PlotSettings parent class

    Attributes
    ----------
    logger : loguru.Logger
        Logger instance for the class
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.logger = logger
        # lower-cased class name; save_plot uses it as the default file stem
        # when it is handed a directory instead of a file name
        self._basename = self.__class__.__name__.lower()

    def __str__(self) -> str:
        """
        Return string representation of the plotting object.

        Returns
        -------
        str
            String describing the plotting class
        """
        return f"Plotting {self.__class__.__name__}"

    def __repr__(self) -> str:
        """
        Return repr representation of the plotting object.

        Returns
        -------
        str
            Same text as ``__str__``
        """
        return self.__str__()

    def _set_subplot_params(self) -> None:
        """
        Set matplotlib subplot parameters from instance attributes.

        Writes the font size and the bottom/top/left/right subplot margins
        into ``plt.rcParams``. ``wspace`` and ``hspace`` are only written
        when the corresponding attribute is not None.
        """
        # NOTE(review): the attributes read here are presumably defined by
        # PlotSettings -- confirm against that class.
        plt.rcParams["font.size"] = self.font_size
        plt.rcParams["figure.subplot.bottom"] = self.subplot_bottom
        plt.rcParams["figure.subplot.top"] = self.subplot_top
        plt.rcParams["figure.subplot.left"] = self.subplot_left
        plt.rcParams["figure.subplot.right"] = self.subplot_right
        if self.subplot_wspace is not None:
            plt.rcParams["figure.subplot.wspace"] = self.subplot_wspace
        if self.subplot_hspace is not None:
            plt.rcParams["figure.subplot.hspace"] = self.subplot_hspace

    def plot(self) -> None:
        """
        Create the plot.

        This method should be overridden by subclasses to implement
        specific plotting functionality. The base implementation does
        nothing.
        """

    def save_plot(
        self,
        save_fn: str | Path,
        file_format: str = "pdf",
        orientation: str = "portrait",
        fig_dpi: int | None = None,
        close_plot: bool = True,
    ) -> None:
        """
        Save the figure to a file.

        Parameters
        ----------
        save_fn : str | Path
            Where to save the figure. Can be:

            - An existing directory: the file is saved as
              ``save_fn/<basename>.<file_format>``
            - A path with an extension: the file is saved to that path and
              the format is taken from the extension (``file_format`` is
              ignored)
            - A path without an extension: ``.<file_format>`` is appended
        file_format : str, optional
            File format for saved figure (pdf, eps, jpg, png, svg),
            by default 'pdf'
        orientation : str, optional
            Page orientation ('landscape' or 'portrait'), by default 'portrait'
        fig_dpi : int | None, optional
            Resolution in dots-per-inch. If None, uses ``self.fig_dpi``,
            by default None
        close_plot : bool, optional
            Whether to close the plot after saving, by default True

        Notes
        -----
        The final path is stored on ``self.fig_fn`` and logged.

        Examples
        --------
        >>> # Save plot as jpg
        >>> p1.save_plot(r'/home/MT/figures', file_format='jpg')
        """
        if fig_dpi is None:
            fig_dpi = self.fig_dpi

        save_fn = Path(save_fn)
        if save_fn.is_dir():
            save_fn = save_fn.joinpath(f"{self._basename}.{file_format}")
        elif save_fn.suffix:
            # an explicit extension wins over the file_format argument
            file_format = save_fn.suffix[1:]
        else:
            # No extension given: fall back to file_format rather than handing
            # matplotlib an empty format string, which it rejects.
            save_fn = save_fn.with_suffix(f".{file_format}")

        self.fig.savefig(
            save_fn, dpi=fig_dpi, format=file_format, orientation=orientation
        )

        if close_plot:
            plt.close(self.fig)

        self.fig_fn = save_fn
        self.logger.info(f"Saved figure to: {self.fig_fn}")

    def update_plot(self) -> None:
        """
        Update the plot after changing figure or axes properties.

        Uses matplotlib's canvas draw method to refresh the display
        after modifying figure or axes attributes.

        Examples
        --------
        >>> [ax.grid(True, which='major') for ax in [p1.axr, p1.axp]]
        >>> p1.update_plot()
        """
        self.fig.canvas.draw()

    def redraw_plot(self) -> None:
        """
        Recreate the plot after updating attributes.

        Closes the current figure and calls plot() to create a new one
        with updated attributes.

        Examples
        --------
        >>> # Change the color and marker of the xy components
        >>> p1.xy_color = (0.5, 0.5, 0.9)
        >>> p1.xy_marker = '*'
        >>> p1.redraw_plot()
        """
        plt.close(self.fig)
        self.plot()
# [docs]
class PlotBaseMaps(PlotBase):
    """
    Base object for plot classes that use map views.

    Includes methods for interpolation of data onto map grids and for
    evaluating transfer-function components at ``self.plot_period``.

    Parameters
    ----------
    **kwargs : dict
        Keyword arguments passed to PlotBase parent class and for setting
        interpolation parameters

    Attributes
    ----------
    cell_size : float
        Size of grid cells for interpolation, by default 0.002
    n_padding_cells : int
        Number of padding cells around data extent, by default 10
    interpolation_method : str
        Interpolation method ('delaunay', 'linear', 'nearest', 'cubic'),
        by default 'delaunay'
    interpolation_power : int
        Power parameter for inverse distance weighting, by default 5
    nearest_neighbors : int
        Number of nearest neighbors to use in interpolation, by default 7
    """

    def __init__(self, **kwargs) -> None:
        super().__init__(**kwargs)
        self.cell_size = 0.002
        self.n_padding_cells = 10
        self.interpolation_method = "delaunay"
        self.interpolation_power = 5
        self.nearest_neighbors = 7

        # user-supplied keyword arguments override the defaults above
        for key, value in kwargs.items():
            setattr(self, key, value)

    def interpolate_to_map(self, plot_array, component: str):
        """
        Interpolate data points onto a 2D map grid.

        Thin wrapper around the module-level ``interpolate_to_map`` that
        forwards this instance's interpolation settings.

        Parameters
        ----------
        plot_array : np.ndarray
            Array containing data to interpolate
        component : str
            Name of the component being interpolated

        Returns
        -------
        tuple
            Whatever the module-level ``interpolate_to_map`` returns
            (NOTE(review): presumably grid coordinates and gridded data --
            confirm against that function).
        """
        return interpolate_to_map(
            plot_array,
            component,
            cell_size=self.cell_size,
            n_padding_cells=self.n_padding_cells,
            interpolation_method=self.interpolation_method,
            interpolation_power=self.interpolation_power,
            nearest_neighbors=self.nearest_neighbors,
        )

    # ------------------------------------------------------------------
    # construction of 1D interpolators
    # ------------------------------------------------------------------
    @staticmethod
    def _component_interp1d(
        tf_object, data_name: str, row: int, col: int, interp_type: str
    ) -> dict:
        """
        Build the interpolators for one tensor component.

        Parameters
        ----------
        tf_object : object
            Container (``tf.Z`` or ``tf.Tipper``) exposing ``frequency``,
            ``<data_name>``, ``<data_name>_error``,
            ``<data_name>_model_error``, ``_has_tf_error()`` and
            ``_has_tf_model_error()``
        data_name : str
            Name of the data attribute ('z' or 'tipper')
        row, col : int
            Indices of the component in the last two axes
        interp_type : str
            ``kind`` argument for ``scipy.interpolate.interp1d``

        Returns
        -------
        dict
            Empty if the component is zero at every frequency, otherwise
            keys 'real', 'imag', 'err', 'model_err' ('err'/'model_err' are
            None when the corresponding data is not available).
        """
        data = getattr(tf_object, data_name)
        # zeros mark missing estimates, so only non-zero samples are used
        nz_index = np.nonzero(data[:, row, col])
        if len(nz_index[0]) == 0:
            return {}

        # nz_index is the tuple returned by np.nonzero; nesting it inside the
        # index tuple yields a (1, n) array, which is why the interpolators
        # return length-1 arrays and callers take element [0].
        values = data[nz_index, row, col]
        frequency = tf_object.frequency[nz_index]

        functions = {
            "real": interpolate.interp1d(frequency, values.real, kind=interp_type),
            "imag": interpolate.interp1d(frequency, values.imag, kind=interp_type),
        }

        if tf_object._has_tf_error():
            error = getattr(tf_object, f"{data_name}_error")[nz_index, row, col]
            functions["err"] = interpolate.interp1d(
                frequency, error, kind=interp_type
            )
        else:
            functions["err"] = None

        if tf_object._has_tf_model_error():
            model_error = getattr(tf_object, f"{data_name}_model_error")[
                nz_index, row, col
            ]
            functions["model_err"] = interpolate.interp1d(
                frequency, model_error, kind=interp_type
            )
        else:
            functions["model_err"] = None

        return functions

    @staticmethod
    def get_interp1d_functions_z(tf, interp_type: str = "slinear") -> dict | None:
        """
        Create 1D interpolation functions for impedance tensor components.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object containing impedance data
        interp_type : str, optional
            Type of interpolation ('linear', 'slinear', 'cubic'),
            by default 'slinear'

        Returns
        -------
        dict | None
            Dictionary keyed by component (zxx, zxy, zyx, zyy). Each value
            holds 'real', 'imag', 'err' and 'model_err' entries, or is an
            empty dict when that component is zero at every frequency.
            Returns None if no Z data available.
        """
        if tf.Z is None:
            return None

        axis_names = ("x", "y")
        interp_dict = {}
        for ii in range(2):
            for jj in range(2):
                comp = f"z{axis_names[ii]}{axis_names[jj]}"
                interp_dict[comp] = PlotBaseMaps._component_interp1d(
                    tf.Z, "z", ii, jj, interp_type
                )
        return interp_dict

    @staticmethod
    def get_interp1d_functions_t(tf, interp_type: str = "slinear") -> dict | None:
        """
        Create 1D interpolation functions for tipper components.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object containing tipper data
        interp_type : str, optional
            Type of interpolation ('linear', 'slinear', 'cubic'),
            by default 'slinear'

        Returns
        -------
        dict | None
            Dictionary keyed by component (tzx, tzy). Each value holds
            'real', 'imag', 'err' and 'model_err' entries, or is an empty
            dict when that component is zero at every frequency.
            Returns None if no Tipper data available.
        """
        if tf.Tipper is None:
            return None

        axis_names = ("x", "y")
        interp_dict = {}
        for jj in range(2):
            comp = f"tz{axis_names[jj]}"
            interp_dict[comp] = PlotBaseMaps._component_interp1d(
                tf.Tipper, "tipper", 0, jj, interp_type
            )
        return interp_dict

    # ------------------------------------------------------------------
    # shared helpers for the _get_interpolated_* methods
    # ------------------------------------------------------------------
    def _get_plot_period_index(self, tf, rtol: float = 1e-6) -> int | None:
        """
        Return index of the configured plot period if present in TF data.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Object whose ``period`` attribute (if any) is searched
        rtol : float, optional
            Relative tolerance passed to ``np.isclose``, by default 1e-6

        Returns
        -------
        int | None
            Index of the first period close to ``self.plot_period``, or
            None when ``tf`` has no periods or none of them match.
        """
        period = getattr(tf, "period", None)
        if period is None:
            return None

        period_array = np.asarray(period, dtype=float)
        if period_array.size == 0:
            return None

        idx = np.where(np.isclose(period_array, float(self.plot_period), rtol=rtol))[0]
        if idx.size == 0:
            return None
        return int(idx[0])

    @staticmethod
    def _direct_value(container, attr_name: str, idx: int, dtype) -> np.ndarray | None:
        """
        Slice ``container.<attr_name>[idx:idx + 1]`` without interpolating.

        Returns the NaN-cleaned slice, or None if anything goes wrong so
        the caller can fall back to interpolation.
        """
        try:
            return np.nan_to_num(
                np.asarray(getattr(container, attr_name)[idx : idx + 1], dtype=dtype)
            )
        except Exception:
            # Deliberately best-effort: the direct lookup is only a shortcut,
            # any failure here is recovered by the interpolation path.
            return None

    def _evaluate_components(self, interp_dict: dict, components, key: str) -> list:
        """
        Evaluate ``interp_dict[comp][key]`` at ``1 / self.plot_period``.

        A component without an interpolator for ``key`` evaluates to 0.0.
        This covers components that are zero at every frequency, for which
        ``_component_interp1d`` stores an empty dict.
        """
        frequency = 1.0 / self.plot_period
        values = []
        for comp in components:
            function = interp_dict[comp].get(key)
            values.append(0.0 if function is None else function(frequency)[0])
        return values

    # ------------------------------------------------------------------
    # impedance
    # ------------------------------------------------------------------
    def _get_interpolated_z(self, tf) -> np.ndarray:
        """
        Get impedance tensor at plot period.

        Uses the stored value directly when the plot period is one of the
        periods of ``tf``; otherwise interpolates. The interpolation
        dictionary is cached on ``tf.z_interp_dict``.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        np.ndarray
            Complex impedance tensor array of shape (1, 2, 2) at the
            specified plot period
        """
        idx = self._get_plot_period_index(tf)
        if idx is not None and tf.Z is not None:
            direct = self._direct_value(tf.Z, "z", idx, complex)
            if direct is not None:
                return direct

        if not hasattr(tf, "z_interp_dict"):
            tf.z_interp_dict = self.get_interp1d_functions_z(tf)

        components = ("zxx", "zxy", "zyx", "zyy")
        real = np.array(self._evaluate_components(tf.z_interp_dict, components, "real"))
        imag = np.array(self._evaluate_components(tf.z_interp_dict, components, "imag"))
        return np.nan_to_num(real + 1j * imag).reshape((1, 2, 2))

    def _get_interpolated_z_error(self, tf) -> np.ndarray:
        """
        Get impedance tensor error at plot period.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        np.ndarray
            Impedance tensor error array of shape (1, 2, 2) at the
            specified plot period. Returns zeros if no error data available.
        """
        idx = self._get_plot_period_index(tf)
        if idx is not None and tf.Z is not None and tf.Z._has_tf_error():
            direct = self._direct_value(tf.Z, "z_error", idx, float)
            if direct is not None:
                return direct

        if not hasattr(tf, "z_interp_dict"):
            tf.z_interp_dict = self.get_interp1d_functions_z(tf)

        # zxy acts as the indicator component for "error data is available"
        if tf.z_interp_dict["zxy"].get("err") is None:
            return np.zeros((1, 2, 2), dtype=float)

        components = ("zxx", "zxy", "zyx", "zyy")
        return np.nan_to_num(
            np.array(self._evaluate_components(tf.z_interp_dict, components, "err"))
        ).reshape((1, 2, 2))

    def _get_interpolated_z_model_error(self, tf) -> np.ndarray:
        """
        Get impedance tensor model error at plot period.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        np.ndarray
            Impedance tensor model error array of shape (1, 2, 2) at the
            specified plot period. Returns zeros if no model error data
            available.
        """
        idx = self._get_plot_period_index(tf)
        if idx is not None and tf.Z is not None and tf.Z._has_tf_model_error():
            direct = self._direct_value(tf.Z, "z_model_error", idx, float)
            if direct is not None:
                return direct

        if not hasattr(tf, "z_interp_dict"):
            tf.z_interp_dict = self.get_interp1d_functions_z(tf)

        if tf.z_interp_dict["zxy"].get("model_err") is None:
            return np.zeros((1, 2, 2), dtype=float)

        components = ("zxx", "zxy", "zyx", "zyy")
        return np.nan_to_num(
            np.array(
                self._evaluate_components(tf.z_interp_dict, components, "model_err")
            )
        ).reshape((1, 2, 2))

    # ------------------------------------------------------------------
    # tipper
    # ------------------------------------------------------------------
    def _get_interpolated_t(self, tf) -> np.ndarray:
        """
        Get tipper at plot period.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        np.ndarray
            Complex tipper array of shape (1, 1, 2) at the specified
            plot period. Returns zeros if no tipper data available.
        """
        idx = self._get_plot_period_index(tf)
        if idx is not None and tf.has_tipper() and tf.Tipper is not None:
            direct = self._direct_value(tf.Tipper, "tipper", idx, complex)
            if direct is not None:
                return direct

        if not hasattr(tf, "t_interp_dict"):
            tf.t_interp_dict = self.get_interp1d_functions_t(tf)

        if not tf.has_tipper():
            return np.zeros((1, 1, 2), dtype=complex)

        components = ("tzx", "tzy")
        real = np.array(self._evaluate_components(tf.t_interp_dict, components, "real"))
        imag = np.array(self._evaluate_components(tf.t_interp_dict, components, "imag"))
        return np.nan_to_num(real + 1j * imag).reshape((1, 1, 2))

    def _get_interpolated_t_err(self, tf) -> np.ndarray:
        """
        Get tipper error at plot period.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        np.ndarray
            Tipper error array of shape (1, 1, 2) at the specified plot
            period. Returns zeros if no tipper or no tipper error data
            available.
        """
        idx = self._get_plot_period_index(tf)
        if idx is not None and tf.has_tipper() and tf.Tipper._has_tf_error():
            direct = self._direct_value(tf.Tipper, "tipper_error", idx, float)
            if direct is not None:
                return direct

        if not hasattr(tf, "t_interp_dict"):
            tf.t_interp_dict = self.get_interp1d_functions_t(tf)

        if not tf.has_tipper() or not tf.Tipper._has_tf_error():
            return np.zeros((1, 1, 2), dtype=float)

        return np.nan_to_num(
            np.array(
                self._evaluate_components(tf.t_interp_dict, ("tzx", "tzy"), "err")
            )
        ).reshape((1, 1, 2))

    def _get_interpolated_t_model_err(self, tf) -> np.ndarray:
        """
        Get tipper model error at plot period.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        np.ndarray
            Tipper model error array of shape (1, 1, 2) at the specified
            plot period. Returns zeros if no tipper or no tipper model
            error data available.
        """
        idx = self._get_plot_period_index(tf)
        if idx is not None and tf.has_tipper() and tf.Tipper._has_tf_model_error():
            direct = self._direct_value(tf.Tipper, "tipper_model_error", idx, float)
            if direct is not None:
                return direct

        if not hasattr(tf, "t_interp_dict"):
            tf.t_interp_dict = self.get_interp1d_functions_t(tf)

        # Gate on the *model* error flag: the 'model_err' interpolators only
        # exist when _has_tf_model_error() is true, regardless of whether
        # measurement errors are present.
        if not tf.has_tipper() or not tf.Tipper._has_tf_model_error():
            return np.zeros((1, 1, 2), dtype=float)

        return np.nan_to_num(
            np.array(
                self._evaluate_components(
                    tf.t_interp_dict, ("tzx", "tzy"), "model_err"
                )
            )
        ).reshape((1, 1, 2))

    def add_raster(
        self, ax, raster_fn: str | Path, add_colorbar: bool = True, **kwargs
    ):
        """
        Add a raster image to a matplotlib axis.

        Thin wrapper around the module-level ``add_raster`` function.

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Matplotlib axis to add raster to
        raster_fn : str | Path
            Path to raster file (readable by rasterio)
        add_colorbar : bool, optional
            Whether to add a colorbar, by default True
        **kwargs : dict
            Additional keyword arguments forwarded to ``add_raster``

        Returns
        -------
        matplotlib image or collection
            The raster plot object returned by ``add_raster``
        """
        return add_raster(ax, raster_fn, add_colorbar=add_colorbar, **kwargs)
# [docs]
class PlotBaseProfile(PlotBase):
    """
    Base object for profile plots like pseudo sections.

    Provides functionality for creating profile views of MT data along
    a linear transect.

    Parameters
    ----------
    tf_list : list or MTCollection
        List of transfer function objects or MTCollection
    **kwargs : dict
        Additional keyword arguments for PlotBase parent class and
        profile settings

    Attributes
    ----------
    mt_data : list or MTCollection
        MT data to plot
    profile_vector : array-like | None
        Profile direction vector
    profile_angle : float | None
        Profile angle in degrees
    profile_line : tuple | None
        Profile line parameters (slope, intercept)
    profile_reverse : bool
        Whether to reverse profile direction, by default False
    x_stretch : float
        Horizontal stretching factor for profile, by default 5000
    y_stretch : float
        Vertical stretching factor for profile, by default 1000
    y_scale : str
        Y-axis scale type ('period' or 'frequency'), by default 'period'
    """

    def __init__(self, tf_list, **kwargs) -> None:
        super().__init__(**kwargs)
        self.mt_data = tf_list
        self.profile_vector = None
        self.profile_angle = None
        self.profile_line = None
        self.profile_reverse = False
        self.x_stretch = 5000
        self.y_stretch = 1000
        self.y_scale = "period"
        self._rotation_angle = 0

        # user-supplied keyword arguments override the defaults above;
        # a 'rotation_angle' entry goes through the property setter below
        for key, value in kwargs.items():
            setattr(self, key, value)

    # ---need to rotate data on setting rotz
    @property
    def rotation_angle(self) -> float:
        """
        Get rotation angle for data.

        Returns
        -------
        float
            Rotation angle in degrees
        """
        return self._rotation_angle

    @rotation_angle.setter
    def rotation_angle(self, value: float) -> None:
        """
        Set rotation angle for all transfer functions.

        Containers exposing both ``rotate`` and ``get_station`` are rotated
        in place through ``rotate``; otherwise the angle is assigned to
        every MT object individually.

        Parameters
        ----------
        value : float
            Rotation angle in degrees to apply to all data
        """
        if hasattr(self.mt_data, "rotate") and hasattr(self.mt_data, "get_station"):
            self.mt_data.rotate(value, inplace=True)
        else:
            for tf in self._iter_mt_objects():
                tf.rotation_angle = value
        self._rotation_angle = value

    # ------------------------------------------------------------------
    # access to the MT objects
    # ------------------------------------------------------------------
    def _iter_mt_objects(self):
        """
        Yield MT objects from supported container types.

        Raises
        ------
        TypeError
            If ``mt_data`` offers neither ``values()`` nor the
            ``_iter_station_paths``/``get_station`` pair.
        """
        if hasattr(self.mt_data, "values"):
            yield from self.mt_data.values()
            return

        if hasattr(self.mt_data, "_iter_station_paths") and hasattr(
            self.mt_data, "get_station"
        ):
            if hasattr(self.mt_data, "compute"):
                self.mt_data.compute()
            for station_path in self.mt_data._iter_station_paths():
                yield self.mt_data.get_station(station_path, as_mt=True)
            return

        raise TypeError("mt_data must provide values() or MTData-style station access")

    def _get_mt_objects(self):
        """Return MT objects as a list for repeated profile operations."""
        return list(self._iter_mt_objects())

    # ------------------------------------------------------------------
    # profile geometry
    # ------------------------------------------------------------------
    def _sync_mt_data_profile_offsets(
        self,
        x_column: str,
        y_column: str,
    ) -> dict[tuple[str, str], float]:
        """
        Compute and persist profile offsets for MTData-backed station attrs.

        Requires ``self.profile_line`` to be set. Offsets are the absolute
        projections of the station coordinates onto the profile direction,
        shifted so the smallest offset is zero. Each offset is also written
        to the ``profile_offset`` entry of the matching station's ``attrs``.

        Parameters
        ----------
        x_column : str
            Column of ``mt_data.station_locations`` holding x coordinates
        y_column : str
            Column of ``mt_data.station_locations`` holding y coordinates

        Returns
        -------
        dict[tuple[str, str], float]
            Offsets keyed by ``(survey, station)``. Empty when ``mt_data``
            lacks the required interface or columns, or has no finite
            coordinates.
        """
        if not (
            hasattr(self.mt_data, "station_locations")
            and hasattr(self.mt_data, "_iter_station_paths")
            and hasattr(self.mt_data, "get_station")
        ):
            return {}

        station_df = self.mt_data.station_locations
        if station_df is None or station_df.empty:
            return {}

        if (
            x_column not in station_df.columns
            or y_column not in station_df.columns
            or "survey" not in station_df.columns
            or "station" not in station_df.columns
        ):
            return {}

        x_values = station_df[x_column].to_numpy(dtype=float)
        y_values = station_df[y_column].to_numpy(dtype=float)
        finite = np.isfinite(x_values) & np.isfinite(y_values)
        if not np.any(finite):
            return {}

        # unit vector along the profile line y = slope * x + intercept
        profile_vector = np.array([1.0, self.profile_line[0]], dtype=float)
        profile_vector /= np.linalg.norm(profile_vector)

        # station positions relative to the line's y-intercept
        station_vectors = np.column_stack(
            [x_values[finite], y_values[finite] - self.profile_line[1]]
        )
        offsets = np.abs(station_vectors @ profile_vector)
        offsets -= offsets.min()

        key_to_path: dict[tuple[str, str], str] = {}
        for station_path in self.mt_data._iter_station_paths():
            attrs = self.mt_data.get_station(station_path).attrs
            key = (str(attrs.get("survey", "")), str(attrs.get("station", "")))
            key_to_path[key] = station_path

        offset_lookup: dict[tuple[str, str], float] = {}
        profile_df = station_df.loc[finite, ["survey", "station"]].copy()
        profile_df.loc[:, "profile_offset"] = offsets
        for row in profile_df.itertuples(index=False):
            key = (str(getattr(row, "survey", "")), str(getattr(row, "station", "")))
            offset_value = float(getattr(row, "profile_offset", 0.0))
            offset_lookup[key] = offset_value

            station_path = key_to_path.get(key)
            if station_path is None:
                continue
            self.mt_data.get_station(station_path).attrs[
                "profile_offset"
            ] = offset_value

        return offset_lookup

    def _get_profile_line(
        self, x: np.ndarray | None = None, y: np.ndarray | None = None
    ) -> None:
        """
        Calculate profile line using linear regression through data points.

        Determines the best-fit line through station locations and projects
        all stations onto this profile line. Does nothing when
        ``mt_data.station_locations`` already carries non-zero
        ``profile_offset`` values, or when fewer than two points are
        available.

        Parameters
        ----------
        x : np.ndarray | None, optional
            X coordinates of stations. If None, coordinates are taken from
            ``mt_data.station_locations`` (longitude/latitude, then
            east/north) or from each MT object's longitude,
            by default None
        y : np.ndarray | None, optional
            Y coordinates of stations. If None, handled like ``x`` using
            latitude, by default None

        Raises
        ------
        ValueError
            If only one of x or y is provided
        """
        station_locations = getattr(self.mt_data, "station_locations", None)
        if (
            station_locations is not None
            and hasattr(station_locations, "columns")
            and "profile_offset" in station_locations.columns
        ):
            offsets = np.nan_to_num(
                station_locations["profile_offset"].to_numpy(dtype=float),
                nan=0.0,
            )
            # offsets were already computed elsewhere; keep them
            if offsets.size > 0 and np.any(offsets != 0):
                return

        mt_objects = None
        coordinate_columns: tuple[str, str] | None = None
        if x is None and y is None:
            if (
                station_locations is not None
                and getattr(station_locations, "empty", True) is False
            ):
                # use the first coordinate pair with at least two finite points
                for x_col, y_col in [("longitude", "latitude"), ("east", "north")]:
                    if (
                        x_col not in station_locations.columns
                        or y_col not in station_locations.columns
                    ):
                        continue
                    x_values = station_locations[x_col].to_numpy(dtype=float)
                    y_values = station_locations[y_col].to_numpy(dtype=float)
                    finite = np.isfinite(x_values) & np.isfinite(y_values)
                    if np.count_nonzero(finite) < 2:
                        continue
                    x = x_values[finite]
                    y = y_values[finite]
                    coordinate_columns = (x_col, y_col)
                    break

            if x is None or y is None:
                # fall back to the coordinates stored on the MT objects
                mt_objects = self._get_mt_objects()
                x = np.zeros(len(mt_objects))
                y = np.zeros(len(mt_objects))
                for ii, tf in enumerate(mt_objects):
                    x[ii] = tf.longitude
                    y[ii] = tf.latitude
        elif x is None or y is None:
            raise ValueError("x and y must both be provided, or both be None")
        else:
            # accept any array-like so the .size checks below work
            x = np.asarray(x)
            y = np.asarray(y)

        if x.size < 2 or y.size < 2:
            return

        # check regression for 2 profile orientations:
        # horizontal (N=N(E)) or vertical(E=E(N))
        # use the one with the lower standard deviation
        profile1 = stats.linregress(x, y)
        profile2 = stats.linregress(y, x)

        # if the profile is rather E=E(N), the parameters have to converted
        # into N=N(E) form:
        if profile2.stderr < profile1.stderr:
            self.profile_line = (
                1.0 / profile2.slope,
                -profile2.intercept / profile2.slope,
            )
        else:
            self.profile_line = profile1[:2]

        offset_lookup: dict[tuple[str, str], float] = {}
        if coordinate_columns is not None:
            offset_lookup = self._sync_mt_data_profile_offsets(*coordinate_columns)

        if mt_objects is None:
            mt_objects = self._get_mt_objects()

        if offset_lookup:
            for mt_obj in mt_objects:
                key = (
                    str(getattr(mt_obj, "survey", "")),
                    str(getattr(mt_obj, "station", "")),
                )
                if key in offset_lookup:
                    mt_obj.profile_offset = offset_lookup[key]
                else:
                    mt_obj.project_onto_profile_line(
                        self.profile_line[0],
                        self.profile_line[1],
                    )
            return

        for mt_obj in mt_objects:
            mt_obj.project_onto_profile_line(self.profile_line[0], self.profile_line[1])

    def _get_offset(self, tf) -> float:
        """
        Get approximate offset distance for a station along the profile.

        The offset is looked up in ``mt_data.station_locations`` by station
        (and survey, when ``tf`` has one); if no finite value is found there
        the ``profile_offset`` attribute of ``tf`` is used, defaulting to 0.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        float
            Offset multiplied by ``x_stretch``, negated if
            ``profile_reverse`` is True
        """
        direction = -1 if self.profile_reverse else 1

        offset_value = None
        station_locations = getattr(self.mt_data, "station_locations", None)
        if (
            station_locations is not None
            and hasattr(station_locations, "itertuples")
            and hasattr(station_locations, "columns")
            and "profile_offset" in station_locations.columns
            and "station" in station_locations.columns
        ):
            tf_station = str(getattr(tf, "station", ""))
            tf_survey = str(getattr(tf, "survey", ""))
            for row in station_locations.itertuples(index=False):
                row_station = str(getattr(row, "station", ""))
                if row_station != tf_station:
                    continue
                row_survey = str(getattr(row, "survey", ""))
                # survey is only compared when the tf actually has one
                if tf_survey and row_survey != tf_survey:
                    continue
                row_offset = getattr(row, "profile_offset", None)
                if row_offset is None or not np.isfinite(row_offset):
                    continue
                offset_value = float(row_offset)
                break

        if offset_value is None:
            offset_value = float(getattr(tf, "profile_offset", 0.0) or 0.0)

        return direction * offset_value * self.x_stretch

    # ------------------------------------------------------------------
    # interpolation at the plot period
    # ------------------------------------------------------------------
    def _ensure_interp_dict(self, tf, attr_name: str, builder_name: str) -> None:
        """
        Create ``tf.<attr_name>`` if it does not exist yet.

        The builder is taken from this instance when a subclass provides
        one, otherwise from ``PlotBaseMaps`` -- ``PlotBase`` itself does not
        define the ``get_interp1d_functions_*`` builders.
        """
        if hasattr(tf, attr_name):
            return
        builder = getattr(self, builder_name, getattr(PlotBaseMaps, builder_name))
        setattr(tf, attr_name, builder(tf))

    def _get_interpolated_z(self, tf) -> np.ndarray:
        """
        Get interpolated impedance tensor at plot period.

        The interpolation dictionary is created on first use and cached on
        ``tf.z_interp_dict``.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        np.ndarray
            Complex impedance tensor array whose leading shape is (2, 2).
            NOTE(review): the interpolator outputs are stacked without
            indexing, so any dimension they return is kept as a trailing
            axis -- confirm the shape callers expect.
        """
        self._ensure_interp_dict(tf, "z_interp_dict", "get_interp1d_functions_z")
        interp = tf.z_interp_dict
        frequency = 1.0 / self.plot_period

        def _complex_value(comp):
            return interp[comp]["real"](frequency) + 1j * interp[comp]["imag"](
                frequency
            )

        return np.nan_to_num(
            np.array(
                [
                    [_complex_value("zxx"), _complex_value("zxy")],
                    [_complex_value("zyx"), _complex_value("zyy")],
                ]
            )
        )

    def _get_interpolated_z_error(self, tf) -> np.ndarray | None:
        """
        Get interpolated impedance tensor error at plot period.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        np.ndarray | None
            Impedance tensor error array with leading shape (2, 2) at the
            specified plot period, or None when the 'zxy' component has no
            error interpolator.
        """
        self._ensure_interp_dict(tf, "z_interp_dict", "get_interp1d_functions_z")
        interp = tf.z_interp_dict
        if interp["zxy"]["err"] is None:
            # no error data: callers receive None, as before
            return None

        frequency = 1.0 / self.plot_period
        return np.nan_to_num(
            np.array(
                [
                    [interp["zxx"]["err"](frequency), interp["zxy"]["err"](frequency)],
                    [interp["zyx"]["err"](frequency), interp["zyy"]["err"](frequency)],
                ]
            )
        )

    def _get_interpolated_t(self, tf) -> np.ndarray:
        """
        Get interpolated tipper at plot period.

        The interpolation dictionary is created on first use and cached on
        ``tf.t_interp_dict``.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        np.ndarray
            Complex tipper array with leading shape (1, 1, 2) at the
            specified plot period. Returns zeros of shape (1, 1, 2) if the
            tipper dictionary is empty or None.
        """
        self._ensure_interp_dict(tf, "t_interp_dict", "get_interp1d_functions_t")
        interp = tf.t_interp_dict
        if not interp:
            return np.zeros((1, 1, 2), dtype=complex)

        frequency = 1.0 / self.plot_period
        return np.nan_to_num(
            np.array(
                [
                    [
                        [
                            interp["tzx"]["real"](frequency)
                            + 1j * interp["tzx"]["imag"](frequency),
                            interp["tzy"]["real"](frequency)
                            + 1j * interp["tzy"]["imag"](frequency),
                        ]
                    ]
                ]
            )
        )

    def _get_interpolated_t_err(self, tf) -> np.ndarray:
        """
        Get interpolated tipper error at plot period.

        Parameters
        ----------
        tf : MT or Transfer Function object
            Transfer function object

        Returns
        -------
        np.ndarray
            Tipper error array with leading shape (1, 1, 2) at the
            specified plot period. Returns zeros of shape (1, 1, 2) if the
            tipper dictionary is empty or None.
        """
        self._ensure_interp_dict(tf, "t_interp_dict", "get_interp1d_functions_t")
        interp = tf.t_interp_dict
        if not interp:
            # np.zeros, not np.array: the tuple is a shape, not the data
            return np.zeros((1, 1, 2), dtype=float)

        frequency = 1.0 / self.plot_period
        return np.nan_to_num(
            np.array(
                [
                    [
                        [
                            interp["tzx"]["err"](frequency),
                            interp["tzy"]["err"](frequency),
                        ]
                    ]
                ]
            )
        )