hy clear Blog

【物体検出/OBB】YOLOで転移学習したOBBモデルをブラウザで実行する(フロント編)

2024/10/17

2024/10/17

📰 アフィリエイト広告を利用しています

はじめに

YOLOで回転を含めたバウンディングボックス(Oriented Bounding Boxes, OBB)をTensorflow.jsで実行する手順のメモです。
この記事ではフロントエンドのReactで実行する手順です。

間違っている可能性がありますので、自己責任で

モデルの出力

Tensorflow.jsで使える形式でモデルを出力する。
metadata.yaml、model.jsonなどの複数のファイルが出力されるのですべて使用する。

from ultralytics import YOLO
model = YOLO("yolo11n-obb.pt")
model.export(format="saved_model")

転移学習したモデルを出力する場合は、以下のリンク

【物体検出/OBB】YOLOで転移学習したOBBモデルをブラウザで実行する(転移学習編)

モデルの読み込み

tfjs形式で出力したモデルをTensorflow.jsで読み込みます。
まずTensorflow.jsをインストールします。metadataのyamlファイルをパースするライブラリもついでにインストール。

npm install @tensorflow/tfjs js-yaml @types/js-yaml

メタデータのinterfaceを定義して情報をロードします。imgszとnamesを使用します。
imgszは推論する画像の縦横のサイズで、namesはラベルです。

interface YOLOMetadata {
  description: string
  author: string
  date: string,
  version: string
  license: string
  docs: string
  stride: number
  task: string
  batch: number
  imgsz: [number, number]
  names: string[]
}

async function loadMetadata(): Promise<YOLOMetadata | null> {
  let metadata: YOLOMetadata | null = null
  await fetch("/model/default/metadata.yaml")
    .then(response =>
      response.text())
    .then(text =>
      load(text))
    .then(yamlData =>
      metadata = (yamlData as YOLOMetadata)
    )
    .catch(error => console.error('YAML読み込みエラー:', error));
  return metadata;
}

次にYOLOのモデルを読み込んでウォームアップを実行します。最初の実行に時間がかかるためです。imgszはmetadataから取得した画像のサイズを渡します。
imgszは基本的に[640, 640]になると思います。ほかの値でテストしてないので、エラーが起きたときはここをチェックする。

async function loadYOLOModel(imgsz: [number, number]): Promise<tf.GraphModel<string | tf.io.IOHandler>> {
  const model = await tf.loadGraphModel("/model/default/model.json")
  // warm up
  tf.tidy(() => {
    const zeroTensor = tf.zeros([1, imgsz[0], imgsz[1], 3], "float32")
    model.execute(zeroTensor)
  })
  return model
}

返されたモデルを使って画像内の物体検出を行います。

推論を実行する

画像をモデルに渡して推論を実行します。
回転を含む場合の処理は次の段で行います。

推論画像の前処理

YOLOではmetadataのimgszで取得できる画像サイズに合わせて渡さないとエラーになるため、推論する画像のサイズを変更したtensorを取得します。また0~255の値を0~1の正規化をしてYOLOに渡すために次元を1つ追加しています。

  const imageTensor = tf.tidy(() => {
    let imageTensor = tf.browser.fromPixels(img).toFloat().div(tf.scalar(255.0))
    imageTensor = imageTensor.resizeBilinear(imgsz)
    imageTensor = imageTensor.expandDims(0)
    return imageTensor
  })

結果を取得

modelにimageTensorを渡して推論を実行します。

  const results = model.predict(imageTensor) as tf.Tensor<tf.Rank>

resultsは [1, 4+ラベル数, 8400]の形式で出力されます。
これは以下のような形式になっています。

[
  [x, x, ...],
  [y, y, ...],
  [w, w, ...],
  [h, h, ...],
  [class1_score, class1_score, ...],
  [class2_score, class1_score, ...],
  ...
]

(x, y)座標が中心となる幅がw、高さがhのバウンティングボックス(bbox)です。そのbboxのclass_scoreの値がラベルの数だけ続きます。デフォルトのモデルだと80個のnamesがあるので[1, 84]の形式になります。推論した結果のbboxが8400あるので[1, 84, 8400]となります。

結果を変換

8400のbboxの中にはもちろん不要なデータや重なっているbboxが大量に含まれているのでNon-Maximum Suppressionアルゴリズムで使えるデータにします。これはTensorflow.jsのtf.image.nonMaxSuppression()で行えるのでこの関数に必要なデータの形式に変換していきます。

tf.image.nonMaxSuppression(boxes, scores, maxOutputSize, iouThreshold, scoreThreshold)

