# -*- coding: utf-8 -*-
"""
Scaffold for a tree-backed MT data container.
This class is an outline for migrating from OrderedDict-based MTData to an
Xarray tree representation for better scalability.
"""
from __future__ import annotations
import importlib
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, TYPE_CHECKING
import numpy as np
import pandas as pd
import xarray as xr
from loguru import logger
from mtpy.core.transfer_function import IMPEDANCE_UNITS
from mtpy.imaging import (
PlotMultipleResponses,
PlotPenetrationDepthMap,
PlotPhaseTensorMaps,
PlotPhaseTensorPseudoSection,
PlotResidualPTMaps,
PlotResPhaseMaps,
PlotResPhasePseudoSection,
PlotStations,
PlotStrike,
)
from mtpy.modeling.errors import ModelErrors
from . import mt_data_accessor as _mt_data_accessor # noqa: F401
from .mt_data_tree_index import MTDataTreeIndexStore
from .mt_dataframe import MTDataFrame
# Every accepted spelling of a coordinate reference frame, mapped to its
# canonical lowercase frame name ("ned" or "enu"). Insertion order is kept
# because the whole mapping is echoed in the setter's error message.
COORDINATE_REFERENCE_FRAME_OPTIONS = dict(
    (
        ("+", "ned"),
        ("-", "enu"),
        ("z+", "ned"),
        ("z-", "enu"),
        ("nez+", "ned"),
        ("enz-", "enu"),
        ("ned", "ned"),
        ("enu", "enu"),
        ("exp(+ i\\omega t)", "ned"),
        ("exp(+i\\omega t)", "ned"),
        ("exp(- i\\omega t)", "enu"),
        ("exp(-i\\omega t)", "enu"),
        (None, "ned"),
    )
)
if TYPE_CHECKING:
from .mt import MT
from .mt_stations import MTStations
[docs]
class MTData:
    """
    Tree-backed container for MT collection data.

    Notes
    -----
    Composition is intentionally used instead of inheriting from xarray's tree
    type. This keeps the public MT API independent from xarray internals and
    allows controlled migration from the OrderedDict-based MTData.
    """

    # Name given to the root node when the constructor builds an empty tree.
    ROOT_NAME = "root"
    # Top-level child node created by the constructor for survey grouping.
    SURVEYS_NODE = "surveys"
    # NOTE(review): not referenced in the visible part of the class;
    # presumably the per-survey node that holds stations -- confirm.
    STATIONS_NODE = "stations"
    # Accepted values for the ``metadata_storage`` constructor argument.
    METADATA_STORAGE_MODES = {"dict", "summary", "cache"}
    # Accepted values for ``dataset_copy_mode`` (constructor and per-call).
    DATASET_COPY_MODES = {"deep", "shallow", "none"}
    # Class-level defaults; each instance takes its own ``dict`` copy of these.
    COORDINATE_REFERENCE_FRAME = COORDINATE_REFERENCE_FRAME_OPTIONS
    IMPEDANCE_UNITS = IMPEDANCE_UNITS
def __init__(
self,
tree: Any | None = None,
metadata_storage: str = "cache",
dataset_copy_mode: str = "shallow",
use_index: bool = False,
index_db_path: str = ":memory:",
**attrs: Any,
) -> None:
"""Initialize an MTData container.
Parameters
----------
tree : Any, optional
Existing tree-like object, typically an ``xarray.DataTree``.
When ``None``, an empty tree with a root dataset is created.
metadata_storage : {'dict', 'summary', 'cache'}, optional
Strategy used to store station and survey metadata in dataset
attributes.
dataset_copy_mode : {'deep', 'shallow', 'none'}, optional
Default dataset copy behavior used when adding stations.
use_index : bool, optional
If ``True``, enable an SQLite-backed station/period index for fast
geographic and period queries.
index_db_path : str, optional
SQLite database path used by the index.
**attrs
Additional root-level attributes stored on ``self.tree.attrs``.
Raises
------
ValueError
If *metadata_storage* or *dataset_copy_mode* is not a supported
option.
"""
storage_mode = str(metadata_storage).strip().lower()
if storage_mode not in self.METADATA_STORAGE_MODES:
raise ValueError(
"metadata_storage must be one of "
f"{sorted(self.METADATA_STORAGE_MODES)}"
)
self.metadata_storage = storage_mode
copy_mode = str(dataset_copy_mode).strip().lower()
if copy_mode not in self.DATASET_COPY_MODES:
raise ValueError(
"dataset_copy_mode must be one of " f"{sorted(self.DATASET_COPY_MODES)}"
)
self.dataset_copy_mode = copy_mode
# Optional in-memory metadata cache keyed by station tree path.
self._metadata_cache: dict[str, dict[str, Any]] = {
"survey": {},
"station": {},
}
# Optional SQLite-backed index for fast geographic / period queries.
self._index: MTDataTreeIndexStore | None = (
MTDataTreeIndexStore(index_db_path) if use_index else None
)
self._index_db_path = index_db_path
self._lazy_use_index = use_index
# Deferred station-level transforms keyed by station path.
self._lazy_station_transforms: dict[str, Callable[[], xr.Dataset]] = {}
self.tree = (
tree
if tree is not None
else xr.DataTree(name=self.ROOT_NAME, dataset=xr.Dataset())
)
# Keep root metadata lightweight and schema-focused at initialization.
self.tree.attrs.setdefault("schema_name", "mtpy.mt_data_tree")
self.tree.attrs.setdefault("schema_version", "0.1.0")
self.tree.attrs.update(attrs)
self.attrs = self.tree.attrs
self._coordinate_reference_frame_options = dict(self.COORDINATE_REFERENCE_FRAME)
self._coordinate_reference_frame = "+"
self._impedance_unit_factors = dict(self.IMPEDANCE_UNITS)
self._impedance_units = "mt"
self.data_rotation_angle = 0
self.model_parameters: dict[str, Any] = {}
self._center_lat = None
self._center_lon = None
self._center_elev = 0.0
self.z_model_error = ModelErrors(
error_value=5,
error_type="geometric_mean",
floor=True,
mode="impedance",
)
self.t_model_error = ModelErrors(
error_value=0.02,
error_type="absolute",
floor=True,
mode="tipper",
)
# Initialize a predictable top-level path for survey grouping.
if self.SURVEYS_NODE not in self.tree.children:
self.tree[self.SURVEYS_NODE] = xr.DataTree(
name=self.SURVEYS_NODE, dataset=xr.Dataset()
)
self.coordinate_reference_frame = self.attrs.get(
"coordinate_reference_frame", "ned"
)
self.impedance_units = self.attrs.get("impedance_units", "mt")
def __deepcopy__(self, memo: dict) -> "MTData":
"""Create a deep copy of MTData object."""
copied_tree = self.tree.copy(deep=True)
copied = self.__class__(
tree=copied_tree,
metadata_storage=self.metadata_storage,
dataset_copy_mode=self.dataset_copy_mode,
use_index=False,
index_db_path=self._index_db_path,
**dict(self.attrs),
)
memo[id(self)] = copied
copied._metadata_cache = deepcopy(self._metadata_cache, memo)
copied._lazy_station_transforms = dict(self._lazy_station_transforms)
copied._lazy_use_index = self._lazy_use_index
if self._index is not None and not copied.is_lazy:
copied.rebuild_index(index_db_path=self._index_db_path)
return copied
@property
def station_paths(self) -> list[str]:
"""Return sorted station paths present in the tree."""
return sorted(self._iter_station_paths())
@property
def short_station_paths(self) -> list[str]:
"""Return sorted station paths in ``survey/station`` form."""
return sorted(
[
f"{parts[1]}/{parts[3]}"
for path in self.station_paths
if isinstance(path, str)
for parts in [path.split("/")]
if len(parts) >= 4
]
)
@property
def datum_epsg(self) -> int | None:
"""Return the root datum EPSG code when available."""
value = self.attrs.get("datum_crs", self.attrs.get("datum_epsg"))
epsg = self._coerce_epsg_value(value)
if epsg is None:
return None
if str(epsg).isdigit():
return int(epsg)
return None
@datum_epsg.setter
def datum_epsg(self, value: Any) -> None:
"""Set root datum CRS/EPSG."""
self.datum_crs = value
@property
def datum_crs(self) -> Any | None:
"""Return the root datum CRS/EPSG value."""
return self.attrs.get("datum_crs", self.attrs.get("datum_epsg"))
@datum_crs.setter
def datum_crs(self, value: Any) -> None:
"""Set root datum CRS/EPSG."""
if value in [None, "", "None", "none", "null"]:
self.attrs.pop("datum_crs", None)
self.attrs.pop("datum_epsg", None)
return
self.attrs["datum_crs"] = value
epsg = self._coerce_epsg_value(value)
if epsg is not None:
self.attrs["datum_epsg"] = epsg
@property
def utm_epsg(self) -> int | None:
"""Return the root UTM EPSG code when available."""
value = self.attrs.get("utm_crs", self.attrs.get("utm_epsg"))
epsg = self._coerce_epsg_value(value)
if epsg is None:
return None
if str(epsg).isdigit():
return int(epsg)
return None
@utm_epsg.setter
def utm_epsg(self, value: Any) -> None:
"""Set root UTM CRS/EPSG and propagate to all station attrs."""
self.utm_crs = value
@property
def utm_crs(self) -> Any | None:
"""Return the root UTM CRS/EPSG value used for station projections."""
return self.attrs.get("utm_crs", self.attrs.get("utm_epsg"))
@utm_crs.setter
def utm_crs(self, value: Any) -> None:
"""Set root UTM CRS/EPSG and refresh station location attrs."""
if value in [None, "", "None", "none", "null"]:
self.attrs.pop("utm_crs", None)
self.attrs.pop("utm_epsg", None)
return
self.attrs["utm_crs"] = value
epsg = self._coerce_epsg_value(value)
if epsg is not None:
self.attrs["utm_epsg"] = epsg
self._apply_utm_crs_to_station_attrs(value)
def _apply_utm_crs_to_station_attrs(self, utm_crs: Any) -> None:
    """Apply a root UTM CRS/EPSG to all station attrs and recompute EN.

    Every station gets ``utm_crs`` written to its attrs. Stations with both
    ``latitude`` and ``longitude`` also get ``easting``/``northing``
    recomputed through ``MTLocation``. This is best-effort: a station whose
    projection fails keeps its previous easting/northing and a warning is
    logged.

    Parameters
    ----------
    utm_crs : Any
        CRS value written to each station and passed to ``MTLocation``.
    """
    from .mt_location import MTLocation

    missing_values = (None, "", "None", "none", "null")
    for station_path in self._iter_station_paths():
        attrs = self.get_station(station_path).attrs
        attrs["utm_crs"] = utm_crs
        latitude = attrs.get("latitude")
        longitude = attrs.get("longitude")
        if latitude in missing_values or longitude in missing_values:
            continue
        try:
            point = MTLocation(
                latitude=float(latitude),
                longitude=float(longitude),
                utm_crs=utm_crs,
            )
            easting = float(point.east)
            northing = float(point.north)
        except Exception as error:
            # Keep going for the other stations, but make the failure
            # visible instead of silently leaving stale coordinates.
            logger.warning(
                "Could not project station {} to UTM CRS {}: {}",
                station_path,
                utm_crs,
                error,
            )
            continue
        # Assign both together so a failure never updates only one of them.
        attrs["easting"] = easting
        attrs["northing"] = northing
