# Source code for mtpy.imaging.plot_pseudosection

# -*- coding: utf-8 -*-
"""
Created on Thu May 30 18:39:58 2013

@author: jpeacock-pr
"""

import matplotlib.pyplot as plt

# =============================================================================
# Imports
# =============================================================================
import numpy as np
import pandas as pd
from matplotlib import ticker
from scipy import signal

from mtpy.imaging.mtplot_tools import (
    griddata_interpolate,
    PlotBaseProfile,
    triangulate_interpolation,
)

# ==============================================================================


class PlotResPhasePseudoSection(PlotBaseProfile):
    """Plot resistivity and phase pseudo sections along a profile.

    One subplot is drawn for each enabled component (``plot_xx``,
    ``plot_xy``, ``plot_yx``, ``plot_yy``, ``plot_det``) and each enabled
    quantity (``plot_resistivity``, ``plot_phase``). Resistivity panels form
    the first row and phase panels the second. Every attribute assigned in
    ``__init__`` can be overridden by passing a keyword argument of the same
    name.

    :param mt_data: MT data handed to :class:`PlotBaseProfile`.
    """

    # Components in the order they are laid out across subplot columns.
    _COMPONENTS = ("xx", "xy", "yx", "yy", "det")

    def __init__(self, mt_data, **kwargs):
        """Initialize plot parameters and draw the plot if ``show_plot``.

        :param mt_data: MT data handed to :class:`PlotBaseProfile`.
        :param kwargs: ``aspect``, ``xtickspace``, ``station_id`` (alias
            ``stationid``) and ``linedir`` are consumed here; every other
            keyword is set as an attribute after the defaults below.
        """
        super().__init__(mt_data, **kwargs)

        # --> set figure parameters
        self.aspect = kwargs.pop("aspect", "auto")
        self.xtickspace = kwargs.pop("xtickspace", 1)
        self.station_id = kwargs.pop(
            "station_id", kwargs.pop("stationid", [0, 4])
        )
        self.stationid = self.station_id
        self.linedir = kwargs.pop("linedir", "ew")

        # --> set plots to plot and how to plot them
        self.plot_xx = False
        self.plot_xy = True
        self.plot_yx = True
        self.plot_yy = False
        self.plot_det = False
        self.plot_resistivity = True
        self.plot_phase = True

        self.data_df = None
        self.n_periods = 60
        self.interpolation_method = "nearest"
        self.nearest_neighbors = 7
        self.interpolation_power = 4
        self.median_filter_kernel = None
        self.x_stretch = 1
        self.y_stretch = 1
        self.station_step = 1

        # --> set plot limits
        self.cmap_limits = {
            "res_xx": (-1, 2),
            "res_xy": (0, 3),
            "res_yx": (0, 3),
            "res_yy": (-1, 2),
            "res_det": (0, 3),
            "phase_xx": (-180, 180),
            "phase_xy": (0, 100),
            "phase_yx": (0, 100),
            "phase_yy": (-180, 180),
            "phase_det": (0, 100),
        }

        self.label_dict = {
            "res_xx": r"$\rho_{xx} \mathrm{[\Omega m]}$",
            "res_xy": r"$\rho_{xy} \mathrm{[\Omega m]}$",
            "res_yx": r"$\rho_{yx} \mathrm{[\Omega m]}$",
            "res_yy": r"$\rho_{yy} \mathrm{[\Omega m]}$",
            "res_det": r"$\rho_{det} \mathrm{[\Omega m]}$",
            "phase_xx": r"$\phi_{xx}$",
            "phase_xy": r"$\phi_{xy}$",
            "phase_yx": r"$\phi_{yx}$",
            "phase_yy": r"$\phi_{yy}$",
            "phase_det": r"$\phi_{det}$",
        }

        # --> set colormaps Note only mtcolors is supported
        self.res_cmap = plt.get_cmap("mt_rd2gr2bl")
        self.phase_cmap = plt.get_cmap("mt_bl2gr2rd")

        for key, value in kwargs.items():
            setattr(self, key, value)

        if self.show_plot:
            self.plot()

    def _get_period_array(self, df):
        """Return ``n_periods`` evenly spaced periods spanning ``df.period``.

        Both ends are multiplied by ``y_stretch`` so the array is in the same
        (stretched) units as the plotted data.
        """
        return np.linspace(
            df.period.min() * self.y_stretch,
            df.period.max() * self.y_stretch,
            self.n_periods,
        )

    def _get_n_rows(self):
        """Return the number of subplot rows (one per enabled quantity)."""
        return sum(bool(flag) for flag in (self.plot_resistivity, self.plot_phase))

    def _get_n_columns(self):
        """Return the number of subplot columns (one per enabled component)."""
        return sum(1 for cc in self._COMPONENTS if getattr(self, f"plot_{cc}"))

    def _get_n_subplots(self):
        """Map every ``res_*``/``phase_*`` key to its subplot position.

        :return: dict whose values are ``(n_rows, n_columns, index)`` tuples
            for enabled panels and ``None`` for disabled ones. Resistivity
            panels are numbered first, then phase panels.
        """
        n_rows = self._get_n_rows()
        n_columns = self._get_n_columns()
        subplot_dict = {
            f"{quantity}_{cc}": None
            for quantity in ("res", "phase")
            for cc in self._COMPONENTS
        }
        enabled = [cc for cc in self._COMPONENTS if getattr(self, f"plot_{cc}")]

        plot_num = 0
        for quantity, flag in (
            ("res", self.plot_resistivity),
            ("phase", self.plot_phase),
        ):
            if not flag:
                continue
            for cc in enabled:
                plot_num += 1
                subplot_dict[f"{quantity}_{cc}"] = (n_rows, n_columns, plot_num)
        return subplot_dict

    def _get_subplots(self):
        """Create the axes for every enabled panel in ``self.fig``.

        All axes share x and y with the first one created, so zooming one
        panel zooms them all.

        :return: dict mapping component key (e.g. ``"res_xy"``) to its axes.
        """
        subplot_dict = self._get_n_subplots()
        ax_dict = {}
        for cc in self._COMPONENTS:
            if not getattr(self, f"plot_{cc}"):
                continue
            for quantity, flag in (
                ("res", self.plot_resistivity),
                ("phase", self.plot_phase),
            ):
                if flag:
                    comp = f"{quantity}_{cc}"
                    ax_dict[comp] = self.fig.add_subplot(
                        *subplot_dict[comp], aspect=self.aspect
                    )

        axes = list(ax_dict.values())
        for ax in axes[1:]:
            ax.sharex(axes[0])
            ax.sharey(axes[0])
        return ax_dict

    @staticmethod
    def _safe_log10(value):
        """Return ``log10(value)``, or 0.0 when ``value`` is not positive.

        0.0 doubles as the missing-data marker that :meth:`plot` masks out.
        """
        value = float(value)
        if value <= 0:
            return 0.0
        return np.log10(value)

    def _get_data_df(self):
        """Collect resistivity and phase values for every station and period.

        :return: DataFrame with one row per (station, period) and columns
            ``station``, ``offset``, ``period`` (log10 of the period),
            ``res_*`` (log10 resistivity, 0.0 when missing) and ``phase_*``.
            Nonzero ``phase_yx`` values are shifted by +180 so they fall in
            the same colour range as ``phase_xy`` (see ``cmap_limits``);
            zero (missing) values are left at zero.
        """
        self._get_profile_line()
        entries = []
        for tf in self._get_mt_objects():
            offset = self._get_offset(tf)
            z_obj = tf.Z
            # read each component array once per station, not once per period
            res = {cc: getattr(z_obj, f"res_{cc}") for cc in self._COMPONENTS}
            phase = {cc: getattr(z_obj, f"phase_{cc}") for cc in self._COMPONENTS}

            for ii, period in enumerate(tf.period):
                entry = {
                    "station": tf.station,
                    "offset": offset,
                    "period": np.log10(period),
                }
                for cc in self._COMPONENTS:
                    entry[f"res_{cc}"] = self._safe_log10(res[cc][ii])
                for cc in self._COMPONENTS:
                    value = phase[cc][ii]
                    # shift a local copy only; tf.Z is left untouched
                    if cc == "yx" and value != 0:
                        value = value + 180
                    entry[f"phase_{cc}"] = value
                entries.append(entry)
        return pd.DataFrame(entries)

    def _get_offset_station(self, df):
        """Get the x-axis tick positions and station labels.

        :param df: data frame with ``station`` and ``offset`` columns.
        :return: dict with ``"station"`` (labels sliced by ``station_id``)
            and ``"offset"`` (offsets times ``x_stretch``) arrays. When
            ``station_step > 1`` only every ``station_step``-th label is kept
            and the rest are blanked.
        """
        stations = df.station.unique()
        labels = np.array(
            [station[self.station_id[0] : self.station_id[1]] for station in stations],
            dtype=str,
        )
        offsets = np.array(
            [
                df.loc[df.station == station, "offset"].unique()[0] * self.x_stretch
                for station in stations
            ]
        )

        if self.station_step > 1:
            hide = np.ones(labels.size, dtype=bool)
            hide[:: self.station_step] = False
            labels[hide] = ""

        return {"station": labels, "offset": offsets}

    def _get_cmap(self, component):
        """Return the colormap for a ``res_*`` or ``phase_*`` component.

        :raises ValueError: if ``component`` is neither.
        """
        if "res" in component:
            return self.res_cmap
        if "phase" in component:
            return self.phase_cmap
        raise ValueError(f"Unknown component {component!r}")

    def _get_colorbar(self, ax, im_mappable, component):
        """Attach a colorbar for ``component`` next to ``ax``.

        Resistivity colorbars get one tick per integer between the rounded
        ``cmap_limits``, labeled from ``period_label_dict``.

        :param ax: axes the colorbar belongs to.
        :param im_mappable: mappable returned by the plotting call.
        :param component: component key, e.g. ``"res_xy"``.
        :return: the colorbar.
        :raises ValueError: if ``component`` is neither resistivity nor phase.
        """
        if "res" in component:
            low, high = self.cmap_limits[component]
            decades = np.arange(int(np.round(low)), int(np.round(high)) + 1)
            cb = plt.colorbar(
                im_mappable,
                ax=ax,
                ticks=ticker.FixedLocator(decades),
                shrink=0.6,
                extend="both",
                label=self.label_dict[component],
            )
            labels = [self.period_label_dict[dd] for dd in decades]
            cb.ax.yaxis.set_major_formatter(ticker.FixedFormatter(labels))
        elif "phase" in component:
            cb = plt.colorbar(
                im_mappable,
                ax=ax,
                shrink=0.6,
                extend="both",
                label=self.label_dict[component],
            )
        else:
            raise ValueError(f"Unknown component {component!r}")

        cb.ax.tick_params(axis="both", which="major", labelsize=self.font_size - 1)
        cb.ax.tick_params(axis="both", which="minor", labelsize=self.font_size - 1)
        return cb

    def _plot_component(self, ax, comp, plot_periods):
        """Interpolate one component onto the plot grid and draw it on ``ax``.

        Rows where ``comp`` is zero (the missing-data marker) are excluded
        before interpolation.

        :return: the mappable to use for the colorbar.
        :raises ValueError: if ``interpolation_method`` is not supported.
        """
        cmap = self._get_cmap(comp)
        vmin, vmax = self.cmap_limits[comp]

        # keep only nonzero entries of *this* component
        comp_df = self.data_df.loc[self.data_df[comp] != 0]
        x_data = comp_df.offset * self.x_stretch
        y_data = comp_df.period * self.y_stretch
        values = comp_df[comp].to_numpy()
        x_grid = self.data_df.offset * self.x_stretch

        if self.interpolation_method in ["nearest", "linear", "cubic"]:
            x, y, image = griddata_interpolate(
                x_data,
                y_data,
                values,
                x_grid,
                plot_periods,
                self.interpolation_method,
            )
            if self.median_filter_kernel is not None:
                image = signal.medfilt2d(image, self.median_filter_kernel)
            return ax.pcolormesh(x, y, image, cmap=cmap, vmin=vmin, vmax=vmax)

        if self.interpolation_method in ["fancy", "delaunay", "triangulate"]:
            triangulation, image, indices = triangulate_interpolation(
                x_data,
                y_data,
                values,
                x_data,
                y_data,
                x_grid,
                plot_periods,
                nearest_neighbors=self.nearest_neighbors,
                interp_pow=self.interpolation_power,
            )
            return ax.tricontourf(
                triangulation,
                image,
                mask=indices,
                levels=np.linspace(vmin, vmax, 50),
                extend="both",
                cmap=cmap,
            )

        raise ValueError(
            f"Unsupported interpolation_method {self.interpolation_method!r}"
        )

    def _set_period_axis(self, ax, plot_periods):
        """Invert the period axis and put labeled ticks on whole decades.

        Ticks are placed at integer (unstretched) values that lie inside the
        plotted range, so each label from ``period_label_dict`` sits at the
        decade it names. Integers with no label get an empty string.
        """
        p_min = plot_periods.min()
        p_max = plot_periods.max()
        ax.set_ylim((p_max, p_min))

        # round first so floating-point noise does not drop an end decade
        first = np.ceil(np.round(p_min / self.y_stretch, 6))
        last = np.floor(np.round(p_max / self.y_stretch, 6))
        decades = np.arange(first, last + 1)
        ax.set_yticks(decades * self.y_stretch)

        labels = []
        for decade in decades:
            try:
                labels.append(self.period_label_dict[int(decade)])
            except KeyError:
                labels.append("")
        ax.set_yticklabels(labels)

    def _set_axis_labels(self, ax, subplot_position):
        """Label y on the first column and x on the bottom row of panels.

        :param subplot_position: ``(n_rows, n_columns, index)`` tuple from
            :meth:`_get_n_subplots`.
        """
        n_rows, n_columns, index = subplot_position
        if index == 1 or index == n_columns + 1:
            ax.set_ylabel("Period (s)", fontdict=self.font_dict)
        if n_rows == 1 or (n_rows == 2 and index > n_columns):
            ax.set_xlabel("Station", fontdict=self.font_dict)

    def plot(self):
        """Build the data frame if needed, draw every panel, and show it."""
        if self.data_df is None:
            self.data_df = self._get_data_df()

        plot_periods = self._get_period_array(self.data_df)
        plot_dict = self._get_offset_station(self.data_df)

        # set position properties for the plot
        self._set_subplot_params()

        # make figure instance and clear it if one is already up
        self.fig = plt.figure(self.fig_num, figsize=self.fig_size, dpi=self.fig_dpi)
        plt.clf()

        ax_dict = self._get_subplots()
        subplot_numbers = self._get_n_subplots()

        for comp, ax in ax_dict.items():
            im = self._plot_component(ax, comp, plot_periods)
            self._get_colorbar(ax, im, comp)

            self._set_period_axis(ax, plot_periods)

            # x axis: ticks at station offsets labeled with station names
            ax.set_xticks(plot_dict["offset"])
            ax.set_xticklabels(plot_dict["station"])

            self._set_axis_labels(ax, subplot_numbers[comp])

        self.fig.tight_layout()
        plt.show()