iouThresholdは二つのbboxがどの程度重なっている場合に同じものを検出しているとするかの閾値です。
まずboxesを作成します。boxesの形式は[y1, x1, y2, x2]の形式が必要になるのでそれに合うようにデータを作成します。

以下のコードはメモリリークを防ぐためすべてtf.tidyの中で処理する。

const temp = result.squeeze()
// x, y, w, hを取り出す
const x = temp.slice([0, 0], [1, -1]); // x座標
const y = temp.slice([1, 0], [1, -1]); // y座標
const w = temp.slice([2, 0], [1, -1]); // 幅
const h = temp.slice([3, 0], [1, -1]); // 高さ


// (x1, y1)がbboxの左上の座標、(x2, y2)がbboxの右下の座標
const x1 = tf.sub(x , tf.div(w, 2))
const y1 = tf.sub(y , tf.div(h, 2))
const x2 = tf.add(x1, w)
const y2 = tf.add(y1, h)

const boxes = tf.stack([y1, x1, y2, x2], 2).squeeze();

最大のスコアとそのラベルのインデックスを取得する。
その後、Non-Maximum Suppressionでbboxを取得する。

const maxScores = temp.slice([4, 0], [labelCount, -1]).max(0)
const labelIndexes = temp.slice([4, 0], [labelCount, -1]).argMax(0)

const bboxIndexs = tf.image.nonMaxSuppression(boxes.as2D(boxes.shape[0], boxes.shape[1]!), maxScores.as1D(), 100, 0.5, 0.3);

データをまとめてBbox型にして返す。

Bbox.ts
interface Bbox {
  x: number,
  y: number,
  w: number
  h: number
  label: number
  score: number
}
const resultBboxes = boxes.gather(bboxIndexs, 0).arraySync() as []
const resultScores = maxScores.gather(bboxIndexs, 0).arraySync() as []
const resultLables = labelIndexes.gather(bboxIndexs, 0).arraySync() as []

return resultBboxes.map((bbox, index) => {
  return {
    x: bbox[1],
    y: bbox[0],
    w: bbox[3] - bbox[1], // 右下の座標になっているので、幅に戻す
    h: bbox[2] - bbox[0], // 上記同様
    score: resultScores[index],
    label: resultLables[index]
  }
})

結果を表示する

Reactを使って推論したbboxをcanvas上に表示します。

コードのまとめ

今までのコードをまとめます。

モデルをロードする

load_model.ts
import { load } from "js-yaml";

import '@tensorflow/tfjs-backend-cpu';
import '@tensorflow/tfjs-backend-webgl';

import * as tf from '@tensorflow/tfjs';


export interface YOLOMetadata {
    description: string
    author: string
    date: string,
    version: string
    license: string
    docs: string
    stride: number
    task: string
    batch: number
    imgsz: [number, number]
    names: [number, string][]
}

export async function loadMetadata(): Promise<YOLOMetadata | null> {
    let metadata: YOLOMetadata | null = null
    await fetch("/model/default/metadata.yaml")
        .then(response =>
            response.text())
        .then(text =>
            load(text))
        .then(yamlData =>
            metadata = (yamlData as YOLOMetadata)
        )
        .catch(error => console.error('YAML読み込みエラー:', error));
    return metadata;
}

export async function loadYOLOModel(imgsz: [number, number]): Promise<tf.GraphModel<string | tf.io.IOHandler>> {
    const model = await tf.loadGraphModel("/model/default/model.json")
    // warm up
    tf.tidy(() => {
        const zeroTensor = tf.zeros([1, imgsz[0], imgsz[1], 3], "float32")
        model.execute(zeroTensor)
    })
    return model
}

推論

predict.ts
import '@tensorflow/tfjs-backend-cpu';
import '@tensorflow/tfjs-backend-webgl';

import * as tf from '@tensorflow/tfjs';

export async function predict(model: tf.GraphModel, img: HTMLImageElement, imgsz: [number, number]) {
    const imageTensor = tf.tidy(() => {
      let imageTensor = tf.browser.fromPixels(img).toFloat().div(tf.scalar(255.0))
      imageTensor = imageTensor.resizeBilinear(imgsz)
      imageTensor = imageTensor.expandDims(0)
      return imageTensor
    })
  
    const results = model.predict(imageTensor) as tf.Tensor<tf.Rank>
  
    return results
  }

結果を変換

result_to_bbox.ts

import * as tf from '@tensorflow/tfjs';

export interface Bbox {
  x: number,
  y: number,
  w: number
  h: number
  label: number
  score: number
}

