mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Replace helper functions with native burn functions
This commit is contained in:
@@ -1,35 +1,33 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
nn,
|
||||
tensor::{
|
||||
activation::{sigmoid, softmax},
|
||||
backend::Backend,
|
||||
activation::{softmax, sigmoid},
|
||||
module::embedding,
|
||||
Tensor,
|
||||
Distribution,
|
||||
Int,
|
||||
module::embedding,
|
||||
Distribution, Int, Tensor,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::model::attention::{qkv_attention, attn_decoder_mask};
|
||||
|
||||
use crate::model::attention::{attn_decoder_mask, qkv_attention};
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct CLIPConfig {
|
||||
n_vocab: usize,
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
n_ctx: usize,
|
||||
n_layer: usize,
|
||||
n_vocab: usize,
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
n_ctx: usize,
|
||||
n_layer: usize,
|
||||
}
|
||||
|
||||
impl CLIPConfig {
|
||||
pub fn init<B: Backend>(&self) -> CLIP<B> {
|
||||
let token_embedding = nn::EmbeddingConfig::new(self.n_vocab, self.n_state).init();
|
||||
let position_embedding = Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
|
||||
let position_embedding =
|
||||
Tensor::random([self.n_ctx, self.n_state], Distribution::Normal(0.0, 1.0)).into();
|
||||
let blocks = (0..self.n_layer)
|
||||
.into_iter()
|
||||
.map(|_| ResidualDecoderAttentionBlockConfig::new(self.n_state, self.n_head).init())
|
||||
@@ -37,33 +35,35 @@ impl CLIPConfig {
|
||||
let layer_norm = nn::LayerNormConfig::new(self.n_state).init();
|
||||
|
||||
CLIP {
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
blocks,
|
||||
layer_norm,
|
||||
token_embedding,
|
||||
position_embedding,
|
||||
blocks,
|
||||
layer_norm,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct CLIP<B: Backend> {
|
||||
token_embedding: nn::Embedding<B>,
|
||||
position_embedding: Param<Tensor<B, 2>>,
|
||||
blocks: Vec<ResidualDecoderAttentionBlock<B>>,
|
||||
layer_norm: nn::LayerNorm<B>,
|
||||
token_embedding: nn::Embedding<B>,
|
||||
position_embedding: Param<Tensor<B, 2>>,
|
||||
blocks: Vec<ResidualDecoderAttentionBlock<B>>,
|
||||
layer_norm: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> CLIP<B> {
|
||||
pub fn forward(&self, x: Tensor<B, 2, Int>) -> Tensor<B, 3> {
|
||||
let [n_batch, seq_len] = x.dims();
|
||||
|
||||
|
||||
let mask = attn_decoder_mask(seq_len, &x.device());
|
||||
|
||||
let embedded = self.token_embedding.forward(x)
|
||||
+ self.position_embedding.val().slice([0..seq_len]).unsqueeze();
|
||||
|
||||
let embedded = self.token_embedding.forward(x)
|
||||
+ self
|
||||
.position_embedding
|
||||
.val()
|
||||
.slice([0..seq_len])
|
||||
.unsqueeze();
|
||||
|
||||
let mut x = embedded;
|
||||
for block in &self.blocks {
|
||||
x = block.forward(x, mask.clone());
|
||||
@@ -73,37 +73,35 @@ impl<B: Backend> CLIP<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct ResidualDecoderAttentionBlockConfig {
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl ResidualDecoderAttentionBlockConfig {
|
||||
pub fn init<B: Backend>(&self) -> ResidualDecoderAttentionBlock<B> {
|
||||
let attn = MultiHeadSelfAttentionConfig::new(self.n_state, self.n_head).init();
|
||||
let attn_ln = nn::LayerNormConfig::new(self.n_state).init();
|
||||
|
||||
|
||||
let mlp = MLPConfig::new(self.n_state, 4 * self.n_state).init();
|
||||
let mlp_ln = nn::LayerNormConfig::new(self.n_state).init();
|
||||
|
||||
ResidualDecoderAttentionBlock {
|
||||
attn,
|
||||
attn_ln,
|
||||
mlp,
|
||||
mlp_ln,
|
||||
attn,
|
||||
attn_ln,
|
||||
mlp,
|
||||
mlp_ln,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct ResidualDecoderAttentionBlock<B: Backend> {
|
||||
attn: MultiHeadSelfAttention<B>,
|
||||
attn_ln: nn::LayerNorm<B>,
|
||||
mlp: MLP<B>,
|
||||
mlp_ln: nn::LayerNorm<B>,
|
||||
attn: MultiHeadSelfAttention<B>,
|
||||
attn_ln: nn::LayerNorm<B>,
|
||||
mlp: MLP<B>,
|
||||
mlp_ln: nn::LayerNorm<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
||||
@@ -117,12 +115,17 @@ impl<B: Backend> ResidualDecoderAttentionBlock<B> {
|
||||
#[derive(Config)]
|
||||
pub struct MultiHeadSelfAttentionConfig {
|
||||
n_state: usize,
|
||||
n_head: usize,
|
||||
n_head: usize,
|
||||
}
|
||||
|
||||
impl MultiHeadSelfAttentionConfig {
|
||||
fn init<B: Backend>(&self) -> MultiHeadSelfAttention<B> {
|
||||
assert!(self.n_state % self.n_head == 0, "State size {} must be a multiple of head size {}", self.n_state, self.n_head);
|
||||
assert!(
|
||||
self.n_state % self.n_head == 0,
|
||||
"State size {} must be a multiple of head size {}",
|
||||
self.n_state,
|
||||
self.n_head
|
||||
);
|
||||
|
||||
let n_head = self.n_head;
|
||||
let query = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
@@ -130,23 +133,23 @@ impl MultiHeadSelfAttentionConfig {
|
||||
let value = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
let out = nn::LinearConfig::new(self.n_state, self.n_state).init();
|
||||
|
||||
MultiHeadSelfAttention {
|
||||
n_head,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out
|
||||
MultiHeadSelfAttention {
|
||||
n_head,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MultiHeadSelfAttention<B: Backend> {
|
||||
n_head: usize,
|
||||
query: nn::Linear<B>,
|
||||
key: nn::Linear<B>,
|
||||
value: nn::Linear<B>,
|
||||
out: nn::Linear<B>,
|
||||
n_head: usize,
|
||||
query: nn::Linear<B>,
|
||||
key: nn::Linear<B>,
|
||||
value: nn::Linear<B>,
|
||||
out: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MultiHeadSelfAttention<B> {
|
||||
@@ -161,17 +164,10 @@ impl<B: Backend> MultiHeadSelfAttention<B> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
#[derive(Config, Debug)]
|
||||
pub struct MLPConfig {
|
||||
input_size: usize,
|
||||
hidden_size: usize,
|
||||
input_size: usize,
|
||||
hidden_size: usize,
|
||||
}
|
||||
|
||||
impl MLPConfig {
|
||||
@@ -180,19 +176,15 @@ impl MLPConfig {
|
||||
let gelu = QuickGELU::new();
|
||||
let fc2 = nn::LinearConfig::new(self.hidden_size, self.input_size).init();
|
||||
|
||||
MLP {
|
||||
fc1,
|
||||
gelu,
|
||||
fc2,
|
||||
}
|
||||
MLP { fc1, gelu, fc2 }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct MLP<B: Backend> {
|
||||
fc1: nn::Linear<B>,
|
||||
gelu: QuickGELU,
|
||||
fc2: nn::Linear<B>,
|
||||
fc1: nn::Linear<B>,
|
||||
gelu: QuickGELU,
|
||||
fc2: nn::Linear<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> MLP<B> {
|
||||
@@ -217,4 +209,3 @@ impl QuickGELU {
|
||||
x.clone() * sigmoid(x * 1.702)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user