@property
def survey_names(self) -> list[str]:
"""Return sorted survey names inferred from station paths."""
return sorted(
{
path.split("/")[1]
for path in self.station_paths
if isinstance(path, str) and path.count("/") >= 3
}
)
def __repr__(self) -> str:
"""Return a concise constructor-like summary for debugging."""
station_paths = self.station_paths
survey_names = self.survey_names
index_enabled = self._index is not None or self._lazy_use_index
return (
"MTData("
f"stations={len(station_paths)}, "
f"surveys={len(survey_names)}, "
f"lazy_stations={self.lazy_station_count}, "
f"metadata_storage='{self.metadata_storage}', "
f"dataset_copy_mode='{self.dataset_copy_mode}', "
f"index_enabled={index_enabled}"
")"
)
def __str__(self) -> str:
"""Return a human-readable summary of tree content and paths."""
station_paths = self.station_paths
survey_names = self.survey_names
preview_limit = 8
preview_paths = station_paths[:preview_limit]
index_enabled = self._index is not None or self._lazy_use_index
lines = [
"MTData Summary",
f" stations: {len(station_paths)}",
f" surveys: {len(survey_names)}",
f" lazy stations: {self.lazy_station_count}",
f" index enabled: {index_enabled}",
f" metadata storage: {self.metadata_storage}",
f" dataset copy mode: {self.dataset_copy_mode}",
f" impedance units: {self.impedance_units}",
(" coordinate reference frame: " f"{self.coordinate_reference_frame}"),
" survey names:",
]
if survey_names:
lines.extend([f" - {name}" for name in survey_names])
else:
lines.append(" - <none>")
lines.append(" station paths:")
if preview_paths:
lines.extend([f" - {path}" for path in preview_paths])
if len(station_paths) > preview_limit:
lines.append(f" - ... ({len(station_paths) - preview_limit} more)")
else:
lines.append(" - <none>")
return "\n".join(lines)
def __add__(self, other: Any) -> "MTData":
"""Return a new tree containing stations from ``self`` and ``other``.
Notes
-----
Existing station paths from ``self`` are overwritten by ``other`` when
duplicates are found. A warning is emitted for each overwritten path.
"""
if not isinstance(other, MTData):
return NotImplemented
merged = self.copy()
other.compute()
for station_path in other._iter_station_paths():
if merged._path_exists(station_path):
logger.warning(
"Overwriting existing station path during MTData merge: {}",
station_path,
)
station_ds = other.get_station(station_path).copy(deep=True)
merged._set_station_dataset(station_path, station_ds)
# Replace cached metadata entries for overwritten/new paths.
merged._clear_cached_metadata(station_path)
for metadata_kind in ["survey", "station"]:
cached_md = other._metadata_cache[metadata_kind].get(station_path)
if cached_md is not None:
merged._metadata_cache[metadata_kind][station_path] = deepcopy(
cached_md
)
if merged._index is not None and not merged.is_lazy:
merged.rebuild_index(index_db_path=merged._index_db_path)
return merged
[docs]
def copy(self) -> "MTData":
"""Create a deep copy of MTData object."""
return deepcopy(self)
[docs]
def clone_empty(self) -> "MTData":
"""Create a copy of MTData excluding all station datasets."""
return self.__class__(
metadata_storage=self.metadata_storage,
dataset_copy_mode=self.dataset_copy_mode,
use_index=self._index is not None or self._lazy_use_index,
index_db_path=self._index_db_path,
**dict(self.attrs),
)
@staticmethod
def _metadata_to_dict(metadata: Any) -> dict[str, Any]:
"""Safely convert mt_metadata objects to dictionaries."""
if metadata is None:
return {}
if hasattr(metadata, "to_dict"):
for kwargs in ({"single": True}, {}):
try:
out = metadata.to_dict(**kwargs)
if isinstance(out, dict):
return out
except TypeError:
continue
return {}
@staticmethod
def _metadata_to_summary(metadata: Any) -> dict[str, Any]:
"""Build a lightweight metadata summary for fast per-station attrs."""
if metadata is None:
return {}
md_id = getattr(metadata, "id", None)
if md_id in [None, "", "None", "none", "null"]:
return {}
return {"id": str(md_id)}
def _serialize_metadata(self, metadata: Any) -> dict[str, Any]:
"""Serialize metadata according to configured storage mode."""
if self.metadata_storage == "dict":
return self._metadata_to_dict(metadata)
return self._metadata_to_summary(metadata)
def _metadata_ref(self, station_path: str, metadata: Any) -> str | None:
"""Return metadata reference key for cache mode without mutating cache."""
if self.metadata_storage != "cache" or metadata is None:
return None
return station_path
def _commit_cached_metadata(
self,
station_path: str,
survey_metadata: Any,
station_metadata: Any,
) -> None:
"""Persist metadata objects in the in-memory cache after successful insert."""
if self.metadata_storage != "cache":
return
if survey_metadata is not None:
self._metadata_cache["survey"][station_path] = survey_metadata
if station_metadata is not None:
self._metadata_cache["station"][station_path] = station_metadata
def _clear_cached_metadata(self, node_path: str) -> None:
"""Remove cached metadata for one node path or an entire subtree prefix."""
for metadata_kind in ["survey", "station"]:
keys_to_remove = [
key
for key in self._metadata_cache[metadata_kind]
if key == node_path or key.startswith(f"{node_path}/")
]
for key in keys_to_remove:
self._metadata_cache[metadata_kind].pop(key, None)
def _resolve_dataset_copy_mode(self, dataset_copy_mode: str | None) -> str:
"""Resolve copy mode from call-level override or instance default."""
mode = (
self.dataset_copy_mode if dataset_copy_mode is None else dataset_copy_mode
)
mode = str(mode).strip().lower()
if mode not in self.DATASET_COPY_MODES:
raise ValueError(
"dataset_copy_mode must be one of " f"{sorted(self.DATASET_COPY_MODES)}"
)
return mode
@staticmethod
def _copy_station_dataset(station_ds: xr.Dataset, mode: str) -> xr.Dataset:
"""Copy station dataset according to selected copy mode."""
if mode == "none":
return station_ds
if mode == "deep":
return station_ds.copy(deep=True)
return station_ds.copy(deep=False)
def _extract_station_dataset(
self, mt_obj: "MT", dataset_copy_mode: str | None = None
) -> xr.Dataset:
"""Extract an xarray.Dataset from MT object transfer function."""
tf_obj = getattr(mt_obj, "_transfer_function", None)
if tf_obj is None:
raise TypeError("MT object is missing _transfer_function")
if isinstance(tf_obj, xr.Dataset):
source_ds = tf_obj
elif hasattr(tf_obj, "to_xarray"):
source_ds = tf_obj.to_xarray()
elif hasattr(tf_obj, "_dataset") and isinstance(tf_obj._dataset, xr.Dataset):
source_ds = tf_obj._dataset
else:
raise TypeError("Could not extract xarray.Dataset from MT object")
copy_mode = self._resolve_dataset_copy_mode(dataset_copy_mode)
return self._copy_station_dataset(source_ds, copy_mode)
def _build_station_attrs(
self,
mt_obj: "MT",
survey: str,
station: str,
survey_metadata: Any,
station_metadata: Any,
survey_metadata_ref: str | None,
station_metadata_ref: str | None,
) -> dict[str, Any]:
"""Build default station attrs payload for one MT object."""
return {
"survey": survey,
"station": station,
"tf_id": getattr(mt_obj, "tf_id", station),
"latitude": getattr(mt_obj, "latitude", None),
"longitude": getattr(mt_obj, "longitude", None),
"elevation": getattr(mt_obj, "elevation", None),
"datum_crs": getattr(mt_obj, "datum_crs", None),
"utm_crs": self._get_utm_crs(mt_obj),
"easting": getattr(mt_obj, "east", None),
"northing": getattr(mt_obj, "north", None),
"model_east": getattr(mt_obj, "model_east", 0.0),
"model_north": getattr(mt_obj, "model_north", 0.0),
"model_elevation": getattr(mt_obj, "model_elevation", 0.0),
"profile_offset": getattr(mt_obj, "profile_offset", 0.0),
"coordinate_reference_frame": getattr(
mt_obj, "coordinate_reference_frame", None
),
"impedance_units": getattr(mt_obj, "impedance_units", None),
"survey_metadata": self._serialize_metadata(survey_metadata),
"station_metadata": self._serialize_metadata(station_metadata),
"survey_metadata_ref": survey_metadata_ref,
"station_metadata_ref": station_metadata_ref,
}
def _build_station_attrs_from_precomputed(
self,
mt_obj: "MT",
survey: str,
station: str,
survey_metadata: Any,
station_metadata: Any,
survey_metadata_ref: str | None,
station_metadata_ref: str | None,
precomputed_attrs: dict[str, Any],
) -> dict[str, Any]:
"""Build attrs from precomputed payload plus required canonical keys."""
station_attrs = dict(precomputed_attrs)
station_attrs["survey"] = survey
station_attrs["station"] = station
station_attrs.setdefault("tf_id", getattr(mt_obj, "tf_id", station))
station_attrs.setdefault(
"survey_metadata", self._serialize_metadata(survey_metadata)
)
station_attrs.setdefault(
"station_metadata", self._serialize_metadata(station_metadata)
)
station_attrs["survey_metadata_ref"] = survey_metadata_ref
station_attrs["station_metadata_ref"] = station_metadata_ref
return station_attrs
def _coerce_and_prepare_station(
self,
mt_obj: "MT | str | Path",
dataset_copy_mode: str | None = None,
precomputed_attrs: dict[str, Any] | None = None,
) -> tuple[str, str, xr.Dataset, dict[str, Any]]:
"""Coerce station input and build station path/dataset payload."""
mt_obj = self._coerce_mt_object(mt_obj)
survey = self._clean_name(
getattr(mt_obj, "survey", None)
or getattr(getattr(mt_obj, "survey_metadata", None), "id", None),
"default",
)
station = self._clean_name(
getattr(mt_obj, "station", None)
or getattr(getattr(mt_obj, "station_metadata", None), "id", None),
"unknown_station",
)
station_path = self._station_path(survey, station)
station_ds = self._extract_station_dataset(
mt_obj, dataset_copy_mode=dataset_copy_mode
)
survey_metadata_obj = getattr(mt_obj, "survey_metadata", None)
station_metadata_obj = getattr(mt_obj, "station_metadata", None)
survey_metadata_ref = self._metadata_ref(station_path, survey_metadata_obj)
station_metadata_ref = self._metadata_ref(station_path, station_metadata_obj)
if precomputed_attrs is None:
station_attrs = self._build_station_attrs(
mt_obj,
survey,
station,
survey_metadata_obj,
station_metadata_obj,
survey_metadata_ref,
station_metadata_ref,
)
else:
station_attrs = self._build_station_attrs_from_precomputed(
mt_obj,
survey,
station,
survey_metadata_obj,
station_metadata_obj,
survey_metadata_ref,
station_metadata_ref,
precomputed_attrs,
)
station_ds.attrs.update(station_attrs)
return (
station_path,
station,
station_ds,
{
"survey": survey_metadata_obj,
"station": station_metadata_obj,
},
)
def _cache_metadata(
self, station_path: str, metadata_kind: str, metadata: Any
) -> str | None:
"""Cache full metadata object in-memory and return reference key."""
if self.metadata_storage != "cache" or metadata is None:
return None
self._metadata_cache[metadata_kind][station_path] = metadata
return station_path
@property
def metadata_cache(self) -> dict[str, dict[str, Any]]:
"""In-memory metadata map keyed by station path for cache mode."""
return self._metadata_cache
@property
def is_lazy(self) -> bool:
"""True when one or more deferred station transforms are pending."""
return bool(self._lazy_station_transforms)
@property
def lazy_station_count(self) -> int:
"""Number of stations with pending deferred transforms."""
return len(self._lazy_station_transforms)
@property
def coordinate_reference_frame(self) -> str:
"""
Coordinate reference frame.
Returns
-------
str
Reference frame identifier ('NED' or 'ENU')
"""
return self._coordinate_reference_frame_options[
self._coordinate_reference_frame
].upper()
@coordinate_reference_frame.setter
def coordinate_reference_frame(self, value: str) -> None:
"""
Set coordinate reference frame.
Parameters
----------
value : str
Reference frame identifier. Options:
- 'NED': x=North, y=East, z=+down
- 'ENU': x=East, y=North, z=+up
Raises
------
ValueError
If value is not a recognized reference frame
Notes
-----
Updates coordinate reference frame for all MT objects in collection
"""
if not isinstance(value, str):
raise TypeError("Coordinate reference frame input must be a string.")
normalized = value.strip().lower()
if normalized not in self._coordinate_reference_frame_options:
raise ValueError(
f"{value} is not understood as a reference frame. "
f"Options are {self._coordinate_reference_frame_options}"
)
if normalized in ["ned", "+"]:
normalized = "+"
elif normalized in ["enu", "-"]:
normalized = "-"
self._coordinate_reference_frame = normalized
self.attrs["coordinate_reference_frame"] = self.coordinate_reference_frame
for station_path in self._iter_station_paths():
station_ds = self.get_station(station_path)
station_ds.attrs["coordinate_reference_frame"] = (
self.coordinate_reference_frame
)
@property
def impedance_units(self) -> str:
"""
Impedance units.
Returns
-------
str
Impedance units ('mt' or 'ohm')
"""
return self._impedance_units
@impedance_units.setter
def impedance_units(self, value: str) -> None:
"""
Set impedance units.
Parameters
----------
value : str
Impedance units. Options: 'mt' [mV/km/nT] or 'ohm' [Ohms]
Raises
------
TypeError
If value is not a string
ValueError
If value is not 'mt' or 'ohm'
Notes
-----
Updates impedance units for all MT objects in collection
"""
if not isinstance(value, str):
raise TypeError("Units input must be a string.")
if value.lower() not in self._impedance_unit_factors.keys():
raise ValueError(f"{value} is not an acceptable unit for impedance.")
self._impedance_units = value.lower()
self.attrs["impedance_units"] = self._impedance_units
for station_path in self._iter_station_paths():
station_ds = self.get_station(station_path)
station_ds.attrs["impedance_units"] = self._impedance_units
def _realize_station(self, station_path: str) -> str | None:
"""Materialize one deferred station transform if present."""
transform = self._lazy_station_transforms.pop(station_path, None)
if transform is None:
return None
station_ds = transform()
self._set_station_dataset(station_path, station_ds)
if self._index is not None:
station_row, period_row = MTDataTreeIndexStore._extract_rows(
station_path,
station_ds,
)
if period_row is None:
self._index.delete_station_by_tree_path(station_path)
self._index.upsert_station(station_row)
if period_row is not None:
self._index.replace_station_period_rows(period_row)
return station_row.survey_name
return None
@staticmethod
def _is_dask_delayed(obj: Any) -> bool:
"""Return True when *obj* is a dask delayed object."""
return obj.__class__.__name__ == "Delayed" and hasattr(obj, "dask")
[docs]
def compute(
self,
station_paths: list[str] | None = None,
scheduler: str | None = None,
) -> "MTData":
"""Materialize deferred station transforms and refresh index state.
Parameters
----------
station_paths : list[str], optional
Subset of station tree paths to realize. When ``None``, all pending
lazy station transforms are computed.
scheduler : str, optional
Dask scheduler name passed through when delayed transforms are
present.
Returns
-------
MTData
The current tree instance.
"""
if not self._lazy_station_transforms:
return self
if self._index is None and self._lazy_use_index:
self._index = MTDataTreeIndexStore(self._index_db_path)
pending_paths = list(self._lazy_station_transforms.keys())
station_paths = self._normalize_station_paths(station_paths)
if station_paths is None:
paths_to_realize = pending_paths
else:
requested = set(station_paths)
paths_to_realize = [path for path in pending_paths if path in requested]
realized_datasets: dict[str, xr.Dataset] = {}
delayed_paths: list[str] = []
delayed_objs: list[Any] = []
for station_path in paths_to_realize:
transform = self._lazy_station_transforms.pop(station_path, None)
if transform is None:
continue
out = transform()
if self._is_dask_delayed(out):
delayed_paths.append(station_path)
delayed_objs.append(out)
else:
realized_datasets[station_path] = out
if delayed_objs:
try:
dask = importlib.import_module("dask")
except ImportError as exc:
raise RuntimeError(
"Dask delayed transforms are pending but dask is not installed."
) from exc
computed = dask.compute(*delayed_objs, scheduler=scheduler)
for station_path, station_ds in zip(delayed_paths, computed):
realized_datasets[station_path] = station_ds
updated_surveys: set[str] = set()
for station_path in paths_to_realize:
station_ds = realized_datasets.get(station_path)
if station_ds is None:
continue
if not isinstance(station_ds, xr.Dataset):
raise TypeError(
"Deferred transform must return xr.Dataset, "
f"got {type(station_ds)!r}"
)
self._set_station_dataset(station_path, station_ds)
if self._index is not None:
station_row, period_row = MTDataTreeIndexStore._extract_rows(
station_path,
station_ds,
)
if period_row is None:
self._index.delete_station_by_tree_path(station_path)
self._index.upsert_station(station_row)
if period_row is not None:
self._index.replace_station_period_rows(period_row)
updated_surveys.add(station_row.survey_name)
if self._index is not None:
for survey_name in updated_surveys:
self._index.refresh_survey_aggregates(survey_name)
return self
[docs]
def persist(
self,
station_paths: list[str] | None = None,
scheduler: str | None = None,
) -> "MTData":
"""Alias for :meth:`compute`.
Parameters
----------
station_paths : list[str], optional
Station paths to realize.
scheduler : str, optional
Dask scheduler name.
Returns
-------
MTData
The current tree instance.
"""
return self.compute(station_paths=station_paths, scheduler=scheduler)
[docs]
def as_dask(
self,
chunks: dict[str, int] | str | None,
station_paths: list[str] | None = None,
variables: list[str] | None = None,
inplace: bool = False,
) -> "MTData":
"""Chunk station datasets to dask-backed arrays.
Parameters
----------
chunks : dict[str, int] or str or None
Chunk specification passed to xarray ``chunk``.
station_paths : list[str], optional
Subset of station paths to chunk.
variables : list[str], optional
Data-variable names to chunk. When ``None``, all variables are
chunked.
inplace : bool, optional
If ``True``, modify this tree in place. Otherwise return a chunked
subset copy.
Returns
-------
MTData
Chunked tree (same instance when *inplace* is ``True``).
Raises
------
RuntimeError
If dask is not installed.
KeyError
If a requested variable is missing for a selected station.
Examples
--------
>>> tree = tree.as_dask(chunks={"period": 32})
>>> tree = tree.as_dask(chunks="auto", variables=["transfer_function"])
"""
try:
importlib.import_module("dask.array")
except ImportError as exc:
raise RuntimeError("Dask is required for as_dask()") from exc
station_paths = self._normalize_station_paths(station_paths)
tree_obj = self if inplace else self.get_subset(self._iter_station_paths())
target_paths = tree_obj._iter_station_paths()
if station_paths is not None:
requested = set(station_paths)
target_paths = [path for path in target_paths if path in requested]
for station_path in target_paths:
station_ds = tree_obj.get_station(station_path)
if variables is not None:
missing = [
name for name in variables if name not in station_ds.data_vars
]
if missing:
raise KeyError(f"Variables not found for chunking: {missing}")
chunked_ds = station_ds.copy(deep=False)
for var_name in variables:
chunked_ds[var_name] = station_ds[var_name].chunk(chunks)
else:
chunked_ds = station_ds.chunk(chunks)
tree_obj._set_station_dataset(station_path, chunked_ds)
return tree_obj
[docs]
def rechunk(
self,
chunks: dict[str, int] | str | None,
station_paths: list[str] | None = None,
variables: list[str] | None = None,
inplace: bool = True,
) -> "MTData":
"""Rechunk station datasets.
Parameters
----------
chunks : dict[str, int] or str or None
Chunk specification passed to :meth:`as_dask`.
station_paths : list[str], optional
Subset of station paths to rechunk.
variables : list[str], optional
Data variables to rechunk.
inplace : bool, optional
If ``True`` (default), modify the current tree.
Returns
-------
MTData
Rechunked tree.
"""
return self.as_dask(
chunks=chunks,
station_paths=station_paths,
variables=variables,
inplace=inplace,
)
[docs]
def is_dask_backed(self, station_paths: list[str] | None = None) -> bool:
"""Check whether selected stations are dask-backed.
Parameters
----------
station_paths : list[str], optional
Station subset to inspect. When ``None``, all stations are checked.
Returns
-------
bool
``True`` only when each selected station has dask-backed arrays for
all data variables.
"""
station_paths = self._normalize_station_paths(station_paths)
self.compute(station_paths=station_paths)
target_paths = self._iter_station_paths()
if station_paths is not None:
requested = set(station_paths)
target_paths = [path for path in target_paths if path in requested]
if not target_paths:
return False
for station_path in target_paths:
station_ds = self.get_station(station_path)
for da in station_ds.data_vars.values():
if getattr(da.data, "chunks", None) is None:
return False
return True
[docs]
def chunk_plan(
self,
station_paths: list[str] | None = None,
) -> dict[str, dict[str, tuple[tuple[int, ...], ...] | None]]:
"""Return per-station chunk layout for each data variable.
Parameters
----------
station_paths : list[str], optional
Station subset to summarize.
Returns
-------
dict[str, dict[str, tuple[tuple[int, ...], ...] or None]]
Mapping from station path to variable chunk tuples (or ``None`` for
NumPy-backed variables).
"""
station_paths = self._normalize_station_paths(station_paths)
self.compute(station_paths=station_paths)
target_paths = self._iter_station_paths()
if station_paths is not None:
requested = set(station_paths)
target_paths = [path for path in target_paths if path in requested]
plan: dict[str, dict[str, tuple[tuple[int, ...], ...] | None]] = {}
for station_path in target_paths:
station_ds = self.get_station(station_path)
plan[station_path] = {
var_name: da.chunks for var_name, da in station_ds.data_vars.items()
}
return plan
[docs]
def map_stations(
    self,
    transform: Callable[[xr.Dataset], xr.Dataset],
    station_paths: list[str] | None = None,
    lazy: bool = True,
    inplace: bool = False,
) -> "MTData":
    """Apply a dataset transform to selected stations.

    Parameters
    ----------
    transform : callable
        Function receiving one station ``xr.Dataset`` and returning an
        ``xr.Dataset``.
    station_paths : list[str], optional
        Station subset to transform. Keys are passed through
        :meth:`_normalize_station_paths`; ``None`` (default) selects every
        station in the working tree.
    lazy : bool, optional
        If ``True`` (default), register deferred transforms. If ``False``,
        apply immediately.
    inplace : bool, optional
        If ``True``, mutate this tree. Otherwise return a transformed copy.

    Returns
    -------
    MTData
        Tree with registered or applied transforms.

    Raises
    ------
    TypeError
        If *lazy* is ``False`` and *transform* does not return an
        ``xr.Dataset``.

    Notes
    -----
    ``compute()`` is called on the working tree before anything is
    registered or applied. In lazy mode each deferred callable is bound to
    a shallow copy of the station dataset taken at registration time; the
    return type of *transform* is not checked until it actually runs.

    Examples
    --------
    >>> def keep_short_periods(ds):
    ...     return ds.sel(period=ds.period <= 10.0)
    >>> out = tree.map_stations(keep_short_periods, lazy=False, inplace=False)
    """
    station_paths = self._normalize_station_paths(station_paths)
    tree_obj = self if inplace else self.get_subset(self._iter_station_paths())
    tree_obj.compute()

    target_paths = tree_obj._iter_station_paths()
    if station_paths is not None:
        requested = set(station_paths)
        target_paths = [path for path in target_paths if path in requested]

    if lazy:
        # Only the deferred path needs a per-station snapshot; the eager
        # path below hands the paths straight to the tree accessor.
        for station_path in target_paths:
            station_ds = tree_obj.get_station(station_path).copy(deep=False)
            # Default arguments pin the snapshot and callable per station
            # (avoids the late-binding closure pitfall).
            tree_obj._lazy_station_transforms[station_path] = (
                lambda ds=station_ds, op=transform: op(ds)
            )
        return tree_obj

    def _validated_transform(ds: xr.Dataset) -> xr.Dataset:
        """Run *transform* and reject anything that is not a Dataset."""
        out_ds = transform(ds)
        if not isinstance(out_ds, xr.Dataset):
            raise TypeError(
                "map_stations transform must return xr.Dataset, "
                f"got {type(out_ds)!r}"
            )
        return out_ds

    # Applied once for the whole selection, always in place on the working
    # tree (which is already a copy when inplace=False).
    tree_obj.tree.mt.map_stations(
        _validated_transform,
        station_paths=target_paths,
        inplace=True,
    )
    return tree_obj
