feat: update workspace paths and enhance gitignore

- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and user_data directory
- Added Cargo.lock to gitignore with appropriate comment
- Reorganized IDE files section in gitignore for better clarity
- Added newline at end of file for proper formatting
This commit is contained in:
2026-03-05 19:39:14 +01:00
parent 4bb7ca9074
commit 3a67c0979c
1605 changed files with 537032 additions and 2 deletions

View File

@@ -0,0 +1,31 @@
[package]
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
categories = ["science"]
description = "RL crate for the Burn framework"
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
license.workspace = true
name = "burn-rl"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-rl"
documentation = "https://docs.rs/burn-rl"
version.workspace = true
[dependencies]
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", features = [
"dataset",
"std",
], default-features = false }
burn-optim = { path = "../burn-optim", version = "=0.21.0-pre.2", features = [
"std",
], default-features = false }
derive-new.workspace = true
log = { workspace = true }
rand.workspace = true
[dev-dependencies]
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
[lints]
workspace = true

View File

@@ -0,0 +1 @@
../../LICENSE-APACHE

View File

@@ -0,0 +1 @@
../../LICENSE-MIT

View File

@@ -0,0 +1,6 @@
# Burn RL
<!-- This crate should be used with [burn](https://github.com/tracel-ai/burn). -->
<!-- [![Current Crates.io Version](https://img.shields.io/crates/v/burn-rl.svg)](https://crates.io/crates/burn-rl)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-rl/blob/master/README.md) -->

View File

@@ -0,0 +1,46 @@
/// The result of taking a step in an environment.
///
/// `S` is the type used to represent the environment's state.
pub struct StepResult<S> {
/// The state of the environment after the step.
pub next_state: S,
/// The reward obtained for the step.
pub reward: f64,
/// Whether the environment reached a terminal state.
pub done: bool,
/// Whether the environment reached its maximum episode length.
pub truncated: bool,
}
/// Trait to be implemented for a RL environment.
///
/// An environment exposes its current state, advances by one step for a given
/// action, and can be reset to an initial state.
pub trait Environment {
/// The type of the state.
type State;
/// The type of actions.
type Action;
/// The maximum number of steps for one episode.
const MAX_STEPS: usize;
/// Returns the current state.
fn state(&self) -> Self::State;
/// Takes a step in the environment given an action and returns the resulting [`StepResult`].
fn step(&mut self, action: Self::Action) -> StepResult<Self::State>;
/// Resets the environment to an initial state.
fn reset(&mut self);
}
/// Trait to define how to initialize an environment.
///
/// By default, any cloneable function or closure that takes no argument and
/// returns an environment implements it.
pub trait EnvironmentInit<E: Environment>: Clone {
/// Initializes and returns a new environment.
fn init(&self) -> E;
}
impl<F, E> EnvironmentInit<E> for F
where
    E: Environment,
    F: Fn() -> E + Clone,
{
    /// Builds the environment by calling the underlying factory function.
    fn init(&self) -> E {
        let factory: &F = self;
        factory()
    }
}

View File

@@ -0,0 +1,3 @@
mod base;
pub use base::*;

View File

@@ -0,0 +1,166 @@
#![warn(missing_docs)]
#![cfg_attr(docsrs, feature(doc_cfg))]
//! A library for training reinforcement learning agents.
/// Module for implementing an environment.
pub mod environment;
/// Module for implementing a policy.
pub mod policy;
/// Transition buffer.
pub mod transition_buffer;
pub use environment::*;
pub use policy::*;
pub use transition_buffer::*;
/// Backend used by the unit tests of this crate: `NdArray` with `f32` elements.
#[cfg(test)]
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
#[cfg(test)]
pub(crate) mod tests {
    use crate::{Batchable, Policy, PolicyState, TestBackend};
    use burn_core::record::Record;
    // NOTE(review): presumably the `Record` derive below refers to the core crate as `burn`.
    use burn_core::{self as burn};

    /// Mock policy for testing.
    ///
    /// Calling `forward()` with a [`MockObservation`] returns a [`MockActionDistribution`]
    /// containing one `0.0` per element of the observation.
    ///
    /// Calling `action()` with a [`MockObservation`] returns a [`MockAction`] containing one
    /// action per element of the observation, along with one [`MockActionContext`] per element.
    /// The actions are all `1` if the call is requested as deterministic, or else `0`.
    #[derive(Clone)]
    pub(crate) struct MockPolicy {}

    impl MockPolicy {
        /// Creates the mock policy.
        pub fn new() -> Self {
            Self {}
        }
    }

    impl Policy<TestBackend> for MockPolicy {
        type Observation = MockObservation;
        type ActionDistribution = MockActionDistribution;
        type Action = MockAction;
        type ActionContext = MockActionContext;
        type PolicyState = MockPolicyState;

        fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution {
            MockActionDistribution(vec![0.; obs.0.len()])
        }

        fn action(
            &mut self,
            obs: Self::Observation,
            deterministic: bool,
        ) -> (Self::Action, Vec<Self::ActionContext>) {
            let batch_size = obs.0.len();
            // `true` maps to 1 and `false` to 0.
            let action = i32::from(deterministic);
            (
                MockAction(vec![action; batch_size]),
                vec![MockActionContext; batch_size],
            )
        }

        fn update(&mut self, _update: Self::PolicyState) {}

        fn state(&self) -> Self::PolicyState {
            MockPolicyState
        }

        fn load_record(
            self,
            _record: <Self::PolicyState as PolicyState<TestBackend>>::Record,
        ) -> Self {
            self
        }
    }

    /// Mock observation for testing represented as a vector of f32. Can call `batch()` and `unbatch()` on it.
    #[derive(Clone)]
    pub(crate) struct MockObservation(pub Vec<f32>);

    /// Mock action for testing represented as a vector of i32. Can call `batch()` and `unbatch()` on it.
    #[derive(Clone)]
    pub(crate) struct MockAction(pub Vec<i32>);

    /// Mock action distribution for testing represented as a vector of f32. Can call `batch()` and `unbatch()` on it.
    #[derive(Clone)]
    pub(crate) struct MockActionDistribution(Vec<f32>);

    /// Mock action context for testing; carries no data.
    #[derive(Clone)]
    pub(crate) struct MockActionContext;

    /// Mock policy state for testing; carries no data and has no effect on the policy.
    #[derive(Clone)]
    pub(crate) struct MockPolicyState;

    /// Record produced by [`MockPolicyState`]; holds an arbitrary `usize`.
    #[derive(Clone, Record)]
    pub(crate) struct MockRecord {
        item: usize,
    }

    impl PolicyState<TestBackend> for MockPolicyState {
        type Record = MockRecord;

        fn into_record(self) -> Self::Record {
            MockRecord { item: 0 }
        }

        fn load_record(&self, _record: Self::Record) -> Self {
            self.clone()
        }
    }

    impl Batchable for MockObservation {
        /// Concatenates the values of all the observations.
        fn batch(items: Vec<Self>) -> Self {
            MockObservation(items.into_iter().flat_map(|item| item.0).collect())
        }

        /// Splits the batch into one observation per value.
        fn unbatch(self) -> Vec<Self> {
            self.0
                .into_iter()
                .map(|value| MockObservation(vec![value]))
                .collect()
        }
    }

    impl Batchable for MockAction {
        /// Concatenates the values of all the actions.
        fn batch(items: Vec<Self>) -> Self {
            MockAction(items.into_iter().flat_map(|item| item.0).collect())
        }

        /// Splits the batch into one action per value.
        fn unbatch(self) -> Vec<Self> {
            self.0
                .into_iter()
                .map(|value| MockAction(vec![value]))
                .collect()
        }
    }

    impl Batchable for MockActionDistribution {
        /// Concatenates the values of all the distributions.
        fn batch(items: Vec<Self>) -> Self {
            MockActionDistribution(items.into_iter().flat_map(|item| item.0).collect())
        }

        /// Splits the batch into one distribution per value, keeping each value.
        fn unbatch(self) -> Vec<Self> {
            self.0
                .into_iter()
                .map(|value| MockActionDistribution(vec![value]))
                .collect()
        }
    }
}

View File

@@ -0,0 +1,485 @@
use std::{
sync::{
Arc,
atomic::{AtomicUsize, Ordering},
mpsc::{self, Sender},
},
thread::spawn,
};
use burn_core::prelude::Backend;
use crate::{ActionContext, Batchable, Policy, PolicyState};
/// State of the inference server: wraps a policy and queues the pending requests so that
/// they can be answered with a single batched call to the policy.
#[derive(Clone)]
struct PolicyInferenceServer<B: Backend, P: Policy<B>> {
// `num_agents` used to make sure autobatching doesn't block the agents if they are less than the autobatch size.
num_agents: Arc<AtomicUsize>,
// Upper bound on the number of queued requests that triggers a flush.
max_autobatch_size: usize,
// The policy answering the batched requests.
inner_policy: P,
// Pending `action` requests.
batch_action: Vec<ActionItem<P::Observation, P::Action, P::ActionContext>>,
// Pending `forward` requests.
batch_logits: Vec<ForwardItem<P::Observation, P::ActionDistribution>>,
}
impl<B, P> PolicyInferenceServer<B, P>
where
    B: Backend,
    P: Policy<B>,
    P::Observation: Clone + Batchable,
    P::ActionDistribution: Clone + Batchable,
    P::Action: Clone + Batchable,
    P::ActionContext: Clone,
{
    /// Creates a server with empty queues and no registered agent.
    ///
    /// # Arguments
    ///
    /// * `max_autobatch_size` - Maximum number of queued requests before a flush is triggered.
    /// * `inner_policy` - The policy answering the batched requests.
    pub fn new(max_autobatch_size: usize, inner_policy: P) -> Self {
        Self {
            num_agents: Arc::new(AtomicUsize::new(0)),
            max_autobatch_size,
            inner_policy,
            batch_action: vec![],
            batch_logits: vec![],
        }
    }

    /// Number of queued requests at which a queue is flushed: the number of registered
    /// agents, capped by the maximum autobatch size.
    fn flush_threshold(&self) -> usize {
        self.num_agents
            .load(Ordering::Relaxed)
            .min(self.max_autobatch_size)
    }

    /// Queues an action request and flushes the action queue once the threshold is reached.
    pub fn push_action(&mut self, item: ActionItem<P::Observation, P::Action, P::ActionContext>) {
        self.batch_action.push(item);
        if self.len_actions() >= self.flush_threshold() {
            self.flush_actions();
        }
    }

    /// Queues a forward request and flushes the forward queue once the threshold is reached.
    pub fn push_logits(&mut self, item: ForwardItem<P::Observation, P::ActionDistribution>) {
        self.batch_logits.push(item);
        if self.len_logits() >= self.flush_threshold() {
            self.flush_logits();
        }
    }

    /// Number of queued action requests.
    pub fn len_actions(&self) -> usize {
        self.batch_action.len()
    }

    /// Number of queued forward requests.
    pub fn len_logits(&self) -> usize {
        self.batch_logits.len()
    }

    /// Answers every queued action request with a single batched call to the policy.
    ///
    /// Does nothing if the queue is empty.
    ///
    /// # Panics
    ///
    /// Panics if the policy returns fewer actions or contexts than there are queued
    /// requests, or if a reply cannot be sent.
    pub fn flush_actions(&mut self) {
        if self.batch_action.is_empty() {
            return;
        }
        // Only deterministic if all actions are requested as deterministic.
        let deterministic = self.batch_action.iter().all(|item| item.deterministic);
        // Draining moves the observations out of the queue instead of cloning them.
        let (senders, observations): (Vec<_>, Vec<_>) = self
            .batch_action
            .drain(..)
            .map(|item| (item.sender, item.inference_state))
            .unzip();
        let (actions, contexts) = self
            .inner_policy
            .action(P::Observation::batch(observations), deterministic);
        let actions = actions.unbatch();
        assert!(
            actions.len() >= senders.len() && contexts.len() >= senders.len(),
            "Policy should return one action and one context per queued request."
        );
        for ((sender, action), context) in senders.into_iter().zip(actions).zip(contexts) {
            sender
                .send(ActionContext {
                    context: vec![context],
                    action,
                })
                .expect("Autobatcher should be able to send resulting actions.");
        }
    }

    /// Answers every queued forward request with a single batched call to the policy.
    ///
    /// Does nothing if the queue is empty.
    ///
    /// # Panics
    ///
    /// Panics if the policy returns fewer outputs than there are queued requests, or if a
    /// reply cannot be sent.
    pub fn flush_logits(&mut self) {
        if self.batch_logits.is_empty() {
            return;
        }
        // Draining moves the observations out of the queue instead of cloning them.
        let (senders, observations): (Vec<_>, Vec<_>) = self
            .batch_logits
            .drain(..)
            .map(|item| (item.sender, item.inference_state))
            .unzip();
        let logits = self
            .inner_policy
            .forward(P::Observation::batch(observations))
            .unbatch();
        assert!(
            logits.len() >= senders.len(),
            "Policy should return one output per queued request."
        );
        for (sender, logit) in senders.into_iter().zip(logits) {
            sender
                .send(logit)
                .expect("Autobatcher should be able to send resulting probabilities.");
        }
    }

    /// Flushes the pending requests with the current policy, then updates the policy.
    pub fn update_policy(&mut self, policy_update: P::PolicyState) {
        // Both flushes are no-ops on an empty queue.
        self.flush_actions();
        self.flush_logits();
        self.inner_policy.update(policy_update);
    }

    /// Returns the state of the inner policy.
    pub fn state(&self) -> P::PolicyState {
        self.inner_policy.state()
    }

    /// Registers `num` additional agents.
    pub fn increment_agents(&mut self, num: usize) {
        self.num_agents.fetch_add(num, Ordering::Relaxed);
    }

    /// Unregisters `num` agents and flushes the queues that now reach the lowered threshold.
    ///
    /// The count saturates at zero rather than wrapping around, so removing more agents
    /// than were registered cannot raise the flush threshold.
    pub fn decrement_agents(&mut self, num: usize) {
        // The closure always returns `Some`, so the update cannot fail.
        let _ = self
            .num_agents
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                Some(current.saturating_sub(num))
            });
        if self.len_actions() >= self.flush_threshold() {
            self.flush_actions();
        }
        if self.len_logits() >= self.flush_threshold() {
            self.flush_logits();
        }
    }
}
/// Messages handled by the inference server thread.
enum InferenceMessage<B: Backend, P: Policy<B>> {
// Queue a request for an action.
ActionMessage(ActionItem<P::Observation, P::Action, P::ActionContext>),
// Queue a request for an action distribution.
ForwardMessage(ForwardItem<P::Observation, P::ActionDistribution>),
// Update the inner policy with the given state.
PolicyUpdate(P::PolicyState),
// Send the current policy state back through the given sender.
PolicyRequest(Sender<P::PolicyState>),
// Increase the number of registered agents.
IncrementAgents(usize),
// Decrease the number of registered agents.
DecrementAgents(usize),
}
/// A queued request for an action.
#[derive(Clone)]
struct ActionItem<S, A, C> {
// Channel used to send the resulting action and its context back to the requester.
sender: Sender<ActionContext<A, Vec<C>>>,
// The observation to run the policy on.
inference_state: S,
// Whether the requester asked for a deterministic action.
deterministic: bool,
}
/// A queued request for an action distribution.
#[derive(Clone)]
struct ForwardItem<S, O> {
// Channel used to send the resulting output back to the requester.
sender: Sender<O>,
// The observation to run the policy on.
inference_state: S,
}
/// An asynchronous policy using an inference server with autobatching.
///
/// Cloning the policy clones the channel sender, so every clone talks to the same
/// inference server.
#[derive(Clone)]
pub struct AsyncPolicy<B: Backend, P: Policy<B>> {
// Channel used to send requests to the inference server thread.
inference_state_sender: Sender<InferenceMessage<B, P>>,
}
impl<B, P> AsyncPolicy<B, P>
where
    B: Backend,
    P: Policy<B> + Clone + Send + 'static,
    P::ActionContext: Clone + Send,
    P::PolicyState: Send,
    P::Observation: Clone + Send + Batchable,
    P::ActionDistribution: Clone + Send + Batchable,
    P::Action: Clone + Send + Batchable,
{
    /// Create the policy.
    ///
    /// Spawns a thread running the inference server. The thread stops once receiving a
    /// message fails, which happens when every sender to the server has been dropped.
    ///
    /// # Arguments
    ///
    /// * `autobatch_size` - Number of observations to accumulate before running a pass of inference.
    /// * `inner_policy` - The policy used to take actions.
    pub fn new(autobatch_size: usize, inner_policy: P) -> Self {
        let (sender, receiver) = mpsc::channel();
        // The policy is owned here, so it is moved into the server rather than cloned.
        let mut autobatcher = PolicyInferenceServer::new(autobatch_size, inner_policy);
        spawn(move || {
            loop {
                let msg = match receiver.recv() {
                    Ok(msg) => msg,
                    Err(err) => {
                        log::error!("Error in AsyncPolicy : {}", err);
                        break;
                    }
                };
                match msg {
                    InferenceMessage::ActionMessage(item) => autobatcher.push_action(item),
                    InferenceMessage::ForwardMessage(item) => autobatcher.push_logits(item),
                    InferenceMessage::PolicyUpdate(update) => autobatcher.update_policy(update),
                    InferenceMessage::PolicyRequest(sender) => sender
                        .send(autobatcher.state())
                        .expect("Autobatcher should be able to send current policy state."),
                    InferenceMessage::IncrementAgents(num) => autobatcher.increment_agents(num),
                    InferenceMessage::DecrementAgents(num) => autobatcher.decrement_agents(num),
                }
            }
        });
        Self {
            inference_state_sender: sender,
        }
    }

    /// Increment the number of agents using the inference server.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be sent to the inference server.
    pub fn increment_agents(&self, num: usize) {
        self.inference_state_sender
            .send(InferenceMessage::IncrementAgents(num))
            .expect("Can send message to autobatcher.")
    }

    /// Decrement the number of agents using the inference server.
    ///
    /// # Panics
    ///
    /// Panics if the message cannot be sent to the inference server.
    pub fn decrement_agents(&self, num: usize) {
        self.inference_state_sender
            .send(InferenceMessage::DecrementAgents(num))
            .expect("Can send message to autobatcher.")
    }
}
impl<B, P> Policy<B> for AsyncPolicy<B, P>
where
B: Backend,
P: Policy<B> + Send + 'static,
{
type ActionContext = P::ActionContext;
type PolicyState = P::PolicyState;
type Observation = P::Observation;
type ActionDistribution = P::ActionDistribution;
type Action = P::Action;
fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {
let (action_sender, action_receiver) = std::sync::mpsc::channel();
let item = ForwardItem {
sender: action_sender,
inference_state: states,
};
self.inference_state_sender
.send(InferenceMessage::ForwardMessage(item))
.expect("Should be able to send message to inference_server");
action_receiver
.recv()
.expect("AsyncPolicy should receive queued probabilities.")
}
fn action(
&mut self,
states: Self::Observation,
deterministic: bool,
) -> (Self::Action, Vec<Self::ActionContext>) {
let (action_sender, action_receiver) = std::sync::mpsc::channel();
let item = ActionItem {
sender: action_sender,
inference_state: states,
deterministic,
};
self.inference_state_sender
.send(InferenceMessage::ActionMessage(item))
.expect("should be able to send message to inference_server.");
let action = action_receiver
.recv()
.expect("AsyncPolicy should receive queued actions.");
(action.action, action.context)
}
fn update(&mut self, update: Self::PolicyState) {
self.inference_state_sender
.send(InferenceMessage::PolicyUpdate(update))
.expect("AsyncPolicy should be able to send policy state.")
}
fn state(&self) -> Self::PolicyState {
let (sender, receiver) = mpsc::channel();
self.inference_state_sender
.send(InferenceMessage::PolicyRequest(sender))
.expect("should be able to send message to inference_server.");
receiver
.recv()
.expect("AsyncPolicy should be able to receive policy state.")
}
fn load_record(self, _record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {
// Not needed for now
todo!()
}
}
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
use std::thread::JoinHandle;
use std::time::Duration;
use crate::TestBackend;
use crate::tests::{MockAction, MockObservation, MockPolicy};
use super::*;
/// Action requests are only answered once the autobatch size (8) is reached.
///
/// Relies on short sleeps to let the spawned threads and the inference server make progress.
#[test]
fn test_multiple_actions_before_flush() {
// Spawns a thread that blocks on a single `action` request.
fn launch_thread(
policy: &AsyncPolicy<TestBackend, MockPolicy>,
handles: &mut Vec<JoinHandle<()>>,
) {
let mut thread_policy = policy.clone();
let handle = spawn(move || {
thread_policy.action(MockObservation(vec![0.]), false);
});
handles.push(handle);
}
let policy = AsyncPolicy::new(8, MockPolicy::new());
// With 1000 registered agents, the flush threshold is the autobatch size.
policy.increment_agents(1000);
let mut handles = vec![];
launch_thread(&policy, &mut handles);
std::thread::sleep(Duration::from_millis(10));
assert!(!handles[0].is_finished());
for _ in 0..6 {
launch_thread(&policy, &mut handles);
}
std::thread::sleep(Duration::from_millis(10));
// 7 requests are queued: still below the threshold.
for i in 0..7 {
assert!(!handles[i].is_finished());
}
// The 8th request triggers the flush, which unblocks every thread.
launch_thread(&policy, &mut handles);
std::thread::sleep(Duration::from_millis(10));
for i in 0..8 {
assert!(handles[i].is_finished());
}
// The queue starts over after a flush.
let mut handles = vec![];
launch_thread(&policy, &mut handles);
std::thread::sleep(Duration::from_millis(10));
assert!(!handles[0].is_finished());
}
/// Forward requests are only answered once the autobatch size (8) is reached.
///
/// Relies on short sleeps to let the spawned threads and the inference server make progress.
#[test]
fn test_multiple_forward_before_flush() {
// Spawns a thread that blocks on a single `forward` request.
fn launch_thread(
policy: &AsyncPolicy<TestBackend, MockPolicy>,
handles: &mut Vec<JoinHandle<()>>,
) {
let mut thread_policy = policy.clone();
let handle = spawn(move || {
thread_policy.forward(MockObservation(vec![0.]));
});
handles.push(handle);
}
let policy = AsyncPolicy::new(8, MockPolicy::new());
// With 1000 registered agents, the flush threshold is the autobatch size.
policy.increment_agents(1000);
let mut handles = vec![];
launch_thread(&policy, &mut handles);
std::thread::sleep(Duration::from_millis(10));
assert!(!handles[0].is_finished());
for _ in 0..6 {
launch_thread(&policy, &mut handles);
}
std::thread::sleep(Duration::from_millis(10));
// 7 requests are queued: still below the threshold.
for i in 0..7 {
assert!(!handles[i].is_finished());
}
// The 8th request triggers the flush, which unblocks every thread.
launch_thread(&policy, &mut handles);
std::thread::sleep(Duration::from_millis(10));
for i in 0..8 {
assert!(handles[i].is_finished());
}
// The queue starts over after a flush.
let mut handles = vec![];
launch_thread(&policy, &mut handles);
std::thread::sleep(Duration::from_millis(10));
assert!(!handles[0].is_finished());
}
/// A batch is handled deterministically only if every request in it asked for it.
#[test]
fn test_async_policy_deterministic_behaviour() {
// Spawns a thread that requests a single action and returns it.
fn launch_thread(
policy: &AsyncPolicy<TestBackend, MockPolicy>,
handles: &mut Vec<JoinHandle<MockAction>>,
deterministic: bool,
) {
let mut thread_policy = policy.clone();
let handle = spawn(move || {
let (action, _) = thread_policy.action(MockObservation(vec![0.]), deterministic);
action
});
handles.push(handle);
}
// Batches of 2 requests.
let policy = AsyncPolicy::new(2, MockPolicy::new());
policy.increment_agents(1000);
let mut handles = vec![];
// Mixed batch: the mock policy answers 0 (non-deterministic) to both requests.
launch_thread(&policy, &mut handles, true);
launch_thread(&policy, &mut handles, false);
for _ in 0..2 {
let action = handles.pop().unwrap().join().unwrap();
assert_eq!(action.0, vec![0]);
}
let mut handles = vec![];
// Fully deterministic batch: the mock policy answers 1 to both requests.
launch_thread(&policy, &mut handles, true);
launch_thread(&policy, &mut handles, true);
for _ in 0..2 {
let action = handles.pop().unwrap().join().unwrap();
assert_eq!(action.0, vec![1]);
}
}
/// The flush threshold follows the number of registered agents when it is smaller than
/// the autobatch size.
///
/// Relies on short sleeps to let the spawned threads and the inference server make progress.
#[test]
fn flush_when_running_agents_smaller_than_autobatch_size() {
// Spawns a thread that blocks on a single `action` request.
fn launch_thread(
policy: &AsyncPolicy<TestBackend, MockPolicy>,
handles: &mut Vec<JoinHandle<()>>,
) {
let mut thread_policy = policy.clone();
let handle = spawn(move || {
thread_policy.action(MockObservation(vec![0.]), false);
});
handles.push(handle);
}
let policy = AsyncPolicy::new(8, MockPolicy::new());
// 3 agents with an autobatch size of 8: the flush threshold is 3.
policy.increment_agents(3);
let mut handles = vec![];
launch_thread(&policy, &mut handles);
launch_thread(&policy, &mut handles);
std::thread::sleep(Duration::from_millis(10));
assert!(!handles[0].is_finished());
assert!(!handles[1].is_finished());
// The 3rd request reaches the threshold and triggers the flush.
launch_thread(&policy, &mut handles);
std::thread::sleep(Duration::from_millis(10));
for i in 0..3 {
assert!(handles[i].is_finished());
}
let mut handles = vec![];
launch_thread(&policy, &mut handles);
launch_thread(&policy, &mut handles);
std::thread::sleep(Duration::from_millis(10));
assert!(!handles[0].is_finished());
assert!(!handles[1].is_finished());
// Removing an agent lowers the threshold to 2, which flushes the 2 pending requests.
policy.decrement_agents(1);
std::thread::sleep(Duration::from_millis(10));
assert!(handles[0].is_finished());
assert!(handles[1].is_finished());
}
}

View File

@@ -0,0 +1,108 @@
use derive_new::new;
use burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend};
use crate::TransitionBatch;
/// An action along with additional context about the decision.
///
/// `A` is the type of the action and `C` the type of the context.
#[derive(Clone, new)]
pub struct ActionContext<A, C> {
/// The context attached to the decision.
pub context: C,
/// The action that was chosen.
pub action: A,
}
/// The state of a policy, convertible to and from a record.
pub trait PolicyState<B: Backend> {
/// The type of the record.
type Record: Record<B>;
/// Converts the state into a record, consuming the state.
fn into_record(self) -> Self::Record;
/// Returns a new state loaded from the record.
fn load_record(&self, record: Self::Record) -> Self;
}
/// Trait for a RL policy.
///
/// A policy maps observations to action distributions and actions, and exposes its
/// parameterization as a [`PolicyState`].
pub trait Policy<B: Backend>: Clone {
/// The observation given as input to the policy.
type Observation;
/// The action distribution parameters defining how the action will be sampled.
type ActionDistribution;
/// The action.
type Action;
/// Additional context on the policy's decision.
type ActionContext;
/// The current parameterization of the policy.
type PolicyState: PolicyState<B>;
/// Produces the action distribution from a batch of observations.
fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution;
/// Gives the action from a batch of observations, along with a list of contexts.
///
/// `deterministic` tells the policy whether the action is requested as deterministic.
fn action(
&mut self,
obs: Self::Observation,
deterministic: bool,
) -> (Self::Action, Vec<Self::ActionContext>);
/// Updates the policy's parameters.
fn update(&mut self, update: Self::PolicyState);
/// Returns the current parameterization.
fn state(&self) -> Self::PolicyState;
/// Loads the policy parameters from a record, consuming the policy.
fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self;
}
/// Trait for a type that can be batched and unbatched (split).
pub trait Batchable: Sized {
/// Creates a single batch from a list of items.
fn batch(value: Vec<Self>) -> Self;
/// Splits a batch back into a list of items.
fn unbatch(self) -> Vec<Self>;
}
/// The output of a training step.
///
/// `TO` is the type of the item and `P` the type of the policy.
pub struct RLTrainOutput<TO, P> {
/// The policy.
pub policy: P,
/// The item.
pub item: TO,
}
/// Batched transitions for a [`PolicyLearner`], using the observation and action types of
/// the policy `P`.
pub type LearnerTransitionBatch<B, P> =
TransitionBatch<B, <P as Policy<B>>::Observation, <P as Policy<B>>::Action>;
/// Learner for a policy.
///
/// The observation, action distribution and action types of the inner policy must be
/// cloneable and [`Batchable`].
pub trait PolicyLearner<B>
where
B: AutodiffBackend,
<Self::InnerPolicy as Policy<B>>::Observation: Clone + Batchable,
<Self::InnerPolicy as Policy<B>>::ActionDistribution: Clone + Batchable,
<Self::InnerPolicy as Policy<B>>::Action: Clone + Batchable,
{
/// Additional context of a training step.
type TrainContext;
/// The policy to train.
type InnerPolicy: Policy<B>;
/// The record of the learner.
type Record: Record<B>;
/// Executes a training step on the policy from a batch of transitions.
///
/// Returns the training context along with a policy state.
fn train(
&mut self,
input: LearnerTransitionBatch<B, Self::InnerPolicy>,
) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState>;
/// Returns the learner's current policy.
fn policy(&self) -> Self::InnerPolicy;
/// Updates the learner's policy.
fn update_policy(&mut self, update: Self::InnerPolicy);
/// Converts the learner's state into a record.
fn record(&self) -> Self::Record;
/// Loads the learner's state from a record, consuming the learner.
fn load_record(self, record: Self::Record) -> Self;
}

