Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 130 additions & 9 deletions brain4j-math/src/main/java/org/brain4j/math/tensor/impl/GpuTensor.java
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.brain4j.math.tensor.Shape;
import org.brain4j.math.tensor.Tensor;
import org.brain4j.math.tensor.index.Range;

import java.util.Arrays;

import static org.lwjgl.opencl.CL10.*;
Expand Down Expand Up @@ -147,9 +146,12 @@ public static void initKernels(Device device) {
long activationsProgram = DeviceUtils.createBuildProgram(device, "/kernels/basic/activations.cl");
long gradientClipProgram = DeviceUtils.createBuildProgram(device, "/kernels/basic/gradient_clippers.cl");
long attentionProgram = DeviceUtils.createBuildProgram(device, "/kernels/attention/flash_attention.cl");


long conv2dProgram = DeviceUtils.createBuildProgram(device, "/kernels/convolution/conv2d.cl");
GpuContext.register(device, "convolve2d_direct", conv2dProgram);

String[] tensorOpsKernels = { "slice", "concat_last_dim", "concat_copy_a", "concat_copy_b", "matmul_batched",
"add", "sub", "mul", "div", "sum_along_dim", "softmax_last_dim", "layer_norm" };
"add", "sub", "mul", "div", "sum_along_dim", "softmax_last_dim", "layer_norm" , "broadcast"};

