# -*- coding: utf-8 -*-
"""
Scaffold for a tree-backed MT data container.
This class is an outline for migrating from OrderedDict-based MTData to an
Xarray tree representation for better scalability.
"""
from __future__ import annotations
import importlib
from copy import deepcopy
from pathlib import Path
from typing import Any, Callable, TYPE_CHECKING
import numpy as np
import pandas as pd
import xarray as xr
from loguru import logger
from mtpy.core.transfer_function import IMPEDANCE_UNITS
from mtpy.imaging import (
PlotMultipleResponses,
PlotPenetrationDepthMap,
PlotPhaseTensorMaps,
PlotPhaseTensorPseudoSection,
PlotResidualPTMaps,
PlotResPhaseMaps,
PlotResPhasePseudoSection,
PlotStations,
PlotStrike,
)
from mtpy.modeling.errors import ModelErrors
from . import mt_data_accessor as _mt_data_accessor # noqa: F401
from .mt_data_tree_index import MTDataTreeIndexStore
from .mt_dataframe import MTDataFrame
# Lookup table from every accepted spelling of a coordinate reference frame
# to its canonical lowercase name ("ned" or "enu").  Accepted spellings are
# the sign shorthands ("+", "-"), z-direction shorthands ("z+", "z-"),
# axis-order shorthands ("nez+", "enz-"), the canonical names themselves and
# the exp(+/- i omega t) time-dependence strings.
# Insertion order is observable: the table is printed verbatim in the
# ValueError raised by the ``coordinate_reference_frame`` setter.
# NOTE(review): the ``None`` key maps to "ned", but the setter visible in this
# file rejects non-string input before the lookup -- confirm whether ``None``
# is reachable from another code path.
COORDINATE_REFERENCE_FRAME_OPTIONS = {
"+": "ned",
"-": "enu",
"z+": "ned",
"z-": "enu",
"nez+": "ned",
"enz-": "enu",
"ned": "ned",
"enu": "enu",
"exp(+ i\\omega t)": "ned",
"exp(+i\\omega t)": "ned",
"exp(- i\\omega t)": "enu",
"exp(-i\\omega t)": "enu",
None: "ned",
}
if TYPE_CHECKING:
from .mt import MT
from .mt_stations import MTStations
[docs]
class MTData:
"""
Tree-backed container for MT collection data.
Notes
-----
Composition is intentionally used instead of inheriting from xarray's tree
type. This keeps the public MT API independent from xarray internals and
allows controlled migration from MTData.
"""
# Name given to the root node when ``__init__`` has to create an empty tree.
ROOT_NAME = "root"
# Name of the top-level child node that ``__init__`` guarantees to exist.
SURVEYS_NODE = "surveys"
# NOTE(review): not referenced in the visible part of the file; presumably the
# per-survey child node that holds station datasets -- confirm.
STATIONS_NODE = "stations"
# Accepted values for the ``metadata_storage`` constructor argument.
METADATA_STORAGE_MODES = {"dict", "summary", "cache"}
# Accepted values for ``dataset_copy_mode`` (constructor and per-call override).
DATASET_COPY_MODES = {"deep", "shallow", "none"}
# Class-level aliases of the module-level lookup tables; ``__init__`` copies
# them into per-instance dictionaries.
COORDINATE_REFERENCE_FRAME = COORDINATE_REFERENCE_FRAME_OPTIONS
IMPEDANCE_UNITS = IMPEDANCE_UNITS
def __init__(
    self,
    tree: Any | None = None,
    metadata_storage: str = "cache",
    dataset_copy_mode: str = "shallow",
    use_index: bool = False,
    index_db_path: str = ":memory:",
    **attrs: Any,
) -> None:
    """Set up the container state around a new or existing tree.

    Parameters
    ----------
    tree : Any, optional
        Existing tree-like object (typically an ``xarray.DataTree``).
        When ``None`` an empty tree named ``ROOT_NAME`` is created.
    metadata_storage : {'dict', 'summary', 'cache'}, optional
        How survey/station metadata is stored.  The value is stringified,
        stripped and lower-cased before validation.
    dataset_copy_mode : {'deep', 'shallow', 'none'}, optional
        Default copy behavior for station datasets; normalized the same
        way as *metadata_storage*.
    use_index : bool, optional
        If ``True`` an ``MTDataTreeIndexStore`` is created immediately.
    index_db_path : str, optional
        Database path handed to the index store.
    **attrs
        Extra root attributes merged into ``self.tree.attrs``.

    Raises
    ------
    ValueError
        If either mode string is not one of the supported options.
    """
    normalized_storage = str(metadata_storage).strip().lower()
    if normalized_storage not in self.METADATA_STORAGE_MODES:
        allowed_storage = sorted(self.METADATA_STORAGE_MODES)
        raise ValueError(f"metadata_storage must be one of {allowed_storage}")
    self.metadata_storage = normalized_storage

    normalized_copy = str(dataset_copy_mode).strip().lower()
    if normalized_copy not in self.DATASET_COPY_MODES:
        allowed_copy = sorted(self.DATASET_COPY_MODES)
        raise ValueError(f"dataset_copy_mode must be one of {allowed_copy}")
    self.dataset_copy_mode = normalized_copy

    # Full metadata objects, keyed by station tree path, one map per kind.
    self._metadata_cache: dict[str, dict[str, Any]] = {
        "survey": {},
        "station": {},
    }

    # The index store only exists when explicitly requested.
    if use_index:
        index_store = MTDataTreeIndexStore(index_db_path)
    else:
        index_store = None
    self._index: MTDataTreeIndexStore | None = index_store
    self._index_db_path = index_db_path
    self._lazy_use_index = use_index

    # Pending per-station transforms, keyed by station path.
    self._lazy_station_transforms: dict[str, Callable[[], xr.Dataset]] = {}

    if tree is None:
        tree = xr.DataTree(name=self.ROOT_NAME, dataset=xr.Dataset())
    self.tree = tree

    # Schema markers are only filled in when the tree does not carry them.
    for schema_key, schema_default in (
        ("schema_name", "mtpy.mt_data_tree"),
        ("schema_version", "0.1.0"),
    ):
        self.tree.attrs.setdefault(schema_key, schema_default)
    self.tree.attrs.update(attrs)
    self.attrs = self.tree.attrs

    self._coordinate_reference_frame_options = dict(self.COORDINATE_REFERENCE_FRAME)
    self._coordinate_reference_frame = "+"
    self._impedance_unit_factors = dict(self.IMPEDANCE_UNITS)
    self._impedance_units = "mt"

    self.data_rotation_angle = 0
    self.model_parameters: dict[str, Any] = {}
    self._center_lat = None
    self._center_lon = None
    self._center_elev = 0.0

    self.z_model_error = ModelErrors(
        error_value=5,
        error_type="geometric_mean",
        floor=True,
        mode="impedance",
    )
    self.t_model_error = ModelErrors(
        error_value=0.02,
        error_type="absolute",
        floor=True,
        mode="tipper",
    )

    # Guarantee the top-level grouping node exists.
    surveys_key = self.SURVEYS_NODE
    if surveys_key not in self.tree.children:
        self.tree[surveys_key] = xr.DataTree(name=surveys_key, dataset=xr.Dataset())

    # Route the root-level settings through the property setters so they are
    # validated and written back to the root attrs.
    self.coordinate_reference_frame = self.attrs.get(
        "coordinate_reference_frame", "ned"
    )
    self.impedance_units = self.attrs.get("impedance_units", "mt")
def __deepcopy__(self, memo: dict) -> "MTData":
    """Return an independent deep copy of this container.

    The tree is deep-copied and wrapped in a new instance built with the
    same storage/copy settings and root attrs.  The metadata cache is
    deep-copied, pending lazy transforms are carried over (the callables
    themselves are shared), and instance state that the constructor would
    otherwise reset to defaults (rotation angle, model parameters, center
    point, model-error settings) is copied as well.

    Parameters
    ----------
    memo : dict
        Memo dictionary supplied by :func:`copy.deepcopy`.

    Returns
    -------
    MTData
        The copied container.
    """
    copied_tree = self.tree.copy(deep=True)
    copied = self.__class__(
        tree=copied_tree,
        metadata_storage=self.metadata_storage,
        dataset_copy_mode=self.dataset_copy_mode,
        # The index is rebuilt below instead of being created empty here.
        use_index=False,
        index_db_path=self._index_db_path,
        **dict(self.attrs),
    )
    memo[id(self)] = copied
    copied._metadata_cache = deepcopy(self._metadata_cache, memo)
    copied._lazy_station_transforms = dict(self._lazy_station_transforms)
    copied._lazy_use_index = self._lazy_use_index

    # Carry over state that __init__ resets to defaults.  Values are written
    # straight into the instance dict so no property setter is triggered.
    for state_name in (
        "data_rotation_angle",
        "model_parameters",
        "_center_lat",
        "_center_lon",
        "_center_elev",
        "z_model_error",
        "t_model_error",
    ):
        if state_name in self.__dict__:
            copied.__dict__[state_name] = deepcopy(self.__dict__[state_name], memo)

    if self._index is not None and not copied.is_lazy:
        copied.rebuild_index(index_db_path=self._index_db_path)
    return copied
@property
def station_paths(self) -> list[str]:
    """List every station path found in the tree, in sorted order."""
    paths = list(self._iter_station_paths())
    paths.sort()
    return paths
@property
def short_station_paths(self) -> list[str]:
    """List station paths shortened to ``survey/station`` and sorted.

    Paths that are not strings or have fewer than four ``/``-separated
    parts are skipped; the second and fourth parts form the short path.
    """
    short_forms = []
    for full_path in self.station_paths:
        if not isinstance(full_path, str):
            continue
        pieces = full_path.split("/")
        if len(pieces) < 4:
            continue
        short_forms.append(f"{pieces[1]}/{pieces[3]}")
    short_forms.sort()
    return short_forms
@property
def datum_epsg(self) -> int | None:
    """Integer EPSG code of the root datum, or ``None`` when unavailable.

    The ``datum_crs`` root attr is preferred over ``datum_epsg``; the value
    is only returned when its coerced form is made up of digits.
    """
    raw_value = self.attrs.get("datum_crs", self.attrs.get("datum_epsg"))
    code = self._coerce_epsg_value(raw_value)
    if code is not None and str(code).isdigit():
        return int(code)
    return None
@datum_epsg.setter
def datum_epsg(self, value: Any) -> None:
    """Assign the root datum by delegating to the ``datum_crs`` setter."""
    setattr(self, "datum_crs", value)
@property
def datum_crs(self) -> Any | None:
    """Root datum CRS value, falling back to the ``datum_epsg`` attr."""
    root_attrs = self.attrs
    fallback = root_attrs.get("datum_epsg")
    return root_attrs.get("datum_crs", fallback)
@datum_crs.setter
def datum_crs(self, value: Any) -> None:
    """Set or clear the root datum CRS and keep ``datum_epsg`` in sync.

    Parameters
    ----------
    value : Any
        New CRS value.  ``None``, ``""``, ``"None"``, ``"none"`` and
        ``"null"`` clear both ``datum_crs`` and ``datum_epsg``.
    """
    if value in [None, "", "None", "none", "null"]:
        self.attrs.pop("datum_crs", None)
        self.attrs.pop("datum_epsg", None)
        return
    self.attrs["datum_crs"] = value
    epsg = self._coerce_epsg_value(value)
    if epsg is None:
        # Without this, an EPSG code from a previously assigned CRS would
        # linger in the attrs and contradict the new ``datum_crs``.
        self.attrs.pop("datum_epsg", None)
    else:
        self.attrs["datum_epsg"] = epsg
@property
def utm_epsg(self) -> int | None:
    """Integer EPSG code of the root UTM CRS, or ``None`` when unavailable.

    The ``utm_crs`` root attr is preferred over ``utm_epsg``; the value is
    only returned when its coerced form is made up of digits.
    """
    raw_value = self.attrs.get("utm_crs", self.attrs.get("utm_epsg"))
    code = self._coerce_epsg_value(raw_value)
    if code is not None and str(code).isdigit():
        return int(code)
    return None
@utm_epsg.setter
def utm_epsg(self, value: Any) -> None:
    """Assign the root UTM CRS by delegating to the ``utm_crs`` setter."""
    setattr(self, "utm_crs", value)
@property
def utm_crs(self) -> Any | None:
    """Root UTM CRS value, falling back to the ``utm_epsg`` attr."""
    root_attrs = self.attrs
    fallback = root_attrs.get("utm_epsg")
    return root_attrs.get("utm_crs", fallback)
@utm_crs.setter
def utm_crs(self, value: Any) -> None:
    """Set or clear the root UTM CRS and refresh station location attrs.

    Parameters
    ----------
    value : Any
        New CRS value.  ``None``, ``""``, ``"None"``, ``"none"`` and
        ``"null"`` clear both ``utm_crs`` and ``utm_epsg`` on the root;
        station attrs are left untouched in that case.
    """
    if value in [None, "", "None", "none", "null"]:
        self.attrs.pop("utm_crs", None)
        self.attrs.pop("utm_epsg", None)
        return
    self.attrs["utm_crs"] = value
    epsg = self._coerce_epsg_value(value)
    if epsg is None:
        # Without this, an EPSG code from a previously assigned CRS would
        # linger in the attrs and contradict the new ``utm_crs``.
        self.attrs.pop("utm_epsg", None)
    else:
        self.attrs["utm_epsg"] = epsg
    self._apply_utm_crs_to_station_attrs(value)
def _apply_utm_crs_to_station_attrs(self, utm_crs: Any) -> None:
    """Write *utm_crs* to every station and recompute easting/northing.

    Each station's ``utm_crs`` attr is always updated.  ``easting`` and
    ``northing`` are recomputed from ``latitude``/``longitude`` when both
    are present.  A station whose projection fails keeps its previous
    easting/northing and a warning is logged; the remaining stations are
    still processed.

    Parameters
    ----------
    utm_crs : Any
        CRS value handed to ``MTLocation`` for the projection.
    """
    from .mt_location import MTLocation

    missing_values = (None, "", "None", "none", "null")
    for station_path in self._iter_station_paths():
        attrs = self.get_station(station_path).attrs
        attrs["utm_crs"] = utm_crs
        latitude = attrs.get("latitude")
        longitude = attrs.get("longitude")
        if latitude in missing_values or longitude in missing_values:
            continue
        try:
            point = MTLocation(
                latitude=float(latitude),
                longitude=float(longitude),
                utm_crs=utm_crs,
            )
            easting = float(point.east)
            northing = float(point.north)
        except Exception as error:
            # Best effort: one bad station must not abort the update, but the
            # failure should be visible rather than silently dropped.
            logger.warning(
                "Could not project station {} to UTM CRS {}: {}",
                station_path,
                utm_crs,
                error,
            )
            continue
        # Assign both together so a failure never leaves a mismatched pair.
        attrs["easting"] = easting
        attrs["northing"] = northing
@property
def survey_names(self) -> list[str]:
    """Sorted, de-duplicated survey names taken from the station paths.

    Only string paths containing at least three ``/`` separators are
    considered; the survey name is the second path component.
    """
    names = set()
    for full_path in self.station_paths:
        if isinstance(full_path, str) and full_path.count("/") >= 3:
            names.add(full_path.split("/")[1])
    return sorted(names)
def __repr__(self) -> str:
    """Build a one-line, constructor-like summary of the container."""
    n_stations = len(self.station_paths)
    n_surveys = len(self.survey_names)
    has_index = self._index is not None or self._lazy_use_index
    fields = (
        f"stations={n_stations}",
        f"surveys={n_surveys}",
        f"lazy_stations={self.lazy_station_count}",
        f"metadata_storage='{self.metadata_storage}'",
        f"dataset_copy_mode='{self.dataset_copy_mode}'",
        f"index_enabled={has_index}",
    )
    return "MTData(" + ", ".join(fields) + ")"
def __str__(self) -> str:
    """Build a multi-line summary: counts, settings, surveys and paths.

    At most eight station paths are listed; any remainder is reported as a
    count.
    """
    max_preview = 8
    all_paths = self.station_paths
    surveys = self.survey_names
    shown_paths = all_paths[:max_preview]
    has_index = self._index is not None or self._lazy_use_index

    report = ["MTData Summary"]
    report.append(f" stations: {len(all_paths)}")
    report.append(f" surveys: {len(surveys)}")
    report.append(f" lazy stations: {self.lazy_station_count}")
    report.append(f" index enabled: {has_index}")
    report.append(f" metadata storage: {self.metadata_storage}")
    report.append(f" dataset copy mode: {self.dataset_copy_mode}")
    report.append(f" impedance units: {self.impedance_units}")
    report.append(f" coordinate reference frame: {self.coordinate_reference_frame}")

    report.append(" survey names:")
    if not surveys:
        report.append(" - <none>")
    else:
        for survey in surveys:
            report.append(f" - {survey}")

    report.append(" station paths:")
    if not shown_paths:
        report.append(" - <none>")
    else:
        for shown in shown_paths:
            report.append(f" - {shown}")
        hidden = len(all_paths) - max_preview
        if hidden > 0:
            report.append(f" - ... ({hidden} more)")
    return "\n".join(report)
def __add__(self, other: Any) -> "MTData":
    """Return a new container holding the stations of ``self`` and ``other``.

    Parameters
    ----------
    other : Any
        Container to merge in.  Anything that is not an ``MTData`` yields
        ``NotImplemented``.

    Returns
    -------
    MTData
        A copy of ``self`` with every station of ``other`` added.

    Notes
    -----
    Station paths present in both operands are taken from ``other``; a
    warning is logged for each overwritten path.  Pending lazy transforms
    of ``other`` are computed first, which modifies ``other`` in place.
    """
    if not isinstance(other, MTData):
        return NotImplemented
    merged = self.copy()
    other.compute()
    for station_path in other._iter_station_paths():
        if merged._path_exists(station_path):
            logger.warning(
                "Overwriting existing station path during MTData merge: {}",
                station_path,
            )
        # A transform still pending on the copy of ``self`` would be applied
        # by a later compute() and overwrite the dataset taken from
        # ``other``; drop it so ``other`` really wins.
        merged._lazy_station_transforms.pop(station_path, None)
        station_ds = other.get_station(station_path).copy(deep=True)
        merged._set_station_dataset(station_path, station_ds)
        # Replace cached metadata entries for overwritten/new paths.
        merged._clear_cached_metadata(station_path)
        for metadata_kind in ("survey", "station"):
            cached_md = other._metadata_cache[metadata_kind].get(station_path)
            if cached_md is not None:
                merged._metadata_cache[metadata_kind][station_path] = deepcopy(
                    cached_md
                )
    if merged._index is not None and not merged.is_lazy:
        merged.rebuild_index(index_db_path=merged._index_db_path)
    return merged
[docs]
def copy(self) -> "MTData":
    """Return a deep copy of this container (see ``__deepcopy__``)."""
    duplicate = deepcopy(self)
    return duplicate
