1. Introduction
Trong những năm gần đây, Artificial Intelligence (AI) hay trí tuệ nhân tạo đang có những sự phát triển mạnh mẽ trên nhiều lĩnh vực. Một trong những lĩnh vực có được sự phát triển thần sầu nhất phải kể đến là Computer Vision. Sự ra đời của Convolution Neural Network (mạng neuron tich chập) cùng với phát triển mạnh mẽ của Deep Learning đã giúp cho Computer Vision đạt được nhiều bước đột phá đáng kế trong các bài toán: image classification, object detection, video tracking, image restoration, etc. Trong bài viết này mình sẽ giới thiệu với các bạn các CNN pre-trained model nổi tiếng. Bài viết này hướng tới những bạn đã có lý thuyết tốt về Deep Learning nói chung và Convolution Neural Network nói riêng. Trong phần demo mình sẽ sử dụng Python và thư viện Keras.
2. Pre-trained CNN model in Keras VGG-16
Đầu tiên phải kể tới mạng VGG. VGG ra đời năm 2015 và được giới thiệu tại hội thảo ICLR 2015. Kiến trúc của mô hình này có nhiều biến thể khác nhau: 11 layers, 13 layers, 16 layers, và 19 layers, các bạn có thể xem chi tiết trong hình sau: Trong bài viết này mình sẽ đề cập tới VGG-16 – kiến trúc mạng có 16 layers. Nguyên tắc thiết kế của các mạng VGG nói chung rất đơn giản: 2 hoặc 3 layers Convolution (Conv) và tiếp nối sau đó là 1 layer Max Pooling 2D. Ngay sau Conv cuối cùng là 1 Flatten layer để chuyển ma trận 4 chiều của Conv layer về ma trận 2 chiều. Tiếp nối sau đó là các Fully-connected layers và 1 Softmax layer. Do VGG được training trên tập dữ liệu của ImageNet có 1000 class nên ở Fully-connected layer cuối cùng sẽ có 1000 units.
Trong Keras hiện tại có hỗ trợ 2 pre-trained model của VGG: VGG-16 và VGG 19. Có 2 params chính các bạn cần lưu ý là include_top (True / False): có sử dụng các Fully-connected layer hay không và weights (‘imagenet’ / None): có sử dụng pre-trained weights của ImageNet hay không.
1 2 3 4 5 6 7 8 9 10 11 12 | <span class="token comment"># VGG 16 <span class="token keyword">from</span> keras<span class="token punctuation">.</span>applications<span class="token punctuation">.</span>vgg16 <span class="token keyword">import</span> VGG16 <span class="token comment"># Sử dụng pre-trained weight từ ImageNet và không sử dụng các Fully-connected layer ở cuối pretrained_model <span class="token operator">=</span> VGG16<span class="token punctuation">(</span>include_top<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span> weights<span class="token operator">=</span><span class="token string">'imagenet'</span><span class="token punctuation">)</span> <span class="token comment"># VGG 19 <span class="token keyword">from</span> keras<span class="token punctuation">.</span>applications<span class="token punctuation">.</span>vgg19 <span class="token keyword">import</span> VGG19 <span class="token comment"># Không sử dụng pre-trained weight từ ImageNet và không sử dụng các Fully-connected layer ở cuối pretrained_model <span class="token operator">=</span> VGG19<span class="token punctuation">(</span>include_top<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span> weights<span class="token operator">=</span><span class="token boolean">None</span><span class="token punctuation">)</span> |
InceptionNet
Với các mạng CNN thông thường, khi thiết kế ta bắt buộc phải xác định trước các tham số của 1 Conv layer như: kernel_size, padding, strides, etc. Và thường thì rất khó để ta có thể nói trước được tham số nào là phù hợp, kernel_size = (1, 1) hay (3, 3) hay (5, 5) sẽ tốt hơn. Mạng Inception ra đời để giải quyết vấn đề đó, yếu tố quan trọng nhất trong mạng Inception là Inception module, một mạng Inception hoàn chỉnh bao gồm nhiều module Inception nhỏ ghép lại với nhau. Các bạn có thể xem hình minh họa dưới đây để hiểu rõ hơn. Ý tưởng của Inception module rất đơn giản, thay vì sử dụng 1 Conv layer với tham số kernel_size cố định, ta hoàn toàn có thể sử dụng cùng lúc nhiều Conv layer với các tham số kernel_size khác nhau (1, 3, 5, 7, etc) và sau đó concatenate các output lại với nhau. Để không bị lỗi về chiều của ma trận khi concatenate, tất cả các Conv layer đều có chung strides=(1, 1) và padding=’same’. Ở thời điểm hiện tại, có 3 phiên bản của mạng Inception, các version sau thường có 1 vài điểm cải tiến so với phiên bản trước để cải thiện độ chính xác. Keras hiện tại hỗ trợ pre-trained model Inception version 3.
1 2 3 4 5 6 7 | <span class="token keyword">from</span> keras<span class="token punctuation">.</span>applications<span class="token punctuation">.</span>inception_v3 <span class="token keyword">import</span> InceptionV3 <span class="token comment"># các tham số include_top và weights các bạn có thể tùy chỉnh <span class="token comment"># theo ý muốn để phù hợp với bài toán của riêng mình. pretrained_model <span class="token operator">=</span> InceptionV3<span class="token punctuation">(</span>include_top<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span> weights<span class="token operator">=</span><span class="token string">'imagenet'</span><span class="token punctuation">)</span> |
ResNet
Mô hình CNN tiếp theo mình muốn giới thiệu với các bạn là ResNet hay Residual Network. Khi train các mô hình Deep CNN (số lượng layers lớn, số lượng param lớn, etc) ta thường gặp phải vấn đề về vanishing gradient hoặc exploding gradient. Thực tế cho thấy khi số lượng layer trong CNN model tăng, độ chính xác của mô hình cũng tăng theo, tuy nhiên khi tăng số layers quá lớn (>50 layers) thì độ chính xác lại bị giảm đi.
Residual block ra đời nhằm giải quyết vấn đề trên, với Residual block, ta hoàn toàn có thể train các mô hình CNN có kích thước và độ phức tạp “khủng” hơn mà không lo bị exploding/vanishing gradient. Mấu chốt của Residual block là cứ sau 2 layer, ta cộng input với output: F(x) + x. Resnet là một mạng CNN bao gồm nhiều Residual block nhỏ tạo thành. Hiện tại trong Keras có pre-trained model của ResNet50 với weight được train trên tập ImageNet với 1000 clas.
1 2 3 4 5 | <span class="token keyword">from</span> keras<span class="token punctuation">.</span>applications<span class="token punctuation">.</span>resnet50 <span class="token keyword">import</span> ResNet50 pretrained_model <span class="token operator">=</span> ResNet50<span class="token punctuation">(</span>include_top<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span> weights<span class="token operator">=</span><span class="token string">'imagenet'</span><span class="token punctuation">)</span> |
InceptionResNet
Nghe tên mô hình chắc các bạn cũng đoán ra được cấu hình của mạng rồi. InceptionResNet là mô hình được xây dựng dựa trên nhưng ưu điểm của Inception và Residual block. Với sự kết hợp này InceptionResNet đạt được đô chính xác rất đáng kinh ngạc. Trên tập dữ liệu ImageNet, InceptionResNet đạt 80.3% top 1 accuray trong khi con số này của Inception V3 và ResNet50 lần lượt là: 77.9% và 74.9%. Để sử dụng pre-trained InceptionResNet trong Keras chúng tac cũng làm tương tự như các pre-trained model khác.
1 2 3 4 5 | <span class="token keyword">from</span> keras<span class="token punctuation">.</span>applications<span class="token punctuation">.</span>inception_resnet_v2 <span class="token keyword">import</span> InceptionResNetV2 pretrained_model <span class="token operator">=</span> InceptionResNetV2<span class="token punctuation">(</span>include_top<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> weights<span class="token operator">=</span><span class="token string">'imgaenet'</span><span class="token punctuation">)</span> |
MobileNet
Các mô hình CNN vừa được giới thiệu, tuy có độ chính xác cao, nhưng chúng đều có một điểm hạn chế chung đó là không phù hợp với các ứng dụng trên mobile hay các hệ thống nhúng có khả năng tính toán thấp. Nếu muốn deploy các mô hình trên cho các ứng dụng real time, ta cần phải có cấu hình cực kì mạnh mẽ (GPU / TPU) còn đối với các hệ thống nhúng (Raspberry Pi, Nano pc, etc) hay các ứng dụng chạy trên smart phone, ta cần có một mô hình “nhẹ” hơn. Dưới đây benchmark các mô hình trên cùng tập dữ liệu ImageNet, ta có thể thấy MobileNetV2 có độ chính xác không hề thua kém các mô hình khác như VGG16, VGG19 trong khi lượng parameters chỉ vỏn vẹn 3.5M (khoảng 1/40 số tham số của VGG16).
Yếu tố chính giúp MobileNet có được độ chính xác cao trong khi thời gian tính toán thấp nằm ở sự cải tiến Conv layer bình thường. Trong MobileNet có 2 Covn layer được sử dụng là: SeparableConv và DepthwiseConv. Thay vì thực hiện phép tích chập như thông thường, SeparableConv sẽ tiến hành phép tích chập depthwise spatial (mình cũng không biết dịch ntn luôn ????????????) sau đó là phép tích chập pointwise (cũng không biết dịch nốt ????). Còn DepthwiseConv sẽ chỉ thực hiện phép tích chập depthwise spatial (không tính pointwise convolution). Việc chia phép tích chập ra như vậy giúp giảm đáng kể khối lượng tinh toán và số lượng tham số của mạng. Với sự thay đổi này, MobileNet có thể hoạt động một cách mượt mà ngay cả trên phần cứng cấu hình thấp. Và vẫn như ccas pre-trained model trước, Keras cũng có hộ trợ tận răng cho các bạn luôn:
1 2 3 4 5 | <span class="token keyword">from</span> keras<span class="token punctuation">.</span>applications<span class="token punctuation">.</span>mobilenet_v2 <span class="token keyword">import</span> MobileNetV2 pretrained_model <span class="token operator">=</span> MobileNetV2<span class="token punctuation">(</span>include_top<span class="token operator">=</span><span class="token boolean">True</span><span class="token punctuation">,</span> weights<span class="token operator">=</span><span class="token string">'imagenet'</span><span class="token punctuation">)</span> |
3. Build your own CNN model with a pre-trained model
Ở phần trước mình đã giới thiệu với các bạn pre-trained model nổi tiếng, tuy nhiên có thể các bạn vẫn còn khá mông lung về cách sử dụng các pre-trained model này vào trong bài toán thực tế của riêng mình. Vì vậy, ở phần này mình sẽ có 1 demo nho nhỏ để mọi người có thể có cái nhìn cụ thể hơn về transfer learning.
Giả sử bài toán đặt ra là image classificatoin, nhận biết các kí tự chữ từ a-z, A-Z và các số từ 0-9 (tổng cộng 62 labels). Ảnh đầu vào có kích thước 96x96x3 (rgb image). Pre-trained mình sử dụng là MobileNetV2.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 | <span class="token comment"># import các module cần thiết <span class="token keyword">from</span> keras<span class="token punctuation">.</span>applications<span class="token punctuation">.</span>mobilenet_v2 <span class="token keyword">import</span> MobileNetV2 <span class="token keyword">from</span> keras<span class="token punctuation">.</span>layers <span class="token keyword">import</span> Input<span class="token punctuation">,</span> BatchNormalization <span class="token keyword">from</span> keras<span class="token punctuation">.</span>layers <span class="token keyword">import</span> GlobalAveragePooling2D <span class="token keyword">from</span> keras<span class="token punctuation">.</span>layers <span class="token keyword">import</span> Activation <span class="token keyword">from</span> keras<span class="token punctuation">.</span>models <span class="token keyword">import</span> Model <span class="token comment"># 62 class: a-z, A-Z, 0-9 NUMBER_CLASSES <span class="token operator">=</span> <span class="token number">62</span> <span class="token comment"># build model <span class="token keyword">def</span> <span class="token function">create_model</span><span class="token punctuation">(</span>input_shape<span class="token operator">=</span><span class="token punctuation">(</span><span class="token number">96</span><span class="token punctuation">,</span> <span class="token number">96</span><span class="token punctuation">,</span> <span class="token number">3</span><span class="token punctuation">)</span><span class="token punctuation">)</span><span class="token punctuation">:</span> <span class="token comment"># khai bao input layer input_layer <span class="token operator">=</span> Input<span class="token punctuation">(</span>shape<span class="token operator">=</span>input_shape<span class="token punctuation">,</span> name<span class="token operator">=</span><span class="token string">'input'</span><span class="token punctuation">)</span> <span class="token comment"># su dung pre-trained model pretrained_model <span class="token operator">=</span> MobileNetV2<span class="token punctuation">(</span>include_top<span class="token operator">=</span><span class="token boolean">False</span><span class="token punctuation">,</span> weights<span class="token operator">=</span><span class="token string">'imagenet'</span><span class="token punctuation">)</span> pretrained_model_output <span class="token operator">=</span> pretrained_model<span class="token punctuation">(</span>input_layer<span class="token punctuation">)</span> global_avg <span class="token operator">=</span> GlobalAveragePooling2D<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">(</span>pretrained_model_output<span class="token punctuation">)</span> <span class="token comment"># fully-connected layer 1 dense <span class="token operator">=</span> Dense<span class="token punctuation">(</span>units<span class="token operator">=</span><span class="token number">512</span><span class="token punctuation">)</span><span class="token punctuation">(</span>global_avg<span class="token punctuation">)</span> dense <span class="token operator">=</span> BatchNormalization<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">(</span>dense<span class="token punctuation">)</span> dense <span class="token operator">=</span> ReLU<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">(</span>dense<span class="token punctuation">)</span> output <span class="token operator">=</span> Dense<span class="token punctuation">(</span>units<span class="token operator">=</span>NUMBER_CLASSES<span class="token punctuation">)</span><span class="token punctuation">(</span>dense<span class="token punctuation">)</span> output <span class="token operator">=</span> Activation<span class="token punctuation">(</span><span class="token string">'softmax'</span><span class="token punctuation">)</span><span class="token punctuation">(</span>output<span class="token punctuation">)</span> model <span class="token operator">=</span> Model<span class="token punctuation">(</span>input_layer<span class="token punctuation">,</span> output<span class="token punctuation">)</span> <span class="token keyword">print</span> <span class="token punctuation">(</span>model<span class="token punctuation">.</span>summary<span class="token punctuation">(</span><span class="token punctuation">)</span><span class="token punctuation">)</span> <span class="token keyword">return</span> model model <span class="token operator">=</span> create_model<span class="token punctuation">(</span><span class="token punctuation">)</span> |
Mô hình chúng ta vừa định nghĩa có kiến trúc như sau:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 | Layer (type) Output Shape Param # ================================================================= input (InputLayer) (None, 96, 96, 3) 0 _________________________________________________________________ mobilenetv2_1.00_224 (Model) multiple 2257984 _________________________________________________________________ global_average_pooling2d_1 ( (None, 1280) 0 _________________________________________________________________ dense_1 (Dense) (None, 512) 655872 _________________________________________________________________ batch_normalization_1 (Batch (None, 512) 2048 _________________________________________________________________ re_lu_1 (ReLU) (None, 512) 0 _________________________________________________________________ dense_2 (Dense) (None, 62) 31806 _________________________________________________________________ activation_1 (Activation) (None, 62) 0 ================================================================= Total params: 2,947,710 Trainable params: 2,912,574 Non-trainable params: 35,136 |
Như vậy là anh em đã có 1 model CNN với pre-trained là MobileNetV2. Chỉ cần compile và fit dữ liệu vào để training là xong, so easy.
1 2 3 4 5 | model.compile(...) modle.fit(...) |
4. Conclusions
Như vậy lfa trogn bài viết này mình đã giới thiệu với các bạn 1 số pre-trained model nổi tiếng trong lĩnh vực Computer Vision. Khi xây dựng một mô hình CNN cho một bài toán nào đó các bạn hoàn toàn có thể sử dụng transfer learning thay vì xây dựng mô hình từ đầu. Để có thể hiểu rõ hơn về ý tưởng cũng như nguyên tắc thiết kế của các mô hình trên các bạn hoàn toàn có thể vào google scholar tìm kiếm và đọc thêm để hiểu sâu hơn. Rất mong bài viết này có thể giúp các bạn xây dựng mô hình CNN nhanh hơn, tiết kiệm thời gian training hơn so với cách xây dựng mô hình từ đầu. Hẹn gặp lại các bạn trong những bài viết sau.
TechTalk via Viblo