[docs]
def interpolate_dask(
    self,
    new_periods: np.ndarray,
    f_type: str = "period",
    bounds_error: bool = True,
    chunks: dict[str, int] | str | None = None,
    scheduler: str | None = None,
    compute: bool = True,
    inplace: bool = False,
    **kwargs: Any,
) -> "MTData":
    """Interpolate stations through dask-delayed station transforms.

    Parameters
    ----------
    new_periods : ndarray
        Target periods or frequencies depending on *f_type*.
    f_type : str, optional
        ``'period'`` (default) or ``'frequency'``/``'freq'``.
    bounds_error : bool, optional
        Restrict interpolation to each station's native period range.
    chunks : dict[str, int] or str or None, optional
        When given, ``as_dask(chunks=chunks, inplace=True)`` is applied to
        the working tree before the deferred transforms are created.
    scheduler : str, optional
        Passed to ``compute(scheduler=...)`` when *compute* is ``True``.
        When *compute* is ``False`` and a scheduler is given, it is stored
        with ``dask.config.set(scheduler=...)`` instead.
    compute : bool, optional
        If ``True`` (default), execute delayed transforms immediately.
    inplace : bool, optional
        If ``True``, modify this tree; otherwise work on a subset copy
        holding every station.
    **kwargs
        Forwarded to :meth:`interpolate_lazy`.

    Returns
    -------
    MTData
        Tree with interpolated results or pending delayed transforms.

    Raises
    ------
    RuntimeError
        If ``dask`` cannot be imported.
    """
    try:
        dask = importlib.import_module("dask")
        delayed = getattr(dask, "delayed")
    except ImportError as exc:
        raise RuntimeError("Dask is required for interpolate_dask()") from exc

    if inplace:
        base_tree = self
    else:
        base_tree = self.get_subset(self._iter_station_paths())
    if chunks is not None:
        base_tree.as_dask(chunks=chunks, inplace=True)

    lazy_tree = base_tree.interpolate_lazy(
        new_periods,
        f_type=f_type,
        inplace=True,
        bounds_error=bounds_error,
        **kwargs,
    )

    # Re-wrap every pending transform so that invoking it builds and calls
    # a dask ``delayed`` object around the original callable.
    pending = lazy_tree._lazy_station_transforms
    wrapped = {
        station_path: (lambda fn=transform: delayed(fn)())
        for station_path, transform in list(pending.items())
    }
    for station_path, wrapper in wrapped.items():
        pending[station_path] = wrapper

    if not compute:
        if scheduler is not None:
            dask.config.set(scheduler=scheduler)
        return lazy_tree

    lazy_tree.compute(scheduler=scheduler)
    return lazy_tree
[docs]
def rotate_dask(
    self,
    rotation_angle: float | np.ndarray,
    chunks: dict[str, int] | str | None = None,
    scheduler: str | None = None,
    compute: bool = True,
    inplace: bool = False,
) -> "MTData":
    """Rotate stations through dask-delayed station transforms.

    Parameters
    ----------
    rotation_angle : float or ndarray
        Rotation angle in degrees, scalar or per-period array.
    chunks : dict[str, int] or str or None, optional
        When given, ``as_dask(chunks=chunks, inplace=True)`` is applied to
        the working tree before the deferred transforms are created.
    scheduler : str, optional
        Passed to ``compute(scheduler=...)`` when *compute* is ``True``.
        When *compute* is ``False`` and a scheduler is given, it is stored
        with ``dask.config.set(scheduler=...)`` instead.
    compute : bool, optional
        If ``True`` (default), execute delayed transforms immediately.
    inplace : bool, optional
        If ``True``, modify this tree; otherwise work on a subset copy
        holding every station.

    Returns
    -------
    MTData
        Tree with rotated results or pending delayed transforms.

    Raises
    ------
    RuntimeError
        If ``dask`` cannot be imported.
    """
    try:
        dask = importlib.import_module("dask")
        delayed = getattr(dask, "delayed")
    except ImportError as exc:
        raise RuntimeError("Dask is required for rotate_dask()") from exc

    if inplace:
        base_tree = self
    else:
        base_tree = self.get_subset(self._iter_station_paths())
    if chunks is not None:
        base_tree.as_dask(chunks=chunks, inplace=True)

    def _rotate_transform(ds: xr.Dataset) -> xr.Dataset:
        """Rotate one station, preferring its own reference frame."""
        # Station attrs win; the tree-level attr (then "ned") is the fallback.
        tree_frame = self.attrs.get("coordinate_reference_frame", "ned")
        crf = ds.attrs.get("coordinate_reference_frame", tree_frame)
        return MTData._rotate_station_dataset(
            ds,
            rotation_angle,
            coordinate_reference_frame=crf,
        )

    lazy_tree = base_tree.map_stations(
        _rotate_transform,
        lazy=True,
        inplace=True,
    )

    # Re-wrap every pending transform so that invoking it builds and calls
    # a dask ``delayed`` object around the original callable.
    pending = lazy_tree._lazy_station_transforms
    wrapped = {
        station_path: (lambda fn=transform: delayed(fn)())
        for station_path, transform in list(pending.items())
    }
    for station_path, wrapper in wrapped.items():
        pending[station_path] = wrapper

    if not compute:
        if scheduler is not None:
            dask.config.set(scheduler=scheduler)
        return lazy_tree

    lazy_tree.compute(scheduler=scheduler)
    return lazy_tree
[docs]
def finalize_index(self) -> None:
    """Run ``compute()`` and then rebuild the station index.

    The index is rebuilt against the database path stored on this
    instance (``self._index_db_path``).
    """
    self.compute()
    db_path = self._index_db_path
    self.rebuild_index(index_db_path=db_path)
def _hydrate_metadata_from_cache(
self, mt_obj: "MT", station_ds: xr.Dataset
) -> None:
"""Populate MT metadata objects from in-memory cache when references exist."""
attrs = station_ds.attrs
for metadata_kind in ["survey", "station"]:
ref_key = attrs.get(f"{metadata_kind}_metadata_ref")
if not isinstance(ref_key, str):
continue
cached_md = self._metadata_cache[metadata_kind].get(ref_key)
if cached_md is None:
continue
target_md = getattr(mt_obj, f"{metadata_kind}_metadata", None)
if target_md is None or not hasattr(target_md, "from_dict"):
continue
cached_dict = self._metadata_to_dict(cached_md)
if cached_dict:
target_md.from_dict(cached_dict)
@staticmethod
def _clean_name(value: Any, fallback: str) -> str:
"""Normalize path segment names for tree paths."""
name = str(value).strip() if value is not None else ""
if not name:
return fallback
return name.replace("/", "_")
def _station_path(self, survey: str, station: str) -> str:
"""Build canonical station path under /surveys."""
return f"{self.SURVEYS_NODE}/{survey}/{self.STATIONS_NODE}/{station}"
def _resolve_station_path(self, station_key: str) -> str:
"""Resolve a public station key to the canonical stored tree path."""
if not isinstance(station_key, str) or not station_key.strip():
raise KeyError("station_key must be a non-empty string")
key = station_key.strip().strip("/")
candidates: list[str] = []
def _append(candidate: str) -> None:
if candidate not in candidates:
candidates.append(candidate)
if key.startswith(f"{self.SURVEYS_NODE}/"):
_append(key)
if "." in key:
survey, station = key.split(".", 1)
_append(
self._station_path(
self._clean_name(survey, "default"),
self._clean_name(station, "unknown_station"),
)
)
if key.count("/") == 1:
survey, station = key.split("/", 1)
_append(
self._station_path(
self._clean_name(survey, "default"),
self._clean_name(station, "unknown_station"),
)
)
_append(key)
for candidate in candidates:
if (
self._path_exists(candidate)
or candidate in self._lazy_station_transforms
):
return candidate
raise KeyError(f"Station key not found: {station_key}")
def _normalize_station_paths(
self, station_paths: list[str] | None
) -> list[str] | None:
"""Normalize public station-path inputs while preserving no-match behavior."""
if station_paths is None:
return None
normalized: list[str] = []
for station_key in station_paths:
try:
normalized.append(self._resolve_station_path(station_key))
except KeyError:
if isinstance(station_key, str):
normalized.append(station_key.strip().strip("/"))
else:
normalized.append(station_key)
return normalized
@staticmethod
def _coerce_mt_object(mt_obj: "MT | str | Path") -> "MT":
    """Return an ``MT`` instance for an ``MT`` object or a file name.

    ``MT`` instances are returned unchanged. A ``str`` or ``pathlib.Path``
    is wrapped in ``MT(...)`` and ``read()`` is called on it.

    Raises
    ------
    TypeError
        For any other input type.
    """
    from .mt import MT

    if isinstance(mt_obj, MT):
        return mt_obj
    if not isinstance(mt_obj, (str, Path)):
        raise TypeError(
            "mt_obj must be an MT instance, filename string, or pathlib.Path"
        )

    loaded = MT(mt_obj)
    loaded.read()
    return loaded
def _path_exists(self, node_path: str) -> bool:
"""Check if a tree node path exists."""
try:
_ = self.tree[node_path]
return True
except KeyError:
return False
def _iter_station_paths(self) -> list[str]:
    """List the paths of all station nodes.

    When an index is attached its ``all_station_paths()`` result is
    returned directly. Otherwise the tree is walked depth-first (parents
    before children) and every node that carries an ``xr.Dataset`` and
    whose path contains at least three ``/`` separators is reported.
    """
    if self._index is not None:
        return self._index.all_station_paths()

    def _descend(node: Any, node_path: str):
        """Yield qualifying paths for *node* and all of its descendants."""
        dataset = getattr(node, "ds", None)
        if isinstance(dataset, xr.Dataset) and node_path.count("/") >= 3:
            yield node_path
        for child_name, child in getattr(node, "children", {}).items():
            # The root contributes no prefix to its children's paths.
            child_path = f"{node_path}/{child_name}" if node_path else child_name
            yield from _descend(child, child_path)

    return list(_descend(self.tree, ""))
@staticmethod
def _crs_to_epsg(value: Any) -> Any:
"""Convert a CRS-like value to an EPSG code when possible."""
if value in [None, "", "None", "none", "null"]:
return None
if hasattr(value, "to_epsg"):
return value.to_epsg()
return value
@staticmethod
def _station_locations_columns() -> list[str]:
"""Column order matching MTStations.station_locations."""
return [
"survey",
"station",
"latitude",
"longitude",
"elevation",
"datum_epsg",
"east",
"north",
"utm_epsg",
"model_east",
"model_north",
"model_elevation",
"profile_offset",
]
def _station_location_record(self, station_path: str) -> dict[str, Any]:
    """Assemble one station-location row from a station dataset's attrs.

    Missing identity and geographic values come back as ``None``; the
    model coordinates and ``profile_offset`` default to ``0.0``. The
    ``datum_crs``/``utm_crs`` attrs are reduced with :meth:`_crs_to_epsg`.
    """
    lookup = self.get_station(station_path).attrs.get

    record: dict[str, Any] = {
        name: lookup(name)
        for name in ("survey", "station", "latitude", "longitude", "elevation")
    }
    record["datum_epsg"] = self._crs_to_epsg(lookup("datum_crs"))
    # Attrs store projected coordinates as easting/northing.
    record["east"] = lookup("easting")
    record["north"] = lookup("northing")
    record["utm_epsg"] = self._crs_to_epsg(lookup("utm_crs"))
    for name in ("model_east", "model_north", "model_elevation", "profile_offset"):
        record[name] = lookup(name, 0.0)
    return record
def _station_path_to_location_mt(self, station_path: str) -> "MT":
    """Create an ``MT`` object carrying only a station's location attrs.

    Attrs that are ``None`` or absent are not assigned, except the model
    coordinates and ``profile_offset`` which fall back to ``0.0``.
    """
    from .mt import MT

    attrs = self.get_station(station_path).attrs
    mt_obj = MT()

    # Identity and CRS first, in the same order as the stored attrs are named.
    for name in ("survey", "station", "datum_crs", "utm_crs"):
        if attrs.get(name) is not None:
            setattr(mt_obj, name, attrs[name])

    # (MT attribute, dataset attr key, default when the attr is absent)
    location_fields = (
        ("latitude", "latitude", None),
        ("longitude", "longitude", None),
        ("elevation", "elevation", None),
        ("east", "easting", None),
        ("north", "northing", None),
        ("model_east", "model_east", 0.0),
        ("model_north", "model_north", 0.0),
        ("model_elevation", "model_elevation", 0.0),
        ("profile_offset", "profile_offset", 0.0),
    )
    for mt_attr, attr_key, default in location_fields:
        stored = attrs.get(attr_key, default)
        if stored is not None:
            setattr(mt_obj, mt_attr, stored)
    return mt_obj
@staticmethod
def _get_utm_crs(mt_obj: "MT") -> Any:
"""Get UTM CRS information from MT object when available."""
crs = getattr(mt_obj, "utm_crs", None)
if crs is not None:
return crs
return getattr(mt_obj, "utm_epsg", None)
@staticmethod
def _dataset_to_mt(station_ds: xr.Dataset) -> "MT":
    """Rebuild an ``MT`` object from a stored station dataset.

    A copy of the dataset becomes the MT transfer function, then the
    dataset attrs are replayed onto the MT object. ``None``/absent attrs
    are skipped. A string reference frame of ``NED``/``ENU`` (any case) is
    translated to ``"+"``/``"-"`` before assignment. Non-empty
    ``survey_metadata``/``station_metadata`` dicts are loaded through the
    matching metadata object's ``from_dict`` when it has one.
    """
    from .mt import MT

    mt_obj = MT()
    mt_obj._transfer_function = station_ds.copy()
    attrs = station_ds.attrs

    def _assign(attr_key: str, mt_attr: str | None = None) -> None:
        """Copy one attr onto the MT object unless it is None or absent."""
        if attrs.get(attr_key) is not None:
            setattr(mt_obj, mt_attr or attr_key, attrs[attr_key])

    _assign("survey")
    _assign("station")

    crf = attrs.get("coordinate_reference_frame")
    if crf is not None:
        if isinstance(crf, str):
            crf = {"NED": "+", "ENU": "-"}.get(crf.upper(), crf)
        mt_obj.coordinate_reference_frame = crf

    for attr_key, mt_attr in (
        ("impedance_units", None),
        ("datum_crs", None),
        ("utm_crs", None),
        ("latitude", None),
        ("longitude", None),
        ("elevation", None),
        ("easting", "east"),
        ("northing", "north"),
        ("profile_offset", None),
    ):
        _assign(attr_key, mt_attr)

    for metadata_name in ("survey_metadata", "station_metadata"):
        stored_md = attrs.get(metadata_name, {})
        if not (isinstance(stored_md, dict) and stored_md):
            continue
        if hasattr(getattr(mt_obj, metadata_name), "from_dict"):
            getattr(mt_obj, metadata_name).from_dict(stored_md)
    return mt_obj
@staticmethod
def _pick_channel_labels(
available: list[Any], candidates: list[str], required: int
) -> list[Any] | None:
"""Pick channel labels from available coordinates using preferred names."""
channel_map = {str(label).lower(): label for label in available}
selected: list[Any] = []
for candidate in candidates:
key = candidate.lower()
if key in channel_map and channel_map[key] not in selected:
selected.append(channel_map[key])
if len(selected) == required:
return selected
return None
@staticmethod
def _coerce_epsg_value(value: Any) -> str | None:
"""Normalize CRS/EPSG values to a dataframe-compatible EPSG string."""
if value is None:
return None
if isinstance(value, (int, np.integer)):
return str(int(value))
try:
from pyproj import CRS
epsg = CRS.from_user_input(value).to_epsg()
if epsg is not None:
return str(int(epsg))
except Exception:
pass
value_str = str(value).strip()
if not value_str:
return None
if value_str.isdigit():
return str(int(value_str))
return value_str
def _station_dataset_to_dataframe(
    self,
    station_ds: xr.Dataset,
    utm_crs: Any | None = None,
    cols: list[str] | None = None,
    impedance_units: str = "mt",
) -> pd.DataFrame:
    """Convert one station dataset into per-period dataframe rows.

    Parameters
    ----------
    station_ds : xr.Dataset
        Station dataset with a ``period`` coordinate. Impedance and tipper
        values are only extracted when both ``output`` and ``input``
        coordinates are present.
    utm_crs : Any, optional
        Overrides the dataset's ``utm_crs`` attr for the ``utm_epsg``
        field when not ``None``.
    cols : list[str], optional
        Column subset to return; ``None`` returns every column.
    impedance_units : str, optional
        Units handed to the ``Z`` object, by default ``'mt'``.

    Returns
    -------
    pandas.DataFrame
        One row per period.
    """
    from .transfer_function import Tipper, Z

    period = np.asarray(station_ds.coords["period"].values, dtype=float)
    station_df = MTDataFrame(n_entries=period.size)
    attrs = station_ds.attrs

    # Location fields are assigned in the same order as the table columns.
    for name, default in (
        ("survey", ""),
        ("station", ""),
        ("latitude", 0.0),
        ("longitude", 0.0),
        ("elevation", 0.0),
    ):
        setattr(station_df, name, attrs.get(name, default))
    station_df.datum_epsg = self._coerce_epsg_value(attrs.get("datum_crs"))
    station_df.east = attrs.get("easting", 0.0)
    station_df.north = attrs.get("northing", 0.0)
    station_df.utm_epsg = self._coerce_epsg_value(
        attrs.get("utm_crs") if utm_crs is None else utm_crs
    )
    for name in ("model_east", "model_north", "model_elevation", "profile_offset"):
        setattr(station_df, name, attrs.get(name, 0.0))

    station_df.dataframe.loc[:, "period"] = period

    def _component_values(outputs: list[Any], inputs: list[Any]) -> tuple:
        """Return (values, error, model_error) arrays for a channel block."""
        return tuple(
            station_ds[var_name].sel(output=outputs, input=inputs).values
            for var_name in (
                "transfer_function",
                "transfer_function_error",
                "transfer_function_model_error",
            )
        )

    coords = station_ds.coords
    if "output" in coords and "input" in coords:
        output_labels = list(coords["output"].values)
        input_labels = list(coords["input"].values)

        # Impedance: two electric outputs against two magnetic inputs.
        z_outputs = self._pick_channel_labels(
            output_labels, ["ex", "ey", "x", "y"], 2
        )
        z_inputs = self._pick_channel_labels(
            input_labels, ["hx", "hy", "x", "y"], 2
        )
        if z_outputs is not None and z_inputs is not None:
            tf, tf_error, tf_model_error = _component_values(z_outputs, z_inputs)
            station_df.from_z_object(
                Z(
                    z=tf,
                    z_error=tf_error,
                    frequency=1.0 / period,
                    z_model_error=tf_model_error,
                    units=impedance_units,
                )
            )

        # Tipper: the vertical magnetic output against two magnetic inputs.
        t_output = self._pick_channel_labels(output_labels, ["hz", "z"], 1)
        t_inputs = self._pick_channel_labels(
            input_labels, ["hx", "hy", "x", "y"], 2
        )
        if t_output is not None and t_inputs is not None:
            tipper, tipper_error, tipper_model_error = _component_values(
                t_output, t_inputs
            )
            station_df.from_t_object(
                Tipper(
                    tipper=tipper,
                    tipper_error=tipper_error,
                    frequency=1.0 / period,
                    tipper_model_error=tipper_model_error,
                )
            )

    full_frame = station_df.dataframe
    return full_frame if cols is None else full_frame.loc[:, cols]
