王荣胜

Cat and Dog Image Recognition with the Google Computer Vision API

Mounting Google Drive in Colab

from google.colab import drive
drive.flush_and_unmount()
drive.mount('/content/gdrive', force_remount=True)
root_dir = "/content/gdrive/My Drive/"
Mounted at /content/gdrive

Creating the folders

Create two folders, one named cat and one named dog.

from fastai.vision import *
path = Path(root_dir + 'DeepLearning/Datasets2/')
 
dest1 = path/'cat'
dest1.mkdir(parents=True, exist_ok=True)
dest2 = path/'dog'
dest2.mkdir(parents=True, exist_ok=True)

Building the dataset

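Search for cat on Google Images and scroll through the results for a while so that enough thumbnails load. Then press F12 and click Console, or press Ctrl+Shift+J to open the console directly, and paste the following five code snippets: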

function simulateRightClick( element ) {
    var event1 = new MouseEvent( 'mousedown', {
        bubbles: true,
        cancelable: false,
        view: window,
        button: 2,
        buttons: 2,
        clientX: element.getBoundingClientRect().x,
        clientY: element.getBoundingClientRect().y
    } );
    element.dispatchEvent( event1 );
    var event2 = new MouseEvent( 'mouseup', {
        bubbles: true,
        cancelable: false,
        view: window,
        button: 2,
        buttons: 0,
        clientX: element.getBoundingClientRect().x,
        clientY: element.getBoundingClientRect().y
    } );
    element.dispatchEvent( event2 );
    var event3 = new MouseEvent( 'contextmenu', {
        bubbles: true,
        cancelable: false,
        view: window,
        button: 2,
        buttons: 0,
        clientX: element.getBoundingClientRect().x,
        clientY: element.getBoundingClientRect().y
    } );
    element.dispatchEvent( event3 );
}
 
function getURLParam( queryString, key ) {
    var vars = queryString.replace( /^\?/, '' ).split( '&' );
    for ( let i = 0; i < vars.length; i++ ) {
        let pair = vars[ i ].split( '=' );
        if ( pair[0] == key ) {
            return pair[1];
        }
    }
    return false;
}
 
function createDownload( contents ) {
    var hiddenElement = document.createElement( 'a' );
    hiddenElement.href = 'data:attachment/text,' + encodeURI( contents );
    hiddenElement.target = '_blank';
    hiddenElement.download = 'urls.txt';
    hiddenElement.click();
}
 
function grabUrls() {
    var urls = [];
    return new Promise( function( resolve, reject ) {
        var count = document.querySelectorAll(
        	'.isv-r a:first-of-type' ).length,
            index = 0;
        Array.prototype.forEach.call( document.querySelectorAll(
        	'.isv-r a:first-of-type' ), function( element ) {
            // using the right click menu Google will generate the
            // full-size URL; won't work in Internet Explorer
            // (http://pyimg.co/byukr)
            simulateRightClick( element.querySelector( ':scope img' ) );
            // Wait for it to appear on the <a> element
            var interval = setInterval( function() {
                if ( element.href.trim() !== '' ) {
                    clearInterval( interval );
                    // extract the full-size version of the image
                    let googleUrl = element.href.replace( /.*(\?)/, '$1' ),
                        fullImageUrl = decodeURIComponent(
                        	getURLParam( googleUrl, 'imgurl' ) );
                    if ( fullImageUrl !== 'false' ) {
                        urls.push( fullImageUrl );
                    }
                    // sometimes the URL returns a "false" string and
                    // we still want to count those so our Promise
                    // resolves
                    index++;
                    if ( index == ( count - 1 ) ) {
                        resolve( urls );
                    }
                }
            }, 10 );
        } );
    } );
}
 
grabUrls().then( function( urls ) {
    urls = urls.join( '\n' );
    createDownload( urls );
} );
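
The key step in the snippets above is pulling the `imgurl` query parameter out of the link attached to each thumbnail and URL-decoding it. As a rough illustration of that one step, here is the same logic in Python using only the standard library. The sample link and the helper name `extract_imgurl` are made up for this sketch and are not part of the original code.

```python
from urllib.parse import urlsplit, parse_qs

def extract_imgurl(google_href):
    """Return the decoded `imgurl` parameter of a link, or None if absent."""
    query = urlsplit(google_href).query
    # parse_qs splits on '&' and '=' and percent-decodes the values, which
    # covers getURLParam plus decodeURIComponent in the JavaScript version.
    values = parse_qs(query).get("imgurl")
    return values[0] if values else None

# Hypothetical link in the shape the JavaScript expects.
sample = (
    "https://www.google.com/imgres"
    "?imgurl=https%3A%2F%2Fexample.com%2Fcats%2F1.jpg"
    "&imgrefurl=https%3A%2F%2Fexample.com"
)
print(extract_imgurl(sample))
print(extract_imgurl("https://www.google.com/imgres?q=cat"))
```

Returning None for a missing parameter avoids the JavaScript version's need to compare against the string 'false'.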

Repeat the same steps with a search for dog.

Rename the two downloaded txt files to cat.txt and dog.txt so they are ready for the next step.

Downloading the dataset

Upload cat.txt and dog.txt to the path directory, then run the following commands:

download_images(path/'cat.txt', dest1, max_pics=120)
download_images(path/'dog.txt', dest2, max_pics=120)

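This uses download_images from fastai.vision.data. It takes the location of the txt file, the destination folder for the images, the maximum number of images (max_pics) and the maximum number of parallel workers (max_workers, left at its default here).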
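
To make the roles of max_pics and max_workers concrete, here is a minimal standard-library sketch of the same idea: cap the URL list, fetch in parallel, and skip anything that fails. It is not fastai's implementation. The function name `download_from_list`, the injected `fetch` callable and the fixed `.jpg` file names are all assumptions made for illustration.

```python
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import tempfile

def download_from_list(urls, dest, fetch, max_pics=120, max_workers=8):
    """Save at most `max_pics` URLs into `dest`, fetching in parallel.

    `fetch` is any callable that maps a URL to bytes or raises on failure.
    """
    dest = Path(dest)
    dest.mkdir(parents=True, exist_ok=True)
    # Drop blank lines, then keep only the first `max_pics` entries.
    selected = [u.strip() for u in urls if u.strip()][:max_pics]

    def save(item):
        index, url = item
        try:
            data = fetch(url)
        except Exception:
            return None  # skip URLs that cannot be fetched
        # The extension is fixed to .jpg purely to keep the sketch short.
        target = dest / f"{index:08d}.jpg"
        target.write_bytes(data)
        return target

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        results = list(pool.map(save, enumerate(selected)))
    return [p for p in results if p is not None]

def fake_fetch(url):
    """Stand-in for a real HTTP request so the example runs offline."""
    if "broken" in url:
        raise IOError("unreachable")
    return url.encode()

urls = ["http://a/1.jpg", "http://broken/2.jpg", "http://a/3.jpg", "http://a/4.jpg"]
with tempfile.TemporaryDirectory() as tmp:
    saved = download_from_list(urls, tmp, fake_fetch, max_pics=3)
    saved_names = sorted(p.name for p in saved)
print(saved_names)
```

With max_pics=3 only the first three URLs are attempted, and the unreachable one is skipped, so two files are written. For real downloads, `fetch` could wrap `urllib.request.urlopen(url, timeout=4).read()`.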

Cleaning the data

The downloaded images vary in format (jpeg, png, gif and so on) and in size, so they need cleaning. fastai provides a verify_images tool that performs this basic cleanup:

classes = ['cat','dog']
for c in classes:
    print(c)
    verify_images(path/c, delete=True, max_size=500)
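
In fastai v1, verify_images tries to open each file as an image, deletes the ones that cannot be opened when delete=True, and shrinks anything larger than max_size. The sketch below is a much cruder stand-in that only inspects the first few bytes of each file, which is enough to catch HTML error pages saved under an image name. The helper names `sniff_format` and `remove_non_images` are invented for this example.

```python
from pathlib import Path
import tempfile

# Leading "magic" bytes of the three formats named above.
SIGNATURES = {
    "jpeg": (b"\xff\xd8\xff",),
    "png": (b"\x89PNG\r\n\x1a\n",),
    "gif": (b"GIF87a", b"GIF89a"),
}

def sniff_format(path):
    """Return 'jpeg', 'png' or 'gif' based on the file header, else None."""
    with open(path, "rb") as fh:
        header = fh.read(8)
    for name, prefixes in SIGNATURES.items():
        if header.startswith(prefixes):
            return name
    return None

def remove_non_images(folder):
    """Delete files whose header matches none of the known signatures."""
    removed = []
    for file in sorted(Path(folder).iterdir()):
        if file.is_file() and sniff_format(file) is None:
            file.unlink()
            removed.append(file.name)
    return removed

with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "ok.png").write_bytes(b"\x89PNG\r\n\x1a\n" + b"rest")
    (Path(tmp) / "page.jpg").write_bytes(b"<html>not an image</html>")
    removed = remove_non_images(tmp)
    remaining = sorted(p.name for p in Path(tmp).iterdir())
print(removed, remaining)
```

A header check cannot detect truncated or otherwise corrupt image data, so it complements verify_images rather than replacing it.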

Training the model

Use the pretrained resnet34 model for transfer learning:

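# `data` is not defined anywhere in the post; this line assumes it was built from the cat/dog folders with a 20% validation split.
data = ImageDataBunch.from_folder(path, train='.', valid_pct=0.2, ds_tfms=get_transforms(), size=224).normalize(imagenet_stats)
learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)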
epoch train_loss valid_loss error_rate time
0 1.172639 0.687638 0.276596 00:46
1 0.715001 0.246741 0.106383 00:46
2 0.524886 0.033620 0.000000 00:46
3 0.408488 0.008482 0.000000 00:46
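
The error_rate column is the fraction of validation images that were classified incorrectly, in other words 1 minus accuracy. The values in the table all look like multiples of 1/47 (for example 0.276596 is 13/47 rounded to six places), which would be consistent with a validation set of 47 images, although the post does not state its size. A small sketch of the calculation with made-up labels:

```python
def error_rate(predicted, actual):
    """Fraction of samples whose predicted label differs from the true one."""
    if len(predicted) != len(actual):
        raise ValueError("predicted and actual must have the same length")
    wrong = sum(p != a for p, a in zip(predicted, actual))
    return wrong / len(actual)

# Toy labels: 47 validation images, 13 of them misclassified.
actual = ["cat"] * 24 + ["dog"] * 23
predicted = ["dog"] * 13 + ["cat"] * 11 + ["dog"] * 23
rate = error_rate(predicted, actual)
print(round(rate, 6))  # 0.276596
```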

Unfreeze the model, adjust the learning rate and train for a few more epochs:

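learn = cnn_learner(data, models.resnet34, metrics=error_rate)
learn.fit_one_cycle(4)
# Not shown in the original listing: unfreeze so the following epochs update every layer group, as the text describes.
learn.unfreeze()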
epoch train_loss valid_loss error_rate time
0 1.137841 0.362603 0.148936 00:48
1 0.717203 0.497480 0.127660 00:46
2 0.532314 0.338967 0.085106 00:46
3 0.407410 0.189146 0.063830 00:46

Adjust the learning rate and continue training

learn.fit_one_cycle(10,max_lr=slice(1e-4,3e-4))
learn.save('stage-2')
epoch train_loss valid_loss error_rate time
0 0.115359 0.114586 0.042553 00:45
1 0.062113 0.068654 0.042553 00:46
2 0.061450 0.041762 0.021277 00:46
3 0.051107 0.029947 0.000000 00:46
4 0.046757 0.019579 0.000000 00:46
5 0.066720 0.015556 0.000000 00:46
6 0.063555 0.011138 0.000000 00:46
7 0.056562 0.009309 0.000000 00:47
8 0.053216 0.008640 0.000000 00:46
9 0.053072 0.009092 0.000000 00:47
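
Passing max_lr=slice(1e-4,3e-4) gives the layer groups different learning rates. As far as fastai v1 behaves, the earliest group trains at the lower bound, the last group at the upper bound, and any groups in between get values spaced multiplicatively. The sketch below reproduces that spacing in plain Python. The function name `spread_learning_rates` is invented here, and the count of 3 layer groups is an assumption about how cnn_learner splits resnet34.

```python
def spread_learning_rates(start, stop, n_groups):
    """Geometrically spaced learning rates from `start` to `stop`."""
    if n_groups == 1:
        return [stop]
    # Constant ratio between consecutive groups.
    step = (stop / start) ** (1 / (n_groups - 1))
    return [start * step ** i for i in range(n_groups)]

rates = spread_learning_rates(1e-4, 3e-4, 3)
print([f"{r:.2e}" for r in rates])  # ['1.00e-04', '1.73e-04', '3.00e-04']
```

Lower rates for the early layers keep the general-purpose pretrained features mostly intact while the later, more task-specific layers adapt faster.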

fastai provides ClassificationInterpretation for analysing the results efficiently:

interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)
[Figure: confusion matrix (output_24_2.png)]
interp.plot_top_losses(9, figsize=(15,11))
[Figure: the 9 images with the highest losses (output_25_0.png)]
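
For readers who want to see what a confusion matrix counts, here is a minimal sketch in plain Python. It is not how ClassificationInterpretation is implemented, and the labels are toy data invented for the example.

```python
from collections import Counter

def confusion_matrix(actual, predicted, classes):
    """Rows are true classes, columns are predicted classes."""
    counts = Counter(zip(actual, predicted))
    return [[counts[(t, p)] for p in classes] for t in classes]

classes = ["cat", "dog"]
actual = ["cat", "cat", "cat", "dog", "dog", "dog"]
predicted = ["cat", "cat", "dog", "dog", "dog", "cat"]
matrix = confusion_matrix(actual, predicted, classes)
for name, row in zip(classes, matrix):
    print(name, row)
```

The diagonal holds the correct predictions. Here one cat was labelled dog and one dog was labelled cat, so each off-diagonal cell is 1.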

Other places to read this tutorial

  1. GitHub
  2. View the Jupyter notebook
Author: 王荣胜
Source: https://sqdxwz.top/
License: this article is licensed under the Creative Commons Attribution-ShareAlike 4.0 International License.