import torch
from torch import Tensor


def one_hot_to_index(one_hot: Tensor) -> Tensor:
    """
    Converts a one-hot tensor into a tensor of the corresponding indexes.
    """
    # Match the index vector's dtype and device to the input tensor.
    device, dtype = one_hot.device, one_hot.dtype
    vocab_size = one_hot.shape[-1]
    # Column vector [0, 1, ..., vocab_size - 1]; the matrix product picks out
    # the index of the single non-zero entry along the last dimension.
    oh2idx = torch.tensor(range(vocab_size), dtype=dtype, device=device)
    return (one_hot @ oh2idx.unsqueeze(dim=1)).long().squeeze(dim=-1)
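

# A minimal usage sketch, assuming a toy vocabulary of 4 classes chosen purely
# for illustration: build a small one-hot batch with torch.nn.functional.one_hot
# and recover the original indices.
if __name__ == "__main__":
    indices = torch.tensor([2, 0, 3])
    one_hot = torch.nn.functional.one_hot(indices, num_classes=4).float()
    assert torch.equal(one_hot_to_index(one_hot), indices)  # tensor([2, 0, 3])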