feat: update workspace paths and enhance gitignore
- Updated stablediffusion crate path from "../stable-diffusion-burn" to "./crates/stable-diffusion-burn" for proper workspace resolution
- Enhanced .gitignore to include generated model files (.mpk, .pt, .bin, .safetensors, .ckpt) and the user_data directory
- Added Cargo.lock to .gitignore with an appropriate comment
- Reorganized the IDE files section in .gitignore for better clarity
- Added a newline at end of file for proper formatting
This commit is contained in:
31
crates/stable-diffusion-burn/burn-crates/burn-rl/Cargo.toml
Normal file
31
crates/stable-diffusion-burn/burn-crates/burn-rl/Cargo.toml
Normal file
@@ -0,0 +1,31 @@
|
||||
[package]
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
categories = ["science"]
|
||||
description = "RL crate for the Burn framework"
|
||||
edition.workspace = true
|
||||
keywords = ["deep-learning", "machine-learning", "tensor", "pytorch", "ndarray"]
|
||||
license.workspace = true
|
||||
name = "burn-rl"
|
||||
readme.workspace = true
|
||||
repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-rl"
|
||||
documentation = "https://docs.rs/burn-rl"
|
||||
version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
burn-core = { path = "../burn-core", version = "=0.21.0-pre.2", features = [
|
||||
"dataset",
|
||||
"std",
|
||||
], default-features = false }
|
||||
burn-optim = { path = "../burn-optim", version = "=0.21.0-pre.2", features = [
|
||||
"std",
|
||||
], default-features = false }
|
||||
|
||||
derive-new.workspace = true
|
||||
log = { workspace = true }
|
||||
rand.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "=0.21.0-pre.2" }
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-rl/LICENSE-APACHE
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-rl/LICENSE-APACHE
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
1
crates/stable-diffusion-burn/burn-crates/burn-rl/LICENSE-MIT
Symbolic link
1
crates/stable-diffusion-burn/burn-crates/burn-rl/LICENSE-MIT
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-MIT
|
||||
@@ -0,0 +1,6 @@
|
||||
# Burn RL
|
||||
|
||||
<!-- This crate should be used with [burn](https://github.com/tracel-ai/burn). -->
|
||||
|
||||
<!-- [](https://crates.io/crates/burn-rl)
|
||||
[](https://github.com/tracel-ai/burn-rl/blob/master/README.md) -->
|
||||
@@ -0,0 +1,46 @@
|
||||
/// Outcome of advancing an environment by a single step.
///
/// `S` is the state type of the environment that produced the result.
/// `Debug` and `Clone` are available whenever `S` provides them.
#[derive(Debug, Clone)]
pub struct StepResult<S> {
    /// State of the environment after the step.
    pub next_state: S,
    /// Reward produced by the step.
    pub reward: f64,
    /// `true` when the environment reached a terminal state.
    pub done: bool,
    /// `true` when the environment reached its maximum length.
    pub truncated: bool,
}
|
||||
|
||||
/// Trait to be implemented for a RL environment.
|
||||
pub trait Environment {
|
||||
/// The type of the state.
|
||||
type State;
|
||||
/// The type of actions.
|
||||
type Action;
|
||||
|
||||
/// The maximum number of step for one episode.
|
||||
const MAX_STEPS: usize;
|
||||
|
||||
/// Returns the current state.
|
||||
fn state(&self) -> Self::State;
|
||||
/// Take a step in the environment given an action.
|
||||
fn step(&mut self, action: Self::Action) -> StepResult<Self::State>;
|
||||
/// Reset the environment to an initial state.
|
||||
fn reset(&mut self);
|
||||
}
|
||||
|
||||
/// Trait to define how to initialize an environment.
|
||||
/// By default, any function returning an environment implements it.
|
||||
pub trait EnvironmentInit<E: Environment>: Clone {
|
||||
/// Initialize the environment.
|
||||
fn init(&self) -> E;
|
||||
}
|
||||
|
||||
impl<F, E> EnvironmentInit<E> for F
|
||||
where
|
||||
F: Fn() -> E + Clone,
|
||||
E: Environment,
|
||||
{
|
||||
fn init(&self) -> E {
|
||||
(self)()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
mod base;
|
||||
|
||||
pub use base::*;
|
||||
166
crates/stable-diffusion-burn/burn-crates/burn-rl/src/lib.rs
Normal file
166
crates/stable-diffusion-burn/burn-crates/burn-rl/src/lib.rs
Normal file
@@ -0,0 +1,166 @@
|
||||
#![warn(missing_docs)]
|
||||
#![cfg_attr(docsrs, feature(doc_cfg))]
|
||||
|
||||
//! A library for training reinforcement learning agents.
|
||||
|
||||
/// Module for implementing an environment.
|
||||
pub mod environment;
|
||||
/// Module for implementing a policy.
|
||||
pub mod policy;
|
||||
/// Transition buffer.
|
||||
pub mod transition_buffer;
|
||||
|
||||
pub use environment::*;
|
||||
pub use policy::*;
|
||||
pub use transition_buffer::*;
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) type TestBackend = burn_ndarray::NdArray<f32>;
|
||||
|
||||
#[cfg(test)]
pub(crate) mod tests {
    use crate::{Batchable, Policy, PolicyState, TestBackend};

    use burn_core::record::Record;
    use burn_core::{self as burn};

    /// Stateless policy stub shared by the crate's unit tests.
    ///
    /// * `forward()` turns a [`MockObservation`] into a [`MockActionDistribution`] holding one
    ///   `0.0` per observation element.
    /// * `action()` returns a [`MockAction`] with one entry per observation element (`1` when the
    ///   call is deterministic, `0` otherwise) and one [`MockActionContext`] per element.
    #[derive(Clone)]
    pub(crate) struct MockPolicy {}

    impl MockPolicy {
        pub fn new() -> Self {
            MockPolicy {}
        }
    }

    impl Policy<TestBackend> for MockPolicy {
        type Observation = MockObservation;
        type ActionDistribution = MockActionDistribution;
        type Action = MockAction;
        type ActionContext = MockActionContext;
        type PolicyState = MockPolicyState;

        fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution {
            let per_item: Vec<_> = obs
                .0
                .iter()
                .map(|_| MockActionDistribution(vec![0.]))
                .collect();
            MockActionDistribution::batch(per_item)
        }

        fn action(
            &mut self,
            obs: Self::Observation,
            deterministic: bool,
        ) -> (Self::Action, Vec<Self::ActionContext>) {
            let count = obs.0.len();
            // Deterministic requests are answered with 1, all others with 0.
            let value = i32::from(deterministic);

            let per_item: Vec<_> = (0..count).map(|_| MockAction(vec![value])).collect();
            let contexts = vec![MockActionContext; count];

            (MockAction::batch(per_item), contexts)
        }

        fn update(&mut self, _update: Self::PolicyState) {}

        fn state(&self) -> Self::PolicyState {
            MockPolicyState
        }

        fn load_record(
            self,
            _record: <Self::PolicyState as PolicyState<TestBackend>>::Record,
        ) -> Self {
            self
        }
    }

    /// Mock observation for testing, stored as a vector of `f32`. Implements `batch()` and `unbatch()`.
    #[derive(Clone)]
    pub(crate) struct MockObservation(pub Vec<f32>);

    /// Mock action for testing, stored as a vector of `i32`. Implements `batch()` and `unbatch()`.
    #[derive(Clone)]
    pub(crate) struct MockAction(pub Vec<i32>);

    /// Mock action distribution for testing, stored as a vector of `f32`. Implements `batch()` and `unbatch()`.
    #[derive(Clone)]
    pub(crate) struct MockActionDistribution(Vec<f32>);

    /// Mock action context for testing; carries no data.
    #[derive(Clone)]
    pub(crate) struct MockActionContext;

    /// Mock policy state for testing; a unit struct that has no effect on the policy.
    #[derive(Clone)]
    pub(crate) struct MockPolicyState;

    /// Record produced by [`MockPolicyState`].
    #[derive(Clone, Record)]
    pub(crate) struct MockRecord {
        item: usize,
    }

    impl PolicyState<TestBackend> for MockPolicyState {
        type Record = MockRecord;

        fn into_record(self) -> Self::Record {
            MockRecord { item: 0 }
        }

        fn load_record(&self, _record: Self::Record) -> Self {
            MockPolicyState
        }
    }

    impl Batchable for MockObservation {
        fn batch(items: Vec<Self>) -> Self {
            MockObservation(items.into_iter().flat_map(|m| m.0).collect())
        }

        // The observation is handed back as a single item; it is not split per element.
        fn unbatch(self) -> Vec<Self> {
            vec![self]
        }
    }

    impl Batchable for MockAction {
        fn batch(items: Vec<Self>) -> Self {
            MockAction(items.into_iter().flat_map(|m| m.0).collect())
        }

        fn unbatch(self) -> Vec<Self> {
            self.0.into_iter().map(|a| MockAction(vec![a])).collect()
        }
    }

    impl Batchable for MockActionDistribution {
        fn batch(items: Vec<Self>) -> Self {
            MockActionDistribution(items.into_iter().flat_map(|m| m.0).collect())
        }

        // Every element is replaced by a fresh single-zero distribution.
        fn unbatch(self) -> Vec<Self> {
            self.0
                .iter()
                .map(|_| MockActionDistribution(vec![0.]))
                .collect()
        }
    }
}
|
||||
@@ -0,0 +1,485 @@
|
||||
use std::{
|
||||
sync::{
|
||||
Arc,
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
mpsc::{self, Sender},
|
||||
},
|
||||
thread::spawn,
|
||||
};
|
||||
|
||||
use burn_core::prelude::Backend;
|
||||
|
||||
use crate::{ActionContext, Batchable, Policy, PolicyState};
|
||||
|
||||
#[derive(Clone)]
|
||||
struct PolicyInferenceServer<B: Backend, P: Policy<B>> {
|
||||
// `num_agents` used to make sure autobatching doesn't block the agents if they are less than the autobatch size.
|
||||
num_agents: Arc<AtomicUsize>,
|
||||
max_autobatch_size: usize,
|
||||
inner_policy: P,
|
||||
batch_action: Vec<ActionItem<P::Observation, P::Action, P::ActionContext>>,
|
||||
batch_logits: Vec<ForwardItem<P::Observation, P::ActionDistribution>>,
|
||||
}
|
||||
|
||||
impl<B, P> PolicyInferenceServer<B, P>
|
||||
where
|
||||
B: Backend,
|
||||
P: Policy<B>,
|
||||
P::Observation: Clone + Batchable,
|
||||
P::ActionDistribution: Clone + Batchable,
|
||||
P::Action: Clone + Batchable,
|
||||
P::ActionContext: Clone,
|
||||
{
|
||||
pub fn new(max_autobatch_size: usize, inner_policy: P) -> Self {
|
||||
Self {
|
||||
num_agents: Arc::new(AtomicUsize::new(0)),
|
||||
max_autobatch_size,
|
||||
inner_policy,
|
||||
batch_action: vec![],
|
||||
batch_logits: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_action(&mut self, item: ActionItem<P::Observation, P::Action, P::ActionContext>) {
|
||||
self.batch_action.push(item);
|
||||
if self.len_actions()
|
||||
>= self
|
||||
.num_agents
|
||||
.load(Ordering::Relaxed)
|
||||
.min(self.max_autobatch_size)
|
||||
{
|
||||
self.flush_actions();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_logits(&mut self, item: ForwardItem<P::Observation, P::ActionDistribution>) {
|
||||
self.batch_logits.push(item);
|
||||
if self.len_logits()
|
||||
>= self
|
||||
.num_agents
|
||||
.load(Ordering::Relaxed)
|
||||
.min(self.max_autobatch_size)
|
||||
{
|
||||
self.flush_logits();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn len_actions(&self) -> usize {
|
||||
self.batch_action.len()
|
||||
}
|
||||
|
||||
pub fn len_logits(&self) -> usize {
|
||||
self.batch_logits.len()
|
||||
}
|
||||
|
||||
pub fn flush_actions(&mut self) {
|
||||
if self.len_actions() == 0 {
|
||||
return;
|
||||
}
|
||||
let input: Vec<_> = self
|
||||
.batch_action
|
||||
.iter()
|
||||
.map(|m| m.inference_state.clone())
|
||||
.collect();
|
||||
// Only deterministic if all actions are requested as deterministic.
|
||||
let deterministic = self.batch_action.iter().all(|item| item.deterministic);
|
||||
let (actions, context) = self
|
||||
.inner_policy
|
||||
.action(P::Observation::batch(input), deterministic);
|
||||
let actions: Vec<_> = actions.unbatch();
|
||||
|
||||
for (i, item) in self.batch_action.iter().enumerate() {
|
||||
item.sender
|
||||
.send(ActionContext {
|
||||
context: vec![context[i].clone()],
|
||||
action: actions[i].clone(),
|
||||
})
|
||||
.expect("Autobatcher should be able to send resulting actions.");
|
||||
}
|
||||
self.batch_action.clear();
|
||||
}
|
||||
|
||||
pub fn flush_logits(&mut self) {
|
||||
if self.len_logits() == 0 {
|
||||
return;
|
||||
}
|
||||
let input: Vec<_> = self
|
||||
.batch_logits
|
||||
.iter()
|
||||
.map(|m| m.inference_state.clone())
|
||||
.collect();
|
||||
let output = self.inner_policy.forward(P::Observation::batch(input));
|
||||
let logits: Vec<_> = output.unbatch();
|
||||
for (i, item) in self.batch_logits.iter().enumerate() {
|
||||
item.sender
|
||||
.send(logits[i].clone())
|
||||
.expect("Autobatcher should be able to send resulting probabilities.");
|
||||
}
|
||||
self.batch_logits.clear();
|
||||
}
|
||||
|
||||
pub fn update_policy(&mut self, policy_update: P::PolicyState) {
|
||||
if self.len_actions() > 0 {
|
||||
self.flush_actions();
|
||||
}
|
||||
if self.len_logits() > 0 {
|
||||
self.flush_logits();
|
||||
}
|
||||
self.inner_policy.update(policy_update);
|
||||
}
|
||||
|
||||
pub fn state(&self) -> P::PolicyState {
|
||||
self.inner_policy.state()
|
||||
}
|
||||
|
||||
pub fn increment_agents(&mut self, num: usize) {
|
||||
self.num_agents.fetch_add(num, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub fn decrement_agents(&mut self, num: usize) {
|
||||
self.num_agents.fetch_sub(num, Ordering::Relaxed);
|
||||
if self.len_actions()
|
||||
>= self
|
||||
.num_agents
|
||||
.load(Ordering::Relaxed)
|
||||
.min(self.max_autobatch_size)
|
||||
{
|
||||
self.flush_actions();
|
||||
}
|
||||
if self.len_logits()
|
||||
>= self
|
||||
.num_agents
|
||||
.load(Ordering::Relaxed)
|
||||
.min(self.max_autobatch_size)
|
||||
{
|
||||
self.flush_logits();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum InferenceMessage<B: Backend, P: Policy<B>> {
|
||||
ActionMessage(ActionItem<P::Observation, P::Action, P::ActionContext>),
|
||||
ForwardMessage(ForwardItem<P::Observation, P::ActionDistribution>),
|
||||
PolicyUpdate(P::PolicyState),
|
||||
PolicyRequest(Sender<P::PolicyState>),
|
||||
IncrementAgents(usize),
|
||||
DecrementAgents(usize),
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ActionItem<S, A, C> {
|
||||
sender: Sender<ActionContext<A, Vec<C>>>,
|
||||
inference_state: S,
|
||||
deterministic: bool,
|
||||
}
|
||||
|
||||
/// A queued `forward` request.
#[derive(Clone)]
struct ForwardItem<S, O> {
    /// Channel on which the resulting output is sent back.
    sender: Sender<O>,
    /// Observation to run inference on.
    inference_state: S,
}
|
||||
|
||||
/// An asynchronous policy using an inference server with autobatching.
|
||||
#[derive(Clone)]
|
||||
pub struct AsyncPolicy<B: Backend, P: Policy<B>> {
|
||||
inference_state_sender: Sender<InferenceMessage<B, P>>,
|
||||
}
|
||||
|
||||
impl<B, P> AsyncPolicy<B, P>
|
||||
where
|
||||
B: Backend,
|
||||
P: Policy<B> + Clone + Send + 'static,
|
||||
P::ActionContext: Clone + Send,
|
||||
P::PolicyState: Send,
|
||||
P::Observation: Clone + Send + Batchable,
|
||||
P::ActionDistribution: Clone + Send + Batchable,
|
||||
P::Action: Clone + Send + Batchable,
|
||||
{
|
||||
/// Create the policy.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `autobatch_size` - Number of observations to accumulate before running a pass of inference.
|
||||
/// * `inner_policy` - The policy used to take actions.
|
||||
pub fn new(autobatch_size: usize, inner_policy: P) -> Self {
|
||||
let (sender, receiver) = std::sync::mpsc::channel();
|
||||
let mut autobatcher = PolicyInferenceServer::new(autobatch_size, inner_policy.clone());
|
||||
spawn(move || {
|
||||
loop {
|
||||
match receiver.recv() {
|
||||
Ok(msg) => match msg {
|
||||
InferenceMessage::ActionMessage(item) => autobatcher.push_action(item),
|
||||
InferenceMessage::ForwardMessage(item) => autobatcher.push_logits(item),
|
||||
InferenceMessage::PolicyUpdate(update) => autobatcher.update_policy(update),
|
||||
InferenceMessage::PolicyRequest(sender) => sender
|
||||
.send(autobatcher.state())
|
||||
.expect("Autobatcher should be able to send current policy state."),
|
||||
InferenceMessage::IncrementAgents(num) => autobatcher.increment_agents(num),
|
||||
InferenceMessage::DecrementAgents(num) => autobatcher.decrement_agents(num),
|
||||
},
|
||||
Err(err) => {
|
||||
log::error!("Error in AsyncPolicy : {}", err);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Self {
|
||||
inference_state_sender: sender,
|
||||
}
|
||||
}
|
||||
|
||||
/// Increment the number of agents using the inference server.
|
||||
pub fn increment_agents(&self, num: usize) {
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::IncrementAgents(num))
|
||||
.expect("Can send message to autobatcher.")
|
||||
}
|
||||
|
||||
/// Decrement the number of agents using the inference server.
|
||||
pub fn decrement_agents(&self, num: usize) {
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::DecrementAgents(num))
|
||||
.expect("Can send message to autobatcher.")
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, P> Policy<B> for AsyncPolicy<B, P>
|
||||
where
|
||||
B: Backend,
|
||||
P: Policy<B> + Send + 'static,
|
||||
{
|
||||
type ActionContext = P::ActionContext;
|
||||
type PolicyState = P::PolicyState;
|
||||
|
||||
type Observation = P::Observation;
|
||||
type ActionDistribution = P::ActionDistribution;
|
||||
type Action = P::Action;
|
||||
|
||||
fn forward(&mut self, states: Self::Observation) -> Self::ActionDistribution {
|
||||
let (action_sender, action_receiver) = std::sync::mpsc::channel();
|
||||
let item = ForwardItem {
|
||||
sender: action_sender,
|
||||
inference_state: states,
|
||||
};
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::ForwardMessage(item))
|
||||
.expect("Should be able to send message to inference_server");
|
||||
action_receiver
|
||||
.recv()
|
||||
.expect("AsyncPolicy should receive queued probabilities.")
|
||||
}
|
||||
|
||||
fn action(
|
||||
&mut self,
|
||||
states: Self::Observation,
|
||||
deterministic: bool,
|
||||
) -> (Self::Action, Vec<Self::ActionContext>) {
|
||||
let (action_sender, action_receiver) = std::sync::mpsc::channel();
|
||||
let item = ActionItem {
|
||||
sender: action_sender,
|
||||
inference_state: states,
|
||||
deterministic,
|
||||
};
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::ActionMessage(item))
|
||||
.expect("should be able to send message to inference_server.");
|
||||
let action = action_receiver
|
||||
.recv()
|
||||
.expect("AsyncPolicy should receive queued actions.");
|
||||
(action.action, action.context)
|
||||
}
|
||||
|
||||
fn update(&mut self, update: Self::PolicyState) {
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::PolicyUpdate(update))
|
||||
.expect("AsyncPolicy should be able to send policy state.")
|
||||
}
|
||||
|
||||
fn state(&self) -> Self::PolicyState {
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
self.inference_state_sender
|
||||
.send(InferenceMessage::PolicyRequest(sender))
|
||||
.expect("should be able to send message to inference_server.");
|
||||
receiver
|
||||
.recv()
|
||||
.expect("AsyncPolicy should be able to receive policy state.")
|
||||
}
|
||||
|
||||
fn load_record(self, _record: <Self::PolicyState as PolicyState<B>>::Record) -> Self {
|
||||
// Not needed for now
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::thread::JoinHandle;
    use std::time::Duration;

    use crate::TestBackend;
    use crate::tests::{MockAction, MockObservation, MockPolicy};

    use super::*;

    type TestPolicy = AsyncPolicy<TestBackend, MockPolicy>;

    /// Gives the inference thread time to process whatever has been queued.
    fn settle() {
        std::thread::sleep(Duration::from_millis(10));
    }

    /// Requests an action for a single-element observation from a new thread.
    fn spawn_action(policy: &TestPolicy, deterministic: bool) -> JoinHandle<MockAction> {
        let mut thread_policy = policy.clone();
        spawn(move || {
            let (action, _) = thread_policy.action(MockObservation(vec![0.]), deterministic);
            action
        })
    }

    /// Requests a forward pass for a single-element observation from a new thread.
    fn spawn_forward(policy: &TestPolicy) -> JoinHandle<()> {
        let mut thread_policy = policy.clone();
        spawn(move || {
            thread_policy.forward(MockObservation(vec![0.]));
        })
    }

    fn all_finished<T>(handles: &[JoinHandle<T>]) -> bool {
        handles.iter().all(|handle| handle.is_finished())
    }

    fn none_finished<T>(handles: &[JoinHandle<T>]) -> bool {
        !handles.iter().any(|handle| handle.is_finished())
    }

    /// With an autobatch size of 8 and many agents, requests stay blocked until the eighth
    /// one arrives, and a following request starts a new, blocked batch.
    fn assert_flushes_only_on_full_batch<T>(launch: impl Fn(&TestPolicy) -> JoinHandle<T>) {
        let policy: TestPolicy = AsyncPolicy::new(8, MockPolicy::new());
        policy.increment_agents(1000);

        let mut handles = vec![launch(&policy)];
        settle();
        assert!(none_finished(&handles));

        handles.extend((0..6).map(|_| launch(&policy)));
        settle();
        assert!(none_finished(&handles));

        handles.push(launch(&policy));
        settle();
        assert!(all_finished(&handles));

        let next_batch = vec![launch(&policy)];
        settle();
        assert!(none_finished(&next_batch));
    }

    #[test]
    fn test_multiple_actions_before_flush() {
        assert_flushes_only_on_full_batch(|policy| spawn_action(policy, false));
    }

    #[test]
    fn test_multiple_forward_before_flush() {
        assert_flushes_only_on_full_batch(spawn_forward);
    }

    #[test]
    fn test_async_policy_deterministic_behaviour() {
        let policy: TestPolicy = AsyncPolicy::new(2, MockPolicy::new());
        policy.increment_agents(1000);

        // A batch is only deterministic when every request in it is deterministic.
        for (flags, expected) in [([true, false], 0), ([true, true], 1)] {
            let handles: Vec<_> = flags
                .iter()
                .map(|&deterministic| spawn_action(&policy, deterministic))
                .collect();
            for handle in handles.into_iter().rev() {
                let action = handle.join().unwrap();
                assert_eq!(action.0, vec![expected]);
            }
        }
    }

    #[test]
    fn flush_when_running_agents_smaller_than_autobatch_size() {
        let policy: TestPolicy = AsyncPolicy::new(8, MockPolicy::new());
        policy.increment_agents(3);

        let mut handles = vec![spawn_action(&policy, false), spawn_action(&policy, false)];
        settle();
        assert!(none_finished(&handles));

        handles.push(spawn_action(&policy, false));
        settle();
        assert!(all_finished(&handles));

        let handles = vec![spawn_action(&policy, false), spawn_action(&policy, false)];
        settle();
        assert!(none_finished(&handles));

        // Dropping to two agents lowers the threshold to the two queued requests.
        policy.decrement_agents(1);
        settle();
        assert!(all_finished(&handles));
    }
}
|
||||
@@ -0,0 +1,108 @@
|
||||
use derive_new::new;
|
||||
|
||||
use burn_core::{prelude::*, record::Record, tensor::backend::AutodiffBackend};
|
||||
|
||||
use crate::TransitionBatch;
|
||||
|
||||
/// An action along with additional context about the decision.
|
||||
#[derive(Clone, new)]
|
||||
pub struct ActionContext<A, C> {
|
||||
/// The context.
|
||||
pub context: C,
|
||||
/// The action.
|
||||
pub action: A,
|
||||
}
|
||||
|
||||
/// The state of a policy.
|
||||
pub trait PolicyState<B: Backend> {
|
||||
/// The type of the record.
|
||||
type Record: Record<B>;
|
||||
|
||||
/// Convert the state to a record.
|
||||
fn into_record(self) -> Self::Record;
|
||||
/// Load the state from a record.
|
||||
fn load_record(&self, record: Self::Record) -> Self;
|
||||
}
|
||||
|
||||
/// Trait for a RL policy.
|
||||
pub trait Policy<B: Backend>: Clone {
|
||||
/// The observation given as input to the policy.
|
||||
type Observation;
|
||||
/// The action distribution parameters defining how the action will be sampled.
|
||||
type ActionDistribution;
|
||||
/// The action.
|
||||
type Action;
|
||||
|
||||
/// Additional context on the policy's decision.
|
||||
type ActionContext;
|
||||
/// The current parameterization of the policy.
|
||||
type PolicyState: PolicyState<B>;
|
||||
|
||||
/// Produces the action distribution from a batch of observations.
|
||||
fn forward(&mut self, obs: Self::Observation) -> Self::ActionDistribution;
|
||||
/// Gives the action from a batch of observations.
|
||||
fn action(
|
||||
&mut self,
|
||||
obs: Self::Observation,
|
||||
deterministic: bool,
|
||||
) -> (Self::Action, Vec<Self::ActionContext>);
|
||||
|
||||
/// Update the policy's parameters.
|
||||
fn update(&mut self, update: Self::PolicyState);
|
||||
/// Returns the current parameterization.
|
||||
fn state(&self) -> Self::PolicyState;
|
||||
|
||||
/// Loads the policy parameters from a record.
|
||||
fn load_record(self, record: <Self::PolicyState as PolicyState<B>>::Record) -> Self;
|
||||
}
|
||||
|
||||
/// A type whose values can be merged into a batch and split back apart.
pub trait Batchable: Sized {
    /// Merges a list of items into a single batch.
    fn batch(value: Vec<Self>) -> Self;
    /// Splits a batch into a list of items.
    fn unbatch(self) -> Vec<Self>;
}
|
||||
|
||||
/// Output of a training step.
pub struct RLTrainOutput<TO, P> {
    /// Policy part of the output.
    pub policy: P,
    /// Item part of the output.
    pub item: TO,
}
|
||||
|
||||
/// Batched transitions for a PolicyLearner.
|
||||
pub type LearnerTransitionBatch<B, P> =
|
||||
TransitionBatch<B, <P as Policy<B>>::Observation, <P as Policy<B>>::Action>;
|
||||
|
||||
/// Learner for a policy.
|
||||
pub trait PolicyLearner<B>
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
<Self::InnerPolicy as Policy<B>>::Observation: Clone + Batchable,
|
||||
<Self::InnerPolicy as Policy<B>>::ActionDistribution: Clone + Batchable,
|
||||
<Self::InnerPolicy as Policy<B>>::Action: Clone + Batchable,
|
||||
{
|
||||
/// Additional context of a training step.
|
||||
type TrainContext;
|
||||
/// The policy to train.
|
||||
type InnerPolicy: Policy<B>;
|
||||
/// The record of the learner.
|
||||
type Record: Record<B>;
|
||||
|
||||
/// Execute a training step on the policy.
|
||||
fn train(
|
||||
&mut self,
|
||||
input: LearnerTransitionBatch<B, Self::InnerPolicy>,
|
||||
) -> RLTrainOutput<Self::TrainContext, <Self::InnerPolicy as Policy<B>>::PolicyState>;
|
||||
/// Returns the learner's current policy.
|
||||
fn policy(&self) -> Self::InnerPolicy;
|
||||
/// Update the learner's policy.
|
||||
fn update_policy(&mut self, update: Self::InnerPolicy);
|
||||
|
||||
/// Convert the learner's state into a record.
|
||||
fn record(&self) -> Self::Record;
|
||||
/// Load the learner's state from a record.
|
||||
fn load_record(self, record: Self::Record) -> Self;
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod async_policy;
|
||||
mod base;
|
||||
|
||||
pub use async_policy::*;
|
||||
pub use base::*;
|
||||
@@ -0,0 +1,244 @@
|
||||
use burn_core::{Tensor, prelude::Backend, tensor::Distribution};
|
||||
use derive_new::new;
|
||||
|
||||
use super::SliceAccess;
|
||||
|
||||
/// A state transition in an environment.
|
||||
#[derive(Clone, new)]
|
||||
pub struct Transition<B: Backend, S, A> {
|
||||
/// The initial state.
|
||||
pub state: S,
|
||||
/// The state after the step was taken.
|
||||
pub next_state: S,
|
||||
/// The action taken in the step.
|
||||
pub action: A,
|
||||
/// The reward.
|
||||
pub reward: Tensor<B, 1>,
|
||||
/// If the environment has reached a terminal state.
|
||||
pub done: Tensor<B, 1>,
|
||||
}
|
||||
|
||||
/// A batch of transitions.
|
||||
pub struct TransitionBatch<B: Backend, SB, AB> {
|
||||
/// Batched initial states.
|
||||
pub states: SB,
|
||||
/// Batched resulting states.
|
||||
pub next_states: SB,
|
||||
/// Batched actions.
|
||||
pub actions: AB,
|
||||
/// Batched rewards.
|
||||
pub rewards: Tensor<B, 2>,
|
||||
/// Batched flags for terminal states.
|
||||
pub dones: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
/// A tensor-backed circular buffer for transitions.
|
||||
///
|
||||
/// Uses [`SliceAccess`] to store state and action batches in contiguous
|
||||
/// tensor storage, enabling efficient random sampling via `select`.
|
||||
/// The buffer lazily initializes its storage on the first `push` call.
|
||||
pub struct TransitionBuffer<B: Backend, SB: SliceAccess<B>, AB: SliceAccess<B>> {
|
||||
states: Option<SB>,
|
||||
next_states: Option<SB>,
|
||||
actions: Option<AB>,
|
||||
rewards: Option<Tensor<B, 2>>,
|
||||
dones: Option<Tensor<B, 2>>,
|
||||
capacity: usize,
|
||||
write_head: usize,
|
||||
len: usize,
|
||||
device: B::Device,
|
||||
}
|
||||
|
||||
impl<B: Backend, SB: SliceAccess<B>, AB: SliceAccess<B>> TransitionBuffer<B, SB, AB> {
|
||||
/// Creates a new buffer. Storage is lazily allocated on the first `push`.
|
||||
pub fn new(capacity: usize, device: &B::Device) -> Self {
|
||||
Self {
|
||||
states: None,
|
||||
next_states: None,
|
||||
actions: None,
|
||||
rewards: None,
|
||||
dones: None,
|
||||
capacity,
|
||||
write_head: 0,
|
||||
len: 0,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn ensure_init(&mut self, state: &SB, next_state: &SB, action: &AB) {
|
||||
if self.states.is_none() {
|
||||
self.states = Some(SB::zeros_like(state, self.capacity, &self.device));
|
||||
self.next_states = Some(SB::zeros_like(next_state, self.capacity, &self.device));
|
||||
self.actions = Some(AB::zeros_like(action, self.capacity, &self.device));
|
||||
self.rewards = Some(Tensor::zeros([self.capacity, 1], &self.device));
|
||||
self.dones = Some(Tensor::zeros([self.capacity, 1], &self.device));
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a transition, overwriting the oldest if full.
|
||||
pub fn push(&mut self, state: SB, next_state: SB, action: AB, reward: f32, done: bool) {
|
||||
self.ensure_init(&state, &next_state, &action);
|
||||
|
||||
let idx = self.write_head % self.capacity;
|
||||
|
||||
self.states
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.slice_assign_inplace(idx, state);
|
||||
self.next_states
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.slice_assign_inplace(idx, next_state);
|
||||
self.actions
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.slice_assign_inplace(idx, action);
|
||||
|
||||
let reward = Tensor::from_data([[reward]], &self.device);
|
||||
self.rewards
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.inplace(|r| r.slice_assign(idx..idx + 1, reward));
|
||||
|
||||
let done_val = if done { 1.0f32 } else { 0.0 };
|
||||
let done = Tensor::from_data([[done_val]], &self.device);
|
||||
self.dones
|
||||
.as_mut()
|
||||
.unwrap()
|
||||
.inplace(|d| d.slice_assign(idx..idx + 1, done));
|
||||
|
||||
self.write_head += 1;
|
||||
if self.len < self.capacity {
|
||||
self.len += 1;
|
||||
}
|
||||
}
|
||||
|
||||
/// Sample a random batch of transitions.
|
||||
pub fn sample(&self, batch_size: usize) -> TransitionBatch<B, SB, AB> {
|
||||
assert!(batch_size <= self.len, "batch_size exceeds buffer length");
|
||||
|
||||
let indices = Tensor::<B, 1>::random(
|
||||
[batch_size],
|
||||
Distribution::Uniform(0.0, self.len as f64),
|
||||
&self.device,
|
||||
)
|
||||
.int();
|
||||
|
||||
TransitionBatch {
|
||||
states: self
|
||||
.states
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone()
|
||||
.select(0, indices.clone()),
|
||||
next_states: self
|
||||
.next_states
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone()
|
||||
.select(0, indices.clone()),
|
||||
actions: self
|
||||
.actions
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone()
|
||||
.select(0, indices.clone()),
|
||||
rewards: self
|
||||
.rewards
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.clone()
|
||||
.select(0, indices.clone()),
|
||||
dones: self.dones.as_ref().unwrap().clone().select(0, indices),
|
||||
}
|
||||
}
|
||||
|
||||
/// Current number of stored transitions.
|
||||
pub fn len(&self) -> usize {
|
||||
self.len
|
||||
}
|
||||
|
||||
/// Whether the buffer is empty.
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.len == 0
|
||||
}
|
||||
|
||||
/// Buffer capacity.
|
||||
pub fn capacity(&self) -> usize {
|
||||
self.capacity
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    type TB = Tensor<TestBackend, 2>;

    /// Pushes one non-terminal transition whose components are all derived from `val`:
    /// a two-feature state, a successor state offset by one, a one-feature action
    /// and `val` itself as the reward.
    fn push_transition(
        buffer: &mut TransitionBuffer<TestBackend, TB, TB>,
        device: &<TestBackend as Backend>::Device,
        val: f32,
    ) {
        let successor = val + 1.0;
        buffer.push(
            TB::from_data([[val, val]], device),
            TB::from_data([[successor, successor]], device),
            TB::from_data([[val]], device),
            val,
            false,
        );
    }

    /// Pushes `count` transitions with values `0.0, 1.0, ..`.
    fn fill(
        buffer: &mut TransitionBuffer<TestBackend, TB, TB>,
        device: &<TestBackend as Backend>::Device,
        count: usize,
    ) {
        for step in 0..count {
            push_transition(buffer, device, step as f32);
        }
    }

    #[test]
    fn push_increment_len() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(5, &device);

        // A fresh buffer reports no content.
        assert_eq!(buffer.len(), 0);
        assert!(buffer.is_empty());

        // Every push below capacity grows the length by exactly one.
        for (already_pushed, val) in [1.0, 2.0].into_iter().enumerate() {
            push_transition(&mut buffer, &device, val);
            assert_eq!(buffer.len(), already_pushed + 1);
        }
    }

    #[test]
    fn push_overwrites_when_full() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(3, &device);

        // Five pushes into three slots: the length must saturate at capacity.
        fill(&mut buffer, &device, 5);

        assert_eq!(buffer.len(), 3);
        assert_eq!(buffer.capacity(), 3);
    }

    #[test]
    fn sample_returns_correct_shapes() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(10, &device);
        fill(&mut buffer, &device, 5);

        let batch = buffer.sample(3);

        // The leading dimension is the batch size, the trailing one the feature width.
        assert_eq!(batch.states.dims(), [3, 2]);
        assert_eq!(batch.next_states.dims(), [3, 2]);
        assert_eq!(batch.actions.dims(), [3, 1]);
        assert_eq!(batch.rewards.dims(), [3, 1]);
        assert_eq!(batch.dones.dims(), [3, 1]);
    }

    #[test]
    #[should_panic(expected = "batch_size exceeds buffer length")]
    fn sample_panics_when_batch_too_large() {
        let device = Default::default();
        let mut buffer = TransitionBuffer::<TestBackend, TB, TB>::new(5, &device);
        push_transition(&mut buffer, &device, 1.0);

        // Only one transition is stored, so asking for five must panic.
        let _batch = buffer.sample(5);
    }
}
|
||||
@@ -0,0 +1,5 @@
|
||||
mod base;
|
||||
mod slice_access;
|
||||
|
||||
pub use base::*;
|
||||
pub use slice_access::*;
|
||||
@@ -0,0 +1,36 @@
|
||||
use burn_core::prelude::*;
|
||||
|
||||
/// Trait for types that support tensor-like slice operations,
|
||||
/// enabling storage in a [`TransitionBuffer`](super::TransitionBuffer).
|
||||
///
|
||||
/// Implement this trait for any type that wraps tensors and can be stored
|
||||
/// in a replay buffer. The buffer uses these operations for:
|
||||
/// - Pre-allocating storage (`zeros_like`)
|
||||
/// - Writing transitions (`slice_assign_inplace`)
|
||||
/// - Sampling batches (`select`)
|
||||
pub trait SliceAccess<B: Backend>: Clone + Sized {
|
||||
/// Create zeroed storage matching the shape of `sample` but with `capacity` rows
|
||||
/// along the first dimension.
|
||||
fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self;
|
||||
|
||||
/// Select rows at the given indices along the specified dimension.
|
||||
fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self;
|
||||
|
||||
/// Assign `value` at row `index` along the first dimension, in place.
|
||||
fn slice_assign_inplace(&mut self, index: usize, value: Self);
|
||||
}
|
||||
|
||||
impl<B: Backend> SliceAccess<B> for Tensor<B, 2> {
|
||||
fn zeros_like(sample: &Self, capacity: usize, device: &B::Device) -> Self {
|
||||
let feature_dim = sample.dims()[1];
|
||||
Tensor::zeros([capacity, feature_dim], device)
|
||||
}
|
||||
|
||||
fn select(self, dim: usize, indices: Tensor<B, 1, Int>) -> Self {
|
||||
Tensor::select(self, dim, indices)
|
||||
}
|
||||
|
||||
fn slice_assign_inplace(&mut self, index: usize, value: Self) {
|
||||
self.inplace(|t| t.slice_assign(index..index + 1, value));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user