梯度下降(gradient descent)是一個求取區域最小值(local minimum)的演算法,因此也是機器學習的基本技巧之一。本篇的目標,是介紹梯度下降的基本概念與實作,以讓你對機器學習中的數學方法有簡單的認識。

梯度的意思是函數對各個變數偏微分所組成的向量,會指向讓函數值上升最快的方向,可以表示函數在不同變數上的變化率;而利用梯度下降法求極值(以最小值為例,下同),則相當於往梯度的反方向走,亦可以想像成從山坡上慢慢一步一步地走到谷底。對於每一步來說,我們需要知道目前梯度的反方向,以及決定這一步該走多長,通常以梯度的倍率表示,稱為學習率(learning rate),寫成算式則為如下:

θt+1 = θt - η∇f(θt)

其中的 θt 是目前所在位置的函數參數,η 是學習率,θt+1 是走了新的一步以後的所在位置的函數輸入。通常,在有極值存在的狀況下,由於電腦表示小數的精確度限制等因素,我們可以在走了夠多步,或者兩步之間的輸入(下山時的水平方向)或輸出(下山時的高度方向)位置變化不大的情況下,將計算出的結果當作極值。

我們來看看如何利用上述的方法,來求函數極值。因為多項式的極值求取相對簡單,因此我們以相對比較不容易手算極值的 f(x) = ex + x2 來做示範:

import matplotlib.pyplot as plt
import numpy as np


def f(x):
	return np.exp(x) + x ** 2


def df(x):
	return np.exp(x) + x * 2


x = np.linspace(-4, 4, 1000)
y = f(x)
plt.plot(x, y)

x = -3
lr = 0.01
for i in range(500):
	x_new = x - lr * df(x)
	y_new = f(x_new)
	if i % 50 == 0:
		plt.plot(x_new, y_new, '.')
		plt.text(x_new, y_new, f'{i+1}')
	x = x_new

plt.show()

在上述範例中,除非你使用比較先進的函式庫,例如 PyTorch 等,否則函數的微分仍然需要自己計算。此外,你可以試著改變起始點和學習率,看看效果是否有所不同。

多變數的狀況也非常類似,以下用 f(x, y) = ex+y + x2 - y2 來做示範:

import matplotlib.pyplot as plt
import numpy as np


def f(x, y):
	return np.exp(x+y) + x ** 2 - y ** 2


def dfx(x, y):
	return np.exp(x+y) + x * 2


def dfy(x, y):
	return np.exp(x+y) - y * 2


x = np.linspace(-3, 3, 100)
y = np.linspace(-3, 3, 100)
z = np.array([[f(c, r) for c in x] for r in y])
print(z.shape)

plt.contour(x, y, z, levels=100)
plt.colorbar()

x = 3
y = 3
lr = 0.001
for i in range(1000):
	x_new = x - lr * dfx(x, y)
	y_new = y - lr * dfy(x, y)
	if i % 50 == 0:
		plt.plot(x_new, y_new, 'x')
		plt.text(x_new, y_new, f'{i+1}')
	x = x_new
	y = y_new

plt.show()

在上述範例中,由於初始位置的梯度非常大,在一步跨太遠的情況下,很容易造成結果發散,因此我們透過降低學習率的方式,來讓每一步不要跨太遠,使得結果可以呈現收斂。

機器學習中,我們會希望模型的輸出與標準答案的誤差盡可能的小,也就是對某個可以評估誤差的函數(loss function)求取極值。該函數的輸入分成兩部分,一部分是模型的輸出,另一部分是標準答案,而我們能變更的只有模型如何輸出,因此在計算梯度時,只會對模型輸出相關的那一部分作偏微分。舉例來說,假設你看到二維平面上有一堆點,它們的位置分布長的像是下面的 function,只是你不知道實際參數的話:

ŷ = f(x) = e1x + θ2x + θ3

而若以 y 軸方向上的差異的平方當作 loss function,則模型輸出 ŷ 與標準答案 y 之間的差異,可以寫成:

ℒ(ŷ, y) = (ŷ - y)2 = (e1x + θ2x + θ3 - y)2

上列算式對 θ1、θ2、θ3 的偏微分分別是:

