Update to burn v0.14.0 and switch to .mpk model file

This commit is contained in:
Hermes
2024-10-05 14:19:49 -04:00
parent 9e4d7bd310
commit 893fb0950d
19 changed files with 366 additions and 311 deletions

View File

@@ -15,7 +15,7 @@ pub struct GroupNormConfig {
}
impl GroupNormConfig {
pub fn init<B: Backend>(&self) -> GroupNorm<B> {
pub fn init<B: Backend>(&self, device: &B::Device) -> GroupNorm<B> {
assert!(
self.n_channel % self.n_group == 0,
"The number of channels {} must be divisible by the number of groups {}",
@@ -25,8 +25,8 @@ impl GroupNormConfig {
let n_per_group = self.n_channel / self.n_group;
let gamma = Tensor::ones([self.n_channel]).into();
let beta = Tensor::zeros([self.n_channel]).into();
let gamma = Param::from_tensor(Tensor::ones([self.n_channel], device));
let beta = Param::from_tensor(Tensor::zeros([self.n_channel], device));
let eps = self.eps;