Pruning_Sparse_training.py
"""
1.初始化带有随机mask的网络:首先我们定义了一个包含两个线性层的神经网络,同时使用create_mask方法为每个线性层创建一个与权重相同形状的mask,通过top-k方法选择一部分元素变成0,实现了一定的稀疏性,其中sparsity_rate为稀疏率
2.训练一个epoch的pruned network:使用随机mask训练网络,然后更新mask
3.剪枝权重:将权重较小的一部分权重剪枝,对应的mask中的元素变成0
4.重新regrow同样数量的random weights:在mask中元素为0的位置随机选择与剪枝的元素数量相同,将其对应的元素重新生成
"""
# Imports
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# Define the network architecture
class SparseNet(nn.Module):
    def __init__(self, sparsity_rate, mutation_rate=0.5):
        super(SparseNet, self).__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 10)
        self.sparsity_rate = sparsity_rate
        self.mutation_rate = mutation_rate
        self.initialize_masks()  # <== 1. initialize the network with random masks

    def forward(self, x):
        x = x.view(-1, 784)
        # Multiply each weight by its binary mask, so pruned connections
        # contribute nothing to the output and receive zero gradient.
        x = x @ (self.fc1.weight * self.mask1.to(x.device)).T + self.fc1.bias
        x = torch.relu(x)
        x = x @ (self.fc2.weight * self.mask2.to(x.device)).T + self.fc2.bias
        return x
    def initialize_masks(self):
        self.mask1 = self.create_mask(self.fc1.weight, self.sparsity_rate)
        self.mask2 = self.create_mask(self.fc2.weight, self.sparsity_rate)

    def create_mask(self, weight, sparsity_rate):
        # Mask out the k smallest-magnitude entries. Because the weights are
        # freshly initialized, this amounts to a random mask.
        k = int(sparsity_rate * weight.numel())
        _, indices = torch.topk(weight.abs().view(-1), k, largest=False)  # take the k smallest elements
        mask = torch.ones_like(weight, dtype=torch.bool)
        mask.view(-1)[indices] = False
        return mask  # <== 1. a mask with the requested sparsity
    def update_masks(self):
        self.mask1 = self.mutate_mask(self.fc1.weight, self.mask1, self.mutation_rate)
        self.mask2 = self.mutate_mask(self.fc2.weight, self.mask2, self.mutation_rate)

    def mutate_mask(self, weight, mask, mutation_rate=0.5):  # weight and mask: same 2D shape
        # Number of currently active (True) entries in the mask
        num_true = torch.count_nonzero(mask)
        # Number of entries to prune and then regrow
        mutate_num = int(mutation_rate * num_true)
        # 3) prune the active weights with the lowest magnitude
        true_indices_2d = torch.where(mask)  # 2D coordinates of the active entries
        # Rank by absolute value so the smallest-magnitude weights are pruned
        true_element_1d_idx_prune = torch.topk(weight[true_indices_2d].abs(), mutate_num, largest=False)[1]
        mask[true_indices_2d[0][true_element_1d_idx_prune],
             true_indices_2d[1][true_element_1d_idx_prune]] = False
        # 4) regrow the same number of weights at random inactive positions
        false_indices = torch.nonzero(~mask)  # coordinates of the inactive entries
        # Randomly select mutate_num of the inactive positions and reactivate them
        random_indices = torch.randperm(false_indices.shape[0])[:mutate_num]
        regrow_indices = false_indices[random_indices]
        mask[regrow_indices[:, 0], regrow_indices[:, 1]] = True
        return mask
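
# A quick sanity check of the mask operations on a toy tensor (an illustrative
# sketch, not part of the training recipe): create_mask should leave exactly
# numel - int(sparsity_rate * numel) entries active, and mutate_mask should
# preserve the active count, since it regrows exactly as many weights as it
# prunes.
_net = SparseNet(sparsity_rate=0.5)
_w = torch.randn(8, 8)
_m = _net.create_mask(_w, sparsity_rate=0.5)
assert _m.sum().item() == _w.numel() - int(0.5 * _w.numel())
_m2 = _net.mutate_mask(_w, _m.clone(), mutation_rate=0.5)
assert _m2.sum().item() == _m.sum().item()  # active count unchanged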
# Set the device to CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Initialize the network, loss function, and optimizer
sparsity_rate = 0.5
model = SparseNet(sparsity_rate).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Training loop
n_epochs = 10
for epoch in range(n_epochs):
    running_loss = 0.0
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # Move the data to the device
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        # print(f"Loss: {running_loss / (batch_idx + 1)}")
    # 2)-4) prune and regrow: build new masks from the updated weights
    model.update_masks()
    print(f"Epoch {epoch + 1}/{n_epochs}, Loss: {running_loss / len(train_loader)}")