Skip to content

Commit dc2ba1b

Browse files
authored
Merge pull request #185 from tancheng/reduce_cmd
Enable global reduce for multi-cgra vectorized kernel FIR
2 parents ca33d73 + 1d0a663 commit dc2ba1b

52 files changed

Lines changed: 3495 additions & 588 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

cgra/test/CgraRTL_fir_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,16 @@ def line_trace(s):
313313
// 17 * 19 +
314314
// 18 * 20 +
315315
// 19 * 21
316+
// = 168 +
317+
// 195 +
318+
// 224 +
319+
// 255 +
320+
// 288 +
321+
// 323 +
322+
// 360 +
323+
// 399
324+
// = 842 +
325+
// 1370
316326
// = 2212
317327
// expected sum = 2212 + 3 = 2215 (0x8a7)
318328
'''

controller/ControllerRTL.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from ..noc.PyOCN.pymtl3_net.channel.ChannelRTL import ChannelRTL
1919
from ..noc.PyOCN.pymtl3_net.xbar.XbarBypassQueueRTL import XbarBypassQueueRTL
2020

21+
from .GlobalReduceUnitRTL import GlobalReduceUnitRTL
22+
2123
class ControllerRTL(Component):
2224

2325
def construct(s,
@@ -79,6 +81,11 @@ def construct(s,
7981
s.recv_from_cpu_pkt_queue = NormalQueueRTL(IntraCgraPktType)
8082
s.send_to_cpu_pkt_queue = NormalQueueRTL(IntraCgraPktType)
8183

84+
# Global reduce unit.
85+
# TODO: We need multiple GlobalReduceUnitRTL to enable more than 1 reduction
86+
# across the fabric: https://github.com/tancheng/VectorCGRA/issues/184.
87+
s.global_reduce_unit = GlobalReduceUnitRTL(DataType, InterCgraPktType, ControllerXbarPktType)
88+
8289
# LUT for global data address mapping.
8390
addr_offset_nbits = 0
8491
s.addr2controller_lut = [Wire(CgraIdType) for _ in range(len(controller2addr_map))]
@@ -128,6 +135,7 @@ def update_received_msg():
128135
kStoreRequestInportIdx = 2
129136
kFromCpuCtrlAndDataIdx = 3
130137
kFromInterTileRingIdx = 4
138+
kFromReduceUnitIdx = 5
131139

132140
s.send_to_cpu_pkt_queue.recv.val @= 0
133141
s.send_to_cpu_pkt_queue.recv.msg @= IntraCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
@@ -177,6 +185,12 @@ def update_received_msg():
177185
ControllerXbarPktType(0, # dst (always 0 to align with the single outport of the crossbar, i.e., NoC)
178186
s.recv_from_tile_load_response_pkt_queue.send.msg)
179187

188+
# For the reduce response (i.e., the reduced result) from the global reduce unit.
189+
s.crossbar.recv[kFromReduceUnitIdx].val @= \
190+
s.global_reduce_unit.send.val
191+
s.global_reduce_unit.send.rdy @= s.crossbar.recv[kFromReduceUnitIdx].rdy
192+
s.crossbar.recv[kFromReduceUnitIdx].msg @= s.global_reduce_unit.send.msg
193+
180194
# For the ctrl and data preloading.
181195
s.crossbar.recv[kFromCpuCtrlAndDataIdx].val @= \
182196
s.recv_from_cpu_pkt_queue.send.val
@@ -214,6 +228,10 @@ def update_received_msg():
214228
s.recv_from_inter_cgra_noc.rdy @= 0
215229
s.send_to_ctrl_ring_pkt.val @= 0
216230
s.send_to_ctrl_ring_pkt.msg @= IntraCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
231+
s.global_reduce_unit.recv_count.val @= 0
232+
s.global_reduce_unit.recv_count.msg @= InterCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
233+
s.global_reduce_unit.recv_data.val @= 0
234+
s.global_reduce_unit.recv_data.msg @= InterCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
217235

218236
# For the load request from NoC.
219237
received_pkt = s.recv_from_inter_cgra_noc.msg
@@ -274,13 +292,25 @@ def update_received_msg():
274292
0, # vc_id
275293
s.recv_from_inter_cgra_noc.msg.payload)
276294

295+
elif s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_GLOBAL_REDUCE_ADD:
296+
s.recv_from_inter_cgra_noc.rdy @= s.global_reduce_unit.recv_data.rdy
297+
s.global_reduce_unit.recv_data.val @= 1
298+
s.global_reduce_unit.recv_data.msg @= s.recv_from_inter_cgra_noc.msg
299+
300+
elif s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_GLOBAL_REDUCE_COUNT:
301+
s.recv_from_inter_cgra_noc.rdy @= s.global_reduce_unit.recv_count.rdy
302+
s.global_reduce_unit.recv_count.val @= 1
303+
s.global_reduce_unit.recv_count.msg @= s.recv_from_inter_cgra_noc.msg
304+
277305
elif (s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_CONFIG) | \
278306
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_CONFIG_PROLOGUE_FU) | \
279307
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_CONFIG_PROLOGUE_FU_CROSSBAR) | \
280308
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_CONFIG_PROLOGUE_ROUTING_CROSSBAR) | \
281309
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_CONFIG_TOTAL_CTRL_COUNT) | \
282310
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_CONFIG_COUNT_PER_ITER) | \
283311
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_CONST) | \
312+
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_GLOBAL_REDUCE_ADD_RESPONSE) | \
313+
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_GLOBAL_REDUCE_MUL_RESPONSE) | \
284314
(s.recv_from_inter_cgra_noc.msg.payload.cmd == CMD_LAUNCH):
285315
s.recv_from_inter_cgra_noc.rdy @= s.send_to_ctrl_ring_pkt.rdy
286316
s.send_to_ctrl_ring_pkt.val @= s.recv_from_inter_cgra_noc.val

controller/GlobalReduceUnitRTL.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
'''
2+
==========================================================================
3+
GlobalReduceUnitRTL.py
4+
==========================================================================
5+
A global reduce unit that records how many data items need to be reduced
6+
and receives the corresponding data. The unit will send the reduced data
7+
back to the controller.
8+
9+
Author : Cheng Tan
10+
Date : Sep 8, 2025
11+
'''
12+
13+
from ..lib.basic.val_rdy.ifcs import RecvIfcRTL
14+
from ..lib.basic.val_rdy.ifcs import SendIfcRTL
15+
from ..lib.basic.val_rdy.queues import NormalQueueRTL
16+
from ..lib.cmd_type import *
17+
18+
from pymtl3 import *
19+
20+
class GlobalReduceUnitRTL(Component):
  '''
  Collects a configurable number of global-reduce operands and answers
  every buffered request with the reduced value.

  The number of operands to wait for is latched from `recv_count`.
  Operands arriving on `recv_data` are buffered in `s.queue` and folded
  into a running sum (CMD_GLOBAL_REDUCE_ADD) or a running product
  (CMD_GLOBAL_REDUCE_MUL). Once as many operands as requested have been
  received, one response per buffered request is emitted on `send`,
  carrying the reduced value with the request's src/dst fields swapped.
  '''

  def construct(s, DataType, InterCgraPktType, ControllerXbarPktType,
                num_queue_entries = 16):
    '''
    Args:
      DataType: Data bitstruct, built here as
        DataType(payload, predicate, 0, 0).
      InterCgraPktType: Packet type of `recv_count`, `recv_data`, and of
        the buffered requests.
      ControllerXbarPktType: Packet type of `send`; the response is
        written into its `inter_cgra_pkt` field.
      num_queue_entries: Forwarded as the second argument of the request
        queue (presumably its capacity -- confirm against
        NormalQueueRTL). Defaults to 16, the previously hard-coded value.
        NOTE(review): requests are only drained after all operands of a
        reduction have been received, so a reduction with more operands
        than the queue can hold would presumably stall; size the queue
        for the largest expected reduction.
    '''

    # Interfaces.
    s.recv_count = RecvIfcRTL(InterCgraPktType)
    s.recv_data = RecvIfcRTL(InterCgraPktType)
    s.send = SendIfcRTL(ControllerXbarPktType)

    # Components.
    # Buffers the requests so that each one can be answered individually.
    s.queue = NormalQueueRTL(InterCgraPktType, num_queue_entries)
    # Number of operands expected per reduction (kept across reductions
    # until a new count arrives or the unit is reset).
    s.target_count = Wire(DataType)
    # Number of operands accepted so far in the current reduction.
    s.receiving_count = Wire(DataType)
    # Number of responses sent so far in the current reduction.
    s.sending_count = Wire(DataType)
    # Running sum / product of the accepted operands.
    s.reduce_add_value = Wire(DataType)
    s.reduce_mul_value = Wire(DataType)

    # Connections.
    # A new count can always be accepted.
    s.recv_count.rdy //= 1

    # Accepts operands into the queue only while fewer than target_count
    # operands have been received.
    @update
    def set_recv_rdy():
      s.recv_data.rdy @= 0
      s.queue.recv.val @= 0
      s.queue.recv.msg @= InterCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
      if s.target_count.payload > s.receiving_count.payload:
        s.recv_data.rdy @= s.queue.recv.rdy
        s.queue.recv.msg @= s.recv_data.msg
        s.queue.recv.val @= s.recv_data.val

    # Tracks the expected, received, and sent counts.
    @update_ff
    def update_count():
      if s.reset:
        s.target_count <<= DataType(0, 0, 0, 0)
        s.receiving_count <<= DataType(0, 0, 0, 0)
        s.sending_count <<= DataType(0, 0, 0, 0)
      else:
        if s.recv_count.val & s.recv_count.rdy:
          s.target_count <<= DataType(s.recv_count.msg.payload.data.payload, 0, 0, 0)
        if s.recv_data.val & s.recv_data.rdy:
          s.receiving_count <<= DataType(s.receiving_count.payload + 1, 0, 0, 0)
        if s.send.rdy & s.send.val:
          s.sending_count <<= DataType(s.sending_count.payload + 1, 0, 0, 0)
        elif (s.sending_count == s.receiving_count) & \
             (s.sending_count == s.target_count) & \
             (s.target_count.payload > 0):
          # Every response of this reduction has been sent: starts over
          # while keeping target_count for the next reduction.
          s.sending_count <<= DataType(0, 0, 0, 0)
          s.receiving_count <<= DataType(0, 0, 0, 0)

    # Emits one response per buffered request once all operands arrived.
    @update
    def update_send():
      s.send.msg @= ControllerXbarPktType(0, 0)
      s.send.val @= 0
      s.queue.send.rdy @= 0
      if (s.target_count.payload > 0) & (s.receiving_count.payload == s.target_count.payload):
        # Updates the cmd type, result value, and src/dst.
        if s.queue.send.msg.payload.cmd == CMD_GLOBAL_REDUCE_ADD:
          s.send.msg.inter_cgra_pkt.payload.cmd @= CMD_GLOBAL_REDUCE_ADD_RESPONSE
          s.send.msg.inter_cgra_pkt.payload.data @= s.reduce_add_value
        elif s.queue.send.msg.payload.cmd == CMD_GLOBAL_REDUCE_MUL:
          s.send.msg.inter_cgra_pkt.payload.cmd @= CMD_GLOBAL_REDUCE_MUL_RESPONSE
          s.send.msg.inter_cgra_pkt.payload.data @= s.reduce_mul_value
        # The response travels back to the requester, hence src and dst
        # of the buffered request are swapped.
        s.send.msg.inter_cgra_pkt.src @= s.queue.send.msg.dst
        s.send.msg.inter_cgra_pkt.dst @= s.queue.send.msg.src
        s.send.msg.inter_cgra_pkt.src_x @= s.queue.send.msg.dst_x
        s.send.msg.inter_cgra_pkt.src_y @= s.queue.send.msg.dst_y
        s.send.msg.inter_cgra_pkt.dst_x @= s.queue.send.msg.src_x
        s.send.msg.inter_cgra_pkt.dst_y @= s.queue.send.msg.src_y
        s.send.msg.inter_cgra_pkt.src_tile_id @= s.queue.send.msg.dst_tile_id
        s.send.msg.inter_cgra_pkt.dst_tile_id @= s.queue.send.msg.src_tile_id
        s.queue.send.rdy @= s.send.rdy
        s.send.val @= s.queue.send.val

    # Maintains the running sum and product of the accepted operands.
    @update_ff
    def accumulate_value():
      # Back to the identity elements (0 for add, 1 for mul) on reset and
      # whenever as many responses as target_count have been sent.
      if s.reset | (s.sending_count == s.target_count):
        s.reduce_add_value <<= DataType(0, 0, 0, 0)
        s.reduce_mul_value <<= DataType(1, 0, 0, 0)
      else:
        if s.recv_data.val & \
           s.recv_data.rdy:
          # The predicate of the accumulated value is taken from the most
          # recently accepted operand.
          if s.recv_data.msg.payload.cmd == CMD_GLOBAL_REDUCE_ADD:
            s.reduce_add_value <<= DataType(s.reduce_add_value.payload + s.recv_data.msg.payload.data.payload,
                                            s.recv_data.msg.payload.data.predicate,
                                            0,
                                            0)
          elif s.recv_data.msg.payload.cmd == CMD_GLOBAL_REDUCE_MUL:
            s.reduce_mul_value <<= DataType(s.reduce_mul_value.payload * s.recv_data.msg.payload.data.payload,
                                            s.recv_data.msg.payload.data.predicate,
                                            0,
                                            0)

  def line_trace( s ):
    '''Returns a one-line summary of both inputs, the received count, and the output.'''
    input_str = 'count:' + str(s.recv_count) + ', data:' + str(s.recv_data) + ", receiving_count:" + str(s.receiving_count)
    output_str = 'out:'+str(s.send)
    return f'{input_str}(){output_str}'
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
'''
2+
=========================================================================
3+
GlobalReduceUnitRTL_test.py
4+
=========================================================================
5+
Simple test for GlobalReduceUnitRTL.
6+
7+
Author : Cheng Tan
8+
Date : Sep 8, 2025
9+
'''
10+
11+
from pymtl3.stdlib.test_utils import config_model_with_cmdline_opts
12+
13+
from ..GlobalReduceUnitRTL import GlobalReduceUnitRTL
14+
from ...lib.basic.val_rdy.SinkRTL import SinkRTL as TestSinkRTL
15+
from ...lib.basic.val_rdy.SourceRTL import SourceRTL as TestSrcRTL
16+
from ...lib.cmd_type import *
17+
from ...lib.messages import *
18+
from ...lib.opt_type import *
19+
20+
#-------------------------------------------------------------------------
21+
# TestHarness
22+
#-------------------------------------------------------------------------
23+
24+
class TestHarness(Component):
  '''
  Wires two test sources (count and data) and one test sink around a
  GlobalReduceUnitRTL instance.
  '''

  def construct(s, DataType, InterCgraPktType, ControllerXbarPktType,
                input_count, input_data, expected_output):
    '''
    Args:
      DataType, InterCgraPktType, ControllerXbarPktType: Types forwarded
        to the unit under test.
      input_count: Packets fed into the unit's `recv_count` port.
      input_data: Packets fed into the unit's `recv_data` port.
      expected_output: Packets the sink is given for the unit's `send`
        port.
    '''

    # Stimulus: one source per input port of the unit.
    s.src_count = TestSrcRTL(InterCgraPktType, input_count)
    s.src_data = TestSrcRTL(InterCgraPktType, input_data)

    # Checker for the reduced responses.
    s.sink = TestSinkRTL(ControllerXbarPktType, expected_output)

    # Design under test.
    s.dut = GlobalReduceUnitRTL(DataType,
                                InterCgraPktType,
                                ControllerXbarPktType)

    # Connections
    s.dut.recv_count //= s.src_count.send
    s.dut.recv_data //= s.src_data.send
    s.dut.send //= s.sink.recv

  def done(s):
    '''Reports whether both sources and the sink are done.'''
    return (s.src_count.done() and
            s.src_data.done() and
            s.sink.done())

  def line_trace(s):
    '''Delegates to the line trace of the unit under test.'''
    return s.dut.line_trace()
50+
51+
#-------------------------------------------------------------------------
52+
# run_rtl_sim
53+
#-------------------------------------------------------------------------
54+
55+
def run_sim(test_harness, max_cycles = 100):
  '''
  Simulates `test_harness` until it reports done or `max_cycles` elapse.

  Prints the harness line trace once after reset and once after every
  simulated cycle, then ticks three extra cycles after completion.

  Args:
    test_harness: Harness providing elaborate/apply/sim_reset/sim_tick
      as well as done() and line_trace().
    max_cycles: Upper bound on the number of simulated cycles.

  Raises:
    AssertionError: If the harness is not done after `max_cycles` cycles.
  '''

  # Creates a simulator.
  test_harness.elaborate()
  test_harness.apply(DefaultPassGroup())
  test_harness.sim_reset()

  # Runs simulation.
  ncycles = 0
  print()
  print(f"{ncycles}:{test_harness.line_trace()}")
  while not test_harness.done() and ncycles < max_cycles:
    test_harness.sim_tick()
    ncycles += 1
    print(f"{ncycles}:{test_harness.line_trace()}")

  # Checks timeout. Asserting on done() rather than on the cycle count
  # avoids a spurious failure when the harness finishes exactly on the
  # last allowed cycle.
  assert test_harness.done(), \
         f"simulation did not finish within {max_cycles} cycles"

  # Runs a few extra cycles after completion.
  for _ in range(3):
    test_harness.sim_tick()
77+
78+
#-------------------------------------------------------------------------
79+
# Test cases
80+
#-------------------------------------------------------------------------
81+
82+
def test_simple(cmdline_opts):
  '''
  Feeds a reduce count of 3 followed by three groups of three operands
  (two ADD groups, one MUL group) and expects, per group, three responses
  carrying the reduced value (12, 15, and 105) with src/dst reversed.
  '''
  data_nbits = 32
  predicate_nbits = 1

  num_cgra_columns = 4
  num_cgra_rows = 1
  num_tiles = 4
  num_rd_tiles = 3
  ctrl_mem_size = 16
  num_fu_inports = 2
  num_fu_outports = 2
  num_tile_inports = 4
  num_tile_outports = 4
  data_mem_size_global = 16
  addr_nbits = clog2(data_mem_size_global)
  num_registers_per_reg_bank = 16

  DataType = mk_data(data_nbits, predicate_nbits)
  DataAddrType = mk_bits(addr_nbits)

  CtrlType = mk_ctrl(num_fu_inports,
                     num_fu_outports,
                     num_tile_inports,
                     num_tile_outports,
                     num_registers_per_reg_bank)

  CtrlAddrType = mk_bits(clog2(ctrl_mem_size))

  CgraPayloadType = mk_cgra_payload(DataType,
                                    DataAddrType,
                                    CtrlType,
                                    CtrlAddrType)

  InterCgraPktType = mk_inter_cgra_pkt(num_cgra_columns,
                                       num_cgra_rows,
                                       num_tiles,
                                       num_rd_tiles,
                                       CgraPayloadType)

  ControllerXbarPktType = mk_controller_noc_xbar_pkt(InterCgraPktType)

  # A single count packet: every reduction below waits for 3 operands.
  input_count = [
      InterCgraPktType(payload = CgraPayloadType(CMD_GLOBAL_REDUCE_COUNT, data = DataType(3, 0, 0, 0))),
  ]

  input_data = [
      #                src dst src_x src_y dst_x dst_y src_tile_id dst_tile_id
      # First ADD group: 2 + 4 + 6 = 12.
      InterCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD, data = DataType(2, 1, 0, 0))),
      InterCgraPktType(0, 1, 0, 0, 1, 0, 2, 4, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD, data = DataType(4, 1, 0, 0))),
      InterCgraPktType(0, 2, 0, 0, 2, 0, 3, 4, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD, data = DataType(6, 1, 0, 0))),
      # Second ADD group: 3 + 5 + 7 = 15.
      InterCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD, data = DataType(3, 1, 0, 0))),
      InterCgraPktType(0, 1, 0, 0, 1, 0, 2, 4, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD, data = DataType(5, 1, 0, 0))),
      InterCgraPktType(0, 2, 0, 0, 2, 0, 3, 4, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD, data = DataType(7, 1, 0, 0))),
      # MUL group: 3 * 5 * 7 = 105.
      InterCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_MUL, data = DataType(3, 1, 0, 0))),
      InterCgraPktType(0, 1, 0, 0, 1, 0, 2, 4, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_MUL, data = DataType(5, 1, 0, 0))),
      InterCgraPktType(0, 2, 0, 0, 2, 0, 3, 4, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_MUL, data = DataType(7, 1, 0, 0))),
  ]

  expected_output = [
      # Reversed src/dst.
      ControllerXbarPktType(inter_cgra_pkt = InterCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD_RESPONSE, data = DataType(12, 1, 0, 0)))),
      ControllerXbarPktType(inter_cgra_pkt = InterCgraPktType(1, 0, 1, 0, 0, 0, 4, 2, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD_RESPONSE, data = DataType(12, 1, 0, 0)))),
      ControllerXbarPktType(inter_cgra_pkt = InterCgraPktType(2, 0, 2, 0, 0, 0, 4, 3, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD_RESPONSE, data = DataType(12, 1, 0, 0)))),
      ControllerXbarPktType(inter_cgra_pkt = InterCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD_RESPONSE, data = DataType(15, 1, 0, 0)))),
      ControllerXbarPktType(inter_cgra_pkt = InterCgraPktType(1, 0, 1, 0, 0, 0, 4, 2, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD_RESPONSE, data = DataType(15, 1, 0, 0)))),
      ControllerXbarPktType(inter_cgra_pkt = InterCgraPktType(2, 0, 2, 0, 0, 0, 4, 3, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_ADD_RESPONSE, data = DataType(15, 1, 0, 0)))),
      ControllerXbarPktType(inter_cgra_pkt = InterCgraPktType(0, 0, 0, 0, 0, 0, 0, 0, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_MUL_RESPONSE, data = DataType(105, 1, 0, 0)))),
      ControllerXbarPktType(inter_cgra_pkt = InterCgraPktType(1, 0, 1, 0, 0, 0, 4, 2, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_MUL_RESPONSE, data = DataType(105, 1, 0, 0)))),
      ControllerXbarPktType(inter_cgra_pkt = InterCgraPktType(2, 0, 2, 0, 0, 0, 4, 3, payload = CgraPayloadType(CMD_GLOBAL_REDUCE_MUL_RESPONSE, data = DataType(105, 1, 0, 0)))),
  ]

  th = TestHarness(DataType,
                   InterCgraPktType,
                   ControllerXbarPktType,
                   input_count,
                   input_data,
                   expected_output)
  th.elaborate()
  th = config_model_with_cmdline_opts(th, cmdline_opts, duts = ['dut'])
  run_sim(th)

0 commit comments

Comments
 (0)