Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions .github/workflows/python_tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
name: Python Application Test

on: [push, pull_request]

jobs:
test:
runs-on: ubuntu-latest

steps:
- name: Check out repository code
uses: actions/checkout@v2

- name: Set up Python 3.9
uses: actions/setup-python@v2
with:
python-version: '3.9'

- name: Execute unittest script
run: |
chmod +x unittest.sh
./unittest.sh
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
unittest_env/
__pycache__/
*egg-info
.vscode/
trace.html
Empty file added test/__init__.py
Empty file.
Empty file.
38 changes: 38 additions & 0 deletions test/test_node_converters/test_math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import unittest
import torch
import nobuco

class TestPyTorchToKerasConversion(unittest.TestCase):
def test_log_converter(self):
# Define the Sign model inside the test
class Log(torch.nn.Module):
def __init__(self):
super(Log, self).__init__()

def forward(self, input_tensor):
return torch.log(input_tensor)

# Initialize the model and input tensor
torch_model = Log()
torch_model.eval()
input_tensor = torch.randn(1, 10, 20)

# Convert the model and ensure the HTML trace is saved
keras_model = nobuco.pytorch_to_keras(
torch_model,
args=[input_tensor], kwargs=None,
inputs_channel_order=nobuco.ChannelOrder.TENSORFLOW,
outputs_channel_order=nobuco.ChannelOrder.TENSORFLOW,
save_trace_html=True
)

# Read the contents of the trace.html file
with open('trace.html', 'r', encoding='utf-8') as file:
trace_html = file.read()

# Assertions for the content of trace_html
self.assertNotIn('Max diff', trace_html, "The trace HTML should not contain 'Max diff'")


if __name__ == '__main__':
unittest.main()
34 changes: 34 additions & 0 deletions unittest.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
#!/bin/bash

# Run `chmod +x unittest.sh` first

# Define the environment name
ENV_NAME="unittest_env"

# Remove the previous virtual environment to ensure a fresh setup
echo "Removing any existing virtual environment..."
rm -rf $ENV_NAME

# Create a new virtual environment
echo "Creating a new virtual environment..."
python3 -m venv $ENV_NAME

# Activate the virtual environment
echo "Activating the virtual environment..."
source $ENV_NAME/bin/activate

# Upgrade pip to its latest version
echo "Upgrading pip..."
pip install --upgrade pip

# Install dependencies from requirements.txt
echo "Installing dependencies from requirements.txt..."
pip install -r requirements.txt

# Run unit tests
echo "Running unit tests..."
python -m unittest

# Deactivate the virtual environment
echo "Deactivating the virtual environment..."
deactivate