[ML] tf.GradientTape, Regression

2 minute read

TensorFlow tf.GradientTape, Regression

tensorflow-v2.7.0


  • tf.GradientTape: Forward Propagation의 연산값을 저장해주는 API
    연산값들에 대한 정보가 있어야 Back Propagation 과정의 Chain Rule을 적용할 수 있다.
import tensorflow as tf


t1 = tf.Variable([1, 2, 3], dtype=tf.float32)
t2 = tf.Variable([10, 20, 30], dtype=tf.float32)


with tf.GradientTape() as tape:
    t3 = t1 * t2
    
print('t1: ', t1.numpy())
print('t2: ', t2.numpy())
print('t3: ', t3.numpy())
t1:  [1. 2. 3.]
t2:  [10. 20. 30.]
t3:  [10. 40. 90.]

위의 예시코드는 Forward Propagation 연산값(t3 = t1 * t2)을 GradientTape로 저장한 것이다.
tape 객체를 통해 다음을 구할 수 있다.

  • t3에 대한 t1의 미분계수(기울기, 변화량)
  • t3에 대한 t2의 미분계수(기울기, 변화량)
gradients = tape.gradient(t3, [t1, t2])

t1_differential_coefficient = gradients[0]
t2_differential_coefficient = gradients[1]

print('t3에 대한 t1의 미분계수: ', t1_differential_coefficient.numpy())
print('t3에 대한 t2의 미분계수: ', t2_differential_coefficient.numpy())
t3에 대한 t1의 미분계수:  [10. 20. 30.]
t3에 대한 t2의 미분계수:  [1. 2. 3.]


tape.gradient는 한번만 사용 가능하다.
두번 이상 해당 API를 call할 경우, 다음과 같은 에러가 발생한다.

RuntimeError: A non-persistent GradientTape can only be used to compute one set of gradients (or jacobians)



  • tf.Variable이 아닌 tf.constant의 미분계수를 구하려고 할 경우, None이 반환된다.
    tf.constant는 업데이트가 필요한 Tensor가 아니기 때문에 Back Propagation을 진행하지 않기 때문이다.
import tensorflow as tf


t1 = tf.constant([1, 2, 3], dtype=tf.float32)
t2 = tf.Variable([10, 20, 30], dtype=tf.float32)


with tf.GradientTape() as tape:
    t3 = t1 * t2
    
print('t1: ', t1.numpy())
print('t2: ', t2.numpy())
print('t3: ', t3.numpy())
t1:  [1. 2. 3.]
t2:  [10. 20. 30.]
t3:  [10. 40. 90.]
gradients = tape.gradient(t3, [t1, t2])

t1_differential_coefficient = gradients[0]  # dt1
t2_differential_coefficient = gradients[1]  # dt2

print('t3에 대한 t1의 미분계수(dt1): ', t1_differential_coefficient)
print('t3에 대한 t2의 미분계수(dt2): ', t2_differential_coefficient.numpy())
t3에 대한 t1의 미분계수:  None
t3에 대한 t2의 미분계수:  [1. 2. 3.]



tf.GradientTape Example

import tensorflow as tf


t1 = tf.Variable([1, 2, 3], dtype=tf.float32)
t2 = tf.Variable([10, 20, 30], dtype=tf.float32)

with tf.GradientTape() as tape:
    t3 = t1 * t2
    t4 = t3 + t2
    
gradients = tape.gradient(t4, [t1, t2, t3])
dt1, dt2, dt3 = gradients[0], gradients[1], gradients[2]

print('t1: ', t1.numpy())
print('t2: ', t1.numpy())
print('t3: ', t1.numpy())
print('t4: ', t1.numpy())
print()
print('dt1: ', dt1.numpy())
print('dt2: ', dt2.numpy())
print('dt3: ', dt3.numpy())
t1:  [1. 2. 3.]
t2:  [1. 2. 3.]
t3:  [1. 2. 3.]
t4:  [1. 2. 3.]

dt1:  [10. 20. 30.]
dt2:  [2. 3. 4.]
dt3:  [1. 1. 1.]



Regression with tf.GradientTape

import tensorflow as tf
import matplotlib.pyplot as plt


# Save some gpu memories
physical_devices = tf.config.list_physical_devices('GPU')
for physical_device in physical_devices:
    tf.config.experimental.set_memory_growth(device=physical_device, enable=True)
    
    
w_target, b_target = 5, 3
x_data = tf.random.normal(shape=(1000,), dtype=tf.float32)
y_data = w_target * x_data + b_target

# Initialize weights
w = tf.Variable(0, dtype=tf.float32)
b = tf.Variable(0, dtype=tf.float32)

# Loss function
def loss_func(target, pred):
    return (target - pred) ** 2
# Training Hyper Parameters
EPOCHS = 2
LR = 1e-2

# Logging weights for Visualization
w_trace, b_trace = [w.numpy()], [b.numpy()]

# Training code
for _ in range(EPOCHS):
    for x, y in zip(x_data, y_data):
        with tf.GradientTape() as tape:
            pred = w*x + b
            loss = loss_func(target=y, pred=pred)
        
        gradients = tape.gradient(loss, [w, b])
        
        # Update weights
        w.assign_sub(LR * gradients[0])
        b.assign_sub(LR * gradients[1])
        
        # Logging values for visualization
        w_trace.append(w.numpy())
        b_trace.append(b.numpy())
        
# Visualization
fig, ax = plt.subplots(figsize=(20, 10))
ax.plot(w_trace, label='Weight')
ax.plot(b_trace, label='Bias')

# Visualization Parameters
ax.tick_params(labelsize=15)
ax.legend(fontsize=20)
ax.set_title(f'Weight: {w.numpy():.2f}, Bias: {b.numpy():.2f}')



Reference

Leave a comment