RNN-LSTM講解-基於tensorflow實現

cnn卷積神經網絡在前面已經有所了解了,目前博主也使用它進行了一個圖像分類問題,基於kaggle裏面的food-101進行的圖像識別,識別率有點感人,基於數據集的關係,大致來說還可行。
下面我就繼續學習rnn神經網絡。

rnn神經網絡(遞歸/循環神經網絡)模式如下:

我們在處理文字等問題的時候,我們的輸入會把上一個時間輸出的數據作為下一個時間的輸入數據進行處理。
例如:我們有一段話,我們將其分詞,得到t個數據,我們分別將每一個詞傳入到x0,x1….xt裏面,當x0傳入后,會得到一個結果h0,同時我們會將處理后的數據傳入到下個時間,到下個時間的時候,我們會再傳入一個數據x1,同時還有上一個時間處理后的數據,將這兩個數據進行整合計算,然後再向下傳輸,一直到結束。
rnn本質來說還是一個bp迴路,不過他只是比bp網絡多一個環節,即它可以反饋上一時間點處理后的數據。

上圖細化如下:

rnn實際上還是存在梯度消失的問題,因此如上圖所示,當我們在第一個時間輸入的數據,可能在很久之後他就已經梯度消失了(影響很小),因此我們使用lstm(long short trem memory)

上圖有三個門:輸入門    忘記門   輸出門
1.輸入門:通過input * g 來判斷是否輸入,如果不輸入就為0,輸入就是0,以此判斷信號是否輸入
2.忘記門:這個信號是否需要衰減多少,可能為50%,衰減是根據信號來判斷。
3.輸入門:通過判斷是否輸出,或者輸出多少,例如輸出50%。
因此上述圖可化為:

可以看出,這三個門,所有得影響都是關於輸入和上一個數據得輸出來進行計算的。

可以看下圖:

我們使用lstm得話,通過三個門決定信號是否向下傳輸,傳輸多少都可以控制,是否傳入信號,輸出信息都進行控制。

下面我們還是用tensorflow實現,數據集還是手寫数字,雖然rnn主要是用在文字和語言上,但是它依舊可以用在圖片上。
下面給出代碼:

```python
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import  input_data
mnist=input_data.read_data_sets("MNNIST_data",one_hot=True)

#輸入圖片為 28*28
n_inputs=28#輸入一行,一行有28個像素
max_time=28#一共28行,所以為28*28
lstm_size=100#100個隱藏單元
batch_size=50
n_classes=10
n_batch=mnist.train.num_examples//batch_size#計算一共多少批次

#這裏none表示第一個維度可以是任意長度
x=tf.placeholder(tf.float32,[None,784])

y=tf.placeholder(tf.float32,[None,10])

#初始化權值
weights=tf.Variable(tf.truncated_normal([lstm_size,n_classes],stddev=0.1))
#初始化偏置值
biases=tf.Variable(tf.constant(0.1,shape=[n_classes]))

##定義Rnn 網絡
def RNN(X,weights,biases):
    inputs=tf.reshape(X,[-1,max_time,n_inputs])
    #定義lstm基本cell
    lstm_cell = rnn.BasicLSTMCell(lstm_size)
    #lstm_cell=tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(lstm_size)
    outputs,final_state=tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
    results=tf.nn.softmax(tf.matmul(final_state[1],weights)+biases)
    return results
prediction=RNN(x,weights,biases)
#損失函數
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
#優化器
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#保存結果
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))

accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

init=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(6):
        for batch in range(n_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})

        acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("iter:"+str(epoch)+"testing accuracy"+str(acc))

 

“`
運行結果如下:

本站聲明:網站內容來源於博客園,如有侵權,請聯繫我們,我們將及時處理【其他文章推薦】

※公開收購3c價格,不怕被賤賣!

※想知道網站建置網站改版該如何進行嗎?將由專業工程師為您規劃客製化網頁設計後台網頁設計

※不管是台北網頁設計公司台中網頁設計公司,全省皆有專員為您服務

※Google地圖已可更新顯示潭子電動車充電站設置地點!!

※帶您來看台北網站建置台北網頁設計,各種案例分享