Skip to content
6 changes: 6 additions & 0 deletions src/layer/arm/reshape_arm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ Reshape_arm::Reshape_arm()

int Reshape_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
if (batch_mode != 0)
return Reshape::forward(bottom_blobs, top_blobs, opt);

const Mat& bottom_blob = bottom_blobs[0];
Mat& top_blob = top_blobs[0];

Expand Down Expand Up @@ -317,6 +320,9 @@ int Reshape_arm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>&

int Reshape_arm::forward_bf16s_fp16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
if (batch_mode != 0)
return Reshape::forward(bottom_blobs, top_blobs, opt);

const Mat& bottom_blob = bottom_blobs[0];
Mat& top_blob = top_blobs[0];

Expand Down
6 changes: 6 additions & 0 deletions src/layer/loongarch/reshape_loongarch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ Reshape_loongarch::Reshape_loongarch()

int Reshape_loongarch::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
if (batch_mode != 0)
return Reshape::forward(bottom_blobs, top_blobs, opt);

const Mat& bottom_blob = bottom_blobs[0];
Mat& top_blob = top_blobs[0];

Expand Down Expand Up @@ -480,6 +483,9 @@ int Reshape_loongarch::forward(const std::vector<Mat>& bottom_blobs, std::vector

int Reshape_loongarch::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
if (batch_mode != 0)
return Reshape::forward(bottom_blobs, top_blobs, opt);

const Mat& bottom_blob = bottom_blobs[0];
Mat& top_blob = top_blobs[0];

Expand Down
6 changes: 6 additions & 0 deletions src/layer/mips/reshape_mips.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ Reshape_mips::Reshape_mips()

int Reshape_mips::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
if (batch_mode != 0)
return Reshape::forward(bottom_blobs, top_blobs, opt);

const Mat& bottom_blob = bottom_blobs[0];
Mat& top_blob = top_blobs[0];

Expand Down Expand Up @@ -327,6 +330,9 @@ int Reshape_mips::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>

int Reshape_mips::forward_bf16s(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
{
if (batch_mode != 0)
return Reshape::forward(bottom_blobs, top_blobs, opt);

const Mat& bottom_blob = bottom_blobs[0];
Mat& top_blob = top_blobs[0];

Expand Down
122 changes: 118 additions & 4 deletions src/layer/reshape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@

#include "expression.h"

#include <string.h>

namespace ncnn {

Reshape::Reshape()
{
one_blob_only = true;
support_inplace = false;
batch_mode = 0;
}

int Reshape::load_param(const ParamDict& pd)
Expand All @@ -30,6 +33,14 @@ int Reshape::load_param(const ParamDict& pd)
if (w == -233)
ndim = 0;

batch_mode = pd.get(12, 0);
if (batch_mode != 0)
{
support_batch = true;
support_packing = false;
support_vulkan_packing = false;
}

shape_expr = pd.get(6, "");

// count reference blobs
Expand Down Expand Up @@ -79,9 +90,14 @@ int Reshape::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top
}

int total = bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c;
if (batch_mode == 1)
total *= bottom_blob.n;

int dims = bottom_blob.dims;

if (batch_mode != 0 && ndim == 0)
return -1;

if (ndim == 1)
{
if (outw == 0)
Expand All @@ -90,7 +106,7 @@ int Reshape::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top
if (outw == -1)
outw = total;

if (dims == 1 && bottom_blob.w == outw)
if (batch_mode == 0 && dims == 1 && bottom_blob.w == outw)
{
top_blob = bottom_blob;
return 0;
Expand All @@ -108,7 +124,7 @@ int Reshape::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top
if (outh == -1)
outh = total / outw;

if (dims == 2 && bottom_blob.h == outh)
if (batch_mode == 0 && dims == 2 && bottom_blob.h == outh)
{
top_blob = bottom_blob;
return 0;
Expand All @@ -130,7 +146,7 @@ int Reshape::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top
if (outc == -1)
outc = total / outh / outw;

if (dims == 3 && bottom_blob.c == outc)
if (batch_mode == 0 && dims == 3 && bottom_blob.c == outc)
{
top_blob = bottom_blob;
top_blob.w = outw;
Expand Down Expand Up @@ -158,7 +174,7 @@ int Reshape::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top
if (outc == -1)
outc = total / outd / outh / outw;

if (dims == 4 && bottom_blob.c == outc)
if (batch_mode == 0 && dims == 4 && bottom_blob.c == outc)
{
top_blob = bottom_blob;
top_blob.w = outw;
Expand All @@ -168,6 +184,104 @@ int Reshape::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top
}
}

if (batch_mode == 1)
{
if (bottom_blob.elempack != 1)
return -1;

Mat bottom_blob_flattened(total, bottom_blob.elemsize, opt.blob_allocator);
if (bottom_blob_flattened.empty())
return -100;

unsigned char* outptr = bottom_blob_flattened;
const size_t size = (size_t)bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.elemsize;
for (int b = 0; b < bottom_blob.n; b++)
{
const Mat bottom_blob_b = bottom_blob.batch(b);
for (int q = 0; q < bottom_blob.c; q++)
{
const unsigned char* ptr = (const unsigned char*)bottom_blob_b + bottom_blob.cstep * q * bottom_blob.elemsize;
memcpy(outptr, ptr, size);
outptr += size;
}
}

if (ndim == 1)
top_blob = bottom_blob_flattened.reshape(outw, opt.blob_allocator);
if (ndim == 2)
top_blob = bottom_blob_flattened.reshape(outw, outh, opt.blob_allocator);
if (ndim == 3)
top_blob = bottom_blob_flattened.reshape(outw, outh, outc, opt.blob_allocator);
if (ndim == 4)
top_blob = bottom_blob_flattened.reshape(outw, outh, outd, outc, opt.blob_allocator);

if (top_blob.empty())
return -100;

return 0;
}

if (batch_mode == 2)
{
if (bottom_blob.n != 1 || bottom_blob.elempack != 1)
return -1;

size_t out_total = outw;
if (ndim == 2)
out_total *= outh;
if (ndim == 3)
out_total *= (size_t)outh * outc;
if (ndim == 4)
out_total *= (size_t)outh * outd * outc;

if (out_total == 0)
return -1;

const size_t bottom_total = (size_t)bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.c;
const int batch = bottom_total / out_total;
if ((size_t)batch * out_total != bottom_total)
return -1;

if (ndim == 1)
top_blob.create_batch(outw, batch, bottom_blob.elemsize, 1, opt.blob_allocator);
if (ndim == 2)
top_blob.create_batch(outw, outh, batch, bottom_blob.elemsize, 1, opt.blob_allocator);
if (ndim == 3)
top_blob.create_batch(outw, outh, outc, batch, bottom_blob.elemsize, 1, opt.blob_allocator);
if (ndim == 4)
top_blob.create_batch(outw, outh, outd, outc, batch, bottom_blob.elemsize, 1, opt.blob_allocator);

if (top_blob.empty())
return -100;

Mat bottom_blob_flattened(bottom_total, bottom_blob.elemsize, opt.workspace_allocator);
if (bottom_blob_flattened.empty())
return -100;

unsigned char* outptr = bottom_blob_flattened;
const size_t size = (size_t)bottom_blob.w * bottom_blob.h * bottom_blob.d * bottom_blob.elemsize;
for (int q = 0; q < bottom_blob.c; q++)
{
const unsigned char* ptr = (const unsigned char*)bottom_blob + bottom_blob.cstep * q * bottom_blob.elemsize;
memcpy(outptr, ptr, size);
outptr += size;
}

const unsigned char* ptr = bottom_blob_flattened;
const size_t out_channel_size = (size_t)top_blob.w * top_blob.h * top_blob.d * bottom_blob.elemsize;
for (int b = 0; b < batch; b++)
{
Mat top_blob_b = top_blob.batch(b);
for (int q = 0; q < top_blob.c; q++)
{
memcpy(top_blob_b.channel(q), ptr, out_channel_size);
ptr += out_channel_size;
}
}

return 0;
}

if (ndim == 1)
{
top_blob = bottom_blob.reshape(outw, opt.blob_allocator);
Expand Down
1 change: 1 addition & 0 deletions src/layer/reshape.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class Reshape : public Layer
int c;

int ndim;
int batch_mode;

// see docs/developer-guide/expression.md
std::string shape_expr;
Expand Down
Loading
Loading