Update to burn v0.14.0 and switch to .mpk model file

This commit is contained in:
Hermes
2024-10-05 14:19:49 -04:00
parent 9e4d7bd310
commit 893fb0950d
19 changed files with 366 additions and 311 deletions

View File

@@ -18,14 +18,14 @@ pub fn load_group_norm<B: Backend>(
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
let eps = load_f32::<B>("eps", path, device)?.into();
let gamma = load_tensor::<B, 1>("weight", path, device)
let gamma = Param::from_tensor(load_tensor::<B, 1>("weight", path, device)
.ok()
.unwrap_or_else(|| Tensor::ones_device([n_channel], device))
.into();
let beta = load_tensor::<B, 1>("bias", path, device)
.unwrap_or_else(|| Tensor::ones([n_channel], device))
);
let beta = Param::from_tensor(load_tensor::<B, 1>("bias", path, device)
.ok()
.unwrap_or_else(|| Tensor::zeros_device([n_channel], device))
.into();
.unwrap_or_else(|| Tensor::zeros([n_channel], device))
);
Ok(GroupNorm {
n_group,