[docs]
def to_mt_stations(self) -> "MTStations":
    """Create an :class:`MTStations` object from the current station table.

    Returns
    -------
    MTStations
        Built from ``self.utm_epsg``, ``self.datum_epsg`` and
        ``self.station_locations`` as they are at call time.
    """
    from .mt_stations import MTStations

    utm_epsg = self.utm_epsg
    datum_epsg = self.datum_epsg
    locations = self.station_locations
    return MTStations(
        utm_epsg,
        datum_epsg=datum_epsg,
        station_locations=locations,
    )
@property
def center_point(self) -> Any:
    """Return the center location of the station collection.

    When both a center latitude and longitude are stored on this object,
    an ``MTLocation`` is built from the stored latitude, longitude and
    elevation, plus the tree-level ``utm_epsg``/``datum_crs`` attrs when
    those are set. Its model coordinates are copied from its own
    ``east``/``north`` and the stored elevation. Otherwise the center is
    taken from :meth:`to_mt_stations`.

    Returns
    -------
    MTLocation
        Center location of the collection.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # (add stations first)
    >>> cp = tree.center_point
    >>> print(cp.latitude, cp.longitude)
    """
    from .mt_location import MTLocation

    if self._center_lat is None or self._center_lon is None:
        # No explicit center stored: derive it from the station table.
        return self.to_mt_stations().center_point

    unset_markers = [None, "", "None", "none", "null"]

    center_location = MTLocation()
    center_location.latitude = self._center_lat
    center_location.longitude = self._center_lon
    center_location.elevation = self._center_elev

    # Assignment order is kept: UTM EPSG first, then the datum.
    for attr_name in ("utm_epsg", "datum_crs"):
        stored = self.attrs.get(attr_name)
        if stored not in unset_markers:
            setattr(center_location, attr_name, stored)

    center_location.model_east = center_location.east
    center_location.model_north = center_location.north
    center_location.model_elevation = self._center_elev
    return center_location
def _dataframe_with_relative_locations(
    self,
    utm_crs: Any | None = None,
    impedance_units: str = "mt",
) -> pd.DataFrame:
    """Return the station dataframe with model-relative coordinates filled.

    A copy of :meth:`to_dataframe` is returned unchanged when it is empty,
    when ``model_east`` or ``model_north`` already holds a non-zero value,
    or when ``east`` or ``north`` is zero throughout. Otherwise the model
    coordinates are computed as offsets from :attr:`center_point`.

    Parameters
    ----------
    utm_crs : pyproj CRS or int, optional
        Override UTM CRS passed to :meth:`to_dataframe`, by default
        ``None``.
    impedance_units : str, optional
        Units for the impedance tensor, e.g. ``'mt'`` or ``'ohm'``,
        by default ``'mt'``.

    Returns
    -------
    pandas.DataFrame
        Station dataframe, with ``model_east``, ``model_north`` and
        ``model_elevation`` overwritten when they were computed here.

    Raises
    ------
    ValueError
        If relative coordinates have to be computed but
        :attr:`center_point` has no UTM EPSG.
    """
    df = self.to_dataframe(utm_crs=utm_crs, impedance_units=impedance_units).copy()
    if df.empty:
        return df

    def _numeric(column_name: str) -> pd.Series:
        """Column as numbers; unparsable and missing entries become 0.0."""
        return pd.to_numeric(df.get(column_name), errors="coerce").fillna(0.0)

    model_east = _numeric("model_east")
    model_north = _numeric("model_north")
    if not np.allclose(model_east, 0.0) or not np.allclose(model_north, 0.0):
        # Model coordinates are already populated.
        return df

    east = _numeric("east")
    north = _numeric("north")
    if np.allclose(east, 0.0) or np.allclose(north, 0.0):
        # No usable absolute coordinates to derive offsets from.
        return df

    center = self.center_point
    if center.utm_epsg is None:
        raise ValueError(
            "Need to input data UTM EPSG or CRS to compute relative station locations"
        )

    df.loc[:, "model_east"] = east - center.east
    df.loc[:, "model_north"] = north - center.north
    df.loc[:, "model_elevation"] = _numeric("elevation") - center.elevation
    return df
[docs]
def to_geo_df(
    self,
    model_locations: bool = False,
    data_type: str = "station_locations",
) -> Any:
    """Build a GeoDataFrame of point geometries for GIS workflows.

    Parameters
    ----------
    model_locations : bool, optional
        If ``True``, use ``model_east``/``model_north`` as geometry and
        leave the CRS unset. Otherwise use longitude/latitude.
    data_type : str, optional
        Which table to export:

        * ``'station_locations'`` or ``'stations'`` - station table;
        * ``'phase_tensor'`` or ``'pt'`` - phase-tensor view;
        * ``'tipper'`` or ``'t'`` - tipper view;
        * ``'both'`` or ``'shapefiles'`` - combined shapefile view.

    Returns
    -------
    geopandas.GeoDataFrame
        GeoDataFrame with point geometries. For geographic output the CRS
        is taken from the first usable ``datum_epsg`` entry, if any.

    Raises
    ------
    ImportError
        If geopandas is not installed.
    ValueError
        If *data_type* is unsupported.
    """
    try:
        import geopandas as gpd
    except ImportError as exc:
        raise ImportError(
            "geopandas is required for to_geo_df but is not installed"
        ) from exc

    # (accepted aliases, MTDataFrame view attribute)
    frame_views = (
        (("phase_tensor", "pt"), "phase_tensor"),
        (("tipper", "t"), "tipper"),
        (("both", "shapefiles"), "for_shapefiles"),
    )
    if data_type in ("station_locations", "stations"):
        df = self.station_locations
    else:
        for aliases, view_name in frame_views:
            if data_type in aliases:
                df = getattr(self.to_mt_dataframe(), view_name)
                break
        else:
            raise ValueError(f"Option for 'data_type' {data_type} is unsupported.")

    if model_locations:
        return gpd.GeoDataFrame(
            df,
            geometry=gpd.points_from_xy(df.model_east, df.model_north),
            crs=None,
        )

    crs_value = None
    if "datum_epsg" in df.columns:
        normalized = (
            self._coerce_epsg_value(raw) for raw in df["datum_epsg"].tolist()
        )
        first_epsg = next((code for code in normalized if code is not None), None)
        if first_epsg is not None:
            # Bare digit codes get the "EPSG:" authority prefix.
            crs_value = (
                f"EPSG:{first_epsg}" if str(first_epsg).isdigit() else first_epsg
            )

    return gpd.GeoDataFrame(
        df,
        geometry=gpd.points_from_xy(df.longitude, df.latitude),
        crs=crs_value,
    )
[docs]
def to_shp_pt_tipper(
    self,
    save_dir: str | Path,
    output_crs: Any | None = None,
    utm: bool = False,
    pt: bool = True,
    tipper: bool = True,
    periods: np.ndarray | None = None,
    period_tol: float | None = None,
    ellipse_size: float | None = None,
    arrow_size: float | None = None,
) -> dict[str, list[str]]:
    """Write phase-tensor and tipper shapefiles with ``ShapefileCreator``.

    Parameters
    ----------
    save_dir : str or pathlib.Path
        Output directory for shapefiles.
    output_crs : Any, optional
        Output coordinate reference system.
    utm : bool, optional
        If ``True``, export in UTM coordinates.
    pt : bool, optional
        If ``True``, write phase-tensor shapefiles.
    tipper : bool, optional
        If ``True``, write tipper shapefiles.
    periods : numpy.ndarray, optional
        Periods to export. When ``None``, use all available periods.
    period_tol : float, optional
        Period matching tolerance.
    ellipse_size : float, optional
        Phase-tensor ellipse size. When ``None`` and *pt* is ``True``, the
        size is estimated by the shapefile creator.
    arrow_size : float, optional
        Tipper arrow size. When ``None`` and *tipper* is ``True``, the size
        is estimated by the shapefile creator.

    Returns
    -------
    dict[str, list[str]]
        Mapping of output type to written shapefile paths.

    Notes
    -----
    For mixed station period sampling, interpolate first so all stations
    share a common period set.
    """
    from mtpy.gis.shapefile_creator import ShapefileCreator

    creator = ShapefileCreator(self.to_mt_dataframe(), output_crs, save_dir=save_dir)
    creator.utm = utm

    # Sizes are only estimated for the products that will be written.
    estimate_ellipse = ellipse_size is None and pt
    creator.ellipse_size = (
        creator.estimate_ellipse_size() if estimate_ellipse else ellipse_size
    )
    estimate_arrow = arrow_size is None and tipper
    creator.arrow_size = (
        creator.estimate_arrow_size() if estimate_arrow else arrow_size
    )

    return creator.make_shp_files(
        pt=pt,
        tipper=tipper,
        periods=periods,
        period_tol=period_tol,
    )
@property
def station_locations(self) -> pd.DataFrame:
    """Station-location table for every station in the collection.

    Rows come from the attached index records when an index exists,
    otherwise from the station dataset attrs. An empty frame with the
    standard columns is returned when there are no stations.
    """
    columns = self._station_locations_columns()

    if self._index is None:
        station_paths = self._iter_station_paths()
        if not station_paths:
            return pd.DataFrame(columns=columns)
        rows = [self._station_location_record(path) for path in station_paths]
        return pd.DataFrame(rows, columns=columns)

    records = self._index.all_station_records()
    if not records:
        return pd.DataFrame(columns=columns)

    # table column -> attribute name on an index record
    record_fields = {
        "survey": "survey_name",
        "station": "name",
        "latitude": "latitude",
        "longitude": "longitude",
        "elevation": "elevation",
        "datum_epsg": "datum_epsg",
        "east": "east",
        "north": "north",
        "utm_epsg": "utm_epsg",
        "model_east": "model_east",
        "model_north": "model_north",
        "model_elevation": "model_elevation",
        "profile_offset": "profile_offset",
    }
    rows = [
        {column: getattr(record, field) for column, field in record_fields.items()}
        for record in records
    ]
    return pd.DataFrame(rows, columns=columns)
@property
def mt_stations(self) -> "MTStations":
    """Property form of :meth:`to_mt_stations`; a new object on each access."""
    stations = self.to_mt_stations()
    return stations
def _sync_station_locations_from_mt_stations(
    self,
    mt_stations: "MTStations",
) -> None:
    """Copy location values from an ``MTStations`` object back into the tree.

    Rows are matched to stations by cleaned ``(survey, station)`` names;
    unmatched rows are ignored. Nothing is changed when the station table
    is ``None`` or empty. After the per-station attrs, the tree-level CRS
    attrs, rotation angle and stored center are refreshed, and the index is
    rebuilt when one is attached.
    """
    self.compute()
    station_df = mt_stations.station_locations
    if station_df is None or station_df.empty:
        return

    def _lookup_key(survey: Any, station: Any) -> tuple[str, str]:
        """Cleaned (survey, station) pair used to match rows to tree paths."""
        return (
            self._clean_name(survey, "default"),
            self._clean_name(station, "unknown_station"),
        )

    key_to_path: dict[tuple[str, str], str] = {}
    for station_path in self._iter_station_paths():
        stored = self.get_station(station_path).attrs
        key_to_path[_lookup_key(stored.get("survey"), stored.get("station"))] = (
            station_path
        )

    # (dataset attr key, station-table column); a column missing from the
    # row leaves the existing attr value in place.
    direct_fields = (
        ("latitude", "latitude"),
        ("longitude", "longitude"),
        ("elevation", "elevation"),
        ("easting", "east"),
        ("northing", "north"),
        ("model_east", "model_east"),
        ("model_north", "model_north"),
        ("model_elevation", "model_elevation"),
        ("profile_offset", "profile_offset"),
    )
    for row in station_df.itertuples(index=False):
        row_key = _lookup_key(
            getattr(row, "survey", None),
            getattr(row, "station", None),
        )
        station_path = key_to_path.get(row_key)
        if station_path is None:
            continue

        attrs = self.get_station(station_path).attrs
        for attr_key, column in direct_fields:
            attrs[attr_key] = getattr(row, column, attrs.get(attr_key))

        datum_epsg = self._coerce_epsg_value(getattr(row, "datum_epsg", None))
        if datum_epsg is not None:
            attrs["datum_crs"] = datum_epsg
        utm_epsg = self._coerce_epsg_value(getattr(row, "utm_epsg", None))
        if utm_epsg is not None:
            attrs["utm_crs"] = utm_epsg
            attrs["utm_epsg"] = utm_epsg

    tree_utm_epsg = mt_stations.utm_epsg
    if tree_utm_epsg is not None:
        self.attrs["utm_epsg"] = tree_utm_epsg
        self.attrs["utm_crs"] = tree_utm_epsg
    tree_datum_epsg = mt_stations.datum_epsg
    if tree_datum_epsg is not None:
        self.attrs["datum_crs"] = tree_datum_epsg

    self.data_rotation_angle = getattr(
        mt_stations,
        "rotation_angle",
        self.data_rotation_angle,
    )
    # Stored center values are only replaced when MTStations carries them.
    for center_attr in ("_center_lat", "_center_lon", "_center_elev"):
        setattr(
            self,
            center_attr,
            getattr(mt_stations, center_attr, getattr(self, center_attr)),
        )

    if self._index is not None:
        self.rebuild_index(index_db_path=self._index_db_path)
[docs]
def compute_relative_locations(self) -> None:
    """Compute model-relative coordinates via ``MTStations`` and store them.

    The work is delegated to ``MTStations.compute_relative_locations``; the
    updated table is then written back into the tree's station attrs.
    """
    station_table = self.to_mt_stations()
    station_table.compute_relative_locations()
    self._sync_station_locations_from_mt_stations(station_table)
[docs]
def rotate_stations(self, rotation_angle: float) -> None:
    """Rotate station locations via ``MTStations`` and store the result.

    Parameters
    ----------
    rotation_angle : float
        Angle forwarded to ``MTStations.rotate_stations``.
    """
    station_table = self.to_mt_stations()
    station_table.rotate_stations(rotation_angle)
    self._sync_station_locations_from_mt_stations(station_table)
[docs]
def center_stations(self, model_obj: Any) -> None:
    """Center stations on a model via ``MTStations`` and store the result.

    Parameters
    ----------
    model_obj : Any
        Model object forwarded to ``MTStations.center_stations``.
    """
    station_table = self.to_mt_stations()
    station_table.center_stations(model_obj)
    self._sync_station_locations_from_mt_stations(station_table)
[docs]
def project_stations_on_topography(
    self,
    model_object: Any,
    air_resistivity: float = 1e12,
    sea_resistivity: float = 0.3,
    ocean_bottom: bool = False,
) -> None:
    """Project stations onto model topography and store the result.

    All arguments are forwarded unchanged to
    ``MTStations.project_stations_on_topography``; the updated station
    table is then written back into the tree's station attrs.
    """
    station_table = self.to_mt_stations()
    projection_options = {
        "air_resistivity": air_resistivity,
        "sea_resistivity": sea_resistivity,
        "ocean_bottom": ocean_bottom,
    }
    station_table.project_stations_on_topography(model_object, **projection_options)
    self._sync_station_locations_from_mt_stations(station_table)
[docs]
def to_geopd(self) -> Any:
    """Return whatever ``MTStations.to_geopd`` produces for these stations."""
    station_table = self.to_mt_stations()
    return station_table.to_geopd()
[docs]
def to_shp(self, shp_fn: str | Path) -> str | Path:
    """Write station locations with ``MTStations.to_shp`` and return its result.

    Parameters
    ----------
    shp_fn : str or pathlib.Path
        Target file name forwarded to ``MTStations.to_shp``.
    """
    station_table = self.to_mt_stations()
    return station_table.to_shp(shp_fn)
[docs]
def to_csv(self, csv_fn: str | Path, geometry: bool = False) -> None:
    """Write station locations with ``MTStations.to_csv``.

    Parameters
    ----------
    csv_fn : str or pathlib.Path
        Target file name.
    geometry : bool, optional
        Forwarded to ``MTStations.to_csv``, by default ``False``.
    """
    station_table = self.to_mt_stations()
    station_table.to_csv(csv_fn, geometry=geometry)
