Skip to main content

oxide_update_engine/
engine.rs

1// This Source Code Form is subject to the terms of the Mozilla Public
2// License, v. 2.0. If a copy of the MPL was not distributed with this
3// file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
5use crate::{
6    CompletionContext, MetadataContext, StepContext, StepContextPayload,
7    StepHandle, errors::ExecutionError,
8};
9use cancel_safe_futures::coop_cancel;
10use debug_ignore::DebugIgnore;
11use derive_where::derive_where;
12use futures::{future::BoxFuture, prelude::*};
13use linear_map::LinearMap;
14use oxide_update_engine_types::{
15    events::{
16        Event, ExecutionUuid, ProgressEvent, ProgressEventKind,
17        StepComponentSummary, StepEvent, StepEventKind, StepInfo,
18        StepInfoWithMetadata, StepOutcome, StepProgress,
19    },
20    spec::{AsError, EngineSpec, GenericSpec},
21};
22use std::{
23    borrow::Cow,
24    fmt,
25    ops::ControlFlow,
26    pin::Pin,
27    sync::{Arc, Mutex},
28    task::Poll,
29};
30use tokio::{
31    sync::{mpsc, oneshot},
32    time::Instant,
33};
34/// Creates an MPSC channel suitable to be passed into the update engine.
35///
36/// This function is a convenience wrapper around
37/// [`tokio::sync::mpsc::channel`] that creates a channel of the appropriate
38/// size, and may aid in type inference.
39#[inline]
40pub fn channel<S: EngineSpec>()
41-> (mpsc::Sender<Event<S>>, mpsc::Receiver<Event<S>>) {
42    // This is a large enough channel to handle incoming messages without
43    // stalling.
44    const CHANNEL_SIZE: usize = 256;
45    mpsc::channel(CHANNEL_SIZE)
46}
47
48#[derive_where(Debug)]
49pub struct UpdateEngine<'a, S: EngineSpec> {
50    log: slog::Logger,
51    execution_id: ExecutionUuid,
52    sender: EngineSender<S>,
53
54    // This is set to None in Self::execute.
55    canceler: Option<coop_cancel::Canceler<String>>,
56    cancel_receiver: coop_cancel::Receiver<String>,
57
58    // This is a mutex to allow borrows to steps to be held by both
59    // ComponentRegistrar and NewStep at the same time. (This could also
60    // be a `RefCell` if a `Send` bound isn't required.)
61    //
62    // There is an alternative way to do this that doesn't use a mutex
63    // but involves no less than three lifetime parameters, which is
64    // excessive.
65    steps: Mutex<Steps<'a, S>>,
66}
67
68impl<'a, S: EngineSpec + 'a> UpdateEngine<'a, S> {
69    /// Creates a new `UpdateEngine`.
70    ///
71    /// It is recommended that `sender` be created using the [`channel`]
72    /// function, which sets an appropriate channel size.
73    pub fn new(log: &slog::Logger, sender: mpsc::Sender<Event<S>>) -> Self {
74        let sender = Arc::new(DefaultSender { sender });
75        Self::new_impl(log, EngineSender { sender })
76    }
77
78    // See the comment on `StepContext::with_nested_engine` for why
79    // this is necessary.
80    pub(crate) fn new_nested<S2: EngineSpec>(
81        log: &slog::Logger,
82        sender: mpsc::Sender<StepContextPayload<S2>>,
83    ) -> Self {
84        let sender = Arc::new(NestedSender { sender });
85        Self::new_impl(log, EngineSender { sender })
86    }
87
88    fn new_impl(log: &slog::Logger, sender: EngineSender<S>) -> Self {
89        let execution_id = ExecutionUuid::new_v4();
90        let (canceler, cancel_receiver) = coop_cancel::new_pair();
91        Self {
92            log: log.new(slog::o!(
93                "component" => "UpdateEngine",
94                "execution_id" => format!("{execution_id}"),
95            )),
96            execution_id,
97            sender,
98            canceler: Some(canceler),
99            cancel_receiver,
100            steps: Default::default(),
101        }
102    }
103
104    /// Returns the ID for this execution.
105    ///
106    /// All events coming from this engine will have this ID associated
107    /// with them.
108    pub fn execution_id(&self) -> ExecutionUuid {
109        self.execution_id
110    }
111
112    /// Adds a new step corresponding to the given component.
113    ///
114    /// # Notes
115    ///
116    /// The step will be considered to keep running until both the future
117    /// completes and the `StepContext` is dropped. In normal use, both
118    /// happen at the same time. However, it is technically possible to
119    /// make the `StepContext` escape the future.
120    ///
121    /// (Ideally, this would be prevented by making the function take a
122    /// `&mut StepContext`, but there are limitations in stable Rust
123    /// which make this impossible to achieve.)
124    pub fn new_step<F, Fut, T>(
125        &self,
126        component: S::Component,
127        id: S::StepId,
128        description: impl Into<Cow<'static, str>>,
129        step_fn: F,
130    ) -> NewStep<'_, 'a, S, T>
131    where
132        F: FnOnce(StepContext<S>) -> Fut + Send + 'a,
133        Fut: Future<Output = Result<StepResult<T, S>, S::Error>> + Send + 'a,
134        T: Send + 'a,
135    {
136        self.for_component(component).new_step(id, description, step_fn)
137    }
138
139    /// Creates a [`ComponentRegistrar`] that defines steps within the
140    /// context of a component.
141    ///
142    /// It is often useful to define similar steps across multiple
143    /// components. A `ComponentRegistrar` provides an easy way to do
144    /// so.
145    pub fn for_component(
146        &self,
147        component: S::Component,
148    ) -> ComponentRegistrar<'_, 'a, S> {
149        ComponentRegistrar { steps: &self.steps, component }
150    }
151
152    /// Creates and returns an abort handle for this engine.
153    ///
154    /// An abort handle can be used to forcibly cancel update engine
155    /// executions.
156    pub fn abort_handle(&self) -> AbortHandle {
157        AbortHandle {
158            canceler: self
159                .canceler
160                .as_ref()
161                .expect("abort_sender should always be present")
162                .clone(),
163        }
164    }
165
166    /// Executes the engine.
167    ///
168    /// This returns an `ExecutionHandle`, which needs to be awaited on
169    /// to drive the engine forward.
170    pub fn execute(mut self) -> ExecutionHandle<'a, S> {
171        let canceler = self
172            .canceler
173            .take()
174            .expect("execute is the only function which does this");
175        let abort_handle = AbortHandle { canceler };
176
177        let engine_fut = self.execute_impl().boxed();
178
179        ExecutionHandle { engine_fut: DebugIgnore(engine_fut), abort_handle }
180    }
181
    // Runs every registered step in registration order, emitting events to
    // `self.sender` along the way.
    //
    // Returns a `CompletionContext` once all steps have been reported, or
    // the first `ExecutionError` propagated by a `?` below (from sending an
    // event, executing a step, or reporting a step transition).
    async fn execute_impl(
        mut self,
    ) -> Result<CompletionContext<S>, ExecutionError<S>> {
        // Hands out consecutive event indexes, starting at 0: each call
        // bumps the counter and returns its previous value.
        let mut event_index = 0;
        let next_event_index = || {
            event_index += 1;
            event_index - 1
        };
        let mut exec_cx = ExecutionContext::new(
            self.execution_id,
            next_event_index,
            self.sender.clone(),
        );

        let steps = {
            let mut steps_lock = self.steps.lock().unwrap();
            // Grab the steps and component counts from within
            // steps_lock, then let steps_lock go. (Without this,
            // clippy warns about steps_lock being held across await
            // points.)
            //
            // There are no concurrency concerns here because `execute`
            // consumes `self`, and is the only piece of code that has
            // access to the mutex (`self.steps` is a `Mutex<T>`, not
            // an `Arc<Mutex<T>>`!)
            std::mem::take(&mut *steps_lock)
        };

        // Metadata-free summary of every step, computed up front for the
        // `ExecutionStarted` event. Metadata functions are not run here.
        let step_infos: Vec<_> = steps
            .steps
            .iter()
            .enumerate()
            .map(|(index, step)| {
                let total_component_steps = steps
                    .component_counts
                    .get(&step.metadata_gen.component)
                    .expect("this component was added");
                step.metadata_gen.to_step_info(index, *total_component_steps)
            })
            .collect();

        // One summary per component, carrying how many steps it has.
        let components = steps
            .component_counts
            .iter()
            .map(|(component, &total_component_steps)| StepComponentSummary {
                component: component.clone(),
                total_component_steps,
            })
            .collect();

        // This moves `steps.steps` out; `steps.component_counts` stays
        // usable for the lookups below.
        let mut steps_iter = steps.steps.into_iter().enumerate();

        // We need to handle the following separately:
        // * The first step
        // * Intermediate steps
        // * The last step

        let Some((index, first_step)) = steps_iter.next() else {
            // There are no steps defined.
            self.sender
                .send(Event::Step(StepEvent {
                    spec: S::spec_name(),
                    execution_id: self.execution_id,
                    event_index: (exec_cx.next_event_index)(),
                    total_elapsed: exec_cx.total_start.elapsed(),
                    kind: StepEventKind::NoStepsDefined,
                }))
                .await?;
            return Ok(CompletionContext::new());
        };

        // The first step's metadata function (if any) runs here, before the
        // `ExecutionStarted` event is sent, because that event carries the
        // first step's info with metadata.
        let first_step_info = {
            let total_component_steps = steps
                .component_counts
                .get(&first_step.metadata_gen.component)
                .expect("this component was added");
            first_step
                .metadata_gen
                .into_step_info_with_metadata(index, *total_component_steps)
                .await
        };

        let event = Event::Step(StepEvent {
            spec: S::spec_name(),
            execution_id: self.execution_id,
            event_index: (exec_cx.next_event_index)(),
            total_elapsed: exec_cx.total_start.elapsed(),
            kind: StepEventKind::ExecutionStarted {
                steps: step_infos,
                components,
                first_step: first_step_info.clone(),
            },
        });

        self.sender.send(event).await?;

        let step_exec_cx = exec_cx.create(first_step_info);

        // `step_res` is the step's own result and `reporter` is what reports
        // it; both are carried over to the next transition.
        let (mut step_res, mut reporter) = first_step
            .exec
            .execute(&self.log, step_exec_cx, &mut self.cancel_receiver)
            .await?;

        // Now run all remaining steps.
        for (index, step) in steps_iter {
            let total_component_steps = steps
                .component_counts
                .get(&step.metadata_gen.component)
                .expect("this component was added");

            // Each later step's metadata is generated just before that step
            // runs, i.e. after the previous step has finished executing.
            let step_info = step
                .metadata_gen
                .into_step_info_with_metadata(index, *total_component_steps)
                .await;
            // Report the previous step's result together with the info of
            // the step about to start.
            // NOTE(review): a failed previous step presumably surfaces as an
            // error from this await and ends execution -- confirm against
            // `StepProgressReporter::next_step`.
            let next_step = reporter.next_step(step_res, &step_info);
            next_step.await?;

            let step_exec_cx = exec_cx.create(step_info);

            (step_res, reporter) = step
                .exec
                .execute(&self.log, step_exec_cx, &mut self.cancel_receiver)
                .await?;
        }

        // Finally, report the last step.
        reporter.last_step(step_res).await?;

        Ok(CompletionContext::new())
    }
