Skip to content
108 changes: 103 additions & 5 deletions python/src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#include <pybind11/numpy.h>
#include <pybind11/functional.h>

#include <string.h>

#include <cpu.h>
#include <gpu.h>
#include <net.h>
Expand Down Expand Up @@ -248,9 +250,9 @@ PYBIND11_MODULE(ncnn, m)

.def(py::init<const Mat&>(), py::arg("m"))

.def(py::init([](py::buffer const b) {
.def(py::init([](py::buffer const b, int batch_index) {
py::buffer_info info = b.request();
if (info.ndim > 4)
if (batch_index == 233 && info.ndim > 4)
{
std::stringstream ss;
ss << "convert numpy.ndarray to ncnn.Mat only dims <=4 support now, but given " << info.ndim;
Expand All @@ -259,6 +261,85 @@ PYBIND11_MODULE(ncnn, m)

size_t elemsize = info.itemsize;

if (batch_index != 233)
{
if (info.ndim > 5)
{
std::stringstream ss;
ss << "convert numpy.ndarray to ncnn.Mat with batch only dims <=5 support now, but given " << info.ndim;
pybind11::pybind11_fail(ss.str());
}

if (info.ndim < 2)
{
std::stringstream ss;
ss << "convert numpy.ndarray to ncnn.Mat with batch only dims >=2 support now, but given " << info.ndim;
pybind11::pybind11_fail(ss.str());
}

if (batch_index < 0)
batch_index += info.ndim;

if (batch_index < 0 || batch_index >= info.ndim)
{
std::stringstream ss;
ss << "batch_index out of range";
pybind11::pybind11_fail(ss.str());
}

std::vector<int> shape;
for (int i = 0; i < info.ndim; i++)
{
if (i == batch_index)
continue;
shape.push_back((int)info.shape[i]);
}

Mat* v = new Mat;
if (shape.size() == 1)
{
v->create_batch(shape[0], (int)info.shape[batch_index], elemsize, 1);
}
else if (shape.size() == 2)
{
v->create_batch(shape[1], shape[0], (int)info.shape[batch_index], elemsize, 1);
}
else if (shape.size() == 3)
{
v->create_batch(shape[2], shape[1], shape[0], (int)info.shape[batch_index], elemsize, 1);
}
else if (shape.size() == 4)
{
v->create_batch(shape[3], shape[2], shape[1], shape[0], (int)info.shape[batch_index], elemsize, 1);
}

py::object src = py::reinterpret_borrow<py::object>(b);
for (int i = 0; i < v->n; i++)
{
py::array slice = src.attr("take")(i, py::arg("axis") = batch_index).attr("copy")();
py::buffer_info slice_info = slice.request();

Mat mb = v->batch(i);
const unsigned char* sptr = (const unsigned char*)slice_info.ptr;

if (mb.dims <= 2)
{
memcpy(mb.data, sptr, (size_t)mb.w * mb.h * elemsize);
}
else
{
size_t channel_size = (size_t)mb.w * mb.h * mb.d * elemsize;
for (int q = 0; q < mb.c; q++)
{
Mat mbq = mb.channel(q);
memcpy(mbq.data, sptr + channel_size * q, channel_size);
}
}
}

return std::unique_ptr<Mat>(v);
}

Mat* v = nullptr;
if (info.ndim == 1)
{
Expand Down Expand Up @@ -288,16 +369,27 @@ PYBIND11_MODULE(ncnn, m)
}
return std::unique_ptr<Mat>(v);
}),
py::arg("array"))
py::arg("array"), py::arg("batch_index") = 233)
.def_buffer([](Mat& m) -> py::buffer_info {
return to_buffer_info(m);
})
.def(
"numpy", [](py::object obj, const std::string& format = "") -> py::array {
"numpy", [](py::object obj, const std::string& format = "", int batch_index = 233) -> py::array {
auto* m = obj.cast<Mat*>();
if (batch_index != 233)
{
py::object numpy = py::module_::import("numpy");
py::list batch_slices;
for (int i = 0; i < m->n; i++)
{
Mat mb = m->batch(i);
batch_slices.append(py::array(to_buffer_info(mb, format)));
}
return numpy.attr("stack")(batch_slices, py::arg("axis") = batch_index).cast<py::array>();
}
return py::array(to_buffer_info(*m, format), obj);
},
py::arg("format") = "", "i for int32, f for float32, d for double")
py::arg("format") = "", py::arg("batch_index") = 233, "i for int32, f for float32, d for double")
//.def("fill", (void (Mat::*)(int))(&Mat::fill), py::arg("v"))
.def("fill", (void (Mat::*)(float))(&Mat::fill), py::arg("v"))
.def("clone", &Mat::clone, py::arg("allocator") = nullptr)
Expand Down Expand Up @@ -351,13 +443,18 @@ PYBIND11_MODULE(ncnn, m)
.def("create", (void (Mat::*)(int, int, size_t, int, Allocator*)) & Mat::create, py::arg("w"), py::arg("h"), py::kw_only(), py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
.def("create", (void (Mat::*)(int, int, int, size_t, int, Allocator*)) & Mat::create, py::arg("w"), py::arg("h"), py::arg("c"), py::kw_only(), py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
.def("create", (void (Mat::*)(int, int, int, int, size_t, int, Allocator*)) & Mat::create, py::arg("w"), py::arg("h"), py::arg("d"), py::arg("c"), py::kw_only(), py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
.def("create_batch", (void (Mat::*)(int, int, size_t, int, Allocator*)) & Mat::create_batch, py::arg("w"), py::arg("batch"), py::kw_only(), py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
.def("create_batch", (void (Mat::*)(int, int, int, size_t, int, Allocator*)) & Mat::create_batch, py::arg("w"), py::arg("h"), py::arg("batch"), py::kw_only(), py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
.def("create_batch", (void (Mat::*)(int, int, int, int, size_t, int, Allocator*)) & Mat::create_batch, py::arg("w"), py::arg("h"), py::arg("c"), py::arg("batch"), py::kw_only(), py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
.def("create_batch", (void (Mat::*)(int, int, int, int, int, size_t, int, Allocator*)) & Mat::create_batch, py::arg("w"), py::arg("h"), py::arg("d"), py::arg("c"), py::arg("batch"), py::kw_only(), py::arg("elemsize") = 4, py::arg("elempack") = 1, py::arg("allocator") = nullptr)
.def("create_like", (void (Mat::*)(const Mat&, Allocator*)) & Mat::create_like, py::arg("m"), py::arg("allocator") = nullptr)
.def("addref", &Mat::addref)
.def("release", &Mat::release)
.def("empty", &Mat::empty)
.def("total", &Mat::total)
.def("elembits", &Mat::elembits)
.def("shape", &Mat::shape)
.def("batch", (Mat(Mat::*)(int)) & Mat::batch, py::arg("b"))
.def("channel", (Mat(Mat::*)(int)) & Mat::channel, py::arg("c"))
//.def("channel", (const Mat (Mat::*)(int) const) & Mat::channel, py::arg("c"))
.def("depth", (Mat(Mat::*)(int)) & Mat::depth, py::arg("z"))
Expand Down Expand Up @@ -471,6 +568,7 @@ PYBIND11_MODULE(ncnn, m)
.def_readwrite("d", &Mat::d)
.def_readwrite("c", &Mat::c)
.def_readwrite("cstep", &Mat::cstep)
.def_readwrite("n", &Mat::n)
.def("__repr__", [](const Mat& m) {
std::stringstream ss;
ss << "<ncnn.Mat w=" << m.w << " h=" << m.h << " d=" << m.d << " c=" << m.c << " dims=" << m.dims
Expand Down
15 changes: 15 additions & 0 deletions python/tests/test_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,21 @@ def test_numpy():
array2[0] = 100
assert array[0] == 100

def test_numpy_batch_index():
array = np.arange(3 * 5 * 2 * 7, dtype=np.float32).reshape(3, 5, 2, 7)
mat = ncnn.Mat(array, batch_index=2)
array2 = mat.numpy(batch_index=2)
assert (array == array2).all()

mat2 = mat.clone()
array3 = mat2.numpy(batch_index=2)
assert (array == array3).all()

array = np.arange(2 * 3 * 4 * 5 * 6, dtype=np.float32).reshape(2, 3, 4, 5, 6)
mat = ncnn.Mat(array, batch_index=-2)
array2 = mat.numpy(batch_index=-2)
assert (array == array2).all()

def test_fill():
mat = ncnn.Mat(1)
mat.fill(1.0)
Expand Down
Loading
Loading