use amethyst_input::is_close_requested;
use derivative::Derivative;
use crate::{ecs::World, GameData, StateEvent};
use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
#[cfg(feature = "profiler")]
use thread_profiler::profile_scope;
/// Errors that the state machine can report.
#[derive(Debug)]
pub enum StateError {
    /// The state machine was asked to start while its state stack was empty.
    NoStatesPresent,
}

impl Display for StateError {
    /// Writes the human-readable description of the error.
    fn fmt(&self, fmt: &mut Formatter<'_>) -> FmtResult {
        let message = match self {
            StateError::NoStatesPresent => {
                "Tried to start state machine without any states present"
            }
        };
        fmt.write_str(message)
    }
}
/// Bundle of mutable borrows that is handed to every `State` callback.
///
/// Both borrows share the lifetime `'a`.
// NOTE(review): `Debug` is not derived here; presumably one of the borrowed types does not
// implement it, hence the lint override below — confirm before removing it.
#[allow(missing_debug_implementations)]
pub struct StateData<'a, T> {
/// Mutable borrow of the ECS `World`.
pub world: &'a mut World,
/// Mutable borrow of the user-chosen data of type `T`.
pub data: &'a mut T,
}
impl<'a, T> StateData<'a, T>
where
T: 'a,
{
pub fn new(world: &'a mut World, data: &'a mut T) -> Self {
StateData { world, data }
}
}
/// A transition request returned by `State` callbacks and applied by
/// `StateMachine::transition`.
///
/// `T` and `E` are the data and event types of the states involved.
pub enum Trans<T, E> {
/// Leave the state stack untouched.
None,
/// Stop and remove the state on top of the stack, then resume the state below it.
/// If no state is left afterwards, the state machine stops running.
Pop,
/// Pause the state currently on top of the stack, then push and start the given state.
Push(Box<dyn State<T, E>>),
/// Stop and remove the state on top of the stack, then push and start the given state.
Switch(Box<dyn State<T, E>>),
/// Stop and remove every state on the stack, then push and start the given state.
Replace(Box<dyn State<T, E>>),
/// Stop and remove every state on the stack, then push the given states in order.
/// Each one is started, and all but the last are paused right after being started.
NewStack(Vec<Box<dyn State<T, E>>>),
/// Apply the contained transitions one after another, in order.
Sequence(Vec<Trans<T, E>>),
/// Stop and remove every state on the stack and stop the state machine.
Quit,
}
impl<T, E> Debug for Trans<T, E> {
    /// Formats the transition as its variant name.
    ///
    /// The boxed states carried by `Push`, `Switch`, `Replace` and `NewStack` are not printed.
    /// `Sequence` is printed as `Sequence ` followed by the `Debug` form of its transitions.
    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
        let variant_name = match self {
            Trans::None => "None",
            Trans::Pop => "Pop",
            Trans::Push(_) => "Push",
            Trans::Switch(_) => "Switch",
            Trans::Replace(_) => "Replace",
            Trans::NewStack(_) => "NewStack",
            // Write straight into the formatter instead of building an intermediate `String`
            // with `format!` only to copy it out again.
            Trans::Sequence(sequence) => return write!(f, "Sequence {:?}", sequence),
            Trans::Quit => "Quit",
        };
        f.write_str(variant_name)
    }
}
/// A boxed, thread-safe (`Send + Sync + 'static`) closure that produces a `Trans` when called.
pub type TransEvent<T, E> = Box<dyn Fn() -> Trans<T, E> + Send + Sync + 'static>;
/// The `Trans` type used by `EmptyState`: `()` as data and `StateEvent` as event type.
pub type EmptyTrans = Trans<(), StateEvent>;
/// The `Trans` type used by `SimpleState`: `GameData` as data and `StateEvent` as event type.
pub type SimpleTrans = Trans<GameData<'static, 'static>, StateEvent>;
/// A single state managed by a `StateMachine`.
///
/// `T` is the type of the data passed to every callback inside `StateData`, and `E` is the type
/// of the events delivered to `handle_event`. Every method has a default implementation that
/// does nothing (and returns `Trans::None` where a transition is expected), so implementors only
/// override what they need.
pub trait State<T, E: Send + Sync + 'static> {
/// Called when the state is started: when the machine starts with it on top of the stack, or
/// right after it has been pushed by a transition.
fn on_start(&mut self, _data: StateData<'_, T>) {}
/// Called right after the state has been removed from the stack.
fn on_stop(&mut self, _data: StateData<'_, T>) {}
/// Called when the state stops being the top of the stack because another state is placed
/// above it.
fn on_pause(&mut self, _data: StateData<'_, T>) {}
/// Called when the state becomes the top of the stack again after the state above it was
/// popped.
fn on_resume(&mut self, _data: StateData<'_, T>) {}
/// Called on the top state for every event passed to `StateMachine::handle_event`.
/// The returned transition is applied by the machine.
fn handle_event(&mut self, _data: StateData<'_, T>, _event: E) -> Trans<T, E> {
Trans::None
}
/// Called on the top state by `StateMachine::fixed_update`.
/// The returned transition is applied by the machine.
fn fixed_update(&mut self, _data: StateData<'_, T>) -> Trans<T, E> {
Trans::None
}
/// Called on the top state by `StateMachine::update`.
/// The returned transition is applied by the machine.
fn update(&mut self, _data: StateData<'_, T>) -> Trans<T, E> {
Trans::None
}
/// Called by `StateMachine::fixed_update` on every state in the stack (not only the top one),
/// after the top state's `fixed_update` and before its transition is applied.
fn shadow_fixed_update(&mut self, _data: StateData<'_, T>) {}
/// Called by `StateMachine::update` on every state in the stack (not only the top one),
/// after the top state's `update` and before its transition is applied.
fn shadow_update(&mut self, _data: StateData<'_, T>) {}
}
/// Convenience flavour of `State` for states that use `()` as data and `StateEvent` as events.
///
/// Every type implementing this trait automatically implements `State<(), StateEvent>`.
/// All callbacks default to doing nothing, except `handle_event`, which quits on a window close
/// request.
pub trait EmptyState {
    /// Called when the state is started.
    fn on_start(&mut self, _data: StateData<'_, ()>) {}

    /// Called after the state has been removed from the stack.
    fn on_stop(&mut self, _data: StateData<'_, ()>) {}

    /// Called when another state is placed above this one.
    fn on_pause(&mut self, _data: StateData<'_, ()>) {}

    /// Called when this state becomes the top of the stack again.
    fn on_resume(&mut self, _data: StateData<'_, ()>) {}

    /// Handles an incoming event.
    ///
    /// The default returns `Trans::Quit` when the event is a window event for which
    /// `is_close_requested` is true, and `Trans::None` for everything else.
    fn handle_event(&mut self, _data: StateData<'_, ()>, event: StateEvent) -> EmptyTrans {
        match &event {
            StateEvent::Window(window_event) if is_close_requested(window_event) => Trans::Quit,
            _ => Trans::None,
        }
    }

    /// Called on the top state by the machine's fixed update; defaults to no transition.
    fn fixed_update(&mut self, _data: StateData<'_, ()>) -> EmptyTrans {
        Trans::None
    }

    /// Called on the top state by the machine's update; defaults to no transition.
    fn update(&mut self, _data: StateData<'_, ()>) -> EmptyTrans {
        Trans::None
    }

    /// Called on every state of the stack during the machine's fixed update.
    fn shadow_fixed_update(&mut self, _data: StateData<'_, ()>) {}

    /// Called on every state of the stack during the machine's update.
    fn shadow_update(&mut self, _data: StateData<'_, ()>) {}
}
/// Makes every `EmptyState` usable as a `State<(), StateEvent>`.
///
/// Each callback forwards to the `EmptyState` method of the same name. The calls are written in
/// fully qualified form so it is obvious they target `EmptyState` and are not recursive.
impl<T: EmptyState> State<(), StateEvent> for T {
    fn on_start(&mut self, data: StateData<'_, ()>) {
        <Self as EmptyState>::on_start(self, data)
    }

    fn on_stop(&mut self, data: StateData<'_, ()>) {
        <Self as EmptyState>::on_stop(self, data)
    }

    fn on_pause(&mut self, data: StateData<'_, ()>) {
        <Self as EmptyState>::on_pause(self, data)
    }

    fn on_resume(&mut self, data: StateData<'_, ()>) {
        <Self as EmptyState>::on_resume(self, data)
    }

    fn handle_event(&mut self, data: StateData<'_, ()>, event: StateEvent) -> EmptyTrans {
        <Self as EmptyState>::handle_event(self, data, event)
    }

    fn fixed_update(&mut self, data: StateData<'_, ()>) -> EmptyTrans {
        <Self as EmptyState>::fixed_update(self, data)
    }

    fn update(&mut self, data: StateData<'_, ()>) -> EmptyTrans {
        <Self as EmptyState>::update(self, data)
    }

    fn shadow_fixed_update(&mut self, data: StateData<'_, ()>) {
        <Self as EmptyState>::shadow_fixed_update(self, data);
    }

    fn shadow_update(&mut self, data: StateData<'_, ()>) {
        <Self as EmptyState>::shadow_update(self, data);
    }
}
/// Convenience flavour of `State` for states that use `GameData` as data and `StateEvent` as
/// events.
///
/// Every type implementing this trait automatically implements
/// `State<GameData<'static, 'static>, StateEvent>`. All callbacks default to doing nothing,
/// except `handle_event`, which quits on a window close request.
pub trait SimpleState {
    /// Called when the state is started.
    fn on_start(&mut self, _data: StateData<'_, GameData<'_, '_>>) {}

    /// Called after the state has been removed from the stack.
    fn on_stop(&mut self, _data: StateData<'_, GameData<'_, '_>>) {}

    /// Called when another state is placed above this one.
    fn on_pause(&mut self, _data: StateData<'_, GameData<'_, '_>>) {}

    /// Called when this state becomes the top of the stack again.
    fn on_resume(&mut self, _data: StateData<'_, GameData<'_, '_>>) {}

    /// Handles an incoming event.
    ///
    /// The default returns `Trans::Quit` when the event is a window event for which
    /// `is_close_requested` is true, and `Trans::None` for everything else.
    fn handle_event(
        &mut self,
        _data: StateData<'_, GameData<'_, '_>>,
        event: StateEvent,
    ) -> SimpleTrans {
        match &event {
            StateEvent::Window(window_event) if is_close_requested(window_event) => Trans::Quit,
            _ => Trans::None,
        }
    }

    /// Called on the top state by the machine's fixed update; defaults to no transition.
    fn fixed_update(&mut self, _data: StateData<'_, GameData<'_, '_>>) -> SimpleTrans {
        Trans::None
    }

    /// Called on the top state by the machine's update; defaults to no transition.
    ///
    /// Unlike the other callbacks this one receives the `StateData` by mutable reference.
    fn update(&mut self, _data: &mut StateData<'_, GameData<'_, '_>>) -> SimpleTrans {
        Trans::None
    }

    /// Called on every state of the stack during the machine's fixed update.
    fn shadow_fixed_update(&mut self, _data: StateData<'_, GameData<'_, '_>>) {}

    /// Called on every state of the stack during the machine's update.
    fn shadow_update(&mut self, _data: StateData<'_, GameData<'_, '_>>) {}
}
/// Makes every `SimpleState` usable as a `State<GameData<'static, 'static>, StateEvent>`.
///
/// Each callback forwards to the `SimpleState` method of the same name. The calls are written in
/// fully qualified form so it is obvious they target `SimpleState` and are not recursive.
impl<T: SimpleState> State<GameData<'static, 'static>, StateEvent> for T {
    fn on_start(&mut self, data: StateData<'_, GameData<'_, '_>>) {
        <Self as SimpleState>::on_start(self, data)
    }

    fn on_stop(&mut self, data: StateData<'_, GameData<'_, '_>>) {
        <Self as SimpleState>::on_stop(self, data)
    }

    fn on_pause(&mut self, data: StateData<'_, GameData<'_, '_>>) {
        <Self as SimpleState>::on_pause(self, data)
    }

    fn on_resume(&mut self, data: StateData<'_, GameData<'_, '_>>) {
        <Self as SimpleState>::on_resume(self, data)
    }

    fn handle_event(
        &mut self,
        data: StateData<'_, GameData<'_, '_>>,
        event: StateEvent,
    ) -> SimpleTrans {
        <Self as SimpleState>::handle_event(self, data, event)
    }

    fn fixed_update(&mut self, data: StateData<'_, GameData<'_, '_>>) -> SimpleTrans {
        <Self as SimpleState>::fixed_update(self, data)
    }

    /// Runs the state's own `update` first, then calls `GameData::update` with the world, and
    /// finally hands back the transition the state requested.
    fn update(&mut self, mut data: StateData<'_, GameData<'_, '_>>) -> SimpleTrans {
        let requested = <Self as SimpleState>::update(self, &mut data);
        data.data.update(&data.world);
        requested
    }

    fn shadow_fixed_update(&mut self, data: StateData<'_, GameData<'_, '_>>) {
        <Self as SimpleState>::shadow_fixed_update(self, data);
    }

    fn shadow_update(&mut self, data: StateData<'_, GameData<'_, '_>>) {
        <Self as SimpleState>::shadow_update(self, data);
    }
}
/// A stack-based state machine.
///
/// The state at the end of `state_stack` is the active ("top") state; it receives events and
/// updates, and the transitions it returns reshape the stack.
#[derive(Derivative)]
#[derivative(Debug)]
pub struct StateMachine<'a, T, E> {
/// `true` between a successful `start` and the moment the machine stops (all states popped or
/// a `Trans::Quit`).
running: bool,
/// The stack of states; the last element is the active one. Left out of the `Debug` output.
#[derivative(Debug = "ignore")]
state_stack: Vec<Box<dyn State<T, E> + 'a>>,
}
impl<'a, T, E: Send + Sync + 'static> StateMachine<'a, T, E> {
pub fn new<S: State<T, E> + 'a>(initial_state: S) -> StateMachine<'a, T, E> {
StateMachine {
running: false,
state_stack: vec![Box::new(initial_state)],
}
}
pub fn is_running(&self) -> bool {
self.running
}
pub fn start(&mut self, data: StateData<'_, T>) -> Result<(), StateError> {
if !self.running {
let state = self
.state_stack
.last_mut()
.ok_or(StateError::NoStatesPresent)?;
state.on_start(data);
self.running = true;
}
Ok(())
}
pub fn handle_event(&mut self, data: StateData<'_, T>, event: E) {
let StateData { world, data } = data;
if self.running {
let trans = match self.state_stack.last_mut() {
Some(state) => state.handle_event(StateData { world, data }, event),
None => Trans::None,
};
self.transition(trans, StateData { world, data });
}
}
pub fn fixed_update(&mut self, data: StateData<'_, T>) {
let StateData { world, data } = data;
if self.running {
let trans = match self.state_stack.last_mut() {
Some(state) => {
#[cfg(feature = "profiler")]
profile_scope!("stack fixed_update");
state.fixed_update(StateData { world, data })
}
None => Trans::None,
};
for state in &mut self.state_stack {
#[cfg(feature = "profiler")]
profile_scope!("stack shadow_fixed_update");
state.shadow_fixed_update(StateData { world, data });
}
{
#[cfg(feature = "profiler")]
profile_scope!("stack fixed transition");
self.transition(trans, StateData { world, data });
}
}
}
pub fn update(&mut self, data: StateData<'_, T>) {
let StateData { world, data } = data;
if self.running {
let trans = match self.state_stack.last_mut() {
Some(state) => {
#[cfg(feature = "profiler")]
profile_scope!("stack update");
state.update(StateData { world, data })
}
None => Trans::None,
};
for state in &mut self.state_stack {
#[cfg(feature = "profiler")]
profile_scope!("stack shadow_update");
state.shadow_update(StateData { world, data });
}
{
#[cfg(feature = "profiler")]
profile_scope!("stack transition");
self.transition(trans, StateData { world, data });
}
}
}
pub fn transition(&mut self, request: Trans<T, E>, data: StateData<'_, T>) {
if self.running {
match request {
Trans::None => (),
Trans::Pop => self.pop(data),
Trans::Push(state) => self.push(state, data),
Trans::Switch(state) => self.switch(state, data),
Trans::Replace(state) => self.replace(state, data),
Trans::NewStack(states) => self.new_stack(states, data),
Trans::Sequence(sequence) => {
for trans in sequence {
let temp_data = StateData {
world: data.world,
data: data.data,
};
self.transition(trans, temp_data);
}
}
Trans::Quit => self.stop(data),
}
}
}
fn switch(&mut self, state: Box<dyn State<T, E>>, data: StateData<'_, T>) {
if self.running {
let StateData { world, data } = data;
if let Some(mut state) = self.state_stack.pop() {
state.on_stop(StateData { world, data });
}
self.state_stack.push(state);
let new_state = self.state_stack.last_mut().unwrap();
new_state.on_start(StateData { world, data });
}
}
fn push(&mut self, state: Box<dyn State<T, E>>, data: StateData<'_, T>) {
if self.running {
let StateData { world, data } = data;
if let Some(state) = self.state_stack.last_mut() {
state.on_pause(StateData { world, data });
}
self.state_stack.push(state);
let new_state = self.state_stack.last_mut().unwrap();
new_state.on_start(StateData { world, data });
}
}
fn pop(&mut self, data: StateData<'_, T>) {
if self.running {
let StateData { world, data } = data;
if let Some(mut state) = self.state_stack.pop() {
state.on_stop(StateData { world, data });
}
if let Some(state) = self.state_stack.last_mut() {
state.on_resume(StateData { world, data });
} else {
self.running = false;
}
}
}
pub(crate) fn replace(&mut self, state: Box<dyn State<T, E>>, data: StateData<'_, T>) {
if self.running {
let StateData { world, data } = data;
while let Some(mut state) = self.state_stack.pop() {
state.on_stop(StateData { world, data });
}
self.state_stack.push(state);
let new_state = self.state_stack.last_mut().unwrap();
new_state.on_start(StateData { world, data });
}
}
pub(crate) fn new_stack(&mut self, states: Vec<Box<dyn State<T, E>>>, data: StateData<'_, T>) {
if self.running {
let StateData { world, data } = data;
while let Some(mut state) = self.state_stack.pop() {
state.on_stop(StateData { world, data });
}
let state_count = states.len();
for (count, state) in states.into_iter().enumerate() {
self.state_stack.push(state);
let new_state = self.state_stack.last_mut().unwrap();
new_state.on_start(StateData { world, data });
if count != state_count - 1 {
new_state.on_pause(StateData { world, data });
}
}
}
}
pub(crate) fn stop(&mut self, data: StateData<'_, T>) {
if self.running {
let StateData { world, data } = data;
while let Some(mut state) = self.state_stack.pop() {
state.on_stop(StateData { world, data });
}
self.running = false;
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ecs::prelude::{World, WorldExt};

    /// Never requests a transition.
    struct State0;
    /// Counts down on every update, then switches to `State2`.
    struct State1(u8);
    /// Pops itself on the first update.
    struct State2;
    /// Requests a fresh stack of four `State0`s.
    struct StateNewStack;
    /// Requests three pushes followed by one pop.
    struct StateSequence;
    /// Pushes copies of itself with a decreasing counter, then replaces the whole stack.
    struct StateReplace(u8);

    impl State<(), ()> for State0 {
        fn update(&mut self, _: StateData<'_, ()>) -> Trans<(), ()> {
            Trans::None
        }
    }

    impl State<(), ()> for State1 {
        fn update(&mut self, _: StateData<'_, ()>) -> Trans<(), ()> {
            if self.0 == 0 {
                return Trans::Switch(Box::new(State2));
            }
            self.0 -= 1;
            Trans::None
        }
    }

    impl State<(), ()> for State2 {
        fn update(&mut self, _: StateData<'_, ()>) -> Trans<(), ()> {
            Trans::Pop
        }
    }

    impl State<(), ()> for StateNewStack {
        fn update(&mut self, _: StateData<'_, ()>) -> Trans<(), ()> {
            Trans::NewStack(vec![
                Box::new(State0),
                Box::new(State0),
                Box::new(State0),
                Box::new(State0),
            ])
        }
    }

    impl State<(), ()> for StateSequence {
        fn update(&mut self, _: StateData<'_, ()>) -> Trans<(), ()> {
            Trans::Sequence(vec![
                Trans::Push(Box::new(State0)),
                Trans::Push(Box::new(State0)),
                Trans::Push(Box::new(State0)),
                Trans::Pop,
            ])
        }
    }

    impl State<(), ()> for StateReplace {
        fn update(&mut self, _: StateData<'_, ()>) -> Trans<(), ()> {
            match self.0 {
                0 => Trans::Replace(Box::new(State0)),
                remaining => Trans::Push(Box::new(StateReplace(remaining - 1))),
            }
        }
    }

    /// Builds a state machine around `initial` and starts it.
    fn started_machine<'a, S>(world: &mut World, initial: S) -> StateMachine<'a, (), ()>
    where
        S: State<(), ()> + 'a,
    {
        let mut machine = StateMachine::new(initial);
        machine.start(StateData::new(world, &mut ())).unwrap();
        machine
    }

    /// Runs a single `update` on `machine`.
    fn tick(machine: &mut StateMachine<'_, (), ()>, world: &mut World) {
        machine.update(StateData::new(world, &mut ()));
    }

    #[test]
    fn switch_pop() {
        let mut world = World::new();
        let mut sm = started_machine(&mut world, State1(7));

        // Seven countdown updates plus the one performing the switch keep the machine alive.
        for _ in 0..8 {
            tick(&mut sm, &mut world);
            assert!(sm.is_running());
        }

        // `State2` pops itself, which empties the stack and stops the machine.
        tick(&mut sm, &mut world);
        assert!(!sm.is_running());
    }

    #[test]
    fn new_stack() {
        let mut world = World::new();
        let mut sm = started_machine(&mut world, StateNewStack);

        tick(&mut sm, &mut world);
        assert_eq!(sm.state_stack.len(), 4);
    }

    #[test]
    fn sequence() {
        let mut world = World::new();
        let mut sm = started_machine(&mut world, StateSequence);

        // Initial state + three pushes - one pop.
        tick(&mut sm, &mut world);
        assert_eq!(sm.state_stack.len(), 3);
    }

    #[test]
    fn replace() {
        let mut world = World::new();
        let mut sm = started_machine(&mut world, StateReplace(3));

        // Each of the first three updates pushes one more state.
        for expected_len in 2..5 {
            tick(&mut sm, &mut world);
            assert_eq!(sm.state_stack.len(), expected_len);
        }

        // The state with counter 0 replaces the whole stack with a single state.
        tick(&mut sm, &mut world);
        assert_eq!(sm.state_stack.len(), 1);
    }
}