當前位置:編程學習大全網 - 圖片素材 - 樹莓派使用PaddleX做物體分類

樹莓派使用PaddleX做物體分類

1.先使用百度AI運行代碼。參考/aistudio/projectdetail/2160041鏈接網址,從而得到模型。但是paddlex運行得到的模型不能直接在樹莓派上跑。所以進行第二步。

2.把模型轉換成paddle-lite支持的模型。在百度studio,上壹步的代碼裏運行

paddle_lite_opt --model_fie=妳的模型途徑

--param_file=妳的權值途徑

--valid_targets=arm

--optimize_out_type=naive_buffer

--optimize_out=妳要的輸出nb模型的途徑和名稱

3.執行以下分類代碼,修改屬於妳的參數

from paddlelite.lite import *

import cv2

import numpy as np

import sys

import time

from PIL import Image

from PIL import ImageFont

from PIL import ImageDraw

# 加載模型

def create_predictor(model_dir):

config = MobileConfig()

config.set_model_from_file(model_dir)

predictor = create_paddle_predictor(config)

return predictor

#圖像歸壹化處理

def process_img(image, input_image_size):

origin = image

img = origin.resize(input_image_size, Image.BILINEAR)

resized_img = img.copy()

if img.mode != 'RGB':

img = img.convert('RGB')

img = np.array(img).astype('float32').transpose((2, 0, 1)) # HWC to CHW

img -= 127.5

img *= 0.007843

img = img[np.newaxis, :]

return origin,img

# 預測

def predict(image, predictor, input_image_size):

#輸入數據處理

input_tensor = predictor.get_input(0)

input_tensor.resize([1, 3, input_image_size[0], input_image_size[1]])

image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA))

origin, img = process_img(image, input_image_size)

image_data = np.array(img).flatten().tolist()

input_tensor.set_float_data(image_data)

#執行預測

predictor.run()

#獲取輸出

output_tensor = predictor.get_output(0)

print("output_tensor.float_data()[:] : ", output_tensor.float_data()[:])

res = output_tensor.float_data()[:]

return res

# 展示結果

def post_res(label_dict, res):

print(max(res))

target_index = res.index(max(res))

print("結果是:" + " " + label_dict[target_index])

if __name__ == '__main__':

# 初始定義

label_dict = {0:"metal", 1:"paper", 2:"plastic", 3:"glass"}

image = "./test_pic/images_orginal/glass/glass300.jpg"

model_dir = "./trained_model/ResNet50_trash_x86_model.nb"

image_size = (224, 224)

# 初始化

predictor = create_predictor(model_dir)

# 讀入圖片

image = cv2.imread(image)

# 預測

res = predict(image, predictor, image_size)

# 顯示結果

post_res(label_dict, res)

cv2.imshow("image", image)

cv2.waitKey()

  • 上一篇:i世界“足跡”正在壹點壹滴遍布全國,描繪屬於我們的故事~~
  • 下一篇:30歲大齡剩女:壹個真正優秀的女人,年齡不會成為劣勢,妳怎麽看?
  • copyright 2024編程學習大全網