mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Replace helper functions with native burn functions
This commit is contained in:
@@ -1,25 +1,27 @@
|
||||
pub mod load;
|
||||
|
||||
use burn::{
|
||||
config::Config,
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
tensor::{
|
||||
backend::Backend,
|
||||
Tensor,
|
||||
},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct GroupNormConfig {
|
||||
n_group: usize,
|
||||
n_channel: usize,
|
||||
n_group: usize,
|
||||
n_channel: usize,
|
||||
#[config(default = 1e-5)]
|
||||
eps: f64,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl GroupNormConfig {
|
||||
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
|
||||
assert!(self.n_channel % self.n_group == 0, "The number of channels {} must be divisible by the number of groups {}", self.n_channel, self.n_group);
|
||||
assert!(
|
||||
self.n_channel % self.n_group == 0,
|
||||
"The number of channels {} must be divisible by the number of groups {}",
|
||||
self.n_channel,
|
||||
self.n_group
|
||||
);
|
||||
|
||||
let n_per_group = self.n_channel / self.n_group;
|
||||
|
||||
@@ -29,22 +31,22 @@ impl GroupNormConfig {
|
||||
let eps = self.eps;
|
||||
|
||||
GroupNorm {
|
||||
n_group: self.n_group,
|
||||
n_channel: self.n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
n_group: self.n_group,
|
||||
n_channel: self.n_channel,
|
||||
gamma,
|
||||
beta,
|
||||
eps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct GroupNorm<B: Backend> {
|
||||
n_group: usize,
|
||||
n_channel: usize,
|
||||
gamma: Param<Tensor<B, 1>>,
|
||||
beta: Param<Tensor<B, 1>>,
|
||||
eps: f64,
|
||||
n_group: usize,
|
||||
n_channel: usize,
|
||||
gamma: Param<Tensor<B, 1>>,
|
||||
beta: Param<Tensor<B, 1>>,
|
||||
eps: f64,
|
||||
}
|
||||
|
||||
impl<B: Backend> GroupNorm<B> {
|
||||
@@ -56,10 +58,17 @@ impl<B: Backend> GroupNorm<B> {
|
||||
let mut affine_shape = [1; D];
|
||||
affine_shape[1] = self.n_channel;
|
||||
|
||||
layernorm( x.reshape([n_batch, self.n_group, num_elements / (n_batch * self.n_group) ]), self.eps )
|
||||
.reshape(shape)
|
||||
.mul(self.gamma.val().reshape(affine_shape))
|
||||
.add(self.beta.val().reshape(affine_shape))
|
||||
layernorm(
|
||||
x.reshape([
|
||||
n_batch,
|
||||
self.n_group,
|
||||
num_elements / (n_batch * self.n_group),
|
||||
]),
|
||||
self.eps,
|
||||
)
|
||||
.reshape(shape)
|
||||
.mul(self.gamma.val().reshape(affine_shape))
|
||||
.add(self.beta.val().reshape(affine_shape))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -68,5 +77,6 @@ pub fn layernorm<B: Backend, const D: usize>(x: Tensor<B, D>, eps: f64) -> Tenso
|
||||
//x.sub(mean).div(var.sqrt().add_scalar(eps))
|
||||
|
||||
let u = x.clone() - x.mean_dim(D - 1);
|
||||
u.clone().div( (u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt() )
|
||||
}
|
||||
u.clone()
|
||||
.div((u.clone() * u).mean_dim(D - 1).add_scalar(eps).sqrt())
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user