[docs]
def to_vtk(
    self,
    vtk_fn: str | Path | None = None,
    vtk_save_path: str | Path | None = None,
    vtk_fn_basename: str = "ModEM_stations",
    geographic: bool = False,
    shift_east: float = 0,
    shift_north: float = 0,
    shift_elev: float = 0,
    units: str = "km",
    coordinate_system: str = "nez+",
) -> Path:
    """Write a station-location VTK file with ``MTStations.to_vtk``.

    Every argument is forwarded by keyword, unchanged; the return value of
    ``MTStations.to_vtk`` is passed back to the caller.
    """
    vtk_options = {
        "vtk_fn": vtk_fn,
        "vtk_save_path": vtk_save_path,
        "vtk_fn_basename": vtk_fn_basename,
        "geographic": geographic,
        "shift_east": shift_east,
        "shift_north": shift_north,
        "shift_elev": shift_elev,
        "units": units,
        "coordinate_system": coordinate_system,
    }
    station_table = self.to_mt_stations()
    return station_table.to_vtk(**vtk_options)
[docs]
def generate_profile(
    self,
    units: str = "deg",
) -> tuple[float, float, float, float, dict[str, float]]:
    """Return the profile computed by ``MTStations.generate_profile``.

    Parameters
    ----------
    units : str, optional
        Forwarded to ``MTStations.generate_profile``, by default ``'deg'``.
    """
    station_table = self.to_mt_stations()
    return station_table.generate_profile(units=units)
[docs]
def generate_profile_from_strike(
    self,
    strike: float,
    units: str = "deg",
) -> tuple[float, float, float, float, dict[str, float]]:
    """Return the profile from ``MTStations.generate_profile_from_strike``.

    Parameters
    ----------
    strike : float
        Strike value forwarded to ``MTStations``.
    units : str, optional
        Forwarded to ``MTStations``, by default ``'deg'``.
    """
    station_table = self.to_mt_stations()
    return station_table.generate_profile_from_strike(strike, units=units)
[docs]
def get_nearby_stations(
    self,
    station_key: str,
    radius: float,
    radius_units: str = "m",
) -> list[str]:
    """List stations within *radius* of a reference station.

    Distances are planar: computed from ``east``/``north`` for metric
    units and from ``longitude``/``latitude`` for degree units. Table
    coordinates that are missing or non-numeric are treated as ``0.0``.

    Parameters
    ----------
    station_key : str
        Reference station key as canonical tree path or ``survey.station``.
    radius : float
        Search radius in the units specified by *radius_units*.
    radius_units : {'m', 'meters', 'metres', 'deg', 'degrees'}, optional
        Distance units for *radius*.

    Returns
    -------
    list[str]
        Matching stations as ``survey.station`` keys. Stations at zero
        distance (which includes the reference station) are excluded.

    Raises
    ------
    KeyError
        If *station_key* cannot be resolved.
    ValueError
        If metric units are requested without UTM coordinate information,
        or if *radius_units* is unsupported.

    Examples
    --------
    >>> nearby = tree.get_nearby_stations("surveyA.station01", radius=5000)
    >>> nearby_deg = tree.get_nearby_stations("surveyA.station01", 0.1, "deg")
    """
    self.compute()
    station_path = self._resolve_station_path(station_key)
    local_attrs = self.get_station(station_path).attrs

    sdf = self.station_locations.copy()
    if sdf.empty:
        return []

    # (reference attr key, station-table column) for each of the two axes
    if radius_units in ["m", "meters", "metres"]:
        if "utm_epsg" not in sdf.columns or (
            sdf["utm_epsg"].replace("", np.nan).dropna().empty
        ):
            raise ValueError(
                "Cannot estimate distances in meters without a UTM CRS. Set 'utm_crs' first."
            )
        axes = (("easting", "east"), ("northing", "north"))
    elif radius_units in ["deg", "degrees"]:
        axes = (("longitude", "longitude"), ("latitude", "latitude"))
    else:
        raise ValueError(
            "radius_units must be one of: m, meters, metres, deg, degrees"
        )

    squared_offsets = [
        (
            float(local_attrs.get(attr_key, 0.0))
            - pd.to_numeric(sdf[column], errors="coerce").fillna(0.0)
        )
        ** 2
        for attr_key, column in axes
    ]
    sdf["radius"] = np.sqrt(squared_offsets[0] + squared_offsets[1])

    within_radius = (sdf["radius"] <= radius) & (sdf["radius"] > 0)
    return [
        f"{row.survey}.{row.station}"
        for row in sdf.loc[within_radius].itertuples()
    ]
[docs]
def get_profile(
    self,
    x1: float,
    y1: float,
    x2: float,
    y2: float,
    radius: float,
) -> "MTData":
    """Return a new tree with the stations along a profile line.

    Station selection is delegated to ``MTStations._extract_profile``.
    Selected rows are matched back to tree paths by their
    ``(survey, station)`` attrs, and each selected station that can be
    located in the new tree receives the row's ``profile_offset`` when the
    row has one.

    Parameters
    ----------
    x1, y1, x2, y2 : float
        Profile start and end coordinates in the same coordinate system as
        station locations.
    radius : float
        Corridor half-width around the profile line.

    Returns
    -------
    MTData
        New tree containing only the selected stations, or an empty clone
        when no station is selected.
    """
    self.compute()
    profile_stations = self.to_mt_stations()._extract_profile(x1, y1, x2, y2, radius)
    if profile_stations.empty:
        return self.clone_empty()

    key_to_path: dict[tuple[str, str], str] = {}
    for station_path in self._iter_station_paths():
        stored = self.get_station(station_path).attrs
        identity = (str(stored.get("survey", "")), str(stored.get("station", "")))
        key_to_path[identity] = station_path

    profile_rows = list(profile_stations.itertuples(index=False))

    matched = (
        key_to_path.get((str(row.survey), str(row.station))) for row in profile_rows
    )
    selected_paths = [path for path in matched if path is not None]
    profile_tree = self.get_subset(selected_paths)

    for row in profile_rows:
        # The subset is addressed by canonical path built from cleaned names.
        station_path = profile_tree._station_path(
            self._clean_name(getattr(row, "survey", None), "default"),
            self._clean_name(getattr(row, "station", None), "unknown_station"),
        )
        if profile_tree._path_exists(station_path) and hasattr(
            row, "profile_offset"
        ):
            profile_tree.get_station(station_path).attrs["profile_offset"] = float(
                row.profile_offset
            )
    return profile_tree
[docs]
def compute_model_errors(
self,
z_error_value: float | None = None,
z_error_type: str | None = None,
z_floor: bool | None = None,
t_error_value: float | None = None,
t_error_type: str | None = None,
t_floor: bool | None = None,
) -> None:
"""Recompute impedance and tipper model errors for all stations.
Parameters
----------
z_error_value, z_error_type, z_floor : optional
Overrides for impedance model-error settings.
t_error_value, t_error_type, t_floor : optional
Overrides for tipper model-error settings.
"""
self.compute()
if z_error_value is not None:
self.z_model_error.error_value = z_error_value
if z_error_type is not None:
self.z_model_error.error_type = z_error_type
if z_floor is not None:
self.z_model_error.floor = z_floor
if t_error_value is not None:
self.t_model_error.error_value = t_error_value
if t_error_type is not None:
self.t_model_error.error_type = t_error_type
if t_floor is not None:
self.t_model_error.floor = t_floor
for station_path in self._iter_station_paths():
station_ds = self.get_station(station_path)
attrs = dict(station_ds.attrs)
mt_obj = self._dataset_to_mt(station_ds)
self._hydrate_metadata_from_cache(mt_obj, station_ds)
mt_obj.compute_model_z_errors(**self.z_model_error.error_parameters)
mt_obj.compute_model_t_errors(**self.t_model_error.error_parameters)
out_ds = mt_obj._transfer_function
out_ds.attrs = attrs
self._set_station_dataset(station_path, out_ds)
[docs]
def estimate_starting_rho(self) -> None:
    """Estimate a starting resistivity from all stations and plot it.

    The determinant resistivity (``Z.res_det``) of every station is pooled
    per period. The per-period mean and median curves are drawn on a
    log-log plot together with dashed horizontal lines at the mean of the
    mean curve and the median of the median curve; both scalar values are
    reported in the legend.

    When the collection holds no (period, resistivity) samples a warning is
    logged and nothing is plotted.
    """
    import matplotlib.pyplot as plt

    self.compute()
    entries: list[dict[str, float]] = []
    for station_path in self._iter_station_paths():
        mt_obj = self.get_station(station_path, as_mt=True)
        entries.extend(
            {"period": period, "res_det": res_det}
            for period, res_det in zip(mt_obj.period, mt_obj.Z.res_det)
        )

    # Without samples the DataFrame has no "period" column and the groupby
    # below would fail with an opaque KeyError.
    if not entries:
        logger.warning(
            "No station data available; cannot estimate starting resistivity."
        )
        return

    res_df = pd.DataFrame(entries)
    grouped = res_df.groupby("period")
    mean_rho = grouped.mean()
    median_rho = grouped.median()
    overall_mean = mean_rho.res_det.mean()
    overall_median = median_rho.res_det.median()

    mean_color = (0.75, 0.25, 0)
    median_color = (0, 0.25, 0.75)
    label_font = {"size": 12, "weight": "bold"}

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    (l1,) = ax.loglog(mean_rho.index, mean_rho.res_det, lw=2, color=mean_color)
    (l2,) = ax.loglog(
        median_rho.index, median_rho.res_det, lw=2, color=median_color
    )
    # Dashed horizontal lines mark the single-value summaries.
    ax.loglog(
        mean_rho.index,
        np.repeat(overall_mean, mean_rho.shape[0]),
        ls="--",
        lw=2,
        color=mean_color,
    )
    ax.loglog(
        median_rho.index,
        np.repeat(overall_median, median_rho.shape[0]),
        ls="--",
        lw=2,
        color=median_color,
    )
    ax.set_xlabel("Period (s)", fontdict=label_font)
    ax.set_ylabel("Resistivity (Ohm-m)", fontdict=label_font)
    ax.legend(
        [l1, l2],
        [f"Mean = {overall_mean:.1f}", f"Median = {overall_median:.1f}"],
        loc="upper left",
    )
    ax.grid(which="both", ls="--", color=(0.75, 0.75, 0.75))
    ax.set_xlim((res_df.period.min(), res_df.period.max()))
    plt.show()
[docs]
def to_modem(self, data_filename: str | Path | None = None, **kwargs: Any) -> Any:
    """Build a ModEM ``Data`` object from the station collection.

    Parameters
    ----------
    data_filename : str or pathlib.Path, optional
        Path of the ModEM data file to write. With ``None`` (default)
        nothing is written.
    **kwargs
        Extra keyword arguments for :class:`mtpy.modeling.modem.Data`
        (e.g. ``rotation_angle``, ``inv_mode``, ``formatting``). They take
        precedence over entries of :attr:`model_parameters`.

    Returns
    -------
    mtpy.modeling.modem.Data
        Data object whose ``z_model_error`` and ``t_model_error`` are taken
        from this tree.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # tree.add_station(...)  # populate with MT objects first
    >>> tree.model_parameters = {"inv_mode": "1", "formatting": "1"}
    >>> modem_data = tree.to_modem(data_filename="ModEM_data.dat")
    >>> print(modem_data.center_point.latitude)
    """
    from mtpy.modeling.modem import Data

    # Stored model parameters form the base; explicit keywords win.
    data_options = {**self.model_parameters, **kwargs}

    station_df = self._dataframe_with_relative_locations(
        impedance_units=self.impedance_units
    )
    if station_df.empty:
        # Fall back to the plain station dataframe.
        station_df = self.to_dataframe(impedance_units=self.impedance_units)

    modem_data = Data(
        dataframe=station_df,
        center_point=self.center_point,
        **data_options,
    )
    modem_data.z_model_error = self.z_model_error
    modem_data.t_model_error = self.t_model_error

    if data_filename is None:
        return modem_data
    modem_data.write_data_file(file_name=data_filename)
    return modem_data
[docs]
def from_modem(
    self, data_filename: str | Path, survey: str = "data", **kwargs: Any
) -> None:
    """Fill the tree from an existing ModEM data file.

    Restores the station datasets, the impedance and tipper model-error
    settings, the data rotation angle, the center point (latitude,
    longitude, elevation, ``utm_epsg``, ``datum_crs``) and every model
    parameter whose key contains no dot.

    Parameters
    ----------
    data_filename : str or pathlib.Path
        Path to the ModEM ``.dat`` / ``.data`` file.
    survey : str, optional
        Survey label written to all imported stations, by default
        ``'data'``.
    **kwargs
        Extra keyword arguments for :class:`mtpy.modeling.modem.Data`.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> tree.from_modem("ModEM_data.dat", survey="line1")
    >>> print(tree.survey_ids)
    """
    from mtpy.modeling.modem import Data

    reader = Data(**kwargs)
    station_table = reader.read_data_file(data_filename)
    station_table.dataframe.loc[:, "survey"] = survey
    self.from_mt_dataframe(station_table)

    # Rebuild the error settings as fresh ModelErrors instances.
    for mode, settings_name in (
        ("impedance", "z_model_error"),
        ("tipper", "t_model_error"),
    ):
        file_settings = getattr(reader, settings_name)
        setattr(
            self,
            settings_name,
            ModelErrors(mode=mode, **file_settings.error_parameters),
        )

    self.data_rotation_angle = reader.rotation_angle

    center = reader.center_point
    self._center_lat = center.latitude
    self._center_lon = center.longitude
    self._center_elev = center.elevation
    self.attrs["utm_epsg"] = center.utm_epsg
    self.attrs["datum_crs"] = center.datum_crs

    # Dotted keys are skipped; only top-level parameters are kept.
    top_level_parameters = {}
    for key, value in reader.model_parameters.items():
        if "." not in key:
            top_level_parameters[key] = value
    self.model_parameters = top_level_parameters
[docs]
def to_occam2d(self, data_filename: str | Path | None = None, **kwargs: Any) -> Any:
    """Build an Occam2D data object from the station collection.

    Parameters
    ----------
    data_filename : str or pathlib.Path, optional
        Path of the Occam2D data file to write. With ``None`` (default)
        nothing is written.
    **kwargs
        Extra keyword arguments for
        :class:`mtpy.modeling.occam2d.Occam2DData` (e.g. ``model_mode``,
        ``profile_angle``, ``res_te_err``).

    Returns
    -------
    mtpy.modeling.occam2d.Occam2DData
        Data object holding the station dataframe. If its
        ``profile_origin`` is ``None`` after construction, it is set to the
        ``(east, north)`` of :attr:`center_point`.

    Notes
    -----
    Everything is derived from the station dataframe. Create the profile,
    interpolate, and estimate model errors on the tree before calling this
    method.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # tree.add_station(...)  # populate first
    >>> occam_data = tree.to_occam2d(
    ...     data_filename="OccamDataFile.dat", model_mode="5"
    ... )
    """
    from mtpy.modeling.occam2d import Occam2DData

    occam_data = Occam2DData(**kwargs)
    occam_data.dataframe = self.to_dataframe()

    origin_missing = occam_data.profile_origin is None
    if origin_missing:
        center = self.center_point
        occam_data.profile_origin = (center.east, center.north)

    if data_filename is None:
        return occam_data
    occam_data.write_data_file(data_filename)
    return occam_data
[docs]
def from_occam2d(
    self,
    data_filename: str | Path,
    file_type: str = "data",
    **kwargs: Any,
) -> None:
    """Fill the tree from an existing Occam2D data file.

    After reading, ``profile_origin``, ``profile_angle`` and ``model_mode``
    of the Occam2D object are copied into :attr:`model_parameters`.

    Parameters
    ----------
    data_filename : str or pathlib.Path
        Path to the Occam2D data file.
    file_type : str, optional
        ``'data'`` (default) labels every row with survey ``'data'``;
        ``'response'`` or ``'model'`` labels every row with survey
        ``'model'``. Any other value leaves the survey column untouched.
    **kwargs
        Extra keyword arguments for
        :class:`mtpy.modeling.occam2d.Occam2DData`.

    Examples
    --------
    Read a data file:

    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> tree.from_occam2d("OccamDataFile.dat")
    >>> print(tree.station_ids)

    Read a response / model file:

    >>> tree.from_occam2d("OccamResponse.dat", file_type="response")
    """
    from mtpy.modeling.occam2d import Occam2DData

    reader = Occam2DData(**kwargs)
    reader.read_data_file(data_filename)

    if file_type == "data":
        reader.dataframe["survey"] = "data"
    elif file_type == "response" or file_type == "model":
        reader.dataframe["survey"] = "model"

    self.from_dataframe(reader.dataframe)

    # Keep the profile geometry so it can be reused when writing later.
    for parameter in ("profile_origin", "profile_angle", "model_mode"):
        self.model_parameters[parameter] = getattr(reader, parameter)
[docs]
def to_simpeg_2d(self, **kwargs: Any) -> Any:
    """Build a SimPEG 2-D MT data object from the station collection.

    Parameters
    ----------
    **kwargs
        Extra keyword arguments for
        :class:`mtpy.modeling.simpeg.data_2d.Simpeg2DData`. Common options:

        include_elevation : bool
            Include station elevation in the receiver locations,
            by default ``True``.
        invert_te : bool
            Include TE-mode apparent resistivity and phase,
            by default ``True``.
        invert_tm : bool
            Include TM-mode apparent resistivity and phase,
            by default ``True``.

    Returns
    -------
    mtpy.modeling.simpeg.data_2d.Simpeg2DData
        SimPEG 2-D data object built from the station dataframe.

    Notes
    -----
    The station dataframe is requested with impedance units ``'ohm'``.
    Create the profile, interpolate, and estimate model errors on the tree
    before calling this method.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # tree.add_station(...)  # populate first
    >>> simpeg_2d = tree.to_simpeg_2d(invert_te=True, invert_tm=False)
    """
    from mtpy.modeling.simpeg.data_2d import Simpeg2DData

    station_df = self.to_dataframe(impedance_units="ohm")
    simpeg_data = Simpeg2DData(station_df, **kwargs)
    return simpeg_data
