Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
466f775
add test for synapse post neuron transformer
Mar 5, 2026
d487778
add test for synapse post neuron transformer
Mar 5, 2026
8bf5c83
Merge remote-tracking branch 'clinssen/test_synapse_post_neuron_trans…
Mar 5, 2026
8b95377
add test for synapse post neuron transformer
Mar 19, 2026
8b082dd
test synapse post neuron transformer
Mar 19, 2026
1983993
replace custom NEST is_type with std::holds_alternative
Mar 27, 2026
cf1c824
Support for downstream NEST with std::variant (#24)
pnbabu Mar 30, 2026
9deadbd
Merge remote-tracking branch 'upstream/main' into holds_alternative
Mar 31, 2026
01ad099
Merge remote-tracking branch 'clinssen/holds_alternative' into holds_…
Mar 31, 2026
4039733
Merge remote-tracking branch 'upstream/main' into holds_alternative
Apr 2, 2026
f45d9aa
Merge remote-tracking branch 'upstream/main' into test_synapse_post_n…
Apr 3, 2026
702c0e6
Merge remote-tracking branch 'clinssen/holds_alternative' into test_s…
Apr 3, 2026
a1fada2
synapse post transformer wip ??
Apr 3, 2026
a413948
wip
Apr 4, 2026
4d95204
wip
Apr 5, 2026
d0ff7d4
fix and add test for synapse post transformer
Apr 5, 2026
26b81af
fix and add test for synapse post transformer
Apr 5, 2026
eb7ce54
fix and add test for synapse post transformer
Apr 5, 2026
a975556
Merge remote-tracking branch 'clinssen/test_synapse_post_neuron_trans…
Apr 6, 2026
efd2548
Merge remote-tracking branch 'clinssen/holds_alternative' into test_s…
Apr 6, 2026
1c0afd5
fix and add test for synapse post transformer
Apr 6, 2026
e402540
fix and add test for synapse post transformer
Apr 9, 2026
8f6923f
fix and add test for synapse post transformer
Apr 9, 2026
8378b22
Merge remote-tracking branch 'upstream/main' into test_synapse_post_n…
Apr 12, 2026
a816d44
fix and add test for synapse post transformer
Apr 13, 2026
e25a051
Merge remote-tracking branch 'upstream/main' into test_synapse_post_n…
Apr 14, 2026
c4dd67c
updates for new NEST API
Apr 14, 2026
b1f93f7
fix and add test for synapse post transformer
Apr 15, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions doc/running/running_nest.rst
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,12 @@ During code generation, the third-factor variable of the synapse and its corresp

This specifies that the neuron ``iaf_psc_exp_dend`` has to be generated paired with the synapse ``third_factor_stdp``, and that the input ports ``post_spikes`` and ``I_post_dend`` in the synapse are to be connected to the postsynaptic partner. For the ``I_post_dend`` input port, the corresponding variable in the (postsynaptic) neuron is called ``I_dend``. Note that inline expressions can also be used; in this example in case ``I_dend`` had been an inline expression in the postsynaptic neuron.

To prevent the NESTML code generator from moving specific variables from synapse into postsynaptic neuron, the code generation option ``strictly_synaptic_vars`` may be used (see https://nestml.readthedocs.io/en/latest/pynestml.transformers.html#pynestml.transformers.synapse_post_neuron_transformer.SynapsePostNeuronTransformer). Note that the paired-generation can cause subtle changes in the order in which variables are updated. In particular, the numerical value obtained for any moved variables at a time of a spike is always the value "just before" the update due to the spike. Please see the unit test `test_synapse_post_neuron_transformer_update_order.py <https://github.com/nest/nestml/blob/master/tests/nest_tests/test_synapse_post_neuron_transformer_update_order.py>`_.

.. warning::

To ensure correct and reproducible results, it is recommended to generate code initially with ``strictly_synaptic_vars`` marking all synaptic variables. Verify numerical results by hand or using a small, hand-written reference implementation of the model. Only afterwards, remove ``strictly_synaptic_vars``, which results in much more efficient code and better runtime performance, and again validate the results.

When a continuous-time input port is defined in the synapse model which is connected to a postsynaptic neuron, a corresponding buffer is allocated in each neuron which retains the recent history of the needed state variables. Two options are available for how the buffer is implemented: a "continuous-time" based buffer, or a spike-based buffer (see the NEST code generator option ``continuous_state_buffering_method`` on :class:`pynestml.codegeneration.html#pynestml.codegeneration.nest_code_generator.NESTCodeGenerator`).

By default, the "continuous-time" based buffer is selected. This covers the most general case of different synaptic delay values and a discontinuous third-factor signal. The implementation corresponds to the event-based update scheme in Fig. 4b of [Stapmanns2021]_. There, the authors observe that the storage and management of such a buffer can be expensive in terms of memory and runtime. In each time step, the value of the current dendritic current (or membrane potential, or other third factor) is appended to the buffer. The maximum length of the buffer depends on the maximum inter-spike interval of any of the presynaptic neurons.
Expand Down
34 changes: 32 additions & 2 deletions pynestml/transformers/synapse_post_neuron_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,19 @@
# Fallback for Python 3.8 - 3.11
from typing_extensions import override

from pynestml.cocos.co_cos_manager import CoCosManager
from pynestml.codegeneration.code_generator_utils import CodeGeneratorUtils
from pynestml.frontend.frontend_configuration import FrontendConfiguration
from pynestml.meta_model.ast_assignment import ASTAssignment
from pynestml.meta_model.ast_compound_stmt import ASTCompoundStmt
from pynestml.meta_model.ast_equations_block import ASTEquationsBlock
from pynestml.meta_model.ast_for_stmt import ASTForStmt
from pynestml.meta_model.ast_if_stmt import ASTIfStmt
from pynestml.meta_model.ast_inline_expression import ASTInlineExpression
from pynestml.meta_model.ast_model import ASTModel
from pynestml.meta_model.ast_node import ASTNode
from pynestml.meta_model.ast_simple_expression import ASTSimpleExpression
from pynestml.meta_model.ast_variable import ASTVariable
from pynestml.symbols.predefined_variables import PredefinedVariables
from pynestml.meta_model.ast_while_stmt import ASTWhileStmt
from pynestml.symbols.symbol import SymbolKind
from pynestml.symbols.variable_symbol import BlockType
from pynestml.transformers.transformer import Transformer
Expand Down Expand Up @@ -186,6 +188,34 @@ def transform_neuron_synapse_pair_(self, neuron: ASTModel, synapse: ASTModel, me
if self.option_exists("weight_variable") and removesuffix(synapse.get_name(), FrontendConfiguration.suffix) in self.get_option("weight_variable").keys() and self.get_option("weight_variable")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)]:
strictly_synaptic_vars.add(self.get_option("weight_variable")[removesuffix(synapse.get_name(), FrontendConfiguration.suffix)])

# exclude variables that are written to inside compound blocks
for input_block in new_synapse.get_input_blocks():
for port in input_block.get_input_ports():
if CodeGeneratorUtils.is_post_port(port.name, neuron.name, synapse.name, neuron_synapse_pairs=self._options["neuron_synapse_pairs"]):
post_receive_blocks = ASTUtils.get_on_receive_blocks_by_input_port_name(new_synapse, port.name)
for post_receive_block in post_receive_blocks:

class VariablesAssignedToInCompoundBlocksVisitor(ASTVisitor):
r"""Find variables assigned to in compound blocks"""

variables_assigned_to_in_compound_block: Set[str] = set()

def __init__(self):
super().__init__()
self.variables_assigned_to_in_compound_block: Set[str] = set()

def visit_small_stmt(self, node):
if node.is_assignment() \
and (ASTUtils.find_parent_node_by_type(node, ASTIfStmt)
or ASTUtils.find_parent_node_by_type(node, ASTWhileStmt)
or ASTUtils.find_parent_node_by_type(node, ASTForStmt)
or ASTUtils.find_parent_node_by_type(node, ASTCompoundStmt)):
self.variables_assigned_to_in_compound_block.add(node.get_assignment().get_variable().get_complete_name())

visitor = VariablesAssignedToInCompoundBlocksVisitor()
post_receive_block.accept(visitor)
strictly_synaptic_vars |= visitor.variables_assigned_to_in_compound_block

affected_vars = ASTUtils.collect_variables_affected_by_ports(synapse, post_port_names, strictly_synaptic_vars=strictly_synaptic_vars)
metadata[new_neuron.name]["syn_to_neuron_state_vars"] = [var for var in affected_vars if not (synapse.get_kernel_by_name(var) or neuron.get_kernel_by_name(var))]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,324 @@
# -*- coding: utf-8 -*-
#
# test_synapse_post_neuron_transformer_compound_blocks.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.

import numpy as np
import os
import pytest

# try to import matplotlib; set the result in the flag TEST_PLOTS
try:
import matplotlib as mpl
mpl.use("agg")
import matplotlib.pyplot as plt
TEST_PLOTS = True
except BaseException:
TEST_PLOTS = False

import nest

from pynestml.codegeneration.nest_tools import NESTTools
from pynestml.frontend.pynestml_frontend import generate_nest_target


def generate_regular_spike_train(rate, T, dt, start_t=0.):
"""
Generates regular spike train as a Boolean array.

Parameters:
- rate: Firing rate in Hz.
- T: Total simulation time in ms.
- dt: Time step in ms.

Returns:
- spike_train: Boolean array of length len(time), where True indicates a spike.
- spike_times: same data as spike time data
"""
# generate spike times
spikes = []
t = start_t
while t < T:
isi = 1 / rate * 1000
t += isi
if t < T:
spikes.append(t)

spike_times = np.array(spikes)

time = np.arange(0, T, dt)
spike_train = np.zeros(len(time), dtype=bool)

# map spike times to nearest indices in the time array; ensure that spike times correspond to time steps in "time"
indices = np.searchsorted(time, spike_times)
indices = indices[indices < len(time)]

spike_train[indices] = True

return spike_train, spike_times


@pytest.mark.skipif(NESTTools.detect_nest_version().startswith("v2"),
reason="This test does not support NEST 2")
class TestSynapsePostNeuronTransformerCompoundBlocks:
r"""This test checks that variables inside compound blocks are properly identified by the transformer as "strictly synaptic" variables."""
@pytest.fixture(autouse=True)
def generate_model_code(self, request):
r"""Generate the NEST C++ code for neuron and synapse models"""

files = [os.path.join("models", "neurons", "iaf_psc_delta_neuron.nestml"),
os.path.join("tests", "resources", "double_postsyn_trace_synapse.nestml")]
input_path = [os.path.realpath(os.path.join(os.path.dirname(__file__), os.path.join(os.pardir, os.pardir, s))) for s in files]

codegen_opts = {"neuron_synapse_pairs": [{"neuron": "iaf_psc_delta_neuron",
"synapse": "double_postsyn_trace_synapse",
"post_ports": ["post_spikes"]}],
"weight_variable": {"double_postsyn_trace_synapse": "w"}}

use_synapse_post_neuron_transformer = request.param

print("Now generating code with ``use_synapse_post_neuron_transformer`` = " + str(use_synapse_post_neuron_transformer))

if not use_synapse_post_neuron_transformer:
codegen_opts["strictly_synaptic_vars"] = {"double_postsyn_trace_synapse": ["zp_trace", "zm_trace"]}

generate_nest_target(input_path=input_path,
logging_level="DEBUG",
suffix="_nestml",
codegen_opts=codegen_opts)

@pytest.mark.parametrize("generate_model_code", (False, True), indirect=True)
@pytest.mark.parametrize("update_order_w_before_trace", [True, False])
def test_double_postsyn_trace_synapse(self, update_order_w_before_trace: bool):
neuron_model_name = "iaf_psc_delta_neuron_nestml__with_double_postsyn_trace_synapse_nestml"
synapse_model_name = "double_postsyn_trace_synapse_nestml__with_iaf_psc_delta_neuron_nestml"

nest.ResetKernel()
nest.Install("nestmlmodule")
nest.print_time = False
NESTTools.set_nest_verbosity("ERROR")

# define spike train
rate = 20
T = 500
dt = 0.1
syn_delay = dt
nest.resolution = dt

spike_train_pre_bool, spikes_pre = generate_regular_spike_train(rate, T, dt, start_t=15.)
spike_train_post_bool, spikes_post = generate_regular_spike_train(rate * 2, T, dt, start_t=2.5)

spikes_pre = np.round(spikes_pre, 1)
spikes_post = np.round(spikes_post, 1)

spikes_post -= syn_delay
spikes_pre -= syn_delay

p = {"Z": 0.4102089913,
"Z2": 0.11584585200000001,
"tau_post": 100,
"tau_post2": 63.297325105}

wr = nest.Create("weight_recorder")
syn_model = "stdp_stp_synapse_rec"
initial_w = 1.0
nest.CopyModel(synapse_model_name, syn_model, {"weight_recorder": wr,
"w": initial_w,
"delay": syn_delay,
"update_order_w_before_trace": update_order_w_before_trace,
"receptor_type": 0,
"Zp": p["Z2"],
"Zm": p["Z"],
"tau_zm": p["tau_post"],
"tau_zp": p["tau_post2"]})

spikes_pre_gr = nest.Create("spike_generator", params={"spike_times": spikes_pre})
spikes_post_gr = nest.Create("spike_generator", params={"spike_times": spikes_post})

pre_neuron = nest.Create("parrot_neuron")
post_neuron = nest.Create(neuron_model_name)

# set postsynaptic parameters
try:
post_neuron.Zp__for_double_postsyn_trace_synapse_nestml = p["Z2"]
post_neuron.Zm__for_double_postsyn_trace_synapse_nestml = p["Z"]
post_neuron.tau_zm__for_double_postsyn_trace_synapse_nestml = p["tau_post"]
post_neuron.tau_zp__for_double_postsyn_trace_synapse_nestml = p["tau_post2"]
except BaseException:
try:
post_neuron.Zp__for_minimal_SLSTDP_old_synapse_nestml = p["Z2"]
post_neuron.Zm__for_minimal_SLSTDP_old_synapse_nestml = p["Z"]
post_neuron.tau_zm__for_minimal_SLSTDP_old_synapse_nestml = p["tau_post"]
post_neuron.tau_zp__for_minimal_SLSTDP_old_synapse_nestml = p["tau_post2"]
except BaseException:
pass

nest.Connect(spikes_pre_gr, pre_neuron, "one_to_one", syn_spec={"delay": syn_delay})
nest.Connect(pre_neuron, post_neuron, "one_to_one", syn_spec={"synapse_model": syn_model})
nest.Connect(spikes_post_gr, post_neuron, "one_to_one", syn_spec={"delay": syn_delay, "weight": 9999.})

conn = nest.GetConnections(target=post_neuron, synapse_model=syn_model)
try:
conn.Zp = p["Z2"]
conn.Zm = p["Z"]
except BaseException:
pass

# spike detectors
spikedet_pre = nest.Create("spike_recorder")
spikedet_post = nest.Create("spike_recorder")
nest.Connect(pre_neuron, spikedet_pre)
nest.Connect(post_neuron, spikedet_post)

n_steps = int(np.ceil(T / syn_delay))
trace_nest = []
trace_nest_t = []
trace_nest_z2 = []

t = nest.biological_time
trace_nest_t.append(t)

try:
trace_nest.append(post_neuron.zm_trace__for_double_postsyn_trace_synapse_nestml)
trace_nest_z2.append(post_neuron.zp_trace__for_double_postsyn_trace_synapse_nestml)
except BaseException:
try:
trace_nest.append(post_neuron.zm_trace__for_minimal_SLSTDP_old_synapse_nestml)
trace_nest_z2.append(post_neuron.zp_trace__for_minimal_SLSTDP_old_synapse_nestml)
except BaseException:
trace_nest.append(conn.get("zm_trace"))
trace_nest_z2.append(conn.get("zp_trace"))

w_trace = []

w_trace.append(conn.get("weight"))
for step in range(n_steps):
nest.Simulate(syn_delay)
t = nest.biological_time

trace_nest_t.append(t)

conn = nest.GetConnections(target=post_neuron, synapse_model=syn_model)
# post
try:
trace_nest.append(post_neuron.zm_trace__for_double_postsyn_trace_synapse_nestml)
trace_nest_z2.append(post_neuron.zp_trace__for_double_postsyn_trace_synapse_nestml)
except BaseException:
try:
trace_nest.append(post_neuron.zm_trace__for_minimal_SLSTDP_old_synapse_nestml)
trace_nest_z2.append(post_neuron.zp_trace__for_minimal_SLSTDP_old_synapse_nestml)
except BaseException:
trace_nest.append(conn.get("zm_trace"))
trace_nest_z2.append(conn.get("zp_trace"))

w_trace.append(conn.get("weight"))

conn = nest.GetConnections(target=post_neuron, synapse_model=syn_model)

if TEST_PLOTS:
fig, axs = plt.subplots(3, 1)
events = wr.get("events")

axs[0].set_ylabel("w")
axs[0].grid(True)
axs[0].step(trace_nest_t, w_trace, color="black", where="post", label="NEST")

axs[1].plot(trace_nest_t, trace_nest, label="z1 nestml", color="red")
axs[1].plot(trace_nest_t, trace_nest_z2, label="z2 nestml", color="black")
axs[1].legend()
axs[1].grid(True)

axs[-1].scatter(spikedet_post.events["times"], np.zeros_like(spikedet_post.events["times"]), label="post sp")
axs[-1].scatter(spikedet_pre.events["times"], np.ones_like(spikedet_pre.events["times"]), label="pre sp")
axs[-1].set_xlim(axs[0].get_xlim())
axs[1].grid(True)

def spike_times_to_bool_array(spike_times, T, dt):
time = np.arange(0, T, dt)
spike_train = np.zeros(len(time), dtype=bool)

# map spike times to nearest indices in the time array; ensure that spike times correspond to time steps in "time"
indices = np.searchsorted(time, spike_times)
indices = indices[indices < len(time)]
spike_train[indices] = True

return spike_train

spike_train_pre = spikedet_pre.events["times"]
spike_train_post = spikedet_post.events["times"]

spike_train_pre_bool = spike_times_to_bool_array(spike_train_pre, T, dt)
spike_train_post_bool = spike_times_to_bool_array(spike_train_post, T, dt)

# Lists to record variables for plotting
w_history = []
z_history = []
z2_history = []

w = initial_w # Initial synaptic weight
z = 0.0 # Post-synaptic trace
z2 = 0.0

time = np.arange(0, T, dt)
for t in range(len(time)):
z *= np.exp(-dt / p["tau_post"])
z2 *= np.exp(-dt / p["tau_post2"])

if spike_train_pre_bool[t] and spike_train_post_bool[t]:
raise Exception()

if spike_train_post_bool[t]:
if update_order_w_before_trace:
w += z * z2 # Update synaptic weight
z += p["Z"] * (1 - z) # Increment post-synaptic trace
z2 += p["Z2"] * (1 - z2)
else:
z += p["Z"] * (1 - z) # Increment post-synaptic trace
z2 += p["Z2"] * (1 - z2)
w += z * z2 # Update synaptic weight

w_history.append(w)
z_history.append(z)
z2_history.append(z2)

if TEST_PLOTS:
axs[0].plot(time, w_history, "--", color="black", label="w python")
axs[0].legend()

axs[1].plot(time, z_history, "--", label=r"$z_1$ python", color="red")
axs[1].plot(time, z2_history, "--", label=r"$z_2$ python", color="black")
axs[1].legend()

plt.savefig("/tmp/test_synapse_post_neuron_transformer.png")

#
# testing
#

assert len(spike_train_pre) > 0
for pre_spike_time in spike_train_pre:
tidx_nest = np.argmin((pre_spike_time - trace_nest_t)**2)
w_according_to_nest = w_trace[tidx_nest + 1]

tidx_ref = np.argmin((pre_spike_time - time)**2)
w_according_to_ref = w_history[tidx_ref]

np.testing.assert_allclose(trace_nest_t[tidx_nest], time[tidx_ref])

np.testing.assert_allclose(w_according_to_nest, w_according_to_ref)
Loading
Loading