Skip to content

version and compatibility fix #53

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,14 @@
git clone https://github.com/peterliht/knowledge-distillation-pytorch.git
```

* Install python==3.10.15 and create virtualenv
```
sudo apt install python3.10 python3.10-venv python3.10-dev
python3.10 -m venv venv
source venv/bin/activate
```


* Install the dependencies (including Pytorch)
```
pip install -r requirements.txt
Expand Down
4 changes: 2 additions & 2 deletions evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def evaluate(model, loss_fn, dataloader, metrics, params):

# move to GPU if available
if params.cuda:
data_batch, labels_batch = data_batch.cuda(async=True), labels_batch.cuda(async=True)
data_batch, labels_batch = data_batch.cuda(), labels_batch.cuda()
# fetch the next evaluation batch
data_batch, labels_batch = Variable(data_batch), Variable(labels_batch)

Expand Down Expand Up @@ -94,7 +94,7 @@ def evaluate_kd(model, dataloader, metrics, params):

# move to GPU if available
if params.cuda:
data_batch, labels_batch = data_batch.cuda(async=True), labels_batch.cuda(async=True)
data_batch, labels_batch = data_batch.cuda(), labels_batch.cuda()
# fetch the next evaluation batch
data_batch, labels_batch = Variable(data_batch), Variable(labels_batch)

Expand Down
15 changes: 8 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
scipy==1.0.0
numpy==1.14.0
Pillow==8.1.1
tabulate==0.8.2
tensorflow==1.7.0rc0
torch==0.3.0.post4
torchvision==0.2.0
scipy==1.14.1
numpy==1.25.0
Pillow==9.0.0
tabulate==0.5
tensorflow==2.8.0rc0
torch==1.13.0
torchvision==0.14.0
tqdm==4.19.8
torchnet
protobuf == 3.20
18 changes: 9 additions & 9 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ def train(model, optimizer, loss_fn, dataloader, metrics, params):
for i, (train_batch, labels_batch) in enumerate(dataloader):
# move to GPU if available
if params.cuda:
train_batch, labels_batch = train_batch.cuda(async=True), \
labels_batch.cuda(async=True)
train_batch, labels_batch = train_batch.cuda(), \
labels_batch.cuda()
# convert to torch Variables
train_batch, labels_batch = Variable(train_batch), Variable(labels_batch)

Expand Down Expand Up @@ -186,8 +186,8 @@ def train_kd(model, teacher_model, optimizer, loss_fn_kd, dataloader, metrics, p
for i, (train_batch, labels_batch) in enumerate(dataloader):
# move to GPU if available
if params.cuda:
train_batch, labels_batch = train_batch.cuda(async=True), \
labels_batch.cuda(async=True)
train_batch, labels_batch = train_batch.cuda(), \
labels_batch.cuda()
# convert to torch Variables
train_batch, labels_batch = Variable(train_batch), Variable(labels_batch)

Expand All @@ -199,7 +199,7 @@ def train_kd(model, teacher_model, optimizer, loss_fn_kd, dataloader, metrics, p
with torch.no_grad():
output_teacher_batch = teacher_model(train_batch)
if params.cuda:
output_teacher_batch = output_teacher_batch.cuda(async=True)
output_teacher_batch = output_teacher_batch.cuda()

loss = loss_fn_kd(output_batch, labels_batch, output_teacher_batch, params)

Expand All @@ -213,17 +213,17 @@ def train_kd(model, teacher_model, optimizer, loss_fn_kd, dataloader, metrics, p
# Evaluate summaries only once in a while
if i % params.save_summary_steps == 0:
# extract data from torch Variable, move to cpu, convert to numpy arrays
output_batch = output_batch.data.cpu().numpy()
labels_batch = labels_batch.data.cpu().numpy()
output_batch = output_batch.detach().cpu().numpy()
labels_batch = labels_batch.detach().cpu().numpy()

# compute all metrics on this batch
summary_batch = {metric:metrics[metric](output_batch, labels_batch)
for metric in metrics}
summary_batch['loss'] = loss.data[0]
summary_batch['loss'] = loss.item()
summ.append(summary_batch)

# update the average loss
loss_avg.update(loss.data[0])
loss_avg.update(loss.item())

t.set_postfix(loss='{:05.3f}'.format(loss_avg()))
t.update()
Expand Down
76 changes: 23 additions & 53 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def load_checkpoint(checkpoint, model, optimizer=None):
optimizer: (torch.optim) optional: resume optimizer from checkpoint
"""
if not os.path.exists(checkpoint):
raise("File doesn't exist {}".format(checkpoint))
raise(FileNotFoundError("File doesn't exist {}".format(checkpoint)))
if torch.cuda.is_available():
checkpoint = torch.load(checkpoint)
else:
Expand All @@ -163,65 +163,35 @@ def load_checkpoint(checkpoint, model, optimizer=None):
return checkpoint


class Board_Logger(object):
"""Tensorboard log utility"""

class BoardLogger:
"""TensorBoard log utility for TensorFlow 2.x"""
def __init__(self, log_dir):
"""Create a summary writer logging to log_dir."""
self.writer = tf.summary.FileWriter(log_dir)
self.writer = tf.summary.create_file_writer(log_dir)

def scalar_summary(self, tag, value, step):
"""Log a scalar variable."""
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, simple_value=value)])
self.writer.add_summary(summary, step)
with self.writer.as_default():
tf.summary.scalar(tag, value, step=step)
self.writer.flush()

def image_summary(self, tag, images, step):
"""Log a list of images."""
with self.writer.as_default():
for i, img in enumerate(images):
# Convert image to a TensorFlow-compatible format
if isinstance(img, np.ndarray):
img = tf.convert_to_tensor(img, dtype=tf.uint8)
if img.ndim == 2: # Add channel dimension for grayscale images
img = tf.expand_dims(img, axis=-1)
tf.summary.image(f"{tag}/{i}", tf.expand_dims(img, 0), step=step) # Add batch dimension
self.writer.flush()

img_summaries = []
for i, img in enumerate(images):
# Write the image to a string
try:
s = StringIO()
except:
s = BytesIO()
scipy.misc.toimage(img).save(s, format="png")

# Create an Image object
img_sum = tf.Summary.Image(encoded_image_string=s.getvalue(),
height=img.shape[0],
width=img.shape[1])
# Create a Summary value
img_summaries.append(tf.Summary.Value(tag='%s/%d' % (tag, i), image=img_sum))

# Create and write Summary
summary = tf.Summary(value=img_summaries)
self.writer.add_summary(summary, step)

def histo_summary(self, tag, values, step, bins=1000):
"""Log a histogram of the tensor of values."""

# Create a histogram using numpy
counts, bin_edges = np.histogram(values, bins=bins)

# Fill the fields of the histogram proto
hist = tf.HistogramProto()
hist.min = float(np.min(values))
hist.max = float(np.max(values))
hist.num = int(np.prod(values.shape))
hist.sum = float(np.sum(values))
hist.sum_squares = float(np.sum(values**2))

# Drop the start of the first bin
bin_edges = bin_edges[1:]

# Add bin edges and counts
for edge in bin_edges:
hist.bucket_limit.append(edge)
for c in counts:
hist.bucket.append(c)

# Create and write Summary
summary = tf.Summary(value=[tf.Summary.Value(tag=tag, histo=hist)])
self.writer.add_summary(summary, step)
self.writer.flush()
with self.writer.as_default():
# Create histogram data using numpy
counts, bin_edges = np.histogram(values, bins=bins)

# Create a histogram summary
tf.summary.histogram(tag, values, step=step)
self.writer.flush()