export const resultToBbox = (result: tf.Tensor<tf.Rank>, labelCount: number) => {
  const bboxes = tf.tidy(() => {
    const temp = result.squeeze()
    // x, y, w, hを取り出す
    const x = temp.slice([0, 0], [1, -1]); // x座標
    const y = temp.slice([1, 0], [1, -1]); // y座標
    const w = temp.slice([2, 0], [1, -1]); // 幅
    const h = temp.slice([3, 0], [1, -1]); // 高さ

    const x1 = tf.sub(x, tf.div(w, 2))
    const y1 = tf.sub(y, tf.div(h, 2))
    const x2 = tf.add(x1, w)
    const y2 = tf.add(y1, h)
    const boxes = tf.stack([y1, x1, y2, x2], 2).squeeze();

    const maxScores = temp.slice([4, 0], [labelCount, -1]).max(0)
    const labelIndexes = temp.slice([4, 0], [labelCount, -1]).argMax(0)

    const bboxIndexs = tf.image.nonMaxSuppression(
      boxes.as2D(boxes.shape[0], boxes.shape[1]!),
      maxScores.as1D(), 100, 0.5, 0.3);

    const resultBboxes = boxes.gather(bboxIndexs, 0).arraySync() as []
    const resultScores = maxScores.gather(bboxIndexs, 0).arraySync() as []
    const resultLables = labelIndexes.gather(bboxIndexs, 0).arraySync() as []

    return resultBboxes.map((bbox, index) => {
      return {
        x: bbox[1],
        y: bbox[0],
        w: bbox[3] - bbox[1],
        h: bbox[2] - bbox[0],
        score: resultScores[index],
        label: resultLables[index]
      }
    })
  });
  return bboxes;
}

結果の表示

app.tsx
import { useEffect, useRef, useState } from 'react';


import '@tensorflow/tfjs-backend-cpu';
import '@tensorflow/tfjs-backend-webgl';

import * as tf from '@tensorflow/tfjs';
import { YOLOMetadata, loadMetadata, loadYOLOModel } from './load_model';
import { predict } from './predict';
import { Bbox, resultToBbox } from './result_to_bbox';

function App() {
  const imgRef = useRef<HTMLImageElement>(null)
  const modelRef = useRef<tf.GraphModel | null>(null)
  const metadataRef = useRef<YOLOMetadata | null>(null)

  const [labelCount, setLabelCount] = useState(0)

  const startPredict = async () => {
    if (!imgRef.current || !metadataRef.current || !modelRef.current) {
      return;
    }
    const predicts = await predict(modelRef.current, imgRef.current, metadataRef.current.imgsz)
    const bboxes = await resultToBbox(predicts, labelCount)

    drawImage(bboxes)
  }
  const drawImage = (bboxes: Bbox[]) => {
    if (!imgRef.current) {
      return;
    }
    const img = imgRef.current
    const canvas = document.getElementById("canvas") as HTMLCanvasElement;
    canvas.width = img.width
    canvas.height = img.height

    const context = canvas.getContext("2d") as CanvasRenderingContext2D;
    context.drawImage(img, 0, 0, img.width, img.height)
    context.font = '10px Arial'

    const imgsz = metadataRef.current?.imgsz ?? [640, 640]

    const scaleFactorWidth = canvas.width / imgsz[0]
    const scaleFactorHeight = canvas.height / imgsz[1]

    bboxes.forEach(bbox => {
      context.beginPath();
      context.rect(
        bbox.x * scaleFactorWidth,
        bbox.y * scaleFactorHeight,
        bbox.w * scaleFactorWidth,
        bbox.h * scaleFactorHeight)
      context.strokeStyle = "red"
      context.stroke()
      context.fillText(
        `${metadataRef.current?.names[bbox.label] ?? ""} ${bbox.score.toFixed(3)}`, bbox.x * scaleFactorWidth, bbox.y * scaleFactorHeight + -10)
    })
  }

  const setup = async () => {
    metadataRef.current = await loadMetadata()
    modelRef.current = await loadYOLOModel(metadataRef.current!.imgsz)
    setLabelCount(Object.entries(metadataRef.current?.names ?? []).length)
  }

  useEffect(() => {
    setup()
  }, [])
  return (
      <div className=' h-screen w-screen bg-gray-700 p-8 text-center'>
        <h1 className=' text-red-600'>YOLO</h1>
        <div>
          <img ref={imgRef} src={"/ob_test_image.jpg"} className="rounded-md w-[500px]" alt="process image" />
        </div>
        <button onClick={() => startPredict()}>predict</button>
        <canvas id='canvas' className="rounded-md w-[500px]"></canvas>
      </div>
  )
}

