风格迁移应用_PyTorch实战图形风格迁移
前言
什么是圖像風格的遷移?其實現在很多的APP應用中已經普遍存在了,比如讓我們選擇一張自己的大頭照,然后選擇一種風格的圖片,確認后我們的大頭照變成了所選圖片類似的風格。
圖像風格遷移重點就是找出一張圖片的特征,然后將其融合到需要改變的圖片中去,如下圖所展示的就是一種典型的風格遷移。
所以圖像風格遷移實現的難點就在于如何提取一張圖片的特征,這里說的特征也就是圖像的風格。論文《A Neural Algorithm of Artistic Style》使用了CNN(卷積神經網絡)來對圖像的風格進行提取。因為我們都知道CNN本來就可以對特征圖像進行提取,然后通過特征來實現圖像的分類。當我們有了圖像風格的提取方法后,只需要將新提取到的風格融入到新的圖片中去,就實現了圖像風格的遷移。
1、PyTorch核心代碼實現
其實代碼的核心思想并不復雜,就是利用CNN提取內容圖片的內容和風格圖片的風格,然后輸入一張新的圖像。對輸入的圖像提取出內容和風格與CNN提取的內容和風格進行Loss計算,Loss的度量可以使用MSE,然后逐步對Loss進行優化,使Loss值達到最理想,將被優化的參數進行輸出,這樣輸出的圖片就達到了風格遷移的目的。
(一)、計算內容損失
為什么使用卷積提取內容?
下圖是我通過一個卷積提取到的其中一個特征映射,說明使用卷積作為內容提取的方法是完全可行的。
計算內容損失的代碼如下:
class Content_loss(torch.nn.Module):def __init__(self, weight, target):super(Content_loss, self).__init__()self.weight = weightself.target = target.detach()*weightself.loss_fn = torch.nn.MSELoss()def forward(self, input):self.loss = self.loss_fn(input*self.weight, self.target)self.output = inputreturn self.outputdef backward(self):self.loss.backward(retain_graph = True)return self.loss這里的target就是CNN對內容圖像提取得到的內容,weight是用來控制內容和風格對input圖像的影響程度,這里的input就是我們輸入圖像,還有定義的backward主要目的其實是為了調用方向傳播方法和返回我們計算得到的Loss。Loss計算使用的是MSE來度量。
(二)、計算風格損失
計算風格損失的代碼如下:
class Style_loss(torch.nn.Module):def __init__(self, weight, target):super(Style_loss, self).__init__()self.weight = weightself.target = target.detach()*weightself.loss_fn = torch.nn.MSELoss()self.gram = gram_matrix()def forward(self, input):self.output = input.clone()self.G = self.gram(input)self.G.mul_(self.weight)self.loss = self.loss_fn(self.G, self.target)return self.outputdef backward(self):self.loss.backward(retain_graph = True)return self.loss這里的target、weight、input、backward、Loss使用的意義和之前的內容計算類似,唯一不同的地方是引入了Gram矩陣,通過對CNN提取后的內容進行Gram矩陣運算來定義圖像的風格。
為什么Gram矩陣能夠定義圖像的風格了?
因為CNN卷積過后提取了圖像的特征圖,每個數字就是原圖像的特性大小,而Gram矩陣是矩陣的內積運算,運算過后特征圖中越大的數字會變得更大,這就相當于對圖像的特性進行了縮放,使得特征突出了,也就相當于提取到了圖片的風格。
Gram矩陣的代碼如下:
class gram_matrix(torch.nn.Module):def forward(self, input):a,b,c,d = input.size()feature = input.view(a*b, c*d)gram = torch.mm(feature, feature.t())return gram.div(a*b*c*d)(三)、構建訓練CNN
構建新的訓練模型代碼:
content_layer = ["Conv_5","Conv_6"]style_layer = ["Conv_1", "Conv_2", "Conv_3", "Conv_4", "Conv_5"]content_losses = [] style_losses = []conten_weight = 1 style_weight = 1000new_model = torch.nn.Sequential()model = copy.deepcopy(cnn)gram = gram_matrix()if use_gpu:new_model = new_model.cuda()gram = gram.cuda()index = 1 for layer in list(model):if isinstance(layer, torch.nn.Conv2d):name = "Conv_"+str(index)new_model.add_module(name, layer)if name in content_layer:target = new_model(content_img).clone()content_loss = Content_loss(conten_weight, target)new_model.add_module("content_loss_"+str(index), content_loss)content_losses.append(content_loss)if name in style_layer:target = new_model(style_img).clone()target = gram(target)style_loss = Style_loss(style_weight, target)new_model.add_module("style_loss_"+str(index), style_loss)style_losses.append(style_loss)if isinstance(layer, torch.nn.ReLU):name = "Relu_"+str(index)new_model.add_module(name, layer)index = index+1if isinstance(layer, torch.nn.MaxPool2d):name = "MaxPool_"+str(index)new_model.add_module(name, layer)要完成風格遷移,我們還需要構建自己的CNN網絡。首先遷移了vgg16的模型,剔除了全連接部分,之后就是根據vgg16模型架構重構訓練模型,加入了內容和風格Loss的計算部分。這里內容的提取只是選擇了5、6層卷積,風格的提取只選擇了1、2、3、4、5層卷積。
(四)、定義優化
優化定義代碼:
input_img = content_img.clone()parameter = torch.nn.Parameter(input_img.data) optimizer = torch.optim.LBFGS([parameter])這里為什么使用LBFGS來進行優化?
原因是我們要優化的Loss其實是多個,而不是像處理分類問題中只是需要優化一個Loss值,LBFGS能夠獲得更好的效果。
(五)、訓練新定義的CNN
訓練代碼如下:
n_epoch = 1000run = [0] while run[0] <= n_epoch:def closure():optimizer.zero_grad()style_score = 0content_score = 0parameter.data.clamp_(0,1)new_model(parameter)for sl in style_losses:style_score += sl.backward()for cl in content_losses:content_score += cl.backward()run[0] += 1 if run[0] % 50 == 0:print('{} Style Loss : {:4f} Content Loss: {:4f}'.format(run[0],style_score.data[0], content_score.data[0])) return style_score+content_scoreoptimizer.step(closure)n_epoch定義了訓練次數為1000次,使用sl.backward()和cl.backward()實現了反向傳播,對參數進行優化。
2、改進
本文的圖像風格遷移的方法沒次實現都要進行一輪訓練,而且風格調節的方式需要通過weight權重來控制,在實際應用中并不理想,現實中我們需要更加高效智能的實現方式。改進方法已經出現,先放出兩篇論文
Fast Patch-based Style Transfer of Arbitrary Style
Visual Attribute Transfer through Deep Image Analogy
代碼還在實現中......
參考資料:1、Welcome to PyTorch Tutorials
2、圖像風格遷移(Neural Style)簡史
完整代碼:JaimeTang/PyTorch-and-Neural-style-transfer
如果覺得還行,請點個贊哦......
微信公眾號:PyMachine
總結
以上是生活随笔為你收集整理的风格迁移应用_PyTorch实战图形风格迁移的全部內容,希望文章能夠幫你解決所遇到的問題。
- 上一篇: 电工必备实用口诀④
- 下一篇: ssl 2520 小球