PyTorch Functions (3)
PyTorch 함수들에 관한 내용을 정리합니다. (3)
- dataclass
from dataclasses import dataclass  # stdlib module is `dataclasses`, not `dataclass`


@dataclass
class GPTConfig:
    """Hyperparameters for a GPT-style model.

    ``@dataclass`` generates ``__init__``, ``__repr__`` and ``__eq__`` from the
    annotated fields below, which keeps a plain configuration class short.
    Every field has a default, so ``GPTConfig()`` is valid and individual
    values can be overridden by keyword, e.g. ``GPTConfig(n_layer=6)``.
    """

    block_size: int = 1024  # presumably the maximum sequence length -- confirm
    vocab_size: int = 50257  # number of entries in the token vocabulary
    n_layer: int = 12  # number of stacked blocks
    n_head: int = 12  # presumably attention heads per block -- confirm
    n_embd: int = 768  # embedding dimension
    dropout: float = 0.1  # dropout probability
- nn.ModuleDict(), nn.ModuleList()
# nn.ModuleDict / nn.ModuleList are containers that register the modules they
# hold as submodules, so their parameters are tracked by the parent module
# (a plain Python dict / list would not do this).
nn.ModuleDict(
    dict(
        # Token embedding table: one row per vocabulary entry.
        wte=nn.Embedding(config.vocab_size, config.n_embd),
        # Position embedding table: one row per position, so it is sized by
        # block_size (presumably the max sequence length), not vocab_size.
        wpe=nn.Embedding(config.block_size, config.n_embd),
        drop=nn.Dropout(config.dropout),
        # One Block per layer, held in a ModuleList so each is registered.
        h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
        # The class is nn.LayerNorm (capital L); nn.layerNorm does not exist.
        ln_f=nn.LayerNorm(config.n_embd),
    )
)
- nn.Embedding
# nn.Embedding is a simple lookup table that stores embeddings of a fixed
# dictionary and size.
# Input: a list of indices; output: the corresponding word embeddings.
nn.Embedding(num_embeddings, embedding_dim)  # e.g. (vocab_size, n_embd)