Commit 1dfdea6
committed
Created ReplicateKVHeadTransform to integrate KV-heads replication module within Qefficient library.
The Transform enables KV-head replication for CausalLMs and VLMs as well.
The feature is enabled by passing n_kv_head_repeat parameter during initialization of the QEff wrapper class for the corresponding model.
n_kv_head_repeat param acts as the multiplier for the number of repeats to be done to original count of KV heads.
This operation also causes the config and the hash params of the respective model to update the num_key_value_heads parameter and add a paramter orig_kv_heads to it; It allows us to export the same model with different number of kv_heads without causing a hash conflict.
Also added tests for both CausalLMs and VLMs with this functionality to compare outputs of Pytorch HF model and the AIC model.
Two new optional paramters n_kv_head_repeat and test_kv_replicate are added for testing purpose.
Setting test_kv_replicate to True performs a KV-head replication of every model such that the number of KV-heads and attention heads becomes equal. This was done to ensure tests don't fail due to misalignment issues when we simply repeat num_key_value_heads twice and thus cause a divisibility error on hum_heads.
Signed-off-by: Dhiraj Kumar Sah <[email protected]>1 parent 04f1ad7 commit 1dfdea6
File tree
4 files changed
+327
-1
lines changed- QEfficient/transformers/models
- tests/transformers/models
4 files changed
+327
-1
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
43 | 43 | | |
44 | 44 | | |
45 | 45 | | |
| 46 | + | |
46 | 47 | | |
47 | 48 | | |
48 | 49 | | |
| |||
883 | 884 | | |
884 | 885 | | |
885 | 886 | | |
| 887 | + | |
| 888 | + | |
| 889 | + | |
| 890 | + | |
| 891 | + | |
886 | 892 | | |
887 | 893 | | |
888 | 894 | | |
| |||
1511 | 1517 | | |
1512 | 1518 | | |
1513 | 1519 | | |
| 1520 | + | |
| 1521 | + | |
| 1522 | + | |
1514 | 1523 | | |
1515 | 1524 | | |
1516 | 1525 | | |
| |||
2063 | 2072 | | |
2064 | 2073 | | |
2065 | 2074 | | |
| 2075 | + | |
2066 | 2076 | | |
| 2077 | + | |
2067 | 2078 | | |
2068 | 2079 | | |
2069 | 2080 | | |
| |||
2164 | 2175 | | |
2165 | 2176 | | |
2166 | 2177 | | |
| 2178 | + | |
| 2179 | + | |
| 2180 | + | |
2167 | 2181 | | |
2168 | 2182 | | |
2169 | 2183 | | |
| |||
2265 | 2279 | | |
2266 | 2280 | | |
2267 | 2281 | | |
| 2282 | + | |
| 2283 | + | |
2268 | 2284 | | |
| 2285 | + | |
2269 | 2286 | | |
2270 | 2287 | | |
2271 | 2288 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
9 | 9 | | |
10 | 10 | | |
11 | 11 | | |
| 12 | + | |
12 | 13 | | |
13 | 14 | | |
14 | 15 | | |
| |||
424 | 425 | | |
425 | 426 | | |
426 | 427 | | |
| 428 | + | |
| 429 | + | |
| 430 | + | |
427 | 431 | | |
428 | 432 | | |
| 433 | + | |
429 | 434 | | |
430 | 435 | | |
431 | 436 | | |
| |||
630 | 635 | | |
631 | 636 | | |
632 | 637 | | |
| 638 | + | |
| 639 | + | |
| 640 | + | |
| 641 | + | |
| 642 | + | |
| 643 | + | |
| 644 | + | |
| 645 | + | |
| 646 | + | |
| 647 | + | |
| 648 | + | |
| 649 | + | |
| 650 | + | |
| 651 | + | |
| 652 | + | |
| 653 | + | |
| 654 | + | |
| 655 | + | |
| 656 | + | |
| 657 | + | |
| 658 | + | |
| 659 | + | |
| 660 | + | |
| 661 | + | |
| 662 | + | |
| 663 | + | |
| 664 | + | |
| 665 | + | |
| 666 | + | |
| 667 | + | |
| 668 | + | |
| 669 | + | |
| 670 | + | |
| 671 | + | |
| 672 | + | |
| 673 | + | |
| 674 | + | |
| 675 | + | |
| 676 | + | |
| 677 | + | |
| 678 | + | |
| 679 | + | |
| 680 | + | |
| 681 | + | |
| 682 | + | |
| 683 | + | |
| 684 | + | |
| 685 | + | |
| 686 | + | |
| 687 | + | |
| 688 | + | |
| 689 | + | |
| 690 | + | |
| 691 | + | |
| 692 | + | |
| 693 | + | |
| 694 | + | |
| 695 | + | |
| 696 | + | |
| 697 | + | |
| 698 | + | |
| 699 | + | |
| 700 | + | |
| 701 | + | |
| 702 | + | |
| 703 | + | |
| 704 | + | |
| 705 | + | |
| 706 | + | |
| 707 | + | |
| 708 | + | |
| 709 | + | |
| 710 | + | |
| 711 | + | |
| 712 | + | |
| 713 | + | |
| 714 | + | |
| 715 | + | |
| 716 | + | |
| 717 | + | |
| 718 | + | |
| 719 | + | |
| 720 | + | |
| 721 | + | |
| 722 | + | |
| 723 | + | |
| 724 | + | |
| 725 | + | |
| 726 | + | |
| 727 | + | |
| 728 | + | |
| 729 | + | |
| 730 | + | |
| 731 | + | |
| 732 | + | |
| 733 | + | |
| 734 | + | |
| 735 | + | |
| 736 | + | |
| 737 | + | |
| 738 | + | |
| 739 | + | |
| 740 | + | |
| 741 | + | |
| 742 | + | |
| 743 | + | |
| 744 | + | |
| 745 | + | |
| 746 | + | |
| 747 | + | |
| 748 | + | |
| 749 | + | |
| 750 | + | |
| 751 | + | |
| 752 | + | |
| 753 | + | |
| 754 | + | |
| 755 | + | |
| 756 | + | |
| 757 | + | |
| 758 | + | |
| 759 | + | |
| 760 | + | |
| 761 | + | |
| 762 | + | |
| 763 | + | |
| 764 | + | |
| 765 | + | |
| 766 | + | |
| 767 | + | |
| 768 | + | |
| 769 | + | |
| 770 | + | |
| 771 | + | |
| 772 | + | |
| 773 | + | |
| 774 | + | |
| 775 | + | |
| 776 | + | |
| 777 | + | |
| 778 | + | |
| 779 | + | |
| 780 | + | |
| 781 | + | |
633 | 782 | | |
634 | 783 | | |
635 | 784 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
282 | 282 | | |
283 | 283 | | |
284 | 284 | | |
| 285 | + | |
| 286 | + | |
| 287 | + | |
| 288 | + | |
| 289 | + | |
| 290 | + | |
| 291 | + | |
| 292 | + | |
| 293 | + | |
| 294 | + | |
| 295 | + | |
| 296 | + | |
| 297 | + | |
| 298 | + | |
| 299 | + | |
| 300 | + | |
| 301 | + | |
| 302 | + | |
| 303 | + | |
| 304 | + | |
| 305 | + | |
| 306 | + | |
| 307 | + | |
| 308 | + | |
| 309 | + | |
| 310 | + | |
| 311 | + | |
| 312 | + | |
| 313 | + | |
| 314 | + | |
| 315 | + | |
| 316 | + | |
| 317 | + | |
| 318 | + | |
| 319 | + | |
| 320 | + | |
| 321 | + | |
| 322 | + | |
| 323 | + | |
| 324 | + | |
| 325 | + | |
| 326 | + | |
| 327 | + | |
| 328 | + | |
| 329 | + | |
| 330 | + | |
| 331 | + | |
| 332 | + | |
| 333 | + | |
| 334 | + | |
| 335 | + | |
| 336 | + | |
| 337 | + | |
| 338 | + | |
| 339 | + | |
| 340 | + | |
| 341 | + | |
| 342 | + | |
| 343 | + | |
| 344 | + | |
| 345 | + | |
| 346 | + | |
| 347 | + | |
| 348 | + | |
| 349 | + | |
| 350 | + | |
285 | 351 | | |
286 | 352 | | |
287 | 353 | | |
| |||
360 | 426 | | |
361 | 427 | | |
362 | 428 | | |
| 429 | + | |
| 430 | + | |
| 431 | + | |
| 432 | + | |
| 433 | + | |
| 434 | + | |
| 435 | + | |
| 436 | + | |
| 437 | + | |
| 438 | + | |
| 439 | + | |
| 440 | + | |
| 441 | + | |
| 442 | + | |
| 443 | + | |
| 444 | + | |
| 445 | + | |
| 446 | + | |
| 447 | + | |
| 448 | + | |
| 449 | + | |
| 450 | + | |
363 | 451 | | |
364 | 452 | | |
365 | 453 | | |
| |||
0 commit comments