Skip to content

Commit d22704c

Browse files
committed
cuda : add missing FILL op files
1 parent d91f4f9 commit d22704c

File tree

3 files changed

+43
-0
lines changed

3 files changed

+43
-0
lines changed

ggml/src/ggml-cuda/fill.cu

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#include "fill.cuh"
2+
3+
#define CUDA_FILL_BLOCK_SIZE 256
4+
5+
template <typename T>
6+
static __global__ void fill_kernel(T * __restrict__ dst, const int64_t k, const T value) {
7+
const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x;
8+
if (i >= k) {
9+
return;
10+
}
11+
dst[i] = value;
12+
}
13+
14+
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
15+
void * dst_d = dst->data;
16+
cudaStream_t stream = ctx.stream();
17+
18+
GGML_ASSERT(ggml_is_contiguous(dst));
19+
20+
float value;
21+
memcpy(&value, dst->op_params, sizeof(float));
22+
23+
const int64_t k = ggml_nelements(dst);
24+
const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE;
25+
26+
switch (dst->type) {
27+
case GGML_TYPE_F32:
28+
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((float *)dst_d, k, value);
29+
break;
30+
case GGML_TYPE_F16:
31+
fill_kernel<<<num_blocks, CUDA_FILL_BLOCK_SIZE, 0, stream>>>((half *)dst_d, k, ggml_cuda_cast<half>(value));
32+
break;
33+
default:
34+
GGML_ABORT("unsupported type");
35+
}
36+
}

ggml/src/ggml-cuda/fill.cuh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
#include "common.cuh"
2+
3+
void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2731,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
27312731
case GGML_OP_SOLVE_TRI:
27322732
ggml_cuda_op_solve_tri(ctx, dst);
27332733
break;
2734+
case GGML_OP_FILL:
2735+
ggml_cuda_op_fill(ctx, dst);
2736+
break;
27342737
default:
27352738
return false;
27362739
}
@@ -4618,6 +4621,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
46184621
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
46194622
case GGML_OP_OPT_STEP_ADAMW:
46204623
case GGML_OP_OPT_STEP_SGD:
4624+
case GGML_OP_FILL:
46214625
case GGML_OP_CUMSUM:
46224626
case GGML_OP_TRI:
46234627
return true;

0 commit comments

Comments
 (0)