
Commit 51e0c2d

cuda : add FILL op support (#17851)
* cuda : add FILL op support
* cuda : add missing FILL op files
1 parent 37a4f63 commit 51e0c2d

File tree

3 files changed: 45 additions & 0 deletions


ggml/src/ggml-cuda/fill.cu

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
#include "fill.cuh"
#include "convert.cuh"

#define CUDA_FILL_BLOCK_SIZE 256

template <typename T>
static __global__ void fill_kernel(T * __restrict__ dst, const int64_t k, const T value) {
    const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = value;
}

void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    void * dst_d = dst->data;
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(ggml_is_contiguous(dst));

    float value;
    memcpy(&value, dst->op_params, sizeof(float));

    const int64_t k = ggml_nelements(dst);
    const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;

    switch (dst->type) {
        case GGML_TYPE_F32:
            fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((float *)dst_d, k, value);
            break;
        case GGML_TYPE_F16:
            fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((half *)dst_d, k, ggml_cuda_cast<half>(value));
            break;
        default:
            GGML_ABORT("unsupported type");
    }
}
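For context, the snippet below is a minimal standalone sketch of the same launch pattern used by the kernel above: a grid sized by ceiling division over fixed-size blocks, with each thread writing one element and out-of-range threads returning early. It is not part of the commit; the names fill_kernel_demo and FILL_BLOCK_SIZE are illustrative only.

// Standalone illustration of the fill pattern (not part of the commit):
// fill a raw float buffer on the GPU and verify it on the host.
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>

#define FILL_BLOCK_SIZE 256

template <typename T>
static __global__ void fill_kernel_demo(T * __restrict__ dst, const int64_t k, const T value) {
    const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
    if (i >= k) {
        return;
    }
    dst[i] = value;
}

int main() {
    const int64_t k     = 1000;   // deliberately not a multiple of the block size
    const float   value = 3.5f;

    float * dst_d = nullptr;
    cudaMalloc(&dst_d, k * sizeof(float));

    // same ceil-division grid math as ggml_cuda_op_fill
    const int64_t num_blocks = (k + FILL_BLOCK_SIZE - 1) / FILL_BLOCK_SIZE;
    fill_kernel_demo<<<num_blocks, FILL_BLOCK_SIZE>>>(dst_d, k, value);

    // cudaMemcpy on the default stream waits for the kernel to finish
    std::vector<float> dst_h(k);
    cudaMemcpy(dst_h.data(), dst_d, k * sizeof(float), cudaMemcpyDeviceToHost);
    cudaFree(dst_d);

    printf("dst[0] = %f, dst[%lld] = %f\n", dst_h[0], (long long)(k - 1), dst_h[k - 1]);
    return 0;
}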

ggml/src/ggml-cuda/fill.cuh

Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 0 deletions
@@ -56,6 +56,7 @@
 #include "ggml-cuda/solve_tri.cuh"
 #include "ggml-cuda/tri.cuh"
 #include "ggml-cuda/cumsum.cuh"
+#include "ggml-cuda/fill.cuh"
 #include "ggml.h"

 #include <algorithm>
@@ -2730,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_SOLVE_TRI:
             ggml_cuda_op_solve_tri(ctx, dst);
             break;
+        case GGML_OP_FILL:
+            ggml_cuda_op_fill(ctx, dst);
+            break;
         default:
             return false;
     }
@@ -4617,6 +4621,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
+        case GGML_OP_FILL:
         case GGML_OP_CUMSUM:
         case GGML_OP_TRI:
             return true;
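The fill value reaches the backend through dst->op_params, which ggml_cuda_op_fill reads back with memcpy in fill.cu above. The sketch below is a hypothetical illustration of how a graph-construction helper could pack that scalar; build_fill_node is an invented name, not the actual ggml op constructor, and it assumes the internal helpers ggml_view_tensor and ggml_set_op_params behave as described.

// Hypothetical sketch (not from this commit): record a fill value so the
// CUDA backend can read it from op_params when it executes GGML_OP_FILL.
#include "ggml.h"
#include "ggml-impl.h"   // ggml_set_op_params (internal helper)

static struct ggml_tensor * build_fill_node(struct ggml_context * ctx, struct ggml_tensor * a, float value) {
    struct ggml_tensor * result = ggml_view_tensor(ctx, a);   // same shape/type as a
    ggml_set_op_params(result, &value, sizeof(value));        // scalar lands in op_params[0]
    result->op     = GGML_OP_FILL;
    result->src[0] = a;
    return result;
}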

0 commit comments