View File

@@ -0,0 +1,5 @@
mod async_policy;
mod base;
pub use async_policy::*;
pub use base::*;

View File

@@ -0,0 +1,244 @@
use burn_core::{Tensor, prelude::Backend, tensor::Distribution};
use derive_new::new;
use super::SliceAccess;
/// A state transition in an environment.
///
/// `S` is the type of the states and `A` the type of the action.
#[derive(Clone, new)]
pub struct Transition<B: Backend, S, A> {
/// The state before the step was taken.
pub state: S,
/// The state after the step was taken.
pub next_state: S,
/// The action taken in the step.
pub action: A,
/// The reward obtained for the step.
pub reward: Tensor<B, 1>,
/// Whether the environment has reached a terminal state.
pub done: Tensor<B, 1>,
}
/// A batch of transitions.
///
/// `SB` is the type of the batched states and `AB` the type of the batched actions.
pub struct TransitionBatch<B: Backend, SB, AB> {
/// Batched initial states.
pub states: SB,
/// Batched resulting states.
pub next_states: SB,
/// Batched actions.
pub actions: AB,
/// Batched rewards.
pub rewards: Tensor<B, 2>,
/// Batched flags for terminal states.
pub dones: Tensor<B, 2>,
}
/// A tensor-backed circular buffer for transitions.
///
/// Uses [`SliceAccess`] to store state and action batches in contiguous
/// tensor storage, enabling efficient random sampling via `select`.
/// The buffer lazily initializes its storage on the first `push` call.
pub struct TransitionBuffer<B: Backend, SB: SliceAccess<B>, AB: SliceAccess<B>> {
// Storage for the initial states; `None` until the first `push`.
states: Option<SB>,
// Storage for the resulting states; `None` until the first `push`.
next_states: Option<SB>,
// Storage for the actions; `None` until the first `push`.
actions: Option<AB>,
// Storage for the rewards, one row per transition; `None` until the first `push`.
rewards: Option<Tensor<B, 2>>,
// Storage for the terminal flags, one row per transition; `None` until the first `push`.
dones: Option<Tensor<B, 2>>,
// Maximum number of stored transitions.
capacity: usize,
// Counter used to compute the row written by the next `push`.
write_head: usize,
// Number of transitions currently stored.
len: usize,
// Device on which the storage is allocated.
device: B::Device,
}
impl<B: Backend, SB: SliceAccess<B>, AB: SliceAccess<B>> TransitionBuffer<B, SB, AB> {
    /// Creates a new buffer. Storage is lazily allocated on the first `push`.
    ///
    /// # Arguments
    ///
    /// * `capacity` - Maximum number of stored transitions. Must be greater than zero for
    ///   `push` to succeed.
    /// * `device` - Device on which the storage is allocated.
    pub fn new(capacity: usize, device: &B::Device) -> Self {
        Self {
            states: None,
            next_states: None,
            actions: None,
            rewards: None,
            dones: None,
            capacity,
            write_head: 0,
            len: 0,
            device: device.clone(),
        }
    }

