Source code for afe.ir.torch_utils

#########################################################
# Copyright (C) 2023 SiMa Technologies, Inc.
#
# This material is SiMa proprietary and confidential.
#
# This material may not be copied or distributed without
# the express prior written permission of SiMa.
#
# All rights reserved.
#########################################################
# Code owner: Ljubomir Papuga
#########################################################
"""
Helper functions using PyTorch API
"""

import numpy as np
import torch
from typing import Union

from afe.ir.utils import transpose_tensor_according_to_layout_strings

_torch_data_layout: str = "NCHW"


def convert_numpy_to_torch(x: np.ndarray, layout: str = "NHWC") -> torch.Tensor:
    """
    Convert a numpy array to a torch.Tensor, reordering 4-D "NHWC" data to torch's layout.

    Only a 4-D array whose ``layout`` is exactly "NHWC" is transposed, into the
    module-level ``_torch_data_layout`` ("NCHW"). Every other array is handed to
    ``torch.from_numpy`` as-is.

    :param x: Input numpy array.
    :param layout: Layout string describing ``x``. Defaults to "NHWC".
    :return: Tensor produced by ``torch.from_numpy``. That function shares memory
        with the array it receives: ``x`` itself when no transpose is done.
        NOTE(review): whether the transposed array is a view of ``x`` depends on
        ``transpose_tensor_according_to_layout_strings`` — confirm.
    """
    needs_transpose = x.ndim == 4 and layout == "NHWC"
    array = (
        transpose_tensor_according_to_layout_strings(x, layout, _torch_data_layout)
        if needs_transpose
        else x
    )
    return torch.from_numpy(array)


def torch_tensor_to_scalar(x: torch.Tensor) -> Union[int, float]:
    """
    Return the single value stored in a 1-D, one-element torch.Tensor as a Python number.

    :param x: Tensor whose shape must be exactly (1,).
    :return: The element, obtained with ``Tensor.item()``.
    :raises AssertionError: If ``x`` does not have shape (1,). The check is an
        ``assert`` statement, so it is skipped when Python runs with ``-O``.
    """
    assert tuple(x.shape) == (1,), f"Cannot convert torch.Tensor with shape {x.shape} to scalar."
    return x.item()


def numpy_tensor_to_scalar(x: np.ndarray) -> Union[int, float]:
    """
    Return the single value stored in a 1-D, one-element numpy array as a Python number.

    Any floating-point dtype (float16 through float64 and wider) is converted with
    ``float``. Any integer dtype, signed or unsigned, is converted with ``int``.

    :param x: Array whose shape must be exactly (1,).
    :return: The element as a Python ``float`` or ``int``.
    :raises AssertionError: If ``x`` is not 1-D with exactly one element. The check
        is an ``assert`` statement, so it is skipped when Python runs with ``-O``.
    :raises RuntimeError: If the dtype is neither floating-point nor integer
        (e.g. bool, complex, object).
    """
    assert x.ndim == 1 and x.shape[0] == 1, \
        f"Cannot convert numpy.ndarray with shape {x.shape} to scalar."
    # np.issubdtype accepts every width of each kind, not only the 32/64-bit
    # dtypes; np.bool_ is not a subtype of np.integer, so bool still raises.
    if np.issubdtype(x.dtype, np.floating):
        return float(x[0])
    if np.issubdtype(x.dtype, np.integer):
        return int(x[0])
    raise RuntimeError(f"Conversion from numpy to scalar is not supported for dtype {x.dtype}")