N-Gram Language modeling

“Old-school” language modeling based on counting tokens in the data.

https://github.com/karpathy/makemore

Unigram


source

CharUnigram

 CharUnigram (data:List[str])

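Character-level unigram model built from a list of words. As the usage below shows, it exposes character counts and probabilities, a character vocabulary with `stoi`/`itos` lookups, and a `sample()` method.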

Usage

# Option 1 (plain Python): read the file and split it into one entry per line
with open('../data/text/names.txt', 'r') as names_file:
    list_of_words = names_file.read().splitlines()
# Option 2 (pandas): load the file as a single 'name' column and keep only
# the rows returned by head(); this overrides the list built above
df = pd.read_csv('../data/text/names.txt', header=None, names=['name'])
list_of_words = df.head()['name'].tolist()

# Build the character-level unigram model from the list of words
unigram = CharUnigram(list_of_words)
# Per-character counts and probabilities (the printed output lists the most
# frequent characters first)
print("sorted counts: ", unigram.counts)
print("sorted probs: ", unigram.probs)
# len() of the model; here it matches the number of entries in `chars`
print(len(unigram))
# Character vocabulary and the underlying char -> index dictionary
print(unigram.chars)
print(unigram._stoi)
# Lookup helpers: character -> index and index -> character
print(unigram.stoi('a'))
print(unigram.itos(0))
sorted counts:  {'a': 7, 'i': 4, 'l': 3, 'e': 2, 'm': 2, 'o': 2, 'v': 2, 's': 2, 'b': 1, 'p': 1, 'h': 1}
sorted probs:  {'a': 0.25925925925925924, 'i': 0.14814814814814814, 'l': 0.1111111111111111, 'e': 0.07407407407407407, 'm': 0.07407407407407407, 'o': 0.07407407407407407, 'v': 0.07407407407407407, 's': 0.07407407407407407, 'b': 0.037037037037037035, 'p': 0.037037037037037035, 'h': 0.037037037037037035}
11
['a', 'b', 'e', 'h', 'i', 'l', 'm', 'o', 'p', 's', 'v']
{'a': 0, 'b': 1, 'e': 2, 'h': 3, 'i': 4, 'l': 5, 'm': 6, 'o': 7, 'p': 8, 's': 9, 'v': 10}
0
a
# Bar chart of the unigram counts: one row (and therefore one bar) per character
df = pd.DataFrame.from_dict(unigram.counts, orient='index')
df.plot.bar()
<Axes: >

# Draw 10000 samples from the unigram model
samples = [unigram.sample() for _ in range(10000)]

# sampled
# Count every character across all samples, then plot the counts in
# descending order as a bar chart
count = Counter(ch for word in samples for ch in word)
df = pd.DataFrame.from_dict(count, orient='index')
df[0].sort_values(ascending=False).plot.bar()
<Axes: >

Bigram

class CharBigram:
    """Character-level bigram model (stub: no state or behavior yet)."""

    def __init__(self):
        """Create an empty instance; nothing is initialized."""

Usage

# data
# Each line of the file becomes one training example
with open('../data/text/names.txt', 'r') as names_file:
    data = names_file.read().splitlines()
print("first lines of text: ", data[:10])

# data = ["this is a text"]
first lines of text:  ['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn']
# bigram counts
bigrams = {}
unique_tokens = set()
for name in data:
    line = list(name)
    unique_tokens.update(line)
    line.append('<stop>')
    line.insert(0, '<stop>')
    for i,v in enumerate(range(len(line)-1)):
        bigram = (line[i], line[i+1])
        if bigram in bigrams:
            bigrams[bigram] += 1
        else:
            bigrams[bigram] = 1

# print("unsorted: ", list(bigrams)[:10])
# print("sorted: ", sort_dict_by_value(bigrams))

Numericalization

# Vocabulary: all characters in sorted order, followed by the boundary marker.
# use same for start & stop in this case (separate lines of names)
# tokens.append('<start>')
tokens = [*sorted(unique_tokens), '<stop>']
print(tokens)
# index -> token, and its inverse token -> index
itos = dict(enumerate(tokens))
stoi = {tok: idx for idx, tok in itos.items()}
print(stoi, itos)
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '<stop>']
{'a': 0, 'b': 1, 'c': 2, 'd': 3, 'e': 4, 'f': 5, 'g': 6, 'h': 7, 'i': 8, 'j': 9, 'k': 10, 'l': 11, 'm': 12, 'n': 13, 'o': 14, 'p': 15, 'q': 16, 'r': 17, 's': 18, 't': 19, 'u': 20, 'v': 21, 'w': 22, 'x': 23, 'y': 24, 'z': 25, '<stop>': 26} {0: 'a', 1: 'b', 2: 'c', 3: 'd', 4: 'e', 5: 'f', 6: 'g', 7: 'h', 8: 'i', 9: 'j', 10: 'k', 11: 'l', 12: 'm', 13: 'n', 14: 'o', 15: 'p', 16: 'q', 17: 'r', 18: 's', 19: 't', 20: 'u', 21: 'v', 22: 'w', 23: 'x', 24: 'y', 25: 'z', 26: '<stop>'}

Matrix representation

# Square integer matrix with one row and one column per token
n_toks = len(tokens)
print(n_toks)
N = torch.zeros(n_toks, n_toks, dtype=torch.long)
print(N.shape)
27
torch.Size([27, 27])
# Write each bigram count into N[row = first token, column = second token]
for (first, second), n_occurrences in bigrams.items():
    N[stoi[first], stoi[second]] = n_occurrences

# Heatmap of the bigram counts, with both axes labelled by token
plt.xlabel('char_t+1')
plt.ylabel('char_t')
tick_positions = list(itos.keys())
tick_labels = list(itos.values())
plt.xticks(tick_positions, tick_labels)
plt.yticks(tick_positions, tick_labels)
plt.imshow(N, origin='lower')
<matplotlib.image.AxesImage>

From counts to probabilities

# Inspect the raw bigram count matrix (row = current token, column = next token)
print(N)
tensor([[ 556,  541,  470, 1042,  692,  134,  168, 2332, 1650,  175,  568, 2528,
         1634, 5438,   63,   82,   60, 3264, 1118,  687,  381,  834,  161,  182,
         2050,  435, 6640],
        [ 321,   38,    1,   65,  655,    0,    0,   41,  217,    1,    0,  103,
            0,    4,  105,    0,    0,  842,    8,    2,   45,    0,    0,    0,
           83,    0,  114],
        [ 815,    0,   42,    1,  551,    0,    2,  664,  271,    3,  316,  116,
            0,    0,  380,    1,   11,   76,    5,   35,   35,    0,    0,    3,
          104,    4,   97],
        [1303,    1,    3,  149, 1283,    5,   25,  118,  674,    9,    3,   60,
           30,   31,  378,    0,    1,  424,   29,    4,   92,   17,   23,    0,
          317,    1,  516],
        [ 679,  121,  153,  384, 1271,   82,  125,  152,  818,   55,  178, 3248,
          769, 2675,  269,   83,   14, 1958,  861,  580,   69,  463,   50,  132,
         1070,  181, 3983],
        [ 242,    0,    0,    0,  123,   44,    1,    1,  160,    0,    2,   20,
            0,    4,   60,    0,    0,  114,    6,   18,   10,    0,    4,    0,
           14,    2,   80],
        [ 330,    3,    0,   19,  334,    1,   25,  360,  190,    3,    0,   32,
            6,   27,   83,    0,    0,  201,   30,   31,   85,    1,   26,    0,
           31,    1,  108],
        [2244,    8,    2,   24,  674,    2,    2,    1,  729,    9,   29,  185,
          117,  138,  287,    1,    1,  204,   31,   71,  166,   39,   10,    0,
          213,   20, 2409],
        [2445,  110,  509,  440, 1653,  101,  428,   95,   82,   76,  445, 1345,
          427, 2126,  588,   53,   52,  849, 1316,  541,  109,  269,    8,   89,
          779,  277, 2489],
        [1473,    1,    4,    4,  440,    0,    0,   45,  119,    2,    2,    9,
            5,    2,  479,    1,    0,   11,    7,    2,  202,    5,    6,    0,
           10,    0,   71],
        [1731,    2,    2,    2,  895,    1,    0,  307,  509,    2,   20,  139,
            9,   26,  344,    0,    0,  109,   95,   17,   50,    2,   34,    0,
          379,    2,  363],
        [2623,   52,   25,  138, 2921,   22,    6,   19, 2480,    6,   24, 1345,
           60,   14,  692,   15,    3,   18,   94,   77,  324,   72,   16,    0,
         1588,   10, 1314],
        [2590,  112,   51,   24,  818,    1,    0,    5, 1256,    7,    1,    5,
          168,   20,  452,   38,    0,   97,   35,    4,  139,    3,    2,    0,
          287,   11,  516],
        [2977,    8,  213,  704, 1359,   11,  273,   26, 1725,   44,   58,  195,
           19, 1906,  496,    5,    2,   44,  278,  443,   96,   55,   11,    6,
          465,  145, 6763],
        [ 149,  140,  114,  190,  132,   34,   44,  171,   69,   16,   68,  619,
          261, 2411,  115,   95,    3, 1059,  504,  118,  275,  176,  114,   45,
          103,   54,  855],
        [ 209,    2,    1,    0,  197,    1,    0,  204,   61,    1,    1,   16,
            1,    1,   59,   39,    0,  151,   16,   17,    4,    0,    0,    0,
           12,    0,   33],
        [  13,    0,    0,    0,    1,    0,    0,    0,   13,    0,    0,    1,
            2,    0,    2,    0,    0,    1,    2,    0,  206,    0,    3,    0,
            0,    0,   28],
        [2356,   41,   99,  187, 1697,    9,   76,  121, 3033,   25,   90,  413,
          162,  140,  869,   14,   16,  425,  190,  208,  252,   80,   21,    3,
          773,   23, 1377],
        [1201,   21,   60,    9,  884,    2,    2, 1285,  684,    2,   82,  279,
           90,   24,  531,   51,    1,   55,  461,  765,  185,   14,   24,    0,
          215,   10, 1169],
        [1027,    1,   17,    0,  716,    2,    2,  647,  532,    3,    0,  134,
            4,   22,  667,    0,    0,  352,   35,  374,   78,   15,   11,    2,
          341,  105,  483],
        [ 163,  103,  103,  136,  169,   19,   47,   58,  121,   14,   93,  301,
          154,  275,   10,   16,   10,  414,  474,   82,    3,   37,   86,   34,
           13,   45,  155],
        [ 642,    1,    0,    1,  568,    0,    0,    1,  911,    0,    3,   14,
            0,    8,  153,    0,    0,   48,    0,    0,    7,    7,    0,    0,
          121,    0,   88],
        [ 280,    1,    0,    8,  149,    2,    1,   23,  148,    0,    6,   13,
            2,   58,   36,    0,    0,   22,   20,    8,   25,    0,    2,    0,
           73,    1,   51],
        [ 103,    1,    4,    5,   36,    3,    0,    1,  102,    0,    0,   39,
            1,    1,   41,    0,    0,    0,   31,   70,    5,    0,    3,   38,
           30,   19,  164],
        [2143,   27,  115,  272,  301,   12,   30,   22,  192,   23,   86, 1104,
          148, 1826,  271,   15,    6,  291,  401,  104,  141,  106,    4,   28,
           23,   78, 2007],
        [ 860,    4,    2,    2,  373,    0,    1,   43,  364,    2,    2,  123,
           35,    4,  110,    2,    0,   32,    4,    4,   73,    2,    3,    1,
          147,   45,  160],
        [4410, 1306, 1542, 1690, 1531,  417,  669,  874,  591, 2422, 2963, 1572,
         2538, 1146,  394,  515,   92, 1639, 2055, 1308,   78,  376,  307,  134,
          535,  929,    0]])
# Additive smoothing gives unseen bigrams a small non-zero probability, which
# avoids log(0) = -inf when computing the NLL loss.
smoothing = 1
smoothed_counts = N.float() + smoothing
# Normalize by the *smoothed* row totals so that every row of P sums to 1.
# (Dividing by the raw row sums of N would leave each row summing to more than 1.)
P = smoothed_counts / smoothed_counts.sum(1, keepdim=True)
plt.imshow(P, origin='lower')
<matplotlib.image.AxesImage>

# Normalize the raw counts of row 6 by hand and check that they sum to 1
counts_6 = N[6]
row_6 = counts_6 / counts_6.sum()
print(row_6)
print(row_6.sum())
tensor([0.1713, 0.0016, 0.0000, 0.0099, 0.1733, 0.0005, 0.0130, 0.1868, 0.0986,
        0.0016, 0.0000, 0.0166, 0.0031, 0.0140, 0.0431, 0.0000, 0.0000, 0.1043,
        0.0156, 0.0161, 0.0441, 0.0005, 0.0135, 0.0000, 0.0161, 0.0005, 0.0560])
tensor(1.0000)
# Row 6 of P: its total mass, its largest entry, and the index of that entry
p = P[6]
print(p.sum(), p.max(), p.argmax())
tensor(1.0140) tensor(0.1873) tensor(7)

Sampling

# Generate 10 names by walking the bigram table one character at a time.
stop_idx = stoi['<stop>']  # the same marker opens and closes every name
for _ in range(10):
    res = []
    prev = stop_idx
    while True:
        # Multinomial sampling: draw the next token from row `prev` of P.
        # (Greedy alternative: int(torch.argmax(P[prev, :])) always picks the
        # most likely next token.)
        nxt = int(torch.multinomial(P[prev, :], num_samples=1, replacement=True))
        if nxt == stop_idx:
            # Reached the end-of-name marker: emit the name and start a new one.
            print(''.join(res))
            break
        res.append(itos[nxt])
        prev = nxt
rinienah
n
dir
le
dlensllitavilideaeyieshi
jallyon
kukwi
ola
riylikahit
s

Log likelihood loss function

# Look up the probability P[first, second] for every bigram seen in the data
bigram_p = {
    pair: P[stoi[pair[0]], stoi[pair[1]]]
    for pair in bigrams
}

print(bigram_p)
{('<stop>', 'e'): tensor(0.0478), ('e', 'm'): tensor(0.0377), ('m', 'm'): tensor(0.0254), ('m', 'a'): tensor(0.3901), ('a', '<stop>'): tensor(0.1960), ('<stop>', 'o'): tensor(0.0123), ('o', 'l'): tensor(0.0781), ('l', 'i'): tensor(0.1777), ('i', 'v'): tensor(0.0153), ('v', 'i'): tensor(0.3545), ('i', 'a'): tensor(0.1382), ('<stop>', 'a'): tensor(0.1377), ('a', 'v'): tensor(0.0246), ('v', 'a'): tensor(0.2499), ('<stop>', 'i'): tensor(0.0185), ('i', 's'): tensor(0.0744), ('s', 'a'): tensor(0.1483), ('a', 'b'): tensor(0.0160), ('b', 'e'): tensor(0.2480), ('e', 'l'): tensor(0.1591), ('l', 'l'): tensor(0.0964), ('l', 'a'): tensor(0.1880), ('<stop>', 's'): tensor(0.0642), ('s', 'o'): tensor(0.0656), ('o', 'p'): tensor(0.0121), ('p', 'h'): tensor(0.1998), ('h', 'i'): tensor(0.0959), ('<stop>', 'c'): tensor(0.0482), ('c', 'h'): tensor(0.1883), ('h', 'a'): tensor(0.2948), ('a', 'r'): tensor(0.0964), ('r', 'l'): tensor(0.0326), ('l', 'o'): tensor(0.0496), ('o', 't'): tensor(0.0150), ('t', 't'): tensor(0.0673), ('t', 'e'): tensor(0.1287), ('e', '<stop>'): tensor(0.1951), ('<stop>', 'm'): tensor(0.0793), ('m', 'i'): tensor(0.1893), ('a', 'm'): tensor(0.0483), ('m', 'e'): tensor(0.1233), ('<stop>', 'h'): tensor(0.0273), ('r', 'p'): tensor(0.0012), ('p', 'e'): tensor(0.1930), ('e', 'r'): tensor(0.0959), ('r', '<stop>'): tensor(0.1085), ('e', 'v'): tensor(0.0227), ('v', 'e'): tensor(0.2211), ('l', 'y'): tensor(0.1138), ('y', 'n'): tensor(0.1869), ('n', '<stop>'): tensor(0.3691), ('b', 'i'): tensor(0.0824), ('i', 'g'): tensor(0.0242), ('g', 'a'): tensor(0.1718), ('a', 'i'): tensor(0.0487), ('i', 'l'): tensor(0.0760), ('l', '<stop>'): tensor(0.0942), ('y', '<stop>'): tensor(0.2054), ('i', 'z'): tensor(0.0157), ('z', 'a'): tensor(0.3590), ('e', 't'): tensor(0.0284), ('t', 'h'): tensor(0.1163), ('h', '<stop>'): tensor(0.3164), ('r', 'y'): tensor(0.0609), ('o', 'f'): tensor(0.0044), ('f', 'i'): tensor(0.1779), ('c', 'a'): tensor(0.2310), ('r', 'i'): tensor(0.2389), ('s', 'c'): 
tensor(0.0075), ('l', 'e'): tensor(0.2093), ('t', '<stop>'): tensor(0.0869), ('<stop>', 'v'): tensor(0.0118), ('i', 'c'): tensor(0.0288), ('c', 't'): tensor(0.0102), ('t', 'o'): tensor(0.1199), ('o', 'r'): tensor(0.1336), ('a', 'd'): tensor(0.0308), ('d', 'i'): tensor(0.1228), ('o', 'n'): tensor(0.3040), ('<stop>', 'l'): tensor(0.0491), ('l', 'u'): tensor(0.0233), ('u', 'n'): tensor(0.0880), ('n', 'a'): tensor(0.1625), ('<stop>', 'g'): tensor(0.0209), ('g', 'r'): tensor(0.1048), ('r', 'a'): tensor(0.1856), ('a', 'c'): tensor(0.0139), ('c', 'e'): tensor(0.1563), ('h', 'l'): tensor(0.0244), ('o', 'e'): tensor(0.0168), ('<stop>', 'p'): tensor(0.0161), ('e', 'n'): tensor(0.1310), ('n', 'e'): tensor(0.0742), ('a', 'y'): tensor(0.0605), ('y', 'l'): tensor(0.1130), ('<stop>', 'r'): tensor(0.0512), ('e', 'y'): tensor(0.0524), ('<stop>', 'z'): tensor(0.0290), ('z', 'o'): tensor(0.0463), ('<stop>', 'n'): tensor(0.0358), ('n', 'o'): tensor(0.0271), ('e', 'a'): tensor(0.0333), ('a', 'n'): tensor(0.1605), ('n', 'n'): tensor(0.1041), ('a', 'h'): tensor(0.0689), ('d', 'd'): tensor(0.0273), ('a', 'u'): tensor(0.0113), ('u', 'b'): tensor(0.0332), ('b', 'r'): tensor(0.3187), ('r', 'e'): tensor(0.1337), ('i', 'e'): tensor(0.0934), ('s', 't'): tensor(0.0945), ('a', 't'): tensor(0.0203), ('t', 'a'): tensor(0.1846), ('a', 'l'): tensor(0.0746), ('a', 'z'): tensor(0.0129), ('z', 'e'): tensor(0.1560), ('i', 'o'): tensor(0.0333), ('u', 'r'): tensor(0.1324), ('r', 'o'): tensor(0.0685), ('u', 'd'): tensor(0.0437), ('d', 'r'): tensor(0.0773), ('<stop>', 'b'): tensor(0.0408), ('o', 'o'): tensor(0.0146), ('o', 'k'): tensor(0.0087), ('k', 'l'): tensor(0.0278), ('c', 'l'): tensor(0.0331), ('i', 'r'): tensor(0.0480), ('s', 'k'): tensor(0.0102), ('k', 'y'): tensor(0.0754), ('u', 'c'): tensor(0.0332), ('c', 'y'): tensor(0.0297), ('p', 'a'): tensor(0.2047), ('s', 'l'): tensor(0.0345), ('i', 'n'): tensor(0.1202), ('o', 'v'): tensor(0.0223), ('g', 'e'): tensor(0.1738), ('e', 's'): tensor(0.0422), ('s', 
'i'): tensor(0.0845), ('s', '<stop>'): tensor(0.1443), ('<stop>', 'k'): tensor(0.0925), ('k', 'e'): tensor(0.1778), ('e', 'd'): tensor(0.0189), ('d', 'y'): tensor(0.0579), ('n', 't'): tensor(0.0242), ('y', 'a'): tensor(0.2193), ('<stop>', 'w'): tensor(0.0096), ('w', 'i'): tensor(0.1604), ('o', 'w'): tensor(0.0145), ('w', '<stop>'): tensor(0.0560), ('k', 'i'): tensor(0.1012), ('n', 's'): tensor(0.0152), ('a', 'o'): tensor(0.0019), ('o', 'm'): tensor(0.0330), ('i', '<stop>'): tensor(0.1407), ('a', 'a'): tensor(0.0164), ('i', 'y'): tensor(0.0441), ('d', 'e'): tensor(0.2336), ('c', 'o'): tensor(0.1079), ('r', 'u'): tensor(0.0199), ('b', 'y'): tensor(0.0318), ('s', 'e'): tensor(0.1092), ('n', 'i'): tensor(0.0942), ('i', 't'): tensor(0.0306), ('t', 'y'): tensor(0.0614), ('u', 't'): tensor(0.0265), ('t', 'u'): tensor(0.0142), ('u', 'm'): tensor(0.0494), ('m', 'n'): tensor(0.0032), ('g', 'i'): tensor(0.0991), ('t', 'i'): tensor(0.0957), ('<stop>', 'q'): tensor(0.0029), ('q', 'u'): tensor(0.7610), ('u', 'i'): tensor(0.0389), ('a', 'e'): tensor(0.0205), ('e', 'h'): tensor(0.0075), ('v', 'y'): tensor(0.0474), ('p', 'i'): tensor(0.0604), ('i', 'p'): tensor(0.0031), ('y', 'd'): tensor(0.0279), ('e', 'x'): tensor(0.0065), ('x', 'a'): tensor(0.1492), ('<stop>', 'j'): tensor(0.0756), ('j', 'o'): tensor(0.1655), ('o', 's'): tensor(0.0637), ('e', 'p'): tensor(0.0041), ('j', 'u'): tensor(0.0700), ('u', 'l'): tensor(0.0963), ('<stop>', 'd'): tensor(0.0528), ('k', 'a'): tensor(0.3437), ('e', 'e'): tensor(0.0623), ('y', 't'): tensor(0.0107), ('d', 'l'): tensor(0.0111), ('c', 'k'): tensor(0.0898), ('n', 'z'): tensor(0.0080), ('z', 'i'): tensor(0.1522), ('a', 'g'): tensor(0.0050), ('d', 'a'): tensor(0.2373), ('j', 'a'): tensor(0.5083), ('h', 'e'): tensor(0.0886), ('<stop>', 'x'): tensor(0.0042), ('x', 'i'): tensor(0.1478), ('i', 'm'): tensor(0.0242), ('e', 'i'): tensor(0.0401), ('<stop>', 't'): tensor(0.0409), ('<stop>', 'f'): tensor(0.0130), ('f', 'a'): tensor(0.2685), ('n', 'd'): 
tensor(0.0385), ('r', 'g'): tensor(0.0061), ('a', 's'): tensor(0.0330), ('s', 'h'): tensor(0.1586), ('b', 'a'): tensor(0.1217), ('k', 'h'): tensor(0.0611), ('s', 'm'): tensor(0.0112), ('o', 'd'): tensor(0.0241), ('r', 's'): tensor(0.0150), ('g', 'h'): tensor(0.1873), ('s', 'y'): tensor(0.0266), ('y', 's'): tensor(0.0411), ('s', 's'): tensor(0.0570), ('e', 'c'): tensor(0.0075), ('c', 'i'): tensor(0.0770), ('m', 'o'): tensor(0.0682), ('r', 'k'): tensor(0.0072), ('n', 'l'): tensor(0.0107), ('d', 'n'): tensor(0.0058), ('r', 'd'): tensor(0.0148), ('o', 'i'): tensor(0.0088), ('t', 'r'): tensor(0.0634), ('m', 'b'): tensor(0.0170), ('r', 'm'): tensor(0.0128), ('n', 'y'): tensor(0.0254), ('d', 'o'): tensor(0.0690), ('o', 'a'): tensor(0.0189), ('o', 'c'): tensor(0.0145), ('m', 'y'): tensor(0.0434), ('s', 'u'): tensor(0.0229), ('m', 'c'): tensor(0.0078), ('p', 'r'): tensor(0.1481), ('o', 'u'): tensor(0.0348), ('r', 'n'): tensor(0.0111), ('w', 'a'): tensor(0.3025), ('e', 'b'): tensor(0.0060), ('c', 'c'): tensor(0.0122), ('a', 'w'): tensor(0.0048), ('w', 'y'): tensor(0.0797), ('y', 'e'): tensor(0.0309), ('e', 'o'): tensor(0.0132), ('a', 'k'): tensor(0.0168), ('n', 'g'): tensor(0.0150), ('k', 'o'): tensor(0.0685), ('b', 'l'): tensor(0.0393), ('h', 'o'): tensor(0.0378), ('e', 'g'): tensor(0.0062), ('f', 'r'): tensor(0.1271), ('s', 'p'): tensor(0.0064), ('l', 's'): tensor(0.0068), ('y', 'z'): tensor(0.0081), ('g', 'g'): tensor(0.0135), ('z', 'u'): tensor(0.0309), ('i', 'd'): tensor(0.0249), ('m', '<stop>'): tensor(0.0778), ('o', 'g'): tensor(0.0057), ('j', 'e'): tensor(0.1521), ('g', 'n'): tensor(0.0145), ('y', 'r'): tensor(0.0299), ('c', '<stop>'): tensor(0.0277), ('c', 'q'): tensor(0.0034), ('u', 'e'): tensor(0.0542), ('i', 'f'): tensor(0.0058), ('f', 'e'): tensor(0.1370), ('i', 'x'): tensor(0.0051), ('x', '<stop>'): tensor(0.2367), ('o', 'y'): tensor(0.0131), ('g', 'o'): tensor(0.0436), ('g', 't'): tensor(0.0166), ('l', 't'): tensor(0.0056), ('g', 'w'): tensor(0.0140), ('w', 
'e'): tensor(0.1615), ('l', 'd'): tensor(0.0100), ('a', 'p'): tensor(0.0024), ('h', 'n'): tensor(0.0183), ('t', 'l'): tensor(0.0242), ('m', 'r'): tensor(0.0148), ('n', 'c'): tensor(0.0117), ('l', 'b'): tensor(0.0038), ('i', 'k'): tensor(0.0252), ('<stop>', 'y'): tensor(0.0167), ('t', 'z'): tensor(0.0190), ('h', 'r'): tensor(0.0269), ('j', 'i'): tensor(0.0414), ('h', 't'): tensor(0.0095), ('r', 'r'): tensor(0.0335), ('z', 'l'): tensor(0.0517), ('w', 'r'): tensor(0.0248), ('b', 'b'): tensor(0.0147), ('r', 't'): tensor(0.0165), ('l', 'v'): tensor(0.0052), ('e', 'j'): tensor(0.0027), ('o', 'h'): tensor(0.0217), ('u', 's'): tensor(0.1515), ('i', 'b'): tensor(0.0063), ('g', 'l'): tensor(0.0171), ('h', 'y'): tensor(0.0281), ('p', 'o'): tensor(0.0585), ('p', 'p'): tensor(0.0390), ('p', 'y'): tensor(0.0127), ('n', 'r'): tensor(0.0025), ('z', 'm'): tensor(0.0150), ('v', 'o'): tensor(0.0599), ('l', 'm'): tensor(0.0044), ('o', 'x'): tensor(0.0058), ('d', '<stop>'): tensor(0.0941), ('i', 'u'): tensor(0.0062), ('v', '<stop>'): tensor(0.0346), ('f', 'f'): tensor(0.0497), ('b', 'o'): tensor(0.0401), ('e', 'k'): tensor(0.0088), ('c', 'r'): tensor(0.0218), ('d', 'g'): tensor(0.0047), ('r', 'c'): tensor(0.0079), ('r', 'h'): tensor(0.0096), ('n', 'k'): tensor(0.0032), ('h', 'u'): tensor(0.0219), ('d', 's'): tensor(0.0055), ('a', 'x'): tensor(0.0054), ('y', 'c'): tensor(0.0119), ('e', 'w'): tensor(0.0025), ('v', 'k'): tensor(0.0016), ('z', 'h'): tensor(0.0183), ('w', 'h'): tensor(0.0258), ('t', 'n'): tensor(0.0041), ('x', 'l'): tensor(0.0574), ('g', 'u'): tensor(0.0446), ('u', 'a'): tensor(0.0523), ('u', 'p'): tensor(0.0054), ('u', 'g'): tensor(0.0153), ('d', 'u'): tensor(0.0169), ('l', 'c'): tensor(0.0019), ('r', 'b'): tensor(0.0033), ('a', 'q'): tensor(0.0018), ('b', '<stop>'): tensor(0.0435), ('g', 'y'): tensor(0.0166), ('y', 'p'): tensor(0.0016), ('p', 't'): tensor(0.0175), ('e', 'z'): tensor(0.0089), ('z', 'r'): tensor(0.0138), ('f', 'l'): tensor(0.0232), ('o', '<stop>'): 
tensor(0.1079), ('o', 'b'): tensor(0.0178), ('u', 'z'): tensor(0.0147), ('z', '<stop>'): tensor(0.0671), ('i', 'q'): tensor(0.0030), ('y', 'v'): tensor(0.0109), ('n', 'v'): tensor(0.0031), ('d', 'h'): tensor(0.0217), ('g', 'd'): tensor(0.0104), ('t', 's'): tensor(0.0065), ('n', 'h'): tensor(0.0015), ('y', 'j'): tensor(0.0025), ('k', 'r'): tensor(0.0218), ('z', 'b'): tensor(0.0021), ('g', '<stop>'): tensor(0.0566), ('a', 'j'): tensor(0.0052), ('r', 'j'): tensor(0.0020), ('m', 'p'): tensor(0.0059), ('p', 'b'): tensor(0.0029), ('y', 'o'): tensor(0.0278), ('z', 'y'): tensor(0.0617), ('p', 'l'): tensor(0.0166), ('l', 'k'): tensor(0.0018), ('i', 'j'): tensor(0.0044), ('x', 'e'): tensor(0.0531), ('y', 'u'): tensor(0.0145), ('l', 'n'): tensor(0.0011), ('u', 'x'): tensor(0.0112), ('i', 'h'): tensor(0.0054), ('w', 's'): tensor(0.0226), ('k', 's'): tensor(0.0190), ('m', 'u'): tensor(0.0211), ('y', 'k'): tensor(0.0089), ('e', 'f'): tensor(0.0041), ('k', '<stop>'): tensor(0.0722), ('y', 'm'): tensor(0.0152), ('z', 'z'): tensor(0.0192), ('m', 'd'): tensor(0.0038), ('s', 'r'): tensor(0.0069), ('e', 'u'): tensor(0.0034), ('l', 'h'): tensor(0.0014), ('a', 'f'): tensor(0.0040), ('r', 'w'): tensor(0.0017), ('n', 'u'): tensor(0.0053), ('v', 'r'): tensor(0.0190), ('m', 's'): tensor(0.0054), ('<stop>', 'u'): tensor(0.0025), ('f', 's'): tensor(0.0077), ('y', 'b'): tensor(0.0029), ('x', 'o'): tensor(0.0603), ('g', 's'): tensor(0.0161), ('x', 'y'): tensor(0.0445), ('w', 'n'): tensor(0.0635), ('j', 'h'): tensor(0.0159), ('f', 'n'): tensor(0.0055), ('n', 'j'): tensor(0.0025), ('r', 'v'): tensor(0.0064), ('n', 'm'): tensor(0.0011), ('t', 'c'): tensor(0.0032), ('s', 'w'): tensor(0.0031), ('k', 't'): tensor(0.0036), ('f', 't'): tensor(0.0210), ('x', 't'): tensor(0.1019), ('u', 'v'): tensor(0.0121), ('k', 'k'): tensor(0.0042), ('s', 'n'): tensor(0.0031), ('u', '<stop>'): tensor(0.0498), ('j', 'r'): tensor(0.0041), ('y', 'x'): tensor(0.0030), ('h', 'm'): tensor(0.0155), ('e', 'q'): 
tensor(0.0007), ('u', 'o'): tensor(0.0035), ('f', '<stop>'): tensor(0.0895), ('h', 'z'): tensor(0.0028), ('h', 'k'): tensor(0.0039), ('y', 'g'): tensor(0.0032), ('q', 'r'): tensor(0.0074), ('v', 'n'): tensor(0.0035), ('s', 'd'): tensor(0.0012), ('y', 'i'): tensor(0.0197), ('n', 'w'): tensor(0.0007), ('d', 'v'): tensor(0.0033), ('h', 'v'): tensor(0.0053), ('x', 'w'): tensor(0.0057), ('o', 'z'): tensor(0.0069), ('k', 'u'): tensor(0.0101), ('u', 'h'): tensor(0.0188), ('k', 'n'): tensor(0.0054), ('s', 'b'): tensor(0.0027), ('i', 'i'): tensor(0.0047), ('y', 'y'): tensor(0.0025), ('r', 'z'): tensor(0.0019), ('l', 'g'): tensor(0.0005), ('l', 'p'): tensor(0.0011), ('p', '<stop>'): tensor(0.0331), ('b', 'u'): tensor(0.0174), ('f', 'u'): tensor(0.0122), ('b', 'h'): tensor(0.0159), ('f', 'y'): tensor(0.0166), ('u', 'w'): tensor(0.0278), ('x', 'u'): tensor(0.0086), ('q', '<stop>'): tensor(0.1066), ('l', 'r'): tensor(0.0014), ('m', 'h'): tensor(0.0009), ('l', 'w'): tensor(0.0012), ('j', '<stop>'): tensor(0.0248), ('s', 'v'): tensor(0.0019), ('m', 'l'): tensor(0.0009), ('n', 'f'): tensor(0.0007), ('u', 'j'): tensor(0.0048), ('f', 'o'): tensor(0.0674), ('j', 'l'): tensor(0.0034), ('t', 'g'): tensor(0.0005), ('j', 'm'): tensor(0.0021), ('v', 'v'): tensor(0.0031), ('p', 's'): tensor(0.0166), ('t', 'w'): tensor(0.0022), ('x', 'c'): tensor(0.0072), ('u', 'k'): tensor(0.0300), ('v', 'l'): tensor(0.0058), ('h', 'd'): tensor(0.0033), ('l', 'z'): tensor(0.0008), ('k', 'w'): tensor(0.0069), ('n', 'b'): tensor(0.0005), ('q', 's'): tensor(0.0110), ('i', 'w'): tensor(0.0005), ('c', 's'): tensor(0.0017), ('h', 's'): tensor(0.0042), ('m', 't'): tensor(0.0008), ('h', 'w'): tensor(0.0014), ('x', 'x'): tensor(0.0560), ('t', 'x'): tensor(0.0005), ('d', 'z'): tensor(0.0004), ('x', 'z'): tensor(0.0287), ('t', 'm'): tensor(0.0009), ('t', 'j'): tensor(0.0007), ('u', 'q'): tensor(0.0035), ('q', 'a'): tensor(0.0515), ('f', 'k'): tensor(0.0033), ('z', 'n'): tensor(0.0021), ('l', 'j'): tensor(0.0005), 
('j', 'w'): tensor(0.0024), ('v', 'u'): tensor(0.0031), ('c', 'j'): tensor(0.0011), ('h', 'b'): tensor(0.0012), ('z', 't'): tensor(0.0021), ('p', 'u'): tensor(0.0049), ('m', 'z'): tensor(0.0018), ('x', 's'): tensor(0.0459), ('b', 't'): tensor(0.0011), ('u', 'y'): tensor(0.0045), ('d', 'j'): tensor(0.0018), ('j', 's'): tensor(0.0028), ('w', 'u'): tensor(0.0280), ('o', 'j'): tensor(0.0021), ('b', 's'): tensor(0.0034), ('d', 'w'): tensor(0.0044), ('w', 'o'): tensor(0.0398), ('j', 'n'): tensor(0.0010), ('w', 't'): tensor(0.0097), ('l', 'f'): tensor(0.0016), ('d', 'm'): tensor(0.0056), ('p', 'j'): tensor(0.0019), ('j', 'y'): tensor(0.0038), ('y', 'f'): tensor(0.0013), ('q', 'i'): tensor(0.0515), ('j', 'v'): tensor(0.0021), ('q', 'l'): tensor(0.0074), ('s', 'z'): tensor(0.0014), ('k', 'm'): tensor(0.0020), ('w', 'l'): tensor(0.0151), ('p', 'f'): tensor(0.0019), ('q', 'w'): tensor(0.0147), ('n', 'x'): tensor(0.0004), ('k', 'c'): tensor(0.0006), ('t', 'v'): tensor(0.0029), ('c', 'u'): tensor(0.0102), ('z', 'k'): tensor(0.0013), ('c', 'z'): tensor(0.0014), ('y', 'q'): tensor(0.0007), ('y', 'h'): tensor(0.0024), ('r', 'f'): tensor(0.0008), ('s', 'j'): tensor(0.0004), ('h', 'j'): tensor(0.0013), ('g', 'b'): tensor(0.0021), ('u', 'f'): tensor(0.0064), ('s', 'f'): tensor(0.0004), ('q', 'e'): tensor(0.0074), ('b', 'c'): tensor(0.0008), ('c', 'd'): tensor(0.0006), ('z', 'j'): tensor(0.0013), ('n', 'q'): tensor(0.0002), ('m', 'f'): tensor(0.0003), ('p', 'n'): tensor(0.0019), ('f', 'z'): tensor(0.0033), ('b', 'n'): tensor(0.0019), ('w', 'd'): tensor(0.0097), ('w', 'b'): tensor(0.0022), ('b', 'd'): tensor(0.0250), ('z', 's'): tensor(0.0021), ('p', 'c'): tensor(0.0019), ('h', 'g'): tensor(0.0004), ('m', 'j'): tensor(0.0012), ('w', 'w'): tensor(0.0032), ('k', 'j'): tensor(0.0006), ('h', 'p'): tensor(0.0003), ('j', 'k'): tensor(0.0010), ('o', 'q'): tensor(0.0005), ('f', 'w'): tensor(0.0055), ('f', 'h'): tensor(0.0022), ('w', 'm'): tensor(0.0032), ('b', 'j'): tensor(0.0008), ('r', 'q'): 
tensor(0.0013), ('z', 'c'): tensor(0.0013), ('z', 'v'): tensor(0.0013), ('f', 'g'): tensor(0.0022), ('n', 'p'): tensor(0.0003), ('z', 'g'): tensor(0.0008), ('d', 't'): tensor(0.0009), ('w', 'f'): tensor(0.0032), ('d', 'f'): tensor(0.0011), ('w', 'k'): tensor(0.0075), ('q', 'm'): tensor(0.0110), ('k', 'z'): tensor(0.0006), ('j', 'j'): tensor(0.0010), ('c', 'p'): tensor(0.0006), ('p', 'k'): tensor(0.0019), ('p', 'm'): tensor(0.0019), ('j', 'd'): tensor(0.0017), ('r', 'x'): tensor(0.0003), ('x', 'n'): tensor(0.0029), ('d', 'c'): tensor(0.0007), ('g', 'j'): tensor(0.0021), ('x', 'f'): tensor(0.0057), ('j', 'c'): tensor(0.0017), ('s', 'q'): tensor(0.0002), ('k', 'f'): tensor(0.0004), ('z', 'p'): tensor(0.0013), ('j', 't'): tensor(0.0010), ('k', 'b'): tensor(0.0006), ('m', 'k'): tensor(0.0003), ('m', 'w'): tensor(0.0005), ('x', 'h'): tensor(0.0029), ('h', 'f'): tensor(0.0004), ('x', 'd'): tensor(0.0086), ('y', 'w'): tensor(0.0005), ('z', 'w'): tensor(0.0017), ('d', 'k'): tensor(0.0007), ('c', 'g'): tensor(0.0008), ('u', 'u'): tensor(0.0013), ('t', 'f'): tensor(0.0005), ('g', 'm'): tensor(0.0036), ('m', 'v'): tensor(0.0006), ('c', 'x'): tensor(0.0011), ('h', 'c'): tensor(0.0004), ('g', 'f'): tensor(0.0010), ('q', 'o'): tensor(0.0110), ('l', 'q'): tensor(0.0003), ('v', 'b'): tensor(0.0008), ('j', 'p'): tensor(0.0007), ('k', 'd'): tensor(0.0006), ('g', 'z'): tensor(0.0010), ('v', 'd'): tensor(0.0008), ('d', 'b'): tensor(0.0004), ('v', 'h'): tensor(0.0008), ('k', 'v'): tensor(0.0006), ('h', 'h'): tensor(0.0003), ('s', 'g'): tensor(0.0004), ('g', 'v'): tensor(0.0010), ('d', 'q'): tensor(0.0004), ('x', 'b'): tensor(0.0029), ('w', 'z'): tensor(0.0022), ('h', 'q'): tensor(0.0003), ('j', 'b'): tensor(0.0007), ('z', 'd'): tensor(0.0013), ('x', 'm'): tensor(0.0029), ('w', 'g'): tensor(0.0022), ('t', 'b'): tensor(0.0004), ('z', 'x'): tensor(0.0008)}
# Order the bigrams from most to least probable
ranked_pairs = sorted(bigram_p.items(), key=lambda item: item[1], reverse=True)
bigram_p_sorted = {pair: prob.float() for pair, prob in ranked_pairs}
print(bigram_p_sorted)
{('q', 'u'): tensor(0.7610), ('j', 'a'): tensor(0.5083), ('m', 'a'): tensor(0.3901), ('n', '<stop>'): tensor(0.3691), ('z', 'a'): tensor(0.3590), ('v', 'i'): tensor(0.3545), ('k', 'a'): tensor(0.3437), ('b', 'r'): tensor(0.3187), ('h', '<stop>'): tensor(0.3164), ('o', 'n'): tensor(0.3040), ('w', 'a'): tensor(0.3025), ('h', 'a'): tensor(0.2948), ('f', 'a'): tensor(0.2685), ('v', 'a'): tensor(0.2499), ('b', 'e'): tensor(0.2480), ('r', 'i'): tensor(0.2389), ('d', 'a'): tensor(0.2373), ('x', '<stop>'): tensor(0.2367), ('d', 'e'): tensor(0.2336), ('c', 'a'): tensor(0.2310), ('v', 'e'): tensor(0.2211), ('y', 'a'): tensor(0.2193), ('l', 'e'): tensor(0.2093), ('y', '<stop>'): tensor(0.2054), ('p', 'a'): tensor(0.2047), ('p', 'h'): tensor(0.1998), ('a', '<stop>'): tensor(0.1960), ('e', '<stop>'): tensor(0.1951), ('p', 'e'): tensor(0.1930), ('m', 'i'): tensor(0.1893), ('c', 'h'): tensor(0.1883), ('l', 'a'): tensor(0.1880), ('g', 'h'): tensor(0.1873), ('y', 'n'): tensor(0.1869), ('r', 'a'): tensor(0.1856), ('t', 'a'): tensor(0.1846), ('f', 'i'): tensor(0.1779), ('k', 'e'): tensor(0.1778), ('l', 'i'): tensor(0.1777), ('g', 'e'): tensor(0.1738), ('g', 'a'): tensor(0.1718), ('j', 'o'): tensor(0.1655), ('n', 'a'): tensor(0.1625), ('w', 'e'): tensor(0.1615), ('a', 'n'): tensor(0.1605), ('w', 'i'): tensor(0.1604), ('e', 'l'): tensor(0.1591), ('s', 'h'): tensor(0.1586), ('c', 'e'): tensor(0.1563), ('z', 'e'): tensor(0.1560), ('z', 'i'): tensor(0.1522), ('j', 'e'): tensor(0.1521), ('u', 's'): tensor(0.1515), ('x', 'a'): tensor(0.1492), ('s', 'a'): tensor(0.1483), ('p', 'r'): tensor(0.1481), ('x', 'i'): tensor(0.1478), ('s', '<stop>'): tensor(0.1443), ('i', '<stop>'): tensor(0.1407), ('i', 'a'): tensor(0.1382), ('<stop>', 'a'): tensor(0.1377), ('f', 'e'): tensor(0.1370), ('r', 'e'): tensor(0.1337), ('o', 'r'): tensor(0.1336), ('u', 'r'): tensor(0.1324), ('e', 'n'): tensor(0.1310), ('t', 'e'): tensor(0.1287), ('f', 'r'): tensor(0.1271), ('m', 'e'): tensor(0.1233), ('d', 'i'): 
tensor(0.1228), ('b', 'a'): tensor(0.1217), ('i', 'n'): tensor(0.1202), ('t', 'o'): tensor(0.1199), ('t', 'h'): tensor(0.1163), ('l', 'y'): tensor(0.1138), ('y', 'l'): tensor(0.1130), ('s', 'e'): tensor(0.1092), ('r', '<stop>'): tensor(0.1085), ('o', '<stop>'): tensor(0.1079), ('c', 'o'): tensor(0.1079), ('q', '<stop>'): tensor(0.1066), ('g', 'r'): tensor(0.1048), ('n', 'n'): tensor(0.1041), ('x', 't'): tensor(0.1019), ('k', 'i'): tensor(0.1012), ('g', 'i'): tensor(0.0991), ('l', 'l'): tensor(0.0964), ('a', 'r'): tensor(0.0964), ('u', 'l'): tensor(0.0963), ('e', 'r'): tensor(0.0959), ('h', 'i'): tensor(0.0959), ('t', 'i'): tensor(0.0957), ('s', 't'): tensor(0.0945), ('l', '<stop>'): tensor(0.0942), ('n', 'i'): tensor(0.0942), ('d', '<stop>'): tensor(0.0941), ('i', 'e'): tensor(0.0934), ('<stop>', 'k'): tensor(0.0925), ('c', 'k'): tensor(0.0898), ('f', '<stop>'): tensor(0.0895), ('h', 'e'): tensor(0.0886), ('u', 'n'): tensor(0.0880), ('t', '<stop>'): tensor(0.0869), ('s', 'i'): tensor(0.0845), ('b', 'i'): tensor(0.0824), ('w', 'y'): tensor(0.0797), ('<stop>', 'm'): tensor(0.0793), ('o', 'l'): tensor(0.0781), ('m', '<stop>'): tensor(0.0778), ('d', 'r'): tensor(0.0773), ('c', 'i'): tensor(0.0770), ('i', 'l'): tensor(0.0760), ('<stop>', 'j'): tensor(0.0756), ('k', 'y'): tensor(0.0754), ('a', 'l'): tensor(0.0746), ('i', 's'): tensor(0.0744), ('n', 'e'): tensor(0.0742), ('k', '<stop>'): tensor(0.0722), ('j', 'u'): tensor(0.0700), ('d', 'o'): tensor(0.0690), ('a', 'h'): tensor(0.0689), ('r', 'o'): tensor(0.0685), ('k', 'o'): tensor(0.0685), ('m', 'o'): tensor(0.0682), ('f', 'o'): tensor(0.0674), ('t', 't'): tensor(0.0673), ('z', '<stop>'): tensor(0.0671), ('s', 'o'): tensor(0.0656), ('<stop>', 's'): tensor(0.0642), ('o', 's'): tensor(0.0637), ('w', 'n'): tensor(0.0635), ('t', 'r'): tensor(0.0634), ('e', 'e'): tensor(0.0623), ('z', 'y'): tensor(0.0617), ('t', 'y'): tensor(0.0614), ('k', 'h'): tensor(0.0611), ('r', 'y'): tensor(0.0609), ('a', 'y'): tensor(0.0605), ('p', 
'i'): tensor(0.0604), ('x', 'o'): tensor(0.0603), ('v', 'o'): tensor(0.0599), ('p', 'o'): tensor(0.0585), ('d', 'y'): tensor(0.0579), ('x', 'l'): tensor(0.0574), ('s', 's'): tensor(0.0570), ('g', '<stop>'): tensor(0.0566), ('w', '<stop>'): tensor(0.0560), ('x', 'x'): tensor(0.0560), ('u', 'e'): tensor(0.0542), ('x', 'e'): tensor(0.0531), ('<stop>', 'd'): tensor(0.0528), ('e', 'y'): tensor(0.0524), ('u', 'a'): tensor(0.0523), ('z', 'l'): tensor(0.0517), ('q', 'a'): tensor(0.0515), ('q', 'i'): tensor(0.0515), ('<stop>', 'r'): tensor(0.0512), ('u', '<stop>'): tensor(0.0498), ('f', 'f'): tensor(0.0497), ('l', 'o'): tensor(0.0496), ('u', 'm'): tensor(0.0494), ('<stop>', 'l'): tensor(0.0491), ('a', 'i'): tensor(0.0487), ('a', 'm'): tensor(0.0483), ('<stop>', 'c'): tensor(0.0482), ('i', 'r'): tensor(0.0480), ('<stop>', 'e'): tensor(0.0478), ('v', 'y'): tensor(0.0474), ('z', 'o'): tensor(0.0463), ('x', 's'): tensor(0.0459), ('g', 'u'): tensor(0.0446), ('x', 'y'): tensor(0.0445), ('i', 'y'): tensor(0.0441), ('u', 'd'): tensor(0.0437), ('g', 'o'): tensor(0.0436), ('b', '<stop>'): tensor(0.0435), ('m', 'y'): tensor(0.0434), ('e', 's'): tensor(0.0422), ('j', 'i'): tensor(0.0414), ('y', 's'): tensor(0.0411), ('<stop>', 't'): tensor(0.0409), ('<stop>', 'b'): tensor(0.0408), ('e', 'i'): tensor(0.0401), ('b', 'o'): tensor(0.0401), ('w', 'o'): tensor(0.0398), ('b', 'l'): tensor(0.0393), ('p', 'p'): tensor(0.0390), ('u', 'i'): tensor(0.0389), ('n', 'd'): tensor(0.0385), ('h', 'o'): tensor(0.0378), ('e', 'm'): tensor(0.0377), ('<stop>', 'n'): tensor(0.0358), ('o', 'u'): tensor(0.0348), ('v', '<stop>'): tensor(0.0346), ('s', 'l'): tensor(0.0345), ('r', 'r'): tensor(0.0335), ('e', 'a'): tensor(0.0333), ('i', 'o'): tensor(0.0333), ('u', 'b'): tensor(0.0332), ('u', 'c'): tensor(0.0332), ('p', '<stop>'): tensor(0.0331), ('c', 'l'): tensor(0.0331), ('a', 's'): tensor(0.0330), ('o', 'm'): tensor(0.0330), ('r', 'l'): tensor(0.0326), ('b', 'y'): tensor(0.0318), ('y', 'e'): tensor(0.0309), 
('z', 'u'): tensor(0.0309), ('a', 'd'): tensor(0.0308), ('i', 't'): tensor(0.0306), ('u', 'k'): tensor(0.0300), ('y', 'r'): tensor(0.0299), ('c', 'y'): tensor(0.0297), ('<stop>', 'z'): tensor(0.0290), ('i', 'c'): tensor(0.0288), ('x', 'z'): tensor(0.0287), ('e', 't'): tensor(0.0284), ('h', 'y'): tensor(0.0281), ('w', 'u'): tensor(0.0280), ('y', 'd'): tensor(0.0279), ('y', 'o'): tensor(0.0278), ('k', 'l'): tensor(0.0278), ('u', 'w'): tensor(0.0278), ('c', '<stop>'): tensor(0.0277), ('<stop>', 'h'): tensor(0.0273), ('d', 'd'): tensor(0.0273), ('n', 'o'): tensor(0.0271), ('h', 'r'): tensor(0.0269), ('s', 'y'): tensor(0.0266), ('u', 't'): tensor(0.0265), ('w', 'h'): tensor(0.0258), ('m', 'm'): tensor(0.0254), ('n', 'y'): tensor(0.0254), ('i', 'k'): tensor(0.0252), ('b', 'd'): tensor(0.0250), ('i', 'd'): tensor(0.0249), ('j', '<stop>'): tensor(0.0248), ('w', 'r'): tensor(0.0248), ('a', 'v'): tensor(0.0246), ('h', 'l'): tensor(0.0244), ('t', 'l'): tensor(0.0242), ('i', 'g'): tensor(0.0242), ('n', 't'): tensor(0.0242), ('i', 'm'): tensor(0.0242), ('o', 'd'): tensor(0.0241), ('l', 'u'): tensor(0.0233), ('f', 'l'): tensor(0.0232), ('s', 'u'): tensor(0.0229), ('e', 'v'): tensor(0.0227), ('w', 's'): tensor(0.0226), ('o', 'v'): tensor(0.0223), ('h', 'u'): tensor(0.0219), ('k', 'r'): tensor(0.0218), ('c', 'r'): tensor(0.0218), ('o', 'h'): tensor(0.0217), ('d', 'h'): tensor(0.0217), ('m', 'u'): tensor(0.0211), ('f', 't'): tensor(0.0210), ('<stop>', 'g'): tensor(0.0209), ('a', 'e'): tensor(0.0205), ('a', 't'): tensor(0.0203), ('r', 'u'): tensor(0.0199), ('y', 'i'): tensor(0.0197), ('z', 'z'): tensor(0.0192), ('k', 's'): tensor(0.0190), ('v', 'r'): tensor(0.0190), ('t', 'z'): tensor(0.0190), ('o', 'a'): tensor(0.0189), ('e', 'd'): tensor(0.0189), ('u', 'h'): tensor(0.0188), ('<stop>', 'i'): tensor(0.0185), ('z', 'h'): tensor(0.0183), ('h', 'n'): tensor(0.0183), ('o', 'b'): tensor(0.0178), ('p', 't'): tensor(0.0175), ('b', 'u'): tensor(0.0174), ('g', 'l'): tensor(0.0171), ('m', 
'b'): tensor(0.0170), ('d', 'u'): tensor(0.0169), ('a', 'k'): tensor(0.0168), ('o', 'e'): tensor(0.0168), ('<stop>', 'y'): tensor(0.0167), ('g', 't'): tensor(0.0166), ('g', 'y'): tensor(0.0166), ('f', 'y'): tensor(0.0166), ('p', 'l'): tensor(0.0166), ('p', 's'): tensor(0.0166), ('r', 't'): tensor(0.0165), ('a', 'a'): tensor(0.0164), ('<stop>', 'p'): tensor(0.0161), ('g', 's'): tensor(0.0161), ('a', 'b'): tensor(0.0160), ('b', 'h'): tensor(0.0159), ('j', 'h'): tensor(0.0159), ('i', 'z'): tensor(0.0157), ('h', 'm'): tensor(0.0155), ('u', 'g'): tensor(0.0153), ('i', 'v'): tensor(0.0153), ('y', 'm'): tensor(0.0152), ('n', 's'): tensor(0.0152), ('w', 'l'): tensor(0.0151), ('r', 's'): tensor(0.0150), ('z', 'm'): tensor(0.0150), ('o', 't'): tensor(0.0150), ('n', 'g'): tensor(0.0150), ('r', 'd'): tensor(0.0148), ('m', 'r'): tensor(0.0148), ('b', 'b'): tensor(0.0147), ('q', 'w'): tensor(0.0147), ('u', 'z'): tensor(0.0147), ('o', 'o'): tensor(0.0146), ('g', 'n'): tensor(0.0145), ('y', 'u'): tensor(0.0145), ('o', 'w'): tensor(0.0145), ('o', 'c'): tensor(0.0145), ('t', 'u'): tensor(0.0142), ('g', 'w'): tensor(0.0140), ('a', 'c'): tensor(0.0139), ('z', 'r'): tensor(0.0138), ('g', 'g'): tensor(0.0135), ('e', 'o'): tensor(0.0132), ('o', 'y'): tensor(0.0131), ('<stop>', 'f'): tensor(0.0130), ('a', 'z'): tensor(0.0129), ('r', 'm'): tensor(0.0128), ('p', 'y'): tensor(0.0127), ('<stop>', 'o'): tensor(0.0123), ('c', 'c'): tensor(0.0122), ('f', 'u'): tensor(0.0122), ('u', 'v'): tensor(0.0121), ('o', 'p'): tensor(0.0121), ('y', 'c'): tensor(0.0119), ('<stop>', 'v'): tensor(0.0118), ('n', 'c'): tensor(0.0117), ('a', 'u'): tensor(0.0113), ('s', 'm'): tensor(0.0112), ('u', 'x'): tensor(0.0112), ('r', 'n'): tensor(0.0111), ('d', 'l'): tensor(0.0111), ('q', 's'): tensor(0.0110), ('q', 'm'): tensor(0.0110), ('q', 'o'): tensor(0.0110), ('y', 'v'): tensor(0.0109), ('y', 't'): tensor(0.0107), ('n', 'l'): tensor(0.0107), ('g', 'd'): tensor(0.0104), ('s', 'k'): tensor(0.0102), ('c', 't'): 
tensor(0.0102), ('c', 'u'): tensor(0.0102), ('k', 'u'): tensor(0.0101), ('l', 'd'): tensor(0.0100), ('w', 't'): tensor(0.0097), ('w', 'd'): tensor(0.0097), ('<stop>', 'w'): tensor(0.0096), ('r', 'h'): tensor(0.0096), ('h', 't'): tensor(0.0095), ('e', 'z'): tensor(0.0089), ('y', 'k'): tensor(0.0089), ('o', 'i'): tensor(0.0088), ('e', 'k'): tensor(0.0088), ('o', 'k'): tensor(0.0087), ('x', 'u'): tensor(0.0086), ('x', 'd'): tensor(0.0086), ('y', 'z'): tensor(0.0081), ('n', 'z'): tensor(0.0080), ('r', 'c'): tensor(0.0079), ('m', 'c'): tensor(0.0078), ('f', 's'): tensor(0.0077), ('e', 'c'): tensor(0.0075), ('w', 'k'): tensor(0.0075), ('s', 'c'): tensor(0.0075), ('e', 'h'): tensor(0.0075), ('q', 'r'): tensor(0.0074), ('q', 'l'): tensor(0.0074), ('q', 'e'): tensor(0.0074), ('x', 'c'): tensor(0.0072), ('r', 'k'): tensor(0.0072), ('k', 'w'): tensor(0.0069), ('o', 'z'): tensor(0.0069), ('s', 'r'): tensor(0.0069), ('l', 's'): tensor(0.0068), ('e', 'x'): tensor(0.0065), ('t', 's'): tensor(0.0065), ('s', 'p'): tensor(0.0064), ('u', 'f'): tensor(0.0064), ('r', 'v'): tensor(0.0064), ('i', 'b'): tensor(0.0063), ('i', 'u'): tensor(0.0062), ('e', 'g'): tensor(0.0062), ('r', 'g'): tensor(0.0061), ('e', 'b'): tensor(0.0060), ('m', 'p'): tensor(0.0059), ('v', 'l'): tensor(0.0058), ('d', 'n'): tensor(0.0058), ('o', 'x'): tensor(0.0058), ('i', 'f'): tensor(0.0058), ('x', 'w'): tensor(0.0057), ('x', 'f'): tensor(0.0057), ('o', 'g'): tensor(0.0057), ('d', 'm'): tensor(0.0056), ('l', 't'): tensor(0.0056), ('f', 'n'): tensor(0.0055), ('f', 'w'): tensor(0.0055), ('d', 's'): tensor(0.0055), ('i', 'h'): tensor(0.0054), ('u', 'p'): tensor(0.0054), ('m', 's'): tensor(0.0054), ('a', 'x'): tensor(0.0054), ('k', 'n'): tensor(0.0054), ('n', 'u'): tensor(0.0053), ('h', 'v'): tensor(0.0053), ('l', 'v'): tensor(0.0052), ('a', 'j'): tensor(0.0052), ('i', 'x'): tensor(0.0051), ('a', 'g'): tensor(0.0050), ('p', 'u'): tensor(0.0049), ('u', 'j'): tensor(0.0048), ('a', 'w'): tensor(0.0048), ('d', 'g'): 
tensor(0.0047), ('i', 'i'): tensor(0.0047), ('u', 'y'): tensor(0.0045), ('o', 'f'): tensor(0.0044), ('l', 'm'): tensor(0.0044), ('d', 'w'): tensor(0.0044), ('i', 'j'): tensor(0.0044), ('<stop>', 'x'): tensor(0.0042), ('h', 's'): tensor(0.0042), ('k', 'k'): tensor(0.0042), ('j', 'r'): tensor(0.0041), ('t', 'n'): tensor(0.0041), ('e', 'p'): tensor(0.0041), ('e', 'f'): tensor(0.0041), ('a', 'f'): tensor(0.0040), ('h', 'k'): tensor(0.0039), ('l', 'b'): tensor(0.0038), ('j', 'y'): tensor(0.0038), ('m', 'd'): tensor(0.0038), ('g', 'm'): tensor(0.0036), ('k', 't'): tensor(0.0036), ('u', 'o'): tensor(0.0035), ('u', 'q'): tensor(0.0035), ('v', 'n'): tensor(0.0035), ('j', 'l'): tensor(0.0034), ('e', 'u'): tensor(0.0034), ('b', 's'): tensor(0.0034), ('c', 'q'): tensor(0.0034), ('f', 'k'): tensor(0.0033), ('f', 'z'): tensor(0.0033), ('r', 'b'): tensor(0.0033), ('h', 'd'): tensor(0.0033), ('d', 'v'): tensor(0.0033), ('t', 'c'): tensor(0.0032), ('w', 'w'): tensor(0.0032), ('w', 'm'): tensor(0.0032), ('w', 'f'): tensor(0.0032), ('n', 'k'): tensor(0.0032), ('y', 'g'): tensor(0.0032), ('m', 'n'): tensor(0.0032), ('v', 'v'): tensor(0.0031), ('v', 'u'): tensor(0.0031), ('s', 'w'): tensor(0.0031), ('s', 'n'): tensor(0.0031), ('n', 'v'): tensor(0.0031), ('i', 'p'): tensor(0.0031), ('i', 'q'): tensor(0.0030), ('y', 'x'): tensor(0.0030), ('p', 'b'): tensor(0.0029), ('<stop>', 'q'): tensor(0.0029), ('t', 'v'): tensor(0.0029), ('x', 'n'): tensor(0.0029), ('x', 'h'): tensor(0.0029), ('x', 'b'): tensor(0.0029), ('x', 'm'): tensor(0.0029), ('y', 'b'): tensor(0.0029), ('j', 's'): tensor(0.0028), ('h', 'z'): tensor(0.0028), ('e', 'j'): tensor(0.0027), ('s', 'b'): tensor(0.0027), ('e', 'w'): tensor(0.0025), ('<stop>', 'u'): tensor(0.0025), ('n', 'r'): tensor(0.0025), ('n', 'j'): tensor(0.0025), ('y', 'j'): tensor(0.0025), ('y', 'y'): tensor(0.0025), ('a', 'p'): tensor(0.0024), ('j', 'w'): tensor(0.0024), ('y', 'h'): tensor(0.0024), ('f', 'h'): tensor(0.0022), ('f', 'g'): tensor(0.0022), ('t', 
'w'): tensor(0.0022), ('w', 'b'): tensor(0.0022), ('w', 'z'): tensor(0.0022), ('w', 'g'): tensor(0.0022), ('o', 'j'): tensor(0.0021), ('z', 'b'): tensor(0.0021), ('z', 'n'): tensor(0.0021), ('z', 't'): tensor(0.0021), ('z', 's'): tensor(0.0021), ('g', 'b'): tensor(0.0021), ('g', 'j'): tensor(0.0021), ('j', 'm'): tensor(0.0021), ('j', 'v'): tensor(0.0021), ('r', 'j'): tensor(0.0020), ('k', 'm'): tensor(0.0020), ('p', 'j'): tensor(0.0019), ('p', 'f'): tensor(0.0019), ('p', 'n'): tensor(0.0019), ('p', 'c'): tensor(0.0019), ('p', 'k'): tensor(0.0019), ('p', 'm'): tensor(0.0019), ('b', 'n'): tensor(0.0019), ('r', 'z'): tensor(0.0019), ('a', 'o'): tensor(0.0019), ('l', 'c'): tensor(0.0019), ('s', 'v'): tensor(0.0019), ('d', 'j'): tensor(0.0018), ('m', 'z'): tensor(0.0018), ('a', 'q'): tensor(0.0018), ('l', 'k'): tensor(0.0018), ('r', 'w'): tensor(0.0017), ('j', 'd'): tensor(0.0017), ('j', 'c'): tensor(0.0017), ('c', 's'): tensor(0.0017), ('z', 'w'): tensor(0.0017), ('l', 'f'): tensor(0.0016), ('y', 'p'): tensor(0.0016), ('v', 'k'): tensor(0.0016), ('n', 'h'): tensor(0.0015), ('h', 'w'): tensor(0.0014), ('l', 'h'): tensor(0.0014), ('c', 'z'): tensor(0.0014), ('l', 'r'): tensor(0.0014), ('s', 'z'): tensor(0.0014), ('r', 'q'): tensor(0.0013), ('y', 'f'): tensor(0.0013), ('h', 'j'): tensor(0.0013), ('u', 'u'): tensor(0.0013), ('z', 'k'): tensor(0.0013), ('z', 'j'): tensor(0.0013), ('z', 'c'): tensor(0.0013), ('z', 'v'): tensor(0.0013), ('z', 'p'): tensor(0.0013), ('z', 'd'): tensor(0.0013), ('s', 'd'): tensor(0.0012), ('l', 'w'): tensor(0.0012), ('m', 'j'): tensor(0.0012), ('h', 'b'): tensor(0.0012), ('r', 'p'): tensor(0.0012), ('l', 'p'): tensor(0.0011), ('b', 't'): tensor(0.0011), ('c', 'j'): tensor(0.0011), ('c', 'x'): tensor(0.0011), ('d', 'f'): tensor(0.0011), ('n', 'm'): tensor(0.0011), ('l', 'n'): tensor(0.0011), ('g', 'f'): tensor(0.0010), ('g', 'z'): tensor(0.0010), ('g', 'v'): tensor(0.0010), ('j', 'n'): tensor(0.0010), ('j', 'k'): tensor(0.0010), ('j', 'j'): 
tensor(0.0010), ('j', 't'): tensor(0.0010), ('d', 't'): tensor(0.0009), ('m', 'h'): tensor(0.0009), ('m', 'l'): tensor(0.0009), ('t', 'm'): tensor(0.0009), ('c', 'g'): tensor(0.0008), ('z', 'g'): tensor(0.0008), ('z', 'x'): tensor(0.0008), ('l', 'z'): tensor(0.0008), ('r', 'f'): tensor(0.0008), ('v', 'b'): tensor(0.0008), ('v', 'd'): tensor(0.0008), ('v', 'h'): tensor(0.0008), ('b', 'c'): tensor(0.0008), ('b', 'j'): tensor(0.0008), ('m', 't'): tensor(0.0008), ('e', 'q'): tensor(0.0007), ('d', 'c'): tensor(0.0007), ('d', 'k'): tensor(0.0007), ('t', 'j'): tensor(0.0007), ('y', 'q'): tensor(0.0007), ('j', 'p'): tensor(0.0007), ('j', 'b'): tensor(0.0007), ('n', 'w'): tensor(0.0007), ('n', 'f'): tensor(0.0007), ('m', 'v'): tensor(0.0006), ('k', 'c'): tensor(0.0006), ('k', 'j'): tensor(0.0006), ('k', 'z'): tensor(0.0006), ('k', 'b'): tensor(0.0006), ('k', 'd'): tensor(0.0006), ('k', 'v'): tensor(0.0006), ('c', 'd'): tensor(0.0006), ('c', 'p'): tensor(0.0006), ('t', 'g'): tensor(0.0005), ('t', 'x'): tensor(0.0005), ('t', 'f'): tensor(0.0005), ('y', 'w'): tensor(0.0005), ('i', 'w'): tensor(0.0005), ('o', 'q'): tensor(0.0005), ('l', 'g'): tensor(0.0005), ('l', 'j'): tensor(0.0005), ('n', 'b'): tensor(0.0005), ('m', 'w'): tensor(0.0005), ('k', 'f'): tensor(0.0004), ('h', 'g'): tensor(0.0004), ('h', 'f'): tensor(0.0004), ('h', 'c'): tensor(0.0004), ('n', 'x'): tensor(0.0004), ('s', 'j'): tensor(0.0004), ('s', 'f'): tensor(0.0004), ('s', 'g'): tensor(0.0004), ('d', 'z'): tensor(0.0004), ('d', 'b'): tensor(0.0004), ('d', 'q'): tensor(0.0004), ('t', 'b'): tensor(0.0004), ('n', 'p'): tensor(0.0003), ('r', 'x'): tensor(0.0003), ('m', 'f'): tensor(0.0003), ('m', 'k'): tensor(0.0003), ('l', 'q'): tensor(0.0003), ('h', 'p'): tensor(0.0003), ('h', 'h'): tensor(0.0003), ('h', 'q'): tensor(0.0003), ('s', 'q'): tensor(0.0002), ('n', 'q'): tensor(0.0002)}
# The likelihood is a product of bigram probabilities; summing their logs
# gives the log-likelihood without multiplying many small numbers together.
# NOTE(review): every distinct bigram in the table contributes exactly once,
# i.e. the sum is not weighted by how often a bigram occurs in the corpus --
# confirm that this is the intended quantity.
l = sum(torch.log(prob) for prob in bigram_p_sorted.values())

# negative log likelihood (nll): negate and average over the table entries
nll = -l / len(bigram_p_sorted)
print(nll)
tensor(4.4447)

Generate training data

word = "this"
# pair every character with its successor: one (input, target) tuple per bigram
sample = list(zip(word, word[1:]))
# zip(*pairs) transposes the pairs into (all inputs, all targets)
print(list(zip(*sample)))
[('t', 'h', 'i'), ('h', 'i', 's')]
# Build one (inputs, targets) pair of index tensors per word:
# inputs are all characters but the last, targets are all but the first,
# so ys[k][j] is the character that follows xs[k][j].
xs, ys = [], []
for word in data:
    # a word with fewer than two characters contains no bigram; the previous
    # `x, y = zip(*[])` unpacking raised ValueError on such input
    if len(word) < 2:
        continue
    ids = [stoi[ch] for ch in word]  # look each character up only once
    xs.append(torch.tensor(ids[:-1]))
    ys.append(torch.tensor(ids[1:]))
print('x:', xs[:3])
print('y', ys[:3])
x: [tensor([ 4, 12, 12]), tensor([14, 11,  8, 21,  8]), tensor([ 0, 21])]
y [tensor([12, 12,  0]), tensor([11,  8, 21,  8,  0]), tensor([21,  0])]

1-hot encoded input

# One-hot encode the input indices of every word: each index tensor becomes a
# float matrix with one row per character and len(tokens) columns.
num_classes = len(tokens)
enc = [F.one_hot(idx, num_classes=num_classes).float() for idx in xs]
print(enc[:3])
[tensor([[0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.]]), tensor([[1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0.]])]
# Show the one-hot matrix of the first word as an image: one row per input
# character, with a single non-zero cell in the column of that character's index.
plt.imshow(enc[0])
<matplotlib.image.AxesImage>

# Use the one-hot matrix of the first word as the example input X
X = enc[0]
# one row per input character, one column per token class
print(X.shape)
torch.Size([3, 27])

‘Neural net’ modeling

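We model the transition probability matrix with the activations of a neural net: a single weight matrix applied to the one-hot input, followed by exponentiation and row normalization (a softmax).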

# W holds one row per input token and one column per candidate next token.
# Its size must match the one-hot width of X, so derive it from X instead of
# hard-coding 27.
vocab_size = X.shape[1]
W = torch.randn(vocab_size, vocab_size)
# X is one-hot, so X @ W simply selects the row of W for each input character
logits = X @ W
# exponentiate to get positive pseudo-counts, then normalise every row so it
# sums to 1 (together: a softmax over the next-token dimension)
counts = logits.exp()
probs = counts / counts.sum(dim=1, keepdim=True)
print(probs)
tensor([[0.0337, 0.0199, 0.0826, 0.0366, 0.0263, 0.0144, 0.0057, 0.0167, 0.0242,
         0.0250, 0.0264, 0.2481, 0.0328, 0.0063, 0.0064, 0.0177, 0.0164, 0.0161,
         0.0125, 0.0134, 0.0285, 0.0894, 0.0408, 0.0117, 0.0030, 0.0229, 0.1224],
        [0.0113, 0.0214, 0.0160, 0.1308, 0.0192, 0.0053, 0.0200, 0.0808, 0.0523,
         0.0327, 0.0161, 0.0189, 0.0030, 0.0153, 0.0244, 0.0132, 0.0052, 0.0366,
         0.0078, 0.0281, 0.0100, 0.0147, 0.0354, 0.0152, 0.1898, 0.0428, 0.1339],
        [0.0113, 0.0214, 0.0160, 0.1308, 0.0192, 0.0053, 0.0200, 0.0808, 0.0523,
         0.0327, 0.0161, 0.0189, 0.0030, 0.0153, 0.0244, 0.0132, 0.0052, 0.0366,
         0.0078, 0.0281, 0.0100, 0.0147, 0.0354, 0.0152, 0.1898, 0.0428, 0.1339]])

KenLM

We rely on the efficient KenLM implementation for larger n-gram models that are usable in production.


source

KenLM

 KenLM (arpa_path:str, vocab:List)

Initialize self. See help(type(self)) for accurate signature.

Preprocess data into kenlm format

Tokens are separated by spaces, with one sentence per line.

# Read one name per line and split every name into its list of characters;
# from here on `df` is a Series of character lists.
df = pd.read_csv('../data/text/names.txt', names=['name'], header=None)['name'].apply(list)
# (append an '<eos>' marker to every list here if an explicit end token is needed)
print(df.head())
# KenLM's input format expects the tokens of a sentence separated by spaces
df_toks = df.str.join(' ')
print(df_toks.head())
0                [e, m, m, a]
1          [o, l, i, v, i, a]
2                   [a, v, a]
3    [i, s, a, b, e, l, l, a]
4          [s, o, p, h, i, a]
Name: name, dtype: object
0            e m m a
1        o l i v i a
2              a v a
3    i s a b e l l a
4        s o p h i a
Name: name, dtype: object

Unique tokens

# Collect the distinct characters over all names; every element of `df` is
# one name already split into a list of characters.
tokens = {char for chars in df for char in chars}

print(tokens)
len(tokens)
{'f', 'e', 'q', 't', 'a', 'c', 'x', 'p', 'm', 'z', 'u', 'l', 'j', 'd', 'h', 'w', 'v', 'o', 'i', 'r', 'b', 'g', 'n', 'y', 's', 'k'}
26

Save data to kenlm format for training

data_file = df.to_csv('../data/text/names.kenlm.txt', header=None, index=None)
! bzip2 -kz ../data/text/names.kenlm.txt
! bzcat ../data/text/names.kenlm.txt.bz2 | head
bzip2: Output file ../data/names.kenlm.txt.bz2 already exists.
"['e', 'm', 'm', 'a']"
"['o', 'l', 'i', 'v', 'i', 'a']"
"['a', 'v', 'a']"
"['i', 's', 'a', 'b', 'e', 'l', 'l', 'a']"
"['s', 'o', 'p', 'h', 'i', 'a']"
"['c', 'h', 'a', 'r', 'l', 'o', 't', 't', 'e']"
"['m', 'i', 'a']"
"['a', 'm', 'e', 'l', 'i', 'a']"
"['h', 'a', 'r', 'p', 'e', 'r']"
"['e', 'v', 'e', 'l', 'y', 'n']"

bzcat: I/O or other error, bailing out.  Possible reason follows.
bzcat: Broken pipe
    Input file = ../data/names.kenlm.txt.bz2, output file = (stdout)

Train KenLM n-gram model

https://lukesalamone.github.io/posts/running-simple-language-model/

KenLM requires the data to be one lowercase sentence per line.

# Estimate a 2-gram (-o 2) model with lmplz from the compressed corpus on
# stdin and write it in ARPA format, then convert the ARPA file to KenLM's
# binary format with build_binary.
# The `[ ! -f ... ]` guards skip a step whenever its output file already
# exists: delete the .arpa / .kenlm files to retrain after the corpus changes.
! if [ ! -f "../data/text/names.2gram.arpa" ]; then lmplz --discount_fallback -o 2 < ../data/text/names.kenlm.txt.bz2>../data/text/names.2gram.arpa; fi
! if [ ! -f "../data/text/names.2gram.kenlm" ]; then build_binary ../data/text/names.2gram.arpa ../data/text/names.2gram.kenlm; fi
Reading ../data/names.2gram.arpa
----5---10---15---20---25---30---35---40---45---50---55---60---65---70---75---80---85---90---95--100
****************************************************************************************************
SUCCESS

Test the original KenLM Python API probabilities

model = kenlm.LanguageModel('../data/text/names.2gram.kenlm')
sentence = "emma"
# KenLM splits its input on whitespace, so characters must be space-separated
tokenized = " ".join(sentence)
words = ['<s>'] + list(sentence) + ['</s>']
print(words)
final = 0
# full_scores yields one (score, matched n-gram length, oov flag) triple per
# scored token: every token of the sentence, then </s>. <s> is context only and
# is never scored, so the i-th triple belongs to words[i + 1]; the matched
# n-gram ends at that token and reaches `length` tokens back.
# The printed "prob" values are log10 probabilities, hence they add up.
for i, (prob, length, oov) in enumerate(model.full_scores(tokenized)):
    print(f'words: {words[i+2-length:i+2]} index:{i}, prob:{prob}, length:{length}, oov:{oov}')
    final += prob

# the per-token scores must sum to the whole-sentence score
print(final)
print(model.score(tokenized))
# bos / eos control whether <s> is used as left context and whether </s> is scored
print(f'prob <s> e: {model.score("e", bos=True, eos=False)}')
print(f'prob e: {model.score("e", bos=False, eos=False)}')
print(f'prob <s> e m: {model.score("e m", bos=True, eos=False)}')
print(f'prob e m: {model.score("e m", bos=False, eos=False)}')
# Same scores through the incremental state API: start from the <s> context,
# score one token at a time, and swap the states so the output state of one
# step becomes the input state of the next.
state = kenlm.State()
state2 = kenlm.State()
model.BeginSentenceWrite(state)
accum = 0
accum += model.BaseScore(state, "e", state2)
print(f'prob <s> e: {accum}')
state, state2 = state2, state
accum += model.BaseScore(state, "m", state2)
print(f'prob <s> e m: {accum}')
['<s>', 'e', 'm', 'm', 'a', '</s>']
words: ['<s>'] index:0, prob:-6.039403915405273, length:1, oov:True
words: ['e'] index:1, prob:-3.074934959411621, length:1, oov:True
words: ['m'] index:2, prob:-3.074934959411621, length:1, oov:True
words: ['m'] index:3, prob:-3.074934959411621, length:1, oov:True
words: ['a'] index:4, prob:-1.7287936210632324, length:1, oov:False
-16.99300241470337
-16.99300193786621
prob <s> e: -6.039403915405273
prob e: -3.074934959411621
prob <s> e m: -9.114338874816895
prob e m: -6.149869918823242
prob <s> e: -6.039403915405273
prob <s> e m: -9.114338874816895

Define LM vocabulary

# register the sentence-boundary and unknown-token markers as tokens too
tokens.update(('<s>', '</s>', '<unk>'))
print(tokens, len(tokens))
# freeze the set into a list to hand over as the LM vocabulary
vocab = list(tokens)
{'f', 'e', 'q', 't', 'a', 'c', 'x', 'p', 'm', 'z', '<s>', 'u', 'l', 'j', 'd', 'h', 'w', 'v', '<unk>', 'o', 'i', 'r', '</s>', 'b', 'g', 'n', 'y', 's', 'k'} 29

Inference / Sampling from prob distributions

lm = KenLM('../data/text/names.2gram.kenlm', vocab)
# prefix every generated sequence starts from (space-separated tokens)
init_char = '<s> e m m'
# Draw 50 samples: reset the model, feed the prefix, then repeatedly sample
# the next token from the model's distribution until a boundary token appears.
for i in range(50):
    lm.new_sentence_init()
    lm.append(init_char)
    while True:
        # (token, score) pairs for every vocabulary entry at the current state;
        # presumably plain probabilities since log_prob=False -- TODO confirm
        probs = lm.nbest(len(vocab), log_prob=False)
        weights = torch.tensor([prob for _, prob in probs])
        try:
            index_next = int(torch.multinomial(weights, num_samples=1, replacement=True))
        except RuntimeError:
            # torch.multinomial rejects invalid weights (e.g. sum <= 0, nan, inf).
            # Only that error is handled: a bare `except` would also swallow
            # KeyboardInterrupt and make this loop hard to stop.
            print("probs too small")
            break
        char_next = probs[index_next][0]
        lm.append(char_next)
        # Stop on </s>, or on a sampled <s> unless it directly follows the
        # prefix. Same logic as the former unparenthesised `a or b and c and d`,
        # with the operator precedence spelled out.
        is_end = char_next == '</s>'
        is_restart = (
            char_next == '<s>'
            and lm.text != init_char
            and lm.text != init_char + ' <s>'
        )
        if is_end or is_restart:
            print(lm.text.replace(' ', ''))
            break
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s></s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s>l<s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s>h<s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>
<s>emm<s><s>