312}
313
314/// Abstraction used to send events to whatever receiver is interested
315/// in them.
316///
317/// # Why is this type so weird?
318///
319/// `EngineSender` is a wrapper around a cloneable trait object. Why do
320/// we need that?
321///
322/// `SenderImpl` has two implementations:
323///
324/// 1. `DefaultSender`, which is a wrapper around an
325///    `mpsc::Sender<Event<S>>`. This is used when the receiver is user
326///    code.
327/// 2. `NestedSender`, which is a more complex wrapper around an
328///    `mpsc::Sender<StepContextPayload<S>>`.
329///
330/// You might imagine that we could have `EngineSender` be an enum
331/// with these two variants. But we actually want `NestedSender<S>` to
332/// implement `SenderImpl<S>` for *any* EngineSpec, not just `S`, to
333/// allow nested engines to be a different EngineSpec than the outer
334/// engine.
335///
336/// In other words, `NestedSender` doesn't represent a single
337/// `mpsc::Sender<StepContextPayload<S>>`, it represents the universe
338/// of all possible EngineSpecs S. This is an infinite number of
339/// variants, and requires a trait object to represent.
340#[derive_where(Clone, Debug)]
341struct EngineSender<S: EngineSpec> {
342    sender: Arc<dyn SenderImpl<S>>,
343}
344
345impl<S: EngineSpec> EngineSender<S> {
346    async fn send(&self, event: Event<S>) -> Result<(), ExecutionError<S>> {
347        self.sender.send(event).await
348    }
349}
350
351trait SenderImpl<S: EngineSpec>: Send + Sync + fmt::Debug {
352    fn send(
353        &self,
354        event: Event<S>,
355    ) -> BoxFuture<'_, Result<(), ExecutionError<S>>>;
356}
357
358#[derive_where(Debug)]
359struct DefaultSender<S: EngineSpec> {
360    sender: mpsc::Sender<Event<S>>,
361}
362
363impl<S: EngineSpec> SenderImpl<S> for DefaultSender<S> {
364    fn send(
365        &self,
366        event: Event<S>,
367    ) -> BoxFuture<'_, Result<(), ExecutionError<S>>> {
368        self.sender.send(event).map_err(|error| error.into()).boxed()
369    }
370}
371
372#[derive_where(Debug)]
373struct NestedSender<S: EngineSpec> {
374    sender: mpsc::Sender<StepContextPayload<S>>,
375}
376
377// Note that NestedSender<S> implements SenderImpl<S2> for any S2:
378// EngineSpec. That is to allow nested engines to implement arbitrary
379// EngineSpecs.
380impl<S: EngineSpec, S2: EngineSpec> SenderImpl<S2> for NestedSender<S> {
381    fn send(
382        &self,
383        event: Event<S2>,
384    ) -> BoxFuture<'_, Result<(), ExecutionError<S2>>> {
385        let now = Instant::now();
386        async move {
387            let (done, done_rx) = oneshot::channel();
388            self.sender
389                .send(StepContextPayload::NestedSingle {
390                    now,
391                    event: event.into_generic(),
392                    done,
393                })
394                .await
395                .expect("our code always keeps payload_receiver open");
396            _ = done_rx.await;
397            Ok(())
398        }
399        .boxed()
400    }
401}
402
403/// A join handle for an UpdateEngine.
404///
405/// This handle should be awaited to drive and obtain the result of an
406/// execution.
407#[derive(Debug)]
408#[must_use = "ExecutionHandle does nothing unless polled"]
409pub struct ExecutionHandle<'a, S: EngineSpec> {
410    engine_fut: DebugIgnore<
411        BoxFuture<'a, Result<CompletionContext<S>, ExecutionError<S>>>,
412    >,
413    abort_handle: AbortHandle,
414}
415
416impl<S: EngineSpec> ExecutionHandle<'_, S> {
417    /// Aborts this engine execution with a message.
418    ///
419    /// This sends the message immediately, and returns a future that
420    /// can be optionally waited against to block until the abort is
421    /// processed.
422    ///
423    /// If this engine is still running, it is aborted at the next
424    /// await point. The engine sends an `ExecutionAborted` message
425    /// over the wire, and an `ExecutionError::Aborted` is returned.
426    ///
427    /// Returns `Err(message)` if the engine has already completed
428    /// execution.
429    pub fn abort(
430        &self,
431        message: impl Into<String>,
432    ) -> Result<AbortWaiter, String> {
433        self.abort_handle.abort(message.into())
434    }
435
436    /// Creates and returns an abort handle for this engine.
437    ///
438    /// An abort handle can be used to forcibly cancel update engine
439    /// executions.
440    pub fn abort_handle(&self) -> AbortHandle {
441        self.abort_handle.clone()
442    }
443}
444
445impl<S: EngineSpec> Future for ExecutionHandle<'_, S> {
446    type Output = Result<CompletionContext<S>, ExecutionError<S>>;
447
448    fn poll(
449        mut self: Pin<&mut Self>,
450        cx: &mut std::task::Context<'_>,
451    ) -> std::task::Poll<Self::Output> {
452        self.engine_fut.0.as_mut().poll(cx)
453    }
454}
455
456/// An abort handle, used to forcibly cancel update engine executions.
457#[derive(Clone, Debug)]
458pub struct AbortHandle {
459    canceler: coop_cancel::Canceler<String>,
460}
461
462impl AbortHandle {
463    /// Aborts this engine execution with a message.
464    ///
465    /// This sends the message immediately, and returns a future that
466    /// can be optionally waited against to block until the abort is
467    /// processed.
468    ///
469    /// If this engine is still running, it is aborted at the next
470    /// await point. The engine sends an `ExecutionAborted` message
471    /// over the wire, and an `ExecutionError::Aborted` is returned.
472    ///
473    /// Returns `Err(message)` if the engine has already completed
474    /// execution.
475    pub fn abort(
476        &self,
477        message: impl Into<String>,
478    ) -> Result<AbortWaiter, String> {
479        let waiter = self.canceler.cancel(message.into())?;
480        Ok(AbortWaiter { waiter })
481    }
482}
483
484/// A future which can be used to optionally block until an abort
485/// message is processed.
486///
487/// Dropping this future does not cancel the abort.
488#[derive(Debug)]
489pub struct AbortWaiter {
490    waiter: coop_cancel::Waiter<String>,
491}
492
493impl Future for AbortWaiter {
494    type Output = ();
495
496    fn poll(
497        mut self: Pin<&mut Self>,
498        cx: &mut std::task::Context<'_>,
499    ) -> Poll<Self::Output> {
500        self.waiter.poll_unpin(cx)
501    }
502}
503
504#[derive_where(Default, Debug)]
505struct Steps<'a, S: EngineSpec> {
506    steps: Vec<Step<'a, S>>,
507
508    // This is a `LinearMap` and not a `HashMap`/`BTreeMap` because we
509    // don't want to impose a `Hash` or `Ord` restriction on
510    // `S::Component`. In particular, we want to support
511    // `S::Component` being a generic `serde_json::Value`, which
512    // doesn't implement `Hash` or `Ord` but does implement `Eq`.
513    component_counts: LinearMap<S::Component, usize>,
514}
515
516// Note: have to be careful with lifetimes here because 'a is an
517// invariant lifetime. If there are compile errors related to this,
518// they're likely to be because 'a got mixed up with a covariant
519// lifetime like 'engine.
520
521/// Provides component context against which a step can be registered.
522pub struct ComponentRegistrar<'engine, 'a, S: EngineSpec> {
523    steps: &'engine Mutex<Steps<'a, S>>,
524    component: S::Component,
525}
526
527impl<'engine, 'a, S: EngineSpec> ComponentRegistrar<'engine, 'a, S> {
528    /// Returns the component associated with this registrar.
529    #[inline]
530    pub fn component(&self) -> &S::Component {
531        &self.component
532    }
533
534    /// Adds a new step corresponding to the component associated with
535    /// the registrar.
536    ///
537    /// # Notes
538    ///
539    /// The step will be considered to keep running until both the
540    /// future completes and the `StepContext` is dropped. In normal
541    /// use, both happen at the same time. However, it is technically
542    /// possible to make the `StepContext` escape the future.
543    ///
544    /// (Ideally, this would be prevented by making the function take
545    /// a `&mut StepContext`, but there are limitations in stable Rust
546    /// which make this impossible to achieve.)
547    pub fn new_step<F, Fut, T>(
548        &self,
549        id: S::StepId,
550        description: impl Into<Cow<'static, str>>,
551        step_fn: F,
552    ) -> NewStep<'engine, 'a, S, T>
553    where
554        F: FnOnce(StepContext<S>) -> Fut + Send + 'a,
555        Fut: Future<Output = Result<StepResult<T, S>, S::Error>> + Send + 'a,
556        T: Send + 'a,
557    {
558        let (sender, receiver) = oneshot::channel();
559
560        let exec_fn = Box::new(move |cx: StepContext<S>| {
561            let result = (step_fn)(cx);
562            async move {
563                match result.await {
564                    Ok(val) => {
565                        // Ignore errors if the receiver (the
566                        // StepHandle) was dropped.
567                        _ = sender.send(val.output);
568                        Ok(val.outcome)
569                    }
570                    Err(error) => {
571                        // This terminates progress.
572                        Err(error)
573                    }
574                }
575            }
576            .boxed()
577        });
578
579        NewStep {
580            steps: self.steps,
581            component: self.component.clone(),
582            id,
583            description: description.into(),
584            exec_fn: DebugIgnore(exec_fn),
585            receiver,
586            metadata_fn: None,
587        }
588    }
589}
590
591/// A new step that hasn't been registered by an execution engine yet.
592///
593/// Created by [`UpdateEngine::new_step`] or
594/// [`ComponentRegistrar::new_step`].
595#[must_use = "call register() to register this step with the engine"]
596#[derive(Debug)]
597pub struct NewStep<'engine, 'a, S: EngineSpec, T> {
598    steps: &'engine Mutex<Steps<'a, S>>,
599    component: S::Component,
600    id: S::StepId,
601    description: Cow<'static, str>,
602    exec_fn: DebugIgnore<StepExecFn<'a, S>>,
603    receiver: oneshot::Receiver<T>,
604    metadata_fn: Option<DebugIgnore<StepMetadataFn<'a, S>>>,
605}
606
607impl<'a, S: EngineSpec, T> NewStep<'_, 'a, S, T> {
608    /// Adds a metadata-generating function to the step.
609    ///
610    /// This function is expected to produce
611    /// [`S::StepMetadata`](EngineSpec::StepMetadata). The metadata
612    /// function must be infallible, and will often be synchronous
613    /// code.
614    pub fn with_metadata_fn<F, Fut>(mut self, f: F) -> Self
615    where
616        F: FnOnce(MetadataContext<S>) -> Fut + Send + 'a,
617        Fut: Future<Output = S::StepMetadata> + Send + 'a,
618    {
619        self.metadata_fn = Some(DebugIgnore(Box::new(|cx| (f)(cx).boxed())));
620        self
621    }
622
623    /// Registers the step with the engine.
624    pub fn register(self) -> StepHandle<T, S> {
625        let mut steps_lock = self.steps.lock().unwrap();
626        let component_count = steps_lock
627            .component_counts
628            .entry(self.component.clone())
629            .or_insert(0);
630        let current_index = *component_count;
631        *component_count += 1;
632
633        let step = Step {
634            metadata_gen: StepMetadataGen {
635                id: self.id,
636                component: self.component.clone(),
637                component_index: current_index,
638                description: self.description,
639                metadata_fn: self.metadata_fn,
640            },
641            exec: StepExec { exec_fn: self.exec_fn },
642        };
643        steps_lock.steps.push(step);
644        StepHandle::new(self.receiver)
645    }
646}
647
648/// The result of a step.
649///
650/// Returned by the callback passed to `register_step`.
651#[derive_where(Debug; T: std::fmt::Debug)]
652#[must_use = "StepResult must be used"]
653pub struct StepResult<T, S: EngineSpec> {
654    /// The output of the step.
655    pub output: T,
656
657    /// The outcome associated with the step.
658    ///
659    /// This outcome is serializable.
660    pub outcome: StepOutcome<S>,
661}
662
663impl<T, S: EngineSpec> StepResult<T, S> {
664    /// Maps a `StepResult<T, S>` to `StepResult<U, S>` by applying a
665    /// function to the contained `output` value, leaving the `outcome`
666    /// untouched.
667    pub fn map<U, F>(self, op: F) -> StepResult<U, S>
668    where
669        F: FnOnce(T) -> U,
670    {
671        StepResult { output: op(self.output), outcome: self.outcome }
672    }
673}
674
675/// A success result produced by a step.
676#[derive_where(Debug; T: std::fmt::Debug)]
677#[must_use = "StepSuccess must be used"]
678pub struct StepSuccess<T, S: EngineSpec> {
679    /// The output of the step.
680    pub output: T,
681
682    /// An optional message associated with this result.
683    pub message: Option<Cow<'static, str>>,
684
685    /// Optional metadata associated with this step.
686    pub metadata: Option<S::CompletionMetadata>,
687}
688
689impl<T, S: EngineSpec> StepSuccess<T, S> {
690    /// Creates a new `StepSuccess`.
691    pub fn new(output: T) -> Self {
692        Self { output, metadata: None, message: None }
693    }
694
695    /// Adds a message to this step.
696    pub fn with_message(
697        mut self,
698        message: impl Into<Cow<'static, str>>,
699    ) -> Self {
700        self.message = Some(message.into());
701        self
702    }
703
704    /// Adds metadata to this step.
705    pub fn with_metadata(mut self, metadata: S::CompletionMetadata) -> Self {
706        self.metadata = Some(metadata);
707        self
708    }
709
710    /// Creates a `StepResult` from this `StepSuccess`.
711    pub fn build(self) -> StepResult<T, S> {
712        StepResult {
713            output: self.output,
714            outcome: StepOutcome::Success {
715                message: self.message,
716                metadata: self.metadata,
717            },
718        }
719    }
720}
721
722impl<T, S: EngineSpec> From<StepSuccess<T, S>>
723    for Result<StepResult<T, S>, S::Error>
724{
725    fn from(value: StepSuccess<T, S>) -> Self {
726        Ok(value.build())
727    }
728}
729
730#[derive_where(Debug; T: std::fmt::Debug)]
731#[must_use = "StepWarning must be used"]
732pub struct StepWarning<T, S: EngineSpec> {
733    /// The output of the step.
734    pub output: T,
735
736    /// A message associated with this result.
737    pub message: Cow<'static, str>,
738
739    /// Optional metadata associated with this step.
740    pub metadata: Option<S::CompletionMetadata>,
741}
742
743impl<T, S: EngineSpec> StepWarning<T, S> {
744    /// Creates a new `StepWarning`.
745    pub fn new(output: T, message: impl Into<Cow<'static, str>>) -> Self {
746        Self { output, message: message.into(), metadata: None }
747    }
748
749    /// Adds metadata to this step.
750    pub fn with_metadata(mut self, metadata: S::CompletionMetadata) -> Self {
751        self.metadata = Some(metadata);
752        self
753    }
754
755    /// Creates a `StepResult` from this `StepWarning`.
756    pub fn build(self) -> StepResult<T, S> {
757        StepResult {
758            output: self.output,
759            outcome: StepOutcome::Warning {
760                message: self.message,
761                metadata: self.metadata,
762            },
763        }
764    }
765}
766
767impl<T, S: EngineSpec> From<StepWarning<T, S>>
768    for Result<StepResult<T, S>, S::Error>
769{
770    fn from(value: StepWarning<T, S>) -> Self {
771        Ok(value.build())
772    }
773}
774
775#[derive_where(Debug; T: std::fmt::Debug)]
776#[must_use = "StepSkipped must be used"]
777pub struct StepSkipped<T, S: EngineSpec> {
778    /// The output of the step.
779    pub output: T,
780
781    /// A message associated with this step.
782    pub message: Cow<'static, str>,
783
784    /// Optional metadata associated with this step.
785    pub metadata: Option<S::SkippedMetadata>,
786}
787
788impl<T, S: EngineSpec> StepSkipped<T, S> {
789    /// Creates a new `StepSkipped`.
790    pub fn new(output: T, message: impl Into<Cow<'static, str>>) -> Self {
791        Self { output, message: message.into(), metadata: None }
792    }
793
794    /// Adds metadata to this step.
795    pub fn with_metadata(mut self, metadata: S::SkippedMetadata) -> Self {
796        self.metadata = Some(metadata);
797        self
798    }
799
800    /// Creates a `StepResult` from this `StepSkipped`.
801    pub fn build(self) -> StepResult<T, S> {
802        StepResult {
803            output: self.output,
804            outcome: StepOutcome::Skipped {
805                message: self.message,
806                metadata: self.metadata,
807            },
808        }
809    }
810}
811
812impl<T, S: EngineSpec> From<StepSkipped<T, S>>
813    for Result<StepResult<T, S>, S::Error>
814{
815    fn from(value: StepSkipped<T, S>) -> Self {
816        Ok(value.build())
817    }
818}
819
820/// A step consists of three components:
821///
822/// 1. Information about the step, including the component, ID, etc.
823/// 2. Metadata about the step, generated in an async function. For
824///    example, this can be a hash of an artifact, or an address it
825///    was downloaded from.
826/// 3. The actual step function.
827///
828/// 1 and 2 are in StepMetadataGen, while 3 is in exec.
829#[derive_where(Debug)]
830struct Step<'a, S: EngineSpec> {
831    metadata_gen: StepMetadataGen<'a, S>,
832    exec: StepExec<'a, S>,
833}
834
835#[derive_where(Debug)]
836struct StepMetadataGen<'a, S: EngineSpec> {
837    id: S::StepId,
838    component: S::Component,
839    component_index: usize,
840    description: Cow<'static, str>,
841    metadata_fn: Option<DebugIgnore<StepMetadataFn<'a, S>>>,
842}
843
844impl<S: EngineSpec> StepMetadataGen<'_, S> {
845    fn to_step_info(
846        &self,
847        index: usize,
848        total_component_steps: usize,
849    ) -> StepInfo<S> {
850        StepInfo {
851            id: self.id.clone(),
852            component: self.component.clone(),
853            index,
854            component_index: self.component_index,
855            total_component_steps,
856            description: self.description.clone(),
857        }
858    }
859
860    async fn into_step_info_with_metadata(
861        self,
862        index: usize,
863        total_component_steps: usize,
864    ) -> StepInfoWithMetadata<S> {
865        let info = self.to_step_info(index, total_component_steps);
866        let metadata = match self.metadata_fn {
867            None => None,
868            Some(DebugIgnore(metadata_fn)) => {
869                let cx = MetadataContext::new();
870                let metadata_fut = (metadata_fn)(cx);
871                let metadata = metadata_fut.await;
872                Some(metadata)
873            }
874        };
875
876        StepInfoWithMetadata { info, metadata }
877    }
878}
879
880#[derive_where(Debug)]
881struct StepExec<'a, S: EngineSpec> {
882    exec_fn: DebugIgnore<StepExecFn<'a, S>>,
883}
884
885impl<S: EngineSpec> StepExec<'_, S> {
886    async fn execute<F: FnMut() -> usize>(
887        self,
888        log: &slog::Logger,
889        step_exec_cx: StepExecutionContext<S, F>,
890        cancel_receiver: &mut coop_cancel::Receiver<String>,
891    ) -> Result<
892        (Result<StepOutcome<S>, S::Error>, StepProgressReporter<S, F>),
893        ExecutionError<S>,
894    > {
895        slog::debug!(
896            log,
897            "start executing step";
898            "step component" => ?step_exec_cx.step_info.info.component,
899            "step id" => ?step_exec_cx.step_info.info.id,
900        );
901        let (payload_sender, mut payload_receiver) = mpsc::channel(16);
902        let cx = StepContext::new(log, payload_sender);
903
904        let mut step_fut = (self.exec_fn.0)(cx);
905        let mut reporter = StepProgressReporter::new(step_exec_cx);
906
907        let mut step_res = None;
908        let mut payload_done = false;
909
910        loop {
911            // This is the main execution select loop. We break it
912            // up into two portions:
913            //
914            // 1. The inner select, which is the meat of the engine.
915            //    It consists of driving the step and the payload
916            //    receiver forward.
917            //
918            // 2. The outer select, which consists of selecting over
919            //    the inner select and the abort receiver.
920            //
921            // The two selects cannot be combined! That's because
922            // the else block of the inner select only applies to
923            // the step and payload receivers. We do not want to
924            // wait for the abort receiver to exit before exiting
925            // the loop.
926            let inner_select = async {
927                tokio::select! {
928                    res = &mut step_fut, if step_res.is_none() => {
929                        step_res = Some(res);
930                        Ok(ControlFlow::Continue(()))
931                    }
932
933                    // Note: payload_receiver is always kept open
934                    // while step_fut is being driven. It is only
935                    // dropped before completion if the step is
936                    // aborted, in which case step_fut is also
937                    // cancelled without being driven further. A
938                    // bunch of expects with "our code always keeps
939                    // payload_receiver open" rely on this.
940                    //
941                    // If we ever move the payload receiver to
942                    // another task so it runs in parallel, this
943                    // situation would have to be handled with care.
944                    payload = payload_receiver.recv(), if !payload_done => {
945                        match payload {
946                            Some(payload) => {
947                                reporter.handle_payload(payload).await?;
948                            }
949                            None => {
950                                // The payload receiver is complete.
951                                payload_done = true;
952                            }
953                        }
954                        Ok(ControlFlow::Continue(()))
955                    }
956
957                    else => Ok(ControlFlow::Break(())),
958                }
959            };
960
961            // This is the outer select.
962            tokio::select! {
963                ret = inner_select => {
964                    match ret {
965                        Ok(op) => {
966                            if op.is_break() {
967                                break;
968                            }
969                        }
970                        Err(error) => {
971                            return Err(error);
972                        }
973                    }
974                }
975
976                Some(message) = cancel_receiver.recv() => {
977                    return Err(reporter.handle_abort(message).await);
978                }
979            }
980        }
981
982        // Return the result -- the caller is responsible for handling
983        // events.
984        let step_res = step_res.expect("can only get here if res is Some");
985        Ok((step_res, reporter))
986    }
987}
988
989#[derive_where(Debug)]
990struct ExecutionContext<S: EngineSpec, F> {
991    execution_id: ExecutionUuid,
992    next_event_index: DebugIgnore<F>,
993    total_start: Instant,
994    sender: EngineSender<S>,
995}
996
997impl<S: EngineSpec, F> ExecutionContext<S, F> {
998    fn new(
999        execution_id: ExecutionUuid,
1000        next_event_index: F,
1001        sender: EngineSender<S>,
1002    ) -> Self {
1003        let total_start = Instant::now();
1004        Self {
1005            execution_id,
1006            next_event_index: DebugIgnore(next_event_index),
1007            total_start,
1008            sender,
1009        }
1010    }
1011
1012    fn create(
1013        &mut self,
1014        step_info: StepInfoWithMetadata<S>,
1015    ) -> StepExecutionContext<S, &mut F> {
1016        StepExecutionContext {
1017            execution_id: self.execution_id,
1018            next_event_index: DebugIgnore(&mut self.next_event_index.0),
1019            total_start: self.total_start,
1020            step_info,
1021            sender: self.sender.clone(),
1022        }
1023    }
1024}
1025
1026#[derive_where(Debug)]
1027struct StepExecutionContext<S: EngineSpec, F> {
1028    execution_id: ExecutionUuid,
1029    next_event_index: DebugIgnore<F>,
1030    total_start: Instant,
1031    step_info: StepInfoWithMetadata<S>,
1032    sender: EngineSender<S>,
1033}
1034
1035type StepMetadataFn<'a, S> = Box<
1036    dyn FnOnce(
1037            MetadataContext<S>,
1038        ) -> BoxFuture<'a, <S as EngineSpec>::StepMetadata>
1039        + Send
1040        + 'a,
1041>;
1042
1043/// NOTE: Ideally this would take `&mut StepContext<S>`, so that it
1044/// can't get squirreled away by a step's function. However, that
1045/// quickly runs into [this issue in
1046/// Rust](https://users.rust-lang.org/t/passing-self-to-callback-returning-future-vs-lifetimes/53352).
1047///
1048/// It is probably possible to use unsafe code here, though that opens
1049/// up its own can of worms.
1050type StepExecFn<'a, S> = Box<
1051    dyn FnOnce(
1052            StepContext<S>,
1053        ) -> BoxFuture<
1054            'a,
1055            Result<StepOutcome<S>, <S as EngineSpec>::Error>,
1056        > + Send
1057        + 'a,
1058>;
1059
1060struct StepProgressReporter<S: EngineSpec, F> {
1061    execution_id: ExecutionUuid,
1062    next_event_index: F,
1063    total_start: Instant,
1064    step_info: StepInfoWithMetadata<S>,
1065    step_start: Instant,
1066    attempt: usize,
1067    attempt_start: Instant,
1068    sender: EngineSender<S>,
1069}
1070
1071impl<S: EngineSpec, F: FnMut() -> usize> StepProgressReporter<S, F> {
1072    fn new(step_exec_cx: StepExecutionContext<S, F>) -> Self {
1073        let step_start = Instant::now();
1074        Self {
1075            execution_id: step_exec_cx.execution_id,
1076            next_event_index: step_exec_cx.next_event_index.0,
1077            total_start: step_exec_cx.total_start,
1078            step_info: step_exec_cx.step_info,
1079            step_start,
1080            attempt: 1,
1081            // It's slightly nicer for step_start and attempt_start
1082            // to be exactly the same.
1083            attempt_start: step_start,
1084            sender: step_exec_cx.sender,
1085        }
1086    }
1087
1088    async fn handle_payload(
1089        &mut self,
1090        payload: StepContextPayload<S>,
1091    ) -> Result<(), ExecutionError<S>> {
1092        match payload {
1093            StepContextPayload::Progress { now, progress, done } => {
1094                self.handle_progress(now, progress).await?;
1095                // Dropping the sender signals the receiver that processing is
1096                // complete.
1097                drop(done);
1098            }
1099            StepContextPayload::NestedSingle { now, event, done } => {
1100                self.handle_nested(now, event).await?;
1101                drop(done);
1102            }
1103            StepContextPayload::Nested { now, event } => {
1104                self.handle_nested(now, event).await?;
1105            }
1106            StepContextPayload::Sync { done } => {
1107                drop(done);
1108            }
1109        }
1110
1111        Ok(())
1112    }
1113
1114    async fn handle_progress(
1115        &mut self,
1116        now: Instant,
1117        progress: StepProgress<S>,
1118    ) -> Result<(), ExecutionError<S>> {
1119        match progress {
1120            StepProgress::Progress { progress, metadata } => {
1121                // Send the progress to the sender.
1122                self.sender
1123                    .send(Event::Progress(ProgressEvent {
1124                        spec: S::spec_name(),
1125                        execution_id: self.execution_id,
1126                        total_elapsed: now - self.total_start,
1127                        kind: ProgressEventKind::Progress {
1128                            step: self.step_info.clone(),
1129                            attempt: self.attempt,
1130                            progress,
1131                            metadata,
1132                            step_elapsed: now - self.step_start,
1133                            attempt_elapsed: now - self.attempt_start,
1134                        },
1135                    }))
1136                    .await
1137            }
1138            StepProgress::Reset { metadata, message } => {
1139                // Send a progress reset message, but do not reset
1140                // the attempt.
1141                self.sender
1142                    .send(Event::Step(StepEvent {
1143                        spec: S::spec_name(),
1144                        execution_id: self.execution_id,
1145                        event_index: (self.next_event_index)(),
1146                        total_elapsed: now - self.total_start,
1147                        kind: StepEventKind::ProgressReset {
1148                            step: self.step_info.clone(),
1149                            attempt: self.attempt,
1150                            metadata,
1151                            step_elapsed: now - self.step_start,
1152                            attempt_elapsed: now - self.attempt_start,
1153                            message,
1154                        },
1155                    }))
1156                    .await
1157            }
1158            StepProgress::Retry { message } => {
1159                // Retry this step.
1160                self.attempt += 1;
1161                let attempt_elapsed = now - self.attempt_start;
1162                self.attempt_start = Instant::now();
1163
1164                // Send the retry message.
1165                self.sender
1166                    .send(Event::Step(StepEvent {
1167                        spec: S::spec_name(),
1168                        execution_id: self.execution_id,
1169                        event_index: (self.next_event_index)(),
1170                        total_elapsed: now - self.total_start,
1171                        kind: StepEventKind::AttemptRetry {
1172                            step: self.step_info.clone(),
1173                            next_attempt: self.attempt,
1174                            step_elapsed: now - self.step_start,
1175                            attempt_elapsed,
1176                            message,
1177                        },
1178                    }))
1179                    .await
1180            }
1181        }
1182    }
1183
1184    async fn handle_nested(
1185        &mut self,
1186        now: Instant,
1187        event: Event<GenericSpec>,
1188    ) -> Result<(), ExecutionError<S>> {
1189        match event {
1190            Event::Step(event) => {
1191                self.sender
1192                    .send(Event::Step(StepEvent {
1193                        spec: S::spec_name(),
1194                        execution_id: self.execution_id,
1195                        event_index: (self.next_event_index)(),
1196                        total_elapsed: now - self.total_start,
1197                        kind: StepEventKind::Nested {
1198                            step: self.step_info.clone(),
1199                            attempt: self.attempt,
1200                            event: Box::new(event),
1201                            step_elapsed: now - self.step_start,
1202                            attempt_elapsed: now - self.attempt_start,
1203                        },
1204                    }))
1205                    .await
1206            }
1207            Event::Progress(event) => {
1208                self.sender
1209                    .send(Event::Progress(ProgressEvent {
1210                        spec: S::spec_name(),
1211                        execution_id: self.execution_id,
1212                        total_elapsed: now - self.total_start,
1213                        kind: ProgressEventKind::Nested {
1214                            step: self.step_info.clone(),
1215                            attempt: self.attempt,
1216                            event: Box::new(event),
1217                            step_elapsed: now - self.step_start,
1218                            attempt_elapsed: now - self.attempt_start,
1219                        },
1220                    }))
1221                    .await
1222            }
1223        }
1224    }
1225
1226    async fn handle_abort(mut self, message: String) -> ExecutionError<S> {
1227        // Send the abort message over the channel.
1228        //
1229        // The only way this can fail is if the event receiver is
1230        // closed or dropped. That failure doesn't have any
1231        // implications on whether this aborts or not.
1232        let res = self
1233            .sender
1234            .send(Event::Step(StepEvent {
1235                spec: S::spec_name(),
1236                execution_id: self.execution_id,
1237                event_index: (self.next_event_index)(),
1238                total_elapsed: self.total_start.elapsed(),
1239                kind: StepEventKind::ExecutionAborted {
1240                    aborted_step: self.step_info.clone(),
1241                    attempt: self.attempt,
1242                    step_elapsed: self.step_start.elapsed(),
1243                    attempt_elapsed: self.attempt_start.elapsed(),
1244                    message: message.clone(),
1245                },
1246            }))
1247            .await;
1248
1249        match res {
1250            Ok(()) => ExecutionError::Aborted {
1251                component: self.step_info.info.component.clone(),
1252                id: self.step_info.info.id.clone(),
1253                description: self.step_info.info.description.clone(),
1254                message,
1255            },
1256            Err(error) => error,
1257        }
1258    }
1259
1260    async fn next_step(
1261        mut self,
1262        step_res: Result<StepOutcome<S>, S::Error>,
1263        next_step_info: &StepInfoWithMetadata<S>,
1264    ) -> Result<(), ExecutionError<S>> {
1265        match step_res {
1266            Ok(outcome) => {
1267                self.sender
1268                    .send(Event::Step(StepEvent {
1269                        spec: S::spec_name(),
1270                        execution_id: self.execution_id,
1271                        event_index: (self.next_event_index)(),
1272                        total_elapsed: self.total_start.elapsed(),
1273                        kind: StepEventKind::StepCompleted {
1274                            step: self.step_info,
1275                            attempt: self.attempt,
1276                            outcome,
1277                            next_step: next_step_info.clone(),
1278                            step_elapsed: self.step_start.elapsed(),
1279                            attempt_elapsed: self.attempt_start.elapsed(),
1280                        },
1281                    }))
1282                    .await?;
1283                Ok(())
1284            }
1285            Err(error) => {
1286                let component = self.step_info.info.component.clone();
1287                let id = self.step_info.info.id.clone();
1288                let description = self.step_info.info.description.clone();
1289                self.send_error(&error).await?;
1290                Err(ExecutionError::StepFailed {
1291                    component,
1292                    id,
1293                    description,
1294                    error,
1295                })
1296            }
1297        }
1298    }
1299
1300    async fn last_step(
1301        mut self,
1302        step_res: Result<StepOutcome<S>, S::Error>,
1303    ) -> Result<(), ExecutionError<S>> {
1304        match step_res {
1305            Ok(outcome) => {
1306                self.sender
1307                    .send(Event::Step(StepEvent {
1308                        spec: S::spec_name(),
1309                        execution_id: self.execution_id,
1310                        event_index: (self.next_event_index)(),
1311                        total_elapsed: self.total_start.elapsed(),
1312                        kind: StepEventKind::ExecutionCompleted {
1313                            last_step: self.step_info,
1314                            last_attempt: self.attempt,
1315                            last_outcome: outcome,
1316                            step_elapsed: self.step_start.elapsed(),
1317                            attempt_elapsed: self.attempt_start.elapsed(),
1318                        },
1319                    }))
1320                    .await?;
1321                Ok(())
1322            }
1323            Err(error) => {
1324                let component = self.step_info.info.component.clone();
1325                let id = self.step_info.info.id.clone();
1326                let description = self.step_info.info.description.clone();
1327                self.send_error(&error).await?;
1328                Err(ExecutionError::StepFailed {
1329                    component,
1330                    id,
1331                    description,
1332                    error,
1333                })
1334            }
1335        }
1336    }
1337
1338    async fn send_error(
1339        mut self,
1340        error: &S::Error,
1341    ) -> Result<(), ExecutionError<S>> {
1342        // Stringify `error` into a message + list causes; this is
1343        // written the way it is to avoid `error` potentially living
1344        // across the `.await` below (which can cause lifetime issues
1345        // in callers).
1346        let (message, causes) = {
1347            let error = error.as_error();
1348            let message = error.to_string();
1349
1350            let mut current = error;
1351            let mut causes = vec![];
1352            while let Some(source) = current.source() {
1353                causes.push(source.to_string());
1354                current = source;
1355            }
1356            (message, causes)
1357        };
1358
1359        self.sender
1360            .send(Event::Step(StepEvent {
1361                spec: S::spec_name(),
1362                execution_id: self.execution_id,
1363                event_index: (self.next_event_index)(),
1364                total_elapsed: self.total_start.elapsed(),
1365                kind: StepEventKind::ExecutionFailed {
1366                    failed_step: self.step_info,
1367                    total_attempts: self.attempt,
1368                    step_elapsed: self.step_start.elapsed(),
1369                    attempt_elapsed: self.attempt_start.elapsed(),
1370                    message,
1371                    causes,
1372                },
1373            }))
1374            .await
1375    }
1376}
1377
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::bail;
    use oxide_update_engine_test_utils::TestSpec;
    use tokio_stream::wrappers::ReceiverStream;

    /// A failing step must stop the execution: later steps never run,
    /// and the final event reports the failure of the step that failed.
    #[tokio::test]
    async fn error_exits_early() {
        let log = slog::Logger::root(slog::Discard, slog::o!());

        // Each flag is flipped by the corresponding step when it runs.
        let (mut step_1_run, mut step_2_run, mut step_3_run) =
            (false, false, false);

        // The buffer is large enough that the engine can never fill it
        // up, so nothing has to drain it while the engine runs.
        let (sender, receiver) = mpsc::channel(512);
        let engine: UpdateEngine<TestSpec> = UpdateEngine::new(&log, sender);

        // Step 1 succeeds.
        engine
            .new_step("foo".to_owned(), 0, "Step 1", |_| async {
                step_1_run = true;
                StepSuccess::new(()).into()
            })
            .register();

        // Step 2 fails.
        engine
            .new_step::<_, _, ()>("bar".to_owned(), 0, "Step 2", |_| async {
                step_2_run = true;
                bail!("example failed")
            })
            .register();

        // Step 3 would succeed, but must never be reached.
        engine
            .new_step("baz".to_owned(), 0, "Step 3", |_| async {
                step_3_run = true;
                StepSuccess::new(()).into()
            })
            .register();

        engine
            .execute()
            .await
            .expect_err("step 2 failed so we should see an error here");

        let events: Vec<_> = ReceiverStream::new(receiver).collect().await;
        let last_event = events.last().unwrap();
        let Event::Step(step_event) = last_event else {
            panic!("unexpected event: {last_event:?}")
        };
        assert!(
            matches!(
                &step_event.kind,
                StepEventKind::ExecutionFailed { failed_step, message, .. }
                if failed_step.info.component == "bar"
                && message == "example failed"
            ),
            "event didn't match: {last_event:?}"
        );

        assert!(step_1_run, "Step 1 was run");
        assert!(step_2_run, "Step 2 was run");
        assert!(!step_3_run, "Step 3 was not run");
    }
}