mirror of
https://gitea.hainer-ernst.de/rasmus/burn-stablediffusion-vibecode.git
synced 2026-06-11 02:09:21 +00:00
Update to burn v0.14.0 and switch to .mpk model file
This commit is contained in:
@@ -15,7 +15,7 @@ pub struct GroupNormConfig {
|
||||
}
|
||||
|
||||
impl GroupNormConfig {
|
||||
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
|
||||
assert!(
|
||||
self.n_channel % self.n_group == 0,
|
||||
"The number of channels {} must be divisible by the number of groups {}",
|
||||
@@ -25,8 +25,8 @@ impl GroupNormConfig {
|
||||
|
||||
let n_per_group = self.n_channel / self.n_group;
|
||||
|
||||
let gamma = Tensor::ones([self.n_channel]).into();
|
||||
let beta = Tensor::zeros([self.n_channel]).into();
|
||||
let gamma = Param::from_tensor(Tensor::ones([self.n_channel], device));
|
||||
let beta = Param::from_tensor(Tensor::zeros([self.n_channel], device));
|
||||
|
||||
let eps = self.eps;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user