[docs]
def to_simpeg_3d(self, **kwargs: Any) -> Any:
    """Build a SimPEG 3-D MT data object from the station collection.

    Parameters
    ----------
    **kwargs
        Extra keyword arguments for
        :class:`mtpy.modeling.simpeg.data_3d.Simpeg3DData`. Common options:

        include_elevation : bool
            Include station elevation in the receiver locations,
            by default ``False``.
        geographic_coordinates : bool
            Use geographic (UTM) coordinates instead of model-relative
            coordinates, by default ``True``.
        invert_z_xx, invert_z_xy, invert_z_yx, invert_z_yy : bool
            Select the impedance tensor components to include,
            all default to ``True``.
        invert_t_zx, invert_t_zy : bool
            Select the tipper components to include,
            both default to ``True``.

    Returns
    -------
    mtpy.modeling.simpeg.data_3d.Simpeg3DData
        SimPEG 3-D data object built from the station dataframe.

    Notes
    -----
    The station dataframe is requested with impedance units ``'ohm'``.
    Interpolate and estimate model errors on the tree before calling this
    method.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # tree.add_station(...)  # populate first
    >>> simpeg_3d = tree.to_simpeg_3d(invert_z_yy=False, include_elevation=True)
    """
    from mtpy.modeling.simpeg.data_3d import Simpeg3DData

    station_df = self.to_dataframe(impedance_units="ohm")
    simpeg_data = Simpeg3DData(station_df, **kwargs)
    return simpeg_data
[docs]
def add_white_noise(self, value: float, inplace: bool = True) -> "MTData | None":
    """Add white noise to the transfer functions of every station.

    Each station is reconstructed as an MT object and its own
    ``add_white_noise`` is called with the normalised noise level. Useful
    for generating synthetic test datasets.

    Parameters
    ----------
    value : float
        Noise level as a decimal fraction (0-1) or as a percentage (>1).
        Values greater than 1 are divided by 100 (``10`` becomes ``0.10``).
    inplace : bool, optional
        With ``True`` (default) every noisy station is written back into
        this tree and ``None`` is returned. With ``False`` the noisy copies
        are collected in a new tree created by :meth:`clone_empty` and this
        tree is not written to.

    Returns
    -------
    MTData or None
        New tree of noisy stations when *inplace* is ``False``; otherwise
        ``None``.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # tree.add_station(...)  # populate first
    >>> tree.add_white_noise(5)  # adds 5 % noise in place
    >>> noisy_tree = tree.add_white_noise(0.05, inplace=False)
    """
    # Percent input (>1) is normalised to a fraction.
    noise_level = value / 100.0 if value > 1 else value
    station_paths = self._iter_station_paths()

    if not inplace:
        noisy_stations = []
        for station_path in station_paths:
            station = self.get_station(station_path, as_mt=True)
            noisy_stations.append(
                station.add_white_noise(noise_level, inplace=False)
            )
        noisy_tree = self.clone_empty()
        noisy_tree.add_stations(noisy_stations)
        return noisy_tree

    for station_path in station_paths:
        station = self.get_station(station_path, as_mt=True)
        station.add_white_noise(noise_level)
        # Re-adding the station replaces the stored dataset.
        self.add_station(station)
    return None
[docs]
def estimate_spatial_static_shift(
    self,
    station_key: str,
    radius: float,
    period_min: float,
    period_max: float,
    radius_units: str = "m",
    shift_tolerance: float = 0.15,
) -> tuple[float, float]:
    """Estimate static-shift scale factors from nearby stations.

    The factors are the ratio of the neighbours' median ``res_xy`` /
    ``res_yx`` to the target station's median, evaluated on the target
    station's periods that fall inside ``[period_min, period_max]``.

    Parameters
    ----------
    station_key : str
        Target station key.
    radius : float
        Neighbor search radius.
    period_min, period_max : float
        Inclusive period bounds used for the comparison.
    radius_units : str, optional
        Radius units passed to :meth:`get_nearby_stations`.
    shift_tolerance : float, optional
        Values strictly within ``1 +/- shift_tolerance`` are snapped to
        ``1.0``.

    Returns
    -------
    tuple[float, float]
        Estimated ``(sx, sy)`` static-shift factors. ``(1.0, 1.0)`` is
        returned when there are no nearby stations or when the target
        station has no period inside the requested window.
    """
    nearby_keys = self.get_nearby_stations(station_key, radius, radius_units)
    if len(nearby_keys) == 0:
        return 1.0, 1.0

    local_site = self.get_station(
        self._resolve_station_path(station_key),
        as_mt=True,
    )
    local_periods = np.asarray(local_site.period)
    in_window = (local_periods >= period_min) & (local_periods <= period_max)
    interp_periods = local_periods[in_window]
    if interp_periods.size == 0:
        # Nothing to compare: medians over an empty window would be NaN, so
        # report the neutral "no shift" factors like the no-neighbour case.
        return 1.0, 1.0

    nearby_paths = [self._resolve_station_path(key) for key in nearby_keys]
    neighbours = self.get_subset(nearby_paths)

    # Put the target and its neighbours on the same period grid.
    local_site = local_site.interpolate(interp_periods)
    neighbours.interpolate(interp_periods)
    neighbour_df = neighbours.to_dataframe()

    sx = np.nanmedian(neighbour_df.res_xy) / np.nanmedian(local_site.Z.res_xy)
    sy = np.nanmedian(neighbour_df.res_yx) / np.nanmedian(local_site.Z.res_yx)

    # Ratios close to one are treated as "no static shift".
    lower, upper = 1 - shift_tolerance, 1 + shift_tolerance
    if lower < sx < upper:
        sx = 1.0
    if lower < sy < upper:
        sy = 1.0
    return float(sx), float(sy)
@property
def n_stations(self) -> int:
    """Total number of stations in the collection.

    The count comes from the station index when one is attached, otherwise
    from the number of station paths in the tree.
    """
    self.compute()
    index = self._index
    if index is None:
        return len(self._iter_station_paths())
    return index.n_stations()
@property
def survey_ids(self) -> list[str]:
    """Unique survey IDs in the collection.

    With a station index attached, the names come from the index in the
    order it reports them. Otherwise they are taken from the second
    component of every station path with at least three ``/`` separators,
    de-duplicated in first-seen order so the result is stable between runs.
    """
    self.compute()
    if self._index is not None:
        return [row.name for row in self._index.all_surveys()]
    # dict.fromkeys de-duplicates while preserving insertion order; a plain
    # set would give a hash-dependent, run-to-run unstable ordering.
    return list(
        dict.fromkeys(
            path.split("/", 3)[1]
            for path in self._iter_station_paths()
            if path.count("/") >= 3
        )
    )
[docs]
def get_survey(self, survey_id: str) -> "MTData":
    """Return a subset tree holding the stations of one survey.

    Parameters
    ----------
    survey_id : str
        Survey identifier as it appears in the station tree paths.

    Returns
    -------
    MTData
        Tree containing every station stored under the selected survey.
    """
    self.compute()
    # The trailing slash keeps a survey from matching longer survey names
    # that merely share its prefix.
    survey_prefix = f"{self.SURVEYS_NODE}/{survey_id}/{self.STATIONS_NODE}/"
    matching_paths = []
    for tree_path in self._iter_station_paths():
        if tree_path.startswith(survey_prefix):
            matching_paths.append(tree_path)
    return self.get_subset(matching_paths)
[docs]
def add_station(
self,
mt_obj: "MT | str | Path | list[MT | str | Path]",
overwrite: bool = True,
dataset_copy_mode: str | None = None,
) -> str | list[str]:
"""
Add an MT object as a station node in the tree.
Node path pattern:
/surveys/{survey_id}/stations/{station_id}
Parameters
----------
mt_obj : mtpy.core.MT, str, Path, or list
MT object, filename, pathlib.Path, or a list of mixed supported
input types.
overwrite : bool, optional
If False, raise if station path already exists.
dataset_copy_mode : {'deep', 'shallow', 'none'}, optional
Dataset copy behavior for station transfer-function storage.
Returns
-------
str or list[str]
Station node path for scalar inputs or list of paths for list
inputs.
"""
self.compute()
if mt_obj is None:
raise TypeError("mt_obj cannot be None")
if isinstance(mt_obj, list):
return self.add_stations(
mt_obj,
overwrite=overwrite,
dataset_copy_mode=dataset_copy_mode,
)
(
station_path,
station,
station_ds,
metadata_objects,
) = self._coerce_and_prepare_station(
mt_obj,
dataset_copy_mode=dataset_copy_mode,
)
if self._path_exists(station_path) and not overwrite:
raise KeyError(f"Station path already exists: {station_path}")
self.tree[station_path] = xr.DataTree(name=station, dataset=station_ds)
self._commit_cached_metadata(
station_path,
metadata_objects["survey"],
metadata_objects["station"],
)
if self._index is not None:
station_row, period_row = MTDataTreeIndexStore._extract_rows(
station_path, station_ds
)
self._index.upsert_station(station_row)
if period_row is not None:
self._index.replace_station_period_rows(period_row)
self._index.refresh_survey_aggregates(station_row.survey_name)
return station_path
[docs]
def add_tf(
self,
tf: "MT | str | Path | list[MT | str | Path]",
**kwargs: Any,
) -> str | list[str]:
"""Alias for add_station to mirror MTData API."""
return self.add_station(tf, **kwargs)
[docs]
def add_stations(
self,
mt_objects: list["MT | str | Path"],
overwrite: bool = True,
dataset_copy_mode: str | None = None,
precomputed_attrs: list[dict[str, Any] | None] | None = None,
) -> list[str]:
"""
Bulk-add MT stations with optional precomputed attrs for fast ingest.
Parameters
----------
mt_objects : list
List of MT objects, filename strings, or Paths.
overwrite : bool, optional
If False, raise if a station path already exists.
dataset_copy_mode : {'deep', 'shallow', 'none'}, optional
Dataset copy behavior for station transfer-function storage.
precomputed_attrs : list[dict | None], optional
Optional attrs payload aligned by index with mt_objects. When
provided, these attrs are used directly and only canonical keys are
enforced (survey/station and metadata refs).
Returns
-------
list[str]
Inserted station paths.
"""
self.compute()
if mt_objects is None:
raise TypeError("mt_objects cannot be None")
if not isinstance(mt_objects, list):
raise TypeError("mt_objects must be a list")
if not mt_objects:
return []
if precomputed_attrs is not None:
if not isinstance(precomputed_attrs, list):
raise TypeError("precomputed_attrs must be a list when provided")
if len(precomputed_attrs) != len(mt_objects):
raise ValueError("precomputed_attrs must match mt_objects length")
prepared: list[tuple[str, str, xr.Dataset, dict[str, Any]]] = []
seen_paths: set[str] = set()
for index, mt_obj in enumerate(mt_objects):
attrs = None
if precomputed_attrs is not None:
attrs = precomputed_attrs[index]
if attrs is not None and not isinstance(attrs, dict):
raise TypeError("Each precomputed_attrs entry must be dict or None")
(
station_path,
station,
station_ds,
metadata_objects,
) = self._coerce_and_prepare_station(
mt_obj,
dataset_copy_mode=dataset_copy_mode,
precomputed_attrs=attrs,
)
if station_path in seen_paths and not overwrite:
raise KeyError(f"Station path already exists: {station_path}")
seen_paths.add(station_path)
if self._path_exists(station_path) and not overwrite:
raise KeyError(f"Station path already exists: {station_path}")
prepared.append((station_path, station, station_ds, metadata_objects))
parent_cache: dict[str, Any] = {}
inserted_paths: list[str] = []
for station_path, station, station_ds, metadata_objects in prepared:
parent_path, child_name = station_path.rsplit("/", 1)
parent_node = parent_cache.get(parent_path)
if parent_node is None:
try:
parent_node = self.tree[parent_path]
except KeyError:
_, survey_name, _ = parent_path.split("/", 2)
survey_path = f"{self.SURVEYS_NODE}/{survey_name}"
if not self._path_exists(survey_path):
self.tree[survey_path] = xr.DataTree(
name=survey_name,
dataset=xr.Dataset(),
)
if not self._path_exists(parent_path):
self.tree[parent_path] = xr.DataTree(
name=self.STATIONS_NODE,
dataset=xr.Dataset(),
)
parent_node = self.tree[parent_path]
parent_cache[parent_path] = parent_node
parent_node[child_name] = xr.DataTree(
name=station,
dataset=station_ds,
)
self._commit_cached_metadata(
station_path,
metadata_objects["survey"],
metadata_objects["station"],
)
inserted_paths.append(station_path)
if self._index is not None:
updated_surveys: set[str] = set()
for station_path, _station, station_ds, _meta in prepared:
station_row, period_row = MTDataTreeIndexStore._extract_rows(
station_path, station_ds
)
self._index.upsert_station(station_row)
if period_row is not None:
self._index.replace_station_period_rows(period_row)
updated_surveys.add(station_row.survey_name)
for sv in updated_surveys:
self._index.refresh_survey_aggregates(sv)
return inserted_paths
[docs]
def get_station(self, station_key: str, as_mt: bool = False) -> "xr.Dataset | MT":
    """Return one station as a dataset or reconstructed MT object.

    Parameters
    ----------
    station_key : str
        Station identifier in canonical tree-path, ``survey/station``, or
        ``survey.station`` form.
    as_mt : bool, optional
        If ``True``, convert the stored dataset to an ``MT`` object and
        hydrate its metadata from the metadata cache.

    Returns
    -------
    xarray.Dataset or MT
        Station dataset (default) or reconstructed MT object.

    Examples
    --------
    >>> ds = tree.get_station("surveys/surveyA/stations/st01")
    >>> ds = tree.get_station("surveyA/st01")
    >>> mt_obj = tree.get_station("surveys/surveyA/stations/st01", as_mt=True)
    """
    station_path = self._resolve_station_path(station_key)
    # Only this station needs to be realised before it is read.
    self.compute(station_paths=[station_path])
    station_ds = self.tree[station_path].ds
    if not as_mt:
        return station_ds
    mt_obj = self._dataset_to_mt(station_ds)
    self._hydrate_metadata_from_cache(mt_obj, station_ds)
    return mt_obj
[docs]
def remove_station(self, station_key: str) -> None:
    """Delete one station node along with its cached and indexed metadata.

    Any pending lazy transform for the station, its cached metadata and, if
    an index is attached, its index entry are dropped before the node is
    deleted from the tree.

    Parameters
    ----------
    station_key : str
        Station identifier in canonical tree-path, ``survey/station``, or
        ``survey.station`` form.
    """
    station_path = self._resolve_station_path(station_key)
    self.compute()

    self._lazy_station_transforms.pop(station_path, None)
    self._clear_cached_metadata(station_path)
    index = self._index
    if index is not None:
        index.delete_station_by_tree_path(station_path)

    parent_path, separator, child_name = station_path.rpartition("/")
    if separator:
        del self.tree[parent_path][child_name]
    else:
        # Path without a parent component: delete directly from the root.
        del self.tree[station_path]
[docs]
def get_subset(self, station_list: list[str]) -> "MTData":
    """Build a new tree containing only the selected stations.

    Parameters
    ----------
    station_list : list[str]
        Station keys to copy; each is resolved to its tree path first.

    Returns
    -------
    MTData
        New tree created with this tree's ``metadata_storage`` and
        ``attrs``, holding copies of the selected station datasets. In
        ``"cache"`` metadata mode the cached survey/station metadata of
        each station is carried over and the dataset's
        ``*_metadata_ref`` attrs are pointed at the path in the subset.
    """
    resolved_paths = [self._resolve_station_path(key) for key in station_list]
    subset = self.__class__(
        metadata_storage=self.metadata_storage,
        **dict(self.attrs),
    )
    carry_cached_metadata = self.metadata_storage == "cache"

    for source_path in resolved_paths:
        station_copy = self.get_station(source_path).copy()
        # The path in the subset is rebuilt from the dataset's own attrs.
        survey_name = self._clean_name(station_copy.attrs.get("survey"), "default")
        station_name = self._clean_name(
            station_copy.attrs.get("station"), "unknown_station"
        )
        target_path = self._station_path(survey_name, station_name)

        if carry_cached_metadata:
            for metadata_kind in ("survey", "station"):
                cached_md = self._metadata_cache[metadata_kind].get(source_path)
                if cached_md is not None:
                    subset._metadata_cache[metadata_kind][target_path] = cached_md
                    station_copy.attrs[f"{metadata_kind}_metadata_ref"] = target_path

        node_name = target_path.rsplit("/", 1)[-1]
        subset.tree[target_path] = xr.DataTree(name=node_name, dataset=station_copy)
    return subset
def _set_station_dataset(self, station_path: str, station_ds: xr.Dataset) -> None:
    """Insert or replace the dataset stored at a station tree path.

    Parameters
    ----------
    station_path : str
        Tree path of the station node; must contain at least one ``/``.
    station_ds : xarray.Dataset
        Dataset to store in the node.

    Notes
    -----
    When the parent node does not exist yet, the survey node and the
    stations container below it are created with empty datasets first.
    """
    parent_path, child_name = station_path.rsplit("/", 1)
    try:
        parent_node = self.tree[parent_path]
    except KeyError:
        _, survey_name, _ = parent_path.split("/", 2)
        survey_path = f"{self.SURVEYS_NODE}/{survey_name}"
        # Create whichever container nodes are missing, top-down.
        for node_path, node_name in (
            (survey_path, survey_name),
            (parent_path, self.STATIONS_NODE),
        ):
            if not self._path_exists(node_path):
                self.tree[node_path] = xr.DataTree(
                    name=node_name,
                    dataset=xr.Dataset(),
                )
        parent_node = self.tree[parent_path]
    parent_node[child_name] = xr.DataTree(name=child_name, dataset=station_ds)
