ATTENTION

You can interact with this notebook online via the Binder badge.

Understanding Automatic Differentiation (Incomplete)

This tutorial walks through how automatic differentiation works and how sympyle implements it.

In [1]:
import plotly.plotly as py
import numpy as np
from plotly.offline import init_notebook_mode, plot, iplot
init_notebook_mode(connected=True)
import plotly.graph_objs as go
from plotly.tools import FigureFactory as FF
from IPython.display import HTML
# The polling here is to ensure that plotly.js has already been loaded before
# setting display alignment in order to avoid a race condition.
display(HTML(
    '<script>'
        'var waitForPlotly = setInterval( function() {'
            'if( typeof(window.Plotly) !== "undefined" ){'
                'MathJax.Hub.Config({ SVG: { font: "STIX-Web" }, displayAlign: "center" });'
                'MathJax.Hub.Queue(["setRenderer", MathJax.Hub, "SVG"]);'
                'clearInterval(waitForPlotly);'
            '}}, 250 );'
    '</script>'
))

Derivatives and why they are useful

This tutorial assumes that the reader has a basic understanding of derivatives. If this is not the case, the excellent The Matrix Calculus You Need For Deep Learning is a good reference.

Derivatives describe how much some variable changes when some other variable is changed by a tiny amount. This is extremely handy, because if we can describe how much some error changes when we change some parameter, we can optimize that parameter to produce a lower error.

A simple example

To drive home the above concept, let's do an example by hand.

Let's define our ‘error’ equation as:

\begin{equation} error = 2 * x \end{equation}

where \(x\) is a learnable parameter. To begin with, let's give it an initial value of \(10\).

Therefore, \(\frac{d\ error}{d\ x}\) will be \(2\), which means that if \(x\) is changed by an infinitesimally small amount (so small that we can consider it almost not changing at all), \(error\) would change by twice that amount.
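To see this concretely, here is a quick numerical sanity check (a small sketch in plain Python, not part of sympyle): nudge \(x\) by a tiny amount and compare how much \(error\) moves.

x = 10.0
h = 1e-6                                 # the "tiny amount"
error_before = 2 * x
error_after = 2 * (x + h)
print((error_after - error_before) / h)  # prints approximately 2.0

The ratio of the change in \(error\) to the change in \(x\) is the derivative, \(2\).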

Why is this useful?

If we know how much changing an input to a function changes the value of the function, we can simply change the input such that the value of the function decreases (we want to lower our error).

Visualizing the derivative of a function

Let us visualize

\begin{equation} y = x^2 \end{equation}

and its derivative

\begin{equation} \frac{d\ y}{d\ x} = 2*x \end{equation}
In [ ]:
x = np.linspace(-5,5, 1000)
y = x**2

dydx = 2*x

trace1 = go.Scatter(
    x=x,
    y=y,
    mode='lines',
    name='x^2'
)

trace2 = go.Scatter(
    x=x,
    y=dydx,
    mode='lines',
    name='derivative of x^2'
)


trace_data = [trace1, trace2]
iplot(trace_data,show_link=False,config=dict(displaylogo=False))

Notice that in this function, as the value of \(x\) increases, the value of \(\frac{dy}{dx}\) also increases.

Chain rule

Oftentimes our error is not a simple function, but a chain of functions. For example:

\begin{equation} error = add(10,x) \end{equation}

where \(add\) is simply the \(+\) operator, but it helps to think of it explicitly as a function.

Notice that 10 is a constant here, but \(x\) is our input parameter like the previous example. We now want to try to change \(x\) such that \(error\) is lower.

The chain rule describes how the derivative of a chain of functions may be calculated.

It is given by:

\begin{equation} \frac{d\ error}{d\ x} = \frac{d\ error}{d\ add} *\frac{d\ add}{d\ x} \end{equation}

Visualizing the derivative of a chained function

Let us use the function

\begin{equation} y = sin(x)^2 \end{equation}

Let us rewrite it as:

\begin{equation} y = pow(sin(x),2) \end{equation}

where \(pow\) is simply the power function.

The chain rule tells us that

\begin{equation} \frac{dy}{dx} = \color{orange}{\frac{d\ y}{d\ pow}} * \color{blue}{\frac{d\ pow}{d\ x}} \end{equation}

In this case it will be

\begin{equation} \frac{dy}{dx} = \color{orange}{2sin(x)} * \color{blue}{cos(x)} \end{equation}
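Before plotting, here is a minimal sketch of that multiplication of local derivatives in plain Python/NumPy (only an illustration of the idea, not how sympyle actually implements it; the helper names are made up for this example): walk through the chain of functions, evaluating each step and multiplying its local derivative into a running total.

import numpy as np

# each step of the chain is (function, derivative of that function w.r.t. its input)
chain = [
    (np.sin,            np.cos),           # u = sin(x),  du/dx = cos(x)
    (lambda u: u ** 2,  lambda u: 2 * u),  # y = u**2,    dy/du = 2u
]

def value_and_derivative(x, chain):
    value, derivative = x, 1.0
    for func, local_derivative in chain:
        derivative = local_derivative(value) * derivative  # chain rule: multiply local derivatives
        value = func(value)
    return value, derivative

y, dydx = value_and_derivative(1.5, chain)
print(dydx)                           # derivative computed by walking the chain
print(2 * np.sin(1.5) * np.cos(1.5))  # matches the hand-derived 2*sin(x)*cos(x)

The same pattern extends to arbitrarily long chains of functions, and this is, in essence, what automatic differentiation automates.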
In [ ]:
x = np.linspace(-5,5, 1000)
y = np.sin(x)**2

dydx = 2*np.sin(x)*np.cos(x)

trace1 = go.Scatter(
    x=x,
    y=y,
    mode='lines',
    name='sin(x)**2'
)

trace2 = go.Scatter(
    x=x,
    y=dydx,
    mode='lines',
    name='derivative of sin(x)**2'
)

trace_data = [trace1, trace2]
iplot(trace_data,show_link=False,config=dict(displaylogo=False))

Large, modern networks with several million parameters are trained using this same principle. The expression \(\frac{d\ error}{d\ x}\) means precisely this:

How much will \(error\) increase if I increase \(x\) by a tiny amount?

Since we are interested in decreasing \(error\), we change the value of \(x\) in the direction opposite to the derivative.

Updating parameters based on gradients

The derivative tells us the behavior of a function at a point.

We need to update the parameter based on this behavior. We could do param = param - gradient, but this has the potential to overshoot the optimum (in the case of \(x^2\), our optimum is the bottom of the curve).

Therefore, we take much smaller steps by multiplying the gradient by a value between 0 and 1 before applying the update. This multiplier is known as the learning rate.
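As a small illustration (a sketch using only the \(y = x^2\) example from above, not code from sympyle; the minimise helper is made up for this example), compare a full-sized step with a scaled-down one:

def minimise(x, lr, iters=5):
    # a few steps of: param = param - learning_rate * gradient, for y = x**2 (gradient = 2*x)
    for _ in range(iters):
        x = x - lr * (2 * x)
    return x

print(minimise(10.0, lr=1.0))   # overshoots back and forth: 10 -> -10 -> 10 -> ...
print(minimise(10.0, lr=0.1))   # shrinks steadily towards the optimum at 0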

An example

Let us revisit our original function

\begin{equation} y = x^2 \end{equation}

and its derivative

\begin{equation} \frac{d\ y}{d\ x} = 2*x \end{equation}

We will try to update \(x\) such that \(y\) is as close to \(0\) as possible.

In [76]:
def create_animated_plot(func,derivative,x=0,LR=.1,iters=10):
    """
    Create an animated plot of a parameter being optimized

    :param func: a callable that returns the value of the function at a point
    :param derivative: a callable that returns the derivative of the function at a point
    :param x: the initial value of x
    :param LR: the learning rate
    :param iters: the number of iterations to run for
    """

    x_history = []
    y_history = []

    for i in range(iters):
        y = func(x)
        dydx = derivative(x)
        x_history.append(x)
        y_history.append(y)
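        # step in the direction opposite to the gradient, scaled by the learning rate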
        x = x - LR*dydx

    x_history = np.array(x_history)

    x = np.linspace(-max(x_history)-10,max(x_history)+10,1000)
    y= func(x)

    def prepare_data():
        data_arr = []
        for i in range(len(x_history)):
            data_arr.append({"data":[{'x':x_history[i:i+1],
                                      'y':y_history[i:i+1],
                                      "line":{"color":"red",
                                             "dash":"dot"},
                                     "name":"Learned X value"},
                                    ],
                            },
                           )
        data_arr.append({"data":[{"x":x_history,"y":y_history,"line":{"color":"red"}}],})
        return data_arr

    data_arr = prepare_data()
    figure = {'data': [{'x': x, 'y': y,"line":{"color":"green",
                             "dash":"dashdot",
                            "width":.5,

                           },
                       "name":"y = f(x)"},
                   {'x': x, 'y': y,
                    "name":"y = f(x)",
                    "line":{"color":"green",
                             "dash":"dashdot",
                            "width":.5
                           },},

                  ],
          'layout': {'xaxis': {'range': [-10, 10], 'autorange': True},
                     'yaxis': {'range': [-10, 10], 'autorange': True},
                    'updatemenus': [{'type': 'buttons',
                                      'buttons': [{'label': 'Play',
                                                   'method': 'animate',
                                                   'args': [None]}]}]},
          'frames': data_arr}

    return figure
In [78]:
fig = create_animated_plot(lambda x: x**2,
                           lambda x: 2*x,
                           LR=.1,
                           x=20,
                           iters=10)
iplot(fig,show_link=False,config=dict(displaylogo=False))