Skip to content

Commit 84e1123

Browse files
committed
Refactored foreign Vulkan code
1 parent daccf36 commit 84e1123

File tree

10 files changed

+165
-204
lines changed

10 files changed

+165
-204
lines changed

cyfra-core/src/main/scala/io/computenode/cyfra/core/GProgram.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ trait GProgram[Params, L <: Layout: {LayoutBinding, LayoutStruct}] extends GExec
2020
val layout: InitProgramLayout => Params => L
2121
val dispatch: (L, Params) => ProgramDispatch
2222
val workgroupSize: WorkDimensions
23-
def layoutStruct = summon[LayoutStruct[L]]
23+
def layoutStruct: LayoutStruct[L] = summon[LayoutStruct[L]]
2424

2525
object GProgram:
2626
type WorkDimensions = (Int, Int, Int)

cyfra-e2e-test/src/test/scala/io/computenode/cyfra/e2e/fs2interop/Fs2Tests.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@ import algebra.VectorAlgebra
66
import io.computenode.cyfra.fs2interop.*
77
import io.computenode.cyfra.core.CyfraRuntime
88
import io.computenode.cyfra.runtime.VkCyfraRuntime
9-
import fs2.{io as fs2io, *}
10-
import _root_.io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvToolsRunner}
11-
import _root_.io.computenode.cyfra.spirvtools.SpirvTool.ToFile
9+
import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvDisassembler, SpirvToolsRunner}
10+
import io.computenode.cyfra.spirvtools.SpirvTool.ToFile
11+
12+
import fs2.*
1213

1314
import java.nio.file.Paths
1415

cyfra-examples/src/main/scala/io/computenode/cyfra/samples/TestingStuff.scala

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,10 @@ import io.computenode.cyfra.spirvtools.{SpirvCross, SpirvToolsRunner, SpirvValid
1313
import org.lwjgl.BufferUtils
1414
import org.lwjgl.system.MemoryUtil
1515

16-
import java.nio.ByteBuffer
1716
import java.nio.file.Paths
1817
import java.util.concurrent.atomic.AtomicInteger
1918
import scala.collection.parallel.CollectionConverters.given
2019

21-
def printBuffer(bb: ByteBuffer): Unit =
22-
val l = bb.asIntBuffer()
23-
val a = new Array[Int](l.remaining())
24-
l.get(a)
25-
println(a.mkString(" "))
26-
2720
object TestingStuff:
2821

2922
// === Emit program ===
@@ -60,30 +53,30 @@ object TestingStuff:
6053

6154
case class FilterProgramUniform(filterValue: Int32) extends GStruct[FilterProgramUniform]
6255

63-
case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[GBoolean], params: GUniform[FilterProgramUniform] = GUniform.fromParams)
64-
extends Layout
56+
case class FilterProgramLayout(in: GBuffer[Int32], out: GBuffer[Int32], params: GUniform[FilterProgramUniform] = GUniform.fromParams) extends Layout
6557

6658
val filterProgram = GProgram[FilterProgramParams, FilterProgramLayout](
6759
layout = params =>
6860
FilterProgramLayout(
6961
in = GBuffer[Int32](params.inSize),
70-
out = GBuffer[GBoolean](params.inSize),
62+
out = GBuffer[Int32](params.inSize),
7163
params = GUniform(FilterProgramUniform(params.filterValue)),
7264
),
7365
dispatch = (_, args) => GProgram.StaticDispatch((args.inSize / 128, 1, 1)),
7466
): layout =>
7567
val invocId = GIO.invocationId
7668
val element = GIO.read(layout.in, invocId)
7769
val isMatch = element === layout.params.read.filterValue
78-
GIO.write(layout.out, invocId, isMatch)
70+
val a: Int32 = when[Int32](isMatch)(1).otherwise(0)
71+
GIO.write(layout.out, invocId, a)
7972

8073
// === GExecution ===
8174

8275
case class EmitFilterParams(inSize: Int, emitN: Int, filterValue: Int)
8376

84-
case class EmitFilterLayout(inBuffer: GBuffer[Int32], emitBuffer: GBuffer[Int32], filterBuffer: GBuffer[GBoolean]) extends Layout
77+
case class EmitFilterLayout(inBuffer: GBuffer[Int32], emitBuffer: GBuffer[Int32], filterBuffer: GBuffer[Int32]) extends Layout
8578