    /// Allocates zeroed storage with `capacity` rows, shaped after the given samples.
    ///
    /// Does nothing if the storage is already allocated.
    fn ensure_init(&mut self, state: &SB, next_state: &SB, action: &AB) {
        if self.states.is_some() {
            return;
        }
        self.states = Some(SB::zeros_like(state, self.capacity, &self.device));
        self.next_states = Some(SB::zeros_like(next_state, self.capacity, &self.device));
        self.actions = Some(AB::zeros_like(action, self.capacity, &self.device));
        self.rewards = Some(Tensor::zeros([self.capacity, 1], &self.device));
        self.dones = Some(Tensor::zeros([self.capacity, 1], &self.device));
    }

    /// Add a transition, overwriting the oldest if full.
    ///
    /// `done` is stored as `1.0` when true and `0.0` otherwise.
    ///
    /// # Panics
    ///
    /// Panics if the buffer was created with a capacity of zero.
    pub fn push(&mut self, state: SB, next_state: SB, action: AB, reward: f32, done: bool) {
        assert!(
            self.capacity > 0,
            "cannot push a transition into a buffer with zero capacity"
        );
        self.ensure_init(&state, &next_state, &action);
        // `write_head` is kept below `capacity`, so it is the row to write.
        let idx = self.write_head;
        self.states
            .as_mut()
            .expect("storage should be initialized by `ensure_init`")
            .slice_assign_inplace(idx, state);
        self.next_states
            .as_mut()
            .expect("storage should be initialized by `ensure_init`")
            .slice_assign_inplace(idx, next_state);
        self.actions
            .as_mut()
            .expect("storage should be initialized by `ensure_init`")
            .slice_assign_inplace(idx, action);
        let reward = Tensor::from_data([[reward]], &self.device);
        self.rewards
            .as_mut()
            .expect("storage should be initialized by `ensure_init`")
            .inplace(|r| r.slice_assign(idx..idx + 1, reward));
        let done_val = if done { 1.0f32 } else { 0.0 };
        let done = Tensor::from_data([[done_val]], &self.device);
        self.dones
            .as_mut()
            .expect("storage should be initialized by `ensure_init`")
            .inplace(|d| d.slice_assign(idx..idx + 1, done));
        // Wrap around so the counter never grows past the capacity.
        self.write_head = (self.write_head + 1) % self.capacity;
        if self.len < self.capacity {
            self.len += 1;
        }
    }

