TensorFlow 之 Automatic differentiation and gradient tape

Automatic differentiation and gradient tape

​ 之前我们介绍了Tensor 以及在其上的操作,下面我们介绍一下自动微分技术,---用来优化模型参数的关键。

​ tensorflow 提供了用于自动微分的API,来计算一个函数的导数。一种更接近数学的求导方法是:先写一个python函数,封装好对参数的运算。然后使用tf.contrib.eager.gradients_function 来创建一个函数计算上面封装好的函数的导函数(可指定对哪个参数求导)。同时,只要嵌套调用该函数,即可求高阶导。



Gradient tapes

什么是梯度带呢?对于每一个可导的TF函数操作,其都有对应的导函数。

对于上面提到的用户封装的一个函数,TF都会记录下这样一系列的操作,称之为tape

然后根据tape和每一个primitive操作(即上面提高的TF函数),就可以使用reverse mode differentiation

方法来计算用户封装的函数啦



tf.GradientTape(一点点臃肿,但是更详细)

有的时候,函数不是那么好封装,比如你需要对中间的某个变量进行求导。这个时候你可以使用tf.GradientTape context(python的上下文管理机制)

All computation inside the context of a tf.GradientTape is "recorded".


Higher-order gradients

GradientTape 上下文管理器中的所有相关操作都会记录下来用于automatic differentiation.

If gradients are computed in that context, then the gradient computation is recorded as well.

As a result, the exact same API works for higher-order gradients as well.



详见代码



Code1

import tensorflow as tf 
import matplotlib.pyplot as plt
from math import pi

tf.enable_eager_execution()
tfe = tf.contrib.eager

def f(x):
    return tf.square(tf.sin(x))

print( f(pi/2).numpy() == 1.0 )
grad_f = tfe.gradients_function(f) # f的梯度函数
print(tf.abs(grad_f((pi/2,pi))).numpy())

# Higher-order gradients
def grad(f):  #相当于直接返回一个一阶梯度函数
    return lambda x : tfe.gradients_function(f)(x)[0]

x = tf.lin_space(-2*pi,2*pi,100) # 100 points between -2pi ~ 2pi

plt.plot(x, f(x), label = 'f')
plt.plot(x, grad(f)(x), label = 'first derivative')
plt.plot(x, grad(grad(f))(x), label = 'second derivative')
plt.plot(x, grad(grad(grad(f)))(x), label = 'third derivative')
plt.legend()
plt.show()

Code2

'''
Gradient tapes
'''
import tensorflow as tf 

tf.enable_eager_execution()
tfe = tf.contrib.eager # shorthand for some symbols


# x^y
def f(x,y): 
    output = 1
    # Must use range(int(y)) instead of range(y) in Python 3 when
    # using TensorFlow 1.10 and earlier. Can use range(y) in 1.11+
    for i in range(int(y)): # you can use for loop (#^.^#)
        output = tf.multiply(output, x)
    return output         


# d x^y / d x
def g(x,y):
    # Return the gradient of 'f' with respect to it's first parameter  default?
    return tfe.gradients_function(f)(x,y)[0]


print( f(3,2).numpy() )
print( g(3.0,2).numpy() )
print( f(3,3).numpy() )
print( g(3.0,3).numpy() )

Code3

'''
At times it may be inconvenient to encapsulate(封装) computation of 
interest into a function. For example, if you want the gradient
of the output with respect to intermediate(中间的) values computed in
the function. In such cases, the slightly more verbose but 
explicit(明确的)  **tf.GradientTape**  context is useful. All computation 
inside the context(上下文) of a **tf.GradientTape** is "recorded".
'''
import tensorflow as tf 

tf.enable_eager_execution()
tfe = tf.contrib.eager # shorthand for some symbols

x= tf.ones((2, 2))

with tf.GradientTape(persistent = True) as t: # persistent 持久的
    t.watch(x)
    y = tf.reduce_sum(x) # 就是求和的意思(降维)
    z = tf.multiply(y,y)


#use the same tape to compute the derivative of z with 
# respect to the intermediate value y
dz_dy = t.gradient(z, y) # 对y求导
print(dz_dy.numpy())


# Derivative of z with respect to the original input tensor x
dz_dx = t.gradient(z, x) # 对x求导
print(dz_dx.numpy())


'''
higher-order gradients

Operations inside of the GradientTape context manager
 are recorded for automatic differentiation. 

 If gradients are computed in that context, 
 then the gradient computation is recorded as well. 

  As a result, the exact same API works for higher-order 
  gradients as well. For example:
'''

x = tf.Variable(1.0)  # Convert the Python 1.0 to a Tensor object
with tf.GradientTape() as t:
    with tf.GradientTape() as t2: 
        y = x*x*x
        # Compute the gradient inside the 't' context manager
        # which means the gradient computation is differentiable as well.
    dy_dx = t2.gradient(y,x)
d2y_dx2 = t.gradient(dy_dx, x)
print(dy_dx.numpy(), d2y_dx2.numpy())




评论

登录之后就可以评论 / 回复啦(#^.^#)    点此登录    点此注册

评论列表

暂无评论!快写一条吧(๑′ᴗ‵๑)