Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
f9b876d
Initial ox_state calc directory
Jan 28, 2026
bcb575f
Test calc script for aqueous iron chloride MD
Feb 5, 2026
60236d6
Core setup of Fe_oxidation_states benchmarks - fix to remove unecessa…
Feb 6, 2026
01b1438
Fix metrics, highlight ref on the plots, fix analysis, add rdf calcul…
Feb 13, 2026
d63e5e4
Apply suggestion from @ElliottKasoar
PKourtis Feb 16, 2026
82ef69c
Applied PR suggestion on highighted_range for the plot scatter decora…
Feb 16, 2026
6f22db7
Revert pyproject.toml to upstream version
Feb 18, 2026
71745d0
Fixup
ElliottKasoar Mar 11, 2026
7b90812
Cleaned up model declaration in the app.py
PKourtis Feb 19, 2026
740ca75
Each model has its own directory within outputs and the rdf tests get…
PKourtis Feb 19, 2026
1fd6d5f
Updated metrics level of theory to Experimental
PKourtis Feb 19, 2026
c565512
Updated analysis to match outputs/model_name data directory pattern
PKourtis Feb 19, 2026
d066c22
Added download from S3 bucket function for the input data
PKourtis Feb 19, 2026
dcfe48c
Added yes/no units
PKourtis Feb 19, 2026
288cb1d
Fixed plot_scatter highlighted range title and plot title mixup
PKourtis Feb 19, 2026
6f0423e
Download MD starting structures from S3 bucket and save outputs in th…
PKourtis Feb 19, 2026
7b97ddc
Added model name to the scatter title
PKourtis Feb 19, 2026
3e2a472
Apply pre-commit
ElliottKasoar Mar 11, 2026
b742e84
Delete calc data
ElliottKasoar Mar 11, 2026
b8e9dc4
Remove output files
ElliottKasoar Mar 12, 2026
aa657a0
Add docs link
ElliottKasoar Mar 12, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
"""Analyse aqueous Iron Chloride oxidation states."""

from __future__ import annotations

from pathlib import Path

import numpy as np
import pytest

from ml_peg.analysis.utils.decorators import build_table, plot_scatter
from ml_peg.analysis.utils.utils import load_metrics_config
from ml_peg.app import APP_ROOT
from ml_peg.calcs import CALCS_ROOT
from ml_peg.models.get_models import get_model_names
from ml_peg.models.models import current_models

# Names of the models to analyse, derived from the currently configured models.
MODELS = get_model_names(current_models)

# Calculation outputs are read from CALC_PATH / <model name>.
CALC_PATH = CALCS_ROOT / "physicality" / "oxidation_states" / "outputs"
# Plot and table JSON files produced by this analysis are written here.
OUT_PATH = APP_ROOT / "data" / "physicality" / "oxidation_states"

# Metric thresholds and tooltips are loaded from the metrics.yml beside this module.
METRICS_CONFIG_PATH = Path(__file__).with_name("metrics.yml")
DEFAULT_THRESHOLDS, DEFAULT_TOOLTIPS, _ = load_metrics_config(METRICS_CONFIG_PATH)

# Salt labels as they appear in the RDF filenames (O-Fe_<salt>_<model>.rdf).
IRON_SALTS = ["Fe2Cl", "Fe3Cl"]
# Metric names; these are the top-level keys of the pass/fail results dictionary.
TESTS = ["Fe-O RDF Peak Split", "Peak Within Experimental Ref"]
# x-axis windows highlighted on each RDF plot, keyed by the annotation text shown
# above the highlighted region. Values are [lower, upper] bounds on r.
REF_PEAK_RANGE = {
    "Fe<sup>+2</sup><br>Ref": [2.0, 2.2],
    "Fe<sup>+3</sup><br>Ref": [1.9, 2.0],
}