@staticmethod
def _interpolate_station_dataset(
    station_ds: xr.Dataset,
    new_periods: np.ndarray,
    **kwargs: Any,
) -> xr.Dataset:
    """Interpolate a stored station dataset onto new periods.

    Parameters
    ----------
    station_ds : xarray.Dataset
        Stored station dataset.
    new_periods : ndarray
        Target periods; converted to a float array.
    **kwargs
        Forwarded to the dataset's ``tf.interpolate`` accessor method.

    Returns
    -------
    xarray.Dataset
        * If the dataset has no ``period`` coordinate or no data variables:
          a coordinate-only dataset with the other coordinates copied and
          ``period`` set to the target periods.
        * If the target is empty: a deep copy of the dataset sliced to zero
          periods.
        * Otherwise: the result of
          ``station_ds.tf.interpolate(..., inplace=False)``.

        The first two cases carry over the original ``attrs``.
    """
    periods = np.asarray(new_periods, dtype=float)
    original_attrs = dict(station_ds.attrs)

    if "period" not in station_ds.coords or not station_ds.data_vars:
        # Nothing to interpolate: rebuild the coordinates around the target.
        new_coords = {}
        for coord_name, coord in station_ds.coords.items():
            if coord_name != "period":
                new_coords[coord_name] = coord.values
        new_coords["period"] = periods
        result = xr.Dataset(coords=new_coords)
    elif periods.size == 0:
        # Keep the variables and dims but drop every period sample.
        result = station_ds.isel(period=slice(0, 0)).copy(deep=True)
    else:
        return station_ds.tf.interpolate(periods, inplace=False, **kwargs)

    result.attrs.update(original_attrs)
    return result
@staticmethod
def _rotate_station_dataset(
    station_ds: xr.Dataset,
    rotation_angle: float | np.ndarray,
    coordinate_reference_frame: str = "ned",
) -> xr.Dataset:
    """Rotate a station dataset via the dataset ``tf`` accessor.

    Parameters
    ----------
    station_ds : xarray.Dataset
        Stored station dataset.
    rotation_angle : float or ndarray
        Rotation angle handed to ``tf.rotate``.
    coordinate_reference_frame : str, optional
        Reference frame handed to ``tf.rotate``, by default ``"ned"``.

    Returns
    -------
    xarray.Dataset
        Rotated dataset. The input dataset is returned as-is when it has no
        ``transfer_function`` variable, no ``period`` coordinate or zero
        periods, and also when the accessor raises ``ValueError`` (a
        warning is logged in that case).
    """
    nothing_to_rotate = (
        "transfer_function" not in station_ds
        or "period" not in station_ds.coords
        or station_ds.sizes.get("period", 0) == 0
    )
    if nothing_to_rotate:
        return station_ds

    # Materialise dask arrays before rotation; the accessor uses .loc[]
    # assignments which xarray cannot apply to dask-backed arrays.
    try:
        return station_ds.load().tf.rotate(
            rotation_angle,
            coordinate_reference_frame=coordinate_reference_frame,
            inplace=False,
        )
    except ValueError as error:
        # Still best-effort, but no longer silent: the caller gets unrotated
        # data back, which must be visible in the log.
        logger.warning(
            "Rotation failed for station {}; dataset left unrotated: {}",
            station_ds.attrs.get("station", "unknown"),
            error,
        )
        return station_ds
[docs]
def rotate(
self,
rotation_angle: float | np.ndarray,
inplace: bool = True,
) -> "MTData" | None:
"""Rotate all station transfer functions.
Parameters
----------
rotation_angle : float or ndarray
Scalar rotation angle in degrees or per-period angle array.
inplace : bool, optional
If ``True`` (default), mutate this tree and return ``None``.
Returns
-------
MTData or None
Rotated copy when *inplace* is ``False``; otherwise ``None``.
"""
tree_obj = self
if not inplace:
tree_obj = self.__class__(
metadata_storage=self.metadata_storage,
dataset_copy_mode=self.dataset_copy_mode,
use_index=self._index is not None,
index_db_path=self._index_db_path,
**dict(self.attrs),
)
updated_surveys: set[str] = set()
for station_path in self._iter_station_paths():
station_ds = self.get_station(station_path)
crf = station_ds.attrs.get(
"coordinate_reference_frame",
self.attrs.get("coordinate_reference_frame", "ned"),
)
rotated_ds = self._rotate_station_dataset(
station_ds,
rotation_angle,
coordinate_reference_frame=crf,
)
tree_obj._set_station_dataset(station_path, rotated_ds)
if tree_obj.metadata_storage == "cache":
for metadata_kind in ["survey", "station"]:
cached_md = self._metadata_cache[metadata_kind].get(station_path)
if cached_md is not None:
tree_obj._metadata_cache[metadata_kind][
station_path
] = cached_md
if tree_obj._index is not None:
station_row, period_row = MTDataTreeIndexStore._extract_rows(
station_path,
rotated_ds,
)
tree_obj._index.upsert_station(station_row)
if period_row is not None:
tree_obj._index.replace_station_period_rows(period_row)
updated_surveys.add(station_row.survey_name)
if tree_obj._index is not None:
for survey_name in updated_surveys:
tree_obj._index.refresh_survey_aggregates(survey_name)
if not inplace:
return tree_obj
return None
[docs]
def interpolate(
    self,
    new_periods: np.ndarray,
    f_type: str = "period",
    inplace: bool = True,
    bounds_error: bool = True,
    **kwargs: Any,
) -> "MTData | None":
    """Interpolate all stations to a shared period grid.

    Parameters
    ----------
    new_periods : ndarray
        Target period array, or frequency array when *f_type* is
        ``'frequency'``/``'freq'``. Flattened to one dimension.
    f_type : {'frequency', 'freq', 'period', 'per'}, optional
        Specifies the meaning of *new_periods*.
    inplace : bool, optional
        If ``True`` (default), update this tree in place. Otherwise the
        interpolated stations are written to a new tree created with the
        same storage/index settings and ``attrs``.
    bounds_error : bool, optional
        If ``True``, clip the target periods to each station's native
        period range before interpolating.
    **kwargs
        Forwarded to station interpolation.

    Returns
    -------
    MTData or None
        Interpolated copy when *inplace* is ``False``; otherwise ``None``.

    Raises
    ------
    ValueError
        If *f_type* is unsupported.
    """
    if f_type not in ["frequency", "freq", "period", "per"]:
        raise ValueError(
            f"f_type must be either 'frequency' or 'period' not {f_type}"
        )
    target_periods = np.asarray(new_periods, dtype=float)
    if target_periods.ndim != 1:
        target_periods = target_periods.reshape(-1)
    if f_type in ["frequency", "freq"]:
        target_periods = 1.0 / target_periods

    tree_obj = self
    if not inplace:
        tree_obj = self.__class__(
            metadata_storage=self.metadata_storage,
            dataset_copy_mode=self.dataset_copy_mode,
            use_index=self._index is not None,
            index_db_path=self._index_db_path,
            **dict(self.attrs),
        )

    updated_surveys: set[str] = set()
    for station_path in self._iter_station_paths():
        station_ds = self.get_station(station_path)

        interp_periods = target_periods
        if bounds_error and "period" in station_ds.coords:
            station_periods = np.asarray(
                station_ds.coords["period"].values, dtype=float
            )
            if station_periods.size > 0:
                # Only keep targets inside this station's native range.
                interp_periods = target_periods[
                    (target_periods <= station_periods.max())
                    & (target_periods >= station_periods.min())
                ]

        interpolated_ds = self._interpolate_station_dataset(
            station_ds,
            interp_periods,
            **kwargs,
        )
        tree_obj._set_station_dataset(station_path, interpolated_ds)

        # Carry cached metadata over to the target tree (a no-op in place).
        if tree_obj.metadata_storage == "cache":
            for metadata_kind in ("survey", "station"):
                cached_md = self._metadata_cache[metadata_kind].get(station_path)
                if cached_md is not None:
                    tree_obj._metadata_cache[metadata_kind][station_path] = cached_md

        if tree_obj._index is not None:
            station_row, period_row = MTDataTreeIndexStore._extract_rows(
                station_path,
                interpolated_ds,
            )
            if period_row is None:
                # No period rows to replace with: delete the station entry
                # before it is re-inserted below.
                tree_obj._index.delete_station_by_tree_path(station_path)
            tree_obj._index.upsert_station(station_row)
            if period_row is not None:
                tree_obj._index.replace_station_period_rows(period_row)
            updated_surveys.add(station_row.survey_name)

    # Refresh per-survey aggregates once, after all stations are updated.
    if tree_obj._index is not None:
        for survey_name in updated_surveys:
            tree_obj._index.refresh_survey_aggregates(survey_name)

    return None if inplace else tree_obj
[docs]
def interpolate_lazy(
    self,
    new_periods: np.ndarray,
    f_type: str = "period",
    inplace: bool = False,
    bounds_error: bool = True,
    **kwargs: Any,
) -> "MTData":
    """Register deferred interpolation transforms for all stations.

    No interpolation happens here: for each station a callable is stored in
    ``_lazy_station_transforms`` that, when invoked, interpolates a shallow
    snapshot of the current station dataset onto the planned periods.

    Parameters
    ----------
    new_periods : ndarray
        Target period array, or frequency array when *f_type* indicates
        frequency input. Flattened to one dimension.
    f_type : {'frequency', 'freq', 'period', 'per'}, optional
        Specifies the meaning of *new_periods*.
    inplace : bool, optional
        If ``True``, clear and replace the lazy transforms on this
        instance. Otherwise return a new tree holding shallow copies of the
        station datasets with the lazy transforms attached.
    bounds_error : bool, optional
        If ``True``, clip the target periods to each station's native
        period range.
    **kwargs
        Forwarded to station interpolation at compute time.

    Returns
    -------
    MTData
        Tree with pending interpolation transforms (``self`` when
        *inplace* is ``True``).

    Raises
    ------
    ValueError
        If *f_type* is unsupported.
    """
    if f_type not in ("frequency", "freq", "period", "per"):
        raise ValueError(
            f"f_type must be either 'frequency' or 'period' not {f_type}"
        )

    # Lazy plans are built from realised source station datasets.
    self.compute()

    requested = np.asarray(new_periods, dtype=float)
    if requested.ndim != 1:
        requested = requested.reshape(-1)
    if f_type in ("frequency", "freq"):
        requested = 1.0 / requested

    if inplace:
        target_tree = self
        target_tree._lazy_station_transforms.clear()
    else:
        # The copy starts without an index; whether the source had one is
        # recorded on ``_lazy_use_index``.
        target_tree = self.__class__(
            metadata_storage=self.metadata_storage,
            dataset_copy_mode=self.dataset_copy_mode,
            use_index=False,
            index_db_path=self._index_db_path,
            **dict(self.attrs),
        )
        target_tree._lazy_use_index = self._index is not None

    for tree_path in self._iter_station_paths():
        source_ds = self.get_station(tree_path)

        planned_periods = requested
        if bounds_error and "period" in source_ds.coords:
            native_periods = np.asarray(
                source_ds.coords["period"].values,
                dtype=float,
            )
            if native_periods.size > 0:
                shortest, longest = native_periods.min(), native_periods.max()
                inside = (requested <= longest) & (requested >= shortest)
                planned_periods = requested[inside]

        source_snapshot = source_ds.copy(deep=False)
        period_snapshot = np.asarray(planned_periods, dtype=float)
        option_snapshot = dict(kwargs)

        # Default arguments freeze this iteration's values in the closure.
        def _transform(
            ds: "xr.Dataset" = source_snapshot,
            periods: "np.ndarray" = period_snapshot,
            op_kwargs: "dict[str, Any]" = option_snapshot,
        ) -> "xr.Dataset":
            return MTData._interpolate_station_dataset(ds, periods, **op_kwargs)

        target_tree._lazy_station_transforms[tree_path] = _transform

        if not inplace:
            target_tree._set_station_dataset(
                tree_path, source_snapshot.copy(deep=False)
            )
        if target_tree.metadata_storage == "cache":
            for metadata_kind in ("survey", "station"):
                cached_md = self._metadata_cache[metadata_kind].get(tree_path)
                if cached_md is not None:
                    target_tree._metadata_cache[metadata_kind][tree_path] = cached_md
    return target_tree
[docs]
def apply_bounding_box(
    self, lon_min: float, lon_max: float, lat_min: float, lat_max: float
) -> "MTData":
    """Select the stations located inside a longitude/latitude rectangle.

    When a station index is active the spatial lookup is delegated to it.
    Otherwise the ``station_locations`` table is filtered directly, with
    all four bounds treated as inclusive.

    Parameters
    ----------
    lon_min, lon_max : float
        Lower and upper longitude limits.
    lat_min, lat_max : float
        Lower and upper latitude limits.

    Returns
    -------
    MTData
        Result of :meth:`get_subset` for the matching station paths, or a
        new instance built only from ``self.attrs`` when no station
        location table is available.
    """
    # Indexed path: the index answers the bounding-box query itself.
    if self._index is not None:
        indexed_paths = self._index.query_station_paths(
            lon_min=lon_min, lon_max=lon_max, lat_min=lat_min, lat_max=lat_max
        )
        return self.get_subset(indexed_paths)

    locations = self.station_locations
    if locations is None or locations.empty:
        # No location table to filter: return a new instance built from attrs.
        return self.__class__(**dict(self.attrs))

    inside_lon = (locations.longitude >= lon_min) & (locations.longitude <= lon_max)
    inside_lat = (locations.latitude >= lat_min) & (locations.latitude <= lat_max)
    selected = locations.loc[inside_lon & inside_lat]

    # Translate (survey, station) rows into canonical tree paths.
    selected_paths = []
    for survey, station in zip(selected.survey, selected.station):
        survey_name = self._clean_name(survey, "default")
        station_name = self._clean_name(station, "unknown_station")
        selected_paths.append(self._station_path(survey_name, station_name))
    return self.get_subset(selected_paths)
[docs]
def rebuild_index(self, index_db_path: str = ":memory:") -> None:
    """
    Create a fresh station index from the tree and install it.

    The index becomes active even if none was in use before. If building
    the new index raises, the previously active index (possibly ``None``)
    is restored and the exception propagates.

    Parameters
    ----------
    index_db_path : str
        SQLite database path. Defaults to ``":memory:"`` (in-process).
    """
    previous_index = self._index
    # Detach the current index while rebuilding so that station-path
    # iteration walks the tree instead of consulting an index.
    self._index = None
    try:
        rebuilt_index = MTDataTreeIndexStore(index_db_path)
        rebuilt_index.rebuild_from_tree(self)
    except Exception:
        # Roll back to whatever was active before the attempt.
        self._index = previous_index
        raise
    else:
        self._index = rebuilt_index
[docs]
def query_station_paths(
self,
survey: str | None = None,
lat_min: float | None = None,
lat_max: float | None = None,
lon_min: float | None = None,
lon_max: float | None = None,
period_min: float | None = None,
period_max: float | None = None,
) -> list[str]:
"""
Return station tree paths matching filter criteria via the index.
Requires the index to be enabled (``use_index=True`` or after calling
:meth:`rebuild_index`).
Parameters
----------
survey, lat_min, lat_max, lon_min, lon_max, period_min, period_max
See :meth:`MTDataTreeIndexStore.query_station_paths`.
Returns
-------
list[str]
"""
self.compute()
if self._index is None:
raise RuntimeError(
"Index not enabled. Pass use_index=True to the constructor "
"or call rebuild_index() first."
)
return self._index.query_station_paths(
survey=survey,
lat_min=lat_min,
lat_max=lat_max,
lon_min=lon_min,
lon_max=lon_max,
period_min=period_min,
period_max=period_max,
)
[docs]
def to_dataframe(
self,
utm_crs: Any | None = None,
cols: list[str] | None = None,
impedance_units: str = "mt",
) -> pd.DataFrame:
"""Convert all stations to a concatenated pandas DataFrame.
Parameters
----------
utm_crs : Any, optional
CRS override used when exporting station locations.
cols : list[str], optional
Column subset to include.
impedance_units : str, optional
Impedance unit convention for exported transfer-function values.
Returns
-------
pandas.DataFrame
Concatenated station dataframe.
"""
self.compute()
station_paths = self._iter_station_paths()
df_list = []
for path in station_paths:
station_ds = self.get_station(path)
try:
df_list.append(
self._station_dataset_to_dataframe(
station_ds,
utm_crs=utm_crs,
cols=cols,
impedance_units=impedance_units,
)
)
except Exception:
# Fallback keeps behavior for unexpected/legacy dataset layouts.
df_list.append(
self._dataset_to_mt(station_ds)
.to_dataframe(
utm_crs=utm_crs,
cols=cols,
impedance_units=impedance_units,
)
.dataframe
)
if not df_list:
return pd.DataFrame()
return pd.concat(df_list, ignore_index=True)
[docs]
def to_mt_dataframe(
    self, utm_crs: Any | None = None, impedance_units: str = "mt"
) -> MTDataFrame:
    """Wrap the output of :meth:`to_dataframe` in an :class:`MTDataFrame`.

    Parameters
    ----------
    utm_crs : Any, optional
        CRS override forwarded to :meth:`to_dataframe`.
    impedance_units : str, optional
        Impedance unit convention forwarded to :meth:`to_dataframe`.

    Returns
    -------
    MTDataFrame
        MTDataFrame constructed from the concatenated station dataframe.
    """
    station_frame = self.to_dataframe(
        utm_crs=utm_crs, impedance_units=impedance_units
    )
    return MTDataFrame(station_frame)