[docs]
def clone_empty(self) -> "MTData":
    """Return a station-less container with the same configuration.

    The clone is built with the same metadata-storage mode, dataset copy
    mode, index setting and root attrs.  Instance state that the
    constructor would reset to defaults (rotation angle, model parameters,
    center point, model-error settings) is deep-copied onto the clone.

    Returns
    -------
    MTData
        New container without any station datasets.
    """
    clone = self.__class__(
        metadata_storage=self.metadata_storage,
        dataset_copy_mode=self.dataset_copy_mode,
        use_index=self._index is not None or self._lazy_use_index,
        index_db_path=self._index_db_path,
        **dict(self.attrs),
    )
    # Written straight into the instance dict so no property setter runs.
    for state_name in (
        "data_rotation_angle",
        "model_parameters",
        "_center_lat",
        "_center_lon",
        "_center_elev",
        "z_model_error",
        "t_model_error",
    ):
        if state_name in self.__dict__:
            clone.__dict__[state_name] = deepcopy(self.__dict__[state_name])
    return clone
@staticmethod
def _metadata_to_dict(metadata: Any) -> dict[str, Any]:
"""Safely convert mt_metadata objects to dictionaries."""
if metadata is None:
return {}
if hasattr(metadata, "to_dict"):
for kwargs in ({"single": True}, {}):
try:
out = metadata.to_dict(**kwargs)
if isinstance(out, dict):
return out
except TypeError:
continue
return {}
@staticmethod
def _metadata_to_summary(metadata: Any) -> dict[str, Any]:
"""Build a lightweight metadata summary for fast per-station attrs."""
if metadata is None:
return {}
md_id = getattr(metadata, "id", None)
if md_id in [None, "", "None", "none", "null"]:
return {}
return {"id": str(md_id)}
def _serialize_metadata(self, metadata: Any) -> dict[str, Any]:
"""Serialize metadata according to configured storage mode."""
if self.metadata_storage == "dict":
return self._metadata_to_dict(metadata)
return self._metadata_to_summary(metadata)
def _metadata_ref(self, station_path: str, metadata: Any) -> str | None:
"""Return metadata reference key for cache mode without mutating cache."""
if self.metadata_storage != "cache" or metadata is None:
return None
return station_path
def _commit_cached_metadata(
self,
station_path: str,
survey_metadata: Any,
station_metadata: Any,
) -> None:
"""Persist metadata objects in the in-memory cache after successful insert."""
if self.metadata_storage != "cache":
return
if survey_metadata is not None:
self._metadata_cache["survey"][station_path] = survey_metadata
if station_metadata is not None:
self._metadata_cache["station"][station_path] = station_metadata
def _clear_cached_metadata(self, node_path: str) -> None:
"""Remove cached metadata for one node path or an entire subtree prefix."""
for metadata_kind in ["survey", "station"]:
keys_to_remove = [
key
for key in self._metadata_cache[metadata_kind]
if key == node_path or key.startswith(f"{node_path}/")
]
for key in keys_to_remove:
self._metadata_cache[metadata_kind].pop(key, None)
def _resolve_dataset_copy_mode(self, dataset_copy_mode: str | None) -> str:
"""Resolve copy mode from call-level override or instance default."""
mode = (
self.dataset_copy_mode if dataset_copy_mode is None else dataset_copy_mode
)
mode = str(mode).strip().lower()
if mode not in self.DATASET_COPY_MODES:
raise ValueError(
"dataset_copy_mode must be one of " f"{sorted(self.DATASET_COPY_MODES)}"
)
return mode
@staticmethod
def _copy_station_dataset(station_ds: xr.Dataset, mode: str) -> xr.Dataset:
"""Copy station dataset according to selected copy mode."""
if mode == "none":
return station_ds
if mode == "deep":
return station_ds.copy(deep=True)
return station_ds.copy(deep=False)
def _extract_station_dataset(
    self, mt_obj: "MT", dataset_copy_mode: str | None = None
) -> xr.Dataset:
    """Pull the station dataset out of an MT object's transfer function.

    The ``_transfer_function`` attribute is used directly when it is an
    ``xr.Dataset``; otherwise its ``to_xarray()`` result or its ``_dataset``
    attribute is used, in that order of preference.  The dataset is then
    copied according to the resolved copy mode.

    Raises
    ------
    TypeError
        If the transfer function is missing or no dataset can be obtained.
    ValueError
        If *dataset_copy_mode* is not a supported option.
    """
    transfer_function = getattr(mt_obj, "_transfer_function", None)
    if transfer_function is None:
        raise TypeError("MT object is missing _transfer_function")

    if isinstance(transfer_function, xr.Dataset):
        dataset = transfer_function
    elif hasattr(transfer_function, "to_xarray"):
        dataset = transfer_function.to_xarray()
    else:
        inner = getattr(transfer_function, "_dataset", None)
        if not isinstance(inner, xr.Dataset):
            raise TypeError("Could not extract xarray.Dataset from MT object")
        dataset = inner

    resolved_mode = self._resolve_dataset_copy_mode(dataset_copy_mode)
    return self._copy_station_dataset(dataset, resolved_mode)
def _build_station_attrs(
    self,
    mt_obj: "MT",
    survey: str,
    station: str,
    survey_metadata: Any,
    station_metadata: Any,
    survey_metadata_ref: str | None,
    station_metadata_ref: str | None,
) -> dict[str, Any]:
    """Assemble the default attrs dictionary for one station.

    Values are read from *mt_obj* with ``getattr``; attributes missing on
    the MT object fall back to ``None`` (or ``0.0`` for the model and
    profile offsets, and *station* for ``tf_id``).  Key insertion order is
    kept stable.
    """
    attrs: dict[str, Any] = {
        "survey": survey,
        "station": station,
        "tf_id": getattr(mt_obj, "tf_id", station),
    }
    for name in ("latitude", "longitude", "elevation", "datum_crs"):
        attrs[name] = getattr(mt_obj, name, None)
    attrs["utm_crs"] = self._get_utm_crs(mt_obj)
    # Stored names differ from the MT object's attribute names here.
    attrs["easting"] = getattr(mt_obj, "east", None)
    attrs["northing"] = getattr(mt_obj, "north", None)
    for name in ("model_east", "model_north", "model_elevation", "profile_offset"):
        attrs[name] = getattr(mt_obj, name, 0.0)
    for name in ("coordinate_reference_frame", "impedance_units"):
        attrs[name] = getattr(mt_obj, name, None)
    attrs["survey_metadata"] = self._serialize_metadata(survey_metadata)
    attrs["station_metadata"] = self._serialize_metadata(station_metadata)
    attrs["survey_metadata_ref"] = survey_metadata_ref
    attrs["station_metadata_ref"] = station_metadata_ref
    return attrs
def _build_station_attrs_from_precomputed(
self,
mt_obj: "MT",
survey: str,
station: str,
survey_metadata: Any,
station_metadata: Any,
survey_metadata_ref: str | None,
station_metadata_ref: str | None,
precomputed_attrs: dict[str, Any],
) -> dict[str, Any]:
"""Build attrs from precomputed payload plus required canonical keys."""
station_attrs = dict(precomputed_attrs)
station_attrs["survey"] = survey
station_attrs["station"] = station
station_attrs.setdefault("tf_id", getattr(mt_obj, "tf_id", station))
station_attrs.setdefault(
"survey_metadata", self._serialize_metadata(survey_metadata)
)
station_attrs.setdefault(
"station_metadata", self._serialize_metadata(station_metadata)
)
station_attrs["survey_metadata_ref"] = survey_metadata_ref
station_attrs["station_metadata_ref"] = station_metadata_ref
return station_attrs
def _coerce_and_prepare_station(
    self,
    mt_obj: "MT | str | Path",
    dataset_copy_mode: str | None = None,
    precomputed_attrs: dict[str, Any] | None = None,
) -> tuple[str, str, xr.Dataset, dict[str, Any]]:
    """Turn one station input into its tree path, dataset and metadata.

    Parameters
    ----------
    mt_obj : MT or str or Path
        Station input; coerced with ``_coerce_mt_object``.
    dataset_copy_mode : str, optional
        Per-call override of the dataset copy mode.
    precomputed_attrs : dict, optional
        Attrs payload to complete instead of building the default one.

    Returns
    -------
    tuple
        ``(station_path, station_name, station_dataset, metadata)`` where
        *metadata* maps ``"survey"`` and ``"station"`` to the metadata
        objects found on the MT object (or ``None``).
    """
    mt_obj = self._coerce_mt_object(mt_obj)

    # The explicit name wins; the metadata id is the fallback.
    raw_survey = getattr(mt_obj, "survey", None) or getattr(
        getattr(mt_obj, "survey_metadata", None), "id", None
    )
    survey = self._clean_name(raw_survey, "default")
    raw_station = getattr(mt_obj, "station", None) or getattr(
        getattr(mt_obj, "station_metadata", None), "id", None
    )
    station = self._clean_name(raw_station, "unknown_station")

    station_path = self._station_path(survey, station)
    station_ds = self._extract_station_dataset(
        mt_obj, dataset_copy_mode=dataset_copy_mode
    )

    survey_md = getattr(mt_obj, "survey_metadata", None)
    station_md = getattr(mt_obj, "station_metadata", None)
    survey_ref = self._metadata_ref(station_path, survey_md)
    station_ref = self._metadata_ref(station_path, station_md)

    common_args = (
        mt_obj,
        survey,
        station,
        survey_md,
        station_md,
        survey_ref,
        station_ref,
    )
    if precomputed_attrs is not None:
        station_attrs = self._build_station_attrs_from_precomputed(
            *common_args, precomputed_attrs
        )
    else:
        station_attrs = self._build_station_attrs(*common_args)

    station_ds.attrs.update(station_attrs)
    metadata_objects = {"survey": survey_md, "station": station_md}
    return station_path, station, station_ds, metadata_objects
def _cache_metadata(
self, station_path: str, metadata_kind: str, metadata: Any
) -> str | None:
"""Cache full metadata object in-memory and return reference key."""
if self.metadata_storage != "cache" or metadata is None:
return None
self._metadata_cache[metadata_kind][station_path] = metadata
return station_path
@property
def metadata_cache(self) -> dict[str, dict[str, Any]]:
"""In-memory metadata map keyed by station path for cache mode."""
return self._metadata_cache
@property
def is_lazy(self) -> bool:
"""True when one or more deferred station transforms are pending."""
return bool(self._lazy_station_transforms)
@property
def lazy_station_count(self) -> int:
"""Number of stations with pending deferred transforms."""
return len(self._lazy_station_transforms)
@property
def coordinate_reference_frame(self) -> str:
"""
Coordinate reference frame.
Returns
-------
str
Reference frame identifier ('NED' or 'ENU')
"""
return self._coordinate_reference_frame_options[
self._coordinate_reference_frame
].upper()
@coordinate_reference_frame.setter
def coordinate_reference_frame(self, value: str) -> None:
    """
    Set the coordinate reference frame for the root and every station.

    Parameters
    ----------
    value : str
        Any key of the option table, matched after stripping and
        lower-casing.  The canonical frames are:

        - 'NED': x=North, y=East, z=+down
        - 'ENU': x=East, y=North, z=+up

    Raises
    ------
    TypeError
        If value is not a string
    ValueError
        If value is not a recognized reference frame

    Notes
    -----
    The canonical upper-case name is written to the root attrs and to the
    attrs of every station dataset.
    """
    if not isinstance(value, str):
        raise TypeError("Coordinate reference frame input must be a string.")

    frame_key = value.strip().lower()
    if frame_key not in self._coordinate_reference_frame_options:
        raise ValueError(
            f"{value} is not understood as a reference frame. "
            f"Options are {self._coordinate_reference_frame_options}"
        )

    # Collapse the canonical names onto their sign shorthand; every other
    # accepted spelling is stored as given.
    shorthand = {"ned": "+", "+": "+", "enu": "-", "-": "-"}
    self._coordinate_reference_frame = shorthand.get(frame_key, frame_key)

    canonical_name = self.coordinate_reference_frame
    self.attrs["coordinate_reference_frame"] = canonical_name
    for station_path in self._iter_station_paths():
        self.get_station(station_path).attrs[
            "coordinate_reference_frame"
        ] = canonical_name
@property
def impedance_units(self) -> str:
"""
Impedance units.
Returns
-------
str
Impedance units ('mt' or 'ohm')
"""
return self._impedance_units
@impedance_units.setter
def impedance_units(self, value: str) -> None:
    """
    Set the impedance units for the root and every station.

    Parameters
    ----------
    value : str
        Impedance units, matched case-insensitively.
        Options: 'mt' [mV/km/nT] or 'ohm' [Ohms]

    Raises
    ------
    TypeError
        If value is not a string
    ValueError
        If value is not a key of the unit-factor table

    Notes
    -----
    The lower-cased unit is written to the root attrs and to the attrs of
    every station dataset.
    """
    if not isinstance(value, str):
        raise TypeError("Units input must be a string.")
    unit_key = value.lower()
    if unit_key not in self._impedance_unit_factors:
        raise ValueError(f"{value} is not an acceptable unit for impedance.")

    self._impedance_units = unit_key
    self.attrs["impedance_units"] = unit_key
    for station_path in self._iter_station_paths():
        self.get_station(station_path).attrs["impedance_units"] = unit_key
def _realize_station(self, station_path: str) -> str | None:
    """Run the pending transform of one station and store its result.

    The transform is removed from the pending map before it is executed.
    A dask ``Delayed`` result is computed first and the final result must
    be an ``xr.Dataset``, matching what :meth:`compute` enforces.  When an
    index is active the station and period rows are refreshed.

    Parameters
    ----------
    station_path : str
        Tree path of the station to materialize.

    Returns
    -------
    str or None
        Survey name of the indexed station row, or ``None`` when nothing
        was pending or no index is active.

    Raises
    ------
    TypeError
        If the transform does not produce an ``xr.Dataset``.
    """
    transform = self._lazy_station_transforms.pop(station_path, None)
    if transform is None:
        return None
    station_ds = transform()
    # Transforms registered for delayed execution return a Delayed object;
    # storing that as the station dataset would corrupt the tree.
    if self._is_dask_delayed(station_ds):
        station_ds = station_ds.compute()
    if not isinstance(station_ds, xr.Dataset):
        raise TypeError(
            "Deferred transform must return xr.Dataset, "
            f"got {type(station_ds)!r}"
        )
    self._set_station_dataset(station_path, station_ds)
    if self._index is None:
        return None
    station_row, period_row = MTDataTreeIndexStore._extract_rows(
        station_path,
        station_ds,
    )
    if period_row is None:
        # No period rows for the new dataset: remove the existing entry
        # before the station row is written again.
        self._index.delete_station_by_tree_path(station_path)
    self._index.upsert_station(station_row)
    if period_row is not None:
        self._index.replace_station_period_rows(period_row)
    return station_row.survey_name
@staticmethod
def _is_dask_delayed(obj: Any) -> bool:
"""Return True when *obj* is a dask delayed object."""
return obj.__class__.__name__ == "Delayed" and hasattr(obj, "dask")
[docs]
def compute(
self,
station_paths: list[str] | None = None,
scheduler: str | None = None,
) -> "MTData":
"""Materialize deferred station transforms and refresh index state.
Parameters
----------
station_paths : list[str], optional
Subset of station tree paths to realize. When ``None``, all pending
lazy station transforms are computed.
scheduler : str, optional
Dask scheduler name passed through when delayed transforms are
present.
Returns
-------
MTData
The current tree instance.
"""
if not self._lazy_station_transforms:
return self
if self._index is None and self._lazy_use_index:
self._index = MTDataTreeIndexStore(self._index_db_path)
pending_paths = list(self._lazy_station_transforms.keys())
station_paths = self._normalize_station_paths(station_paths)
if station_paths is None:
paths_to_realize = pending_paths
else:
requested = set(station_paths)
paths_to_realize = [path for path in pending_paths if path in requested]
realized_datasets: dict[str, xr.Dataset] = {}
delayed_paths: list[str] = []
delayed_objs: list[Any] = []
for station_path in paths_to_realize:
transform = self._lazy_station_transforms.pop(station_path, None)
if transform is None:
continue
out = transform()
if self._is_dask_delayed(out):
delayed_paths.append(station_path)
delayed_objs.append(out)
else:
realized_datasets[station_path] = out
if delayed_objs:
try:
dask = importlib.import_module("dask")
except ImportError as exc:
raise RuntimeError(
"Dask delayed transforms are pending but dask is not installed."
) from exc
computed = dask.compute(*delayed_objs, scheduler=scheduler)
for station_path, station_ds in zip(delayed_paths, computed):
realized_datasets[station_path] = station_ds
updated_surveys: set[str] = set()
for station_path in paths_to_realize:
station_ds = realized_datasets.get(station_path)
if station_ds is None:
continue
if not isinstance(station_ds, xr.Dataset):
raise TypeError(
"Deferred transform must return xr.Dataset, "
f"got {type(station_ds)!r}"
)
self._set_station_dataset(station_path, station_ds)
if self._index is not None:
station_row, period_row = MTDataTreeIndexStore._extract_rows(
station_path,
station_ds,
)
if period_row is None:
self._index.delete_station_by_tree_path(station_path)
self._index.upsert_station(station_row)
if period_row is not None:
self._index.replace_station_period_rows(period_row)
updated_surveys.add(station_row.survey_name)
if self._index is not None:
for survey_name in updated_surveys:
self._index.refresh_survey_aggregates(survey_name)
return self
[docs]
def persist(
self,
station_paths: list[str] | None = None,
scheduler: str | None = None,
) -> "MTData":
"""Alias for :meth:`compute`.
Parameters
----------
station_paths : list[str], optional
Station paths to realize.
scheduler : str, optional
Dask scheduler name.
Returns
-------
MTData
The current tree instance.
"""
return self.compute(station_paths=station_paths, scheduler=scheduler)
[docs]
def as_dask(
    self,
    chunks: dict[str, int] | str | None,
    station_paths: list[str] | None = None,
    variables: list[str] | None = None,
    inplace: bool = False,
) -> "MTData":
    """Convert station datasets to chunked, dask-backed arrays.

    Parameters
    ----------
    chunks : dict[str, int] or str or None
        Chunk specification forwarded to xarray ``chunk``.
    station_paths : list[str], optional
        Stations to chunk; ``None`` selects every station.
    variables : list[str], optional
        Data variables to chunk.  ``None`` chunks the whole dataset.
    inplace : bool, optional
        If ``True`` this tree is modified, otherwise a subset copy
        containing all stations is chunked and returned.

    Returns
    -------
    MTData
        The chunked tree (this instance when *inplace* is ``True``).

    Raises
    ------
    RuntimeError
        If ``dask.array`` cannot be imported.
    KeyError
        If a requested variable is missing from a selected station.

    Examples
    --------
    >>> tree = tree.as_dask(chunks={"period": 32})
    >>> tree = tree.as_dask(chunks="auto", variables=["transfer_function"])
    """
    try:
        importlib.import_module("dask.array")
    except ImportError as exc:
        raise RuntimeError("Dask is required for as_dask()") from exc

    station_paths = self._normalize_station_paths(station_paths)
    if inplace:
        tree_obj = self
    else:
        tree_obj = self.get_subset(self._iter_station_paths())

    target_paths = tree_obj._iter_station_paths()
    if station_paths is not None:
        requested = set(station_paths)
        target_paths = [path for path in target_paths if path in requested]

    for station_path in target_paths:
        station_ds = tree_obj.get_station(station_path)
        if variables is None:
            chunked_ds = station_ds.chunk(chunks)
        else:
            missing = [
                name for name in variables if name not in station_ds.data_vars
            ]
            if missing:
                raise KeyError(f"Variables not found for chunking: {missing}")
            # Shallow copy so only the selected variables are replaced.
            chunked_ds = station_ds.copy(deep=False)
            for var_name in variables:
                chunked_ds[var_name] = station_ds[var_name].chunk(chunks)
        tree_obj._set_station_dataset(station_path, chunked_ds)
    return tree_obj
