diff --git a/ggml/src/ggml-cuda/fill.cu b/ggml/src/ggml-cuda/fill.cu new file mode 100644 index 00000000000..eb8ccb7802b --- /dev/null +++ b/ggml/src/ggml-cuda/fill.cu @@ -0,0 +1,37 @@ +#include "fill.cuh" +#include "convert.cuh" + +#define CUDA_FILL_BLOCK_SIZE 256 + +template +static __global__ void fill_kernel(T * __restrict__ dst, const int64_t k, const T value) { + const int64_t i = (int64_t)blockDim.x * blockIdx.x + threadIdx.x; + if (i >= k) { + return; + } + dst[i] = value; +} + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(dst)); + + float value; + memcpy(&value, dst->op_params, sizeof(float)); + + const int64_t k = ggml_nelements(dst); + const int64_t num_blocks = (k + CUDA_FILL_BLOCK_SIZE - 1) / CUDA_FILL_BLOCK_SIZE; + + switch (dst->type) { + case GGML_TYPE_F32: + fill_kernel<<>>((float *)dst_d, k, value); + break; + case GGML_TYPE_F16: + fill_kernel<<>>((half *)dst_d, k, ggml_cuda_cast(value)); + break; + default: + GGML_ABORT("unsupported type"); + } +} diff --git a/ggml/src/ggml-cuda/fill.cuh b/ggml/src/ggml-cuda/fill.cuh new file mode 100644 index 00000000000..8443c83620d --- /dev/null +++ b/ggml/src/ggml-cuda/fill.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_fill(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 235d94d5007..d0463388c54 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -56,6 +56,7 @@ #include "ggml-cuda/solve_tri.cuh" #include "ggml-cuda/tri.cuh" #include "ggml-cuda/cumsum.cuh" +#include "ggml-cuda/fill.cuh" #include "ggml.h" #include @@ -2730,6 +2731,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SOLVE_TRI: ggml_cuda_op_solve_tri(ctx, dst); break; + case GGML_OP_FILL: + ggml_cuda_op_fill(ctx, dst); + break; default: return false; } @@ -4617,6 +4621,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CROSS_ENTROPY_LOSS_BACK: case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: + case GGML_OP_FILL: case GGML_OP_CUMSUM: case GGML_OP_TRI: return true;