-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathmodel.py
More file actions
43 lines (31 loc) · 2.03 KB
/
Copy pathmodel.py
File metadata and controls
43 lines (31 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import io
class Plant_Disease_Model(nn.Module):
    """ResNet-34 image classifier with its final layer resized to 38 outputs.

    The wrapped torchvision model is stored as ``self.network``, so parameter
    names in the state dict are prefixed with ``network.``.
    """

    def __init__(self):
        """Build the ResNet-34 backbone and swap in a 38-way linear head."""
        super().__init__()
        backbone = models.resnet34(pretrained=True)
        # Replace the stock fully connected head with one that emits one
        # score per plant/disease label (38 in total).
        backbone.fc = nn.Linear(backbone.fc.in_features, 38)
        self.network = backbone

    def forward(self, xb):
        """Return the wrapped network's output for the input batch ``xb``."""
        return self.network(xb)
# Preprocessing applied to every PIL image before it is fed to the model:
# resize with size=128, then convert the image to a tensor.
transform = transforms.Compose([
    transforms.Resize(size=128),
    transforms.ToTensor(),
])
# Human-readable labels for the classifier's 38 outputs. Despite the name,
# this is the list of class labels, not a count: predict_image() indexes it
# with the argmax of the model output, so position i must correspond to
# output i. Do not reorder or deduplicate entries.
# NOTE(review): the order presumably mirrors the class-index order used when
# the checkpoint was trained -- confirm against the training code.
num_classes = ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Blueberry___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Orange___Haunglongbing_(Citrus_greening)', 'Peach___Bacterial_spot', 'Peach___healthy',
'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Raspberry___healthy', 'Soybean___healthy', 'Squash___Powdery_mildew', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy']
# Module-level singleton used by predict_image(): instantiate the network,
# restore the saved weights onto the CPU, and switch to evaluation mode.
# The checkpoint path is relative, so it is resolved against the process's
# current working directory.
model = Plant_Disease_Model()
model.load_state_dict(
    torch.load(
        './Models/plantDisease-resnet34.pth',
        map_location=torch.device('cpu'),
    )
)
model.eval()
def predict_image(img):
    """Classify an encoded image and return its predicted label.

    Args:
        img: Raw bytes of an encoded image file; they are wrapped in a
            ``BytesIO`` and decoded with ``PIL.Image.open``.

    Returns:
        The entry of ``num_classes`` whose index equals the argmax of the
        model's output for this image.

    Raises:
        Whatever ``PIL.Image.open`` raises when the bytes cannot be decoded
        as an image.
    """
    # Normalise to 3-channel RGB. Without this, RGBA, grayscale or palette
    # images become 4- or 1-channel tensors, which the ResNet input layer
    # (3 input channels) rejects.
    img_pil = Image.open(io.BytesIO(img)).convert('RGB')
    tensor = transform(img_pil)
    # The model expects a batch dimension; build a batch of one.
    xb = tensor.unsqueeze(0)
    # Inference only: skip autograd bookkeeping (model.eval() alone does not
    # disable gradient tracking).
    with torch.no_grad():
        yb = model(xb)
    _, preds = torch.max(yb, dim=1)
    return num_classes[preds[0].item()]