TensorFlow: Understanding Graph a Little Better
This post takes a slightly closer look at Graph.
- Default graph
- Create another graph in this thread (main thread)
- Graph in multi thread
- Write graph as protobuf to disk
- Read graph from disk into a Graph
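The sample below walks through these five cases. Restoring from a checkpoint is not covered.
Restoring from a checkpoint should be possible by following the references listed at the end.
When you import with tf.import_graph_def(graph_def), the ops are added to the default graph. To add them to a different graph, create g = tf.Graph(), enter g.as_default(), and call tf.import_graph_def(graph_def) inside that context.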
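The import into a separate graph described above might look like the following minimal sketch. It assumes a TensorFlow version that provides `tf.compat.v1` (1.13+ or 2.x); the graph names `source` and `target` and the name prefix `"imported"` are illustrative choices, not part of the original sample.

```python
import tensorflow as tf

# Assumption: tf.compat.v1 is available (TensorFlow 1.13+ or 2.x).
tf1 = tf.compat.v1

# Build a small source graph and serialize it to a GraphDef.
source = tf.Graph()
with source.as_default():
    tf.constant(30.0, name="c")
    tf.constant(40.0, name="d")
graph_def = source.as_graph_def()

# Import into a dedicated graph instead of the default graph.
target = tf.Graph()
with target.as_default():
    # "imported" is an illustrative name prefix chosen for this sketch.
    tf.import_graph_def(graph_def, name="imported")

# Collect op names so the two graphs can be compared.
target_op_names = [op.name for op in target.get_operations()]
default_op_names = [op.name for op in tf1.get_default_graph().get_operations()]
print(target_op_names)
print(default_op_names)
```

Because the import runs inside `target.as_default()`, the imported ops should appear only in `target`, each prefixed with the chosen name, while the default graph stays untouched.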
```python
#!/usr/bin/env python
import threading

import numpy as np
import tensorflow as tf


class GraphWorker(threading.Thread):
    """Thread that inspects the graphs visible from inside a worker thread."""

    def __init__(self, index):
        super(GraphWorker, self).__init__()
        self.index = index

    def run(self):
        g = tf.get_default_graph()
        # print("Graph in main thread", g)
        g_local = tf.Graph()
        # print("Graph in this thread", g_local)


def main():
    # Default graph
    # Confirm that a default Graph is always registered, and accessible by
    # calling tf.get_default_graph(). To add an operation to the default
    # graph, simply call one of the functions that defines a new Operation.
    c = tf.constant(4.0)
    assert c.graph is tf.get_default_graph()

    # Create another graph in this thread (main thread)
    # Confirm that the tf.Graph.as_default() method should be used if you
    # want to create multiple graphs in the same process.
    with tf.Graph().as_default() as g:
        c = tf.constant(30.0)
        assert c.graph is g
        d = tf.constant(40.0)
        assert d.graph is g
        e = tf.Variable(np.random.rand(10, 10), name="e")
    # Outside the with block, the default graph is no longer g.
    assert tf.get_default_graph() != g

    # Graph in multi thread
    # The default graph is a property of the current thread. If you create a
    # new thread, and wish to use the default graph in that thread, you must
    # explicitly add a `with g.as_default():` in that thread's function.
    n = 4
    graph_workers = []
    for i in range(n):
        worker = GraphWorker(i)
        worker.start()
        graph_workers.append(worker)
    for worker in graph_workers:
        worker.join()

    # Write graph as protobuf to disk (text and binary formats)
    print(g.get_operations())
    print(len(g.get_operations()))
    tf.train.write_graph(g.as_graph_def(), "./graph_dir_text", "./graph.pbtxt")
    tf.train.write_graph(g.as_graph_def(), "./graph_dir", "./graph.pb",
                         as_text=False)

    # Read graph from disk into a Graph (here, the default graph)
    print(tf.get_default_graph())
    with open("./graph_dir/graph.pb", "rb") as fp:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fp.read())
        tf.import_graph_def(graph_def)
    print(tf.get_default_graph())

    print("----- ops in g -----")
    for op in g.get_operations():
        print(op.name)
    print("----- ops in default graph after import_graph_def -----")
    for op in tf.get_default_graph().get_operations():
        print(op.name)


if __name__ == '__main__':
    main()
```
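The worker threads in the script only inspect the graphs they see. As a hedged sketch of the per-thread rule quoted in its comments, the following shows a thread that does not inherit the caller's `as_default()` context and has to enter it itself. It assumes `tf.compat.v1` is available (TensorFlow 1.13+ or 2.x); the `worker` function and the `results` dictionary are hypothetical names introduced for illustration.

```python
import threading

import tensorflow as tf

# Assumption: tf.compat.v1 is available (TensorFlow 1.13+ or 2.x).
tf1 = tf.compat.v1

g = tf.Graph()
results = {}  # hypothetical container for what the worker observes


def worker():
    # The caller's `with g.as_default():` does not carry over to this thread.
    results["before"] = tf1.get_default_graph() is g
    # Entering the context explicitly makes g the default graph here.
    with g.as_default():
        results["inside"] = tf1.get_default_graph() is g
        c = tf.constant(1.0)
        results["op_graph"] = c.graph is g


# Start the thread while g is the default graph of the main thread.
with g.as_default():
    t = threading.Thread(target=worker)
    t.start()
    t.join()

print(results)
```

Passing the graph object to the thread (here through a closure) and entering `g.as_default()` inside the thread's function is what keeps ops created in the worker attached to the intended graph.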
References
- https://www.tensorflow.org/versions/master/api_docs/python/framework.html#Graph
- https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/ops/
- https://tensorflow.googlesource.com/tensorflow/+/master/tensorflow/python/framework/ops.py
- http://stackoverflow.com/questions/33759623/tensorflow-how-to-restore-a-previously-saved-model-python
- https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/HfQ4etzW7D8