TensorFlow: tf.get_collection() and tf.add_to_collection()
I looked into two mysterious methods, tf.get_collection() and tf.add_to_collection().
The API documentation describes get_collection() as follows:
tf.Graph.get_collection(name, scope=None)
Returns a list of values in the collection with the given name.
Args:
- key: The key for the collection. For example, the GraphKeys class contains many standard names for collections.
- scope: (Optional.) If supplied, the resulting list is filtered to include only items whose name begins with this string.
Returns:
The list of values in the collection with the given name, or an empty list if no value has been added to that collection. The list contains the values in the order under which they were collected.
That is what the documentation says. It seems safe to treat key as the same thing as name. Note that scope here means a name_scope, not a variable_scope.
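The behaviour described in the quoted documentation can be modelled in plain Python. The sketch below is a hypothetical toy (ToyGraph and Item are not part of TensorFlow) that assumes collections are a dict mapping each key to a list, that looking up an unknown key returns an empty list, and that scope is a plain prefix match on each item's name, as the documentation states.

```python
from collections import namedtuple

# Hypothetical stand-in for a graph element; only the .name attribute matters here.
Item = namedtuple("Item", ["name"])


class ToyGraph(object):
    """Toy model of graph collections as a map<key, list> (not TensorFlow's code)."""

    def __init__(self):
        self._collections = {}

    def add_to_collection(self, name, value):
        # Create the list on first use, then append in insertion order.
        self._collections.setdefault(name, []).append(value)

    def get_collection(self, name, scope=None):
        # Unknown keys yield an empty list, as the documentation describes.
        items = self._collections.get(name, [])
        if scope is None:
            # This toy returns a new list so callers cannot mutate the stored one.
            return list(items)
        # Keep only items whose name begins with the scope string.
        return [item for item in items if item.name.startswith(scope)]


g = ToyGraph()
g.add_to_collection("sample", Item("x:0"))
g.add_to_collection("sample", Item("name_scope/z:0"))
print([i.name for i in g.get_collection("sample")])
print([i.name for i in g.get_collection("sample", "name_scope/")])
print(g.get_collection("missing"))
```

Under these assumptions the three prints give ['x:0', 'name_scope/z:0'], then ['name_scope/z:0'], then an empty list.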
In the official samples, these methods are used in cifar10_multi_gpu_train.py and cifar10.py.
The documentation says GraphKeys holds the standard key names, so I looked at it. It has the following class variables:
- 'QUEUE_RUNNERS',
- 'SUMMARIES',
- 'TABLE_INITIALIZERS',
- 'TRAINABLE_VARIABLES',
- 'VARIABLES'
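Judging from the API description and these names, creating a variable (and similar objects) presumably adds it to the corresponding entry of this map<key, list>. If you do nothing special (for example, you call tf.Variable() without explicitly creating a Graph object), it goes into the collections of the default graph.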
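How a variable ends up in the default graph's collections without any explicit call can be pictured with a small toy model. Everything here (ToyVariable, the module-level _default_collections dict, and the key strings) is hypothetical and only assumes the behaviour described above: each new variable registers itself in VARIABLES, and also in TRAINABLE_VARIABLES when it is trainable (tf.Variable has a trainable argument for this).

```python
# Hypothetical toy model, not TensorFlow's implementation.
# The key strings are placeholders chosen for this sketch.
VARIABLES = "variables"
TRAINABLE_VARIABLES = "trainable_variables"

# Stand-in for the collections of the implicit default graph.
_default_collections = {}


def _add(key, value):
    _default_collections.setdefault(key, []).append(value)


class ToyVariable(object):
    def __init__(self, name, trainable=True):
        self.name = name + ":0"
        # Every variable is registered in VARIABLES ...
        _add(VARIABLES, self)
        # ... and trainable ones also in TRAINABLE_VARIABLES.
        if trainable:
            _add(TRAINABLE_VARIABLES, self)


def get_collection(key):
    return list(_default_collections.get(key, []))


x = ToyVariable("x")
step = ToyVariable("step", trainable=False)
print([v.name for v in get_collection(VARIABLES)])
print([v.name for v in get_collection(TRAINABLE_VARIABLES)])
```

In this toy the non-trainable step appears only in VARIABLES, while x appears in both collections.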
Sample code
```python
#!/usr/bin/env python
import tensorflow as tf
import numpy as np

# Collection keys
print "### Collection keys ###"
print dir(tf.GraphKeys)
print ""

# Show values in the VARIABLES collection
print "### Show values in the VARIABLES collection ###"
print tf.get_collection(tf.GraphKeys.VARIABLES)
print ""

# Create Variables and show the collections again
print "### Create Variables and show the collections again ###"
x = tf.Variable(np.random.rand(5, 5, 5), name="x")
y = tf.Variable(np.random.rand(5, 5, 5), name="y")
c = tf.constant(np.random.rand(5, 5, 5), name="c")
init_op = tf.initialize_all_variables()
w = tf.get_variable("w", shape=(5, 5), initializer=tf.random_normal_initializer())
print "VARIABLES", tf.get_collection(tf.GraphKeys.VARIABLES)
print "TRAINABLE_VARIABLES", tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
print "TABLE_INITIALIZERS", tf.get_collection(tf.GraphKeys.TABLE_INITIALIZERS)
print "SUMMARIES", tf.get_collection(tf.GraphKeys.SUMMARIES)
print "QUEUE_RUNNERS", tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS)
print ""

# Add something to a collection and show it
print "### Add something to a collection ###"
# The "sample" key does not exist yet, so get_collection() returns a fresh
# empty list; appending to it does not register x in the collection.
sample = tf.get_collection("sample")
sample.append(x)
print tf.get_collection("sample")
tf.add_to_collection("sample", x)
tf.add_to_collection("sample", y)
print tf.get_collection("sample")
print ""

# Add something to a collection and show it with a scope filter
print "### Add something to a collection and show it with a scope filter ###"
tf.add_to_collection("sample", x)
with tf.name_scope("name_scope") as scope:
    # scope is the prefix string "name_scope/"; nothing in "sample" matches yet
    print tf.get_collection("sample", scope)
    z = tf.Variable(np.random.rand(5, 5, 5), name="z")
    tf.add_to_collection("sample", z)
print len(tf.get_collection("sample"))
print len(tf.get_collection("sample", scope))
print x.name
print y.name
print z.name
```
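What I most wanted to check was this: create a name_scope, create a variable inside it, add that variable to some collection inside it, and see whether filtering by scope actually works.
It does: even within the same collection, items can be filtered by scope.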
What TABLE_INITIALIZERS is for is still a mystery to me; I have not figured it out yet.
References
- https://www.tensorflow.org/versions/master/api_docs/python/framework.html#Graph.add_to_collection
- https://www.tensorflow.org/versions/master/api_docs/python/framework.html#Graph.get_collection
- https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/models/image/cifar10/cifar10_multi_gpu_train.py
- https://www.tensorflow.org/versions/master/api_docs/python/framework.html#add_to_collection