for (String kernel : tensorOpsKernels) {
GpuContext.register(device, kernel, tensorOpsProgram);
Expand Down Expand Up @@ -641,13 +643,15 @@ public Tensor slice(Range... ranges) {

GpuTensor result = new GpuTensor(device, newShape);

int[] starts = new int[ranges.length];
int[] steps = new int[ranges.length];
int[] starts = new int[shape.length];
int[] steps = new int[shape.length];
Arrays.fill(steps, 1);

for (int i = 0; i < ranges.length; i++) {
Range range = ranges[i];
starts[i] = range == null ? 0 : range.start();
steps[i] = range == null ? 1 : range.step();
for (int i = 0; i < shape.length; i++) {
if (i < ranges.length && ranges[i] != null) {
starts[i] = ranges[i].start();
steps[i] = ranges[i].step();
}
}

long flags = CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR;
Expand Down Expand Up @@ -676,6 +680,60 @@ public Tensor slice(Range... ranges) {
return result;
}


@Override
public Tensor broadcast(int[] targetShape) {


if (Arrays.equals(shape, targetShape)) {
return this;
}
int outSize= 1;
for (int i=0; i<targetShape.length; i++)
outSize= outSize * targetShape[i];


int[] paddedSrcShape = new int[targetShape.length];
int[] paddedSrcStrides = new int[targetShape.length];

Arrays.fill(paddedSrcShape, 1);
Arrays.fill(paddedSrcStrides, 0);

int offset = targetShape.length - shape.length;
for (int i = 0; i < shape.length; i++) {
paddedSrcShape[i + offset] = shape[i];
paddedSrcStrides[i + offset] = strides[i];
}

GpuTensor result = new GpuTensor(device, targetShape);


long flags = CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR;

TempBuffer outShapeBuffer = device.createBuffer(flags, targetShape);

TempBuffer srcShapeBuf = device.createBuffer(flags, paddedSrcShape);
TempBuffer srcStridesBuf = device.createBuffer(flags, paddedSrcStrides);

try (GpuQueue queue = GpuContext.getOrCreate(device)) {
KernelFactory.create(device, "broadcast")
.addMemParam(this.dataBuffer)
.addMemParam(result.dataBuffer)
.addMemParam(srcShapeBuf)
.addMemParam(outShapeBuffer)
.addMemParam(srcStridesBuf)
.addIntParam(targetShape.length)
.addIntParam(outSize)
.launch(queue, 1, outSize);
}

outShapeBuffer.release();
srcStridesBuf.release();
srcShapeBuf.release();
return result;
}


@Override
public Tensor layerNorm(double epsilon) {
GpuTensor result = new GpuTensor(device, shape);
Expand All @@ -700,6 +758,69 @@ public Tensor layerNorm(double epsilon) {
return result;
}

@Override
public Tensor convolve(Tensor kernel) {
if (!(kernel instanceof GpuTensor gpuKernel)) {
return convolve(kernel.to(device));
}

// Porta a 4D: [batch, channels, height, width]
GpuTensor input = this;
while (input.rank() < 4) input = (GpuTensor) input.unsqueeze();

GpuTensor kern = gpuKernel;
while (kern.rank() < 4) kern = (GpuTensor) kern.unsqueeze();

int[] inShape = input.shape(); // [batch, inChannels, inH, inW]
int[] kShape = kern.shape(); // [numFilters, inChannels, kH, kW]

int batch = inShape[0];
int inChannels = inShape[1];
int inH = inShape[2];
int inW = inShape[3];
int numFilters = kShape[0];
int kH = kShape[2];
int kW = kShape[3];
int outH = inH - kH + 1;
int outW = inW - kW + 1;

GpuTensor result = new GpuTensor(device, new int[]{batch, numFilters, outH, outW});

long[] globalWorkSize = new long[]{outH, outW};

try (GpuQueue queue = GpuContext.getOrCreate(device)) {
for (int b = 0; b < batch; b++) {
for (int f = 0; f < numFilters; f++) {
int outOffset = (b * numFilters + f) * outH * outW;

for (int c = 0; c < inChannels; c++) {
int inOffset = (b * inChannels + c) * inH * inW;
int kerOffset = (f * inChannels + c) * kH * kW;

KernelFactory.create(device, "convolve2d_direct")
.addMemParam(input.getDataBuffer())
.addMemParam(kern.getDataBuffer())
.addMemParam(result.getDataBuffer())
.addIntParam(inH)
.addIntParam(inW)
.addIntParam(kH)
.addIntParam(kW)
.addIntParam(outH)
.addIntParam(outW)
.addIntParam(0) // paddingTop
.addIntParam(0) // paddingLeft
.addIntParam(inOffset) // ← NUOVO
.addIntParam(kerOffset) // ← NUOVO
.addIntParam(outOffset) // ← NUOVO
.launch(queue, 2, globalWorkSize);
}
}
}
}

return result;
}

@Override
public float[] data() {
float[] buffer = new float[size];
Expand Down
29 changes: 29 additions & 0 deletions brain4j-math/src/main/resources/kernels/basic/tensor_ops.cl
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ __kernel void softmax_last_dim(
}
}




__kernel void slice(
__global const float* srcData,
__global float* dstData,
Expand Down Expand Up @@ -445,6 +448,32 @@ __kernel void slice(
dstData[dstLinearIdx] = srcData[srcOffset];
}

__kernel void broadcast(
__global const float* in,
__global float* out,
__global const int* inShape,
__global const int* outShape,
__global const int* inStrides,
const int rank,
const int outSize
) {
int dstLinearIdx = get_global_id(0);

if (dstLinearIdx >= outSize) return; // ← usa direttamente outSize

int tmp = dstLinearIdx;
int srcOffset = 0;

for (int i = rank - 1; i >= 0; i--) {
int idx = tmp % outShape[i];
tmp = tmp / outShape[i];

if (inShape[i] != 1) {srcOffset += idx * inStrides[i]; }
}

out[dstLinearIdx] = in[srcOffset];
}

__kernel void concat_last_dim(
__global const float* A,
__global const float* B,
Expand Down
99 changes: 99 additions & 0 deletions brain4j-math/src/test/java/TestSliceANDBroadCast.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import org.brain4j.math.gpu.device.Device;
import org.brain4j.math.gpu.device.DeviceUtils;
import org.brain4j.math.tensor.impl.GpuTensor;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.brain4j.math.tensor.index.Range;// ← aggiunto

import java.util.Arrays;

import static org.junit.jupiter.api.Assertions.assertArrayEquals; // ← era org.junit.Assert
import static org.junit.jupiter.api.Assertions.assertNotNull;

public class TestSliceANDBroadCast {
private static Device device;

@BeforeAll
static void setup() {
device = DeviceUtils.findDevice(null);
assertNotNull(device, "No GPU found!");
GpuTensor.initKernels(device);
}

@Test
public void testBroadcastRows() {
GpuTensor input = new GpuTensor(device, new int[]{1, 3}, 10f, 20f, 30f);
int[] in = new int[]{4,3};
float[] result = input.broadcast(in).data();
float[] expected = {10,20,30, 10,20,30, 10,20,30, 10,20,30};
System.out.println("testBroadcastRows: " + Arrays.toString(result));
assertArrayEquals(expected, result, 1e-4f);
}

@Test
public void testBroadcastCols() {
GpuTensor input = new GpuTensor(device, new int[]{4, 1}, 10f, 20f, 30f, 40f);
int[] in = new int[]{4,3};
float[] result = input.broadcast(in).data();
float[] expected = {10,10,10, 20,20,20, 30,30,30, 40,40,40};
System.out.println("testBroadcastCols: " + Arrays.toString(result));
assertArrayEquals(expected, result, 1e-4f);
}

@Test
public void testBroadcastScalar() {
GpuTensor input = new GpuTensor(device, new int[]{1, 1}, 42f);
int[] in = new int[]{4,3};
float[] result = input.broadcast(in).data();
float[] expected = new float[12];
Arrays.fill(expected, 42f);
System.out.println("testBroadcastScalar: " + Arrays.toString(result));
assertArrayEquals(expected, result, 1e-4f);
}

@Test
public void testBroadcastAlreadyCorrectShape() {
GpuTensor input = new GpuTensor(device, new int[]{4, 3},
1f,2f,3f, 4f,5f,6f, 7f,8f,9f, 10f,11f,12f);
int[] in = new int[]{4,3};
float[] result = input.broadcast(in).data();
float[] expected = {1,2,3, 4,5,6, 7,8,9, 10,11,12};
System.out.println("testBroadcastAlreadyCorrectShape: " + Arrays.toString(result));
assertArrayEquals(expected, result, 1e-4f);
}

@Test
public void testSlicePartialRanges() {
GpuTensor input = new GpuTensor(device, new int[]{4, 3},
1f,2f,3f, 4f,5f,6f, 7f,8f,9f, 10f,11f,12f);

float[] result = input.slice(Range.interval(0, 2)).data();
float[] expected = {1f, 2f, 3f, 4f, 5f, 6f};

System.out.println("testSlicePartialRanges: 1 " + Arrays.toString(result));
assertArrayEquals(expected, result, 1e-4f);


result = input.slice(new Range(1,4,2)).data();
expected = new float[]{4f, 5f, 6f, 10f, 11f, 12f};

System.out.println("testSlicePartialRanges: 2" + Arrays.toString(result));
assertArrayEquals(expected, result, 1e-4f);
}

@Test
public void testBroadcast() {
GpuTensor input = new GpuTensor(device, new int[]{2, 3},
1f,2f,3f, 4f,5f,6f);


float[] result = input.broadcast(new int[]{4,2,3}).data();
float[] expected = {1f, 2f, 3f, 4f, 5f, 6f, 1f, 2f, 3f, 4f, 5f, 6f
, 1f, 2f, 3f, 4f, 5f, 6f, 1f, 2f, 3f, 4f, 5f, 6f};

System.out.println("testBroadcast: 1 " + Arrays.toString(result));
assertArrayEquals(expected, result, 1e-4f);


}
}