∂ℒ(ŷ, y) / ∂θ1 = 2(ŷ - y)e1x(-x),

∂ℒ(ŷ, y) / ∂θ2 = 2(ŷ - y)x,以及

∂ℒ(ŷ, y) / ∂θ3 = 2(ŷ - y)

以下範例會亂數產生一堆分布像是 ŷ = f(x) = e1x + θ2x + θ3 的點,並在假設不知道原本 function 參數的情況下,用梯度下降法試圖推測出參數:

import matplotlib.pyplot as plt
import numpy as np


def func(x, a, b, c):
	return np.exp((-a) * x) + b * x + c


def loss(y_hat, y):
	return np.mean((y_hat - y) ** 2)


def d_loss_a(x, y_hat, y, a, b, c):
	return np.mean(2 * (y_hat - y) * np.exp((-a)*x) * (-x))


def d_loss_b(x, y_hat, y, a, b, c):
	return np.mean(2 * (y_hat - y)* x)


def d_loss_c(x, y_hat, y, a, b, c):
	return np.mean(2 * (y_hat - y))


points = []
for _ in range(100):
	x = np.random.rand(1) * 6 - 3
	y = func(x, 1.2, 3.4, 5.6) + np.random.randn(1) * 0.01
	points.append([x, y])
points = np.array(points)

plt.plot(points[:, 0], points[:, 1], '.', label='Data')

a = 1
b = 2
c = 3
lr = 0.001
for i in range(500):
	x = points[:, 0]
	y = points[:, 1]
	y_hat = func(x, a, b, c)
	a_new = a - lr * d_loss_a(x, y_hat, y, a, b, c)
	b_new = b - lr * d_loss_b(x, y_hat, y, a, b, c)
	c_new = c - lr * d_loss_c(x, y_hat, y, a, b, c)
	a = a_new
	b = b_new
	c = c_new
print(a, b, c)

x = np.linspace(-3, 3, 1000)
y = func(x, a, b, c)
plt.plot(x, y, label='Pred')

plt.show()

在上述範例中,我們選擇了離實際位置相當近的點做為起始。你可以試著使用其他起始位置,但是有可能會讓計算結果跟實際位置相差比較大。此外,在範例中是看過所有資料以後再計算梯度,但實務上若遇到資料集非常龐大時,你也可以每次只取一小部分計算梯度,稱為 mini-batch gradient descent,實作範例如下:

import matplotlib.pyplot as plt
import numpy as np


def func(x, a, b, c):
	return np.exp((-a) * x) + b * x + c


def loss(y_hat, y):
	return np.mean((y_hat - y) ** 2)


def d_loss_a(x, y_hat, y, a, b, c):
	return np.mean(2 * (y_hat - y) * np.exp((-a)*x) * (-x))


def d_loss_b(x, y_hat, y, a, b, c):
	return np.mean(2 * (y_hat - y)* x)


def d_loss_c(x, y_hat, y, a, b, c):
	return np.mean(2 * (y_hat - y))


points = []
for _ in range(100):
	x = np.random.rand(1) * 6 - 3
	y = func(x, 1.2, 3.4, 5.6) + np.random.randn(1) * 0.01
	points.append([x, y])
points = np.array(points)

plt.plot(points[:, 0], points[:, 1], '.', label='Data')

a = 1
b = 3
c = 5
lr = 0.001
for i in range(500):
	batch_size = 10
	for j in range(0, points.shape[0], batch_size):
		x = points[j:j+batch_size, 0]
		y = points[j:j+batch_size, 1]
		y_hat = func(x, a, b, c)
		a_new = a - lr * d_loss_a(x, y_hat, y, a, b, c)
		b_new = b - lr * d_loss_b(x, y_hat, y, a, b, c)
		c_new = c - lr * d_loss_c(x, y_hat, y, a, b, c)
		a = a_new
		b = b_new
		c = c_new
print(a, b, c)

x = np.linspace(-3, 3, 1000)
y = func(x, a, b, c)
plt.plot(x, y, label='Pred')

plt.show()

Mini-batch 的技巧,在訓練類神經網路時會經常被使用,並且通常會將資料點隨機打亂順序,以提升梯度下降時的多樣性。