Plot NCATS-sol

This tutorial runs the PSMA workflow on the NCATS-sol classification dataset and renders the static and interactive posterior plots.

The dataset uses low_solubility as a binary endpoint: class 1 means low solubility and class 0 means moderate-to-high solubility.

Imports

The tutorial uses the pure computation API, so no PSMA artifacts are written to disk.

from pathlib import Path
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from bokeh.embed import file_html
from bokeh.resources import CDN
from IPython.display import HTML

# Make the local ``psma`` package importable when the notebook runs from a
# source checkout: take the nearest directory, starting at the working
# directory and moving up through its parents, that contains a ``psma``
# folder and put it at the front of sys.path. Nothing changes if none exists.
search_roots = (Path.cwd(), *Path.cwd().parents)
package_root = next(
    (root for root in search_roots if (root / "psma").is_dir()),
    None,
)
if package_root is not None:
    sys.path.insert(0, str(package_root))

from psma import compute_psma_surface  # noqa: E402
from psma.plot import (  # noqa: E402
    plot_posterior_2d,
    plot_posterior_2d_interactive,
    plot_posterior_3d,
)

Load the tutorial data

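The CSV is stored alongside the documentation, in the _data directory, because it supports this tutorial directly. The helper below checks the relative locations that the notebook is commonly executed from.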

def find_data_path() -> Path:
    """Return the first existing location of the NCATS-sol CSV.

    The locations are relative to the current working directory and are
    tried in a fixed order, covering the common notebook execution roots.

    Returns:
        The first relative path that exists.

    Raises:
        FileNotFoundError: If none of the known locations exists.
    """
    relative_locations = (
        "docs/_data/solubility_NCATS-sol.csv",
        "_data/solubility_NCATS-sol.csv",
        "../_data/solubility_NCATS-sol.csv",
    )
    located = next(
        (path for path in map(Path, relative_locations) if path.exists()),
        None,
    )
    if located is None:
        raise FileNotFoundError("Could not find solubility_NCATS-sol.csv")
    return located


# Build the modelling table in one pass:
#   1. keep only the structure and endpoint columns and drop incomplete rows;
#   2. normalise SMILES to stripped strings and discard rows left empty;
#   3. coerce the endpoint to integers (non-numeric values raise);
#   4. keep the first row for each SMILES and renumber the index from zero.
df = (
    pd.read_csv(find_data_path())[["canonical_smiles", "low_solubility"]]
    .dropna()
    .assign(
        canonical_smiles=lambda frame: frame["canonical_smiles"]
        .astype(str)
        .str.strip()
    )
    .loc[lambda frame: frame["canonical_smiles"] != ""]
    .assign(
        low_solubility=lambda frame: pd.to_numeric(
            frame["low_solubility"], errors="raise"
        ).astype(int)
    )
    .drop_duplicates(subset="canonical_smiles", keep="first")
    .reset_index(drop=True)
)

# Give every remaining compound a zero-padded identifier as the first column.
df.insert(0, "mol_id", [f"ncats_sol_{row_number:05d}" for row_number in df.index])

df.head()
mol_id canonical_smiles low_solubility
0 ncats_sol_00000 O=c1cc(-c2ccc(O)c(O)c2)oc2cc(O)cc(O)c12 0
1 ncats_sol_00001 C=CCc1ccc(O)c(-c2ccc(O)c(CC=C)c2)c1 0
2 ncats_sol_00002 CC[C@H]1NC(=O)[C@@H](NC(=O)c2ncccc2O)[C@@H](C)... 0
3 ncats_sol_00003 O=c1ncn2nc(Sc3ccc(F)cc3F)ccc2c1-c1c(Cl)cccc1Cl 0
4 ncats_sol_00004 O=C(Cc1ccc(Cl)c(Cl)c1)Nc1ccc(S(=O)(=O)Nc2ccon2... 0
# Count compounds per class, ordered by class label.
class_counts = df["low_solubility"].value_counts().sort_index()
class_counts.rename("count")
low_solubility
0    1054
1    1399
Name: count, dtype: int64

Run PSMA without writing artifacts

The dataset is already binary, so label_threshold=0.5 with label_direction="ge" preserves the 0/1 class labels: a label of 1 is at or above the threshold and a label of 0 is below it.

def run_ncats_sol(split_method: str):
    """Run the PSMA computation on the module-level ``df`` for one split.

    Only the split strategy varies between calls; every other setting is
    fixed here so results from different splits are directly comparable.

    Args:
        split_method: Split strategy name forwarded to
            ``compute_psma_surface`` (the tutorial uses "random" and
            "butina").

    Returns:
        Whatever ``compute_psma_surface`` returns for these settings.
    """
    column_roles = {
        "y_col": "low_solubility",
        "smiles_col": "canonical_smiles",
        "mol_id_col": "mol_id",
    }
    settings = {
        "similarity_method": "rdkit_morgan_tanimoto",
        "split_method": split_method,
        "butina_distance_cutoff": 0.4,
        "label_threshold": 0.5,
        "label_direction": "ge",
        "test_fraction": 0.2,
        "random_state": 7,
    }
    return compute_psma_surface(df, **column_roles, **settings)


# Compute one result per split strategy; "random" is evaluated before "butina".
random_result, butina_result = map(run_ncats_sol, ("random", "butina"))
Projection diagnostics flagged instability with rank=1957 cond=3.9070720654275174e+19
Projection diagnostics flagged instability with rank=1958 cond=1.9596175125530284e+19
# One summary row per split: headline metrics plus train/test set sizes.
results_by_split = {"random": random_result, "butina": butina_result}
pd.DataFrame(
    [
        {
            "split": split_name,
            "auc": outcome.metrics.auc,
            "mcc": outcome.metrics.mcc,
            "train_n": len(outcome.indices.train),
            "test_n": len(outcome.indices.test),
        }
        for split_name, outcome in results_by_split.items()
    ]
)
split auc mcc train_n test_n
0 random 0.652273 0.226285 1962 491
1 butina 0.634628 0.213330 1962 491

Static 2D posterior plot

The 2D view is the most compact summary of the posterior surface and projected test compounds.

# Draw the random-split posterior and its test compounds on a single axes.
fig, ax = plt.subplots(figsize=(8, 6))
surface_2d = random_result.grid
plot_posterior_2d(
    ax,
    grid_x=surface_2d.grid_x,
    grid_y=surface_2d.grid_y,
    posterior_z=surface_2d.posterior_z,
    pos_density_z=surface_2d.pos_density_z,
    c_test=random_result.coords.c_test,
    y_test_bin=random_result.labels.y_test_bin,
)
ax.set_title("NCATS-sol random split posterior")
fig.tight_layout()
../_images/0721625e981a5d89caf772d77a5fbcef675266afbcd8d61e94c89c61e5e0c389.png

Static 3D posterior plot

The 3D view makes the geometry of the posterior surface easier to inspect.

# Render the same random-split surface in 3D; the plotting helper creates
# and returns its own figure and axes.
surface_3d = random_result.grid
fig3d, ax3d = plot_posterior_3d(
    grid_x=surface_3d.grid_x,
    grid_y=surface_3d.grid_y,
    posterior_z=surface_3d.posterior_z,
    c_test=random_result.coords.c_test,
    prob_test=random_result.prob_test,
)
ax3d.set_title("NCATS-sol random split posterior")
fig3d.tight_layout()
../_images/6b6a88f749bbe1544b364979ff342fa0d0ed07ff7787951f577c1f4934ac429f.png

Interactive posterior plot

The interactive plot adds toggleable layers, contour-level hover, compound-level hover metadata, and RDKit hover depictions from SMILES.

# Test-set identifiers as strings, in the row order of the test frame.
test_mol_ids = random_result.frames.test["mol_id"].astype(str).to_numpy()

# Look up each test compound's SMILES by mol_id so the hover metadata lines
# up with the plotted points. strict=True raises if the two columns ever
# differ in length.
id_column = df["mol_id"].astype(str)
smiles_column = df["canonical_smiles"].astype(str)
smiles_by_mol_id = {
    mol_id: smiles
    for mol_id, smiles in zip(id_column, smiles_column, strict=True)
}
test_smiles = np.asarray(
    [smiles_by_mol_id[identifier] for identifier in test_mol_ids],
    dtype=object,
)

surface_interactive = random_result.grid
interactive_plot = plot_posterior_2d_interactive(
    grid_x=surface_interactive.grid_x,
    grid_y=surface_interactive.grid_y,
    posterior_z=surface_interactive.posterior_z,
    pos_density_z=surface_interactive.pos_density_z,
    c_test=random_result.coords.c_test,
    y_test_bin=random_result.labels.y_test_bin,
    prob_test=random_result.prob_test,
    mol_ids=test_mol_ids,
    smiles=test_smiles,
    width=800,
    height=600,
)

# Serialise the Bokeh figure to standalone HTML (resources loaded from the
# CDN) and display it as the cell output.
standalone_page = file_html(interactive_plot, CDN, "Interactive PSMA Posterior")
HTML(standalone_page)
Interactive PSMA Posterior