{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3f2949ca",
   "metadata": {},
   "source": [
    "# Plot NCATS-sol\n",
    "\n",
    "This tutorial runs the PSMA workflow on the NCATS-sol classification dataset and renders static (2D and 3D) and interactive posterior plots.\n",
    "\n",
    "The dataset uses `low_solubility` as a binary endpoint: class `1` means low solubility and class `0` means moderate-to-high solubility."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f433bc45",
   "metadata": {},
   "source": [
    "## Imports\n",
    "\n",
    "The tutorial calls the pure computation API (`compute_psma_surface`), so PSMA writes no artifacts to disk. The import cell also prepends the nearest directory that contains a `psma` package folder to `sys.path`, so the notebook can import `psma` from a source checkout."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "492c4c72",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "import sys\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from bokeh.embed import file_html\n",
    "from bokeh.resources import CDN\n",
    "from IPython.display import HTML\n",
    "\n",
    "# Make a source checkout of `psma` importable: prepend the nearest directory\n",
    "# (the current one or a parent) that contains a `psma` package folder.\n",
    "for candidate in [Path.cwd(), *Path.cwd().parents]:\n",
    "    if (candidate / \"psma\").is_dir():\n",
    "        sys.path.insert(0, str(candidate))\n",
    "        break\n",
    "\n",
    "from psma import compute_psma_surface  # noqa: E402\n",
    "from psma.plot import (  # noqa: E402\n",
    "    plot_posterior_2d,\n",
    "    plot_posterior_2d_interactive,\n",
    "    plot_posterior_3d,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff7efe69",
   "metadata": {},
   "source": [
    "## Load the tutorial data\n",
    "\n",
    "The NCATS-sol CSV ships with the documentation, so the tutorial does not depend on an external download. The loading cell keeps only the SMILES and label columns, drops rows with a missing SMILES or label, discards empty SMILES strings, removes duplicate SMILES (keeping the first occurrence), and assigns each remaining compound a sequential `mol_id`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "543d42b1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_data_path() -> Path:\n",
    "    \"\"\"Find the NCATS-sol CSV from common notebook execution roots.\"\"\"\n",
    "    candidates = [\n",
    "        Path(\"docs/_data/solubility_NCATS-sol.csv\"),\n",
    "        Path(\"_data/solubility_NCATS-sol.csv\"),\n",
    "        Path(\"../_data/solubility_NCATS-sol.csv\"),\n",
    "    ]\n",
    "    for candidate in candidates:\n",
    "        if candidate.exists():\n",
    "            return candidate\n",
    "    raise FileNotFoundError(\"Could not find solubility_NCATS-sol.csv\")\n",
    "\n",
    "\n",
    "df = pd.read_csv(find_data_path())\n",
    "# Keep the SMILES and label columns and drop rows missing either value.\n",
    "df = df.loc[:, [\"canonical_smiles\", \"low_solubility\"]].dropna()\n",
    "# Strip surrounding whitespace and discard empty SMILES strings.\n",
    "df[\"canonical_smiles\"] = df[\"canonical_smiles\"].astype(str).str.strip()\n",
    "df = df.loc[df[\"canonical_smiles\"] != \"\"].copy()\n",
    "# Fail loudly on non-numeric labels, then cast them to integers.\n",
    "df[\"low_solubility\"] = pd.to_numeric(df[\"low_solubility\"], errors=\"raise\").astype(int)\n",
    "# Keep the first row per SMILES and assign sequential compound IDs.\n",
    "df = df.drop_duplicates(subset=\"canonical_smiles\", keep=\"first\").reset_index(drop=True)\n",
    "df.insert(0, \"mol_id\", [f\"ncats_sol_{index:05d}\" for index in range(len(df))])\n",
    "\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "589453f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "df[\"low_solubility\"].value_counts().sort_index().rename(\"count\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38a642f1",
   "metadata": {},
   "source": [
    "## Run PSMA without writing artifacts\n",
    "\n",
    "The `run_ncats_sol` helper computes two results that differ only in `split_method`: a random train/test split and a Butina cluster split on RDKit Morgan/Tanimoto similarity (`butina_distance_cutoff=0.4`). A cluster split typically keeps structurally similar compounds on the same side of the split, so its test set is usually less similar to the training set than a random test set is.\n",
    "\n",
    "The labels are already binary, so `label_threshold=0.5` with `label_direction=\"ge\"` preserves the `0/1` classes: class `1` (low solubility) is treated as positive and class `0` as negative."
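   ]
  },
  {
   "cell_type": "markdown",
   "id": "7c1e5a20",
   "metadata": {},
   "source": [
    "To check that `label_threshold=0.5` with `label_direction=\"ge\"` leaves `0/1` labels unchanged, the next cell applies a threshold rule to a toy label array. It is a minimal sketch that assumes `\"ge\"` means \"a label greater than or equal to the threshold is positive\"; `binarize_labels` is a hypothetical helper, and PSMA's internal binarization may differ in detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d4b2f61",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def binarize_labels(y, threshold=0.5):\n",
    "    \"\"\"Hypothetical helper: 1 where y >= threshold (the assumed \"ge\" rule), else 0.\"\"\"\n",
    "    return (np.asarray(y, dtype=float) >= threshold).astype(int)\n",
    "\n",
    "\n",
    "labels = np.array([0, 1, 1, 0, 1])\n",
    "# Thresholding 0/1 labels at 0.5 returns them unchanged.\n",
    "assert np.array_equal(binarize_labels(labels), labels)\n",
    "# A continuous target would be split at the threshold, with 0.5 itself positive.\n",
    "binarize_labels([0.2, 0.5, 0.9])"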
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a528c782",
   "metadata": {},
   "outputs": [],
   "source": [
    "def run_ncats_sol(split_method: str):\n",
    "    \"\"\"Compute one NCATS-sol PSMA result without writing artifacts.\"\"\"\n",
    "    return compute_psma_surface(\n",
    "        df,\n",
    "        y_col=\"low_solubility\",\n",
    "        smiles_col=\"canonical_smiles\",\n",
    "        mol_id_col=\"mol_id\",\n",
    "        similarity_method=\"rdkit_morgan_tanimoto\",\n",
    "        split_method=split_method,\n",
    "        butina_distance_cutoff=0.4,\n",
    "        label_threshold=0.5,\n",
    "        label_direction=\"ge\",\n",
    "        test_fraction=0.2,\n",
    "        random_state=7,\n",
    "    )\n",
    "\n",
    "\n",
    "random_result = run_ncats_sol(\"random\")\n",
    "butina_result = run_ncats_sol(\"butina\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a4de6c3c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare AUC, MCC, and train/test sizes for the two split strategies.\n",
    "pd.DataFrame(\n",
    "    {\n",
    "        \"split\": [\"random\", \"butina\"],\n",
    "        \"auc\": [random_result.metrics.auc, butina_result.metrics.auc],\n",
    "        \"mcc\": [random_result.metrics.mcc, butina_result.metrics.mcc],\n",
    "        \"train_n\": [len(random_result.indices.train), len(butina_result.indices.train)],\n",
    "        \"test_n\": [len(random_result.indices.test), len(butina_result.indices.test)],\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f1c7744c",
   "metadata": {},
   "source": [
    "## Static 2D posterior plot\n",
    "\n",
    "The 2D view is the most compact summary: it shows the posterior surface over the 2D projection together with the projected test compounds and their binary labels (`y_test_bin`). The example uses the random-split result."
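   ]
  },
  {
   "cell_type": "markdown",
   "id": "b2e80c4d",
   "metadata": {},
   "source": [
    "PSMA's exact construction of the surface is not shown in this tutorial. As a hedged sketch of one common approach, and not PSMA's implementation, a posterior over 2D coordinates can be obtained by applying Bayes' rule to per-class kernel density estimates: $P(y=1 \\mid x) = \\frac{\\pi_1 p_1(x)}{\\pi_1 p_1(x) + \\pi_0 p_0(x)}$, where $\\pi_c$ is the prior of class $c$ and $p_c$ its class-conditional density. The helpers `gaussian_density` and `posterior_surface`, the fixed bandwidth of `0.5`, and the synthetic data are all assumptions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a6f13e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def gaussian_density(points, samples, bandwidth):\n",
    "    \"\"\"Hypothetical helper: mean of isotropic 2D Gaussian kernels on `samples`, evaluated at `points`.\"\"\"\n",
    "    sq_dist = ((points[:, None, :] - samples[None, :, :]) ** 2).sum(axis=-1)\n",
    "    norm = 1.0 / (2.0 * np.pi * bandwidth**2)\n",
    "    return norm * np.exp(-0.5 * sq_dist / bandwidth**2).mean(axis=1)\n",
    "\n",
    "\n",
    "def posterior_surface(coords, y_bin, grid_x, grid_y, bandwidth=0.5):\n",
    "    \"\"\"Hypothetical helper: P(y=1 | x) on a grid via Bayes' rule over per-class KDEs.\"\"\"\n",
    "    points = np.column_stack([grid_x.ravel(), grid_y.ravel()])\n",
    "    prior_pos = y_bin.mean()\n",
    "    dens_pos = gaussian_density(points, coords[y_bin == 1], bandwidth)\n",
    "    dens_neg = gaussian_density(points, coords[y_bin == 0], bandwidth)\n",
    "    num = prior_pos * dens_pos\n",
    "    den = num + (1.0 - prior_pos) * dens_neg\n",
    "    # Where both densities underflow to zero, fall back to the class prior.\n",
    "    post = np.divide(num, den, out=np.full_like(num, prior_pos), where=den > 0)\n",
    "    return post.reshape(grid_x.shape), dens_pos.reshape(grid_x.shape)\n",
    "\n",
    "\n",
    "# Synthetic 2D coordinates: negatives around (-1, -1), positives around (1, 1).\n",
    "rng = np.random.default_rng(0)\n",
    "coords = np.vstack([rng.normal(-1.0, 0.5, (50, 2)), rng.normal(1.0, 0.5, (50, 2))])\n",
    "y_bin = np.repeat([0, 1], 50)\n",
    "axis = np.linspace(-3.0, 3.0, 61)\n",
    "grid_x, grid_y = np.meshgrid(axis, axis)\n",
    "post_z, pos_density = posterior_surface(coords, y_bin, grid_x, grid_y)\n",
    "post_z.min(), post_z.max()"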
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "35ad7cc7",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots(figsize=(8, 6))\n",
    "plot_posterior_2d(\n",
    "    ax,\n",
    "    grid_x=random_result.grid.grid_x,\n",
    "    grid_y=random_result.grid.grid_y,\n",
    "    posterior_z=random_result.grid.posterior_z,\n",
    "    pos_density_z=random_result.grid.pos_density_z,\n",
    "    c_test=random_result.coords.c_test,\n",
    "    y_test_bin=random_result.labels.y_test_bin,\n",
    ")\n",
    "ax.set_title(\"NCATS-sol random split posterior\")\n",
    "fig.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abc05cbb",
   "metadata": {},
   "source": [
    "## Static 3D posterior plot\n",
    "\n",
    "The 3D view draws the posterior as a surface over the same grid, which makes its shape easier to inspect than in the flat 2D view. The test compounds are passed with their predicted probabilities (`prob_test`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fcb76d1b",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig3d, ax3d = plot_posterior_3d(\n",
    "    grid_x=random_result.grid.grid_x,\n",
    "    grid_y=random_result.grid.grid_y,\n",
    "    posterior_z=random_result.grid.posterior_z,\n",
    "    c_test=random_result.coords.c_test,\n",
    "    prob_test=random_result.prob_test,\n",
    ")\n",
    "ax3d.set_title(\"NCATS-sol random split posterior\")\n",
    "fig3d.tight_layout()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cacdfb8d",
   "metadata": {},
   "source": [
    "## Interactive posterior plot\n",
    "\n",
    "The interactive Bokeh plot adds toggleable layers, hover readouts for contour levels and for individual compounds, and RDKit structure depictions rendered from each test compound's SMILES. Before plotting, each test `mol_id` is mapped back to its SMILES string so the hover tool can draw the structure."
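   ]
  },
  {
   "cell_type": "markdown",
   "id": "e57d0a82",
   "metadata": {},
   "source": [
    "The hover depictions are produced inside `plot_posterior_2d_interactive`. As an assumption-labelled sketch of one way to render a depiction from SMILES, and not a description of that function's internals, the hypothetical helper `smiles_to_svg` uses RDKit's `MolDraw2DSVG` drawer and returns `None` for SMILES that RDKit cannot parse. Aspirin is used only as an example input."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f9c6b37",
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import SVG\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem.Draw import rdMolDraw2D\n",
    "\n",
    "\n",
    "def smiles_to_svg(smiles, width=250, height=200):\n",
    "    \"\"\"Hypothetical helper: render a SMILES string as SVG text, or None if it does not parse.\"\"\"\n",
    "    mol = Chem.MolFromSmiles(smiles)\n",
    "    if mol is None:\n",
    "        return None\n",
    "    drawer = rdMolDraw2D.MolDraw2DSVG(width, height)\n",
    "    # Prepare the molecule for drawing (e.g. 2D coordinates), then draw it.\n",
    "    rdMolDraw2D.PrepareAndDrawMolecule(drawer, mol)\n",
    "    drawer.FinishDrawing()\n",
    "    return drawer.GetDrawingText()\n",
    "\n",
    "\n",
    "SVG(smiles_to_svg(\"CC(=O)Oc1ccccc1C(=O)O\"))"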
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b11fcb16",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Look up each test compound's SMILES by mol_id, preserving test-set order.\n",
    "test_mol_ids = random_result.frames.test[\"mol_id\"].astype(str).to_numpy()\n",
    "smiles_by_mol_id = dict(\n",
    "    zip(df[\"mol_id\"].astype(str), df[\"canonical_smiles\"].astype(str), strict=True)\n",
    ")\n",
    "test_smiles = np.asarray(\n",
    "    [smiles_by_mol_id[mol_id] for mol_id in test_mol_ids], dtype=object\n",
    ")\n",
    "\n",
    "interactive_plot = plot_posterior_2d_interactive(\n",
    "    grid_x=random_result.grid.grid_x,\n",
    "    grid_y=random_result.grid.grid_y,\n",
    "    posterior_z=random_result.grid.posterior_z,\n",
    "    pos_density_z=random_result.grid.pos_density_z,\n",
    "    c_test=random_result.coords.c_test,\n",
    "    y_test_bin=random_result.labels.y_test_bin,\n",
    "    prob_test=random_result.prob_test,\n",
    "    mol_ids=test_mol_ids,\n",
    "    smiles=test_smiles,\n",
    "    width=800,\n",
    "    height=600,\n",
    ")\n",
    "# Embed the Bokeh figure as a standalone HTML document that loads BokehJS from the CDN.\n",
    "HTML(file_html(interactive_plot, CDN, \"Interactive PSMA Posterior\"))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}