-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path: train_HTM.py
112 lines (88 loc) · 4.36 KB
/
train_HTM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#import libraries------------------------------------
from datasets import load_dataset
from tokenizers import decoders, models, normalizers, pre_tokenizers, processors, trainers, Tokenizer
from transformers import BertTokenizer
import numpy as np
from htm.bindings.sdr import SDR
from htm.algorithms import TemporalMemory as TM
from tqdm import tqdm
#settings--------------------------------------------
# One SDR column per vocabulary entry: each token id maps to exactly one
# active column (one-hot), so the SDR width must equal the vocab size.
vocab_size = 10000
batch_size = 1000  # NOTE(review): defined but not used in this script — confirm intent
arraySize = vocab_size
inputSDR = SDR( arraySize )
# Temporal Memory learns transitions between consecutive token SDRs.
# Thresholds of 1 make the TM react to a single active column, which matches
# the one-hot encoding used below (a default of 10/13 would never fire).
tm = TM(columnDimensions = (inputSDR.size,),
cellsPerColumn = 32, # default: 32
minThreshold = 1, # default: 10
activationThreshold = 1, # default: 13
initialPermanence = 0.4, # default: 0.21
connectedPermanence = 0.5, # default: 0.5
permanenceIncrement = 0.1, # default: 0.1
permanenceDecrement = 0.1, # default: 0.1
predictedSegmentDecrement = 0.0, # default: 0.0 --> #set to 0.05?
maxSegmentsPerCell = 1, # default: 255
maxSynapsesPerSegment = 1 # default: 255
)
#functions-------------------------------------------
# def formatSdr(sdr):
# result = ''
# for i in range(sdr.size):
# if i > 0 and i % 8 == 0:
# result += ' '
# result += str(sdr.dense.flatten()[i])
# return result
#acquire data----------------------------------------
# WikiText-2 (raw) training split; each row is a dict with a 'text' field.
dataset = load_dataset("wikitext", name="wikitext-2-raw-v1", split="train")
#slice small portion
#dataset = dataset.select(range(10))
#acquire tokenizer-----------------------------------
# Pre-trained tokenizer file produced by a separate training step;
# its vocab size must match vocab_size above for the one-hot SDR encoding.
custom_tokenizer = Tokenizer.from_file("my-new-tokenizer.json") #self trained
#training tokenizers is quick (<30s)
# Train the Temporal Memory on token-id sequences for 3 full passes over
# the dataset. Each token is encoded as a one-hot SDR (one active column
# at index == token id) and fed to the TM with learning enabled.
for cycle in range(3):
    description = f'Processing sentences, cycle = {str(cycle+1)}'
    for sentence in tqdm(dataset, desc=description):
        #tokenize sentences----------------------------------
        # Each dataset row is a dict; extract the raw text field.
        # (str(sentence) would serialize the whole dict, including the
        # "{'text': ...}" wrapper, and pollute the token stream.)
        sequence = sentence['text']
        encodings = custom_tokenizer.encode(sequence)
        #print(encodings.tokens) #display
        for token_id in encodings.ids:  # renamed from `id` (shadowed builtin)
            #encode to SDR---------------------------------------
            # One-hot encoding: a single active column at the token's id.
            # This has no semantic meaning; ideally related words would
            # share overlapping bits.
            inputSDR.sparse = [token_id]
            #pass into TM----------------------------------------
            tm.compute(inputSDR, learn = True)
            #print the active cell ids --------------------------
            # active_cell_ids = tm.cellsToColumns(tm.getActiveCells()).sparse
            # print('active cells = ', active_cell_ids)
            # decoded_string = custom_tokenizer.decode(active_cell_ids)
            # print('current token: ', decoded_string) #print the current processing token
            # tm.activateDendrites(True) #necessary, call before getPredictiveCells
            # #print/acquire the predicted cell ids
            # predicted_cell_ids = tm.cellsToColumns(tm.getPredictiveCells()).sparse
            # decoded_string = custom_tokenizer.decode(predicted_cell_ids)
            # print('predicted next token: ', decoded_string) #print the current processing token
#save trained model----------------------------------
# Persist the trained TemporalMemory state to disk so it can be reloaded
# for inference without retraining.
filename = 'trained_HTM'
# Serialization format: 'BINARY', 'PORTABLE', 'JSON', or 'XML'.
# Only BINARY is used here to avoid load errors.
# (renamed from `format` to avoid shadowing the builtin)
save_format = 'BINARY'
try:
    tm.saveToFile(filename, save_format)
    # fix: interpolate the actual filename instead of a literal placeholder
    print(f"TemporalMemory state saved to {filename} in {save_format} format.")
except Exception as e:
    # Best-effort save: report the failure but do not crash after a long
    # training run.
    print("Error during save:", e)
#get user input--------------------------------------
#pass into TM and get prediction---------------------
#pass prediction into TM-----------------------------
#decode and print prediction-------------------------