def get_rdf_results(
    model: str,
) -> dict[str, tuple[list[float], list[float]]]:
    """
    Load a model's Fe-O RDFs for the aqueous Fe2Cl and Fe3Cl MD runs.

    One ``O-Fe_<salt>_<model>.rdf`` file is read per salt from the model's
    output directory. The first column is taken as the radii and the second
    as the RDF intensities.

    Parameters
    ----------
    model
        Name of MLIP.

    Returns
    -------
    results
        For each salt, the RDF radii followed by the RDF intensities.
    """
    rdf_dir = CALC_PATH / model

    rdfs = {}
    for salt in IRON_SALTS:
        table = np.loadtxt(rdf_dir / f"O-Fe_{salt}_{model}.rdf")
        radii = list(table[:, 0])
        intensities = list(table[:, 1])
        rdfs[salt] = [radii, intensities]

    return rdfs


def plot_rdfs(model: str, results: dict[str, tuple[list[float], list[float]]]) -> None:
    """
    Save a line plot of a model's Fe-O RDFs with the reference ranges highlighted.

    Parameters
    ----------
    model
        Name of MLIP.
    results
        RDF Radii and intensities for the aqueous Fe2Cl and Fe3Cl systems.
    """

    def plot_result() -> dict[str, tuple[list[float], list[float]]]:
        """
        Hand the RDF data to the plotting decorator.

        Returns
        -------
        model_results
            Dictionary of model Fe-O RDFs for the aqueous Fe2Cl and Fe3Cl systems.
        """
        return results

    # Lines only (no markers), one trace per salt, saved as JSON for the app.
    as_rdf_plot = plot_scatter(
        filename=OUT_PATH / f"Fe-O_{model}_RDF_scatter.json",
        title=f"<b>{model} MD</b>",
        x_label="r [Å]",
        y_label="Fe-O G(r)",
        show_line=True,
        show_markers=False,
        highlight_range=REF_PEAK_RANGE,
    )
    as_rdf_plot(plot_result)()


# Two RDF peaks count as split when their positions differ by more than this
# (same units as the RDF radii).
_MIN_PEAK_SPLIT = 0.07

# REF_PEAK_RANGE key holding the reference window for each salt. Reading the
# windows from REF_PEAK_RANGE keeps the pass/fail criterion identical to the
# regions highlighted on the plots.
_SALT_REF_LABELS = {
    "Fe2Cl": "Fe<sup>+2</sup><br>Ref",
    "Fe3Cl": "Fe<sup>+3</sup><br>Ref",
}


def _rdf_peak_position(r: list[float], g_r: list[float]) -> float:
    """
    Get the radius at which an RDF is highest.

    Parameters
    ----------
    r
        RDF radii.
    g_r
        RDF intensities, aligned with ``r``.

    Returns
    -------
    float
        Radius of the (first) global maximum of the RDF.
    """
    return r[int(np.argmax(g_r))]


def _score_peaks(peak_position: dict[str, float]) -> tuple[float, float]:
    """
    Score one model's Fe-O peak positions.

    Parameters
    ----------
    peak_position
        Peak radius for each salt in ``_SALT_REF_LABELS``.

    Returns
    -------
    tuple[float, float]
        Peak split score (1.0 or 0.0), and the within-reference score, which
        gains 0.5 for each salt whose peak lies inside its reference window.
    """
    separation = abs(peak_position["Fe2Cl"] - peak_position["Fe3Cl"])
    split = 1.0 if separation > _MIN_PEAK_SPLIT else 0.0

    n_within = sum(
        1
        for salt, label in _SALT_REF_LABELS.items()
        if REF_PEAK_RANGE[label][0] <= peak_position[salt] <= REF_PEAK_RANGE[label][1]
    )
    return split, 0.5 * n_within


@pytest.fixture
def get_oxidation_states_passfail() -> dict[str, dict]:
    """
    Test whether model RDF peaks are split and they fall within the reference range.

    As a side effect, the Fe-O RDF plot of every model is written out.

    Returns
    -------
    oxidation_states_passfail
        Dictionary of pass fail per model.
    """
    oxidation_state_passfail = {test: {} for test in TESTS}

    for model in MODELS:
        results = get_rdf_results(model)
        plot_rdfs(model, results)

        peak_position = {
            salt: _rdf_peak_position(results[salt][0], results[salt][1])
            for salt in IRON_SALTS
        }
        split, within_ref = _score_peaks(peak_position)

        oxidation_state_passfail["Fe-O RDF Peak Split"][model] = split
        oxidation_state_passfail["Peak Within Experimental Ref"][model] = within_ref

    return oxidation_state_passfail