[docs]
def rechunk(
    self,
    chunks: dict[str, int] | str | None,
    station_paths: list[str] | None = None,
    variables: list[str] | None = None,
    inplace: bool = True,
) -> "MTData":
    """Rechunk station datasets by delegating to :meth:`as_dask`.

    Parameters
    ----------
    chunks : dict[str, int] or str or None
        Chunk specification passed to :meth:`as_dask`.
    station_paths : list[str], optional
        Stations to rechunk.
    variables : list[str], optional
        Data variables to rechunk.
    inplace : bool, optional
        If ``True`` (default), this tree is modified.

    Returns
    -------
    MTData
        The rechunked tree.
    """
    forwarded = {
        "chunks": chunks,
        "station_paths": station_paths,
        "variables": variables,
        "inplace": inplace,
    }
    return self.as_dask(**forwarded)
[docs]
def is_dask_backed(self, station_paths: list[str] | None = None) -> bool:
    """Check whether the selected stations are fully dask-backed.

    Pending lazy transforms of the selected stations are computed first.

    Parameters
    ----------
    station_paths : list[str], optional
        Stations to inspect.  ``None`` inspects every station.

    Returns
    -------
    bool
        ``True`` only when at least one station is selected and every data
        variable of every selected station exposes chunks.
    """
    station_paths = self._normalize_station_paths(station_paths)
    self.compute(station_paths=station_paths)
    # Materialize the selection so the emptiness check below is reliable
    # even if the path helper yields lazily.
    target_paths = list(self._iter_station_paths())
    if station_paths is not None:
        requested = set(station_paths)
        target_paths = [path for path in target_paths if path in requested]
    if not target_paths:
        return False
    for station_path in target_paths:
        station_ds = self.get_station(station_path)
        for da in station_ds.data_vars.values():
            if getattr(da.data, "chunks", None) is None:
                return False
    return True
[docs]
def chunk_plan(
    self,
    station_paths: list[str] | None = None,
) -> dict[str, dict[str, tuple[tuple[int, ...], ...] | None]]:
    """Report the ``chunks`` attribute of every variable, per station.

    Pending lazy transforms of the selected stations are computed first.

    Parameters
    ----------
    station_paths : list[str], optional
        Stations to summarize.  ``None`` summarizes every station.

    Returns
    -------
    dict[str, dict[str, tuple[tuple[int, ...], ...] or None]]
        Station path mapped to ``{variable name: chunks}``.
    """
    station_paths = self._normalize_station_paths(station_paths)
    self.compute(station_paths=station_paths)

    selected = self._iter_station_paths()
    if station_paths is not None:
        wanted = set(station_paths)
        selected = [path for path in selected if path in wanted]

    layout: dict[str, dict[str, tuple[tuple[int, ...], ...] | None]] = {}
    for path in selected:
        dataset = self.get_station(path)
        per_variable = {}
        for var_name, data_array in dataset.data_vars.items():
            per_variable[var_name] = data_array.chunks
        layout[path] = per_variable
    return layout
[docs]
def map_stations(
    self,
    transform: Callable[[xr.Dataset], xr.Dataset],
    station_paths: list[str] | None = None,
    lazy: bool = True,
    inplace: bool = False,
) -> "MTData":
    """Apply a dataset transform to selected stations.

    Parameters
    ----------
    transform : callable
        Function receiving one station ``xr.Dataset`` and returning an
        ``xr.Dataset``.
    station_paths : list[str], optional
        Stations to transform.  ``None`` selects every station.
    lazy : bool, optional
        If ``True`` (default), register a deferred transform per station
        that runs on the next :meth:`compute`.  If ``False``, apply the
        transform immediately.
    inplace : bool, optional
        If ``True``, modify this tree.  Otherwise work on a subset copy
        containing all stations and return it.

    Returns
    -------
    MTData
        Tree with registered or applied transforms.

    Raises
    ------
    TypeError
        If *lazy* is ``False`` and *transform* does not return an
        ``xr.Dataset``.

    Notes
    -----
    Transforms already pending on the working tree are computed before the
    new transform is registered or applied.

    Examples
    --------
    >>> def keep_short_periods(ds):
    ...     return ds.sel(period=ds.period <= 10.0)
    >>> out = tree.map_stations(keep_short_periods, lazy=False, inplace=False)
    """
    station_paths = self._normalize_station_paths(station_paths)
    tree_obj = self if inplace else self.get_subset(self._iter_station_paths())
    tree_obj.compute()

    # Materialize the selection once: it must survive being iterated here
    # and then passed on, even if the path helper yields lazily.
    target_paths = list(tree_obj._iter_station_paths())
    if station_paths is not None:
        requested = set(station_paths)
        target_paths = [path for path in target_paths if path in requested]

    if lazy:
        for station_path in target_paths:
            station_ds = tree_obj.get_station(station_path).copy(deep=False)
            # Defaults bind the current dataset and transform so each
            # closure keeps its own station.
            tree_obj._lazy_station_transforms[station_path] = (
                lambda ds=station_ds, op=transform: op(ds)
            )
        return tree_obj

    def _validated_transform(ds: xr.Dataset) -> xr.Dataset:
        # Reject anything that is not a dataset before it reaches the tree.
        out_ds = transform(ds)
        if not isinstance(out_ds, xr.Dataset):
            raise TypeError(
                "map_stations transform must return xr.Dataset, "
                f"got {type(out_ds)!r}"
            )
        return out_ds

    tree_obj.tree.mt.map_stations(
        _validated_transform,
        station_paths=target_paths,
        inplace=True,
    )
    return tree_obj
[docs]
def interpolate_dask(
    self,
    new_periods: np.ndarray,
    f_type: str = "period",
    bounds_error: bool = True,
    chunks: dict[str, int] | str | None = None,
    scheduler: str | None = None,
    compute: bool = True,
    inplace: bool = False,
    **kwargs: Any,
) -> "MTData":
    """Interpolate stations, wrapping each deferred transform in ``dask.delayed``.

    Parameters
    ----------
    new_periods : ndarray
        Target periods or frequencies depending on *f_type*.
    f_type : str, optional
        ``'period'`` (default) or ``'frequency'``/``'freq'``.
    bounds_error : bool, optional
        Forwarded to :meth:`interpolate_lazy`.
    chunks : dict[str, int] or str or None, optional
        When given, :meth:`as_dask` is applied in place before the lazy
        interpolation is registered.
    scheduler : str, optional
        Passed to :meth:`compute` when *compute* is ``True``; otherwise,
        if not ``None``, written to the dask configuration.
    compute : bool, optional
        If ``True`` (default), execute the delayed transforms immediately.
    inplace : bool, optional
        If ``True``, modify this tree; otherwise work on a subset copy.
    **kwargs
        Forwarded to :meth:`interpolate_lazy`.

    Returns
    -------
    MTData
        Tree with interpolated results or pending delayed transforms.

    Raises
    ------
    RuntimeError
        If dask cannot be imported.
    """
    try:
        dask_module = importlib.import_module("dask")
        make_delayed = dask_module.delayed
    except ImportError as exc:
        raise RuntimeError("Dask is required for interpolate_dask()") from exc

    if inplace:
        working_tree = self
    else:
        working_tree = self.get_subset(self._iter_station_paths())
    if chunks is not None:
        working_tree.as_dask(chunks=chunks, inplace=True)

    result_tree = working_tree.interpolate_lazy(
        new_periods,
        f_type=f_type,
        inplace=True,
        bounds_error=bounds_error,
        **kwargs,
    )

    # Re-wrap every pending transform so that calling it builds and
    # invokes a dask-delayed version of the original callable.
    pending = result_tree._lazy_station_transforms
    for path, original_fn in list(pending.items()):
        pending[path] = lambda fn=original_fn: make_delayed(fn)()

    if compute:
        result_tree.compute(scheduler=scheduler)
    elif scheduler is not None:
        dask_module.config.set(scheduler=scheduler)
    return result_tree
[docs]
def rotate_dask(
    self,
    rotation_angle: float | np.ndarray,
    chunks: dict[str, int] | str | None = None,
    scheduler: str | None = None,
    compute: bool = True,
    inplace: bool = False,
) -> "MTData":
    """Rotate stations, wrapping each deferred transform in ``dask.delayed``.

    **Note**: Rotation is defined as a rotation of the coordinate reference
    frame of the transfer function, not a rotation of the physical
    measurement. So, for example, a 90 degree rotation of a station with
    NED coordinates would swap the North and East components of the
    transfer function and change the coordinate reference frame to ENU.

    Parameters
    ----------
    rotation_angle : float or ndarray
        Rotation angle in degrees, scalar or per-period array.
    chunks : dict[str, int] or str or None, optional
        When given, :meth:`as_dask` is applied in place before the lazy
        rotation is registered.
    scheduler : str, optional
        Passed to :meth:`compute` when *compute* is ``True``; otherwise,
        if not ``None``, written to the dask configuration.
    compute : bool, optional
        If ``True`` (default), execute the delayed transforms immediately.
    inplace : bool, optional
        If ``True``, modify this tree; otherwise work on a subset copy.

    Returns
    -------
    MTData
        Tree with rotated results or pending delayed transforms.

    Raises
    ------
    RuntimeError
        If dask cannot be imported.
    """
    try:
        dask_module = importlib.import_module("dask")
        make_delayed = dask_module.delayed
    except ImportError as exc:
        raise RuntimeError("Dask is required for rotate_dask()") from exc

    if inplace:
        working_tree = self
    else:
        working_tree = self.get_subset(self._iter_station_paths())
    if chunks is not None:
        working_tree.as_dask(chunks=chunks, inplace=True)

    def _rotate_transform(ds: xr.Dataset) -> xr.Dataset:
        """Rotate one station, preferring its own reference frame attr."""
        tree_default = self.attrs.get("coordinate_reference_frame", "ned")
        frame = ds.attrs.get("coordinate_reference_frame", tree_default)
        return MTData._rotate_station_dataset(
            ds,
            rotation_angle,
            coordinate_reference_frame=frame,
        )

    result_tree = working_tree.map_stations(
        _rotate_transform,
        lazy=True,
        inplace=True,
    )

    # Re-wrap every pending transform so that calling it builds and
    # invokes a dask-delayed version of the original callable.
    pending = result_tree._lazy_station_transforms
    for path, original_fn in list(pending.items()):
        pending[path] = lambda fn=original_fn: make_delayed(fn)()

    if compute:
        result_tree.compute(scheduler=scheduler)
    elif scheduler is not None:
        dask_module.config.set(scheduler=scheduler)
    return result_tree
[docs]
def finalize_index(self) -> None:
    """Run :meth:`compute`, then rebuild the index at the stored database path."""
    self.compute()
    db_path = self._index_db_path
    self.rebuild_index(index_db_path=db_path)
def _hydrate_metadata_from_cache(
self, mt_obj: "MT", station_ds: xr.Dataset
) -> None:
"""Populate MT metadata objects from in-memory cache when references exist."""
attrs = station_ds.attrs
for metadata_kind in ["survey", "station"]:
ref_key = attrs.get(f"{metadata_kind}_metadata_ref")
if not isinstance(ref_key, str):
continue
cached_md = self._metadata_cache[metadata_kind].get(ref_key)
if cached_md is None:
continue
target_md = getattr(mt_obj, f"{metadata_kind}_metadata", None)
if target_md is None or not hasattr(target_md, "from_dict"):
continue
cached_dict = self._metadata_to_dict(cached_md)
if cached_dict:
target_md.from_dict(cached_dict)
@staticmethod
def _clean_name(value: Any, fallback: str) -> str:
"""Normalize path segment names for tree paths."""
name = str(value).strip() if value is not None else ""
if not name:
return fallback
return name.replace("/", "_")
def _station_path(self, survey: str, station: str) -> str:
"""Build canonical station path under /surveys."""
return f"{self.SURVEYS_NODE}/{survey}/{self.STATIONS_NODE}/{station}"
def _resolve_station_path(self, station_key: str) -> str:
"""Resolve a public station key to the canonical stored tree path."""
if not isinstance(station_key, str) or not station_key.strip():
raise KeyError("station_key must be a non-empty string")
key = station_key.strip().strip("/")
candidates: list[str] = []
def _append(candidate: str) -> None:
if candidate not in candidates:
candidates.append(candidate)
if key.startswith(f"{self.SURVEYS_NODE}/"):
_append(key)
if "." in key:
survey, station = key.split(".", 1)
_append(
self._station_path(
self._clean_name(survey, "default"),
self._clean_name(station, "unknown_station"),
)
)
if key.count("/") == 1:
survey, station = key.split("/", 1)
_append(
self._station_path(
self._clean_name(survey, "default"),
self._clean_name(station, "unknown_station"),
)
)
_append(key)
for candidate in candidates:
if (
self._path_exists(candidate)
or candidate in self._lazy_station_transforms
):
return candidate
raise KeyError(f"Station key not found: {station_key}")
def _normalize_station_paths(
self, station_paths: list[str] | None
) -> list[str] | None:
"""Normalize public station-path inputs while preserving no-match behavior."""
if station_paths is None:
return None
normalized: list[str] = []
for station_key in station_paths:
try:
normalized.append(self._resolve_station_path(station_key))
except KeyError:
if isinstance(station_key, str):
normalized.append(station_key.strip().strip("/"))
else:
normalized.append(station_key)
return normalized
@staticmethod
def _coerce_mt_object(mt_obj: "MT | str | Path") -> "MT":
    """Return an ``MT`` instance for *mt_obj*.

    ``MT`` instances are returned unchanged. A string or ``pathlib.Path``
    is passed to ``MT(...)`` and ``read()`` is called on the result.

    Raises
    ------
    TypeError
        For any other input type.
    """
    from .mt import MT

    if isinstance(mt_obj, MT):
        return mt_obj
    if not isinstance(mt_obj, (str, Path)):
        raise TypeError(
            "mt_obj must be an MT instance, filename string, or pathlib.Path"
        )
    loaded = MT(mt_obj)
    loaded.read()
    return loaded
def _path_exists(self, node_path: str) -> bool:
"""Check if a tree node path exists."""
try:
_ = self.tree[node_path]
return True
except KeyError:
return False
def _iter_station_paths(self) -> list[str]:
"""Return all station node paths under /surveys."""
if self._index is not None:
return self._index.all_station_paths()
station_paths: list[str] = []
def _walk(node: Any, node_path: str = "") -> None:
ds = getattr(node, "ds", None)
if isinstance(ds, xr.Dataset) and node_path.count("/") >= 3:
station_paths.append(node_path)
for child_name, child in getattr(node, "children", {}).items():
child_path = f"{node_path}/{child_name}" if node_path else child_name
_walk(child, child_path)
_walk(self.tree)
return station_paths
@staticmethod
def _crs_to_epsg(value: Any) -> Any:
"""Convert a CRS-like value to an EPSG code when possible."""
if value in [None, "", "None", "none", "null"]:
return None
if hasattr(value, "to_epsg"):
return value.to_epsg()
return value
@staticmethod
def _station_locations_columns() -> list[str]:
"""Column order matching MTStations.station_locations."""
return [
"survey",
"station",
"latitude",
"longitude",
"elevation",
"datum_epsg",
"east",
"north",
"utm_epsg",
"model_east",
"model_north",
"model_elevation",
"profile_offset",
]
def _station_location_record(self, station_path: str) -> dict[str, Any]:
"""Build one station-location record directly from dataset attrs."""
attrs = self.get_station(station_path).attrs
return {
"survey": attrs.get("survey"),
"station": attrs.get("station"),
"latitude": attrs.get("latitude"),
"longitude": attrs.get("longitude"),
"elevation": attrs.get("elevation"),
"datum_epsg": self._crs_to_epsg(attrs.get("datum_crs")),
"east": attrs.get("easting"),
"north": attrs.get("northing"),
"utm_epsg": self._crs_to_epsg(attrs.get("utm_crs")),
"model_east": attrs.get("model_east", 0.0),
"model_north": attrs.get("model_north", 0.0),
"model_elevation": attrs.get("model_elevation", 0.0),
"profile_offset": attrs.get("profile_offset", 0.0),
}
def _station_path_to_location_mt(self, station_path: str) -> "MT":
    """Create an ``MT`` object carrying only a station's location attrs.

    Attributes are assigned in a fixed order (identity, CRS, then
    coordinates); values that resolve to ``None`` are skipped. Model
    coordinates and ``profile_offset`` fall back to ``0.0``.
    """
    from .mt import MT

    attrs = self.get_station(station_path).attrs
    mt_obj = MT()

    # (MT attribute name, value) in assignment order.
    assignments = [
        ("survey", attrs.get("survey")),
        ("station", attrs.get("station")),
        ("datum_crs", attrs.get("datum_crs")),
        ("utm_crs", attrs.get("utm_crs")),
        ("latitude", attrs.get("latitude")),
        ("longitude", attrs.get("longitude")),
        ("elevation", attrs.get("elevation")),
        ("east", attrs.get("easting")),
        ("north", attrs.get("northing")),
        ("model_east", attrs.get("model_east", 0.0)),
        ("model_north", attrs.get("model_north", 0.0)),
        ("model_elevation", attrs.get("model_elevation", 0.0)),
        ("profile_offset", attrs.get("profile_offset", 0.0)),
    ]
    for target_name, value in assignments:
        if value is not None:
            setattr(mt_obj, target_name, value)
    return mt_obj