    /// Sample a random batch of transitions.
    ///
    /// Each index is drawn independently, so the same transition may appear more than once
    /// in the batch.
    ///
    /// # Panics
    ///
    /// Panics if `batch_size` exceeds the number of stored transitions, or if no transition
    /// was ever pushed.
    pub fn sample(&self, batch_size: usize) -> TransitionBatch<B, SB, AB> {
        assert!(batch_size <= self.len, "batch_size exceeds buffer length");
        // NOTE(review): assumes the uniform distribution never yields its upper bound,
        // otherwise the index `len` would be out of range. Confirm for every backend.
        let indices = Tensor::<B, 1>::random(
            [batch_size],
            Distribution::Uniform(0.0, self.len as f64),
            &self.device,
        )
        .int();
        TransitionBatch {
            states: self
                .states
                .as_ref()
                .expect("cannot sample from a buffer that was never pushed to")
                .clone()
                .select(0, indices.clone()),
            next_states: self
                .next_states
                .as_ref()
                .expect("cannot sample from a buffer that was never pushed to")
                .clone()
                .select(0, indices.clone()),
            actions: self
                .actions
                .as_ref()
                .expect("cannot sample from a buffer that was never pushed to")
                .clone()
                .select(0, indices.clone()),
            rewards: self
                .rewards
                .as_ref()
                .expect("cannot sample from a buffer that was never pushed to")
                .clone()
                .select(0, indices.clone()),
            dones: self
                .dones
                .as_ref()
                .expect("cannot sample from a buffer that was never pushed to")
                .clone()
                .select(0, indices),
        }
    }

