from tensorflow.keras.applications.resnet import ResNet50
from tensorflow.keras.layers import (
    AveragePooling2D,
    Conv2D,
    Dense,
    Flatten,
    GlobalAveragePooling2D,
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def load_train(path):
    """Return an augmented image iterator over the class sub-folders of ``path``.

    Pixel values are scaled by 1/255. Images are randomly flipped
    horizontally and vertically, resized to 150x150, and yielded in batches
    of 16 with integer ("sparse") class labels. The iterator is seeded with
    12345 for reproducible shuffling/augmentation.

    Args:
        path: Directory handed to ``flow_from_directory``.

    Returns:
        The iterator produced by ``ImageDataGenerator.flow_from_directory``.
    """
    augmenter = ImageDataGenerator(
        rescale=1.0 / 255,
        horizontal_flip=True,
        vertical_flip=True,
    )
    return augmenter.flow_from_directory(
        path,
        target_size=(150, 150),
        batch_size=16,
        class_mode='sparse',
        seed=12345,
    )
def create_model(input_shape):
    """Build and compile a ResNet50-based 12-class image classifier.

    The ResNet50 backbone is created without its classification head
    (``include_top=False``) and with ``weights=None``, i.e. randomly
    initialised rather than pretrained. Its output feature maps are reduced
    by global average pooling and fed to a 12-unit softmax layer.

    Args:
        input_shape: Shape of one input image, forwarded to ``ResNet50``.
            Presumably ``(150, 150, 3)`` to match the training generator's
            target size -- TODO confirm against the caller.

    Returns:
        A compiled ``Sequential`` model using sparse categorical
        cross-entropy (integer class labels), the Adam optimizer and an
        accuracy metric reported as ``acc``.
    """
    backbone = ResNet50(
        input_shape=input_shape,
        weights=None,
        include_top=False,
    )

    model = Sequential()
    model.add(backbone)
    # Collapse each feature map to a single value so the head receives a
    # flat vector regardless of the spatial size of the backbone output.
    model.add(GlobalAveragePooling2D())
    model.add(Dense(12, activation='softmax'))

    # ``lr`` is a deprecated alias that newer Keras versions reject;
    # ``learning_rate`` is the supported keyword.
    # NOTE(review): 0.01 is ten times Adam's default (0.001), and the
    # backbone starts from random weights -- consider a smaller rate if
    # training diverges.
    model.compile(
        loss='sparse_categorical_crossentropy',
        optimizer=Adam(learning_rate=0.01),
        metrics=['acc'],
    )
    return model
def train_model(model, train_data, test_data, batch_size=None, epochs=3,
                steps_per_epoch=None, validation_steps=None):
    """Fit ``model`` on ``train_data`` and hand the same model object back.

    ``test_data`` is supplied to ``fit`` as validation data. The optional
    arguments are forwarded to ``model.fit`` unchanged, and ``verbose=2`` is
    always passed.

    Args:
        model: A compiled model exposing a Keras-style ``fit`` method.
        train_data: Training data, passed positionally to ``fit``.
        test_data: Passed to ``fit`` as ``validation_data``.
        batch_size: Forwarded as-is. Keras advises leaving this ``None``
            when the data is a generator that already yields batches.
        epochs: Number of training epochs.
        steps_per_epoch: Forwarded as-is.
        validation_steps: Forwarded as-is.

    Returns:
        The same ``model`` object. Whatever ``fit`` returns is discarded.
    """
    fit_options = {
        'validation_data': test_data,
        'batch_size': batch_size,
        'epochs': epochs,
        'steps_per_epoch': steps_per_epoch,
        'validation_steps': validation_steps,
        'verbose': 2,
    }
    model.fit(train_data, **fit_options)
    return model
fruit
From Kate, 5 Months ago, written in Plain Text, viewed 109 times.
This paste will die in 1 Second.
URL http://codebin.org/view/22cad52c
Embed
Download Paste or View Raw
— Expand Paste to full width of browser