Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/RLMatrix.Common/IProvidesDiscreteActionMask.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using System;

namespace RLMatrix.Common
{
/// <summary>
/// Optional interface implemented by generated discrete environments that provide action masks.
/// </summary>
public interface IProvidesDiscreteActionMask
{
/// <summary>
/// Returns per-head masks; for H discrete heads returns H arrays of length actionSize[head].
/// </summary>
/// <remarks>
/// NOTE(review): the meaning of individual mask entries is not defined by this interface —
/// presumably a non-zero entry marks an allowed action; confirm against the consuming agent.
/// </remarks>
/// <returns>A jagged array with one mask array per discrete action head.</returns>
int[][] GetDiscreteActionMasks();
}
}

10 changes: 10 additions & 0 deletions src/RLMatrix.Toolkit/AdditionalFiles.cs
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,16 @@ public class RLMatrixRewardAttribute : Attribute { }
public class RLMatrixDoneAttribute : Attribute { }
[AttributeUsage(AttributeTargets.Method)]
public class RLMatrixResetAttribute : Attribute { }
[AttributeUsage(AttributeTargets.Method)]
public class RlMatrixActionMaskProviderAttribute : Attribute
{
public string MethodName { get; }
public RlMatrixActionMaskProviderAttribute(string methodName)
{
MethodName = methodName;
}
}
}";
}
}
9 changes: 9 additions & 0 deletions src/RLMatrix.Toolkit/Attributes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ public class RLMatrixDoneAttribute : Attribute { }
[AttributeUsage(AttributeTargets.Method)]
public class RLMatrixResetAttribute : Attribute { }

/// <summary>
/// Method-level attribute that carries the name of a mask-provider method.
/// </summary>
/// <remarks>
/// The attribute only stores the name it is given; it performs no validation or lookup itself.
/// </remarks>
[AttributeUsage(AttributeTargets.Method)]
public class RlMatrixActionMaskProviderAttribute : Attribute
{
    /// <summary>
    /// Name of the mask-provider method, exactly as supplied to the constructor.
    /// </summary>
    public string MethodName { get; }

    /// <summary>
    /// Creates the attribute and records <paramref name="methodName"/> unchanged.
    /// </summary>
    /// <param name="methodName">Name of the method that supplies the action mask.</param>
    public RlMatrixActionMaskProviderAttribute(string methodName) => MethodName = methodName;
}


//----------------------------------------------Stubs for future semantics integration----------------------------------------------
Expand Down
60 changes: 57 additions & 3 deletions src/RLMatrix.Toolkit/RLMatrixSourceGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
var discreteActionInfos = discreteActionMethods
.Select(m => new ActionInfo(
m.Name,
GetAttributeArgument<int>(m, "RLMatrixActionDiscreteAttribute", 0)))
GetAttributeArgument<int>(m, "RLMatrixActionDiscreteAttribute", 0),
GetAttributeArgument<string>(m, "RlMatrixActionMaskProviderAttribute", 0, null)))
.ToArray();

