読者です 読者をやめる 読者になる 読者になる

KZKY memo

自分用メモ.

TensorFlow: tf.get_colleciton()とtf.add_to_collection()

python tensorflow deeplearning

なぞのメソッドtf.get_collection()とtf.get_to_collection()を調べた.

API時には,

tf.Graph.get_collection(name, scope=None)

Returns a list of values in the collection with the given name.
Args:

key: The key for the collection. For example, the GraphKeys class contains many standard names for collections.
scope: (Optional.) If supplied, the resulting list is filtered to include only items whose name begins with this string.

Returns:

The list of values in the collection with the given name, or an empty list if no value has been added to that collection. The list contains the values in the order under which they were collected

こう書いてある.key = nameとみていいだろう.scopeはvarialbe_scopeでない,name_scopeなので注意.

公式サンプルでは,cifar10_multi_gpu_train.pycifar10.pyで使用されている.

GraphKeysに基本的なkey nameが入っているというので,見てみると,

  • 'QUEUE_RUNNERS',
  • 'SUMMARIES',
  • 'TABLE_INITIALIZERS',
  • 'TRAINABLE_VARIABLES',
  • 'VARIABLES'

がクラス変数にある.APIの説明と名前を見る感じだとvariableとか作ったら,これらのmap<key, list>の変数に入ると思われる.なにも考えない(Graph object を明示的に作らないで,tf.Varialbe()するとか)と,default graphのcollectionsに入る.

サンプルコード
#!/usr/bin/env python

import tensorflow as tf
import numpy as np


# Collection keys
print "### Collection keys ###"
print dir(tf.GraphKeys)
print ""

# Show values in VARIABLES collections
print "### Show values in VARIABLES collections ###"
print tf.get_collection(tf.GraphKeys.VARIABLES)
print ""

# Create Varialbes and show the colllections again
print "### Create Varialbes and show the colllections again ###"
x = tf.Variable(np.random.rand(5, 5, 5), name="x")
y = tf.Variable(np.random.rand(5, 5, 5), name="y")
c = tf.constant(np.random.rand(5, 5, 5), name="c")
init_op = tf.initialize_all_variables()
w = tf.get_variable("w", shape=(5, 5), initializer=tf.random_normal_initializer())

print "VARIABLES", tf.get_collection(tf.GraphKeys.VARIABLES)
print "TRAINABLE_VARIABLES", tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print "TABLE_INITIALIZERS", tf.get_collection(tf.GraphKeys.TABLE_INITIALIZERS)
print "SUMMARIES", tf.get_collection(tf.GraphKeys.SUMMARIES)
print "QUEUE_RUNNERS", tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
print ""

# Add somthing to any collection and show that
print "### Add somthing to any collection ###"
sample = tf.get_collection("sample")
sample.append(x)
print tf.get_collection("sample")

tf.add_to_collection("sample", x)
tf.add_to_collection("sample", y)
print tf.get_collection("sample")
print ""

# Add somthing to any collection and show that with scope filter
print "### Add somthing to any collection and show that with scope filter ###"
tf.add_to_collection("sample", x)
with tf.name_scope("name_scope") as scope:
    print tf.get_collection("sample", scope)
    z = tf.Variable(np.random.rand(5, 5, 5), name="z")
    tf.add_to_collection("sample", z)
    
print len(tf.get_collection("sample"))
print len(tf.get_collection("sample", scope))

print x.name
print y.name
print z.name

一番確認したいのは,name_scopeを作って,その中でvarialbeを作る.その中であるcollectionに入れる.scopeでちゃんとフィルタできているかどうか.
こうすることで,同じcollectionでも,scopeでフィルタできる

TABLE_INITIALIZERSがなぞのまま.個人的には未解決.