a0-jax

Category: collect
Development tool: Python
File size: 0KB
Downloads: 0
Upload date: 2023-01-02 08:39:13
Uploader: sh-1993
Description: AlphaZero in JAX

File list:
.pylintrc (13977, 2023-01-02)
LICENSE (1063, 2023-01-02)
convert_to_tfjs.py (2501, 2023-01-02)
games/ (0, 2023-01-02)
games/__init__.py (0, 2023-01-02)
games/caro_game.py (6408, 2023-01-02)
games/connect_four_game.py (5161, 2023-01-02)
games/connect_two_game.py (3680, 2023-01-02)
games/connect_two_game_test.py (2123, 2023-01-02)
games/dsu.py (2794, 2023-01-02)
games/env.py (1652, 2023-01-02)
games/go_game.py (11798, 2023-01-02)
games/tic_tac_toe_game.py (5245, 2023-01-02)
go_game_test.py (46, 2023-01-02)
go_web_app.py (3654, 2023-01-02)
index.html (9254, 2023-01-02)
play.py (6268, 2023-01-02)
plot_search_tree.py (3117, 2023-01-02)
policies/ (0, 2023-01-02)
policies/__init__.py (0, 2023-01-02)
policies/mlp_policy.py (1521, 2023-01-02)
policies/resnet_policy.py (3525, 2023-01-02)
requirements.txt (85, 2023-01-02)
search_tree.png (263854, 2023-01-02)
test.sh (560, 2023-01-02)
train_agent.py (10560, 2023-01-02)
tree_search.py (1958, 2023-01-02)
utils.py (1506, 2023-01-02)

# a0-jax

AlphaZero in JAX, using DeepMind's [mctx](https://github.com/deepmind/mctx) library.

```sh
pip install -r requirements.txt
```

## Train agent

### Connect-Two game

```sh
python train_agent.py --weight-decay=1e-2 --num-iterations=3
```

### Connect-Four game

```sh
TF_CPP_MIN_LOG_LEVEL=2 \
python train_agent.py \
    --game_class="games.connect_four_game.Connect4Game" \
    --agent_class="policies.resnet_policy.ResnetPolicyValueNet" \
    --batch-size=4096 \
    --num_simulations_per_move=32 \
    --num_self_plays_per_iteration=102400 \
    --learning-rate=1e-2 \
    --num_iterations=500 \
    --lr-decay-steps=200000
```

A live Connect-4 agent is running at https://huggingface.co/spaces/ntt123/Connect-4-Game. We use tensorflow.js to run the policy in the browser.

### Caro (Gomoku) game

```sh
TF_CPP_MIN_LOG_LEVEL=2 \
python3 train_agent.py \
    --game-class="games.caro_game.CaroGame" \
    --agent-class="policies.resnet_policy.ResnetPolicyValueNet128" \
    --selfplay-batch-size=1024 \
    --training-batch-size=1024 \
    --num-simulations-per-move=32 \
    --num-self-plays-per-iteration=102400 \
    --learning-rate=1e-2 \
    --random-seed=42 \
    --ckpt-filename="./caro_agent_9x9_128.ckpt" \
    --num-iterations=100 \
    --lr-decay-steps=500000
```

A live Caro agent is running at https://caro.ntt123.repl.co.

### Go game

```sh
TF_CPP_MIN_LOG_LEVEL=2 \
python3 train_agent.py \
    --game-class="games.go_game.GoBoard9x9" \
    --agent-class="policies.resnet_policy.ResnetPolicyValueNet128" \
    --selfplay-batch-size=1024 \
    --training-batch-size=1024 \
    --num-simulations-per-move=32 \
    --num-self-plays-per-iteration=102400 \
    --learning-rate=1e-2 \
    --random-seed=42 \
    --ckpt-filename="./go_agent_9x9_128.ckpt" \
    --num-iterations=200 \
    --lr-decay-steps=1000000
```

A live Go agent is running at https://go.ntt123.repl.co. You can also run the agent on your local machine with the `go_web_app.py` script.
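The tree search behind these training runs is mctx's AlphaZero-style MCTS, which scores candidate moves with the PUCT rule: the child's mean value plus an exploration bonus weighted by the network's prior and the parent's visit count. As a rough, framework-free illustration (the function names and the default `c_puct` value here are our own, not from this repo or mctx):

```python
import math

def puct_score(prior: float, value: float, visits: int,
               parent_visits: int, c_puct: float = 1.25) -> float:
    """PUCT score for one child: exploitation (mean value so far)
    plus an exploration bonus that shrinks as the child is visited."""
    return value + c_puct * prior * math.sqrt(parent_visits) / (1 + visits)

def select_action(priors, values, visits):
    """Pick the index of the child with the highest PUCT score."""
    parent_visits = sum(visits)
    scores = [puct_score(p, q, n, parent_visits)
              for p, q, n in zip(priors, values, visits)]
    return max(range(len(scores)), key=scores.__getitem__)
```

With a high prior and few visits, a move can outscore one with a better empirical value, which is what drives exploration early in the search.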
We also have an [interactive colab notebook](https://colab.research.google.com/drive/1IlN1gThYrLazxTGrhryNzspx-Ts_6llj?usp=sharing) that runs the agent on a GPU to reduce inference time.

## Plot the search tree

```sh
python plot_search_tree.py
# ./search_tree.png
```

## Play

```sh
python play.py
```

## TPU sponsor

Agents in the demos above are trained on Google TPUs sponsored by Google under the [TPU Research Cloud program](https://sites.research.google/trc/about/).
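For background on what the self-play iterations above produce: AlphaZero-style training converts each finished game into per-move value targets by taking the final outcome and flipping its sign every ply, so each target is from the perspective of the player about to move. A minimal sketch (a hypothetical helper for illustration, not a function from this repo):

```python
def value_targets(outcome: float, num_moves: int) -> list[float]:
    """Per-ply value targets for a finished self-play game.

    `outcome` is +1 if the player who moved first won, -1 if they
    lost, and 0 for a draw. The sign alternates every ply so each
    target reflects the player to move at that position.
    """
    return [outcome * (-1) ** ply for ply in range(num_moves)]
```

These targets, together with the MCTS visit-count distributions as policy targets, form the (state, policy, value) tuples the network is trained on.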