86-
case class EmitFilterResult(out: GBuffer[GBoolean]) extends Layout
79+
case class EmitFilterResult(out: GBuffer[Int32]) extends Layout
8780

8881
val emitFilterExecution = GExecution[EmitFilterParams, EmitFilterLayout]()
8982
.addProgram(emitProgram)(
@@ -153,17 +146,17 @@ object TestingStuff:
153146
init = EmitFilterLayout(
154147
inBuffer = GBuffer[Int32](buffer),
155148
emitBuffer = GBuffer[Int32](data.length * 2),
156-
filterBuffer = GBuffer[GBoolean](data.length * 2),
149+
filterBuffer = GBuffer[Int32](data.length * 2),
157150
),
158151
onDone = layout => layout.filterBuffer.read(rbb),
159152
)
160153
runtime.close()
161154

162-
printBuffer(rbb)
163155
val actual = (0 until 2 * 1024).map(i => result.get(i) != 0)
164156
val expected = (0 until 1024).flatMap(x => Seq.fill(emitFilterParams.emitN)(x)).map(_ == emitFilterParams.filterValue)
165157
expected
166158
.zip(actual)
167159
.zipWithIndex
168160
.foreach:
169161
case ((e, a), i) => assert(e == a, s"Mismatch at index $i: expected $e, got $a")
162+
println("DONE")

cyfra-fs2/src/main/scala/io/computenode/cyfra/fs2interop/GPipe.scala

Lines changed: 11 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,17 @@ object GPipe:
8181
val invocId = GIO.invocationId
8282
val element = GIO.read[C](layout.in, invocId)
8383
val result = when(pred(element))(1: Int32).otherwise(0)
84-
for
85-
_ <- GIO.printf("Pred: Element %d -> %d", invocId, result)
86-
_ <- GIO.write[Int32](layout.out, invocId, result)
84+
for _ <- GIO.write[Int32](layout.out, invocId, result)
8785
yield Empty()
8886

8987
// Prefix sum (inclusive), upsweep/downsweep
9088
case class ScanParams(inSize: Int, intervalSize: Int)
9189
case class ScanArgs(intervalSize: Int32) extends GStruct[ScanArgs]
92-
case class ScanLayout(ints: GBuffer[Int32], intervalSize: GUniform[ScanArgs] = GUniform.fromParams) extends Layout
90+
case class ScanLayout(ints: GBuffer[Int32]) extends Layout
91+
case class ScanProgramLayout(ints: GBuffer[Int32], intervalSize: GUniform[ScanArgs] = GUniform.fromParams) extends Layout
9392

94-
val upsweep = GProgram[ScanParams, ScanLayout](
95-
layout = params => ScanLayout(ints = GBuffer[Int32](params.inSize), intervalSize = GUniform(ScanArgs(params.intervalSize))),
93+
val upsweep = GProgram[ScanParams, ScanProgramLayout](
94+
layout = params => ScanProgramLayout(ints = GBuffer[Int32](params.inSize), intervalSize = GUniform(ScanArgs(params.intervalSize))),
9695
dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize.toFloat / params.intervalSize / 256).toInt, 1, 1)),
9796
): layout =>
9897
val ScanArgs(size) = layout.intervalSize.read
@@ -104,23 +103,11 @@ object GPipe:
104103
val oldValue = GIO.read[Int32](layout.ints, end)
105104
val addValue = GIO.read[Int32](layout.ints, mid)
106105
val newValue = oldValue + addValue
107-
for
108-
_ <- GIO.printf(
109-
"Upsweep: invocId %d, root %d, size %d, mid %d, end %d, oldValue %d, addValue %d, newValue %d",
110-
invocId,
111-
root,
112-
size,
113-
mid,
114-
end,
115-
oldValue,
116-
addValue,
117-
newValue,
118-
)
119-
_ <- GIO.write[Int32](layout.ints, end, newValue)
106+
for _ <- GIO.write[Int32](layout.ints, end, newValue)
120107
yield Empty()
121108

