use derive_new::new;
use burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend};
use crate::TransitionBatch;

// NOTE(review): in this copy of the file the generic parameter/argument lists (`<...>`)
// appear to be missing throughout. For example:
// - `A`/`C` in `ActionContext` and `P`/`TO` in `RLTrainOutput` are never declared;
// - `Vec` has no element type;
// - several type paths begin with a dangling `>::` (presumably a stripped
//   `<X as Trait<..>>::` qualified path).
// As written this will not compile. Restore the generics from the upstream source.

/// An action paired with additional context about the decision that produced it.
///
/// `#[derive(new)]` (from `derive_new`) generates a plain constructor for the two fields.
// NOTE(review): `C` and `A` are not declared on this struct here; presumably
// `ActionContext<A, C>` or similar — confirm the parameter order against the original.
#[derive(Clone, new)]
pub struct ActionContext {
    /// Additional context about the decision.
    pub context: C,
    /// The action that was chosen.
    pub action: A,
}

/// The state (parameterization) of a policy, convertible to and from a burn [`Record`].
pub trait PolicyState {
    /// The record type this state is serialized to.
    type Record: Record;

    /// Consumes the state and converts it into its record.
    fn into_record(self) -> Self::Record;

    /// Builds a state from `record`.
    ///
    /// Takes `&self` and returns a new `Self`, so the receiver is not modified.
    fn load_record(&self, record: Self::Record) -> Self;
}

/// Trait for an RL policy.
///
/// Implementors must be [`Clone`].
pub trait Policy: Clone {
    /// The observation given as input to the policy.
    type Observation;
    /// The action distribution parameters defining how the action will be sampled.
    type ActionDistribution;
    /// The action.
    type Action;
    /// Additional context on the policy's decision.
    type ActionContext;
    /// The current parameterization of the policy.
    type PolicyState: PolicyState;

    /// Produces the action distribution for a batch of observations.
    ///
    /// Takes `&mut self`, so implementations may update internal state.
    fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution;

    /// Selects the action for a batch of observations.
    ///
    /// Returns the action together with a `Vec` of decision context.
    // NOTE(review): the `Vec` element type is not visible here; presumably
    // `Self::ActionContext` — confirm. `deterministic` presumably requests a deterministic
    // choice instead of sampling from the action distribution — confirm against implementors.
    fn action(
        &mut self,
        obs: Self::Observation,
        deterministic: bool,
    ) -> (Self::Action, Vec);

    /// Updates the policy's parameters from the given state.
    fn update(&mut self, update: Self::PolicyState);

    /// Returns the current parameterization.
    fn state(&self) -> Self::PolicyState;

    /// Consumes the policy and returns it with parameters loaded from `record`.
    // NOTE(review): the record type path is truncated here; presumably
    // `<Self::PolicyState as PolicyState<..>>::Record` — confirm.
    fn load_record(self, record: >::Record) -> Self;
}

/// Trait for a type that can be combined into a batch and split back into items.
pub trait Batchable: Sized {
    /// Creates a batch from a list of items.
    fn batch(value: Vec) -> Self;

    /// Splits a batch back into a list of items.
    // NOTE(review): presumably `unbatch(batch(items))` yields the original items, in order;
    // this is not enforced by the trait — confirm with implementors.
    fn unbatch(self) -> Vec;
}

/// The output of a training step.
// NOTE(review): `P` and `TO` are not declared here; the generic list was probably
// stripped — confirm.
pub struct RLTrainOutput {
    /// The policy produced by the training step.
    pub policy: P,
    /// The accompanying training item.
    pub item: TO,
}

/// Batched transitions for a [`PolicyLearner`], built from its policy's observation and
/// action types.
// NOTE(review): the alias's own generic parameters and the qualified paths for the
// observation/action types are missing here; check the `TransitionBatch` generic order
// before restoring them.
pub type LearnerTransitionBatch = TransitionBatch>::Observation,

>::Action>;

/// Learner that trains a policy on batched transitions.
///
/// The policy's observation, action-distribution and action types must be
/// `Clone + Batchable` so transitions can be batched for training.
pub trait PolicyLearner
where
    B: AutodiffBackend,
    // NOTE(review): the subjects of the three bounds below are truncated; presumably
    // `<Self::InnerPolicy as Policy<..>>::...` — confirm.
    >::Observation: Clone + Batchable,
    >::ActionDistribution: Clone + Batchable,
    >::Action: Clone + Batchable,
{
    /// Additional context produced by a training step.
    type TrainContext;
    /// The policy being trained.
    type InnerPolicy: Policy;
    /// The record type used to save and restore the learner.
    type Record: Record;

    /// Executes one training step on a batch of transitions.
    ///
    /// The returned [`RLTrainOutput`] includes the inner policy's `PolicyState`.
    // NOTE(review): presumably the output pairs `Self::TrainContext` with the updated
    // policy state; the generic arguments are truncated here — confirm.
    fn train(
        &mut self,
        input: LearnerTransitionBatch,
    ) -> RLTrainOutput>::PolicyState>;

    /// Returns the learner's current policy by value.
    // NOTE(review): presumably a clone of the internally held policy (`Policy: Clone`) — confirm.
    fn policy(&self) -> Self::InnerPolicy;

    /// Replaces the learner's policy with `update`.
    fn update_policy(&mut self, update: Self::InnerPolicy);

    /// Converts the learner's current state into a record without consuming the learner.
    fn record(&self) -> Self::Record;

    /// Consumes the learner and returns it with state loaded from `record`.
    fn load_record(self, record: Self::Record) -> Self;
}