Skip to content

Commit 82209ef

Browse files
authored
vulkan: Support F16 OP_FILL (#22177)
1 parent 9998d88 commit 82209ef

2 files changed

Lines changed: 8 additions & 1 deletion

File tree

ggml/src/ggml-vulkan/ggml-vulkan.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,7 @@ struct vk_device_struct {
792792
vk_pipeline pipeline_arange_f32;
793793

794794
vk_pipeline pipeline_fill_f32;
795+
vk_pipeline pipeline_fill_f16;
795796

796797
vk_pipeline pipeline_geglu[2];
797798
vk_pipeline pipeline_reglu[2];
@@ -4577,6 +4578,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
45774578
ggml_vk_create_pipeline(device, device->pipeline_arange_f32, "arange_f32", arange_f32_len, arange_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
45784579

45794580
ggml_vk_create_pipeline(device, device->pipeline_fill_f32, "fill_f32", fill_f32_len, fill_f32_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
4581+
ggml_vk_create_pipeline(device, device->pipeline_fill_f16, "fill_f16", fill_f16_len, fill_f16_data, "main", 1, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
45804582

45814583
#define CREATE_GLU(name) \
45824584
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
@@ -9844,6 +9846,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
98449846
if (dst->type == GGML_TYPE_F32) {
98459847
return ctx->device->pipeline_fill_f32;
98469848
}
9849+
if (dst->type == GGML_TYPE_F16) {
9850+
return ctx->device->pipeline_fill_f16;
9851+
}
98479852
return nullptr;
98489853
default:
98499854
return nullptr;
@@ -15713,8 +15718,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1571315718
|| (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
1571415719
|| (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
1571515720
case GGML_OP_ARANGE:
15716-
case GGML_OP_FILL:
1571715721
return op->type == GGML_TYPE_F32;
15722+
case GGML_OP_FILL:
15723+
return op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16;
1571815724
case GGML_OP_SCALE:
1571915725
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
1572015726
case GGML_OP_PAD:

ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,7 @@ void process_shaders() {
889889
string_to_spv("add1_f32_f32", "add1.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
890890
string_to_spv("arange_f32", "arange.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
891891
string_to_spv("fill_f32", "fill.comp", {{"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
892+
string_to_spv("fill_f16", "fill.comp", {{"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
892893
string_to_spv("step_f16", "step.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
893894
string_to_spv("step_f32", "step.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
894895
string_to_spv("round_f16", "round.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});

0 commit comments

Comments
 (0)