/ #IT学習 #AWS 

Amazon SageMakerを使ってみる【学習データ準備編】

はじめに

社内勉強会の討論ネタの確保のため、前回の記事ではAWSが提供している機械学習以外のAI関連サービスについてご紹介しました。
今回はAWSが提供する機械学習サービス「Amazon SageMaker」を使ってみようと思います。
なお、AIは完全初心者で勉強中の身なので、内容に間違いがある可能性があります。
認識違いなどございましたら、ご指摘よろしくお願いします。
(なお今回の作業にあたり、こちらの記事を参考にさせていただきました)

機械学習の流れ

超大雑把に言えば、学習用のデータを用意して、そのデータをAIのエンジンに読み込ませて学習させて、学習用のデータ以外のデータを渡して精度をチェックして、学習用データを変えてみたり、学習のパラメータを変えてみたりを繰り返していくイメージ。
今回の記事では最初の最初、学習用のデータを用意するところを解説したいと思います。

そもそもどんな機械学習を行うのか

一口に機械学習といっても、できることは色々ありますが、今回は「物体検出」を行いたいと思います。
物体検出とは、画像の中から定められた物体の位置とその分類を検出することを指します。
そして今回検出するのは、私もペットとして飼育している「デグーマウス」という動物です。

つまり、対象の画像からこのデグーマウスの位置を検出する機械学習を行おうと思います。

学習用データを用意する

物体検出の学習用データは検出する物体が含まれた画像と、その画像内のどこに物体があるのかという正解情報です。
つまり今回行う作業は、デグーマウスの画像を大量に集め、その画像内のどこにデグーマウスが写っているのかをマーキングするというものになります(この作業をタグ付けと呼びます)。

さてそれでは早速、デグーマウスの画像を手に入れるためにGoogle画像検索を見てみましょう。

たくさんの可愛いデグーマウスたちが表示されました。
これらの画像をひとつひとつ保存してもいいのですが、それは大変なので今回はツールを使おうと思います。
今回使用するのは「google-images-download」というコマンドラインツールです。
このツールはPythonで作られており、Pythonのパッケージ管理ツールである「pip」を使ってインストールすることができます。

1pip install google_images_download

インストールが完了したら、以下のコマンドでデグーマウスの画像を取得することができます。

1googleimagesdownload --keywords "デグーマウス"

これを実行すると…
やりました!

このコマンドで取得できた画像は98枚ですが、これではちょっと少ないので水増しします。
以下のPythonコード(お借りしました)を使って、1つの画像を回転させた画像を複数生成して水増ししました。

 1import os
 2from PIL import Image, ImageFilter
 3 
 4def main():
 5    data_dir_path = './out/'
 6    data_dir_path_in = './in/'
 7    file_list = os.listdir('./in/')
 8 
 9    count = 1
10    for file_name in file_list:
11        root, ext = os.path.splitext(file_name)
12        if ext == '.png' or '.jpeg' or '.jpg':
13            img = Image.open(data_dir_path_in + '/' + file_name) 
14            tmp = img.transpose(Image.FLIP_LEFT_RIGHT)
15            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
16            count+=1
17            tmp = img.transpose(Image.FLIP_TOP_BOTTOM)
18            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
19            count+=1
20            tmp = img.transpose(Image.ROTATE_90)
21            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
22            count+=1
23            tmp = img.transpose(Image.ROTATE_180)
24            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
25            count+=1
26            tmp = img.transpose(Image.ROTATE_270)
27            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
28            count+=1
29            tmp = img.rotate(15)
30            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
31            count+=1
32            tmp = img.rotate(75)
33            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
34            count+=1
35            tmp = img.rotate(135)
36            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
37            count+=1         
38            tmp = img.rotate(195)
39            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
40            count+=1
41            tmp = img.rotate(255)
42            tmp.save(data_dir_path + '/' + '{0:04d}'.format(count) +'.jpg')
43            count+=1
44 
45if __name__ == '__main__':
46    main()

こんな感じで生成されます!560枚に増やせました

次にVoTT(Visual Object Tagging Tool)というツールを使って、この画像ひとつひとつのどこにデグーマウスが写っているのかをタグ付けしていきます。
本来なら場所だけでなく分類付け(白っぽいデグーマウスと茶色のデグーマウスを区別するとか)も可能なのですが、今回は分類はデグーマウスひとつだけで、あくまで位置を検出するだけとします。

こんな画面でタグ付けします。

私はこの矩形選択で位置をタグ付けするという作業を手作業で画像560枚に対して行いました(1時間かかりました)。
もっと効率のいい方法がありそうな気もするので、ご存知の方がおりましたら、教えてください…

そうしてすべての画像のタグ付けが終わると「out.json」というファイルが出力されます。
これが画像のどこにデグーマウスが写っているのか、タグ付けした結果になります。
ただこのままではAmazon SageMakerに読み込ませることはできないみたいですので、以下のPythonコードを使って変換します。
(こちらのサイトのものをお借りしましたが、現時点のVoTTのout.jsonを読ませるとエラーが発生したので一部修正しています)

 1import json
 2 
 3file_name = './out.json'
 4class_list = {'Degu':0}
 5 
 6with open(file_name) as f:
 7    js = json.load(f)
 8 
 9    for k, v in js['frames'].items():
10 
11        k = int(k.replace('.jpg',''))
12        if len(v) == 0:
13            continue
14        line = {}
15        line['file'] = '{0:04d}'.format(k+1) + '.jpg'
16        line['image_size'] = [{
17            'width':int(v[0]['width']),
18            'height':int(v[0]['height']),
19            'depth':3
20        }]
21 
22        line['annotations'] = []
23 
24        for annotation in v:
25 
26            line['annotations'].append(
27                {
28                    'class_id':class_list[annotation['tags'][0]],
29                    'top':int(annotation['y1']),
30                    'left':int(annotation['x1']),
31                    'width':int(annotation['x2'])-int(annotation['x1']),
32                    'height':int(annotation['y2']-int(annotation['y1']))
33                }
34            )
35 
36        line['categories'] = []
37         
38        for name, class_id in class_list.items():
39 
40            line['categories'].append(
41                {
42                    'class_id':class_id,
43                    'name':name
44                }
45            )
46 
47        f = open('./json/'+'{0:04d}'.format(k+1) + '.json', 'w')
48        json.dump(line, f)

完了するとこのようなjsonファイルが出力されます。

これで学習用データの用意ができました!

おわりに

今回はAmazon SageMakerで使用する、学習用データの用意を行いました。
次回以降の記事で実際に学習をさせてみようと思います。
本当にタグ付けは途中で発狂しそうな作業だったので、誰かもっといい方法をご存知なら、私のTwitterで教えてください…

それでは今回はこの辺で。ここまで読んでいただき、ありがとうございました!

私のおすすめの書籍です。ぜひ読んでみてください!

Author

りんごく

2019年8月よりフリーランスとして活動するフロントエンドエンジニア。Reactが友達。