@staticmethod
def _get_utm_crs(mt_obj: "MT") -> Any:
"""Get UTM CRS information from MT object when available."""
crs = getattr(mt_obj, "utm_crs", None)
if crs is not None:
return crs
return getattr(mt_obj, "utm_epsg", None)
@staticmethod
def _dataset_to_mt(station_ds: xr.Dataset) -> "MT":
    """Create an ``MT`` object from a stored station dataset.

    A copy of *station_ds* becomes the MT transfer function. Dataset
    attrs that are not ``None`` are then written to the MT object in a
    fixed order. A string ``coordinate_reference_frame`` of ``NED``/``ENU``
    (any case) is translated to ``"+"``/``"-"`` before assignment.
    Non-empty ``survey_metadata``/``station_metadata`` dicts are loaded
    through ``from_dict`` when the target offers it.
    """
    from .mt import MT

    mt_obj = MT()
    mt_obj._transfer_function = station_ds.copy()
    attrs = station_ds.attrs

    def _copy_attr(source_key: str, target_name: str) -> None:
        value = attrs.get(source_key)
        if value is not None:
            setattr(mt_obj, target_name, value)

    _copy_attr("survey", "survey")
    _copy_attr("station", "station")

    frame = attrs.get("coordinate_reference_frame")
    if frame is not None:
        if isinstance(frame, str):
            frame = {"NED": "+", "ENU": "-"}.get(frame.upper(), frame)
        mt_obj.coordinate_reference_frame = frame

    for source_key, target_name in (
        ("impedance_units", "impedance_units"),
        ("datum_crs", "datum_crs"),
        ("utm_crs", "utm_crs"),
        ("latitude", "latitude"),
        ("longitude", "longitude"),
        ("elevation", "elevation"),
        ("easting", "east"),
        ("northing", "north"),
        ("profile_offset", "profile_offset"),
    ):
        _copy_attr(source_key, target_name)

    for metadata_name in ("survey_metadata", "station_metadata"):
        metadata_dict = attrs.get(metadata_name, {})
        if isinstance(metadata_dict, dict) and metadata_dict:
            metadata_target = getattr(mt_obj, metadata_name)
            if hasattr(metadata_target, "from_dict"):
                metadata_target.from_dict(metadata_dict)
    return mt_obj
@staticmethod
def _pick_channel_labels(
available: list[Any], candidates: list[str], required: int
) -> list[Any] | None:
"""Pick channel labels from available coordinates using preferred names."""
channel_map = {str(label).lower(): label for label in available}
selected: list[Any] = []
for candidate in candidates:
key = candidate.lower()
if key in channel_map and channel_map[key] not in selected:
selected.append(channel_map[key])
if len(selected) == required:
return selected
return None
@staticmethod
def _coerce_epsg_value(value: Any) -> str | None:
"""Normalize CRS/EPSG values to a dataframe-compatible EPSG string."""
if value is None:
return None
if isinstance(value, (int, np.integer)):
return str(int(value))
try:
from pyproj import CRS
epsg = CRS.from_user_input(value).to_epsg()
if epsg is not None:
return str(int(epsg))
except Exception:
pass
value_str = str(value).strip()
if not value_str:
return None
if value_str.isdigit():
return str(int(value_str))
return value_str
def _station_dataset_to_dataframe(
    self,
    station_ds: xr.Dataset,
    utm_crs: Any | None = None,
    cols: list[str] | None = None,
    impedance_units: str = "mt",
) -> pd.DataFrame:
    """Convert one station dataset into per-period dataframe rows.

    Parameters
    ----------
    station_ds : xr.Dataset
        Station dataset with a ``period`` coordinate. Impedance and
        tipper columns are filled only when ``output`` and ``input``
        coordinates exist and the needed channel labels are found.
    utm_crs : Any, optional
        Overrides the dataset's ``utm_crs`` attr for the ``utm_epsg``
        column.
    cols : list[str], optional
        Column subset to return; ``None`` returns every column.
    impedance_units : str, optional
        Passed as ``output_units`` to ``Z`` and as ``units`` to
        ``from_z_object``.

    Returns
    -------
    pandas.DataFrame
        One row per period.
    """
    from .transfer_function import Tipper, Z

    periods = np.asarray(station_ds.coords["period"].values, dtype=float)
    frame = MTDataFrame(n_entries=periods.size)
    attrs = station_ds.attrs

    # Location/identity fields, assigned in the same order as the columns.
    frame.survey = attrs.get("survey", "")
    frame.station = attrs.get("station", "")
    for name in ("latitude", "longitude", "elevation"):
        setattr(frame, name, attrs.get(name, 0.0))
    frame.datum_epsg = self._coerce_epsg_value(attrs.get("datum_crs"))
    frame.east = attrs.get("easting", 0.0)
    frame.north = attrs.get("northing", 0.0)
    effective_utm = attrs.get("utm_crs") if utm_crs is None else utm_crs
    frame.utm_epsg = self._coerce_epsg_value(effective_utm)
    for name in ("model_east", "model_north", "model_elevation", "profile_offset"):
        setattr(frame, name, attrs.get(name, 0.0))
    frame.dataframe.loc[:, "period"] = periods

    if "output" in station_ds.coords and "input" in station_ds.coords:
        output_labels = list(station_ds.coords["output"].values)
        input_labels = list(station_ds.coords["input"].values)

        def _select_all(outputs: list[Any], inputs: list[Any]) -> tuple:
            """Return (values, error, model_error) for the channel subset."""
            return tuple(
                station_ds[name].sel(output=outputs, input=inputs).values
                for name in (
                    "transfer_function",
                    "transfer_function_error",
                    "transfer_function_model_error",
                )
            )

        z_outputs = self._pick_channel_labels(
            output_labels, ["ex", "ey", "x", "y"], 2
        )
        z_inputs = self._pick_channel_labels(
            input_labels, ["hx", "hy", "x", "y"], 2
        )
        if z_outputs is not None and z_inputs is not None:
            z_values, z_err, z_model_err = _select_all(z_outputs, z_inputs)
            z_object = Z(
                z=z_values,
                z_error=z_err,
                frequency=1.0 / periods,
                z_model_error=z_model_err,
                input_units="mt",
                output_units=impedance_units,
            )
            frame.from_z_object(z_object, units=impedance_units)

        t_output = self._pick_channel_labels(output_labels, ["hz", "z"], 1)
        t_inputs = self._pick_channel_labels(
            input_labels, ["hx", "hy", "x", "y"], 2
        )
        if t_output is not None and t_inputs is not None:
            t_values, t_err, t_model_err = _select_all(t_output, t_inputs)
            tipper_object = Tipper(
                tipper=t_values,
                tipper_error=t_err,
                frequency=1.0 / periods,
                tipper_model_error=t_model_err,
            )
            frame.from_t_object(tipper_object)

    table = frame.dataframe
    return table if cols is None else table.loc[:, cols]
[docs]
def to_mt_stations(self) -> "MTStations":
    """Build an :class:`MTStations` object from the current station table.

    Returns
    -------
    MTStations
        Constructed from ``self.utm_epsg``, ``self.datum_epsg`` and
        ``self.station_locations``.
    """
    from .mt_stations import MTStations

    utm_epsg = self.utm_epsg
    datum_epsg = self.datum_epsg
    locations = self.station_locations
    return MTStations(
        utm_epsg,
        datum_epsg=datum_epsg,
        station_locations=locations,
    )
@property
def center_point(self) -> Any:
    """Return the center location of the station collection.

    When both ``_center_lat`` and ``_center_lon`` are set, an
    ``MTLocation`` is built from the stored latitude, longitude and
    elevation. ``utm_epsg`` and ``datum_crs`` are copied from
    ``self.attrs`` unless null-like, and the model coordinates are set to
    the location's ``east``/``north`` and the stored elevation.
    Otherwise the center is taken from
    ``self.to_mt_stations().center_point``.

    Returns
    -------
    MTLocation
        Center location.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # (add stations first)
    >>> cp = tree.center_point
    >>> print(cp.latitude, cp.longitude)
    """
    from .mt_location import MTLocation

    if self._center_lat is None or self._center_lon is None:
        return self.to_mt_stations().center_point

    null_like = [None, "", "None", "none", "null"]
    center = MTLocation()
    center.latitude = self._center_lat
    center.longitude = self._center_lon
    center.elevation = self._center_elev
    # UTM first, then datum -- same order as the stored attrs are applied.
    for attr_name in ("utm_epsg", "datum_crs"):
        stored = self.attrs.get(attr_name)
        if stored not in null_like:
            setattr(center, attr_name, stored)
    center.model_east = center.east
    center.model_north = center.north
    center.model_elevation = self._center_elev
    return center
def _dataframe_with_relative_locations(
    self,
    utm_crs: Any | None = None,
    impedance_units: str = "mt",
) -> pd.DataFrame:
    """Return the station dataframe with model-relative coordinates filled.

    A copy of :meth:`to_dataframe` is returned unchanged when it is
    empty, when ``model_east``/``model_north`` already hold non-zero
    values, or when ``east`` or ``north`` is all zero. Otherwise
    ``model_east``, ``model_north`` and ``model_elevation`` are computed
    relative to :attr:`center_point`.

    Parameters
    ----------
    utm_crs : pyproj CRS or int, optional
        Forwarded to :meth:`to_dataframe`.
    impedance_units : str, optional
        Forwarded to :meth:`to_dataframe`, by default ``'mt'``.

    Returns
    -------
    pandas.DataFrame
        Station dataframe.

    Raises
    ------
    ValueError
        If relative coordinates must be computed but
        ``center_point.utm_epsg`` is ``None``.
    """
    df = self.to_dataframe(utm_crs=utm_crs, impedance_units=impedance_units).copy()
    if df.empty:
        return df

    def _numeric(column: str) -> pd.Series:
        """Column as numbers; unparsable or missing entries become 0.0."""
        return pd.to_numeric(df.get(column), errors="coerce").fillna(0.0)

    model_east = _numeric("model_east")
    model_north = _numeric("model_north")
    model_unset = np.allclose(model_east, 0.0) and np.allclose(model_north, 0.0)
    if not model_unset:
        return df

    east = _numeric("east")
    north = _numeric("north")
    if np.allclose(east, 0.0) or np.allclose(north, 0.0):
        return df

    center = self.center_point
    if center.utm_epsg is None:
        raise ValueError(
            "Need to input data UTM EPSG or CRS to compute relative station locations"
        )

    df.loc[:, "model_east"] = east - center.east
    df.loc[:, "model_north"] = north - center.north
    df.loc[:, "model_elevation"] = _numeric("elevation") - center.elevation
    return df
[docs]
def to_geo_df(
    self,
    model_locations: bool = False,
    data_type: str = "station_locations",
) -> Any:
    """Create a GeoDataFrame for GIS workflows.

    Parameters
    ----------
    model_locations : bool, optional
        If ``True``, use ``model_east``/``model_north`` as geometry with
        no CRS. Otherwise use longitude/latitude with the CRS taken from
        the first usable ``datum_epsg`` entry.
    data_type : str, optional
        Which table to export:

        * ``'station_locations'`` or ``'stations'``
        * ``'phase_tensor'`` or ``'pt'``
        * ``'tipper'`` or ``'t'``
        * ``'both'`` or ``'shapefiles'``

    Returns
    -------
    geopandas.GeoDataFrame
        GeoDataFrame with point geometries.

    Raises
    ------
    ImportError
        If geopandas is not installed.
    ValueError
        If *data_type* is unsupported.
    """
    try:
        import geopandas as gpd
    except ImportError as exc:
        raise ImportError(
            "geopandas is required for to_geo_df but is not installed"
        ) from exc

    # Alias groups paired with a lazy builder so only the requested
    # table is computed.
    sources = (
        (["station_locations", "stations"], lambda: self.station_locations),
        (["phase_tensor", "pt"], lambda: self.to_mt_dataframe().phase_tensor),
        (["tipper", "t"], lambda: self.to_mt_dataframe().tipper),
        (["both", "shapefiles"], lambda: self.to_mt_dataframe().for_shapefiles),
    )
    for aliases, build in sources:
        if data_type in aliases:
            df = build()
            break
    else:
        raise ValueError(f"Option for 'data_type' {data_type} is unsupported.")

    if model_locations:
        return gpd.GeoDataFrame(
            df,
            geometry=gpd.points_from_xy(df.model_east, df.model_north),
            crs=None,
        )

    crs_value = None
    if "datum_epsg" in df.columns:
        for raw_value in df["datum_epsg"].tolist():
            epsg_value = self._coerce_epsg_value(raw_value)
            if epsg_value is not None:
                # Numeric codes get the authority prefix; anything else is
                # handed over as-is.
                crs_value = (
                    f"EPSG:{epsg_value}"
                    if str(epsg_value).isdigit()
                    else epsg_value
                )
                break

    return gpd.GeoDataFrame(
        df,
        geometry=gpd.points_from_xy(df.longitude, df.latitude),
        crs=crs_value,
    )
[docs]
def to_shp_pt_tipper(
    self,
    save_dir: str | Path,
    output_crs: Any | None = None,
    utm: bool = False,
    pt: bool = True,
    tipper: bool = True,
    periods: np.ndarray | None = None,
    period_tol: float | None = None,
    ellipse_size: float | None = None,
    arrow_size: float | None = None,
) -> dict[str, list[str]]:
    """Write phase-tensor and tipper shapefiles.

    Parameters
    ----------
    save_dir : str or pathlib.Path
        Output directory handed to ``ShapefileCreator``.
    output_crs : Any, optional
        Output coordinate reference system.
    utm : bool, optional
        Stored on the creator's ``utm`` attribute.
    pt : bool, optional
        If ``True``, write phase-tensor shapefiles.
    tipper : bool, optional
        If ``True``, write tipper shapefiles.
    periods : numpy.ndarray, optional
        Periods forwarded to ``make_shp_files``.
    period_tol : float, optional
        Period matching tolerance forwarded to ``make_shp_files``.
    ellipse_size : float, optional
        Phase-tensor ellipse size. When ``None`` and *pt* is true, the
        creator's estimate is used.
    arrow_size : float, optional
        Tipper arrow size. When ``None`` and *tipper* is true, the
        creator's estimate is used.

    Returns
    -------
    dict[str, list[str]]
        Result of ``ShapefileCreator.make_shp_files``.

    Notes
    -----
    For mixed station period sampling, interpolate first so all stations
    share a common period set.
    """
    from mtpy.gis.shapefile_creator import ShapefileCreator

    creator = ShapefileCreator(self.to_mt_dataframe(), output_crs, save_dir=save_dir)
    creator.utm = utm
    creator.ellipse_size = (
        creator.estimate_ellipse_size()
        if (ellipse_size is None and pt)
        else ellipse_size
    )
    creator.arrow_size = (
        creator.estimate_arrow_size()
        if (arrow_size is None and tipper)
        else arrow_size
    )
    return creator.make_shp_files(
        pt=pt,
        tipper=tipper,
        periods=periods,
        period_tol=period_tol,
    )
@property
def station_locations(self) -> pd.DataFrame:
"""Station-location table built directly from tree dataset attrs."""
columns = self._station_locations_columns()
if self._index is not None:
records = self._index.all_station_records()
if not records:
return pd.DataFrame(columns=columns)
return pd.DataFrame(
[
{
"survey": r.survey_name,
"station": r.name,
"latitude": r.latitude,
"longitude": r.longitude,
"elevation": r.elevation,
"datum_epsg": r.datum_epsg,
"east": r.east,
"north": r.north,
"utm_epsg": r.utm_epsg,
"model_east": r.model_east,
"model_north": r.model_north,
"model_elevation": r.model_elevation,
"profile_offset": r.profile_offset,
}
for r in records
],
columns=columns,
)
station_paths = self._iter_station_paths()
if not station_paths:
return pd.DataFrame(columns=columns)
return pd.DataFrame(
[self._station_location_record(path) for path in station_paths],
columns=columns,
)
@property
def mt_stations(self) -> "MTStations":
    """Property shorthand for :meth:`to_mt_stations`."""
    stations_view = self.to_mt_stations()
    return stations_view
def _sync_station_locations_from_mt_stations(
    self,
    mt_stations: "MTStations",
) -> None:
    """Copy location values from an ``MTStations`` object into the tree.

    Rows of ``mt_stations.station_locations`` are matched to stations by
    cleaned ``(survey, station)`` name; unmatched rows are ignored. For
    matched stations the coordinate attrs are overwritten and the CRS
    attrs are updated when the row supplies a usable EPSG. Tree-level
    CRS attrs, the data rotation angle and the stored center are then
    taken from *mt_stations*, and the index is rebuilt if one is attached.
    Nothing happens beyond :meth:`compute` when the table is ``None`` or
    empty.
    """
    self.compute()
    locations = mt_stations.station_locations
    if locations is None or locations.empty:
        return

    def _key(survey: Any, station: Any) -> tuple[str, str]:
        return (
            self._clean_name(survey, "default"),
            self._clean_name(station, "unknown_station"),
        )

    path_by_key: dict[tuple[str, str], str] = {}
    for path in self._iter_station_paths():
        ds_attrs = self.get_station(path).attrs
        path_by_key[_key(ds_attrs.get("survey"), ds_attrs.get("station"))] = path

    # (dataset attr name, dataframe column) copied verbatim; a missing
    # column leaves the current attr value in place.
    copied_fields = (
        ("latitude", "latitude"),
        ("longitude", "longitude"),
        ("elevation", "elevation"),
        ("easting", "east"),
        ("northing", "north"),
        ("model_east", "model_east"),
        ("model_north", "model_north"),
        ("model_elevation", "model_elevation"),
        ("profile_offset", "profile_offset"),
    )
    for row in locations.itertuples(index=False):
        row_key = _key(getattr(row, "survey", None), getattr(row, "station", None))
        path = path_by_key.get(row_key)
        if path is None:
            continue
        ds_attrs = self.get_station(path).attrs
        for attr_name, column in copied_fields:
            ds_attrs[attr_name] = getattr(row, column, ds_attrs.get(attr_name))

        datum_epsg = self._coerce_epsg_value(getattr(row, "datum_epsg", None))
        if datum_epsg is not None:
            ds_attrs["datum_crs"] = datum_epsg
        utm_epsg = self._coerce_epsg_value(getattr(row, "utm_epsg", None))
        if utm_epsg is not None:
            ds_attrs["utm_crs"] = utm_epsg
            ds_attrs["utm_epsg"] = utm_epsg

    if mt_stations.utm_epsg is not None:
        self.attrs["utm_epsg"] = mt_stations.utm_epsg
        self.attrs["utm_crs"] = mt_stations.utm_epsg
    if mt_stations.datum_epsg is not None:
        self.attrs["datum_crs"] = mt_stations.datum_epsg

    self.data_rotation_angle = getattr(
        mt_stations,
        "rotation_angle",
        self.data_rotation_angle,
    )
    for center_attr in ("_center_lat", "_center_lon", "_center_elev"):
        current = getattr(self, center_attr)
        setattr(self, center_attr, getattr(mt_stations, center_attr, current))

    if self._index is not None:
        self.rebuild_index(index_db_path=self._index_db_path)
[docs]
def compute_relative_locations(self) -> None:
    """Run ``MTStations.compute_relative_locations`` and write results to the tree."""
    station_view = self.to_mt_stations()
    station_view.compute_relative_locations()
    self._sync_station_locations_from_mt_stations(station_view)
