Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions ggml/include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,8 @@ extern "C" {
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
GGML_OP_OPT_STEP_ADAMW,

GGML_OP_FILL,

GGML_OP_COUNT,
};

Expand Down Expand Up @@ -1818,6 +1820,12 @@ extern "C" {
float stop,
float step);

// fill in-place the tensor with a constant value, return view(a)
GGML_API struct ggml_tensor * ggml_fill(
struct ggml_context * ctx,
struct ggml_tensor * a,
float value);

// top k elements per row
GGML_API struct ggml_tensor * ggml_top_k(
struct ggml_context * ctx,
Expand Down
7 changes: 6 additions & 1 deletion ggml/src/ggml-cpu/ggml-cpu.c
Original file line number Diff line number Diff line change
Expand Up @@ -1959,6 +1959,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_arange(params, tensor);
} break;
case GGML_OP_FILL:
{
ggml_compute_forward_fill(params, tensor);
} break;
case GGML_OP_TIMESTEP_EMBEDDING:
{
ggml_compute_forward_timestep_embedding(params, tensor);
Expand Down Expand Up @@ -2242,6 +2246,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_TRANSPOSE:
case GGML_OP_GET_ROWS_BACK:
case GGML_OP_DIAG:
case GGML_OP_ARANGE:
{
n_tasks = 1;
} break;
Expand Down Expand Up @@ -2279,7 +2284,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE:
case GGML_OP_FILL:
case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT:
case GGML_OP_FLASH_ATTN_EXT:
Expand Down
49 changes: 49 additions & 0 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6833,6 +6833,55 @@ void ggml_compute_forward_arange(
}
}

// ggml_compute_forward_fill

static void ggml_compute_forward_fill_f32(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    // Fill every element of dst with a constant; the value is stored as a
    // raw float in op_params (see ggml_fill).
    float v;
    memcpy(&v, dst->op_params, sizeof(float));

    const int ith = params->ith;
    const int nth = params->nth;

    const int n  = ggml_nrows(dst);
    const int nc = dst->ne[0];

    const int64_t ne1 = dst->ne[1];
    const int64_t ne2 = dst->ne[2];

    const size_t nb0 = dst->nb[0];
    const size_t nb1 = dst->nb[1];
    const size_t nb2 = dst->nb[2];
    const size_t nb3 = dst->nb[3];

    // elements within a row must be tightly packed floats
    GGML_ASSERT(nb0 == sizeof(float));

    // rows are interleaved across threads; decompose the flat row index j into
    // (i1, i2, i3) and address with the per-dimension strides so that
    // non-contiguous views (nb2 != ne1*nb1) are filled correctly — for a
    // contiguous tensor this reduces to the plain j*nb1 offset
    for (int j = ith; j < n; j += nth) {
        const int64_t i3 = j/(ne2*ne1);
        const int64_t i2 = (j - i3*ne2*ne1)/ne1;
        const int64_t i1 = j - i3*ne2*ne1 - i2*ne1;

        float * dst_ptr = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);

        for (int i = 0; i < nc; i++) {
            dst_ptr[i] = v;
        }
    }
}

// Dispatch GGML_OP_FILL to the implementation matching the source type.
// Only F32 is supported; any other type is a hard error.
void ggml_compute_forward_fill(
        const ggml_compute_params * params,
        ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];

    if (src0->type == GGML_TYPE_F32) {
        ggml_compute_forward_fill_f32(params, dst);
        return;
    }

    GGML_ABORT("fatal error");
}

static void ggml_compute_forward_timestep_embedding_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-cpu/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ void ggml_compute_forward_upscale(const struct ggml_compute_params * params, str
void ggml_compute_forward_pad(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_pad_reflect_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_arange(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_fill(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_timestep_embedding(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_argsort(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_leaky_relu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
Expand Down
22 changes: 20 additions & 2 deletions ggml/src/ggml.c
Original file line number Diff line number Diff line change
Expand Up @@ -982,9 +982,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS",
"CROSS_ENTROPY_LOSS_BACK",
"OPT_STEP_ADAMW",

"FILL",
};

static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
Expand Down Expand Up @@ -1077,9 +1079,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss(x,y)",
"cross_entropy_loss_back(x,y)",
"adamw(x)",

"fill(x)",
};

static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");

Expand Down Expand Up @@ -4342,6 +4346,20 @@ struct ggml_tensor * ggml_arange(
return result;
}

// Fill tensor a in-place with a constant value; returns a view of a.
// The value is carried to the compute kernel through op_params.
struct ggml_tensor * ggml_fill(
        struct ggml_context * ctx,
        struct ggml_tensor * a,
        float value) {
    struct ggml_tensor * view = ggml_view_tensor(ctx, a);

    view->op     = GGML_OP_FILL;
    view->src[0] = a;

    ggml_set_op_params(view, &value, sizeof(value));

    return view;
}

// ggml_timestep_embedding

struct ggml_tensor * ggml_timestep_embedding(
Expand Down
28 changes: 28 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2422,6 +2422,32 @@ struct test_clamp : public test_case {
}
};

// GGML_OP_FILL
struct test_fill : public test_case {
    const ggml_type type;              // tensor type (op currently supports F32 only)
    const std::array<int64_t, 4> ne;   // tensor dimensions
    const float v;                     // constant fill value (const for consistency with the other members)

    std::string vars() override {
        return VARS_TO_STR3(type, ne, v);
    }

    test_fill(ggml_type type = GGML_TYPE_F32,
            std::array<int64_t, 4> ne = {10, 5, 4, 3},
            float v = 0.5f)
        : type(type), ne(ne), v(v) {}

    ggml_tensor * build_graph(ggml_context * ctx) override {
        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
        ggml_set_name(a, "a");

        // fill operates in-place and returns a view of a
        ggml_tensor * out = ggml_fill(ctx, a, v);
        ggml_set_name(out, "out");

        return out;
    }
};

// GGML_OP_DIAG_MASK_INF
struct test_diag_mask_inf : public test_case {
const ggml_type type;
Expand Down Expand Up @@ -4199,6 +4225,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_gla(GGML_TYPE_F32, 32, 64, 128, 4));

test_cases.emplace_back(new test_fill(GGML_TYPE_F32));

for (ggml_type type_a : all_types) {
for (int i = 1; i < 10; ++i) {
test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
Expand Down
Loading