122-
val downsweep = GProgram[ScanParams, ScanLayout](
123-
layout = params => ScanLayout(ints = GBuffer[Int32](params.inSize), intervalSize = GUniform(ScanArgs(params.intervalSize))),
109+
val downsweep = GProgram[ScanParams, ScanProgramLayout](
110+
layout = params => ScanProgramLayout(ints = GBuffer[Int32](params.inSize), intervalSize = GUniform(ScanArgs(params.intervalSize))),
124111
dispatch = (layout, params) => GProgram.StaticDispatch((Math.ceil(params.inSize.toFloat / params.intervalSize / 256).toInt, 1, 1)),
125112
): layout =>
126113
val ScanArgs(size) = layout.intervalSize.read
@@ -131,17 +118,7 @@ object GPipe:
131118
val oldValue = GIO.read[Int32](layout.ints, mid)
132119
val addValue = when(end > 0)(GIO.read[Int32](layout.ints, end)).otherwise(0)
133120
val newValue = oldValue + addValue
134-
for
135-
_ <- GIO.printf(
136-
"Downsweep: invocId %d, end %d, mid %d, oldValue %d, addValue %d, newValue %d",
137-
invocId,
138-
end,
139-
mid,
140-
oldValue,
141-
addValue,
142-
newValue,
143-
)
144-
_ <- GIO.write[Int32](layout.ints, mid, newValue)
121+
for _ <- GIO.write[Int32](layout.ints, mid, newValue)
145122
yield Empty()
146123

147124
// Stitch together many upsweep / downsweep program phases recursively
@@ -153,7 +130,7 @@ object GPipe:
153130
): GExecution[ScanParams, ScanLayout, ScanLayout] =
154131
if intervalSize > inSize then exec
155132
else
156-
val newExec = exec.addProgram(upsweep)(params => ScanParams(inSize, intervalSize), layout => layout)
133+
val newExec = exec.addProgram(upsweep)(params => ScanParams(inSize, intervalSize), layout => ScanProgramLayout(layout.ints))
157134
upsweepPhases(newExec, inSize, intervalSize * 2)
158135

159136
@annotation.tailrec
@@ -164,7 +141,7 @@ object GPipe:
164141
): GExecution[ScanParams, ScanLayout, ScanLayout] =
165142
if intervalSize < 2 then exec
166143
else
167-
val newExec = exec.addProgram(downsweep)(params => ScanParams(inSize, intervalSize), layout => layout)
144+
val newExec = exec.addProgram(downsweep)(params => ScanParams(inSize, intervalSize), layout => ScanProgramLayout(layout.ints))
168145
downsweepPhases(newExec, inSize, intervalSize / 2)
169146

170147
val initExec = GExecution[ScanParams, ScanLayout]() // no program
@@ -183,7 +160,6 @@ object GPipe:
183160
val element = GIO.read[C](layout.in, invocId)
184161
val prefixSum = GIO.read[Int32](layout.scan, invocId)
185162
for
186-
_ <- GIO.printf("Compact: Element %d, prefix sum %d", invocId, prefixSum)
187163
_ <- GIO.when(invocId > 0):
188164
val prevScan = GIO.read[Int32](layout.scan, invocId - 1)
189165
GIO.when(prevScan < prefixSum):

cyfra-runtime/src/main/scala/io/computenode/cyfra/runtime/ExecutionHandler.scala

Lines changed: 19 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte
4646
val (result, shaderCalls) = interpret(execution, params, layout)
4747

4848
val descriptorSets = shaderCalls.map:
49-
case ShaderCall(pipeline, layout, _, _) =>
49+
case ShaderCall(pipeline, layout, _) =>
5050
pipeline.pipelineLayout.sets
5151
.map(dsManager.allocate)
5252
.zip(layout)
@@ -58,7 +58,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte
5858
val dispatches: Seq[Dispatch] = shaderCalls
5959
.zip(descriptorSets)
6060
.map:
61-
case (ShaderCall(pipeline, layout, dispatch, _), sets) =>
61+
case (ShaderCall(pipeline, layout, dispatch), sets) =>
6262
Dispatch(pipeline, layout, sets, dispatch)
6363

6464
val (executeSteps, _) = dispatches.foldLeft((Seq.empty[ExecutionStep], Set.empty[GBinding[?]])):
@@ -94,7 +94,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte
9494
case x: ExecutionBinding[?] => x
9595
case x: GBinding[?] =>
9696
val e = ExecutionBinding(x)(using x.fromExpr, x.tag)
97-
bindingsAcc.put(e, mutable.Buffer(x)) // store only base contribution here
97+
bindingsAcc.put(e, mutable.Buffer(x))
9898
e
9999
mapper.fromBindings(res)
100100