[docs]
def rotate_stations(self, rotation_angle: float) -> None:
    """Run ``MTStations.rotate_stations(rotation_angle)`` and write results to the tree."""
    station_view = self.to_mt_stations()
    station_view.rotate_stations(rotation_angle)
    self._sync_station_locations_from_mt_stations(station_view)
[docs]
def center_stations(self, model_obj: Any) -> None:
    """Run ``MTStations.center_stations(model_obj)`` and write results to the tree."""
    station_view = self.to_mt_stations()
    station_view.center_stations(model_obj)
    self._sync_station_locations_from_mt_stations(station_view)
[docs]
def project_stations_on_topography(
    self,
    model_object: Any,
    air_resistivity: float = 1e12,
    sea_resistivity: float = 0.3,
    ocean_bottom: bool = False,
) -> None:
    """Run ``MTStations.project_stations_on_topography`` and sync to the tree.

    All arguments are forwarded unchanged to the ``MTStations`` method;
    the updated station table is then written back into the tree.
    """
    station_view = self.to_mt_stations()
    projection_options = {
        "air_resistivity": air_resistivity,
        "sea_resistivity": sea_resistivity,
        "ocean_bottom": ocean_bottom,
    }
    station_view.project_stations_on_topography(model_object, **projection_options)
    self._sync_station_locations_from_mt_stations(station_view)
[docs]
def to_geopd(self) -> Any:
    """Delegate to ``MTStations.to_geopd`` for the current station table."""
    station_view = self.to_mt_stations()
    return station_view.to_geopd()
[docs]
def to_shp(self, shp_fn: str | Path) -> str | Path:
    """Delegate to ``MTStations.to_shp(shp_fn)`` and return its result."""
    station_view = self.to_mt_stations()
    return station_view.to_shp(shp_fn)
[docs]
def to_csv(self, csv_fn: str | Path, geometry: bool = False) -> None:
    """Delegate to ``MTStations.to_csv(csv_fn, geometry=geometry)``."""
    station_view = self.to_mt_stations()
    station_view.to_csv(csv_fn, geometry=geometry)
[docs]
def to_vtk(
    self,
    vtk_fn: str | Path | None = None,
    vtk_save_path: str | Path | None = None,
    vtk_fn_basename: str = "ModEM_stations",
    geographic: bool = False,
    shift_east: float = 0,
    shift_north: float = 0,
    shift_elev: float = 0,
    units: str = "km",
    coordinate_system: str = "nez+",
) -> Path:
    """Delegate to ``MTStations.to_vtk`` with all arguments forwarded by keyword."""
    writer_options = {
        "vtk_fn": vtk_fn,
        "vtk_save_path": vtk_save_path,
        "vtk_fn_basename": vtk_fn_basename,
        "geographic": geographic,
        "shift_east": shift_east,
        "shift_north": shift_north,
        "shift_elev": shift_elev,
        "units": units,
        "coordinate_system": coordinate_system,
    }
    return self.to_mt_stations().to_vtk(**writer_options)
[docs]
def generate_profile(
    self,
    units: str = "deg",
) -> tuple[float, float, float, float, dict[str, float]]:
    """Delegate to ``MTStations.generate_profile(units=units)``."""
    station_view = self.to_mt_stations()
    return station_view.generate_profile(units=units)
[docs]
def generate_profile_from_strike(
    self,
    strike: float,
    units: str = "deg",
) -> tuple[float, float, float, float, dict[str, float]]:
    """Delegate to ``MTStations.generate_profile_from_strike(strike, units=units)``."""
    station_view = self.to_mt_stations()
    return station_view.generate_profile_from_strike(strike, units=units)
[docs]
def get_nearby_stations(
    self,
    station_key: str,
    radius: float,
    radius_units: str = "m",
) -> list[str]:
    """Find neighboring stations around a reference station.

    Parameters
    ----------
    station_key : str
        Reference station key as canonical tree path or ``survey.station``.
    radius : float
        Search radius in the units specified by *radius_units*.
    radius_units : {'m', 'meters', 'metres', 'deg', 'degrees'}, optional
        Metric units compare ``easting``/``northing`` against the
        ``east``/``north`` columns; degree units compare
        longitude/latitude.

    Returns
    -------
    list[str]
        ``survey.station`` keys of stations whose distance to the
        reference is greater than zero and at most *radius*. The
        zero-distance filter is what excludes the reference station.

    Raises
    ------
    ValueError
        If metric units are requested without UTM coordinate information,
        or if *radius_units* is unsupported.

    Examples
    --------
    >>> nearby = tree.get_nearby_stations("surveyA.station01", radius=5000)
    >>> nearby_deg = tree.get_nearby_stations("surveyA.station01", 0.1, "deg")
    """
    self.compute()
    station_path = self._resolve_station_path(station_key)
    reference = self.get_station(station_path).attrs
    sdf = self.station_locations.copy()
    if sdf.empty:
        return []

    if radius_units in ["m", "meters", "metres"]:
        utm_missing = "utm_epsg" not in sdf.columns or (
            sdf["utm_epsg"].replace("", np.nan).dropna().empty
        )
        if utm_missing:
            raise ValueError(
                "Cannot estimate distances in meters without a UTM CRS. Set 'utm_crs' first."
            )
        x_attr, x_column = "easting", "east"
        y_attr, y_column = "northing", "north"
    elif radius_units in ["deg", "degrees"]:
        x_attr, x_column = "longitude", "longitude"
        y_attr, y_column = "latitude", "latitude"
    else:
        raise ValueError(
            "radius_units must be one of: m, meters, metres, deg, degrees"
        )

    def _offset(attr_name: str, column: str) -> pd.Series:
        """Reference coordinate minus the column (bad cells count as 0.0)."""
        return float(reference.get(attr_name, 0.0)) - pd.to_numeric(
            sdf[column], errors="coerce"
        ).fillna(0.0)

    sdf["radius"] = np.sqrt(
        _offset(x_attr, x_column) ** 2 + _offset(y_attr, y_column) ** 2
    )

    in_range = sdf.loc[(sdf["radius"] <= radius) & (sdf["radius"] > 0)]
    return [f"{row.survey}.{row.station}" for row in in_range.itertuples()]
[docs]
def get_profile(
    self,
    x1: float,
    y1: float,
    x2: float,
    y2: float,
    radius: float,
) -> "MTData":
    """Extract stations within a corridor around a profile line.

    Parameters
    ----------
    x1, y1, x2, y2 : float
        Profile start and end coordinates, passed to
        ``MTStations._extract_profile``.
    radius : float
        Corridor half-width, passed to ``MTStations._extract_profile``.

    Returns
    -------
    MTData
        Subset tree holding the stations selected for the profile, with
        ``profile_offset`` attrs updated from the extraction result; an
        empty clone when no station is selected.
    """
    self.compute()
    corridor = self.to_mt_stations()._extract_profile(x1, y1, x2, y2, radius)
    if corridor.empty:
        return self.clone_empty()

    # Match extracted rows to tree paths by stringified (survey, station).
    path_by_key: dict[tuple[str, str], str] = {}
    for path in self._iter_station_paths():
        ds_attrs = self.get_station(path).attrs
        lookup_key = (
            str(ds_attrs.get("survey", "")),
            str(ds_attrs.get("station", "")),
        )
        path_by_key[lookup_key] = path

    selected_paths: list[str] = []
    for row in corridor.itertuples(index=False):
        matched = path_by_key.get((str(row.survey), str(row.station)))
        if matched is not None:
            selected_paths.append(matched)

    profile_tree = self.get_subset(selected_paths)

    # Second pass: store each row's profile offset on the subset tree,
    # addressing stations by their cleaned-name path.
    for row in corridor.itertuples(index=False):
        subset_path = profile_tree._station_path(
            self._clean_name(getattr(row, "survey", None), "default"),
            self._clean_name(getattr(row, "station", None), "unknown_station"),
        )
        if profile_tree._path_exists(subset_path) and hasattr(
            row, "profile_offset"
        ):
            profile_tree.get_station(subset_path).attrs["profile_offset"] = float(
                row.profile_offset
            )
    return profile_tree
[docs]
def compute_model_errors(
self,
z_error_value: float | None = None,
z_error_type: str | None = None,
z_floor: bool | None = None,
t_error_value: float | None = None,
t_error_type: str | None = None,
t_floor: bool | None = None,
) -> None:
"""Recompute impedance and tipper model errors for all stations.
Parameters
----------
z_error_value, z_error_type, z_floor : optional
Overrides for impedance model-error settings.
t_error_value, t_error_type, t_floor : optional
Overrides for tipper model-error settings.
"""
self.compute()
if z_error_value is not None:
self.z_model_error.error_value = z_error_value
if z_error_type is not None:
self.z_model_error.error_type = z_error_type
if z_floor is not None:
self.z_model_error.floor = z_floor
if t_error_value is not None:
self.t_model_error.error_value = t_error_value
if t_error_type is not None:
self.t_model_error.error_type = t_error_type
if t_floor is not None:
self.t_model_error.floor = t_floor
for station_path in self._iter_station_paths():
station_ds = self.get_station(station_path)
attrs = dict(station_ds.attrs)
mt_obj = self._dataset_to_mt(station_ds)
self._hydrate_metadata_from_cache(mt_obj, station_ds)
mt_obj.compute_model_z_errors(**self.z_model_error.error_parameters)
mt_obj.compute_model_t_errors(**self.t_model_error.error_parameters)
out_ds = mt_obj._transfer_function
out_ds.attrs = attrs
self._set_station_dataset(station_path, out_ds)
[docs]
def estimate_starting_rho(self) -> None:
    """Plot mean and median ``Z.res_det`` curves over all stations.

    Every station is converted to an MT object and its ``period`` /
    ``Z.res_det`` pairs are pooled.  The pooled values are grouped by
    period; the per-period mean and median are drawn as solid log-log
    curves, and the mean of the means / median of the medians are drawn as
    dashed horizontal lines and reported in the legend.

    Notes
    -----
    When the collection yields no period/resistivity pairs (for example an
    empty tree) there is nothing to group or plot; a warning is logged and
    the method returns without opening a figure.  Otherwise the figure is
    displayed with ``plt.show()``.
    """
    import matplotlib.pyplot as plt

    self.compute()

    entries: list[dict[str, float]] = []
    for station_path in self._iter_station_paths():
        mt_obj = self.get_station(station_path, as_mt=True)
        entries.extend(
            {"period": period, "res_det": res_det}
            for period, res_det in zip(mt_obj.period, mt_obj.Z.res_det)
        )

    # An empty DataFrame has no "period" column, so grouping would fail with
    # an unhelpful KeyError; bail out explicitly instead.
    if not entries:
        logger.warning(
            "No station data available; cannot estimate a starting resistivity."
        )
        return

    res_df = pd.DataFrame(entries)
    grouped = res_df.groupby("period")
    mean_rho = grouped.mean()
    median_rho = grouped.median()
    overall_mean = mean_rho.res_det.mean()
    overall_median = median_rho.res_det.median()

    mean_color = (0.75, 0.25, 0)
    median_color = (0, 0.25, 0.75)
    label_font = {"size": 12, "weight": "bold"}

    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1)
    (l1,) = ax.loglog(mean_rho.index, mean_rho.res_det, lw=2, color=mean_color)
    (l2,) = ax.loglog(
        median_rho.index, median_rho.res_det, lw=2, color=median_color
    )
    # Dashed horizontal lines: one summary value repeated across all periods.
    ax.loglog(
        mean_rho.index,
        np.repeat(overall_mean, mean_rho.shape[0]),
        ls="--",
        lw=2,
        color=mean_color,
    )
    ax.loglog(
        median_rho.index,
        np.repeat(overall_median, median_rho.shape[0]),
        ls="--",
        lw=2,
        color=median_color,
    )
    ax.set_xlabel("Period (s)", fontdict=label_font)
    ax.set_ylabel("Resistivity (Ohm-m)", fontdict=label_font)
    ax.legend(
        [l1, l2],
        [
            f"Mean = {overall_mean:.1f}",
            f"Median = {overall_median:.1f}",
        ],
        loc="upper left",
    )
    ax.grid(which="both", ls="--", color=(0.75, 0.75, 0.75))
    ax.set_xlim((res_df.period.min(), res_df.period.max()))
    plt.show()
[docs]
def to_modem(self, data_filename: str | Path | None = None, **kwargs: Any) -> Any:
    """Build a ModEM ``Data`` object from the station collection.

    Parameters
    ----------
    data_filename : str or pathlib.Path, optional
        Destination of the ModEM data file.  With the default ``None`` no
        file is written.
    **kwargs
        Extra keyword arguments for :class:`mtpy.modeling.modem.Data`
        (e.g. ``rotation_angle``, ``inv_mode``, ``formatting``).  They are
        layered on top of :attr:`model_parameters`, so a key given here
        wins over the stored value.

    Returns
    -------
    mtpy.modeling.modem.Data
        Data object built from the station dataframe and
        :attr:`center_point`, with ``z_model_error`` and ``t_model_error``
        taken from the tree.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # tree.add_station(...)  # populate with MT objects first
    >>> tree.model_parameters = {"inv_mode": "1", "formatting": "1"}
    >>> modem_data = tree.to_modem(data_filename="ModEM_data.dat")
    >>> print(modem_data.center_point.latitude)
    """
    from mtpy.modeling.modem import Data

    # Stored model parameters first, call-site overrides second.
    data_options = {**dict(self.model_parameters), **kwargs}

    station_df = self._dataframe_with_relative_locations(
        impedance_units=self.impedance_units
    )
    if station_df.empty:
        # Fall back to the plain dataframe export.
        station_df = self.to_dataframe(impedance_units=self.impedance_units)

    modem_data = Data(
        dataframe=station_df,
        center_point=self.center_point,
        **data_options,
    )
    modem_data.z_model_error = self.z_model_error
    modem_data.t_model_error = self.t_model_error

    if data_filename is None:
        return modem_data
    modem_data.write_data_file(file_name=data_filename)
    return modem_data
[docs]
def from_modem(
    self, data_filename: str | Path, survey: str = "data", **kwargs: Any
) -> None:
    """Fill the tree from an existing ModEM data file.

    The file is read through :class:`mtpy.modeling.modem.Data`; the
    resulting station dataframe, model-error parameters, rotation angle,
    center point and top-level model parameters are copied onto this tree.

    Parameters
    ----------
    data_filename : str or pathlib.Path
        Path to the ModEM ``.dat`` / ``.data`` file.
    survey : str, optional
        Survey label written to every imported row, by default ``'data'``.
    **kwargs
        Extra keyword arguments for :class:`mtpy.modeling.modem.Data`.

    Notes
    -----
    Only model parameters whose key contains no ``"."`` are kept in
    :attr:`model_parameters`.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> tree.from_modem("ModEM_data.dat", survey="line1")
    >>> print(tree.survey_ids)
    """
    from mtpy.modeling.modem import Data

    modem_data = Data(**kwargs)
    mdf = modem_data.read_data_file(data_filename)
    mdf.dataframe.loc[:, "survey"] = survey
    self.from_mt_dataframe(mdf)

    # Rebuild the error settings from the parameters stored in the file.
    z_parameters = modem_data.z_model_error.error_parameters
    self.z_model_error = ModelErrors(mode="impedance", **z_parameters)
    t_parameters = modem_data.t_model_error.error_parameters
    self.t_model_error = ModelErrors(mode="tipper", **t_parameters)

    self.data_rotation_angle = modem_data.rotation_angle

    center = modem_data.center_point
    self._center_lat = center.latitude
    self._center_lon = center.longitude
    self._center_elev = center.elevation
    self.attrs["utm_epsg"] = center.utm_epsg
    self.attrs["datum_crs"] = center.datum_crs

    # Keys containing a dot are skipped.
    top_level_parameters = {}
    for key, value in modem_data.model_parameters.items():
        if "." in key:
            continue
        top_level_parameters[key] = value
    self.model_parameters = top_level_parameters
[docs]
def to_occam2d(self, data_filename: str | Path | None = None, **kwargs: Any) -> Any:
    """Build an Occam2D data object from the station collection.

    Parameters
    ----------
    data_filename : str or pathlib.Path, optional
        Destination of the Occam2D data file.  With the default ``None``
        no file is written.
    **kwargs
        Extra keyword arguments for
        :class:`mtpy.modeling.occam2d.Occam2DData` (e.g. ``model_mode``,
        ``profile_angle``, ``res_te_err``).

    Returns
    -------
    mtpy.modeling.occam2d.Occam2DData
        Data object whose ``dataframe`` is the result of
        :meth:`to_dataframe`.  If its ``profile_origin`` is still ``None``
        after construction it is set to the ``(east, north)`` of
        :attr:`center_point`.

    Notes
    -----
    All information is derived from the station dataframe.  Create the
    profile, interpolate, and estimate model errors on the tree before
    calling this method.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # tree.add_station(...)  # populate first
    >>> occam_data = tree.to_occam2d(
    ...     data_filename="OccamDataFile.dat", model_mode="5"
    ... )
    """
    from mtpy.modeling.occam2d import Occam2DData

    occam2d_data = Occam2DData(**kwargs)
    occam2d_data.dataframe = self.to_dataframe()

    origin_missing = occam2d_data.profile_origin is None
    if origin_missing:
        center = self.center_point
        occam2d_data.profile_origin = (center.east, center.north)

    if data_filename is None:
        return occam2d_data
    occam2d_data.write_data_file(data_filename)
    return occam2d_data
[docs]
def from_occam2d(
    self,
    data_filename: str | Path,
    file_type: str = "data",
    **kwargs: Any,
) -> None:
    """Fill the tree from an existing Occam2D data file.

    After reading, ``profile_origin``, ``profile_angle`` and ``model_mode``
    of the Occam2D object are stored in :attr:`model_parameters`.

    Parameters
    ----------
    data_filename : str or pathlib.Path
        Path to the Occam2D data file.
    file_type : str, optional
        ``'data'`` (default) labels every row with survey ``'data'``;
        ``'response'`` or ``'model'`` labels every row with survey
        ``'model'``.  Any other value leaves the ``survey`` column as read.
    **kwargs
        Extra keyword arguments for
        :class:`mtpy.modeling.occam2d.Occam2DData`.

    Examples
    --------
    Read a data file:

    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> tree.from_occam2d("OccamDataFile.dat")
    >>> print(tree.station_ids)

    Read a response / model file:

    >>> tree.from_occam2d("OccamResponse.dat", file_type="response")
    """
    from mtpy.modeling.occam2d import Occam2DData

    occam2d_data = Occam2DData(**kwargs)
    occam2d_data.read_data_file(data_filename)

    if file_type == "data":
        occam2d_data.dataframe["survey"] = "data"
    elif file_type in ("response", "model"):
        occam2d_data.dataframe["survey"] = "model"

    self.from_dataframe(occam2d_data.dataframe)

    # Carry the profile description over to the tree's model parameters.
    for parameter in ("profile_origin", "profile_angle", "model_mode"):
        self.model_parameters[parameter] = getattr(occam2d_data, parameter)
