翼度科技»论坛 编程开发 python 查看内容

TensorFlow中关于tf.app.flags命令行参数解析模块

13

主题

13

帖子

39

积分

新手上路

Rank: 1

积分
39
tf.app.flags命令行参数解析模块

说道命令行参数解析,就不得不提到 python 的 argparse 模块,详情可参考我之前的一篇文章:python argparse 模块命令行参数用法及说明
在阅读相关工程的源码时,很容易发现 tf.app.flags 模块的身影。其作用与 python 的 argparse 类似。
直接上代码实例,新建一个名为 test_flags.py 的文件,内容如下:
  1. #coding:utf-8
  2. import tensorflow as tf

  3. FLAGS = tf.app.flags.FLAGS
  4. # tf.app.flags.DEFINE_string("param_name", "default_val", "description")
  5. tf.app.flags.DEFINE_string("train_data_path", "/home/feige", "training data dir")
  6. tf.app.flags.DEFINE_string("log_dir", "./logs", " the log dir")
  7. tf.app.flags.DEFINE_integer("train_batch_size", 128, "batch size of train data")
  8. tf.app.flags.DEFINE_integer("test_batch_size", 64, "batch size of test data")
  9. tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate")

  10. def main(unused_argv):
  11.     train_data_path = FLAGS.train_data_path
  12.     print("train_data_path", train_data_path)
  13.     train_batch_size = FLAGS.train_batch_size
  14.     print("train_batch_size", train_batch_size)
  15.     test_batch_size = FLAGS.test_batch_size
  16.     print("test_batch_size", test_batch_size)
  17.     size_sum = tf.add(train_batch_size, test_batch_size)
  18.     with tf.Session() as sess:
  19.         sum_result = sess.run(size_sum)
  20.         print("sum_result", sum_result)

  21. # 使用这种方式保证了,如果此文件被其他文件 import的时候,不会执行main 函数
  22. if __name__ == '__main__':
  23.     tf.app.run()   # 解析命令行参数,调用main 函数 main(sys.argv)
复制代码
上述代码已给出较为详细的注释,在此不再赘述。

该文件的调用示例以及运行结果如下所示


如果需要修改默认参数的值,则在命令行传入自定义参数值即可,若全部使用默认参数值,则可直接在命令行运行该 python 文件。
读者可能会对 tf.app.run() 有些疑问,在上述注释中也有所解释,但要真正弄清楚其运行原理

还需查阅其源代码
  1. def run(main=None, argv=None):
  2.   """Runs the program with an optional 'main' function and 'argv' list."""
  3.   f = flags.FLAGS

  4.   # Extract the args from the optional `argv` list.
  5.   args = argv[1:] if argv else None

  6.   # Parse the known flags from that list, or from the command
  7.   # line otherwise.
  8.   # pylint: disable=protected-access
  9.   flags_passthrough = f._parse_flags(args=args)
  10.   # pylint: enable=protected-access

  11.   main = main or sys.modules['__main__'].main

  12.   # Call the main function, passing through any arguments
  13.   # to the final program.
  14.   sys.exit(main(sys.argv[:1] + flags_passthrough))
复制代码
  1. flags_passthrough=f._parse_flags(args=args)
复制代码
这里的
  1. _parse_flags
复制代码
就是我们
  1. tf.app.flags
复制代码
源码中用来解析命令行参数的函数。
所以这一行就是解析参数的功能;
下面两行代码也就是 tf.app.run 的核心意思:执行程序中 main 函数,并解析命令行参数!
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。

来源:https://www.jb51.net/article/266499.htm
免责声明:由于采集信息均来自互联网,如果侵犯了您的权益,请联系我们【E-Mail:cb@itdo.tech】 我们会及时删除侵权内容,谢谢合作!

本帖子中包含更多资源

您需要 登录 才可以下载或查看,没有账号?立即注册

x

举报 回复 使用道具