머신러닝&딥러닝/모두를 위한 딥러닝

[머신러닝 실습] Gradient Descent(경사하강법)

Chaerry._o 2023. 9. 7. 19:15
반응형

이 글은 PyTorchZeroToAll을 기반으로 작성한 글입니다.

 

 

[머신러닝 이론] Gradient descent(경사하강법)

이 글은 모두를 위한 딥러닝 시즌1을 기반으로 작성한 글입니다. [머신러닝 이론] Linear Regression(선형 회귀) 이 글은 모두의 딥러닝 시즌1을 기반으로 작성한 글입니다. [머신러닝 이론] Machine Learni

chaerrylog.tistory.com

위 글에서 경사하강법에 대해서 설명했다. 이 내용을 토대로 파이썬을 사용해 실습을 진행한다.


Gradient Descent

 

이전 실습에서 loss function에 weight 값을 다양하게 대입했을 때 아래로 볼록한 그래프가 나오는 걸 확인했다.

이러한 그래프는 경사하강법을 사용해서 loss가 최소가 되는 weight 값을 구할 수 있다.

 

x_data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
y_data = [1, 1, 2, 4, 5, 7, 8, 9, 9, 10]

w = 1.0

def forward(x):
    return x * w

def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) * (y_pred - y)

print("Prediction (before training): ", forward(4)) # Prediction (before training):  4.0

이 부분에 대해서는 이전 실습 글에서 자세하게 다뤘으므로 간단히 넘어가겠다.

 

def gradient(x, y):  # d_loss/d_w
    return 2 * x * (x * w - y)

이론 글에서 Gradient 값이

$$ 2x(x\omega-y) $$

라는 것을 확인했다. 따라서 위에 수식을 리턴하는 gradient(x, y) 함수를 만든다.

 

for epoch in range(10):
    for x_val, y_val in zip(x_data, y_data):
        # Compute derivative w.r.t to the learned weights
        # Update the weights
        # Compute the loss and print progress
        grad = gradient(x_val, y_val)
        w = w - 0.01 * grad
        print("\tgrad: ", x_val, y_val, round(grad, 2))
        l = loss(x_val, y_val)
    print("progress:", epoch, "w=", round(w, 2), "loss=", round(l, 2))

Gradient descent 알고리즘을 10번 학습시킨다.

 

x_data와 y_data의 요소를 각각 x_val과 y_val에 넣는다.

grad에는 gradient() 함수를 사용해 기울기 값을 저장한다.

w에 w에서 learning rate(0.01)와 기울기를 곱한 값을 저장한다. 이 코드를 통해 weight를 업데이트한다.

l에는 loss() 함수를 사용해 loss 값을 저장한다.

 

print() 함수를 사용해 grad:, x값, y값, grad에 저장된 기울기 값을 출력한다.

print() 함수를 사용해 progress:, 훈련한 횟수, w= 업데이트 된 weight 값, loss= loss 값을 출력한다.

 

훈련 횟수에 따라 weight 값이 바뀌면서 기울기 값과 loss 값이 바뀌는 것을 확인할 수 있다.

 

print("Predicted score (after training): ", forward(4)) # Predicted score (after training): 7.804863933862125

훈련한 후에 weight 값이 1.95가 됐을 때 선형회귀 모델에 x=4를 대입했더니 예측값이 7.804863933862125이 나왔다.

반응형