    /// Current number of stored transitions.
    pub fn len(&self) -> usize {
        self.len
    }

    /// Whether the buffer is empty.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Buffer capacity.
    pub fn capacity(&self) -> usize {
        self.capacity
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    type TB = Tensor<TestBackend, 2>;

    /// Pushes a non-terminal transition whose state, action and reward are derived from `val`.
    fn push_transition(
        buffer: &mut TransitionBuffer<TestBackend, TB, TB>,
        device: &<TestBackend as Backend>::Device,
        val: f32,
    ) {
        let successor = val + 1.0;
        buffer.push(
            TB::from_data([[val, val]], device),
            TB::from_data([[successor, successor]], device),
            TB::from_data([[val]], device),
            val,
            false,
        );
    }

    /// Each push on a non-full buffer increases the length by one.
    #[test]
    fn push_increment_len() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(5, &device);
        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());

        for (already_pushed, val) in [1.0, 2.0].into_iter().enumerate() {
            push_transition(&mut buffer, &device, val);
            assert_eq!(buffer.len(), already_pushed + 1);
        }
    }

    /// Pushing more transitions than the capacity keeps the length at the capacity.
    #[test]
    fn push_overwrites_when_full() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(3, &device);

        (0..5).for_each(|i| push_transition(&mut buffer, &device, i as f32));

        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.capacity(), 3);
    }

    /// A sampled batch has one row per requested transition.
    #[test]
    fn sample_returns_correct_shapes() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(10, &device);

        (0..5).for_each(|i| push_transition(&mut buffer, &device, i as f32));

        let TransitionBatch {
            states,
            next_states,
            actions,
            rewards,
            dones,
        } = buffer.sample(3);
        assert_eq!(states.dims(), [3, 2]);
        assert_eq!(next_states.dims(), [3, 2]);
        assert_eq!(actions.dims(), [3, 1]);
        assert_eq!(rewards.dims(), [3, 1]);
        assert_eq!(dones.dims(), [3, 1]);
    }

    /// Requesting more transitions than stored panics.
    #[test]
    #[should_panic(expected = "batch_size exceeds buffer length")]
    fn sample_panics_when_batch_too_large() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(5, &device);
        push_transition(&mut buffer, &device, 1.0);

        buffer.sample(5);
    }
}

