Re: Re: project

From Rude Frog, 3 Months ago, written in Plain Text, viewed 77 times. This paste is a reply to Re: project from Ksenia - view diff
URL http://codebin.org/view/34c5a8ee Embed
Download Paste or View Raw
import os

import pandas as pd
from tensorflow.keras.applications.resnet import ResNet50
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
  6.  
  7.  
  8. def load_train(path):
  9.     labels_path = os.path.join(path, 'labels.csv')
  10.     data_df = pd.read_csv(labels_path)
  11.     train_datagen = ImageDataGenerator(horizontal_flip=True,
  12.                                        vertical_flip=True,
  13.                                        rescale=1. / 255)
  14.  
  15.     train_datagen_flow = train_datagen.flow_from_dataframe(
  16.         data_df,
  17.         labels_path,
  18.         target_size=(150, 150),
  19.         batch_size=16,
  20.         class_mode='sparse',
  21.         subset='training',
  22.         seed=12345)
  23.  
  24.     return train_datagen_flow
  25.  
  26. def load_test(path):
  27.     labels_path = os.path.join(path, 'labels.csv')
  28.     data_df = pd.read_csv(labels_path)
  29.     train_datagen = ImageDataGenerator(horizontal_flip=True,
  30.                                        vertical_flip=True,
  31.                                        rescale=1. / 255)
  32.  
  33.     test_datagen_flow = test_datagen.flow_from_dataframe(
  34.         data_df,
  35.         labels_path,
  36.         target_size=(150, 150),
  37.         batch_size=16,
  38.         class_mode='sparse',
  39.         subset='validation',
  40.         seed=12345)
  41.  
  42.     return test_datagen_flow
  43.  
  44. def create_model(input_shape):
  45.     backbone = ResNet50(input_shape=input_shape,
  46.                     weights='imagenet',
  47.                     include_top=False)
  48.     model = Sequential()
  49.     model.add(backbone)
  50.     model.add(GlobalAveragePooling2D())
  51.     model.add(Dense(1, activation='relu'))
  52.  
  53.     optimizer = Adam(lr=0.00001)
  54.     model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])
  55.     return model
  56.  
  57.  
  58. def train_model(model, train_data, test_data, batch_size=None, epochs=20, steps_per_epoch=None, validation_steps=None):
  59.     test_gen_flow = test_data
  60.     if steps_per_epoch is None:
  61.         steps_per_epoch = len(train_data)
  62.     if validation_steps is None:
  63.         validation_steps = len(test_gen_flow)
  64.     model.fit(train_data,
  65.               validation_data=test_data,
  66.               batch_size=batch_size, epochs=epochs,
  67.               steps_per_epoch=steps_per_epoch,
  68.               validation_steps=validation_steps,
  69.               verbose=2, shuffle=True)
  70.  
  71.     return model

Replies to Re: Re: project rss

Title Name Language When
Re: Re: Re: project Crippled Anoa text 3 Months ago.

Reply to "Re: Re: project"

Here you can reply to the paste above