[docs]
def to_simpeg_2d(self, **kwargs: Any) -> Any:
    """Build a SimPEG 2-D MT data object from the station collection.

    Parameters
    ----------
    **kwargs
        Extra keyword arguments for
        :class:`mtpy.modeling.simpeg.data_2d.Simpeg2DData`.  Common options
        include:

        include_elevation : bool
            Include station elevation in the receiver locations,
            by default ``True``.
        invert_te : bool
            Include TE-mode apparent resistivity and phase,
            by default ``True``.
        invert_tm : bool
            Include TM-mode apparent resistivity and phase,
            by default ``True``.

    Returns
    -------
    mtpy.modeling.simpeg.data_2d.Simpeg2DData
        Data object built from the station dataframe.

    Notes
    -----
    The dataframe is exported with ``impedance_units="ohm"``.  Create the
    profile, interpolate, and estimate model errors on the tree before
    calling this method.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # tree.add_station(...)  # populate first
    >>> simpeg_2d = tree.to_simpeg_2d(invert_te=True, invert_tm=False)
    """
    from mtpy.modeling.simpeg.data_2d import Simpeg2DData

    station_df = self.to_dataframe(impedance_units="ohm")
    return Simpeg2DData(station_df, **kwargs)
[docs]
def to_simpeg_3d(self, **kwargs: Any) -> Any:
    """Build a SimPEG 3-D MT data object from the station collection.

    Parameters
    ----------
    **kwargs
        Extra keyword arguments for
        :class:`mtpy.modeling.simpeg.data_3d.Simpeg3DData`.  Common options
        include:

        include_elevation : bool
            Include station elevation in the receiver locations,
            by default ``False``.
        geographic_coordinates : bool
            Use geographic (UTM) coordinates instead of model-relative
            coordinates, by default ``True``.
        invert_z_xx, invert_z_xy, invert_z_yx, invert_z_yy : bool
            Select which impedance tensor components to include,
            all default to ``True``.
        invert_t_zx, invert_t_zy : bool
            Select which tipper components to include,
            both default to ``True``.

    Returns
    -------
    mtpy.modeling.simpeg.data_3d.Simpeg3DData
        Data object built from the station dataframe.

    Notes
    -----
    The dataframe is exported with ``impedance_units="ohm"``.  Interpolate
    and estimate model errors on the tree before calling this method.

    Examples
    --------
    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # tree.add_station(...)  # populate first
    >>> simpeg_3d = tree.to_simpeg_3d(invert_z_yy=False, include_elevation=True)
    """
    from mtpy.modeling.simpeg.data_3d import Simpeg3DData

    station_df = self.to_dataframe(impedance_units="ohm")
    return Simpeg3DData(station_df, **kwargs)
[docs]
def add_white_noise(self, value: float, inplace: bool = True) -> "MTData | None":
    """Add white noise to every station in the collection.

    Each station is fetched as an MT object and its own
    ``add_white_noise`` is called with the normalized noise level.  Useful
    for generating synthetic test datasets.

    Parameters
    ----------
    value : float
        Noise level as a decimal fraction (0-1) or as a percentage (>1).
        Values greater than 1 are divided by 100 before use (``10``
        becomes ``0.10``, i.e. 10 %).
    inplace : bool, optional
        When ``True`` (default) each noisy MT object is written back with
        :meth:`add_station` and ``None`` is returned.  When ``False`` the
        noisy copies are collected into a new tree created with
        :meth:`clone_empty` and this tree is not written to.

    Returns
    -------
    MTData or None
        New tree with the noisy stations when *inplace* is ``False``;
        ``None`` otherwise.

    Examples
    --------
    In-place noise addition:

    >>> from mtpy.core import MTData
    >>> tree = MTData()
    >>> # tree.add_station(...)  # populate first
    >>> tree.add_white_noise(5)  # adds 5 % noise in place

    Non-destructive copy with noise:

    >>> noisy_tree = tree.add_white_noise(0.05, inplace=False)
    """
    # Percent input (> 1) is normalized to a fraction.
    noise_level = value / 100.0 if value > 1 else value
    station_paths = self._iter_station_paths()

    if not inplace:
        noisy_copies = [
            self.get_station(station_path, as_mt=True).add_white_noise(
                noise_level, inplace=False
            )
            for station_path in station_paths
        ]
        noisy_tree = self.clone_empty()
        noisy_tree.add_stations(noisy_copies)
        return noisy_tree

    for station_path in station_paths:
        station_mt = self.get_station(station_path, as_mt=True)
        station_mt.add_white_noise(noise_level)
        # Write the modified MT object back into the tree.
        self.add_station(station_mt)
    return None
[docs]
def estimate_spatial_static_shift(
    self,
    station_key: str,
    radius: float,
    period_min: float,
    period_max: float,
    radius_units: str = "m",
    shift_tolerance: float = 0.15,
) -> tuple[float, float]:
    """Estimate static-shift factors for one station from its neighbors.

    The neighbors found by :meth:`get_nearby_stations` and the target
    station are interpolated onto the target's own periods that fall in
    ``[period_min, period_max]``.  Each factor is the ratio of the
    neighbors' NaN-ignoring median resistivity to the target's, computed
    separately for ``res_xy`` and ``res_yx``.

    Parameters
    ----------
    station_key : str
        Target station key.
    radius : float
        Neighbor search radius.
    period_min, period_max : float
        Inclusive period bounds used for the comparison.
    radius_units : str, optional
        Radius units passed to :meth:`get_nearby_stations`.
    shift_tolerance : float, optional
        Factors strictly inside ``1 +/- shift_tolerance`` are snapped to
        ``1.0``.

    Returns
    -------
    tuple[float, float]
        Estimated ``(sx, sy)`` static-shift factors; ``(1.0, 1.0)`` when no
        neighbor is found.
    """
    nearby_keys = self.get_nearby_stations(station_key, radius, radius_units)
    if not len(nearby_keys):
        return 1.0, 1.0

    neighbor_tree = self.get_subset(
        [self._resolve_station_path(key) for key in nearby_keys]
    )
    target_site = self.get_station(
        self._resolve_station_path(station_key),
        as_mt=True,
    )

    # Target periods inside the comparison band (bounds inclusive).
    in_band = (target_site.period >= period_min) & (target_site.period <= period_max)
    band_periods = target_site.period[np.nonzero(in_band)]

    target_site = target_site.interpolate(band_periods)
    neighbor_tree.interpolate(band_periods)
    neighbor_df = neighbor_tree.to_dataframe()

    def _snap(factor):
        """Return 1.0 when *factor* lies strictly within the tolerance band."""
        if 1 - shift_tolerance < factor < 1 + shift_tolerance:
            return 1.0
        return factor

    shift_x = np.nanmedian(neighbor_df.res_xy) / np.nanmedian(target_site.Z.res_xy)
    shift_y = np.nanmedian(neighbor_df.res_yx) / np.nanmedian(target_site.Z.res_yx)
    return _snap(shift_x), _snap(shift_y)
@property
def n_stations(self) -> int:
    """Number of stations held by the collection.

    Uses the station index when one is attached; otherwise counts the
    station paths found in the tree.
    """
    self.compute()
    station_index = self._index
    if station_index is None:
        return len(self._iter_station_paths())
    return station_index.n_stations()
@property
def survey_ids(self) -> list[str]:
    """Unique survey IDs in the collection.

    With a station index attached, the names of the rows returned by the
    index's ``all_surveys()`` are used.  Otherwise the survey name is read
    from the second component of every station path that has at least
    three ``"/"`` separators.

    Returns
    -------
    list[str]
        Survey identifiers without duplicates.  On the tree-scan path the
        order is the order in which each survey is first met while
        iterating the station paths, so repeated calls and separate
        interpreter runs give the same sequence.
    """
    self.compute()
    if self._index is not None:
        return [row.name for row in self._index.all_surveys()]

    # dict.fromkeys de-duplicates while keeping first-seen order; a plain
    # set would make the order depend on string hashing.
    first_seen = dict.fromkeys(
        path.split("/", 3)[1]
        for path in self._iter_station_paths()
        if path.count("/") >= 3
    )
    return list(first_seen)
[docs]
def get_survey(self, survey_id: str) -> "MTData":
    """Return a subset tree restricted to one survey.

    Parameters
    ----------
    survey_id : str
        Survey identifier, used as the survey component of the station
        path prefix.

    Returns
    -------
    MTData
        Result of :meth:`get_subset` for every station path that starts
        with ``{SURVEYS_NODE}/{survey_id}/{STATIONS_NODE}/``.
    """
    self.compute()
    survey_prefix = f"{self.SURVEYS_NODE}/{survey_id}/{self.STATIONS_NODE}/"
    survey_paths = []
    for station_path in self._iter_station_paths():
        if station_path.startswith(survey_prefix):
            survey_paths.append(station_path)
    return self.get_subset(survey_paths)
[docs]
def add_station(
self,
mt_obj: "MT | str | Path | list[MT | str | Path]",
overwrite: bool = True,
dataset_copy_mode: str | None = None,
) -> str | list[str]:
"""
Add an MT object as a station node in the tree.
Node path pattern:
/surveys/{survey_id}/stations/{station_id}
Parameters
----------
mt_obj : mtpy.core.MT, str, Path, or list
MT object, filename, pathlib.Path, or a list of mixed supported
input types.
overwrite : bool, optional
If False, raise if station path already exists.
dataset_copy_mode : {'deep', 'shallow', 'none'}, optional
Dataset copy behavior for station transfer-function storage.
Returns
-------
str or list[str]
Station node path for scalar inputs or list of paths for list
inputs.
"""
self.compute()
if mt_obj is None:
raise TypeError("mt_obj cannot be None")
if isinstance(mt_obj, list):
return self.add_stations(
mt_obj,
overwrite=overwrite,
dataset_copy_mode=dataset_copy_mode,
)
(
station_path,
station,
station_ds,
metadata_objects,
) = self._coerce_and_prepare_station(
mt_obj,
dataset_copy_mode=dataset_copy_mode,
)
if self._path_exists(station_path) and not overwrite:
raise KeyError(f"Station path already exists: {station_path}")
self.tree[station_path] = xr.DataTree(name=station, dataset=station_ds)
self._commit_cached_metadata(
station_path,
metadata_objects["survey"],
metadata_objects["station"],
)
if self._index is not None:
station_row, period_row = MTDataTreeIndexStore._extract_rows(
station_path, station_ds
)
self._index.upsert_station(station_row)
if period_row is not None:
self._index.replace_station_period_rows(period_row)
self._index.refresh_survey_aggregates(station_row.survey_name)
return station_path
[docs]
def add_tf(
self,
tf: "MT | str | Path | list[MT | str | Path]",
**kwargs: Any,
) -> str | list[str]:
"""Alias for add_station to mirror MTData API."""
return self.add_station(tf, **kwargs)
[docs]
def add_stations(
self,
mt_objects: list["MT | str | Path"],
overwrite: bool = True,
dataset_copy_mode: str | None = None,
precomputed_attrs: list[dict[str, Any] | None] | None = None,
) -> list[str]:
"""
Bulk-add MT stations with optional precomputed attrs for fast ingest.
Parameters
----------
mt_objects : list
List of MT objects, filename strings, or Paths.
overwrite : bool, optional
If False, raise if a station path already exists.
dataset_copy_mode : {'deep', 'shallow', 'none'}, optional
Dataset copy behavior for station transfer-function storage.
precomputed_attrs : list[dict | None], optional
Optional attrs payload aligned by index with mt_objects. When
provided, these attrs are used directly and only canonical keys are
enforced (survey/station and metadata refs).
Returns
-------
list[str]
Inserted station paths.
"""
self.compute()
if mt_objects is None:
raise TypeError("mt_objects cannot be None")
if not isinstance(mt_objects, list):
raise TypeError("mt_objects must be a list")
if not mt_objects:
return []
if precomputed_attrs is not None:
if not isinstance(precomputed_attrs, list):
raise TypeError("precomputed_attrs must be a list when provided")
if len(precomputed_attrs) != len(mt_objects):
raise ValueError("precomputed_attrs must match mt_objects length")
prepared: list[tuple[str, str, xr.Dataset, dict[str, Any]]] = []
seen_paths: set[str] = set()
for index, mt_obj in enumerate(mt_objects):
attrs = None
if precomputed_attrs is not None:
attrs = precomputed_attrs[index]
if attrs is not None and not isinstance(attrs, dict):
raise TypeError("Each precomputed_attrs entry must be dict or None")
(
station_path,
station,
station_ds,
metadata_objects,
) = self._coerce_and_prepare_station(
mt_obj,
dataset_copy_mode=dataset_copy_mode,
precomputed_attrs=attrs,
)
if station_path in seen_paths and not overwrite:
raise KeyError(f"Station path already exists: {station_path}")
seen_paths.add(station_path)
if self._path_exists(station_path) and not overwrite:
raise KeyError(f"Station path already exists: {station_path}")
prepared.append((station_path, station, station_ds, metadata_objects))
parent_cache: dict[str, Any] = {}
inserted_paths: list[str] = []
for station_path, station, station_ds, metadata_objects in prepared:
parent_path, child_name = station_path.rsplit("/", 1)
parent_node = parent_cache.get(parent_path)
if parent_node is None:
try:
parent_node = self.tree[parent_path]
except KeyError:
_, survey_name, _ = parent_path.split("/", 2)
survey_path = f"{self.SURVEYS_NODE}/{survey_name}"
if not self._path_exists(survey_path):
self.tree[survey_path] = xr.DataTree(
name=survey_name,
dataset=xr.Dataset(),
)
if not self._path_exists(parent_path):
self.tree[parent_path] = xr.DataTree(
name=self.STATIONS_NODE,
dataset=xr.Dataset(),
)
parent_node = self.tree[parent_path]
parent_cache[parent_path] = parent_node
parent_node[child_name] = xr.DataTree(
name=station,
dataset=station_ds,
)
self._commit_cached_metadata(
station_path,
metadata_objects["survey"],
metadata_objects["station"],
)
inserted_paths.append(station_path)
if self._index is not None:
updated_surveys: set[str] = set()
for station_path, _station, station_ds, _meta in prepared:
station_row, period_row = MTDataTreeIndexStore._extract_rows(
station_path, station_ds
)
self._index.upsert_station(station_row)
if period_row is not None:
self._index.replace_station_period_rows(period_row)
updated_surveys.add(station_row.survey_name)
for sv in updated_surveys:
self._index.refresh_survey_aggregates(sv)
return inserted_paths
[docs]
def get_station(self, station_key: str, as_mt: bool = False) -> xr.Dataset | "MT":
    """Fetch one station as a dataset or as an MT object.

    Parameters
    ----------
    station_key : str
        Station identifier in canonical tree-path, ``survey/station``, or
        ``survey.station`` form; resolved with ``_resolve_station_path``.
    as_mt : bool, optional
        When ``True``, the stored dataset is converted to an ``MT`` object
        and its metadata is hydrated from the cache.

    Returns
    -------
    xarray.Dataset or MT
        The stored station dataset (default) or the reconstructed MT
        object.

    Examples
    --------
    >>> ds = tree.get_station("surveys/surveyA/stations/st01")
    >>> ds = tree.get_station("surveyA/st01")
    >>> mt_obj = tree.get_station("surveys/surveyA/stations/st01", as_mt=True)
    """
    resolved_path = self._resolve_station_path(station_key)
    # Only the requested station is passed to compute().
    self.compute(station_paths=[resolved_path])
    dataset = self.tree[resolved_path].ds
    if not as_mt:
        return dataset

    mt_obj = self._dataset_to_mt(dataset)
    self._hydrate_metadata_from_cache(mt_obj, dataset)
    return mt_obj
[docs]
def remove_station(self, station_key: str) -> None:
    """Delete one station node together with its bookkeeping entries.

    Besides the tree node, this drops any pending lazy transform for the
    station, clears its cached metadata, and removes it from the station
    index when one is attached.

    Parameters
    ----------
    station_key : str
        Station identifier in canonical tree-path, ``survey/station``, or
        ``survey.station`` form.
    """
    resolved_path = self._resolve_station_path(station_key)
    self.compute()

    self._lazy_station_transforms.pop(resolved_path, None)
    self._clear_cached_metadata(resolved_path)
    if self._index is not None:
        self._index.delete_station_by_tree_path(resolved_path)

    parent_path, separator, child_name = resolved_path.rpartition("/")
    if not separator:
        # Path without any "/": the node is addressed directly on the tree.
        del self.tree[resolved_path]
        return
    del self.tree[parent_path][child_name]
[docs]
def get_subset(self, station_list: list[str]) -> "MTData":
    """Copy selected stations into a new tree.

    The new tree is built with this tree's ``metadata_storage`` and
    ``attrs``.  Each selected dataset is copied with ``Dataset.copy()`` and
    stored at the path derived from its own ``survey`` / ``station``
    attributes.

    Parameters
    ----------
    station_list : list[str]
        Station keys; each one is resolved with ``_resolve_station_path``.

    Returns
    -------
    MTData
        New tree holding the copied datasets.  With
        ``metadata_storage == "cache"``, cached survey/station metadata is
        carried over under the new path and the dataset's
        ``*_metadata_ref`` attributes are pointed at it.
    """
    resolved_paths = [self._resolve_station_path(key) for key in station_list]
    subset = self.__class__(
        metadata_storage=self.metadata_storage,
        **dict(self.attrs),
    )

    for source_path in resolved_paths:
        station_ds = self.get_station(source_path).copy()
        station_attrs = station_ds.attrs
        survey_name = self._clean_name(station_attrs.get("survey"), "default")
        station_name = self._clean_name(
            station_attrs.get("station"), "unknown_station"
        )
        target_path = self._station_path(survey_name, station_name)

        if self.metadata_storage == "cache":
            for metadata_kind in ("survey", "station"):
                cached_md = self._metadata_cache[metadata_kind].get(source_path)
                if cached_md is not None:
                    # Re-key the cached metadata under the subset's path.
                    subset._metadata_cache[metadata_kind][target_path] = cached_md
                    station_ds.attrs[f"{metadata_kind}_metadata_ref"] = target_path

        node_name = target_path.rsplit("/", 1)[-1]
        subset.tree[target_path] = xr.DataTree(name=node_name, dataset=station_ds)
    return subset
