import os

import pandas as pd
from tensorflow.keras.applications.resnet import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def load_train(path, image_dir=None, validation_split=0.25,
               x_col='filename', y_col='class',
               target_size=(150, 150), batch_size=16, seed=12345):
    """Build the augmented training iterator over the images listed in labels.csv.

    The rows of ``labels.csv`` are divided into a training and a validation
    subset by ``validation_split``; this function returns the training subset.
    Call ``load_test`` with the same ``validation_split`` so that the two
    subsets are complementary.

    Args:
        path: Directory containing ``labels.csv``.
        image_dir: Directory holding the image files named in ``x_col``.
            Defaults to ``path``. NOTE(review): if the images live in a
            sub-folder of ``path``, pass it here — confirm the dataset layout.
        validation_split: Fraction of rows reserved for the validation subset.
        x_col: Column of ``labels.csv`` with image file names (Keras default
            name; TODO confirm against the real CSV header).
        y_col: Column of ``labels.csv`` with the numeric regression target
            (Keras default name; TODO confirm against the real CSV header).
        target_size: ``(height, width)`` every image is resized to.
        batch_size: Number of images per batch.
        seed: Random seed for shuffling and augmentation.

    Returns:
        A Keras ``DataFrameIterator`` yielding ``(images, targets)`` batches,
        with pixel values rescaled to [0, 1] and random horizontal/vertical
        flips applied.
    """
    labels_path = os.path.join(path, 'labels.csv')
    data_df = pd.read_csv(labels_path)
    if image_dir is None:
        image_dir = path
    # validation_split must be set on the generator: `subset` only carves out
    # a training/validation portion when a split fraction is configured.
    train_datagen = ImageDataGenerator(validation_split=validation_split,
                                       horizontal_flip=True,
                                       vertical_flip=True,
                                       rescale=1. / 255)
    train_datagen_flow = train_datagen.flow_from_dataframe(
        dataframe=data_df,
        # Must be the images directory, not the CSV file: Keras joins it with
        # each file name, so passing labels.csv resolved to no images.
        directory=image_dir,
        x_col=x_col,
        y_col=y_col,
        target_size=target_size,
        batch_size=batch_size,
        # 'raw' passes the numeric target through unchanged; 'sparse' treats
        # it as class labels, which does not fit the MSE regression head.
        class_mode='raw',
        subset='training',
        seed=seed)
    return train_datagen_flow
def load_test(path, image_dir=None, validation_split=0.25,
              x_col='filename', y_col='class',
              target_size=(150, 150), batch_size=16, seed=12345):
    """Build the evaluation iterator over the validation subset of labels.csv.

    Use the same ``validation_split`` as ``load_train`` so that this subset
    is the complement of the training subset.

    Args:
        path: Directory containing ``labels.csv``.
        image_dir: Directory holding the image files named in ``x_col``.
            Defaults to ``path``. NOTE(review): if the images live in a
            sub-folder of ``path``, pass it here — confirm the dataset layout.
        validation_split: Fraction of rows reserved for this subset.
        x_col: Column of ``labels.csv`` with image file names (Keras default
            name; TODO confirm against the real CSV header).
        y_col: Column of ``labels.csv`` with the numeric regression target
            (Keras default name; TODO confirm against the real CSV header).
        target_size: ``(height, width)`` every image is resized to.
        batch_size: Number of images per batch.
        seed: Random seed passed to the iterator.

    Returns:
        A Keras ``DataFrameIterator`` yielding ``(images, targets)`` batches,
        with pixel values rescaled to [0, 1] and no augmentation.
    """
    labels_path = os.path.join(path, 'labels.csv')
    data_df = pd.read_csv(labels_path)
    if image_dir is None:
        image_dir = path
    # Evaluation data is only rescaled. Random flips would make validation
    # metrics depend on augmentation noise. The original also built this
    # generator under the wrong name, so `test_datagen` was undefined.
    test_datagen = ImageDataGenerator(validation_split=validation_split,
                                      rescale=1. / 255)
    test_datagen_flow = test_datagen.flow_from_dataframe(
        dataframe=data_df,
        # Must be the images directory, not the CSV file: Keras joins it with
        # each file name, so passing labels.csv resolved to no images.
        directory=image_dir,
        x_col=x_col,
        y_col=y_col,
        target_size=target_size,
        batch_size=batch_size,
        # 'raw' passes the numeric target through unchanged, matching the
        # MSE regression head.
        class_mode='raw',
        subset='validation',
        seed=seed)
    return test_datagen_flow
def create_model(input_shape, learning_rate=0.00001):
    """Build and compile a ResNet50-based model with a single regression output.

    Args:
        input_shape: Shape of one input image, passed to ``ResNet50``.
            Presumably ``(150, 150, 3)`` to match the loaders' ``target_size``
            and RGB images — TODO confirm at the call site.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A compiled Keras ``Sequential`` model (loss: MSE, metric: MAE).
    """
    # ImageNet-pretrained convolutional backbone without its classifier head.
    backbone = ResNet50(input_shape=input_shape,
                        weights='imagenet',
                        include_top=False)
    model = Sequential()
    model.add(backbone)
    # Collapse the backbone's spatial feature maps to one vector per image.
    model.add(GlobalAveragePooling2D())
    # A single ReLU unit: one non-negative predicted value per image.
    model.add(Dense(1, activation='relu'))
    # Use `learning_rate`, not the deprecated `lr` alias. Depending on the
    # version, newer Keras optimizers either ignore `lr` with only a warning
    # (silently training at the default rate) or reject it.
    optimizer = Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])
    return model
def train_model(model, train_data, test_data, batch_size=None, epochs=20,
                steps_per_epoch=None, validation_steps=None):
    """Fit ``model`` on ``train_data``, validating on ``test_data`` each epoch.

    Args:
        model: A compiled model exposing ``fit``.
        train_data: Training batches; ``len(train_data)`` is used as the
            default number of steps per epoch.
        test_data: Validation batches; ``len(test_data)`` is used as the
            default number of validation steps.
        batch_size: Forwarded unchanged to ``model.fit``.
        epochs: Number of training epochs.
        steps_per_epoch: Training batches per epoch; ``None`` means
            ``len(train_data)``.
        validation_steps: Validation batches per epoch; ``None`` means
            ``len(test_data)``.

    Returns:
        The same ``model`` object after training.
    """
    fit_steps = len(train_data) if steps_per_epoch is None else steps_per_epoch
    val_steps = len(test_data) if validation_steps is None else validation_steps
    model.fit(
        train_data,
        validation_data=test_data,
        batch_size=batch_size,
        epochs=epochs,
        steps_per_epoch=fit_steps,
        validation_steps=val_steps,
        verbose=2,  # one log line per epoch
        shuffle=True,
    )
    return model
# Source: codebin paste "Re: Re: project" by Rude Frog (written in Plain Text,
# posted 3 months before capture, viewed 77 times), a reply to "Re: project"
# from Ksenia. URL: http://codebin.org/view/34c5a8ee