Skip to content

Commit e703d87

Browse files
authored
Set unique index vars in rx.foreach (#2126)
1 parent e6b0255 commit e703d87

File tree

3 files changed

+39
-37
lines changed

3 files changed

+39
-37
lines changed

reflex/components/layout/foreach.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Create a list of components from an iterable."""
22
from __future__ import annotations
33

4+
import typing
45
from typing import Any, Callable, Iterable
56

67
from reflex.components.component import Component
@@ -47,15 +48,20 @@ def create(cls, iterable: Var[Iterable], render_fn: Callable, **props) -> Foreac
4748
f"Could not foreach over var of type Any. (If you are trying to foreach over a state var, add a type annotation to the var.)"
4849
)
4950
arg = BaseVar(_var_name="_", _var_type=type_, _var_is_local=True)
51+
comp = IterTag(iterable=iterable, render_fn=render_fn).render_component(arg)
5052
return cls(
5153
iterable=iterable,
5254
render_fn=render_fn,
53-
children=[IterTag.render_component(render_fn, arg=arg)],
55+
children=[comp],
5456
**props,
5557
)
5658

5759
def _render(self) -> IterTag:
58-
return IterTag(iterable=self.iterable, render_fn=self.render_fn)
60+
return IterTag(
61+
iterable=self.iterable,
62+
render_fn=self.render_fn,
63+
index_var_name=get_unique_variable_name(),
64+
)
5965

6066
def render(self):
6167
"""Render the component.
@@ -66,9 +72,9 @@ def render(self):
6672
tag = self._render()
6773
try:
6874
type_ = (
69-
self.iterable._var_type
70-
if self.iterable._var_type.mro()[0] == dict
71-
else self.iterable._var_type.__args__[0]
75+
tag.iterable._var_type
76+
if tag.iterable._var_type.mro()[0] == dict
77+
else typing.get_args(tag.iterable._var_type)[0]
7278
)
7379
except Exception:
7480
type_ = Any
@@ -77,7 +83,7 @@ def render(self):
7783
_var_type=type_,
7884
)
7985
index_arg = tag.get_index_var_arg()
80-
component = tag.render_component(self.render_fn, arg)
86+
component = tag.render_component(arg)
8187
return dict(
8288
tag.add_props(
8389
**self.event_triggers,

reflex/components/tags/iter_tag.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,6 @@
1111
from reflex.components.component import Component
1212

1313

14-
INDEX_VAR = "i"
15-
16-
1714
class IterTag(Tag):
1815
"""An iterator tag."""
1916

@@ -23,37 +20,40 @@ class IterTag(Tag):
2320
# The component render function for each item in the iterable.
2421
render_fn: Callable
2522

26-
@staticmethod
27-
def get_index_var() -> Var:
28-
"""Get the index var for the tag.
23+
# The name of the index var.
24+
index_var_name: str = "i"
25+
26+
def get_index_var(self) -> Var:
27+
"""Get the index var for the tag (with curly braces).
28+
29+
This is used to reference the index var within the tag.
2930
3031
Returns:
3132
The index var.
3233
"""
3334
return BaseVar(
34-
_var_name=INDEX_VAR,
35+
_var_name=self.index_var_name,
3536
_var_type=int,
3637
)
3738

38-
@staticmethod
39-
def get_index_var_arg() -> Var:
40-
"""Get the index var for the tag.
39+
def get_index_var_arg(self) -> Var:
40+
"""Get the index var for the tag (without curly braces).
41+
42+
This is used to render the index var in the .map() function.
4143
4244
Returns:
4345
The index var.
4446
"""
4547
return BaseVar(
46-
_var_name=INDEX_VAR,
48+
_var_name=self.index_var_name,
4749
_var_type=int,
4850
_var_is_local=True,
4951
)
5052

51-
@staticmethod
52-
def render_component(render_fn: Callable, arg: Var) -> Component:
53+
def render_component(self, arg: Var) -> Component:
5354
"""Render the component.
5455
5556
Args:
56-
render_fn: The render function.
5757
arg: The argument to pass to the render function.
5858
5959
Returns:
@@ -65,16 +65,16 @@ def render_component(render_fn: Callable, arg: Var) -> Component:
6565
from reflex.components.layout.fragment import Fragment
6666

6767
# Get the render function arguments.
68-
args = inspect.getfullargspec(render_fn).args
69-
index = IterTag.get_index_var()
68+
args = inspect.getfullargspec(self.render_fn).args
69+
index = self.get_index_var()
7070

7171
if len(args) == 1:
7272
# If the render function doesn't take the index as an argument.
73-
component = render_fn(arg)
73+
component = self.render_fn(arg)
7474
else:
7575
# If the render function takes the index as an argument.
7676
assert len(args) == 2
77-
component = render_fn(arg, index)
77+
component = self.render_fn(arg, index)
7878

7979
# Nested foreach components or cond must be wrapped in fragments.
8080
if isinstance(component, (Foreach, Cond)):

tests/components/layout/test_foreach.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,9 @@ def display_nested_list_element(element: str, index: int):
7878
return box(text(element[index]))
7979

8080

81+
seen_index_vars = set()
82+
83+
8184
@pytest.mark.parametrize(
8285
"state_var, render_fn, render_dict",
8386
[
@@ -86,7 +89,6 @@ def display_nested_list_element(element: str, index: int):
8689
display_color,
8790
{
8891
"iterable_state": "for_each_state.colors_list",
89-
"arg_index": "i",
9092
"iterable_type": "list",
9193
},
9294
),
@@ -95,7 +97,6 @@ def display_nested_list_element(element: str, index: int):
9597
display_color_name,
9698
{
9799
"iterable_state": "for_each_state.colors_dict_list",
98-
"arg_index": "i",
99100
"iterable_type": "list",
100101
},
101102
),
@@ -104,7 +105,6 @@ def display_nested_list_element(element: str, index: int):
104105
display_shade,
105106
{
106107
"iterable_state": "for_each_state.colors_nested_dict_list",
107-
"arg_index": "i",
108108
"iterable_type": "list",
109109
},
110110
),
@@ -113,7 +113,6 @@ def display_nested_list_element(element: str, index: int):
113113
display_primary_colors,
114114
{
115115
"iterable_state": "for_each_state.primary_color",
116-
"arg_index": "i",
117116
"iterable_type": "dict",
118117
},
119118
),
@@ -122,7 +121,6 @@ def display_nested_list_element(element: str, index: int):
122121
display_color_with_shades,
123122
{
124123
"iterable_state": "for_each_state.color_with_shades",
125-
"arg_index": "i",
126124
"iterable_type": "dict",
127125
},
128126
),
@@ -131,7 +129,6 @@ def display_nested_list_element(element: str, index: int):
131129
display_nested_color_with_shades,
132130
{
133131
"iterable_state": "for_each_state.nested_colors_with_shades",
134-
"arg_index": "i",
135132
"iterable_type": "dict",
136133
},
137134
),
@@ -140,7 +137,6 @@ def display_nested_list_element(element: str, index: int):
140137
display_nested_color_with_shades_v2,
141138
{
142139
"iterable_state": "for_each_state.nested_colors_with_shades",
143-
"arg_index": "i",
144140
"iterable_type": "dict",
145141
},
146142
),
@@ -149,7 +145,6 @@ def display_nested_list_element(element: str, index: int):
149145
display_color_tuple,
150146
{
151147
"iterable_state": "for_each_state.color_tuple",
152-
"arg_index": "i",
153148
"iterable_type": "tuple",
154149
},
155150
),
@@ -158,7 +153,6 @@ def display_nested_list_element(element: str, index: int):
158153
display_colors_set,
159154
{
160155
"iterable_state": "for_each_state.colors_set",
161-
"arg_index": "i",
162156
"iterable_type": "set",
163157
},
164158
),
@@ -167,7 +161,6 @@ def display_nested_list_element(element: str, index: int):
167161
lambda el, i: display_nested_list_element(el, i),
168162
{
169163
"iterable_state": "for_each_state.nested_colors_list",
170-
"arg_index": "i",
171164
"iterable_type": "list",
172165
},
173166
),
@@ -184,8 +177,11 @@ def test_foreach_render(state_var, render_fn, render_dict):
184177
component = Foreach.create(state_var, render_fn)
185178

186179
rend = component.render()
187-
arg_index = rend["arg_index"]
188180
assert rend["iterable_state"] == render_dict["iterable_state"]
189-
assert arg_index._var_name == render_dict["arg_index"]
190-
assert arg_index._var_type == int
191181
assert rend["iterable_type"] == render_dict["iterable_type"]
182+
183+
# Make sure the index vars are unique.
184+
arg_index = rend["arg_index"]
185+
assert arg_index._var_name not in seen_index_vars
186+
assert arg_index._var_type == int
187+
seen_index_vars.add(arg_index._var_name)

0 commit comments

Comments
 (0)