# Source code for mtpy.imaging.plot_resphase_maps

#!/bin/env python
"""
Description:
    Plots apparent resistivity and phase maps for a given period

References:

CreationDate:   4/19/18
Developer:      rakib.hassan@ga.gov.au

Revision History:
    LastUpdate:     4/19/18   RH
                    2022-09 JP


"""

# =============================================================================
# Imports
# =============================================================================

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import ticker

from mtpy.core import Z
from mtpy.imaging.mtplot_tools import PlotBaseMaps

# =============================================================================


class PlotResPhaseMaps(PlotBaseMaps):
    """Plot apparent resistivity and phase in map view.

    Every requested impedance component (``xx``, ``xy``, ``yx``, ``yy`` and
    ``det``) gets its own panel. Resistivity panels are numbered first and
    phase panels second, so with both enabled resistivity fills the top row
    and phase the bottom row. All values are evaluated at a single period,
    ``plot_period``.

    Arguments::

        **mt_data** : container of MT objects
                      Either an object with a ``values()`` method yielding
                      MT objects (e.g. a dict), or an MTData-like object
                      exposing ``_iter_station_paths`` and ``get_station``.

        **map_units** : [ 'deg' | 'm' | 'km' ]
                        Sets ``scale`` (applied to station elevation) and,
                        for 'm'/'km', ``cell_size``.
                        * 'deg' --> scale 1
                        * 'm'   --> scale 1, cell_size 200
                        * 'km'  --> scale 1/1000, cell_size 0.2
                        *default* is 'deg'

        **plot_period** : float
                          Period (s) at which the maps are drawn.
                          *default* is 1

        **plot_xx, plot_xy, plot_yx, plot_yy, plot_det** : bool
                          Components to plot. *default* is xy and yx only.

        **plot_resistivity, plot_phase** : bool
                          Draw the resistivity and/or phase panels.
                          *default* is True for both

        **res_cmap, phase_cmap** : str
                          Colormap names. *default* 'rainbow_r' / 'rainbow'

        **cmap_limits** : dict
                          (min, max) colour limits keyed by panel name,
                          e.g. ``"res_xy"`` or ``"phase_yx"``.

        **plot_stations** : bool
                          Draw station markers. *default* is True

        **marker_color, marker_size** : station marker style.
                          *default* is 'k' and 10

        **use_mt_data_preinterpolation** : bool
                          Interpolate the whole MTData object to
                          ``plot_period`` once (cached) when it supports
                          ``interpolate``. *default* is True

    Any keyword argument is set as an attribute after the defaults above,
    so it overrides them. Attributes such as ``fig_num``, ``fig_size``,
    ``fig_dpi``, ``font_size``, ``font_dict``, ``marker``,
    ``interpolation_method``, ``show_plot`` and ``logger`` are used here but
    are presumably provided by ``PlotBaseMaps`` -- confirm there.
    """

    # impedance components, in the order they are laid out across columns
    _components = ("xx", "xy", "yx", "yy", "det")

    def __init__(self, mt_data, **kwargs):
        """Initialise the object.

        :param mt_data: Container of MT objects to map (see class docstring).
        :param **kwargs: Keyword-value pairs; passed to the base class and
            then set as attributes, overriding the defaults below.
        """
        super().__init__(**kwargs)

        self.mt_data = mt_data

        # Cache of the MTData-level interpolation. It is reused only while
        # both the period and the mt_data object are unchanged.
        self.use_mt_data_preinterpolation = True
        self._interpolated_mt_data_cache = None
        self._interpolated_mt_data_cache_period = None
        self._interpolated_mt_data_cache_source = None

        # the map_units setter also sets scale (and cell_size for m/km)
        self.map_units = "deg"

        self.res_cmap = "rainbow_r"
        self.phase_cmap = "rainbow"
        self.plot_period = 1

        self.plot_xx = False
        self.plot_xy = True
        self.plot_yx = True
        self.plot_yy = False
        self.plot_det = False

        self.plot_resistivity = True
        self.plot_phase = True

        self.plot_stations = True
        self.marker_color = "k"
        self.marker_size = 10

        self.cmap_limits = {
            "res_xx": (-1, 2),
            "res_xy": (0, 3),
            "res_yx": (0, 3),
            "res_yy": (-1, 2),
            "res_det": (0, 3),
            "phase_xx": (-180, 180),
            "phase_xy": (0, 100),
            "phase_yx": (0, 100),
            "phase_yy": (-180, 180),
            "phase_det": (0, 100),
        }

        self.label_dict = {
            "res_xx": r"$\rho_{xx} \mathrm{[\Omega m]}$",
            "res_xy": r"$\rho_{xy} \mathrm{[\Omega m]}$",
            "res_yx": r"$\rho_{yx} \mathrm{[\Omega m]}$",
            "res_yy": r"$\rho_{yy} \mathrm{[\Omega m]}$",
            "res_det": r"$\rho_{det} \mathrm{[\Omega m]}$",
            "phase_xx": r"$\phi_{xx}$",
            "phase_xy": r"$\phi_{xy}$",
            "phase_yx": r"$\phi_{yx}$",
            "phase_yy": r"$\phi_{yy}$",
            "phase_det": r"$\phi_{det}$",
        }

        for key, value in kwargs.items():
            setattr(self, key, value)

        if self.show_plot:
            self.plot()

    @property
    def map_units(self):
        """Units of the map coordinates ('deg', 'm' or 'km')."""
        return self._map_units

    @map_units.setter
    def map_units(self, value):
        """Set the map units and the matching ``scale`` / ``cell_size``."""
        self._map_units = value
        if value in ["km"]:
            self.scale = 1.0 / 1000
            self.cell_size = 0.2
        # elif (not a separate if): otherwise the final else would reset the
        # km scale back to 1.0
        elif value in ["m"]:
            self.scale = 1
            self.cell_size = 200
        else:
            self.scale = 1.0

    def _iter_mt_objects(self, source=None):
        """Yield MT objects from a supported container.

        :param source: Container to read; defaults to ``self.mt_data``.
        :raises TypeError: If the container has neither ``values()`` nor
            MTData-style station access.
        """
        if source is None:
            source = self.mt_data

        if hasattr(source, "values"):
            yield from source.values()
            return

        if hasattr(source, "_iter_station_paths") and hasattr(
            source, "get_station"
        ):
            if hasattr(source, "compute"):
                source.compute()
            for station_path in source._iter_station_paths():
                yield source.get_station(station_path, as_mt=True)
            return

        raise TypeError("mt_data must provide values() or MTData-style station access")

    def _get_mt_objects(self):
        """Return MT objects for the active plot period as a list.

        A list is returned because the objects are iterated more than once.
        """
        data_source = self._get_mt_data_for_plot_period()
        supported = hasattr(data_source, "values") or (
            hasattr(data_source, "_iter_station_paths")
            and hasattr(data_source, "get_station")
        )
        if not supported:
            # fall back to the original container; _iter_mt_objects raises
            # TypeError if that one is unsupported as well
            data_source = self.mt_data
        return list(self._iter_mt_objects(data_source))

    def _get_mt_data_for_plot_period(self):
        """Return mt_data interpolated to ``plot_period`` when supported.

        Falls back to ``self.mt_data`` unchanged when pre-interpolation is
        disabled, when mt_data lacks ``interpolate`` or station access, or
        when interpolation fails or returns None. In those cases each
        station is interpolated individually later.

        NOTE(review): the cache is keyed on the identity of ``mt_data``, so
        in-place mutation of the same object is not detected.
        """
        if not self.use_mt_data_preinterpolation:
            return self.mt_data

        if not (
            hasattr(self.mt_data, "interpolate")
            and hasattr(self.mt_data, "_iter_station_paths")
            and hasattr(self.mt_data, "get_station")
        ):
            return self.mt_data

        target_period = float(self.plot_period)
        if (
            self._interpolated_mt_data_cache is not None
            and self._interpolated_mt_data_cache_source is self.mt_data
            and self._interpolated_mt_data_cache_period is not None
            and np.isclose(self._interpolated_mt_data_cache_period, target_period)
        ):
            return self._interpolated_mt_data_cache

        try:
            interpolated = self.mt_data.interpolate(
                np.array([target_period], dtype=float),
                inplace=False,
                bounds_error=False,
            )
        except Exception as error:
            # deliberate best effort: per-station interpolation still works
            self.logger.debug(
                "Falling back to per-station interpolation for plot period "
                f"{target_period}: {error}"
            )
            return self.mt_data

        if interpolated is None:
            return self.mt_data

        self._interpolated_mt_data_cache = interpolated
        self._interpolated_mt_data_cache_period = target_period
        self._interpolated_mt_data_cache_source = self.mt_data
        return interpolated

    def _get_n_rows(self):
        """Return the number of subplot rows (one each for res and phase)."""
        return sum(1 for flag in (self.plot_resistivity, self.plot_phase) if flag)

    def _get_n_columns(self):
        """Return the number of subplot columns (one per plotted component)."""
        return sum(1 for cc in self._components if getattr(self, f"plot_{cc}"))

    def _get_n_subplots(self):
        """Map each panel key to its ``(n_rows, n_cols, index)`` subplot spec.

        Resistivity panels are numbered before phase panels. Panels that
        are not plotted map to None.
        """
        nr = self._get_n_rows()
        nc = self._get_n_columns()

        subplot_dict = {
            f"{kind}_{cc}": None
            for kind in ("res", "phase")
            for cc in self._components
        }

        plot_num = 0
        for kind, enabled in (
            ("res", self.plot_resistivity),
            ("phase", self.plot_phase),
        ):
            if not enabled:
                continue
            for cc in self._components:
                if getattr(self, f"plot_{cc}"):
                    plot_num += 1
                    subplot_dict[f"{kind}_{cc}"] = (nr, nc, plot_num)

        return subplot_dict

    def _get_subplots(self):
        """Create one axes per plotted panel and link their x/y limits.

        :return: Dict mapping panel keys (e.g. ``"res_xy"``) to axes.
        """
        subplot_dict = self._get_n_subplots()

        ax_dict = {}
        for cc in self._components:
            if not getattr(self, f"plot_{cc}"):
                continue
            for kind, enabled in (
                ("res", self.plot_resistivity),
                ("phase", self.plot_phase),
            ):
                if enabled:
                    comp = f"{kind}_{cc}"
                    ax_dict[comp] = self.fig.add_subplot(
                        *subplot_dict[comp], aspect="equal"
                    )

        # share x and y across all subplots for easier zooming
        axes = list(ax_dict.values())
        for ax in axes[1:]:
            ax.sharex(axes[0])
            ax.sharey(axes[0])

        return ax_dict

    def _get_data_array(self):
        """Build a structured array of station positions and panel values.

        Stations whose impedance cannot be interpolated to ``plot_period``
        are logged and left as all-zero rows. ``plot`` drops zeros per panel.

        :return: numpy structured array with one row per MT object.
        """
        mt_objects = self._get_mt_objects()

        dtype = [
            ("station", "U20"),
            ("latitude", float),
            ("longitude", float),
            ("elevation", float),
        ]
        dtype += [
            (f"{kind}_{cc}", float)
            for kind in ("res", "phase")
            for cc in self._components
        ]
        plot_array = np.zeros(len(mt_objects), dtype=dtype)

        for ii, tf in enumerate(mt_objects):
            try:
                z = self._get_interpolated_z(tf)
            except ValueError:
                self.logger.warning(
                    f"Could not interpolate period {self.plot_period} for station {tf.station}"
                )
                continue
            z_object = Z(z, frequency=[1.0 / self.plot_period])

            plot_array["station"][ii] = tf.station
            plot_array["latitude"][ii] = tf.latitude
            plot_array["longitude"][ii] = tf.longitude
            if tf.elevation is not None:
                plot_array["elevation"][ii] = tf.elevation * self.scale

            for cc in self._components:
                plot_array[f"res_{cc}"][ii] = getattr(z_object, f"res_{cc}")[0]

            for cc in self._components:
                phase = getattr(z_object, f"phase_{cc}")[0]
                if cc != "yx":
                    plot_array[f"phase_{cc}"][ii] = phase
                elif phase != 0:
                    # shift yx by 180 deg, presumably so it shares the xy
                    # colour range (both limited to 0-100 in cmap_limits);
                    # 0 stays 0 so plot() still treats it as "no data"
                    plot_array["phase_yx"][ii] = phase + 180

        return plot_array

    def _get_cmap(self, component):
        """Return the colormap name for a resistivity or phase panel.

        :raises ValueError: If ``component`` is neither a res nor a phase key.
        """
        if "res" in component:
            return self.res_cmap
        if "phase" in component:
            return self.phase_cmap
        raise ValueError(f"Unknown component {component!r}")

    def _get_colorbar(self, ax, im_mappable, component):
        """Attach a colorbar for ``component`` next to ``ax``.

        :param ax: Axes the colorbar belongs to.
        :param im_mappable: Mappable returned by pcolormesh/tricontourf.
        :param component: Panel key such as ``"res_xy"`` or ``"phase_yx"``.
        :return: The matplotlib colorbar.
        :raises ValueError: If ``component`` is neither a res nor a phase key.
        """
        if "res" in component:
            # one tick per integer between the rounded limits (presumably
            # log10 resistivity), labelled through period_label_dict, which
            # is defined outside this class
            tick_values = np.arange(
                int(np.round(self.cmap_limits[component][0])),
                int(np.round(self.cmap_limits[component][1])) + 1,
            )
            cb = plt.colorbar(
                im_mappable,
                ax=ax,
                ticks=ticker.FixedLocator(tick_values),
                shrink=0.6,
                extend="both",
            )
            labels = [self.period_label_dict[dd] for dd in tick_values]
            cb.ax.yaxis.set_major_formatter(ticker.FixedFormatter(labels))
        elif "phase" in component:
            cb = plt.colorbar(im_mappable, ax=ax, shrink=0.6, extend="both")
        else:
            raise ValueError(f"Unknown component {component!r}")

        cb.ax.tick_params(axis="both", which="major", labelsize=self.font_size - 1)
        cb.ax.tick_params(axis="both", which="minor", labelsize=self.font_size - 1)

        return cb

    # -----------------------------------------------
    # The main plot method for this module
    # -----------------------------------------------
    def plot(self):
        """Plot the resistivity and/or phase maps at ``plot_period``.

        Creates (or clears) figure ``fig_num`` and interpolates each panel
        with ``interpolate_to_map``. Grid methods ('nearest', 'linear',
        'cubic') are drawn with ``pcolormesh``; triangulation methods
        ('fancy', 'delaunay', 'triangulate') with ``tricontourf``.

        :raises ValueError: If ``interpolation_method`` is not one of the
            methods listed above.
        """
        grid_methods = ("nearest", "linear", "cubic")
        tri_methods = ("fancy", "delaunay", "triangulate")

        # set position properties for the plot
        self._set_subplot_params()

        # make figure instance and clear it if there is already one up
        self.fig = plt.figure(self.fig_num, figsize=self.fig_size, dpi=self.fig_dpi)
        plt.clf()

        ax_dict = self._get_subplots()
        data_array = self._get_data_array()
        subplot_numbers = self._get_n_subplots()

        for comp, ax in ax_dict.items():
            cmap = self._get_cmap(comp)
            vmin, vmax = self.cmap_limits[comp][0], self.cmap_limits[comp][1]

            # drop stations with a zero value for this panel (includes rows
            # left empty because interpolation failed)
            plot_array = data_array[np.nonzero(data_array[comp])]

            if self.interpolation_method in grid_methods:
                x, y, image = self.interpolate_to_map(plot_array, comp)
                im = ax.pcolormesh(x, y, image, cmap=cmap, vmin=vmin, vmax=vmax)
            elif self.interpolation_method in tri_methods:
                triangulation, image, _indices = self.interpolate_to_map(
                    plot_array,
                    comp,
                )
                im = ax.tricontourf(
                    triangulation,
                    image,
                    levels=np.linspace(vmin, vmax, 50),
                    extend="both",
                    cmap=cmap,
                )
            else:
                raise ValueError(
                    f"Unsupported interpolation_method {self.interpolation_method!r}; "
                    f"expected one of {grid_methods + tri_methods}"
                )

            self._get_colorbar(ax, im, comp)

            # show stations
            if self.plot_stations:
                ax.scatter(
                    plot_array["longitude"],
                    plot_array["latitude"],
                    marker=self.marker,
                    s=self.marker_size,
                    c=self.marker_color,
                )

            # label the panel
            ax.text(
                0.01,
                0.9,
                self.label_dict[comp],
                fontdict={"size": self.font_size + 2},
                transform=ax.transAxes,
                bbox=dict(facecolor="white", alpha=0.5),
            )

            n_rows, n_cols, index = subplot_numbers[comp]
            # y label on the first panel of each row
            if index == 1 or index == n_cols + 1:
                ax.set_ylabel("Latitude (deg)", fontdict=self.font_dict)
            # x label only on the bottom row
            if n_rows == 1 or (n_rows == 2 and index > n_cols):
                ax.set_xlabel("Longitude (deg)", fontdict=self.font_dict)

        # Plot title
        self.fig.suptitle(f"Plot Period: {self.plot_period:.5g} s", y=0.985)