@pytest.fixture
@build_table(
    filename=OUT_PATH / "oxidation_states_table.json",
    metric_tooltips=DEFAULT_TOOLTIPS,
    thresholds=DEFAULT_THRESHOLDS,
)
def oxidation_states_passfail_metrics(
    get_oxidation_states_passfail: dict[str, dict],
) -> dict[str, dict]:
    """
    Collect the oxidation states metrics for the benchmark table.

    The pass/fail results are passed through unchanged; the table itself is
    built by the ``build_table`` decorator.

    Parameters
    ----------
    get_oxidation_states_passfail
        Pass/fail scores per model, keyed by metric name.

    Returns
    -------
    dict[str, dict]
        Pass/fail scores per model, keyed by metric name.
    """
    return get_oxidation_states_passfail


def test_oxidation_states_passfail_metrics(
    oxidation_states_passfail_metrics: dict[str, dict],
) -> None:
    """
    Run the oxidation states analysis.

    Requesting the fixture is what triggers the analysis, so the test body
    itself has nothing to do.

    Parameters
    ----------
    oxidation_states_passfail_metrics
        All oxidation states pass fail.
    """
13 changes: 13 additions & 0 deletions ml_peg/analysis/physicality/oxidation_states/metrics.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
metrics:
  # Scored 1.0 by the analysis when the Fe2Cl and Fe3Cl Fe-O RDF peak positions
  # are separated, otherwise 0.0.
  Fe-O RDF Peak Split:
    good: 1.0
    bad: 0.0
    unit: Yes(1)/No(0)
    tooltip: Whether there is a split between Fe-O RDF peaks for different iron oxidation states
    level_of_theory: Experimental
  # Scored in steps of 0.5 by the analysis: 0.5 for each of the two salts whose
  # peak lies inside its reference range, so 0.5 means only one of them matched.
  Peak Within Experimental Ref:
    good: 1.0
    bad: 0.0
    unit: Yes(1)/No(0)
    tooltip: Whether the RDF peak positions match experimental peaks
    level_of_theory: Experimental
29 changes: 28 additions & 1 deletion ml_peg/analysis/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from dash import dash_table
import numpy as np
import pandas as pd
import plotly.colors as pc
import plotly.graph_objects as go