def _set_station_dataset(self, station_path: str, station_ds: xr.Dataset) -> None:
    """Write *station_ds* at *station_path*, creating parents if needed.

    Parameters
    ----------
    station_path : str
        Tree path of the station node; the text after the last ``"/"`` is
        used as the node name.
    station_ds : xarray.Dataset
        Dataset to store in the node.

    Notes
    -----
    When the parent path is missing from the tree, the survey node and the
    stations container below it are created with empty datasets before the
    station node is assigned.
    """
    parent_path, child_name = station_path.rsplit("/", 1)
    try:
        parent_node = self.tree[parent_path]
    except KeyError:
        _, survey_name, _ = parent_path.split("/", 2)
        survey_path = f"{self.SURVEYS_NODE}/{survey_name}"
        # Survey node first, then its stations container.
        required_nodes = (
            (survey_path, survey_name),
            (parent_path, self.STATIONS_NODE),
        )
        for node_path, node_name in required_nodes:
            if not self._path_exists(node_path):
                self.tree[node_path] = xr.DataTree(
                    name=node_name,
                    dataset=xr.Dataset(),
                )
        parent_node = self.tree[parent_path]
    parent_node[child_name] = xr.DataTree(name=child_name, dataset=station_ds)
@staticmethod
def _interpolate_station_dataset(
    station_ds: xr.Dataset,
    new_periods: np.ndarray,
    **kwargs: Any,
) -> xr.Dataset:
    """Interpolate one stored station dataset onto *new_periods*.

    Parameters
    ----------
    station_ds : xarray.Dataset
        Stored station dataset.
    new_periods : ndarray
        Target periods; converted to a float array.
    **kwargs
        Forwarded to the dataset's ``tf.interpolate`` accessor call.

    Returns
    -------
    xarray.Dataset
        * A coordinate-only dataset (original non-period coordinates plus
          the target periods) when *station_ds* has no ``period``
          coordinate or no data variables.
        * A zero-length deep copy along ``period`` when the target array is
          empty.
        * Otherwise the result of ``station_ds.tf.interpolate`` called with
          ``inplace=False``.

        In the first two cases the original attrs are copied onto the
        result.
    """
    target_periods = np.asarray(new_periods, dtype=float)
    original_attrs = dict(station_ds.attrs)

    can_interpolate = "period" in station_ds.coords and bool(station_ds.data_vars)
    if not can_interpolate:
        # Nothing to interpolate: keep the other coordinates and only swap
        # in the requested period axis.
        kept_coords = {
            coord_name: coord.values
            for coord_name, coord in station_ds.coords.items()
            if coord_name != "period"
        }
        result = xr.Dataset(coords={**kept_coords, "period": target_periods})
        result.attrs.update(original_attrs)
        return result

    if target_periods.size == 0:
        result = station_ds.isel(period=slice(0, 0)).copy(deep=True)
        result.attrs.update(original_attrs)
        return result

    return station_ds.tf.interpolate(target_periods, inplace=False, **kwargs)
@staticmethod
def _rotate_station_dataset(
    station_ds: xr.Dataset,
    rotation_angle: float | np.ndarray,
    coordinate_reference_frame: str = "ned",
) -> xr.Dataset:
    """Rotate one stored station dataset through its ``tf`` accessor.

    Parameters
    ----------
    station_ds : xarray.Dataset
        Stored station dataset.
    rotation_angle : float or ndarray
        Rotation angle(s), forwarded to ``tf.rotate``.
    coordinate_reference_frame : str, optional
        Forwarded to ``tf.rotate``; by default ``"ned"``.

    Returns
    -------
    xarray.Dataset
        Result of ``tf.rotate`` called with ``inplace=False``.  The input
        dataset is returned as-is when it has no ``transfer_function``
        variable, no ``period`` coordinate, or a zero-length period axis,
        and also when the rotation raises ``ValueError``.

    Notes
    -----
    A ``ValueError`` from the rotation is treated as "cannot rotate this
    station" and does not abort the caller.  It is logged as a warning so
    that unrotated data is not handed back silently.
    """
    nothing_to_rotate = (
        "transfer_function" not in station_ds
        or "period" not in station_ds.coords
        or station_ds.sizes.get("period", 0) == 0
    )
    if nothing_to_rotate:
        return station_ds

    # Materialise dask arrays before rotation; the accessor uses .loc[]
    # assignments which xarray cannot apply to dask-backed arrays.
    try:
        return station_ds.load().tf.rotate(
            rotation_angle,
            coordinate_reference_frame=coordinate_reference_frame,
            inplace=False,
        )
    except ValueError as error:
        logger.warning(
            f"Could not rotate station dataset; returning it unrotated: {error}"
        )
        return station_ds
[docs]
def rotate(
self,
rotation_angle: float | np.ndarray,
inplace: bool = True,
) -> "MTData" | None:
"""Rotate all station transfer functions.
**Note**: This method is not intended for general coordinate
rotation of station locations. It is designed to rotate the
impedance and tipper channels of each station dataset
according to the specified angle and coordinate reference frame.
The station location coordinates are not modified by this method.
Rotation is off the coordinate system therefore a positive
clockwise rotation of 10 degrees will rotate the coordinate
system 10 degrees and the estimated strike angle will be 10
degrees less than the strike angle before rotation.
Parameters
----------
rotation_angle : float or ndarray
Scalar rotation angle in degrees or per-period angle array.
inplace : bool, optional
If ``True`` (default), mutate this tree and return ``None``.
Returns
-------
MTData or None
Rotated copy when *inplace* is ``False``; otherwise ``None``.
"""
tree_obj = self
if not inplace:
tree_obj = self.__class__(
metadata_storage=self.metadata_storage,
dataset_copy_mode=self.dataset_copy_mode,
use_index=self._index is not None,
index_db_path=self._index_db_path,
**dict(self.attrs),
)
updated_surveys: set[str] = set()
for station_path in self._iter_station_paths():
station_ds = self.get_station(station_path)
crf = station_ds.attrs.get(
"coordinate_reference_frame",
self.attrs.get("coordinate_reference_frame", "ned"),
)
rotated_ds = self._rotate_station_dataset(
station_ds,
rotation_angle,
coordinate_reference_frame=crf,
)
tree_obj._set_station_dataset(station_path, rotated_ds)
if tree_obj.metadata_storage == "cache":
for metadata_kind in ["survey", "station"]:
cached_md = self._metadata_cache[metadata_kind].get(station_path)
if cached_md is not None:
tree_obj._metadata_cache[metadata_kind][
station_path
] = cached_md
if tree_obj._index is not None:
station_row, period_row = MTDataTreeIndexStore._extract_rows(
station_path,
rotated_ds,
)
tree_obj._index.upsert_station(station_row)
if period_row is not None:
tree_obj._index.replace_station_period_rows(period_row)
updated_surveys.add(station_row.survey_name)
if tree_obj._index is not None:
for survey_name in updated_surveys:
tree_obj._index.refresh_survey_aggregates(survey_name)
if not inplace:
return tree_obj
return None
[docs]
def interpolate(
    self,
    new_periods: np.ndarray,
    f_type: str = "period",
    inplace: bool = True,
    bounds_error: bool = True,
    **kwargs: Any,
) -> "MTData | None":
    """Interpolate every station onto a shared period grid.

    Parameters
    ----------
    new_periods : ndarray
        Target period array, or frequency array when *f_type* is
        ``'frequency'`` / ``'freq'``.  Flattened to one dimension.
    f_type : {'frequency', 'freq', 'period', 'per'}, optional
        Meaning of *new_periods*.  Frequencies are converted with
        ``1.0 / new_periods``.
    inplace : bool, optional
        If ``True`` (default), replace the datasets in this tree and return
        ``None``.  If ``False``, write the interpolated datasets into a new
        tree built with this tree's storage settings and attrs.
    bounds_error : bool, optional
        If ``True``, each station only receives the target periods that lie
        inside its own ``[min, max]`` period range (bounds inclusive).
    **kwargs
        Forwarded to the per-station interpolation.

    Returns
    -------
    MTData or None
        Interpolated copy when *inplace* is ``False``; otherwise ``None``.

    Raises
    ------
    ValueError
        If *f_type* is unsupported.
    """
    if f_type not in ["frequency", "freq", "period", "per"]:
        raise ValueError(
            f"f_type must be either 'frequency' or 'period' not {f_type}"
        )

    target_periods = np.asarray(new_periods, dtype=float)
    if target_periods.ndim != 1:
        target_periods = target_periods.reshape(-1)
    if f_type in ["frequency", "freq"]:
        target_periods = 1.0 / target_periods

    if inplace:
        tree_obj = self
    else:
        tree_obj = self.__class__(
            metadata_storage=self.metadata_storage,
            dataset_copy_mode=self.dataset_copy_mode,
            use_index=self._index is not None,
            index_db_path=self._index_db_path,
            **dict(self.attrs),
        )

    updated_surveys: set[str] = set()
    for station_path in self._iter_station_paths():
        station_ds = self.get_station(station_path)

        # Optionally restrict the targets to this station's native range.
        interp_periods = target_periods
        if bounds_error and "period" in station_ds.coords:
            station_periods = np.asarray(
                station_ds.coords["period"].values, dtype=float
            )
            if station_periods.size > 0:
                inside = (target_periods <= station_periods.max()) & (
                    target_periods >= station_periods.min()
                )
                interp_periods = target_periods[inside]

        interpolated_ds = self._interpolate_station_dataset(
            station_ds,
            interp_periods,
            **kwargs,
        )
        tree_obj._set_station_dataset(station_path, interpolated_ds)

        # Carry cached metadata over to the target tree (a self-assignment
        # when interpolating in place).
        if tree_obj.metadata_storage == "cache":
            for metadata_kind in ("survey", "station"):
                cached_md = self._metadata_cache[metadata_kind].get(station_path)
                if cached_md is not None:
                    tree_obj._metadata_cache[metadata_kind][station_path] = cached_md

        if tree_obj._index is not None:
            station_row, period_row = MTDataTreeIndexStore._extract_rows(
                station_path,
                interpolated_ds,
            )
            if period_row is None:
                # No period row was extracted: drop the existing index entry
                # before the station row is re-inserted.
                tree_obj._index.delete_station_by_tree_path(station_path)
            tree_obj._index.upsert_station(station_row)
            if period_row is not None:
                tree_obj._index.replace_station_period_rows(period_row)
            updated_surveys.add(station_row.survey_name)

    # Survey aggregates are refreshed once, after all stations are written.
    if tree_obj._index is not None:
        for survey_name in updated_surveys:
            tree_obj._index.refresh_survey_aggregates(survey_name)

    if not inplace:
        return tree_obj
    return None
[docs]
def interpolate_lazy(
    self,
    new_periods: np.ndarray,
    f_type: str = "period",
    inplace: bool = False,
    bounds_error: bool = True,
    **kwargs: Any,
) -> "MTData":
    """Register a deferred interpolation for every station.

    No interpolation is carried out here.  For each station a callable is
    stored in ``_lazy_station_transforms``; it captures a shallow copy of
    the current dataset, the (optionally clipped) target periods and the
    keyword arguments, and calls ``MTData._interpolate_station_dataset``
    when invoked.

    Parameters
    ----------
    new_periods : ndarray
        Target period array, or frequency array when *f_type* indicates
        frequency input.  Flattened to one dimension.
    f_type : {'frequency', 'freq', 'period', 'per'}, optional
        Meaning of *new_periods*.
    inplace : bool, optional
        If ``True``, clear the pending transforms of this instance and
        register the new ones on it.  Otherwise (default) build a new tree
        that receives shallow copies of the source datasets together with
        the transforms.
    bounds_error : bool, optional
        If ``True``, each station only receives the target periods that lie
        inside its own ``[min, max]`` period range (bounds inclusive).
    **kwargs
        Stored and forwarded to the station interpolation when the
        transform runs.

    Returns
    -------
    MTData
        The tree carrying the pending transforms: this instance when
        *inplace* is ``True``, a new tree otherwise.

    Raises
    ------
    ValueError
        If *f_type* is unsupported.
    """
    if f_type not in ["frequency", "freq", "period", "per"]:
        raise ValueError(
            f"f_type must be either 'frequency' or 'period' not {f_type}"
        )

    # Build lazy plans from realized source station datasets.
    self.compute()

    target_periods = np.asarray(new_periods, dtype=float)
    if target_periods.ndim != 1:
        target_periods = target_periods.reshape(-1)
    if f_type in ["frequency", "freq"]:
        target_periods = 1.0 / target_periods

    if inplace:
        tree_obj = self
        tree_obj._lazy_station_transforms.clear()
    else:
        # The new tree starts without an index; whether the source had one
        # is recorded on it as ``_lazy_use_index``.
        tree_obj = self.__class__(
            metadata_storage=self.metadata_storage,
            dataset_copy_mode=self.dataset_copy_mode,
            use_index=False,
            index_db_path=self._index_db_path,
            **dict(self.attrs),
        )
        tree_obj._lazy_use_index = self._index is not None

    for station_path in self._iter_station_paths():
        source_station_ds = self.get_station(station_path)

        # Optionally restrict the targets to this station's native range.
        interp_periods = target_periods
        if bounds_error and "period" in source_station_ds.coords:
            station_periods = np.asarray(
                source_station_ds.coords["period"].values,
                dtype=float,
            )
            if station_periods.size > 0:
                inside = (target_periods <= station_periods.max()) & (
                    target_periods >= station_periods.min()
                )
                interp_periods = target_periods[inside]

        source_snapshot = source_station_ds.copy(deep=False)
        target_snapshot = np.asarray(interp_periods, dtype=float)
        interp_kwargs = dict(kwargs)

        # Default arguments freeze this iteration's values inside the
        # callable (avoids late binding of the loop variables).
        def _transform(
            ds: xr.Dataset = source_snapshot,
            periods: np.ndarray = target_snapshot,
            op_kwargs: dict[str, Any] = interp_kwargs,
        ) -> xr.Dataset:
            return MTData._interpolate_station_dataset(ds, periods, **op_kwargs)

        tree_obj._lazy_station_transforms[station_path] = _transform

        if not inplace:
            tree_obj._set_station_dataset(
                station_path, source_snapshot.copy(deep=False)
            )
        if tree_obj.metadata_storage == "cache":
            for metadata_kind in ("survey", "station"):
                cached_md = self._metadata_cache[metadata_kind].get(station_path)
                if cached_md is not None:
                    tree_obj._metadata_cache[metadata_kind][station_path] = cached_md
    return tree_obj
[docs]
def apply_bounding_box(
    self, lon_min: float, lon_max: float, lat_min: float, lat_max: float
) -> "MTData":
    """Return stations that fall inside a lon/lat bounding box.

    When the station index is enabled the lookup is delegated to it.
    Otherwise the ``station_locations`` dataframe is filtered directly,
    treating all four bounds as inclusive.

    Parameters
    ----------
    lon_min, lon_max : float
        Longitude bounds.
    lat_min, lat_max : float
        Latitude bounds.

    Returns
    -------
    MTData
        Subset tree containing stations inside the bounding box. When no
        station locations are available, an empty container is returned
        that keeps this instance's ``metadata_storage`` and
        ``dataset_copy_mode`` settings along with its attributes.
    """
    if self._index is not None:
        station_keys = self._index.query_station_paths(
            lon_min=lon_min, lon_max=lon_max, lat_min=lat_min, lat_max=lat_max
        )
        return self.get_subset(station_keys)

    station_df = self.station_locations
    if station_df is None or station_df.empty:
        # Carry the storage configuration over explicitly; relying on the
        # constructor defaults would silently change how the (empty) result
        # stores metadata and copies datasets compared with this instance.
        return self.__class__(
            metadata_storage=self.metadata_storage,
            dataset_copy_mode=self.dataset_copy_mode,
            **dict(self.attrs),
        )

    inside_box = (
        (station_df.longitude >= lon_min)
        & (station_df.longitude <= lon_max)
        & (station_df.latitude >= lat_min)
        & (station_df.latitude <= lat_max)
    )
    bb_df = station_df.loc[inside_box]

    # Rebuild tree paths from the same cleaned names used when stations are
    # stored, so the keys match what get_subset() expects.
    station_keys = [
        self._station_path(
            self._clean_name(survey, "default"),
            self._clean_name(station, "unknown_station"),
        )
        for survey, station in zip(bb_df.survey, bb_df.station)
    ]
    return self.get_subset(station_keys)
[docs]
def rebuild_index(self, index_db_path: str = ":memory:") -> None:
    """
    Build or replace the station index from the current tree contents.

    A fresh :class:`MTDataTreeIndexStore` is populated from this tree and
    installed as the active index, which also enables indexing when it was
    previously off. If building fails, the previous index (or ``None``) is
    put back and the exception propagates.

    Parameters
    ----------
    index_db_path : str
        SQLite database path. Defaults to ``":memory:"`` (in-process).
    """
    previous_index = self._index
    # The index must be detached while rebuilding: _iter_station_paths()
    # should walk the tree rather than consult an index during the rebuild.
    self._index = None
    try:
        fresh_index = MTDataTreeIndexStore(index_db_path)
        fresh_index.rebuild_from_tree(self)
    except Exception:
        self._index = previous_index
        raise
    else:
        self._index = fresh_index
[docs]
def query_station_paths(
    self,
    survey: str | None = None,
    lat_min: float | None = None,
    lat_max: float | None = None,
    lon_min: float | None = None,
    lon_max: float | None = None,
    period_min: float | None = None,
    period_max: float | None = None,
) -> list[str]:
    """
    Look up station tree paths through the index.

    ``compute()`` is called first, then the filters are forwarded unchanged
    to the index store. The index must be enabled (``use_index=True`` or
    after calling :meth:`rebuild_index`).

    Parameters
    ----------
    survey, lat_min, lat_max, lon_min, lon_max, period_min, period_max
        See :meth:`MTDataTreeIndexStore.query_station_paths`.

    Returns
    -------
    list[str]
        Whatever the index store returns for the given filters.

    Raises
    ------
    RuntimeError
        If no index is active after ``compute()``.
    """
    self.compute()

    index = self._index
    if index is None:
        raise RuntimeError(
            "Index not enabled. Pass use_index=True to the constructor "
            "or call rebuild_index() first."
        )

    filters = {
        "survey": survey,
        "lat_min": lat_min,
        "lat_max": lat_max,
        "lon_min": lon_min,
        "lon_max": lon_max,
        "period_min": period_min,
        "period_max": period_max,
    }
    return index.query_station_paths(**filters)
[docs]
def to_dataframe(
    self,
    utm_crs: Any | None = None,
    cols: list[str] | None = None,
    impedance_units: str = "mt",
) -> pd.DataFrame:
    """Export every station into one concatenated pandas DataFrame.

    Parameters
    ----------
    utm_crs : Any, optional
        CRS override used when exporting station locations.
    cols : list[str], optional
        Column subset to include.
    impedance_units : str, optional
        Impedance unit convention for exported transfer-function values.

    Returns
    -------
    pandas.DataFrame
        Station dataframes stacked with a fresh integer index, or an empty
        dataframe when the tree has no stations.
    """
    self.compute()
    export_options = {
        "utm_crs": utm_crs,
        "cols": cols,
        "impedance_units": impedance_units,
    }

    def _station_frame(station_ds: Any) -> pd.DataFrame:
        """Convert one station dataset, falling back to the MT object path."""
        try:
            return self._station_dataset_to_dataframe(station_ds, **export_options)
        except Exception:
            # Fallback keeps behavior for unexpected/legacy dataset layouts.
            legacy_export = self._dataset_to_mt(station_ds).to_dataframe(
                **export_options
            )
            return legacy_export.dataframe

    frames = [
        _station_frame(self.get_station(path))
        for path in self._iter_station_paths()
    ]
    if not frames:
        return pd.DataFrame()
    return pd.concat(frames, ignore_index=True)