@@ -128,53 +128,33 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte
128128
val layoutInit =
129129
val initProgram: InitProgramLayout = summon[VkAllocation].getInitProgramLayout
130130
program.layout(initProgram)(params)
131-
132-
val callInits: Map[GBinding[?], Seq[GBinding[?]]] =
133-
lb
134-
.toBindings(layout)
135-
.zip(lb.toBindings(layoutInit))
136-
.groupMap(_._1)(_._2)
137-
131+
lb.toBindings(layout)
132+
.zip(lb.toBindings(layoutInit))
133+
.foreach:
134+
case (binding, initBinding) =>
135+
bindingsAcc(binding).append(initBinding)
138136
val dispatch = program.dispatch(layout, params) match
139137
case GProgram.DynamicDispatch(buffer, offset) => DispatchType.Indirect(buffer, offset)
140138
case GProgram.StaticDispatch(size) => DispatchType.Direct(size._1, size._2, size._3)
141139
// noinspection ScalaRedundantCast
142-
(layout.asInstanceOf[RL], Seq(ShaderCall(shader.underlying, shader.shaderBindings(layout), dispatch, callInits)))
140+
(layout.asInstanceOf[RL], Seq(ShaderCall(shader.underlying, shader.shaderBindings(layout), dispatch)))
143141
case _ => ???
144142

145143
val (rl, steps) = interpretImpl(execution, params, mockBindings(layout))
146-
147-
val finalBindingForRl: mutable.Map[GBinding[?], GBinding[?]] = mutable.Map.empty
144+
val bingingToVk = bindingsAcc.map(x => (x._1, interpretBinding(x._1, x._2.toSeq)))
148145

149146
val nextSteps = steps.map:
150-
case ShaderCall(pipeline, layout, dispatch, callInits) =>
147+
case ShaderCall(pipeline, layout, dispatch) =>
151148
val nextLayout = layout.map:
152149
_.map:
153-
case Binding(binding, operation) =>
154-
val base = bindingsAcc.getOrElse(binding, mutable.Buffer.empty).toSeq
155-
val extras = callInits.getOrElse(binding, Seq.empty)
156-
val resolved = interpretBinding(binding, base ++ extras)
157-
finalBindingForRl.update(binding, resolved)
158-
Binding(resolved, operation)
159-
150+
case Binding(binding, operation) => Binding(bingingToVk(binding), operation)
160151
val nextDispatch = dispatch match
161-
case x: DispatchType.Direct => x
162-
case DispatchType.Indirect(buffer, offset) =>
163-
val base = bindingsAcc.getOrElse(buffer, mutable.Buffer.empty).toSeq
164-
val extras = callInits.getOrElse(buffer, Seq.empty)
165-
val resolved = interpretBinding(buffer, base ++ extras)
166-
finalBindingForRl.update(buffer, resolved)
167-
DispatchType.Indirect(resolved, offset)
168-
169-
ShaderCall(pipeline, nextLayout, nextDispatch, Map.empty)
152+
case x: Direct => x
153+
case Indirect(buffer, offset) => Indirect(bingingToVk(buffer), offset)
154+
ShaderCall(pipeline, nextLayout, nextDispatch)
170155

171156
val mapper = summon[LayoutBinding[RL]]
172-
val rlBindings = mapper
173-
.toBindings(rl)
174-
.map: b =>
175-
finalBindingForRl.getOrElse(b, interpretBinding(b, bindingsAcc.getOrElse(b, mutable.Buffer.empty).toSeq))
176-
val res = mapper.fromBindings(rlBindings)
177-
157+
val res = mapper.fromBindings(mapper.toBindings(rl).map(bingingToVk.apply))
178158
(res, nextSteps)
179159

180160
private def interpretBinding(binding: GBinding[?], bindings: Seq[GBinding[?]])(using VkAllocation): GBinding[?] =
@@ -207,7 +187,7 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte
207187
case _: GUniform.ParamUniform[?] => false
208188
case x => throw BindingLogicError(x, "Unsupported binding type")
209189
if allocations.size > 1 then throw BindingLogicError(allocations, "Multiple allocations for uniform")
210-
allocations.headOption.getOrElse(throw new IllegalStateException("Uniform never allocated"))
190+
allocations.headOption.getOrElse(throw new BindingLogicError(Seq(), "Uniform never allocated"))
211191
case x => throw new IllegalArgumentException(s"Binding of type ${x.getClass.getName} should not be here")
212192