[docs]
def from_dataframe(self, df: pd.DataFrame, impedance_units: str = "mt") -> None:
    """Populate the tree from a station dataframe.

    Rows are grouped by ``station`` (and by ``survey`` as well when that
    column exists), one MT object is built per group, and all objects are
    added with :meth:`add_stations`. An empty dataframe is a no-op.

    Rows whose grouping key is missing (NaN/None) are excluded by the
    pandas group-by; a warning reports how many rows were skipped.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe containing MT rows, grouped by station (and survey when
        available).
    impedance_units : str, optional
        Unit convention used by impedance values in *df*.
    """
    if df.empty:
        return

    # Local import, performed only once there is actual work to do.
    from .mt import MT

    group_cols = ["survey", "station"] if "survey" in df.columns else ["station"]

    # groupby() drops rows with a missing key by default; surface that
    # instead of letting the rows disappear without a trace.
    n_missing_key_rows = int(df[group_cols].isna().any(axis=1).sum())
    if n_missing_key_rows:
        logger.warning(
            "Skipping {} dataframe row(s) with missing values in {}.",
            n_missing_key_rows,
            group_cols,
        )

    mt_objects = []
    for _, station_df in df.groupby(group_cols, sort=False):
        mt_object = MT(period=station_df.period.unique())
        mt_object.from_dataframe(station_df, impedance_units=impedance_units)
        mt_objects.append(mt_object)

    # All rows may have been skipped above, so the list can still be empty.
    if mt_objects:
        self.add_stations(mt_objects)
[docs]
def from_mt_dataframe(
    self, mt_df: MTDataFrame, impedance_units: str = "mt"
) -> None:
    """Load stations from an :class:`MTDataFrame`.

    Thin wrapper that unwraps ``mt_df.dataframe`` and hands it to
    :meth:`from_dataframe`.

    Parameters
    ----------
    mt_df : MTDataFrame
        Input MTDataFrame.
    impedance_units : str, optional
        Unit convention used by impedance values.
    """
    wrapped_frame = mt_df.dataframe
    self.from_dataframe(wrapped_frame, impedance_units=impedance_units)
[docs]
def get_periods(self) -> np.ndarray:
    """Collect every distinct period found anywhere in the tree.

    Pending work is realized with :meth:`compute`, then every node
    reachable from ``self.tree`` is visited. Nodes whose ``ds`` attribute
    is an ``xr.Dataset`` with a ``period`` coordinate contribute their
    period values.

    Returns
    -------
    numpy.ndarray
        One-dimensional float array of unique periods in ascending order;
        empty when no node carries a ``period`` coordinate.
    """
    self.compute()

    collected: list[np.ndarray] = []
    # Explicit stack instead of recursion; visiting order is irrelevant
    # because the values are de-duplicated and sorted afterwards.
    pending = [self.tree]
    while pending:
        node = pending.pop()
        dataset = getattr(node, "ds", None)
        if isinstance(dataset, xr.Dataset) and "period" in dataset.coords:
            collected.append(
                np.asarray(dataset.coords["period"].values, dtype=float)
            )
        pending.extend(getattr(node, "children", {}).values())

    if not collected:
        return np.array([], dtype=float)
    # np.unique already returns its result sorted in ascending order.
    return np.unique(np.concatenate(collected))
[docs]
def keys(self) -> list[str]:
    """List the names of the nodes directly below the tree root.

    Returns
    -------
    list[str]
        Keys of ``self.tree.children``, in that mapping's iteration order.
    """
    root_children = self.tree.children
    return [child_name for child_name in root_children.keys()]
def _resolve_plot_station_key(
self,
station_key: str | None = None,
station_id: str | None = None,
survey_id: str | None = None,
) -> str:
"""Resolve plotting selectors to one canonical station tree path.
Parameters
----------
station_key : str, optional
Canonical station path or alternate station key accepted by
:meth:`_resolve_station_path`.
station_id : str, optional
Station identifier.
survey_id : str, optional
Survey identifier used to disambiguate duplicate station IDs.
Returns
-------
str
Canonical station tree path.
Raises
------
ValueError
If both *station_key* and *station_id* are missing, or if
*station_id* is ambiguous without *survey_id*.
KeyError
If no matching station can be resolved.
"""
if station_key is not None:
return self._resolve_station_path(station_key)
if station_id is None:
raise ValueError("Provide station_key or station_id")
station_name = self._clean_name(station_id, "unknown_station")
if survey_id is not None:
survey_name = self._clean_name(survey_id, "default")
return self._resolve_station_path(
self._station_path(survey_name, station_name)
)
matches = [
path
for path in self.station_paths
if path.rsplit("/", 1)[-1] == station_name
]
if len(matches) == 1:
return matches[0]
if len(matches) == 0:
raise KeyError(
"Station key not found for station_id without survey_id: "
f"{station_id}"
)
raise ValueError(
"Multiple stations matched station_id. "
"Provide survey_id to disambiguate."
)
[docs]
def plot_mt_response(
    self,
    station_key: str | list[str] | None = None,
    station_id: str | list[str] | None = None,
    survey_id: str | list[str] | None = None,
    **kwargs: Any,
) -> PlotMultipleResponses | Any:
    """
    Plot the MT response of a single station or of several stations.

    A list/tuple *station_key* takes precedence, then a list/tuple
    *station_id*; both produce a multi-station plot. Any other selector
    combination is resolved to one station.

    Parameters
    ----------
    station_key : str, list of str, optional
        Station key(s) in canonical or accepted alternate form.
    station_id : str, list of str, optional
        Station ID(s). When provided without *survey_id*, each station ID
        must be unique across surveys.
    survey_id : str, list of str, optional
        Survey ID(s) used with *station_id*. If list-valued, must align
        one-to-one with *station_id*; a single value is applied to every
        station ID.
    **kwargs : dict
        Additional plotting keyword arguments.

    Returns
    -------
    PlotMultipleResponses or plot object
        Multi-station response plot for list-valued selectors, otherwise
        the object returned by the station's own ``plot_mt_response``.

    Raises
    ------
    ValueError
        If list-valued *survey_id* does not match list-valued
        *station_id*, or if station selection is ambiguous.
    KeyError
        If a requested station cannot be resolved.

    Examples
    --------
    >>> tree.plot_mt_response(station_key="survey_a/st01")
    >>> tree.plot_mt_response(station_id="st01", survey_id="survey_a")
    >>> tree.plot_mt_response(
    ...     station_id=["st01", "st02"],
    ...     survey_id=["survey_a", "survey_b"],
    ... )
    """
    multi_types = (list, tuple)

    # Several explicit station keys.
    if isinstance(station_key, multi_types):
        resolved_paths = [self._resolve_station_path(key) for key in station_key]
        return PlotMultipleResponses(self.get_subset(resolved_paths), **kwargs)

    # Several station IDs, optionally paired with survey IDs.
    if isinstance(station_id, multi_types):
        requested_stations = list(station_id)
        if isinstance(survey_id, multi_types):
            requested_surveys = list(survey_id)
            if len(requested_surveys) != len(requested_stations):
                raise ValueError("Number of survey must match number of stations")
        else:
            # One survey (or None) shared by every requested station.
            requested_surveys = [survey_id] * len(requested_stations)

        resolved_paths = []
        for survey, station in zip(requested_surveys, requested_stations):
            resolved_paths.append(
                self._resolve_plot_station_key(station_id=station, survey_id=survey)
            )
        return PlotMultipleResponses(self.get_subset(resolved_paths), **kwargs)

    # Single station.
    single_path = self._resolve_plot_station_key(
        station_id=station_id,
        survey_id=survey_id,
        station_key=station_key,
    )
    return self.get_station(single_path, as_mt=True).plot_mt_response(**kwargs)
[docs]
def plot_stations(
    self,
    map_epsg: int = 4326,
    bounding_box: tuple[float, float, float, float] | None = None,
    model_locations: bool = False,
    **kwargs: Any,
) -> PlotStations:
    """
    Draw the station locations on a map.

    Parameters
    ----------
    map_epsg : int, optional
        EPSG code forwarded to :class:`PlotStations` as ``map_epsg``.
    bounding_box : tuple of float, optional
        Optional ``(lon_min, lon_max, lat_min, lat_max)`` used to subset
        stations with :meth:`apply_bounding_box` before plotting.
    model_locations : bool, optional
        Forwarded to :meth:`to_geo_df`. When true, ``plot_cx`` is forced
        to ``False`` in the plotting keyword arguments.
    **kwargs : dict
        Additional plotting keyword arguments.

    Returns
    -------
    PlotStations
        Station plot object

    Raises
    ------
    ValueError
        If *bounding_box* is provided and does not contain four values.

    Examples
    --------
    >>> tree.plot_stations()
    >>> tree.plot_stations(map_epsg=3857)
    >>> tree.plot_stations(bounding_box=(-121.5, -120.0, 36.5, 38.0))
    """
    if bounding_box is None:
        plotted_data = self
    else:
        if len(bounding_box) != 4:
            raise ValueError(
                "bounding_box must be (lon_min, lon_max, lat_min, lat_max)"
            )
        lon_min, lon_max, lat_min, lat_max = bounding_box
        plotted_data = self.apply_bounding_box(lon_min, lon_max, lat_min, lat_max)

    station_gdf = plotted_data.to_geo_df(model_locations=model_locations)

    if model_locations:
        # Always overrides any caller-supplied plot_cx value.
        kwargs["plot_cx"] = False
    kwargs.setdefault("map_epsg", map_epsg)
    return PlotStations(station_gdf, **kwargs)
[docs]
def plot_strike(self, **kwargs: Any) -> PlotStrike:
    """
    Create a strike-angle plot for this container.

    Parameters
    ----------
    **kwargs : dict
        Keyword arguments passed straight through to :class:`PlotStrike`.

    Returns
    -------
    PlotStrike
        Strike plot object constructed with ``self`` as its data source.

    Examples
    --------
    >>> tree.plot_strike()
    >>> tree.plot_strike(show_plot=False)
    """
    strike_plot = PlotStrike(self, **kwargs)
    return strike_plot
[docs]
def plot_phase_tensor(
self,
station_key: str | None = None,
station_id: str | None = None,
survey_id: str | None = None,
**kwargs: Any,
) -> Any:
"""
Plot phase tensor elements for a station.
Parameters
----------
station_key : str, optional
Station key in canonical or accepted alternate form.
station_id : str, optional
Station ID.
survey_id : str, optional
Survey ID used to disambiguate duplicate station IDs.
**kwargs : dict
Additional plotting keyword arguments.
Returns
-------
plot object
Phase tensor plot object
Raises
------
ValueError
If station selection is ambiguous.
KeyError
If the station cannot be resolved.
Examples
--------
>>> tree.plot_phase_tensor(station_key="survey_a/st01")
>>> tree.plot_phase_tensor(station_id="st01", survey_id="survey_a")
"""
station_path = self._resolve_plot_station_key(
station_id=station_id,
survey_id=survey_id,
station_key=station_key,
)
mt_object = self.get_station(station_path, as_mt=True)
return mt_object.plot_phase_tensor(**kwargs)
[docs]
def plot_phase_tensor_map(self, **kwargs: Any) -> PlotPhaseTensorMaps:
    """
    Create a phase tensor map plot for this container.

    Parameters
    ----------
    **kwargs : dict
        Keyword arguments passed through to :class:`PlotPhaseTensorMaps`.

    Returns
    -------
    PlotPhaseTensorMaps
        Phase tensor map plot object built with ``mt_data=self``.

    Examples
    --------
    >>> tree.plot_phase_tensor_map(plot_period=10)
    >>> tree.plot_phase_tensor_map(plot_station=True)
    """
    phase_tensor_map = PlotPhaseTensorMaps(mt_data=self, **kwargs)
    return phase_tensor_map
[docs]
def plot_tipper_map(self, **kwargs: Any) -> PlotPhaseTensorMaps:
    """
    Create a tipper (induction vector) map plot.

    Parameters
    ----------
    **kwargs : dict
        Additional plotting keyword arguments. ``plot_pt=False`` and
        ``plot_tipper='yri'`` are filled in only when the caller did not
        supply those options.

    Returns
    -------
    PlotPhaseTensorMaps
        Tipper map plot object built with ``mt_data=self``.

    Examples
    --------
    >>> tree.plot_tipper_map()
    >>> tree.plot_tipper_map(plot_tipper="yri", plot_pt=False)
    """
    # Caller-provided values always win over these defaults.
    for option, default in (("plot_pt", False), ("plot_tipper", "yri")):
        kwargs.setdefault(option, default)
    return PlotPhaseTensorMaps(mt_data=self, **kwargs)
[docs]
def plot_phase_tensor_pseudosection(
    self, mt_data: "MTData" | None = None, **kwargs: Any
) -> PlotPhaseTensorPseudoSection:
    """
    Create a phase tensor pseudosection plot.

    Parameters
    ----------
    mt_data : MTData, optional
        MTData object to plot. ``self`` is used when omitted.
    **kwargs : dict
        Keyword arguments passed through to
        :class:`PlotPhaseTensorPseudoSection`.

    Returns
    -------
    PlotPhaseTensorPseudoSection
        Pseudosection plot object

    Examples
    --------
    >>> tree.plot_phase_tensor_pseudosection()
    >>> subset = tree.get_survey("survey_a")
    >>> tree.plot_phase_tensor_pseudosection(mt_data=subset)
    """
    plotted_data = self if mt_data is None else mt_data
    return PlotPhaseTensorPseudoSection(mt_data=plotted_data, **kwargs)
[docs]
def plot_penetration_depth_1d(
self,
station_key: str | None = None,
station_id: str | None = None,
survey_id: str | None = None,
**kwargs: Any,
) -> Any:
"""
Plot 1D penetration depth.
Parameters
----------
station_key : str, optional
Station key in canonical or accepted alternate form.
station_id : str, optional
Station ID.
survey_id : str, optional
Survey ID used to disambiguate duplicate station IDs.
**kwargs : dict
Additional plotting keyword arguments.
Returns
-------
plot object
Penetration depth plot object
Raises
------
ValueError
If station selection is ambiguous.
KeyError
If the station cannot be resolved.
Notes
-----
Based on Niblett-Bostick transformation
Examples
--------
>>> tree.plot_penetration_depth_1d(station_key="survey_a/st01")
>>> tree.plot_penetration_depth_1d(
... station_id="st01", survey_id="survey_a", depth_units="km"
... )
"""
station_path = self._resolve_plot_station_key(
station_id=station_id,
survey_id=survey_id,
station_key=station_key,
)
mt_object = self.get_station(station_path, as_mt=True)
return mt_object.plot_depth_of_penetration(**kwargs)
[docs]
def plot_penetration_depth_map(self, **kwargs: Any) -> PlotPenetrationDepthMap:
    """
    Create a map-view penetration depth plot for this container.

    Parameters
    ----------
    **kwargs : dict
        Keyword arguments passed through to
        :class:`PlotPenetrationDepthMap`.

    Returns
    -------
    PlotPenetrationDepthMap
        Penetration depth map plot object built with ``mt_data=self``.

    Examples
    --------
    >>> tree.plot_penetration_depth_map(plot_period=10)
    >>> tree.plot_penetration_depth_map(depth_units="km")
    """
    depth_map = PlotPenetrationDepthMap(mt_data=self, **kwargs)
    return depth_map
[docs]
def plot_resistivity_phase_maps(self, **kwargs: Any) -> PlotResPhaseMaps:
    """
    Create apparent resistivity and/or phase map plots.

    Parameters
    ----------
    **kwargs : dict
        Keyword arguments passed through to :class:`PlotResPhaseMaps`.

    Returns
    -------
    PlotResPhaseMaps
        Resistivity/phase map plot object built with ``mt_data=self``.

    Examples
    --------
    >>> tree.plot_resistivity_phase_maps(plot_period=10)
    >>> tree.plot_resistivity_phase_maps(plot_xy=True, plot_yx=False)
    """
    res_phase_maps = PlotResPhaseMaps(mt_data=self, **kwargs)
    return res_phase_maps
[docs]
def plot_resistivity_phase_pseudosections(
    self, **kwargs: Any
) -> PlotResPhasePseudoSection:
    """
    Create resistivity and phase pseudosection plots.

    Parameters
    ----------
    **kwargs : dict
        Keyword arguments passed through to
        :class:`PlotResPhasePseudoSection`.

    Returns
    -------
    PlotResPhasePseudoSection
        Pseudosection plot object built with ``mt_data=self``.

    Examples
    --------
    >>> tree.plot_resistivity_phase_pseudosections()
    >>> tree.plot_resistivity_phase_pseudosections(interpolation_method="nearest")
    """
    pseudosection = PlotResPhasePseudoSection(mt_data=self, **kwargs)
    return pseudosection
[docs]
def plot_residual_phase_tensor_maps(
    self, survey_01: str, survey_02: str, **kwargs: Any
) -> PlotResidualPTMaps:
    """
    Create residual phase tensor maps between two surveys.

    Both surveys are fetched with :meth:`get_survey` before either is
    validated; a survey subset with zero stations is treated as missing.

    Parameters
    ----------
    survey_01 : str
        First survey ID.
    survey_02 : str
        Second survey ID.
    **kwargs : dict
        Additional plotting keyword arguments.

    Returns
    -------
    PlotResidualPTMaps
        Residual phase tensor map plot object

    Raises
    ------
    KeyError
        If either survey subset contains no stations.

    Examples
    --------
    >>> tree.plot_residual_phase_tensor_maps("survey_a", "survey_b")
    >>> tree.plot_residual_phase_tensor_maps(
    ...     "survey_a", "survey_b", plot_freq=1.0
    ... )
    """
    # Fetch both subsets first, then validate them in argument order.
    fetched = [
        (survey_name, self.get_survey(survey_name))
        for survey_name in (survey_01, survey_02)
    ]
    for survey_name, survey_data in fetched:
        if survey_data.n_stations == 0:
            raise KeyError(f"Survey not found: {survey_name}")

    (_, first_data), (_, second_data) = fetched
    return PlotResidualPTMaps(first_data, second_data, **kwargs)