[docs]
def to_mt_dataframe(
    self, utm_crs: Any | None = None, impedance_units: str = "mt"
) -> MTDataFrame:
    """Wrap the all-station dataframe in an :class:`MTDataFrame`.

    Parameters
    ----------
    utm_crs : Any, optional
        CRS override used during dataframe conversion.
    impedance_units : str, optional
        Impedance unit convention for exported values.

    Returns
    -------
    MTDataFrame
        MTDataFrame built from the result of :meth:`to_dataframe`.
    """
    combined = self.to_dataframe(utm_crs=utm_crs, impedance_units=impedance_units)
    return MTDataFrame(combined)
[docs]
def from_dataframe(self, df: pd.DataFrame, impedance_units: str = "mt") -> None:
    """Fill the tree with stations read from a dataframe.

    Rows are grouped per station (per survey and station when a ``survey``
    column exists), each group becomes one ``MT`` object, and the objects
    are added in a single :meth:`add_stations` call. An empty dataframe is
    a no-op.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe containing MT rows, grouped by station (and survey when
        available).
    impedance_units : str, optional
        Unit convention used by impedance values in *df*.
    """
    from .mt import MT

    if df.empty:
        return

    group_cols = ["survey", "station"] if "survey" in df.columns else ["station"]

    def _build_station(station_rows: pd.DataFrame) -> Any:
        """Create one MT object from the rows of a single station."""
        station_mt = MT(period=station_rows.period.unique())
        station_mt.from_dataframe(station_rows, impedance_units=impedance_units)
        return station_mt

    # sort=False keeps groups in order of first appearance in *df*.
    mt_objects = [
        _build_station(station_rows)
        for _, station_rows in df.groupby(group_cols, sort=False)
    ]
    if mt_objects:
        self.add_stations(mt_objects)
[docs]
def from_mt_dataframe(
    self, mt_df: MTDataFrame, impedance_units: str = "mt"
) -> None:
    """Fill the tree from the dataframe held by an :class:`MTDataFrame`.

    Parameters
    ----------
    mt_df : MTDataFrame
        Input MTDataFrame; its ``dataframe`` attribute is handed to
        :meth:`from_dataframe`.
    impedance_units : str, optional
        Unit convention used by impedance values.
    """
    station_rows = mt_df.dataframe
    self.from_dataframe(station_rows, impedance_units=impedance_units)
[docs]
def get_periods(self) -> np.ndarray:
    """Collect the distinct periods found anywhere in the tree.

    Every node reachable from ``self.tree`` is visited; nodes whose ``ds``
    is an :class:`xarray.Dataset` with a ``period`` coordinate contribute
    their period values.

    Returns
    -------
    numpy.ndarray
        One-dimensional float array of unique periods in ascending order
        (empty when no node carries a ``period`` coordinate).
    """
    self.compute()

    collected: list[np.ndarray] = []
    pending = [self.tree]
    while pending:
        node = pending.pop()
        node_ds = getattr(node, "ds", None)
        if isinstance(node_ds, xr.Dataset) and "period" in node_ds.coords:
            collected.append(
                np.asarray(node_ds.coords["period"].values, dtype=float)
            )
        # Push children reversed so nodes are visited in pre-order.
        child_nodes = list(getattr(node, "children", {}).values())
        pending.extend(reversed(child_nodes))

    if not collected:
        return np.array([], dtype=float)
    # np.unique already returns its result sorted ascending.
    return np.unique(np.concatenate(collected))
[docs]
def keys(self) -> list[str]:
    """List the names of the nodes directly under the tree root.

    Returns
    -------
    list[str]
        Keys of ``self.tree.children``, in that mapping's iteration order.
    """
    top_level_children = self.tree.children
    return [child_name for child_name in top_level_children.keys()]
def _resolve_plot_station_key(
self,
station_key: str | None = None,
station_id: str | None = None,
survey_id: str | None = None,
) -> str:
"""Resolve plotting selectors to one canonical station tree path.
Parameters
----------
station_key : str, optional
Canonical station path or alternate station key accepted by
:meth:`_resolve_station_path`.
station_id : str, optional
Station identifier.
survey_id : str, optional
Survey identifier used to disambiguate duplicate station IDs.
Returns
-------
str
Canonical station tree path.
Raises
------
ValueError
If both *station_key* and *station_id* are missing, or if
*station_id* is ambiguous without *survey_id*.
KeyError
If no matching station can be resolved.
"""
if station_key is not None:
return self._resolve_station_path(station_key)
if station_id is None:
raise ValueError("Provide station_key or station_id")
station_name = self._clean_name(station_id, "unknown_station")
if survey_id is not None:
survey_name = self._clean_name(survey_id, "default")
return self._resolve_station_path(
self._station_path(survey_name, station_name)
)
matches = [
path
for path in self.station_paths
if path.rsplit("/", 1)[-1] == station_name
]
if len(matches) == 1:
return matches[0]
if len(matches) == 0:
raise KeyError(
"Station key not found for station_id without survey_id: "
f"{station_id}"
)
raise ValueError(
"Multiple stations matched station_id. "
"Provide survey_id to disambiguate."
)
[docs]
def plot_mt_response(
    self,
    station_key: str | list[str] | None = None,
    station_id: str | list[str] | None = None,
    survey_id: str | list[str] | None = None,
    **kwargs: Any,
) -> PlotMultipleResponses | Any:
    """
    Plot the MT response of one station or of several stations together.

    A list/tuple *station_key* takes precedence over a list/tuple
    *station_id*; either one produces a multi-station plot. Scalar
    selectors produce a single-station plot.

    Parameters
    ----------
    station_key : str, list of str, optional
        Station key(s) in canonical or accepted alternate form.
    station_id : str, list of str, optional
        Station ID(s). When provided without *survey_id*, each station ID
        must be unique across surveys.
    survey_id : str, list of str, optional
        Survey ID(s) used with *station_id*. If list-valued, must align
        one-to-one with *station_id*; a scalar (or ``None``) is applied to
        every station ID.
    **kwargs : dict
        Additional plotting keyword arguments.

    Returns
    -------
    PlotMultipleResponses or plot object
        Multi-station response plot for list-valued selectors, otherwise
        a single-station MT response plot object.

    Raises
    ------
    ValueError
        If list-valued *survey_id* does not match list-valued
        *station_id*, or if station selection is ambiguous.
    KeyError
        If a requested station cannot be resolved.

    Examples
    --------
    >>> tree.plot_mt_response(station_key="survey_a/st01")
    >>> tree.plot_mt_response(station_id="st01", survey_id="survey_a")
    >>> tree.plot_mt_response(
    ...     station_id=["st01", "st02"],
    ...     survey_id=["survey_a", "survey_b"],
    ... )
    """

    def _plot_many(station_paths: list[str]) -> PlotMultipleResponses:
        """Plot the subset made of the given station paths."""
        return PlotMultipleResponses(self.get_subset(station_paths), **kwargs)

    if isinstance(station_key, (list, tuple)):
        return _plot_many([self._resolve_station_path(key) for key in station_key])

    if isinstance(station_id, (list, tuple)):
        station_ids = list(station_id)
        if isinstance(survey_id, (list, tuple)):
            survey_ids = list(survey_id)
            if len(survey_ids) != len(station_ids):
                raise ValueError("Number of survey must match number of stations")
        else:
            # One survey (or None) shared by all requested stations.
            survey_ids = [survey_id] * len(station_ids)
        return _plot_many(
            [
                self._resolve_plot_station_key(
                    station_id=station,
                    survey_id=survey,
                )
                for survey, station in zip(survey_ids, station_ids)
            ]
        )

    station_path = self._resolve_plot_station_key(
        station_id=station_id,
        survey_id=survey_id,
        station_key=station_key,
    )
    return self.get_station(station_path, as_mt=True).plot_mt_response(**kwargs)
[docs]
def plot_stations(
    self,
    map_epsg: int = 4326,
    bounding_box: tuple[float, float, float, float] | None = None,
    model_locations: bool = False,
    **kwargs: Any,
) -> PlotStations:
    """
    Draw station locations on a map.

    Parameters
    ----------
    map_epsg : int, optional
        EPSG code forwarded to :class:`PlotStations` as ``map_epsg`` unless
        ``map_epsg`` is already present in *kwargs*.
    bounding_box : tuple of float, optional
        Optional ``(lon_min, lon_max, lat_min, lat_max)`` used to subset
        stations (via :meth:`apply_bounding_box`) before plotting.
    model_locations : bool, optional
        Use model coordinates instead of geographic coordinates. When
        true, ``plot_cx`` is forced to ``False``.
    **kwargs : dict
        Additional plotting keyword arguments.

    Returns
    -------
    PlotStations
        Station plot object

    Raises
    ------
    ValueError
        If *bounding_box* is provided and does not contain four values.

    Examples
    --------
    >>> tree.plot_stations()
    >>> tree.plot_stations(map_epsg=3857)
    >>> tree.plot_stations(bounding_box=(-121.5, -120.0, 36.5, 38.0))
    """
    if bounding_box is None:
        mt_data = self
    elif len(bounding_box) == 4:
        mt_data = self.apply_bounding_box(*bounding_box)
    else:
        raise ValueError(
            "bounding_box must be (lon_min, lon_max, lat_min, lat_max)"
        )

    gdf = mt_data.to_geo_df(model_locations=model_locations)

    if model_locations:
        kwargs["plot_cx"] = False
    # An explicit map_epsg inside kwargs takes priority over the parameter.
    kwargs.setdefault("map_epsg", map_epsg)
    return PlotStations(gdf, **kwargs)
[docs]
def plot_strike(self, **kwargs: Any) -> PlotStrike:
    """
    Create a strike-angle plot for this collection.

    Parameters
    ----------
    **kwargs : dict
        Additional plotting keyword arguments, passed to
        :class:`PlotStrike`.

    Returns
    -------
    PlotStrike
        Strike plot object

    Examples
    --------
    >>> tree.plot_strike()
    >>> tree.plot_strike(show_plot=False)
    """
    strike_plot = PlotStrike(self, **kwargs)
    return strike_plot
[docs]
def plot_phase_tensor(
    self,
    station_key: str | None = None,
    station_id: str | None = None,
    survey_id: str | None = None,
    **kwargs: Any,
) -> Any:
    """
    Plot phase tensor elements for a single station.

    The selectors are resolved to one station path, the station is loaded
    as an MT object, and its ``plot_phase_tensor`` method does the work.

    Parameters
    ----------
    station_key : str, optional
        Station key in canonical or accepted alternate form.
    station_id : str, optional
        Station ID.
    survey_id : str, optional
        Survey ID used to disambiguate duplicate station IDs.
    **kwargs : dict
        Additional plotting keyword arguments.

    Returns
    -------
    plot object
        Phase tensor plot object

    Raises
    ------
    ValueError
        If station selection is ambiguous.
    KeyError
        If the station cannot be resolved.

    Examples
    --------
    >>> tree.plot_phase_tensor(station_key="survey_a/st01")
    >>> tree.plot_phase_tensor(station_id="st01", survey_id="survey_a")
    """
    selectors = {
        "station_id": station_id,
        "survey_id": survey_id,
        "station_key": station_key,
    }
    resolved_path = self._resolve_plot_station_key(**selectors)
    station_mt = self.get_station(resolved_path, as_mt=True)
    return station_mt.plot_phase_tensor(**kwargs)
[docs]
def plot_phase_tensor_map(self, **kwargs: Any) -> PlotPhaseTensorMaps:
    """
    Create phase tensor maps for this collection.

    Parameters
    ----------
    **kwargs : dict
        Additional plotting keyword arguments, passed to
        :class:`PlotPhaseTensorMaps`.

    Returns
    -------
    PlotPhaseTensorMaps
        Phase tensor map plot object

    Examples
    --------
    >>> tree.plot_phase_tensor_map(plot_period=10)
    >>> tree.plot_phase_tensor_map(plot_station=True)
    """
    pt_map = PlotPhaseTensorMaps(mt_data=self, **kwargs)
    return pt_map
[docs]
def plot_tipper_map(self, **kwargs: Any) -> PlotPhaseTensorMaps:
    """
    Create tipper (induction vector) maps for this collection.

    Parameters
    ----------
    **kwargs : dict
        Additional plotting keyword arguments. Defaults are
        ``plot_pt=False`` and ``plot_tipper='yri'`` when not explicitly
        provided.

    Returns
    -------
    PlotPhaseTensorMaps
        Tipper map plot object

    Examples
    --------
    >>> tree.plot_tipper_map()
    >>> tree.plot_tipper_map(plot_tipper="yri", plot_pt=False)
    """
    # Fill in tipper-oriented defaults without overriding caller choices.
    for option, fallback in (("plot_pt", False), ("plot_tipper", "yri")):
        kwargs.setdefault(option, fallback)
    return PlotPhaseTensorMaps(mt_data=self, **kwargs)
[docs]
def plot_phase_tensor_pseudosection(
    self, mt_data: "MTData" | None = None, **kwargs: Any
) -> PlotPhaseTensorPseudoSection:
    """
    Create a phase tensor pseudosection.

    Parameters
    ----------
    mt_data : MTData, optional
        MTData object to plot. Defaults to ``self``.
    **kwargs : dict
        Additional plotting keyword arguments, passed to
        :class:`PlotPhaseTensorPseudoSection`.

    Returns
    -------
    PlotPhaseTensorPseudoSection
        Pseudosection plot object

    Examples
    --------
    >>> tree.plot_phase_tensor_pseudosection()
    >>> subset = tree.get_survey("survey_a")
    >>> tree.plot_phase_tensor_pseudosection(mt_data=subset)
    """
    plot_source = self if mt_data is None else mt_data
    return PlotPhaseTensorPseudoSection(mt_data=plot_source, **kwargs)
[docs]
def plot_penetration_depth_1d(
    self,
    station_key: str | None = None,
    station_id: str | None = None,
    survey_id: str | None = None,
    **kwargs: Any,
) -> Any:
    """
    Plot 1D penetration depth for a single station.

    The selectors are resolved to one station path, the station is loaded
    as an MT object, and its ``plot_depth_of_penetration`` method does the
    work.

    Parameters
    ----------
    station_key : str, optional
        Station key in canonical or accepted alternate form.
    station_id : str, optional
        Station ID.
    survey_id : str, optional
        Survey ID used to disambiguate duplicate station IDs.
    **kwargs : dict
        Additional plotting keyword arguments.

    Returns
    -------
    plot object
        Penetration depth plot object

    Raises
    ------
    ValueError
        If station selection is ambiguous.
    KeyError
        If the station cannot be resolved.

    Notes
    -----
    Based on Niblett-Bostick transformation

    Examples
    --------
    >>> tree.plot_penetration_depth_1d(station_key="survey_a/st01")
    >>> tree.plot_penetration_depth_1d(
    ...     station_id="st01", survey_id="survey_a", depth_units="km"
    ... )
    """
    selectors = {
        "station_id": station_id,
        "survey_id": survey_id,
        "station_key": station_key,
    }
    resolved_path = self._resolve_plot_station_key(**selectors)
    station_mt = self.get_station(resolved_path, as_mt=True)
    return station_mt.plot_depth_of_penetration(**kwargs)
[docs]
def plot_penetration_depth_map(self, **kwargs: Any) -> PlotPenetrationDepthMap:
    """
    Create a map-view penetration depth plot for this collection.

    Parameters
    ----------
    **kwargs : dict
        Additional plotting keyword arguments, passed to
        :class:`PlotPenetrationDepthMap`.

    Returns
    -------
    PlotPenetrationDepthMap
        Penetration depth map plot object

    Examples
    --------
    >>> tree.plot_penetration_depth_map(plot_period=10)
    >>> tree.plot_penetration_depth_map(depth_units="km")
    """
    depth_map = PlotPenetrationDepthMap(mt_data=self, **kwargs)
    return depth_map
[docs]
def plot_resistivity_phase_maps(self, **kwargs: Any) -> PlotResPhaseMaps:
    """
    Create apparent resistivity and/or phase maps for this collection.

    Parameters
    ----------
    **kwargs : dict
        Additional plotting keyword arguments, passed to
        :class:`PlotResPhaseMaps`.

    Returns
    -------
    PlotResPhaseMaps
        Resistivity/phase map plot object

    Examples
    --------
    >>> tree.plot_resistivity_phase_maps(plot_period=10)
    >>> tree.plot_resistivity_phase_maps(plot_xy=True, plot_yx=False)
    """
    res_phase_maps = PlotResPhaseMaps(mt_data=self, **kwargs)
    return res_phase_maps
[docs]
def plot_resistivity_phase_pseudosections(
    self, **kwargs: Any
) -> PlotResPhasePseudoSection:
    """
    Create resistivity and phase pseudosections for this collection.

    Parameters
    ----------
    **kwargs : dict
        Additional plotting keyword arguments, passed to
        :class:`PlotResPhasePseudoSection`.

    Returns
    -------
    PlotResPhasePseudoSection
        Pseudosection plot object

    Examples
    --------
    >>> tree.plot_resistivity_phase_pseudosections()
    >>> tree.plot_resistivity_phase_pseudosections(interpolation_method="nearest")
    """
    pseudosection = PlotResPhasePseudoSection(mt_data=self, **kwargs)
    return pseudosection
[docs]
def plot_residual_phase_tensor_maps(
    self, survey_01: str, survey_02: str, **kwargs: Any
) -> PlotResidualPTMaps:
    """
    Create residual phase tensor maps between two surveys.

    Both surveys are fetched with :meth:`get_survey` first; a survey that
    comes back with zero stations is reported as missing.

    Parameters
    ----------
    survey_01 : str
        First survey ID.
    survey_02 : str
        Second survey ID.
    **kwargs : dict
        Additional plotting keyword arguments.

    Returns
    -------
    PlotResidualPTMaps
        Residual phase tensor map plot object

    Raises
    ------
    KeyError
        If either survey ID is not present in the current MTData.

    Examples
    --------
    >>> tree.plot_residual_phase_tensor_maps("survey_a", "survey_b")
    >>> tree.plot_residual_phase_tensor_maps(
    ...     "survey_a", "survey_b", plot_freq=1.0
    ... )
    """
    requested = (survey_01, survey_02)
    # Fetch both subsets before validating either one.
    first_subset, second_subset = (self.get_survey(name) for name in requested)

    for survey_name, subset in zip(requested, (first_subset, second_subset)):
        if subset.n_stations == 0:
            raise KeyError(f"Survey not found: {survey_name}")

    return PlotResidualPTMaps(first_subset, second_subset, **kwargs)