#########################################################
# Copyright (C) 2023 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Ljubomir Papuga
#########################################################
"""
Helper functions using PyTorch API
"""
import numpy as np
import torch
from typing import Union
from afe.ir.utils import transpose_tensor_according_to_layout_strings
# Target layout string used when reordering 4-D NHWC numpy arrays before they
# are wrapped as torch tensors (see convert_numpy_to_torch).
_torch_data_layout: str = "NCHW"
[docs]
def convert_numpy_to_torch(x: np.ndarray, layout: str = "NHWC") -> torch.Tensor:
    """
    Wrap a numpy array as a torch.Tensor, reordering 4-D NHWC data first.

    Only a 4-D array whose ``layout`` is exactly "NHWC" is reordered, using
    ``transpose_tensor_according_to_layout_strings`` into the module's
    ``_torch_data_layout``. Every other array is handed to
    ``torch.from_numpy`` unchanged.

    :param x: Input array.
    :param layout: Layout string describing ``x``.
    :return: Tensor produced by ``torch.from_numpy``. It shares memory with
        the array passed to ``torch.from_numpy``.
    """
    must_reorder = x.ndim == 4 and layout == "NHWC"
    if not must_reorder:
        return torch.from_numpy(x)
    reordered = transpose_tensor_according_to_layout_strings(x, layout, _torch_data_layout)
    return torch.from_numpy(reordered)
[docs]
def torch_tensor_to_scalar(x: torch.Tensor) -> Union[int, float]:
    """
    Extract the single value held by a one-element 1-D tensor.

    :param x: Tensor expected to have shape ``(1,)``.
    :return: The element as a Python number, via ``Tensor.item()``.
    :raises AssertionError: If ``x`` is not 1-D with exactly one element.
    """
    holds_one_value = x.ndim == 1 and x.shape[0] == 1
    assert holds_one_value, f"Cannot convert torch.Tensor with shape {x.shape} to scalar."
    return x.item()
[docs]
def numpy_tensor_to_scalar(x: np.ndarray) -> Union[int, float]:
    """
    Return the only element of a one-element 1-D numpy array as a Python scalar.

    :param x: Array of shape ``(1,)`` with a floating-point or integer dtype.
    :return: ``float`` for any floating-point dtype, ``int`` for any integer
        dtype (signed or unsigned, any width).
    :raises AssertionError: If ``x`` is not 1-D with exactly one element.
    :raises RuntimeError: If the dtype is neither floating-point nor integer
        (e.g. bool or complex).
    """
    # Raised explicitly rather than via ``assert`` so the check survives
    # ``python -O``; without it, x[0] would silently return the first element
    # of a longer array. AssertionError is kept for caller compatibility.
    if not (x.ndim == 1 and x.shape[0] == 1):
        raise AssertionError(f"Cannot convert numpy.ndarray with shape {x.shape} to scalar.")
    # Returning the only element in the array.
    element = x[0]
    # issubdtype matches every width (float16..float64, int8..uint64), not
    # just the 32/64-bit types; np.bool_ is not an np.integer subtype.
    if np.issubdtype(x.dtype, np.floating):
        return float(element)
    if np.issubdtype(x.dtype, np.integer):
        return int(element)
    raise RuntimeError(f"Conversion from numpy to scalar is not supported for dtype {x.dtype}")