Skip to content

Add PPO self-play implementation for OpenSpiel#1519

Open
Arahan-kujur wants to merge 2 commits into
google-deepmind:masterfrom
Arahan-kujur:ppo-selfplay
Open

Add PPO self-play implementation for OpenSpiel#1519
Arahan-kujur wants to merge 2 commits into
google-deepmind:masterfrom
Arahan-kujur:ppo-selfplay

Conversation

@Arahan-kujur

Copy link
Copy Markdown
Contributor

Adds a PPO (Proximal Policy Optimization) implementation in JAX/Flax (NNX) to OpenSpiel, including a self-play training loop for turn-based imperfect-information games.

Features
Actor-critic PPO agent implemented in JAX + Flax (NNX)
Supports self-play with a single agent controlling all players
Generalized Advantage Estimation (GAE) per player trajectory
Legal action masking for arbitrary OpenSpiel games
Example training script for Kuhn Poker and Leduc Poker
Unit tests covering training, evaluation mode, and self-play behavior
Results

Tested on Kuhn Poker using the example script:

Exploitability: ~0.22 after 500 iterations (entropy_coef=0.1)
Average returns close to game value (-1/18 ≈ -0.056)

This suggests the self-play setup and training loop are functioning as expected.

Notes
Designed as a reference implementation for policy gradient methods in OpenSpiel
PPO does not have convergence guarantees in imperfect-information games
Performance is sensitive to hyperparameters (e.g., entropy regularization)
Files Added
open_spiel/python/jax/ppo.py — PPO agent and training logic
open_spiel/python/examples/ppo_example_jax.py — example self-play training script
open_spiel/python/jax/ppo_jax_test.py — unit tests
Future Work
Scaling to larger games (e.g., Leduc Poker tuning)
Benchmarking against CFR-based methods
Multi-agent extensions or population-based training

Remove .vs/ IDE files, add .vs/ to .gitignore, improve PPO docstrings

Made-with: Cursor
@alexunderch

Copy link
Copy Markdown
Contributor

Hello! in general, really cool implementation! It would also be so cool if you implemented GAE calculation in a vectorised form (using jax.lax.scan) because it's a very resource-demanding operation, see: for example, here. Moveover, you mostly use numpy.random generator, but it would be nice if you took some advantage of jax.random reproducability.

Also, it would be nice if you provided some insights on how the algorithm performs for some known games.
Nice work!

@Arahan-kujur

Copy link
Copy Markdown
Contributor Author

Hello! in general, really cool implementation! It would also be so cool if you implemented GAE calculation in a vectorised form (using jax.lax.scan) because it's a very resource-demanding operation, see: for example, here. Moveover, you mostly use numpy.random generator, but it would be nice if you took some advantage of jax.random reproducability.

Also, it would be nice if you provided some insights on how the algorithm performs for some known games. Nice work!

Hi! Thanks a lot for the thoughtful feedback — I really appreciate it.

I’ve made several updates based on your suggestions:

Vectorized GAE: Replaced the Python loop with a jax.lax.scan (reverse-time) implementation, making it fully JIT-compatible and more efficient.
JAX PRNG: Switched from numpy.random to jax.random throughout the codebase, with explicit key handling for reproducibility.
Benchmarks: Added evaluation on standard OpenSpiel environments (kuhn_poker, leduc_poker, matrix_pd) along with training curves (policy/value loss, entropy, exploitability).

I also included tests for GAE correctness and PRNG reproducibility, and added documentation explaining the design choices.

Thanks again for the suggestions — they were really helpful in improving both performance and clarity.

@alexunderch

Copy link
Copy Markdown
Contributor

Thank you, I will give the results a look!

@alexunderch

alexunderch commented May 9, 2026

Copy link
Copy Markdown
Contributor

Hey! kunh_poker result looks nice!
Can you also report:

  1. exploitability for the leduc_poker (in a similar manner)?
  2. cumulative return plots/entropy for RPS(pyspiel.load_game("matrix_rps")) and breakout game
  3. delete the readme file

P.s. also, use jax.Array/chex.Array for the type annonation, not jnp.ndarray, please, because it's an official annotation way, I guess: https://docs.jax.dev/en/latest/_autosummary/jax.Array.html

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants