@@ -733,20 +733,31 @@ impl<'tcx> CodegenCx<'tcx> {
733733 . decorate ( var_id. unwrap ( ) , Decoration :: Invariant , std:: iter:: empty ( ) ) ;
734734 }
735735 if let Some ( per_primitive_ext) = attrs. per_primitive_ext {
736- if storage_class != Ok ( StorageClass :: Output ) {
737- self . tcx . dcx ( ) . span_fatal (
738- per_primitive_ext. span ,
739- "`#[spirv(per_primitive_ext)]` is only valid on Output variables" ,
740- ) ;
741- }
742- if !( execution_model == ExecutionModel :: MeshEXT
743- || execution_model == ExecutionModel :: MeshNV )
744- {
745- self . tcx . dcx ( ) . span_fatal (
746- per_primitive_ext. span ,
747- "`#[spirv(per_primitive_ext)]` is only valid in mesh shaders" ,
748- ) ;
736+ match execution_model {
737+ ExecutionModel :: Fragment => {
738+ if storage_class != Ok ( StorageClass :: Input ) {
739+ self . tcx . dcx ( ) . span_fatal (
740+ per_primitive_ext. span ,
741+ "`#[spirv(per_primitive_ext)]` in fragment shaders is only valid on Input variables" ,
742+ ) ;
743+ }
744+ }
745+ ExecutionModel :: MeshNV | ExecutionModel :: MeshEXT => {
746+ if storage_class != Ok ( StorageClass :: Output ) {
747+ self . tcx . dcx ( ) . span_fatal (
748+ per_primitive_ext. span ,
749+ "`#[spirv(per_primitive_ext)]` in mesh shaders is only valid on Output variables" ,
750+ ) ;
751+ }
752+ }
753+ _ => {
754+ self . tcx . dcx ( ) . span_fatal (
755+ per_primitive_ext. span ,
756+ "`#[spirv(per_primitive_ext)]` is only valid in fragment or mesh shaders" ,
757+ ) ;
758+ }
749759 }
760+
750761 self . emit_global ( ) . decorate (
751762 var_id. unwrap ( ) ,
752763 Decoration :: PerPrimitiveEXT ,
0 commit comments