Skip to content

Commit 07873e2

Browse files
authored
Add argument for use_groups to subMatrix (#225)
1 parent 6db779d commit 07873e2

2 files changed

Lines changed: 19 additions & 3 deletions

File tree

src/storage/matrix.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ void define_sparse_matrix(py::module& m, std::string const& vtSuffix) {
9393
}
9494
return stream.str();
9595
}, py::arg("row"), "Print rows from start to end")
96-
.def("submatrix", [](SparseMatrix<ValueType> const& matrix, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false) {
97-
return matrix.getSubmatrix(true, rowConstraint, columnConstraint, insertDiagonalEntries);
98-
}, py::arg("row_constraint"), py::arg("column_constraint"), py::arg("insert_diagonal_entries") = false, "Get submatrix")
96+
.def("submatrix", [](SparseMatrix<ValueType> const& matrix, storm::storage::BitVector const& rowConstraint, storm::storage::BitVector const& columnConstraint, bool insertDiagonalEntries = false, bool useGroups = true) {
97+
return matrix.getSubmatrix(useGroups, rowConstraint, columnConstraint, insertDiagonalEntries);
98+
}, py::arg("row_constraint"), py::arg("column_constraint"), py::arg("insert_diagonal_entries") = false, py::arg("use_groups") = true, "Get submatrix")
9999
// Entry_index lead to problems
100100
.def("row_iter", [](SparseMatrix<ValueType>& matrix, row_index start, row_index end) {
101101
return py::make_iterator(matrix.begin(start), matrix.end(end));

tests/storage/test_matrix.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,3 +152,19 @@ def test_submatrix(self):
152152
assert submatrix.nr_entries == 10
153153
for e in submatrix:
154154
assert e.value() == 0.5 or e.value() == 0 or (e.value() == 1 and e.column > 3)
155+
156+
def test_submatrix_no_groups(self):
157+
model = stormpy.build_sparse_model_from_explicit(get_example_path("mdp", "two_dice.tra"), get_example_path("mdp", "two_dice.lab"))
158+
matrix = model.transition_matrix
159+
assert type(matrix) is stormpy.storage.SparseMatrix
160+
assert matrix.nr_rows == 254
161+
assert matrix.nr_columns == model.nr_states
162+
assert matrix.nr_entries == 436
163+
assert matrix.nr_entries == model.nr_transitions
164+
165+
assert matrix.nr_rows != matrix.nr_columns
166+
167+
row_constraint = stormpy.BitVector(254, [0, 1, 3, 4, 7, 8, 9])
168+
submatrix = matrix.submatrix(row_constraint, row_constraint, use_groups=False)
169+
assert submatrix.nr_rows == 7
170+
assert submatrix.nr_columns == 7

0 commit comments

Comments
 (0)