View File

@@ -0,0 +1,5 @@
mod base;
mod slice_access;
pub use base::*;
pub use slice_access::*;

View File

@@ -0,0 +1,36 @@
use burn_core::prelude::*;
/// Trait for types that support tensor-like slice operations,
/// enabling storage in a [`TransitionBuffer`](super::TransitionBuffer).
///
/// Implement this trait for any type that wraps tensors and can be stored
/// in a replay buffer. The buffer uses these operations for:
/// - Pre-allocating storage (`zeros_like`)
/// - Writing transitions (`slice_assign_inplace`)
/// - Sampling batches (`select`)
pub trait SliceAccess<B: Backend>: Clone + Sized {
/// Creates zeroed storage matching the shape of `sample` but with `capacity` rows
/// along the first dimension, allocated on `device`.
fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self;
/// Selects the rows at the given indices along the specified dimension.
fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self;
/// Assigns `value` at row `index` along the first dimension, in place.
fn slice_assign_inplace(&mut self, index: usize, value: Self);
}
impl<B: Backend> SliceAccess<B> for Tensor<B, 2> {
    /// Creates a zeroed `[capacity, features]` tensor, where `features` is the size of the
    /// second dimension of `sample`.
    fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self {
        let [_, feature_dim] = sample.dims();
        let shape = [capacity, feature_dim];
        Tensor::zeros(shape, device)
    }

    /// Delegates to the tensor's own `select`.
    fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {
        Tensor::select(self, dim, indices)
    }

    /// Overwrites the single row at `index` with `value`.
    fn slice_assign_inplace(&mut self, index: usize, value: Self) {
        let row = index..index + 1;
        self.inplace(|storage| storage.slice_assign(row, value));
    }
}