Update to burn v0.14.0 and switch to .mpk model file
This commit is contained in:
@@ -18,14 +18,14 @@ pub fn load_group_norm<B: Backend>(
|
||||
let n_channel = load_usize::<B>("n_channel", path, device)?.into();
|
||||
let eps = load_f32::<B>("eps", path, device)?.into();
|
||||
|
||||
let gamma = load_tensor::<B, 1>("weight", path, device)
|
||||
let gamma = Param::from_tensor(load_tensor::<B, 1>("weight", path, device)
|
||||
.ok()
|
||||
.unwrap_or_else(|| Tensor::ones_device([n_channel], device))
|
||||
.into();
|
||||
let beta = load_tensor::<B, 1>("bias", path, device)
|
||||
.unwrap_or_else(|| Tensor::ones([n_channel], device))
|
||||
);
|
||||
let beta = Param::from_tensor(load_tensor::<B, 1>("bias", path, device)
|
||||
.ok()
|
||||
.unwrap_or_else(|| Tensor::zeros_device([n_channel], device))
|
||||
.into();
|
||||
.unwrap_or_else(|| Tensor::zeros([n_channel], device))
|
||||
);
|
||||
|
||||
Ok(GroupNorm {
|
||||
n_group,
|
||||
|
||||
Reference in New Issue
Block a user