logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

12 lines
404 B

import torch
from torch import Tensor
def one_hot_to_index(one_hot: Tensor) -> Tensor:
    """
    Converts a one-hot tensor into a tensor with corresponding indexes.

    Args:
        one_hot: Tensor whose last dimension is the one-hot axis. Bool,
            integer and floating-point dtypes are accepted.

    Returns:
        A ``torch.long`` tensor of shape ``one_hot.shape[:-1]``, on the same
        device as ``one_hot``. Each entry holds the position of the hot entry
        along the last dimension. Rows that are all zeros map to index 0.
        If a row is not strictly one-hot, the position of its first maximal
        entry is returned.
    """
    vocab_size = one_hot.shape[-1]
    if vocab_size == 0:
        # argmax cannot reduce over an empty dimension. There is no hot entry,
        # so every position maps to 0, matching the previous matmul behaviour.
        return torch.zeros(
            one_hot.shape[:-1], dtype=torch.long, device=one_hot.device
        )
    if one_hot.dtype == torch.bool:
        # argmax may not support bool tensors on every PyTorch version.
        # uint8 has the same element size and ordering (False < True).
        one_hot = one_hot.to(torch.uint8)
    # argmax returns exact int64 indices directly. A matmul against an index
    # vector in the input dtype loses precision for large vocabularies in
    # low-precision dtypes (e.g. float16 is exact only up to 2048). It also
    # fails for bool input, and for integer input on CUDA.
    return one_hot.argmax(dim=-1)