213193
private def recordCommandBuffer(steps: Seq[ExecutionStep]): VkCommandBuffer = pushStack: stack =>
@@ -256,19 +236,13 @@ class ExecutionHandler(runtime: VkCyfraRuntime, threadContext: VulkanThreadConte
256236
.distinct
257237

258238
object ExecutionHandler:
259-
case class ShaderCall(
260-
pipeline: ComputePipeline,
261-
layout: ShaderLayout,
262-
dispatch: DispatchType,
263-
callInits: Map[GBinding[?], Seq[GBinding[?]]], // per-program contributions
264-
)
239+
case class ShaderCall(pipeline: ComputePipeline, layout: ShaderLayout, dispatch: DispatchType)
265240

266241
sealed trait ExecutionStep
267-
268242
case class Dispatch(pipeline: ComputePipeline, layout: ShaderLayout, descriptorSets: Seq[DescriptorSet], dispatch: DispatchType)
269243
extends ExecutionStep
270-
271244
case object PipelineBarrier extends ExecutionStep
245+
272246
sealed trait DispatchType
273247
object DispatchType:
274248
case class Direct(x: Int, y: Int, z: Int) extends DispatchType

cyfra-vulkan/src/main/scala/io/computenode/cyfra/vulkan/VulkanContext.scala

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
package io.computenode.cyfra.vulkan
22

33
import io.computenode.cyfra.utility.Logger.logger
4-
import io.computenode.cyfra.vulkan.VulkanContext.ValidationLayers
4+
import io.computenode.cyfra.vulkan.VulkanContext.{validation, vulkanPrintf}
55
import io.computenode.cyfra.vulkan.command.CommandPool
6-
import io.computenode.cyfra.vulkan.core.{DebugCallback, Device, Instance, PhysicalDevice, Queue}
6+
import io.computenode.cyfra.vulkan.core.{DebugMessengerCallback, DebugReportCallback, Device, Instance, PhysicalDevice, Queue}
77
import io.computenode.cyfra.vulkan.memory.{Allocator, DescriptorPool, DescriptorPoolManager, DescriptorSetManager}
88
import org.lwjgl.system.Configuration
99

@@ -15,12 +15,13 @@ import scala.jdk.CollectionConverters.*
1515
* MarconZet Created 13.04.2020
1616
*/
1717
private[cyfra] object VulkanContext:
18-
val ValidationLayer: String = "VK_LAYER_KHRONOS_validation"
19-
private val ValidationLayers: Boolean = System.getProperty("io.computenode.cyfra.vulkan.validation", "false").toBoolean
18+
private val validation: Boolean = System.getProperty("io.computenode.cyfra.vulkan.validation", "false").toBoolean
19+
private val vulkanPrintf: Boolean = System.getProperty("io.computenode.cyfra.vulkan.printf", "false").toBoolean
2020

2121
private[cyfra] class VulkanContext:
22-
private val instance: Instance = new Instance(ValidationLayers)
23-
private val debugCallback: Option[DebugCallback] = if ValidationLayers then Some(new DebugCallback(instance)) else None
22+
private val instance: Instance = new Instance(validation, vulkanPrintf)
23+
private val debugReport: Option[DebugReportCallback] = if validation then Some(new DebugReportCallback(instance)) else None
24+
private val debugMessenger: Option[DebugMessengerCallback] = if validation & vulkanPrintf then Some(new DebugMessengerCallback(instance)) else None
2425
private val physicalDevice = new PhysicalDevice(instance)
2526
physicalDevice.assertRequirements()
2627

@@ -54,5 +55,6 @@ private[cyfra] class VulkanContext:
5455
descriptorPoolManager.destroy()
5556
allocator.destroy()
5657
device.destroy()
57-
debugCallback.foreach(_.destroy())
58+
debugReport.foreach(_.destroy())
59+
debugMessenger.foreach(_.destroy())
5860
instance.destroy()

0 commit comments

Comments
 (0)