@@ -83,6 +83,14 @@ struct PackedVec3::State {
8383 // / A map from type to the name of a helper function used to unpack that type.
8484 Hashmap<const core::type::Type*, Symbol, 4 > unpack_helpers;
8585
86+ // / @returns true if @p addrspace requires vec3 types to be packed
87+ bool AddressSpaceNeedsPacking (core::AddressSpace addrspace) {
88+ // Host-shareable address spaces need to be packed to match the memory layout on the host.
89+ // The workgroup address space needs to be packed so that the size of generated threadgroup
90+ // variables matches the size of the original WGSL declarations.
91+ return core::IsHostShareable (addrspace) || addrspace == core::AddressSpace::kWorkgroup ;
92+ }
93+
8694 // / @param ty the type to test
8795 // / @returns true if `ty` is a vec3, false otherwise
8896 bool IsVec3 (const core::type::Type* ty) {
@@ -374,7 +382,7 @@ struct PackedVec3::State {
374382 // if the transform is necessary.
375383 for (auto * decl : src.AST ().GlobalVariables ()) {
376384 auto * var = sem.Get <sem::GlobalVariable>(decl);
377- if (var && core::IsHostShareable (var->AddressSpace ()) &&
385+ if (var && AddressSpaceNeedsPacking (var->AddressSpace ()) &&
378386 ContainsVec3 (var->Type ()->UnwrapRef ())) {
379387 return true ;
380388 }
@@ -411,7 +419,7 @@ struct PackedVec3::State {
411419 [&](const sem::TypeExpression* type) {
412420 // Rewrite pointers to types that contain vec3s.
413421 auto * ptr = type->Type ()->As <core::type::Pointer>();
414- if (ptr && core::IsHostShareable (ptr->AddressSpace ())) {
422+ if (ptr && AddressSpaceNeedsPacking (ptr->AddressSpace ())) {
415423 auto new_store_type = RewriteType (ptr->StoreType ());
416424 if (new_store_type) {
417425 auto access = ptr->AddressSpace () == core::AddressSpace::kStorage
@@ -424,7 +432,7 @@ struct PackedVec3::State {
424432 }
425433 },
426434 [&](const sem::Variable* var) {
427- if (!core::IsHostShareable (var->AddressSpace ())) {
435+ if (!AddressSpaceNeedsPacking (var->AddressSpace ())) {
428436 return ;
429437 }
430438
@@ -440,7 +448,7 @@ struct PackedVec3::State {
440448 auto * lhs = sem.GetVal (assign->lhs );
441449 auto * rhs = sem.GetVal (assign->rhs );
442450 if (!ContainsVec3 (rhs->Type ()) ||
443- !core::IsHostShareable (
451+ !AddressSpaceNeedsPacking (
444452 lhs->Type ()->As <core::type::Reference>()->AddressSpace ())) {
445453 // Skip assignments to address spaces that are not host-shareable, or
446454 // that do not contain vec3 types.
@@ -468,7 +476,7 @@ struct PackedVec3::State {
468476 [&](const sem::Load* load) {
469477 // Unpack loads of types that contain vec3s in host-shareable address spaces.
470478 if (ContainsVec3 (load->Type ()) &&
471- core::IsHostShareable (load->ReferenceType ()->AddressSpace ())) {
479+ AddressSpaceNeedsPacking (load->ReferenceType ()->AddressSpace ())) {
472480 to_unpack.Add (load);
473481 }
474482 },
@@ -478,7 +486,7 @@ struct PackedVec3::State {
478486 // struct.
479487 if (auto * ref = accessor->Type ()->As <core::type::Reference>()) {
480488 if (IsVec3 (ref->StoreType ()) &&
481- core::IsHostShareable (ref->AddressSpace ())) {
489+ AddressSpaceNeedsPacking (ref->AddressSpace ())) {
482490 ctx.Replace (node, b.MemberAccessor (ctx.Clone (accessor->Declaration ()),
483491 kStructMemberName ));
484492 }
0 commit comments