var continuousActionInfos = continuousActionMethods
Expand Down Expand Up @@ -113,6 +114,7 @@ private static void GenerateEnvironmentClass(SourceProductionContext context, En
private static string GenerateSource(EnvironmentModel model)
{
var interfaceName = model.IsContinuous ? "IContinuousEnvironmentAsync<float[]>" : "IEnvironmentAsync<float[]>";
var maskInterface = model.IsContinuous ? string.Empty : ", IProvidesDiscreteActionMask";
var actionProperties = model.IsContinuous ? GenerateContinuousProperties(model) : GenerateDiscreteProperties(model);
var stepMethod = model.IsContinuous ? GenerateContinuousStepMethod(model) : GenerateDiscreteStepMethod(model);
var actionSizeCalc = model.IsContinuous ? "DiscreteActionSize.Length + ContinuousActionBounds.Length" : "actionSize.Length";
Expand All @@ -121,6 +123,7 @@ private static string GenerateSource(EnvironmentModel model)
var actionMethodsInit = GenerateActionMethodsInitialization(model);
var observationCollection = GenerateObservationCollection(model);
var ghostStepActions = GenerateGhostStepActions(model);
var masksMethod = model.IsContinuous ? string.Empty : GenerateGetDiscreteActionMasksMethod();

return $$"""
using System;
Expand All @@ -130,10 +133,11 @@ private static string GenerateSource(EnvironmentModel model)
using OneOf;
using RLMatrix;
using RLMatrix.Toolkit;
using RLMatrix.Common;

namespace {{model.NamespaceName}}
{
public partial class {{model.ClassName}} : {{interfaceName}}
public partial class {{model.ClassName}} : {{interfaceName}}{{maskInterface}}
{
private int _poolingRate;
private RLMatrixPoolingHelper _poolingHelper;
Expand All @@ -144,6 +148,8 @@ public partial class {{model.ClassName}} : {{interfaceName}}
private int _maxStepsSoft;
private bool _rlMatrixEpisodeTerminated;
private bool _rlMatrixEpisodeTruncated;
private Func<int[]>[] _maskProviders;

private (Action<int> method, int maxValue)[] _actionMethodsWithCaps;
private Action<float>[] _continuousActionMethods;

Expand Down Expand Up @@ -281,6 +287,8 @@ private float[] _GetBaseObservations()
{{observationCollection}}
return observations.ToArray();
}

{{masksMethod}}
}
}
""";
Expand Down Expand Up @@ -330,6 +338,25 @@ private static string GenerateActionMethodsInitialization(EnvironmentModel model
sb.AppendLine(" _actionMethodsWithCaps = new (Action<int>, int)[0];");
}

if (model.DiscreteActions.Any())
{
sb.AppendLine();
sb.AppendLine(" _maskProviders = new Func<int[]>[]");
sb.AppendLine(" {");
foreach (var action in model.DiscreteActions)
{
if (string.IsNullOrEmpty(action.MaskProvider))
sb.AppendLine(" null,");
else
sb.AppendLine($" () => {action.MaskProvider}(),");
}
sb.AppendLine(" };");
}
else
{
sb.AppendLine(" _maskProviders = new Func<int[]>[0];");
}

if (model.IsContinuous && model.ContinuousActions.Any())
{
sb.AppendLine();
Expand Down Expand Up @@ -454,6 +481,31 @@ private static string GenerateObservationCollection(EnvironmentModel model)
return sb.ToString();
}

/// <summary>
/// Returns the C# source text of the <c>GetDiscreteActionMasks</c> method that is spliced into a
/// generated environment class.
/// </summary>
/// <remarks>
/// The emitted method allocates one mask slot per entry of <c>_actionMethodsWithCaps</c>. For each
/// slot it invokes the matching <c>_maskProviders</c> delegate when that array is non-null, long
/// enough, and holds a non-null delegate at that index; otherwise it fills the slot with an array of
/// <c>maxValue</c> ones. The provider's return value is stored as-is (no length or null check).
/// The template contains no interpolation holes, so the returned text is constant.
/// </remarks>
/// <returns>The method's source code as a string.</returns>
private static string GenerateGetDiscreteActionMasksMethod()
{
return $$"""
public int[][] GetDiscreteActionMasks()
{
int heads = _actionMethodsWithCaps.Length;
var masks = new int[heads][];
for (int i = 0; i < heads; i++)
{
if (_maskProviders != null && i < _maskProviders.Length && _maskProviders[i] != null)
{
masks[i] = _maskProviders[i]();
}
else
{
int size = _actionMethodsWithCaps[i].maxValue;
masks[i] = Enumerable.Repeat(1, size).ToArray();
}
}
return masks;
}
""";
}


public class EnvironmentModel
{
public string ClassName { get; }
Expand Down Expand Up @@ -498,11 +550,13 @@ public class ActionInfo
{
public string MethodName { get; }
public int Size { get; }
public string MaskProvider { get; }

public ActionInfo(string methodName, int size)
public ActionInfo(string methodName, int size, string maskProvider = null)
{
MethodName = methodName;
Size = size;
MaskProvider = maskProvider;
}
}

Expand Down
17 changes: 17 additions & 0 deletions src/RLMatrix/Agents/Common/IDiscreteProxyWithMask.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

namespace RLMatrix.Agents.Common
{
/// <summary>
/// Optional extension for proxies that support action masks.
/// </summary>
/// <typeparam name="T">Type of the state carried in each batch entry.</typeparam>
public interface IDiscreteProxyWithMask<T>
{
// The two declarations below differ only in return type: ValueTask when NET8_0_OR_GREATER
// is defined, Task otherwise. Parameters are identical in both.
//
// stateInfos: one entry per environment — its id, its state, and its action masks.
//             NOTE(review): actionMasks may be null when the environment supplies no masks;
//             implementers should confirm how callers populate it.
// isTraining: forwarded by the caller; presumably toggles exploration — confirm in implementers.
// Returns a dictionary keyed by environment id; NOTE(review): the int[] value is presumably the
// selected action per discrete head — confirm against the non-masked SelectActionsBatchAsync.
#if NET8_0_OR_GREATER
ValueTask<Dictionary<Guid, int[]>> SelectActionsBatchWithMaskAsync(List<(Guid environmentId, T state, int[][] actionMasks)> stateInfos, bool isTraining);
#else
Task<Dictionary<Guid, int[]>> SelectActionsBatchWithMaskAsync(List<(Guid environmentId, T state, int[][] actionMasks)> stateInfos, bool isTraining);
#endif
}
}

52 changes: 51 additions & 1 deletion src/RLMatrix/Agents/Common/LocalDiscreteRolloutAgent.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using RLMatrix.Agents.DQN.Domain;
using RLMatrix.Agents.PPO.Implementations;
using RLMatrix.Common;
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
Expand Down Expand Up @@ -58,7 +59,23 @@ public async Task Step(bool isTraining = true)
var stateResults = await Task.WhenAll(stateTaskList);

List<(Guid environmentId, TState state)> payload = stateResults.ToList();
var actions = await _agent.SelectActionsBatchAsync(payload, isTraining);

Dictionary<Guid, int[]> actions;
if (_agent is IDiscreteProxyWithMask<TState> withMask)
{
var payloadWithMasks = stateResults.Select(r =>
{
int[][] masks = null;
if (_environments[r.environmentId] is IProvidesDiscreteActionMask provider)
masks = provider.GetDiscreteActionMasks();
return (r.environmentId, r.state, masks);
}).ToList();
actions = await withMask.SelectActionsBatchWithMaskAsync(payloadWithMasks, isTraining);
}
else
{
actions = await _agent.SelectActionsBatchAsync(payload, isTraining);
}

List<Task<(Guid environmentId, (float, bool) reward)>> rewardTaskList = new List<Task<(Guid environmentId, (float, bool) reward)>>();
foreach (var action in actions)
Expand Down Expand Up @@ -212,6 +229,17 @@ public void StepSync(bool isTraining = true)

/// <summary>
/// Selects actions for a batch of environments and blocks the calling thread until the
/// agent's asynchronous selection has finished.
/// </summary>
/// <param name="stateInfos">Environment id / state pairs to select actions for.</param>
/// <param name="isTraining">Passed through unchanged to the agent.</param>
/// <returns>The dictionary produced by the agent, keyed by environment id.</returns>
private Dictionary<Guid, int[]> GetActionsBatchSync(List<(Guid environmentId, TState state)> stateInfos, bool isTraining)
{
    // Agents without mask support take the plain (id, state) batch.
    if (!(_agent is IDiscreteProxyWithMask<TState> maskAwareAgent))
    {
        return _agent.SelectActionsBatchAsync(stateInfos, isTraining).GetAwaiter().GetResult();
    }

    // Pairs one batch entry with its environment's masks; masks stay null when the
    // environment does not implement IProvidesDiscreteActionMask.
    (Guid environmentId, TState state, int[][] actionMasks) AttachMasks((Guid environmentId, TState state) entry)
    {
        int[][] actionMasks = null;
        if (_environments[entry.environmentId] is IProvidesDiscreteActionMask maskSource)
        {
            actionMasks = maskSource.GetDiscreteActionMasks();
        }
        return (entry.environmentId, entry.state, actionMasks);
    }

    var maskedBatch = stateInfos.Select(AttachMasks).ToList();
    var pendingSelection = maskAwareAgent.SelectActionsBatchWithMaskAsync(maskedBatch, isTraining);
    return pendingSelection.GetAwaiter().GetResult();
}

Expand All @@ -225,11 +253,33 @@ private Dictionary<Guid, int[]> GetActionsBatchSync(List<(Guid environmentId, TS
#if NET8_0_OR_GREATER
public ValueTask<Dictionary<Guid, int[]>> GetActionsBatchAsync(List<(Guid environmentId, TState state)> stateInfos, bool isTraining)
#else
public Task<Dictionary<Guid, int[]>> GetActionsBatchAsync(List<(Guid environmentId, TState state)> stateInfos, bool isTraining)
#endif
{
    // Only the return type differs between target frameworks; the body is shared.
    // The agent's pending operation is returned directly rather than awaited here.

    // Agents without mask support take the plain (id, state) batch.
    if (!(_agent is IDiscreteProxyWithMask<TState> maskAwareAgent))
    {
        return _agent.SelectActionsBatchAsync(stateInfos, isTraining);
    }

    // Pairs one batch entry with its environment's masks; masks stay null when the
    // environment does not implement IProvidesDiscreteActionMask.
    (Guid environmentId, TState state, int[][] actionMasks) AttachMasks((Guid environmentId, TState state) entry)
    {
        int[][] actionMasks = null;
        if (_environments[entry.environmentId] is IProvidesDiscreteActionMask maskSource)
        {
            actionMasks = maskSource.GetDiscreteActionMasks();
        }
        return (entry.environmentId, entry.state, actionMasks);
    }

    var maskedBatch = stateInfos.Select(AttachMasks).ToList();
    return maskAwareAgent.SelectActionsBatchWithMaskAsync(maskedBatch, isTraining);
}
Expand Down
11 changes: 8 additions & 3 deletions src/RLMatrix/Agents/DQN/Domain/ComposableDiscreteAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class ComposableQDiscreteAgent<T> : IDiscreteAgentCore<T>, IHasMemory<T>,
public required Action ResetNoisyLayers { get; init; }
public required DQNAgentOptions Options { get; init; }
public required Device Device { get; init; }
public required Func<T[], ComposableQDiscreteAgent<T>, bool, int[][]> SelectActionsFunc { private get; init; }
public required Func<T[], ComposableQDiscreteAgent<T>, bool, int[][][]?, int[][]> SelectActionsFunc { private get; init; }
#else
public Module<Tensor, Tensor> policyNet { get; set; }
public Module<Tensor, Tensor> targetNet { get; set; }
Expand All @@ -38,7 +38,7 @@ public class ComposableQDiscreteAgent<T> : IDiscreteAgentCore<T>, IHasMemory<T>,
public Action ResetNoisyLayers { get; set; }
public DQNAgentOptions Options { get; set; }
public Device Device { get; set; }
public Func<T[], ComposableQDiscreteAgent<T>, bool, int[][]> SelectActionsFunc { private get; set; }
public Func<T[], ComposableQDiscreteAgent<T>, bool, int[][][]?, int[][]> SelectActionsFunc { private get; set; }
#endif

public Random Random = new Random();
Expand All @@ -61,7 +61,12 @@ public void OptimizeModel()

/// <summary>
/// Selects actions for the given states without any action masks.
/// </summary>
/// <param name="states">States to select actions for.</param>
/// <param name="isTraining">Passed through unchanged to <c>SelectActionsFunc</c>.</param>
/// <returns>Whatever <c>SelectActionsFunc</c> returns when invoked with this agent and a null mask argument.</returns>
public int[][] SelectActions(T[] states, bool isTraining)
    => SelectActionsFunc(states, this, isTraining, null);

/// <summary>
/// Selects actions for the given states, forwarding the supplied action masks.
/// </summary>
/// <param name="states">States to select actions for.</param>
/// <param name="isTraining">Passed through unchanged to <c>SelectActionsFunc</c>.</param>
/// <param name="masks">Action masks handed to <c>SelectActionsFunc</c> as-is.</param>
/// <returns>Whatever <c>SelectActionsFunc</c> returns when invoked with this agent and <paramref name="masks"/>.</returns>
public int[][] SelectActions(T[] states, bool isTraining, int[][][] masks)
    => SelectActionsFunc(states, this, isTraining, masks);

public void Save(string path)
Expand Down
Loading