JAX Basics and a Simple Linear Regression Implementation Exercise

2020-02-28


Deep_Learning_TIL_(20200228)

Reference material (source) used for this study:

A JAX tutorial written by Professor Nicolas Papernot of the University of Toronto, provided as a hands-on notebook for students of the 'ECE1513H: Introduction to Machine Learning' course.

Course homepage: https://www.papernot.fr/teaching/w20-ml?fbclid=IwAR01vI_2yncQzfCen6ro2_BhBPaQskpCBm9dMD_G1LY1fmB7NEfvTAGZ-Sc

Notebook title: ECE1513_Introduction_to_JAX_Student_Solutions

URL : https://colab.research.google.com/drive/1YsoctmFgZ0joXdMqEHmTsP3AAyhyNlzI?fbclid=IwAR3aNKWRdJmKqLvRS9ENegrLe9e3C6HRb33T1-mg7QpMah9LZ42knf_szlo#scrollTo=zafRRNriqeyq

Note: a hands-on notebook for students of the ECE1513H: Introduction to Machine Learning course

Environment: Google Colab GPU runtime

1.0 Introduction to JAX

Main references:

  • [1] https://github.com/google/jax
  • [2] https://colindcarroll.com/2019/04/06/exercises-in-automatic-differentiation-using-autograd-and-jax/
  • [3] https://github.com/HIPS/autograd/blob/master/docs/tutorial.md
  • [4] https://colinraffel.com/blog/you-don-t-know-jax.html
  • [5] https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

At its core, JAX is an extensible system for transforming numerical functions. JAX implements an updated version of Autograd: its grad function takes in a scalar-valued function and returns the corresponding gradient function.
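
For instance, here is a minimal sketch (the function f below is purely illustrative and not part of the tutorial's model):

import jax.numpy as np
from jax import grad

def f(x):
  return np.sum(x ** 2)   # scalar-valued function of x

df = grad(f)              # df is itself a function that computes df/dx
print(df(3.0))            # prints 6.0, since d(x^2)/dx = 2x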

In this tutorial we will go through some of the main features included in the JAX package: grad and jit.

One of the main features you will probably use during this course is the grad function for computing gradients of your loss functions with respect to your model’s weights.

1.1 Step-by-step example on how to use jax.grad for auto-differentiation.

1) Define your function

\[y = \sigma(wx + b),\ \sigma(x)=\frac{1}{1+e^{-x}}\]
import jax.numpy as np
from jax import grad, jit, vmap, random, value_and_grad
import itertools
import time


key = random.PRNGKey(1)

x = random.normal(key)
w = random.normal(key+1)
b = random.normal(key+2)

def sigmoid(x):
    return 1/(1 + np.exp(-x))

def make_linear_sigmoid(x): 
  def predict(W, b): # Here, you define the function with the parameters that define your model
    return sigmoid(np.dot(x, W) + b)
  return predict 

2) Define your gradient with respect to the fitting weights

The function grad takes as arguments:

  • fun: the numpy function for which the computation of the gradients is needed.
  • argnums: the arguments of the function with respect to which it will be differentiated

grad_not_jit = grad(fun = make_linear_sigmoid(x),
                    argnums = (0, 1))

grad returns the gradient function; evaluating it at the current parameters, e.g. grad_not_jit(w, b), returns the evaluated gradients.


Quick quiz: what would you change in the code above if you needed to differentiate the function only with respect to $b$?
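
(One possible answer, as a sketch: pass argnums so that only the second argument, b, is differentiated.)

# differentiate make_linear_sigmoid(x) only with respect to its second argument, b
grad_wrt_b = grad(make_linear_sigmoid(x), argnums = 1)
b_gradient_only = grad_wrt_b(w, b)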

Using jit to speed up functions (VERY IMPORTANT FOR THE FUTURE)

JAX provides jit (just-in-time compiler), which takes Python (numpy-based) functions and compiles them so that they run efficiently on the chosen accelerator (CPU/GPU/TPU). Using jit can significantly speed up your computations and it requires little to no overhead in your coding. Let’s have a look at a simple example. It is sufficient to use the jit decorator in front of your function so that the declared operations are compiled in advance (only once) and the code will run much more efficiently without any interpreter overhead.
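
As a small sketch of the decorator form (the function fast_sigmoid below is purely illustrative):

from jax import jit
import jax.numpy as np

@jit                              # compiled on the first call, reused afterwards
def fast_sigmoid(x):
  return 1 / (1 + np.exp(-x))

fast_sigmoid(np.arange(5.0))      # first call triggers tracing/compilation
fast_sigmoid(np.arange(5.0))      # later calls run the compiled version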

# I am ignoring the values returned for timing because of issues with the scoping of functions (timeit issue)
grad_not_jit = grad(make_linear_sigmoid(x), (0,1))
%timeit _, _ = grad_not_jit(w,b) # AUTOGRAD

grad_jit = jit(grad(make_linear_sigmoid(x), (0,1)))
%timeit _, _ = grad_jit(w,b) # AUTOGRAD

# Check with the analytical computation
w_gradient , b_gradient = grad_jit(w,b) # AUTOGRAD

w_gradient_manual = sigmoid(np.dot(x, w) + b) * (1 - sigmoid(np.dot(x, w) + b)) * x
b_gradient_manual = sigmoid(np.dot(x, w) + b) * (1 - sigmoid(np.dot(x, w) + b)) 

print()
print('Autograd result : w')
print(w_gradient)
print('Manually derived result : w')
print(w_gradient_manual)

print('Autograd result : b')
print(b_gradient)
print('Manually derived result : b')
print(b_gradient_manual)
The slowest run took 87.19 times longer than the fastest. This could mean that an intermediate result is being cached.
1 loop, best of 3: 4.96 ms per loop
The slowest run took 513.01 times longer than the fastest. This could mean that an intermediate result is being cached.
1000 loops, best of 3: 213 µs per loop

Autograd result : w
-0.26804116
Manually derived result : w
-0.2680411
Autograd result : b
0.22633177
Manually derived result : b
0.22633173

2.0 Review of Linear Regression

As studied in your lecture, in linear regression we are given a data set $\mathcal{D} = \{(x_n,t_n)\}_{n=1}^N$, where $x_n \in \mathbb{R}^d$, with $d \geq 1$ the dimension of your data, and $t_n \in \mathbb{R}$ the target value that each data point $x_n$ corresponds to.

Next, we seek a prediction function $y$ that takes in $x_n$ and outputs a prediction, \(y(x) = \textstyle\sum\limits_{i = 1}^d w_i x_i + b.\)

Let the squared loss be defined as, \begin{equation} \mathcal{L}(x,t) = \dfrac{1}{2}(y(x) - t)^2, \quad t \in \mathbb{R} \end{equation} We can then find the optimal prediction function by minimizing the mean squared error, \begin{equation} \mathcal{E}(w) = \dfrac{1}{N}\sum\limits_{n=1}^N \mathcal{L}(x_n, t_n) = \dfrac{1}{2N}\sum\limits_{n=1}^N (w^\top x_n + b - t_n)^2 \tag*{(mse)} \end{equation}

In the following code, you will generate a data set $\mathcal{D}$ consisting of $20$ 1D data points $x_n$ sampled from a uniform distribution along with targets $t_n$ assumed to be given by the formula,

\(t_n = \sin(\pi x_n) + 0.3\,\epsilon, \quad \epsilon \sim \mathcal{N}(0,1).\) Let us visualize these data points using the Matplotlib package.

n, d = 20, 1 #n represents data points and d represents the number of dimensions

x = random.uniform(key, (n, d), dtype=np.float64,  minval = -5., maxval = 5.)    
t = np.sin(np.pi*x) + 0.3*random.normal(key, (n, d))    

import matplotlib.pyplot as plt

plt.scatter(x, t)
plt.show()

[Figure: scatter plot of the generated data points $(x_n, t_n)$]

2.1 Directly Solving for $w^\star$ Using the Normal Equation

Since our data set is small, we can directly find the optimal predictor using a bit of linear algebra.

To account for the bias term, it is common to redefine our data and weights as $x_n := \begin{bmatrix} 1 & x_n^\top \end{bmatrix}$ and $w := \begin{bmatrix} b & w^\top \end{bmatrix}$, so that \(y(x_n) = \textstyle\sum\limits_{i = 1}^{d+1} w_i x_{n_i} = w^\top x_n.\)

Then, we can show that our mean squared error can be equivalently written as \(\mathcal{E}(w) = \dfrac{1}{2N} \|Xw-t\|^2_2,\) where \begin{equation} X = \begin{bmatrix} x_1^\top \\ \vdots \\ x_N^\top \end{bmatrix} \in \mathbb{R}^{N \times (d+1)} \quad \text{and} \quad t = \begin{bmatrix} t_1 \\ \vdots \\ t_N \end{bmatrix} \in \mathbb{R}^N \end{equation}

It can be shown that, \(w^\star = (X^\top X)^{-1}X^\top t\) minimizes the mean-squared error $\mathcal{E}$. This expression is sometimes referred to as the normal equation.
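
As a quick sanity check on where the normal equation comes from: setting the gradient of $\mathcal{E}$ to zero gives

\[\nabla_w \mathcal{E}(w) = \dfrac{1}{N} X^\top (Xw - t) = 0 \;\Longrightarrow\; X^\top X w = X^\top t \;\Longrightarrow\; w^\star = (X^\top X)^{-1}X^\top t,\]

assuming $X^\top X$ is invertible.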

In the space below, build the matrix $X$ and find the optimal solution $w^\star$. Afterwards, plot the optimal prediction function against your dataset.

#Build the X matrix here
X = np.concatenate((np.ones((n,1)),x),axis=1)
Xt = np.transpose(X)

#Compute w using the normal equation
w = np.dot(np.dot(np.linalg.inv(np.dot(Xt, X)), Xt),t)

#Write your optimal predictor here
y = w[0] + w[1]*x

#Plot your prediction against data points
plt.scatter(x, t)
plt.plot(x,y)
plt.show()

print("bias is", w[0],"and the weight is", w[1])

[Figure: data points with the linear fit from the normal equation]

bias is [0.01311082] and the weight is [0.20701133]

2.2 Find $w^\star$ Using Gradient Descent

Next, you will use gradient descent to find the solution to the linear regression problem. Recall that the update equation for gradient descent (as studied in class) is given by,

\[w_{k+1} = w_k - \alpha \nabla \mathcal{E}(w_k)\]

where $\alpha$ is our learning rate, which is usually a small number, e.g., $0.1$.

Implement a jax routine for calculating the gradient of $\mathcal{E}$, and then run the gradient descent algorithm (say, for 100 iterations) until the optimal weight is found.
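
For reference, the gradients that jax.grad should reproduce here are

\[\nabla_w \mathcal{E} = \dfrac{1}{N}\sum\limits_{n=1}^N (w^\top x_n + b - t_n)\, x_n, \qquad \dfrac{\partial \mathcal{E}}{\partial b} = \dfrac{1}{N}\sum\limits_{n=1}^N (w^\top x_n + b - t_n).\]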

2.2.1 Define the mse function

Hint: your code could be similar to how we built the logistic function at the beginning

# Build the mse function here
def make_mse(x, t):  
  def mse(w,b): 
    return np.sum(np.power(x.dot(w) + b - t, 2))/(2*n)
  return mse 
# You can build this function in one line as well, e.g., def make_mse(x,t,w,b)

2.2.2 Implement Gradient Descent and Run for (say, 100) Iterations

\[w_{k+1} = w_k - \alpha \nabla \mathcal{E}(w_k)\]

Hint: The function grad takes as arguments:

  • fun: the numpy function for which the computation of the gradients is needed.
  • argnums: the arguments of the function with respect to which it will be differentiated

Example:

w_gradient, b_gradient = grad(fun = make_mse(x, t), argnums = (0,1))(w,b) 

Then use the calculated gradient in your gradient descent update inside a for loop.

#initialize w, alpha, and number of iterations
w, b = np.zeros(()), np.zeros(())
alpha = 0.1
iterations = 100

# Update w, b by calculating the gradient and making the update.
for _ in range(iterations):
  w_gradient, b_gradient = grad(make_mse(x, t), (0,1))(w,b)
  w, b = w - alpha * w_gradient, b - alpha * b_gradient  

print(w,b)
0.20700936 0.013098102

2.2.3 Build your optimal linear prediction function and visualize your fit to the data

# Your optimal linear prediction function here
y = b + w*x

# Visualize your data and your fit
plt.scatter(x, t)
plt.plot(x,y)
plt.show()

[Figure: data points with the gradient-descent linear fit]

2.2.4 (Optional) You can use jit to speed up the calculation.

Repeat the gradient descent procedure as before but using jit to speed up the calculation of gradients.

# Build your jitted gradient descent iteration here

def make_mse(x, t):  
  def mse(w,b): 
    return np.sum(np.power(x.dot(w) + b - t, 2))/(2*n)
  return mse 

w, b = np.zeros(()), np.zeros(())
epsilon = 0.1
iterations = 100

#JIT version for computing the gradients: compile once, outside the loop,
#so the compiled function is reused on every iteration
mse_grad_jit = jit(grad(make_mse(x, t), (0,1)))

for _ in range(iterations):
  w_gradient, b_gradient = mse_grad_jit(w, b)
  w, b = w - epsilon * w_gradient, b - epsilon * b_gradient 

y = b + w*x

plt.scatter(x, t)
plt.plot(x,y)
plt.show()

[Figure: data points with the jit-accelerated gradient-descent fit]

3.0 Overfitting and Regularization

The goal of supervised machine learning is to make accurate predictions on data that are not from the training set, also known as the test set. However, it is not guaranteed that fitting the training data better will lead to better performance on the test set.

It could happen that we fit our training set too well, so that a test point that does not look like the training data causes a large error in our prediction. This is called overfitting: the scenario in which a predictor that achieves a low error on the training set nevertheless incurs a large error on the test set.

To overcome overfitting, we can use a technique called regularization, which involves adding a penalty term to our mean squared error. In the following sections, we will step through several regularization techniques.

3.1 Building a Dataset

The code below creates a data set of $n = 100$ examples generated uniformly from $[-5, 5]$; each example is $5$ dimensional. The targets are assigned as follows, \(t = x^\top \overline w + 0.1\,\epsilon, \quad \epsilon \sim \mathcal{N}(0,1),\) where $\overline w$ is a randomly generated vector drawn from a standard normal distribution.

Your first task is to partition the data into 50 training examples, 20 validation examples and 30 test examples.

d = 5                             #dimension of your data
n = 100                           #total number of data points
n_train = 50                      #total number of training examples
n_validate = 20                   #total number of examples for validation 
n_test = 30                       #total number of examples for testing

x_data = random.uniform(key, (n, d), dtype=np.float64,  minval = -5., maxval = 5.)    
true_w = random.normal(key, (d,))
t_data = x_data.dot(true_w) + 0.1 * random.normal(key, (n,))

#Partition into training, validation and test set
x_train = x_data[0:n_train]                 #can also use x_data[:n_train], without the zero.
x_val = x_data[n_train:n_train+n_validate]
x_test = x_data[n_train + n_validate:n]     #can also use x_data[n_train + n_validate:], without the n

t_train = t_data[0:n_train]
t_val = t_data[n_train:n_train+n_validate]
t_test = t_data[n_train + n_validate:n]

3.2 Solve for $w^\star$ Using Gradient Descent

Next, using the same linear regression code that you built above, use gradient descent to compute the optimal predictor $y = x^\top w^\star + b^\star$ corresponding to the mean-squared error \(\mathcal{E}(w) = \dfrac{1}{2N}\sum\limits_{n=1}^N (w^\top x_n + b - t_n)^2\)

Afterwards, find the error between the predicted targets on the test set versus the true test targets using the formula, \(\|y_\text{test}- t_\text{test}\|_2\) where $y_\text{test}$ is your prediction on the test set.

#Define your mse function here 
def make_mse_no_reg(x, t):  
  def mse_no_reg(w,b): 
    return np.sum(np.power(x.dot(w) + b - t, 2))/(2*n_train) 
  return mse_no_reg

#initialize variables
w, b = np.zeros((d, )), np.zeros((1, ))
alpha = 0.1
iterations = 100

#Perform gradient descent

for _ in range(iterations):
  w_gradient, b_gradient = grad(make_mse_no_reg(x_train, t_train), (0,1))(w,b)
  w = w - alpha * w_gradient
  b = b - alpha * b_gradient

#Compute and print your test error
y_test = b*np.ones((n_test,)) + x_test.dot(w)

print("Error for base case (no regularization):", np.linalg.norm(y_test - t_test, 2))

Error for base case (no regularization): 0.62145925

3.3 $l_2$ regularization

The first regularization technique we will introduce is called $l_2$ regularization.

We define our mean-squared loss function as \(\mathcal{E}_{l_2}(w) = \dfrac{1}{2N}\sum\limits_{n=1}^N (w^\top x_n + b - t_n)^2 + \dfrac{\lambda_2}{2}\|w\|^2_2\) where $\lambda_2$ is a small positive constant we refer to as the regularization constant, e.g., $\lambda_2 = 0.01$.

This optimization problem is called ridge regression.

Make slight changes to your previous code to account for the $l_2$ norm, and using gradient descent, compute the optimal predictor $y = x^\top w^\star + b^\star$ corresponding to this regularized mean-squared error.
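
If you want to check the autograd result by hand, the $l_2$ penalty in the formula above simply adds a $\lambda_2 w$ term to the unregularized gradient,

\[\nabla_w \mathcal{E}_{l_2}(w) = \dfrac{1}{N}\sum\limits_{n=1}^N (w^\top x_n + b - t_n)\, x_n + \lambda_2 w.\]

(The solution code below writes the penalty as $\lambda_2\|w\|_2^2$ without the factor of $\frac{1}{2}$, which simply rescales $\lambda_2$ by a factor of two.)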

# regularization constant for l2 regularized mse 
lambda_2 = 0.01

def make_mse_l2(x, t):  
  def mse_l2(w,b): 
    return np.sum(np.power(x.dot(w) + b - t, 2))/(2*n_train) + lambda_2*np.square(np.linalg.norm(w, 2))
  return mse_l2

#initialize w, alpha, iterations
w, b = np.zeros((d, )), np.zeros((1, ))
alpha = 0.1
iterations = 100

#this calculates the gradient w.r.t. w and b at the data points (x, t)

for _ in range(iterations):
  w_gradient, b_gradient = grad(make_mse_l2(x_train, t_train), (0,1))(w,b)
  w, b = w - alpha * w_gradient, b - alpha * b_gradient 

3.3.1 Hyperparameter Tuning Using Validation Set

Next, you are going to tune the regularization parameter $\lambda_2$ using your validation set.

Train several different models for $\lambda_2 \in \{0.1, 0.01, 0.001\}$. Then compute the error between your prediction on the validation set and the corresponding targets using the formula, \(\|y_\text{val}- t_\text{val}\|_2\)

Pick the best performing model.

#Your code Here
def make_mse_l2(x, t, lambda_2):  
  def mse_l2(w,b): 
    return np.sum(np.power(x.dot(w) + b - t, 2))/(2*n_train) + lambda_2*np.square(np.linalg.norm(w, 2))
  return mse_l2

w, b = np.zeros((d, )), np.zeros((1, ))
alpha = 0.1
iterations = 100

#create two lists to store the weights and biases
w_list = []
b_list = []

lambda_set = np.array([0.1, 0.01, 0.001])

#compute the weights corresponding to several regularization constants and save them in a list
for i in lambda_set:
  for _ in range(iterations):
    w_gradient, b_gradient = grad(make_mse_l2(x_train, t_train, i), (0,1))(w,b)
    w, b = w - alpha * w_gradient, b - alpha * b_gradient 
  w_list.append(w)
  b_list.append(b)
  y_val = b*np.ones((n_validate,)) + x_val.dot(w)
  print(np.linalg.norm(y_val - t_val, 2))
0.6131907
0.42614818
0.4301212

3.3.2 Compute Test Error

Finally, using the best performing model, find the error between the predicted targets on the test set versus the true test targets using the formula, \(\|y_\text{test}- t_\text{test}\|_2\) where $y_\text{test}$ is your prediction on the test set.

# Your code Here

# from validation, we see that the weights corresponding to lambda = 0.01 
# performs the best on the validation set

y_test = b_list[1]*np.ones((n_test,)) + x_test.dot(w_list[1])

print("The error or l2 regularization:", np.linalg.norm(y_test - t_test, 2))
The error or l2 regularization: 0.62681055

3.4 $l_1$ Regularization

The next regularization technique we will introduce is called $l_1$ regularization.

We define our mean-squared loss function as \(\mathcal{E}_{l_1}(w) = \dfrac{1}{2N}\sum\limits_{n=1}^N (w^\top x_n + b - t_n)^2 + \dfrac{\lambda_1}{2}\|w\|_1\) where $\lambda_1$ is again our regularization constant.

This optimization problem is called the LASSO.
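
For reference, wherever $w_i \neq 0$ the penalty is differentiable and contributes a constant-magnitude term to the gradient,

\[\dfrac{\partial}{\partial w_i}\left(\dfrac{\lambda_1}{2}\|w\|_1\right) = \dfrac{\lambda_1}{2}\,\mathrm{sign}(w_i), \qquad w_i \neq 0,\]

which, unlike the $l_2$ penalty, does not shrink as $w_i$ approaches zero; this is what pushes LASSO solutions towards sparsity.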

Repeat the previous experiment by training several different models and choosing the best one based on its performance on the validation set.

Finally, pick the best performing model and use it to find the error between the predicted targets on the test set versus the true test targets.

# Your code here
def make_mse_l1(x, t, lambda_1):  
  def mse_l1(w,b): 
    return np.sum(np.power(x.dot(w) + b - t, 2))/(2*n) + lambda_1*np.linalg.norm(w, 1)
  return mse_l1

w, b = np.zeros((d, )), np.zeros((1, ))
alpha = 0.1
iterations = 100

w_list = []
b_list = []

#this calculates the gradient w.r.t. w and b at the data points (x, t)

lambda_set = np.array([0.1, 0.01, 0.001])
for i in lambda_set:
  for _ in range(iterations):
    w_gradient, b_gradient = grad(make_mse_l1(x_train, t_train, i), (0,1))(w,b)
    w, b = w - alpha * w_gradient, b - alpha * b_gradient 
  w_list.append(w)
  b_list.append(b)
  y = b*np.ones((n_validate,)) + x_val.dot(w)
  print(np.linalg.norm(y - t_val, 2))

y_test = b_list[2]*np.ones((n_test,)) + x_test.dot(w_list[2])

print("Error for l2 regularization:", np.linalg.norm(y_test - t_test, 2))
1.1694859
0.43717736
0.43021026
Error for l1 regularization: 0.62039083

3.5 Approximate $l_1$ Regularization

A technicality is that the $l_1$ norm is not differentiable at zero, so JAX was actually working around the non-differentiable points.

To address the non-differentiability of $l_1$ norm, Schmidt et al. [1] proposed a differentiable approximation to the $l_1$ norm.

In this case, we define our mean-squared loss function as \(\mathcal{E}_{\widetilde{l_1}}(w) = \dfrac{1}{2N}\sum\limits_{n=1}^N (w^\top x_n + b - t_n)^2 + \sum\limits_{i = 1}^d \dfrac{\lambda}{a} \left(\log(1+\exp(a w_i)) + \log(1+\exp(-a w_i))\right)\) where both $a$ and $\lambda$ are regularization constants.
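
A quick way to see why this is a reasonable surrogate: for each coordinate,

\[\dfrac{1}{a}\left(\log(1+\exp(a w_i)) + \log(1+\exp(-a w_i))\right) \;\longrightarrow\; |w_i| \quad \text{as } a \to \infty,\]

while the expression remains smooth for any finite $a > 0$, so jax.grad can differentiate it everywhere.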

Repeat the previous experiments by training several different models and choosing the best one based on its performance on the validation set.

(For convenience, you can fix $a$, say $a = 0.1$. Alternatively you can tune both $a$ and $\lambda$)

Finally, pick the best performing model and use it to find the error between the predicted targets on the test set versus the true test targets.

[1] Mark Schmidt, Glenn Fung, and Rómer Rosales. Fast optimization methods for l1 regularization: A comparative study and two new approaches. In European Conference on Machine Learning, pages 286–297. Springer, 2007.

# Your code here
a = 0.1

def make_mse_approx_l1(x, t, a, lambda_approx_l1):  
  def mse_approx_l1(w,b): 
    return np.sum(np.power(x.dot(w) + b - t, 2))/(2*n_train) + lambda_approx_l1/a*(np.log(1+np.exp(a*w)) + np.log(1+np.exp(-a*w))).sum()
  return mse_approx_l1

w, b = np.zeros((d, )), np.zeros((1, ))
alpha = 0.1
iterations = 100

w_list = []
b_list = []

#this calculates the gradient w.r.t. w and b at the data points (x, t)

lambda_set = np.array([0.1, 0.01, 0.001])
for i in lambda_set:
  for _ in range(iterations):
    w_gradient, b_gradient = grad(make_mse_approx_l1(x_train, t_train, a, i), (0,1))(w,b)
    w, b = w - alpha * w_gradient, b - alpha * b_gradient 
  w_list.append(w)
  b_list.append(b)
  y = b*np.ones((n_validate,)) + x_val.dot(w)
  print(np.linalg.norm(y - t_val, 2))

y_test = b_list[0]*np.ones((n_test,)) + x_test.dot(w_list[0])

print("Error for l2 regularization:", np.linalg.norm(y_test - t_test, 2))
0.42910042
0.4306859
0.43086198
Error for approximate l1 regularization: 0.6223559

3.6 Elastic Net

Finally, we can define our mean-squared loss function as a combination of the $l_1$ and $l_2$ norms, \(\mathcal{E}_{l_1+l_2}(w) = \dfrac{1}{2N}\sum\limits_{n=1}^N (w^\top x_n + b - t_n)^2 + \frac{r_1}{2}\|w\|_1 + \dfrac{r_2}{2}\|w\|_2^2\)

This regularizer is called the elastic net.

Repeat the previous experiments by training several different models and choosing the best one based on its performance on the validation set.

Finally, pick the best performing model and use it to find the error between the predicted targets on the test set versus the true test targets.

# Your code here: 

def make_mse_elastic(x, t, r_l1, r_l2):  
  def mse_elastic(w,b): 
    return np.sum(np.power(x.dot(w) + b - t, 2))/(2*n) + r_l1*np.linalg.norm(w, 1) + r_l2*np.square(np.linalg.norm(w, 2)) 
  return mse_elastic

w, b = np.zeros((d, )), np.zeros((1, ))
alpha = 0.1
iterations = 100

w_list = []
b_list = []

#this calculates the gradient w.r.t. w and b at the data points (x, t)

lambda_set = np.array([0.1, 0.01, 0.001])
for i in lambda_set:
  for _ in range(iterations):
    w_gradient, b_gradient = grad(make_mse_elastic(x_train, t_train, i, i), (0,1))(w,b)
    w, b = w - alpha * w_gradient, b - alpha * b_gradient 
  w_list.append(w)
  b_list.append(b)
  y = b*np.ones((n_validate,)) + x_val.dot(w)
  print(np.linalg.norm(y - t_val, 2))

y_test = b_list[2]*np.ones((n_test,)) + x_test.dot(w_list[2])

print("Error for l2 regularization:", np.linalg.norm(y_test - t_test, 2))

1.8915151
0.45457628
0.42896822
Error for the elastic net: 0.62130165