export default App

推論を実行する(Oriented Bounding Box)

次に回転を含む場合のBBoxの推論を実行する場合の手順
回転しないBBoxからの変更点は以下です。

  • スコアの後ろにラジアンが追加される
  • nonMaxSuppressionがないので自作する
  • canvasで表示するときに回転させる

追加変更されている部分のコードだけ

結果をbboxに変換

resultToBbox を以下のように修正する。nonMaxSuppressionWithRotateは次で作成する。

result_to_bbox.ts
export const resultToBbox = (result: tf.Tensor<tf.Rank>, labelCount: number) => {
  const bboxes = tf.tidy(() => {
    const temp = result.squeeze()
    // x, y, w, hを取り出す
    const x = temp.slice([0, 0], [1, -1]); // x座標
    const y = temp.slice([1, 0], [1, -1]); // y座標
    const w = temp.slice([2, 0], [1, -1]); // 幅
    const h = temp.slice([3, 0], [1, -1]); // 高さ
    const r = temp.slice([(result.shape[1] ?? 0) - 1, 0], [1, -1])// R

    const x1 = tf.sub(x, tf.div(w, 2))
    const y1 = tf.sub(y, tf.div(h, 2))
    const x2 = tf.add(x1, w)
    const y2 = tf.add(y1, h)
    const boxes = tf.stack([y1, x1, y2, x2], 2).squeeze();
    const boxesWithR = tf.stack([x, y, w, h, r], 2).squeeze();

    const maxScores = temp.slice([4, 0], [labelCount, -1]).max(0)
    const labelIndexes = temp.slice([4, 0], [labelCount, -1]).argMax(0)
    const bboxIndexs = nonMaxSuppressionWithRotate(
      boxesWithR.as2D(boxesWithR.shape[0], boxesWithR.shape[1]!),
      maxScores.as1D(), 200, 0.45, 0.4)
    
    const resultBboxes = boxes.gather(bboxIndexs, 0).arraySync() as []
    const resultScores = maxScores.gather(bboxIndexs, 0).arraySync() as []
    const resultLables = labelIndexes.gather(bboxIndexs, 0).arraySync() as []

    const rs = r.squeeze().gather(bboxIndexs, 0).arraySync() as []

    return resultBboxes.map((bbox, index) => {
      return {
        x: bbox[1],
        y: bbox[0],
        w: bbox[3] - bbox[1],
        h: bbox[2] - bbox[0],
        score: resultScores[index],
        label: resultLables[index],
        r: rs[index]
      }
    })
  });
  return bboxes;
}

nonMaxSuppressionWithRotateを作成

テストしてないので間違っているかも。

tensorflow.jsのnonMaxSuppressionはOBBに対応してないので新しく作成する。
作ったはいいがnonMaxSuppressionをそのまま使ってもそれっぽい出力にはなったし、
テストもほぼしてないので、以下のコードは自信ないです。
rotationMatrixで回転した点を演算して、IoUの計算にはturf.jsを使用している。

non_max_suppression_with_rotate.ts
export default function nonMaxSuppressionWithRotate(
  boxes: tf.Tensor2D,
  score: tf.Tensor1D,
  maxOutputSize: number,
  iouThreshold: number = 0.5,
  scoreThreshold: number = 0.3
) {
  let candidates: Candidate[] = []
  // Scoreの閾値以下を切り捨て
  const scoreArray = score.arraySync()
  for (let i = 0; i < scoreArray.length; i++) {
    if (scoreArray[i] > scoreThreshold) {
      candidates.push({ score: scoreArray[i], boxIndex: i, box: null } as Candidate)
    }
  }

  // scoreをいい順番に並び変え
  candidates.sort((a, b) => (b.score - a.score))
  // 回転した座標に変換してからturfで処理するためのポリゴンに変更
  const candidatesTensor = tf.tensor1d(candidates.map(e => e.boxIndex), "int32")
  const rotatedMatrix = rotationMatrix(boxes.gather(candidatesTensor, 0))
  const polygons = matrix2Polygons(rotatedMatrix)
  polygons.forEach((polygon, index) => {
    candidates[index].box = polygon
  })

  // 選択されたボックスのインデックスを格納する配列を初期化します。
  const selectedIndices: Candidate[] = [];

  while (candidates.length > 0) {
    const currentCandidate = candidates[0]
    // 残っている候補で一番いいスコアのboxは残す
    selectedIndices.push(candidates[0])
    // maxを超えていたら終わり
    if (selectedIndices.length >= maxOutputSize) {
      break;
    }
    // 一番いいスコアのboxと残っているboxを比較し、IoUの値が閾値より小さいもののみ候補に残す
    currentCandidate.box
    candidates.filter(box => box.boxIndex !== currentCandidate.boxIndex)
    candidates = candidates.filter((candidate) => {
      if (candidate.boxIndex === currentCandidate.boxIndex) return false;
      const iou = calculateRotatedIOU(currentCandidate.box!, candidate.box!)
      return iou < iouThreshold
    })
  }

  return tf.tensor1d(selectedIndices.map(e => e.boxIndex), "int32")
}

