@@ -44,26 +44,26 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_tensor * node,
4444 int num_heads_kv,
4545 int head_size,
4646 const std::vector<int > & swa_layers) :
47+ m_is_static(is_static),
4748 m_cgraph(cgraph),
4849 m_node(node),
4950 m_op_name(std::string(node->name)),
50- m_context_size(context_size),
51- m_context_size_swa(context_size_swa),
52- m_swa_layers(swa_layers),
53- m_num_heads(num_heads),
54- m_num_heads_kv(num_heads_kv),
51+ m_ctx(context_size),
52+ m_ctx_swa(context_size_swa),
53+ m_n_heads(num_heads),
54+ m_n_heads_kv(num_heads_kv),
5555 m_head_size(head_size),
56- m_is_static(is_static ) {
56+ m_swa_layers(swa_layers ) {
5757 set_input_output (node);
5858}
5959
6060GgmlOvDecoder::GgmlOvDecoder (ggml_cgraph * cgraph,
6161 std::map<std::string, std::shared_ptr<ov::Node>> & model_weights,
6262 bool is_static) :
63+ m_is_static(is_static),
6364 m_cgraph(cgraph),
6465 m_op_name(m_node ? std::string(m_node->name) : ""),
65- m_model_weights(model_weights),
66- m_is_static(is_static) {
66+ m_model_weights(model_weights) {
6767 if (auto * env = getenv (" GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS" ); env && std::string (env) != " 0" ) {
6868 unsetenv (" GGML_OPENVINO_PRINT_CGRAPH_TENSOR_ADDRESS" );
6969 print_tensor_address_map (cgraph);
@@ -78,7 +78,7 @@ GgmlOvDecoder::GgmlOvDecoder(ggml_cgraph * cgraph,
7878 set_input_output (cur_node);
7979 }
8080
81- // add_extra_inputs();
81+ add_extra_inputs ();
8282}
8383
8484GgmlOvDecoder::GgmlOvDecoder (ggml_cgraph * cgraph, std::map<std::string, std::shared_ptr<ov::Node>> & model_weights) {
@@ -125,7 +125,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
125125 // Add model inputs and weights constants, if called for the whole graph
126126 if (naive) {
127127 if (m_model_weights.find (src_name) == m_model_weights.end ()) {
128- auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type (src), get_graph_input_shape (src));
128+ auto param_node =
129+ std::make_shared<ov::op::v0::Parameter>(get_ov_type (src), get_graph_input_shape (node, src));
129130 param_node->set_friendly_name (src_name);
130131 param_node->output (0 ).get_tensor ().set_names ({src_name});
131132 m_model_inputs[src_name] = param_node;
@@ -142,7 +143,8 @@ void GgmlOvDecoder::set_input_output(ggml_tensor * node, bool naive) {
142143 if (m_model_inputs.find (src_name) != m_model_inputs.end ()) {
143144 continue ;
144145 }
145- auto param_node = std::make_shared<ov::op::v0::Parameter>(get_ov_type (src), get_graph_input_shape (src));
146+ auto param_node =
147+ std::make_shared<ov::op::v0::Parameter>(get_ov_type (src), get_graph_input_shape (node, src));
146148 param_node->set_friendly_name (src_name);
147149 param_node->output (0 ).get_tensor ().set_names ({src_name});
148150 m_model_inputs[src_name] = param_node;
@@ -191,63 +193,93 @@ void GgmlOvDecoder::set_llm_params() {
191193 auto * node = m_cgraph->nodes [i];
192194 std::string name = std::string (node->name );
193195 if (node->op == GGML_OP_FLASH_ATTN_EXT) {
194- auto * cache_k = node->src [1 ];
195- cache_k = cache_k->view_src ? cache_k->view_src : cache_k;
196+ auto * cache_k_perm = node->src [1 ];
197+ assert (cache_k_perm->op == GGML_OP_PERMUTE);
198+ auto * cache_k_view = cache_k_perm->src [0 ];
199+ assert (cache_k_view->op == GGML_OP_VIEW);
200+
201+ auto * cache_k = cache_k_view->src [0 ];
196202 int layer = extract_layer_from_name (cache_k->name );
203+ auto * mask = node->src [3 ];
204+ std::string mask_name (mask->name );
205+ assert (mask_name.find (" KQ_mask" ) == 0 );
197206
198207 if (std::string (node->src [3 ]->name ).find (" swa" ) != std::string::npos) {
199208 m_swa_layers.push_back (layer);
200- m_context_size_swa = cache_k->ne [1 ];
209+ m_ctx_per_seq_swa = cache_k->ne [1 ];
210+ } else {
211+ m_ctx_per_seq = cache_k->ne [1 ];
212+ m_n_seq = cache_k->ne [2 ];
213+ }
214+
215+ m_n_seq_active = mask->ne [3 ];
216+ auto seq_size = cache_k->ne [0 ] * cache_k->ne [1 ] * ggml_type_size (cache_k->type );
217+ size_t offset;
218+ memcpy (&offset, cache_k_view->op_params , sizeof (size_t ));
219+ m_seq_active_start = offset / seq_size;
220+ m_token_len_per_seq = node->ne [2 ];
221+
222+ if (mask_name.find (" swa" ) != std::string::npos) {
223+ m_attention_size_swa = mask->ne [0 ];
201224 } else {
202- m_context_size = cache_k ->ne [1 ];
225+ m_attention_size = mask ->ne [0 ];
203226 }
227+ if (m_is_static) {
228+ m_attention_size = m_ctx_per_seq;
229+ m_attention_size_swa = m_ctx_per_seq_swa;
230+ m_token_len_per_seq = 1 ;
231+ }
232+
204233 } else if (node->op == GGML_OP_ROPE) {
205234 if (name.find (" Qcur-0" ) == 0 || std::string (node->src [0 ]->name ).find (" Qcur-0" ) == 0 ) {
206235 m_head_size = node->ne [0 ];
207236 m_rope_params = node->op_params ;
208237 auto * inp_pos = node->src [1 ];
209238 m_input_len = inp_pos->ne [0 ];
210- m_past_kv_len = *(int32_t *) inp_pos->data ;
211239 } else if (name.find (" Kcur-0" ) == 0 || std::string (node->src [0 ]->name ).find (" Kcur-0" ) == 0 ) {
212- m_num_heads_kv = node->ne [1 ];
240+ m_n_heads_kv = node->ne [1 ];
213241 }
214242 }
215243 }
244+ m_ctx = m_ctx_per_seq * m_n_seq;
245+ m_ctx_swa = m_ctx_per_seq_swa * m_n_seq;
216246}
217247
218- void GgmlOvDecoder::validate_cgraph () const {}
248+ void GgmlOvDecoder::validate_cgraph () const {
249+ if (m_n_seq > 1 && m_is_static == true ) {
250+ throw std::runtime_error (" n_seq > 1 is not supported on NPU. Try setting -np 1." );
251+ }
252+ }
219253
220- ov::PartialShape GgmlOvDecoder::get_graph_input_shape (const ggml_tensor * src ) const {
221- auto name = std::string (src ->name );
254+ ov::PartialShape GgmlOvDecoder::get_graph_input_shape (const ggml_tensor * op, const ggml_tensor * input ) const {
255+ auto name = std::string (input ->name );
222256 ov::PartialShape input_shape;
223257
224258 if (name == " inp_tokens" || name == " inp_pos" || name == " inp_out_ids" ) {
225- input_shape = ov::PartialShape{1 , 1 , m_is_static ? 1 : -1 };
259+ input_shape = ov::PartialShape{1 , 1 , 1 , m_is_static ? 1 : -1 };
226260
227261 } else if (name.find (" KQ_mask" ) == 0 ) {
228262 if (m_is_static) {
229- input_shape = ov::PartialShape{1 , 1 , m_context_size };
263+ input_shape = ov::PartialShape{1 , 1 , 1 , m_ctx };
230264 } else {
231- input_shape = ov::PartialShape{1 , -1 , -1 };
265+ input_shape = ov::PartialShape{- 1 , 1 , -1 , -1 };
232266 }
233267
234268 } else if (name.find (" cache_" ) == 0 ) {
235- auto past_token_len = -1 ;
236- if (m_is_static) {
237- int layer = extract_layer_from_name (name);
238- bool is_swa = is_swa_layer (layer);
239- past_token_len = is_swa ? m_context_size_swa : m_context_size;
269+ input_shape = ov::PartialShape{get_shape (input)};
270+ if (!m_is_static) {
271+ // do not fix ctx size to make llama-bench work
272+ input_shape[2 ] = -1 ;
240273 }
241- input_shape = ov::PartialShape{past_token_len, m_num_heads_kv, m_head_size};
242274
243- } else if (const auto * op = get_tensor_used_op (src); op && op->op == GGML_OP_SET_ROWS ) {
244- input_shape = ov::PartialShape{1 , 1 , m_is_static ? 1 : -1 };
275+ } else if (op && op-> op == GGML_OP_SET_ROWS && op->src [ 1 ] == input ) {
276+ input_shape = ov::PartialShape{1 , 1 , 1 , m_is_static ? 1 : -1 };
245277
246- } else if (src ->op == GGML_OP_VIEW) {
278+ } else if (input ->op == GGML_OP_VIEW) {
247279 // This case is added to make test-backend-ops work
248- input_shape = ov::PartialShape{get_shape (src ->view_src )};
280+ input_shape = ov::PartialShape{get_shape (input ->view_src )};
249281 } else {
250- input_shape = ov::PartialShape{get_shape (src )};
282+ input_shape = ov::PartialShape{get_shape (input )};
251283 }
252284 return input_shape;
253285}
@@ -256,40 +288,35 @@ void GgmlOvDecoder::add_extra_inputs() {
256288 // Extra inputs:
257289 // 1. `attention_size`, used in FLASH_ATTN where the shape of the matmul's are 256 aligned,
258290 // see llama_kv_cache_unified::get_n_kv and llama_kv_cache_unified::get_padding.
259- // Not used for NPU.
260- // Update: not used anymore after the optimization of making kvcache dynamic (but breaks iSWA models)
261- int64_t attention_size = -1 ;
262- int64_t attention_size_swa = -1 ;
263- for (const auto & node : m_nodes) {
264- if (node->op == GGML_OP_FLASH_ATTN_EXT) {
265- auto * mask = node->src [3 ];
266- std::string mask_name (mask->name );
267- if (mask_name.find (" KQ_mask" ) != 0 ) {
268- throw std::runtime_error (" Unexpected flash attention node: " + std::string (mask->name ));
269- }
270- if (mask_name.find (" swa" ) != std::string::npos) {
271- attention_size_swa = mask->ne [0 ];
272- } else {
273- attention_size = mask->ne [0 ];
274- }
275- }
276- }
277-
278- auto create_attention_size_input = [this ](const std::string & name, int64_t size) {
279- auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64 , ov::Shape{1 });
280- param_node->set_friendly_name (name);
281- param_node->output (0 ).get_tensor ().set_names ({name});
282- m_model_extra_inputs[name] = param_node;
291+ // 2. `n_seq_active` and `seq_active_start`, used in FLASH_ATTN_EXT to indicate the active sequences in the batch
283292
284- auto tensor = std::make_shared<ov::Tensor>(ov::element::i64 , ov::Shape{1 });
285- *tensor->data <int64_t >() = size;
286- m_model_extra_input_values[name] = tensor;
293+ auto create_1d_input = [this ](const std::string & name, int64_t value) {
294+ if (m_is_static) {
295+ auto constant =
296+ std::make_shared<ov::op::v0::Constant>(ov::element::i64 , ov::Shape{1 }, std::vector<int64_t >{value});
297+ constant->set_friendly_name (name);
298+ m_model_extra_inputs[name] = constant;
299+ } else {
300+ auto param_node = std::make_shared<ov::op::v0::Parameter>(ov::element::i64 , ov::Shape{1 });
301+ param_node->set_friendly_name (name);
302+ param_node->output (0 ).get_tensor ().set_names ({name});
303+ m_model_extra_inputs[name] = param_node;
304+
305+ auto tensor = std::make_shared<ov::Tensor>(ov::element::i64 , ov::Shape{1 });
306+ *tensor->data <int64_t >() = value;
307+ m_model_extra_input_values[name] = tensor;
308+ }
287309 };
288310
289- create_attention_size_input (" attention_size" , attention_size );
290- if (attention_size_swa != -1 ) {
291- create_attention_size_input (" attention_size_swa" , attention_size_swa );
311+ create_1d_input (" attention_size" , m_attention_size );
312+ if (m_attention_size_swa != -1 ) {
313+ create_1d_input (" attention_size_swa" , m_attention_size_swa );
292314 }
315+ create_1d_input (" n_seq_active" , m_n_seq_active);
316+ create_1d_input (" seq_active_start" , m_seq_active_start);
317+ create_1d_input (" seq_active_end" , m_seq_active_start + m_n_seq_active);
318+ create_1d_input (" token_len_per_seq" , m_token_len_per_seq);
319+ // create_1d_input("token_len", m_token_len_per_seq * m_n_seq_active);
293320}
294321
295322const ggml_tensor * GgmlOvDecoder::get_tensor_used_op (const ggml_tensor * tensor) const {
@@ -390,6 +417,8 @@ std::shared_ptr<ov::Node> GgmlOvDecoder::create_weight_node(ggml_tensor * tensor
390417 auto node_shape = get_shape (tensor);
391418 auto ne_total = ggml_nelements (tensor);
392419
420+ OPENVINO_ASSERT (node_shape[0 ] == 1 , " Got 4D weights, expect all weights to be 2D: " , tensor->name );
421+ node_shape.erase (node_shape.begin ());
393422 OPENVINO_ASSERT (node_shape[0 ] == 1 , " Got 3D weights, expect all weights to be 2D: " , tensor->name );
394423 node_shape.erase (node_shape.begin ());
395424
@@ -559,15 +588,15 @@ void print_tensor_address_map(const ggml_cgraph * cgraph) {
559588
560589std::vector<size_t > GgmlOvDecoder::get_shape (const ggml_tensor * tensor) {
561590 std::vector<size_t > shape;
562- for (int i = GGML_MAX_DIMS - 2 ; i >= 0 ; --i) {
591+ for (int i = GGML_MAX_DIMS - 1 ; i >= 0 ; --i) {
563592 shape.push_back (static_cast <size_t >(tensor->ne [i]));
564593 }
565594 return shape;
566595}
567596
568597std::vector<size_t > GgmlOvDecoder::get_stride (const ggml_tensor * tensor) {
569598 std::vector<size_t > stride;
570- for (int i = GGML_MAX_DIMS - 2 ; i >= 0 ; --i) {
599+ for (int i = GGML_MAX_DIMS - 1 ; i >= 0 ; --i) {
571600 stride.push_back (static_cast <size_t >(tensor->nb [i]));
572601 }
573602 return stride;
@@ -626,7 +655,11 @@ std::vector<size_t> GgmlOvDecoder::get_output_stride(const std::string & name) c
626655}
627656
628657ov::PartialShape GgmlOvDecoder::get_output_shape (const std::string & name) const {
629- return ov::PartialShape (get_shape (m_outputs.at (name)));
658+ auto * ggml_tensor = m_outputs.at (name);
659+ if (ggml_tensor->op == GGML_OP_SET_ROWS) {
660+ ggml_tensor = ggml_tensor->view_src ;
661+ }
662+ return ov::PartialShape (get_shape (ggml_tensor));
630663}
631664
632665ov::element::Type GgmlOvDecoder::get_output_type (const std::string & name) const {
0 commit comments