import torch
from torch.distributions import multinomial
# Simulate rolls of a fair six-sided die with a multinomial distribution and
# watch the empirical frequencies approach the true probability of 1/6.
# Every sampled value below is random and no RNG seed is set, so the trailing
# "e.g." comments show one possible output, not a fixed result.

# Equal probability (1/6) for each of the 6 faces.
fairProbs = torch.ones([6], dtype=torch.float32) / 6.0

# One sample of n rolls is a length-6 vector of per-face counts summing to n.
print(multinomial.Multinomial(1, fairProbs).sample())  # e.g. tensor([0., 0., 1., 0., 0., 0.])
print(multinomial.Multinomial(10, fairProbs).sample())  # e.g. tensor([1., 2., 1., 1., 2., 3.])
print(multinomial.Multinomial(1000, fairProbs).sample())  # e.g. tensor([192., 167., 148., 167., 152., 174.])

# Relative frequencies from one large sample. The sample size is named once so
# the normalizer cannot drift out of sync with the number of rolls drawn.
NUM_TOSSES = 10000
count = multinomial.Multinomial(NUM_TOSSES, fairProbs).sample()
print(count / NUM_TOSSES)  # e.g. tensor([0.1711, 0.1716, 0.1684, 0.1632, 0.1649, 0.1608])

# 500 independent experiments of 10 rolls each. sample((500,)) prepends the
# sample shape, giving a (500, 6) tensor with one row per experiment.
counts = multinomial.Multinomial(10, fairProbs).sample((500,))
print(counts)

# Running totals over experiments: row i holds the summed counts of rows 0..i.
cum_counts = counts.cumsum(dim=0)
print(cum_counts)

# Divide each running total by its row sum to get the running estimate of each
# face's probability. keepdim=True keeps a (500, 1) column so the division
# broadcasts across the 6 faces.
estimates = cum_counts / cum_counts.sum(dim=1, keepdim=True)
print(estimates)