{ "cells": [ { "cell_type": "markdown", "id": "03d685b8", "metadata": {}, "source": [ "# Example 5, gradients\n", "\n", "In this example we show how `jax_zero_contour` allows gradients of quantities computed from the zero contours of a function to be taken with respect to that function's parameters." ] }, { "cell_type": "code", "execution_count": 1, "id": "28b92ba7", "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2" ] }, { "cell_type": "code", "execution_count": 2, "id": "1dac839a", "metadata": {}, "outputs": [], "source": [ "import jax\n", "import jax.numpy as jnp\n", "\n", "jax.config.update(\"jax_enable_x64\", True)\n", "\n", "import matplotlib.pyplot as plt\n", "import matplotlib.colors as colors\n", "\n", "from jax_zero_contour import ZeroSolver" ] }, { "cell_type": "markdown", "id": "da577e9d", "metadata": {}, "source": [ "To demonstrate how gradients work, let's write a function that takes in a position `p` along with two free parameters, `A` and `B`. The shape of the zero contours changes depending on the values of these parameters." ] }, { "cell_type": "code", "execution_count": 3, "id": "760fa33d", "metadata": {}, "outputs": [], "source": [ "def f(p, A, B):\n", "    return A * ((p + B)**4 - 30 * p**2 - 20 * p).sum(axis=0) + 2.0" ] }, { "cell_type": "markdown", "id": "2b8b7c69", "metadata": {}, "source": [ "Next we need a function that takes in these parameters and calculates some quantity of interest from the contours. For this example we will calculate the centroid of each contour. Inside this function we just need to bind our input parameters to `f` with `jax.tree_util.Partial` and then call the solver as normal. Unlike `functools.partial`, `jax.tree_util.Partial` is a pytree, so the bound values of `A` and `B` are treated as leaves and can be traced through JAX transformations.\n", "\n", "```{note}\n", "For plotting purposes we will also return the calculated contours. When taking the Jacobian later, we will pass `has_aux=True` to indicate that the function returns this extra data.\n", "```" ] }, { "cell_type": "code", "execution_count": 4, "id": "0c99629c", "metadata": {}, "outputs": [], "source": [ "zs = ZeroSolver()" ] }, { "cell_type": "code", "execution_count": 5, "id": "866cd0f5", "metadata": {}, "outputs": [], "source": [ "def centroids(A, B):\n", "    # Bind A and B so the solver sees a function of position only\n", "    _f = jax.tree_util.Partial(f, A=A, B=B)\n", "    # Starting points, one per contour to trace\n", "    init_guess = jnp.array([[-2.0, 0.0], [4.0, 4.0]])\n", "    paths, _ = zs.zero_contour_finder(\n", "        _f,\n", "        init_guess\n", "    )\n", "    # Average over the points of each path; nanmean skips any NaN entries\n", "    return jnp.nanmean(paths['path'], axis=1), paths" ] }, { "cell_type": "markdown", "id": "da49b7c9", "metadata": {}, "source": [ "As a test, let's look at the output for one set of values of `A` and `B`." ] }, { "cell_type": "code", "execution_count": 10, "id": "c3057540", "metadata": {}, "outputs": [], "source": [ "center, paths = centroids(0.01, 0.5)" ] }, { "cell_type": "code", "execution_count": 24, "id": "25e30610", "metadata": {}, "outputs": [], "source": [ "plt.plot(*paths['path'].T)\n", "plt.plot(*center[0].T, 'x', color='C0')\n", "plt.plot(*center[1].T, 'x', color='C1')\n", "plt.gca().set_aspect(1);" ] }, { "cell_type": "markdown", "id": "1e7c2bfa", "metadata": {}, "source": [ "As with any JAX function, we can evaluate the Jacobian with respect to the inputs. Because `A` and `B` are scalars, the result is a tuple with one `(2, 2)` array per entry in `argnums`, giving the derivative of each centroid coordinate with respect to `A` and `B`, respectively:" ] }, { "cell_type": "code", "execution_count": 22, "id": "80b15cdf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(Array([[ 17.77701254, 17.91892068],\n", " [-39.5424976 , -39.5436664 ]], dtype=float64, weak_type=True),\n", " Array([[-0.75541393, -0.75484016],\n", " [-0.2485473 , -0.24775855]], dtype=float64, weak_type=True))" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "grad, paths = jax.jacobian(centroids, argnums=(0, 1), has_aux=True)(0.01, 0.5)\n", "grad" ] } ], "metadata": { "kernelspec": { "display_name": "lensing", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" } }, "nbformat": 4, "nbformat_minor": 5 }