// IOUの計算関数
function calculateRotatedIOU(polygon_a: Feature<Polygon, GeoJsonProperties>, polygon_b: Feature<Polygon, GeoJsonProperties>): number {
  const intersectPolygon = turf.intersect(turf.featureCollection([polygon_a, polygon_b]))
  if (!intersectPolygon) {
    return 0
  }
  const unionPolygon = turf.union(turf.featureCollection([polygon_a, polygon_b]))

  if (!unionPolygon) {
    return 0
  }

  const iou = turf.area(intersectPolygon) / turf.area(unionPolygon)
  return iou

}

// 回転した4つの点を求める
function rotationMatrix(boxes: tf.Tensor<tf.Rank>) {
  const results = tf.tidy(() => {
    const [x, y, w, h, rad] = tf.split(boxes, [1, 1, 1, 1, 1], 1)
    // cosAとsinAを計算 (角度のラジアン部分)
    const cos = tf.cos(rad).squeeze();
    const sin = tf.sin(rad).squeeze();

    // x,yを中心としてx1~4, y1~4を求める
    const x1 = w.div(-2).squeeze()
    const x2 = w.div(2).squeeze()
    const y1 = h.div(-2).squeeze()
    const y2 = h.div(2).squeeze()
    // 回転した点を求めるp1~時計回りに進める
    const p1x = x1.mul(cos).sub(y1.mul(sin)).add(x.squeeze())
    const p1y = x1.mul(sin).add(y1.mul(cos)).add(y.squeeze())
    const p2x = x2.mul(cos).sub(y1.mul(sin)).add(x.squeeze())
    const p2y = x2.mul(sin).add(y1.mul(cos)).add(y.squeeze())
    const p3x = x2.mul(cos).sub(y2.mul(sin)).add(x.squeeze())
    const p3y = x2.mul(sin).add(y2.mul(cos)).add(y.squeeze())
    const p4x = x1.mul(cos).sub(y2.mul(sin)).add(x.squeeze())
    const p4y = x1.mul(sin).add(y2.mul(cos)).add(y.squeeze())

    return tf.stack([p1x, p1y, p2x, p2y, p3x, p3y, p4x, p4y])
  })
  return results
}

function matrix2Polygons(matrix: tf.Tensor<tf.Rank>): Feature<Polygon, GeoJsonProperties>[] {

  const transposedMatrix = tf.tidy(() => {
    return matrix.transpose([1, 0])

  })
  const matrixArray = transposedMatrix.arraySync() as []
  return matrixArray.map(e => {
    return turf.polygon([[
      [e[0], e[1]],
      [e[2], e[3]],
      [e[4], e[5]],
      [e[6], e[7]],
      [e[0], e[1]],
    ]])
  })
}

Canvasの表示を修正

表示するときの手順は

  • Canvasの中心点をBBoxの中心に移動する。
  • Canvasを回転させる。
  • 線を引いて、restore()する。
app.tsx
    bboxes.forEach(bbox => {
      context.save()
      context.translate(
        bbox.x * scaleFactorWidth + ((bbox.w * scaleFactorWidth) / 2),
        bbox.y * scaleFactorHeight + ((bbox.h * scaleFactorHeight) / 2)
      )
      context.rotate(bbox.r)
      context.beginPath();
      context.rect(
        -bbox.w * scaleFactorWidth / 2,
        -bbox.h * scaleFactorHeight / 2,
        bbox.w * scaleFactorWidth,
        bbox.h * scaleFactorHeight)
      context.strokeStyle = "cyan"
      context.lineWidth = 2
      context.stroke()
      context.font = '20px Roboto medium';
      context.fillStyle = 'cyan';
      context.fillText(
        metadataRef.current?.names[bbox.label] ?? "none",
        -bbox.w * scaleFactorWidth / 2,
        -bbox.h * scaleFactorHeight / 2)

      context.restore()
    })

おわり