Rustの機械学習ライブラリcandleを使用した、DQNの実装です。
candleは"0.8.2"を使用。
Rustの練習のためにGridWorldを作成。
RustのcandleでのDQN実装コードはまだ少ないので、誰かの参考になればと思います。
field図
・5×5のグリッドを設定
・start地点は(4,0)、goalは(1,4)、中間報酬は(2,3)
→中間報酬を設定することで、フラグによって目的を切り替えるかを実験
・エージェントにとって学習しづらいように意図的に壁を(2,1),(3,1),(1,3),(2,4)に配置
→ゴールへ行くには壁沿いを進む必要があるので、マイナス報酬になりやすいのと、中間報酬を取ると次に壁に当たりやすい。
→また、中間報酬とゴール地点を近くに置くことによって、QNetがどっちの報酬かわかりづらくしている
・報酬設定 = { ゴール:1.0点, 中間報酬:0.5点, 壁:-0.1点, 1ステップ:-0.01点 }
以下のヒートマップは中間報酬を取得する前後のQ値の可視化です。
矢印はそのstateでのmax_Qのaction方向です。
※左:中間報酬前(Pre)、右:中間報酬後(Post)
ゴール付近では黄色になっている。また、グラーデーションになっており、しっかり隣のマスの期待値が伝播しているのがわかる。
中間報酬取る前は、近くを通った時は中間報酬取った方が良く、中間報酬を取った後は即座にゴールに向かっていることがわかる
探索フェーズを長めに取らないと、ゴールを知る前に探索を終えてしまうため、epsilon_decayは長めに取った方が安定する。
学習率lrは小さすぎると、探索時にゴールしたことを上手く学習できない。大きくし過ぎると、マイナス報酬の頻度が多いので、重みが0になって学習崩壊する。lr = 1e-4あたりがちょうどいいと思います。微調整してください。
とりあえずRustでDQNを実装し、学習させることができたが課題は残っている。
気になるポイントは大きく分けて2つある。
1.ゴールの1マス前までしか黄色ではない。
2.スタート地点から中間報酬を取らずに直接ゴールに行くように収束している
1つ目の問題点は、TD法なのでおそらく報酬が伝わりにくいため、1つ前までしか高Q値になっていない。 これを解決するには、n-stepにする必要があると思われる。
2つ目の問題点は、中間報酬を取りにいかない点だ。 スタート地点から中間報酬を取らずに直接ゴールへ向かった方が事故が少なく、安定すると判断していそうだ。探索率が高いフェーズでは、中間報酬を取りに行くと事故率が高く、探索率が低い時にはすでにゴールへ直行するルートが確立されているため、わざわざ中間報酬を取りに行こうとしない。これは、DoubleDQN、DuelingDQNを導入するだけである程度改善されそうではある。また、活性化関数をReLUにしているのでバッサリ行き過ぎている可能性もあり、スタート地点での微差を判断できていない可能性がある。
(以下、書いている途中)
そもそも中間報酬が0.5点でゴールが1.0点だとエージェントから見てゴールのほうが高い山に見えてるのでは?
そしてゴールすると、中間報酬を得る機会がないので、最初からゴールに直行しているのでは?と思い、とある検証をした。
割引率をγとして、スタートからのマス目をゴールが9マス、中間報酬が5マスなので、
γ9 * (1.0) = γ5 * (0.5) → γ ≈ 0.84
なので、γが0.84以下であれば中間報酬の方を先に取るという仮説がたつ。
収束状況やQNetの表現力の事情もあるので、少し抑えめでγ=0.80で検証した。
γ=0.80での中間報酬前のq_map
スタート位置で以前はUpを選択していたが、中間報酬の方が見えるようになったため、Rightを選択するようになった
γを小さくすると、どこに報酬があるのかわからなくなりやすく、学習崩壊しやすい。
実際に探索フェーズが終わった後に局所解にハマったので、あまりいい方法ではないと思われる。
二分木なので、PrioritizedReplayBufferのcapacityは2の累乗になるように実装してある。なので、capacityに大きすぎる値を入れるのNG。本実装では、Agentに10000を渡しているため、 2**14がcapacityになっている。
betaとεのバランスは考える必要がある。探索フェーズが終わってもbetaが低いままだと、有力でないものをたくさん学習してしまうから。 betaは0.0に近いほど優先度を重視したIS_Weightになるため、微調整してください。
alphaはどの程度、優先度をつけるかのパラメータであり、1.0に近いほど優先度を考慮します。 本実装では0.35にしていますが、0.4~0.6ぐらいが一般的です。
※左:中間報酬前(Pre)、右:中間報酬後(Post)
中間報酬を取ってからゴールを目指すようになったのがわかる。中間報酬をとるまでは引き返しが最適方策になっている。 (γ=0.99に戻してある。)
中間報酬がある(2,3)の隣に注目すると、次のstepで中間報酬が取れるのにもかかわらず、ヒートマップの色が異なる。これは、おそらくだが問題設定に対して、表現力ギリギリのQNetを設定しているからだと思われる。(表現力をあげたければ、NNを大きくすればいいがそれでは面白くないためギリギリを攻めている)。
また、報酬と関係ないところで壁にぶつかったり無限ループしてたりするが、これはPERを実装したことで、TD誤差があまり大きくならないところは曖昧になっている。このあたりも表現力あげれば解決すると思われる。
学習は4000episodeしているので、データが足りていないというよりはTD誤差の小さいところをQNetが捨てているのだと考えられます。
target_netのdetach()を忘れていたため、target_netにも勾配が流れてしまっていた。その修正後の学習結果
あまり結果は変わらないが、表現力を抑えているため特定のルート以外は表現できていない。betaが収束しきる前に学習を終えている影響もあるとは思われる
n_step_bufferがn_step貯まる前にdoneが来てしまうと、割引すぎ問題が発覚した。まだ修正に手が回ってないため、後々修正する。n=1では何も問題はない。
RustでのDQN実装例が少なく、さらに日本語で書いているものは非常に少ないので、ひとつの実装例として参考になる部分があれば幸いです。
勉強するにしても日本語ベースの情報が少なすぎて、困ってる日本人がどこかにいるかもしれないですよね。 とりあえず実装して動かせるものを見るだけでもありがたいんじゃないかなと思い、Githubにあげました。 生成AIでコードを出すにしても、全体像がみえてる方が良いコードを出しやすいとも思います。
古典的なGridworldですが、RLをはじめる時は必ずと言っていいほど全員が学びますよね。 後のRL初学者が実装例を見て、理解が深まれば嬉しく思います。







