@@ -67,30 +67,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];};
67
67
#if defined(A_TYPE_PACKED16)
// Binding-index selectors passed to dequantize4() to choose between the
// K and V tensors, which live in two separate read-only storage buffers.
#define BINDING_IDX_K 0
#define BINDING_IDX_V 1
layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed;
layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed;
#endif
72
73
73
74
#if defined(DATA_A_Q4_0)
#define BLOCK_BYTE_SIZE 18

// Dequantize 4 consecutive Q4_0 values from block `ib` (relative to
// `a_offset`), starting at in-block index `iqs`.
// `binding_idx` selects the source buffer (BINDING_IDX_K or BINDING_IDX_V);
// the two buffers must be addressed through separate interface blocks, so
// only the loads are branched — the nibble extraction and scaling are shared.
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    uint vui_lo;
    uint vui_hi;
    float d;
    if (binding_idx == BINDING_IDX_K) {
        vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
        vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
        d = float(k_packed.k_data_packed16[a_offset + ib].d);
    } else {
        vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]);
        vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]);
        d = float(v_packed.v_data_packed16[a_offset + ib].d);
    }
    // shift is 0 for the low nibbles (iqs < 16) and 4 for the high nibbles.
    const uint shift = (iqs & 0x10) >> 2;
    vui_lo >>= shift;
    vui_hi >>= shift;

    // Q4_0 stores unsigned 4-bit values with an implicit bias of 8.
    return d * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f);
}

#endif
86
97
87
98
#if defined(DATA_A_Q8_0)
#define BLOCK_BYTE_SIZE 34

// Dequantize 4 consecutive Q8_0 values from block `ib` (relative to
// `a_offset`), starting at in-block index `iqs`.
// `binding_idx` selects the source buffer (BINDING_IDX_K or BINDING_IDX_V);
// only the loads differ per buffer, so the scale/assemble tail is shared.
vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) {
    i8vec2 v0;
    i8vec2 v1;
    float d;
    if (binding_idx == BINDING_IDX_K) {
        v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
        v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
        d = float(k_packed.k_data_packed16[a_offset + ib].d);
    } else {
        v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147
        v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy;
        d = float(v_packed.v_data_packed16[a_offset + ib].d);
    }

    return d * vec4(v0.x, v0.y, v1.x, v1.y);
}

#endif
96
114
0 commit comments