
functools ) partial() Creating a new function with its arguments already filled in

by 하이방가루 2022. 5. 12.

functools.partial(func, /, *args, **keywords)

  As the name suggests, it is used to create a new function whose arguments have already been filled in.
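
For example, here is a minimal sketch of the idea (power and square are made-up names purely for illustration, unrelated to the U-Net code below):

from functools import partial

def power(base, exponent):
    return base ** exponent

# square() is power() with exponent already filled in as 2
square = partial(power, exponent=2)

print(square(5))   # 25  -> same as power(5, exponent=2)
print(square(12))  # 144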

 

Building a U-Net model with partial()

from functools import partial
from tensorflow import keras

# Factories: partial() pre-fills kernel_size, strides and padding,
# so each call below only needs to pass the number of filters.
conv_filters = [16, 32, 64, 128, 256, 512]
cont_factory = partial(
    keras.layers.Conv2D, kernel_size=(3,3), strides=1, padding="same"
)
# A single activation layer instance is reused at every level,
# which is why it shows up once with a "multiple" output shape in the summary.
cont_activation = keras.layers.ELU()

expan_factory = partial(
    keras.layers.Conv2DTranspose,
    kernel_size=(3,3), strides=1, padding="same"
)
expan_activation = keras.layers.LeakyReLU(0.2)


## Input
inputs = keras.Input((None,None,1))
## Contracting path
## 1
conv1 = cont_factory(conv_filters[0])(inputs)
batch1 = keras.layers.BatchNormalization(axis=-1)(conv1)
rel1 = cont_activation(batch1)
## 2
conv2 = cont_factory(conv_filters[1])(rel1)
batch2 = keras.layers.BatchNormalization(axis=-1)(conv2)
rel2 = cont_activation(batch2)
## 3
conv3 = cont_factory(conv_filters[2])(rel2)
batch3 = keras.layers.BatchNormalization(axis=-1)(conv3)
rel3 = cont_activation(batch3)
## 4
conv4 = cont_factory(conv_filters[3])(rel3)
batch4 = keras.layers.BatchNormalization(axis=-1)(conv4)
rel4 = cont_activation(batch4)
## 5
conv5 = cont_factory(conv_filters[4])(rel4)
batch5 = keras.layers.BatchNormalization(axis=-1)(conv5)
rel5 = cont_activation(batch5)
## 6
conv6 = cont_factory(conv_filters[5])(rel5)

## Expansive path
## 6
up1 = expan_factory(conv_filters[4])(conv6)
up1 = expan_activation(up1)
up_batch1 = keras.layers.BatchNormalization(axis=-1)(up1)
drop1 = keras.layers.Dropout(0.5)(up_batch1)
merge1 = keras.layers.Concatenate(axis=-1)([conv5,drop1])
## 5
up2 = expan_factory(conv_filters[3])(merge1)
up2 = expan_activation(up2)
up_batch2 = keras.layers.BatchNormalization(axis=-1)(up2)
drop2 = keras.layers.Dropout(0.5)(up_batch2)
merge2 = keras.layers.Concatenate(axis=-1)([conv4,drop2])
## 4
up3 = expan_factory(conv_filters[2])(merge2)
up3 = expan_activation(up3)
up_batch3 = keras.layers.BatchNormalization(axis=-1)(up3)
drop3 = keras.layers.Dropout(0.5)(up_batch3)
merge3 = keras.layers.Concatenate(axis=-1)([conv3,drop3])
## 3
up4 = expan_factory(conv_filters[1])(merge3)
up4 = expan_activation(up4)
up_batch4 = keras.layers.BatchNormalization(axis=-1)(up4)
drop4 = keras.layers.Dropout(0.5)(up_batch4)
merge4 = keras.layers.Concatenate(axis=-1)([conv2,drop4])
## 2
up5 = expan_factory(conv_filters[0])(merge4)
up5 = expan_activation(up5)
up_batch5 = keras.layers.BatchNormalization(axis=-1)(up5)
drop5 = keras.layers.Dropout(0.5)(up_batch5)
merge5 = keras.layers.Concatenate(axis=-1)([conv1,drop5])
## 1
up6 = expan_factory(1)(merge5)
up6 = expan_activation(up6)
up_batch6 = keras.layers.BatchNormalization(axis=-1)(up6)
## output
outputs = keras.layers.Conv2D(1, (4,4), dilation_rate=(2,2), activation="sigmoid", padding="same")(up_batch6)

model = keras.Model(inputs, outputs)
model.summary()
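
Each cont_factory(...) call above is just shorthand: because partial() has already filled in kernel_size, strides and padding, only the number of filters has to be passed. For example, cont_factory(conv_filters[0]) builds the same layer as:

keras.layers.Conv2D(16, kernel_size=(3,3), strides=1, padding="same")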

Result

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, None, None,  0           []                               
                                 1)]                                                              
                                                                                                  
 conv2d (Conv2D)                (None, None, None,   160         ['input_1[0][0]']                
                                16)                                                               
                                                                                                  
 batch_normalization (BatchNorm  (None, None, None,   64         ['conv2d[0][0]']                 
 alization)                     16)                                                               
                                                                                                  
 elu (ELU)                      multiple             0           ['batch_normalization[0][0]',    
                                                                  'batch_normalization_1[0][0]',  
                                                                  'batch_normalization_2[0][0]',  
                                                                  'batch_normalization_3[0][0]',  
                                                                  'batch_normalization_4[0][0]']  
                                                                                                  
 conv2d_1 (Conv2D)              (None, None, None,   4640        ['elu[0][0]']                    
                                32)                                                               
                                                                                                  
 batch_normalization_1 (BatchNo  (None, None, None,   128        ['conv2d_1[0][0]']               
 rmalization)                   32)                                                               
                                                                                                  
...
Total params: 3,540,262
Trainable params: 3,538,276
Non-trainable params: 1,986