Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 41 additions & 12 deletions pynestml/codegeneration/nest_code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple

import datetime
import inspect
import re

import odetoolbox
Expand Down Expand Up @@ -98,6 +99,8 @@ class NESTCodeGenerator(CodeGenerator):
- **neuron_parent_class_include**: The C++ header filename to include that contains **neuron_parent_class**. Default: ``"archiving_node.h"``.
- **neuron_synapse_pairs**: List of pairs of (neuron, synapse) model names.
- **synapse_models**: List of synapse model names. Instructs the code generator that models with these names are synapse models.
- **disable_singularity_detection**: Set to True to disable detection of conditions under which numerical singularities (division by zero) could occur in the generated analytic solver. This can be useful for analytic solvers containing a large amount of conditions, which could take a long time to compute. If True, at most one analytic solver will be returned, in which numerical singularities could occur. (This parameter is directly passed to ODE-toolbox.)
- **use_alternative_expM**: If :python:`False`, use the sympy function ``sympy.exp`` to compute the matrix exponential. If :python:`True`, use an alternative function (see :py:func:`odetoolbox.sympy_helpers.expMt` for details). This can be useful as calls to ``sympy.exp`` can sometimes take a very large amount of time. (This parameter is directly passed to ODE-toolbox.)
- **preserve_expressions**: Set to True, or a list of strings corresponding to individual variable names, to disable internal rewriting of expressions, and return same output as input expression where possible. Only applies to variables specified as first-order differential equations. (This parameter is passed to ODE-toolbox.)
- **simplify_expression**: For all expressions ``expr`` that are rewritten by ODE-toolbox: the contents of this parameter string are ``eval()``ed in Python to obtain the final output expression. Override for custom expression simplification steps. Example: ``sympy.simplify(expr)``. Default: ``"sympy.logcombine(sympy.powsimp(sympy.expand(expr)))"``. (This parameter is passed to ODE-toolbox.)
- **gap_junctions**:
Expand Down Expand Up @@ -126,6 +129,8 @@ class NESTCodeGenerator(CodeGenerator):
"neuron_parent_class_include": "archiving_node.h",
"neuron_synapse_pairs": [],
"synapse_models": [],
"disable_singularity_detection": False,
"use_alternative_expM": False,
"preserve_expressions": True,
"simplify_expression": "sympy.logcombine(sympy.powsimp(sympy.expand(expr)))",
"gap_junctions": {
Expand Down Expand Up @@ -414,8 +419,12 @@ def analyse_neuron(self, neuron: ASTModel) -> Tuple[Dict[str, ASTAssignment], Di
ASTUtils.replace_convolution_aliasing_inlines(neuron)

if self.analytic_solver[neuron.get_name()] is not None:
neuron = ASTUtils.add_declarations_to_internals(
neuron, self.analytic_solver[neuron.get_name()]["propagators"])
if "conditions" in self.analytic_solver[neuron.get_name()].keys():
propagators = self.analytic_solver[neuron.get_name()]["conditions"]["default"]["propagators"]
else:
propagators = self.analytic_solver[neuron.get_name()]["propagators"]

neuron = ASTUtils.add_declarations_to_internals(neuron, propagators)

self.update_symbol_table(neuron)

Expand Down Expand Up @@ -461,8 +470,12 @@ def analyse_synapse(self, synapse: ASTModel) -> Dict[str, ASTAssignment]:
spike_updates, _ = self.get_spike_update_expressions(synapse, kernel_buffers, [analytic_solver, numeric_solver], delta_factors)

if not self.analytic_solver[synapse.get_name()] is None:
synapse = ASTUtils.add_declarations_to_internals(
synapse, self.analytic_solver[synapse.get_name()]["propagators"])
if "conditions" in self.analytic_solver[synapse.get_name()].keys():
propagators = self.analytic_solver[synapse.get_name()]["conditions"]["default"]["propagators"]
else:
propagators = self.analytic_solver[synapse.get_name()]["propagators"]

synapse = ASTUtils.add_declarations_to_internals(synapse, propagators)

self.update_symbol_table(synapse)

Expand Down Expand Up @@ -801,8 +814,11 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:

namespace["update_expressions"] = {}
for sym in namespace["analytic_state_variables"] + namespace["analytic_state_variables_moved"]:
expr_str = self.analytic_solver[neuron.get_name()]["update_expressions"][sym]
expr_str = ODEToolboxUtils._rewrite_piecewise_into_ternary(expr_str)
if "conditions" in self.analytic_solver[neuron.get_name()].keys():
update_expressions = self.analytic_solver[neuron.get_name()]["conditions"]["default"]["update_expressions"][sym]
else:
update_expressions = self.analytic_solver[neuron.get_name()]["update_expressions"][sym]
expr_str = ODEToolboxUtils._rewrite_piecewise_into_ternary(update_expressions)
expr_ast = ModelParser.parse_expression(expr_str)
# pretend that update expressions are in "equations" block, which should always be present, as differential equations must have been defined to get here
expr_ast.update_scope(neuron.get_equations_blocks()[0].get_scope())
Expand All @@ -820,7 +836,10 @@ def _get_neuron_model_namespace(self, neuron: ASTModel) -> Dict:
sets_vector_param_in_update_expr_visitor = ASTSetVectorParameterInUpdateExpressionVisitor(var)
expr_ast.accept(sets_vector_param_in_update_expr_visitor)

namespace["propagators"] = self.analytic_solver[neuron.get_name()]["propagators"]
if "conditions" in self.analytic_solver[neuron.get_name()].keys():
namespace["propagators"] = self.analytic_solver[neuron.get_name()]["conditions"]["default"]["propagators"]
else:
namespace["propagators"] = self.analytic_solver[neuron.get_name()]["propagators"]

namespace["propagators_are_state_dependent"] = False
for prop_name, prop_expr in namespace["propagators"].items():
Expand Down Expand Up @@ -959,11 +978,21 @@ def ode_toolbox_analysis(self, neuron: ASTModel, kernel_buffers: Mapping[ASTKern

odetoolbox_indict["options"]["simplify_expression"] = self.get_option("simplify_expression")
disable_analytic_solver = self.get_option("solver") != "analytic"
solver_result = odetoolbox.analysis(odetoolbox_indict,
disable_stiffness_check=True,
disable_analytic_solver=disable_analytic_solver,
preserve_expressions=self.get_option("preserve_expressions"),
log_level=FrontendConfiguration.logging_level)
if not "use_alternative_expM" in inspect.signature(odetoolbox.analysis).parameters.keys():
Logger.log_message(None, None, "Old version of ODE-toolbox used; consider upgrading. ``disable_singularity_detection`` and ``use_alternative_expM`` flags will be ignored.", None, LoggingLevel.WARNING)
solver_result = odetoolbox.analysis(odetoolbox_indict,
disable_stiffness_check=True,
disable_analytic_solver=disable_analytic_solver,
preserve_expressions=self.get_option("preserve_expressions"),
log_level=FrontendConfiguration.logging_level)
else:
solver_result = odetoolbox.analysis(odetoolbox_indict,
disable_stiffness_check=True,
disable_analytic_solver=disable_analytic_solver,
disable_singularity_detection=self.get_option("disable_singularity_detection"),
use_alternative_expM=self.get_option("use_alternative_expM"),
preserve_expressions=self.get_option("preserve_expressions"),
log_level=FrontendConfiguration.logging_level)
analytic_solver = None
analytic_solvers = [x for x in solver_result if x["solver"] == "analytical"]
assert len(analytic_solvers) <= 1, "More than one analytic solver not presently supported"
Expand Down
6 changes: 4 additions & 2 deletions pynestml/utils/ast_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2221,9 +2221,11 @@ def get_delta_factors_(cls, neuron: ASTModel, equations_block: ASTEquationsBlock
if cls.is_delta_kernel(neuron.get_kernel_by_name(kernel.get_variable().get_name())):
inport = conv_call.args[1].get_variable()
expr_str = str(expr)
sympy_expr = sympy.parsing.sympy_parser.parse_expr(expr_str, global_dict=odetoolbox.Shape._sympy_globals)
global_dict = odetoolbox.Shape._sympy_globals.copy()
sympy_expr = sympy.parsing.sympy_parser.parse_expr(expr_str, global_dict=global_dict)
sympy_expr = sympy.expand(sympy_expr)
sympy_conv_expr = sympy.parsing.sympy_parser.parse_expr(str(conv_call), global_dict=odetoolbox.Shape._sympy_globals)
global_dict = odetoolbox.Shape._sympy_globals.copy()
sympy_conv_expr = sympy.parsing.sympy_parser.parse_expr(str(conv_call), global_dict=global_dict)
factor_str = []
for term in sympy.Add.make_args(sympy_expr):
if term.find(sympy_conv_expr):
Expand Down
2 changes: 1 addition & 1 deletion pynestml/utils/ode_toolbox_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _rewrite_piecewise_into_ternary(cls, s: str) -> str:
"Float": sympy.Float,
"Function": sympy.Function}

sympy_expr = sympy.parsing.sympy_parser.parse_expr(s, global_dict=_sympy_globals_no_functions)
sympy_expr = sympy.parsing.sympy_parser.parse_expr(s, global_dict=_sympy_globals_no_functions.copy())

class MySympyPrinter(StrPrinter):
"""Resulting expressions will be parsed by NESTML parser. R
Expand Down
Loading