# -*- coding: utf-8 -*-
"""
plots multiple MT responses simultaneously
Created on Thu May 30 17:02:39 2013
@author: jpeacock-pr
YG: the code here is messy; TODO: it may need to be rewritten at some point.
"""
# ============================================================================
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.ticker import MultipleLocator
from mtpy.imaging.mtplot_tools import (
get_log_tick_labels,
plot_phase,
plot_pt_lateral,
plot_resistivity,
plot_tipper_lateral,
PlotBase,
)
# ============================================================================
class PlotMultipleResponses(PlotBase):
    """Plot multiple MT responses simultaneously.

    The responses can be drawn in separate figures, as one figure with a
    column of sub-plots per station, or overlaid in a single figure so
    they can be compared.

    Arguments::

        **mt_data** : container of MT objects. Either an object with a
                      ``values()`` method yielding MT objects, or an
                      MTData-style object exposing ``_iter_station_paths()``
                      and ``get_station()``.

        **plot_num** : [ 1 | 2 | 3 ]
                        * 1 for just Ex/By and Ey/Bx *default*
                        * 2 for all 4 components
                        * 3 for off diagonal plus the determinant

        **plot_style** : [ '1' | 'all' | 'compare' ]
                          determines the plotting style:
                              * '1' for plotting each station in a
                                different figure. *default*
                              * 'all' for plotting each station in a
                                subplot, all in the same figure
                              * 'compare' for comparing the responses
                                all in one plot. Here the responses are
                                colored from dark to light. This plot can
                                get messy if too many stations are
                                plotted. plot_num=2 is not supported.

        **include_survey** : [ True | False ]
                             in 'compare' mode, append the survey id to
                             each station label. *default* is True

    Any other keyword arguments are forwarded to ``PlotBase``.
    """

    def __init__(self, mt_data, **kwargs):
        """Initialize parameters and plot immediately if ``show_plot`` is set.

        :param mt_data: container of MT objects (see class docstring).
        :param kwargs: forwarded to ``PlotBase``. The arrow, ellipse and
            ``plot_model_error`` defaults below are assigned after
            ``PlotBase.__init__`` runs, so they are taken from ``kwargs``
            when given instead of overwriting a caller-supplied value.
        """
        self.plot_num = 1
        self.plot_style = "1"
        self.mt_data = mt_data
        self.include_survey = True
        super().__init__(**kwargs)

        self.plot_dict = {
            "tip": self.plot_tipper,
            "pt": self.plot_pt,
            "strike": self.plot_strike,
            "skew": self.plot_skew,
        }

        # set arrow properties
        self.arrow_head_length = kwargs.get("arrow_head_length", 0.03)
        self.arrow_head_width = kwargs.get("arrow_head_width", 0.03)
        self.arrow_lw = kwargs.get("arrow_lw", 0.5)
        self.plot_model_error = kwargs.get("plot_model_error", None)

        # ellipse_properties
        self.ellipse_size = kwargs.get("ellipse_size", 0.25)

        # plot on initializing
        if self.show_plot:
            self.plot()

    # ---need to rotate data on setting rotz
    @property
    def rotation_angle(self):
        """Rotation angle last applied to the MT data."""
        return self._rotation_angle

    @rotation_angle.setter
    def rotation_angle(self, value):
        """Rotate every station in ``mt_data`` by a single angle.

        Containers exposing both ``rotate`` and ``get_station`` are rotated
        in place at the container level (avoids repeatedly rebuilding MT
        objects from the dataset); otherwise the ``rotation_angle`` of each
        MT object is set individually.
        """
        if hasattr(self.mt_data, "rotate") and hasattr(self.mt_data, "get_station"):
            self.mt_data.rotate(value, inplace=True)
        else:
            for tf in self._iter_mt_objects():
                tf.rotation_angle = value
        self._rotation_angle = value

    def _iter_mt_objects(self):
        """Yield MT objects from the supported container types.

        Objects with a ``values()`` method are iterated directly. Objects
        exposing ``_iter_station_paths()`` and ``get_station()`` have
        ``compute()`` called first (when present), and each station is
        then materialized with ``get_station(path, as_mt=True)``.

        :raises TypeError: when iterated, if ``mt_data`` supports neither
            interface.
        """
        if hasattr(self.mt_data, "values"):
            yield from self.mt_data.values()
            return
        if hasattr(self.mt_data, "_iter_station_paths") and hasattr(
            self.mt_data, "get_station"
        ):
            if hasattr(self.mt_data, "compute"):
                self.mt_data.compute()
            for station_path in self.mt_data._iter_station_paths():
                yield self.mt_data.get_station(station_path, as_mt=True)
            return
        raise TypeError("mt_data must provide values() or MTData-style station access")

    def _get_mt_objects(self):
        """Return MT objects as a list for repeated plotting passes."""
        return list(self._iter_mt_objects())

    @property
    def plot_model_error(self):
        """Whether to plot model errors instead of data errors."""
        return self._plot_model_error

    @plot_model_error.setter
    def plot_model_error(self, value):
        """Set the flag and the attribute infix used to look up errors.

        A truthy value selects ``*_model_error_*`` attributes, otherwise
        ``*_error_*`` attributes are used.
        """
        self._error_str = "model_error" if value else "error"
        self._plot_model_error = value

    def _get_component_setup(self, mode, ax, ax2=None):
        """Return the components, error-bar properties and axes for a mode.

        Shared by ``_plot_resistivity`` and ``_plot_phase`` so that both
        accept exactly the same modes.

        :param mode: one of
            * 'od' -- off-diagonal components xy, yx
            * 'd' -- diagonal components xx, yy
            * 'det' -- xy, yx and the determinant
            * 'det_only' -- only the determinant
        :param ax: axes for the first component (and the determinant).
        :param ax2: optional axes for the second component; when None the
            second component is drawn on ``ax`` as well.
        :returns: (components, error-bar property mappings, axes), three
            lists of equal length.
        :raises ValueError: if ``mode`` is not one of the above.
        """
        second_ax = ax if ax2 is None else ax2
        if mode == "od":
            return (
                ["xy", "yx"],
                [self.xy_error_bar_properties, self.yx_error_bar_properties],
                [ax, second_ax],
            )
        if mode == "d":
            # the diagonal components reuse the xy / yx styling
            return (
                ["xx", "yy"],
                [self.xy_error_bar_properties, self.yx_error_bar_properties],
                [ax, second_ax],
            )
        if mode == "det":
            return (
                ["xy", "yx", "det"],
                [
                    self.xy_error_bar_properties,
                    self.yx_error_bar_properties,
                    self.det_error_bar_properties,
                ],
                [ax, second_ax, ax],
            )
        if mode == "det_only":
            return ["det"], [self.det_error_bar_properties], [ax]
        raise ValueError(
            f"mode must be one of 'od', 'd', 'det' or 'det_only', not {mode!r}"
        )

    def _plot_resistivity(self, axr, period, z_obj, mode="od", index=0, axr2=None):
        """Plot apparent resistivity components of ``z_obj``.

        :param axr: primary axes.
        :param period: periods used as x values.
        :param z_obj: impedance object providing ``resistivity``,
            ``res_<comp>`` and ``res_<error>_<comp>`` attributes.
        :param mode: 'od', 'd', 'det' or 'det_only'.
        :param index: column index; index 0 gets the y-axis label, other
            columns have their y tick labels hidden.
        :param axr2: optional second axes for the second component.
        :returns: (list of error-bar handles, list of legend labels).
        :raises ValueError: if ``mode`` is not recognized.
        """
        comps, props, ax_list = self._get_component_setup(mode, axr, axr2)
        res_limits = self.set_resistivity_limits(z_obj.resistivity, mode=mode)
        x_limits = self.set_period_limits(period)
        eb_list = []
        label_list = []
        for comp, prop, ax in zip(comps, props, ax_list):
            ebax = plot_resistivity(
                ax,
                period,
                getattr(z_obj, f"res_{comp}"),
                getattr(z_obj, f"res_{self._error_str}_{comp}"),
                **prop,
            )
            eb_list.append(ebax[0])
            label_list.append("$Z_{" + comp + "}$")
            # --> set axes properties; the phase panel below carries the
            # period tick labels, so hide them here
            plt.setp(ax.get_xticklabels(), visible=False)
            ax.set_yscale("log", nonpositive="clip")
            ax.set_xscale("log", nonpositive="clip")
            ax.set_xlim(x_limits)
            ax.set_ylim(res_limits)
            ax.grid(
                True,
                alpha=0.25,
                which="both",
                color=(0.25, 0.25, 0.25),
                lw=0.25,
            )
        if index == 0:
            axr.set_ylabel(
                r"App. Res. ($\mathbf{\Omega \cdot m}$)",
                fontdict=self.font_dict,
            )
        else:
            plt.setp(axr.get_yticklabels(), visible=False)
        axr.legend(
            eb_list,
            label_list,
            loc=3,
            markerscale=1,
            borderaxespad=0.01,
            labelspacing=0.07,
            handletextpad=0.2,
            borderpad=0.02,
        )
        return eb_list, label_list

    def _plot_phase(self, axp, period, z_obj, mode="od", index=0, axp2=None):
        """Plot phase components of ``z_obj``.

        :param axp: primary axes.
        :param period: periods used as x values.
        :param z_obj: impedance object providing ``phase``,
            ``phase_<comp>`` and ``phase_<error>_<comp>`` attributes.
        :param mode: 'od', 'd', 'det' or 'det_only'.
        :param index: column index; only index 0 gets the y-axis label.
        :param axp2: optional second axes for the second component.
        :raises ValueError: if ``mode`` is not recognized.
        """
        comps, props, ax_list = self._get_component_setup(mode, axp, axp2)
        phase_limits = self.set_phase_limits(z_obj.phase, mode=mode)
        for comp, prop, ax in zip(comps, props, ax_list):
            plot_phase(
                ax,
                period,
                getattr(z_obj, f"phase_{comp}"),
                getattr(z_obj, f"phase_{self._error_str}_{comp}"),
                yx=comp == "yx",
                **prop,
            )
            ax.set_ylim(phase_limits)
            # coarser ticks when the limits extend beyond roughly 0-90 deg
            if phase_limits[0] < -10 or phase_limits[1] > 100:
                ax.yaxis.set_major_locator(MultipleLocator(30))
                ax.yaxis.set_minor_locator(MultipleLocator(10))
            else:
                ax.yaxis.set_major_locator(MultipleLocator(15))
                ax.yaxis.set_minor_locator(MultipleLocator(5))
            ax.grid(
                True,
                alpha=0.25,
                which="both",
                color=(0.25, 0.25, 0.25),
                lw=0.25,
            )
            ax.set_xscale("log", nonpositive="clip")
            # the phase panel is the bottom panel only when neither a
            # tipper nor a phase tensor row is drawn below it
            if "y" not in self.plot_tipper and not self.plot_pt:
                ax.set_xlabel("Period (s)", self.font_dict)
        # --> set axes properties
        if index == 0:
            axp.set_ylabel("Phase (deg)", self.font_dict)

    def _plot_tipper(
        self, axt, period, t_obj, index=0, legend=False, zero_reference=False
    ):
        """Plot tipper arrows on ``axt``.

        :param axt: tipper axes, or None when no tipper is requested.
        :param period: unused, kept for call compatibility.
        :param t_obj: tipper object, or None.
        :param index: column index; only index 0 gets the y-axis label.
        :returns: (legend handles, legend labels), or (None, None) when
            nothing was plotted.
        """
        if axt is None or t_obj is None:
            return None, None
        axt, tip_list, tip_label = plot_tipper_lateral(
            axt,
            t_obj,
            self.plot_tipper,
            self.arrow_real_properties,
            self.arrow_imag_properties,
            self.font_size,
            legend=legend,
            zero_reference=zero_reference,
            arrow_direction=self.arrow_direction,
        )
        if axt is None:
            return None, None
        axt.yaxis.set_major_locator(MultipleLocator(0.2))
        axt.yaxis.set_minor_locator(MultipleLocator(0.1))
        axt.set_xlabel("Period (s)", fontdict=self.font_dict)
        if index == 0:
            axt.set_ylabel("Tipper", fontdict=self.font_dict)
        # the phase tensor row sits below and carries the period labels
        if self.plot_pt:
            plt.setp(axt.xaxis.get_ticklabels(), visible=False)
            axt.set_xlabel("")
        return tip_list, tip_label

    def _plot_pt(self, axpt, period, pt_obj, index=0, y_shift=0, edge_color=None):
        """Plot phase tensor ellipses along the period axis.

        Does nothing unless ``plot_pt`` is set. The axes use log10(period)
        as x coordinates. Stores the colorbar objects returned by
        ``plot_pt_lateral`` on ``self.cbax`` / ``self.cbpt``.

        :param axpt: phase tensor axes.
        :param period: periods used to derive the x limits.
        :param pt_obj: phase tensor object.
        :param index: column index; index 0 labels the colorbar.
        :param y_shift: vertical offset of the ellipses.
        :param edge_color: optional ellipse edge color.
        """
        if not self.plot_pt:
            return
        color_array = self.get_pt_color_array(pt_obj)
        x_limits = self.set_period_limits(period)
        # -------------plot ellipses-----------------------------------
        self.cbax, self.cbpt = plot_pt_lateral(
            axpt,
            pt_obj,
            color_array,
            self.ellipse_properties,
            y_shift,
            self.fig,
            edge_color,
            index,
        )
        # ----set axes properties-----------------------------------------------
        # --> set tick labels and limits
        log_limits = (np.log10(x_limits[0]), np.log10(x_limits[1]))
        axpt.set_xlim(*log_limits)
        tklabels, xticks = get_log_tick_labels(axpt)
        axpt.set_xticks(xticks)
        axpt.set_xticklabels(tklabels, fontdict={"size": self.font_size})
        axpt.set_xlabel("Period (s)", fontdict=self.font_dict)
        # setting the ticks resets the x limits, so apply them again
        axpt.set_xlim(*log_limits)
        axpt.grid(
            True,
            alpha=0.25,
            which="major",
            color=(0.25, 0.25, 0.25),
            lw=0.25,
        )
        plt.setp(axpt.get_yticklabels(), visible=False)
        if index == 0:
            self.cbpt.set_label(
                self.cb_label_dict[self.ellipse_colorby],
                fontdict={"size": self.font_size},
            )

    def _get_nrows(self):
        """Work out the grid rows needed for the requested panels.

        :returns: (nrows, n_aux, height_ratios, pdict). ``nrows`` is the
            number of master grid rows: 1, or 2 when a tipper and/or phase
            tensor row is requested. ``n_aux`` is the number of auxiliary
            rows in the second master row. ``height_ratios`` are the master
            row height ratios. ``pdict`` maps 'res'/'phase' to their rows
            and, when present, 'tip'/'pt' to their auxiliary rows.
        """
        pdict = {"res": 0, "phase": 1}
        index = 0
        nrows = 1
        if "y" in self.plot_tipper:
            pdict["tip"] = index
            index += 1
            nrows = 2
        if self.plot_pt:
            pdict["pt"] = index
            index += 1
            nrows = 2
        hr = [1] if nrows == 1 else [2, 1]
        return nrows, index, hr, pdict

    def _setup_subplots(
        self,
        gs_master,
        n_stations=1,
        n_index=0,
        plot_num=1,
        hspace=0.05,
        wspace=0.15,
    ):
        """Create the axes for one column of ``gs_master``.

        :param gs_master: master GridSpec. Row 0 holds resistivity/phase,
            and row 1 (when present) the tipper and phase tensor axes.
        :param n_stations: unused, kept for call compatibility.
        :param n_index: column of ``gs_master`` to fill.
        :param plot_num: 1 or 3 for a single resistivity/phase column,
            2 for side-by-side axes (``axr2`` / ``axp2``).
        :param hspace: vertical spacing of the sub-grids.
        :param wspace: horizontal spacing of the resistivity/phase grid.
        :returns: (axr, axp, axr2, axp2, axt, axpt, label_coords); axes
            that are not needed are None.
        :raises ValueError: if ``plot_num`` is not 1, 2 or 3.
        """
        if plot_num not in (1, 2, 3):
            raise ValueError(f"plot_num must be 1, 2 or 3, not {plot_num}")
        axr2 = axp2 = axt = axpt = None
        nrows, index, _, pdict = self._get_nrows()
        gs_rp = gridspec.GridSpecFromSubplotSpec(
            2,
            2,
            subplot_spec=gs_master[0, n_index],
            height_ratios=[2, 1.5],
            hspace=hspace,
            wspace=wspace,
        )
        if nrows == 2:
            gs_aux = gridspec.GridSpecFromSubplotSpec(
                index, 1, subplot_spec=gs_master[1, n_index], hspace=hspace
            )

        if plot_num in (1, 3):
            # --> one resistivity / phase column spanning the full width
            if self.plot_style == "compare":
                label_coords = (-0.075, 0.5)
            else:
                label_coords = (-0.095, 0.5)
            axr = self.fig.add_subplot(gs_rp[0, :])
            # phase axis shares the period axis with resistivity
            axp = self.fig.add_subplot(gs_rp[1, :], sharex=axr)
        else:
            # --> two columns: left and right component panels
            label_coords = (-0.095, 0.5)
            axr = self.fig.add_subplot(gs_rp[0, 0])
            axr2 = self.fig.add_subplot(gs_rp[0, 1], sharex=axr)
            axr2.yaxis.set_label_coords(-0.1, 0.5)
            axp = self.fig.add_subplot(gs_rp[1, 0], sharex=axr)
            axp2 = self.fig.add_subplot(gs_rp[1, 1], sharex=axr)
            axp2.yaxis.set_label_coords(-0.1, 0.5)

        # set label coordinates
        axr.yaxis.set_label_coords(*label_coords)
        axp.yaxis.set_label_coords(*label_coords)

        # --> tipper axes
        if "y" in self.plot_tipper:
            axt = self.fig.add_subplot(gs_aux[pdict["tip"], :])
            axt.yaxis.set_label_coords(*label_coords)
        # --> phase tensor axes; cannot share x with the resistivity axes
        # because the ellipses are plotted on a different (log10) scale.
        # aspect="equal" is deliberately not set so the layout flows better.
        if self.plot_pt:
            axpt = self.fig.add_subplot(gs_aux[pdict["pt"], :])
            axpt.yaxis.set_label_coords(*label_coords)
        return axr, axp, axr2, axp2, axt, axpt, label_coords

    def _plot_all(self):
        """Plot every station in its own column of a single figure."""
        mt_objects = self._get_mt_objects()
        ns = len(mt_objects)
        # set figure size according to what the plot will be.
        if self.fig_size is None:
            if self.plot_num in (1, 3):
                self.fig_size = [ns * 4, 6]
            elif self.plot_num == 2:
                self.fig_size = [ns * 8, 6]
        nrows, _, hr, _ = self._get_nrows()
        gs_master = gridspec.GridSpec(nrows, ns, hspace=0.05, height_ratios=hr)
        # make a figure instance
        self.fig = plt.figure(self.fig_num, self.fig_size, dpi=self.fig_dpi)
        # plot_num 1 and 2 show the off-diagonal components in the main
        # panels, plot_num 3 adds the determinant
        main_mode = "det" if self.plot_num == 3 else "od"
        for ii, mt in enumerate(mt_objects):
            axr, axp, axr2, axp2, axt, axpt, _ = self._setup_subplots(
                gs_master,
                n_stations=ns,
                n_index=ii,
                plot_num=self.plot_num,
                hspace=0.075,
                wspace=0.09,
            )
            self._plot_resistivity(axr, mt.period, mt.Z, mode=main_mode, index=ii)
            if self.res_limits is not None:
                axr.set_ylim(self.res_limits)
            self._plot_phase(axp, mt.period, mt.Z, mode=main_mode, index=ii)
            if self.phase_limits is not None:
                axp.set_ylim(self.phase_limits)
            # diagonal components go in the right-hand panels
            if self.plot_num == 2:
                self._plot_resistivity(axr2, mt.period, mt.Z, mode="d", index=ii)
                self._plot_phase(axp2, mt.period, mt.Z, mode="d", index=ii)
            # plot tipper
            self._plot_tipper(axt, mt.period, mt.Tipper, index=ii)
            if axt is not None and self.tipper_limits is not None:
                axt.set_ylim(self.tipper_limits)
            # plot phase tensor
            self._plot_pt(axpt, mt.period, mt.pt, index=ii)
            axr.set_title(mt.station, fontsize=self.font_size, fontweight="bold")

    def _plot_compare(self):
        """Overlay all stations' responses in a single figure.

        Each station gets its own color and marker; phase tensor ellipses
        are shifted vertically by ``ellipse_size`` per station. The axes
        are stored on ``self.axr``, ``self.axp``, ``self.axr2``,
        ``self.axp2``, ``self.axt`` and ``self.axpt``.

        :raises ValueError: if ``plot_num`` is 2 (diagonal components are
            not supported here) or not one of 1, 2, 3.
        """
        if self.plot_num == 2:
            raise ValueError(
                "Compare mode does not support plotting diagonal components yet"
            )
        if self.plot_num == 1:
            # xy and yx are drawn side by side -> two-column layout
            setup_num = 2
        elif self.plot_num == 3:
            setup_num = 1
        else:
            raise ValueError(f"plot_num must be 1, 2 or 3, not {self.plot_num}")

        mt_objects = self._get_mt_objects()
        ns = len(mt_objects)
        # make color lists for the plots, one shade per station
        cxy = [(0, cc / ns, 1 - cc / ns) for cc in range(ns)]
        cyx = [(1, cc / ns, 0) for cc in range(ns)]
        cdet = [(0, 1 - cc / ns, 0) for cc in range(ns)]
        ctipr = [(0.75 * cc / ns, 0.75 * cc / ns, 0.75 * cc / ns) for cc in range(ns)]
        ctipi = [(cc / ns, 1 - cc / ns, 0.25) for cc in range(ns)]
        # make marker lists for the different components
        mxy = ["s", "D", "x", "+", "*", "1", "3", "4"] * ns
        myx = ["o", "h", "8", "p", "H", 7, 4, 6] * ns

        legend_list_xy = []
        legend_list_yx = []
        legend_list_tip = []
        station_list = []
        station_list_t = []

        # make a figure instance
        self.fig = plt.figure(self.fig_num, self.fig_size, dpi=self.fig_dpi)
        nrows, _, hr, _ = self._get_nrows()
        gs_master = gridspec.GridSpec(nrows, 1, hspace=0.15, height_ratios=hr)
        axr, axp, axr2, axp2, axt, axpt, _ = self._setup_subplots(
            gs_master, plot_num=setup_num
        )

        period = []
        for ii, mt in enumerate(mt_objects):
            period.append(mt.period.min())
            period.append(mt.period.max())
            # per-station styling picked up by the error-bar/arrow properties
            self.xy_color = cxy[ii]
            self.xy_marker = mxy[ii]
            self.yx_color = cyx[ii]
            self.yx_marker = myx[ii]
            self.det_color = cdet[ii]
            self.det_marker = mxy[ii]
            self.arrow_color_real = ctipr[ii]
            self.arrow_color_imag = ctipi[ii]

            if self.plot_num == 1:
                eb_list, _ = self._plot_resistivity(
                    axr, mt.period, mt.Z, mode="od", axr2=axr2
                )
                self._plot_phase(axp, mt.period, mt.Z, mode="od", index=ii, axp2=axp2)
            else:  # plot_num == 3: determinant only
                eb_list, _ = self._plot_resistivity(
                    axr, mt.period, mt.Z, mode="det_only"
                )
                self._plot_phase(axp, mt.period, mt.Z, mode="det_only")

            # plot tipper
            tip_list, tip_label = self._plot_tipper(axt, mt.period, mt.Tipper)
            if tip_list is not None:
                if self.plot_tipper.find("r") > 0:
                    legend_list_tip.append(tip_list[0])
                    station_list_t.append(f"{mt.station}_{tip_label[0]}")
                    if self.plot_tipper.find("i") > 0:
                        legend_list_tip.append(tip_list[1])
                        station_list_t.append(f"{mt.station}_{tip_label[1]}")
                elif self.plot_tipper.find("i") > 0:
                    legend_list_tip.append(tip_list[0])
                    station_list_t.append(f"{mt.station}_{tip_label[0]}")

            # plot phase tensor, stacked vertically per station
            self._plot_pt(
                axpt,
                mt.period,
                mt.pt,
                y_shift=ii * self.ellipse_size,
                edge_color=cxy[ii],
            )

            legend_list_xy.append(eb_list[0])
            if self.plot_num == 1:
                legend_list_yx.append(eb_list[1])
            if self.include_survey:
                station_list.append(f"{mt.station}_{mt.survey_metadata.id}")
            else:
                station_list.append(f"{mt.station}")

        # set limits
        if self.res_limits is not None:
            axr.set_ylim(self.res_limits)
            if axr2 is not None:
                axr2.set_ylim(self.res_limits)
        if self.phase_limits is not None:
            axp.set_ylim(self.phase_limits)
            if axp2 is not None:
                axp2.set_ylim(self.phase_limits)
        if axt is not None and self.tipper_limits is not None:
            axt.set_ylim(self.tipper_limits)

        period_limits = [
            10 ** np.floor(np.log10(min(period))),
            10 ** np.ceil(np.log10(max(period))),
        ]
        for ax in (axr, axp):
            if ax is not None:
                ax.set_xlim(period_limits)
        # the phase tensor axes use log10(period) x coordinates (see
        # _plot_pt); the tipper axes are assumed to do the same
        log_limits = [np.log10(period_limits[0]), np.log10(period_limits[1])]
        for ax in (axt, axpt):
            if ax is not None:
                ax.set_xlim(log_limits)

        # make legend
        legend_kwargs = {
            "loc": 3,
            "ncol": 2,
            "markerscale": 0.75,
            "borderaxespad": 0.01,
            "labelspacing": 0.07,
            "handletextpad": 0.2,
            "borderpad": 0.25,
        }
        if self.plot_num == 1:
            small_font = {"size": self.font_size - 2}
            axr.legend(legend_list_xy, station_list, prop=small_font, **legend_kwargs)
            axr2.legend(legend_list_yx, station_list, prop=small_font, **legend_kwargs)
        elif self.plot_num == 3:
            axr.legend(legend_list_xy, station_list, **legend_kwargs)
        if "y" in self.plot_tipper:
            axt.legend(legend_list_tip, station_list_t, **legend_kwargs)

        self.axr = axr
        self.axp = axp
        self.axr2 = axr2
        self.axp2 = axp2
        self.axt = axt
        self.axpt = axpt

    def _plot_single(self):
        """Plot each station in its own figure.

        :returns: dict mapping station name to the object returned by that
            station's ``plot_mt_response``.
        """
        p_dict = {}
        for ii, tf in enumerate(self._iter_mt_objects(), 1):
            p = tf.plot_mt_response(
                fig_num=ii,
                plot_tipper=self.plot_tipper,
                plot_pt=self.plot_pt,
                plot_num=self.plot_num,
            )
            p_dict[tf.station] = p
        return p_dict

    # ---plot the resistivity and phase
    def plot(self):
        """Plot the apparent resistivity and phase.

        Dispatches on ``plot_style``: 'all' draws one column per station,
        'compare' overlays all stations, and 1 / '1' / 'single' draws a
        separate figure per station.

        :returns: the dict from ``_plot_single`` for the single-figure
            style, otherwise None.
        """
        plt.clf()
        self.subplot_right = 0.98
        self._set_subplot_params()
        # Plot all in one figure as subplots
        if self.plot_style == "all":
            self.subplot_left = 0.04
            self.subplot_top = 0.96
            self._set_subplot_params()
            self._plot_all()
        # Plot all responses into one plot to compare changes
        elif self.plot_style == "compare":
            self._plot_compare()
        elif self.plot_style in [1, "1", "single"]:
            return self._plot_single()