-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_models.py
More file actions
143 lines (118 loc) · 4.45 KB
/
Copy pathdata_models.py
File metadata and controls
143 lines (118 loc) · 4.45 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""Data models and constants for the ARC-AGI-3 game API.
Holds the wire types shared by the client and the environment: the colour and
game-state enums, the colour palette, and the :class:`FrameData` snapshot
(a Pydantic model).
"""
from __future__ import annotations
from enum import Enum, IntEnum
from typing import Any, Optional
from pydantic import BaseModel, Field
GRID = 64  # side length of a frame: frames are GRID x GRID
NUM_COLORS = 16  # number of colour ids; valid ids are 0..15
class Color(IntEnum):
    """Named ids for the 16 fixed ARC-AGI colours; a frame cell holds one.

    As an ``IntEnum`` every member is itself an ``int``, so it can be used
    wherever a raw colour id is expected -- in comparisons such as
    ``frame[y][x] == Color.RED`` or as an array index -- while giving the
    otherwise-magic numbers 0..15 readable names. The RGB value for each id
    is stored in :data:`PALETTE` at the position equal to the id.
    """

    BLACK = 0
    BLUE = 1
    RED = 2
    GREEN = 3
    YELLOW = 4
    GREY = 5
    FUCHSIA = 6
    ORANGE = 7
    AZURE = 8
    MAROON = 9
    WHITE = 10
    DARK_GREY = 11
    PINK = 12
    TEAL = 13
    PURPLE = 14
    NEAR_BLACK = 15

    @property
    def rgb(self) -> tuple[int, int, int]:
        """Look up this colour's (R, G, B) triple in :data:`PALETTE`."""
        return PALETTE[self.value]
# RGB palette indexed by colour id (Color): entry i is the (R, G, B) triple
# for id i, so the order here must stay in step with the ids declared on
# Color (Color.rgb indexes this tuple by id). Used by the env renderers.
PALETTE: tuple[tuple[int, int, int], ...] = (
    (0, 0, 0),  # 0 BLACK
    (0, 116, 217),  # 1 BLUE
    (255, 65, 54),  # 2 RED
    (46, 204, 64),  # 3 GREEN
    (255, 220, 0),  # 4 YELLOW
    (170, 170, 170),  # 5 GREY
    (240, 18, 190),  # 6 FUCHSIA
    (255, 133, 27),  # 7 ORANGE
    (127, 219, 255),  # 8 AZURE
    (135, 12, 37),  # 9 MAROON
    (255, 255, 255),  # 10 WHITE
    (99, 99, 99),  # 11 DARK_GREY
    (255, 167, 167),  # 12 PINK
    (0, 128, 128),  # 13 TEAL
    (128, 0, 128),  # 14 PURPLE
    (60, 60, 60),  # 15 NEAR_BLACK
)
class GameState(str, Enum):
    """Where an ARC-AGI-3 game currently stands in its lifecycle.

    The ``str`` mixin makes each member compare equal to its wire value
    (``GameState.WIN == "WIN"``), so members serialise transparently to
    JSON/Lance.
    """

    NOT_PLAYED = "NOT_PLAYED"
    NOT_FINISHED = "NOT_FINISHED"
    WIN = "WIN"
    GAME_OVER = "GAME_OVER"

    @classmethod
    def _missing_(cls, value: Any) -> "Optional[GameState]":
        """Resolve a value that did not match any member exactly.

        Accepts any casing of a member value, plus the ``NOT_STARTED`` alias
        seen in some API docs (mapped to ``NOT_PLAYED``). Returning ``None``
        for everything else lets ``Enum`` raise its usual ``ValueError``.
        """
        if not isinstance(value, str):
            return None
        wanted = value.upper()
        if wanted == "NOT_STARTED":
            return cls.NOT_PLAYED
        return next((state for state in cls if state.value == wanted), None)

    @property
    def is_terminal(self) -> bool:
        """Whether the episode has ended, by either a win or a loss."""
        return self is GameState.WIN or self is GameState.GAME_OVER
class FrameData(BaseModel):
    """Normalised snapshot returned by RESET / ACTION* commands.

    A Pydantic model: validates/coerces incoming JSON (including the
    :class:`GameState` enum) and round-trips via ``model_dump_json``.

    Wire shape::

        {
          "game_id": "...",
          "guid": "...",              # session id, echoed on every call
          "frame": [[[0..15, ...]]],  # list of 64x64 grids of colour ids
          "state": "NOT_PLAYED" | "NOT_FINISHED" | "WIN" | "GAME_OVER",
          "levels_completed": int,    # (formerly "score")
          "win_levels": int,          # (formerly "win_score")
          "action_input": {"id": int, ...},
          "available_actions": [1, 2, ...]
        }
    """

    game_id: str = ""
    guid: str = ""
    frame: list[list[list[int]]] = Field(default_factory=list)  # list of 2D grids
    state: GameState = GameState.NOT_PLAYED
    levels_completed: int = 0
    win_levels: int = 0
    action_input: dict[str, Any] = Field(default_factory=dict)
    available_actions: list[int] = Field(default_factory=list)

    @classmethod
    def from_json(cls, d: dict[str, Any]) -> "FrameData":
        """Build a :class:`FrameData` from a decoded JSON response.

        A key that is absent and a key whose value is ``null`` are treated
        the same way for every field: the field's default is used. For the
        two renamed counters the legacy key (``score`` / ``win_score``) is
        consulted whenever the new key is absent or ``null``.

        Args:
            d: The decoded JSON object in the wire shape described on the
                class.

        Returns:
            The validated model instance.

        Raises:
            ValueError: If ``state`` is present but is not a recognised
                :class:`GameState` value, or if Pydantic validation of the
                assembled fields fails (``ValidationError`` subclasses
                ``ValueError``).
        """

        def first_set(default: Any, *keys: str) -> Any:
            """Return the first non-``None`` value among *keys*, else *default*."""
            for key in keys:
                found = d.get(key)
                if found is not None:
                    return found
            return default

        return cls(
            game_id=first_set("", "game_id"),
            guid=first_set("", "guid"),
            frame=first_set([], "frame") or [],
            # GameState(...) resolves casing / the NOT_STARTED alias up front;
            # a null state falls back to NOT_PLAYED instead of raising.
            state=GameState(first_set("NOT_PLAYED", "state")),
            # accept both new and legacy field names
            levels_completed=first_set(0, "levels_completed", "score") or 0,
            win_levels=first_set(0, "win_levels", "win_score") or 0,
            action_input=first_set({}, "action_input") or {},
            available_actions=first_set([], "available_actions") or [],
        )