from ml_peg.analysis.utils.utils import (
Expand Down Expand Up @@ -436,8 +437,10 @@ def plot_scatter(
x_label: str | None = None,
y_label: str | None = None,
show_line: bool = False,
show_markers: bool = True,
hoverdata: dict | None = None,
filename: str = "scatter.json",
highlight_range: dict = None,
) -> Callable:
"""
Plot scatter plot of MLIP results.
Expand All @@ -452,10 +455,14 @@ def plot_scatter(
Label for y-axis. Default is `None`.
show_line
Whether to show line between points. Default is False.
show_markers
Whether to show markers on the plot. Default is True.
hoverdata
Hover data dictionary. Default is `{}`.
filename
Filename to save plot as JSON. Default is "scatter.json".
highlight_range
Dictionary of rectangle title and x-axis endpoints.

Returns
-------
Expand Down Expand Up @@ -504,7 +511,13 @@ def plot_scatter_wrapper(*args, **kwargs) -> dict[str, Any]:
hovertemplate += f"<b>{key}: </b>%{{customdata[{i}]}}<br>"
customdata = list(zip(*hoverdata.values(), strict=True))

mode = "lines+markers" if show_line else "markers"
modes = []
if show_line:
modes.append("lines")
if show_markers:
modes.append("markers")

mode = "+".join(modes)

fig = go.Figure()
for mlip, value in results.items():
Expand All @@ -520,6 +533,20 @@ def plot_scatter_wrapper(*args, **kwargs) -> dict[str, Any]:
)
)

colors = pc.qualitative.Plotly

if highlight_range:
for i, (h_text, range) in enumerate(highlight_range.items()):
fig.add_vrect(
x0=range[0],
x1=range[1],
annotation_text=h_text,
annotation_position="top",
fillcolor=colors[i],
opacity=0.25,
line_width=0,
)

fig.update_layout(
title={"text": title},
xaxis={"title": {"text": x_label}},
Expand Down
90 changes: 90 additions & 0 deletions ml_peg/app/physicality/oxidation_states/app_oxidation_states.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Run oxidation states app."""

from __future__ import annotations

from dash import Dash
from dash.html import Div

from ml_peg.app import APP_ROOT
from ml_peg.app.base_app import BaseApp
from ml_peg.app.utils.build_callbacks import (
plot_from_table_cell,
)
from ml_peg.app.utils.load import read_plot
from ml_peg.calcs import CALCS_ROOT
from ml_peg.models.get_models import get_model_names
from ml_peg.models.models import current_models

# Names of the models to show, derived from the currently configured models.
MODELS = get_model_names(current_models)

# Display name of the benchmark; also used as the prefix of the component ids.
BENCHMARK_NAME = "Iron Oxidation States"
# Documentation page linked from the app.
DOCS_URL = "https://ddmms.github.io/ml-peg/user_guide/benchmarks/physicality.html#oxidation-states"
# Directory the table and per-model RDF plot JSON files are read from.
DATA_PATH = APP_ROOT / "data" / "physicality" / "oxidation_states"
# NOTE(review): REF_PATH is not referenced anywhere in the visible part of this
# module - confirm whether it is still needed.
REF_PATH = CALCS_ROOT / "physicality" / "oxidation_states" / "data"


class FeOxidationStatesApp(BaseApp):
    """Layout and callbacks for the Fe oxidation states benchmark app."""

    def register_callbacks(self) -> None:
        """Register callbacks to app."""
        # Both metric columns show the same per-model RDF plot when clicked.
        metric_columns = ("Fe-O RDF Peak Split", "Peak Within Experimental Ref")

        scatter_plots = {}
        for model in MODELS:
            rdf_json = DATA_PATH / f"Fe-O_{model}_RDF_scatter.json"
            figure_id = f"{BENCHMARK_NAME}-{model}-figure-Fe-O-RDF"
            scatter_plots[model] = {
                column: read_plot(rdf_json, id=figure_id) for column in metric_columns
            }

        plot_from_table_cell(
            table_id=self.table_id,
            plot_id=f"{BENCHMARK_NAME}-figure-placeholder",
            cell_to_plot=scatter_plots,
        )


def get_app() -> FeOxidationStatesApp:
    """
    Get Fe Oxidation States benchmark app layout and callback registration.

    Returns
    -------
    FeOxidationStatesApp
        Benchmark layout and callback registration.
    """
    return FeOxidationStatesApp(
        name=BENCHMARK_NAME,
        # Adjacent string literals are concatenated as-is, so the first part
        # must end with a space to keep "Fe" and "from" as separate words.
        description=(
            "Evaluate model ability to capture different oxidation states of Fe "
            "from aqueous Fe 2Cl and Fe 3Cl MD RDFs"
        ),
        docs_url=DOCS_URL,
        table_path=DATA_PATH / "oxidation_states_table.json",
        extra_components=[
            Div(id=f"{BENCHMARK_NAME}-figure-placeholder"),
            Div(id=f"{BENCHMARK_NAME}-struct-placeholder"),
        ],
    )


if __name__ == "__main__":
    # Create Dash app
    full_app = Dash(
        __name__,
        assets_folder=DATA_PATH.parent.parent,
        suppress_callback_exceptions=True,
    )

    # Construct layout and register callbacks. The instance gets its own name
    # so that the FeOxidationStatesApp class is not rebound to an instance.
    benchmark_app = get_app()
    full_app.layout = benchmark_app.layout
    benchmark_app.register_callbacks()

    # Run app
    full_app.run(port=8054, debug=True)
Loading
Loading