# Source code for mtpy.processing.aurora.process_aurora

# -*- coding: utf-8 -*-
"""
Created on Tue Jul 30 17:11:42 2024

@author: jpeacock
"""

# =============================================================================
# Imports
# =============================================================================
from __future__ import annotations

import warnings

import numpy as np
import pandas as pd
from aurora.config.config_creator import ConfigCreator
from aurora.config.metadata import Processing
from aurora.pipelines.process_mth5 import process_mth5
from loguru import logger
from mt_metadata.features import StridingWindowCoherence
from mt_metadata.features.weights import (
    ChannelWeightSpec,
    FeatureWeightSpec,
    TaperMonotonicWeightKernel,
)
from mth5.helpers import close_open_files
from mth5.mth5 import MTH5
from mth5.processing.kernel_dataset import KernelDataset

from mtpy import MT
from mtpy.processing.base import BaseProcessing

# NOTE: executed at import time; installs an "ignore" filter for every warning
# category, so importing this module silences warnings for the whole process.
warnings.filterwarnings("ignore")
# =============================================================================


class AuroraProcessing(BaseProcessing):
    """
    Convenience class to process with Aurora.

    .. code-block:: python

        from mtpy.processing.aurora.process_aurora import AuroraProcessing

        ap = AuroraProcessing()

        # set local station and path to MTH5
        ap.local_station_id = "mt01"
        ap.local_mth5_path = "/path/to/local_mth5.h5"

        # set remote station and path to MTH5
        ap.remote_station_id = "rr01"
        ap.remote_mth5_path = "/path/to/remote_mth5.h5"

        # process single sample rate
        tf_obj = ap.process_single_sample_rate(sample_rate=1)

        # process multiple sample rates, merge them all together and
        # save transfer functions to the local MTH5
        tf_processed_dict = ap.process(
            sample_rates=[4096, 1],
            merge=True,
            save_to_mth5=True
        )

    """

    def __init__(self, **kwargs) -> None:
        """
        Initialize AuroraProcessing with default merge and window parameters.

        Parameters
        ----------
        **kwargs : dict
            Additional keyword arguments passed to BaseProcessing.
        """
        # Period limits used when merging transfer functions, keyed by sample
        # rate.
        # NOTE(review): the lower limits for 50, 150 and 256 (15, 30, 64) do
        # not follow the roughly 4 / sample_rate pattern of the other entries;
        # they may have been meant as 1/15, 1/30 and 1/64 -- confirm.
        self.merge_dictionary = {
            1: {"period_min": 4, "period_max": 30000},
            4: {"period_min": 1, "period_max": 30000},
            50: {"period_min": 15, "period_max": 10000},
            150: {"period_min": 30, "period_max": 3000},
            256: {"period_min": 64, "period_max": 100},
            1024: {"period_min": 1.0 / 256, "period_max": 1.0 / 2.6},
            4096: {"period_min": 1.0 / 1024, "period_max": 1.0 / 26},
            24000: {"period_min": 1.0 / 6000, "period_max": 1.0 / 187.5},
        }

        # STFT window defaults; "high" is used for sample rates above 1000,
        # "low" otherwise (see create_config).
        self.default_window_parameters = {
            "high": {
                "stft.window.overlap": 256,
                "stft.window.num_samples": 1024,
                "stft.window.type": "dpss",
                "stft.window.additional_args": {"alpha": 2.5},
            },
            "low": {
                "stft.window.overlap": 64,
                "stft.window.num_samples": 128,
                "stft.window.type": "dpss",
                "stft.window.additional_args": {"alpha": 2.5},
            },
        }

        # keys every entry of a ``processing_dict`` must have
        self._processing_dict_keys = ["config", "kernel_dataset"]

        super().__init__(**kwargs)

    def _get_merge_df(self) -> pd.DataFrame:
        """
        Get a DataFrame containing the periods to use for each sample rate.

        Returns
        -------
        pd.DataFrame
            DataFrame with columns: sample_rate, period_min, period_max.
        """
        limits = list(self.merge_dictionary.values())
        return pd.DataFrame(
            {
                "sample_rate": list(self.merge_dictionary),
                "period_min": [entry["period_min"] for entry in limits],
                "period_max": [entry["period_max"] for entry in limits],
            }
        )

    def add_simple_coherence_weights(
        self,
        channel_list=None,
        **kwargs,
    ) -> list[ChannelWeightSpec]:
        """
        Add coherence weights using the channel weight spec.

        Parameters
        ----------
        channel_list : sequence of (str, [str, str]) or None, optional
            One entry per output channel: the output channel name followed by
            the pair of channels whose coherence is used as the feature.
            By default None, which uses
            ``[("ex", ["ex", "hy"]), ("ey", ["ey", "hx"]), ("hz", ["hx", "hx"])]``.
        **kwargs : dict
            Optional ``transition_lower_bound`` (default 0.6) and
            ``transition_upper_bound`` (default 0.9) for the taper kernel.

        Returns
        -------
        list[ChannelWeightSpec]
            List of channel weight specifications with coherence features.
        """
        if channel_list is None:
            # built per call so no mutable object is shared between calls
            channel_list = [
                ("ex", ["ex", "hy"]),
                ("ey", ["ey", "hx"]),
                ("hz", ["hx", "hx"]),
            ]

        lower_bound = kwargs.get("transition_lower_bound", 0.6)
        upper_bound = kwargs.get("transition_upper_bound", 0.9)

        channel_weight_specs = []
        for output_channel, (channel_1, channel_2) in channel_list:
            station_1 = self.local_station_id
            station_2 = self.local_station_id
            # the vertical magnetic channel is compared against the remote
            # station instead of the local one
            if output_channel in ["hz", "h3"]:
                station_2 = self.remote_station_id

            # time domain coherence estimation; the window is set to the
            # stft window internally
            coherence = StridingWindowCoherence(
                channel_1=channel_1,
                channel_2=channel_2,
                station_1=station_1,
                station_2=station_2,
            )
            # how to weight the coherence, could be a list of different tapers
            taper = TaperMonotonicWeightKernel(
                style="taper",
                half_window_style="hann",
                threshold="low cut",
                transition_lower_bound=lower_bound,
                transition_upper_bound=upper_bound,
            )
            channel_weight_specs.append(
                ChannelWeightSpec(
                    combination_style="multiplication",
                    output_channels=[output_channel],
                    feature_weight_specs=[
                        FeatureWeightSpec(
                            feature_name="coherence",
                            feature=coherence,
                            weight_kernels=[taper],
                        )
                    ],
                )
            )

        return channel_weight_specs

    def create_config(
        self,
        kernel_dataset: KernelDataset | None = None,
        decimation_kwargs: dict | None = None,
        add_coherence_weights: bool = False,
        **kwargs,
    ) -> Processing:
        """
        Create Aurora processing configuration.

        Parameters
        ----------
        kernel_dataset : KernelDataset or None, optional
            Kernel dataset defining processing runs, by default None, in which
            case this object's own kernel dataset is used.
        decimation_kwargs : dict or None, optional
            Decimation parameters including window settings, by default None.
            The default window parameters for the sample rate are used as a
            base and any key given here takes precedence.  The input
            dictionary is not modified.
        add_coherence_weights : bool, optional
            Whether to add coherence-based weights, by default False.
        **kwargs : dict
            Additional configuration parameters passed to
            ``ConfigCreator.create_from_kernel_dataset``.
            ``num_samples_window`` is always set from the window settings.

        Returns
        -------
        Processing
            Aurora configuration object.

        Raises
        ------
        ValueError
            If kernel_dataset is None and no kernel dataset exists.
        """
        if kernel_dataset is None:
            if not self.has_kernel_dataset():
                raise ValueError(
                    "Cannot make config because KernelDataset has not been set yet."
                )
            # this object carries the kernel dataset itself (it provides
            # sample_rate, from_run_summary and clone), so use it rather than
            # forwarding None to the config creator
            kernel_dataset = self

        window_key = "high" if kernel_dataset.sample_rate > 1000 else "low"

        # start from a copy of the defaults so that neither the defaults nor
        # the caller's dictionary are modified, then let user values win
        level_parameters = dict(self.default_window_parameters[window_key])
        if decimation_kwargs:
            level_parameters.update(decimation_kwargs)

        # need to pass the number of samples in the window to correctly set
        # the bands
        kwargs["num_samples_window"] = level_parameters["stft.window.num_samples"]

        config = ConfigCreator().create_from_kernel_dataset(kernel_dataset, **kwargs)
        self._set_decimation_level_parameters(
            config,
            add_coherence_weights=add_coherence_weights,
            **level_parameters,
        )
        return config

    def _set_decimation_level_parameters(
        self, config: Processing, add_coherence_weights: bool = False, **kwargs
    ) -> None:
        """
        Set decimation level parameters for all decimation bands.

        Parameters
        ----------
        config : Processing
            Aurora configuration object to modify.
        add_coherence_weights : bool, optional
            Whether to add coherence-based channel weights, by default False.
        **kwargs : dict
            Key-value pairs to update in each decimation level.
        """
        for decimation in config.decimations:
            for key, value in kwargs.items():
                decimation.update_attribute(key, value)
            if add_coherence_weights:
                # built inside the loop so every level gets its own objects
                decimation.channel_weight_specs = (
                    self.add_simple_coherence_weights()
                )

    def _initialize_kernel_dataset(self, sample_rate: float | None = None) -> None:
        """
        Initialize a kernel dataset.

        Parameters
        ----------
        sample_rate : float or None, optional
            Sample rate to use, by default None (uses first available).
        """
        if not self.has_run_summary():
            self.run_summary = self.get_run_summary()

        if sample_rate is None:
            # have to use a single sample rate otherwise an error is thrown.
            sample_rate = self.run_summary.df.sample_rate.unique()[0]

        self.from_run_summary(self.run_summary.set_sample_rate(sample_rate))

    def create_kernel_dataset(
        self,
        run_summary: RunSummary | None = None,
        local_station_id: str | None = None,
        remote_station_id: str | None = None,
        sample_rate: float | None = None,
    ) -> KernelDataset:
        """
        Build and return a KernelDataset.

        Parameters
        ----------
        run_summary : RunSummary or None, optional
            Run summary to use, by default None (uses the existing run summary
            or creates one with ``get_run_summary``).
        local_station_id : str or None, optional
            Local station identifier, by default None.
        remote_station_id : str or None, optional
            Remote reference station identifier, by default None.
        sample_rate : float or None, optional
            Sample rate to filter runs, by default None.

        Returns
        -------
        KernelDataset
            Clone of this object after it was populated from the run summary.
        """
        if run_summary is None:
            if self.has_run_summary():
                run_summary = self.run_summary
            else:
                run_summary = self.get_run_summary()

        if sample_rate is not None:
            run_summary = run_summary.set_sample_rate(sample_rate)

        self.from_run_summary(
            run_summary,
            local_station_id=local_station_id,
            remote_station_id=remote_station_id,
            sample_rate=sample_rate,
        )
        return self.clone()

    def process_single_sample_rate(
        self,
        sample_rate: float,
        config: Processing | None = None,
        kernel_dataset: KernelDataset | None = None,
        plot: bool = False,
    ) -> MT | None:
        """
        Process a single sample rate to generate transfer functions.

        Parameters
        ----------
        sample_rate : float
            Sample rate of time series data to process.
        config : Processing or None, optional
            Aurora configuration object, by default None (creates from
            kernel_dataset).
        kernel_dataset : KernelDataset or None, optional
            Kernel dataset defining processing runs, by default None
            (creates from run summary).
        plot : bool, optional
            Passed to ``process_mth5`` as ``show_plot``, by default False.

        Returns
        -------
        MT or None
            Transfer function object, or None if processing fails.
        """
        if kernel_dataset is None:
            kernel_dataset = self.create_kernel_dataset(
                local_station_id=self.local_station_id,
                remote_station_id=self.remote_station_id,
                sample_rate=sample_rate,
            )
        if config is None:
            config = self.create_config(kernel_dataset=kernel_dataset)

        try:
            tf_obj = process_mth5(config, kernel_dataset, show_plot=plot)
        except Exception as error:
            # make sure no file handle is left open before giving up on this
            # sample rate
            close_open_files()
            logger.exception(error)
            logger.error(f"Skipping sample_rate {sample_rate}")
            return None

        tf_obj.tf_id = self.processing_id

        # Move the result into a fresh MT object: metadata is transferred with
        # ``update`` while the transfer function dataset is assigned by
        # reference (it is not copied).
        mt_obj = MT()
        mt_obj.survey_metadata.update(tf_obj.survey_metadata)
        mt_obj.station_metadata.update(tf_obj.station_metadata)
        mt_obj.channel_nomenclature = tf_obj.channel_nomenclature
        mt_obj._transfer_function = tf_obj._transfer_function

        return mt_obj

    @staticmethod
    def _normalize_sample_rates(sample_rates) -> list:
        """Return ``sample_rates`` as a list, accepting a scalar or a sequence."""
        if isinstance(sample_rates, (int, float, np.integer, np.floating)):
            return [sample_rates]
        if isinstance(sample_rates, (list, tuple, np.ndarray)):
            return list(sample_rates)
        raise TypeError(
            "Sample rates are incorrect type. Expected an int or "
            f"list not {type(sample_rates)}"
        )

    def process(
        self,
        sample_rates: float | list[float] | None = None,
        processing_dict: (
            dict[float, dict[str, Processing | KernelDataset]] | None
        ) = None,
        merge: bool = True,
        save_to_mth5: bool = True,
        plot: bool = False,
    ) -> dict[float | str, dict[str, bool | MT]]:
        """
        Process magnetotelluric data at multiple sample rates.

        Parameters
        ----------
        sample_rates : float, list of float, or None, optional
            Sample rate(s) to process, by default None.  Ignored when
            processing_dict is given.
        processing_dict : dict or None, optional
            Dictionary mapping sample rates to config and kernel_dataset.
            Format: {sample_rate: {'config': Processing,
            'kernel_dataset': KernelDataset}}
            By default None.
        merge : bool, optional
            Whether to merge all sample rates into a single transfer function
            according to merge_dict, by default True.
        save_to_mth5 : bool, optional
            Whether to save transfer functions to local MTH5 file,
            by default True.
        plot : bool, optional
            Forwarded to process_single_sample_rate for every sample rate,
            by default False.

        Returns
        -------
        dict[float or str, dict[str, bool or MT]]
            Dictionary with the successfully processed sample rates and,
            where applicable, 'combined' as keys, each containing
            {'processed': bool, 'tf': MT or None}.

        Raises
        ------
        ValueError
            If neither sample_rates nor processing_dict is provided, or if no
            sample rate was processed successfully.
        TypeError
            If sample_rates or processing_dict is not the correct format.

        Notes
        -----
        If merge is True and multiple sample rates are processed, a 'combined'
        key is added with the merged transfer function.  If only one sample
        rate was processed, 'combined' refers to that single result.

        Examples
        --------
        >>> ap = AuroraProcessing()
        >>> ap.local_station_id = "mt01"
        >>> ap.local_mth5_path = "data.h5"
        >>> results = ap.process(sample_rates=[1, 4], merge=True)
        """
        if sample_rates is None and processing_dict is None:
            raise ValueError("Must set either sample rates or processing_dict.")

        # Reduce both input styles to {sample_rate: extra keyword arguments}
        # so that every sample rate goes through exactly the same code path.
        if processing_dict is None:
            jobs = {sr: {} for sr in self._normalize_sample_rates(sample_rates)}
        else:
            self._validate_processing_dict(processing_dict)
            jobs = {
                sr: {
                    "config": pdict["config"],
                    "kernel_dataset": pdict["kernel_dataset"],
                }
                for sr, pdict in processing_dict.items()
            }

        tf_processed = {sr: {"processed": False, "tf": None} for sr in jobs}
        for sr, job_kwargs in jobs.items():
            logger.info(f"Processing sample rate {sr}.")
            try:
                mt_obj = self.process_single_sample_rate(
                    sr, plot=plot, **job_kwargs
                )
            except Exception as error:
                # one failing sample rate must not abort the others
                logger.exception(error)
                logger.error(f"Skipping sample rate {sr}")
                continue
            if mt_obj is not None:
                tf_processed[sr]["processed"] = mt_obj.has_transfer_function()
                tf_processed[sr]["tf"] = mt_obj
                logger.info(f"Finished processing sample rate {sr}.")

        processed = self._validate_tf_processed_dict(tf_processed)
        if len(processed) > 1:
            if merge:
                # merge transfer functions according to merge dict
                combined_tf = self.merge_transfer_functions(processed)
                combined_tf_id = self.local_station_id
                if self.remote_station_id:
                    combined_tf_id += f"_rr_{self.remote_station_id}"
                combined_tf_id += "_combined"
                combined_tf.tf_id = combined_tf_id
                processed["combined"] = {"processed": True, "tf": combined_tf}
        else:
            # only one result: expose it under "combined" as well
            processed["combined"] = next(iter(processed.values()))

        if save_to_mth5:
            # add tf to local MTH5
            self._add_tf_to_local_mth5(processed)

        return processed

    def _validate_config(self, config: Processing) -> None:
        """
        Validate configuration object type.

        Parameters
        ----------
        config : Processing
            Configuration object to validate.

        Raises
        ------
        TypeError
            If config is not a Processing object.
        """
        if not isinstance(config, Processing):
            raise TypeError(
                "Config must be a aurora.config.metadata.Processing object. "
                f"Got type {type(config)}"
            )

    def _validate_kernel_dataset(self, kernel_dataset: KernelDataset) -> None:
        """
        Validate kernel dataset object type.

        Parameters
        ----------
        kernel_dataset : KernelDataset
            Kernel dataset object to validate.

        Raises
        ------
        TypeError
            If kernel_dataset is not a KernelDataset object.
        """
        if not isinstance(kernel_dataset, KernelDataset):
            raise TypeError(
                "Kernel dataset must be a "
                "mth5.processing.kernel_dataset.KernelDataset object. "
                f"Got type {type(kernel_dataset)}"
            )

    def _validate_processing_dict(
        self, processing_dict: dict[float, dict[str, Processing | KernelDataset]]
    ) -> None:
        """
        Validate the processing dictionary format.

        Parameters
        ----------
        processing_dict : dict
            Dictionary mapping sample rates to config and kernel_dataset.

        Raises
        ------
        TypeError
            If processing_dict or its values are not dictionaries, or if a
            config / kernel_dataset entry has the wrong type.
        KeyError
            If an entry does not have exactly the required keys.
        """
        # parenthesized so both literals are guaranteed to be one message
        error_msg = (
            "Format is {sample_rate: {'config': config object, "
            "'kernel_dataset': KernelDataset object}}"
        )
        if not isinstance(processing_dict, dict):
            raise TypeError(
                "Input processing_dict must be a dictionary. "
                f"Got type {type(processing_dict)}."
            )

        required_keys = set(self._processing_dict_keys)
        for pdict in processing_dict.values():
            if not isinstance(pdict, dict):
                raise TypeError(
                    "Input processing_dict must be a dictionary. "
                    f"Got type {type(pdict)}. " + error_msg
                )
            # set comparison: order independent and safe for keys that cannot
            # be sorted against each other
            if set(pdict) != required_keys:
                raise KeyError(
                    "Processing dict can only have keys "
                    f"{self._processing_dict_keys}. " + error_msg
                )
            self._validate_config(pdict["config"])
            self._validate_kernel_dataset(pdict["kernel_dataset"])

    def _validate_tf_processed_dict(
        self, tf_dict: dict[float, dict[str, bool | MT]]
    ) -> dict[float, dict[str, bool | MT]]:
        """
        Filter processed transfer functions from processing dictionary.

        Parameters
        ----------
        tf_dict : dict
            Dictionary of processed transfer functions with format:
            {sample_rate: {'processed': bool, 'tf': MT or None}}.

        Returns
        -------
        dict
            Dictionary containing only successfully processed transfer
            functions.

        Raises
        ------
        ValueError
            If no transfer functions were processed successfully.
        """
        new_dict = {}
        for key, p_dict in tf_dict.items():
            if p_dict["processed"]:
                new_dict[key] = p_dict
            else:
                logger.warning(
                    f"Sample rate {key} was not processed correctly. Check log."
                )

        if not new_dict:
            raise ValueError("No Transfer Functions were processed.")
        return new_dict

    def _add_tf_to_local_mth5(
        self, tf_dict: dict[float | str, dict[str, bool | MT]]
    ) -> None:
        """
        Add transfer functions to the local MTH5 file.

        Parameters
        ----------
        tf_dict : dict
            Dictionary of transfer functions to add.  An object that appears
            under several keys (a single result is also stored as 'combined')
            is written only once.
        """
        written = set()
        with MTH5() as m:
            m.open_mth5(self.local_mth5_path)
            for p_dict in tf_dict.values():
                tf = p_dict["tf"]
                if id(tf) in written:
                    continue
                written.add(id(tf))
                m.add_transfer_function(tf)

    def _get_merge_tf_list(
        self, tf_dict: dict[float, dict[str, bool | MT]]
    ) -> list[dict[str, MT | float]]:
        """
        Prepare transfer functions list for merging with period constraints.

        Parameters
        ----------
        tf_dict : dict
            Dictionary of processed transfer functions.

        Returns
        -------
        list[dict]
            List of dictionaries containing transfer functions and their
            period min/max values for merging.  Sample rates without an entry
            in the merge table use the full period range of their transfer
            function.
        """
        merge_df = self._get_merge_df()
        known_rates = merge_df.sample_rate.tolist()

        merge_list = []
        for key, pdict in self._validate_tf_processed_dict(tf_dict).items():
            tf = pdict["tf"]
            if key in known_rates:
                row = merge_df.loc[merge_df.sample_rate == key]
                period_min = row.period_min.iloc[0]
                period_max = row.period_max.iloc[0]
            else:
                period_min = tf.period.min()
                period_max = tf.period.max()
            merge_list.append(
                {"tf": tf, "period_min": period_min, "period_max": period_max}
            )

        return merge_list

    def merge_transfer_functions(
        self, tf_dict: dict[float, dict[str, bool | MT]]
    ) -> MT:
        """
        Merge multiple transfer functions according to merge_dict.

        Parameters
        ----------
        tf_dict : dict
            Dictionary of transfer functions to merge.

        Returns
        -------
        MT
            Merged transfer function combining all sample rates, or the single
            transfer function if only one is available.
        """
        merge_list = self._get_merge_tf_list(tf_dict)
        first, others = merge_list[0], merge_list[1:]
        if not others:
            return first["tf"]

        # NOTE(review): only period_max of the first entry is forwarded; its
        # period_min is not used here -- confirm this is intended.
        return first["tf"].merge(others, period_max=first["period_max"])