自定義Autograd函數
對于淺層的網絡,我們可以手動的書寫前向傳播和反向傳播過程。但是當網絡變得很大時,特別是在做深度學習時,網絡結構變得復雜。前向傳播和反向傳播也隨之變得復雜,手動書寫這兩個過程就會存在很大的困難。幸運地是在pytorch中存在了自動微分的包,可以用來解決該問題。在使用自動求導的時候,網絡的前向傳播會定義一個計算圖(computational graph),圖中的節點是張量(tensor),兩個節點之間的邊對應了兩個張量之間變換關系的函數。有了計算圖的存在,張量的梯度計算也變得容易了些。例如, x是一個張量,其屬性 x.requires_grad = True,那么 x.grad就是一個保存這個張量x的梯度的一些標量值。
最基礎的自動求導操作在底層就是作用在兩個張量上。前向傳播函數是從輸入張量到輸出張量的計算過程;反向傳播是輸入輸出張量的梯度(一些標量)并輸出輸入張量的梯度(一些標量)。在pytorch中我們可以很容易地定義自己的自動求導操作,通過繼承torch.autograd.Function并定義forward和backward函數。
forward(): 前向傳播操作。可以輸入任意多的參數,任意的python對象都可以。
backward():反向傳播(梯度公式)。輸出的梯度個數需要與所使用的張量個數保持一致,且返回的順序也要對應起來。
# Inherit from Function class LinearFunction(Function): # Note that both forward and backward are @staticmethods @staticmethod # bias is an optional argument def forward(ctx, input, weight, bias=None): # ctx在這里類似self,ctx的屬性可以在backward中調用 ctx.save_for_backward(input, weight, bias) output = input.mm(weight.t()) if bias is not None: output += bias.unsqueeze(0).expand_as(output) return output # This function has only a single output, so it gets only one gradient @staticmethod def backward(ctx, grad_output): # This is a pattern that is very convenient - at the top of backward # unpack saved_tensors and initialize all gradients w.r.t. inputs to # None. Thanks to the fact that additional trailing Nones are # ignored, the return statement is simple even when the function has # optional inputs. input, weight, bias = ctx.saved_tensors grad_input = grad_weight = grad_bias = None # These needs_input_grad checks are optional and there only to # improve efficiency. If you want to make your code simpler, you can # skip them. Returning gradients for inputs that don't require it is # not an error. if ctx.needs_input_grad[0]: grad_input = grad_output.mm(weight) if ctx.needs_input_grad[1]: grad_weight = grad_output.t().mm(input) if bias is not None and ctx.needs_input_grad[2]: grad_bias = grad_output.sum(0).squeeze(0) return grad_input, grad_weight, grad_bias #調用自定義的自動求導函數 linear = LinearFunction.apply(*args) #前向傳播 linear.backward()#反向傳播 linear.grad_fn.apply(*args)#反向傳播
另外有需要云服務器可以了解下創新互聯scvps.cn,海內外云服務器15元起步,三天無理由+7*72小時售后在線,公司持有idc許可證,提供“云服務器、裸金屬服務器、高防服務器、香港服務器、美國服務器、虛擬主機、免備案服務器”等云主機租用服務以及企業上云的綜合解決方案,具有“安全穩定、簡單易用、服務可用性高、性價比高”等特點與優勢,專為企業上云打造定制,能夠滿足用戶豐富、多元化的應用場景需求。
網頁題目:Pytorch:自定義網絡層實例-創新互聯
網頁URL:http://m.kartarina.com/article22/cdsocc.html
成都網站建設公司_創新互聯,為您提供微信公眾號、Google、微信小程序、服務器托管、軟件開發、外貿網站建設
聲明:本網站發布的內容(圖片、視頻和文字)以用戶投稿、用戶轉載內容為主,如果涉及侵權請盡快告知,我們將會在第一時間刪除。文章觀點不代表本網站立場,如需處理請聯系客服。電話:028-86922220;郵箱:631063699@qq.com。內容未經允許不得轉載,或轉載時需注明來源: 創新互聯