Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/dodal/beamlines/i09.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dodal.devices.fast_shutter import DualFastShutter, GenericFastShutter
from dodal.devices.hutch_shutter import EXP_SHUTTER_2_INFIX, HutchShutter
from dodal.devices.motors import XYZAzimuthPolarStage
from dodal.devices.pause_plan_device import PausePlanDevice
from dodal.devices.pgm import PlaneGratingMonochromator
from dodal.devices.selectable_source import SourceSelector
from dodal.devices.synchrotron import Synchrotron
Expand Down Expand Up @@ -126,3 +127,24 @@ def lakeshore() -> Lakeshore336:
def smpm() -> XYZAzimuthPolarStage:
"""Sample Manipulator."""
return XYZAzimuthPolarStage(prefix=f"{I_PREFIX.beamline_prefix}-MO-SMPM-01:")


@devices.factory()
def checkbeam(
synchrotron: Synchrotron, dual_fast_shutter: DualFastShutter
) -> PausePlanDevice:
async def _close_shutters():
await dual_fast_shutter.set(dual_fast_shutter.close_state)

async def _open_shutters():
await dual_fast_shutter.set(dual_fast_shutter.open_state)

checkbeam = PausePlanDevice(
signals_to_condition={
synchrotron.current: lambda rc: rc > 190,
synchrotron.top_up_start_countdown: lambda topup: topup < 5,
},
callable_when_paused=_close_shutters,
callable_on_resume=_open_shutters,
)
return checkbeam
92 changes: 92 additions & 0 deletions src/dodal/devices/pause_plan_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import asyncio
import contextlib
from collections.abc import Awaitable, Callable
from typing import Any

from bluesky.protocols import Readable, Stageable
from ophyd_async.core import (
AsyncStatus,
Device,
SignalR,
observe_value,
)


class PausePlanDevice(Device, Stageable, Readable):
def __init__(
self,
signals_to_condition: dict[SignalR[Any], Callable[[Any], bool]],
callable_when_paused: Callable[[], Awaitable[None]] | None = None,
callable_on_resume: Callable[[], Awaitable[None]] | None = None,
seconds_to_wait_before_resume: float = 5,
name: str = "",
):
self._signals_to_condition = signals_to_condition
self._callable_when_paused = callable_when_paused
self._callable_on_resume = callable_on_resume
self._seconds_to_wait_before_resume = seconds_to_wait_before_resume
super().__init__(name)

async def _pause(self):
"""Pause until all signal conditions are met, calling hooks as needed."""
# Check if we actually need to pause
values = await asyncio.gather(
*(sig.get_value() for sig in self._signals_to_condition)
)
all_met = all(
pred(value)
for value, (_, pred) in zip(
values, self._signals_to_condition.items(), strict=True
)
)
if all_met:
return # no need to pause

# Call pause hook
if self._callable_when_paused:
await self._callable_when_paused()

latest = {}
event = asyncio.Event()

async def watch(signal, predicate):
async for value in observe_value(signal):
latest[signal] = predicate(value)
if len(latest) == len(self._signals_to_condition) and all(
latest.values()
):
event.set()
return

tasks = [
asyncio.create_task(watch(sig, pred))
for sig, pred in self._signals_to_condition.items()
]

await event.wait()

# Cancel watchers
for task in tasks:
task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await task

await asyncio.sleep(self._seconds_to_wait_before_resume)
# Call resume hook
if self._callable_on_resume:
await self._callable_on_resume()

@AsyncStatus.wrap
async def stage(self):
await self._pause()

@AsyncStatus.wrap
async def unstage(self):
pass

async def read(self):
await self._pause()
return {}

async def describe(self):
return {}
119 changes: 119 additions & 0 deletions tests/devices/test_pause_plan_device.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import asyncio

import pytest
from bluesky import RunEngine
from bluesky import plan_stubs as bps
from ophyd_async.core import InOut, SignalRW, init_devices, soft_signal_rw

from dodal.devices.fast_shutter import GenericFastShutter
from dodal.devices.pause_plan_device import PausePlanDevice


@pytest.fixture
def shutter() -> GenericFastShutter:
with init_devices(mock=True):
shutter = GenericFastShutter(
"TEST:", open_state=InOut.OUT, close_state=InOut.IN
)
return shutter


@pytest.fixture
def sig1() -> SignalRW[int]:
with init_devices(mock=True):
sig1 = soft_signal_rw(int, initial_value=0)
return sig1


@pytest.fixture
def sig2() -> SignalRW[int]:
with init_devices(mock=True):
sig2 = soft_signal_rw(int, initial_value=0)
return sig2


@pytest.fixture
async def pause_plan_device(
sig1: SignalRW[float],
sig2: SignalRW[float],
shutter: GenericFastShutter,
) -> PausePlanDevice:

async def _close_shutter():
await shutter.set(shutter.close_state)

async def _open_shutter():
await shutter.set(shutter.open_state)

with init_devices(mock=True):
pause_plan_device = PausePlanDevice(
{
sig1: lambda v: v == 1,
sig2: lambda v: v > 5,
},
callable_when_paused=_close_shutter,
callable_on_resume=_open_shutter,
seconds_to_wait_before_resume=0,
)
return pause_plan_device


async def test_conditions_can_arrive_in_any_order(
pause_plan_device: PausePlanDevice,
sig1: SignalRW[float],
sig2: SignalRW[float],
shutter: GenericFastShutter,
):
await shutter.set(shutter.open_state)

status = pause_plan_device.stage()

await asyncio.sleep(0.1)

assert not status.done
assert await shutter.shutter_state.get_value() == shutter.close_state

await sig1.set(1)
await sig2.set(10)

await status
assert status.success
assert await shutter.shutter_state.get_value() == shutter.open_state


@pytest.mark.asyncio
async def test_conditions_already_met(
pause_plan_device: PausePlanDevice, sig1: SignalRW[float], sig2: SignalRW[float]
):
await sig1.set(1)
await sig2.set(10)

status = pause_plan_device.stage()

await status

assert status.success


async def test_pause_device_blocks_plan_until_conditions_met(
run_engine: RunEngine,
pause_plan_device: PausePlanDevice,
sig1: SignalRW[float],
sig2: SignalRW[float],
):

start = asyncio.Event()

async def update_signals():
await start.wait()
await sig1.set(1)
await sig2.set(6)

asyncio.create_task(update_signals())

def plan():
start.set()
yield from bps.stage(pause_plan_device)
yield from bps.null()

run_engine(plan())
Loading